In [1]:
import os
import numpy as np
import pandas as pd
# from chexpert_specific.model.dataloader import ChexpertDataset
import learn2learn as l2l
from pathlib import Path
import torch
from torch import nn, optim
import PIL.Image as Image
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
import torchvision.models as models

In [2]:
class ChexpertDataset(Dataset):
    def __init__(self, csv_path: Path, split: str) -> None:
        super(ChexpertDataset, self).__init__()
        self.data_path = Path(csv_path).parent
        self.annotations = pd.read_csv(csv_path).fillna(0)
        self.train_annotations = None
        self.split = split
        self.transforms = None
        self.height, self.width = 224, 224
        self.transforms = transforms.Compose([
                transforms.Resize((self.height, self.width)),
                transforms.ToTensor(),
                transforms.Normalize(128, 64),
                transforms.ToPILImage(),
                transforms.Lambda(lambda x: transforms.functional.equalize(x)),
                transforms.ToTensor(),
        ])
        if split == "train":
            # assert cfg.DATA.BATCH_SIZE <= cfg.DATA.LABELED_SIZE, "Batch size must be smaller than train size."
            self.annotations = self.annotations.sample(frac=1).reset_index(drop=True)
            # self.train_annotations = self.annotations[:cfg.DATA.LABELED_SIZE]
            self.train_annotations = self.annotations[:1000]
            self.transforms = transforms.Compose([
                self.transforms,
                transforms.RandomAffine(
                    degrees=(-15, 15),
                    translate=(0.05, 0.05),
                    scale=(0.95, 1.05)
                ),
            ])

    def __len__(self) -> int:
        return self.annotations.shape[0] if self.split != 'train' else self.train_annotations.shape[0]

    def __getitem__(self, index: int) -> None:
        annotations = self.annotations if self.split != 'train' else self.train_annotations
        annotation = annotations.iloc[index]
        image = Image.open(self.data_path.parent / annotation['Path'])
        classes = annotation[['Atelectasis', 'Cardiomegaly', 'Consolidation', 'Edema', 'Pleural Effusion']].values.astype("float32")
        classes = torch.sum(torch.pow(2, torch.arange(5)) * classes)
        data = self.transforms(image)
        return data.repeat(3, 1, 1), classes # torch.from_numpy(classes)

In [3]:
path_ = '/home/smodi9/CheXpert-v1.0-small/'
ds_path = Path(path_)
split = 'train'

dataset = ChexpertDataset(ds_path / f"{split}.csv", split)

In [4]:
type(dataset)

__main__.ChexpertDataset

In [5]:
dataset_l2l = l2l.data.MetaDataset(dataset)

In [6]:
ways=3
shots=1

train_tasks = l2l.data.TaskDataset(dataset_l2l,
                                   task_transforms=[
                                         l2l.data.transforms.NWays(dataset_l2l, ways),
                                         l2l.data.transforms.KShots(dataset_l2l, 2*shots),
                                         l2l.data.transforms.LoadData(dataset_l2l),
#                                          l2l.data.transforms.RemapLabels(dataset_l2l),
#                                          l2l.data.transforms.ConsecutiveLabels(dataset_l2l),
                                    ],
                                    num_tasks=32)

In [7]:
len(train_tasks)

32

In [8]:
train_task = train_tasks.sample()

In [9]:
train_task[0].shape

torch.Size([6, 3, 224, 224])

In [10]:
train_task[1].shape

torch.Size([6])

In [11]:
train_task[1]

tensor([11., 11.,  0.,  0., 20., 20.], dtype=torch.float64)

In [12]:
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.backbone = models.resnet18(pretrained=False)
        self.backbone.fc = nn.Linear(512, 5)

    def forward(self, x):
        return self.backbone(x)

In [13]:
model = Net()
meta_model = l2l.algorithms.MAML(model, lr=0.01)
opt = optim.Adam(meta_model.parameters(), lr=0.005)
loss_func = nn.NLLLoss(reduction='mean')

In [14]:
iterations=1000
device=torch.device("cpu")
tps=32
fas=5

In [15]:
def accuracy(predictions, targets):
    predictions = predictions.argmax(dim=1)
    acc = (predictions == targets).sum().float()
    acc /= len(targets)
    return acc.item()

In [16]:
for iteration in range(iterations):
    iteration_error = 0.0
    iteration_acc = 0.0
    for _ in range(tps):
        learner = meta_model.clone()
        try:
            train_task = train_tasks.sample()
        except ValueError:
            continue
        data, labels = train_task
        data = data.to(device)
        labels = labels.to(device)
        print(labels)

        # Separate data into adaptation/evalutation sets
        adaptation_indices = np.zeros(data.size(0), dtype=bool)
        adaptation_indices[np.arange(shots*ways) * 2] = True
        evaluation_indices = torch.from_numpy(~adaptation_indices)
        adaptation_indices = torch.from_numpy(adaptation_indices)
        adaptation_data, adaptation_labels = data[adaptation_indices], labels[adaptation_indices]
        evaluation_data, evaluation_labels = data[evaluation_indices], labels[evaluation_indices]

        # Fast Adaptation
        for step in range(fas):
            print(adaptation_data.dtype)
            print(adaptation_data.type(torch.LongTensor).dtype)
            print(learner(adaptation_data).dtype)
            print(adaptation_labels.dtype)
            learner_ = learner(adaptation_data.type(torch.float32))
            train_error = loss_func(learner_, adaptation_labels.long())
            learner.adapt(train_error)

        # Compute validation loss
        predictions = learner(evaluation_data)
        valid_error = loss_func(predictions, evaluation_labels)
        valid_error /= len(evaluation_data)
        valid_accuracy = accuracy(predictions, evaluation_labels)
        iteration_error += valid_error
        iteration_acc += valid_accuracy

    iteration_error /= tps
    iteration_acc /= tps
    print('Loss : {:.3f} Acc : {:.3f}'.format(iteration_error, iteration_acc))

    # Take the meta-learning step
    opt.zero_grad()
    iteration_error.backward()
    opt.step()

tensor([ 9.,  9., 11., 11., 19., 19.], dtype=torch.float64)
torch.float32
torch.int64
torch.float32
torch.float64


IndexError: Target 9 is out of bounds.