In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.autograd import Variable
from torchvision.utils import save_image
import pickle
from torch.utils.data import Dataset
import numpy as np
import torchvision
from torchvision.utils import save_image
from sklearn.cross_decomposition import CCA
from scipy.optimize import linear_sum_assignment

In [2]:
class CVAE(nn.Module):

    def __init__(self, condi_dim, channels, latent_dim) -> None:
        super().__init__()

        # encoder
        pre_channel = 1
        modules = []
        img_length = 28
        
        for i in range(len(channels)-1):
            modules.append(
                nn.Sequential(
                    nn.Conv2d(pre_channel,
                              channels[i+1],
                              kernel_size=3,
                              stride=2,
                              padding=1),
                    nn.BatchNorm2d(channels[i+1]),
                    nn.ReLU()
                )
            )
            pre_channel = channels[i+1]
            img_length = (img_length-1)//2+1
        
        self.encoder_projection = nn.Sequential(
                nn.Linear(pre_channel * img_length * img_length + condi_dim, pre_channel * img_length * img_length),
                nn.ReLU()
        )
        
        self.encoder = nn.Sequential(*modules)
        self.mean_linear = nn.Linear(pre_channel * img_length * img_length,
                                     latent_dim)
        self.var_linear = nn.Linear(pre_channel * img_length * img_length,
                                    latent_dim)
        self.latent_dim = latent_dim
        
        # decoder
        modules = []
        self.decoder_projection = nn.Linear(
            latent_dim + condi_dim, pre_channel * img_length * img_length)
        self.decoder_input_chw = (pre_channel, img_length, img_length)

        for i in range(len(channels)-1):
            if(i == 0):
                modules.append(
                    nn.Sequential(
                        nn.ConvTranspose2d(channels[len(channels)-i-1],
                                           channels[len(channels)-i-2],
                                           kernel_size=3,
                                           stride=2,
                                           padding=1,
                                           output_padding=0),
                        nn.BatchNorm2d(channels[len(channels)-i-2]),
                        nn.ReLU()
                    )
                )
            else:
                modules.append(
                    nn.Sequential(
                        nn.ConvTranspose2d(channels[len(channels)-i-1],
                                           channels[len(channels)-i-2],
                                           kernel_size=3,
                                           stride=2,
                                           padding=1,
                                           output_padding=1),
                        nn.BatchNorm2d(channels[len(channels)-i-2]),
                        nn.ReLU()
                    )
                )
        self.decoder_layers = nn.Sequential(*modules)
        
    def decoder(self, z, c):
        z = torch.cat([z, c], dim = 1)
        z = self.decoder_projection(z)
        z = torch.reshape(z, (-1, *self.decoder_input_chw))
        decoded = self.decoder_layers(z)
        return decoded

    def forward(self, x, c):
        x = self.encoder(x)
        x = torch.flatten(x, 1)
        x = torch.cat([x, c], dim = 1)
        encoded = self.encoder_projection(x)
        mean = self.mean_linear(encoded)
        logvar = self.var_linear(encoded)
        eps = torch.randn_like(logvar)
        std = torch.exp(logvar / 2)
        z = eps * std + mean
        decoded = self.decoder(z, c)
        return decoded, mean, logvar

In [26]:
model_i = "1"
device = 'cpu'
model = torch.load('./model' + str(model_i) + '.pth', map_location=device)

In [27]:
latend_dim = 10
num_class = 10
diff_unit = 0.4

def one_hot_encode(labels, num_classes):
    one_hot_labels = np.zeros((len(labels), num_classes))
    one_hot_labels[np.arange(len(labels)), labels] = 1
    return one_hot_labels.tolist()

effect_num = 0
chosen_i_list = []
for chosen_i in range(latend_dim):
#for chosen_i in chosen_i_list:
    for i in range(10):
        for k in range(10):
            diff = (k-5)*diff_unit
            z0 = torch.zeros(latend_dim)
            z0[chosen_i] = diff
            if(i == 0 and k == 0):
                z = z0[None, :]
            else:
                z = torch.cat((z, z0[None, :]), 0)
    model.eval()
    with torch.no_grad():
        max_diff = -1
        c = [i for i in range(10) for _ in range(10)]
        c = one_hot_encode(c, num_classes = num_class)
        c = torch.tensor(c).to(device)
        noise = z
        generated_imgs = model.decoder(noise, c)
        t = 2
        for i in range(10):
            if(i == 0):
                x = generated_imgs[i*10:(i+1)*10-1]
            else:
                x = torch.cat((x, generated_imgs[i*10:(i+1)*10-1]), 0)
            diff_x = torch.abs((generated_imgs[i*10+t]-generated_imgs[(i+1)*10-1-t])[None, :])
            x = torch.cat((x, diff_x), 0)
            if(torch.sum(diff_x).item() > max_diff):
                max_diff = torch.sum(diff_x).item()
        print(chosen_i, max_diff)
        if(max_diff >= 2):
            effect_num += 1
            chosen_i_list.append(chosen_i)
        #resized_image = torchvision.transforms.Resize((50, 50))(x)
        #save_image(resized_image, f'./pictures2/latend_{chosen_i}.png', nrow=10)
print(effect_num)

0 0.9632492661476135
1 131.75804138183594
2 70.04931640625
3 0.994452953338623
4 111.01948547363281
5 72.21029663085938
6 51.83504104614258
7 0.8544827699661255
8 0.7677341103553772
9 1.007795810699463
5


In [28]:
print(chosen_i_list)
no_effect_list = [i for i in range(latend_dim) if i not in chosen_i_list]

[1, 2, 4, 5, 6]


In [29]:
device = 'cuda:0'
model = model.to(device)

In [7]:
# calculate FID effective

model.eval()
with torch.no_grad():
    num = 30000
    t = 100
    num0 = int(num/t)
    for i in range(t):
        noise = torch.randn(num0, latend_dim).to(device)
        noise[:, no_effect_list] = 0
        label0 = torch.randint(low=0, high=10, size=(num0,), dtype=torch.int32)
        c = label0.tolist()
        c = one_hot_encode(c, num_classes = num_class)
        c = torch.tensor(c).to(device)
        x0 = model.decoder(noise, c)
        if(i == 0):
            x = x0
        else:
            x = torch.cat((x, x0), 0)
            
import random
from torchvision.transforms import ToPILImage
import torch_fidelity
import numpy as np
import os
import shutil

save_dir = './samples/generated' + str(model_i)
if os.path.exists(save_dir) and os.path.isdir(save_dir):
    shutil.rmtree(save_dir)
os.makedirs(save_dir, exist_ok=True)
to_pil = ToPILImage()
for i in range(x.size(0)):
    image_tensor = x[i]
    image = to_pil(image_tensor)
    save_path = os.path.join(save_dir, f'image_{i}.png')
    image.save(save_path)

# IS and FID
metrics_dict = torch_fidelity.calculate_metrics(
    input1= './samples/generated' + str(model_i),
    input2= './samples/base',
    cuda=True,
    isc=True,
    fid=True
)

Creating feature extractor "inception-v3-compat" with features ['logits_unbiased', '2048']
Extracting features from input1
Looking for samples non-recursivelty in "./samples/generated1" with extensions png,jpg,jpeg
Found 30000 samples
Processing samples                                                              
Extracting features from input2
Looking for samples non-recursivelty in "./samples/base" with extensions png,jpg,jpeg
Found 30000 samples
Processing samples                                                              
Inception Score: 2.4646159540931833 ± 0.02533158709602798
Frechet Inception Distance: 14.120685147728295


In [59]:
# view generated pictures

import torchvision
from torchvision.utils import save_image
# Generation
import random
with torch.no_grad():
    noise = torch.randn(100, latend_dim).to(device)
    c = [i for i in range(10) for _ in range(10)]
    c = one_hot_encode(c, num_classes = num_class)
    c = torch.tensor(c).to(device)
    generated_imgs = model.decoder(noise, c)
    resized_image = torchvision.transforms.Resize((50, 50))(generated_imgs)
    save_image(resized_image, './pictures/genera.png', nrow=10)

In [30]:
device = 'cuda:0'
compare_model = torch.load('./model.pth', map_location=device)
base_dim = 5

In [31]:
def get_correlation_matrix(X, Y):
    combined_matrix = np.hstack((X, Y))
    correlation_matrix = np.corrcoef(combined_matrix, rowvar=False)[:X.shape[1], X.shape[1]:]
    return correlation_matrix

def get_strong_MCC(X, Y):
    train_length = int(len(X)/2)
    train_X = X[:train_length]
    train_Y = Y[:train_length]
    test_X = X[train_length:]
    test_Y = Y[train_length:]
    correlation_matrix = get_correlation_matrix(train_X, train_Y)
    row_ind, col_ind = linear_sum_assignment(-np.abs(correlation_matrix))
    P_fitted = np.zeros((X.shape[1], X.shape[1]), dtype=int)
    P_fitted[col_ind, row_ind] = 1
    return np.mean(np.diag(np.abs(get_correlation_matrix(test_X @ P_fitted.T, test_Y))))

def get_weak_MCC(X, Y):
    train_length = int(len(X)/2)
    train_X = X[:train_length]
    train_Y = Y[:train_length]
    test_X = X[train_length:]
    test_Y = Y[train_length:]
    cca = CCA(n_components=min(train_X.shape[1], train_Y.shape[1]))
    cca.fit(train_X, train_Y)
    canonical_coefficients = cca.coef_
    return np.mean(np.diag(np.abs(get_correlation_matrix(test_X @ canonical_coefficients.T, test_Y))))

In [32]:
mean_strong_MCC = 0
mean_weak_MCC = 0

for class_type in range(10):
    
    # get data
    model.eval()
    compare_model.eval()
    with torch.no_grad():
        num = 5000
        t = 100
        num0 = int(num/t)
        for i in range(t):
            # get x from z
            noise0 = torch.randn(num0, base_dim).to(device)
            c = [class_type for _ in range(num0)]
            c = one_hot_encode(c, num_classes = num_class)
            c = torch.tensor(c).to(device)
            x0 = compare_model.decoder(noise0, c)
            # get z from x
            _, z0, _ = model(x0, c)
            compare_noise0 = z0[:, chosen_i_list]
            if(i == 0):
                noise = noise0
                compare_noise = compare_noise0
            else:
                noise = torch.cat((noise, noise0), 0)
                compare_noise = torch.cat((compare_noise, compare_noise0), 0)
           
    noise = np.array(noise.to("cpu"))
    compare_noise = np.array(compare_noise.to("cpu"))
    
    strong_MCC = get_strong_MCC(noise, compare_noise)
    weak_MCC = get_weak_MCC(noise, compare_noise)
    print(f"c = {class_type}, strong MCC = {strong_MCC}, weak MCC = {weak_MCC}")
    
    mean_strong_MCC += strong_MCC
    mean_weak_MCC += weak_MCC

print(f"mean strong MCC = {mean_strong_MCC/10}, mean weak MCC = {mean_weak_MCC/10}")

c = 0, strong MCC = 0.7896775780907309, weak MCC = 0.9936773818170099
c = 1, strong MCC = 0.5964680031881946, weak MCC = 0.6680308770102208
c = 2, strong MCC = 0.9466711581890486, weak MCC = 0.9939553466955185
c = 3, strong MCC = 0.8789764344490534, weak MCC = 0.9914380410618862
c = 4, strong MCC = 0.9317557360348726, weak MCC = 0.9918375226572309
c = 5, strong MCC = 0.8660216882774942, weak MCC = 0.9804627638645226
c = 6, strong MCC = 0.925603593902807, weak MCC = 0.987941306901936
c = 7, strong MCC = 0.9131378189801378, weak MCC = 0.9663520779372297
c = 8, strong MCC = 0.8965656196082558, weak MCC = 0.9909589151567445
c = 9, strong MCC = 0.8835648813817137, weak MCC = 0.973312193013442
mean strong MCC = 0.8628442512102309, mean weak MCC = 0.953796642611574
