In [7]:
import os
from pathlib import Path

import torch.nn as nn
import torch.optim as optim
import torch
from torch.utils.data import TensorDataset, DataLoader, SubsetRandomSampler
import numpy as np

In [5]:
# Root of the semrep project checkout. Set the SEMREP_ROOT environment variable
# to point elsewhere (e.g. a mounted Google Drive) instead of editing this cell.
project_root = Path(os.environ.get("SEMREP_ROOT", "/home/ubuntu/semrep"))
# Directory holding one `.npz` embedding archive per encoder.
embeddings_path = project_root / "embeddings" / "encoders"

In [8]:
# Collect every `.npz` embedding archive in the encoder directory.
# NOTE: glob order is filesystem-dependent; don't rely on positional indexing.
dataset_paths = [npz_file for npz_file in embeddings_path.glob("*.npz")]

dataset_paths

[PosixPath('/home/ubuntu/semrep/embeddings/encoders/ResNet18.npz'),
 PosixPath('/home/ubuntu/semrep/embeddings/encoders/dinov2.npz'),
 PosixPath('/home/ubuntu/semrep/embeddings/encoders/ResNet50.npz'),
 PosixPath('/home/ubuntu/semrep/embeddings/encoders/VIP.npz'),
 PosixPath('/home/ubuntu/semrep/embeddings/encoders/ResNet34.npz'),
 PosixPath('/home/ubuntu/semrep/embeddings/encoders/MVP.npz')]

In [11]:
# Load one encoder's embeddings by name. `dataset_paths[0]` depends on glob
# order, which is filesystem-dependent, so it may pick a different file on
# another machine.
encoder_name = "ResNet18"
data = np.load(embeddings_path / f"{encoder_name}.npz")

data.keys()

KeysView(NpzFile '/home/ubuntu/semrep/embeddings/encoders/ResNet18.npz' with keys: embeddings, labels, train_flag, dataset_flag)

In [None]:
# Build the domain-probe dataset from the loaded `.npz` archive (`data`).
# Cast to float32: numpy arrays are often float64, which nn.Linear's float32
# weights reject, and BCEWithLogitsLoss requires floating-point targets.
embeddings = torch.from_numpy(data["embeddings"]).float()
domain_labels = torch.from_numpy(data["dataset_flag"]).float()
train_flag = data["train_flag"]  # Binary array indicating train (1) or val (0)

# The probe below is a single-logit binary classifier, so the flag must be 0/1.
assert set(np.unique(data["dataset_flag"]).tolist()) <= {0, 1}, "dataset_flag must be binary"
assert len(embeddings) == len(domain_labels) == len(train_flag), "array lengths differ"

# Create the full dataset
full_dataset = TensorDataset(embeddings, domain_labels)

# Create indices for train and validation splits
train_indices = np.where(train_flag == 1)[0]
val_indices = np.where(train_flag == 0)[0]
# Every sample must land in exactly one split; otherwise train_flag holds
# values other than 0/1 and those samples would be silently dropped.
assert len(train_indices) + len(val_indices) == len(full_dataset), "train_flag must be 0/1"

# Create samplers for train and validation splits
train_sampler = SubsetRandomSampler(train_indices)
val_sampler = SubsetRandomSampler(val_indices)

# Create DataLoaders
batch_size = 200
train_loader = DataLoader(full_dataset, batch_size=batch_size, sampler=train_sampler)
val_loader = DataLoader(full_dataset, batch_size=batch_size, sampler=val_sampler)

# Print some information about the splits
print(f"Total samples: {len(full_dataset)}")
print(f"Training samples: {len(train_indices)}")
print(f"Validation samples: {len(val_indices)}")

# If you need separate datasets for any reason, you can create them like this:
train_dataset = torch.utils.data.Subset(full_dataset, train_indices)
val_dataset = torch.utils.data.Subset(full_dataset, val_indices)


class LinearProbe(nn.Module):
    """A single affine layer mapping each embedding to `output_dim` raw scores.

    No activation is applied, so the outputs are unnormalised logits; pair it
    with a logits-based loss such as ``nn.BCEWithLogitsLoss``.
    """

    def __init__(self, input_dim, output_dim):
        super().__init__()
        # The attribute name `linear` determines the state_dict keys
        # ("linear.weight", "linear.bias"); keep it stable for checkpoints.
        self.linear = nn.Linear(in_features=input_dim, out_features=output_dim)

    def forward(self, x):
        projected = self.linear(x)
        return projected


# Seed torch's global RNG so probe initialisation (and SubsetRandomSampler
# shuffling, which draws from the default generator) is reproducible.
SEED = 0
torch.manual_seed(SEED)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Set input dimension from embeddings (second axis of the embedding matrix).
input_dim = embeddings.shape[1]

LEARNING_RATE = 1e-3

# One output logit: BCEWithLogitsLoss applies the sigmoid itself, so the probe
# outputs raw scores.
domain_probe = LinearProbe(input_dim, 1).to(device)
domain_criterion = nn.BCEWithLogitsLoss()  # Binary classification loss
domain_optimizer = optim.Adam(domain_probe.parameters(), lr=LEARNING_RATE)


def train_probe(probe, criterion, optimizer, train_loader, val_loader, epochs, device=None):
    """Train a single-logit binary probe and print per-epoch metrics.

    Each epoch does one optimisation pass over ``train_loader`` followed by a
    no-grad pass over ``val_loader``, then prints the mean per-batch train
    loss, mean per-batch validation loss and validation accuracy.

    Args:
        probe: Module mapping a batch of features to logits of shape ``(B, 1)``.
        criterion: Loss called as ``criterion(logits, targets)`` with both of
            shape ``(B,)``, e.g. ``nn.BCEWithLogitsLoss``.
        optimizer: Optimizer over ``probe``'s parameters.
        train_loader: Sized iterable of ``(features, labels)`` training batches.
        val_loader: Sized iterable of ``(features, labels)`` validation batches;
            labels are expected to be 0/1 for the accuracy computation.
        epochs: Number of passes over ``train_loader``.
        device: Device to move batches to. Defaults to the device of the
            probe's parameters.

    Returns:
        None. Metrics are printed, one line per epoch; a metric whose loader
        is empty is reported as ``nan``.
    """
    first_param = next(probe.parameters())
    if device is None:
        device = first_param.device
    # Batches are cast to the probe's dtype: numpy-sourced features are often
    # float64, and BCEWithLogitsLoss rejects integer targets.
    dtype = first_param.dtype

    def _to_batch(features, labels):
        """Move one (features, labels) batch to `device` in the probe's dtype."""
        return (
            features.to(device=device, dtype=dtype),
            labels.to(device=device, dtype=dtype),
        )

    for epoch in range(epochs):
        # --- Training phase ---
        probe.train()
        train_loss = 0.0
        for features, labels in train_loader:
            features, labels = _to_batch(features, labels)
            optimizer.zero_grad()
            # squeeze(-1), not squeeze(): a final batch of size 1 must stay
            # shape (1,) to match its targets instead of collapsing to 0-d.
            logits = probe(features).squeeze(-1)
            loss = criterion(logits, labels)
            loss.backward()
            optimizer.step()
            train_loss += loss.item()

        # --- Validation phase ---
        probe.eval()
        val_loss = 0.0
        correct = 0
        total = 0
        with torch.no_grad():
            for features, labels in val_loader:
                features, labels = _to_batch(features, labels)
                logits = probe(features).squeeze(-1)
                val_loss += criterion(logits, labels).item()

                # logit > 0  <=>  sigmoid(logit) > 0.5: threshold the predicted
                # probability at 0.5 without computing the sigmoid.
                predicted = (logits > 0).to(labels.dtype)
                correct += (predicted == labels).sum().item()
                total += labels.size(0)

        # Report nan for an empty loader instead of raising ZeroDivisionError.
        num_train_batches = len(train_loader)
        num_val_batches = len(val_loader)
        mean_train_loss = train_loss / num_train_batches if num_train_batches else float("nan")
        mean_val_loss = val_loss / num_val_batches if num_val_batches else float("nan")
        val_accuracy = correct / total if total else float("nan")
        print(
            f"Epoch [{epoch+1}/{epochs}] | "
            f"Train Loss: {mean_train_loss:.4f} | "
            f"Val Loss: {mean_val_loss:.4f} | "
            f"Val Accuracy: {val_accuracy * 100:.2f}% "
        )


NUM_EPOCHS = 30

# Fit the domain probe on the train split, reporting val metrics every epoch.
train_probe(
    domain_probe,
    domain_criterion,
    domain_optimizer,
    train_loader,
    val_loader,
    epochs=NUM_EPOCHS,
)
# TODO: split train and val on the trajectory level