In [None]:
import pandas as pd
import numpy as np
import os
import multiprocessing as mp

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import torchvision.transforms as T
import torchvision.transforms.functional as TF

from PIL import Image
from skimage import feature
from skimage import filters
from skimage import exposure
from skimage import img_as_float
from skimage import io

from sklearn.model_selection import train_test_split
from torch.utils.data import Dataset, DataLoader, Sampler

from datetime import datetime
from tqdm.notebook import tqdm as tqdm
from utils import mkdir, rmdir

import matplotlib.pyplot as plt
import seaborn as sns
# Notebook-wide plot styling: set the base font size, then apply the ggplot style sheet.
plt.rcParams['font.size'] = 10
plt.style.use('ggplot')

In [None]:
# Not referenced in the cells visible here — presumably the DataLoader worker count; TODO confirm.
NUM_WORKERS = 24
# Not referenced in the cells visible here — presumably the number of target classes; TODO confirm.
CLASS_NUM = 43

# Size handed to T.Resize / T.CenterCrop in every transform pipeline.
IMG_SIZE = (64, 64)

# `degrees` argument of T.RandomAffine in the 'train' pipeline.
MAX_ROTATION = 40

# NOTE(review): the brightness / contrast / saturation bounds below are not used by any
# transform in the visible cells — presumably meant for a colour-jitter step; confirm or remove.
MIN_BRIGHTNESS = 0.4
MAX_BRIGHTNESS = 2.0

MIN_CONTRAST = 0.4
MAX_CONTRAST = 2.0

MIN_SATURATION = 0.4
MAX_SATURATION = 1.6

# First positional argument of T.RandomPerspective in the 'train' pipeline.
DISTORTION = 0.4
# Used for both entries of the `translate` pair of T.RandomAffine.
MAX_TRANSITION = 0.25

# Lower / upper bound of the `scale` pair of T.RandomAffine.
MIN_SCALE = 0.7
MAX_SCALE = 1.5

# Upper bound of the uniformly sampled noise standard deviation in RandomNoise.
NOISE_STD = 0.2

In [None]:
def RandomNoise(tensor, max_std=None):
    """Add zero-mean Gaussian noise of random strength and clamp the result to [0, 1].

    The noise standard deviation is drawn uniformly from ``[0, max_std)`` on
    every call, so each sample gets a different noise level.

    Args:
        tensor: Image tensor to corrupt. It is not modified in place.
        max_std: Upper bound for the sampled standard deviation. ``None``
            (the default) falls back to the module-level ``NOISE_STD``.

    Returns:
        A new tensor with the same shape as ``tensor`` and values in [0, 1].
    """
    if max_std is None:
        max_std = NOISE_STD

    # Sample the strength from torch's RNG rather than NumPy's global RNG:
    # DataLoader seeds torch per worker, whereas NumPy's state may be duplicated
    # across forked workers (notably on older PyTorch releases), which would
    # give every worker the same sequence of noise levels.
    std = torch.rand(1).item() * max_std

    # Create the noise on the input's device so the transform also works for
    # non-CPU tensors; dtype promotion in the addition matches the original.
    noise = torch.randn(tensor.shape, device=tensor.device) * std
    return torch.clamp(tensor + noise, 0, 1)

In [None]:
# Transform pipelines keyed by usage. Every pipeline ends by resizing / cropping to
# IMG_SIZE and converting to a tensor; only 'train' adds random augmentation.
my_transforms = {
    # Plain preprocessing, no augmentation.
    'default': T.Compose([
        T.Resize(size=IMG_SIZE),
        T.CenterCrop(size=IMG_SIZE),
        T.ToTensor(),
    ]),
    # Geometric augmentation on the PIL image, then additive noise on the tensor.
    'train': T.Compose([
        T.RandomAffine(
            degrees=MAX_ROTATION,
            translate=(MAX_TRANSITION, MAX_TRANSITION),
            scale=(MIN_SCALE, MAX_SCALE),
        ),
        T.RandomPerspective(distortion_scale=DISTORTION),
        T.Resize(size=IMG_SIZE),
        T.CenterCrop(size=IMG_SIZE),
        T.ToTensor(),
        T.Lambda(RandomNoise),
    ]),
    # Same steps as 'default'.
    'test': T.Compose([
        T.Resize(size=IMG_SIZE),
        T.CenterCrop(size=IMG_SIZE),
        T.ToTensor(),
    ]),
}

In [None]:
# CSV files passed as `path` to BengaliDataset (read there with pd.read_csv).
train_path = 'dataset/train.csv'
test_path = 'dataset/test.csv'

# Directory joined with '<image_id>.png' to locate each image file.
image_path = 'dataset/images/'

In [None]:
def load_image(path, as_gray=True):
    """Read an image file and return its negative (255 - pixel) as a PIL image.

    Args:
        path: Location of the image file, forwarded to ``skimage.io.imread``.
        as_gray: Forwarded to ``skimage.io.imread``.

    Returns:
        ``PIL.Image.Image`` built from the inverted pixel array.
    """
    pixels = io.imread(path, as_gray=as_gray)

    # With as_gray=True skimage leaves files that are already greyscale untouched,
    # but converts colour files to float64 in [0, 1]. Subtracting such an array
    # from 255 would give values near 254-255 in a float-mode PIL image, so bring
    # it back to the 0-255 uint8 range first. uint8 input is unaffected.
    if as_gray and np.issubdtype(pixels.dtype, np.floating):
        pixels = np.rint(np.clip(pixels, 0.0, 1.0) * 255).astype(np.uint8)

    return Image.fromarray(255 - pixels)

In [None]:
class BengaliDataset(Dataset):
    """Dataset that loads every image listed in a CSV into memory up front.

    Depending on ``target`` an item is either

    * ``'reconstruct'``: ``(transform[0](image), transform[1](image))`` — two
      views of the same image, or
    * ``'classify'``: ``(transform[0](image), label)``.

    Args:
        path: CSV file with an ``image_id`` column and, when ``is_train`` is
            true, the columns ``grapheme_root``, ``vowel_diacritic`` and
            ``consonant_diacritic``.
        image_path: Directory containing ``<image_id>.png`` files.
        target: ``'classify'`` or ``'reconstruct'``.
        transform: A single callable (used for both views) or a tuple of
            callables. ``None`` selects ``my_transforms['default']``.
        is_train: When false, every label is ``None``.
        verbose: Show a tqdm progress bar while the images are read.
        num_workers: Number of processes used to read the images.
    """

    def __init__(self, path, image_path,
                 target='classify',
                 transform=None,
                 is_train=True, verbose=True,
                 num_workers=8):
        self.is_train = is_train

        # Resolve the default when the dataset is built rather than when the class
        # is defined, so re-running the transforms cell is picked up.
        if transform is None:
            transform = my_transforms['default']
        # Always stored as a tuple so __getitem__ can index it uniformly.
        if not isinstance(transform, tuple):
            self.transform = (transform, transform)
        else:
            self.transform = transform

        assert target in {'reconstruct', 'classify'}, f'unknown target: {target!r}'
        self.target = target

        # reading labels
        # One row per image: de-duplicate before taking ids AND labels so that
        # self.labels[i] always belongs to self.image_ids[i].
        df = pd.read_csv(path).drop_duplicates(subset='image_id')
        self.image_ids = df['image_id'].to_numpy()
        self.length = len(self.image_ids)

        if self.is_train:
            self.labels = df[['grapheme_root', 'vowel_diacritic', 'consonant_diacritic']].values
        else:
            self.labels = [None] * self.length

        # reading images
        image_files = [os.path.join(image_path, f'{x}.png') for x in self.image_ids]
        with mp.Pool(num_workers) as pool:
            loaded = pool.imap(load_image, image_files)
            if verbose:
                loaded = tqdm(loaded, total=self.length)
            # Consume the lazy iterator while the pool is still alive.
            self.images = list(loaded)

    def __getitem__(self, idx):
        """Return the item at ``idx`` in the form selected by ``self.target``."""
        image = self.images[idx]
        if self.target == 'reconstruct':
            return self.transform[0](image), self.transform[1](image)

        # self.transform is always a tuple; the first entry is the input transform.
        return self.transform[0](image), self.labels[idx]

    def __len__(self):
        """Number of images in the dataset."""
        return self.length
    
# Reconstruction datasets: the first transform produces the network input, the second
# ('default') produces the reconstruction target.
Autoencoder_Train = BengaliDataset(
    path=train_path,
    image_path=image_path,
    target='reconstruct',
    transform=(my_transforms['train'], my_transforms['default']),
    is_train=True,
    verbose=True,
)
Autoencoder_Test = BengaliDataset(
    path=test_path,
    image_path=image_path,
    target='reconstruct',
    transform=(my_transforms['test'], my_transforms['default']),
    is_train=False,
    verbose=True,
)

# Visual sanity check: show the first five (input, target) pairs of the training set.
for idx in range(5):
    img, target = Autoencoder_Train[idx]

    # Report size and type, target first and then the input.
    for _tensor in (target, img):
        print(_tensor.size(), type(_tensor))

    # Draw the input, then the target, each as its own greyscale figure.
    for _tensor in (img, target):
        plt.imshow(TF.to_pil_image(_tensor), cmap='gray')
        plt.axis('off')
        plt.show()