In [None]:
# Workaround for a Kaggle environment error, see:
# https://www.kaggle.com/product-feedback/279990
# Pins torch to 1.9.0; pip's stdout and stderr are discarded.
!pip install --user torch==1.9.0 > /dev/null 2>&1

import pandas as pd
from pathlib import Path
from PIL import Image
from tqdm.notebook import tqdm
from mpl_toolkits.axes_grid1 import ImageGrid
from PIL import Image, ImageStat

from fastai.vision.all import *

# Intro

In this notebook, I train a basic species classifier on the train set and then use it to predict the species of each image in the test set. The classifier uses a resnet34 architecture with images resized to 224x224.

At the end, I perform an analysis comparing the species distribution in train with the predicted distribution in test. I also look at a few examples of the top species across both datasets.

My hope is that this will go some way towards explaining why the test set appears to be quite different from the train set. See [Adversarial Validation](https://www.kaggle.com/lextoumbourou/happywhale-adversarial-validation).

# Params

In [None]:
# Seed for the random train/validation split.
SEED = 420
# Side length (pixels) used for item resizing and as the augmentation output size.
IMG_SIZE = 224
# Batch size for the dataloaders.
BS = 64
# Backbone architecture handed to the learner.
ARCH = resnet34
# Root folder of the image dataset; holds train_images/ and test_images/.
IMG_PATH_BASE = '../input/happy-whale-512'

# Prepare Data

In [None]:
# Read the train and test tables; later cells use their `image` column
# (and `species` for train).
train_df, test_df = (
    pd.read_csv(csv_path)
    for csv_path in (
        '../input/happy-whale-dolphin-q-a-style-eda/train_stats.csv',
        '../input/happy-whale-dolphin-q-a-style-eda/test_stats.csv',
    )
)

In [None]:
def remove_corrupt_examples(df, dataset):
    """Drop rows whose image file cannot be opened by PIL.

    Args:
        df: DataFrame with an ``image`` column holding image file names.
        dataset: Split name (e.g. ``'train'`` or ``'test'``); files are looked
            up under ``IMG_PATH_BASE/{dataset}_images``.

    Returns:
        A copy of ``df`` restricted to the rows whose image opened without
        raising. Index, columns and dtypes are preserved, even when no row
        survives.
    """
    img_dir = Path(IMG_PATH_BASE)/f'{dataset}_images'

    is_valid = []
    for name in tqdm(df['image'], total=len(df)):
        try:
            # Image.open keeps the underlying file open; the context manager
            # releases the handle immediately instead of leaking one per row.
            with Image.open(img_dir/name):
                pass
        except Exception:
            # Best effort: anything that fails to open counts as corrupt.
            is_valid.append(False)
        else:
            is_valid.append(True)

    num = is_valid.count(False)
    print(f'Found {num} corrupt examples')

    # Filtering with a boolean mask (instead of rebuilding the frame from row
    # Series) keeps the schema intact; .copy() lets callers add columns to the
    # result without SettingWithCopy warnings.
    mask = pd.Series(is_valid, dtype=bool).to_numpy()
    return df.loc[mask].copy()

In [None]:
# Keep only the rows of each split whose image file can be opened.
train_df = remove_corrupt_examples(train_df, 'train')
test_df = remove_corrupt_examples(test_df, 'test')

In [None]:
# Full path to every image, as a string column read by the DataBlock getter.
train_df['image_path'] = IMG_PATH_BASE + '/train_images/' + train_df.image
test_df['image_path'] = IMG_PATH_BASE + '/test_images/' + test_df.image

In [None]:
# Number of training images per species label.
train_df.species.value_counts()

In [None]:
# Batch-level augmentations: flips (horizontal and vertical), rotation and
# random scaling, producing IMG_SIZE outputs.
batch_augs = aug_transforms(
    size=IMG_SIZE,
    do_flip=True,
    flip_vert=True,
    max_rotate=30.,
    min_scale=0.75,
)

# Image -> species pipeline: inputs come from `image_path`, targets from
# `species`; the train/validation split is random but seeded.
datablock = DataBlock(
    blocks=(ImageBlock, CategoryBlock),
    getters=[ColReader('image_path'), ColReader('species')],
    splitter=RandomSplitter(seed=SEED),
    item_tfms=Resize(IMG_SIZE),
    batch_tfms=batch_augs,
)

In [None]:
# Build the dataloaders from the cleaned training frame with batch size BS.
dls = datablock.dataloaders(source=train_df, bs=BS)

In [None]:
# Visual sanity check of one batch from the dataloaders.
dls.show_batch()

# Train Model

In [None]:
def get_learner(dls, lr=1e-3):
    """Create a mixed-precision CNN learner for the given dataloaders.

    Args:
        dls: Dataloaders the learner trains and validates on.
        lr: Learning rate bound into the Adam optimizer factory.

    Returns:
        The learner built by ``cnn_learner`` on ``ARCH`` with the accuracy
        metric, after ``to_fp16()`` has been applied.
    """
    # Adam with fixed weight decay and epsilon; only the learning rate varies.
    adam = partial(Adam, lr=lr, wd=0.01, eps=1e-8)
    return cnn_learner(dls, ARCH, opt_func=adam, metrics=[accuracy]).to_fp16()

In [None]:
# Create the learner with the default learning rate.
learn = get_learner(dls)

In [None]:
# One epoch before the model is unfrozen in the next cell.
# NOTE(review): presumably only the new head trains here (pretrained body
# frozen by default) — confirm against the fastai version in use.
learn.fit_one_cycle(1)

In [None]:
# Unfreeze the model and train for 4 more epochs.
# NOTE(review): the slice presumably spreads learning rates from 1e-4 to 1e-3
# across layer groups — confirm against the fastai docs.
learn.unfreeze()
learn.fit_one_cycle(4, slice(1e-4, 1e-3))

In [None]:
# Save the trained model under the name 'species'.
learn.save('species')

In [None]:
# Use distinct names for the validation results: binding the score to
# `accuracy` would shadow fastai's `accuracy` metric, which get_learner()
# passes to cnn_learner, so re-running that cell afterwards would break.
val_loss, val_accuracy = learn.validate()
print(val_accuracy)

In [None]:
# Display up to 9 examples together with the model's predictions.
learn.show_results(max_n=9)

# Save Test Set Predictions

In [None]:
# Test dataloader built from the test frame via the training dataloaders.
test_dl = dls.test_dl(test_df)

In [None]:
# Predict on the test set; the second return value is discarded
# (presumably the targets, which the test set does not have).
test_preds, _ = learn.get_preds(dl=test_dl)

In [None]:
# Take the top class and its probability from a single max over the class
# dimension, and convert to plain Python/NumPy values before handing them to
# pandas so the column dtypes do not depend on how pandas coerces a tensor.
top_probs, top_idxs = test_preds.max(dim=1)
test_df['species_pred'] = [dls.vocab[i] for i in top_idxs.tolist()]
test_df['species_prob'] = top_probs.cpu().numpy()

# Keep only the columns needed for the saved predictions (as a copy, so later
# column assignments do not operate on a slice of the wider frame).
test_df = test_df[['image', 'species_pred', 'species_prob']].copy()
test_df.head()

In [None]:
# Write the per-image species predictions to disk (without the index column).
test_df.to_csv('test_species.csv', index=False)

# Results

Let's compare the distribution of species predictions in the test set to train.

## Distribution

In [None]:
fig, axes = plt.subplots(nrows=1, ncols=2, figsize=(20, 5))

# Figure-level title: plt.title() would only target the current (right-hand)
# axes and would then be overwritten by that subplot's own title.
fig.suptitle('Train vs test species distribution')
train_df.species.value_counts().plot(kind='bar', title='Train species', ax=axes[0])
test_df.species_pred.value_counts().plot(kind='bar', title='Test species', ax=axes[1])
# Annotate every bar with its count.
for ax in axes:
    ax.bar_label(ax.containers[0], padding=5, rotation=45)
plt.show()

## Species Counts in Test

In [None]:
# Number of test images assigned to each predicted species.
test_df.species_pred.value_counts()

## Visualise Species Across Datasets

Let's look at some examples of the most common species across datasets.

In [None]:
def image_grid(images, nrows_ncols, title=None, figsize=(16, 5)):
    """Show images in a tightly spaced grid.

    Args:
        images: Iterable of images accepted by ``Axes.imshow``; images beyond
            the number of grid cells are ignored.
        nrows_ncols: ``(rows, cols)`` layout of the grid.
        title: Optional title for the whole figure.
        figsize: Figure size passed to ``plt.figure``.
    """
    fig = plt.figure(figsize=figsize)
    if title:
        # Figure-level title: plt.title() on an empty figure would first create
        # a default axes that then sits behind the image grid.
        fig.suptitle(title)

    grid = ImageGrid(fig, 111, nrows_ncols=nrows_ncols, axes_pad=0.1)

    for ax, im in zip(grid, images):
        ax.imshow(im)

    plt.show()


def load_images(image_ids, dataset, resize=(128, 128)):
    """Open the requested images, optionally resizing each one.

    Args:
        image_ids: Iterable of image file names.
        dataset: Split name selecting the ``{dataset}_images`` folder.
        resize: Size passed to ``Image.resize``; a falsy value leaves the
            images at their original size.

    Returns:
        List of PIL images in the same order as ``image_ids``.
    """
    def _load(image_id):
        picture = Image.open(Path(f'../input/happy-whale-and-dolphin/{dataset}_images')/image_id)
        return picture.resize(resize) if resize else picture

    return [_load(image_id) for image_id in image_ids]

### Humpback Whale

In [None]:
# Up to 10 train images labelled humpback_whale. The sample is seeded so the
# grid is reproducible and capped so a small subset cannot raise.
subset = train_df.query('species == "humpback_whale"')
img_ids = list(subset.sample(min(10, len(subset)), random_state=SEED).image)
images = load_images(img_ids, 'train')
image_grid(images, nrows_ncols=(2, 5), figsize=(18, 8), title='Humpback Whale in Train')

In [None]:
# Up to 10 test images predicted as humpback_whale. The sample is seeded so
# the grid is reproducible and capped so fewer than 10 predictions cannot raise.
subset = test_df.query('species_pred == "humpback_whale"')
img_ids = list(subset.sample(min(10, len(subset)), random_state=SEED).image)
images = load_images(img_ids, 'test')
image_grid(images, nrows_ncols=(2, 5), figsize=(18, 8), title='Humpback Whale in Test')

### Bottlenose Dolphin

In [None]:
# Up to 10 train images labelled bottlenose_dolphin. The sample is seeded so
# the grid is reproducible and capped so a small subset cannot raise.
subset = train_df.query('species == "bottlenose_dolphin"')
img_ids = list(subset.sample(min(10, len(subset)), random_state=SEED).image)
images = load_images(img_ids, 'train')
image_grid(images, nrows_ncols=(2, 5), figsize=(18, 8), title='Bottlenose Dolphin in Train')

In [None]:
# Up to 10 test images predicted as bottlenose_dolphin. The sample is seeded so
# the grid is reproducible and capped so fewer than 10 predictions cannot raise.
subset = test_df.query('species_pred == "bottlenose_dolphin"')
img_ids = list(subset.sample(min(10, len(subset)), random_state=SEED).image)
images = load_images(img_ids, 'test')
image_grid(images, nrows_ncols=(2, 5), figsize=(18, 8), title='Bottlenose Dolphin in Test')

### Beluga Whale

In [None]:
# Up to 10 train images labelled beluga. The sample is seeded so the grid is
# reproducible and capped so a small subset cannot raise.
subset = train_df.query('species == "beluga"')
img_ids = list(subset.sample(min(10, len(subset)), random_state=SEED).image)
images = load_images(img_ids, 'train')
image_grid(images, nrows_ncols=(2, 5), figsize=(18, 8), title='Beluga in Train')

In [None]:
# Up to 10 test images predicted as beluga. The sample is seeded so the grid
# is reproducible and capped so fewer than 10 predictions cannot raise.
subset = test_df.query('species_pred == "beluga"')
img_ids = list(subset.sample(min(10, len(subset)), random_state=SEED).image)
images = load_images(img_ids, 'test')
image_grid(images, nrows_ncols=(2, 5), figsize=(18, 8), title='Beluga in Test')

### Blue Whale

In [None]:
# Up to 10 train images labelled blue_whale. The sample is seeded so the grid
# is reproducible and capped so a small subset cannot raise.
subset = train_df.query('species == "blue_whale"')
img_ids = list(subset.sample(min(10, len(subset)), random_state=SEED).image)
images = load_images(img_ids, 'train')
image_grid(images, nrows_ncols=(2, 5), figsize=(18, 8), title='Blue Whale in Train')

In [None]:
# Up to 10 test images predicted as blue_whale. The sample is seeded so the
# grid is reproducible and capped so fewer than 10 predictions cannot raise.
subset = test_df.query('species_pred == "blue_whale"')
img_ids = list(subset.sample(min(10, len(subset)), random_state=SEED).image)
images = load_images(img_ids, 'test')
image_grid(images, nrows_ncols=(2, 5), figsize=(18, 8), title='Blue Whale in Test')

### False Killer Whale

In [None]:
# Up to 10 train images labelled false_killer_whale. The sample is seeded so
# the grid is reproducible and capped so a small subset cannot raise.
subset = train_df.query('species == "false_killer_whale"')
img_ids = list(subset.sample(min(10, len(subset)), random_state=SEED).image)
images = load_images(img_ids, 'train')
image_grid(images, nrows_ncols=(2, 5), figsize=(18, 8), title='False Killer Whale in Train')

In [None]:
# Up to 10 test images predicted as false_killer_whale. The sample is seeded so
# the grid is reproducible and capped so fewer than 10 predictions cannot raise.
subset = test_df.query('species_pred == "false_killer_whale"')
img_ids = list(subset.sample(min(10, len(subset)), random_state=SEED).image)
images = load_images(img_ids, 'test')
image_grid(images, nrows_ncols=(2, 5), figsize=(18, 8), title='False Killer Whale in Test')