Code taken from the oversampling notebook in the radekosmulski/whale repository: https://github.com/radekosmulski/whale/blob/master/oversample.ipynb

In [1]:
%matplotlib inline
import matplotlib.pyplot as plt
from fastai.vision import *
from fastai.metrics import accuracy
from fastai.basic_data import *
from skimage.util import montage

import pandas as pd
from torch import optim
import re

from utils import *

In [2]:
import fastai
from fastprogress import force_console_behavior
import fastprogress
# Switch fastprogress to its console implementation with NO_BAR set (presumably to keep
# saved cell outputs free of bar redraws), then make fastai's training loop use those bars.
fastprogress.fastprogress.NO_BAR = True
master_bar, progress_bar = force_console_behavior()
fastai.basic_train.master_bar = master_bar
fastai.basic_train.progress_bar = progress_bar

In [3]:
MODEL_PATH = '../model/'  # passed as model_dir when the learner is created below

In [5]:
# train.csv pairs each image filename (Image) with its whale Id; preview the first rows.
df = pd.read_csv('../data/train.csv')
df.head(5)

Unnamed: 0,Image,Id
0,0000e88ab.jpg,w_f48451c
1,0001f9222.jpg,w_c3d896a
2,00029d126.jpg,w_20df2c5
3,00050a15a.jpg,new_whale
4,0005c1ef8.jpg,new_whale


In [6]:
# Filename -> whale Id lookup built from train.csv (the `df` loaded above); used to label
# images by filename. zip over the two columns avoids materialising a Series per row as
# iterrows() does; on duplicate filenames the last row wins, as before.
fn2label = dict(zip(df.Image, df.Id))
def path2fn(path):
    """Return the trailing ``<stem>.jpg`` filename of an image path.

    Accepts plain strings or path-like objects (converted with ``str``). The
    pattern only allows word characters before ``.jpg``, so for a name with
    other characters (e.g. ``a-b.jpg``) only the final word-run (``b.jpg``)
    is returned.

    Raises:
        ValueError: if ``path`` does not end in a ``.jpg`` filename.
    """
    # Raw string: '\w' / '\.' in a normal literal are invalid escapes (warning in newer Pythons).
    match = re.search(r'\w*\.jpg$', str(path))
    if match is None:
        raise ValueError(f'no .jpg filename found in {path!r}')
    return match.group(0)

In [7]:
# Filenames held out for validation; the data pipeline below tests `path2fn(path) in val_fns`.
# NOTE(review): read_pickle can execute arbitrary code — only load trusted files.
# NOTE(review): if this is a list rather than a set, each membership test during the split is linear.
val_fns = pd.read_pickle('../data/10_val_fns')

In [8]:
# Base 224 px configuration. The next config cell overrides SZ and BS for the 448 px stage.
# NOTE(review): SEED is not referenced by any cell shown here.
SZ, BS = 224, 64
NUM_WORKERS = 16
SEED = 0

In [9]:
# Experiment name: prefixes the checkpoint loaded below and names the submission file.
name = '11-res50-full-train'

In [10]:
# 448 px stage: double the image size and quarter the batch size (presumably to fit GPU memory).
SZ = 2 * 224
BS = 64 // 4
NUM_WORKERS, SEED = 16, 0

In [11]:
# Oversampled train+validation image list (per the filename). This rebinds `df`;
# fn2label was built earlier from train.csv and is not affected.
df = pd.read_csv('../data/10_oversampled_train_and_val.csv')

In [12]:
def _is_validation_image(path):
    """True when the image's filename is listed in ``val_fns`` (the held-out split)."""
    return path2fn(path) in val_fns


def _whale_id_for(path):
    """Whale Id for an image, looked up by filename in ``fn2label``."""
    return fn2label[path2fn(path)]


# NOTE(review): the validation preview below shows repeated Ids, so oversampled duplicates
# appear to end up in the validation split too — validation metrics may be inflated.
data = (
    # Items come from the oversampled frame; image files live in the 448 px extracted folder.
    ImageItemList.from_df(df, '../data/train-extracted-448', cols=['Image'])
    .split_by_valid_func(_is_validation_image)
    .label_from_func(_whale_id_for)
    # Unlabelled test images, read straight from their folder.
    .add_test(ImageItemList.from_folder('../data/test-extracted-448'))
    # No flip augmentation; images resized to SZ with ResizeMethod.SQUISH.
    .transform(get_transforms(do_flip=False), size=SZ, resize_method=ResizeMethod.SQUISH)
    .databunch(bs=BS, num_workers=NUM_WORKERS, path='../data')
    .normalize(imagenet_stats)
)

In [13]:
# Display the DataBunch summary: item counts and sample items per split.
data

ImageDataBunch;

Train: LabelList (61171 items)
x: ImageItemList
Image (3, 448, 448),Image (3, 448, 448),Image (3, 448, 448),Image (3, 448, 448),Image (3, 448, 448)
y: CategoryList
w_0003639,w_0003639,w_0003639,w_0003639,w_0003639
Path: ../data/train-extracted-448;

Valid: LabelList (15116 items)
x: ImageItemList
Image (3, 448, 448),Image (3, 448, 448),Image (3, 448, 448),Image (3, 448, 448),Image (3, 448, 448)
y: CategoryList
w_0027efa,w_00289b1,w_00289b1,w_00289b1,w_00289b1
Path: ../data/train-extracted-448;

Test: LabelList (7960 items)
x: ImageItemList
Image (3, 448, 448),Image (3, 448, 448),Image (3, 448, 448),Image (3, 448, 448),Image (3, 448, 448)
y: EmptyLabelList
,,,,
Path: ../data/train-extracted-448

In [16]:
%%time

learn = create_cnn(data, models.resnet50, lin_ftrs=[2048], model_dir=MODEL_PATH, metrics=[accuracy, map5])
learn.clip_grad();
learn.load(f'{name}-stage-6')
learn.freeze_to(-1)

CPU times: user 3.75 s, sys: 1.35 s, total: 5.1 s
Wall time: 4.37 s


In [17]:
# Scores for every test image. The second return value is discarded.
# NOTE(review): binding `_` shadows IPython's last-output variable for the rest of the session.
preds, _ = learn.get_preds(DatasetType.Test)

In [18]:
# Append a constant column for the extra 'new_whale' label (its score is set in the next cell).
# Assumes get_preds returns one column per entry of learn.data.classes — the `classes` list
# built below relies on the same ordering. Slicing to those columns first makes this cell
# idempotent: re-running it replaces the extra column instead of stacking another one.
n_model_classes = len(learn.data.classes)
preds = torch.cat((preds[:, :n_model_classes], torch.ones_like(preds[:, :1])), 1)

In [19]:
# Fixed score assigned to the 'new_whale' column (presumably top_5_pred_labels ranks by
# descending score, so new_whale outranks any whale scored below this value).
# NOTE(review): with this value the check further down shows new_whale ranked first for only
# ~0.01% of test images — worth tuning.
NEW_WHALE_SCORE = 0.06
# The new_whale column sits right after the model's own classes, i.e. at index
# len(learn.data.classes) — the same position `learn.data.classes + ['new_whale']` gives it —
# rather than a hard-coded column number.
preds[:, len(learn.data.classes)] = NEW_WHALE_SCORE

In [20]:
# Column-index -> label list for the submission: model classes, then the appended new_whale column.
classes = [*learn.data.classes, 'new_whale']

In [25]:
def create_submission(preds, data, name, classes=None, out_dir='../subs'):
    """Write a gzipped submission CSV with the top-5 predicted labels per test image.

    Args:
        preds: 2-D score matrix (torch tensor or numpy array) with one row per
            test item and one column per entry of ``classes``.
        data: DataBunch-like object; ``data.test_ds.x.items`` supplies the test
            image paths and ``data.classes`` is the fallback label list.
        name: file stem; the file is written to ``<out_dir>/<name>.csv.gz``.
        classes: column-index -> label list. Defaults to ``data.classes`` when None.
        out_dir: directory for the submission file; created if missing.

    Raises:
        ValueError: if the number of labels differs from the number of columns in ``preds``.
    """
    import os

    # `is None` rather than truthiness: `not classes` is ambiguous for numpy arrays.
    if classes is None:
        classes = data.classes
    # A label/column mismatch would silently shift every predicted label, so fail loudly.
    if len(classes) != preds.shape[1]:
        raise ValueError(
            f'{len(classes)} class labels but preds has {preds.shape[1]} columns')

    # basename(str(...)) works for both Path objects and plain string items.
    image_names = [os.path.basename(str(item)) for item in data.test_ds.x.items]
    sub = pd.DataFrame({'Image': image_names})
    sub['Id'] = top_5_pred_labels(preds, classes)
    os.makedirs(out_dir, exist_ok=True)
    sub.to_csv(os.path.join(out_dir, f'{name}.csv.gz'), index=False, compression='gzip')

In [26]:
# Write ../subs/<name>.csv.gz from the test-set scores.
create_submission(preds, learn.data, name, classes=classes)

In [27]:
# Sanity-check the first rows of the written submission.
pd.read_csv(f'../subs/{name}.csv.gz', compression='gzip').head()

Unnamed: 0,Image,Id
0,ef60d186c.jpg,w_5d81be1 w_2b50adf w_aaf5ab5 w_921cdae w_9b89c88
1,e141fd305.jpg,w_6bab2bd w_5841fb9 w_76a45de w_7e0659e new_whale
2,25045eeda.jpg,w_5d5c6a6 new_whale w_45e277d w_67a56b4 w_580ba51
3,d11ed8266.jpg,w_f765256 new_whale w_1133530 w_b6e4761 w_0815d2c
4,98e1ea193.jpg,w_a7806ad w_82fac51 w_463f219 w_51e7506 w_685b8e1


In [28]:
# Fraction of test images whose top-ranked label is new_whale. Uses the vectorised .str
# accessor instead of apply+lambda; a missing Id yields False instead of raising.
(
    pd.read_csv(f'../subs/{name}.csv.gz')
    .Id.str.split()
    .str[0]
    .eq('new_whale')
    .mean()
)

0.0001256281407035176

In [30]:
# Upload the gzipped submission written above; the experiment name doubles as the submission message.
!kaggle competitions submit -c humpback-whale-identification -f ../subs/{name}.csv.gz -m "{name}"

100%|████████████████████████████████████████| 185k/185k [00:08<00:00, 21.3kB/s]
Successfully submitted to Humpback Whale Identification