In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.autograd import Variable
from torchvision.utils import save_image
import pickle
from torch.utils.data import Dataset
import math

class CustomDataset(Dataset):
    """Map-style dataset that pairs each sample in ``x`` with its entry in ``label``.

    Args:
        x: Indexable container of samples exposing ``.shape``; its leading
            dimension is used as the dataset length.
        label: Indexable container of labels, indexed with the same ``idx`` as ``x``.
    """

    # NOTE(review): attribute names are left untouched because the datasets
    # unpickled elsewhere in this notebook are presumably instances of this
    # class -- confirm before renaming anything.
    def __init__(self, x, label):
        self.x, self.label = x, label
        # Sample count comes from the leading dimension of ``x``.
        self.n = x.shape[0]

    def __len__(self):
        """Return the number of samples recorded at construction time."""
        return self.n

    def __getitem__(self, idx):
        """Return the ``(sample, label)`` pair stored at position ``idx``."""
        sample = self.x[idx]
        target = self.label[idx]
        return sample, target

# Mini-batch size shared by both loaders.
bs = 100
# Generated Dataset
train_file = './data/model/train_dataset.pkl'
test_file = './data/model/test_dataset.pkl'


def _read_pickled_dataset(path):
    """Open ``path`` in binary mode and return the unpickled object."""
    # NOTE(review): pickle.load executes arbitrary code from the file; only
    # point this at dataset files you generated yourself.
    with open(path, 'rb') as handle:
        return pickle.load(handle)


train_dataset = _read_pickled_dataset(train_file)
test_dataset = _read_pickled_dataset(test_file)

# Training batches are reshuffled on every pass; test batches keep dataset order.
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=bs, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=bs, shuffle=False)

In [2]:
class CVAE(nn.Module):
    """Conditional VAE built from stride-2 convolutions.

    The encoder applies one ``Conv2d -> BatchNorm2d -> ReLU`` stage per entry
    in ``channels[1:]``, concatenates the flattened features with the
    condition vector and projects them to the mean / log-variance of the
    latent Gaussian. The decoder mirrors the encoder with
    ``ConvTranspose2d -> BatchNorm2d -> ReLU`` stages.

    Args:
        condi_dim: Length of the condition vector ``c`` concatenated to the
            encoder features and to the latent code.
        channels: Channel counts; ``channels[0]`` is the image channel count
            (encoder input / decoder output) and each following entry adds one
            stride-2 stage.
        latent_dim: Size of the latent code ``z``.
        img_length: Side length of the (square) input images. Defaults to 28,
            the value that was previously hard-coded.
    """

    def __init__(self, condi_dim, channels, latent_dim, img_length=28) -> None:
        super().__init__()

        # encoder
        # The decoder already emits ``channels[0]`` channels, so the encoder
        # reads the same number (falls back to 1 when ``channels`` is empty).
        pre_channel = channels[0] if len(channels) > 0 else 1
        modules = []
        # Side length seen before each encoder stage; the decoder walks this
        # list backwards to restore exactly the same sizes.
        side_lengths = [img_length]

        for out_channel in channels[1:]:
            modules.append(
                nn.Sequential(
                    nn.Conv2d(pre_channel,
                              out_channel,
                              kernel_size=3,
                              stride=2,
                              padding=1),
                    nn.BatchNorm2d(out_channel),
                    nn.ReLU()
                )
            )
            pre_channel = out_channel
            # Output side of a kernel-3 / stride-2 / padding-1 convolution.
            img_length = (img_length - 1) // 2 + 1
            side_lengths.append(img_length)

        flat_dim = pre_channel * img_length * img_length

        self.encoder_projection = nn.Sequential(
            nn.Linear(flat_dim + condi_dim, flat_dim),
            nn.ReLU()
        )

        self.encoder = nn.Sequential(*modules)
        self.mean_linear = nn.Linear(flat_dim, latent_dim)
        self.var_linear = nn.Linear(flat_dim, latent_dim)
        self.latent_dim = latent_dim

        # decoder
        modules = []
        self.decoder_projection = nn.Linear(latent_dim + condi_dim, flat_dim)
        self.decoder_input_chw = (pre_channel, img_length, img_length)

        for stage in range(len(channels) - 1, 0, -1):
            in_side = side_lengths[stage]
            out_side = side_lengths[stage - 1]
            # With kernel 3 / stride 2 / padding 1 a transposed convolution
            # yields 2 * in_side - 1 + output_padding, so pick the padding
            # (always 0 or 1) that lands on the matching encoder input size.
            # The previous rule (0 for the first stage, 1 afterwards) was only
            # right for 28x28 inputs with exactly three stages.
            output_padding = out_side - (2 * in_side - 1)
            modules.append(
                nn.Sequential(
                    nn.ConvTranspose2d(channels[stage],
                                       channels[stage - 1],
                                       kernel_size=3,
                                       stride=2,
                                       padding=1,
                                       output_padding=output_padding),
                    nn.BatchNorm2d(channels[stage - 1]),
                    nn.ReLU()
                )
            )
        self.decoder_layers = nn.Sequential(*modules)

    def decoder(self, z, c):
        """Decode latent codes ``z`` conditioned on ``c``.

        ``z`` and ``c`` are concatenated along dim 1, projected, reshaped to
        ``decoder_input_chw`` and passed through the transposed-conv stack.
        """
        z = torch.cat([z, c], dim = 1)
        z = self.decoder_projection(z)
        z = torch.reshape(z, (-1, *self.decoder_input_chw))
        decoded = self.decoder_layers(z)
        return decoded

    def forward(self, x, c):
        """Encode ``x`` with condition ``c``, sample ``z`` and decode it.

        Returns:
            Tuple ``(decoded, mean, logvar)`` where ``mean`` and ``logvar``
            parameterise the latent Gaussian that ``z`` was sampled from.
        """
        x = self.encoder(x)
        x = torch.flatten(x, 1)
        x = torch.cat([x, c], dim = 1)
        encoded = self.encoder_projection(x)
        mean = self.mean_linear(encoded)
        logvar = self.var_linear(encoded)
        # Reparameterisation: z = mean + eps * std with eps ~ N(0, I).
        eps = torch.randn_like(logvar)
        std = torch.exp(logvar / 2)
        z = eps * std + mean
        decoded = self.decoder(z, c)
        return decoded, mean, logvar

In [3]:
from time import time
import numpy as np

# Number of passes over train_loader performed by train().
n_epochs = 5
# Learning rate handed to the Adagrad optimizer in train().
learning_rate = 1e-2
# Factor applied to the KL term inside elbo_loss().
kl_weight = 2e-5
# Number of label classes; used as the one-hot width and as the CVAE condi_dim.
num_class = 10

def one_hot_encode(labels, num_classes):
    """Return one-hot rows for ``labels`` as a nested Python list of floats.

    Args:
        labels: Sequence of integer class indices, one per output row.
        num_classes: Width of every one-hot row.

    Returns:
        A list with ``len(labels)`` rows of ``num_classes`` floats; each row
        holds a single 1.0 at the position given by its label.
    """
    row_count = len(labels)
    encoded = np.zeros((row_count, num_classes))
    row_ids = np.arange(row_count)
    # Fancy indexing sets exactly one cell per row.
    encoded[row_ids, labels] = 1
    return encoded.tolist()

def elbo_loss(x, x_hat, mean, logvar):
    """Training objective: reconstruction error plus weighted KL, per sample.

    Args:
        x: Original inputs.
        x_hat: Reconstructions produced by the model.
        mean: Latent means.
        logvar: Latent log-variances.

    Returns:
        Scalar tensor ``(mse + kl_weight * KL) / len(x)``, where ``KL`` is
        summed over every element of ``mean`` / ``logvar``.
    """
    batch_size = len(x)
    # NOTE(review): F.mse_loss is called with its default reduction and the
    # result is divided by the batch size again below -- confirm this scaling
    # is intended.
    reconstruction = F.mse_loss(x_hat, x)
    kl_divergence = -0.5 * torch.sum(1 + logvar - mean**2 - torch.exp(logvar))
    return (reconstruction + kl_divergence * kl_weight) / batch_size

def reconst_loss(x, x_hat):
    """Return the MSE between ``x_hat`` and ``x`` divided by ``len(x)``."""
    batch_size = len(x)
    mse = F.mse_loss(x_hat, x)
    return mse / batch_size

def kl_loss(mean, logvar):
    """Return the KL term summed over all elements, divided by ``len(mean)``.

    Computes ``-0.5 * sum(1 + logvar - mean**2 - exp(logvar))`` and divides
    the total by the size of the leading dimension of ``mean``.
    """
    per_element = 1 + logvar - mean**2 - torch.exp(logvar)
    total = -0.5 * torch.sum(per_element)
    return total / len(mean)

def train(device, model):
    """Train ``model`` on ``train_loader`` and probe it after every epoch.

    Each epoch runs one pass of Adagrad updates on ``elbo_loss`` and then
    evaluates 100 randomly drawn test samples. Training stops early when the
    probe loss is NaN or when the per-sample KL term drops below 0.5.

    Args:
        device: Device the batches are moved to; ``model`` must already live there.
        model: Network called as ``model(x, one_hot_label)`` and returning
            ``(x_hat, mean, logvar)``.

    Returns:
        Tuple ``(is_nan, is_kl_small)`` of 0/1 flags telling which early-stop
        condition, if any, was hit.
    """
    is_nan = 0
    is_kl_small = 0
    optimizer = torch.optim.Adagrad(model.parameters(), lr=learning_rate)
    begin_time = time()
    # train
    for epoch in range(n_epochs):
        # The probe at the end of every epoch switches the model to eval
        # mode; switch back so BatchNorm keeps using batch statistics (and
        # updating its running averages) in every epoch, not only the first.
        model.train()
        for x, label in train_loader:
            x = x.to(device)
            label = one_hot_encode(label, num_classes = num_class)
            label = torch.tensor(label)
            label = label.to(device)
            x_hat, mean, logvar = model(x, label)
            loss = elbo_loss(x, x_hat, mean, logvar)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # estimate loss
        model.eval()
        with torch.no_grad():
            n_samples = 100
            #test
            indices = torch.randperm(len(test_dataset))[:n_samples]
            x = torch.stack([test_dataset[i][0] for i in indices]).to(device)
            label = one_hot_encode([test_dataset[i][1] for i in indices], num_classes = num_class)
            label = torch.tensor(label).to(device)
            x_hat, mean, logvar = model(x, label)
            loss = elbo_loss(x, x_hat, mean, logvar)
            loss_recons = reconst_loss(x, x_hat)
            loss_kl = kl_loss(mean, logvar)
            #print(loss.item(), loss_recons.item(), loss_kl.item())
            # Either condition ends training for this model (break leaves the epoch loop).
            if(math.isnan(loss.item())):
                is_nan = 1
                print("Get nan!")
                break
            if(loss_kl.item() < 0.5):
                is_kl_small = 1
                print("Get small kl")
                break

    tot_training_time = time() - begin_time
    minute = int(tot_training_time // 60)
    second = int(tot_training_time % 60)
    print(f'total time loss {minute}:{second}')

    return is_nan, is_kl_small

def initialize_parameters(model):
    """Overwrite every parameter of ``model`` in place with N(0, 0.01**2) samples.

    All tensors yielded by ``model.parameters()`` are re-drawn, in that order.
    """
    # In-place sampling on parameters has to happen outside autograd tracking.
    with torch.no_grad():
        for weight in model.parameters():
            weight.normal_(mean=0, std=0.01)

In [4]:
def main():
    """Train ``num_all`` freshly initialised CVAEs and count the failure modes.

    Every run builds a new model, re-initialises its parameters, trains it
    and records whether ``train`` reported a NaN loss or a small KL term.
    The two totals are printed at the end.
    """
    # Prefer the first GPU, but fall back to the CPU so the notebook still
    # runs on machines without CUDA instead of failing on ``.to(device)``.
    device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
    num_all = 100
    num_nan = 0
    num_kl_small = 0
    for i in range(num_all):
        print("i:", i)
        torch.cuda.empty_cache()
        model = CVAE(condi_dim = num_class, channels = [1, 500, 500, 500], latent_dim = 100).to(device)
        initialize_parameters(model)
        is_nan, is_kl_small = train(device, model)
        if(is_nan == 1):
            num_nan += 1
        if(is_kl_small == 1):
            num_kl_small += 1
    print("num nan:", num_nan)
    print("num kl small:", num_kl_small)

if __name__ == '__main__':
    main()

i: 0
total time loss 1:20
i: 1
total time loss 1:20
i: 2
total time loss 1:20
i: 3
total time loss 1:20
i: 4
total time loss 1:20
i: 5
Get small kl
total time loss 0:16
i: 6
total time loss 1:20
i: 7
total time loss 1:20
i: 8
total time loss 1:20
i: 9
total time loss 1:20
i: 10
total time loss 1:20
i: 11
total time loss 1:20
i: 12
total time loss 1:20
i: 13
total time loss 1:20
i: 14
total time loss 1:20
i: 15
Get small kl
total time loss 0:15
i: 16
total time loss 1:20
i: 17
total time loss 1:20
i: 18
total time loss 1:20
i: 19
Get small kl
total time loss 0:16
i: 20
total time loss 1:20
i: 21
total time loss 1:20
i: 22
total time loss 1:20
i: 23
total time loss 1:20
i: 24
total time loss 1:20
i: 25
total time loss 1:20
i: 26
total time loss 1:20
i: 27
total time loss 1:20
i: 28
total time loss 1:20
i: 29
total time loss 1:20
i: 30
Get small kl
total time loss 0:15
i: 31
total time loss 1:20
i: 32
total time loss 1:20
i: 33
total time loss 1:20
i: 34
total time loss 1:20
i: 35
total t