In [69]:
import torch
import torch.nn as nn
from torchvision.transforms import transforms
import torchvision.datasets as datasets
from torch.utils.data import DataLoader
import torch.optim as optim
from torch.utils.tensorboard import SummaryWriter
import torchvision
from datetime import datetime
import numpy as np

now = datetime.now()

In [70]:
class Discriminator(nn.Module):
    def __init__(self,channels_inp,input_features):
        super(Discriminator,self).__init__()
        self.disc=nn.Sequential(
            nn.Conv2d(
                channels_inp,input_features,kernel_size=4,stride=2,padding=1
            ),
            nn.LeakyReLU(0.2),
            self._block(input_features,input_features*2,4,2,1),
            self._block(input_features*2,input_features*4,4,2,1),
            self._block(input_features*4,input_features*8,4,2,1),
            nn.Conv2d(input_features*8,1,kernel_size=4,stride=2,padding=0),
            nn.Sigmoid()    
        )
    def _block(self,in_channels,out_channels,kernalsize,stride,padding):
        return nn.Sequential(
        nn.Conv2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernalsize,
            stride=stride,
            padding=padding,
            bias=False
        ),
        nn.BatchNorm2d(out_channels),
        nn.LeakyReLU(0.2)
        )
    def forward(self,x):
        return self.disc(x)
        

In [71]:
class Generator(nn.Module):
    def __init__(self,z_dim,channels_img,input_features):
        super(Generator,self).__init__()
        self.net = nn.Sequential(   
            self._block(z_dim,input_features*16,4,1,0),
            self._block(input_features*16,input_features*8,4,2,1),
            self._block(input_features*8,input_features*4,4,2,1),
            self._block(input_features*4,input_features*2,4,2,1),
            nn.ConvTranspose2d(
                input_features*2,channels_img,kernel_size=4,stride=2,padding=1
            ),
            nn.Tanh()

        )   
    def _block(self,in_channels,out_channels,kernalsize,stride,padding):
        return nn.Sequential(
            nn.ConvTranspose2d(
                in_channels,
                out_channels,
                kernalsize,
                stride,
                padding,
                bias=False
            ),
            nn.BatchNorm2d(out_channels),
            nn.ReLU()
        )
    def forward(self,x):
        return self.net(x)
        

In [72]:
def init_weights(model):
    for m in model.modules():
        if isinstance(m,(nn.Conv2d,nn.ConvTranspose2d,nn.BatchNorm2d)):
            nn.init.normal_(m.weight.data,0.0,0.02)



In [74]:
DEVICE="cuda" if torch.cuda.is_available() else "cpu"
lr=1e-4
z_dim=100
img_size=64
channels_dim=3
batch_size=128
num_epochs=500

features_disc=64
features_gen=64


disc=Discriminator(channels_dim,features_disc).to(DEVICE) 
gen=Generator(z_dim,channels_dim,features_gen).to(DEVICE)
init_weights(disc)
init_weights(gen)

fixed_noise=torch.randn(32,z_dim,1,1).to(DEVICE)
transforms_img= transforms.Compose(
    [
        transforms.Resize((64,64)), #transforms.Resize(IMAGE_SIZE) resizes propotionally
        transforms.ToTensor(),
        transforms.Normalize([0.5 for _ in range(channels_dim)], [0.5 for _ in range(channels_dim)])
    ]
)
dataset = datasets.ImageFolder(root="./images",transform=transforms_img)
loader=DataLoader(dataset, batch_size=batch_size,shuffle=True)
opt_disc=optim.Adam(disc.parameters(),lr=lr,betas=(0.5,0.999))    
opt_gen=optim.Adam(gen.parameters(),lr=lr,betas=(0.5,0.999))
critereon=nn.BCELoss()
print("TIME: ",now.strftime("%Y%m%d-%H%M%S"))
writer_fake=SummaryWriter(f"runs/DCGAN/fake/"+ now.strftime("%Y%m%d-%H%M%S") + "/")
writer_real=SummaryWriter(f"runs/DCGAN/real/"+ now.strftime("%Y%m%d-%H%M%S") + "/")

step=0
gen.train()
disc.train()
for epoch in range(num_epochs):
    for batch_index, (real,_) in enumerate(loader):
        real=real.to(DEVICE)
        noise=torch.randn((batch_size,z_dim,1,1)).to(DEVICE)

        #Discriminator loss: max(log(D(real)) + log(1 - D(G(z))) )
        fake_img=gen(noise)
        
        disc_real=disc(real).reshape(-1)
        lossD_real=critereon(disc_real,torch.ones_like(disc_real))
        disc_fake=disc(fake_img).reshape(-1)
        lossD_fake=critereon(disc_fake,torch.zeros_like(disc_fake))
        lossD = (lossD_fake+lossD_real)/2

        disc.zero_grad()
        lossD.backward(retain_graph=True)
        opt_disc.step()

        #Train Generator
        output=disc(fake_img).reshape(-1)
        lossG=critereon(output,torch.ones_like(output))
        gen.zero_grad()
        lossG.backward()
        opt_gen.step()

        if(batch_index%100==0):
            print(f'[{epoch}/{num_epochs}--Loss(D):{lossD:.4f}--Loss(G):{lossG:.4f}')

        #start training
        with torch.no_grad():
            fake = gen(fixed_noise)
            img_grid_fake=torchvision.utils.make_grid(fake[:32],normalize=True)
            img_grid_real=torchvision.utils.make_grid(real[:32],normalize=True)

            writer_fake.add_image(
                "Fake img1",img_grid_fake,global_step=step
            )
            writer_real.add_image(
                "Real img1",img_grid_real,global_step=step
            )

        step+=1



TIME:  20230204-030223
[0/500--Loss(D):0.6914--Loss(G):0.7604
[1/500--Loss(D):0.5183--Loss(G):0.9999
[2/500--Loss(D):0.3929--Loss(G):1.1910
[3/500--Loss(D):0.3096--Loss(G):1.3733
[4/500--Loss(D):0.2458--Loss(G):1.5566
[5/500--Loss(D):0.1991--Loss(G):1.7266
[6/500--Loss(D):0.1652--Loss(G):1.8647
[7/500--Loss(D):0.1394--Loss(G):2.0066
[8/500--Loss(D):0.1180--Loss(G):2.1484
[9/500--Loss(D):0.1010--Loss(G):2.2731
[10/500--Loss(D):0.0895--Loss(G):2.3950
[11/500--Loss(D):0.0766--Loss(G):2.5123
[12/500--Loss(D):0.0680--Loss(G):2.6272
[13/500--Loss(D):0.0601--Loss(G):2.7371
[14/500--Loss(D):0.0549--Loss(G):2.8259
[15/500--Loss(D):0.0483--Loss(G):2.9482
[16/500--Loss(D):0.0441--Loss(G):3.0481
[17/500--Loss(D):0.0388--Loss(G):3.1436
[18/500--Loss(D):0.0360--Loss(G):3.2353
[19/500--Loss(D):0.0323--Loss(G):3.3219
[20/500--Loss(D):0.0293--Loss(G):3.4113
[21/500--Loss(D):0.0273--Loss(G):3.4950
[22/500--Loss(D):0.0258--Loss(G):3.5761
[23/500--Loss(D):0.0234--Loss(G):3.6572
[24/500--Loss(D):0.0217--Lo

In [75]:
torch.save(gen,'./models/gen.pt')
torch.save(disc,'./models/disc.pt')