In [None]:
import os

import cv2
import numpy as np 
import pandas as pd 
import plotly.express as px
import matplotlib.pyplot as plt
%matplotlib inline

### Define data paths:

In [None]:
# Input locations. Everything except the species CSV sits under the
# competition folder `root_dir`.
root_dir = '../input/happy-whale-and-dolphin/'
train_dir, test_dir, train_csv = (
    os.path.join(root_dir, entry)
    for entry in ('train_images', 'test_images', 'train.csv')
)
# The test-set species table comes from a separate input dataset.
test_species_csv = '../input/happywhale-test-species/Happywhale_test_species.csv'

### Generate image paths and fix the species column:

In [None]:
# Test table: one row per test image, plus a `path` column holding the
# image file location under `test_dir`.
test_df = pd.read_csv(test_species_csv)
test_df['path'] = test_df['image'].apply(lambda x: os.path.join(test_dir, x))

# Train table: normalise the species labels, then add the image paths.
train_df = pd.read_csv(train_csv)
# Misspelled labels are corrected, and "globis" / "pilot_whale" are merged into
# "short_finned_pilot_whale".
# The result is assigned back to the column instead of calling
# `train_df.species.replace(..., inplace=True)`: an in-place call on a column
# accessed through the DataFrame is chained assignment, which warns on recent
# pandas and leaves `train_df` unchanged once Copy-on-Write is enabled.
train_df['species'] = train_df['species'].replace({
    "globis": "short_finned_pilot_whale",
    "pilot_whale": "short_finned_pilot_whale",
    "kiler_whale": "killer_whale",
    "bottlenose_dolpin": "bottlenose_dolphin",
})
train_df['path'] = train_df['image'].apply(lambda x: os.path.join(train_dir, x))

### Predicted test species:

In [None]:
# Number of test images per species label, most frequent first.
test_df['species'].value_counts()

### Compare test and train species distributions (counts scaled by the largest class in each set):

In [None]:
# Overlay the per-species counts of both splits. Each split is scaled by its
# own largest class so the tallest bar is 1.0 and the shapes are comparable
# despite the different dataset sizes.
plt.figure(figsize=(12, 12))
plt.xticks(rotation='vertical')
for split_name, split_df in (('train', train_df), ('test', test_df)):
    species_counts = split_df.species.value_counts()
    plt.bar(
        species_counts.index,
        species_counts / species_counts.max(),
        alpha=0.5,
        label=split_name,
    )
_ = plt.legend()

### Some predicted images:

In [None]:
def show_images(images_paths: list, titles=None, columns: int = 4):
    """Display images in a grid, one subplot per readable image.

    Args:
        images_paths: Sequence (list, pandas Series, ...) of image file paths.
        titles: Optional sequence of subplot titles, matched to
            ``images_paths`` by position. When ``None``, each title is the
            last 15 characters of the corresponding path.
        columns: Number of grid columns (default 4).

    Returns:
        None. The figure is drawn as a side effect.

    Raises:
        ValueError: If ``columns`` is smaller than 1.

    Paths that ``cv2.imread`` cannot load (it returns ``None``) are skipped,
    leaving their grid slot empty.
    """
    if columns < 1:
        raise ValueError("columns must be a positive integer")

    # Materialise the inputs so `titles[i]` below is positional. Indexing a
    # pandas Series with an int is a label lookup and fails for slices that do
    # not start at label 0.
    images_paths = list(images_paths)
    if titles is not None:
        titles = list(titles)

    if not images_paths:
        # Nothing to draw; also avoids creating a zero-height figure.
        return

    total = len(images_paths)
    rows = -(-total // columns)  # ceiling division
    # Equals `total` inches of height for the default 4 columns, and keeps the
    # same per-row height for other column counts.
    fig = plt.figure(figsize=(12, total * 4 / columns))
    for i, image_path in enumerate(images_paths):
        img = cv2.imread(image_path)
        if img is None:
            continue
        ax = fig.add_subplot(rows, columns, i + 1)
        ax.set_title(image_path[-15:] if titles is None else str(titles[i]))
        # Reverse the channel axis: OpenCV loads BGR, matplotlib expects RGB.
        ax.imshow(img[..., ::-1])

# Preview the first 20 test images, each titled with its species label.
show_images(test_df['path'].head(20), test_df['species'].head(20))