In [1]:
import torch
from datasets.download_data import download_data
from datasets.consts import Dataset, DatasetType
from models.feature_extractor import get_pretrained_model, get_transform
import torchvision
import torch
from models.protonet import ProtoNet
from timm.utils import accuracy
import learn2learn as l2l
import tqdm
import wandb

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
device
# device = "cpu" # Due to CUDA error

device(type='cuda', index=0)

In [3]:
wandb.init(project="few-shot-learning", config={"architecture": "PMF", "dataset": "mini-imagenet"})

[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33mszaryvip[0m ([33mmgr-few-shot[0m). Use [1m`wandb login --relogin`[0m to force relogin


In [4]:
train_transform, test_transform = get_transform("timm/vit_base_patch16_clip_224.laion2b_ft_in12k_in1k")
train_transform = torchvision.transforms.Compose(train_transform.transforms[:-2])
test_transform = torchvision.transforms.Compose(test_transform.transforms[:-2])
train = download_data(Dataset.MINI_IMAGENET, DatasetType.TRAIN, transform=train_transform)
# valid = download_data(Dataset.MINI_IMAGENET, DatasetType.VAL, transform=test_transform)
# test = download_data(Dataset.MINI_IMAGENET, DatasetType.TEST, transform=test_transform)

In [5]:
fe = get_pretrained_model("timm/vit_base_patch16_clip_224.laion2b_ft_in12k_in1k")
model = ProtoNet(backbone=fe).to(device)

criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

In [17]:
way = 5
shot = 1

In [None]:
train_dataset = l2l.data.MetaDataset(train)
transforms = [
    l2l.data.transforms.FusedNWaysKShots(train_dataset, n=way, k=shot+1),
    l2l.data.transforms.LoadData(train_dataset),
    l2l.data.transforms.RemapLabels(train_dataset),
]

train_tasksets = l2l.data.TaskDataset(train_dataset, task_transforms=transforms, num_tasks=10)
train_loader = torch.utils.data.DataLoader(train_tasksets)

In [14]:
for task in train_tasksets:
    X, y = task
    print("support: ", y[::2], "query: ", y[1::2])

support:  tensor([2, 0, 4, 3, 1]) query:  tensor([2, 0, 4, 3, 1])
support:  tensor([4, 2, 0, 3, 1]) query:  tensor([4, 2, 0, 3, 1])
support:  tensor([3, 1, 0, 4, 2]) query:  tensor([3, 1, 0, 4, 2])
support:  tensor([0, 1, 4, 2, 3]) query:  tensor([0, 1, 4, 2, 3])
support:  tensor([4, 1, 3, 2, 0]) query:  tensor([4, 1, 3, 2, 0])
support:  tensor([3, 4, 2, 0, 1]) query:  tensor([3, 4, 2, 0, 1])
support:  tensor([0, 3, 4, 2, 1]) query:  tensor([0, 3, 4, 2, 1])
support:  tensor([4, 3, 1, 2, 0]) query:  tensor([4, 3, 1, 2, 0])
support:  tensor([2, 4, 3, 1, 0]) query:  tensor([2, 4, 3, 1, 0])
support:  tensor([4, 3, 0, 1, 2]) query:  tensor([4, 3, 0, 1, 2])


In [7]:
# valid_dataset = l2l.data.MetaDataset(valid)
# transforms = [
#     l2l.data.transforms.FusedNWaysKShots(valid_dataset, n=5, k=1),
#     l2l.data.transforms.LoadData(valid_dataset),
#     l2l.data.transforms.RemapLabels(valid_dataset),
# ]
# valid_tasksets = l2l.data.Taskset(valid_dataset, task_transforms=transforms, num_tasks=100)
# valid_loader = torch.utils.data.DataLoader(valid_tasksets, shuffle=True)

In [8]:
# test_dataset = l2l.data.MetaDataset(test)
# transforms = [
#     l2l.data.transforms.FusedNWaysKShots(test_dataset, n=5, k=1),
#     l2l.data.transforms.LoadData(test_dataset),
#     l2l.data.transforms.RemapLabels(test_dataset),
# ]
# test_tasksets = l2l.data.Taskset(test_dataset, task_transforms=transforms, num_tasks=100)
# test_loader = torch.utils.data.DataLoader(test_tasksets, shuffle=True)

In [18]:
model.train()
epochs = 5
best_val_acc = 0
for epoch in tqdm.tqdm(range(epochs)):
    avg_loss = 0.0
    avg_acc = 0.0
    for i, batch in enumerate(train_loader):
        optimizer.zero_grad()
        X, labels = batch
        X, labels = X.to(device), labels.to(device)
        suppX, queryX, suppY, queryY = [], [], [], []
        for i in range(0, labels.shape[1], shot+1):
            suppX = X[:, i:i+shot]
            queryX = X[:, i+shot]
            suppY = labels[:, i:i+shot]
            queryY = labels[:, i+shot]
        print(f"Support labels {suppY.view(-1)}, Query labels {queryY.view(-1)}")
        print(f"Support set shape: {suppX.shape}, Query set shape: {queryX.shape}")
        print(f"Support labels shape: {suppY.shape}, Query labels shape: {queryY.shape}")
        logits = model(suppX, suppY, queryX)
        logits = logits.view(queryX.shape[0] * queryX.shape[1], -1)
        loss = criterion(logits, queryY.view(-1))
        loss.backward()
        optimizer.step()
        loss_value = loss.item()
        
        
        _, max_index = torch.max(logits, 1)
        print("max index: ", max_index, "target: ", queryY, "logits: ", logits)
        acc = 100 * torch.sum(torch.eq(max_index, queryY.view(-1))).item() / queryY.view(-1).shape[0]
        print(acc)
        avg_acc += acc
        avg_loss += loss_value
    avg_acc = avg_acc / (i + 1)
    avg_loss = avg_loss / (i + 1)
    print(f"Epoch {epoch} - Loss: {avg_loss}, Acc: {avg_acc}")
    wandb.log({"train_acc": acc, "train_loss": loss})
    
    torch.save(model.state_dict(), "model.pth")
    wandb.save("model.pth")
    
    # Validation
    # avg_acc = 0.0
    # avg_loss = 0.0
    # with torch.no_grad():
    #     for i, batch in enumerate(valid_loader):
    #         X, labels = batch
    #         shot=1
    #         batch_size, way, channels, height, width = X.shape
    #         X = X.view(batch_size * way, channels, height, width)
    #         labels = labels.view(-1)
    #         X, labels = X.to(device), labels.to(device)
    #         logits = model(X, labels[:(way-1)*shot], way=(way-1), shot=shot)
    #         loss = criterion(logits, labels[(way-1)*shot:])
    #         loss_value = loss.item()
    #         _, max_index = torch.max(logits, 1)
            
    #         acc = 100 * torch.sum(torch.eq(max_index, labels[(way-1)*shot:])).item() / labels[(way-1)*shot:].shape[0]
    #         avg_acc += acc
    #         avg_loss += loss_value
    # avg_acc = avg_acc / (i + 1)
    # avg_loss = avg_loss / (i + 1)
    # print(f"Validation - Loss: {avg_loss}, Acc: {avg_acc}")
    # wandb.log({"valid_acc": acc, "valid_loss": loss})
    # best_val_acc = max(best_val_acc, avg_acc)
    
    # if avg_acc == best_val_acc:
    #     torch.save(model.state_dict(), "best_model.pth")
    #     wandb.save("best_model.pth")
    

  0%|          | 0/5 [00:00<?, ?it/s]

Support labels tensor([1], device='cuda:0'), Query labels tensor([1], device='cuda:0')
Support set shape: torch.Size([1, 1, 3, 224, 224]), Query set shape: torch.Size([1, 3, 224, 224])
Support labels shape: torch.Size([1, 1]), Query labels shape: torch.Size([1])


  0%|          | 0/5 [00:00<?, ?it/s]


RuntimeError: Expected size for first two dimensions of batch2 tensor to be: [1, 256] but got: [1, 768].

In [11]:
wandb.finish(exit_code=0)