In [6]:
import os
import random

import torch
from torchvision import transforms
from learn2learn.data import TaskDataset
from learn2learn.data import MetaDataset
from learn2learn.data.transforms import NWays, KShots, LoadData, RemapLabels

from dataset import ShapesDataset
from prototypical_net import ConvNet

In [7]:
# Fix RNG seeds so task sampling and weight initialisation are repeatable.
# NOTE(review): assumes learn2learn's NWays/KShots sample via Python's `random`
# module and ConvNet init via torch's RNG — verify if results still vary.
SEED = 42
random.seed(SEED)
torch.manual_seed(SEED)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
train_root = "images/augmented-images"

# Each sub-directory of train_root is one class. os.listdir returns entries in
# arbitrary order, so sort to keep the class list stable across machines/runs.
class_names = sorted(
    d for d in os.listdir(train_root) if os.path.isdir(os.path.join(train_root, d))
)
if not class_names:
    raise ValueError(f"No class sub-directories found in {train_root!r}")

n_ways = len(class_names)  # every task uses all available classes
print(f"Classes: {class_names}")
print(f"N-way automatically set to: {n_ways}")

# ToTensor gives pixel values in [0, 1]; Normalize(0.5, 0.5) maps them to [-1, 1].
# Three mean/std entries assume 3-channel (RGB) images.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.5] * 3, [0.5] * 3),
])

dataset = ShapesDataset(train_root, transform=transform)
meta_dataset = MetaDataset(dataset)

Classes: ['apple', 'kiwi']
N-way automatically set to: 2


In [8]:
# Episode pipeline, applied in order: pick `n_ways` classes, draw 10 samples of
# each, load the samples, then remap labels. The training cell relies on labels
# being class indices in range(n_ways) and splits each class's 10 samples into
# k_shot support + k_query query examples, so k must be >= k_shot + k_query.
episode_transforms = [
    NWays(meta_dataset, n=n_ways),
    KShots(meta_dataset, k=10),
    LoadData(meta_dataset),
    RemapLabels(meta_dataset),
]
taskset = TaskDataset(meta_dataset, task_transforms=episode_transforms, num_tasks=1000)

# Embedding network, its optimiser, and the loss applied to the
# negative-distance logits in the training loop.
model = ConvNet().to(device)
opt = torch.optim.Adam(model.parameters(), lr=1e-3)
loss_fn = torch.nn.CrossEntropyLoss()

### Training

In [9]:
k_shot = 5  # support samples per class
k_query = 5  # query samples per class; the taskset's KShots(k=...) must supply >= k_shot + k_query
n_iterations = 1000
log_every = 100


def split_support_query(embeddings, labels, n_ways, k_shot, k_query):
    """Split one episode into support and query sets, class by class.

    For each class index in ``range(n_ways)``, the first ``k_shot`` samples
    carrying that label go to the support set and the next ``k_query`` go to
    the query set. A class with fewer than ``k_shot + k_query`` samples is
    reported and left out.

    Args:
        embeddings: model outputs for the episode, indexed by sample on dim 0.
        labels: 1-D tensor of class indices, aligned with ``embeddings``.

    Returns:
        ``(support, support_labels, query, query_labels)``, each concatenated
        over classes in class-index order, or ``None`` if any class was left out.
    """
    support, support_labels, query, query_labels = [], [], [], []
    for class_idx in range(n_ways):
        # as_tuple=True keeps the indices 1-D even when exactly one sample
        # matches; `.squeeze()` would yield a 0-d tensor there, and len() on a
        # 0-d tensor raises TypeError.
        class_indices = torch.nonzero(labels == class_idx, as_tuple=True)[0]

        if class_indices.numel() < k_shot + k_query:
            print(f"Not enough samples for class {class_idx}")
            continue

        support_idx = class_indices[:k_shot]
        query_idx = class_indices[k_shot:k_shot + k_query]
        support.append(embeddings[support_idx])
        support_labels.append(labels[support_idx])
        query.append(embeddings[query_idx])
        query_labels.append(labels[query_idx])

    if len(support) < n_ways:
        return None
    return (
        torch.cat(support),
        torch.cat(support_labels),
        torch.cat(query),
        torch.cat(query_labels),
    )


def class_prototypes(support, support_labels, n_ways):
    """Return each class's mean support embedding, stacked in class-index order."""
    return torch.stack(
        [support[support_labels == class_idx].mean(0) for class_idx in range(n_ways)]
    )


model.train()  # ensure training-mode behaviour in case an earlier cell called eval()
n_failed = 0

for iteration in range(n_iterations):
    try:
        data, labels = taskset.sample()
        data, labels = data.to(device), labels.to(device)

        embeddings = model(data)

        episode = split_support_query(embeddings, labels, n_ways, k_shot, k_query)
        if episode is None:
            print("Skipping task — not enough valid classes")
            continue
        support, support_labels, query, query_labels = episode

        prototypes = class_prototypes(support, support_labels, n_ways)

        # Negative Euclidean distance (torch.cdist default p=2) to each
        # prototype is used as the class logit: closer prototype -> higher score.
        predictions = -torch.cdist(query, prototypes)
        loss = loss_fn(predictions, query_labels)

        # Skip NaN *and* inf losses; either would corrupt the weights on step().
        if not torch.isfinite(loss):
            print("Loss is not finite — skipping iteration")
            continue

        opt.zero_grad()
        loss.backward()
        opt.step()

        if iteration % log_every == 0:
            # Accuracy is only needed for logging; detach so it stays off the graph.
            acc = (predictions.detach().argmax(1) == query_labels).float().mean()
            print(f"Iteration {iteration}: Loss={loss.item():.4f}, Accuracy={acc.item()*100:.4f}")

    except Exception as e:
        # Best-effort: keep training past a bad episode, but make the failure
        # visible (type + message) and count it instead of hiding it.
        n_failed += 1
        print(f"Error in iteration {iteration}: {type(e).__name__}: {e}")

if n_failed:
    print(f"{n_failed} of {n_iterations} iterations failed with an exception")

Iteration 0: Loss=0.1466, Accuracy=100.0000
Iteration 100: Loss=0.0003, Accuracy=100.0000
Iteration 200: Loss=0.0002, Accuracy=100.0000
Iteration 300: Loss=0.0001, Accuracy=100.0000
Iteration 400: Loss=0.0001, Accuracy=100.0000
Iteration 500: Loss=0.0001, Accuracy=100.0000
Iteration 600: Loss=0.0001, Accuracy=100.0000
Iteration 700: Loss=0.0001, Accuracy=100.0000
Iteration 800: Loss=0.0000, Accuracy=100.0000
Iteration 900: Loss=0.0000, Accuracy=100.0000


### Save model parameters

In [10]:
# Persist only the learned weights (state_dict), not the pickled module.
save_path = os.path.join("saved_models", "normal_model.pth")
# torch.save does not create missing parent directories.
os.makedirs(os.path.dirname(save_path), exist_ok=True)
torch.save(model.state_dict(), save_path)
# Report the path actually written (the message previously said "model.pth").
print(f"Model saved to {save_path}")

Model saved to model.pth
