In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import timm
import numpy as np
from torchvision import transforms
from tqdm import tqdm
from torch.utils.data import Dataset, DataLoader

# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Set random seeds for reproducibility
torch.manual_seed(42)
np.random.seed(42)
if torch.cuda.is_available():
    torch.cuda.manual_seed(42)

Using device: cuda


In [2]:
class Mlp(nn.Module):
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.1):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x

class CPEA(nn.Module):
    def __init__(self, in_dim=384):
        super().__init__()
        self.fc1 = Mlp(in_features=in_dim, hidden_features=int(in_dim/4), out_features=in_dim)
        self.fc_norm1 = nn.LayerNorm(in_dim)
        self.fc2 = Mlp(in_features=196**2, hidden_features=256, out_features=1)
        self.class_agnostic_embedding = nn.Parameter(torch.randn(1, 1, in_dim))

    def forward(self, feat_query, feat_shot, shot):
        B, N, C = feat_query.size()

        # Generate class-aware embeddings
        feat_query = self.fc1(torch.mean(feat_query, dim=1, keepdim=True)) + feat_query
        feat_shot = self.fc1(torch.mean(feat_shot, dim=1, keepdim=True)) + feat_shot

        feat_query = self.fc_norm1(feat_query)
        feat_shot = self.fc_norm1(feat_shot)

        # Split class token and patch embeddings
        query_class = feat_query[:, 0, :].unsqueeze(1)
        query_patches = feat_query[:, 1:, :]

        support_class = feat_shot[:, 0, :].unsqueeze(1)
        support_patches = feat_shot[:, 1:, :]

        # Class-aware patch embedding adaptation
        feat_query = query_patches + 2.0 * query_class
        feat_shot = support_patches + 2.0 * support_class

        # Normalization
        feat_query = F.normalize(feat_query, p=2, dim=2)
        feat_query = feat_query - torch.mean(feat_query, dim=2, keepdim=True)

        # Reshape support features
        feat_shot = feat_shot.contiguous().reshape(shot, -1, N-1, C)
        feat_shot = feat_shot.mean(dim=0)
        feat_shot = F.normalize(feat_shot, p=2, dim=2)
        feat_shot = feat_shot - torch.mean(feat_shot, dim=2, keepdim=True)

        # Compute similarity scores
        results = []
        for idx in range(feat_query.size(0)):
            tmp_query = feat_query[idx].unsqueeze(0)
            out = torch.matmul(feat_shot, tmp_query.transpose(1, 2))
            out = out.flatten(1)
            out = self.fc2(out.pow(2))
            out = out.transpose(0, 1)
            results.append(out)

        return torch.cat(results, dim=0)

class CPEAVisionTransformer(nn.Module):
    def __init__(self, pretrained_model='vit_small_patch16_224'):
        super().__init__()
        self.vit = timm.create_model(pretrained_model, pretrained=True, num_classes=0)
        self.embed_dim = self.vit.embed_dim
        self.cpea = CPEA(in_dim=self.embed_dim)

    def forward_features(self, x):
        x = self.vit.patch_embed(x)
        cls_token = self.vit.cls_token.expand(x.shape[0], -1, -1)
        x = torch.cat((cls_token, x), dim=1)
        x = self.vit.pos_drop(x + self.vit.pos_embed)

        for blk in self.vit.blocks:
            x = blk(x)
        x = self.vit.norm(x)
        return x

    def forward(self, support_images, query_images, n_way, n_support, n_query):
        support_features = self.forward_features(support_images.view(-1, *support_images.shape[2:]))
        query_features = self.forward_features(query_images.view(-1, *query_images.shape[2:]))

        logits = self.cpea(query_features, support_features, n_support)
        return logits

In [3]:
def create_cifar_fs():
    transform = transforms.Compose([
        transforms.Resize(224),
        transforms.RandomHorizontalFlip(),
        transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4),
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
    ])

    cifar100 = torchvision.datasets.CIFAR100(root='./data', train=True, download=True, transform=transform)

    # Split classes for CIFAR-FS
    base_classes = np.random.choice(100, 64, replace=False)
    cifar_fs_indices = np.where(np.isin(cifar100.targets, base_classes))[0]
    cifar_fs = torch.utils.data.Subset(cifar100, cifar_fs_indices)
    cifar_fs_targets = np.array(cifar100.targets)[cifar_fs_indices]

    return cifar_fs, cifar_fs_targets

def create_fewshot_task(dataset, targets, n_way, n_support, n_query):
    classes = np.random.choice(np.unique(targets), n_way, replace=False)
    support_set = []
    query_set = []

    for cls in classes:
        cls_indices = np.where(targets == cls)[0]
        support_indices = np.random.choice(cls_indices, n_support, replace=False)
        query_indices = np.random.choice(np.setdiff1d(cls_indices, support_indices), n_query, replace=False)

        support_set.append(torch.stack([dataset[i][0] for i in support_indices]))
        query_set.append(torch.stack([dataset[i][0] for i in query_indices]))

    return torch.stack(support_set), torch.stack(query_set)

In [4]:
def train_episode(model, optimizer, dataset, targets, n_way, n_support, n_query, device):
    model.train()
    optimizer.zero_grad()

    support_images, query_images = create_fewshot_task(dataset, targets, n_way, n_support, n_query)
    support_images, query_images = support_images.to(device), query_images.to(device)

    logits = model(support_images, query_images, n_way, n_support, n_query)
    loss = F.cross_entropy(logits, torch.arange(n_way).repeat_interleave(n_query).to(device))

    loss.backward()
    optimizer.step()

    return loss.item()

def train(model, dataset, targets, n_episodes, n_way, n_support, n_query, device):
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=500, gamma=0.5)

    for episode in tqdm(range(n_episodes)):
        loss = train_episode(model, optimizer, dataset, targets, n_way, n_support, n_query, device)
        scheduler.step()

        if (episode + 1) % 100 == 0:
            print(f"Episode {episode + 1}, Loss: {loss:.4f}")

    return model

def evaluate(model, dataset, targets, n_episodes, n_way, n_support, n_query, device):
    model.eval()
    accuracies = []

    for _ in tqdm(range(n_episodes)):
        support_images, query_images = create_fewshot_task(dataset, targets, n_way, n_support, n_query)
        support_images, query_images = support_images.to(device), query_images.to(device)

        with torch.no_grad():
            logits = model(support_images, query_images, n_way, n_support, n_query)
            _, predicted = torch.max(logits, 1)
            correct = (predicted == torch.arange(n_way).repeat_interleave(n_query).to(device)).sum().item()
            accuracies.append(correct / (n_way * n_query))

    return np.mean(accuracies)

In [5]:
# Create dataset
cifar_fs, cifar_fs_targets = create_cifar_fs()

# Initialize model
model = CPEAVisionTransformer().to(device)

# Training parameters (as per paper)
n_episodes = 10000  # Paper uses 10000 episodes
n_way = 5
n_support = 5
n_query = 15

# Train the model
model = train(model, cifar_fs, cifar_fs_targets, n_episodes, n_way, n_support, n_query, device)

# Save the model
torch.save({
    'model_state_dict': model.state_dict(),
    'n_way': n_way,
    'n_support': n_support,
    'n_query': n_query
}, 'cpea_model.pth')

# Evaluate
eval_episodes = 600  # Paper evaluates on 600 episodes
accuracy = evaluate(model, cifar_fs, cifar_fs_targets, eval_episodes, n_way, n_support, n_query, device)
print(f"5-way 5-shot accuracy: {accuracy:.4f}")

Downloading https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz to ./data/cifar-100-python.tar.gz


100%|██████████| 169M/169M [00:02<00:00, 78.5MB/s]


Extracting ./data/cifar-100-python.tar.gz to ./data


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


model.safetensors:   0%|          | 0.00/88.2M [00:00<?, ?B/s]

  1%|          | 100/10000 [02:25<3:57:56,  1.44s/it]

Episode 100, Loss: 1.6058


  2%|▏         | 200/10000 [04:49<3:53:57,  1.43s/it]

Episode 200, Loss: 1.6096


  3%|▎         | 300/10000 [07:12<3:50:31,  1.43s/it]

Episode 300, Loss: 1.6096


  4%|▍         | 400/10000 [09:36<3:49:22,  1.43s/it]

Episode 400, Loss: 1.6090


  5%|▌         | 500/10000 [11:59<3:45:55,  1.43s/it]

Episode 500, Loss: 1.6093


  6%|▌         | 600/10000 [14:22<3:44:23,  1.43s/it]

Episode 600, Loss: 1.6093


  7%|▋         | 700/10000 [16:45<3:43:02,  1.44s/it]

Episode 700, Loss: 1.6096


  8%|▊         | 800/10000 [19:07<3:37:21,  1.42s/it]

Episode 800, Loss: 1.6070


  9%|▉         | 900/10000 [21:30<3:35:29,  1.42s/it]

Episode 900, Loss: 1.6100


 10%|█         | 1000/10000 [23:53<3:33:43,  1.42s/it]

Episode 1000, Loss: 1.6093


 11%|█         | 1100/10000 [26:15<3:32:49,  1.43s/it]

Episode 1100, Loss: 1.6092


 12%|█▏        | 1200/10000 [28:38<3:29:46,  1.43s/it]

Episode 1200, Loss: 1.6095


 13%|█▎        | 1300/10000 [31:02<3:27:09,  1.43s/it]

Episode 1300, Loss: 1.6094


 14%|█▍        | 1400/10000 [33:25<3:25:24,  1.43s/it]

Episode 1400, Loss: 1.6093


 15%|█▌        | 1500/10000 [35:47<3:21:54,  1.43s/it]

Episode 1500, Loss: 1.6093


 16%|█▌        | 1600/10000 [38:10<3:20:33,  1.43s/it]

Episode 1600, Loss: 1.6096


 17%|█▋        | 1700/10000 [40:33<3:18:17,  1.43s/it]

Episode 1700, Loss: 1.6094


 18%|█▊        | 1800/10000 [42:56<3:15:49,  1.43s/it]

Episode 1800, Loss: 1.6094


 19%|█▉        | 1900/10000 [45:18<3:11:44,  1.42s/it]

Episode 1900, Loss: 1.6095


 20%|██        | 2000/10000 [47:40<3:11:13,  1.43s/it]

Episode 2000, Loss: 1.6096


 21%|██        | 2100/10000 [50:02<3:06:15,  1.41s/it]

Episode 2100, Loss: 1.6094


 22%|██▏       | 2200/10000 [52:24<3:03:01,  1.41s/it]

Episode 2200, Loss: 1.6094


 23%|██▎       | 2300/10000 [54:45<3:02:35,  1.42s/it]

Episode 2300, Loss: 1.6094


 24%|██▍       | 2400/10000 [57:07<2:59:45,  1.42s/it]

Episode 2400, Loss: 1.6095


 25%|██▌       | 2500/10000 [59:29<2:59:07,  1.43s/it]

Episode 2500, Loss: 1.6094


 26%|██▌       | 2600/10000 [1:01:51<2:54:36,  1.42s/it]

Episode 2600, Loss: 1.6095


 27%|██▋       | 2700/10000 [1:04:13<2:52:19,  1.42s/it]

Episode 2700, Loss: 1.6094


 28%|██▊       | 2800/10000 [1:06:35<2:49:18,  1.41s/it]

Episode 2800, Loss: 1.6095


 29%|██▉       | 2900/10000 [1:08:56<2:46:52,  1.41s/it]

Episode 2900, Loss: 1.6096


 30%|███       | 3000/10000 [1:11:18<2:45:48,  1.42s/it]

Episode 3000, Loss: 1.6095


 31%|███       | 3100/10000 [1:13:39<2:42:27,  1.41s/it]

Episode 3100, Loss: 1.6095


 32%|███▏      | 3200/10000 [1:16:01<2:39:50,  1.41s/it]

Episode 3200, Loss: 1.6095


 33%|███▎      | 3300/10000 [1:18:23<2:39:09,  1.43s/it]

Episode 3300, Loss: 1.6094


 34%|███▍      | 3400/10000 [1:20:44<2:35:53,  1.42s/it]

Episode 3400, Loss: 1.6094


 35%|███▌      | 3500/10000 [1:23:06<2:35:50,  1.44s/it]

Episode 3500, Loss: 1.6095


 36%|███▌      | 3600/10000 [1:25:28<2:30:40,  1.41s/it]

Episode 3600, Loss: 1.6094


 37%|███▋      | 3700/10000 [1:27:50<2:28:10,  1.41s/it]

Episode 3700, Loss: 1.6095


 38%|███▊      | 3800/10000 [1:30:12<2:25:35,  1.41s/it]

Episode 3800, Loss: 1.6094


 39%|███▉      | 3900/10000 [1:32:33<2:23:34,  1.41s/it]

Episode 3900, Loss: 1.6094


 40%|████      | 4000/10000 [1:34:55<2:23:06,  1.43s/it]

Episode 4000, Loss: 1.6094


 41%|████      | 4100/10000 [1:37:17<2:19:21,  1.42s/it]

Episode 4100, Loss: 1.6095


 42%|████▏     | 4200/10000 [1:39:38<2:16:52,  1.42s/it]

Episode 4200, Loss: 1.6095


 43%|████▎     | 4300/10000 [1:42:00<2:14:40,  1.42s/it]

Episode 4300, Loss: 1.6094


 44%|████▍     | 4400/10000 [1:44:22<2:12:36,  1.42s/it]

Episode 4400, Loss: 1.6094


 45%|████▌     | 4500/10000 [1:46:44<2:11:00,  1.43s/it]

Episode 4500, Loss: 1.6094


 46%|████▌     | 4600/10000 [1:49:05<2:07:43,  1.42s/it]

Episode 4600, Loss: 1.6095


 47%|████▋     | 4700/10000 [1:51:27<2:05:32,  1.42s/it]

Episode 4700, Loss: 1.6094


 48%|████▊     | 4800/10000 [1:53:49<2:02:37,  1.41s/it]

Episode 4800, Loss: 1.6094


 49%|████▉     | 4900/10000 [1:56:11<2:00:48,  1.42s/it]

Episode 4900, Loss: 1.6095


 50%|█████     | 5000/10000 [1:58:33<2:00:21,  1.44s/it]

Episode 5000, Loss: 1.6094


 51%|█████     | 5100/10000 [2:00:55<1:55:17,  1.41s/it]

Episode 5100, Loss: 1.6095


 52%|█████▏    | 5200/10000 [2:03:16<1:53:01,  1.41s/it]

Episode 5200, Loss: 1.6094


 53%|█████▎    | 5300/10000 [2:05:38<1:50:39,  1.41s/it]

Episode 5300, Loss: 1.6095


 54%|█████▍    | 5400/10000 [2:08:00<1:49:07,  1.42s/it]

Episode 5400, Loss: 1.6094


 55%|█████▌    | 5500/10000 [2:10:22<1:47:44,  1.44s/it]

Episode 5500, Loss: 1.6094


 56%|█████▌    | 5600/10000 [2:12:45<1:44:51,  1.43s/it]

Episode 5600, Loss: 1.6094


 57%|█████▋    | 5700/10000 [2:15:07<1:42:03,  1.42s/it]

Episode 5700, Loss: 1.6094


 58%|█████▊    | 5800/10000 [2:17:30<1:39:36,  1.42s/it]

Episode 5800, Loss: 1.6095


 59%|█████▉    | 5900/10000 [2:19:52<1:36:50,  1.42s/it]

Episode 5900, Loss: 1.6095


 60%|██████    | 6000/10000 [2:22:15<1:34:22,  1.42s/it]

Episode 6000, Loss: 1.6093


 61%|██████    | 6100/10000 [2:24:37<1:32:58,  1.43s/it]

Episode 6100, Loss: 1.6094


 62%|██████▏   | 6200/10000 [2:27:00<1:30:21,  1.43s/it]

Episode 6200, Loss: 1.6094


 63%|██████▎   | 6300/10000 [2:29:22<1:27:25,  1.42s/it]

Episode 6300, Loss: 1.6094


 64%|██████▍   | 6400/10000 [2:31:45<1:25:32,  1.43s/it]

Episode 6400, Loss: 1.6094


 65%|██████▌   | 6500/10000 [2:34:06<1:22:23,  1.41s/it]

Episode 6500, Loss: 1.6093


 66%|██████▌   | 6600/10000 [2:36:29<1:20:48,  1.43s/it]

Episode 6600, Loss: 1.6095


 67%|██████▋   | 6700/10000 [2:38:51<1:18:30,  1.43s/it]

Episode 6700, Loss: 1.6095


 68%|██████▊   | 6800/10000 [2:41:14<1:15:42,  1.42s/it]

Episode 6800, Loss: 1.6094


 69%|██████▉   | 6900/10000 [2:43:36<1:13:04,  1.41s/it]

Episode 6900, Loss: 1.6094


 70%|███████   | 7000/10000 [2:45:58<1:11:01,  1.42s/it]

Episode 7000, Loss: 1.6094


 71%|███████   | 7100/10000 [2:48:20<1:08:40,  1.42s/it]

Episode 7100, Loss: 1.6095


 72%|███████▏  | 7200/10000 [2:50:42<1:06:10,  1.42s/it]

Episode 7200, Loss: 1.6094


 73%|███████▎  | 7300/10000 [2:53:05<1:04:52,  1.44s/it]

Episode 7300, Loss: 1.6095


 74%|███████▍  | 7400/10000 [2:55:27<1:02:10,  1.43s/it]

Episode 7400, Loss: 1.6095


 75%|███████▌  | 7500/10000 [2:57:50<59:18,  1.42s/it]

Episode 7500, Loss: 1.6095


 76%|███████▌  | 7600/10000 [3:00:12<57:34,  1.44s/it]

Episode 7600, Loss: 1.6094


 77%|███████▋  | 7700/10000 [3:02:34<54:13,  1.41s/it]

Episode 7700, Loss: 1.6095


 78%|███████▊  | 7800/10000 [3:04:57<52:39,  1.44s/it]

Episode 7800, Loss: 1.6095


 79%|███████▉  | 7900/10000 [3:07:19<49:47,  1.42s/it]

Episode 7900, Loss: 1.6094


 80%|████████  | 8000/10000 [3:09:41<47:05,  1.41s/it]

Episode 8000, Loss: 1.6094


 81%|████████  | 8100/10000 [3:12:04<45:08,  1.43s/it]

Episode 8100, Loss: 1.6094


 82%|████████▏ | 8200/10000 [3:14:26<42:27,  1.42s/it]

Episode 8200, Loss: 1.6095


 83%|████████▎ | 8300/10000 [3:16:48<40:25,  1.43s/it]

Episode 8300, Loss: 1.6094


 84%|████████▍ | 8400/10000 [3:19:11<38:09,  1.43s/it]

Episode 8400, Loss: 1.6094


 85%|████████▌ | 8500/10000 [3:21:33<35:23,  1.42s/it]

Episode 8500, Loss: 1.6094


 86%|████████▌ | 8600/10000 [3:23:55<33:14,  1.42s/it]

Episode 8600, Loss: 1.6095


 87%|████████▋ | 8700/10000 [3:26:17<30:36,  1.41s/it]

Episode 8700, Loss: 1.6094


 88%|████████▊ | 8800/10000 [3:28:40<28:21,  1.42s/it]

Episode 8800, Loss: 1.6095


 89%|████████▉ | 8900/10000 [3:31:02<25:52,  1.41s/it]

Episode 8900, Loss: 1.6095


 90%|█████████ | 9000/10000 [3:33:24<23:41,  1.42s/it]

Episode 9000, Loss: 1.6093


 91%|█████████ | 9100/10000 [3:35:47<21:09,  1.41s/it]

Episode 9100, Loss: 1.6094


 92%|█████████▏| 9200/10000 [3:38:09<19:00,  1.43s/it]

Episode 9200, Loss: 1.6095


 93%|█████████▎| 9300/10000 [3:40:31<16:38,  1.43s/it]

Episode 9300, Loss: 1.6094


 94%|█████████▍| 9400/10000 [3:42:53<14:10,  1.42s/it]

Episode 9400, Loss: 1.6094


 95%|█████████▌| 9500/10000 [3:45:15<11:49,  1.42s/it]

Episode 9500, Loss: 1.6095


 96%|█████████▌| 9600/10000 [3:47:38<09:35,  1.44s/it]

Episode 9600, Loss: 1.6094


 97%|█████████▋| 9700/10000 [3:50:00<07:04,  1.41s/it]

Episode 9700, Loss: 1.6095


 98%|█████████▊| 9800/10000 [3:52:22<04:46,  1.43s/it]

Episode 9800, Loss: 1.6094


 99%|█████████▉| 9900/10000 [3:54:44<02:22,  1.42s/it]

Episode 9900, Loss: 1.6094


100%|██████████| 10000/10000 [3:57:07<00:00,  1.42s/it]

Episode 10000, Loss: 1.6094



100%|██████████| 600/600 [06:21<00:00,  1.57it/s]

5-way 5-shot accuracy: 0.1988





In [7]:
def load_model(path):
    checkpoint = torch.load(path)
    model = CPEAVisionTransformer().to(device)
    model.load_state_dict(checkpoint['model_state_dict'])
    return model, checkpoint['n_way'], checkpoint['n_support'], checkpoint['n_query']

# Load and evaluate
loaded_model, n_way, n_support, n_query = load_model('cpea_model.pth')
accuracy = evaluate(loaded_model, cifar_fs, cifar_fs_targets, eval_episodes, n_way, n_support, n_query, device)
print(f"Loaded model accuracy: {accuracy:.4f}")

  checkpoint = torch.load(path)
100%|██████████| 600/600 [06:24<00:00,  1.56it/s]

Loaded model accuracy: 0.1986



