In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torchvision import datasets, transforms
import numpy as np
from torchvision.utils import save_image

In [2]:
class CVAE(nn.Module):

    def __init__(self, condi_dim, channels, latent_dim) -> None:
        super().__init__()

        # encoder
        pre_channel = 1
        modules = []
        img_length = 28
        
        for i in range(len(channels)-1):
            modules.append(
                nn.Sequential(
                    nn.Conv2d(pre_channel,
                              channels[i+1],
                              kernel_size=3,
                              stride=2,
                              padding=1),
                    nn.BatchNorm2d(channels[i+1]),
                    nn.ReLU()
                )
            )
            pre_channel = channels[i+1]
            img_length = (img_length-1)//2+1
        
        self.encoder_projection = nn.Sequential(
                nn.Linear(pre_channel * img_length * img_length + condi_dim, pre_channel * img_length * img_length),
                nn.ReLU()
        )
        
        self.encoder = nn.Sequential(*modules)
        self.mean_linear = nn.Linear(pre_channel * img_length * img_length,
                                     latent_dim)
        self.var_linear = nn.Linear(pre_channel * img_length * img_length,
                                    latent_dim)
        self.latent_dim = latent_dim
        
        # decoder
        modules = []
        self.decoder_projection = nn.Linear(
            latent_dim + condi_dim, pre_channel * img_length * img_length)
        self.decoder_input_chw = (pre_channel, img_length, img_length)

        for i in range(len(channels)-1):
            if(i == 0):
                modules.append(
                    nn.Sequential(
                        nn.ConvTranspose2d(channels[len(channels)-i-1],
                                           channels[len(channels)-i-2],
                                           kernel_size=3,
                                           stride=2,
                                           padding=1,
                                           output_padding=0),
                        nn.BatchNorm2d(channels[len(channels)-i-2]),
                        nn.ReLU()
                    )
                )
            else:
                modules.append(
                    nn.Sequential(
                        nn.ConvTranspose2d(channels[len(channels)-i-1],
                                           channels[len(channels)-i-2],
                                           kernel_size=3,
                                           stride=2,
                                           padding=1,
                                           output_padding=1),
                        nn.BatchNorm2d(channels[len(channels)-i-2]),
                        nn.ReLU()
                    )
                )
        self.decoder_layers = nn.Sequential(*modules)
        
    def decoder(self, z, c):
        z = torch.cat([z, c], dim = 1)
        z = self.decoder_projection(z)
        z = torch.reshape(z, (-1, *self.decoder_input_chw))
        decoded = self.decoder_layers(z)
        return decoded

    def forward(self, x, c):
        x = self.encoder(x)
        x = torch.flatten(x, 1)
        x = torch.cat([x, c], dim = 1)
        encoded = self.encoder_projection(x)
        mean = self.mean_linear(encoded)
        logvar = self.var_linear(encoded)
        eps = torch.randn_like(logvar)
        std = torch.exp(logvar / 2)
        z = eps * std + mean
        decoded = self.decoder(z, c)
        return decoded, mean, logvar

In [29]:
model_i = ""
device = 'cpu'
model = torch.load('./model1' + str(model_i) + '.pth', map_location=device, weights_only=False)

In [31]:
latend_dim = 100
num_class = 10
diff_unit = 0.5

def one_hot_encode(labels, num_classes):
    one_hot_labels = np.zeros((len(labels), num_classes))
    one_hot_labels[np.arange(len(labels)), labels] = 1
    return one_hot_labels.tolist()

effect_num = 0
chosen_i_list = []
for chosen_i in range(latend_dim):
#for chosen_i in chosen_i_list:
    for i in range(10):
        for k in range(10):
            diff = (k-5)*diff_unit
            z0 = torch.zeros(latend_dim)
            z0[chosen_i] = diff
            if(i == 0 and k == 0):
                z = z0[None, :]
            else:
                z = torch.cat((z, z0[None, :]), 0)
    model.eval()
    with torch.no_grad():
        max_diff = -1
        c = [i for i in range(10) for _ in range(10)]
        c = one_hot_encode(c, num_classes = num_class)
        c = torch.tensor(c).to(device)
        noise = z
        generated_imgs = model.decoder(noise, c)
        t = 2
        for i in range(10):
            if(i == 0):
                x = generated_imgs[i*10:(i+1)*10-1]
            else:
                x = torch.cat((x, generated_imgs[i*10:(i+1)*10-1]), 0)
            diff_x = torch.abs((generated_imgs[i*10+t]-generated_imgs[(i+1)*10-1-t])[None, :])
            x = torch.cat((x, diff_x), 0)
            if(torch.sum(diff_x).item() > max_diff):
                max_diff = torch.sum(diff_x).item()
        print(chosen_i, max_diff)
        if(max_diff >= 5):
            effect_num += 1
            chosen_i_list.append(chosen_i)
        #resized_image = torchvision.transforms.Resize((50, 50))(x)
        #save_image(resized_image, f'./pictures2/latend_{chosen_i}.png', nrow=10)
print(effect_num)

0 128.77056884765625
1 0.4584888219833374
2 46.97119903564453
3 0.45331287384033203
4 0.5370331406593323
5 0.5602951645851135
6 0.5077562928199768
7 0.4779769778251648
8 0.6215984225273132
9 0.4455057382583618
10 0.520866870880127
11 19.317182540893555
12 0.39344024658203125
13 0.46297192573547363
14 0.4824103116989136
15 0.5133993625640869
16 0.45491018891334534
17 0.42163193225860596
18 0.5148807764053345
19 0.41963309049606323
20 0.49334728717803955
21 0.5261502265930176
22 20.828359603881836
23 0.44930776953697205
24 0.4930216670036316
25 0.45483753085136414
26 47.7014274597168
27 0.4196412265300751
28 0.5666756629943848
29 0.48159247636795044
30 0.4435326159000397
31 0.5242775082588196
32 0.5265655517578125
33 0.5647923946380615
34 0.438334584236145
35 0.474765807390213
36 0.45799362659454346
37 52.24833679199219
38 0.46859073638916016
39 0.5341638922691345
40 0.42603036761283875
41 0.4484359323978424
42 0.47025173902511597
43 0.5253233909606934
44 0.5021467208862305
45 0.55219435

In [32]:
device = 'cuda:0'
compare_model = torch.load('./model.pth', map_location=device, weights_only=False)
model = model.to(device)
base_dim = 20

In [33]:
def get_correlation_matrix(X, Y):
    combined_matrix = np.hstack((X, Y))
    correlation_matrix = np.corrcoef(combined_matrix, rowvar=False)[:X.shape[1], X.shape[1]:]
    return correlation_matrix

def get_strong_MCC(X, Y):
    train_length = int(len(X)/2)
    train_X = X[:train_length]
    train_Y = Y[:train_length]
    test_X = X[train_length:]
    test_Y = Y[train_length:]
    correlation_matrix = get_correlation_matrix(train_X, train_Y)
    row_ind, col_ind = linear_sum_assignment(-np.abs(correlation_matrix))
    P_fitted = np.zeros((X.shape[1], X.shape[1]), dtype=int)
    P_fitted[col_ind, row_ind] = 1
    return np.mean(np.diag(np.abs(get_correlation_matrix(test_X @ P_fitted.T, test_Y))))

def get_weak_MCC(X, Y):
    train_length = int(len(X)/2)
    train_X = X[:train_length]
    train_Y = Y[:train_length]
    test_X = X[train_length:]
    test_Y = Y[train_length:]
    cca = CCA(n_components=min(train_X.shape[1], train_Y.shape[1]))
    cca.fit(train_X, train_Y)
    canonical_coefficients = cca.coef_
    return np.mean(np.diag(np.abs(get_correlation_matrix(test_X @ canonical_coefficients.T, test_Y))))

In [34]:
from sklearn.cross_decomposition import CCA
from scipy.optimize import linear_sum_assignment

mean_strong_MCC = 0
mean_weak_MCC = 0

for class_type in range(10):
    
    # get data
    model.eval()
    compare_model.eval()
    with torch.no_grad():
        num = 5000
        t = 100
        num0 = int(num/t)
        for i in range(t):
            # get x from z
            noise0 = torch.randn(num0, base_dim).to(device)
            c = [class_type for _ in range(num0)]
            c = one_hot_encode(c, num_classes = num_class)
            c = torch.tensor(c).to(device)
            x0 = compare_model.decoder(noise0, c)
            # get z from x
            _, z0, _ = model(x0, c)
            compare_noise0 = z0[:, chosen_i_list]
            if(i == 0):
                noise = noise0
                compare_noise = compare_noise0
            else:
                noise = torch.cat((noise, noise0), 0)
                compare_noise = torch.cat((compare_noise, compare_noise0), 0)
           
    noise = np.array(noise.to("cpu"))
    compare_noise = np.array(compare_noise.to("cpu"))
    
    strong_MCC = get_strong_MCC(noise, compare_noise)
    weak_MCC = get_weak_MCC(noise, compare_noise)
    print(f"c = {class_type}, strong MCC = {strong_MCC}, weak MCC = {weak_MCC}")
    
    mean_strong_MCC += strong_MCC
    mean_weak_MCC += weak_MCC

print(f"mean strong MCC = {mean_strong_MCC/10}, mean weak MCC = {mean_weak_MCC/10}")

  noise = np.array(noise.to("cpu"))
  compare_noise = np.array(compare_noise.to("cpu"))


c = 0, strong MCC = 0.5635043596363752, weak MCC = 0.917024842793122


  noise = np.array(noise.to("cpu"))
  compare_noise = np.array(compare_noise.to("cpu"))


c = 1, strong MCC = 0.5463857643114413, weak MCC = 0.8940892614855184


  noise = np.array(noise.to("cpu"))
  compare_noise = np.array(compare_noise.to("cpu"))


c = 2, strong MCC = 0.5571462513277916, weak MCC = 0.9071189715754086


  noise = np.array(noise.to("cpu"))
  compare_noise = np.array(compare_noise.to("cpu"))


c = 3, strong MCC = 0.5539298958646283, weak MCC = 0.9080429439731551


  noise = np.array(noise.to("cpu"))
  compare_noise = np.array(compare_noise.to("cpu"))


c = 4, strong MCC = 0.5546250398107467, weak MCC = 0.90415324064854


  noise = np.array(noise.to("cpu"))
  compare_noise = np.array(compare_noise.to("cpu"))


c = 5, strong MCC = 0.5574132372126517, weak MCC = 0.9112164685520586


  noise = np.array(noise.to("cpu"))
  compare_noise = np.array(compare_noise.to("cpu"))


c = 6, strong MCC = 0.5591608858994486, weak MCC = 0.9057998640128038


  noise = np.array(noise.to("cpu"))
  compare_noise = np.array(compare_noise.to("cpu"))


c = 7, strong MCC = 0.5607725073415694, weak MCC = 0.9166571685801301


  noise = np.array(noise.to("cpu"))
  compare_noise = np.array(compare_noise.to("cpu"))


c = 8, strong MCC = 0.5651033481374456, weak MCC = 0.9168227113041141


  noise = np.array(noise.to("cpu"))
  compare_noise = np.array(compare_noise.to("cpu"))


c = 9, strong MCC = 0.5591768409539432, weak MCC = 0.9223280535852671
mean strong MCC = 0.5577218130496042, mean weak MCC = 0.910325352651012
