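# **CGAN**

Conditional GAN (CGAN) on MNIST. The notebook restores a trained generator/discriminator pair and a LeNet classifier from checkpoints. It then generates class-conditioned digits and measures how often the classifier predicts the label each digit was conditioned on.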

In [None]:
%matplotlib inline
import matplotlib.pyplot as plt
import torch
import numpy as np
# Record which PyTorch build produced the outputs below.
print(torch.__version__)
from torch import cuda
# Use the first GPU when one is visible, otherwise run everything on the CPU.
device = torch.device("cuda:0") if cuda.is_available() else torch.device("cpu")

1.10.0+cu111


### LOAD DATA

In [None]:
import tensorflow as tf

# DataLoader options: reshuffle the training split every epoch, keep the test split in a fixed order.
train_kwargs = dict(batch_size=128, shuffle=True)
test_kwargs = dict(batch_size=128, shuffle=False)

In [None]:
from torchvision import datasets, transforms

# Preprocessing shared by both MNIST splits:
#   * Pad the 28x28 digits to 32x32. The 4-tuple is (left, top, right, bottom), so the four
#     zero-valued pixels are added on the top and right edges only (digits sit bottom-left).
#   * ToTensor scales pixels to [0, 1]; Normalize then applies the MNIST mean/std.
# NOTE(review): after this normalization real pixels span roughly [-0.42, 2.82], while the
# Generator's final tanh produces [-1, 1]. Confirm that the classifier/discriminator are
# meant to see both ranges as-is.
transform = transforms.Compose(
    [
        transforms.Pad((0, 4, 4, 0), fill=0, padding_mode='constant'),
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,)),
    ]
)

# Only the training split triggers the download; the test split reuses the same files.
dataset_train = datasets.MNIST(root='../data', train=True, download=True, transform=transform)
dataset_test = datasets.MNIST(root='../data', train=False, transform=transform)

train_loader = torch.utils.data.DataLoader(dataset=dataset_train, **train_kwargs)
test_loader = torch.utils.data.DataLoader(dataset=dataset_test, **test_kwargs)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ../data/MNIST/raw/train-images-idx3-ubyte.gz


  0%|          | 0/9912422 [00:00<?, ?it/s]

Extracting ../data/MNIST/raw/train-images-idx3-ubyte.gz to ../data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ../data/MNIST/raw/train-labels-idx1-ubyte.gz


  0%|          | 0/28881 [00:00<?, ?it/s]

Extracting ../data/MNIST/raw/train-labels-idx1-ubyte.gz to ../data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ../data/MNIST/raw/t10k-images-idx3-ubyte.gz


  0%|          | 0/1648877 [00:00<?, ?it/s]

Extracting ../data/MNIST/raw/t10k-images-idx3-ubyte.gz to ../data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ../data/MNIST/raw/t10k-labels-idx1-ubyte.gz


  0%|          | 0/4542 [00:00<?, ?it/s]

Extracting ../data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ../data/MNIST/raw



### ARCHITECTURE

In [None]:
import pdb
import random
from torch import nn, optim
from torch.autograd import Variable
import torch.nn.functional as F
import time

class Generator(nn.Module):
    """Conditional generator: (noise, one-hot label) -> one 32x32 image in [-1, 1].

    The 100-d noise vector and the 10-d label vector are each projected to a 256x4x4
    feature map. The two maps are concatenated along channels (512x4x4) and upsampled
    by three stride-2 transpose convolutions to 1x32x32. A final tanh bounds the output.
    """

    def __init__(self):
        super(Generator, self).__init__()
        # Layer 0: transpose convolution from a 1x1 input to 4x4, then BatchNorm + ReLU.
        #   Noise: 100x1x1 -> 256x4x4.   Label: 10x1x1 -> 256x4x4.
        # Each branch has its OWN BatchNorm. One shared BatchNorm2d would tie the affine
        # parameters of the noise and label branches and blend their running statistics.
        # Checkpoints saved with the old shared layout still load via load_state_dict:
        # state_dict lists a shared layer under both prefixes, so each branch gets the same values.
        # NOTE: this adds one BatchNorm's weight/bias, so optimizer state saved for the old layout
        # will not load into an optimizer built on a freshly constructed Generator().
        trans_conv_noise = nn.ConvTranspose2d(100, 256, (4, 4), stride=1)
        trans_conv_label = nn.ConvTranspose2d(10, 256, (4, 4), stride=1)
        batch_norm_noise = nn.BatchNorm2d(256)
        batch_norm_label = nn.BatchNorm2d(256)

        # Layer 1: transpose convolution. 512x4x4 -> 256x8x8. BatchNorm.
        trans_conv1 = nn.ConvTranspose2d(512, 256, (4, 4), stride=2, padding=1)
        batch_norm1 = nn.BatchNorm2d(256)
        # Layer 2: transpose convolution. 256x8x8 -> 128x16x16. BatchNorm.
        trans_conv2 = nn.ConvTranspose2d(256, 128, (4, 4), stride=2, padding=1)
        batch_norm2 = nn.BatchNorm2d(128)
        # Layer 3: transpose convolution. 128x16x16 -> 1x32x32. tanh.
        trans_conv3 = nn.ConvTranspose2d(128, 1, (4, 4), stride=2, padding=1)
        tanh = nn.Tanh()

        # ReLU holds no state, so one instance can safely be reused in every Sequential.
        relu = nn.ReLU()

        self.project_noise = nn.Sequential(trans_conv_noise, batch_norm_noise, relu)
        self.project_label = nn.Sequential(trans_conv_label, batch_norm_label, relu)
        self.model = nn.Sequential(trans_conv1, batch_norm1, relu, trans_conv2, batch_norm2, relu, trans_conv3, tanh)

    def forward(self, noise, label):
        """Generate images.

        Args:
            noise: tensor reshaped to (batch, 100, 1, 1).
            label: tensor reshaped to (batch, 10, 1, 1); this notebook feeds one-hot vectors.

        Returns:
            Tensor of shape (batch, 1, 32, 32) with values in [-1, 1].
        """
        bsize = noise.size(0)
        noise = noise.view(bsize, 100, 1, 1)
        label = label.view(bsize, 10, 1, 1)
        noise = self.project_noise(noise)
        label = self.project_label(label)
        # Concatenate the two 256-channel maps into the 512-channel input of Layer 1.
        x = torch.cat((noise, label), 1)
        return self.model(x)

In [None]:
class Discriminator(nn.Module):
    """Conditional discriminator: (image, one-hot label) -> probability that the image is real.

    The image (1 channel) and the label (10 channels, tiled over the image's height and
    width) are each downsampled by a stride-2 convolution to 64 channels. The two maps
    are concatenated (128 channels) and reduced by two more stride-2 convolutions. A
    final layer maps the result to one sigmoid score per sample, so 32x32 inputs give
    an output of shape (batch, 1, 1, 1).
    """

    def __init__(self):
        super(Discriminator, self).__init__()
        # Layer 0: convolution, then BatchNorm + LeakyReLU.
        #   Image: 1x32x32 -> 64x16x16.   Label: 10x32x32 (tiled) -> 64x16x16.
        # Each branch has its OWN BatchNorm. One shared BatchNorm2d would tie the affine
        # parameters of the image and label branches and blend their running statistics.
        # Old checkpoints still load via load_state_dict: a shared layer appears in state_dict
        # under both prefixes.
        # NOTE: optimizer state saved for the old shared layout has fewer parameters, so it will
        # not load into an optimizer built on a freshly constructed Discriminator().
        conv_image = nn.Conv2d(1, 64, (4, 4), stride=2, padding=1)
        conv_label = nn.Conv2d(10, 64, (4, 4), stride=2, padding=1)
        batch_norm_image = nn.BatchNorm2d(64)
        batch_norm_label = nn.BatchNorm2d(64)

        # Layer 1: convolution. 128x16x16 -> 256x8x8. BatchNorm.
        conv1 = nn.Conv2d(128, 256, (4, 4), stride=2, padding=1)
        batch_norm1 = nn.BatchNorm2d(256)

        # Layer 2: convolution. 256x8x8 -> 512x4x4. BatchNorm.
        conv2 = nn.Conv2d(256, 512, (4, 4), stride=2, padding=1)
        batch_norm2 = nn.BatchNorm2d(512)

        # Layer 3: transpose convolution used as a reducer. With stride 1 and padding 3, a 4x4
        # input yields (4 - 1) - 2*3 + 4 = 1, i.e. a 1x1 map. Sigmoid.
        trans_conv1 = nn.ConvTranspose2d(512, 1, (4, 4), stride=1, padding=3)
        sigmoid = nn.Sigmoid()

        # LeakyReLU holds no state, so one instance can safely be reused in every Sequential.
        lrelu = nn.LeakyReLU(0.01)

        self.project_image = nn.Sequential(conv_image, batch_norm_image, lrelu)
        self.project_label = nn.Sequential(conv_label, batch_norm_label, lrelu)
        self.model = nn.Sequential(conv1, batch_norm1, lrelu, conv2, batch_norm2, lrelu, trans_conv1, sigmoid)

    def forward(self, image, label):
        """Score (image, label) pairs.

        Args:
            image: tensor of shape (batch, 1, H, W); the layer stack assumes H = W = 32.
            label: tensor reshaped to (batch, 10, 1, 1); this notebook feeds one-hot vectors.

        Returns:
            Sigmoid scores, shape (batch, 1, 1, 1) for 32x32 images.
        """
        bsize = image.size(0)
        # Broadcast the label over the image's spatial size so both branches produce
        # spatially aligned feature maps that can be concatenated.
        label = label.view(bsize, 10, 1, 1)
        label = torch.tile(label, (1, 1, image.size(2), image.size(3)))
        image = self.project_image(image)
        label = self.project_label(label)
        x = torch.cat((image, label), 1)
        return self.model(x)

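### TRAINING PIPELINE

The cells below do not train anything. They restore the epoch-14 CGAN checkpoint and a trained LeNet classifier, then score class-conditioned samples drawn from the generator.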

In [None]:
loss = nn.BCELoss()

# Restore the trained CGAN. The *_epoch14.pt files hold whole pickled modules, so they carry
# their own architecture. Their weights are then overwritten from the combined checkpoint below.
# NOTE: torch.load unpickles arbitrary Python objects — only load checkpoints you trust.
generator = torch.load('generator_epoch14.pt', map_location=device)
discriminator = torch.load('discriminator_epoch14.pt', map_location=device)

optimizerD = optim.Adam(discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999))
optimizerG = optim.Adam(generator.parameters(), lr=0.0002, betas=(0.5, 0.999))

# One checkpoint holds both networks and both optimizer states. It is the final writer of the
# model weights, so separately saved *_state_dict.pt weights would just be overwritten and
# are not read.
checkpoint = torch.load('cgan_epoch14.tar', map_location=device)
generator.load_state_dict(checkpoint['generator'])
discriminator.load_state_dict(checkpoint['discriminator'])
optimizerD.load_state_dict(checkpoint['optD'])
optimizerG.load_state_dict(checkpoint['optG'])

In [None]:
from torch import nn, optim
from torch.autograd import Variable
import torch.nn.functional as F

class LeNet(nn.Module):
    """LeNet-5 style classifier for 1x32x32 images that returns 10 class scores.

    conv 5x5 (6) -> ReLU -> 2x2 max-pool -> conv 5x5 (16) -> ReLU -> 2x2 max-pool
    -> flatten (16*5*5 = 400) -> fc 120 -> ReLU -> fc 84 -> ReLU -> fc 10.
    """

    def __init__(self):
        super().__init__()
        # Feature extractor: 1x32x32 -> 6x28x28 (pooled to 6x14x14)
        #                    -> 16x10x10 (pooled to 16x5x5).
        self.conv1 = nn.Conv2d(1, 6, (5, 5))
        self.conv2 = nn.Conv2d(6, 16, (5, 5))
        # Classifier head over the 400 flattened features.
        self.fc1 = nn.Linear(400, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        """Return raw (un-normalized) class scores of shape (batch, 10)."""
        # Two conv -> ReLU -> 2x2 max-pool stages.
        for conv in (self.conv1, self.conv2):
            x = F.max_pool2d(F.relu(conv(x)), (2, 2))
        # Keep the batch dimension, flatten the rest (16x5x5 -> 400).
        x = x.flatten(1)
        # Two hidden fully connected layers with ReLU, then the linear output layer.
        for hidden in (self.fc1, self.fc2):
            x = F.relu(hidden(x))
        return self.fc3(x)

    def num_flat_features(self, x):
        """Number of elements per sample: the product of every dimension after the first."""
        return x.size()[1:].numel()

In [None]:
def evaluate(target_loader, target_dataset, net):
    """Classify every sample in ``target_loader`` with ``net`` and score it against the labels.

    Args:
        target_loader: iterable of ``(images, labels)`` batches, in a fixed order.
        target_dataset: the dataset behind the loader; ``len(target_dataset)`` is the
            accuracy denominator.
        net: classifier returning one row of class scores per image.

    Returns:
        ``(accuracy, predictions)``: the fraction of samples whose highest-scoring class
        equals the label, and a NumPy array of predicted class indices in loader order.

    Raises:
        ValueError: if the loader yields no batches (accuracy would be undefined).

    Side effect: switches ``net`` to eval mode and does not switch it back.
    """
    net.eval()
    # Feed batches to whatever device the network's weights live on (no-op on CPU-only setups).
    first_param = next(net.parameters(), None)
    net_device = first_param.device if first_param is not None else None

    predictions = []
    total_correct = 0
    # Inference only: skip building autograd graphs.
    with torch.no_grad():
        for images, labels in target_loader:
            if net_device is not None:
                images = images.to(net_device)
                labels = labels.to(net_device)
            pred = net(images).argmax(dim=1)
            total_correct += pred.eq(labels.view_as(pred)).sum().item()
            # Keep predictions on the CPU so they can be turned into a NumPy array at the end.
            predictions.append(pred.cpu())

    if not predictions:
        raise ValueError("evaluate() received an empty loader; accuracy is undefined")

    accuracy = float(total_correct) / len(target_dataset)
    return accuracy, torch.cat(predictions).numpy()

In [None]:
# Restore the trained LeNet classifier used to score generated digits.
# map_location='cpu': the generated images are moved to the CPU before evaluation, and this
# also lets a GPU-saved checkpoint load on a CPU-only machine.
# NOTE: torch.load unpickles arbitrary Python objects — only load checkpoints you trust.
net = torch.load('classifier.pt', map_location='cpu')

criterion = nn.CrossEntropyLoss()
optimizerC = optim.SGD(net.parameters(), lr=0.001)

# The combined checkpoint holds the final classifier weights and the SGD state. It is the
# final writer of the weights, so a separate classifier_state_dict.pt load is unnecessary.
classifier_w = torch.load('classifier_tot.tar', map_location='cpu')
net.load_state_dict(classifier_w['net'])
optimizerC.load_state_dict(classifier_w['opt'])

In [None]:
from torchvision.utils import save_image,make_grid
import os

# Number of class-conditioned samples to draw, and the seed that makes them reproducible.
N_GEN_SAMPLES = 400
GEN_SEED = 0

# A dedicated RNG makes the sampled noise and labels reproducible without reseeding torch's global RNG.
sample_rng = torch.Generator().manual_seed(GEN_SEED)
fixed_noise = torch.randn(N_GEN_SAMPLES, 100, generator=sample_rng).to(device)
# fixed_label stays on the CPU: it is reused as the label tensor of the generated-image dataset.
fixed_label = torch.randint(0, 10, (N_GEN_SAMPLES,), generator=sample_rng)
fixed_one_hot = F.one_hot(fixed_label, num_classes=10).float().to(device)

# NOTE(review): generator.eval() is not called here. If the loaded module is in training
# mode, its BatchNorm layers use batch statistics and update their running stats while sampling.
with torch.no_grad():
    img_fake = generator(fixed_noise, fixed_one_hot).cpu()

# Saved as a NumPy array; the evaluation cell reloads it with torch.load.
img_fake = img_fake.numpy()
torch.save(img_fake, 'gen_imgs')

In [None]:
# Reload the generated samples. as_tensor accepts either the saved NumPy array or a tensor.
gen_images = torch.as_tensor(torch.load('gen_imgs'), dtype=torch.float32)

# Pair each generated image with the label it was conditioned on.
dataset_gen = torch.utils.data.TensorDataset(gen_images, fixed_label)
# Evaluate in batches (the default batch_size=1 meant one forward pass per image).
# test_kwargs keeps shuffle=False, so predictions stay aligned with fixed_label.
gen_loader = torch.utils.data.DataLoader(dataset_gen, **test_kwargs)

gen_accuracy, gen_predictions = evaluate(gen_loader, dataset_gen, net)

print("Test Accuracy = {:.3f}".format(gen_accuracy))



Test Accuracy = 0.935
