In [1]:
from datasets.download_data import download_data
from datasets.consts import Dataset, DatasetType
from datasets.get_data_loader import get_data_loader
from models.CAML import CAML
from models.feature_extractor import get_pretrained_model, get_transform
from utils import count_learnable_params, count_non_learnable_params
from evaluate import eval_func
from train import train_epoch
import time
import torch
import tqdm
import wandb

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Runtime configuration: fix the RNG seed, then choose the compute device.
# Seeding happens here, before the model is built and the loaders are created,
# so that weight initialisation (and any torch-driven shuffling) is repeatable
# across "Restart Kernel -> Run All" executions.
# NOTE(review): this seeds torch's generators only; if the data pipeline draws
# from `random` or numpy, those need seeding too — confirm in get_data_loader.
SEED = 42
torch.manual_seed(SEED)

# Use the first CUDA GPU when one is visible, otherwise fall back to the CPU.
# To force CPU (e.g. while debugging a CUDA error), replace the assignment
# below with: device = torch.device("cpu")
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
device

device(type='cuda', index=0)

In [3]:
# wandb.init(project="few-shot-learning", config={"architecture": "CAML", "dataset": "mini-imagenet"})

In [4]:
# Get the train/test transforms associated with the timm CLIP ViT-B/16
# checkpoint, then obtain the mini-ImageNet TRAIN and VAL splits, each paired
# with its own transform.
train_transform, test_transform = get_transform(
    "timm/vit_base_patch16_clip_224.laion2b_ft_in12k_in1k"
)
train = download_data(
    Dataset.MINI_IMAGENET,
    DatasetType.TRAIN,
    transform=train_transform,
)
valid = download_data(
    Dataset.MINI_IMAGENET,
    DatasetType.VAL,
    transform=test_transform,
)
# The TEST split is intentionally not loaded in this notebook. To enable it:
# test = download_data(Dataset.MINI_IMAGENET, DatasetType.TEST, transform=test_transform)

In [5]:
# Backbone: the pretrained timm CLIP ViT-B/16 checkpoint, handed to CAML as its
# feature extractor. train_fe=False is passed to CAML.
# NOTE(review): presumably this keeps the backbone frozen — confirm in models/CAML.py.
fe = get_pretrained_model(
    "timm/vit_base_patch16_clip_224.laion2b_ft_in12k_in1k"
)
model = CAML(
    feature_extractor=fe,
    fe_dim=768,
    fe_dtype=torch.float32,
    train_fe=False,
    encoder_size="tiny",
    device=device,
).to(device)

# Loss function and optimiser consumed by the training / evaluation cells.
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(
    model.parameters(),
    lr=1e-3,
)

In [6]:
# Report the model size, split into the two counts returned by the helpers
# from utils.
learnable_params = count_learnable_params(model)
non_learnable_params = count_non_learnable_params(model)
print(
    f"Learnable parameters: {learnable_params}",
    f"Non-learnable parameters: {non_learnable_params}",
    sep="\n",
)

Learnable parameters: 25214209
Non-learnable parameters: 85800704


In [7]:
# Episode and schedule settings shared by the loaders and the training loop.
# NOTE(review): `way` / `shot` are presumably classes per episode and support
# examples per class (N-way K-shot) — confirm against get_data_loader.
way = 5      # forwarded to get_data_loader, train_epoch and eval_func
shot = 3     # forwarded to get_data_loader, train_epoch and eval_func
epochs = 5   # number of iterations of the training loop

In [8]:
# Build the loaders for the training and validation splits.
train_loader = get_data_loader(
    train,
    way,
    shot,
    15,    # NOTE(review): presumably query samples per class — confirm in get_data_loader
    True,  # NOTE(review): only the train loader passes True; likely a shuffle flag — confirm
)
valid_loader = get_data_loader(
    valid,
    way,
    shot,
    15,
    False,
)
# The test loader is disabled together with the TEST split. To enable it:
# test_loader = get_data_loader(test, way, shot, 15, False)

In [9]:
# Run `epochs` rounds of training; after each round evaluate on the validation
# loader and keep track of the highest validation accuracy seen so far.
best_val_acc = 0
for epoch in tqdm.tqdm(range(epochs)):
    epoch_start = time.time()

    # Training pass.
    avg_loss, avg_acc = train_epoch(
        model, train_loader, optimizer, criterion, device, way, shot
    )
    train_epoch_time = time.time() - epoch_start
    print(f"Epoch {epoch} - Loss: {avg_loss}, Acc: {avg_acc}, Time: {train_epoch_time}")
    # wandb.log({"train_acc": avg_acc, "train_loss": avg_loss, "train_epoch_time": train_epoch_time})

    # torch.save(model.state_dict(), "model.pth")
    # wandb.save("model.pth")

    # Validation pass. avg_loss / avg_acc are rebound here, so from this point
    # on they hold the validation metrics, not the training ones.
    avg_loss, avg_acc = eval_func(
        model, valid_loader, criterion, device, way, shot
    )
    # Measured from epoch_start, i.e. training plus validation time.
    full_epoch_time = time.time() - epoch_start
    print(f"Validation - Loss: {avg_loss}, Acc: {avg_acc}, Time: {full_epoch_time}")
    # wandb.log({"valid_acc": avg_acc, "valid_loss": avg_loss, "full_epoch_time": full_epoch_time})

    # Running maximum of the validation accuracy.
    if avg_acc > best_val_acc:
        best_val_acc = avg_acc

    # After the update above, equality means this epoch matched or set the best.
    # if avg_acc >= best_val_acc:
    #     torch.save(model.state_dict(), "best_model.pth")
    #     wandb.save("best_model.pth")

  0%|          | 0/5 [00:00<?, ?it/s]

Epoch 0 - Loss: 3.544, Acc: 20.0, Time: 10.669377326965332


 20%|██        | 1/5 [00:20<01:22, 20.64s/it]

Validation - Loss: 1.932, Acc: 20.0, Time: 20.64151430130005
Epoch 1 - Loss: 1.66, Acc: 33.333, Time: 10.323870658874512


 40%|████      | 2/5 [00:40<01:01, 20.45s/it]

Validation - Loss: 1.717, Acc: 16.0, Time: 20.32196068763733
Epoch 2 - Loss: 1.463, Acc: 45.333, Time: 10.30575942993164


 60%|██████    | 3/5 [01:01<00:40, 20.38s/it]

Validation - Loss: 1.625, Acc: 20.0, Time: 20.29905104637146
Epoch 3 - Loss: 1.019, Acc: 70.667, Time: 10.359575510025024


 80%|████████  | 4/5 [01:21<00:20, 20.37s/it]

Validation - Loss: 2.139, Acc: 17.333, Time: 20.35268545150757
Epoch 4 - Loss: 0.756, Acc: 74.667, Time: 10.302518844604492


100%|██████████| 5/5 [01:41<00:00, 20.39s/it]

Validation - Loss: 4.864, Acc: 20.0, Time: 20.30772113800049





In [10]:
# wandb.finish()