In [None]:
# !pip install torchsummary

import os, sys, glob, gc, copy
import time
import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
# from torchsummary import summary

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)
import cv2
import matplotlib.pyplot as plt

from tqdm import tqdm

In [None]:
# Pick the compute device: CUDA when torch reports a usable GPU, otherwise the CPU
is_cuda = torch.cuda.is_available()
device = torch.device("cuda" if is_cuda else "cpu")
print("GPU is available" if is_cuda else "GPU not available, CPU used")

In [None]:
# Input locations: the training image folder and the CSV that is read into `total_csv`
_competition_dir = '../input/cassava-leaf-disease-classification'
data_path = f'{_competition_dir}/train_images/'
total_csv = pd.read_csv(f'{_competition_dir}/train.csv')

## Create Dataset

In [None]:
def get_img(path):
    """Read an image file and return it with the channel axis reversed (BGR -> RGB).

    Args:
        path: Location of the image file; passed unchanged to ``cv2.imread``.

    Returns:
        The decoded array with its last axis reversed. This is a reversed
        *view* of the decoded array, so callers that need contiguous memory
        (e.g. before ``torch.tensor``) should ``.copy()`` it.

    Raises:
        FileNotFoundError: If ``cv2.imread`` returns ``None``, i.e. the file
            could not be read.
    """
    im_bgr = cv2.imread(path)
    if im_bgr is None:
        # cv2.imread reports failure by returning None instead of raising;
        # fail here with the offending path rather than with an opaque
        # "'NoneType' object is not subscriptable" on the slice below.
        raise FileNotFoundError(f"Could not read image: {path!r}")
    return im_bgr[:, :, ::-1]


class HDD_Dataset(Dataset):
    """Dataset that yields one labelled image per dataframe row.

    Args:
        df: Dataframe with an 'image_id' column (file name relative to
            `data_root`) and a 'label' column (integer class index).
        data_root: Directory that contains the image files.
        device: torch device on which the returned tensors are placed.
        transform: Optional callable applied to the channels-first image tensor.

    Each item is a dict with keys:
        'image': float tensor, channels first, pixel values divided by 255.
        'label': integer (int64) class-index tensor.
        'label_onehot': float one-hot encoding of the label with `N_class` entries.
    """

    def __init__(self, df, data_root, device, transform=None):
        super().__init__()
        # Work on a private copy with a clean 0..len-1 index
        self.df = df.reset_index(drop=True).copy()
        self.data_root = data_root
        self.device = device
        self.transform = transform
        self.N_class = 5  # number of classes

    def __len__(self):
        return self.df.shape[0]

    def __getitem__(self, index: int):
        row = self.df.iloc[index]
        # read label and one-hot; tensors go to the device given to the
        # constructor (previously a module-level `device` global was used,
        # which silently ignored the constructor argument).
        label = torch.tensor(int(row['label']), dtype=torch.long, device=self.device)
        label_onehot = nn.functional.one_hot(label, self.N_class)
        # prepare image
        path = "{}/{}".format(self.data_root, row['image_id'])
        # .copy() materialises the reversed-channel view returned by get_img
        img = torch.tensor(get_img(path).copy()).float().to(self.device) / 255  # scale to [0, 1]
        img = img.permute(2, 0, 1)  # channels last -> channels first
        if self.transform:
            img = self.transform(img)
        return {'image': img, 'label': label, 'label_onehot': label_onehot.float()}

## Split data into `train` and `test` sets

In [None]:
def frac_train_val(N:int, val_frac:float, seed=None):
    """Randomly split the positions 0..N-1 into train and validation lists.

    Args:
        N: dataset length, e.g. 2100.
        val_frac: fraction of the positions assigned to validation, e.g. 0.2;
            must lie in [0, 1]. The validation size is ``int(N * val_frac)``.
        seed: optional seed for a reproducible split. When None (default) the
            global NumPy random state is used, as before.

    Returns:
        (train_indices, validation_indices) as two disjoint lists of ints
        that together cover every position exactly once.

    Raises:
        ValueError: if N is negative or val_frac is outside [0, 1].
    """
    if N < 0:
        raise ValueError(f"N must be non-negative, got {N}")
    if not 0.0 <= val_frac <= 1.0:
        # A negative or >1 fraction would silently produce a wrong split
        # through Python's negative / out-of-range slicing.
        raise ValueError(f"val_frac must be in [0, 1], got {val_frac}")
    if seed is None:
        perm = np.random.permutation(N)
    else:
        perm = np.random.default_rng(seed).permutation(N)
    thrshld = int(N * val_frac)
    return perm[thrshld:].tolist(), perm[:thrshld].tolist()

# Hold out a random 10% of the rows as the test split
inds_train, inds_test = frac_train_val(total_csv.shape[0], 0.1)
# frac_train_val returns row *positions*, so select with .iloc; .loc would
# treat the values as index labels and only agrees for a default RangeIndex.
train_csv = total_csv.iloc[inds_train]
test_csv = total_csv.iloc[inds_test]
# Create datasets; both splits share the same random 512x512 crop transform
crop_transform = transforms.RandomCrop(512)
train_dataset = HDD_Dataset(train_csv, data_path, device, crop_transform)
test_dataset = HDD_Dataset(test_csv, data_path, device, crop_transform)
# gc.collect()

## Define CVAE architecture

In [None]:
class print_layer(nn.Module):
    """Debugging pass-through: prints the shape of its input and returns the input unchanged.

    Has no parameters, so it can be dropped anywhere into an ``nn.Sequential``.
    """

    def forward(self, x):
        print(x.shape)
        return x

class encoder(nn.Module):
    """Conditional encoder of the CVAE.

    Inputs:
        x: image batch, shape (N, 3, 512, 512).
        y: condition batch, shape (N, ydim) (the notebook passes one-hot labels).

    Returns two tensors of shape (N, zdim): the latent mean and the second
    latent parameter. The second one is named `log_sigma`, but `CVAE.sampler`
    and the loss consume it as a log-variance (std = exp(value / 2)).
    """

    # One row per convolution of the image tower, in execution order:
    # (in_channels, out_channels, kernel, stride, padding, batch_norm, leaky_relu)
    _CONV_PLAN = (
        (3, 16, 5, 2, 2, False, True),       # -> (256, 256)
        (16, 32, 5, 1, 2, True, True),
        (32, 32, 5, 1, 2, True, True),
        (32, 64, 5, 2, 2, True, True),       # -> (128, 128)
        (64, 64, 5, 1, 2, False, True),
        (64, 64, 3, 2, 1, True, True),       # -> (64, 64)
        (64, 64, 3, 1, 1, False, True),
        (64, 128, 3, 2, 1, True, True),      # -> (32, 32)
        (128, 128, 3, 1, 1, False, True),
        (128, 128, 3, 1, 1, True, True),
        (128, 128, 3, 2, 1, False, False),   # -> (16, 16); bare conv, no norm / activation
        (128, 128, 3, 1, 1, False, True),
        (128, 256, 3, 2, 1, True, True),     # -> (8, 8)
        (256, 256, 3, 1, 1, False, True),
        (256, 512, 3, 2, 1, True, True),     # -> (4, 4)
        (512, 1024, 4, 1, 0, False, True),   # -> (1, 1)
    )

    def __init__(self, zdim:int, ydim:int):
        super().__init__()
        # Image tower: expand the plan into conv [+ batch norm] [+ LeakyReLU] modules
        tower = []
        for c_in, c_out, kernel, stride, pad, with_bn, with_act in self._CONV_PLAN:
            tower.append(nn.Conv2d(c_in, c_out, kernel_size=kernel, stride=stride, padding=pad))
            if with_bn:
                tower.append(nn.BatchNorm2d(c_out))
            if with_act:
                tower.append(nn.LeakyReLU())
        self.x_conv = nn.Sequential(*tower)
        # Fully connected head on the flattened 1024 image features
        self.x_fc = nn.Sequential(
            nn.Linear(in_features=1024, out_features=1024),
            nn.ReLU(),
            nn.Dropout(),
        )
        # Embedding of the condition vector into zdim features
        self.y_fc = nn.Sequential(
            nn.Linear(in_features=ydim, out_features=zdim),
            nn.ReLU(),
            nn.Linear(in_features=zdim, out_features=zdim),
        )
        # Two linear heads on the concatenated image + condition features
        joint_features = 1024 + zdim
        self.mu_lin = nn.Sequential(nn.Linear(in_features=joint_features, out_features=zdim))
        self.var_lin = nn.Sequential(nn.Linear(in_features=joint_features, out_features=zdim))

    def forward(self, x, y):
        image_feat = self.x_fc(self.x_conv(x).view(-1, 1024))
        cond_feat = self.y_fc(y)
        joint = torch.cat([image_feat, cond_feat], dim=1)
        mean = self.mu_lin(joint)
        log_sigma = self.var_lin(joint) + 1e-6
        return mean, log_sigma


class decoder(nn.Module):
    """Conditional decoder of the CVAE.

    Inputs:
        z: latent batch, shape (N, zdim).
        y: condition batch, shape (N, ydim).

    Returns an image batch of shape (N, 3, 512, 512) with values squashed by a
    final sigmoid.
    """

    # One row per (transposed) convolution of the upsampling tower, in order:
    # (kind, in_channels, out_channels, kernel, padding, batch_norm, relu)
    # 'conv' rows use stride 1; 'up' rows are stride-2 transposed convolutions
    # with output_padding=1, i.e. they double the spatial size.
    _CONV_PLAN = (
        ('conv', 256, 256, 3, 1, True, True),    # (8, 8)
        ('up', 256, 256, 3, 1, False, True),
        ('conv', 256, 128, 3, 1, True, True),    # (16, 16)
        ('conv', 128, 128, 3, 1, True, True),
        ('up', 128, 128, 3, 1, False, True),
        ('conv', 128, 128, 3, 1, True, True),    # (32, 32)
        ('conv', 128, 128, 3, 1, True, True),
        ('up', 128, 64, 3, 1, False, True),
        ('conv', 64, 64, 3, 1, True, True),      # (64, 64)
        ('conv', 64, 64, 3, 1, True, True),
        ('up', 64, 64, 3, 1, False, True),
        ('conv', 64, 32, 3, 1, True, True),      # (128, 128)
        ('conv', 32, 32, 5, 2, True, True),
        ('up', 32, 32, 3, 1, False, True),
        ('conv', 32, 16, 5, 2, True, True),      # (256, 256)
        ('conv', 16, 16, 5, 2, True, True),
        ('up', 16, 16, 3, 1, False, True),
        ('conv', 16, 8, 5, 2, True, True),       # (512, 512)
        ('conv', 8, 4, 5, 2, False, True),
        ('conv', 4, 3, 5, 2, False, False),      # RGB output, followed by the sigmoid
    )

    def __init__(self, zdim:int, ydim:int):
        super().__init__()
        # Condition embedding (2*zdim features)
        self.y_decode = nn.Sequential(
            nn.Linear(in_features=ydim, out_features=zdim*2),
            nn.ReLU(),
        )
        # Mixes the embedded condition (2*zdim) with the latent code (zdim)
        self.linearMix = nn.Sequential(
            nn.Linear(in_features=zdim*3, out_features=256),
            nn.ReLU(),
        )
        # Projection to a 256-channel 8x8 feature map (reshaped in forward)
        self.to_conv = nn.Linear(in_features=256, out_features=256*8*8)
        tower = []
        for kind, c_in, c_out, kernel, pad, with_bn, with_relu in self._CONV_PLAN:
            if kind == 'up':
                tower.append(nn.ConvTranspose2d(c_in, c_out, kernel_size=kernel, stride=2,
                                                padding=pad, output_padding=1))
            else:
                tower.append(nn.Conv2d(c_in, c_out, kernel_size=kernel, stride=1, padding=pad))
            if with_bn:
                tower.append(nn.BatchNorm2d(c_out))
            if with_relu:
                tower.append(nn.ReLU())
        tower.append(nn.Sigmoid())
        self.conv_decoder = nn.Sequential(*tower)

    def forward(self, z, y):
        cond = self.y_decode(y)
        mixed = self.linearMix(torch.cat((cond, z), dim=1))
        feature_map = self.to_conv(mixed).view(-1, 256, 8, 8)
        return self.conv_decoder(feature_map)


class CVAE(nn.Module):
    """Conditional VAE: an `encoder` and a `decoder` joined by a reparameterised sampling step.

    Args:
        zdim: size of the latent vector.
        ydim: size of the condition vector.
    """

    def __init__(self, zdim:int, ydim:int):
        super().__init__()
        self.zdim, self.ydim = zdim, ydim
        self.encoder = encoder(zdim, ydim)
        self.decoder = decoder(zdim, ydim)

    def sampler(self, mu, log_sigma):
        """Draw mu + eps * exp(log_sigma / 2) with eps ~ N(0, I).

        Despite its name, `log_sigma` is used here as a log-variance: the
        standard deviation is exp(log_sigma / 2).
        """
        std = (log_sigma / 2).exp()
        noise = torch.randn_like(std)
        return noise * std + mu

    def forward(self, x, y):
        """Encode (x, y), sample a latent code, and decode it with the same condition y.

        Returns:
            (reconstruction, mu, log_sigma)
        """
        mu, log_sigma = self.encoder(x, y)
        latent = self.sampler(mu, log_sigma)
        return self.decoder(latent, y), mu, log_sigma

### Test model output shape

In [None]:
# Smoke test: push a random batch of two 512x512 RGB inputs with random
# one-hot conditions through a fresh CVAE and print the output shape.
zdim, ydim = 100, 5
####
model = CVAE(zdim, ydim)

x = torch.randn(2, 3, 512, 512)
y = F.one_hot(torch.randint(ydim, size=(2,)), num_classes=ydim).float()
out, mu, log_sigma = model(x, y)
print(out.shape)

## Define Loss

In [None]:
def calculate_loss(x, reconstructed_x, mean, log_var):
    """CVAE training loss averaged over the batch.

    Args:
        x: target images; used as the BCE target, so values must lie in [0, 1].
        reconstructed_x: decoder output, same shape as `x`, values in [0, 1].
        mean: latent means; its first dimension is taken as the batch size.
        log_var: latent log-variances, same shape as `mean`.

    Returns:
        Scalar tensor: (summed binary cross-entropy + KL divergence to a
        standard normal) divided by the batch size.
    """
    batch = mean.shape[0]
    # reconstruction term, summed over every pixel of every sample
    recon_term = F.binary_cross_entropy(reconstructed_x, x, reduction='sum')
    # KL(N(mean, exp(log_var)) || N(0, I)), summed over batch and latent dims
    kl_term = (1 + log_var - mean.pow(2) - log_var.exp()).sum() * -0.5
    return (recon_term + kl_term) / batch

## Define model and optimizers

In [None]:
# Latent size and number of condition classes
zdim, ydim = 100, 5
# Build the CVAE and move it to the selected device (GPU or CPU)
model = CVAE(zdim, ydim).to(device, non_blocking=True)
# Learning-rate settings, kept as tensors on `device`
# NOTE(review): `lr` is handed to Adam as a tensor and later modified in place
# by the training cell — confirm the optimizer is meant to observe that.
lr = torch.tensor(1e-5).to(device)
lr_decay = torch.tensor(0.99).to(device)  # multiplicative factor, applied per epoch
lr_floor = torch.tensor(5e-6).to(device)  # lower bound for lr
# Adam over all model parameters
optimizer = torch.optim.Adam(model.parameters(), lr=lr)

In [None]:
# Shell escape: runs `nvidia-smi --gpu-reset`.
# NOTE(review): this cell sits after the model has already been moved to
# `device`; confirm a GPU reset is really intended at this point and that it
# does not invalidate the running kernel's CUDA state — TODO confirm.
!nvidia-smi --gpu-reset

## Train loop

In [None]:
# Hyper parameters
batch_size = 48
n_epochs = 10
# dataloader
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
tot_batch = len(train_dataloader)
# per-minibatch training loss history
train_loss = []
# Train the CVAE
for epoch in range(1, n_epochs + 1):
    epoch_start = time.time()
    batch_start = time.time()
    for num_batch, data_batch in enumerate(train_dataloader):
        images = data_batch['image']

        optimizer.zero_grad()
        # The encoder sees the image shifted by -0.5 (zero-centred input); the
        # loss target stays the unshifted image in [0, 1], as BCE requires.
        output, mu, log_sigma = model(images - 1/2, data_batch['label_onehot'])
        loss = calculate_loss(images, output, mu, log_sigma)
        loss_value = loss.item()
        train_loss.append(loss_value)
        loss.backward()
        optimizer.step()

        batch_stop = time.time()
        print(f"*****Epoch {epoch}, minibatch: {num_batch+1}/{tot_batch} elapsed {batch_stop-batch_start:0.1f}s, loss:{loss_value:0.4f}*****")  # write in replace
        batch_start = time.time()
    epoch_stop = time.time()
    print(f"Epoch {epoch}/{n_epochs} has finished in {epoch_stop-epoch_start:0.1f}s")
    # Per-epoch exponential decay, clamped at lr_floor. The new rate is written
    # into the optimizer explicitly. The previous `lr *= lr_decay` only reached
    # the optimizer through in-place tensor aliasing, and once `lr = lr_floor`
    # rebound the name, later decays mutated `lr_floor` itself.
    lr = max(float(lr) * float(lr_decay), float(lr_floor))
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
    gc.collect()

In [None]:
# Persist the whole model object (not only its state_dict) in the working directory.
# NOTE(review): torch.save writes a pickle-based file, so the '.h5' extension is
# just a name (this is not HDF5), and loading it back presumably needs the class
# definitions from this notebook to be available — confirm this is intended.
torch.save(model, 'CVAE_model_v2.h5')

## Plot loss

In [None]:
# Training-loss curve: one point per minibatch, x axis counted from 1
fig, ax = plt.subplots()
ax.plot(range(1, len(train_loss) + 1), train_loss)
ax.set(xlabel='minibatch', title='CVAE loss')
plt.show()

In [None]:
# Generate N images: decode standard-normal latent vectors, each paired with a
# random class label.
# NOTE(review): the eval()/train() toggles below are left disabled as in the
# original, so the decoder runs in whatever mode `model` is currently in.
# model.eval()

N = 4
tmp_dec = model.decoder
z = torch.randn(N, zdim).to(device)
class_label = torch.randint(5, size=(N,))
y = F.one_hot(class_label, 5).float().to(device)
# Sampling needs no gradients: skip building the autograd graph for the
# full-resolution decoder pass.
with torch.no_grad():
    out = tmp_dec(z, y)

# model.train()

In [None]:
# Show the N generated images side by side, each titled with its class label
# squeeze=False always returns a 2-D grid of axes, so indexing also works for
# N == 1 (where plt.subplots would otherwise hand back a bare Axes object).
fig, ax = plt.subplots(1, N, squeeze=False)
ax = ax[0]
samples_cpu = out.detach().to(torch.device('cpu'))
for i in range(N):
    # channels first -> channels last, as imshow expects
    ax[i].imshow(samples_cpu[i].permute(1, 2, 0).numpy())
    ax[i].axis('off')
    ax[i].set_title(f'cls:{class_label[i]}')

## Load model from last training

In [None]:
# Shell escape: long listing of the working directory, sorted by modification
# time with the newest entries last, to see which checkpoint files exist.
!ls -ltrh
# Load a whole pickled model object from disk.
# NOTE(review): the save cell in this notebook writes 'CVAE_model_v2.h5' while
# this reads 'CVAE_model_v1.h5' — confirm which checkpoint is intended.
# NOTE(review): torch.load unpickles the file, which can execute arbitrary code;
# only load checkpoints from a trusted source.
model = torch.load('CVAE_model_v1.h5')