In [1]:
from datasets.download_data import download_data
from datasets.consts import Dataset, DatasetType
from datasets.get_data_loader import get_data_loader
from models.protonet import ProtoNet
from models.feature_extractor import get_pretrained_model, get_transform
from utils import count_learnable_params, count_non_learnable_params
from evaluate import eval_func
from train import train_epoch
import time
import torch
import tqdm
import wandb

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
device
# device = "cpu" # Due to CUDA error

device(type='cuda', index=0)

In [3]:
# wandb.init(project="few-shot-learning", config={"architecture": "PMF", "dataset": "mini-imagenet"})

In [4]:
fe_train_transform, fe_test_transform = get_transform("timm/vit_base_patch16_clip_224.laion2b_ft_in12k_in1k")
train = download_data(Dataset.MINI_IMAGENET, DatasetType.TRAIN, transform=fe_train_transform)
valid = download_data(Dataset.MINI_IMAGENET, DatasetType.VAL, transform=fe_test_transform)
# test = download_data(Dataset.MINI_IMAGENET, DatasetType.TEST, transform=test_transform)

In [5]:
fe = get_pretrained_model("timm/vit_base_patch16_clip_224.laion2b_ft_in12k_in1k")
model = ProtoNet(backbone=fe).to(device)

criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

In [6]:
learnable_params = count_learnable_params(model)
non_learnable_params = count_non_learnable_params(model)
print(f"Learnable parameters: {learnable_params}")
print(f"Non-learnable parameters: {non_learnable_params}")

Learnable parameters: 85799426
Non-learnable parameters: 0


In [7]:
way = 5
shot = 3
epochs = 5

In [8]:
train_loader = get_data_loader(train, way, shot, 15, True)
valid_loader = get_data_loader(valid, way, shot, 15, False)
# test_laoder = get_data_loader(test, way, shot, 15, False)

In [9]:
best_val_acc = 0
for epoch in tqdm.tqdm(range(epochs)):
    epoch_start = time.time()
    avg_loss, avg_acc = train_epoch(model, train_loader, optimizer, criterion, device, way, shot)
    train_epoch_time = time.time() - epoch_start
    print(f"Epoch {epoch} - Loss: {avg_loss}, Acc: {avg_acc}, Time: {train_epoch_time}")
    # wandb.log({"train_acc": acc, "train_loss": loss, "train_epoch_time": train_epoch_time})
    
    # torch.save(model.state_dict(), "model.pth")
    # wandb.save("model.pth")
    
    eval_func(model, valid_loader, criterion, device, way, shot)
    full_epoch_time = time.time() - epoch_start
    print(f"Validation - Loss: {avg_loss}, Acc: {avg_acc}, Time: {full_epoch_time}")
    # wandb.log({"valid_acc": acc, "valid_loss": loss, "full_epoch_time": full_epoch_time})
    
    best_val_acc = max(best_val_acc, avg_acc)
    
    # if avg_acc >= best_val_acc:
    #     torch.save(model.state_dict(), "best_model.pth")
    #     wandb.save("best_model.pth")
    

  0%|          | 0/5 [00:00<?, ?it/s]

Epoch 0 - Loss: 1.678, Acc: 33.333, Time: 24.75922465324402


 20%|██        | 1/5 [00:33<02:14, 33.51s/it]

Validation - Loss: 1.678, Acc: 33.333, Time: 33.508912801742554
Epoch 1 - Loss: 1.618, Acc: 26.667, Time: 24.509443521499634


 40%|████      | 2/5 [01:06<01:40, 33.39s/it]

Validation - Loss: 1.618, Acc: 26.667, Time: 33.30507445335388
Epoch 2 - Loss: 1.609, Acc: 33.333, Time: 24.49422597885132


 60%|██████    | 3/5 [01:40<01:06, 33.34s/it]

Validation - Loss: 1.609, Acc: 33.333, Time: 33.273425817489624
Epoch 3 - Loss: 1.609, Acc: 26.667, Time: 24.479862451553345


 80%|████████  | 4/5 [02:13<00:33, 33.31s/it]

Validation - Loss: 1.609, Acc: 26.667, Time: 33.25819659233093
Epoch 4 - Loss: 1.609, Acc: 33.333, Time: 24.493269681930542


100%|██████████| 5/5 [02:46<00:00, 33.33s/it]

Validation - Loss: 1.609, Acc: 33.333, Time: 33.275880098342896





In [None]:
# wandb.finish(exit_code=0)

: 