In [1]:
from datasets.download_data import download_data
from datasets.consts import Dataset, DatasetType
from datasets.get_data_loader import get_data_loader
from models.CAML import CAML
from models.feature_extractor import get_pretrained_model, get_transform
from utils import count_learnable_params, count_non_learnable_params
from evaluate import eval_func
from scheduler import WarmupCosineDecayScheduler
from train import train_epoch
import time
import torch
import tqdm
import wandb

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
device
# device = "cpu" # Due to CUDA error

device(type='cuda', index=0)

In [3]:
# wandb.init(project="few-shot-learning", config={"architecture": "CAML", "dataset": "mini-imagenet"})

In [3]:
train_transform, test_transform = get_transform("timm/vit_small_patch16_224.dino")
train = download_data(Dataset.MINI_IMAGENET, DatasetType.TRAIN, transform=train_transform)
valid = download_data(Dataset.MINI_IMAGENET, DatasetType.VAL, transform=test_transform)
# test = download_data(Dataset.MINI_IMAGENET, DatasetType.TEST, transform=test_transform)

In [4]:
way = 5
shot = 3
epochs = 5
warmup_epochs = 2

In [5]:
# fe = get_pretrained_model("timm/vit_small_patch16_224.dino")
fe = get_pretrained_model("timm/vit_base_patch16_clip_224.openai")
model = CAML(feature_extractor=fe, fe_dim=768, fe_dtype=torch.float32, train_fe=False, encoder_size='large', device=device).to(device)

criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-5)

eta_min = 1e-6
scheduler = WarmupCosineDecayScheduler(optimizer, warmup_epochs, epochs, eta_min)

In [7]:
state_dict = torch.load("../caml_pretrained_models/CAML_CLIP/model.pth", map_location=device)
model.load_state_dict(state_dict, strict=False)

_IncompatibleKeys(missing_keys=['feature_extractor.cls_token', 'feature_extractor.pos_embed', 'feature_extractor.patch_embed.proj.weight', 'feature_extractor.norm_pre.weight', 'feature_extractor.norm_pre.bias', 'feature_extractor.blocks.0.norm1.weight', 'feature_extractor.blocks.0.norm1.bias', 'feature_extractor.blocks.0.attn.qkv.weight', 'feature_extractor.blocks.0.attn.qkv.bias', 'feature_extractor.blocks.0.attn.proj.weight', 'feature_extractor.blocks.0.attn.proj.bias', 'feature_extractor.blocks.0.norm2.weight', 'feature_extractor.blocks.0.norm2.bias', 'feature_extractor.blocks.0.mlp.fc1.weight', 'feature_extractor.blocks.0.mlp.fc1.bias', 'feature_extractor.blocks.0.mlp.fc2.weight', 'feature_extractor.blocks.0.mlp.fc2.bias', 'feature_extractor.blocks.1.norm1.weight', 'feature_extractor.blocks.1.norm1.bias', 'feature_extractor.blocks.1.attn.qkv.weight', 'feature_extractor.blocks.1.attn.qkv.bias', 'feature_extractor.blocks.1.attn.proj.weight', 'feature_extractor.blocks.1.attn.proj.bias

In [8]:
learnable_params = count_learnable_params(model)
non_learnable_params = count_non_learnable_params(model)
print(f"Learnable parameters: {learnable_params}")
print(f"Non-learnable parameters: {non_learnable_params}")

Learnable parameters: 302316801
Non-learnable parameters: 85800704


In [11]:
train_loader = get_data_loader(train, way, shot, 15, True)
valid_loader = get_data_loader(valid, way, shot, 15, False)
# test_laoder = get_data_loader(test, way, shot, 15, False)

In [12]:
best_val_acc = 0
for epoch in tqdm.tqdm(range(epochs)):
    epoch_start = time.time()
    # avg_loss, avg_acc = train_epoch(model, train_loader, optimizer, scheduler, criterion, device, way, shot)
    train_epoch_time = time.time() - epoch_start
    # print(f"Epoch {epoch} - Loss: {avg_loss}, Acc: {avg_acc}, Time: {train_epoch_time}")
    # wandb.log({"train_acc": acc, "train_loss": loss, "train_epoch_time": train_epoch_time})
    
    # torch.save(model.state_dict(), "model.pth")
    # wandb.save("model.pth")
    
    avg_loss, avg_acc = eval_func(model, valid_loader, criterion, device, way, shot)
    full_epoch_time = time.time() - epoch_start
    print(f"Validation - Loss: {avg_loss}, Acc: {avg_acc}, Time: {full_epoch_time}")
    # wandb.log({"valid_acc": acc, "valid_loss": loss, "full_epoch_time": full_epoch_time})
    
    best_val_acc = max(best_val_acc, avg_acc)
    
    # if avg_acc >= best_val_acc:
    #     torch.save(model.state_dict(), "best_model.pth")
    #     wandb.save("best_model.pth")
    

 20%|██        | 1/5 [02:07<08:30, 127.72s/it]

Validation - Loss: 0.113, Acc: 94.667, Time: 127.71152114868164


 40%|████      | 2/5 [04:17<06:26, 128.84s/it]

Validation - Loss: 0.113, Acc: 94.667, Time: 129.62123703956604


 60%|██████    | 3/5 [06:27<04:18, 129.33s/it]

Validation - Loss: 0.113, Acc: 94.667, Time: 129.902437210083


 80%|████████  | 4/5 [08:36<02:09, 129.33s/it]

Validation - Loss: 0.113, Acc: 94.667, Time: 129.32808804512024


100%|██████████| 5/5 [10:46<00:00, 129.33s/it]

Validation - Loss: 0.113, Acc: 94.667, Time: 130.04531955718994





In [10]:
# wandb.finish()