In [25]:
import csv
import torch
import torch.nn as nn
import torchvision.transforms as transforms
import torchvision.models as models
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from torchvision import models
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE
from wilds import get_dataset
from wilds.common.data_loaders import get_train_loader, get_eval_loader
import numpy as np

In [54]:

# Step 1: Read the CSV file and build the list of (image path, label) pairs.
# NOTE(review): the condition below is inverted -- it selects 'cuda' exactly when
# CUDA is NOT available and 'cpu' when it is. It is intentionally left unchanged in
# this cell: flipping it is only safe together with moving `model` onto the same
# device before inference, otherwise batches and weights end up on different devices.
device = torch.device('cuda') if not torch.cuda.is_available() else torch.device('cpu')

DATA_ROOT = "./data/waterbirds_v1.0/"  # prefix prepended to every relative path from the CSV
MAX_IMAGES = 101                       # number of rows the original counter-based loop kept

image_paths = []
labels = []
# newline='' is what the csv module asks for so embedded newlines are parsed correctly.
with open('train.csv', 'r', newline='') as csvfile:
    reader = csv.reader(csvfile)
    next(reader, None)  # skip the header row; an empty file simply yields no samples
    for row in reader:
        if len(row) < 2:
            # Blank or truncated line: there is no path column to read.
            continue
        relative_path = row[1]  # the image path lives in the SECOND column
        # The label is the 3-digit prefix of the path,
        # e.g. "001.Black_footed_Albatross/..." -> 1.
        label = int(relative_path[:3])
        image_paths.append((DATA_ROOT + relative_path, label))
        if len(image_paths) >= MAX_IMAGES:
            break

In [60]:
image_paths

[('./data/waterbirds_v1.0/001.Black_footed_Albatross/Black_Footed_Albatross_0009_34.jpg',
  1),
 ('./data/waterbirds_v1.0/001.Black_footed_Albatross/Black_Footed_Albatross_0074_59.jpg',
  1),
 ('./data/waterbirds_v1.0/001.Black_footed_Albatross/Black_Footed_Albatross_0014_89.jpg',
  1),
 ('./data/waterbirds_v1.0/001.Black_footed_Albatross/Black_Footed_Albatross_0031_100.jpg',
  1),
 ('./data/waterbirds_v1.0/001.Black_footed_Albatross/Black_Footed_Albatross_0010_796097.jpg',
  1),
 ('./data/waterbirds_v1.0/001.Black_footed_Albatross/Black_Footed_Albatross_0023_796059.jpg',
  1),
 ('./data/waterbirds_v1.0/001.Black_footed_Albatross/Black_Footed_Albatross_0040_796066.jpg',
  1),
 ('./data/waterbirds_v1.0/001.Black_footed_Albatross/Black_Footed_Albatross_0089_796069.jpg',
  1),
 ('./data/waterbirds_v1.0/001.Black_footed_Albatross/Black_Footed_Albatross_0067_170.jpg',
  1),
 ('./data/waterbirds_v1.0/001.Black_footed_Albatross/Black_Footed_Albatross_0060_796076.jpg',
  1),
 ('./data/waterbir

In [47]:
class CustomDataset(Dataset):
    """Dataset backed by a sequence of ``(image_path, label)`` pairs.

    Each item is loaded from disk on access, converted to RGB, optionally passed
    through ``transform``, and returned together with its label and source path.
    """

    def __init__(self, data, transform=None):
        # `data` is indexed as data[idx] -> (image_path, label).
        self.transform = transform
        self.data = data

    def __len__(self):
        """Number of (path, label) records."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(image, label, image_path)`` for the record at ``idx``."""
        record = self.data[idx]
        image_path, label = record
        rgb_image = Image.open(image_path).convert('RGB')
        if not self.transform:
            # No transform configured: hand back the PIL image untouched.
            return rgb_image, label, image_path
        return self.transform(rgb_image), label, image_path

In [55]:
# Step 3: Build the backbone -- a pretrained ResNet-50 whose final fully connected
# layer is swapped for a fresh linear head with `num_classes` outputs.
# NOTE(review): newer torchvision releases prefer the `weights=` argument over
# `pretrained=True` -- confirm against the installed version before changing it.
num_classes = 2
model = models.resnet50(pretrained=True)
model.fc = nn.Linear(in_features=model.fc.in_features, out_features=num_classes)




In [56]:

# Preprocessing applied to every image: resize to 448x448, convert to a tensor,
# then normalise each channel with the mean/std commonly used for
# ImageNet-pretrained torchvision models.
resize_step = transforms.Resize((448, 448))
to_tensor_step = transforms.ToTensor()
normalize_step = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

transform = transforms.Compose([resize_step, to_tensor_step, normalize_step])



In [57]:
# Step 4: Wrap the (path, label) list in the dataset and iterate it in fixed order,
# 16 images per batch, so row i of the extracted features matches image_paths[i].
custom_dataset = CustomDataset(data=image_paths, transform=transform)
data_loader = DataLoader(
    dataset=custom_dataset,
    batch_size=16,
    shuffle=False,
)



In [58]:
# Step 5: Use a forward hook on `avgpool` to collect one feature vector per image
feature_vectors = []
labels = []

# a dict to store the activations
activation = {}
def getActivation(name):
    """Return a forward hook that stores the hooked module's detached output in
    ``activation[name]`` (overwritten on every forward pass)."""
    # the hook signature
    def hook(module, input, output):
        activation[name] = output.detach()
    return hook

# Select the device here and move the model onto it, so that the weights and every
# batch are guaranteed to live on the same device (GPU when available, else CPU).
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)
model.eval()

# register forward hooks on the layers of choice
hook_handle = model.avgpool.register_forward_hook(getActivation('avgpool'))
try:
    with torch.no_grad():
        for batch_images, batch_labels, _batch_paths in data_loader:
            # Only the side effect of the forward pass (the hook firing) is needed.
            model(batch_images.to(device))
            # Labels are never used on the device, so they stay on the CPU.
            labels.extend(batch_labels.numpy())

            # Flatten (N, C, 1, 1) -> (N, C) and store the batch of feature vectors
            pooled = activation['avgpool']
            feature_vectors.append(torch.flatten(pooled, start_dim=1).cpu().numpy())
finally:
    # Always detach the hook, even if the loop fails, so re-running this cell
    # does not leave stale hooks registered on the model.
    hook_handle.remove()

# Convert the list of per-batch arrays to a single (num_images, num_features) array
feature_vectors = np.concatenate(feature_vectors, axis=0)



In [62]:
from matplotlib.backend_bases import PickEvent

def on_pick(event):
    """Matplotlib pick-event callback: open the image behind the clicked point.

    ``event.ind`` holds the indices of the scatter points under the cursor; the
    first one is used to look up the matching entry of ``image_paths``.
    """
    # Nothing was actually hit -- nothing to show.
    if len(event.ind) == 0:
        return

    # Get the index of the picked point
    index = event.ind[0]

    # Each entry of image_paths is a (path, label) tuple; only the path can be opened.
    image_path, _label = image_paths[index]

    # Load and display the image for the selected point
    image = Image.open(image_path)
    image.show()

In [66]:

# Enable the notebook backend for interactive plots in Jupyter Lab
%matplotlib widget

# Apply t-SNE to the feature vectors to obtain 2D embeddings
tsne = TSNE(n_components=2, perplexity=30, learning_rate=200, random_state=42)
embeddings_2d = tsne.fit_transform(feature_vectors)

# Create a scatter plot of the 2D embeddings
# Assuming you have 5 classes (0 to 4)
num_classes = 5

# Use the 'tab10' colormap for 5 classes
colors = plt.cm.tab10.colors

plt.figure(figsize=(10, 8))
scatter = plt.scatter(embeddings_2d[:, 0], embeddings_2d[:, 1], c=labels, cmap='tab10', s=50, picker=True)
plt.colorbar(scatter, ticks=range(num_classes))
plt.title('t-SNE Clustering of Waterbird Images')
plt.xlabel('t-SNE Dimension 1')
plt.ylabel('t-SNE Dimension 2')

# Connect the pick event to the on_pick function
plt.gcf().canvas.mpl_connect('pick_event', on_pick)

plt.show()

ModuleNotFoundError: No module named 'ipympl'

In [65]:
embeddings_2d

array([[ 5.8830816e-01, -1.0662376e+00],
       [ 2.2068832e+00, -6.1704642e-01],
       [-3.9335397e-01,  2.4787757e+00],
       [ 3.4450307e+00, -1.5581744e+00],
       [-1.7193385e+00, -1.3585700e+00],
       [ 7.2822821e-01,  4.0388465e+00],
       [-7.0312536e-01,  3.6477158e+00],
       [ 4.1002402e+00,  7.9619505e-02],
       [-5.8757186e-01,  3.6768107e+00],
       [-1.9244497e+00,  1.2159057e+00],
       [-6.8993074e-01, -1.4854709e+00],
       [ 2.6903221e-01,  3.8569574e+00],
       [ 1.7391416e+00,  2.7635844e+00],
       [-5.3442464e+00,  2.1529372e+00],
       [ 1.3615773e+00,  3.3456284e-01],
       [-2.6908371e-01, -2.9079220e+00],
       [ 4.8582900e-01, -5.1400822e-01],
       [ 1.7898320e+00, -8.4523863e-01],
       [-3.4683785e+00,  4.5592375e+00],
       [ 1.0548605e+00,  4.2639432e+00],
       [ 3.3009071e+00,  2.1652486e+00],
       [-3.1946776e+00,  3.0727432e+00],
       [ 1.1345544e+00, -1.5015808e-01],
       [-2.3635225e+00,  1.3850421e+00],
       [ 2.59405