In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms
import numpy as np
import tqdm
from sklearn.manifold import TSNE

In [None]:
%matplotlib inline
import matplotlib.pyplot as plt

In [None]:
class Net(nn.Module):
    """Small CNN that maps a 1-channel image batch to unit-length 128-d embeddings.

    Two 3x3 convolutions (stride 1, no padding) are followed by 2x2 max
    pooling, flattening and one linear projection. ``fc1`` expects 9216
    input features (64 channels x 12 x 12), which is consistent with a
    28x28 input image.
    """

    def __init__(self):
        super().__init__()
        # Attribute names and creation order are kept so saved state dicts
        # (conv1.*, conv2.*, fc1.*) continue to load.
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=32, kernel_size=3, stride=1)
        self.conv2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, stride=1)
        self.fc1 = nn.Linear(in_features=9216, out_features=128)

    def forward(self, x):
        """Return the L2-normalised embedding for each item in batch ``x``."""
        hidden = F.relu(self.conv1(x))
        # No activation is applied between conv2 and the pooling step.
        pooled = F.max_pool2d(self.conv2(hidden), 2)
        embedding = self.fc1(torch.flatten(pooled, 1))
        # F.normalize defaults to p=2 along dim=1, i.e. per-sample unit norm.
        return F.normalize(embedding)

In [None]:
# Preprocessing pipeline: convert each image to a tensor, then standardise it
# with mean 0.1307 and std 0.3081 on its single channel.
# NOTE(review): these constants are presumably the MNIST mean/std used at
# training time -- confirm against the training script.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.1307,), std=(0.3081,)),
    ]
)

In [None]:
# MNIST split selected by train=False, stored under 'data' and downloaded if
# missing; `transform` is applied to every image on access.
mnist = datasets.MNIST(
    root='data',
    train=False,
    download=True,
    transform=transform,
)

In [None]:
# The notebook only runs inference, so switch the network to eval mode right
# away. eval() returns the module itself, so `model` is still the Net instance.
model = Net().eval()

In [None]:
# Read the trained weights onto the CPU, then copy them into `model`.
# NOTE(review): torch.load unpickles the file, so only load checkpoints from a
# trusted source. Recent PyTorch versions accept weights_only=True -- confirm
# the installed version before enabling it.
state_dict = torch.load('result/model.pth', map_location=torch.device('cpu'))
model.load_state_dict(state_dict)

In [None]:
# Take the image (item 0 of the (image, label) pair) of the first sample and
# prepend a batch dimension of size 1 so it can be fed to the model.
first_image = mnist[0][0]
x = first_image.view(1, *first_image.shape)

In [None]:
# Embed the single sample. Gradients are never used in this notebook, so skip
# building the autograd graph for the forward pass.
with torch.no_grad():
    feat = model(x)

In [None]:
# Embed every sample in `mnist`, one at a time. `data` collects one 1-D feature
# array per sample and `targets` the matching labels, in dataset order.
data = []
targets = []
# Inference only: no_grad avoids recording an autograd graph for each of the
# forward passes, and lets the result be converted to NumPy directly.
with torch.no_grad():
    for image, target in tqdm.tqdm(mnist):
        targets.append(target)
        x = image.unsqueeze(0)  # add the batch dimension the model expects
        feat = model(x)
        data.append(feat.numpy()[0])  # drop the batch dimension again

In [None]:
# Project the collected embeddings to 2-D with t-SNE; random_state is fixed so
# repeated runs use the same seed.
tsne = TSNE(n_components=2, random_state=0)
ret = tsne.fit_transform(data)

In [None]:
# Scatter plot of the 2-D t-SNE coordinates, coloured by digit label.
fig, ax = plt.subplots(figsize=(8, 6))
# Labels are discrete classes, so use a qualitative colormap; vmin/vmax are
# set so that each integer label 0..9 falls in the middle of its own colour bin.
points = ax.scatter(
    ret[:, 0],
    ret[:, 1],
    c=targets,
    cmap='tab10',
    vmin=-0.5,
    vmax=9.5,
    s=5,
)
cbar = fig.colorbar(points, ax=ax, ticks=range(10))
cbar.set_label('digit')
ax.set_title('t-SNE of MNIST embeddings')
ax.set_xlabel('t-SNE dimension 1')
ax.set_ylabel('t-SNE dimension 2')
plt.show()