In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.autograd import Variable
from torchvision.utils import save_image
import pickle
from torch.utils.data import Dataset

class CustomDataset(Dataset):
    def __init__(self, x, label):
        self.x = x
        self.label = label
        self.n = x.shape[0]

    def __len__(self):
        return self.n

    def __getitem__(self, idx):
        return self.x[idx], self.label[idx]

bs = 100
# Generated Dataset
train_file = './data/model/train_dataset.pkl'
test_file = './data/model/test_dataset.pkl'
with open(train_file, 'rb') as f:
    train_dataset = pickle.load(f)
with open(test_file, 'rb') as f:
    test_dataset = pickle.load(f)
    
train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=bs, shuffle=True)
test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=bs, shuffle=False)

In [2]:
class VAE(nn.Module):

    def __init__(self, channels, latent_dim) -> None:
        super().__init__()

        # encoder
        pre_channel = 1
        modules = []
        img_length = 32
        
        for i in range(len(channels)-1):
            modules.append(
                nn.Sequential(
                    nn.Conv2d(pre_channel,
                              channels[i+1],
                              kernel_size=3,
                              stride=2,
                              padding=1),
                    nn.BatchNorm2d(channels[i+1]),
                    nn.ReLU()
                )
            )
            pre_channel = channels[i+1]
            img_length = (img_length-1)//2+1
        
        self.encoder_projection = nn.Sequential(
                nn.Linear(pre_channel * img_length * img_length, pre_channel * img_length * img_length),
                nn.ReLU()
        )
        
        self.encoder = nn.Sequential(*modules)
        self.mean_linear = nn.Linear(pre_channel * img_length * img_length,
                                     latent_dim)
        self.var_linear = nn.Linear(pre_channel * img_length * img_length,
                                    latent_dim)
        self.latent_dim = latent_dim
        
        # decoder
        modules = []
        self.decoder_projection = nn.Linear(
            latent_dim, pre_channel * img_length * img_length)
        self.decoder_input_chw = (pre_channel, img_length, img_length)

        for i in range(len(channels)-1):
            modules.append(
                nn.Sequential(
                    nn.ConvTranspose2d(channels[len(channels)-i-1],
                                       channels[len(channels)-i-2],
                                       kernel_size=3,
                                       stride=2,
                                       padding=1,
                                       output_padding=1),
                    nn.BatchNorm2d(channels[len(channels)-i-2]),
                    nn.ReLU()
                )
            )
        self.decoder_layers = nn.Sequential(*modules)
        
    def decoder(self, z):
        z = self.decoder_projection(z)
        z = torch.reshape(z, (-1, *self.decoder_input_chw))
        decoded = self.decoder_layers(z)
        return decoded

    def forward(self, x):
        x = self.encoder(x)
        x = torch.flatten(x, 1)
        encoded = self.encoder_projection(x)
        mean = self.mean_linear(encoded)
        logvar = self.var_linear(encoded)
        eps = torch.randn_like(logvar)
        std = torch.exp(logvar / 2)
        z = eps * std + mean
        decoded = self.decoder(z)
        return decoded, mean, logvar

In [3]:
from time import time
import numpy as np

n_epochs = 200
learning_rate = 1e-2
kl_weight = 1e-7

def elbo_loss(x, x_hat, mean, logvar):
    recons_loss = F.mse_loss(x_hat, x)
    kl_loss = -0.5 * torch.sum(1 + logvar - mean**2 - torch.exp(logvar))
    loss = (recons_loss + kl_loss * kl_weight)/len(x)
    return loss

def reconst_loss(x, x_hat):
    return F.mse_loss(x_hat, x) / len(x)

def kl_loss(mean, logvar):
    return -0.5 * torch.sum(1 + logvar - mean**2 - torch.exp(logvar)) / len(mean)

def train(device, model):
    optimizer = torch.optim.Adagrad(model.parameters(), lr=learning_rate)
    begin_time = time()
    # train
    with open('./train_loss1.txt', 'w') as file:
        for i in range(n_epochs):
            for batch_idx, (x, label) in enumerate(train_loader):
                x = x.to(torch.float32).to(device)
                x_hat, mean, logvar = model(x)
                loss = elbo_loss(x, x_hat, mean, logvar)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                
            # estimate loss
            model.eval()
            with torch.no_grad():
                each_epoch = 10
                n_samples = 100
                #train
                indices = torch.randperm(len(train_dataset))[:n_samples]
                x = torch.stack([train_dataset[i][0] for i in indices]).to(torch.float32).to(device)
                x_hat, mean, logvar = model(x)
                loss = elbo_loss(x, x_hat, mean, logvar)
                loss_recons = reconst_loss(x, x_hat)
                loss_kl = kl_loss(mean, logvar)
                if(i % each_epoch == 0):
                    print('====> Epoch: {} elbo loss: {:.7f}'.format(i, loss))
                    print('====> Epoch: {} reconst loss: {:.7f}'.format(i, loss_recons))
                    print('====> Epoch: {} kl loss: {:.7f}'.format(i, loss_kl))
                file.write(str(loss.item()) + ' ' + str(loss_recons.item()) + ' ' + str(loss_kl.item()) + ' ')
                #test
                indices = torch.randperm(len(test_dataset))[:n_samples]
                x = torch.stack([test_dataset[i][0] for i in indices]).to(torch.float32).to(device)
                x_hat, mean, logvar = model(x)
                loss = elbo_loss(x, x_hat, mean, logvar)
                loss_recons = reconst_loss(x, x_hat)
                loss_kl = kl_loss(mean, logvar)
                if(i % each_epoch == 0):
                    print('====> Epoch: {} elbo loss: {:.7f}'.format(i, loss))
                    print('====> Epoch: {} reconst loss: {:.7f}'.format(i, loss_recons))
                    print('====> Epoch: {} kl loss: {:.7f}'.format(i, loss_kl))
                file.write(str(loss.item()) + ' ' + str(loss_recons.item()) + ' ' + str(loss_kl.item()) + '\n')
            
            #time
            if(i % each_epoch == 0):
                training_time = time() - begin_time
                minute = int(training_time // 60)
                second = int(training_time % 60)
                print(f'time loss {minute}:{second}')
        
        torch.save(model, './train_model1.pth')
        
    tot_training_time = time() - begin_time
    minute = int(tot_training_time // 60)
    second = int(tot_training_time % 60)
    print(f'total time loss {minute}:{second}')

def initialize_parameters(model):
    for param in model.parameters():
        param.data.normal_(mean=0, std=0.01)

In [4]:
latent_dim = 2

def main():
    device = 'cuda:0'

    model = VAE(channels = [1, 200, 200, 200, 200], latent_dim = latent_dim).to(device)
    initialize_parameters(model)

    train(device, model)

if __name__ == '__main__':
    main()

====> Epoch: 0 elbo loss: 0.0000627
====> Epoch: 0 reconst loss: 0.0000620
====> Epoch: 0 kl loss: 6.4906788
====> Epoch: 0 elbo loss: 0.0000663
====> Epoch: 0 reconst loss: 0.0000656
====> Epoch: 0 kl loss: 6.6931314
time loss 0:0
====> Epoch: 10 elbo loss: 0.0000180
====> Epoch: 10 reconst loss: 0.0000171
====> Epoch: 10 kl loss: 8.5132704
====> Epoch: 10 elbo loss: 0.0000159
====> Epoch: 10 reconst loss: 0.0000151
====> Epoch: 10 kl loss: 8.4548302
time loss 0:5
====> Epoch: 20 elbo loss: 0.0000157
====> Epoch: 20 reconst loss: 0.0000149
====> Epoch: 20 kl loss: 7.5242629
====> Epoch: 20 elbo loss: 0.0000159
====> Epoch: 20 reconst loss: 0.0000152
====> Epoch: 20 kl loss: 7.5390611
time loss 0:10
====> Epoch: 30 elbo loss: 0.0000154
====> Epoch: 30 reconst loss: 0.0000147
====> Epoch: 30 kl loss: 7.4307337
====> Epoch: 30 elbo loss: 0.0000129
====> Epoch: 30 reconst loss: 0.0000122
====> Epoch: 30 kl loss: 7.3939872
time loss 0:15
====> Epoch: 40 elbo loss: 0.0000147
====> Epoch: 40

In [5]:
device = 'cpu'
model = torch.load('./train_model1.pth', map_location=device)

In [6]:
import torchvision
from torchvision.utils import save_image
# Generation
import random
with torch.no_grad():    
    
    noise = torch.randn(100, latent_dim).to(device)
    generated_imgs = model.decoder(noise)
    resized_image = torchvision.transforms.Resize((50, 50))(generated_imgs)
    save_image(resized_image, './pictures/genera.png', nrow=10)

In [7]:
trains_recon = []
trains_kl = []
tests_recon = []
tests_kl = []
with open('./train_loss1.txt', 'r') as file:
    for line in file:
        parts = line.split()
        trains_recon.append(float(parts[1]))
        trains_kl.append(float(parts[2]))
        tests_recon.append(float(parts[4]))
        tests_kl.append(float(parts[5]))

mean_num = 10
trains_recon = np.mean(np.array(trains_recon[(len(trains_recon)-mean_num):])).item()
trains_kl = np.mean(np.array(trains_kl[(len(trains_kl)-mean_num):])).item()
tests_recon = np.mean(np.array(tests_recon[(len(tests_recon)-mean_num):])).item()
tests_kl = np.mean(np.array(tests_kl[(len(tests_kl)-mean_num):])).item()
print("train recon:", trains_recon, "train kl:", trains_kl, "test recon:", tests_recon, "test kl:", tests_kl)

train recon: 1.2410323233780218e-05 train kl: 7.271074199676514 test recon: 1.306250396737596e-05 test kl: 7.282270765304565


In [8]:
device = 'cuda:0'
model = torch.load('./train_model1.pth', map_location=device)

In [9]:
from torch.utils.data import Dataset
class CustomDataset(Dataset):
    def __init__(self, x, label):
        self.x = x
        self.label = label
        self.n = x.shape[0]

    def __len__(self):
        return self.n

    def __getitem__(self, idx):
        return self.x[idx], self.label[idx]

In [10]:
n = 10000
model.eval()
with torch.no_grad():
    num = n
    t = 100
    num0 = int(num/t)
    for i in range(t):
        noise = torch.randn(num0, latent_dim).to(device)
        x0 = model.decoder(noise)
        if(i == 0):
            x = x0
        else:
            x = torch.cat((x, x0), 0)

In [11]:
label = torch.full((n, ), 0)

In [12]:
split_idx = 8000
train_x, train_label = x[:split_idx], label[:split_idx]
test_x, test_label = x[split_idx:], label[split_idx:]
train_dataset = CustomDataset(train_x, train_label)
test_dataset = CustomDataset(test_x, test_label)

import pickle
train_file = './data/model1/train_dataset.pkl'
test_file = './data/model1/test_dataset.pkl'

with open(train_file, 'wb') as f:
    pickle.dump(train_dataset, f)
with open(test_file, 'wb') as f:
    pickle.dump(test_dataset, f)

In [13]:
import pickle
train_file = './data/model1/train_dataset.pkl'
test_file = './data/model1/test_dataset.pkl'

with open(train_file, 'rb') as f:
    train_dataset = pickle.load(f)
with open(test_file, 'rb') as f:
    test_dataset = pickle.load(f)

bs = 100
train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=bs, shuffle=True)
test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=bs, shuffle=False)

print(len(train_dataset), len(test_dataset))

8000 2000


In [14]:
import os
from torch.utils.data import DataLoader, Subset
from torchvision import datasets, transforms

num_samples = 8000
indices = np.random.choice(len(train_dataset), num_samples, replace=False)
subset_train_dataset = Subset(train_dataset, indices)

print("start deleting")
import shutil
save_dir = './samples/base1'
if os.path.exists(save_dir) and os.path.isdir(save_dir):
    shutil.rmtree(save_dir)
else:
    print(f"Directory does not exist.")
print("done deleting")

os.makedirs(save_dir, exist_ok=True)

print("start saving")
def save_images(dataset, save_dir):
    for idx, (image, label) in enumerate(dataset):
        image = transforms.ToPILImage()(image)
        image.save(os.path.join(save_dir, f'image_{idx}.png'))
save_images(subset_train_dataset, save_dir)
print("done saving")

start deleting
done deleting
start saving
done saving


In [16]:
# IS and FID
import torch_fidelity

metrics_dict = torch_fidelity.calculate_metrics(
    input1= './samples/base1',
    input2= './samples/base',
    cuda=True,
    isc=True,
    fid=True
)

Creating feature extractor "inception-v3-compat" with features ['logits_unbiased', '2048']
Extracting features from input1
Looking for samples non-recursivelty in "./samples/base1" with extensions png,jpg,jpeg
Found 8000 samples
Processing samples                                                            
Extracting features from input2
Looking for samples non-recursivelty in "./samples/base" with extensions png,jpg,jpeg
Found 8000 samples
Processing samples                                                            
Inception Score: 1.1152017462009967 ± 0.001677010470989183
Frechet Inception Distance: 30.995100357331054
