In [1]:
import torch
import inspect
import torch.optim as optim
import torch.utils.data
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.modules.activation as A
from tqdm import tqdm
%pylab inline
import torchvision
import torchvision.transforms as transforms

Populating the interactive namespace from numpy and matplotlib


In [15]:
mnist_data = torchvision.datasets.MNIST('./datasets/mnist', download=True, train=True,
                                           transform=transforms.Compose([
                           transforms.ToTensor(),
                           transforms.Normalize((0.1307,), (0.3081,))
                       ]))
mnist_loader = torch.utils.data.DataLoader(mnist_data, batch_size=32, shuffle=True)
cifar_data = torchvision.datasets.CIFAR10('./datasets/cifar10', download=True, train=True, \
                                          transform=transforms.Compose([
                                              transforms.ToTensor(), 
                                              transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]))
cifar_loader = torch.utils.data.DataLoader(cifar_data, batch_size=32, shuffle=True)

Files already downloaded and verified


In [3]:
list(mnist_loader)[3][0].shape

torch.Size([32, 1, 28, 28])

In [5]:
def unpickle(file):
    import pickle
    with open(file, 'rb') as fo:
        dict = pickle.load(fo, encoding='bytes')
    return dict

# cifar10_path = "./cifar-10-batches-py/"
# batch1 = unpickle(cifar10_path + 'data_batch_1')
# import matplotlib.pyplot as plt
# b = batch1[b'data'][4].reshape(3, 32, 32).transpose(1,2,0).astype("uint8")
# # plt.imshow(b)
# b1_reshaped = batch1[b'data'].reshape(batch1[b'data'].shape[0], 3, 32, 32)
# b1_reshaped[0].shape
# plt.imshow(b1_reshaped[1000].transpose(1, 2, 0))
# b1_reshaped[0]

In [31]:
def makeChannels(max_channel, reverse=False, num_channels=3):
    assert max_channel % 2==0
    channels = [max_channel] # Stores channels for layers 1->4
    for i in range(1,num_channels):
        channels.append(channels[i-1]//2)
    if reverse:
        channels.reverse()
        return channels
    return channels

class GeneratorNet(nn.Module):
    def __init__(self, cz=100, gf_dim=64, img_channels=3):
        super(GeneratorNet, self).__init__()
        channels = makeChannels(4*gf_dim, reverse=True)
        self.model = nn.Sequential(
            # Project and reshape
            nn.ConvTranspose2d(cz, channels[0], 4, 1, 0, bias=False),
            A.ReLU(),
            nn.BatchNorm2d(channels[0]),
            # Conv1. (?, channels[0], 4, 4) -> (?, channels[1], 8, 8)
            nn.ConvTranspose2d(channels[0], channels[1], 4, 2, 1, bias=False),
            A.ReLU(),
            nn.BatchNorm2d(channels[1]),
            # Conv2. (?, channels[1], 8, 8) -> (?, channels[2], 16, 16)
            nn.ConvTranspose2d(channels[1], channels[2], 4, 2, 1, bias=False),
            A.ReLU(),
            nn.BatchNorm2d(channels[2]),
            ### MNIST (?, channels[2], 16, 16) -> (?, img_channels=1, 28, 28) ### 
            nn.ConvTranspose2d(channels[2], img_channels, 4, 2, 3, bias=False),
            A.Tanh()
            ### MNIST ### 

            ### CIFAR (?, channels[2], 16, 16) -> (?, img_channels=3, 32, 32)###
#             nn.ConvTranspose2d(channels[2], img_channels, 4, 2, 1, bias=True),            
#             A.Tanh()
            ### CIFAR ###
            # Conv3. (?, channels[2], 16, 16) -> (?, channels[3], 32, 32)
#             nn.ConvTranspose2d(channels[2], channels[3], 4, 2, 1, bias=True),
#             A.ReLU(),
#             nn.BatchNorm2d(channels[3]),
            # Conv4. (?, channels[3], 32, 32) -> (?, 3, 64, 64)
#             nn.ConvTranspose2d(channels[2], img_channels, 4, 2, 1, bias=True),
#             A.Tanh()
        )
        
    def forward(self, z):
        """
        Input: 
            - z: minibatch of 1D noise vectors/"codes"
        Shapes:
            - z: (N, C) where N is batch_size and C is channels of z.
        """
        return self.model(z)

class DiscriminatorNet(nn.Module):
    def __init__(self, gf_dim=64, img_channels=3):
        super(DiscriminatorNet, self).__init__()
        channels = makeChannels(4*gf_dim)
        self.model = nn.Sequential(
            # Project image 
            nn.Conv2d(img_channels, channels[0], 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2),
            nn.Conv2d(channels[0], channels[1], 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2),
            nn.BatchNorm2d(channels[1]),
            nn.Conv2d(channels[1], channels[2], 4, 2, 1, bias=False),
            nn.BatchNorm2d(channels[2]),
            ### MNIST ###
            nn.Conv2d(channels[2], 1, 3, 1, 0, bias=False),
            nn.Sigmoid()
            ### MNIST ### 
            
            ### CIFAR10 ###
#             nn.Conv2d(channels[2], 1, 4, 1, 0, bias=False),
#             nn.Sigmoid()
            ### CIFAR10 ###
#             nn.Conv2d(channels[2], channels[3], 4, 1, 0, bias=True),
#             nn.Conv2d(channels[2], 1, 4, 1, 0, bias=True),
#             nn.Sigmoid()
        )
        
    def forward(self, x):
        return self.model(x)

In [32]:
dataset = torch.utils.data.TensorDataset(torch.from_numpy(b1_reshaped))
cifar_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)

NameError: name 'b1_reshaped' is not defined

In [33]:
batch_size=64
img_channels=1
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
G=GeneratorNet(cz=100, gf_dim=64, img_channels=img_channels)
D=DiscriminatorNet(img_channels=img_channels)
gopt = optim.Adam(G.parameters(), lr=0.0002, betas=(.5, .999))
dopt = optim.Adam(D.parameters(), lr=0.0002, betas=(.5, .999))

criterion = nn.BCELoss(reduction='elementwise_mean')

In [43]:
mbatch_mod = 10
cz=100
epoch_mod = 2

### LOAD CHECKPOINT ###
checkG = torch.load("./chkpts/G_epoch0_mbatch60.pth")
G.load_state_dict(checkG)
checkD = torch.load("./chkpts/D_epoch0_mbatch60.pth")
D.load_state_dict(checkD)

for epoch in range(5):
    for mbatch_idx, x in tqdm(enumerate(mnist_loader)):
        ##############################
        # Update discriminator 
        ##############################
        x = x[0]
        x = x.to(device, dtype=torch.float32)
        real_loss = criterion(D(x), torch.ones([x.shape[0], 1, 1, 1]))
        real_loss.backward()
        
        z = torch.zeros([x.shape[0], 100, 1, 1]).uniform_(0, 1)
        generated_loss = criterion(D(G(z)), torch.zeros([x.shape[0], 1, 1, 1]))
        generated_loss.backward()

        dloss = real_loss.item() + generated_loss.item()

        dopt.step()
        dopt.zero_grad()
        
        ##############################
        # Update generator
        ##############################
        z = torch.zeros([x.shape[0], cz, 1, 1]).uniform_(0, 1) 
#         gloss = torch.Tensor([-1.0]) * criterion(D(G(z)), torch.zeros([x.shape[0], 1, 1, 1]))   
        gloss = criterion(D(G(z)), torch.ones([x.shape[0], 1, 1, 1]))
        gloss.backward()
        
        gopt.step()
        gopt.zero_grad()
#         import ipdb; ipdb.set_trace()
        if (mbatch_idx % mbatch_mod == 0 and mbatch_idx > 0 and epoch % epoch_mod == 0):
            print(F"Epoch: {epoch}, minibatch: {mbatch_idx}, DLoss: {dloss}, GLoss: {gloss}")
            checkpt_dir = "./chkpts"
            torch.save(G.state_dict(), F"{checkpt_dir}/G_epoch{epoch}_mbatch{mbatch_idx}.pth")
            torch.save(D.state_dict(), F"{checkpt_dir}/D_epoch{epoch}_mbatch{mbatch_idx}.pth")


0it [00:00, ?it/s][A
1it [00:01,  1.16s/it][A
2it [00:02,  1.12s/it][A
[A

KeyboardInterrupt: 

In [None]:
testimg = list(mnist_loader)[0][0]

In [None]:
z = torch.zeros([32, 100, 1, 1]).uniform_(0, 1)
G(z)