<a href="https://colab.research.google.com/github/satvikk/ai_synthesize/blob/master/learn1_customGAN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

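# GAN on MNIST

Trains a small convolutional GAN on the MNIST handwritten-digit dataset: a discriminator (`Apartheid`) learns to tell real digits from generated ones, while a generator (`Reproduction`) learns to produce digits the discriminator accepts as real.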

In [0]:
import torch as t
import plotly.graph_objects as go
import torchvision.datasets as datasets
t.set_default_tensor_type(t.cuda.FloatTensor)
import torch.nn.functional as F
import tqdm

In [0]:
# Load (downloading on first use) the MNIST train split, then the test split, as raw datasets.
mnist_trainset, mnist_testset = (
    datasets.MNIST(root='./data', train=is_train, download=True, transform=None)
    for is_train in (True, False)
)
def full_noise(noise_size = 1, dim = 28):
  """Return `noise_size` uniform-noise images of shape (dim, dim), quantised like MNIST.

  Every pixel is k/255 for an integer k drawn uniformly from 0..255. This is the same
  value grid as MNIST bytes divided by 255, so 1.0 (pure white) can occur.

  Returns:
    Float tensor of shape (noise_size, dim, dim) on the default device.
  """
  # randint's upper bound is exclusive, so levels 0..255 are all reachable.
  # The previous floor(rand * 255) form could never produce 255.
  levels = t.randint(0, 256, (noise_size, dim, dim))
  return levels.float() / 255

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ./data/MNIST/raw/train-images-idx3-ubyte.gz


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))

Extracting ./data/MNIST/raw/train-images-idx3-ubyte.gz to ./data/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ./data/MNIST/raw/train-labels-idx1-ubyte.gz


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))

Extracting ./data/MNIST/raw/train-labels-idx1-ubyte.gz to ./data/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw/t10k-images-idx3-ubyte.gz



HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))

Extracting ./data/MNIST/raw/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))

Extracting ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw
Processing...
Done!


In [0]:
def fixed_leaky_relu(x, dummy = None):
  """Leaky ReLU whose negative slope is hard-wired to 0.2.

  `dummy` is accepted only so call sites may pass a slope argument; it is ignored.
  NOTE(review): callers in this notebook pass 0.04, which therefore has no effect.
  """
  slope = 0.2
  return F.leaky_relu(x, negative_slope=slope)
class Apartheid(t.nn.Module):
  """GAN discriminator: labels 28x28 images as generated (class 0) or real (class 1).

  forward(x) expects x of shape (B, 28, 28) and returns (B, 2) log-probabilities.
  Two conv -> leaky-ReLU -> dropout -> 2x max-pool stages shrink 28x28 to 7x7.
  The result is flattened and passed through a two-layer MLP.
  """

  def __init__(self):
    super().__init__()
    # Construction order is kept fixed: it determines parameter order and the
    # RNG draws used for weight initialisation.
    self.conv0 = t.nn.Conv2d(in_channels=1, out_channels=1, kernel_size=5, padding=2)
    self.conv1 = t.nn.Conv2d(in_channels=1, out_channels=1, kernel_size=5, padding=2)
    self.lin0 = t.nn.Linear(in_features=7 * 7, out_features=128)
    self.lin1 = t.nn.Linear(in_features=128, out_features=2)
    self.dropout0 = t.nn.Dropout(p=0.1)
    self.dropout1 = t.nn.Dropout(p=0.1)

  def forward(self, x):
    h = x.unsqueeze(1)  # add a channel dim: (B, 1, 28, 28)
    # fixed_leaky_relu always uses slope 0.2, so no slope argument is passed.
    for conv in (self.conv0, self.conv1):
      h = F.max_pool2d(self.dropout0(fixed_leaky_relu(conv(h))), 2)  # 28 -> 14 -> 7
    h = h.reshape(-1, 1, 7 * 7)
    h = self.dropout1(fixed_leaky_relu(self.lin0(h)))
    # NOTE(review): dropout is also applied to the output logits -- unusual; confirm intended.
    logits = self.dropout1(self.lin1(h))
    return F.log_softmax(logits, 2).squeeze(1)

class Reproduction(t.nn.Module):
  """GAN generator: maps (B, 100) latent vectors to (B, 1, 28, 28) images in (0, 1).

  A linear layer projects to a 128-channel 7x7 map. Two stride-2 transposed
  convolutions upsample it to 28x28, and a final 3x3 conv plus sigmoid yields pixel values.
  """

  def __init__(self):
    super().__init__()
    # Construction order is kept fixed (parameter order / init RNG draws).
    self.lin0 = t.nn.Linear(in_features=100, out_features=7 * 7 * 128)
    self.convt0 = t.nn.ConvTranspose2d(128, 1, kernel_size=4, stride=2, padding=1)
    self.convt1 = t.nn.ConvTranspose2d(1, 1, kernel_size=4, stride=2, padding=1)
    self.conv0 = t.nn.Conv2d(1, 1, kernel_size=3, padding=1)
    self.dropout = t.nn.Dropout(p=0.4)

  def forward(self, x):
    # NOTE(review): this layer really uses slope 0.04. The upsampling layers go through
    # fixed_leaky_relu, which ignores its argument and uses 0.2. Confirm the mismatch is intended.
    h = self.dropout(F.leaky_relu(self.lin0(x), 0.04)).reshape(-1, 128, 7, 7)
    for upsample in (self.convt0, self.convt1):  # 7 -> 14 -> 28
      h = self.dropout(fixed_leaky_relu(upsample(h)))
    return t.sigmoid(self.conv0(h))

# Seed the torch RNGs (CPU and CUDA) before building the models. This makes weight
# initialisation and later random draws repeat across Restart & Run All.
# (Some cuDNN kernels can still be nondeterministic.)
SEED = 0
t.manual_seed(SEED)
discriminator = Apartheid()
generator = Reproduction()
def generate(generator, noise_size, latent_dim = 100):
  """Sample `noise_size` outputs from `generator` fed standard-normal latent vectors.

  Args:
    generator: callable mapping a (noise_size, latent_dim) tensor to images.
      The Reproduction model in this notebook returns (B, 1, 28, 28).
    noise_size: number of samples to draw.
    latent_dim: width of each latent vector; the default 100 matches Reproduction.lin0.
  Returns:
    The generator output with dim 1 squeezed (a size-1 channel dim is dropped),
    e.g. (noise_size, 28, 28) for Reproduction.
  """
  latent = t.randn(noise_size, latent_dim)
  return generator(latent).squeeze(1)

In [0]:
class datamaker(t.utils.data.Dataset):
    """Dataset of real MNIST digits (label 1), optionally followed by noise images (label 0).

    Indices [0, len(mnist)) return MNIST images scaled to [0, 1].
    Indices [len(mnist), len(mnist) + noise_size) return pre-generated noise images.
    Each item is a dict {'x': float image tensor, 'y': 0-d integer label tensor}.
    """

    def __init__(self, mnist, noise_size = 0, generator_model = None):
        """
        Args:
            mnist: MNIST dataset; only len(mnist) and mnist.data[idx] are used.
            noise_size: number of noise images appended after the real digits.
            generator_model: currently unused; kept for backward compatibility.
        """
        self.mnist = mnist
        self.noise_size = noise_size
        self.noise_data = full_noise(noise_size, dim = 28)

    def __len__(self):
        return len(self.mnist) + self.noise_size

    def __getitem__(self, idx):
        n_real = len(self.mnist)
        if idx < n_real:
            # Put the uint8 digit on the same device as the noise tensor (the default
            # device). Unlike a hard-coded .cuda(), this also works on CPU-only machines.
            image = self.mnist.data[idx].to(self.noise_data.device).float() / 255
            return {'x': image, 'y': t.tensor(1)}
        return {'x': self.noise_data[idx - n_real, :, :], 'y': t.tensor(0)}
# Training mini-batch size, and how many samples each evaluation batch draws.
batch_size = 40
test_size = 5000
# Real digits only (noise_size=0); fakes come from the generator during training.
dataloader = t.utils.data.DataLoader(
    datamaker(mnist_trainset, noise_size = 0),
    batch_size = batch_size,
    shuffle = True,
)
testloader = t.utils.data.DataLoader(
    datamaker(mnist_testset, noise_size = 0),
    batch_size = test_size,
    shuffle = True,
)

In [0]:
# --- Training hyperparameters ---
n_epochs = 15
init_lr = 0.002
ema_loss = 2.4  # NOTE(review): not read anywhere in this notebook
# When True, record the mean squared gradient of every parameter tensor at each step.
# The gradient-plot cell reads gradients_d / gradients_g.
track_grad_norms = False
gradients_d = []
gradients_g = []
# Evaluate four times per epoch; max() guards against a modulo by zero on tiny loaders.
eval_every = max(1, len(dataloader) // 4)

optimizer_d = t.optim.Adam(discriminator.parameters(), lr = init_lr, betas=(0.9,0.999))
optimizer_g = t.optim.Adam(generator.parameters(), lr = init_lr, betas=(0.9,0.999))
for epochs in range(n_epochs):
  for i, dat in enumerate(dataloader):
    discriminator.train()
    generator.train()

    # Discriminator step: real digits are class 1, generated digits are class 0.
    optimizer_d.zero_grad()
    real = dat['x']
    n_real = real.shape[0]  # can be < batch_size on a short final batch
    # detach(): the D update must not build or backprop a graph through the generator.
    fake = generate(generator, batch_size).detach()
    output_d = discriminator(t.cat([real, fake], dim = 0))
    target_d = t.cat([t.ones(n_real), t.zeros(batch_size)], dim = 0).long()
    loss_d = F.nll_loss(output_d, target_d)
    loss_d.backward()
    optimizer_d.step()
    if track_grad_norms:
      gradients_d.append(t.stack([(p.grad ** 2).mean() for p in discriminator.parameters()]))

    # Generator step: train G so that D labels its samples as real (class 1).
    optimizer_g.zero_grad()
    output_g = generate(generator, batch_size)
    loss_g = F.nll_loss(discriminator(output_g), t.ones(batch_size).long())
    loss_g.backward()
    optimizer_g.step()
    if track_grad_norms:
      gradients_g.append(t.stack([(p.grad ** 2).mean() for p in generator.parameters()]))

    if i % eval_every == 0:
      discriminator.eval()
      generator.eval()
      with t.no_grad():
        testset = next(iter(testloader))
        real_test = testset['x']
        n_test = real_test.shape[0]
        inp_d = t.cat([real_test, generate(generator, test_size)], dim = 0)
        output_d = discriminator(inp_d).argmax(1)
        target_test = t.cat([t.ones(n_test), t.zeros(test_size)], dim = 0).long()
        # accuracy_d: D accuracy on a mixed batch of real test digits and fresh fakes.
        accuracy_d = (output_d == target_test).float().mean()
        # accuracy_t: fraction of another batch of fakes that D correctly flags as class 0.
        output_t = discriminator(generate(generator, test_size)).argmax(1)
        accuracy_t = (output_t == 0).float().mean()
        print(epochs, " | ", i, " | ", accuracy_d, " | ", accuracy_t)

if track_grad_norms:
  # (n_param_tensors, n_steps) matrices of RMS gradient per parameter tensor.
  gradients_d = t.sqrt(t.stack(gradients_d, dim = 1))
  gradients_g = t.sqrt(t.stack(gradients_g, dim = 1))

0  |  0  |  tensor(0.6402)  |  tensor(0.5124)
0  |  375  |  tensor(0.6600)  |  tensor(0.5106)
0  |  750  |  tensor(0.6342)  |  tensor(0.4148)
0  |  1125  |  tensor(0.6154)  |  tensor(0.3934)
1  |  0  |  tensor(0.6938)  |  tensor(0.5604)
1  |  375  |  tensor(0.6101)  |  tensor(0.3188)
1  |  750  |  tensor(0.6427)  |  tensor(0.4174)
1  |  1125  |  tensor(0.6419)  |  tensor(0.4588)
2  |  0  |  tensor(0.6142)  |  tensor(0.4032)
2  |  375  |  tensor(0.6667)  |  tensor(0.5120)
2  |  750  |  tensor(0.6606)  |  tensor(0.5182)
2  |  1125  |  tensor(0.6303)  |  tensor(0.4546)
3  |  0  |  tensor(0.6201)  |  tensor(0.4164)
3  |  375  |  tensor(0.6102)  |  tensor(0.3770)
3  |  750  |  tensor(0.5938)  |  tensor(0.2996)
3  |  1125  |  tensor(0.6734)  |  tensor(0.6254)
4  |  0  |  tensor(0.6272)  |  tensor(0.3688)
4  |  375  |  tensor(0.6102)  |  tensor(0.3554)
4  |  750  |  tensor(0.6002)  |  tensor(0.3012)
4  |  1125  |  tensor(0.6221)  |  tensor(0.4166)
5  |  0  |  tensor(0.6464)  |  tensor(0.4798)

In [0]:
# Draw generator samples until the discriminator labels one as real (class 1), then plot it.
generator.eval()
discriminator.eval()
# Cap the rejection sampling: a strong discriminator may never accept a sample,
# and an unbounded loop would then hang the kernel.
max_attempts = 10000
with t.no_grad():  # inference only; no autograd graph needed
  for _ in range(max_attempts):
    image = generate(generator, 1)
    if discriminator(image).argmax(1) == 1:
      break
  else:
    print(f"No sample accepted as real in {max_attempts} attempts; showing the last one.")
  print(discriminator(image).argmax(1))

# Pixel-centre coordinates for a 28x28 scatter "image".
# Pixel (row r, col c) sits at (c, 27 - r), so row 0 is drawn at the top.
grid20x = t.cat([t.linspace(0,27,28)]*28)
grid20y = t.cat([t.linspace(27,0,28).unsqueeze(1)]*28, dim = 1).flatten()
fig = go.Figure()
fig.add_scatter(
    x = grid20x.cpu(),
    y = grid20y.cpu(),
    mode = "markers",
    marker = dict(
        color = image[0,:,:].detach().cpu().flatten(),
        showscale=True,
        colorscale = "gray",
        symbol = "square",
        size = 15,
    )
)
fig.update_layout(
    title = "Generated sample",
    yaxis = dict(
      scaleanchor = "x",
      scaleratio = 1,
    ),
)

fig.show()

tensor([1])


In [0]:
# Plot the RMS gradient of one discriminator parameter tensor over training steps.
# Index 3 is conv1.bias, given the layer registration order in Apartheid.__init__.
param_index = 3
grad_hist = gradients_d
if isinstance(grad_hist, list):
  # Raw per-step history of mean squared gradients -> (n_param_tensors, n_steps) RMS matrix.
  grad_hist = t.sqrt(t.stack(grad_hist, dim = 1)) if grad_hist else None
if grad_hist is None:
  print("No gradient history recorded; enable gradient tracking in the training cell.")
else:
  # Take the x-axis length from the tensor being plotted, so x and y always match.
  n_steps = grad_hist.shape[1]
  fig = go.Figure()
  fig.add_scatter(
      x = t.linspace(0, n_steps - 1, n_steps).cpu(),
      y = grad_hist[param_index, :].cpu().detach(),
  )
  fig.update_layout(
      title = f"Discriminator parameter {param_index}: RMS gradient",
      xaxis_title = "training step",
      yaxis_title = "RMS gradient",
  )
  fig.show()

In [0]:
# Plot a random real MNIST training digit, for visual comparison with generated samples.
# Sample uniformly over the actual dataset length rather than a hard-coded 60000.
irow = int(t.randint(len(mnist_trainset), (1,)))
# Pixel (row r, col c) sits at (c, 27 - r), so row 0 is drawn at the top.
grid20x = t.cat([t.linspace(0,27,28)]*28)
grid20y = t.cat([t.linspace(27,0,28).unsqueeze(1)]*28, dim = 1).flatten()
fig = go.Figure()
fig.add_scatter(
    x = grid20x.cpu(),
    y = grid20y.cpu(),
    mode = "markers",
    marker = dict(
        color = mnist_trainset.data[irow,:,:].flatten().cpu(),
        showscale=True,
        colorscale = "gray",
        symbol = "square",
        size = 15,
    )
)
fig.update_layout(
    title = f"MNIST training digit #{irow}",
    yaxis = dict(
      scaleanchor = "x",
      scaleratio = 1,
    )
)
fig.show()