In [1]:
from typing import Dict, Iterable, Callable
from tqdm.notebook import tqdm
import torchvision.transforms.functional as T
from matplotlib.colors import ListedColormap
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE
from torch import nn
import torch
import collections
from collections import defaultdict
import torchvision
from torchvision import transforms
import numpy as np

In [2]:
class FeatureExtractor(nn.Module):
    def __init__(self, model, layer_names):
        super().__init__()
        self.model = model
        self.layer_names = layer_names
        self._features = defaultdict(list)

        layer_dict = dict([*self.model.named_modules()])
        for layer_name in layer_names:
            layer = layer_dict[layer_name]
            layer.register_forward_hook(self.save_outputs_hook(layer_name))

    def save_outputs_hook(self, layer_name):
        def fn(_, __, output):
            self._features[layer_name] = output
        return fn

    def forward(self, x):
        _ = self.model(x)
        return self._features

In [3]:
# loading model
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.fc1 = nn.Linear(28*28, 512)
        self.fc2 = nn.Linear(512, 10)

    def forward(self, x):
        x = x.view(-1, 28*28)
        x = torch.relu(self.fc1(x))
        x = self.fc2(x)
        return x
    
model = Net()

In [4]:
checkpoint = torch.load('fashion_mnist_cnn.pt')
model.load_state_dict(checkpoint)

<All keys matched successfully>

In [5]:
layer_dict = dict([*model.named_modules()])
layer_dict

{'': Net(
   (fc1): Linear(in_features=784, out_features=512, bias=True)
   (fc2): Linear(in_features=512, out_features=10, bias=True)
 ),
 'fc1': Linear(in_features=784, out_features=512, bias=True),
 'fc2': Linear(in_features=512, out_features=10, bias=True)}

In [6]:
feature_extractor = FeatureExtractor(model, layer_names=['fc1'])

In [7]:
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])
testset = torchvision.datasets.FashionMNIST(root='./data', train=False,
                                        download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=64,
                                            shuffle=False)

In [8]:
features = []
labels = []

In [9]:
for batch in tqdm(testloader, total=len(testloader)):
    out = feature_extractor(batch[0])['fc1']
    features.append(out.detach().numpy())
    labels.append([batch[1].numpy()])

  0%|          | 0/157 [00:00<?, ?it/s]

In [10]:
len(features), len(labels)  

(157, 157)

In [21]:
np.concatenate(features).shape

(10000, 512)

In [11]:
np.concatenate(features).shape, np.concatenate(labels).shape

ValueError: all the input array dimensions except for the concatenation axis must match exactly, but along dimension 1, the array at index 0 has size 64 and the array at index 156 has size 16

In [22]:
tnse = TSNE(n_components=3, perplexity=50, learning_rate=200, n_iter=500, random_state=0)

In [23]:
x_tsne = tnse.fit_transform(np.concatenate(features))