In [185]:
import torch
from pathlib import Path
import torch
import torchvision
from torch import nn
from torchvision import transforms
from dataclasses import dataclass
from torch.utils.tensorboard import SummaryWriter  

In [186]:
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
device

'cuda'

In [187]:
@dataclass
class ModelArgs:
    """Shared configuration for the conditional GAN notebook.

    Every attribute carries a type annotation so that ``@dataclass`` actually
    registers it as a field; without annotations the decorator sees no fields
    and ``ModelArgs(batch_size=64)`` raises ``TypeError``. The defaults remain
    available as class attributes, so class-level use such as
    ``ModelArgs.img_size`` or ``ModelArgs.device = device`` keeps working.
    """

    # Length of the noise vector z; also the width of the generator's label embedding.
    latent_vector_size: int = 100
    # Device string handed to layers/tensors; overwritten after CUDA detection.
    device: str = 'cpu'
    # Batch size used by the data loaders and the dummy/fixed noise batches.
    batch_size: int = 128
    # SGD-style schedule settings; only referenced by commented-out optimizers in the visible cells.
    initial_lr: float = 0.1
    final_lr: float = 1e-6
    decay_factor: float = 1.00004
    momentum_initial: float = 0.5
    final_momentum_value: float = 0.7
    # Not referenced by any visible cell.
    dropout: float = 0.5
    # Number of rows in the label embedding tables.
    num_classes: int = 10
    # Side length the images are resized to (square images).
    img_size: int = 64
    # Upper bound when sampling random labels for fake images
    # (spelling kept as-is because other cells use this name).
    no_of_lables: int = 10
    # Channel count of the images.
    no_of_channels: int = 1


In [188]:
# Store the detected device on the shared config; the model constructors read ModelArgs.device.
ModelArgs.device = device

In [189]:
# Image preprocessing pipeline for MNIST.
# The pipeline is bound to the name `transforms` (the data-loading cell passes it as
# `transform=transforms`), which shadows the `torchvision.transforms` module imported at the
# top of the notebook. The individual transforms are therefore looked up through
# `torchvision.transforms` so that this cell can be re-run without an AttributeError.
transforms = torchvision.transforms.Compose([
    # Upscale to the square resolution configured in ModelArgs
    torchvision.transforms.Resize(size=(ModelArgs.img_size, ModelArgs.img_size)),
    torchvision.transforms.ToTensor(),
    # (x - 0.5) / 0.5 on the single channel
    torchvision.transforms.Normalize((0.5,), (0.5,)),
])

In [None]:
#Loading MNIST Dataset
import os
import torchvision
from torch.utils.data import DataLoader

# NOTE(review): machine-specific absolute path; nothing in this cell reads it
# (the MNIST files are stored under ./data).
data_path = Path('/home/cmi_10101/Documents/datasets/content/data/')

# Training split: downloaded on first use, reshuffled on every pass
trainset = torchvision.datasets.MNIST(
    root='./data',
    train=True,
    download=True,
    transform=transforms,
)
trainloader = DataLoader(trainset, batch_size=ModelArgs.batch_size, shuffle=True)

# Test split: same preprocessing, iterated in a fixed order
testset = torchvision.datasets.MNIST(
    root='./data',
    train=False,
    download=True,
    transform=transforms,
)
testloader = DataLoader(testset, batch_size=ModelArgs.batch_size, shuffle=False)

In [191]:
def weights_init(m):
    """Re-draw the weights of convolution layers from a normal distribution.

    Meant to be passed to ``module.apply``. Any module whose class name
    contains ``'Conv'`` gets its ``weight`` overwritten in place with samples
    from N(mean=0, std=0.02); every other module is left untouched.
    """
    layer_type = type(m).__name__
    if 'Conv' in layer_type:
        nn.init.normal_(m.weight.data, mean=0.0, std=0.02)


In [192]:
# ModelArgs.img_size = 64


In [193]:
class Generator(nn.Module):
    """Conditional DCGAN-style generator.

    The class label is embedded into a vector of ``latent_vector_size``
    entries, reshaped to ``(B, latent_vector_size, 1, 1)`` and concatenated
    with the noise tensor along the channel axis. The resulting
    ``(B, 2 * latent_vector_size, 1, 1)`` tensor is upsampled by five
    transposed convolutions; with the default kernel/stride/padding the
    spatial size goes 1 -> 4 -> 8 -> 16 -> 32 -> 64.

    Args:
        latent_vector_size: Channels of the noise input and width of the label embedding.
        no_of_channels: Channels of the generated image.
        kernel_size: Kernel size of every transposed convolution.
        stride: Stride of every transposed convolution.
        number_of_feature_maps: Base width; the layers use 16x, 8x, 4x and 2x this value.
        padding: Padding of every transposed convolution except the first (which uses 0).
        z_out_dimensions, y_out_dimensions, labels_size,
        combined_hidden_layer_dimensions: Unused; kept for signature compatibility.
        img_size: Only sizes the unused ``dense`` layer; kept for compatibility.
    """

    def __init__(
        self,
        latent_vector_size = 100,
        no_of_channels = 1,
        kernel_size = (4,4),
        stride: int = 2,
        number_of_feature_maps: int = 64,
        padding: int = 1,
        z_out_dimensions: int = 200,
        y_out_dimensions: int = 1000,
        labels_size: int = 10,
        combined_hidden_layer_dimensions: int = 1200,
        img_size: int = 1

    ):

        super().__init__()

        # Remember the latent width so forward() and the first layer agree with the
        # embedding even when the caller passes a non-default value.
        self.latent_vector_size = latent_vector_size

        # NOTE(review): `dense` is never used in forward(); it is kept so the module's
        # parameters / state_dict keys stay as they were.
        self.dense = nn.Linear(in_features=latent_vector_size, out_features=img_size * img_size, device=ModelArgs.device)
        self.combined_hidden_layer_dimensions = latent_vector_size + ModelArgs.no_of_lables
        self.embedding = nn.Embedding(num_embeddings=ModelArgs.num_classes, embedding_dim=latent_vector_size, device=ModelArgs.device)
        self.img_size = img_size

        # Shape comments below assume the default arguments.
        self.main = nn.Sequential(
            # (B, 2 * latent, 1, 1) -> (B, 1024, 4, 4)
            nn.ConvTranspose2d(2 * latent_vector_size, number_of_feature_maps * 16, kernel_size=kernel_size, stride=stride, padding=0, bias=False),
            nn.InstanceNorm2d(number_of_feature_maps * 16),
            nn.ReLU(),

            # -> (B, 512, 8, 8)
            nn.ConvTranspose2d(number_of_feature_maps * 16, number_of_feature_maps * 8, kernel_size=kernel_size, stride=stride, padding=padding, bias=False),
            nn.InstanceNorm2d(number_of_feature_maps * 8),
            nn.ReLU(),

            # -> (B, 256, 16, 16)
            # NOTE(review): this block and the next apply ReLU *before* the norm, unlike the
            # first two blocks. The order is left as-is to keep the network's behaviour.
            nn.ConvTranspose2d(number_of_feature_maps * 8, number_of_feature_maps * 4, kernel_size=kernel_size, stride=stride, padding=padding, bias=False),
            nn.ReLU(),
            nn.InstanceNorm2d(number_of_feature_maps * 4),

            # -> (B, 128, 32, 32)
            nn.ConvTranspose2d(number_of_feature_maps * 4, number_of_feature_maps * 2, kernel_size=kernel_size, stride=stride, padding=padding, bias=False),
            nn.ReLU(),
            nn.InstanceNorm2d(number_of_feature_maps * 2),

            # -> (B, no_of_channels, 64, 64), squashed to [-1, 1] by Tanh
            nn.ConvTranspose2d(number_of_feature_maps * 2, no_of_channels, kernel_size=kernel_size, stride=stride, padding=padding, bias=False),
            nn.Tanh()
        )

    def forward(self, x, y):
        """Generate images from noise ``x`` conditioned on integer class labels ``y``.

        Args:
            x: Noise tensor of shape ``(B, latent_vector_size, 1, 1)``.
            y: Integer class labels; one label per sample.

        Returns:
            The output of the transposed-convolution stack for the concatenated input.
        """
        # Embed the labels and give them the same (B, C, 1, 1) layout as the noise.
        labels = self.embedding(y).view(x.shape[0], self.latent_vector_size, 1, 1)

        # Stack noise and label channels.
        combined = torch.cat([x, labels], dim=1)
        return self.main(combined)

In [194]:
# Build the generator on the configured device and re-initialise its conv weights
# (Module.apply returns the module itself, so the calls can be chained).
generator = Generator().to(ModelArgs.device).apply(weights_init)

# Show the layer structure
print(generator)

Generator(
  (dense): Linear(in_features=100, out_features=1, bias=True)
  (embedding): Embedding(10, 100)
  (main): Sequential(
    (0): ConvTranspose2d(200, 1024, kernel_size=(4, 4), stride=(2, 2), bias=False)
    (1): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): ConvTranspose2d(1024, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (4): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU()
    (6): ConvTranspose2d(512, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (7): ReLU()
    (8): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (9): ConvTranspose2d(256, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (10): ReLU()
    (11): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (12): ConvTranspose2d(128, 1, kernel_size=(4, 4), stride=(2, 2

In [167]:
# Scratch check: shape of a (128, 1) batch of random integer labels in [0, 10).
torch.randint(0, 10, (128, 1), dtype=torch.long, device=ModelArgs.device).shape

torch.Size([128, 1])

In [195]:
from torchinfo import summary

# Dummy generator inputs. Both batches are sized from ModelArgs so the noise and the labels
# always have the same batch dimension (the generator concatenates them along dim 1).
random_data = torch.randn(ModelArgs.batch_size, ModelArgs.latent_vector_size, 1, 1, device=ModelArgs.device)
labels = torch.randint(0, ModelArgs.num_classes, (ModelArgs.batch_size,), dtype=torch.long, device=ModelArgs.device)

summary(model=generator,
        input_data=(random_data, labels),
        col_names=["input_size", "output_size", "num_params", "trainable"],
        col_width=20,
        row_settings=["var_names"])

Layer (type (var_name))                  Input Shape          Output Shape         Param #              Trainable
Generator (Generator)                    [128, 100, 1, 1]     [128, 1, 64, 64]     101                  True
├─Embedding (embedding)                  [128]                [128, 100]           1,000                True
├─Sequential (main)                      [128, 200, 1, 1]     [128, 1, 64, 64]     --                   True
│    └─ConvTranspose2d (0)               [128, 200, 1, 1]     [128, 1024, 4, 4]    3,276,800            True
│    └─BatchNorm2d (1)                   [128, 1024, 4, 4]    [128, 1024, 4, 4]    2,048                True
│    └─ReLU (2)                          [128, 1024, 4, 4]    [128, 1024, 4, 4]    --                   --
│    └─ConvTranspose2d (3)               [128, 1024, 4, 4]    [128, 512, 8, 8]     8,388,608            True
│    └─BatchNorm2d (4)                   [128, 512, 8, 8]     [128, 512, 8, 8]     1,024                True
│    └─ReLU (5) 

In [13]:
# Sanity check: shapes of the dummy noise batch and label batch.
random_data.shape, labels.shape

(torch.Size([128, 100]), torch.Size([128]))

In [118]:
def one_hot_encode(labels, num_classes=10):
    """Convert a 1-D tensor of integer class ids into a one-hot matrix.

    Args:
        labels: 1-D integer tensor of class ids.
        num_classes: Number of columns of the result.

    Returns:
        A float tensor of shape ``(len(labels), num_classes)`` on the same
        device as ``labels``, with a 1 in column ``labels[i]`` of row ``i``
        and 0 everywhere else.
    """
    n_rows = labels.size(0)
    encoded = torch.zeros((n_rows, num_classes), device=labels.device)

    # One column index per row: write a 1 at (i, labels[i]).
    column_index = labels[:, None]
    encoded.scatter_(1, column_index, 1)

    return encoded

# Demonstration: one-hot encode the labels of the first training batch only.
for images, labels in trainloader:
    # Convert the integer class labels of this batch to one-hot vectors
    one_hot_labels = one_hot_encode(labels)
    break  # stop after the first batch; `one_hot_labels` stays defined for later cells


In [196]:
class Discriminator(nn.Module):
    """Conditional DCGAN-style discriminator.

    The class label is embedded into ``img_size * img_size`` values, reshaped
    into one extra image channel and concatenated with the input image. Four
    stride-2 convolutions halve the resolution each time (64 -> 32 -> 16 ->
    8 -> 4 with the defaults) and a final convolution collapses the remaining
    4x4 map to one value per sample, which a sigmoid turns into a probability.

    Args:
        no_of_channels: Channels of the input image (the label adds one more).
        kernel_size: Kernel size of every convolution.
        stride: Stride of the four downsampling convolutions.
        number_of_feature_maps: Base width; the layers use 2x, 4x, 8x and 16x this value.
        padding: Padding of the four downsampling convolutions.
        lr_slope: Negative slope of the LeakyReLU activations.
        num_classes: Number of rows of the label embedding.
        z_out_dimensions, y_out_dimensions, labels_size,
        combined_hidden_layer_dimensions, latent_vector_size: Unused; kept for
            signature compatibility.
        img_size: Side length of the (square) input images.
    """

    def __init__(
        self,
        no_of_channels = 1,
        kernel_size = (4,4),
        stride: int = 2,
        number_of_feature_maps: int = 64,
        padding: int = 1,
        lr_slope=0.2,
        num_classes: int = ModelArgs.num_classes,
        z_out_dimensions: int = 200,
        y_out_dimensions: int = 1000,
        labels_size: int = 10,
        combined_hidden_layer_dimensions: int = 1200,
        latent_vector_size: int = ModelArgs.latent_vector_size,
        img_size: int = ModelArgs.img_size
    ):

        super().__init__()

        # Keep the resolution so forward() reshapes the label map consistently with
        # the embedding built below.
        self.img_size = img_size

        # One learned img_size x img_size "label image" per class.
        self.embedding = nn.Embedding(num_embeddings=num_classes, embedding_dim=img_size * img_size, device=ModelArgs.device)
        # NOTE(review): `sig` is not used in forward() (the Sequential ends with its own
        # Sigmoid); kept so the printed module structure is unchanged.
        self.sig = nn.Sigmoid()

        # Shape comments below assume the default arguments and 64x64 inputs.
        self.main = nn.Sequential(
            # (B, no_of_channels + 1, 64, 64) -> (B, 128, 32, 32)
            nn.Conv2d(no_of_channels + 1, number_of_feature_maps * 2, kernel_size=kernel_size, stride=stride, padding=padding, bias=False),
            nn.InstanceNorm2d(number_of_feature_maps * 2),
            nn.LeakyReLU(negative_slope=lr_slope),

            # -> (B, 256, 16, 16)
            nn.Conv2d(number_of_feature_maps * 2, number_of_feature_maps * 4, kernel_size=kernel_size, stride=stride, padding=padding, bias=False),
            nn.InstanceNorm2d(number_of_feature_maps * 4),
            nn.LeakyReLU(negative_slope=lr_slope),

            # -> (B, 512, 8, 8)
            nn.Conv2d(number_of_feature_maps * 4, number_of_feature_maps * 8, kernel_size=kernel_size, stride=stride, padding=padding, bias=False),
            nn.InstanceNorm2d(number_of_feature_maps * 8),
            nn.LeakyReLU(negative_slope=lr_slope),

            # -> (B, 1024, 4, 4)
            nn.Conv2d(number_of_feature_maps * 8, number_of_feature_maps * 16, kernel_size=kernel_size, stride=stride, padding=padding, bias=False),
            nn.InstanceNorm2d(number_of_feature_maps * 16),
            nn.LeakyReLU(negative_slope=lr_slope),

            # -> (B, 1, 1, 1). An unpadded stride-1 4x4 kernel covers the whole 4x4 map.
            # The previous stride=4 / padding=1 setting also produced a 1x1 output, but its
            # single window started in the zero padding and never reached the last row and
            # column of the feature map. The weight shape is unchanged.
            nn.Conv2d(number_of_feature_maps * 16, 1, kernel_size=kernel_size, stride=1, padding=0, bias=False),
            # -> (B, 1)
            nn.Flatten(),
            nn.Sigmoid(),
         )

    def forward(self, x, y):
        """Score images ``x`` given their integer class labels ``y``.

        Args:
            x: Image batch of shape ``(B, no_of_channels, img_size, img_size)``.
            y: Integer class labels; one label per sample.

        Returns:
            Tensor of shape ``(B, 1)`` with values in (0, 1) for 64x64 inputs.
        """
        # The embedding holds img_size * img_size values per label, i.e. exactly ONE
        # extra channel regardless of how many channels the image has.
        label_map = self.embedding(y).view(x.shape[0], 1, self.img_size, self.img_size)

        combined = torch.cat([x, label_map], dim=1)

        return self.main(combined)

In [130]:
# x.view(B, C, IMG_SIZE*IMG_SIZE)

In [131]:
# ModelArgs.device

In [197]:
# Build the discriminator on the configured device and re-initialise its conv weights
# (Module.apply visits every sub-module and returns the module itself).
discriminator = Discriminator().to(ModelArgs.device).apply(weights_init)

# Show the layer structure
print(discriminator)

Discriminator(
  (linear_z_out): Linear(in_features=1, out_features=200, bias=True)
  (final_linear_layer): Linear(in_features=1200, out_features=1, bias=True)
  (linear_y_out): Linear(in_features=10, out_features=1000, bias=True)
  (combined_layer): Linear(in_features=1200, out_features=1200, bias=True)
  (embedding): Embedding(10, 4096)
  (sig): Sigmoid()
  (main): Sequential(
    (0): Conv2d(2, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): LeakyReLU(negative_slope=0.2)
    (3): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (4): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): LeakyReLU(negative_slope=0.2)
    (6): Conv2d(256, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (7): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (

In [199]:
from torchinfo import summary

# Dummy discriminator inputs, sized from ModelArgs instead of hard-coded 128/1/64/10 so they
# stay valid when the configuration changes.
images = torch.randn(ModelArgs.batch_size, ModelArgs.no_of_channels, ModelArgs.img_size, ModelArgs.img_size)
labels = torch.randint(0, ModelArgs.num_classes, (ModelArgs.batch_size,), dtype=torch.long)

summary(model=discriminator,
        input_data=(images.to(ModelArgs.device), labels.to(ModelArgs.device)),
        col_names=["input_size", "output_size", "num_params", "trainable"],
        col_width=20,
        row_settings=["var_names"])

Layer (type (var_name))                  Input Shape          Output Shape         Param #              Trainable
Discriminator (Discriminator)            [128, 1, 64, 64]     [128, 1]             1,453,801            True
├─Embedding (embedding)                  [128]                [128, 4096]          40,960               True
├─Sequential (main)                      [128, 2, 64, 64]     [128, 1]             --                   True
│    └─Conv2d (0)                        [128, 2, 64, 64]     [128, 128, 32, 32]   4,096                True
│    └─BatchNorm2d (1)                   [128, 128, 32, 32]   [128, 128, 32, 32]   256                  True
│    └─LeakyReLU (2)                     [128, 128, 32, 32]   [128, 128, 32, 32]   --                   --
│    └─Conv2d (3)                        [128, 128, 32, 32]   [128, 256, 16, 16]   524,288              True
│    └─BatchNorm2d (4)                   [128, 256, 16, 16]   [128, 256, 16, 16]   512                  True
│    └─LeakyReLU

In [20]:
# Sanity check: shapes of the dummy image batch and label batch.
images.shape, labels.shape

(torch.Size([128, 1, 64, 64]), torch.Size([128]))

In [21]:
# NOTE(review): this rebinds `labels` from integer class ids to one-hot floats, so the cell is
# not idempotent: running it a second time hands a 2-D float tensor to one_hot_encode.
# Re-run the cell that creates `labels` before re-running this one.
labels = one_hot_encode(labels)
labels = labels.to(ModelArgs.device)

In [22]:
# Display the current value of `labels`.
labels

tensor([[0., 0., 0.,  ..., 0., 1., 0.],
        [1., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 1., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]], device='cuda:0')

In [146]:
# --- Training hyper-parameters ---
epochs = 30
beta_1 = 0.5
lr_optimizer = 0.0002
loss_fn = nn.BCELoss()  # binary cross-entropy on the discriminator's probabilities

# --- Fresh networks, each with its conv weights re-initialised by weights_init ---
generator = Generator().to(ModelArgs.device).apply(weights_init)
discriminator = Discriminator().to(ModelArgs.device).apply(weights_init)

# --- One Adam optimizer per network ---
optimizerD = torch.optim.Adam(
    params=discriminator.parameters(),
    lr=lr_optimizer,
    betas=(beta_1, 0.999),
)
optimizerG = torch.optim.Adam(
    params=generator.parameters(),
    lr=lr_optimizer,
    betas=(beta_1, 0.999),
)

# Target values for real / generated samples
real_label = 1
fake_label = 0

# Per-iteration loss histories and a list for generated image grids
loss_g = []
loss_d = []
img_list = []

# Noise batch that stays the same for every preview image written during training
fixed_noise = torch.randn(
    (ModelArgs.batch_size, ModelArgs.latent_vector_size, 1, 1),
    dtype=torch.float32,
    device=ModelArgs.device,
)

In [114]:
import shutil

# Output directory for the image grids written during training.
# torchvision.utils.save_image does not create missing directories, so make sure the
# directory exists before the training loop writes into it.
save_images = Path('output_images/MNIST')
save_images.mkdir(parents=True, exist_ok=True)


In [115]:
# Print the tensor shapes of the first training batch only.
for images, labels in trainloader:
  print(images.shape)
  print(labels.shape)
  break

torch.Size([128, 1, 64, 64])
torch.Size([128])


In [162]:
# Scratch check of one_hot_encode on random labels. Sized with ModelArgs.batch_size rather than
# `current_batch_size`, which only exists after the training loop has run.
one_hot_encode(torch.randint(low=0, high=ModelArgs.num_classes, size=(ModelArgs.batch_size,), device=ModelArgs.device))

tensor([[0., 1., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 1.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]])

In [149]:
#Training loop

generator.train()
discriminator.train()
iters = 0

writer_fake = SummaryWriter("logs/fake")
writer_real = SummaryWriter("logs/real")

# Class labels paired with `fixed_noise` for the preview images: 0, 1, ..., 9, 0, 1, ...
# Using a fixed label batch (instead of the labels of whatever training batch is current)
# keeps the previews comparable across steps and always matches fixed_noise's batch size,
# even when the last training batch of an epoch is smaller.
fixed_labels = torch.arange(fixed_noise.shape[0], device=ModelArgs.device) % ModelArgs.num_classes

for epoch in range(epochs):
    for X, y in trainloader:

        X = X.to(ModelArgs.device)
        y = y.to(ModelArgs.device)

        current_batch_size = X.shape[0]  # the last batch of an epoch can be smaller

        # BCE targets: 1 for "real", 0 for "generated"
        real_targets = torch.ones((current_batch_size,), device=ModelArgs.device, dtype=torch.float32)
        fake_targets = torch.zeros((current_batch_size,), device=ModelArgs.device, dtype=torch.float32)

        ############################
        # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
        ###########################
        optimizerD.zero_grad()

        # Real images with their true labels
        loss_real = loss_fn(discriminator(X, y).view(-1), real_targets)
        loss_real.backward()

        # Generated images with randomly sampled labels
        noise = torch.randn((current_batch_size, ModelArgs.latent_vector_size, 1, 1), device=ModelArgs.device)
        sampled_labels = torch.randint(0, ModelArgs.no_of_lables, (current_batch_size, ), device=ModelArgs.device)
        fake_images = generator(noise, sampled_labels)

        # detach(): the discriminator step must not back-propagate into the generator.
        # This also makes retain_graph=True unnecessary.
        loss_fake = loss_fn(discriminator(fake_images.detach(), sampled_labels).view(-1), fake_targets)
        loss_fake.backward()

        optimizerD.step()

        # Total discriminator loss for logging
        discriminator_combined_loss = loss_real + loss_fake
        loss_d.append(discriminator_combined_loss.item())

        ############################
        # (2) Update G network: maximize log(D(G(z)))
        ###########################
        optimizerG.zero_grad()

        # Re-generate with the same noise, conditioned on the real batch labels, and ask the
        # freshly updated discriminator to rate the result as "real".
        fake_images = generator(noise, y)
        loss_gen = loss_fn(discriminator(fake_images, y).view(-1), real_targets)
        loss_gen.backward()

        optimizerG.step()

        loss_g.append(loss_gen.item())

        if iters % 300 == 0:
            print("Iterations: ", iters, "Epoch: ", epoch, "Generator loss: ", loss_gen.item(), "Discriminator loss: ", discriminator_combined_loss.item())

        # Periodically save real / generated image grids to disk and TensorBoard
        if iters % 500 == 0:
            print('saving the output')
            with torch.no_grad():
                fake = generator(fixed_noise, fixed_labels)

            torchvision.utils.save_image(X, '{}/real_images_steps_{}.png'.format(save_images, iters), normalize=True)
            torchvision.utils.save_image(fake, '{}/fake_images_steps_{}.png'.format(save_images, iters), normalize=True)

            img_grid_fake = torchvision.utils.make_grid(fake, normalize=True)
            img_grid_real = torchvision.utils.make_grid(X, normalize=True)

            writer_fake.add_image(
                "Mnist Fake Images", img_grid_fake, global_step=iters
            )
            writer_real.add_image(
                "Mnist Real Images", img_grid_real, global_step=iters
            )

        iters += 1

# Flush pending TensorBoard events to disk
writer_fake.close()
writer_real.close()


Iterations:  0 Epoch:  0 Generator loss:  3.579348564147949 Discriminator loss:  2.062045097351074
saving the output
Iterations:  300 Epoch:  0 Generator loss:  6.042276382446289 Discriminator loss:  0.8542827367782593
saving the output
Iterations:  600 Epoch:  1 Generator loss:  0.7866476774215698 Discriminator loss:  0.5937706828117371
Iterations:  900 Epoch:  1 Generator loss:  5.631103515625 Discriminator loss:  0.006849944591522217
saving the output
Iterations:  1200 Epoch:  2 Generator loss:  6.262673854827881 Discriminator loss:  0.025467010214924812
Iterations:  1500 Epoch:  3 Generator loss:  4.356583595275879 Discriminator loss:  0.05104273557662964
saving the output
Iterations:  1800 Epoch:  3 Generator loss:  8.941598892211914 Discriminator loss:  0.0033288539852946997
saving the output
Iterations:  2100 Epoch:  4 Generator loss:  7.465685844421387 Discriminator loss:  0.0012019629357382655
Iterations:  2400 Epoch:  5 Generator loss:  7.324221611022949 Discriminator loss:  