In [1]:
from datasets.download_data import download_data
from datasets.consts import Dataset, DatasetType
from models.protonet import ProtoNet
from models.feature_extractor import get_pretrained_model, get_transform
from utils import divide_into_query_and_support, get_accuracy_from_logits, count_learnable_params, count_non_learnable_params
import learn2learn as l2l
import time
import torch
import torchvision
import tqdm
import wandb

ImportError: cannot import name 'ProtoNet' from partially initialized module 'models.protonet' (most likely due to a circular import) (/home/lszarejko/MGR/few-shot-image-classification/src/models/protonet.py)

In [None]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
device
# device = "cpu" # Due to CUDA error

device(type='cuda', index=0)

In [None]:
# wandb.init(project="few-shot-learning", config={"architecture": "PMF", "dataset": "mini-imagenet"})

[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33mszaryvip[0m ([33mmgr-few-shot[0m). Use [1m`wandb login --relogin`[0m to force relogin


In [None]:
train_transform, test_transform = get_transform("timm/vit_base_patch16_clip_224.laion2b_ft_in12k_in1k")
# train_transform = torchvision.transforms.Compose(train_transform.transforms[:-2])
# test_transform = torchvision.transforms.Compose(test_transform.transforms[:-2])
train = download_data(Dataset.MINI_IMAGENET, DatasetType.TRAIN, transform=train_transform)
valid = download_data(Dataset.MINI_IMAGENET, DatasetType.VAL, transform=test_transform)
# test = download_data(Dataset.MINI_IMAGENET, DatasetType.TEST, transform=test_transform)

In [None]:
fe = get_pretrained_model("timm/vit_base_patch16_clip_224.laion2b_ft_in12k_in1k")
model = ProtoNet(backbone=fe).to(device)

criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

In [None]:
learnable_params = count_learnable_params(model)
non_learnable_params = count_non_learnable_params(model)
print(f"Learnable parameters: {learnable_params}")
print(f"Non-learnable parameters: {non_learnable_params}")

Learnable parameters: 85799426
Non-learnable parameters: 0


: 

In [None]:
way = 5
shot = 3
epochs = 5

In [None]:
train_dataset = l2l.data.MetaDataset(train)
transforms = [
    l2l.data.transforms.FusedNWaysKShots(train_dataset, n=way, k=shot+1),
    l2l.data.transforms.LoadData(train_dataset),
    l2l.data.transforms.RemapLabels(train_dataset),
]

train_tasksets = l2l.data.TaskDataset(train_dataset, task_transforms=transforms, num_tasks=10)
train_loader = torch.utils.data.DataLoader(train_tasksets)

In [None]:
valid_dataset = l2l.data.MetaDataset(valid)
transforms = [
    l2l.data.transforms.FusedNWaysKShots(valid_dataset, n=way, k=shot+1),
    l2l.data.transforms.LoadData(valid_dataset),
    l2l.data.transforms.RemapLabels(valid_dataset),
]
valid_tasksets = l2l.data.Taskset(valid_dataset, task_transforms=transforms, num_tasks=10)
valid_loader = torch.utils.data.DataLoader(valid_tasksets, shuffle=True)

In [None]:
# test_dataset = l2l.data.MetaDataset(test)
# transforms = [
#     l2l.data.transforms.FusedNWaysKShots(test_dataset, n=5, k=1),
#     l2l.data.transforms.LoadData(test_dataset),
#     l2l.data.transforms.RemapLabels(test_dataset),
# ]
# test_tasksets = l2l.data.Taskset(test_dataset, task_transforms=transforms, num_tasks=100)
# test_loader = torch.utils.data.DataLoader(test_tasksets, shuffle=True)

In [None]:
best_val_acc = 0
for epoch in tqdm.tqdm(range(epochs)):
    epoch_start = time.time()
    # Training
    model.train()
    avg_loss = 0.0
    avg_acc = 0.0
    for i, (X, labels) in enumerate(train_loader):
        optimizer.zero_grad()
        X, labels = X.to(device), labels.to(device)
        suppX, queryX, suppY, queryY = divide_into_query_and_support(X, labels, way, shot)
        # print(f"Support labels {suppY.view(-1)}, Query labels {queryY.view(-1)}")
        # print(f"Support set shape: {suppX.shape}, Query set shape: {queryX.shape}")
        # print(f"Support labels shape: {suppY.shape}, Query labels shape: {queryY.shape}")
        logits = model(suppX, suppY, queryX)
        logits = logits.view(queryX.shape[0] * queryX.shape[1], -1)
        loss = criterion(logits, queryY.view(-1))
        loss.backward()
        optimizer.step()
        loss_value = loss.item()

        acc = get_accuracy_from_logits(logits, queryY.view(-1))
        avg_acc += acc
        avg_loss += loss_value
    train_epoch_time = time.time() - epoch_start
    avg_acc = avg_acc / (i + 1)
    avg_loss = avg_loss / (i + 1)
    print(f"Epoch {epoch} - Loss: {avg_loss}, Acc: {avg_acc}, Time: {train_epoch_time}")
    # wandb.log({"train_acc": acc, "train_loss": loss, "train_epoch_time": train_epoch_time})
    
    # torch.save(model.state_dict(), "model.pth")
    # wandb.save("model.pth")
    
    # Validation
    model.eval()
    avg_acc = 0.0
    avg_loss = 0.0
    with torch.no_grad():
        for i, (X, labels) in enumerate(valid_loader):
            X, labels = X.to(device), labels.to(device)
            suppX, queryX, suppY, queryY = divide_into_query_and_support(X, labels, way, shot)
            logits = model(suppX, suppY, queryX)
            logits = logits.view(queryX.shape[0] * queryX.shape[1], -1)
            loss = criterion(logits, queryY.view(-1))
            
            acc = get_accuracy_from_logits(logits, queryY.view(-1))
            avg_acc += acc
            avg_loss += loss_value
    full_epoch_time = time.time() - epoch_start
    avg_acc = avg_acc / (i + 1)
    avg_loss = avg_loss / (i + 1)
    print(f"Validation - Loss: {avg_loss}, Acc: {avg_acc}, Time: {full_epoch_time}")
    # wandb.log({"valid_acc": acc, "valid_loss": loss, "full_epoch_time": full_epoch_time})
    best_val_acc = max(best_val_acc, avg_acc)
    
    # if avg_acc >= best_val_acc:
    #     torch.save(model.state_dict(), "best_model.pth")
    #     wandb.save("best_model.pth")
    

  0%|          | 0/5 [00:00<?, ?it/s]

Epoch 0 - Loss: 1.5947038650512695, Acc: 28.000000417232513


 20%|██        | 1/5 [01:15<05:00, 75.15s/it]

Validation - Loss: 1.6094436645507812, Acc: 30.400000676512718
Epoch 1 - Loss: 1.609278154373169, Acc: 30.000000447034836


 40%|████      | 2/5 [02:29<03:44, 74.75s/it]

Validation - Loss: 1.6094324588775635, Acc: 29.200000688433647
Epoch 2 - Loss: 1.6089608073234558, Acc: 28.000000566244125


 60%|██████    | 3/5 [03:43<02:28, 74.46s/it]

Validation - Loss: 1.6094083786010742, Acc: 29.200000643730164
Epoch 3 - Loss: 1.6026477813720703, Acc: 26.00000038743019


 80%|████████  | 4/5 [04:57<01:14, 74.34s/it]

Validation - Loss: 1.6092277765274048, Acc: 28.800000593066216
Epoch 4 - Loss: 1.6024526000022887, Acc: 28.000000566244125


100%|██████████| 5/5 [06:11<00:00, 74.40s/it]

Validation - Loss: 1.605920433998108, Acc: 28.60000056028366





In [None]:
# wandb.finish(exit_code=0)

0,1
train_acc,▁▁███
train_loss,████▁
valid_acc,▅▅▁█▅
valid_loss,███▇▁

0,1
train_acc,20.0
train_loss,1.60592
valid_acc,20.0
valid_loss,1.60097
