In [1]:
import random
import time

import torch
import tqdm
import wandb

from datasets.download_data import download_data
from datasets.consts import Dataset, DatasetType
from datasets.get_data_loader import get_data_loader
from models.protonet import ProtoNet, ProtoNet_Finetune
from models.feature_extractor import get_pretrained_model, get_transform
from utils import count_learnable_params, count_non_learnable_params
from evaluate import eval_func
from scheduler import WarmupCosineDecayScheduler
from train import train_epoch

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Pick the first CUDA GPU when one is visible, otherwise run on the CPU.
use_cuda = torch.cuda.is_available()
device = torch.device("cuda:0") if use_cuda else torch.device("cpu")
device
# To force CPU (e.g. to work around a CUDA error), uncomment:
# device = "cpu" # Due to CUDA error

device(type='cuda', index=0)

In [3]:
# wandb.init(project="few-shot-learning", config={"architecture": "PMF", "dataset": "mini-imagenet"})

In [4]:
# Preprocessing for the pretrained backbone: one transform for the training
# split and one for evaluation splits.
fe_train_transform, fe_test_transform = get_transform("timm/vit_small_patch16_224.dino")

# Mini-ImageNet splits (train first, then test). The validation split is
# currently disabled:
# valid = download_data(Dataset.MINI_IMAGENET, DatasetType.VAL, transform=fe_test_transform)
train, test = [
    download_data(Dataset.MINI_IMAGENET, split, transform=split_transform)
    for split, split_transform in (
        (DatasetType.TRAIN, fe_train_transform),
        (DatasetType.TEST, fe_test_transform),
    )
]

In [5]:
# Episode and schedule configuration.
way = 5            # presumably N in N-way: classes per episode
shot = 3           # presumably K in K-shot: support images per class
epochs = 5
warmup_epochs = 2  # warm-up length handed to WarmupCosineDecayScheduler

# Seed the RNGs so episode sampling (and anything else stochastic) is
# reproducible under Restart & Run All. torch.manual_seed also seeds CUDA.
# NOTE(review): if get_data_loader samples episodes with NumPy, seed it too.
seed = 42
random.seed(seed)
torch.manual_seed(seed)

In [6]:



# Model: ProtoNet wrapped around the pretrained "vit_small_patch16_224.dino"
# backbone (same name as the transforms in the data cell).
fe = get_pretrained_model("timm/vit_small_patch16_224.dino")
model = ProtoNet(backbone=fe)
model = model.to(device)

# Loss and optimiser: Adam over every parameter of the model.
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-5)

# LR schedule with warm-up followed by cosine decay; eta_min is presumably the
# decay floor. NOTE(review): the schedule is sized in epochs but is passed into
# train_epoch — confirm it is stepped once per epoch, not once per batch.
eta_min = 1e-6
scheduler = WarmupCosineDecayScheduler(optimizer, warmup_epochs, epochs, eta_min)

In [7]:
# Report how many parameters the optimiser will update vs. how many are frozen.
learnable_params, non_learnable_params = (
    count_learnable_params(model),
    count_non_learnable_params(model),
)
print(
    f"Learnable parameters: {learnable_params}\n"
    f"Non-learnable parameters: {non_learnable_params}"
)

Learnable parameters: 21665666
Non-learnable parameters: 0


In [8]:
# Episodic data loaders. `n_query` is the 4th positional argument of
# get_data_loader — presumably query images per class; confirm against its
# signature. The final flag is True only for the training split (presumably
# shuffling / random episode sampling — TODO confirm).
n_query = 15
train_loader = get_data_loader(train, way, shot, n_query, True)
# valid_loader = get_data_loader(valid, way, shot, n_query, False)
test_loader = get_data_loader(test, way, shot, n_query, False)
test_laoder = test_loader  # legacy misspelled name, still used by the evaluation cell

In [9]:
# Training loop: one call to train_epoch per epoch (it also receives the
# scheduler), then report that epoch's average loss / accuracy and duration.
best_val_acc = 0  # only used by the (currently disabled) validation block below
for epoch in tqdm.tqdm(range(epochs)):
    # perf_counter is monotonic, so durations are immune to wall-clock changes.
    epoch_start = time.perf_counter()
    avg_loss, avg_acc = train_epoch(model, train_loader, optimizer, scheduler, criterion, device, way, shot)
    train_epoch_time = time.perf_counter() - epoch_start
    # tqdm.write prints without tearing the active progress bar.
    tqdm.tqdm.write(f"Epoch {epoch} - Loss: {avg_loss}, Acc: {avg_acc}, Time: {train_epoch_time}")
    # wandb.log({"train_acc": avg_acc, "train_loss": avg_loss, "train_epoch_time": train_epoch_time})

    # torch.save(model.state_dict(), "model.pth")
    # wandb.save("model.pth")

    # Validation is disabled until `valid` / `valid_loader` are re-enabled in the
    # data cells. It uses separate names so the training metrics are not overwritten.
    # val_loss, val_acc = eval_func(model, valid_loader, criterion, device, way, shot)
    # full_epoch_time = time.perf_counter() - epoch_start
    # print(f"Validation - Loss: {val_loss}, Acc: {val_acc}, Time: {full_epoch_time}")
    # wandb.log({"valid_acc": val_acc, "valid_loss": val_loss, "full_epoch_time": full_epoch_time})

    # Keep the best checkpoint (ties overwrite, i.e. the latest best wins).
    # if val_acc >= best_val_acc:
    #     best_val_acc = val_acc
    #     torch.save(model.state_dict(), "best_model.pth")
    #     wandb.save("best_model.pth")
    

 20%|██        | 1/5 [00:08<00:35,  8.82s/it]

Epoch 0 - Loss: 0.38, Acc: 90.667, Time: 8.823339223861694


 40%|████      | 2/5 [00:16<00:25,  8.44s/it]

Epoch 1 - Loss: 0.084, Acc: 100.0, Time: 8.165836334228516


 60%|██████    | 3/5 [00:25<00:16,  8.34s/it]

Epoch 2 - Loss: 0.029, Acc: 100.0, Time: 8.215368270874023


 80%|████████  | 4/5 [00:33<00:08,  8.31s/it]

Epoch 3 - Loss: 0.018, Acc: 100.0, Time: 8.274571657180786


100%|██████████| 5/5 [00:41<00:00,  8.36s/it]

Epoch 4 - Loss: 0.013, Acc: 100.0, Time: 8.312508821487427





In [None]:
# Test-time fine-tuning variant of ProtoNet. It is built on the same `fe`
# object that `model` was constructed from, then loaded with `model`'s weights.
# NOTE(review): the author flags lr=1e-5 as a problem here. Because `fe` is
# shared, any in-place backbone updates ProtoNet_Finetune makes would also
# affect `model` unless it restores them itself — confirm.
model_with_fn = ProtoNet_Finetune(backbone=fe, lr=1e-5)
model_with_fn = model_with_fn.to(device)
model_with_fn.load_state_dict(model.state_dict())

In [14]:
# Evaluate the fine-tuning model on the test episodes. Distinct names keep the
# final training metrics (`avg_loss`, `avg_acc`) from being overwritten.
# `test_laoder` (sic) is the test loader built in the data-loader cell.
# NOTE(review): the saved run scored 26.667% here, barely above 5-way chance
# (20%), despite 100% training accuracy. Investigate the fine-tuning lr
# before trusting this number.
test_loss, test_acc = eval_func(model_with_fn, test_laoder, criterion, device, way, shot)
print(f"Test - Loss: {test_loss}, Acc: {test_acc}")

Test - Loss: 1.615, Acc: 26.667


In [12]:
# wandb.finish(exit_code=0)