In [None]:
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd 
import random 
import torch
import torchvision
from torchvision import transforms
from torch.utils.data import DataLoader,random_split
from torch import nn
import torch.nn.functional as F
import torch.optim as optim
from tqdm import tqdm
from sklearn.manifold import TSNE
import plotly.express as px
from vae import VariationalAutoencoder
from modules import train_epoch, test_epoch

!pip install nbformat 

In [None]:
# Prefer the GPU whenever PyTorch can see one; otherwise fall back to the CPU.
use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")
print(f'Selected device: {device}')

In [None]:
# Load MNIST, carve a validation split out of the training set, and build loaders.
DATA_ROOT = ""        # download directory ("" = current working directory, as before)
VAL_FRACTION = 0.2    # share of the training set held out for validation
SPLIT_SEED = 42       # makes the train/validation split reproducible

batch_size = 256

# ToTensor converts each PIL image into a float tensor scaled to [0, 1].
train_transform = transforms.Compose([
    transforms.ToTensor(),
])

test_transform = transforms.Compose([
    transforms.ToTensor(),
])

mnist_trainset = torchvision.datasets.MNIST(root=DATA_ROOT, train=True, download=True, transform=train_transform)
mnist_testset = torchvision.datasets.MNIST(root=DATA_ROOT, train=False, download=True, transform=test_transform)

m = len(mnist_trainset)
# Compute the train size as the remainder so the two lengths always sum to m.
# int(m - 0.2*m) + int(0.2*m) can come out one short because of float rounding,
# which makes random_split raise.
n_val = int(m * VAL_FRACTION)
n_train = m - n_val
mnist_trainset, val_data = random_split(
    mnist_trainset, [n_train, n_val], generator=torch.Generator().manual_seed(SPLIT_SEED)
)

# Reshuffle training batches every epoch. Validation order does not affect the loss.
train_loader = DataLoader(mnist_trainset, batch_size=batch_size, shuffle=True)
valid_loader = DataLoader(val_data, batch_size=batch_size)
test_loader = DataLoader(mnist_testset, batch_size=batch_size, shuffle=True)

In [None]:
# Hyper-parameters: latent size of the VAE and the Adam learning rate.
d, lr = 4, 1e-3

vae = VariationalAutoencoder(latent_dims=d, device=device)
vae.to(device)

# NOTE(review): this rebinds `optim`, shadowing the `torch.optim as optim` module
# alias from the imports cell. The training loop passes `optim` to train_epoch,
# so the name is kept. Consider renaming both to `optimizer`.
optim = torch.optim.Adam(params=vae.parameters(), lr=lr, weight_decay=1e-5)

In [None]:
def plot_ae_outputs(encoder, decoder, n=10, dataset=None, target_device=None):
    """Plot the first image of each digit 0..n-1 above its reconstruction.

    Each image is passed through ``decoder(encoder(img))`` with gradients
    disabled. Both modules are put in eval mode for the forward passes and are
    restored to their previous train/eval mode afterwards, so the helper can be
    called between training epochs without altering training behaviour.

    Args:
        encoder: module whose output is fed straight into ``decoder``.
        decoder: module whose output squeezes to a 2-D image
            (assumed, as in the original, e.g. (1, 1, 28, 28) for MNIST).
        n: number of digit classes to show, one column per digit 0..n-1.
        dataset: object with a ``targets`` attribute, indexed as
            ``dataset[i] -> (image_tensor, label)``. Defaults to the global
            ``mnist_testset``.
        target_device: device the images are moved to. Defaults to the global
            ``device``.

    Raises:
        ValueError: if ``n < 1`` or some digit in 0..n-1 never appears in
            ``dataset.targets``.
    """
    if dataset is None:
        dataset = mnist_testset
    if target_device is None:
        target_device = device
    if n < 1:
        raise ValueError(f"n must be >= 1, got {n}")

    # Index of the first occurrence of each requested digit.
    targets = np.asarray(dataset.targets)
    first_idx = {}
    for digit in range(n):
        hits = np.flatnonzero(targets == digit)
        if hits.size == 0:
            raise ValueError(f"digit {digit} does not occur in dataset.targets; reduce n")
        first_idx[digit] = int(hits[0])

    # Remember the current modes so the caller's training state is not clobbered.
    enc_was_training = encoder.training
    dec_was_training = decoder.training
    encoder.eval()
    decoder.eval()

    fig, axes = plt.subplots(2, n, figsize=(16, 4.5), squeeze=False)
    try:
        with torch.no_grad():
            for digit in range(n):
                img = dataset[first_idx[digit]][0].unsqueeze(0).to(target_device)
                rec_img = decoder(encoder(img))
                # Row 0: original, row 1: reconstruction.
                for row, picture in enumerate((img, rec_img)):
                    ax = axes[row, digit]
                    ax.imshow(picture.cpu().squeeze().numpy(), cmap='gist_gray')
                    ax.get_xaxis().set_visible(False)
                    ax.get_yaxis().set_visible(False)
        axes[0, n // 2].set_title('Original images')
        axes[1, n // 2].set_title('Reconstructed images')
    finally:
        encoder.train(enc_was_training)
        decoder.train(dec_was_training)
    plt.show()

In [None]:
# Train the VAE, reporting train/validation loss after every epoch.
num_epochs = 10
plot_every = 0   # set to k > 0 to show reconstructions every k epochs
history = []     # one dict per epoch so the loss curves can be plotted afterwards

for epoch in range(num_epochs):
    train_loss = train_epoch(vae, device, train_loader, optim)
    val_loss = test_epoch(vae, device, valid_loader)
    history.append({"epoch": epoch + 1, "train_loss": train_loss, "val_loss": val_loss})
    print('\n EPOCH {}/{} \t train loss {:.3f} \t val loss {:.3f}'.format(epoch + 1, num_epochs,train_loss,val_loss))
    if plot_every and (epoch + 1) % plot_every == 0:
        plot_ae_outputs(vae.encoder, vae.decoder, n=10)

In [None]:
# Persist only the learned weights (state dict), not the pickled module object.
model_path = "model.pt"
torch.save(obj=vae.state_dict(), f=model_path)

In [None]:
def _encode_image(image):
    """Encode a single image tensor with ``vae.encoder`` and return the output as a flat numpy array."""
    # NOTE(review): if the encoder samples z rather than returning the mean,
    # these codes are stochastic between runs. Confirm in vae.py.
    with torch.no_grad():
        encoded = vae.encoder(image.unsqueeze(0).to(device))
    return encoded.flatten().cpu().numpy()


# Split the test set: 20% keeps its labels, 80% is marked unlabeled (-1),
# the marker sklearn's semi-supervised estimators use for unknown targets.
n_test = len(mnist_testset)
n_labeled = int(n_test * 0.2)
mnist_testset_label, mnist_testset_unlabel = random_split(
    mnist_testset, [n_labeled, n_test - n_labeled], torch.Generator().manual_seed(42)
)
print("Len of labeled: ", len(mnist_testset_label), " Len of unlabeled: ", len(mnist_testset_unlabel))

encoded_samples = []
true_labels = []   # ground truth for every row, labeled or not (used for scoring)
imgs = []          # raw image tensors, row-aligned with encoded_samples

vae.eval()
for subset, reveal_label in ((mnist_testset_label, True), (mnist_testset_unlabel, False)):
    for image, target in tqdm(subset):
        imgs.append({"image": image})
        true_labels.append(target)
        encoded_sample = {f"Enc. Variable {i}": enc for i, enc in enumerate(_encode_image(image))}
        encoded_sample['label'] = target if reveal_label else -1
        encoded_samples.append(encoded_sample)

encoded_samples = pd.DataFrame(encoded_samples)

# px.scatter(encoded_samples, x='Enc. Variable 0', y='Enc. Variable 1', color=encoded_samples.label.astype(str), opacity=0.7)

In [None]:
# Project the latent codes to 2-D. t-SNE is stochastic, so fix random_state to make
# the embedding (and the label-spreading accuracy computed from it) reproducible.
tsne = TSNE(n_components=2, random_state=42)
tsne_results = tsne.fit_transform(encoded_samples.drop(columns=['label']))

# fig = px.scatter(tsne_results, x=0, y=1, color=encoded_samples.label.astype(str),labels={'0': 'tsne-2d-one', '1': 'tsne-2d-two'},
#                  color_discrete_map= {'-1': "black"})
# fig.show()

In [None]:
from sklearn.semi_supervised import LabelSpreading

# Spread the known labels to the rows marked -1, using the 2-D t-SNE
# coordinates as features.
# NOTE(review): the default kernel is 'rbf' (gamma=20). On t-SNE-scale
# coordinates that may give near-zero affinities between points; consider
# kernel='knn'. Verify before relying on the accuracy.
seed_labels = encoded_samples["label"]
label_prop_model = LabelSpreading()
label_prop_model.fit(X=tsne_results, y=seed_labels)
label_prop_model.get_params()

In [None]:
# Predicted digit for every t-SNE point, including the originally unlabeled ones.
labels = label_prop_model.predict(X=tsne_results)

In [None]:
# fig = px.scatter(tsne_results, x=0, y=1, color=labels.astype(str),labels={'0': 'tsne-2d-one', '1': 'tsne-2d-two'},
#                  color_discrete_map= {'-1': "black"})
# fig.show()

In [None]:
from sklearn.metrics import accuracy_score

# Scoring every row includes the 20% whose true labels were given to the model,
# which inflates the number. Also report accuracy on the unlabeled rows only.
true_arr = np.asarray(true_labels)
pred_arr = np.asarray(labels)
unlabeled_mask = encoded_samples["label"].to_numpy() == -1
pd.Series({
    "all samples": accuracy_score(true_arr, pred_arr),
    "unlabeled only": (
        accuracy_score(true_arr[unlabeled_mask], pred_arr[unlabeled_mask])
        if unlabeled_mask.any() else float("nan")
    ),
})

In [None]:
# Attach each image to its t-SNE point (rows are aligned by position).
# This cell rebinds `tsne_results` to a DataFrame. Recover the plain 2-D
# coordinates first so that re-running the cell does not add a second image column.
tsne_coords = (
    tsne_results if isinstance(tsne_results, np.ndarray)
    else tsne_results[[0, 1]].to_numpy()
)
assert len(tsne_coords) == len(imgs), "t-SNE rows and images are out of sync"
tsne_results = pd.concat([pd.DataFrame(tsne_coords), pd.DataFrame(imgs)], axis=1)

# fig = px.scatter(tsne_results, x=0, y=1, color=labels.astype(str),labels={'0': 'tsne-2d-one', '1': 'tsne-2d-two'},
#                  color_discrete_map= {'-1': "black"}, custom_data=["image"])
# fig.update_layout(clickmode='event+select')