In [None]:
# Scratch cell: preview one test sample and check the raw array shape.
# NOTE: it uses names (DeepLesion, np, plt, pickle) that are only defined in
# the import and class-definition cells further down, so run those first.
dataset = DeepLesion(mode='test')

# DeepLesion.__getitem__ returns an (x, y) tuple, not a dict, so unpack it
# instead of indexing with string keys.
spl = dataset[0]
sample_x, sample_y = spl
print(sample_y)

# The dataset moves axis 3 to position 1 for every sample, so a single sample
# has its last original axis first. imshow expects that axis last, so undo it.
sample_img = np.transpose(sample_x, (1, 2, 0))
# imshow clips float RGB data to [0, 1]; min-max rescale for display only.
# NOTE(review): the stored value range is not visible here -- confirm this
# rescaling is appropriate for how the images should be viewed.
value_span = sample_img.max() - sample_img.min()
if value_span > 0:
    sample_img = (sample_img - sample_img.min()) / value_span
plt.imshow(sample_img)

root_dir='/home/Surya/Meta-Learning-Deep-Leision/'
mode = 'test'

# NOTE: pickle.load can execute arbitrary code; only load trusted files.
with open( root_dir + 'X_' + mode + '.p', 'rb' ) as f:
    X = pickle.load(f)
    # Same axis permutation the dataset class applies: axis 3 -> axis 1.
    X = np.transpose(X, (0,3,1,2))

print(X.shape)

In [13]:
import os
import argparse
import pickle
import gc
import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

import learn2learn as l2l
from learn2learn.data import MetaDataset, TaskDataset
from learn2learn.data.transforms import NWays, KShots, LoadData, RemapLabels

In [2]:
class DeepLesion(Dataset):
    """In-memory dataset read from the pickled ``X_<mode>.p`` / ``Y_<mode>.p`` files."""

    def __init__(self, mode, root_dir='/home/Surya/Meta-Learning-Deep-Leision/', transform=None):
        """
        Args:
            mode (string): Split name spliced into the file names
                ``X_<mode>.p`` and ``Y_<mode>.p`` (this notebook uses
                'train', 'val' and 'test').
            root_dir (string): Directory containing the pickle files; a
                trailing separator is optional.
            transform (callable, optional): Optional transform to be applied
                on a sample.
        """
        # NOTE: pickle.load can execute arbitrary code; only point root_dir
        # at files you trust.
        with open(os.path.join(root_dir, 'X_' + mode + '.p'), 'rb') as f:
            self.X = pickle.load(f)
        with open(os.path.join(root_dir, 'Y_' + mode + '.p'), 'rb') as f:
            self.Y = pickle.load(f)
        # Cast once up front, then move axis 3 to position 1.
        # NOTE(review): presumably channels-last -> channels-first; confirm
        # against how the pickles were written.
        self.X = np.float32(self.X)
        self.X = np.transpose(self.X, (0,3,1,2))
        self.root_dir = root_dir
        self.transform = transform

    def __len__(self):
        """Number of samples, taken from the label container."""
        return len(self.Y)

    def __getitem__(self, idx):
        """Return ``(x, y)`` for sample ``idx``.

        ``x`` is the (optionally transformed) sample cast to float32 and
        ``y`` is the label converted to a Python int.
        """
        if torch.is_tensor(idx):
            idx = idx.tolist()

        x = self.X[idx]
        y = self.Y[idx]

        # Previously the transformed value was stored in an unused variable
        # and the untransformed sample was returned; apply it to x instead.
        if self.transform is not None:
            x = self.transform(x)

        return np.float32(x), int(y)

In [3]:
def pairwise_distances_logits(a, b):
    """Score every row of ``a`` against every row of ``b``.

    The score is the negated squared Euclidean distance, so closer pairs get
    larger values. For ``a`` with n rows and ``b`` with m rows the result
    has shape (n, m).
    """
    rows_a, rows_b = a.shape[0], b.shape[0]
    # Tile both inputs to a common (n, m, d) shape before subtracting.
    a_tiled = a.unsqueeze(1).expand(rows_a, rows_b, -1)
    b_tiled = b.unsqueeze(0).expand(rows_a, rows_b, -1)
    squared_diff = (a_tiled - b_tiled) ** 2
    return -squared_diff.sum(dim=2)


def accuracy(predictions, targets):
    """Fraction of rows whose arg-max over dim 1 equals the target.

    Returns a 0-dim float tensor: the number of matches divided by
    ``targets.size(0)``.
    """
    predicted_labels = torch.argmax(predictions, dim=1).view(targets.shape)
    n_correct = predicted_labels.eq(targets).sum().float()
    return n_correct / targets.size(0)


class Convnet(nn.Module):
    """Embedding network: a learn2learn ``ConvBase`` encoder followed by a flatten."""

    def __init__(self, x_dim=3, hid_dim=64, z_dim=64):
        """
        Args:
            x_dim (int): Passed to ConvBase as ``channels``.
            hid_dim (int): Passed to ConvBase as ``hidden``.
            z_dim (int): Passed to ConvBase as ``output_size``.
        """
        super().__init__()
        # Hard-coded value; it is not read anywhere inside this class.
        # NOTE(review): presumably the flattened feature size for the
        # expected input resolution -- confirm before relying on it.
        self.out_channels = 1600
        self.encoder = l2l.vision.models.ConvBase(
            channels=x_dim,
            hidden=hid_dim,
            output_size=z_dim,
        )

    def forward(self, x):
        """Encode ``x`` and flatten everything after the first dimension."""
        features = self.encoder(x)
        batch_size = features.size(0)
        return features.view(batch_size, -1)


def fast_adapt(model, batch, ways, shot, query_num, metric=None, device=None):
    """Run one prototypical-network episode and return ``(loss, accuracy)``.

    Args:
        model: Module mapping the task's samples to flat embeddings.
        batch: ``(data, labels)`` pair for a single task. A leading
            dimension of size 1 is squeezed away -- NOTE(review): presumably
            the DataLoader batch dimension; confirm.
        ways (int): Number of classes in the task.
        shot (int): Support samples per class.
        query_num (int): Query samples per class.
        metric (callable, optional): ``metric(query, prototypes)`` returning
            logits of shape (n_query, ways). Defaults to
            ``pairwise_distances_logits``.
        device (torch.device, optional): Where to run. Defaults to the
            device of the model's first parameter (CPU if it has none).

    Returns:
        Tuple of the cross-entropy loss over the query samples and the
        query accuracy.
    """
    if metric is None:
        metric = pairwise_distances_logits
    if device is None:
        # nn.Module has no ``device`` attribute or method; infer the device
        # from the parameters instead.
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device('cpu')
    data, labels = batch
    data = data.to(device)
    labels = labels.to(device)

    # Sort data samples by labels
    # TODO: Can this be replaced by ConsecutiveLabels ?
    sort = torch.sort(labels)
    data = data.squeeze(0)[sort.indices].squeeze(0)
    labels = labels.squeeze(0)[sort.indices].squeeze(0)

    # Compute support and query embeddings
    embeddings = model(data)
    # The indexing below assumes that, after sorting, every class occupies
    # (shot + query_num) consecutive rows: the first `shot` rows of each
    # class are the support set and the rest are queries.
    support_indices = np.zeros(data.size(0), dtype=bool)
    selection = np.arange(ways) * (shot + query_num)
    for offset in range(shot):
        support_indices[selection + offset] = True
    query_indices = torch.from_numpy(~support_indices)
    support_indices = torch.from_numpy(support_indices)
    support = embeddings[support_indices]
    # One prototype per class: the mean of that class's support embeddings.
    support = support.reshape(ways, shot, -1).mean(dim=1)
    query = embeddings[query_indices]
    labels = labels[query_indices].long()

    # Use the caller-supplied metric (it was previously ignored in favour of
    # a hard-coded pairwise_distances_logits call).
    logits = metric(query, support)
    loss = F.cross_entropy(logits, labels)
    acc = accuracy(logits, labels)
    return loss, acc

In [14]:
# Free whatever earlier cells left behind before configuring the run.
gc.collect()

# Hyper-parameters are declared through argparse but parsed from an empty
# argument string, so every option takes its default inside the notebook.
parser = argparse.ArgumentParser()
for _flag, _default in (
    ('--max-epoch', 10),
    ('--shot', 5),
    ('--test-way', 5),
    ('--test-shot', 5),
    ('--test-query', 1),
    ('--train-query', 15),
    ('--train-way', 5),
):
    parser.add_argument(_flag, type=int, default=_default)
# --gpu carries no type, so a value given on the command line stays a string.
parser.add_argument('--gpu', default=1)
args = parser.parse_args('')

In [16]:
# Train a prototypical network on few-shot tasks, validating after every
# epoch and testing once at the end.
#
# The setup below used to be commented out, which made this cell depend on
# variables left over in the kernel (device, model, the datasets and the
# loaders). It is restored so the cell works after Restart & Run All.
device = torch.device('cpu')
if args.gpu and torch.cuda.device_count():
    print("Using gpu")
    torch.cuda.manual_seed(43)
    device = torch.device('cuda')

model = Convnet()
model.to(device)

# Load the three splits from their pickle files.
train_dataset = DeepLesion(mode='train')
valid_dataset = DeepLesion(mode='val')
test_dataset = DeepLesion(mode='test')

# Training tasks: `train_way` classes with (shot + train_query) samples each.
train_dataset = l2l.data.MetaDataset(train_dataset)
train_transforms = [
    NWays(train_dataset, args.train_way),
    KShots(train_dataset, args.train_query + args.shot),
    LoadData(train_dataset),
    RemapLabels(train_dataset),
]
train_tasks = l2l.data.TaskDataset(train_dataset, task_transforms=train_transforms)
train_loader = DataLoader(train_tasks, pin_memory=True, shuffle=True)

# Validation tasks: `test_way` classes with (test_shot + test_query) samples
# each; the task count is the number of such tasks the split can fill.
valid_dataset = l2l.data.MetaDataset(valid_dataset)
valid_transforms = [
    NWays(valid_dataset, args.test_way),
    KShots(valid_dataset, args.test_query + args.test_shot),
    LoadData(valid_dataset),
    RemapLabels(valid_dataset),
]

num_tasks_val = math.floor(len(valid_dataset)/(args.test_way*(args.test_query + args.test_shot)))
valid_tasks = l2l.data.TaskDataset(valid_dataset,
                                   task_transforms=valid_transforms,
                                   num_tasks=num_tasks_val)
valid_loader = DataLoader(valid_tasks, pin_memory=True, shuffle=True)

# Test tasks: same episode layout as validation.
test_dataset = l2l.data.MetaDataset(test_dataset)
test_transforms = [
    NWays(test_dataset, args.test_way),
    KShots(test_dataset, args.test_query + args.test_shot),
    LoadData(test_dataset),
    RemapLabels(test_dataset),
]

num_tasks_test = math.floor(len(test_dataset)/(args.test_way*(args.test_query + args.test_shot)))
test_tasks = l2l.data.TaskDataset(test_dataset,
                                  task_transforms=test_transforms,
                                  num_tasks=num_tasks_test)
test_loader = DataLoader(test_tasks, pin_memory=True, shuffle=True)

optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
lr_scheduler = torch.optim.lr_scheduler.StepLR(
    optimizer, step_size=20, gamma=0.5)

for epoch in range(1, args.max_epoch + 1):
    model.train()

    loss_ctr = 0
    n_loss = 0
    n_acc = 0

    # 100 training episodes per epoch.
    for i in range(100):
        # A fresh iterator is created for every step, as in the original.
        # NOTE(review): presumably because the un-sized task dataset yields
        # a newly sampled task on each access -- confirm before refactoring
        # this into a single iterator.
        batch = next(iter(train_loader))

        loss, acc = fast_adapt(model,
                               batch,
                               args.train_way,
                               args.shot,
                               args.train_query,
                               metric=pairwise_distances_logits,
                               device=device)

        loss_ctr += 1
        n_loss += loss.item()
        # Accumulate plain floats rather than tensors.
        n_acc += acc.item()

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    lr_scheduler.step()

    print('epoch {}, train, loss={:.4f} acc={:.4f}'.format(
        epoch, n_loss/loss_ctr, n_acc/loss_ctr))

    model.eval()

    loss_ctr = 0
    n_loss = 0
    n_acc = 0
    # No gradients are needed for validation; skip building the graph.
    with torch.no_grad():
        for i, batch in enumerate(valid_loader):
            loss, acc = fast_adapt(model,
                                   batch,
                                   args.test_way,
                                   args.test_shot,
                                   args.test_query,
                                   metric=pairwise_distances_logits,
                                   device=device)

            loss_ctr += 1
            n_loss += loss.item()
            n_acc += acc.item()

    print('epoch {}, val, loss={:.4f} acc={:.4f}'.format(
        epoch, n_loss/loss_ctr, n_acc/loss_ctr))

loss_ctr = 0
n_acc = 0

# Make sure the model is in eval mode even if the epoch loop did not run.
model.eval()
with torch.no_grad():
    for i, batch in enumerate(test_loader, 1):
        loss, acc = fast_adapt(model,
                               batch,
                               args.test_way,
                               args.test_shot,
                               args.test_query,
                               metric=pairwise_distances_logits,
                               device=device)
        loss_ctr += 1
        n_acc += acc.item()
        # Running mean accuracy, with this batch's accuracy in parentheses.
        print('batch {}: {:.2f}({:.2f})'.format(
            i, n_acc/loss_ctr * 100, acc.item() * 100))

epoch 1, train, loss=1.4211 acc=0.3955
epoch 1, val, loss=1.2869 acc=0.4750
epoch 2, train, loss=1.3834 acc=0.4039
epoch 2, val, loss=1.1972 acc=0.5500
epoch 3, train, loss=1.3224 acc=0.4223
epoch 3, val, loss=1.2364 acc=0.4750
epoch 4, train, loss=1.3185 acc=0.4275
epoch 4, val, loss=1.2403 acc=0.4500
epoch 5, train, loss=1.2715 acc=0.4368
epoch 5, val, loss=1.3038 acc=0.5000
epoch 6, train, loss=1.2731 acc=0.4472
epoch 6, val, loss=1.1967 acc=0.4750
epoch 7, train, loss=1.2945 acc=0.4205
epoch 7, val, loss=1.2858 acc=0.3750
epoch 8, train, loss=1.2712 acc=0.4373
epoch 8, val, loss=1.2726 acc=0.4000
epoch 9, train, loss=1.2715 acc=0.4513
epoch 9, val, loss=1.1791 acc=0.4500
epoch 10, train, loss=1.2905 acc=0.4372
epoch 10, val, loss=1.1582 acc=0.4750
batch 1: 40.00(40.00)
batch 2: 40.00(40.00)
batch 3: 40.00(40.00)
batch 4: 30.00(0.00)
batch 5: 36.00(60.00)
batch 6: 33.33(20.00)
batch 7: 34.29(40.00)
batch 8: 32.50(20.00)
batch 9: 33.33(40.00)


In [12]:
# Sizes of the train, validation and test splits, one per line.
print(len(train_dataset), len(valid_dataset), len(test_dataset), sep='\n')


2994
268
277
