In [1]:
import os

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torchvision.utils as vutils
from torch.utils.tensorboard import SummaryWriter

# NOTE(review): star imports hide where names come from; Generator and Discriminator
# (and load_ImageNet100) are expected from these local modules. Prefer explicit names.
from Generator import *
from Discriminator import *
from load_datasets import *

# TensorBoard writer used by the training cell for loss / D-output curves.
writer = SummaryWriter(log_dir='../GAN/TensorboardLogs/')

device = 'cuda' if torch.cuda.is_available() else 'cpu'
latent_dim = 100      # length of the noise vector z fed to the Generator
image_channels = 3    # RGB images
total_filters = 128   # width hyper-parameter passed to both networks (exact meaning defined in Generator/Discriminator)
epochs = 20

# Seed torch's RNGs (CPU and all CUDA devices) before the models are built, so weight init,
# latent noise and DataLoader shuffling repeat across runs (cuDNN kernels may still be
# nondeterministic).
seed = 999
torch.manual_seed(seed)

# Alternative dataset loader from load_datasets; the training cell currently builds a
# CIFAR-10 DataLoader instead.
#dl_train,_,_ = load_ImageNet100(validate=True)

real_label = 1.0   # BCE target for real images
fake_label = 0.0   # BCE target for generated images
g_loss = []        # intended to collect Generator losses during training
d_loss = []        # intended to collect Discriminator losses during training




In [2]:
# Both networks are sized from the same config-cell hyper-parameters, then moved to `device`.
_shared_hparams = dict(
    latent_dim=latent_dim,
    image_channels=image_channels,
    total_filters=total_filters,
)
model_Generator = Generator(**_shared_hparams).to(device)
model_Discriminator = Discriminator(**_shared_hparams).to(device)

In [3]:

# Data: CIFAR-10 resized to 64 px and normalised to [-1, 1] (mean = std = 0.5 per channel).
dataset = dset.CIFAR10(root="../GAN/data/", download=True,
                       transform=transforms.Compose([
                           transforms.Resize(64),
                           transforms.ToTensor(),
                           transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                       ]))

dl_train = torch.utils.data.DataLoader(dataset, batch_size=128,
                                       shuffle=True, num_workers=2)

criterion = nn.BCELoss()
generator_lr = 0.0001
discriminator_lr = 0.00001   # D is deliberately given a 10x smaller step than G

optimizerD = optim.Adam(model_Discriminator.parameters(), lr=discriminator_lr, betas=(0.5, 0.999))
optimizerG = optim.Adam(model_Generator.parameters(), lr=generator_lr, betas=(0.5, 0.999))

# Write a sample grid every `sample_interval` batches (the old `i % 1` wrote two PNGs per batch).
sample_interval = 100
sample_dir = '../GAN/Samples/Fakes'
os.makedirs(sample_dir, exist_ok=True)   # save_image does not create missing directories

# A fixed latent batch makes successive sample grids directly comparable.
fixed_noise = torch.randn(64, latent_dim, 1, 1, device=device)

for epoch in range(epochs):
    for i, data in enumerate(dl_train, 0):
        # ---- (1) Discriminator step: real batch labelled real, fake batch labelled fake ----
        model_Discriminator.zero_grad()
        real_images = data[0].to(device)
        batch_size = real_images.size(0)
        label = torch.full((batch_size,), real_label, device=device)

        # view(-1) flattens e.g. (B, 1, 1, 1) to (B,) so BCELoss gets matching shapes;
        # it is a no-op if the Discriminator already returns (B,).
        output = model_Discriminator(real_images).view(-1)
        errD_real = criterion(output, label)
        errD_real.backward()
        D_x = output.mean().item()

        noise = torch.randn(batch_size, latent_dim, 1, 1, device=device)
        fake = model_Generator(noise)
        label.fill_(fake_label)
        # detach(): the Discriminator update must not backpropagate into the Generator.
        output = model_Discriminator(fake.detach()).view(-1)
        errD_fake = criterion(output, label)
        errD_fake.backward()
        D_G_z1 = output.mean().item()   # D(G(z)) before the D update
        errD = errD_real + errD_fake
        optimizerD.step()

        # ---- (2) Generator step: push D to label the fakes as real ----
        model_Generator.zero_grad()
        label.fill_(real_label)
        output = model_Discriminator(fake).view(-1)
        errG = criterion(output, label)
        errG.backward()
        D_G_z2 = output.mean().item()   # D(G(z)) after the D update
        optimizerG.step()

        d_loss.append(errD.item())
        g_loss.append(errG.item())

        # add_scalar(tag, value, global_step): one tag per curve, indexed by batch count.
        global_step = epoch * len(dl_train) + i
        writer.add_scalar("Loss/Discriminator", errD.item(), global_step)
        writer.add_scalar("Loss/Generator", errG.item(), global_step)
        writer.add_scalar("D(x)", D_x, global_step)
        writer.add_scalar("D(G(z))/before_D_step", D_G_z1, global_step)
        writer.add_scalar("D(G(z))/after_D_step", D_G_z2, global_step)

        if i % sample_interval == 0:
            # no_grad: sampling only, so don't build an autograd graph.
            with torch.no_grad():
                samples = model_Generator(fixed_noise).cpu()
            vutils.save_image(samples, os.path.join(sample_dir, f'fakes_samples_{epoch}.png'), normalize=True)
            vutils.save_image(samples, os.path.join(sample_dir, 'fakes_samples.png'), normalize=True)

writer.flush()
        

Files already downloaded and verified
