In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms, models
from torchvision.utils import make_grid, save_image

import os
import time

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix
import seaborn as sns

%matplotlib inline

In [2]:
use_cuda = torch.cuda.is_available()
device = torch.device("cuda:0" if use_cuda else "cpu")
print('Running on device:', device)
if use_cuda:
    print('Using GPU:',
          torch.cuda.get_device_name(torch.cuda.current_device()))

Running on device: cuda:0
Using GPU: NVIDIA TITAN RTX


In [3]:
root = '/home/therock/data2/devnagari_data/'

expr_name = 'devnagari_cnn_vae'
model_name = expr_name + '_PyTorch_model.pt'

In [4]:
batch_size = 256
# each image in dataset is 32x32 pixels
image_dim = 32
learning_rate = 0.001
num_epochs = 100

train_transform = transforms.Compose([
    transforms.Grayscale(num_output_channels=1),
    transforms.Resize(image_dim),
    transforms.ToTensor(),
])


train_data = datasets.ImageFolder(os.path.join(root, 'Train'),
                                  transform=train_transform)

train_data_len = len(train_data)

train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True)

In [5]:
decoded_out_dir = expr_name + '_decoded'
if not os.path.exists(decoded_out_dir):
    os.mkdir(decoded_out_dir)


def to_img(x):
    x = x.view(x.size(0), 1, 32, 32)
    return x

In [6]:
def calc_conv_out(n=1, p=1, f=1, s=1):
        return int(((n + 2 * p - f) / s) + 1)

def calc_deconv_out(n=1, p=1, f=1, s=1):
        return int(s * (n - 1) + f - 2 * p)

In [7]:
conv1 = calc_conv_out(n=32, f=3, s=1, p=0)
mp1 = calc_conv_out(n=conv1, f=2, s=2, p=0)
conv2 = calc_conv_out(n=mp1, f=5, s=1, p=0)
mp2 = calc_conv_out(n=conv2, f=2, s=2, p=0)
deconv1 = calc_deconv_out(n=mp2, f=3, s=1, p=0)
deconv2 = calc_deconv_out(n=deconv1, f=3, s=2, p=0)
deconv3 = calc_deconv_out(n=deconv2, f=3, s=2, p=0)


print('1:', conv1)
print('2:', mp1)
print('3:', conv2)
print('4:', mp2)
print('de 5:', deconv1)
print('de 6:', deconv2)
print('de 7:', deconv3)

1: 30
2: 15
3: 11
4: 5
de 5: 7
de 6: 15
de 7: 31


In [8]:
class VAEEncoder(nn.Module):
    def __init__(self, latent_dim=64):

        super(VAEEncoder, self).__init__()
        self.encoder = nn.Sequential(
            nn.Conv2d(1, 16, 3, stride=1, padding=0),  # b, 32, 30,30
            nn.ReLU(True),
            nn.MaxPool2d(2, stride=2),  # b, 32, 15, 15
            nn.Conv2d(16, 8, 5, stride=1, padding=0),  # b, 16, 11, 11
            nn.ReLU(True),
            nn.MaxPool2d(2, stride=2)  # b, 8, 5, 5
        )
        self.mu = nn.Linear(8*5*5, latent_dim)
        self.var = nn.Linear(8*5*5, latent_dim)

    def forward(self, x):
        x = self.encoder(x)
        x = x.view(x.shape[0],8*5*5)
        latent_mu = self.mu(x)
        latent_var = self.var(x)
        return latent_mu, latent_var

class VAEDecoder(nn.Module):
    def __init__(self, latent_dim=10):

        super(VAEDecoder, self).__init__()
        self.l = nn.Linear(latent_dim, 8*5*5)
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(8, 4, 3, stride=1, padding=0),  # b, 12, 11,11
            nn.ReLU(True),
            nn.ConvTranspose2d(4, 2, 3, stride=2, padding=0),  # b, 6, 17, 17
            nn.ReLU(True),
            nn.ConvTranspose2d(2, 1, 3, stride=2, padding=0,
                               output_padding=1),  # b, 1, 32, 32
            nn.Tanh()
        )

    def forward(self, x):
        x = self.l(x)
        x = x.view(x.shape[0],8,5,5)
        x = self.decoder(x)
        pred = torch.sigmoid(x)
        return pred


class VAE(nn.Module):
    def __init__(self, encd, decd):
        super(VAE, self).__init__()

        self.encoder = encd
        self.decoder = decd

    def forward(self, x):
        latent_mu, latent_var = self.encoder(x)

        std = torch.exp(latent_var / 2)
        eps = torch.randn_like(std)
        x_sample = eps.mul(std).add_(latent_mu)

        # decode
        predicted = self.decoder(x_sample)
        return predicted, latent_mu, latent_var

In [9]:
in_dim = image_dim * image_dim
latent_dim = 64
encoder = VAEEncoder(latent_dim=latent_dim)
decoder = VAEDecoder(latent_dim=latent_dim)

model = VAE(encoder, decoder)
if use_cuda:
    model = model.to(device)
print(model)
# optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
criterion = nn.MSELoss()
lowest_loss = float('inf')

VAE(
  (encoder): VAEEncoder(
    (encoder): Sequential(
      (0): Conv2d(1, 16, kernel_size=(3, 3), stride=(1, 1))
      (1): ReLU(inplace=True)
      (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
      (3): Conv2d(16, 8, kernel_size=(5, 5), stride=(1, 1))
      (4): ReLU(inplace=True)
      (5): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    )
    (mu): Linear(in_features=200, out_features=64, bias=True)
    (var): Linear(in_features=200, out_features=64, bias=True)
  )
  (decoder): VAEDecoder(
    (l): Linear(in_features=64, out_features=200, bias=True)
    (decoder): Sequential(
      (0): ConvTranspose2d(8, 4, kernel_size=(3, 3), stride=(1, 1))
      (1): ReLU(inplace=True)
      (2): ConvTranspose2d(4, 2, kernel_size=(3, 3), stride=(2, 2))
      (3): ReLU(inplace=True)
      (4): ConvTranspose2d(2, 1, kernel_size=(3, 3), stride=(2, 2), output_padding=(1, 1))
      (5): Tanh()
    )
  )
)


In [10]:
def count_parameters(model):
    params = [p.numel() for p in model.parameters() if p.requires_grad]
    for i, item in enumerate(params):
        print(f'{i:2} : {item:}')
    print(f'==========\n{sum(params):>6}')


count_parameters(model)

 0 : 144
 1 : 16
 2 : 3200
 3 : 8
 4 : 12800
 5 : 64
 6 : 12800
 7 : 64
 8 : 12800
 9 : 200
10 : 288
11 : 4
12 : 72
13 : 2
14 : 18
15 : 1
 42481


In [11]:
def train(image_dim, train_loader, model, e):
    global lowest_loss
    # set the train mode
    model.train()

    # loss of the epoch
    train_loss = 0

    for i, (x, _) in enumerate(train_loader):

        x = x.to(device)

        # forward pass
        x_sample, latent_mu, latent_var = model(x)

        if e % 10 == 0:
            pic = to_img(x_sample.cpu().data)
            save_image(pic, './{}/image_{}.png'.format(decoded_out_dir, e))

        # reconstruction loss
        recon_loss = F.binary_cross_entropy(x_sample, x, reduction='sum')
        # kl divergence loss
        kl_loss = 0.5 * torch.sum(
            torch.exp(latent_var) + latent_mu**2 - 1.0 - latent_var)

        # total loss
        loss = recon_loss + kl_loss

        train_loss += loss.item()
        if loss.item() < lowest_loss:
            lowest_loss = loss.item()
            torch.save(model.state_dict(), model_name)
            #print(f"saved model")

        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

    return train_loss

In [12]:
best_test_loss = float('inf')

for e in range(num_epochs):

    train_loss = train(image_dim, train_loader, model, e)
    train_loss /= len(train_data)
    print(f'Epoch {e}, Train Loss: {train_loss}')

torch.save(model.state_dict(), model_name)

Epoch 0, Train Loss: 629.9134830562659
Epoch 1, Train Loss: 567.1098889066496
Epoch 2, Train Loss: 557.3271369485294
Epoch 3, Train Loss: 552.5030888347187
Epoch 4, Train Loss: 548.8391469189578
Epoch 5, Train Loss: 545.4971893981777
Epoch 6, Train Loss: 542.7125213295236
Epoch 7, Train Loss: 539.8243871882993
Epoch 8, Train Loss: 535.8824558423913
Epoch 9, Train Loss: 530.3134697190696
Epoch 10, Train Loss: 526.1031888886669
Epoch 11, Train Loss: 500.0064392183504
Epoch 12, Train Loss: 489.4164639945652
Epoch 13, Train Loss: 484.5542238451087
Epoch 14, Train Loss: 481.6919627657449
Epoch 15, Train Loss: 479.42204433743603
Epoch 16, Train Loss: 477.01482731577687
Epoch 17, Train Loss: 475.24334119245526
Epoch 18, Train Loss: 473.7912414082481
Epoch 19, Train Loss: 472.84519516264385
Epoch 20, Train Loss: 472.0948720728101
Epoch 21, Train Loss: 471.2701434622762
Epoch 22, Train Loss: 470.48918762987535
Epoch 23, Train Loss: 469.9128294337436
Epoch 24, Train Loss: 469.3993243985774
Epoch

In [14]:
sample_batches = 20
decoded_data = torch.FloatTensor(batch_size,1,image_dim,image_dim)

for sb_ in range(sample_batches):
    for i in range(batch_size):
        z = torch.randn(1, latent_dim).to(device)
        reconstructed_img = model.decoder(z).to('cpu')
        #print(f"reconstructed_img {reconstructed_img.shape}")
        img = reconstructed_img.view(image_dim, image_dim).data

    pic = to_img(decoded_data)
    save_image(pic, './{}/image_decoded_{}.png'.format(decoded_out_dir,sb_))
    