In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms

In [None]:
class ResidualBlock(nn.Module):
    """Residual unit that keeps the channel count and spatial size of its input.

    Pipeline: 3x3 conv -> batch norm -> ReLU -> 1x1 conv -> batch norm,
    then the input is added back and a final ReLU is applied.

    NOTE(review): one ``BatchNorm2d`` instance serves both normalisation
    points, so both share affine parameters and running statistics. Giving
    each point its own layer would change the state-dict keys, so confirm
    against existing checkpoints before splitting it.

    Args:
        dim: Number of input (and output) channels.
    """

    def __init__(self, dim):
        super().__init__()
        self.relu = nn.ReLU()
        self.conv1 = nn.Conv2d(in_channels=dim, out_channels=dim,
                               kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(in_channels=dim, out_channels=dim, kernel_size=1)
        self.batchnorm2d = nn.BatchNorm2d(num_features=dim)

    def forward(self, x):
        """Apply the two-convolution branch and add the identity skip."""
        hidden = self.relu(self.batchnorm2d(self.conv1(x)))
        branch = self.batchnorm2d(self.conv2(hidden))
        return self.relu(x + branch)

class CVAE(nn.Module):
    """Conditional variational autoencoder for single-channel 28x28 images.

    The encoder is a stack of stride-2 convolutions. Its flattened output is
    concatenated with the condition vector and projected before the mean and
    log-variance heads. The decoder concatenates a latent code with the same
    kind of condition vector, projects it back to the encoder's final
    feature-map shape and upsamples with stride-2 transposed convolutions.

    Args:
        condi_dim: Length of the condition vector ``c``.
        channels: Channel widths. ``channels[1:]`` are the encoder stage
            output widths; the decoder walks the list backwards and ends with
            ``channels[0]`` output channels. The encoder input is fixed to one
            channel, so ``channels[0]`` should be 1 for the reconstruction to
            have the same channel count as the input.
        latent_dim: Size of the latent code ``z``.
    """

    def __init__(self, condi_dim, channels, latent_dim) -> None:
        super().__init__()

        # encoder
        pre_channel = 1
        modules = []
        img_length = 28
        # Spatial size entering each encoder stage. The decoder needs it to
        # undo the rounding of the stride-2 convolutions exactly.
        stage_input_lengths = []

        for out_channel in channels[1:]:
            modules.append(
                nn.Sequential(
                    nn.Conv2d(pre_channel,
                              out_channel,
                              kernel_size=3,
                              stride=2,
                              padding=1),
                    nn.BatchNorm2d(out_channel),
                    nn.ReLU()
                )
            )
            stage_input_lengths.append(img_length)
            pre_channel = out_channel
            # Output length of a k=3, s=2, p=1 convolution.
            img_length = (img_length - 1) // 2 + 1

        flat_dim = pre_channel * img_length * img_length

        self.encoder_projection = nn.Sequential(
            nn.Linear(flat_dim + condi_dim, flat_dim),
            nn.ReLU()
        )

        self.encoder = nn.Sequential(*modules)
        self.mean_linear = nn.Linear(flat_dim, latent_dim)
        self.var_linear = nn.Linear(flat_dim, latent_dim)
        self.latent_dim = latent_dim

        # decoder
        modules = []
        self.decoder_projection = nn.Linear(latent_dim + condi_dim, flat_dim)
        self.decoder_input_chw = (pre_channel, img_length, img_length)

        decoder_stages = zip(reversed(channels[1:]),
                             reversed(channels[:-1]),
                             reversed(stage_input_lengths))
        for in_channel, out_channel, target_length in decoder_stages:
            # A k=3, s=2, p=1 transposed convolution maps length n to
            # 2 * n - 1 + output_padding. The matching encoder stage mapped
            # target_length to (target_length - 1) // 2 + 1, so one extra
            # pixel is needed only when target_length is even. A fixed
            # output_padding of 1 would overshoot for odd lengths
            # (e.g. 4 -> 8 instead of 4 -> 7).
            output_padding = (target_length + 1) % 2
            modules.append(
                nn.Sequential(
                    nn.ConvTranspose2d(in_channel,
                                       out_channel,
                                       kernel_size=3,
                                       stride=2,
                                       padding=1,
                                       output_padding=output_padding),
                    nn.BatchNorm2d(out_channel),
                    nn.ReLU()
                )
            )
        self.decoder_layers = nn.Sequential(*modules)

    def decoder(self, z, c):
        """Decode latent codes ``z`` under conditions ``c``.

        ``z`` and ``c`` are concatenated along dim 1, projected, reshaped to
        ``decoder_input_chw`` per sample and passed through the transposed
        convolution stack. The last stage ends in BatchNorm + ReLU, so the
        output is non-negative but not bounded above.
        """
        z = torch.cat([z, c], dim = 1)
        z = self.decoder_projection(z)
        z = torch.reshape(z, (-1, *self.decoder_input_chw))
        decoded = self.decoder_layers(z)
        return decoded

    def forward(self, x, c):
        """Encode ``x`` with condition ``c``, sample a latent code and decode it.

        Returns:
            Tuple ``(decoded, mean, logvar)`` where ``mean`` and ``logvar``
            parameterise the approximate posterior the latent code was
            sampled from.
        """
        x = self.encoder(x)
        x = torch.flatten(x, 1)
        x = torch.cat([x, c], dim = 1)
        encoded = self.encoder_projection(x)
        mean = self.mean_linear(encoded)
        logvar = self.var_linear(encoded)
        # Reparameterisation: z = mean + std * eps with eps ~ N(0, I).
        eps = torch.randn_like(logvar)
        std = torch.exp(logvar / 2)
        z = eps * std + mean
        decoded = self.decoder(z, c)
        return decoded, mean, logvar

In [None]:
# Disabled one-off preparation script, kept as a bare string so the cell does
# nothing when run. When enabled it draws 30000 random MNIST training images
# and writes them as PNGs to './samples/original', the directory the IS/FID
# evaluation cell reads as its reference set (`input2`).
# NOTE(review): the MNIST root is an absolute local path and download=False,
# so the dataset must already exist there — adjust before re-enabling.
'''
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
import numpy as np
from torch.utils.data import DataLoader, Subset
from torchvision import datasets, transforms
import os

root = '/root/data/MNIST/'
transform = transforms.Compose([
    transforms.ToTensor()
])
train_dataset = datasets.MNIST(root=root, train=True, transform=transforms.ToTensor(), download=False)

num_samples = 30000
indices = np.random.choice(len(train_dataset), num_samples, replace=False)
subset_train_dataset = Subset(train_dataset, indices)

print("start deleting")
import shutil
save_dir = './samples/original'
if os.path.exists(save_dir) and os.path.isdir(save_dir):
    shutil.rmtree(save_dir)
else:
    print(f"Directory does not exist.")
print("done deleting")

os.makedirs(save_dir, exist_ok=True)

print("start saving")
def save_images(dataset, save_dir):
    for idx, (image, label) in enumerate(dataset):
        if(idx%1000 == 0):
            print(idx)
        image = transforms.ToPILImage()(image)
        image.save(os.path.join(save_dir, f'image_{idx}.png'))
save_images(subset_train_dataset, save_dir)
print("done saving")
'''

In [None]:
# Width of the noise vectors sampled for generation further down.
# Presumably this has to equal the latent_dim the loaded checkpoint was
# built with, since the noise is fed straight into model.decoder — TODO confirm.
latent_dim = 500

In [None]:
# Prefer the first GPU, but fall back to CPU so the checkpoint can still be
# loaded on a machine without CUDA.
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

# The loaded object is used directly as a module (model.eval(), model.decoder),
# i.e. the file is a fully pickled model rather than a state_dict. Newer
# PyTorch releases default torch.load to weights_only=True, which refuses to
# unpickle arbitrary classes, so full unpickling is requested explicitly.
# The class definitions above must be in scope for this to work.
# SECURITY: unpickling can execute arbitrary code — only load checkpoints
# from a trusted source.
model = torch.load('./model.pth', map_location=device, weights_only=False)

In [None]:
def one_hot_encode(labels, num_classes):
    """Turn integer class labels into one-hot rows.

    Args:
        labels: Sequence of integer class indices, one per output row.
        num_classes: Width of every one-hot row.

    Returns:
        A nested Python list with ``len(labels)`` rows of ``num_classes``
        floats: 1.0 at the label's index and 0.0 everywhere else.
    """
    row_count = len(labels)
    encoded = np.zeros(shape=(row_count, num_classes))
    # Fancy indexing pairs row r with column labels[r].
    encoded[np.arange(row_count), labels] = 1
    return encoded.tolist()

In [None]:
import random
from torchvision.transforms import ToPILImage
import torch_fidelity
import numpy as np
import os
import shutil

# Guidance weight: each sample is (1 + w) * decoder(z, c) - w * decoder(z, c0),
# so w = 0 keeps the plain conditional output.
w = 0
# Width of the one-hot condition vector (labels 0..10).
num_class = 11

model.eval()
batches = []
with torch.no_grad():
    num = 300
    # Reference condition: every sample labelled class 0. It does not depend
    # on the loop, so it is built once.
    # NOTE(review): class 0 is presumably the "no condition" label used
    # during training — confirm against the training code.
    c0 = one_hot_encode([0]*num, num_classes = num_class)
    c0 = torch.tensor(c0).to(device)
    # 100 batches of 300 samples = 30000 generated images.
    for k in range(100):
        noise = torch.randn(num, latent_dim).to(device)
        c = random.choices(list(range(1, 11)), k=num)
        c = one_hot_encode(c, num_classes = num_class)
        c = torch.tensor(c).to(device)
        x0 = (1+w)*model.decoder(noise, c) - w*model.decoder(noise, c0)
        batches.append(x0)

# One concatenation at the end instead of re-copying the growing tensor on
# every iteration.
x = torch.cat(batches, 0)

save_dir = './samples/generated'
if os.path.exists(save_dir) and os.path.isdir(save_dir):
    shutil.rmtree(save_dir)
os.makedirs(save_dir, exist_ok=True)
to_pil = ToPILImage()
# ToPILImage multiplies float tensors by 255 and casts to uint8. The decoder
# output is not bounded to [0, 1] (and guidance with w != 0 extrapolates), so
# out-of-range values are clamped first to stop them wrapping around in the
# integer cast and corrupting pixels.
images = x.clamp(0, 1).cpu()
for i in range(images.size(0)):
    image_tensor = images[i]
    image = to_pil(image_tensor)
    save_path = os.path.join(save_dir, f'image_{i}.png')
    image.save(save_path)

# IS and FID
metrics_dict = torch_fidelity.calculate_metrics(
    input1= './samples/generated',
    input2= './samples/original',
    cuda=torch.cuda.is_available(),
    isc=True,
    fid=True
)