In [2]:
import os
import pickle

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
from torch.autograd import Variable
from torch.utils.data import Dataset
from torchvision import datasets, transforms
from torchvision.utils import save_image

In [3]:
class CVAE(nn.Module):
    """Convolutional conditional variational autoencoder for square images.

    The encoder is a stack of stride-2 ``Conv2d -> BatchNorm2d -> ReLU`` blocks.
    Its flattened output is concatenated with the condition vector, projected,
    and mapped to the mean and log-variance of the latent distribution.  The
    decoder concatenates a latent vector with the condition vector, projects it
    back to the encoder's final feature-map shape and upsamples it with the
    mirrored stack of ``ConvTranspose2d -> BatchNorm2d -> ReLU`` blocks.

    Args:
        condi_dim: Width of the condition vector ``c``.
        channels: Channel plan.  ``channels[1:]`` are the output channels of the
            successive encoder blocks; the decoder walks the list backwards, so
            ``channels[0]`` is the channel count of the decoded image.  The
            encoder itself always expects single-channel input.
        latent_dim: Width of the latent vector ``z``.
        img_length: Side length of the (square) input images.  Defaults to 28.
    """

    def __init__(self, condi_dim, channels, latent_dim, img_length=28) -> None:
        super().__init__()

        # encoder
        pre_channel = 1
        modules = []
        # Side length before the first block and after every block; the decoder
        # needs these sizes to undo each downsampling step exactly.
        side_lengths = [img_length]
        for out_channel in channels[1:]:
            modules.append(
                nn.Sequential(
                    nn.Conv2d(pre_channel,
                              out_channel,
                              kernel_size=3,
                              stride=2,
                              padding=1),
                    nn.BatchNorm2d(out_channel),
                    nn.ReLU()
                )
            )
            pre_channel = out_channel
            # Output size of a kernel-3 / stride-2 / padding-1 convolution.
            side_lengths.append((side_lengths[-1] - 1) // 2 + 1)
        img_length = side_lengths[-1]
        flat_dim = pre_channel * img_length * img_length

        # Mixes the flattened image features with the condition vector.
        self.encoder_projection = nn.Sequential(
            nn.Linear(flat_dim + condi_dim, flat_dim),
            nn.ReLU()
        )

        self.encoder = nn.Sequential(*modules)
        self.mean_linear = nn.Linear(flat_dim, latent_dim)
        self.var_linear = nn.Linear(flat_dim, latent_dim)
        self.latent_dim = latent_dim

        # decoder
        modules = []
        self.decoder_projection = nn.Linear(latent_dim + condi_dim, flat_dim)
        self.decoder_input_chw = (pre_channel, img_length, img_length)

        for i in range(len(channels) - 1, 0, -1):
            in_side, out_side = side_lengths[i], side_lengths[i - 1]
            # A kernel-3 / stride-2 / padding-1 transposed convolution yields
            # 2 * in_side - 1 pixels before output_padding.  The remainder
            # (0 for an odd target size, 1 for an even one) restores exactly the
            # size the matching encoder block received, for any depth and any
            # image size instead of only three blocks on 28x28 images.
            output_padding = out_side - (2 * in_side - 1)
            modules.append(
                nn.Sequential(
                    nn.ConvTranspose2d(channels[i],
                                       channels[i - 1],
                                       kernel_size=3,
                                       stride=2,
                                       padding=1,
                                       output_padding=output_padding),
                    nn.BatchNorm2d(channels[i - 1]),
                    nn.ReLU()
                )
            )
        self.decoder_layers = nn.Sequential(*modules)

    def decoder(self, z, c):
        """Decode latent vectors ``z`` under the conditions ``c``.

        Args:
            z: Tensor of shape ``(N, latent_dim)``.
            c: Tensor of shape ``(N, condi_dim)``.

        Returns:
            Tensor of shape ``(N, channels[0], img_length, img_length)``.  The
            last decoder block ends in a ReLU, so the values are non-negative.
        """
        z = torch.cat([z, c], dim=1)
        z = self.decoder_projection(z)
        z = torch.reshape(z, (-1, *self.decoder_input_chw))
        decoded = self.decoder_layers(z)
        return decoded

    def forward(self, x, c):
        """Encode ``x`` under condition ``c``, sample a latent and decode it.

        Args:
            x: Image batch of shape ``(N, 1, img_length, img_length)``.
            c: Condition batch of shape ``(N, condi_dim)``.

        Returns:
            Tuple ``(decoded, mean, logvar)``: the reconstruction plus the mean
            and log-variance of the approximate posterior, each of the latter
            with shape ``(N, latent_dim)``.
        """
        x = self.encoder(x)
        x = torch.flatten(x, 1)
        x = torch.cat([x, c], dim=1)
        encoded = self.encoder_projection(x)
        mean = self.mean_linear(encoded)
        logvar = self.var_linear(encoded)
        # Reparameterisation trick: z = mean + std * eps with eps ~ N(0, I).
        eps = torch.randn_like(logvar)
        std = torch.exp(logvar / 2)
        z = eps * std + mean
        decoded = self.decoder(z, c)
        return decoded, mean, logvar

In [4]:
# Everything below runs on the CPU; tensors and the checkpoint are mapped there.
device = 'cpu'
# The checkpoint is used as a ready-made module further down (model.eval(),
# model.decoder(...)), i.e. it is a whole pickled object rather than a state_dict.
# The CVAE class therefore has to be defined before this cell runs, and full
# unpickling must be requested explicitly: PyTorch >= 2.6 defaults to
# weights_only=True, which refuses to load arbitrary pickled modules.
# (The weights_only keyword itself requires PyTorch >= 1.13.)
# NOTE(security): unpickling can execute arbitrary code - only load trusted files.
model = torch.load('./model2.pth', map_location=device, weights_only=False)

In [8]:
# Settings for the latent-traversal experiment in this cell.
# Width of the latent vectors handed to model.decoder ("latend" [sic] - the
# misspelled name is the one the code below reads).
# NOTE(review): presumably has to equal the latent size of the loaded model - confirm.
latend_dim = 200
# Number of condition classes, i.e. the width of the one-hot condition vector.
num_class = 10
# Spacing between consecutive traversal values: value = (k - 5) * diff_unit.
diff_unit = 0.4

def one_hot_encode(labels, num_classes):
    """Turn integer class labels into one-hot rows.

    Args:
        labels: Sequence of integer class indices, one per sample.
        num_classes: Length of every one-hot row.

    Returns:
        A list of ``len(labels)`` lists of floats; row ``r`` is 1.0 at column
        ``labels[r]`` and 0.0 everywhere else.
    """
    sample_count = len(labels)
    encoded = np.zeros((sample_count, num_classes))
    # Pair every row index with its label so one assignment sets all the ones.
    row_index = np.arange(sample_count)
    encoded[row_index, labels] = 1
    return encoded.tolist()

# Latent traversal: sweep one latent dimension at a time (all others held at
# zero), decode the sweep for every class, and keep the dimensions whose sweep
# visibly changes the decoded image.
traversal_steps = 10   # sweep positions per class
diff_offset = 2        # the difference image compares sweep positions 2 and 7
effect_threshold = 2   # summed |difference| a dimension needs to count as effective
output_dir = './pictures2'
os.makedirs(output_dir, exist_ok=True)  # save_image does not create missing folders

# The sweep values and the class conditioning are the same for every latent
# dimension, so they are built once.  Row cls * traversal_steps + k holds class
# `cls` at sweep position k, i.e. at latent value (k - 5) * diff_unit.
sweep_values = torch.tensor(
    [(k - traversal_steps // 2) * diff_unit for k in range(traversal_steps)])
class_labels = [cls for cls in range(num_class) for _ in range(traversal_steps)]
c = torch.tensor(one_hot_encode(class_labels, num_classes=num_class)).to(device)

effect_num = 0
# Rebuilt by the scan below instead of being read from an earlier run, so the
# cell also works after Restart Kernel -> Run All.
chosen_i_list = []

model.eval()  # decode with BatchNorm running statistics, not batch statistics
with torch.no_grad():
    for chosen_i in range(latend_dim):
        # One row per (class, sweep position); only column chosen_i is non-zero.
        z = torch.zeros(num_class * traversal_steps, latend_dim)
        z[:, chosen_i] = sweep_values.repeat(num_class)
        generated_imgs = model.decoder(z.to(device), c)

        grid_parts = []
        max_diff = -1
        for cls in range(num_class):
            sweep = generated_imgs[cls * traversal_steps:(cls + 1) * traversal_steps]
            diff_x = torch.abs(
                sweep[diff_offset] - sweep[traversal_steps - 1 - diff_offset])[None, :]
            # Each grid row shows the first nine sweep images; the tenth slot
            # holds the difference image instead of the last sweep position.
            grid_parts.append(sweep[:-1])
            grid_parts.append(diff_x)
            max_diff = max(max_diff, torch.sum(diff_x).item())

        # Only dimensions with a visible effect are reported and saved.
        if max_diff >= effect_threshold:
            effect_num += 1
            chosen_i_list.append(chosen_i)
            print(chosen_i, max_diff)
            x = torch.cat(grid_parts, 0)
            resized_image = torchvision.transforms.Resize((50, 50))(x)
            save_image(resized_image, f'{output_dir}/latend_{chosen_i}.png',
                       nrow=traversal_steps)
print(effect_num)

49 45.09355163574219
52 131.7681884765625
60 74.11412811279297
171 112.82151794433594
180 58.813621520996094
5


In [59]:
import torchvision
from torchvision.utils import save_image
# Generation: decode random latent vectors, ten samples for every class, and
# save them as one image grid (one class per row).
import random
# The noise width has to match the latent size the decoder was built with; a
# hard-coded value that disagrees with the loaded model makes
# model.decoder_projection fail with a shape mismatch, so it is read from the model.
latent_dim = model.latent_dim
samples_per_class = 10
os.makedirs('./pictures', exist_ok=True)  # save_image does not create missing folders
model.eval()  # BatchNorm must use its running statistics, whatever ran before this cell
with torch.no_grad():
    class_labels = [cls for cls in range(num_class) for _ in range(samples_per_class)]
    noise = torch.randn(len(class_labels), latent_dim).to(device)
    c = one_hot_encode(class_labels, num_classes = num_class)
    c = torch.tensor(c).to(device)
    generated_imgs = model.decoder(noise, c)
    resized_image = torchvision.transforms.Resize((50, 50))(generated_imgs)
    save_image(resized_image, './pictures/genera.png', nrow=samples_per_class)