In [None]:
import torch
import numpy as np
from sklearn.manifold import TSNE
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
import utils
from voc_dataset import VOCDataset
from train_q2 import ResNet


# ---- Configuration ----------------------------------------------------------
# Checkpoint to visualise.
# NOTE(review): absolute, machine-specific path; consider making it relative
# to a configurable checkpoint directory.
MODEL_PATH = r"E:\courses\16824\hw1\checkpoints\q2_1e-4_0.3_6_50_aug_0.7929\checkpoint-model-epoch32.pth"
# Upper bound on the number of test images that are embedded with t-SNE.
NUM_IMAGES = 1000
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ---- Load model -------------------------------------------------------------
# SECURITY: weights_only=False unpickles the whole model object, which can
# execute arbitrary code. Only load checkpoints from a trusted source.
# map_location lets a checkpoint saved on a GPU load on a CPU-only machine,
# which the DEVICE fallback above is meant to support.
model = torch.load(MODEL_PATH, map_location=DEVICE, weights_only=False)
model.to(DEVICE)
model.eval()  # inference mode

# ---- Load data --------------------------------------------------------------
test_loader = utils.get_data_loader('voc', train=False, batch_size=32, split='test', inp_size=224)

# ---- Generate color map -----------------------------------------------------
# One colour per class, keyed by class index. `plt.cm.get_cmap` was removed in
# Matplotlib 3.9; `plt.get_cmap` takes the same (name, lut) arguments.
num_classes = len(VOCDataset.CLASS_NAMES)
colors = plt.get_cmap('tab20', num_classes).colors
CLASS_COLOR_MAP = {i: colors[i] for i in range(num_classes)}

In [None]:
# Define feature extraction hook
# `features` receives one array per forward pass of the hooked module.
features = []


def hook_fn(module, input, output):
    """Forward hook that stores the hooked module's output as a NumPy array.

    The output is flattened from dimension 1 onward, so the batch axis is
    always kept. A plain ``.squeeze()`` would also drop the batch axis when a
    batch holds a single image, and the later ``np.concatenate`` over
    ``features`` would then fail on mismatched ranks.

    Args:
        module: The module the hook is attached to (unused).
        input: The positional inputs passed to the module (unused).
        output: The tensor produced by the module's forward pass.
    """
    features.append(output.flatten(start_dim=1).detach().cpu().numpy())

# Re-running this cell must not stack a second hook on avgpool, otherwise
# every forward pass would append its features twice. Drop any hook left over
# from a previous run first; on a fresh kernel `handle` does not exist yet.
try:
    handle.remove()
except NameError:
    pass

# Capture the output of the backbone's average-pooling layer on every forward pass.
handle = model.resnet.avgpool.register_forward_hook(hook_fn)

In [None]:
# Run test batches through the model until at least NUM_IMAGES samples have
# been seen. The forward hook fills `features` as a side effect of each
# `model(data)` call; labels are gathered here in the same order.
features.clear()  # discard anything captured by earlier forward passes
all_labels = []
num_collected = 0

try:
    with torch.no_grad():
        for data, target, _ in test_loader:
            all_labels.append(target.numpy())
            num_collected += len(target)

            data = data.to(DEVICE)
            model(data)  # output is ignored; only the hook's capture is used

            # Count the samples actually seen rather than batches * batch_size.
            if num_collected >= NUM_IMAGES:
                break
finally:
    # Always detach the hook, even if the loop above raised.
    handle.remove()

if not features:
    raise RuntimeError(
        "No features were captured; re-run the hook registration cell before this one."
    )

# Collect to nd.array
all_features_np = np.concatenate(features, axis=0)[:NUM_IMAGES]
all_labels_np = np.concatenate(all_labels, axis=0)[:NUM_IMAGES]
assert len(all_features_np) == len(all_labels_np), "features and labels are misaligned"

In [None]:
# Project the collected features to 2-D with t-SNE.
TSNE_SEED = 0  # fixed seed so the embedding does not change between runs
# scikit-learn requires perplexity < n_samples; cap it for small subsets.
tsne_perplexity = min(40, len(all_features_np) - 1)
tsne = TSNE(
    n_components=2,
    verbose=1,
    perplexity=tsne_perplexity,
    max_iter=300,
    random_state=TSNE_SEED,
)
tsne_results = tsne.fit_transform(all_features_np)

# Give each sample one colour: the mean of the colours of every class marked
# 1 in its label row, or neutral grey when no class is marked.
class_palette = np.array(
    [CLASS_COLOR_MAP[i] for i in range(len(CLASS_COLOR_MAP))], dtype=float
)
# Size the grey fallback to the class colours (RGB or RGBA) so that every
# entry of plot_colors has the same length.
no_label_color = np.array([0.5, 0.5, 0.5, 1.0])[:class_palette.shape[1]]

plot_colors = []
for sample_labels in all_labels_np:
    active_class_indices = np.flatnonzero(sample_labels == 1)

    if active_class_indices.size == 0:
        plot_colors.append(no_label_color)
    else:
        plot_colors.append(class_palette[active_class_indices].mean(axis=0))

In [None]:
# Scatter the 2-D embedding, one point per sample, tinted with its blended class colour.
fig, ax = plt.subplots(figsize=(14, 10))
ax.scatter(tsne_results[:, 0], tsne_results[:, 1], c=plot_colors, alpha=0.7, s=15)

# One legend swatch per class, placed outside the axes on the right.
legend_patches = [
    mpatches.Patch(color=CLASS_COLOR_MAP[class_idx], label=name)
    for class_idx, name in enumerate(VOCDataset.CLASS_NAMES)
]
ax.legend(handles=legend_patches, bbox_to_anchor=(1.05, 1), loc='upper left', borderaxespad=0.)

ax.set_title('t-SNE Visualization of PASCAL VOC Features')
ax.set_xlabel('t-SNE Component 1')
ax.set_ylabel('t-SNE Component 2')
ax.grid(True)
fig.tight_layout()
plt.show()