# Advanced Computer Vision - Week_08 - Conditional GAN (dataset: faces)

In [5]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision.utils import make_grid
from torchvision.datasets import ImageFolder
import torchvision.transforms as T
import matplotlib.pyplot as plt
%matplotlib inline
import os
from tqdm import tqdm
import torch.nn.functional as F
from torchvision.utils import save_image
import pandas as pd

In [6]:
df_daces_metadata = pd.read_csv("/home/mh731nk/_data/experiments_tmp/data/celeba_dataset/list_attr_celeba.csv")

In [7]:
df_daces_metadata.columns

Index(['image_id', '5_o_Clock_Shadow', 'Arched_Eyebrows', 'Attractive',
       'Bags_Under_Eyes', 'Bald', 'Bangs', 'Big_Lips', 'Big_Nose',
       'Black_Hair', 'Blond_Hair', 'Blurry', 'Brown_Hair', 'Bushy_Eyebrows',
       'Chubby', 'Double_Chin', 'Eyeglasses', 'Goatee', 'Gray_Hair',
       'Heavy_Makeup', 'High_Cheekbones', 'Male', 'Mouth_Slightly_Open',
       'Mustache', 'Narrow_Eyes', 'No_Beard', 'Oval_Face', 'Pale_Skin',
       'Pointy_Nose', 'Receding_Hairline', 'Rosy_Cheeks', 'Sideburns',
       'Smiling', 'Straight_Hair', 'Wavy_Hair', 'Wearing_Earrings',
       'Wearing_Hat', 'Wearing_Lipstick', 'Wearing_Necklace',
       'Wearing_Necktie', 'Young'],
      dtype='object')

In [8]:
df_daces_metadata

Unnamed: 0,image_id,5_o_Clock_Shadow,Arched_Eyebrows,Attractive,Bags_Under_Eyes,Bald,Bangs,Big_Lips,Big_Nose,Black_Hair,...,Sideburns,Smiling,Straight_Hair,Wavy_Hair,Wearing_Earrings,Wearing_Hat,Wearing_Lipstick,Wearing_Necklace,Wearing_Necktie,Young
0,000001.jpg,-1,1,1,-1,-1,-1,-1,-1,-1,...,-1,1,1,-1,1,-1,1,-1,-1,1
1,000002.jpg,-1,-1,-1,1,-1,-1,-1,1,-1,...,-1,1,-1,-1,-1,-1,-1,-1,-1,1
2,000003.jpg,-1,-1,-1,-1,-1,-1,1,-1,-1,...,-1,-1,-1,1,-1,-1,-1,-1,-1,1
3,000004.jpg,-1,-1,1,-1,-1,-1,-1,-1,-1,...,-1,-1,1,-1,1,-1,1,1,-1,1
4,000005.jpg,-1,1,1,-1,-1,-1,1,-1,-1,...,-1,-1,-1,-1,-1,-1,1,-1,-1,1
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
202594,202595.jpg,-1,-1,1,-1,-1,-1,1,-1,-1,...,-1,-1,-1,-1,-1,-1,1,-1,-1,1
202595,202596.jpg,-1,-1,-1,-1,-1,1,1,-1,-1,...,-1,1,1,-1,-1,-1,-1,-1,-1,1
202596,202597.jpg,-1,-1,-1,-1,-1,-1,-1,-1,1,...,-1,1,-1,-1,-1,-1,-1,-1,-1,1
202597,202598.jpg,-1,1,1,-1,-1,-1,1,-1,1,...,-1,1,-1,1,1,-1,1,-1,-1,1


In [15]:
df_daces_metadata.loc[df_daces_metadata["Male"]==1].shape

(84434, 41)

In [16]:
image_size = 64
batch_size = 265
stats = (0.5, 0.5, 0.5), (0.5, 0.5, 0.5) # mean, std for normalize imagess
latent_size = 128
lr = 0.00025
epochs = 60

In [17]:
df_daces_metadata.loc[df_daces_metadata["Blond_Hair"]==1].shape

(29983, 41)

In [18]:
df_daces_metadata.loc[df_daces_metadata["Brown_Hair"]==1].shape

(41572, 41)

# Hyperparameters

In [19]:
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
device

device(type='cuda')

In [20]:
sample_dir = 'generated_lab_week08'
os.makedirs(sample_dir, exist_ok=True)

# Helpers

In [21]:
def to_device(data, device):
  """Move tensor(s) to chosen device"""
  if isinstance(data, (list,tuple)):
      return [to_device(x, device) for x in data]
  return data.to(device, non_blocking=True)

class DeviceDataLoader():
    """Wrap a dataloader to move data to a device"""
    def __init__(self, dl, device):
        self.dl = dl
        self.device = device
        
    def __iter__(self):
        """Yield a batch of data after moving it to device"""
        for b in self.dl: 
            yield to_device(b, self.device)

    def __len__(self):
        """Number of batches"""
        return len(self.dl)

In [22]:
def save_samples(index, latent_tensors, show=True):
  fake_images = generator(latent_tensors)
  fake_fname = 'generated=images-{0:0=4d}.png'.format(index)
  save_image(denorm(fake_images), os.path.join(sample_dir, fake_fname), nrow=8)
  print("Saving", fake_fname)

  if show:
    fig, ax = plt.subplots(figsize=(8,8))
    ax.set_xticks([]); ax.set_yticks([])
    ax.imshow(make_grid(fake_images.cpu().detach(), nrow=8).permute(1, 2, 0))

In [23]:
def show_images(images, nmax=64):
  fig, ax = plt.subplots(figsize=(8,8))
  ax.set_xticks([]); ax.set_yticks([])
  ax.imshow(make_grid(denorm(images.detach()[:nmax]), nrow=8).permute(1, 2, 0))
  
def show_batch(dl, nmax=64):
  for images, _ in dl:
    show_images(images, nmax)
    break

# Dataset prepair

In [24]:
DATA_DIR = '/home/mh731nk/_data/experiments_tmp/data/celeba_dataset/img_align_celeba'

In [25]:
from PIL import Image
import torch
from torchvision.datasets import ImageFolder
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
# Load CSV file with labels
csv_file = '/home/mh731nk/_data/experiments_tmp/data/celeba_dataset/list_attr_celeba.csv'
df = pd.read_csv(csv_file)

class CustomDataset(Dataset):
    def __init__(self, root_dir, csv_file, transform=None):
        self.image_folder = ImageFolder(root=root_dir, transform=None)  # Ensure ImageFolder returns PIL Images
        self.labels_df = pd.read_csv(csv_file)
        self.transform = transform

    def __len__(self):
        return len(self.image_folder)

    def __getitem__(self, idx):
        img_path, _ = self.image_folder.imgs[idx]
        img = Image.open(img_path).convert('RGB')  # Open image as PIL Image
        
        image_id = img_path.split('/')[-1]  # Extract image_id from image path
        labels = self.labels_df[self.labels_df['image_id'] == image_id].iloc[:, 1:].values.flatten()
        labels = torch.IntTensor(labels)

        if self.transform:
            img = self.transform(img)

        return img, labels

# Define transformations
transform = transforms.Compose([
    transforms.Resize(image_size),
    transforms.CenterCrop(image_size),
    transforms.ToTensor(),
    transforms.Normalize(*stats)
])

# Create custom dataset and dataloader
custom_dataset = CustomDataset(root_dir=DATA_DIR, csv_file=csv_file, transform=transform)
custom_dataloader = DataLoader(custom_dataset, batch_size=batch_size, shuffle=True, num_workers=0, pin_memory=True)


In [30]:
for images,labels in custom_dataloader:
    print(labels.shape)
    assert False



torch.Size([265, 40])


AssertionError: 

In [17]:
# train_ds = ImageFolder(root=DATA_DIR, 
#             transform=T.Compose([T.Resize(image_size),
#                                 T.CenterCrop(image_size), # pick central square crop of it
#                                 T.ToTensor(),
#                                 T.Normalize(*stats)        # normalize => -1 to 1                               
#                             ]))



In [18]:

# train_dl = DataLoader(train_ds, batch_size, shuffle=True, num_workers=3, pin_memory=True) # use multiple cores

In [31]:
# helper functions to denormalize the image tensors and display some sample images from a training batch. 
# In future for helping train could be calculated zero centered nomralization per channel :)
def denorm(img_tensors):
    "Denormalize image tensor with specified mean and std"
    return img_tensors * stats[1][0] + stats[0][0]

In [32]:
def show_images(images, nmax=64):
  fig, ax = plt.subplots(figsize=(8,8))
  ax.set_xticks([]); ax.set_yticks([])
  ax.imshow(make_grid(denorm(images.detach()[:nmax]), nrow=8).permute(1, 2, 0))
  
def show_batch(dl, nmax=64):
  for images, _ in dl:
    show_images(images, nmax)
    break

In [33]:
# show_batch(custom_dataloader) # original data looks "cudne pokrucene"

In [34]:
# move to GPU
train_dl = DeviceDataLoader(custom_dataloader, device)

# Models

## Generator

In [35]:
latent_size

128

In [36]:
def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0)


In [61]:
# import torch
# import torch.nn as nn

class Generator(nn.Module):
    def __init__(self, latent_size, number_of_class):
        super(Generator, self).__init__()
        self.latent_size = latent_size
        self.number_of_class = number_of_class
        self.embeding_size = self.latent_size

        self.label_embeddings = nn.Embedding(self.number_of_class,self.embeding_size)

        self.conv1 = nn.ConvTranspose2d(self.latent_size + self.embeding_size, 512, kernel_size=4, stride=1, padding=0, bias=False)
        self.conv2 = nn.ConvTranspose2d(512, 256, kernel_size=4, stride=2, padding=1, bias=False)
        self.conv3 = nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1, bias=False)
        self.conv4 = nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1, bias=False)
        self.conv5 = nn.ConvTranspose2d(64, 3, kernel_size=4, stride=2, padding=1, bias=False)

        self.bn1 = nn.BatchNorm2d(512)
        self.bn2 = nn.BatchNorm2d(256)
        self.bn3 = nn.BatchNorm2d(128)
        self.bn4 = nn.BatchNorm2d(64)
        
        self.relu = nn.ReLU(True)
        self.tanh = nn.Tanh()

    def forward(self, input_latent, input_label):
        print(input_latent.shape)
        print(input_label.shape)
        label_embeddings = self.label_embeddings(input_label).view(len(input_label), self.embeding_size, 1, 1)
        print(input_label.shape)
        input_cat = torch.cat([input_latent, label_embeddings], 1)
        x = self.relu(self.bn1(self.conv1(input_cat)))
        x = self.relu(self.bn2(self.conv2(x)))
        x = self.relu(self.bn3(self.conv3(x)))
        x = self.relu(self.bn4(self.conv4(x)))
        x = self.tanh(self.conv5(x))
        return x


class Generator(nn.Module):
    def __init__(self, emb_size=128):
        super(Generator,self).__init__()
        self.emb_size = emb_size
        self.label_embeddings = nn.Embedding(40, self.emb_size)
        self.model = nn.Sequential(
            nn.ConvTranspose2d(100+self.emb_size,64*8,4,1,0,bias=False),
            nn.BatchNorm2d(64*8),
            nn.ReLU(True),
            nn.ConvTranspose2d(64*8,64*4,4,2,1,bias=False),
            nn.BatchNorm2d(64*4),
            nn.ReLU(True),
            nn.ConvTranspose2d(64*4,64*2,4,2,1,bias=False),
            nn.BatchNorm2d(64*2),
            nn.ReLU(True),
            nn.ConvTranspose2d(64*2,64,4,2,1,bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(True),
            nn.ConvTranspose2d(64,3,4,2,1,bias=False),
            nn.Tanh()
        )
        self.apply(weights_init)


    def forward(self,input_noise,labels):
        label_embeddings = self.label_embeddings(labels).view(len(labels), self.emb_size, 1)
        
        input_x = torch.cat([input_noise, label_embeddings], 1)
        return self.model(input_x)


generator = Generator().to(device)



## Discriminator

In [62]:
import torch
import torch.nn as nn

# class Discriminator(nn.Module):
#     def __init__(self):
#         super(Discriminator, self).__init__()

#         self.label_embeddings = nn.Embedding(2, self.emb_size)
#         # Define the leaky ReLU activation
#         self.leaky_relu = nn.LeakyReLU(0.2, inplace=True)

#         # Define the convolutional layers
#         self.conv1 = nn.Conv2d(3, 64, kernel_size=4, stride=2, padding=1, bias=False)
#         self.bn1 = nn.BatchNorm2d(64)

#         self.conv2 = nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1, bias=False)
#         self.bn2 = nn.BatchNorm2d(128)

#         self.conv3 = nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1, bias=False)
#         self.bn3 = nn.BatchNorm2d(256)

#         self.conv4 = nn.Conv2d(256, 512, kernel_size=4, stride=2, padding=1, bias=False)
#         self.bn4 = nn.BatchNorm2d(512)

#         self.conv5 = nn.Conv2d(512, 1, kernel_size=4, stride=1, padding=0, bias=False)

#         self.flatten = nn.Flatten()
#         self.sigmoid = nn.Sigmoid()

#     def forward(self, input):
#         # Define the forward pass
#         x = self.leaky_relu(self.bn1(self.conv1(input)))
#         x = self.leaky_relu(self.bn2(self.conv2(x)))
#         x = self.leaky_relu(self.bn3(self.conv3(x)))
#         x = self.leaky_relu(self.bn4(self.conv4(x)))
#         x = self.conv5(x)
#         x = self.flatten(x)
#         x = self.sigmoid(x)
#         return x
class Discriminator(nn.Module):
    def __init__(self, number_of_class,embeding_size=128):
        super(Discriminator, self).__init__()
        
        self.number_of_class = number_of_class
        self.embeding_size = embeding_size

        self.label_embeddings = nn.Embedding(40, self.embeding_size )
        
        self.model = nn.Sequential(
            nn.Conv2d(3,64,4,2,1,bias=False),
            nn.LeakyReLU(0.2,inplace=True),
            nn.Conv2d(64,64*2,4,2,1,bias=False),
            nn.BatchNorm2d(64*2),
            nn.LeakyReLU(0.2,inplace=True),
            nn.Conv2d(64*2,64*4,4,2,1,bias=False),
            nn.BatchNorm2d(64*4),
            nn.LeakyReLU(0.2,inplace=True),
            nn.Conv2d(64*4,64*8,4,2,1,bias=False),
            nn.BatchNorm2d(64*8),
            nn.LeakyReLU(0.2,inplace=True),
            nn.Conv2d(64*8,64,4,2,1,bias=False),
            nn.BatchNorm2d(64),
            nn.LeakyReLU(0.2,inplace=True),
            nn.Flatten()
        )
        self.model2 = nn.Sequential(
            nn.Linear(288,100),
            nn.LeakyReLU(0.2,inplace=True),
            nn.Linear(100,1),
            nn.Sigmoid()
        )
        self.apply(weights_init)


    def forward(self, img, labels):
        x = self.model(img)
        y = self.label_embeddings(labels)
        print(x.shape)
        print(y.shape)
        z = torch.cat([x, y], 1)
        final_output = self.model2(z)
        return final_output


discriminator = Discriminator(40,128).to(device)

In [63]:
# # generator = Generator(latent_size, 41)
# generator = Generator()
# # self.apply(weights_init)
# # generator = weights_init(generator)
# discriminator = Discriminator(41)

Weight Initialization
We define the weight initialization so that we do not have a widespread variation across randomly initialized weight values


In [64]:
# from torchsummary import summary
# summary(generator, (128,1,1))
# print(generator)

## Check out

In [65]:
torch.ones((1,41))

tensor([[1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
         1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
         1., 1., 1., 1., 1.]])

In [66]:
def noise(size):
    n = torch.randn(size, 100, 1, 1, device=device)
    return n.to(device)


In [67]:
def discriminator_train_step(real_data, real_labels, fake_data, fake_labels):
    d_optimizer.zero_grad()

    prediction_real = discriminator(real_data, real_labels)
    error_real = loss(prediction_real, torch.ones(len(real_data), 1).to(device))
    error_real.backward()

    prediction_fake = discriminator(fake_data, fake_labels)
    error_fake = loss(prediction_fake, torch.zeros(len(fake_data), 1).to(device))
    error_fake.backward()
    
    d_optimizer.step()    
    return error_real + error_fake


In [68]:
def generator_train_step(fake_data, fake_labels):
    g_optimizer.zero_grad()

    prediction = discriminator(fake_data, fake_labels)

    error = loss(prediction, torch.ones(len(fake_data), 1).to(device))
    
    error.backward()
    g_optimizer.step()
    return error


In [69]:
# !pip install  torch_snippets

In [70]:
from torch_snippets import *
import torch
from torchvision.utils import make_grid
from torch_snippets import *
from PIL import Image
import torchvision
from torchvision import transforms
import torchvision.utils as vutils
from tqdm import tqdm


In [71]:
print(device)

# generator = Generator().to(device)

In [72]:
loss = nn.BCELoss()



d_optimizer = optim.Adam(discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999))
g_optimizer = optim.Adam(generator.parameters(), lr=0.0002, betas=(0.5, 0.999))


# fixed_noise will be used to generate images from random noise
fixed_noise = torch.randn(64, 100, 1, 1, device=device)


# Half labels correspond to class 0 and remaining to class 1
fixed_fake_labels = torch.LongTensor([0]*(len(fixed_noise)//2) + [1]*(len(fixed_noise)//2)).to(device)


n_epochs = 25
img_list = []


In [73]:
type(fixed_noise)

torch.Tensor

In [74]:
torch.LongTensor(np.random.randint(0, 2, (256,41,1,1))).shape

torch.Size([256, 41, 1, 1])

In [75]:
# Train the model for 25 epochs
for epoch in tqdm(range(n_epochs), total = n_epochs):
    N = len(custom_dataloader)


    for bx, (images, labels) in enumerate(custom_dataloader):
	    # Obtain the data
        real_data, real_labels = images.to(device), labels.to(device)

        fake_labels = torch.LongTensor(np.random.randint(0, 2, (len(real_data), 40))).to(device)

        print(device)
        print("real_data " , type(real_data))
        print(real_data.shape)

        print("-----")
        print("real_labels " ,type(real_labels))
        print(real_labels.shape)

        print("fake_labels " ,type(fake_labels))
        print(fake_labels.shape)


        fake_data = generator(noise(len(real_data)), fake_labels)
        # fake_data = fake_data.detach()


        # Train discriminator
       
        print("fake_data " ,type(fake_data))
        print(fake_data.shape)


        # d_loss = discriminator_train_step(real_data, real_labels.to(torch.int64), fake_data, fake_labels)

        d_loss = discriminator_train_step(real_data, real_labels, fake_data, fake_labels)
        
        
        fake_labels = torch.LongTensor(np.random.randint(0, 2, len(real_data))).to(device)

        # Train generator
        fake_data = generator(noise(len(real_data)), fake_labels).to(device)
        
        g_loss = generator_train_step(fake_data, fake_labels)

	# Log to wandb
        # wandb.log(
        #     {
        #         'd_loss':d_loss.detach(),
        #         'g_loss':g_loss.detach()
        #     }
        # )
	
    # Inference
    with torch.no_grad():
        fake = generator(fixed_noise, fixed_fake_labels).detach().cpu()
        imgs = vutils.make_grid(fake, padding=2, normalize=True).permute(1,2,0)
        img_list.append(imgs)
        
        show(imgs, sz=10)


  0%|          | 0/25 [00:00<?, ?it/s]

  0%|          | 0/25 [00:01<?, ?it/s]


RuntimeError: shape '[265, 128, 1]' is invalid for input of size 1356800

In [None]:
import torch
import torch.nn as nn

# class Discriminator(nn.Module):
#     def __init__(self):
#         super(Discriminator, self).__init__()

#         self.label_embeddings = nn.Embedding(2, self.emb_size)
#         # Define the leaky ReLU activation
#         self.leaky_relu = nn.LeakyReLU(0.2, inplace=True)

#         # Define the convolutional layers
#         self.conv1 = nn.Conv2d(3, 64, kernel_size=4, stride=2, padding=1, bias=False)
#         self.bn1 = nn.BatchNorm2d(64)

#         self.conv2 = nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1, bias=False)
#         self.bn2 = nn.BatchNorm2d(128)

#         self.conv3 = nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1, bias=False)
#         self.bn3 = nn.BatchNorm2d(256)

#         self.conv4 = nn.Conv2d(256, 512, kernel_size=4, stride=2, padding=1, bias=False)
#         self.bn4 = nn.BatchNorm2d(512)

#         self.conv5 = nn.Conv2d(512, 1, kernel_size=4, stride=1, padding=0, bias=False)

#         self.flatten = nn.Flatten()
#         self.sigmoid = nn.Sigmoid()

#     def forward(self, input):
#         # Define the forward pass
#         x = self.leaky_relu(self.bn1(self.conv1(input)))
#         x = self.leaky_relu(self.bn2(self.conv2(x)))
#         x = self.leaky_relu(self.bn3(self.conv3(x)))
#         x = self.leaky_relu(self.bn4(self.conv4(x)))
#         x = self.conv5(x)
#         x = self.flatten(x)
#         x = self.sigmoid(x)
#         return x
class Discriminator(nn.Module):
    def __init__(self, number_of_class,embeding_size=128,):
        super(Discriminator, self).__init__()
        
        self.number_of_class = number_of_class
        self.embeding_size = embeding_size

        self.label_embeddings = nn.Embedding(41, self.embeding_size )
        
        self.model = nn.Sequential(
            nn.Conv2d(3,64,4,2,1,bias=False),
            nn.LeakyReLU(0.2,inplace=True),
            nn.Conv2d(64,64*2,4,2,1,bias=False),
            nn.BatchNorm2d(64*2),
            nn.LeakyReLU(0.2,inplace=True),
            nn.Conv2d(64*2,64*4,4,2,1,bias=False),
            nn.BatchNorm2d(64*4),
            nn.LeakyReLU(0.2,inplace=True),
            nn.Conv2d(64*4,64*8,4,2,1,bias=False),
            nn.BatchNorm2d(64*8),
            nn.LeakyReLU(0.2,inplace=True),
            nn.Conv2d(64*8,64,4,2,1,bias=False),
            nn.BatchNorm2d(64),
            nn.LeakyReLU(0.2,inplace=True),
            nn.Flatten()
        )
        self.model2 = nn.Sequential(
            nn.Linear(288,100),
            nn.LeakyReLU(0.2,inplace=True),
            nn.Linear(100,1),
            nn.Sigmoid()
        )
        self.apply(weights_init)


    def forward(self, input, labels):
        x = self.model(input)
        y = self.label_embeddings(labels)
        print(x.shape)
        print(y.shape)
        input = torch.cat([x, y], 1)
        final_output = self.model2(input)
        return final_output


discriminator = Discriminator(41).to(device)

In [None]:
assert False

AssertionError: 

In [None]:
import numpy as np
xb = torch.randn(batch_size, 100, 1, 1) # random latent tensors
fake_images = generator(xb, torch.ones((41,1,1), dtype=torch.int))
# print(fake_images.shape)
# show_images(fake_images)

RuntimeError: Sizes of tensors must match except in dimension 1. Expected size 265 but got size 41 for tensor number 1 in the list.

In [None]:
generator = to_device(generator, device) 
discriminator = to_device(discriminator, device)
lr

0.00025

In [None]:

def train(epochs, lr, start_idx = 1):
  torch.cuda.empty_cache()

  # Losses & scores
  losses_g = []
  losses_d = []
  real_scores = []
  fake_scores = []

  # Create optimizers
  opt_d = torch.optim.Adam(discriminator.parameters(), lr=lr, betas=(0.5, 0.999))
  opt_g = torch.optim.Adam(generator.parameters(), lr=lr, betas=(0.5, 0.999))

  for idx, epoch in enumerate(range(0,epochs)):
    print(f"epoch {idx} start")
    for real_images, _ in tqdm(train_dl):


        ### Train discriminator
        # Clear discriminator gradients
        opt_d.zero_grad()

        # Pass real images through  discriminator
        real_preds = discriminator(real_images)
        real_targets = torch.ones(real_images.size(0), 1, device=device)
        real_loss = F.binary_cross_entropy(real_preds, real_targets)
        real_score = torch.mean(real_preds).item()

        # Generate fake images
        latent = torch.randn(batch_size, latent_size, 1, 1, device=device)
        fake_images = generator(latent)

        # Pass Fake images through discriminator
        fake_targets = torch.zeros(fake_images.size(0), 1, device=device)
        fake_preds = discriminator(fake_images)
        fake_loss = F.binary_cross_entropy(fake_preds, fake_targets)
        fake_score = torch.mean(fake_preds).item()

        # Update discriminator weights
        loss_d = real_loss + fake_loss
        loss_d.backward()
        opt_d.step()

        # loss_d, real_score, fake_score = train_discriminator(real_images, opt_d)




        ### Train generator
        # Clear generator gradients
        opt_g.zero_grad()

        # Generate fake images
        latent = torch.randn(batch_size, latent_size, 1,1, device=device)
        fake_images = generator(latent)

        # Try to fool the discriminator
        preds = discriminator(fake_images)
        targets = torch.ones(batch_size, 1, device=device)
        loss_g = F.binary_cross_entropy(preds, targets)

        # Update generator 
        loss_g.backward()
        opt_g.step()



    # Record losses & scores
    losses_g.append(loss_g)
    losses_d.append(loss_d.item())
    real_scores.append(real_score)
    fake_scores.append(fake_score)

    # Log losses & scores (last batch)
    print("Epoch [{}/{}], loss_g: {:.4f}, loss_d: {:.4f}, real_score: {:.4f}, fake_score: {:.4f}".format(epoch+1, epochs, loss_g, loss_d, real_score, fake_score))
    # Save generated images
    save_samples(epoch+start_idx, fixed_latent, show=False)

  return losses_g, losses_d, real_scores, fake_scores

## Test latent space
We'll use a fixed set of input vectors to the generator to see how the individual generated images evolve over time as we train the model. Let's save one set of images before we start training our model.

In [None]:
fixed_latent = torch.randn(64, latent_size, 1, 1, device=device)

In [None]:
history = train(epochs, lr)

epoch 0 start


 70%|██████▉   | 532/765 [00:29<00:12, 18.72it/s]