In [57]:
import torch
import torchvision
import plotly.express as px
import numpy as np
import pandas as pd

from tqdm import tqdm
from torch.utils.data import random_split
from torchvision import transforms
from sklearn.semi_supervised import LabelSpreading
from sklearn.manifold import TSNE
from PIL import Image
from sklearn.model_selection import StratifiedShuffleSplit

from modules import np_image_to_base64
from vae import VariationalAutoencoder

In [58]:
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

model = VariationalAutoencoder(4, device)
model.load_state_dict(torch.load("model.pt"))
model.to(device)

mnist_testset = torchvision.datasets.MNIST(root="", train=False, download=True, transform=None)
test_transform = transforms.Compose([
    transforms.ToTensor(),
])


# Stratified Split

In [59]:
encoded_samples = []
true_labels = []
imgs = []

m=0.005

features = mnist_testset.data
labels = mnist_testset.targets

stratified_split = StratifiedShuffleSplit(n_splits=1, test_size=m, random_state=42)

train_indices, test_indices = next(stratified_split.split(features, labels))

unlabeled_features = features[train_indices]
unlabeled_labels = labels[train_indices]
labeled_features = features[test_indices]
labeled_labels = labels[test_indices]


unlabeled_features.transform = test_transform
labeled_features.transform = test_transform
print("Len of labeled: ", len(labeled_features), " Len of unlabeled: ", len(unlabeled_features))


Len of labeled:  50  Len of unlabeled:  9950


In [60]:
hist1 = px.histogram(unlabeled_labels, title="Unlabeled")
hist1.show()
hist2 = px.histogram(labeled_labels, title="Labeled")
hist2.show()

In [61]:
model.eval()
transform = transforms.ToTensor()
for index, sample in tqdm(enumerate(labeled_features)):
    img = sample.unsqueeze(0).to(device)
    imgs.append({"image": sample})
    label = labeled_labels[index]
    with torch.no_grad():
        encoded_img  = model.encoder(img.unsqueeze(0).float())
    encoded_img = encoded_img.flatten().cpu().numpy()
    encoded_sample = {f"Enc. Variable {i}": enc for i, enc in enumerate(encoded_img)}
    encoded_sample['label'] = label
    encoded_samples.append(encoded_sample)
    true_labels.append(label)

for index, sample in tqdm(enumerate(unlabeled_features)):
    img = sample.unsqueeze(0).to(device)
    imgs.append({"image": sample})
    true_labels.append(unlabeled_labels[index])
    label = -1
    with torch.no_grad():
        encoded_img  = model.encoder(img.unsqueeze(0).float())
    encoded_img = encoded_img.flatten().cpu().numpy()
    encoded_sample = {f"Enc. Variable {i}": enc for i, enc in enumerate(encoded_img)}
    encoded_sample['label'] = label
    encoded_samples.append(encoded_sample)

encoded_samples = pd.DataFrame(encoded_samples)

50it [00:00, 1133.36it/s]
9950it [00:07, 1392.76it/s]


In [62]:
tsne = TSNE(n_components=2)
tsne_results = tsne.fit_transform(encoded_samples.drop(['label'],axis=1))

In [63]:
label_prop_model = LabelSpreading()
label_prop_model.fit(tsne_results, encoded_samples["label"].astype("int"))

labels = label_prop_model.predict(tsne_results)


invalid value encountered in divide



In [64]:
from sklearn.metrics import accuracy_score
accuracy_score(true_labels, labels)

0.4544

In [65]:
fig = px.scatter(tsne_results, x=0, y=1, color=labels.astype(str),labels={'0': 'tsne-2d-one', '1': 'tsne-2d-two'},
                 color_discrete_map= {'-1': "black"})
fig.show()

# Random Split

In [66]:
encoded_samples = []
true_labels = []
imgs = []

mnist_testset = torchvision.datasets.MNIST(root="", train=False, download=True, transform=None)
test_transform = transforms.Compose([
    transforms.ToTensor(),
])
mnist_testset.transform = test_transform

m=len(mnist_testset)
mnist_testset_label, mnist_testset_unlabel = random_split(mnist_testset, [int(m*0.005), int(m*0.995)], torch.Generator().manual_seed(42))
print("Len of labeled: ", len(mnist_testset_label), " Len of unlabeled: ", len(mnist_testset_unlabel))

Len of labeled:  50  Len of unlabeled:  9950


In [67]:
labels_label = [label for _, label in mnist_testset_label]
labels_unlabel = [label for _, label in mnist_testset_unlabel]
hist1 = px.histogram(labels_unlabel, title="Unlabeled")
hist1.show()
hist2 = px.histogram(labels_label, title="Labeled")
hist2.show()

In [68]:
model.eval()
for sample in tqdm(mnist_testset_label):
    img = sample[0].unsqueeze(0).to(device)
    imgs.append({"image": sample[0]})
    label = sample[1]
    with torch.no_grad():
        encoded_img  = model.encoder(img)
    encoded_img = encoded_img.flatten().cpu().numpy()
    encoded_sample = {f"Enc. Variable {i}": enc for i, enc in enumerate(encoded_img)}
    encoded_sample['label'] = label
    encoded_samples.append(encoded_sample)
    true_labels.append(label)

for sample in tqdm(mnist_testset_unlabel):
    img = sample[0].unsqueeze(0).to(device)
    imgs.append({"image": sample[0]})
    true_labels.append(sample[1])
    label = -1
    with torch.no_grad():
        encoded_img  = model.encoder(img)
    encoded_img = encoded_img.flatten().cpu().numpy()
    encoded_sample = {f"Enc. Variable {i}": enc for i, enc in enumerate(encoded_img)}
    encoded_sample['label'] = label
    encoded_samples.append(encoded_sample)

encoded_samples = pd.DataFrame(encoded_samples)

100%|██████████| 50/50 [00:00<00:00, 771.95it/s]
100%|██████████| 9950/9950 [00:08<00:00, 1185.69it/s]


In [69]:
tsne = TSNE(n_components=2)
tsne_results = tsne.fit_transform(encoded_samples.drop(['label'],axis=1))

In [70]:
label_prop_model = LabelSpreading()
label_prop_model.fit(tsne_results, encoded_samples["label"])
labels = label_prop_model.predict(tsne_results)


invalid value encountered in divide



In [71]:
accuracy_score(true_labels, labels)

0.4846

In [72]:
fig = px.scatter(tsne_results, x=0, y=1, color=labels.astype(str),labels={'0': 'tsne-2d-one', '1': 'tsne-2d-two'},
                 color_discrete_map= {'-1': "black"})
fig.show()

# TSNE trick from https://arxiv.org/pdf/1712.09005.pdf

In [82]:
tsne = TSNE(n_components=2, perplexity=50, early_exaggeration=30, n_iter=2000, random_state=42)
tsne_results = tsne.fit_transform(encoded_samples.drop(['label'],axis=1))

In [83]:
true_labels_str = [str(label) for label in true_labels]
fig = px.scatter(tsne_results, x=0, y=1, color=true_labels_str,labels={'0': 'tsne-2d-one', '1': 'tsne-2d-two'},
                 color_discrete_map= {'-1': "black"})
fig.show()

# Clustering and selection

In [95]:
from sklearn.cluster import DBSCAN

In [121]:
encoded_samples = []
true_labels = []

mnist_testset = torchvision.datasets.MNIST(root="", train=False, download=True, transform=None)
test_transform = transforms.Compose([
    transforms.ToTensor(),
])
mnist_testset.transform = test_transform

In [122]:
model.eval()
for sample in tqdm(mnist_testset):
    img = sample[0].unsqueeze(0).to(device)
    label = sample[1]
    with torch.no_grad():
        encoded_img  = model.encoder(img)
    encoded_img = encoded_img.flatten().cpu().numpy()
    encoded_sample = {f"Enc. Variable {i}": enc for i, enc in enumerate(encoded_img)}
    encoded_samples.append(encoded_sample)
    true_labels.append(label)

encoded_samples = pd.DataFrame(encoded_samples)

100%|██████████| 10000/10000 [00:07<00:00, 1307.96it/s]


In [123]:
tsne = TSNE(n_components=2)
tsne_results = tsne.fit_transform(encoded_samples)

In [124]:
ms = DBSCAN(eps=3, min_samples=20).fit(tsne_results)
labels = ms.labels_
labels_unique = np.unique(labels)
print(len(labels_unique))

56


In [125]:
cluster_labels_str = [str(label) for label in labels]
fig = px.scatter(tsne_results, x=0, y=1, color=cluster_labels_str,labels={'0': 'tsne-2d-one', '1': 'tsne-2d-two'},
                 color_discrete_map= {'-1': "black"})
fig.show()

In [126]:
import random
from collections import defaultdict
label_indices = defaultdict(list)
for idx, label in enumerate(labels):
    label_indices[label].append(idx)
    
num_samples_per_label = 10

selected_indices = []
for label in label_indices.keys():
    selected_indices.extend(random.sample(label_indices[label], num_samples_per_label))
print(selected_indices)

[4865, 5180, 4453, 3103, 9393, 554, 987, 6253, 9723, 6269, 5703, 9804, 437, 1746, 8780, 4311, 2991, 6381, 309, 1798, 4931, 6101, 2221, 2406, 377, 8650, 1643, 9241, 9978, 963, 4698, 2057, 2609, 4464, 6068, 7721, 2170, 8013, 3447, 2039, 3332, 27, 6367, 1300, 906, 9432, 379, 9701, 185, 1222, 7360, 2074, 7320, 7515, 5017, 8150, 6507, 1543, 6738, 6249, 1724, 9492, 2380, 2394, 3287, 2493, 432, 4665, 3174, 1873, 2793, 1695, 7423, 1532, 8247, 426, 7731, 2615, 9252, 7902, 8343, 2932, 1843, 8657, 3805, 5136, 6309, 8611, 5428, 7255, 8491, 6110, 8488, 5678, 2901, 768, 1834, 196, 4990, 3065, 5862, 877, 53, 1896, 3157, 4116, 4808, 2534, 1466, 3677, 8534, 8452, 6200, 4182, 3076, 4683, 3981, 3924, 4319, 4213, 9312, 6613, 5898, 9138, 6387, 4565, 5690, 9461, 6128, 2002, 8232, 6611, 508, 869, 6651, 1670, 1087, 3022, 5078, 8878, 1920, 562, 8577, 4987, 1767, 4391, 1903, 1516, 1512, 3288, 2407, 9451, 4593, 8734, 3184, 8401, 917, 9169, 1186, 6243, 2287, 3455, 4676, 3070, 3503, 1012, 900, 3454, 3765, 3601, 52

In [127]:
selected = []
for idx, label in enumerate(labels):
    if idx in selected_indices:
        selected.append(str(1))
    else:
        selected.append(str(-1))

In [128]:
fig = px.scatter(tsne_results, x=0, y=1, color=selected,labels={'0': 'tsne-2d-one', '1': 'tsne-2d-two'},
                 color_discrete_map= {'-1': "black"})
fig.show()

In [136]:
labels = []

for idx, label in enumerate(true_labels):
    if idx in selected_indices:
        labels.append(label)
    else:
        labels.append(-1)

In [138]:
labels_str = [str(label) for label in labels]
fig = px.scatter(tsne_results, x=0, y=1, color=labels_str,labels={'0': 'tsne-2d-one', '1': 'tsne-2d-two'},
                 color_discrete_map= {'-1': "black"})
fig.show()

In [139]:
label_prop_model = LabelSpreading()
label_prop_model.fit(tsne_results, labels)
labels = label_prop_model.predict(tsne_results)


invalid value encountered in divide



In [141]:
labels_str = [str(label) for label in labels]
fig = px.scatter(tsne_results, x=0, y=1, color=labels_str,labels={'0': 'tsne-2d-one', '1': 'tsne-2d-two'},
                 color_discrete_map= {'-1': "black"})
fig.show()

In [142]:
accuracy_score(true_labels, labels)

0.7127

# Other