In [None]:
import numpy as np 
import pandas as pd 
import os
import matplotlib.pyplot as plt
import cv2

In [None]:
TRAIN_PATH = 'data/full/train/'
IGNORE = ['.DS_Store']

train_images = []
train_labels = []
skipped_files = []

# Sort directory listings so image indices (and hence the seeded random
# examples drawn from them) are reproducible across filesystems.
for label in sorted(os.listdir(TRAIN_PATH)):
    full_path = os.path.join(TRAIN_PATH, label)
    # Each class is a sub-directory named after its label; skip stray files.
    if label in IGNORE or not os.path.isdir(full_path):
        continue

    for filename in sorted(os.listdir(full_path)):
        if filename in IGNORE:
            continue
        file_path = os.path.join(full_path, filename)
        img = cv2.imread(file_path)
        # cv2.imread does not raise on unreadable / non-image files; it
        # returns None, which would crash any later array operation.
        if img is None:
            skipped_files.append(file_path)
            continue
        train_images.append(img)
        train_labels.append(label)

if skipped_files:
    print(f'Skipped {len(skipped_files)} unreadable file(s), e.g. {skipped_files[0]}')

# Exploratory Data Analysis (EDA)

## Image Examples

In [None]:
# Show two distinct, randomly chosen training images with their class labels.
rng = np.random.default_rng(9)
# choice() draws indices from [0, len) without replacement, so it can neither
# produce the out-of-range index len(train_images) nor pick the same image twice.
example_idx = rng.choice(len(train_images), size=2, replace=False)

fig, axs = plt.subplots(1, 2, figsize=(10, 10))
for ax, idx in zip(axs, example_idx):
    # cv2.imread returns BGR channel order; reverse the channel axis so
    # matplotlib (which expects RGB) shows the true colours.
    ax.imshow(np.flip(train_images[idx], axis=-1))
    ax.set_title(train_labels[idx])
    ax.set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])

fig.savefig('Example_of_images.png')

## Augmentation Filter Examples

In [None]:
from PIL import Image
from pathlib import Path
import matplotlib.pyplot as plt
import numpy as np

import torch
import torchvision.transforms as T


plt.rcParams["savefig.bbox"] = 'tight'

# Reference image used to illustrate every augmentation filter. Reuse the
# training root instead of repeating it. Load the pixels and close the file
# handle, and convert to RGB so every transform and imshow sees a 3-channel
# image whatever mode the JPEG was stored in.
with Image.open(Path(TRAIN_PATH) / 'curling' / '018.jpg') as img_file:
    orig_img = img_file.convert('RGB')

# Seed torch so the random transforms applied to this image are reproducible.
torch.manual_seed(0)


def plot(imgs, with_orig=True, row_title=None, col_titles=None,
         save_path='Augmentation.png', **imshow_kwargs):
    """Plot a grid of images, optionally prefixing each row with ``orig_img``.

    Args:
        imgs: a list of images (a single row) or a list of lists (one row
            each). Every image must be convertible with ``np.asarray``.
        with_orig: if True, prepend the module-level ``orig_img`` to each row.
        row_title: optional sequence of y-axis labels, one per row.
        col_titles: optional titles for the first row's columns. When None and
            ``with_orig`` is True, defaults to 'Original Image' followed by
            HorizontalFlip, VerticalFlip, Equalize, Perspective, Autocontrast.
            Surplus titles are ignored, and missing ones leave columns untitled.
        save_path: file the figure is saved to; pass None to skip saving.
        **imshow_kwargs: forwarded to ``Axes.imshow``.
    """
    if not isinstance(imgs[0], list):
        imgs = [imgs]

    num_rows = len(imgs)
    num_cols = len(imgs[0]) + int(with_orig)
    fig, axs = plt.subplots(nrows=num_rows, ncols=num_cols, squeeze=False, figsize=(15, 15))
    for row_idx, row in enumerate(imgs):
        row = [orig_img, *row] if with_orig else row
        for col_idx, img in enumerate(row):
            ax = axs[row_idx, col_idx]
            ax.imshow(np.asarray(img), **imshow_kwargs)
            ax.set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])

    if col_titles is None and with_orig:
        col_titles = ['Original Image', 'HorizontalFlip', 'VerticalFlip',
                      'Equalize', 'Perspective', 'Autocontrast']
    # zip() stops at the shorter sequence, so a grid with fewer columns than
    # titles no longer raises IndexError.
    for ax, title in zip(axs[0], col_titles or []):
        ax.set_title(title, fontsize=15)

    if row_title is not None:
        for row_idx in range(num_rows):
            axs[row_idx, 0].set(ylabel=row_title[row_idx])

    plt.tight_layout()
    if save_path is not None:
        fig.savefig(save_path)

In [None]:
# One instance of each augmentation, with p=1 so the transform is always
# applied and its effect is visible.
eq = T.RandomEqualize(p=1)
h_flip = T.RandomHorizontalFlip(p=1)
v_flip = T.RandomVerticalFlip(p=1)
pers = T.RandomPerspective(p=1, distortion_scale=0.4)
contrast = T.RandomAutocontrast(p=1)

# Order matters: plot() titles its columns HorizontalFlip, VerticalFlip,
# Equalize, Perspective, Autocontrast (after the original image).
all_filters = [h_flip, v_flip, eq, pers, contrast]

# Apply every filter to the same reference image.
images = [custom_filter(orig_img) for custom_filter in all_filters]
plot(images, row_title=None)

## Training Visualization

In [None]:
# summary.csv has no header row; label its four columns explicitly.
data = (
    pd.read_csv('remote_project/stat254/summary.csv', header=None)
    .set_axis(['date', 'name', 'training_time', 'accuracy'], axis=1)
)

In [None]:
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import glob
import os

def get_logs(path, cls):
    """Load and stack every training-log CSV matching a glob pattern.

    Args:
        path: root directory whose sub-directories contain the CSV logs.
        cls: glob pattern relative to ``path``, e.g. ``'*_native/*.csv'``.

    Returns:
        One DataFrame with the rows of all matched files (fresh RangeIndex)
        plus a ``model`` column holding the containing directory's name up
        to its first underscore (``'resnet_native/log.csv'`` -> ``'resnet'``).

    Raises:
        ValueError: if no file matches the pattern.
    """
    pattern = os.path.join(path, cls)
    # glob order is filesystem-dependent; sort for a deterministic row order.
    all_files = sorted(glob.glob(pattern))
    if not all_files:
        raise ValueError(f'No log files match {pattern!r}')

    frames = []
    for filename in all_files:
        df = pd.read_csv(filename, index_col=None, header=0)
        # os.path handles the platform's separator, unlike split('/').
        run_dir = os.path.basename(os.path.dirname(filename))
        df['model'] = run_dir.split('_')[0]
        frames.append(df)
    return pd.concat(frames, axis=0, ignore_index=True)

In [None]:
# All training logs live under one weights directory; each experiment's runs
# sit in sub-folders whose names end with the experiment suffix.
path = 'remote_project/stat254/weights/'

data_native = get_logs(path, '*_native/*.csv')   # no augmentation
data_soft = get_logs(path, '*_soft_aug/*.csv')   # two filters
data_hard = get_logs(path, '*_hard_aug/*.csv')   # multiple filters

In [None]:
# Validation loss per epoch for each model, one panel per augmentation
# strategy. sharey=True puts every panel on the same y-scale so the
# strategies can be compared directly rather than by eyeballing tick labels.
experiments = [
    (data_native, 'No augmentation'),
    (data_soft, 'Two filters'),
    (data_hard, 'Multiple filters augmentation'),
]

fig, ax = plt.subplots(1, 3, figsize=(15, 5), sharey=True)
for panel, (logs, title) in zip(ax, experiments):
    sns.lineplot(data=logs, x='epoch', y='valid_loss', hue='model', ax=panel)
    panel.set_title(title)
fig.tight_layout()