In [1]:
import torch
import torch.nn as nn
import torchvision.models as models
from torch.utils.data import DataLoader, Dataset
import torchvision.transforms as transforms
import numpy as np
from tqdm import tqdm

In [2]:
class CustomImageDataset(Dataset):
    def __init__(self, data, labels=None, transform=None):
        self.data = data
        self.labels = labels
        self.transform = transform

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        image = self.data[idx]
        label = self.labels[idx] if self.labels is not None else -1
        if self.transform:
            image = self.transform(image)
        return image, label


transform = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# Load ResNet-152 with weights, suppress warnings
device = 'cuda' if torch.cuda.is_available() else 'cpu'
feature_extractor = models.resnet152(weights=models.ResNet152_Weights.DEFAULT)
feature_extractor = nn.Sequential(*list(feature_extractor.children())[:-1])  # Remove final layer
feature_extractor.eval().to(device)  # Move model to device

Downloading: "https://download.pytorch.org/models/resnet152-f82ba261.pth" to /root/.cache/torch/hub/checkpoints/resnet152-f82ba261.pth
100%|██████████| 230M/230M [00:01<00:00, 170MB/s]


Sequential(
  (0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (2): ReLU(inplace=True)
  (3): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (4): Sequential(
    (0): Bottleneck(
      (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (downsample): Sequential(
        (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)


In [3]:
def extract_features(model, dataloader, device='cpu'):
    """Extracts features for images using the feature extractor."""
    features, labels = [], []
    model.to(device)
    with torch.no_grad():
        for images, lbls in dataloader:
            images = images.to(device)
            feats = model(images).squeeze()  # Extract features
            features.append(feats.cpu())  # Store features on CPU to save GPU memory
            labels.append(lbls)
    return torch.cat(features, dim=0), torch.cat(labels, dim=0)

In [4]:
def compute_prototypes(features, labels, num_classes=10):
    """Computes class prototypes as the mean of feature vectors for each class."""
    prototypes = []
    for cls in range(num_classes):
        class_features = features[labels == cls]
        prototype = class_features.mean(dim=0)
        prototypes.append(prototype)
    return torch.stack(prototypes)

In [5]:
def classify_with_prototypes(prototypes, query_features):
    """Classifies query features by finding the nearest prototype."""
    distances = torch.cdist(query_features, prototypes)  # Pairwise distances
    return distances.argmin(dim=1)  # Nearest prototype index

In [7]:
import zipfile

# Path to the zip file
zip_file_path = '/content/eval_data.zip'

# Destination folder to extract the files
extract_to = '/content/eval_data'

# Unzipping the file
with zipfile.ZipFile(zip_file_path, 'r') as zip_ref:
    zip_ref.extractall(extract_to)

print(f"Extracted all files to: {extract_to}")

Extracted all files to: /content/eval_data


In [8]:
def train_sequential_models():
    """Trains sequential models from D1 to D10 and evaluates on held-out datasets."""
    model_prototypes = []
    accuracies = []

    # Process datasets D1 to D10
    prototypes = None
    for i in tqdm(range(1, 11), desc="Processing Training Datasets"):
        data_path = f'train_data/train_data/{i}_train_data.tar.pth'
        dataset = torch.load(data_path)
        images, labels = dataset['data'], dataset.get('targets', None)

        # Feature extraction
        dataset = CustomImageDataset(images, labels, transform=transform)
        loader = DataLoader(dataset, batch_size=32, shuffle=False)
        features, lbls = extract_features(feature_extractor, loader, device)

        # Update prototypes
        if labels is not None:  # D1 is labeled
            prototypes = compute_prototypes(features, lbls)
        else:  # Unlabeled datasets
            predicted_labels = classify_with_prototypes(prototypes, features)
            prototypes = compute_prototypes(features, predicted_labels)

        model_prototypes.append(prototypes)

        # Evaluate on all held-out datasets up to current model
        model_accuracies = []
        for eval_idx in tqdm(range(1, i + 1), desc=f"Evaluating Model f{i}"):
            eval_data = torch.load(f'eval_data/eval_data/{eval_idx}_eval_data.tar.pth')
            eval_images, eval_labels = eval_data['data'], eval_data['targets']
            eval_dataset = CustomImageDataset(eval_images, eval_labels, transform=transform)
            eval_loader = DataLoader(eval_dataset, batch_size=32, shuffle=False)

            eval_features, eval_lbls = extract_features(feature_extractor, eval_loader, device)
            predicted_labels = classify_with_prototypes(prototypes, eval_features)

            accuracy = (predicted_labels == eval_lbls).float().mean().item()
            model_accuracies.append(accuracy)
            print(f"Model f{i} accuracy on D̂{eval_idx}: {accuracy * 100:.2f}%")

        accuracies.append(model_accuracies)

    # Save the prototypes for f10
    prototype_save_path = "prototypes_f10.pth"
    torch.save(prototypes, prototype_save_path)
    print(f"Prototypes for f10 saved as {prototype_save_path}")

    return accuracies

In [9]:
# Main execution
if __name__ == "__main__":
    accuracy_matrix = train_sequential_models()
    print("\nAccuracy Matrix:")
    for row in accuracy_matrix:
        print(row)

  dataset = torch.load(data_path)

  eval_data = torch.load(f'eval_data/eval_data/{eval_idx}_eval_data.tar.pth')

Evaluating Model f1: 100%|██████████| 1/1 [00:20<00:00, 20.66s/it]
Processing Training Datasets:  10%|█         | 1/10 [00:42<06:18, 42.02s/it]

Model f1 accuracy on D̂1: 89.88%



Evaluating Model f2:   0%|          | 0/2 [00:00<?, ?it/s][A
Evaluating Model f2:  50%|█████     | 1/2 [00:20<00:20, 20.33s/it][A

Model f2 accuracy on D̂1: 88.28%



Evaluating Model f2: 100%|██████████| 2/2 [00:40<00:00, 20.31s/it]
Processing Training Datasets:  20%|██        | 2/10 [01:42<07:05, 53.15s/it]

Model f2 accuracy on D̂2: 89.72%



Evaluating Model f3:   0%|          | 0/3 [00:00<?, ?it/s][A
Evaluating Model f3:  33%|███▎      | 1/3 [00:20<00:40, 20.40s/it][A

Model f3 accuracy on D̂1: 87.72%



Evaluating Model f3:  67%|██████▋   | 2/3 [00:40<00:20, 20.35s/it][A

Model f3 accuracy on D̂2: 89.40%



Evaluating Model f3: 100%|██████████| 3/3 [01:01<00:00, 20.35s/it]
Processing Training Datasets:  30%|███       | 3/10 [03:04<07:42, 66.08s/it]

Model f3 accuracy on D̂3: 88.28%



Evaluating Model f4:   0%|          | 0/4 [00:00<?, ?it/s][A
Evaluating Model f4:  25%|██▌       | 1/4 [00:20<01:00, 20.32s/it][A

Model f4 accuracy on D̂1: 87.32%



Evaluating Model f4:  50%|█████     | 2/4 [00:40<00:40, 20.31s/it][A

Model f4 accuracy on D̂2: 89.28%



Evaluating Model f4:  75%|███████▌  | 3/4 [01:01<00:20, 20.35s/it][A

Model f4 accuracy on D̂3: 87.96%



Evaluating Model f4: 100%|██████████| 4/4 [01:21<00:00, 20.35s/it]
Processing Training Datasets:  40%|████      | 4/10 [04:46<08:01, 80.21s/it]

Model f4 accuracy on D̂4: 88.24%



Evaluating Model f5:   0%|          | 0/5 [00:00<?, ?it/s][A
Evaluating Model f5:  20%|██        | 1/5 [00:20<01:21, 20.43s/it][A

Model f5 accuracy on D̂1: 87.36%



Evaluating Model f5:  40%|████      | 2/5 [00:40<01:01, 20.43s/it][A

Model f5 accuracy on D̂2: 89.28%



Evaluating Model f5:  60%|██████    | 3/5 [01:01<00:40, 20.34s/it][A

Model f5 accuracy on D̂3: 87.76%



Evaluating Model f5:  80%|████████  | 4/5 [01:21<00:20, 20.37s/it][A

Model f5 accuracy on D̂4: 87.96%



Evaluating Model f5: 100%|██████████| 5/5 [01:41<00:00, 20.39s/it]
Processing Training Datasets:  50%|█████     | 5/10 [06:48<07:56, 95.37s/it]

Model f5 accuracy on D̂5: 88.72%



Evaluating Model f6:   0%|          | 0/6 [00:00<?, ?it/s][A
Evaluating Model f6:  17%|█▋        | 1/6 [00:20<01:42, 20.43s/it][A

Model f6 accuracy on D̂1: 87.20%



Evaluating Model f6:  33%|███▎      | 2/6 [00:40<01:21, 20.45s/it][A

Model f6 accuracy on D̂2: 89.12%



Evaluating Model f6:  50%|█████     | 3/6 [01:01<01:01, 20.37s/it][A

Model f6 accuracy on D̂3: 87.56%



Evaluating Model f6:  67%|██████▋   | 4/6 [01:21<00:40, 20.40s/it][A

Model f6 accuracy on D̂4: 87.76%



Evaluating Model f6:  83%|████████▎ | 5/6 [01:42<00:20, 20.40s/it][A

Model f6 accuracy on D̂5: 88.32%



Evaluating Model f6: 100%|██████████| 6/6 [02:02<00:00, 20.38s/it]
Processing Training Datasets:  60%|██████    | 6/10 [09:11<07:25, 111.39s/it]

Model f6 accuracy on D̂6: 88.12%



Evaluating Model f7:   0%|          | 0/7 [00:00<?, ?it/s][A
Evaluating Model f7:  14%|█▍        | 1/7 [00:20<02:02, 20.44s/it][A

Model f7 accuracy on D̂1: 87.44%



Evaluating Model f7:  29%|██▊       | 2/7 [00:40<01:41, 20.32s/it][A

Model f7 accuracy on D̂2: 89.00%



Evaluating Model f7:  43%|████▎     | 3/7 [01:01<01:21, 20.36s/it][A

Model f7 accuracy on D̂3: 87.68%



Evaluating Model f7:  57%|█████▋    | 4/7 [01:21<01:01, 20.39s/it][A

Model f7 accuracy on D̂4: 87.56%



Evaluating Model f7:  71%|███████▏  | 5/7 [01:41<00:40, 20.36s/it][A

Model f7 accuracy on D̂5: 88.28%



Evaluating Model f7:  86%|████████▌ | 6/7 [02:02<00:20, 20.36s/it][A

Model f7 accuracy on D̂6: 88.12%



Evaluating Model f7: 100%|██████████| 7/7 [02:22<00:00, 20.38s/it]
Processing Training Datasets:  70%|███████   | 7/10 [11:54<06:24, 128.28s/it]

Model f7 accuracy on D̂7: 87.84%



Evaluating Model f8:   0%|          | 0/8 [00:00<?, ?it/s][A
Evaluating Model f8:  12%|█▎        | 1/8 [00:20<02:22, 20.42s/it][A

Model f8 accuracy on D̂1: 86.48%



Evaluating Model f8:  25%|██▌       | 2/8 [00:40<02:02, 20.46s/it][A

Model f8 accuracy on D̂2: 88.24%



Evaluating Model f8:  38%|███▊      | 3/8 [01:01<01:42, 20.43s/it][A

Model f8 accuracy on D̂3: 87.16%



Evaluating Model f8:  50%|█████     | 4/8 [01:21<01:21, 20.36s/it][A

Model f8 accuracy on D̂4: 87.20%



Evaluating Model f8:  62%|██████▎   | 5/8 [01:41<01:01, 20.39s/it][A

Model f8 accuracy on D̂5: 87.44%



Evaluating Model f8:  75%|███████▌  | 6/8 [02:02<00:40, 20.42s/it][A

Model f8 accuracy on D̂6: 87.72%



Evaluating Model f8:  88%|████████▊ | 7/8 [02:22<00:20, 20.42s/it][A

Model f8 accuracy on D̂7: 86.68%



Evaluating Model f8: 100%|██████████| 8/8 [02:43<00:00, 20.43s/it]
Processing Training Datasets:  80%|████████  | 8/10 [14:57<04:51, 145.97s/it]

Model f8 accuracy on D̂8: 87.36%



Evaluating Model f9:   0%|          | 0/9 [00:00<?, ?it/s][A
Evaluating Model f9:  11%|█         | 1/9 [00:20<02:42, 20.36s/it][A

Model f9 accuracy on D̂1: 86.00%



Evaluating Model f9:  22%|██▏       | 2/9 [00:40<02:22, 20.42s/it][A

Model f9 accuracy on D̂2: 88.08%



Evaluating Model f9:  33%|███▎      | 3/9 [01:01<02:02, 20.44s/it][A

Model f9 accuracy on D̂3: 86.72%



Evaluating Model f9:  44%|████▍     | 4/9 [01:21<01:41, 20.38s/it][A

Model f9 accuracy on D̂4: 87.04%



Evaluating Model f9:  56%|█████▌    | 5/9 [01:42<01:21, 20.43s/it][A

Model f9 accuracy on D̂5: 87.36%



Evaluating Model f9:  67%|██████▋   | 6/9 [02:02<01:01, 20.43s/it][A

Model f9 accuracy on D̂6: 87.24%



Evaluating Model f9:  78%|███████▊  | 7/9 [02:22<00:40, 20.38s/it][A

Model f9 accuracy on D̂7: 86.68%



Evaluating Model f9:  89%|████████▉ | 8/9 [02:43<00:20, 20.40s/it][A

Model f9 accuracy on D̂8: 87.40%



Evaluating Model f9: 100%|██████████| 9/9 [03:03<00:00, 20.41s/it]
Processing Training Datasets:  90%|█████████ | 9/10 [18:22<02:44, 164.15s/it]

Model f9 accuracy on D̂9: 86.76%



Evaluating Model f10:   0%|          | 0/10 [00:00<?, ?it/s][A
Evaluating Model f10:  10%|█         | 1/10 [00:20<03:03, 20.39s/it][A

Model f10 accuracy on D̂1: 86.32%



Evaluating Model f10:  20%|██        | 2/10 [00:40<02:43, 20.41s/it][A

Model f10 accuracy on D̂2: 88.52%



Evaluating Model f10:  30%|███       | 3/10 [01:01<02:22, 20.36s/it][A

Model f10 accuracy on D̂3: 86.92%



Evaluating Model f10:  40%|████      | 4/10 [01:21<02:02, 20.40s/it][A

Model f10 accuracy on D̂4: 87.00%



Evaluating Model f10:  50%|█████     | 5/10 [01:42<01:42, 20.43s/it][A

Model f10 accuracy on D̂5: 87.68%



Evaluating Model f10:  60%|██████    | 6/10 [02:02<01:21, 20.39s/it][A

Model f10 accuracy on D̂6: 87.64%



Evaluating Model f10:  70%|███████   | 7/10 [02:22<01:01, 20.41s/it][A

Model f10 accuracy on D̂7: 86.84%



Evaluating Model f10:  80%|████████  | 8/10 [02:43<00:40, 20.43s/it][A

Model f10 accuracy on D̂8: 87.56%



Evaluating Model f10:  90%|█████████ | 9/10 [03:03<00:20, 20.43s/it][A

Model f10 accuracy on D̂9: 86.72%



Evaluating Model f10: 100%|██████████| 10/10 [03:24<00:00, 20.41s/it]
Processing Training Datasets: 100%|██████████| 10/10 [22:06<00:00, 132.65s/it]

Model f10 accuracy on D̂10: 86.84%
Prototypes for f10 saved as prototypes_f10.pth

Accuracy Matrix:
[0.8988000154495239]
[0.8827999830245972, 0.8971999883651733]
[0.8772000074386597, 0.8939999938011169, 0.8827999830245972]
[0.873199999332428, 0.892799973487854, 0.8795999884605408, 0.8823999762535095]
[0.8736000061035156, 0.892799973487854, 0.8776000142097473, 0.8795999884605408, 0.8871999979019165]
[0.871999979019165, 0.8912000060081482, 0.8755999803543091, 0.8776000142097473, 0.8831999897956848, 0.8812000155448914]
[0.8744000196456909, 0.8899999856948853, 0.876800000667572, 0.8755999803543091, 0.8827999830245972, 0.8812000155448914, 0.8784000277519226]
[0.864799976348877, 0.8823999762535095, 0.8715999722480774, 0.871999979019165, 0.8744000196456909, 0.8772000074386597, 0.8668000102043152, 0.8736000061035156]
[0.8600000143051147, 0.8808000087738037, 0.8672000169754028, 0.8704000115394592, 0.8736000061035156, 0.8723999857902527, 0.8668000102043152, 0.8740000128746033, 0.8676000237464905


