In [None]:
import gc
import os
import cv2
import sys
import json
import time
import timm
import torch
import tqdm
import random
import sklearn.metrics

from PIL import Image
from pathlib import Path
from functools import partial
from contextlib import contextmanager
from scipy.special import softmax

import numpy as np
import pandas as pd
import torch.nn as nn

from torch.optim import Adam, SGD
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.utils.data import DataLoader, Dataset
from albumentations import Compose, Normalize, Resize
from albumentations.pytorch import ToTensorV2
from sklearn.preprocessing import LabelBinarizer

# Restrict this kernel to physical GPU 1; this must be set before CUDA is initialised.
os.environ["CUDA_VISIBLE_DEVICES"] = "1"
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
device

# Loading Metadata

In [2]:
# Training metadata; 'binomial' holds the species name of each observation.
metadata = pd.read_csv("../metadata/SnakeCLEF2021_train_metadata_PROD.csv")
labels_species = metadata['binomial']

# `lb_species.classes_` is later used to turn species names into class indices.
lb_species = LabelBinarizer().fit(np.asarray(labels_species))

print(len(metadata), labels_species.nunique(dropna=False))

386006 772


In [None]:
# Machine-specific root of the private test images (file_path is relative to it).
TEST_IMAGE_ROOT = "/local/nahouby/Datasets/SnakeCLEF2021-test/test-inat-post-4.2/"

test_metadata = pd.read_csv("../metadata/SnakeCLEF2021_TEST_METADATA-PRIVATE.csv")
test_metadata["image_path"] = test_metadata['file_path'].map(lambda rel_path: TEST_IMAGE_ROOT + rel_path)
# Every missing value becomes the literal 'unknown'; the prior-weighting cells key on it.
test_metadata = test_metadata.fillna('unknown')

test_metadata.head(5)

In [None]:
# Species x country relevance table: one row per species ('binomial'), one
# column per lower-cased country name, used later as a per-country mask.
country_relevance = pd.read_csv("../metadata/Species_Relevance.csv", sep=';')
country_relevance = country_relevance.rename(columns={'Unnamed: 0': 'binomial'})

# Drop species the model was never trained on.
KO_species = set(country_relevance['binomial']) - set(metadata['binomial'])
country_relevance = country_relevance[~country_relevance.binomial.isin(KO_species)].reset_index(drop=True)

# Class index = position of the species in lb_species.classes_ (dict lookup
# instead of a row-wise np.where scan over all classes).
species_to_class = {name: idx for idx, name in enumerate(lb_species.classes_)}
country_relevance['class_id'] = country_relevance['binomial'].map(species_to_class)
country_relevance = country_relevance.sort_values(by=['class_id'])

# Country columns are multiplied position-by-position with the model's output
# vector, so each class must appear exactly once and in order.
assert country_relevance['class_id'].tolist() == list(range(len(lb_species.classes_))), \
    "Species_Relevance.csv must contain exactly one row per training species"

# Loading Model

In [5]:
def getModel(architecture_name, target_size, pretrained = False):
    """Create a timm model and replace its classifier with a fresh linear layer.

    Args:
        architecture_name: model name passed to ``timm.create_model``.
        target_size: number of output classes of the new classifier.
        pretrained: whether timm should load its own pretrained weights.

    Returns:
        The model, with the module named by ``default_cfg['classifier']``
        replaced by ``nn.Linear(in_features, target_size)``.
    """
    net = timm.create_model(architecture_name, pretrained=pretrained)
    # The classifier name may be a dotted path to a nested module (e.g. 'head.fc');
    # resolve its parent so the replacement is attached in the right place.
    classifier_path = net.default_cfg['classifier']
    parent_path, _, attr_name = classifier_path.rpartition('.')
    parent = net.get_submodule(parent_path) if parent_path else net
    num_ftrs = getattr(parent, attr_name).in_features
    setattr(parent, attr_name, nn.Linear(num_ftrs, target_size))
    return net

In [7]:
# %%
N_CLASSES = 772
MODEL_NAME = 'vit_large_patch16_384'
CHECKPOINT_PATH = '../../SnakeCLEF2021/CKPTS/SnakeCLEF2021-ViT_large_patch16-384-FT-FL-OCLR-20E.pth'

# timm's ImageNet weights are not needed: load_state_dict (strict by default)
# overwrites every parameter with the fine-tuned checkpoint below.
model = getModel(MODEL_NAME, N_CLASSES, pretrained=False)
# Normalisation statistics come from the architecture's timm default config.
model_mean = list(model.default_cfg['mean'])
model_std = list(model.default_cfg['std'])

# map_location lets a checkpoint saved on GPU load on whatever `device` is
# available, including the CPU fallback.
model.load_state_dict(torch.load(CHECKPOINT_PATH, map_location=device))

model.to(device)
model.eval()
print('Done.')

Done.


# Dataset, Seeding and Transforms

In [10]:
def seed_torch(seed=777):
    """Seed every RNG the notebook uses, for reproducible runs.

    Seeds Python's ``random``, NumPy's global RNG and PyTorch (CPU and all
    CUDA devices), and exports PYTHONHASHSEED. Setting PYTHONHASHSEED at
    runtime only affects interpreters started afterwards, not the current one.
    It also puts cuDNN into deterministic mode.

    Args:
        seed: integer seed applied to every generator.
    """
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # cuDNN autotuning can pick different (non-deterministic) kernels per run.
    torch.backends.cudnn.benchmark = False

    
class CustomDataset(Dataset):
    """Image dataset backed by a metadata DataFrame.

    Each item is ``(image, file_path, class_id, country, continent)``, taken from
    the 'image_path', 'class_id', 'country' and 'continent' columns.
    """

    def __init__(self, df, transform=None):
        """
        Args:
            df: metadata frame containing the columns listed above.
            transform: optional callable invoked as ``transform(image=...)`` that
                returns a dict with an 'image' entry (albumentations style).
        """
        self.df = df
        self.transform = transform

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        file_path = self.df['image_path'].values[idx]
        continent = self.df['continent'].values[idx]
        country = self.df['country'].values[idx]
        class_id = self.df['class_id'].values[idx]

        # cv2.imread returns None (it does not raise) for missing/unreadable files;
        # fail here with the path instead of crashing later inside the transform.
        image = cv2.imread(file_path)
        if image is None:
            raise FileNotFoundError(f"Could not read image: {file_path}")
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        if self.transform:
            image = self.transform(image=image)['image']

        return image, file_path, class_id, country, continent
    
# Global seed; re-seed every RNG before building loaders and running inference.
SEED = 777
seed_torch(seed=SEED)

In [11]:
# Network input resolution; matches the 384px ViT variant named in MODEL_NAME.
HEIGHT, WIDTH = 384, 384

from albumentations import RandomGridShuffle, CenterCrop, HueSaturationValue, RandomCrop, HorizontalFlip, VerticalFlip, RandomBrightnessContrast, CenterCrop, PadIfNeeded, RandomResizedCrop, ShiftScaleRotate, Blur, JpegCompression, RandomShadow

def get_transforms(*, data):
    """Return the albumentations pipeline for training or a test-time variant.

    Args:
        data: 'train' for the augmentation pipeline, or one of the TTA variants
            'test1' (plain resize), 'test2' (pad if needed, resize to 1.2x,
            centre-crop) and 'test3' (same, with 1.5x).

    Returns:
        An ``albumentations.Compose`` that ends with Normalize (model mean/std)
        and ToTensorV2.
    """
    assert data in ('train', 'test1', 'test2', 'test3')

    # albumentations size arguments are (height, width). HEIGHT == WIDTH here,
    # but the documented order keeps non-square input sizes correct.
    if data == 'train':
        return Compose([
            RandomResizedCrop(HEIGHT, WIDTH, scale=(0.7, 1.0)),
            HorizontalFlip(p=0.5),
            VerticalFlip(p=0.5),
            ShiftScaleRotate(shift_limit=0.0625, scale_limit=0.25, rotate_limit=45, p=.75),
            JpegCompression(quality_lower=50, quality_upper=100),
            Blur(blur_limit=2),
            RandomBrightnessContrast(p=0.3),
            HueSaturationValue(p=0.2),
            Normalize(mean=model_mean, std=model_std),
            ToTensorV2(),
        ])

    if data == 'test1':
        return Compose([
            Resize(HEIGHT, WIDTH),
            Normalize(mean=model_mean, std=model_std),
            ToTensorV2(),
        ])

    # 'test2' / 'test3': upscale by a variant-specific factor, then centre-crop
    # back to the network input size.
    upscale = {'test2': 1.2, 'test3': 1.5}[data]
    return Compose([
        PadIfNeeded(HEIGHT, WIDTH),
        Resize(int(HEIGHT * upscale), int(WIDTH * upscale)),
        CenterCrop(HEIGHT, WIDTH),
        Normalize(mean=model_mean, std=model_std),
        ToTensorV2(),
    ])

# Testing

In [12]:
BATCH_SIZE = 16
WORKERS = 8


def _make_test_loader(variant):
    """Build the (dataset, loader) pair for one TTA variant, in test_metadata order."""
    dataset = CustomDataset(test_metadata, transform=get_transforms(data=variant))
    loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=WORKERS)
    return dataset, loader


test_dataset_1, test_loader_1 = _make_test_loader('test1')
test_dataset_2, test_loader_2 = _make_test_loader('test2')
test_dataset_3, test_loader_3 = _make_test_loader('test3')

### Test-Time Augmentation v1: plain resize to 384×384

In [None]:
# TTA variant 1: run the model over the test set and collect per-image outputs.
preds = np.zeros(len(test_metadata), dtype=np.int64)  # top-1 class per image
GT_lbls = []
image_paths = []
preds_raw = []  # raw logits, one array per image, in test_metadata order
countries = []
continents = []

offset = 0
for images, paths, labels, counts, conts in tqdm.tqdm(test_loader_1, total=len(test_loader_1)):
    with torch.no_grad():
        logits = model(images.to(device)).cpu()

    # Fill by running offset so rows stay aligned even if the loader's batch size differs.
    preds[offset:offset + len(logits)] = logits.argmax(1).numpy()
    offset += len(logits)

    # Labels are only recorded, so they never need to visit the GPU.
    GT_lbls.extend(labels.numpy())
    preds_raw.extend(logits.numpy())
    image_paths.extend(paths)
    countries.extend(counts)
    continents.extend(conts)

In [15]:
# Keep the raw logits for later ensembling, plus the per-image top-1 class.
test_metadata['logits_t1'] = preds_raw
test_metadata['preds_t1'] = list(map(np.argmax, preds_raw))

In [17]:
from sklearn.metrics import f1_score, accuracy_score, top_k_accuracy_score, classification_report

# Macro-F1 and accuracy of TTA variant 1 against the ground-truth class ids.
y_true, y_pred = test_metadata['class_id'], test_metadata['preds_t1']
vanilla_f1 = f1_score(y_true, y_pred, average='macro')
vanilla_accuracy = accuracy_score(y_true, y_pred)

print('Test Augmentation 1:', np.round(vanilla_f1 * 100, 2), np.round(vanilla_accuracy * 100, 2))

Test Augmentation 1: 88.77 94.13


In [18]:
# Submission from TTA variant 1. Read the stored column rather than the global
# `preds`, which later cells overwrite (re-running this cell would then be wrong).
output = pd.DataFrame(zip(test_metadata.UUID, test_metadata.preds_t1), columns=['UUID', 'prediction'])
output.to_csv('vanilla.csv', index=False)

### Test-Time Augmentation v2: pad if needed, resize to 1.2×, center-crop to 384×384

In [None]:
# TTA variant 2: run the model over the test set and collect per-image outputs.
preds = np.zeros(len(test_metadata), dtype=np.int64)  # top-1 class per image (int, as in variant 1)
GT_lbls = []
image_paths = []
preds_raw = []  # raw logits, one array per image, in test_metadata order
countries = []
continents = []

offset = 0
for images, paths, labels, counts, conts in tqdm.tqdm(test_loader_2, total=len(test_loader_2)):
    with torch.no_grad():
        logits = model(images.to(device)).cpu()

    # Fill by running offset so rows stay aligned even if the loader's batch size differs.
    preds[offset:offset + len(logits)] = logits.argmax(1).numpy()
    offset += len(logits)

    # Labels are only recorded, so they never need to visit the GPU.
    GT_lbls.extend(labels.numpy())
    preds_raw.extend(logits.numpy())
    image_paths.extend(paths)
    countries.extend(counts)
    continents.extend(conts)

In [21]:
# Keep the raw logits for later ensembling, plus the per-image top-1 class.
test_metadata['logits_t2'] = preds_raw
test_metadata['preds_t2'] = list(map(np.argmax, preds_raw))

In [22]:
from sklearn.metrics import f1_score, accuracy_score, top_k_accuracy_score, classification_report

# Macro-F1 and accuracy of TTA variant 2 against the ground-truth class ids.
y_true, y_pred = test_metadata['class_id'], test_metadata['preds_t2']
vanilla_f1 = f1_score(y_true, y_pred, average='macro')
vanilla_accuracy = accuracy_score(y_true, y_pred)

print('Test Augmentation 2:', np.round(vanilla_f1 * 100, 2), np.round(vanilla_accuracy * 100, 2))

Test Augmentation 2: 88.68 94.55


### Test-Time Augmentation v3: pad if needed, resize to 1.5×, center-crop to 384×384

In [None]:
# TTA variant 3: run the model over the test set and collect per-image outputs.
preds = np.zeros(len(test_metadata), dtype=np.int64)  # top-1 class per image (int, as in variant 1)
GT_lbls = []
image_paths = []
preds_raw = []  # raw logits, one array per image, in test_metadata order
countries = []
continents = []

offset = 0
for images, paths, labels, counts, conts in tqdm.tqdm(test_loader_3, total=len(test_loader_3)):
    with torch.no_grad():
        logits = model(images.to(device)).cpu()

    # Fill by running offset so rows stay aligned even if the loader's batch size differs.
    preds[offset:offset + len(logits)] = logits.argmax(1).numpy()
    offset += len(logits)

    # Labels are only recorded, so they never need to visit the GPU.
    GT_lbls.extend(labels.numpy())
    preds_raw.extend(logits.numpy())
    image_paths.extend(paths)
    countries.extend(counts)
    continents.extend(conts)

In [24]:
# Keep the raw logits for later ensembling, plus the per-image top-1 class.
test_metadata['logits_t3'] = preds_raw
test_metadata['preds_t3'] = list(map(np.argmax, preds_raw))

In [25]:
from sklearn.metrics import f1_score, accuracy_score, top_k_accuracy_score, classification_report

# Macro-F1 and accuracy of TTA variant 3 against the ground-truth class ids.
y_true, y_pred = test_metadata['class_id'], test_metadata['preds_t3']
vanilla_f1 = f1_score(y_true, y_pred, average='macro')
vanilla_accuracy = accuracy_score(y_true, y_pred)

print('Test Augmentation 3:', np.round(vanilla_f1 * 100, 2), np.round(vanilla_accuracy * 100, 2))

Test Augmentation 3: 88.58 93.92


### Ensemble: mean of the three TTA softmax outputs

In [None]:
# Ensemble the three TTA variants: sum their softmax probabilities (same argmax
# as the mean) and keep the top-1 class. Vectorised over stacked
# (n_images, n_classes) logit arrays instead of a per-row iterrows loop.
tta_prob_sum = sum(
    softmax(np.stack(list(test_metadata[column])), axis=1)
    for column in ('logits_t1', 'logits_t2', 'logits_t3')
)
test_metadata['mean_softmax'] = tta_prob_sum.argmax(axis=1)

In [28]:
# Baseline for the prior-weighting cells below: metrics of the TTA ensemble.
y_true, y_pred = test_metadata['class_id'], test_metadata['mean_softmax']
vanilla_f1 = f1_score(y_true, y_pred, average='macro')
vanilla_accuracy = accuracy_score(y_true, y_pred)

print('Mean softmax:', np.round(vanilla_f1 * 100, 2), np.round(vanilla_accuracy * 100, 2))

Mean softmax: 89.13 95.15


In [29]:
# Submission file for the TTA-ensemble predictions.
output = pd.DataFrame({
    'UUID': test_metadata['UUID'].tolist(),
    'prediction': test_metadata['mean_softmax'].tolist(),
})
output.to_csv('mean_softmax.csv', index=False)

# Metadata extraction

### Extracting Species distribution

In [30]:
# Global class prior: relative frequency of each class_id in the training set.
# A single bincount replaces one full-frame filter per class.
n_train_classes = len(metadata['class_id'].unique())
class_counts = np.bincount(metadata['class_id'].to_numpy(), minlength=n_train_classes).astype(float)
# class ids are used as array positions, so they must be exactly 0..K-1.
assert len(class_counts) == n_train_classes, "class_id values must be the contiguous range 0..K-1"
class_priors = class_counts / class_counts.sum()

In [None]:
# Taxonomy lookups class_id -> genus / family, taken from the first training row
# of each class (drop_duplicates keeps the first occurrence), in one vectorised pass.
first_row_per_class = metadata.drop_duplicates(subset='class_id', keep='first').set_index('class_id')
class_2_genus = first_row_per_class['genus'].to_dict()
class_2_family = first_row_per_class['family'].to_dict()

### Country Distribution

In [None]:
# Per-country class distribution of the training set with add-one (Laplace)
# smoothing: counts start at one and EVERY observation is then added, so no
# (country, class) pair has zero mass. Keys are str(country), so missing
# countries end up under 'nan'.
n_train_classes = len(metadata['class_id'].unique())
country_distributions = {}

for country, class_ids in metadata.groupby(metadata['country'].astype(str), sort=False)['class_id']:
    counts = np.ones(n_train_classes) + np.bincount(class_ids.to_numpy(), minlength=n_train_classes)
    country_distributions[country] = counts / counts.sum()

### Continent Distribution

In [None]:
# Per-continent class distribution of the training set with add-one (Laplace)
# smoothing: counts start at one and EVERY observation is then added, so no
# (continent, class) pair has zero mass. Keys are str(continent), so missing
# continents end up under 'nan'.
n_train_classes = len(metadata['class_id'].unique())
continent_distributions = {}

for continent, class_ids in metadata.groupby(metadata['continent'].astype(str), sort=False)['class_id']:
    counts = np.ones(n_train_classes) + np.bincount(class_ids.to_numpy(), minlength=n_train_classes)
    continent_distributions[continent] = counts / counts.sum()

# Prior Weighting

## Binary filtration

In [35]:
# Binary filtration: average the three TTA softmax outputs, zero out species that
# country_relevance marks as irrelevant for the observation's country, then argmax.

# Test-set spelling -> country_relevance column spelling.
COUNTRY_ALIASES = {'Republic of Congo': 'republic of the congo'}
# Country values that are never masked (presumably no usable relevance column).
UNMASKED_COUNTRIES = {'unknown', 'West Bank', 'Macau S.A.R', 'US Naval Base Guantanamo Bay',
                      'United States Virgin Islands', 'Guam', 'Turks and Caicos Islands',
                      'Cyprus No Mans Area'}

test_metadata['binary_filtration'] = 0

for index, row in tqdm.tqdm(test_metadata.iterrows(), total=len(test_metadata)):
    country = COUNTRY_ALIASES.get(row.country, row.country)
    probs = sum((softmax(row.logits_t1), softmax(row.logits_t2), softmax(row.logits_t3))) / 3

    # Countries without a relevance column are left unmasked instead of raising KeyError.
    if country not in UNMASKED_COUNTRIES and country.lower() in country_relevance.columns:
        masked = probs * np.array(country_relevance[country.lower()], dtype=np.int32)
        # An all-zero mask would make argmax silently return class 0; keep the
        # unmasked prediction in that case.
        if masked.any():
            probs = masked

    test_metadata.at[index, 'binary_filtration'] = np.argmax(probs)

f1 = f1_score(test_metadata['class_id'], test_metadata['binary_filtration'], average='macro')
accuracy = accuracy_score(test_metadata['class_id'], test_metadata['binary_filtration'])
# Differences are relative to the most recently computed vanilla_f1 / vanilla_accuracy
# (the mean-softmax cell when the notebook runs top to bottom).
print('Binary Masking:')
print('F1:', np.round(f1 * 100, 2), 'Acc:', np.round(accuracy * 100, 2))
print('F1 dif:', np.round((f1-vanilla_f1) * 100, 2), 'Acc dif:', np.round((accuracy-vanilla_accuracy) * 100, 2))

100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 23673/23673 [00:04<00:00, 4931.83it/s]


Binary Masking:
F1: 92.24 Acc: 96.0
F1 dif: 3.11 Acc dif: 0.84


In [36]:
# Submission file for the binary country-masking predictions.
output = pd.DataFrame({
    'UUID': test_metadata['UUID'].tolist(),
    'prediction': test_metadata['binary_filtration'].tolist(),
})
output.to_csv('masking.csv', index=False)

## Country Weighting

In [37]:
# Country weighting: re-weight the TTA-averaged softmax by the training-set class
# distribution of the observation's country relative to the global class prior,
# then take the argmax.
test_metadata['country_weighting'] = 0

for index, row in tqdm.tqdm(test_metadata.iterrows(), total=len(test_metadata)):
    probs = sum((softmax(row.logits_t1), softmax(row.logits_t2), softmax(row.logits_t3))) / 3

    # Countries unseen in training use the 'unknown' bucket; if that bucket does not
    # exist either, use the global prior, which reduces to the plain argmax.
    country_dist = country_distributions.get(row.country, country_distributions.get('unknown', class_priors))

    p_countries = (probs * country_dist) / sum(probs * country_dist)
    prior_ratio = p_countries / class_priors
    test_metadata.at[index, 'country_weighting'] = np.argmax(prior_ratio * probs)

f1 = f1_score(test_metadata['class_id'], test_metadata['country_weighting'], average='macro')
accuracy = accuracy_score(test_metadata['class_id'], test_metadata['country_weighting'])
print('Country Weighting:')
print('F1:', np.round(f1 * 100, 2), 'Acc:', np.round(accuracy * 100, 2))
print('F1 dif:', np.round((f1-vanilla_f1) * 100, 2), 'Acc dif:', np.round((accuracy-vanilla_accuracy) * 100, 2))

100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 23673/23673 [00:06<00:00, 3553.41it/s]

Binary Masking:
F1: 92.27 Acc: 95.92
F1 dif: 3.14 Acc dif: 0.77





In [None]:
# Submission file for the country-weighted predictions.
output = pd.DataFrame({
    'UUID': test_metadata['UUID'].tolist(),
    'prediction': test_metadata['country_weighting'].tolist(),
})
output.to_csv('country_weighting.csv', index=False)

## Continent Weighting

In [38]:
# Continent weighting: re-weight the TTA-averaged softmax by the training-set class
# distribution of the observation's continent relative to the global class prior,
# then take the argmax.
test_metadata['continent_weighting'] = 0

for index, row in tqdm.tqdm(test_metadata.iterrows(), total=len(test_metadata)):
    probs = sum((softmax(row.logits_t1), softmax(row.logits_t2), softmax(row.logits_t3))) / 3

    # Continents unseen in training use the 'unknown' bucket (or, failing that,
    # the global prior) instead of raising KeyError.
    continent_dist = continent_distributions.get(row.continent, continent_distributions.get('unknown', class_priors))

    p_continent = (probs * continent_dist) / sum(probs * continent_dist)
    prior_ratio = p_continent / class_priors
    test_metadata.at[index, 'continent_weighting'] = np.argmax(prior_ratio * probs)

f1 = f1_score(test_metadata['class_id'], test_metadata['continent_weighting'], average='macro')
accuracy = accuracy_score(test_metadata['class_id'], test_metadata['continent_weighting'])
print('Continent Weighting:')
print('F1:', np.round(f1 * 100, 2), 'Acc:', np.round(accuracy * 100, 2))
print('F1 dif:', np.round((f1-vanilla_f1) * 100, 2), 'Acc dif:', np.round((accuracy-vanilla_accuracy) * 100, 2))

100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 23673/23673 [00:06<00:00, 3609.67it/s]

Binary Masking:
F1: 91.16 Acc: 95.15
F1 dif: 2.03 Acc dif: 0.0





In [39]:
# Submission file for the continent-weighted predictions.
output = pd.DataFrame({
    'UUID': test_metadata['UUID'].tolist(),
    'prediction': test_metadata['continent_weighting'].tolist(),
})
output.to_csv('continent_weighting.csv', index=False)

## Continent Weighting + Binary Masking

In [40]:
# Binary country masking followed by continent weighting of the TTA-averaged softmax.

# Test-set spelling -> country_relevance column spelling.
COUNTRY_ALIASES = {'Republic of Congo': 'republic of the congo'}
# Country values that are never masked (presumably no usable relevance column).
UNMASKED_COUNTRIES = {'unknown', 'West Bank', 'Macau S.A.R', 'US Naval Base Guantanamo Bay',
                      'United States Virgin Islands', 'Guam', 'Turks and Caicos Islands',
                      'Cyprus No Mans Area'}

test_metadata['continent_binary'] = 0

for index, row in tqdm.tqdm(test_metadata.iterrows(), total=len(test_metadata)):
    country = COUNTRY_ALIASES.get(row.country, row.country)
    probs = sum((softmax(row.logits_t1), softmax(row.logits_t2), softmax(row.logits_t3))) / 3

    # Countries without a relevance column are left unmasked instead of raising KeyError.
    if country not in UNMASKED_COUNTRIES and country.lower() in country_relevance.columns:
        masked = probs * np.array(country_relevance[country.lower()], dtype=np.int32)
        # An all-zero mask would make the normalisation below 0/0 (NaN) and argmax
        # return class 0; keep the unmasked probabilities in that case.
        if masked.any():
            probs = masked

    # Continents unseen in training use the 'unknown' bucket (or the global prior).
    continent_dist = continent_distributions.get(row.continent, continent_distributions.get('unknown', class_priors))

    p_continent = (probs * continent_dist) / sum(probs * continent_dist)
    prior_ratio = p_continent / class_priors
    test_metadata.at[index, 'continent_binary'] = np.argmax(prior_ratio * probs)

f1 = f1_score(test_metadata['class_id'], test_metadata['continent_binary'], average='macro')
accuracy = accuracy_score(test_metadata['class_id'], test_metadata['continent_binary'])

print('Binary Masking + Continent:')
print('F1:', np.round(f1 * 100, 2), 'Acc:', np.round(accuracy * 100, 2))
print('F1 dif:', np.round((f1-vanilla_f1) * 100, 2), 'Acc dif:', np.round((accuracy-vanilla_accuracy) * 100, 2))

100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 23673/23673 [00:07<00:00, 3374.14it/s]


Binary Masking + Continent:
F1: 92.08 Acc: 95.18
F1 dif: 2.95 Acc dif: 0.03


In [41]:
# Submission file for masking + continent-weighted predictions.
output = pd.DataFrame({
    'UUID': test_metadata['UUID'].tolist(),
    'prediction': test_metadata['continent_binary'].tolist(),
})
output.to_csv('continent_binary.csv', index=False)