In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.autograd import Variable
from torchvision.utils import save_image
import pickle
from torch.utils.data import Dataset
import numpy as np
import torchvision
from torchvision.utils import save_image
from sklearn.cross_decomposition import CCA
from scipy.optimize import linear_sum_assignment

In [2]:
class VAE(nn.Module):
    """Convolutional variational autoencoder.

    The encoder is a stack of stride-2 3x3 convolutions (each followed by
    batch norm and ReLU), then a fully connected ReLU projection and two
    linear heads that produce the latent mean and log-variance. The decoder
    maps a latent vector through a linear projection and a mirrored stack of
    stride-2 transposed convolutions (each followed by batch norm and ReLU,
    including the last one, so decoded values are non-negative).

    The linear layers are sized for inputs with 1 channel and a side length
    of 32; both values are hard-coded in ``__init__``.

    Args:
        channels: Sequence of channel widths. ``channels[1:]`` are the output
            widths of the encoder convolutions. ``channels[0]`` is only used
            as the output width of the last decoder layer; the encoder's
            input width is always 1.
        latent_dim: Size of the latent vector.
    """

    def __init__(self, channels, latent_dim) -> None:
        super().__init__()

        # ----- encoder -----
        in_ch = 1   # width of the encoder input (hard-coded)
        side = 32   # side length of the input image (hard-coded)
        enc_blocks = []
        for out_ch in channels[1:]:
            enc_blocks.append(
                nn.Sequential(
                    nn.Conv2d(in_ch, out_ch, kernel_size=3, stride=2, padding=1),
                    nn.BatchNorm2d(out_ch),
                    nn.ReLU(),
                )
            )
            in_ch = out_ch
            # Output side of a k=3, s=2, p=1 convolution.
            side = (side - 1) // 2 + 1

        flat_dim = in_ch * side * side

        # Assigned before `encoder`: the order of attribute assignment fixes
        # the order of entries in parameters() / state_dict().
        self.encoder_projection = nn.Sequential(
            nn.Linear(flat_dim, flat_dim),
            nn.ReLU(),
        )
        self.encoder = nn.Sequential(*enc_blocks)
        self.mean_linear = nn.Linear(flat_dim, latent_dim)
        self.var_linear = nn.Linear(flat_dim, latent_dim)
        self.latent_dim = latent_dim

        # ----- decoder -----
        self.decoder_projection = nn.Linear(latent_dim, flat_dim)
        self.decoder_input_chw = (in_ch, side, side)

        # Walk the channel widths backwards, pairing each width with the
        # next-shallower one: (c[-1], c[-2]), (c[-2], c[-3]), ..., (c[1], c[0]).
        widths = channels[::-1]
        dec_blocks = [
            nn.Sequential(
                nn.ConvTranspose2d(src, dst,
                                   kernel_size=3,
                                   stride=2,
                                   padding=1,
                                   output_padding=1),
                nn.BatchNorm2d(dst),
                nn.ReLU(),
            )
            for src, dst in zip(widths[:-1], widths[1:])
        ]
        self.decoder_layers = nn.Sequential(*dec_blocks)

    def decoder(self, z):
        """Decode latent vectors ``z`` of shape ``(batch, latent_dim)`` into images."""
        feature_map = self.decoder_projection(z).reshape(-1, *self.decoder_input_chw)
        return self.decoder_layers(feature_map)

    def forward(self, x):
        """Encode ``x``, sample a latent with the reparameterisation trick, and decode.

        Returns:
            Tuple ``(decoded, mean, logvar)`` where ``mean`` and ``logvar``
            are the outputs of the two latent heads and ``decoded`` is the
            decoder applied to ``eps * exp(logvar / 2) + mean`` with
            ``eps`` drawn from a standard normal.
        """
        features = torch.flatten(self.encoder(x), 1)
        hidden = self.encoder_projection(features)
        mean = self.mean_linear(hidden)
        logvar = self.var_linear(hidden)
        noise = torch.randn_like(logvar)
        z = noise * torch.exp(logvar / 2) + mean
        return self.decoder(z), mean, logvar

In [5]:
# Suffix of the checkpoint to inspect; kept as a string because it is also
# spliced into the output image file name further down.
model_i = "2"
device = 'cpu'
# NOTE: torch.load unpickles the file, so only load checkpoints you trust.
checkpoint_path = f'./model{model_i}.pth'
model = torch.load(checkpoint_path, map_location=device)

In [6]:
# Latent traversal: for every latent dimension, decode a row of latent vectors
# that are zero everywhere except in that dimension, which sweeps over
# (k - 5) * diff_unit for k = 0 .. num_steps - 1. A dimension is counted as
# "effective" when the decoded images at sweep positions 2 and 7 differ enough.
latend_dim = 2        # number of latent dimensions to traverse
diff_unit = 0.4       # spacing between consecutive sweep values
num_steps = 10        # sweep values per dimension (also the image-grid row length)
effect_threshold = 2  # minimum summed absolute pixel difference to count as effective

effect_num = 0
chosen_i_list = []
per_dim_imgs = []     # one (num_steps, C, H, W) batch per traversed dimension

# Sweep values are the same for every dimension, so build them once.
offsets = torch.tensor([(k - 5) * diff_unit for k in range(num_steps)])

model.eval()
# Build latents on whatever device the model currently lives on, so the cell
# still works when it is re-run after the model has been moved to the GPU.
model_device = next(model.parameters()).device
with torch.no_grad():
    # Any iterable of dimension indices can be used here (e.g. a hand-picked
    # subset); results are collected in a list, so nothing depends on the
    # first index being 0.
    for chosen_i in range(latend_dim):
        z = torch.zeros(num_steps, latend_dim)
        z[:, chosen_i] = offsets
        generated_imgs = model.decoder(z.to(model_device))
        # Positions 2 and 7 correspond to sweep values -3*diff_unit and +2*diff_unit.
        max_diff = torch.sum(torch.abs(generated_imgs[7] - generated_imgs[2])).item()
        print(chosen_i, max_diff)
        if max_diff >= effect_threshold:
            effect_num += 1
            chosen_i_list.append(chosen_i)
        per_dim_imgs.append(generated_imgs)

all_generated_imgs = torch.cat(per_dim_imgs, 0)
print(effect_num)
resized_image = torchvision.transforms.Resize((50, 50))(all_generated_imgs)
# One grid row per latent dimension. The target directory must already exist.
save_image(resized_image, f'./pictures1/latend_all{model_i}.png', nrow=num_steps)

0 62.44047927856445
1 25.432859420776367
2


In [7]:
# Move the model under evaluation to the first GPU when CUDA is available;
# otherwise stay on the CPU so the notebook still runs on GPU-less machines.
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
model = model.to(device)

In [8]:
# Load the reference model whose decoder generates the evaluation images.
# Use the first GPU when CUDA is available, otherwise fall back to the CPU.
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
# NOTE: torch.load unpickles the file, so only load checkpoints you trust.
compare_model = torch.load('./model1.pth', map_location=device)
base_dim = 2  # width of the latent vectors fed to compare_model.decoder

In [9]:
def get_correlation_matrix(X, Y):
    """Cross-correlation block between the columns of ``X`` and ``Y``.

    Rows of ``X`` and ``Y`` are paired observations; columns are variables.

    Returns:
        Array of shape ``(X.shape[1], Y.shape[1])`` whose entry ``[i, j]`` is
        the Pearson correlation between column ``i`` of ``X`` and column
        ``j`` of ``Y``.
    """
    stacked = np.hstack((X, Y))
    # Full correlation matrix over all X- and Y-columns together.
    full = np.corrcoef(stacked, rowvar=False)
    n_x = X.shape[1]
    # Keep the upper-right block: X-variables as rows, Y-variables as columns.
    return full[:n_x, n_x:]

def get_strong_MCC(X, Y):
    """Mean absolute correlation after one-to-one matching of X- and Y-columns.

    The first half of the rows is used to find the one-to-one assignment
    between columns of ``X`` and columns of ``Y`` that maximises the summed
    absolute correlation; the second half is used to score that assignment.

    Args:
        X: 2-D array, rows are observations and columns are variables.
        Y: 2-D array with the same number of rows as ``X``. It may have a
            different number of columns; only ``min(X.shape[1], Y.shape[1])``
            pairs are matched and scored.

    Returns:
        Mean of the absolute held-out correlations of the matched pairs.
    """
    train_length = len(X) // 2
    train_X, test_X = X[:train_length], X[train_length:]
    train_Y, test_Y = Y[:train_length], Y[train_length:]

    # Negate because linear_sum_assignment minimises total cost.
    train_corr = get_correlation_matrix(train_X, train_Y)
    row_ind, col_ind = linear_sum_assignment(-np.abs(train_corr))

    # Select the matched columns directly instead of multiplying by a dense
    # permutation matrix: that matrix was sized by X alone and broke when Y
    # had a different number of columns. Ordering pairs by Y-column keeps the
    # same pairing order as before when the widths are equal.
    order = np.argsort(col_ind)
    x_idx, y_idx = row_ind[order], col_ind[order]

    test_corr = get_correlation_matrix(test_X[:, x_idx], test_Y[:, y_idx])
    return np.mean(np.abs(np.diag(test_corr)))

def get_weak_MCC(X, Y):
    """Mean absolute correlation between CCA-predicted and actual Y-columns.

    A CCA model is fitted on the first half of the rows; on the second half,
    ``Y`` is predicted from ``X`` and each predicted column is correlated
    with the corresponding actual column.

    Args:
        X: 2-D array, rows are observations and columns are variables.
        Y: 2-D array with the same number of rows as ``X``.

    Returns:
        Mean of the absolute held-out correlations between predicted and
        actual columns of ``Y``.
    """
    train_length = len(X) // 2
    train_X, test_X = X[:train_length], X[train_length:]
    train_Y, test_Y = Y[:train_length], Y[train_length:]

    cca = CCA(n_components=min(train_X.shape[1], train_Y.shape[1]))
    cca.fit(train_X, train_Y)

    # Use predict() rather than multiplying by cca.coef_ by hand: the
    # orientation of coef_ differs between scikit-learn versions, while
    # predict() always returns an (n_samples, n_targets) array. The extra
    # centring/intercept applied by predict() does not affect correlations.
    predicted_Y = cca.predict(test_X)
    return np.mean(np.abs(np.diag(get_correlation_matrix(predicted_Y, test_Y))))

In [10]:
mean_strong_MCC = 0
mean_weak_MCC = 0

# Sampling budget: `num` latent vectors in total, drawn in `t` equal batches.
num = 5000
t = 100
num0 = num // t

# Both networks are used for inference only.
model.eval()
compare_model.eval()

noise_batches = []
compare_noise_batches = []
with torch.no_grad():
    for batch_idx in range(t):
        # Ground-truth latents -> images, via the reference decoder.
        noise0 = torch.randn(num0, base_dim).to(device)
        x0 = compare_model.decoder(noise0)
        # Images -> latent means, via the model under evaluation.
        _, z0, _ = model(x0)
        noise_batches.append(noise0)
        compare_noise_batches.append(z0)

# Concatenate once, then hand NumPy arrays to the metric functions.
noise = torch.cat(noise_batches, 0).to("cpu").numpy()
compare_noise = torch.cat(compare_noise_batches, 0).to("cpu").numpy()

strong_MCC = get_strong_MCC(noise, compare_noise)
weak_MCC = get_weak_MCC(noise, compare_noise)
print(f"strong MCC = {strong_MCC}, weak MCC = {weak_MCC}")

strong MCC = 0.6183719347023469, weak MCC = 0.6296616725844808
