### Lecture 2: Conditional GAN on MNIST

In [1]:
import torch
import numpy as np
from matplotlib import pyplot as plt
% matplotlib inline

UsageError: Line magic function `%` not found.


In [None]:
# Report whether PyTorch can see a CUDA device; the next cell picks `device` from this.
bool(torch.cuda.is_available())

In [None]:
# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

device

In [None]:
# Sizes shared by the Generator and Discriminator.
noise_dim = 100  # length of the latent noise vector z fed to the generator
label_dim = 10   # width of the one-hot class condition (digits 0-9)

# Seed PyTorch's RNGs (CPU and CUDA) so weight init, noise and shuffling repeat
# across runs. Some CUDA kernels may still be nondeterministic.
SEED = 0
torch.manual_seed(SEED)

#### Generator Model

In [None]:
class Generator(torch.nn.Module):
    """Conditional MLP generator: (noise, one-hot label) -> flat 784-pixel image.

    The noise and label are concatenated (``noise_dim + label_dim`` features) and
    passed through five Linear -> ReLU -> Dropout(0.5) blocks of width 240, then a
    Linear(240 -> 784) followed by a Sigmoid, so every output value lies in [0, 1].
    """

    def __init__(self):
        super(Generator, self).__init__()

        width = 240
        # Input size of each hidden block: the conditioned input first, then `width`.
        hidden_inputs = [noise_dim + label_dim] + [width] * 4

        stack = []
        for fan_in in hidden_inputs:
            stack.extend([
                torch.nn.Linear(in_features=fan_in, out_features=width, bias=True),
                torch.nn.ReLU(),
                torch.nn.Dropout(0.5),
            ])
        # Output layer: one unit per pixel of a 28 x 28 image.
        stack.extend([
            torch.nn.Linear(in_features=width, out_features=784, bias=True),
            torch.nn.Sigmoid(),
        ])

        # Same module order (hence the same state_dict keys) as listing the layers by hand.
        self.fcn = torch.nn.Sequential(*stack)

    def forward(self, batch, labels):
        """Generate one image per (noise, label) row.

        Args:
            batch: noise with a leading batch dimension; everything after it is
                flattened per sample (the training loop passes (B, noise_dim)).
            labels: (B, label_dim) one-hot conditions.

        Returns:
            (B, 784) tensor of pixel intensities from the final Sigmoid.
        """
        noise = batch.view(batch.size(0), -1)
        conditioned = torch.cat((noise, labels), dim=1)
        return self.fcn(conditioned)

#### Maxout Activation

##### Source: https://github.com/pytorch/pytorch/issues/805

In [None]:
class Maxout(torch.nn.Module):
    """Maxout activation over groups of consecutive features.

    The feature axis (dim 1) of size ``C`` is split into ``C // num_pieces`` groups
    of ``num_pieces`` consecutive units, and each group is reduced to its maximum.
    Any dimensions after dim 1 are carried through unchanged. In the Discriminator,
    (batch, 240) with ``num_pieces=5`` becomes (batch, 48).
    """

    def __init__(self, num_pieces):
        super(Maxout, self).__init__()
        self.num_pieces = num_pieces

    def forward(self, x):
        pieces = self.num_pieces
        width = x.shape[1]

        # Every group must be complete, e.g. 240 % 5 == 0.
        assert width % pieces == 0

        # (batch, width, *rest) -> (batch, width // pieces, pieces, *rest)
        grouped = x.view(x.shape[0], width // pieces, pieces, *x.shape[2:])

        # Keep only the largest piece of each group (see torch.max(dim=...)).
        return grouped.max(dim=2).values

#### Discriminator Model

In [None]:
class Discriminator(torch.nn.Module):
    """Conditional MLP discriminator scoring (image, one-hot label) pairs.

    The flattened image (784 values) is concatenated with the label and passed
    through three Linear(-> 240) -> Maxout(5) (-> 48) -> Dropout(0.5) blocks. A final
    Linear(48 -> 1) and Sigmoid follow; the training loop treats the output as the
    probability that the pair is real (target 1) rather than generated (target 0).
    """

    def __init__(self):
        super(Discriminator, self).__init__()

        units, pieces = 240, 5
        pooled = units // pieces  # width after Maxout: 240 // 5 = 48

        blocks = []
        for fan_in in (784 + label_dim, pooled, pooled):
            blocks.extend([
                torch.nn.Linear(in_features=fan_in, out_features=units, bias=True),
                Maxout(pieces),
                torch.nn.Dropout(0.5),
            ])
        blocks.extend([
            torch.nn.Linear(in_features=pooled, out_features=1, bias=True),
            torch.nn.Sigmoid(),
        ])

        # Module order (and therefore state_dict keys) matches the hand-written stack.
        self.fcn = torch.nn.Sequential(*blocks)

    def forward(self, batch, labels):
        """Score each (image, label) row.

        Args:
            batch: images with a leading batch dimension, flattened per sample to
                784 values (real batches arrive as (B, 1, 784), generator output as (B, 784)).
            labels: (B, label_dim) one-hot conditions.

        Returns:
            (B, 1) tensor of Sigmoid outputs.
        """
        flat = batch.view(batch.size(0), -1)
        joint = torch.cat((flat, labels), dim=1)
        return self.fcn(joint)

#### MNIST Dataset

In [None]:
import torchvision

In [None]:
class FlattenTransform:
    """Flatten every axis after the first, e.g. a (1, 28, 28) image tensor -> (1, 784)."""

    def __call__(self, inputs):
        leading = inputs.shape[0]
        return inputs.view(leading, -1)
        

# PIL image -> float tensor in [0, 1] (ToTensor), then flattened per sample by FlattenTransform.
mnist_transform = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
    FlattenTransform(),
])

# MNIST training split, downloaded to ./data/mnist on first use.
data_train = torchvision.datasets.MNIST(
    root='./data/mnist',
    train=True,
    transform=mnist_transform,
    download=True,
)

In [None]:
BATCH_SIZE = 64

# drop_last=True: every batch then has exactly BATCH_SIZE rows, matching the
# fixed-size real_labels / fake_labels BCE targets used by the training loop.
# For the 60,000-image MNIST training split, the last batch would otherwise have
# 60000 % 64 = 32 rows.
# NOTE(review): with the 'spawn' start method (Windows/macOS), worker processes must
# be able to import FlattenTransform, which is defined in this notebook; set
# num_workers=0 if loading fails there.
train_loader = torch.utils.data.DataLoader(
    data_train,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=4,
    drop_last=True,
)

#### Models, Optimizers and Loss

In [None]:
# Instantiate both networks directly on the selected device.
generator = Generator().to(device)
discriminator = Discriminator().to(device)

# Both players share the same SGD-with-momentum settings (dampening=0.0001 left disabled).
sgd_kwargs = dict(lr=0.001, momentum=0.5)

discriminator_optimizer = torch.optim.SGD(discriminator.parameters(), **sgd_kwargs)
generator_optimizer = torch.optim.SGD(generator.parameters(), **sgd_kwargs)

# Binary cross-entropy on the discriminator's Sigmoid output (real = 1, fake = 0).
criterion = torch.nn.BCELoss()

#### Learning-Rate Scheduler: StepLR (disabled)

In [None]:
# Disabled experiment: StepLR would multiply each optimizer's lr by gamma=0.99 on
# every scheduler.step() (step_size=1). The code is kept inside a string literal, so
# nothing is created here. The matching .step()/.get_lr() lines in the training cell
# are commented out too; there, .step() would run once per batch, not once per epoch.
# NOTE(review): as the cell's only expression, this string is echoed as cell output.
'''
discriminator_scheduler = torch.optim.lr_scheduler.StepLR(
    optimizer=discriminator_optimizer,
    step_size=1,
    gamma=0.99,
    last_epoch=-1
)

generator_scheduler = torch.optim.lr_scheduler.StepLR(
    optimizer=generator_optimizer,
    step_size=1,
    gamma=0.99,
    last_epoch=-1
)
'''

#### Learning-Rate Scheduler: LambdaLR (disabled)

In [None]:
# Disabled experiment (kept inside a string literal, so nothing runs): a LambdaLR
# schedule driven by the stateful DecayLR callable.
# NOTE(review): LambdaLR multiplies each group's *initial* lr (0.001 here) by the
# value the lambda returns. DecayLR's 0.9, 0.09, ... therefore act as factors, not
# absolute learning rates.
# DecayLR also multiplies by 0.1 whenever _epoch % step_size == 0, which includes
# _epoch == 0. LambdaLR evaluates the lambda once at construction, so the first
# factor would likely already be 0.09; confirm before re-enabling.
# The training cell's (commented) .step() calls sit inside the batch loop, so
# "_epoch" would actually count batches.
'''
class DecayLR:
    
    def __init__(self, _lr, _step_size):
        
        self.lr = _lr
        self.step_size = _step_size
    
    def __call__(self, _epoch):

        if _epoch % self.step_size == 0:
            self.lr = self.lr * 0.1
        
        return self.lr


discriminator_scheduler = torch.optim.lr_scheduler.LambdaLR(
    discriminator_optimizer,
    DecayLR(
        _lr=0.9,
        _step_size=100
    )
)

generator_scheduler = torch.optim.lr_scheduler.LambdaLR(
    generator_optimizer,
    DecayLR(
        _lr=0.9,
        _step_size=100
    )
)
'''

#### Visualization Function

In [None]:
def visualizeGAN(tgt_pth, images, labels, epoch):
    """Save a 2 x 5 grid of generated digits, each titled with its conditioning class.

    Args:
        tgt_pth: directory the figure is written to (must already exist).
        images: indexable collection of at least 10 2-D images, e.g. a (10, 28, 28)
            CPU tensor.
        labels: one-hot condition rows aligned with ``images``; each panel's title
            is the argmax of its row.
        epoch: shown in the figure title (zero-padded to 4 digits) and used as the
            file name ``<epoch zero-padded to 3 digits>.jpg``.
    """
    import os  # local import: the notebook's `import os` cell runs after this one

    nrows, ncols = 2, 5
    fig, axes = plt.subplots(nrows, ncols, figsize=(20, 18))

    fig.suptitle('Epoch {}'.format(str(epoch).zfill(4)))

    # axes.flat walks the grid row by row, so index == row * ncols + col.
    for idx, cell in enumerate(axes.flat):
        cell.imshow(images[idx], cmap='gray')
        # .item() yields a plain int, so the title reads "3" instead of
        # "tensor(3, device='cuda:0')" when labels live on the GPU.
        cell.set_title('{}'.format(torch.argmax(labels[idx]).item()))
        cell.axis("off")

    fig.tight_layout()

    fig.savefig(os.path.join(tgt_pth, '{}.jpg'.format(str(epoch).zfill(3))))

    # Close this specific figure so repeated calls during training don't accumulate figures.
    plt.close(fig)

#### One-Hot Encoding

In [None]:
def encodeOneHot(labels, num_classes=None):
    """Return float32 one-hot rows for a 1-D tensor of integer class labels.

    Args:
        labels: 1-D integer tensor of class indices in ``[0, num_classes)``.
        num_classes: width of the encoding; defaults to the notebook-wide ``label_dim``.

    Returns:
        ``(len(labels), num_classes)`` float32 tensor on the same device as ``labels``.

    Raises:
        RuntimeError: if a label is negative or not smaller than ``num_classes``.
    """
    if num_classes is None:
        num_classes = label_dim

    # one_hot needs int64 indices, so .long() also accepts int32 labels. The result
    # follows labels.device, unlike the legacy CPU-only torch.FloatTensor, so CUDA
    # labels work too.
    onehot = torch.nn.functional.one_hot(labels.reshape(-1).long(), num_classes=num_classes)
    return onehot.to(torch.float32)

#### Train GANs

In [None]:
# Fixed-size BCE targets: 1 = "real", 0 = "fake". Every training batch must have
# exactly BATCH_SIZE rows for these to line up (see num_steps below).
real_labels = torch.ones(BATCH_SIZE, 1, device=device)
fake_labels = torch.zeros(BATCH_SIZE, 1, device=device)

# Fixed probe inputs: one noise vector per digit class 0-9, reused by every
# visualization so the saved grids are comparable across epochs.
# NOTE(review): `2 * randn - 1` is Gaussian (mean -1, std 2). If uniform noise in
# [-1, 1] was intended, use `2 * torch.rand(...) - 1`, changing it here and in the
# training loop together.
test_z = (2 * torch.randn(10, noise_dim) - 1).to(device)
# torch.arange is always int64 (np.arange can be int32 on some platforms).
test_y = encodeOneHot(torch.arange(10)).to(device)

num_epochs = 256
# Number of *full* batches per epoch. len(train_loader) is already a batch count, so
# the former `len(train_loader) // BATCH_SIZE` yielded only ~14 steps per epoch;
# count samples instead. Stopping here also skips a trailing partial batch that
# would not match real_labels. Each epoch now covers the whole dataset (~64x more
# steps than before); lower num_epochs if training is too slow.
num_steps = len(train_loader.dataset) // BATCH_SIZE

In [None]:
import os

# Output folder for the per-epoch sample grids written by visualizeGAN.
visuals_dir = 'visuals-section-3-lecture-2-c'

# exist_ok makes the cell safe to re-run. Unlike the exists()/mkdir() pair, it also
# fails immediately if a *file* of that name is in the way, instead of letting
# savefig fail later.
os.makedirs(visuals_dir, exist_ok=True)

In [None]:
d_loss_ls = []
g_loss_ls = []
# Only filled if the (disabled) learning-rate schedulers are re-enabled below.
d_lr_ls = []
g_lr_ls = []

# Discriminator updates per generator update.
D_STEPS = 4


for epoch in range(num_epochs):

    # Per-epoch running sums used to report mean losses.
    d_counter = 0
    g_counter = 0
    d_loss = 0.0
    g_loss = 0.0

    for i, (images, labels) in enumerate(train_loader):

        # Only full batches: the BCE targets real_labels / fake_labels have BATCH_SIZE rows.
        if i == num_steps:
            break

        # The real batch is the same for every discriminator step, so move it once.
        real_images = images.to(device)
        real_conditions = encodeOneHot(labels).to(device)

        # ---- Train Discriminator (D_STEPS updates per batch) ----
        for _ in range(D_STEPS):

            fake_conditions = encodeOneHot(
                torch.randint(0, 10, (BATCH_SIZE,))
            ).to(device)

            # The D update needs no generator gradients. Sampling under no_grad avoids
            # building the generator graph and back-propagating into it for nothing.
            # NOTE(review): `2 * randn - 1` is Gaussian (mean -1, std 2), not uniform
            # in [-1, 1]; if uniform was intended, switch to torch.rand here, below,
            # and for test_z together.
            with torch.no_grad():
                fake_images = generator(
                    (2 * torch.randn(BATCH_SIZE, noise_dim) - 1).to(device),
                    fake_conditions
                )

            discriminator_optimizer.zero_grad()

            real_outputs = discriminator(real_images, real_conditions)
            fake_outputs = discriminator(fake_images, fake_conditions)

            d_x = criterion(real_outputs, real_labels)
            d_g_z = criterion(fake_outputs, fake_labels)

            # One backward over the summed loss gives the same gradients as two backwards.
            (d_x + d_g_z).backward()

            discriminator_optimizer.step()

            # Accumulate. The former `d_loss = ...` kept only the last step's value,
            # so the logged "mean" was that value divided by the step count.
            d_counter += 1
            d_loss += d_x.item() + d_g_z.item()

        # ---- Train Generator ----
        z = (2 * torch.randn(BATCH_SIZE, noise_dim) - 1).to(device)
        y = encodeOneHot(torch.randint(0, 10, (BATCH_SIZE,))).to(device)

        generator_optimizer.zero_grad()

        # The generator is rewarded when D labels its samples as real.
        outputs = discriminator(generator(z, y), y)
        loss = criterion(outputs, real_labels)
        loss.backward()

        generator_optimizer.step()

        # LR Decay (disabled; see the scheduler cells)
        # discriminator_scheduler.step()
        # generator_scheduler.step()

        g_counter += 1
        g_loss += loss.item()

    # Every 10 epochs: mean G loss per step and mean (real + fake) D loss per D step.
    if epoch % 10 == 0:
        print(
            'e:{}, G:{:.3f}, D:{:.3f}'.format(
                epoch,
                g_loss / g_counter,
                d_loss / d_counter
            )
        )

    # Per-epoch means for the loss plot.
    g_loss_ls.append(g_loss / g_counter)
    d_loss_ls.append(d_loss / d_counter)

    # Learning-rate log (disabled together with the schedulers)
    # g_lr_ls.append(generator_scheduler.get_lr())
    # d_lr_ls.append(discriminator_scheduler.get_lr())

    # Every 5 epochs save a sample grid, with dropout off and no autograd, then resume training mode.
    if epoch % 5 == 0:
        generator.eval()
        with torch.no_grad():
            generated = generator(test_z, test_y).cpu().view(-1, 28, 28)
        generator.train()

        visualizeGAN(visuals_dir, generated, test_y, epoch)

In [None]:
# Visualize Results after training; `epoch` keeps its last value from the training loop.
# The former bare line continuation (`.detach()` on its own line) was a SyntaxError.
generator.eval()  # disable dropout while sampling
with torch.no_grad():
    generated = (
        generator(test_z, test_y)
        .cpu()
        .view(-1, 28, 28)
    )
generator.train()

visualizeGAN(visuals_dir, generated, test_y, epoch)

#### Visualize Loss

In [None]:
# Per-epoch mean losses recorded by the training loop.
fig, ax = plt.subplots(figsize=(16, 10))
ax.plot(d_loss_ls, label='D Loss')
ax.plot(g_loss_ls, label='G Loss')
ax.set(
    title='Mean training loss per epoch',
    xlabel='Epoch',
    ylabel='BCE loss (D: real + fake)',
)
ax.legend()
plt.show();

#### Visualize Learning Rate Decay

In [None]:
# d_lr_ls / g_lr_ls are only filled when the (currently disabled) LR schedulers are
# used; skip the plot instead of drawing empty axes.
if d_lr_ls or g_lr_ls:
    fig, ax = plt.subplots(figsize=(16, 10))
    ax.plot(d_lr_ls, label='D LR')
    ax.plot(g_lr_ls, label='G LR')
    ax.set(title='Learning rate per epoch', xlabel='Epoch', ylabel='Learning rate')
    ax.legend()
    plt.show()
else:
    print('No learning-rate history recorded (schedulers are disabled).')

#### Visualize Outputs

In [None]:
# Visualize Results: fresh noise with the digit conditions 0-9 from test_y.
# A separate name keeps the shared test_z probe intact for other cells.
sample_z = (2 * torch.randn(10, noise_dim) - 1).to(device)

# Sample with dropout disabled and without autograd, then restore training mode.
generator.eval()
with torch.no_grad():
    generated = generator(sample_z, test_y).cpu().view(-1, 1, 28, 28)
generator.train()

grid = torchvision.utils.make_grid(
    generated,
    nrow=5,
    padding=10,
    pad_value=1
)

# make_grid returns (C, H, W); imshow expects (H, W, C).
img = np.transpose(grid.numpy(), (1, 2, 0))

fig, ax = plt.subplots(figsize=(16, 16))
ax.axis("off")
ax.imshow(img);

#### Google Colaboratory

Notebook: https://colab.research.google.com/drive/1O0Id95mJUZLsxu3AJy8xCNMm5phbeYh7