In [1]:
from datasets.download_data import download_data
from datasets.consts import Dataset, DatasetType
from models.CAML import CAML
from models.feature_extractor import get_pretrained_model, get_transform
from utils import divide_into_query_and_support, get_accuracy_from_logits, count_learnable_params, count_non_learnable_params
import learn2learn as l2l
import time
import torch
import torchvision
import tqdm
import wandb

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Run on the first CUDA GPU when one is visible, otherwise fall back to the CPU.
# (To work around CUDA errors, override with: device = torch.device("cpu"))
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
device

device(type='cuda', index=0)

In [None]:
# Start a Weights & Biases run; the training loop below logs its metrics to it.
run_config = {"architecture": "CAML", "dataset": "mini-imagenet"}
wandb.init(project="few-shot-learning", config=run_config)

[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33mszaryvip[0m ([33mmgr-few-shot[0m). Use [1m`wandb login --relogin`[0m to force relogin




In [4]:
# Build train/test preprocessing from the backbone's own transforms, keeping
# every step except the final two of each pipeline.
# NOTE(review): the dropped steps are presumably tensor conversion /
# normalization handled elsewhere (download_data or CAML) — confirm against
# get_transform before changing the slice.
train_transform, test_transform = (
    torchvision.transforms.Compose(pipeline.transforms[:-2])
    for pipeline in get_transform("timm/vit_base_patch16_clip_224.laion2b_ft_in12k_in1k")
)

train = download_data(Dataset.MINI_IMAGENET, DatasetType.TRAIN, transform=train_transform)
valid = download_data(Dataset.MINI_IMAGENET, DatasetType.VAL, transform=test_transform)
# The test split is not loaded in this notebook:
# test = download_data(Dataset.MINI_IMAGENET, DatasetType.TEST, transform=test_transform)

In [5]:
# Pretrained ViT backbone used as CAML's feature extractor.
fe = get_pretrained_model("timm/vit_base_patch16_clip_224.laion2b_ft_in12k_in1k")
model = CAML(
    feature_extractor=fe,
    fe_dim=768,              # presumably the ViT-B/16 embedding width — confirm
    fe_dtype=torch.float32,
    train_fe=False,          # presumably keeps the backbone frozen (cf. the non-learnable count)
    encoder_size='tiny',
    device=device,
).to(device)

# Loss on query logits, and Adam over all model parameters.
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

In [6]:
# Report how many of the model's parameters will actually be trained.
learnable_params, non_learnable_params = (
    count_learnable_params(model),
    count_non_learnable_params(model),
)
print(
    f"Learnable parameters: {learnable_params}\n"
    f"Non-learnable parameters: {non_learnable_params}"
)

Learnable parameters: 25214209
Non-learnable parameters: 85800704


In [7]:
# Episode configuration: `way` classes per task, `shot` support images per
# class, and the number of training epochs.
way, shot, epochs = 5, 3, 5

In [8]:
# Few-shot task sampler over the training split.
train_dataset = l2l.data.MetaDataset(train)

# Per task: sample `way` classes with `shot + 1` images each (the extra image
# per class presumably becomes the query — see the loop's split), load the
# samples, then remap the task's labels.
transforms = [
    l2l.data.transforms.FusedNWaysKShots(train_dataset, n=way, k=shot + 1),
    l2l.data.transforms.LoadData(train_dataset),
    l2l.data.transforms.RemapLabels(train_dataset),
]

# NOTE(review): num_tasks=10 presumably fixes a pool of only 10 training tasks
# that is re-used every epoch — confirm this is intended.
train_tasksets = l2l.data.Taskset(
    train_dataset,
    task_transforms=transforms,
    num_tasks=10,
)
train_loader = torch.utils.data.DataLoader(train_tasksets, shuffle=True)

In [9]:
# Few-shot task sampler over the validation split (same task shape as training).
valid_dataset = l2l.data.MetaDataset(valid)

# Per task: `way` classes x (`shot` support + presumably 1 query) images,
# loaded and with labels remapped.
transforms = [
    l2l.data.transforms.FusedNWaysKShots(valid_dataset, n=way, k=shot + 1),
    l2l.data.transforms.LoadData(valid_dataset),
    l2l.data.transforms.RemapLabels(valid_dataset),
]

valid_tasksets = l2l.data.Taskset(
    valid_dataset,
    task_transforms=transforms,
    num_tasks=100,
)
valid_loader = torch.utils.data.DataLoader(valid_tasksets, shuffle=True)

In [10]:
# Test-split task loader — disabled because the `test` split is not loaded in
# this notebook (its download_data call is commented out). Uncomment both to
# evaluate on held-out classes.
# test_dataset = l2l.data.MetaDataset(test)
# transforms = [
#     l2l.data.transforms.FusedNWaysKShots(test_dataset, n=way, k=shot+1),
#     l2l.data.transforms.LoadData(test_dataset),
#     l2l.data.transforms.RemapLabels(test_dataset),
# ]
# test_tasksets = l2l.data.Taskset(test_dataset, task_transforms=transforms, num_tasks=100)
# test_loader = torch.utils.data.DataLoader(test_tasksets, shuffle=True)

In [13]:
def _run_epoch(loader, train):
    """Run CAML once over every task in `loader` and return the mean metrics.

    Args:
        loader: iterable of (images, labels) task batches, e.g. the
            learn2learn Taskset DataLoaders built earlier in the notebook.
        train: if True, put the model in training mode and take one optimizer
            step per task; if False, put it in eval mode and disable autograd.

    Returns:
        (mean_loss, mean_acc) averaged over the tasks in `loader`; the accuracy
        is in whatever units `get_accuracy_from_logits` returns. Both are 0.0
        when `loader` yields no tasks.

    Reads the notebook globals `model`, `optimizer`, `criterion`, `device`,
    `way` and `shot`.
    """
    model.train(train)  # model.train(False) is equivalent to model.eval()
    total_loss, total_acc, n_tasks = 0.0, 0.0, 0
    with torch.set_grad_enabled(train):
        for X, labels in loader:
            X, labels = X.to(device), labels.to(device)
            suppX, queryX, suppY, queryY = divide_into_query_and_support(X, labels, way, shot)
            # CAML model divides the input into support and query set itself, so it
            # gets the concatenated images but only the support labels.
            # inp.shape = [way*(shot+1), channels, height, width], where the last
            # way-number of images are the query images.
            X = torch.cat([suppX, queryX], dim=1).squeeze(0)
            suppY = suppY.squeeze(0)
            queryY = queryY.squeeze(0)

            if train:
                optimizer.zero_grad()
            logits = model(X, suppY, way=way, shot=shot)
            loss = criterion(logits, queryY)
            if train:
                loss.backward()
                optimizer.step()

            total_loss += loss.item()
            total_acc += get_accuracy_from_logits(logits, queryY)
            n_tasks += 1

    # Guard against an empty loader instead of dividing by a stale/undefined count.
    if n_tasks == 0:
        return 0.0, 0.0
    return total_loss / n_tasks, total_acc / n_tasks


best_val_acc = 0
for epoch in tqdm.tqdm(range(epochs)):
    epoch_start = time.time()

    # Training
    train_loss, train_acc = _run_epoch(train_loader, train=True)
    train_epoch_time = time.time() - epoch_start
    print(f"Epoch {epoch} - Loss: {train_loss}, Acc: {train_acc}, Time: {train_epoch_time}")
    # Log the epoch averages as plain floats (not the last task's values or a
    # graph-attached loss tensor).
    wandb.log({"epoch": epoch, "train_acc": train_acc, "train_loss": train_loss,
               "train_epoch_time": train_epoch_time})

    # Validation: the model is switched to eval mode and autograd is disabled.
    val_loss, val_acc = _run_epoch(valid_loader, train=False)
    full_epoch_time = time.time() - epoch_start
    print(f"Validation - Loss: {val_loss}, Acc: {val_acc}, Time: {full_epoch_time}")
    wandb.log({"epoch": epoch, "valid_acc": val_acc, "valid_loss": val_loss,
               "full_epoch_time": full_epoch_time})

    # Compare BEFORE updating, so only an actual improvement counts as "best".
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        # Checkpointing is disabled; uncomment to keep the best model.
        # torch.save(model.state_dict(), "best_model.pth")
        # wandb.save("best_model.pth")
    

  0%|          | 0/5 [00:00<?, ?it/s]

Epoch 0 - Loss: 3.9555023431777956, Acc: 18.0, Time: 6.9784934520721436


 20%|██        | 1/5 [01:13<04:54, 73.72s/it]

Validation - Loss: 2.4787175178527834, Acc: 20.0, Time: 73.71824431419373
Epoch 1 - Loss: 1.889948844909668, Acc: 22.0, Time: 6.891533613204956


 40%|████      | 2/5 [02:27<03:40, 73.57s/it]

Validation - Loss: 1.6772475862503051, Acc: 20.0, Time: 73.47093033790588
Epoch 2 - Loss: 1.64488285779953, Acc: 26.0, Time: 6.909821033477783


 60%|██████    | 3/5 [03:40<02:26, 73.49s/it]

Validation - Loss: 1.6138462042808532, Acc: 19.0, Time: 73.38893342018127
Epoch 3 - Loss: 1.57606098651886, Acc: 34.0, Time: 6.867777109146118


 80%|████████  | 4/5 [04:54<01:13, 73.50s/it]

Validation - Loss: 1.61329726934433, Acc: 20.6, Time: 73.51720476150513
Epoch 4 - Loss: 1.5278473258018495, Acc: 34.0, Time: 7.0167436599731445


100%|██████████| 5/5 [06:07<00:00, 73.52s/it]

Validation - Loss: 1.6220751667022706, Acc: 20.4, Time: 73.48469114303589





In [None]:
# Mark the W&B run started by wandb.init as finished and upload any remaining
# logged data; run this last so every epoch's metrics are included.
wandb.finish()

0,1
acc,▁
loss,▁

0,1
acc,0.0
loss,1.41545
