In [5]:
from __future__ import print_function

import argparse

import matplotlib.pyplot as plt
import torch
import torch.utils.data
import torchbearer
import torchvision
import torchvision.transforms as transforms
from torch import nn, optim
from torch.nn import functional as F
from torchbearer import Callback
from torchvision import datasets, transforms
from torchvision.utils import save_image

# import cv2

# Albumentations for augmentations
import albumentations as A
from albumentations.pytorch import ToTensorV2

In [2]:
# !install_package_python310.sh add albumentations
import ssl
# WARNING: this swaps the default HTTPS context factory for the unverified one,
# so code in this kernel that relies on ssl's default HTTPS context no longer
# verifies server certificates.
# NOTE(review): presumably a workaround for the dataset download failing
# certificate checks — confirm, and prefer fixing the certificate store.
ssl._create_default_https_context = ssl._create_unverified_context

In [3]:
# import cv2
import torchvision
import torch
import torchvision.transforms as transforms

# Albumentations for augmentations

import albumentations as A
from albumentations.pytorch import ToTensorV2

# cv2.setNumThreads(0)
# cv2.ocl.setUseOpenCL(False)


class Cifar10SearchDataset(torchvision.datasets.CIFAR10):
    """CIFAR10 dataset whose transform is called albumentations-style.

    The transform is invoked with the keyword argument ``image=`` and the
    result is read back from the ``"image"`` key of the returned mapping.
    """

    def __init__(self, root="~/data/cifar10", train=True, download=True, transform=None):
        """Forward all arguments unchanged to ``torchvision.datasets.CIFAR10``."""
        super().__init__(root=root, train=train, download=download, transform=transform)

    def __getitem__(self, index):
        """Return ``(image, label)`` for ``index``, transforming the image if a transform is set."""
        sample = self.data[index]
        target = self.targets[index]
        if self.transform is None:
            # No transform configured: hand back the stored entry untouched.
            return sample, target
        return self.transform(image=sample)["image"], target


# Per-channel (mean, std) handed to A.Normalize by both pipelines.
# NOTE(review): presumably CIFAR-10 channel statistics — confirm the source.
NORM_MEAN = (0.49139968, 0.48215841, 0.44653091)
NORM_STD = (0.24703223, 0.24348513, 0.26158784)

# Training pipeline: flip, one 16x16 cutout, shift/scale/rotate, normalize, to tensor.
_train_steps = [
    A.HorizontalFlip(p=0.5),
    A.CoarseDropout(
        max_holes=1,
        max_height=16,
        max_width=16,
        min_holes=1,
        min_height=16,
        min_width=16,
        fill_value=(0.5, 0.5, 0.5),
        mask_fill_value=None,
    ),
    A.ShiftScaleRotate(),
    A.Normalize(NORM_MEAN, NORM_STD),
    ToTensorV2(),
]
train_transforms = A.Compose(_train_steps, p=1.0)

# Evaluation pipeline: normalization and tensor conversion only.
_test_steps = [
    A.Normalize(NORM_MEAN, NORM_STD),
    ToTensorV2(),
]
test_transforms = A.Compose(_test_steps, p=1.0)


class args:
    """Small settings holder for the data loaders.

    Attributes set by ``__init__``:
        batch_size: always 64.
        device: the ``device`` argument, stored as given.
        use_cuda: the ``use_cuda`` argument, stored as given.
        kwargs: extra ``DataLoader`` keyword arguments; one worker with pinned
            memory when ``use_cuda`` is truthy, otherwise an empty dict.
    """

    def __init__(self, device="cpu", use_cuda=False) -> None:
        self.batch_size = 64
        self.device = device
        self.use_cuda = use_cuda
        if self.use_cuda:
            loader_options = {"num_workers": 1, "pin_memory": True}
        else:
            loader_options = {}
        self.kwargs = loader_options

# NOTE(review): args() is built with its defaults (use_cuda=False), so the extra
# DataLoader keyword arguments are always an empty dict here.
loader_cfg = args()

# Datasets: augmented training split and normalize-only test split.
trainset = Cifar10SearchDataset(root="./data", train=True, download=True, transform=train_transforms)
testset = Cifar10SearchDataset(root="./data", train=False, download=True, transform=test_transforms)

# Both loaders shuffle and share the same batch size.
trainloader = torch.utils.data.DataLoader(
    trainset,
    batch_size=loader_cfg.batch_size,
    shuffle=True,
    **loader_cfg.kwargs,
)
testloader = torch.utils.data.DataLoader(
    testset,
    batch_size=loader_cfg.batch_size,
    shuffle=True,
    **loader_cfg.kwargs,
)

# Class names; position in the tuple is the class index.
classes = (
    'plane', 'car', 'bird', 'cat', 'deer',
    'dog', 'frog', 'horse', 'ship', 'truck',
)

Files already downloaded and verified
Files already downloaded and verified


In [6]:
# Dropout probability shared by every nn.Dropout layer in the conv blocks.
dropout_value = 0.0


class VAE(nn.Module):
    """Label-conditioned variational autoencoder for 3x32x32 images.

    Pipeline implemented by ``forward``:
      1. Four convolutional blocks reduce the image to a 192x4x4 feature map,
         which is flattened to 3072 values.
      2. The class label is one-hot encoded and projected to 3072 values
         through ``fc`` -> ``fc4`` -> ``fc5`` -> ``fc6`` (no activations).
         ``fc4``/``fc5``/``fc6`` are the same layers the decoder uses.
      3. Image features and label projection are summed, encoded to a
         20-dimensional ``mu``/``logvar`` pair, sampled, and decoded back to
         3072 sigmoid outputs.

    ``fc30`` is registered but not used by any method; it is kept so the
    module's parameter set stays unchanged.
    """

    def __init__(self):
        super(VAE, self).__init__()

        # Block 0: four plain 3x3 convolutions at full resolution.
        self.convblock_0 = nn.Sequential(
            nn.Conv2d(in_channels=3, out_channels=16, kernel_size=(3, 3), dilation=1, stride=1, padding=1, bias=False),
            nn.ReLU(),
            nn.BatchNorm2d(16),
            nn.Dropout(dropout_value),  # 3x32x32 -> 16x32x32

            nn.Conv2d(in_channels=16, out_channels=32, kernel_size=(3, 3), dilation=1, stride=1, padding=1, bias=False),
            nn.ReLU(),
            nn.BatchNorm2d(32),
            nn.Dropout(dropout_value),  # 16x32x32 -> 32x32x32

            nn.Conv2d(in_channels=32, out_channels=32, kernel_size=(3, 3), dilation=1, stride=1, padding=1, bias=False),
            nn.ReLU(),
            nn.BatchNorm2d(32),
            nn.Dropout(dropout_value),  # 32x32x32 -> 32x32x32

            nn.Conv2d(in_channels=32, out_channels=32, kernel_size=(3, 3), dilation=1, stride=1, padding=1, bias=False),
            nn.ReLU(),
            nn.BatchNorm2d(32),
            nn.Dropout(dropout_value),  # 32x32x32 -> 32x32x32
        )

        # Block 1: strided conv (downsample by 2) + two depthwise-separable convs.
        self.convblock_1 = nn.Sequential(
            nn.Conv2d(in_channels=32, out_channels=64, kernel_size=(3, 3), stride=(2, 2), dilation=1, padding=1, bias=False),
            nn.ReLU(),
            nn.BatchNorm2d(64),
            nn.Dropout(dropout_value),  # 32x32x32 -> 64x16x16

            # depthwise (groups == channels) followed by pointwise 1x1
            nn.Conv2d(in_channels=64, out_channels=64, groups=64, kernel_size=(3, 3), stride=(1, 1), dilation=1, padding=1, bias=False),
            nn.Conv2d(in_channels=64, out_channels=128, kernel_size=(1, 1), stride=(1, 1), padding=0, bias=False),
            nn.ReLU(),
            nn.BatchNorm2d(128),
            nn.Dropout(dropout_value),  # 64x16x16 -> 128x16x16

            nn.Conv2d(in_channels=128, out_channels=128, groups=128, kernel_size=(3, 3), dilation=1, stride=(1, 1), padding=1, bias=False),
            nn.Conv2d(in_channels=128, out_channels=64, kernel_size=(1, 1), padding=0, bias=False),
            nn.ReLU(),
            nn.BatchNorm2d(64),
            nn.Dropout(dropout_value),  # 128x16x16 -> 64x16x16
        )

        # Block 2: strided conv (downsample by 2) + two depthwise-separable convs.
        self.convblock_2 = nn.Sequential(
            nn.Conv2d(in_channels=64, out_channels=32, kernel_size=(3, 3), stride=(2, 2), dilation=1, padding=1, bias=False),
            nn.ReLU(),
            nn.BatchNorm2d(32),
            nn.Dropout(dropout_value),  # 64x16x16 -> 32x8x8

            nn.Conv2d(in_channels=32, out_channels=32, groups=32, kernel_size=(3, 3), stride=(1, 1), padding=1, bias=False),
            nn.Conv2d(in_channels=32, out_channels=64, kernel_size=(1, 1), stride=(1, 1), padding=0, bias=False),
            nn.ReLU(),
            nn.BatchNorm2d(64),
            nn.Dropout(dropout_value),  # 32x8x8 -> 64x8x8

            nn.Conv2d(in_channels=64, out_channels=64, groups=64, kernel_size=(3, 3), stride=(1, 1), padding=1, bias=False),
            nn.Conv2d(in_channels=64, out_channels=64, kernel_size=(1, 1), stride=(1, 1), padding=0, bias=False),
            nn.ReLU(),
            nn.BatchNorm2d(64),
            nn.Dropout(dropout_value),  # 64x8x8 -> 64x8x8
        )

        # Block 3: strided conv (downsample by 2) + grouped conv + pointwise conv.
        self.convblock_3 = nn.Sequential(
            nn.Conv2d(in_channels=64, out_channels=64, kernel_size=(3, 3), dilation=1, stride=(2, 2), padding=1, bias=False),
            nn.ReLU(),
            nn.BatchNorm2d(64),
            nn.Dropout(dropout_value),  # 64x8x8 -> 64x4x4

            nn.Conv2d(in_channels=64, out_channels=128, groups=64, kernel_size=(3, 3), stride=(1, 1), padding=1, bias=False),
            nn.Conv2d(in_channels=128, out_channels=192, kernel_size=(1, 1), stride=(1, 1), padding=0, bias=False),
            nn.ReLU(),
            nn.BatchNorm2d(192),
            nn.Dropout(dropout_value),  # 64x4x4 -> 192x4x4 (192 * 4 * 4 = 3072)
        )

        # Average pooling with a 1x1 window: spatial size is left unchanged.
        self.gap = nn.Sequential(nn.AvgPool2d(kernel_size=1))
        # Projects the 10-way one-hot label into the 20-d latent width.
        self.fc = nn.Linear(in_features = 10, out_features = 20)

        # Encoder MLP: 3072 -> 1024 -> 256 -> (mu, logvar) of size 20.
        self.fc1 = nn.Linear(3072, 1024)
        self.fc2 = nn.Linear(1024, 256)
        self.fc30 = nn.Linear(256, 20)  # not used by any method of this class
        self.fc31 = nn.Linear(256, 20)
        self.fc32 = nn.Linear(256, 20)
        # Decoder MLP: 20 -> 256 -> 1024 -> 3072 (also reused for the label path).
        self.fc4 = nn.Linear(20, 256)
        self.fc5 = nn.Linear(256, 1024)
        self.fc6 = nn.Linear(1024, 3072)

    def encode(self, x):
        """Map a ``(batch, 3072)`` tensor to the pair ``(mu, logvar)``, each ``(batch, 20)``."""
        h1 = F.relu(self.fc1(x))
        h2 = F.relu(self.fc2(h1))
        return self.fc31(h2), self.fc32(h2)

    def reparameterize(self, mu, logvar):
        """Sample ``mu + eps * exp(0.5 * logvar)`` in training mode; return ``mu`` in eval mode."""
        if self.training:
            std = torch.exp(0.5*logvar)
            eps = torch.randn_like(std)
            return eps.mul(std).add_(mu)
        else:
            return mu

    def decode(self, z):
        """Map a ``(batch, 20)`` latent to ``(batch, 3072)`` sigmoid outputs."""
        h3 = F.relu(self.fc4(z))
        h4 = F.relu(self.fc5(h3))
        return torch.sigmoid(self.fc6(h4))

    def forward(self, x, y):
        """Reconstruct ``x`` conditioned on integer class labels ``y``.

        Args:
            x: image batch; the conv stack and the 3072-wide flatten require
                3x32x32 inputs.
            y: integer class indices, one per image (one-hot encoded here).

        Returns:
            Tuple ``(reconstruction, mu, logvar)`` with shapes
            ``(batch, 3072)``, ``(batch, 20)``, ``(batch, 20)``.
        """
        x = self.convblock_0(x)
        x = self.convblock_1(x)
        x = self.convblock_2(x)
        x = self.convblock_3(x)
        x = self.gap(x)
        x = x.view(-1, 3072)

        # One-hot encode the label. `.float()` converts to float32 while keeping
        # the tensor on whatever device `y` already lives on, so the model runs
        # on CPU as well as CUDA (the previous cast forced a CUDA tensor type).
        y = F.one_hot(y, num_classes=self.fc.in_features).float()
        y = self.fc(y)
        y = self.fc4(y)
        y = self.fc5(y)
        y = self.fc6(y)

        # Condition the image features on the label by element-wise addition.
        x = torch.add(x, y)
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        return self.decode(z), mu, logvar


def bce_loss(y_pred, y_true):
    """Summed binary cross-entropy between predictions and targets.

    Args:
        y_pred: predicted probabilities (any shape).
        y_true: targets with the same number of elements as ``y_pred``; they
            are reshaped to ``y_pred``'s shape, so image-shaped targets can be
            compared against flat predictions of any width.

    Returns:
        Scalar tensor: the element-wise BCE summed over all elements.
    """
    target = y_true.reshape(y_pred.shape)
    # reduction='sum' is the non-deprecated spelling of size_average=False.
    return F.binary_cross_entropy(y_pred, target, reduction='sum')


class AddKLDLoss(Callback):
    """torchbearer callback that adds a KL-divergence term to the running loss.

    Requires the ``torchbearer`` module itself (not only ``Callback``) to be
    importable by name, because the loss is addressed through
    ``torchbearer.LOSS``.
    """

    def on_criterion(self, state):
        """Add the KLD computed from ``state['mu']`` / ``state['logvar']`` to the loss in ``state``."""
        super().on_criterion(state)
        KLD = self.KLD_Loss(state['mu'], state['logvar'])
        state[torchbearer.LOSS] = state[torchbearer.LOSS] + KLD

    def KLD_Loss(self, mu, logvar):
        """Return ``-0.5 * sum(1 + logvar - mu^2 - exp(logvar))`` over all elements."""
        KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
        return KLD


class SaveReconstruction(Callback):
    """torchbearer callback that saves inputs stacked above their reconstructions.

    Args:
        num_images: how many inputs (and matching reconstructions) to save.
        folder: path prefix for the output file; it is concatenated directly
            with the file name, so it should end with a path separator.
    """

    def __init__(self, num_images=8, folder='results/'):
        super().__init__()
        self.num_images = num_images
        self.folder = folder

    def on_step_validation(self, state):
        """On validation batch 0, write ``reconstruction_<epoch>.png`` under ``folder``."""
        super().on_step_validation(state)
        if state[torchbearer.BATCH] == 0:
            data = state[torchbearer.X]
            recon_batch = state[torchbearer.Y_PRED]
            # Reshape the prediction to the input batch's own shape instead of a
            # fixed (128, 1, 28, 28), so any batch size / image size works.
            recon_images = recon_batch.view_as(data)
            comparison = torch.cat([data[:self.num_images],
                                    recon_images[:self.num_images]])
            save_image(comparison.cpu(),
                       str(self.folder) + 'reconstruction_' + str(state[torchbearer.EPOCH]) + '.png', nrow=self.num_images)


# Fresh model plus an Adam optimizer over the parameters that require gradients.
model = VAE()

optimizer = optim.Adam(
    (param for param in model.parameters() if param.requires_grad),
    lr=0.001,
)
# Alias kept for callers that expect a `loss` callable.
loss = bce_loss

In [7]:
# Class names; position in the tuple is the class index.
classes = (
    'plane', 'car', 'bird', 'cat', 'deer',
    'dog', 'frog', 'horse', 'ship', 'truck',
)

In [8]:
def loss_function(recon_x, x, mu, logvar):
    """VAE objective: summed reconstruction BCE plus KL divergence.

    Args:
        recon_x: reconstructed probabilities, e.g. ``(batch, 3072)``.
        x: target batch with the same number of elements as ``recon_x``; it is
            reshaped to ``recon_x``'s shape.
        mu: latent means.
        logvar: latent log-variances.

    Returns:
        Scalar tensor ``BCE + KLD`` where
        ``KLD = -0.5 * sum(1 + logvar - mu^2 - exp(logvar))``.

    NOTE(review): binary cross-entropy expects targets in [0, 1]. The loaders in
    this notebook apply A.Normalize, so ``x`` presumably falls outside that
    range — a likely contributor to the ``Loss=nan`` seen in the recorded run.
    Confirm, and rescale the targets if so.
    """
    # reduction='sum' is the non-deprecated spelling of size_average=False.
    BCE = F.binary_cross_entropy(recon_x, x.reshape(recon_x.shape), reduction='sum')

    KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())

    return BCE + KLD

In [9]:
def train(model, device, traingen, optimizer, epoch, train_losses):
    """Run one training epoch and record every batch loss.

    Args:
        model: network called as ``model(images, labels)`` and returning
            ``(reconstruction, mu, logvar)``.
        device: device the model and each batch are moved to.
        traingen: iterable yielding ``(images, labels)`` batches.
        optimizer: optimizer stepped once per batch.
        epoch: epoch number, used only in the progress-bar text.
        train_losses: list that receives one float per batch (mutated in place).
    """
    # Moving the model is loop-invariant, so do it once instead of per batch.
    model = model.to(device)
    model.train()
    pbar = tqdm(traingen)
    for images, labels in pbar:
        x = images.to(device)
        y = labels.to(device)
        output_1, mu_1, logvar_1 = model(x, y)
        loss = loss_function(output_1, x, mu_1, logvar_1)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # Read the scalar once; it is used for both display and logging.
        loss_value = loss.item()
        pbar.set_description(desc= f'Loss={loss_value} epoch={epoch}')
        train_losses.append(loss_value)
    

In [10]:
def test_model(model, device, testgen, optimizer, epoch, test_losses):
    """Evaluate the model on ``testgen`` and record the average per-sample loss.

    The loss of every batch is accumulated and divided by the number of samples
    actually seen (previously only the last batch's loss was divided by a fixed
    10000).

    Args:
        model: network called as ``model(images, labels)`` and returning
            ``(reconstruction, mu, logvar)``.
        device: device the model and each batch are moved to.
        testgen: iterable yielding ``(images, labels)`` batches.
        optimizer: unused; kept so existing call sites keep working.
        epoch: epoch number, used only in the progress-bar text.
        test_losses: list that receives the average loss (mutated in place).
    """
    model = model.to(device)
    model.eval()
    pbar = tqdm(testgen)
    total_loss = 0.0
    num_samples = 0
    with torch.no_grad():
        for images, labels in pbar:
            x = images.to(device)
            y = labels.to(device)
            output_1, mu_1, logvar_1 = model(x, y)
            # Replace NaN/inf reconstructions with finite values before the loss.
            output_1 = torch.nan_to_num(output_1)
            loss = loss_function(output_1, x, mu_1, logvar_1)
            batch_loss = loss.item()
            total_loss += batch_loss
            num_samples += x.size(0)
            pbar.set_description(desc= f'Loss={batch_loss} epoch={epoch}')
    # An empty loader has no meaningful average; report NaN rather than crash.
    test_loss = total_loss / num_samples if num_samples else float('nan')
    test_losses.append(test_loss)
    print('\nTest set: Avg. loss: {:.4f}\n'.format(test_loss))
    

In [11]:
# Pick the GPU when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(device)

# Build the model and move it to the device; the trailing expression makes the
# notebook display the module structure.
model = VAE()
model.to(device)

cuda


VAE(
  (convblock_0): Sequential(
    (0): Conv2d(3, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (1): ReLU()
    (2): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (3): Dropout(p=0.0, inplace=False)
    (4): Conv2d(16, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (5): ReLU()
    (6): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (7): Dropout(p=0.0, inplace=False)
    (8): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (9): ReLU()
    (10): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (11): Dropout(p=0.0, inplace=False)
    (12): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (13): ReLU()
    (14): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (15): Dropout(p=0.0, inplace=False)
  )
  (convblock_1): Sequential(
    

In [12]:
# Per-batch training losses and per-epoch test losses are collected here.
train_losses = []
test_losses = []

# Fresh model on the selected device, with an SGD optimizer over its parameters.
model = VAE().to(device)
optimizer = optim.SGD(model.parameters(), lr=0.1, momentum=0.9)

In [13]:
from tqdm import tqdm
# Replace the optimizer with Adam over the parameters that require gradients.
optimizer = optim.Adam(
    (param for param in model.parameters() if param.requires_grad),
    lr=0.01,
)

# Alternate one training pass and one evaluation pass per epoch.
num_epochs = 10
for epoch in range(num_epochs):
    print(f"Epoch: {epoch}")
    train(model, device, trainloader, optimizer, epoch, train_losses)
    test_model(model, device, testloader, optimizer, epoch, test_losses)

Epoch: 0


Loss=nan epoch=0: 100%|██████████| 782/782 [00:17<00:00, 43.65it/s]        
Loss=nan epoch=0: 100%|██████████| 157/157 [00:01<00:00, 116.03it/s]



Test set: Avg. loss: nan

Epoch: 1


Loss=nan epoch=1: 100%|██████████| 782/782 [00:17<00:00, 45.46it/s]
Loss=nan epoch=1: 100%|██████████| 157/157 [00:01<00:00, 117.83it/s]



Test set: Avg. loss: nan

Epoch: 2


Loss=nan epoch=2:  16%|█▌        | 125/782 [00:02<00:14, 45.37it/s]


KeyboardInterrupt: 

In [None]:
# Run the model on the first batch the test loader yields, then stop.
pbar = tqdm(testloader)
for images, labels in pbar:
    x, y = images.to(device), labels.to(device)
    model = model.to(device)
    output_1, mu_1, logvar_1 = model(x, y)
    break

In [None]:
# NOTE(review): `test` is first assigned in a later cell
# (`test = output_1.reshape(...)`), so on a fresh top-to-bottom run this cell
# raises NameError — move it below that assignment.
test.shape

In [None]:
import matplotlib.pyplot as plt
# Reshape the flat model output (3072 values per sample, per the model's last
# linear layer) into image-shaped tensors of shape (N, 3, 32, 32).
test = output_1.reshape(-1,3,32,32)
# plt.imshow(test[0].cpu().detach().numpy().squeeze(), cmap='gray_r')

In [None]:
# 10x10 grid of reconstructions drawn at random (with replacement) from `test`.
figure = plt.figure(figsize=(8, 8))
cols, rows = 10, 10
for i in range(1, cols * rows + 1):
    sample_idx = torch.randint(len(test), size=(1,)).item()
    # Host copy, detached from autograd, rearranged to channels-last for imshow.
    img = test[sample_idx].cpu().detach().squeeze().permute(1, 2, 0)
    ax = figure.add_subplot(rows, cols, i)
    ax.axis("off")
    ax.imshow(img)
figure.tight_layout()
plt.show()

In [None]:
# Reconstruct the first test batch conditioned on random class labels.
pbar = tqdm(testloader)
for images, _labels in pbar:
    x = images.to(device)
    # One random label per image in the batch. The upper bound of torch.randint
    # is exclusive, so len(classes) makes every class reachable; sizing by
    # x.size(0) keeps the labels aligned with the actual batch.
    y = torch.randint(0, len(classes), (x.size(0),)).to(device)
    model = model.to(device)
    output_1, mu_1, logvar_1 = model(x, y)
    break

In [None]:
# The model emits 3072 values per sample (3 channels x 32 x 32); reshaping to
# 1x28x28 cannot divide a batch of such outputs evenly, so use the image shape.
test_new = output_1.reshape(-1, 3, 32, 32)

In [None]:
# 10x10 grid of reconstructions drawn at random (with replacement) from `test_new`.
figure = plt.figure(figsize=(8, 8))
cols, rows = 10, 10
for i in range(1, cols * rows + 1):
    sample_idx = torch.randint(len(test_new), size=(1,)).item()
    # Copy to host memory first: .numpy() is not available on CUDA tensors.
    img = test_new[sample_idx].detach().cpu()
    # imshow expects channels-last; squeeze() then drops a singleton channel so
    # single-channel images are still drawn as 2-D grayscale.
    img = img.permute(1, 2, 0).numpy().squeeze()
    figure.add_subplot(rows, cols, i)
#     plt.title(label)
    plt.axis("off")
    plt.imshow(img, cmap="gray")  # cmap only affects 2-D (single-channel) images
figure.tight_layout()
plt.show()

In [None]:
# Class names; position in the tuple is the class index.
classes = (
    'plane', 'car', 'bird', 'cat', 'deer',
    'dog', 'frog', 'horse', 'ship', 'truck',
)

In [None]:
# Look up the name of class index 0 and display it.
# NOTE(review): this rebinds `y`, which earlier cells use for the label tensor,
# to a string.
y = classes[0]
y

In [None]:
# Print the image shape, label shape and labels of the first test batch.
pbar = tqdm(testloader)
for images, labels in pbar:
    x = images.to(device)
    y = labels
    print(x.shape)
    print(y.shape)
    print(y)
    break

In [None]:
# Draw one random index into `test`.
sample_idx = torch.randint(len(test), size=(1,)).item()
# NOTE: a bare expression in the middle of a cell is not displayed; only the
# final expression (`len(test)`) is shown as the cell output.
sample_idx
len(test)

In [None]:
# torch.randint requires an explicit output size; draw a single index.
torch.randint(len(test), size=(1,)).item()

In [None]:
# Copy to host memory before converting: .numpy() is not available on CUDA tensors.
test[sample_idx].detach().cpu().numpy().squeeze().shape