In [1]:
import random
import time

import torch
import tqdm
import wandb

from datasets.consts import Dataset, DatasetType
from datasets.download_data import download_data
from datasets.get_data_loader import get_data_loader
from evaluate import eval_func
from models.CAML import CAML
from models.feature_extractor import get_pretrained_model, get_transform
from scheduler import WarmupCosineDecayScheduler
from train import train_epoch
from utils import count_learnable_params, count_non_learnable_params

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Run on the first GPU when CUDA is available, otherwise on the CPU.
# If CUDA errors occur, override with:  device = "cpu"
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
device

device(type='cuda', index=0)

In [3]:
# Weights & Biases experiment tracking is currently disabled; uncomment to log this run.
# wandb.init(project="few-shot-learning", config={"architecture": "CAML", "dataset": "mini-imagenet"})

In [4]:
# Preprocessing pipelines for the timm ViT backbone: one for training data,
# one for evaluation data.
train_transform, test_transform = get_transform("timm/vit_small_patch16_224.dino")

# mini-ImageNet splits; validation uses the evaluation-time transform.
train = download_data(
    Dataset.MINI_IMAGENET,
    DatasetType.TRAIN,
    transform=train_transform,
)
valid = download_data(
    Dataset.MINI_IMAGENET,
    DatasetType.VAL,
    transform=test_transform,
)
# test = download_data(Dataset.MINI_IMAGENET, DatasetType.TEST, transform=test_transform)

In [5]:
# Episode shape and training schedule.
way = 5            # classes per episode
shot = 3           # support examples per class
epochs = 5
warmup_epochs = 2

# Seed the RNGs so model initialisation and episode sampling are reproducible
# across Restart & Run All. torch.manual_seed also seeds every CUDA device.
# NOTE(review): if get_data_loader samples episodes with NumPy, seed np.random too.
seed = 42
random.seed(seed)
torch.manual_seed(seed)

In [6]:
# Pretrained ViT feature extractor wrapped by CAML. train_fe=False keeps the
# backbone out of training (cf. the non-learnable parameter count printed below).
fe = get_pretrained_model("timm/vit_small_patch16_224.dino")
model = CAML(
    feature_extractor=fe,
    fe_dim=384,  # presumably the ViT-Small embedding width — confirm against the backbone
    fe_dtype=torch.float32,
    train_fe=False,
    encoder_size='tiny',
    device=device,
).to(device)

# Episode classification loss and optimiser.
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-5)

# Presumably warmup over `warmup_epochs`, then cosine decay towards eta_min.
eta_min = 1e-6
scheduler = WarmupCosineDecayScheduler(optimizer, warmup_epochs, epochs, eta_min)

In [7]:
# Sanity-check the trainable/frozen split of the model.
learnable_params = count_learnable_params(model)
non_learnable_params = count_non_learnable_params(model)
print(
    f"Learnable parameters: {learnable_params}\n"
    f"Non-learnable parameters: {non_learnable_params}"
)

Learnable parameters: 25214593
Non-learnable parameters: 21668864


In [8]:
# Episodic loaders, presumably sampling `way` classes with `shot` support images each.
# NOTE(review): the 4th positional argument is presumably the number of query
# images per class and the 5th the shuffle flag — confirm against get_data_loader.
# One shared constant keeps every split's episodes the same shape.
n_query = 15
train_loader = get_data_loader(train, way, shot, n_query, True)
valid_loader = get_data_loader(valid, way, shot, n_query, False)
# test_loader = get_data_loader(test, way, shot, n_query, False)

In [9]:
# Train for `epochs` epochs, validating after each one and tracking the best
# validation accuracy seen so far.
best_val_acc = float("-inf")
for epoch in tqdm.tqdm(range(epochs)):
    epoch_start = time.time()
    train_loss, train_acc = train_epoch(model, train_loader, optimizer, scheduler, criterion, device, way, shot)
    train_epoch_time = time.time() - epoch_start
    # tqdm.write prints above the progress bar instead of splitting it into fragments.
    tqdm.tqdm.write(f"Epoch {epoch} - Loss: {train_loss}, Acc: {train_acc}, Time: {train_epoch_time}")
    # wandb.log({"train_acc": train_acc, "train_loss": train_loss, "train_epoch_time": train_epoch_time})

    # torch.save(model.state_dict(), "model.pth")
    # wandb.save("model.pth")

    val_loss, val_acc = eval_func(model, valid_loader, criterion, device, way, shot)
    # Wall-clock time for training + validation of this epoch.
    full_epoch_time = time.time() - epoch_start
    tqdm.tqdm.write(f"Validation - Loss: {val_loss}, Acc: {val_acc}, Time: {full_epoch_time}")
    # wandb.log({"valid_acc": val_acc, "valid_loss": val_loss, "full_epoch_time": full_epoch_time})

    # Compare against the previous best BEFORE updating it, so the (optional)
    # best-model checkpoint is only written on a strict improvement.
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        # torch.save(model.state_dict(), "best_model.pth")
        # wandb.save("best_model.pth")

print(f"Best validation accuracy: {best_val_acc}")
    

  0%|          | 0/5 [00:00<?, ?it/s]

Epoch 0 - Loss: 1.6, Acc: 18.667, Time: 3.937694549560547


 20%|██        | 1/5 [00:06<00:27,  6.90s/it]

Validation - Loss: 1.875, Acc: 14.667, Time: 6.900905132293701
Epoch 1 - Loss: 1.599, Acc: 22.667, Time: 3.5077009201049805


 40%|████      | 2/5 [00:13<00:20,  6.67s/it]

Validation - Loss: 1.875, Acc: 14.667, Time: 6.514173746109009
Epoch 2 - Loss: 1.614, Acc: 21.333, Time: 3.395956516265869


 60%|██████    | 3/5 [00:19<00:13,  6.52s/it]

Validation - Loss: 1.874, Acc: 14.667, Time: 6.3451149463653564
Epoch 3 - Loss: 1.611, Acc: 18.667, Time: 3.3839399814605713


 80%|████████  | 4/5 [00:26<00:06,  6.45s/it]

Validation - Loss: 1.876, Acc: 14.667, Time: 6.341275215148926
Epoch 4 - Loss: 1.605, Acc: 24.0, Time: 3.4002678394317627


100%|██████████| 5/5 [00:32<00:00,  6.50s/it]

Validation - Loss: 1.877, Acc: 14.667, Time: 6.371903896331787





In [10]:
# Close the Weights & Biases run; only needed when wandb.init is enabled.
# wandb.finish()