In [1]:
from torchvision import transforms, models
from torch.utils.data import DataLoader, Dataset
import pytorch_lightning as pl
from src.transformations import EqualizeTransform, CenterCrop, Translate
import numpy as np
import pandas as pd
import torch
import matplotlib.pyplot as plt
from PIL import Image
import json
import os
from argparse import ArgumentParser

In [2]:
# `original_shuffled_map` maps an original image filename (e.g. 'DEV13781.jpg') to its
# shuffled id (e.g. 'SHUF00000'). JSON is UTF-8 by spec, so don't rely on the platform default.
with open('original_shuffled_map.json', encoding='utf-8') as fp:
    original_shuffled_map = json.load(fp)

# Reverse lookup: which original file was assigned the shuffled id "SHUF00000"?
# (list.index returns the first match and raises ValueError if the id is missing)
original_names = list(original_shuffled_map.keys())
shuffled_ids = list(original_shuffled_map.values())
original_names[shuffled_ids.index("SHUF00000")]

'DEV13781.jpg'

In [3]:
def _get_class_labels_df(cls_labels_file):  # e.g. 'data/dev_labels.csv'
    cls_labels = pd.read_csv(cls_labels_file)
    cls_labels['id'] = cls_labels['aimi_id']
    cls_labels['label_num'] = (cls_labels['class'] == 'RG').astype(int)
    cls_labels = cls_labels.drop(columns=['aimi_id', 'class'])
    cls_labels = cls_labels.set_index('id')
    return cls_labels


class ClassifierDataset(Dataset):
    """Map-style dataset of ``(image, label)`` pairs read from image files in ``data_dir``.

    Rows are ordered by ``filename`` so that index ``i`` always refers to the same file.

    Args:
        data_dir: directory containing the image files.
        df_labels: DataFrame with a ``filename`` column (relative to ``data_dir``) and a
            ``label_num`` column. Its index is kept as a regular column (``reset_index``).
        cache_all: if True, decode every image once in ``__init__`` and keep it in memory.
        transform: optional callable applied to the RGB PIL image in ``__getitem__``.
    """

    def __init__(self, data_dir, df_labels, cache_all=False, transform=None):
        self.data_dir = data_dir
        self.transform = transform
        self.cache_all = cache_all
        self.data = {}
        self.df_labels = df_labels.reset_index().sort_values(by='filename', ascending=True)
        self.filenames = self.df_labels.filename.to_numpy()

        if self.cache_all:
            # Decode eagerly in the main process: PIL loads lazily, so otherwise the actual
            # read would happen inside a DataLoader worker process, and PIL has issues with
            # multiprocessing.
            self.data = {i: self._load_image(i) for i in range(len(self.filenames))}

    def _load_image(self, idx):
        """Open the ``idx``-th file, fully decode it as RGB and close the file handle."""
        filepath = os.path.join(self.data_dir, self.filenames[idx])
        with Image.open(filepath) as img:
            # convert() returns a new, fully loaded image that stays valid after the file is closed
            return img.convert('RGB')

    def __len__(self):
        return len(self.filenames)

    def _get_data(self, idx):
        """Return the RGB PIL image for ``idx``, from the cache when ``cache_all`` is set."""
        if self.cache_all:
            return self.data[idx]
        return self._load_image(idx)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for sample ``idx``; ``transform`` is applied if set."""
        if torch.is_tensor(idx):
            idx = idx.tolist()

        img = self._get_data(idx)

        # apply transforms
        if self.transform:
            img = self.transform(img)

        # int() raises ValueError if label_num is NaN (e.g. a file without a label row)
        label = int(self.df_labels.iloc[idx].label_num)
        return img, label


class MyDataModule(pl.LightningDataModule):
    """Lightning data module serving the labelled ``.png`` images found in ``args.data_dir``.

    The files are joined with the labels from ``args.cls_label_file``, split (in sorted-id
    order) into train / validation / test subsets and wrapped in ``ClassifierDataset``s.
    Training uses random affine + horizontal-flip augmentation; validation and test use a
    deterministic center crop.
    """

    @staticmethod
    def add_argparse_args(parent_parser):
        """Register this module's command-line options on ``parent_parser`` and return it."""
        parser = parent_parser.add_argument_group(
            title="MyDataModule", description="Class to organize and manage the data"
        )
        parser.add_argument('--data_dir', type=str, default='./data/cfp_od_crop_OD_f2.0')
        parser.add_argument('--cls_label_file', default='./data/dev_labels.csv')
        parser.add_argument("--equalize", choices=["no", "yes", "IgnoreBlack"], default="no")
        parser.add_argument("--batch_size", default=8, type=int)
        parser.add_argument("--num_workers", default=6, type=int,
                            help="DataLoader worker processes; 0 loads data in the main process")
        # NOTE(review): not read by this class; presumably consumed elsewhere — confirm
        parser.add_argument("--use_validation_set_for_test", action="store_true")
        parser.add_argument("--od_crop_factor", default=1.0, type=float)
        parser.add_argument("--aug_degrees", default=45, type=float)
        parser.add_argument("--aug_translate", default=0, type=float)
        parser.add_argument("--aug_scale", default=0.0, type=float)
        parser.add_argument("--split_num_val_folds", default=5, type=int)
        parser.add_argument("--split_val_fold_idx", default=4, type=int)
        parser.add_argument("--split_test_prop", default=0.0, type=float)
        # type=float: without it a value given on the command line stays a str and the
        # crop-factor arithmetic in __init__ fails.
        parser.add_argument("--DATA_MAX_OD_DIAMETER_PROP",
                            default=2.0,
                            type=float,
                            help="needs to match data generated with lossless_od_crops_using_yolo_predictions.ipynb")
        parser.add_argument("--DATA_CROP_ENLARGMENT_FACTOR",
                            default=2**0.5 * 1.01,  # so that a rotation of the area of interest will not show an artificial border
                            type=float,
                            help="needs to match data generated with  lossless_od_crops_using_yolo_predictions.ipynb")
        return parent_parser

    def __init__(self, args, backbone_transform, backbone_resize):
        """Build the transform pipelines, the data split and the three datasets.

        Args:
            args: namespace holding the options registered by ``add_argparse_args``.
            backbone_transform: transform applied as the last step of both pipelines.
            backbone_resize: transform applied right after the center crop in both pipelines.
        """
        super().__init__()

        # Rescale the requested crop factor / translation by the factors the stored crops
        # were generated with (see the --DATA_* options). float() also tolerates namespaces
        # built without type=float on those options.
        data_scale = float(args.DATA_MAX_OD_DIAMETER_PROP) * float(args.DATA_CROP_ENLARGMENT_FACTOR)
        crop_factor = args.od_crop_factor / data_scale
        aug_translate = args.aug_translate / data_scale

        equalizer = EqualizeTransform(args)
        self.train_transforms = transforms.Compose([
            equalizer,
            transforms.ToTensor(),
            transforms.RandomApply([
                transforms.RandomAffine(degrees=args.aug_degrees,
                                        translate=(aug_translate/2, aug_translate/2),
                                        scale=None)],  # random scaling is done through the cropping
                p=1.0),
            CenterCrop(crop_factor, args.aug_scale/2),
            backbone_resize,
            transforms.RandomHorizontalFlip(),
            backbone_transform,
        ])

        self.test_transforms = transforms.Compose([
            equalizer,
            transforms.ToTensor(),
            CenterCrop(crop_factor),
            backbone_resize,
            backbone_transform,
        ])

        df_labels = self._load_labels(args.data_dir, args.cls_label_file)
        fold_train, fold_val, test_mask = self._split_masks(
            len(df_labels), args.split_test_prop,
            args.split_num_val_folds, args.split_val_fold_idx)

        self.train_ds = ClassifierDataset(args.data_dir, df_labels=df_labels[fold_train],
                                          cache_all=False,
                                          transform=self.train_transforms)
        self.val_ds = ClassifierDataset(args.data_dir, df_labels=df_labels[fold_val],
                                        cache_all=False,
                                        transform=self.test_transforms)
        if args.split_test_prop == 0:
            # no held-out test split: evaluate "test" on the validation fold
            self.test_ds = self.val_ds
        else:
            self.test_ds = ClassifierDataset(args.data_dir, df_labels=df_labels[test_mask],
                                             cache_all=False,
                                             transform=self.test_transforms)
        self.batch_size = args.batch_size
        # getattr keeps namespaces created without --num_workers working with the old default
        self.num_workers = getattr(args, 'num_workers', 6)

    @staticmethod
    def _load_labels(data_dir, cls_label_file):
        """Return one row per ``.png`` file in ``data_dir`` with its label columns.

        The frame is indexed by ``id`` (file name without the ``.png`` extension), sorted
        ascending, and has a ``filename`` column. It is a left join: files without a row in
        the label file are kept with a NaN ``label_num``.
        """
        labels = _get_class_labels_df(cls_label_file).sort_index(ascending=True)
        png_files = [fn for fn in os.listdir(data_dir) if fn.endswith('.png')]
        files = pd.DataFrame(columns=['filename'], data=png_files)
        files['id'] = files['filename'].map(lambda fn: fn[:-4])  # strip '.png'
        files = files.set_index('id')
        merged = pd.merge(files, labels, how='left', left_index=True, right_index=True)  # note the left join
        return merged.sort_index(ascending=True)

    @staticmethod
    def _split_masks(n, test_prop, num_val_folds, val_fold_idx):
        """Split ``n`` ordered samples into boolean ``(train, val, test)`` masks.

        The first ``int(n * (1 - test_prop))`` samples form the development set and all
        remaining samples form the test set. The development set is cut into
        ``num_val_folds`` contiguous folds of ``dev_len // num_val_folds`` samples, the last
        fold also absorbing the remainder; fold ``val_fold_idx`` is the validation set and
        every other non-test sample is training data.
        """
        dev_len = int(n * (1 - test_prop))
        test_mask = np.zeros(n, dtype=bool)
        test_mask[dev_len:] = True  # everything after the dev set, including the last sample

        val_mask = np.zeros(n, dtype=bool)
        fold_len = dev_len // num_val_folds
        start = val_fold_idx * fold_len
        if val_fold_idx < num_val_folds - 1:
            val_mask[start:start + fold_len] = True
        else:
            # last fold: extend to the end of the dev set so the remainder is not always train
            val_mask[start:dev_len] = True

        train_mask = (~val_mask) & (~test_mask)
        return train_mask, val_mask, test_mask

    def _make_dataloader(self, dataset, shuffle=False):
        """Wrap ``dataset`` in a DataLoader using this module's batch size and worker count."""
        return DataLoader(dataset, shuffle=shuffle, batch_size=self.batch_size,
                          num_workers=self.num_workers,
                          # DataLoader rejects persistent_workers=True when num_workers == 0
                          persistent_workers=self.num_workers > 0)

    def train_dataloader(self):
        return self._make_dataloader(self.train_ds, shuffle=True)

    def val_dataloader(self):
        return self._make_dataloader(self.val_ds)

    def test_dataloader(self):
        return self._make_dataloader(self.test_ds)


In [4]:
# Build the data module with the default options. parse_args([]) takes an explicit, empty
# argument list so sys.argv (the kernel's own launch arguments in a notebook) is ignored.
parser = ArgumentParser()
parser = MyDataModule.add_argparse_args(parser)
args = parser.parse_args([])

data = MyDataModule(args,
                    backbone_transform=transforms.Lambda(lambda x: x),  # a no-op
                    backbone_resize=transforms.Resize((224, 224)))

In [5]:
# Visual sanity check of the training augmentation: show up to 8 images of one batch,
# inline, titled with their labels. (PIL's Image.show() would open an external viewer on
# the machine running the kernel instead of rendering in the notebook.)
train_images, train_labels = next(iter(data.train_dataloader()))
to_pil = transforms.ToPILImage()
n_show = min(8, train_images.shape[0])  # the batch may hold fewer than 8 images

fig, axes = plt.subplots(1, n_show, figsize=(3 * n_show, 3), squeeze=False)
for ax, img, label in zip(axes[0], train_images, train_labels):
    ax.imshow(to_pil(img))
    ax.set_title(f"label={int(label)}")
    ax.axis("off")
fig.suptitle("Augmented training batch")
plt.show()