In [None]:
import torch
import numpy as np

# Whether PyTorch can see a CUDA device in this kernel.
train_on_gpu: bool = torch.cuda.is_available()

In [None]:
import pickle
import numpy as np
import pandas as pd
import random
from skimage import io

from tqdm import tqdm, tqdm_notebook
from PIL import Image
from pathlib import Path

from torchvision import transforms
from multiprocessing.pool import ThreadPool
from sklearn.preprocessing import LabelEncoder
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn

from matplotlib import colors, pyplot as plt
import seaborn as sns
%matplotlib inline

import warnings
warnings.filterwarnings(action='ignore', category=DeprecationWarning)

# One seed shared by every random number generator the notebook uses.
SEED = 42

# `random` picks the augmenter for every generated sample, so it must be
# seeded as well for the augmented dataset to be reproducible.
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)
# Request deterministic cuDNN kernels and disable autotuning, which may
# select different algorithms from run to run.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

In [None]:
# Allowed values for Picture_Dataset's `mode` argument.
DATA_MODES = ['train', 'val', 'test']
# Side length (pixels) every image is resized to before entering the model.
RESCALE_SIZE = 224
# Fall back to the CPU so the notebook still runs on machines without CUDA.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
class Picture_Dataset(Dataset):
    """Image dataset whose class label is the name of each file's parent folder.

    Each item is loaded with PIL, converted to RGB, resized to
    ``RESCALE_SIZE`` x ``RESCALE_SIZE``, scaled to [0, 1], turned into a
    float32 CHW tensor and normalised per channel. In ``'test'`` mode an item
    is just the tensor; otherwise it is ``(tensor, label_id)`` where
    ``label_id`` is a Python ``int``.

    Args:
        files: iterable of ``pathlib.Path`` image paths (stored sorted).
        mode: one of ``DATA_MODES``.
        label_encoder: optional, already-fitted ``LabelEncoder`` to reuse
            instead of fitting a new one. Without it every dataset fits its
            own encoder on its own labels, so two datasets whose class sets
            differ (e.g. a validation split lacking a rare class) can map the
            same class to different ids.

    Raises:
        NameError: if ``mode`` is not in ``DATA_MODES``.
    """

    # Per-channel normalisation statistics (the values commonly used with
    # ImageNet-pretrained models). ``ToTensor`` does not rescale float32
    # arrays, so the explicit division by 255 in ``__getitem__`` is what
    # brings pixels into [0, 1] first.
    _MEAN = [0.485, 0.456, 0.406]
    _STD = [0.229, 0.224, 0.225]

    def __init__(self, files, mode, label_encoder=None):
        super().__init__()
        self.files = sorted(files)
        self.mode = mode

        if self.mode not in DATA_MODES:
            raise NameError(f"{self.mode} is not correct; correct modes: {DATA_MODES}")

        self.len_ = len(self.files)

        # Stateless, so build it once rather than on every __getitem__ call.
        self._transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(self._MEAN, self._STD),
        ])

        refit = label_encoder is None
        self.label_encoder = LabelEncoder() if refit else label_encoder

        if self.mode != 'test':
            self.labels = [path.parent.name for path in self.files]
            if refit:
                self.label_encoder.fit(self.labels)
                # Every construction that fits an encoder overwrites this file.
                with open('label_encoder.pkl', 'wb') as le_dump_file:
                    pickle.dump(self.label_encoder, le_dump_file)
            # Encode all labels once instead of calling transform() per item.
            self.label_ids = self.label_encoder.transform(self.labels)

    def __len__(self):
        return self.len_

    def load_sample(self, file):
        """Open ``file`` with PIL and force its pixel data to be read."""
        image = Image.open(file)
        image.load()
        return image

    def __getitem__(self, index):
        """Return the normalised tensor (test mode) or ``(tensor, label_id)``."""
        # Input normalisation & conversion to a tensor.
        x = self.load_sample(self.files[index])
        x = self._prepare_sample(x)
        x = np.array(x / 255, dtype='float32')
        x = self._transform(x)
        if self.mode == 'test':
            return x
        return x, int(self.label_ids[index])

    def _prepare_sample(self, image):
        """Convert ``image`` to RGB, resize it to a square and return an HxWx3 uint8 array."""
        # Grayscale ('L'), palette ('P') or RGBA images would otherwise produce
        # arrays that the 3-channel Normalize cannot handle.
        image = image.convert('RGB').resize((RESCALE_SIZE, RESCALE_SIZE))
        return np.array(image)

In [None]:
def imshow(inp, title=None, plt_ax=plt, default=False):
    """Show a normalised (C, H, W) image after undoing the per-channel normalisation.

    Args:
        inp: image of shape (C, H, W) with C == 3, normalised with the same
            mean/std as Picture_Dataset. May be a torch tensor (any device,
            with or without grad) or a NumPy array.
        title: optional title for the axes.
        plt_ax: matplotlib ``Axes`` to draw on. The default (the ``pyplot``
            module) draws on the current axes.
        default: unused; kept for backward compatibility.
    """
    if isinstance(inp, torch.Tensor):
        # .numpy() fails on CUDA tensors and on tensors that require grad.
        inp = inp.detach().cpu().numpy()
    inp = np.asarray(inp).transpose((1, 2, 0))  # CHW -> HWC for matplotlib
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])
    inp = np.clip(std * inp + mean, 0, 1)

    # The pyplot module has no `set_title`, so resolve the default argument to
    # the current Axes; otherwise passing a title would raise AttributeError.
    ax = plt.gca() if plt_ax is plt else plt_ax
    ax.imshow(inp)
    if title is not None:
        ax.set_title(title)
    ax.grid(False)

In [None]:
TRAIN_DIR = Path('/train')
TEST_DIR = Path('/test')

# Every JPEG under each root (searched recursively), in a stable sorted order.
train_val_files, test_files = (
    sorted(root.rglob('*.jpg')) for root in (TRAIN_DIR, TEST_DIR)
)

In [None]:
from sklearn.model_selection import train_test_split

# Stratify on the class (parent folder name) so both splits keep the same class proportions.
train_val_labels = [file_path.parent.name for file_path in train_val_files]
train_files, val_files = train_test_split(
    train_val_files,
    test_size=0.20,
    stratify=train_val_labels,
    random_state=111,
)

val_dataset = Picture_Dataset(val_files, mode='val')

### Class distribution (images per character)

In [None]:
from collections import Counter

# Images per class (class = parent folder name) over the full train + val set.
data_dict = dict(Counter(file_path.parent.name for file_path in train_val_files))
data = pd.DataFrame({'count': list(data_dict.values())}, index=list(data_dict.keys()))

plt.figure(figsize=(20, 10))
bar_ax = sns.barplot(x=data.index, y=data['count'])
bar_ax.set_xticklabels(data.index, rotation=90)
plt.show()

### Augmentation

In [None]:
def _resize_crop_steps():
    """Return fresh steps: shorter side -> 300 px, centre 300x300 crop, random 250x250 crop.

    No ``max_size`` is passed to ``Resize``. With ``max_size=301`` the longer
    side is capped at 301 px, so any non-square image ends up with a shorter
    side below 300 and ``CenterCrop(300)`` pads it with black borders.
    """
    return [
        transforms.Resize(size=300),
        transforms.CenterCrop(size=300),
        transforms.RandomCrop(250),
    ]


# Named augmentation pipelines operating on PIL images. Probabilities are
# capped at 1.0 (values above 1 behaved identically but are not valid
# probabilities).
augmenters = {
    'Crop_comb': transforms.Compose(_resize_crop_steps()),
    'Perspective': transforms.Compose([
        transforms.RandomPerspective(distortion_scale=0.3, p=1.0),
        *_resize_crop_steps(),
    ]),
    'Rotate': transforms.RandomRotation(degrees=(-25, 25)),
    'HFlip': transforms.RandomHorizontalFlip(p=1),
    'Comb1': transforms.Compose([
        transforms.RandomPerspective(distortion_scale=0.3, p=1.0),
        transforms.RandomHorizontalFlip(p=1),
    ]),
    'Comb2': transforms.Compose([
        transforms.RandomPerspective(distortion_scale=0.3, p=1.0),
        transforms.RandomHorizontalFlip(p=1),
        transforms.RandomRotation(degrees=(-25, 25)),
        *_resize_crop_steps(),
    ]),
    'Comb3': transforms.Compose([
        transforms.RandomPerspective(distortion_scale=0.2, p=1.0),
        transforms.RandomHorizontalFlip(p=1),
        transforms.RandomRotation(degrees=(-15, 15)),
        *_resize_crop_steps(),
    ]),
}

In [None]:
train_dataset = Picture_Dataset(train_val_files, mode='train')

n_examples = 5
fig, ax = plt.subplots(nrows=n_examples, ncols=len(augmenters) + 1, figsize=(10, 10))

# Each row: one randomly chosen image (titled with its class) followed by every augmentation of it.
for row in range(n_examples):
    random_character = int(np.random.uniform(0, len(train_val_files)))
    source_path = train_val_files[random_character]
    img_orig = train_dataset.load_sample(source_path)

    panels = [(source_path.parent.name, img_orig)]
    panels += [(name, aug(img_orig)) for name, aug in augmenters.items()]

    for col, (panel_title, panel_img) in enumerate(panels):
        panel_ax = ax[row][col]
        panel_ax.imshow(panel_img)
        panel_ax.set_title(panel_title)
        panel_ax.axis('off')

### Augmentation quotas: set the per-class target image count (1500 by default) to a suitable value

In [None]:
from collections import Counter

# Target number of training images per class after oversampling.
TARGET_COUNT = 1500

# Quotas are computed from the *training* split, because the augmentation loop
# iterates over `train_files` only; train+val counts would under-fill classes.
train_counts = Counter(path.parent.name for path in train_files)
data = pd.DataFrame(
    {'count': list(train_counts.values())}, index=list(train_counts.keys())
).sort_index()

# True where a class has fewer than TARGET_COUNT images.
is_enght = data['count'] < TARGET_COUNT
# Augmented images each class still needs (0 when it is already large enough).
data['add'] = (TARGET_COUNT - data['count']).where(is_enght, 0).astype(int)
# Augmented copies to generate from each source image to reach the quota.
data['from_one_image'] = np.ceil(data['add'] / data['count']).astype(int)
data

In [None]:
import os

create_dir = Path('/sample')
create_dir.mkdir(parents=True, exist_ok=True)

proc_dataset = Picture_Dataset(train_files, mode='train')
augmenter_pool = list(augmenters.values())

for image_path in tqdm(train_files):
    character = image_path.parent.name

    # Augmented images this class still needs.
    remaining = int(data.loc[character, 'add'])
    if remaining <= 0:
        continue
    iter_size = min(int(data.loc[character, 'from_one_image']), remaining)
    # Single .loc call: chained indexing (`data.loc[c]['add'] -= n`) can update
    # a temporary copy and silently lose the decrement.
    data.loc[character, 'add'] -= iter_size

    # Load only images that will actually be augmented.
    img = proc_dataset.load_sample(image_path)
    path = create_dir / character
    path.mkdir(exist_ok=True)

    for i in range(iter_size):
        augmenter = random.choice(augmenter_pool)
        aug_img = augmenter(img)
        # JPEG cannot store alpha/palette images, so convert to RGB before saving.
        aug_img.convert('RGB').save(path / f"{image_path.stem}_{i}.jpg")

In [None]:
# Root folder holding the generated (augmented) training images.
SAMPLE_DIR = Path('/') / 'sample'

In [None]:
train_sample_files = sorted(SAMPLE_DIR.rglob('*.jpg'))

# Labels of the generated samples (class = parent folder name).
sample_labels = [path.parent.name for path in train_sample_files]

# Rebind and de-duplicate (order preserved) instead of extending in place,
# so re-running this cell does not add the generated files a second time.
train_files = list(dict.fromkeys([*train_files, *train_sample_files]))

In [None]:
# Images per class in the training file list, re-counted after the merge cell.
train_class_counts = Counter(file_path.parent.name for file_path in train_files)
data_dict = dict(train_class_counts)
data = pd.DataFrame({'count': list(data_dict.values())}, index=list(data_dict.keys()))

plt.figure(figsize=(20, 10))
bar_ax = sns.barplot(x=data.index, y=data['count'])
bar_ax.set_xticklabels(data.index, rotation=90)
plt.show()

In [None]:
# Training dataset built from the current training file list.
train_dataset = Picture_Dataset(files=train_files, mode='train')