# Whale/Dolphin Identification with CNNs and PyTorch Lightning

## Intro

<p>
    Although this competition asks us to identify individual animals, the labels also contain two coarser levels of classes:
</p>
<ul>
    <li>Whale or dolphin</li>
    <li>Species</li>
</ul>

<p>
In this notebook, I attempt to classify images as whale or dolphin, and hopefully learn more about the dataset along the way.
</p>
<p>
The augmentations are taken from this notebook: https://www.kaggle.com/kohjiahng/whaleeda-aug
</p>

## Imports

In [None]:
!pip install https://github.com/ufoym/imbalanced-dataset-sampler/archive/master.zip

In [None]:
import numpy as np
import pandas as pd
import os

import matplotlib.pyplot as plt
plt.rcParams.update({'font.size': 15})

# DL
import torch
from torch import nn
import pytorch_lightning as pl
from torchvision import transforms

from torchsampler import ImbalancedDatasetSampler

# Image reading
import cv2

# Splitting data
from sklearn.model_selection import train_test_split

# Augmentations
import albumentations as A
from albumentations.pytorch import ToTensorV2

# Progress bar
from tqdm import tqdm

import wandb
from pytorch_lightning.loggers import WandbLogger

In [None]:
# Global variables: location of the competition data (judging by the dataset
# name, a copy with images pre-resized to 224x224).
_DATA_ROOT = '../input/w-d-224x224-fast-dataset'
TRAIN_IMAGE_PATH = f'{_DATA_ROOT}/train_images'  # directory of training images
TRAIN_CSV_PATH = f'{_DATA_ROOT}/train.csv'       # labels: image, species, individual_id

## Weights and Biases

In [None]:
# Log in to Weights & Biases with the API key stored as the Kaggle secret
# "wandb_api". `anony` records the outcome: None after a successful login,
# "must" otherwise ("must" is the value wandb.init's `anonymous` argument
# uses to force an anonymous run).
try:
    from kaggle_secrets import UserSecretsClient
    user_secrets = UserSecretsClient()
    api_key = user_secrets.get_secret("wandb_api")
    wandb.login(key=api_key)
    anony = None
except Exception:
    # Catch ordinary errors only (e.g. not running on Kaggle, secret not
    # configured, login failure). A bare `except:` would also swallow
    # KeyboardInterrupt and SystemExit.
    anony = "must"
    print('If you want to use your W&B account, go to Add-ons -> Secrets and provide your W&B access token. Use the Label name as wandb_api. \nGet your W&B access token from here: https://wandb.ai/authorize')

## Config

In [None]:
# Hyper-parameters and run settings; also passed to wandb.init as the run config.
CONFIG = dict(
    seed=2021,
    epochs=10,
    img_size=224,            # images are resized to img_size x img_size
    model_name="resnet50",   # torchvision hub entry point for the backbone
    batch_size=128,
    train_size=0.8,          # fraction of rows used for training in train_test_split
    learning_rate=1e-5,
    optimizer="adam",        # informational only: configure_optimizers always builds Adam
    device=torch.device("cuda:0" if torch.cuda.is_available() else "cpu"),
)

## Seeding

In [None]:
def set_seed(seed=42, *, deterministic=False):
    """Seed every random number generator the notebook relies on.

    Seeds Python's ``random`` module, NumPy's global generator and PyTorch
    (CPU and all CUDA devices). Python's ``random`` matters here because
    albumentations (at least its 1.x releases) draws augmentation decisions
    from it; leaving it unseeded made augmentations differ between runs.

    Args:
        seed: integer seed applied to every generator.
        deterministic: if True, also force deterministic cuDNN kernels and
            disable cuDNN autotuning. This trades speed for run-to-run
            reproducibility on GPU. With the default (False), GPU results
            may still differ slightly between runs.
    """
    import random  # local import: the notebook's import cell does not import it

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)  # safe no-op when CUDA is unavailable
    if deterministic:
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
    # PYTHONHASHSEED is read at interpreter start-up. Setting it here only
    # affects Python subprocesses launched afterwards, not this process.
    os.environ['PYTHONHASHSEED'] = str(seed)
    
# Seed all RNGs once, before the data split, augmentations and model
# initialisation in the cells below.
set_seed(CONFIG['seed'])

## Utility Functions

In [None]:
def read_image(path) -> np.ndarray:
    """Load an image file as an RGB array scaled to [0, 1].

    Args:
        path: filesystem path of the image.

    Returns:
        ``np.float32`` array of shape (H, W, 3) with values in [0, 1].
        float32 matches the model's parameter dtype and halves memory
        compared with the float64 that ``uint8 / 255`` would produce.

    Raises:
        FileNotFoundError: if OpenCV cannot read ``path``. ``cv2.imread``
            signals a missing or undecodable file by returning ``None``
            instead of raising, which would otherwise surface later as an
            obscure ``cv2.cvtColor`` error.
    """
    bgr = cv2.imread(path)
    if bgr is None:
        raise FileNotFoundError(f'cv2.imread could not read image: {path}')
    rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)  # OpenCV loads channels as BGR
    return rgb.astype(np.float32) / 255

## Data Loading and Cleaning

In [None]:
# Load the training labels and add a `path` column pointing at each image
# file: "<TRAIN_IMAGE_PATH>/<value of the `image` column>".
df = pd.read_csv(TRAIN_CSV_PATH).assign(
    path=lambda frame: TRAIN_IMAGE_PATH + '/' + frame['image']
)

In [None]:
# Fix misspelled species labels, and merge the 'globis' and 'pilot_whale'
# labels into 'short_finned_pilot_whale'.
SPECIES_FIXES = {
    'kiler_whale': 'killer_whale',
    'bottlenose_dolpin': 'bottlenose_dolphin',
    'globis': 'short_finned_pilot_whale',
    'pilot_whale': 'short_finned_pilot_whale',
}
df['species'] = df['species'].replace(SPECIES_FIXES)

In [None]:
# Coarse target classes. Every cleaned species label must appear in one of
# these lists; species2type raises for anything else.
# NOTE(review): the split appears to follow common names rather than
# taxonomy. Killer, pilot and melon-headed whales are oceanic dolphins, yet
# only 'false_killer_whale' is grouped with the dolphins. Confirm this is
# intended.
WHALE_SPECIES = """
    blue_whale
    brydes_whale
    cuviers_beaked_whale
    fin_whale
    gray_whale
    humpback_whale
    killer_whale
    long_finned_pilot_whale
    melon_headed_whale
    minke_whale
    pygmy_killer_whale
    sei_whale
    short_finned_pilot_whale
    southern_right_whale
    beluga
""".split()
DOLPHIN_SPECIES = """
    false_killer_whale
    bottlenose_dolphin
    commersons_dolphin
    common_dolphin
    dusky_dolphin
    frasiers_dolphin
    pantropic_spotted_dolphin
    rough_toothed_dolphin
    spinner_dolphin
    spotted_dolphin
    white_sided_dolphin
""".split()

In [None]:
def species2type(species):
    """Map a species label to its coarse class.

    Args:
        species: species name as it appears in the cleaned ``species`` column.

    Returns:
        'whale' if ``species`` is in WHALE_SPECIES (checked first), otherwise
        'dolphin' if it is in DOLPHIN_SPECIES.

    Raises:
        ValueError: if ``species`` is in neither list, so an unmapped label
            fails loudly instead of being silently mislabelled. ValueError
            subclasses Exception, so existing ``except Exception`` handlers
            still catch it.
    """
    if species in WHALE_SPECIES:
        return 'whale'
    if species in DOLPHIN_SPECIES:
        return 'dolphin'
    raise ValueError(f'{species} not in whale or dolphin lists')

# Coarse target label ('whale' / 'dolphin') for every row.
df['type'] = df['species'].apply(species2type)

## A bit of EDA

In [None]:
# Number of distinct species within each coarse class, whales first.
species_per_type = df.groupby('type')['species'].nunique()
n_species_by_type = species_per_type.loc[['whale', 'dolphin']]
n_species_by_type.plot.bar(figsize=(20, 10))

In [None]:
# Compare, per coarse class, how many images vs. how many distinct
# individuals the training set contains (whales first).
class_order = ['whale', 'dolphin']
n_images_by_type = df['type'].value_counts().loc[class_order]

# Each individual is assigned the class of its first row in `df`.
animal_types = df.groupby('individual_id')['type'].agg(lambda types: types.iloc[0])
n_animals_by_type = animal_types.value_counts().loc[class_order]

type_numbers_df = pd.DataFrame({
    'Image': n_images_by_type,
    'Animal': n_animals_by_type,
})

ax = type_numbers_df.plot.bar(figsize=(20, 10))
ax.set_title('Number of images/animals of each type')
ax.set_ylabel('Frequency')
ax.set_xlabel('Type')

<p>
    Whales are the majority class, so the training set is rebalanced later by oversampling (ImbalancedDatasetSampler in the training dataloader).
</p>

## Dataset

In [None]:
# Albumentations pipelines keyed by split ('train' / 'val').
#
# read_image() returns RGB images already scaled to [0, 1], so:
#   * every pixel-level augmentation runs BEFORE Normalize. Brightness
#     shifts and fog are defined on the natural [0, 1] range, and (in
#     albumentations 1.x) both clip their output to that range, which
#     destroyed the negative values of an already-normalized image;
#   * Normalize uses max_pixel_value=1.0. Its default (255.0) computes
#     (img - mean*255) / (std*255), i.e. it would divide the already-scaled
#     image by 255 a second time.
data_transforms = {
    'train': A.Compose([
        # Cast to float32 without rescaling. Some transforms (e.g. RandomFog
        # in albumentations 1.x) reject float64 input.
        A.ToFloat(max_value=1.0),
        A.Resize(CONFIG['img_size'], CONFIG['img_size']),
        A.HorizontalFlip(p=0.5),
        # Equivalent of the deprecated A.RandomBrightness(0.5): brightness
        # shift in [-0.5, 0.5], no contrast change, p=0.5.
        A.RandomBrightnessContrast(brightness_limit=0.5, contrast_limit=0),
        A.RandomFog(),
        A.Rotate((0, 45)),
        # ImageNet statistics, matching the pretrained backbone.
        A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225],
                    max_pixel_value=1.0),
        ToTensorV2()
    ]),
    'val': A.Compose([
        A.ToFloat(max_value=1.0),
        A.Resize(CONFIG['img_size'], CONFIG['img_size']),
        A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225],
                    max_pixel_value=1.0),
        ToTensorV2()
    ])
}

In [None]:
class WhaleDataset(torch.utils.data.Dataset):
    """Map-style dataset yielding ``(augmented_image, label)`` pairs.

    Args:
        df: frame with a ``path`` column (image file paths) and an integer
            ``token_type`` column (the label; WhaleDataModule encodes
            'whale' as 1 and everything else as 0).
        transforms: callable invoked as ``transforms(image=array)`` that
            returns a dict holding the result under ``'image'`` (an
            albumentations ``Compose`` in this notebook).
    """

    def __init__(self, df, transforms):
        super().__init__()
        self.transforms = transforms
        self.df = df

    def __len__(self):
        return self.df.shape[0]

    def __getitem__(self, idx):
        # Read the idx-th image (positional index), augment it, then pair it
        # with the label from the same row.
        augmented = self.transforms(image=read_image(self.df['path'].iloc[idx]))
        return augmented['image'], self.df['token_type'].iloc[idx]

    def get_labels(self):
        """Return the per-sample labels; ImbalancedDatasetSampler calls this."""
        return self.df['token_type']

In [None]:
class WhaleDataModule(pl.LightningDataModule):
    """Lightning data module for the binary whale-vs-dolphin task.

    Builds the integer target ``token_type`` (1 for 'whale', 0 otherwise),
    splits the rows into train/validation sets and serves them through
    DataLoaders. The training loader uses ImbalancedDatasetSampler to
    rebalance the classes.

    Args:
        df: frame with at least a ``path`` (image file) and a ``type``
            ('whale' / 'dolphin') column. It is not modified.
        batch_size: samples per batch for both loaders.
        num_workers: DataLoader worker processes. The default 0 loads data
            in the main process, as before.
    """

    def __init__(self, df, batch_size, num_workers=0):
        super().__init__()
        # Build the target here rather than in prepare_data(). Lightning
        # documents that prepare_data() runs on a single process and that
        # state assigned there may not be available to other processes.
        # The explicit .copy() detaches the column subset from `df`, avoiding
        # pandas' SettingWithCopyWarning on the assignment below.
        self.df = df[['path', 'type']].copy()
        self.df['token_type'] = (self.df['type'] == 'whale').astype(int)
        self.batch_size = batch_size
        self.num_workers = num_workers

    def prepare_data(self):
        """No-op hook kept for the Lightning interface; labels are built in __init__."""

    def setup(self, stage=None):
        """Split rows into train/validation sets and wrap each in a WhaleDataset.

        ``random_state=0`` makes the split identical on every call.
        """
        self.train_df, self.val_df = train_test_split(
            self.df, train_size=CONFIG['train_size'], random_state=0
        )
        self.train_ds = WhaleDataset(self.train_df, data_transforms['train'])
        self.val_ds = WhaleDataset(self.val_df, data_transforms['val'])

    def train_dataloader(self):
        # The sampler replaces shuffle=True: it oversamples the minority class
        # using the labels from WhaleDataset.get_labels().
        return torch.utils.data.DataLoader(self.train_ds,
                                           batch_size=self.batch_size,
                                           sampler=ImbalancedDatasetSampler(self.train_ds),
                                           num_workers=self.num_workers,
                                           pin_memory=True)

    def val_dataloader(self):
        return torch.utils.data.DataLoader(self.val_ds,
                                           batch_size=self.batch_size,
                                           shuffle=False,
                                           num_workers=self.num_workers,
                                           pin_memory=True)


whale_datamodule = WhaleDataModule(df, batch_size=CONFIG['batch_size'])

## Finetuning ResNet

In [None]:
class ResNet(pl.LightningModule):
    """Pretrained torchvision ResNet fine-tuned for whale/dolphin classification.

    The backbone ``CONFIG['model_name']`` is loaded from torch.hub with
    pretrained weights. Its ``fc`` head is replaced by ``Linear(in_features, 1)``
    followed by ``Sigmoid``, so ``forward`` yields one probability per image.
    Training uses ``BCELoss`` against the 0/1 labels (1 = whale in
    WhaleDataModule's encoding).
    """

    def __init__(self):
        super().__init__()
        backbone = torch.hub.load('pytorch/vision:v0.10.0', CONFIG['model_name'], pretrained=True)
        head_inputs = backbone.fc.in_features
        backbone.fc = nn.Sequential(
            nn.Linear(head_inputs, 1),
            nn.Sigmoid(),
        )
        self.net = backbone
        self.loss_fn = nn.BCELoss()

    def forward(self, X):
        # Drop the trailing size-1 output dimension: (batch, 1) -> (batch,).
        return self.net(X).squeeze(-1)

    def _shared_step(self, batch):
        """Return (BCE loss, accuracy at a 0.5 threshold) for one (images, labels) batch."""
        images, labels = batch
        labels = labels.float()
        probs = self.forward(images)
        loss = self.loss_fn(probs, labels)
        # Detached so the metric never joins the autograd graph.
        acc = ((probs.detach() > 0.5) == labels).float().mean()
        return loss, acc

    def training_step(self, batch, batch_idx):
        loss, acc = self._shared_step(batch)
        self.log('loss', {'train': loss}, logger=True)
        self.log('acc', {'train': acc}, logger=True)
        return loss

    def validation_step(self, batch, batch_idx):
        loss, acc = self._shared_step(batch)
        self.log('loss', {'val': loss}, logger=True)
        self.log('acc', {'val': acc}, logger=True)
        return loss

    def configure_optimizers(self):
        # Plain Adam over all parameters (backbone + new head); CONFIG['optimizer'] is not consulted.
        return torch.optim.Adam(self.parameters(), lr=CONFIG['learning_rate'])

In [None]:
# Build the classifier; ResNet.__init__ fetches the pretrained backbone via torch.hub.load.
model = ResNet()

In [None]:
# Start the W&B run. `anony` comes from the login cell: None after a
# successful login, "must" when no API key was available. In that case the
# run is sent to an anonymous W&B account rather than requiring a login.
# NOTE(review): `entity` is the notebook author's account; change it to your
# own when running under your login.
run = wandb.init(project="HappyWhale-WhaleorDolphin",
                 config=CONFIG,
                 entity="jiahng",
                 anonymous=anony)

In [None]:
# Track the model's gradients in W&B (every 100 steps) and route Lightning's
# logged metrics to the active W&B run.
wandb.watch(model, log_freq=100)
logger = WandbLogger()
# Use one GPU when CONFIG selected CUDA, otherwise train on CPU (gpus=0).
# A hard-coded gpus=1 makes Lightning raise on machines without a GPU.
trainer = pl.Trainer(gpus=1 if CONFIG['device'].type == 'cuda' else 0,
                     max_epochs=CONFIG['epochs'],
                     profiler='simple',
                     logger=logger,
                     log_every_n_steps=10)
trainer.fit(model, datamodule=whale_datamodule)

In [None]:
# Persist the fine-tuned weights locally, then attach them to the W&B run
# as a model artifact.
weights_file = 'model.pth'
torch.save(model.state_dict(), weights_file)

final_model_artifact = wandb.Artifact('model', type='model')
final_model_artifact.add_file(weights_file)
run.log_artifact(final_model_artifact)

In [None]:
# Mark the W&B run as finished so the remaining logs and the model artifact are uploaded.
wandb.finish()

## Results

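<p> Validation accuracy: 94%; validation BCE loss: 0.1557. </p>
<p> Training accuracy and loss are worse than their validation counterparts. This is partly expected, because training metrics are computed on augmented, oversampled batches while validation uses un-augmented images. More training might still help. </p>
<p> Full logs and final model: https://wandb.ai/jiahng/HappyWhale-WhaleorDolphin/runs/3p812n8p </p>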