Code taken from https://github.com/radekosmulski/whale/blob/master/oversample.ipynb

In [6]:
%matplotlib inline
import matplotlib.pyplot as plt
from fastai.vision import *
from fastai.metrics import accuracy
from fastai.basic_data import *
from skimage.util import montage
import pandas as pd
from torch import optim
import re

from utils import *

In [7]:
df = pd.read_csv('../data/train.csv')
df.head()

Unnamed: 0,Image,Id
0,0000e88ab.jpg,w_f48451c
1,0001f9222.jpg,w_c3d896a
2,00029d126.jpg,w_20df2c5
3,00050a15a.jpg,new_whale
4,0005c1ef8.jpg,new_whale


In [8]:
im_count = df[df.Id != 'new_whale'].Id.value_counts()
im_count.name = 'sighting_count'
df = df.join(im_count, on='Id')
val_fns = set(df.sample(frac=1)[(df.Id != 'new_whale') & (df.sighting_count > 1)].groupby('Id').first().Image)

  after removing the cwd from sys.path.


In [9]:
pd.to_pickle(val_fns, '../data/10_val_fns')
val_fns = pd.read_pickle('../data/10_val_fns')

In [12]:
fn2label = {row[1].Image: row[1].Id for row in df.iterrows()}

In [13]:
SZ = 224
BS = 64
NUM_WORKERS = 16
SEED=0

In [14]:
path2fn = lambda path: re.search('\w*\.jpg$', path).group(0)

In [15]:
df = df[df.Id != 'new_whale']

In [16]:
df.shape

(15697, 3)

In [17]:
df.sighting_count.max()

73.0

In [18]:
df_val = df[df.Image.isin(val_fns)]
df_train = df[~df.Image.isin(val_fns)]
df_train_with_val = df

In [19]:
df_val.shape, df_train.shape, df_train_with_val.shape

((2931, 3), (12766, 3), (15697, 3))

In [20]:
%%time

res = None
sample_to = 15

for grp in df_train.groupby('Id'):
    n = grp[1].shape[0]
    additional_rows = grp[1].sample(0 if sample_to < n  else sample_to - n, replace=True)
    rows = pd.concat((grp[1], additional_rows))
    
    if res is None: res = rows
    else: res = pd.concat((res, rows))

CPU times: user 14.5 s, sys: 155 µs, total: 14.5 s
Wall time: 14.5 s


In [21]:
%%time

res_with_val = None
sample_to = 15

for grp in df_train_with_val.groupby('Id'):
    n = grp[1].shape[0]
    additional_rows = grp[1].sample(0 if sample_to < n  else sample_to - n, replace=True)
    rows = pd.concat((grp[1], additional_rows))
    
    if res_with_val is None: res_with_val = rows
    else: res_with_val = pd.concat((res_with_val, rows))

CPU times: user 14.4 s, sys: 27.7 ms, total: 14.4 s
Wall time: 14.4 s


In [22]:
res.shape, res_with_val.shape

((76174, 3), (76287, 3))

In [23]:
pd.concat((res, df_val))[['Image', 'Id']].to_csv('../data/10_oversampled_train.csv', index=False)
res_with_val[['Image', 'Id']].to_csv('../data/10_oversampled_train_and_val.csv', index=False)

In [25]:
df = pd.read_csv('../data/10_oversampled_train.csv')

In [30]:
data = (
    ImageItemList
        .from_df(df[df.Id != 'new_whale'], '../data/train', cols=['Image'])
        .split_by_valid_func(lambda path: path2fn(path) in val_fns)
        .label_from_func(lambda path: fn2label[path2fn(path)])
        .add_test(ImageItemList.from_folder('../data/test'))
        .transform(get_transforms(do_flip=False, max_zoom=1, max_warp=0, max_rotate=2), size=SZ, resize_method=ResizeMethod.SQUISH)
        .databunch(bs=BS, num_workers=NUM_WORKERS, path='data')
        .normalize(imagenet_stats)
)

In [31]:
data

ImageDataBunch;

Train: LabelList (76174 items)
x: ImageItemList
Image (3, 224, 224),Image (3, 224, 224),Image (3, 224, 224),Image (3, 224, 224),Image (3, 224, 224)
y: CategoryList
w_0003639,w_0003639,w_0003639,w_0003639,w_0003639
Path: ../data/train;

Valid: LabelList (2931 items)
x: ImageItemList
Image (3, 224, 224),Image (3, 224, 224),Image (3, 224, 224),Image (3, 224, 224),Image (3, 224, 224)
y: CategoryList
w_cb622a2,w_8dddbee,w_8a6a8d5,w_3881f28,w_cee684e
Path: ../data/train;

Test: LabelList (7960 items)
x: ImageItemList
Image (3, 224, 224),Image (3, 224, 224),Image (3, 224, 224),Image (3, 224, 224),Image (3, 224, 224)
y: EmptyLabelList
,,,,
Path: ../data/train

In [None]:
val_fns = pd.read_pickle('../data/10_val_fns')