In [None]:
import numpy as np 
import pandas as pd
import warnings
warnings.filterwarnings('ignore')

from sklearn.model_selection import train_test_split

import torch
import torch.nn as nn
import torch.utils.data as Dataset

In [None]:
SEED = 2021


def seedTorch(seed=SEED):
    """Seed every random number generator this notebook can draw from.

    Args:
        seed: Integer seed applied to Python's ``random`` module, NumPy's
            global generator, and PyTorch (CPU and all CUDA devices).
    """
    import random  # local import keeps this cell runnable on its own

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Covers every visible GPU, not just the current one; PyTorch silently
    # ignores this call when CUDA is unavailable.
    torch.cuda.manual_seed_all(seed)
    # Give up cuDNN autotuning in exchange for run-to-run reproducibility.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


seedTorch()

In [None]:
# Dataset locations: every path is derived from ROOT_PATH so the dataset can be
# relocated by editing a single line.
ROOT_PATH = "../input/recursion-cellular-image-classification"
TRAIN_PATH = f"{ROOT_PATH}/train.csv"
TEST_PATH = f"{ROOT_PATH}/test.csv"
TRAIN_IMAGE_PATH = f"{ROOT_PATH}/train"
TEST_IMAGE_PATH = f"{ROOT_PATH}/test"

# Fraction of the rows in train.csv held out for validation.
TEST_SIZE = 0.25

# Prefer the GPU, but fall back to the CPU so the notebook still runs on a
# kernel without CUDA instead of failing at the first `.to(device)`.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
BATCH_SIZE = 32
# Channel numbers substituted into the `_w{channel}` suffix of image file names.
CHANNELS = [1, 2, 3, 4, 5, 6]

In [None]:
# Read the labelled metadata and the unlabelled test metadata.
train_val = pd.read_csv(TRAIN_PATH)
test = pd.read_csv(TEST_PATH)

# Hold out TEST_SIZE of the labelled rows for validation; the fixed
# random_state makes the split identical on every run.
train, val = train_test_split(
    train_val,
    random_state=SEED,
    test_size=TEST_SIZE,
)

Sample image path: `../input/recursion-cellular-image-classification/test/HEPG2-08/Plate1/B02_s1_w1.png`

In [None]:
class ImagesDS(Dataset.Dataset):
    """Dataset that yields one multi-channel image tensor per metadata row.

    Each item is built by loading one PNG per channel and concatenating the
    per-channel tensors with ``torch.cat``. File paths follow the layout
    ``{img_dir}/{mode}/{experiment}/Plate{plate}/{well}_s{site}_w{channel}.png``.

    Args:
        df: DataFrame whose rows expose ``experiment``, ``plate`` and ``well``
            fields, plus ``sirna`` when ``mode == 'train'`` and ``id_code``
            otherwise.
        img_dir: Root directory containing the ``mode`` sub-directory.
        mode: Name of the sub-directory to read from. ``'train'`` also makes
            ``__getitem__`` return the integer ``sirna`` label; any other value
            makes it return the row's ``id_code`` instead.
        site: Site number substituted into the ``_s{site}`` part of the file
            name. Defaults to 1, the previously hard-coded value.
        channels: Iterable of channel numbers to load, in order. Defaults to
            the module-level ``CHANNELS``.
    """

    def __init__(self, df, img_dir, mode='train', site=1, channels=None):
        self.records = df.to_records(index=False)
        # Copy the list so later edits to the caller's list (or to the global
        # CHANNELS) cannot silently change which files this dataset reads.
        self.channels = list(CHANNELS if channels is None else channels)
        self.site = site
        self.mode = mode
        self.img_dir = img_dir
        self.len = df.shape[0]

    @staticmethod
    def _load_img_as_tensor(file_name):
        """Open the image at ``file_name`` and convert it with ``T.ToTensor()``."""
        # NOTE(review): `Image` and `T` are not imported in the visible cells --
        # presumably `PIL.Image` and `torchvision.transforms`; confirm the
        # imports exist before the first item is fetched.
        with Image.open(file_name) as img:
            return T.ToTensor()(img)

    def _get_img_path(self, index, channel):
        """Return the PNG path for row ``index`` and channel number ``channel``."""
        record = self.records[index]
        return '/'.join([
            self.img_dir,
            self.mode,
            record.experiment,
            f'Plate{record.plate}',
            f'{record.well}_s{self.site}_w{channel}.png',
        ])

    def __getitem__(self, index):
        """Return ``(image, label)`` in train mode, else ``(image, id_code)``."""
        paths = [self._get_img_path(index, ch) for ch in self.channels]
        # torch.cat joins along dim 0 -- presumably the channel axis of each
        # single-channel tensor; TODO confirm against the image format.
        img = torch.cat([self._load_img_as_tensor(img_path) for img_path in paths])
        record = self.records[index]
        if self.mode == 'train':
            return img, int(record.sirna)
        return img, record.id_code

    def __len__(self):
        """Return the number of rows in the source DataFrame."""
        return self.len

# Make image datasets

In [None]:
# The validation rows come from train.csv, so they are read in 'train' mode:
# that mode selects the train image directory and returns sirna labels.
train_ds = ImagesDS(df=train, img_dir=ROOT_PATH, mode='train')
val_ds = ImagesDS(df=val, img_dir=ROOT_PATH, mode='train')
# 'test' mode reads the test image directory and returns id_code, not a label.
test_ds = ImagesDS(df=test, img_dir=ROOT_PATH, mode='test')

For reference, the signature of `torch.utils.data.DataLoader`:

```python
DataLoader(dataset, batch_size=1, shuffle=False, sampler=None,
           batch_sampler=None, num_workers=0, collate_fn=None,
           pin_memory=False, drop_last=False, timeout=0,
           worker_init_fn=None, *, prefetch_factor=2,
           persistent_workers=False)
```

# Make data loaders

In [None]:
# Only the training stream is shuffled. Validation and test batches keep the
# DataFrame order so evaluation is deterministic and predictions line up with
# their source rows.
train_loader = Dataset.DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True)
val_loader = Dataset.DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False)
test_loader = Dataset.DataLoader(test_ds, batch_size=BATCH_SIZE, shuffle=False)