In [1]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import warnings
warnings.filterwarnings('ignore')

import torch
from torch import nn, optim
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
from torch.autograd import Variable
import torch.nn.functional as F

def check_cuda():
    return "CUDA ON" if torch.cuda.is_available() else "NO CUDA :("

check_cuda()

'CUDA ON'

In [2]:
train = pd.read_csv('/kaggle/input/digit-recognizer/train.csv')
test = pd.read_csv('/kaggle/input/digit-recognizer/test.csv')
train.shape, test.shape

((42000, 785), (28000, 784))

In [3]:
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5])
])

data = np.concatenate((np.array(train)[:,1:], np.array(test)), axis=0) / 255.0

class MNIST(Dataset):
    def __init__(self, data, transform):
        self.data = data
        self.transform = transform
        
    def __getitem__(self, index):
        if self.transform:
            return np.asarray(self.transform(self.data[index, ]))
        else:
            return np.asarray(self.data[index, ])
        
    def __len__(self):
        return self.data.shape[0]
    
dataset = MNIST(data.reshape(-1,28,28), transform=False)
dataloader = DataLoader(dataset, batch_size=64, shuffle=True, num_workers=0)

In [4]:
class Generator(nn.Module):
    def __init__(self):
        super().__init__()
        
        self.main_module = nn.Sequential(
            nn.ConvTranspose2d(in_channels=100, out_channels=1024, kernel_size=7, stride=1, padding=0),
            nn.BatchNorm2d(1024),
            nn.ReLU(True),
            
            nn.ConvTranspose2d(in_channels=1024, out_channels=512, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU(True),
        
            nn.ConvTranspose2d(in_channels=512, out_channels=256, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(True),

            nn.ConvTranspose2d(in_channels=256, out_channels=1, kernel_size=1, stride=1, padding=0)
        )
        
        self.out = nn.Tanh()
        
    def forward(self, x):
        x = self.main_module(x)
        x = self.out(x)
        return x
    
class Discriminator(nn.Module):
    def __init__(self):
        super().__init__()
        
        self.main_module = nn.Sequential(
            nn.Conv2d(in_channels=1, out_channels=256, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2, inplace=True),
            
            nn.Conv2d(in_channels=256, out_channels=512, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2, inplace=True),
            
            nn.Conv2d(in_channels=512, out_channels=1024, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(1024),
            nn.LeakyReLU(0.2, inplace=True)
        )
        
        self.out = nn.Conv2d(in_channels=1024, out_channels=1, kernel_size=3, stride=1, padding=1)
        self.l = nn.Linear(4*4, 1)
        
    def forward(self, x):
        x = self.main_module(x)
        x = self.out(x)
        x = x.view(-1, 4*4)
        x = self.l(x)
        return x

In [5]:
G = Generator().cuda()
D = Discriminator().cuda()
G_optim = optim.RMSprop(G.parameters(), lr=0.00005)
D_optim = optim.RMSprop(D.parameters(), lr=0.00005)

In [6]:
import time
epochs = 100
critic_iter = 5
clamp = 0.01

for epoch in range(epochs):
    t0 = time.time()
    t = torch.FloatTensor([1]).cuda()
    f = t * (-1)
    
    for image in dataloader:
        image = image.view(-1, 1, 28, 28).float()
    # Train Critic 5 times per Generator train
    
        for p in D.parameters():
            p.requires_grad = True
            
        for d_iter in range(critic_iter):

            D.zero_grad()

            # Weight Clipping
            for p in D.parameters():
                p.data.clamp_(-clamp, clamp)


            z = torch.randn((64, 100, 1, 1)) # Batch_size 64
            image = Variable(image.cuda())
            z = Variable(z.cuda())

            d_loss_real = torch.sum(-D(image) * t)
#             d_loss_real = 
            d_loss_real.backward()

            fake_image = G(z)
            d_loss_fake = torch.sum(-D(fake_image) * f)
#             d_loss_fake = d_loss_fake.mean(0).view(1)
            d_loss_fake.backward()

            d_loss = d_loss_real + d_loss_fake
            wasserstein_d = d_loss_real - d_loss_fake
            D_optim.step()

            # Training Generator
        for p in D.parameters():
            p.requires_grad = False

        G.zero_grad()

        z = Variable(torch.randn((64, 100, 1, 1)).cuda())
        fake_image = G(z)
        g_loss = torch.sum(D(fake_image) * f)
#         g_loss = g_loss.mean().mean(0).view(1)
        g_loss.backward()
        g_cost = -g_loss
        G_optim.step()
        
    t1 = time.time()
    
    print("Current Epoch: {}".format(epoch+1))
    print("Time Spent: {:.1f}s".format(t1-t0))
    print("Critic Loss: {:.3f}".format(d_loss.item()))
    print("Generator Loss: {:.3f}".format(g_loss.item()))
    print()

Current Epoch: 1
Time Spent: 712.5s
Critic Loss: -8.331
Generator Loss: 4.093

Current Epoch: 2
Time Spent: 711.6s
Critic Loss: -8.080
Generator Loss: 4.150

Current Epoch: 3
Time Spent: 711.5s
Critic Loss: -8.089
Generator Loss: 4.098

Current Epoch: 4
Time Spent: 711.6s
Critic Loss: -8.231
Generator Loss: 4.044

Current Epoch: 5
Time Spent: 711.6s
Critic Loss: -8.086
Generator Loss: 4.038

Current Epoch: 6
Time Spent: 710.2s
Critic Loss: -8.108
Generator Loss: 4.142

Current Epoch: 7
Time Spent: 709.8s
Critic Loss: -8.235
Generator Loss: 4.150

Current Epoch: 8
Time Spent: 709.8s
Critic Loss: -8.122
Generator Loss: 4.169

Current Epoch: 9
Time Spent: 709.7s
Critic Loss: -8.263
Generator Loss: 4.168

Current Epoch: 10
Time Spent: 710.2s
Critic Loss: -8.012
Generator Loss: 4.088

Current Epoch: 11
Time Spent: 710.1s
Critic Loss: -8.107
Generator Loss: 4.146

Current Epoch: 12
Time Spent: 710.3s
Critic Loss: -8.226
Generator Loss: 4.118

Current Epoch: 13
Time Spent: 709.9s
Critic Loss:

KeyboardInterrupt: 