# Import modules

In [None]:
pip install timm

In [None]:
# Debug option: maximum number of training rows to use; None keeps the whole
# dataset (per the original note, the whole dataset takes about 4h per epoch).
datset_count_limit = None

In [None]:
import os.path
import json
import codecs
from collections import Counter
import random
from datetime import datetime
import math

import cv2
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

import warnings
import random
import numpy as np

from tqdm import trange
from tqdm import tqdm as tqdm
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as D
import torch.optim as optim

import torchvision
import torchvision.datasets as datasets
import torchvision.transforms as transforms

from sklearn.model_selection import train_test_split
from sklearn.metrics import f1_score

import timm

In [None]:
# Device identifiers handed to torch.device.
CUDA, CPU = "cuda:0", "cpu"

# Run on the first GPU when CUDA is available, otherwise on the CPU.
use_cuda = torch.cuda.is_available()
if use_cuda:
    device = torch.device(CUDA)
else:
    device = torch.device(CPU)

## Load files

In [None]:
# Image folders and their metadata files.
TRAIN_PATH = "../input/herbarium-2020-fgvc7/nybg2020/train/"
TEST_PATH = "../input/herbarium-2020-fgvc7/nybg2020/test/"
TRAIN_META_PATH = "../input/herbarium-2020-fgvc7/nybg2020/train/metadata.json"
TEST_META_PATH = "../input/herbarium-2020-fgvc7/nybg2020/test/metadata.json"

# Submission template, checkpoint to resume from, and output directory.
SUBMISSION_PATH = '../input/herbarium-2020-fgvc7/sample_submission.csv'
WEIGHTS_PATH = '../input/fc2ep3/h_2fc_ce_ep3.pth'
OUTPUT_PATH =  '/kaggle/working'


def _load_metadata(meta_path):
    """Parse a metadata JSON file, ignoring bytes that cannot be decoded as UTF-8."""
    with codecs.open(meta_path, 'r', encoding='utf-8', errors='ignore') as meta_file:
        return json.load(meta_file)


train_meta = _load_metadata(TRAIN_META_PATH)
test_meta = _load_metadata(TEST_META_PATH)

## Merge training data

In [None]:
def _section_frame(section, column_names):
    """Turn one section of the training metadata into a DataFrame with the given column names."""
    frame = pd.DataFrame(train_meta[section])
    # Columns are renamed by position, so the order of `column_names` matters.
    frame.columns = column_names
    return frame


train_cat = _section_frame('categories', ['family', 'genus', 'category_id', 'category_name'])
train_img = _section_frame('images', ['file_name', 'height', 'image_id', 'license', 'width'])
train_reg = _section_frame('regions', ['region_id', 'region_name'])

# Attach category, image and region details to every annotation (outer joins).
train_df = pd.DataFrame(train_meta['annotations'])
for lookup, join_key in ((train_cat, 'category_id'),
                         (train_img, 'image_id'),
                         (train_reg, 'region_id')):
    train_df = train_df.merge(lookup, on=join_key, how='outer')

sample_sub = pd.read_csv(SUBMISSION_PATH)

display(train_df)

In [None]:
# Integer index for every distinct family name, assigned in sorted order.
family_index = {name: idx for idx, name in enumerate(sorted(set(train_df['family'])))}
# print('Family count: {}'.format(len(family_index)))

# category_id -> family index, read from one representative row per category.
distinct_cat_df = train_df.drop_duplicates('category_id')
cat_id_family_map = {}
for row_pos in trange(len(distinct_cat_df)):
    row = distinct_cat_df.iloc[row_pos]
    cat_id_family_map[row['category_id']] = family_index[row['family']]

# Family index of each category, listed in ascending category_id order.
unique_cat_ids = pd.unique(train_df['category_id'])
cat_family_index = [cat_id_family_map[cat] for cat in sorted(unique_cat_ids)]

# Per-category sample counts and row positions (category ids are used as list indices).
category_total = len(unique_cat_ids)
cat_id_count = [0] * category_total
cat_id_items = [[] for _ in range(category_total)]

fam_id_count = [0] * len(family_index)
category_column = train_df['category_id'].values
for row_pos in trange(len(category_column)):
    cat_id = category_column[row_pos]
    cat_id_count[cat_id] += 1
    cat_id_items[cat_id].append(row_pos)

#     family_name = train_df['family'].values[n]
#     family_id = family_index[family_name]
#     fam_id_count[family_id] = fam_id_count[family_id] + 1

# Keep at most `threshold` randomly chosen rows per category, then shuffle the result.
train_ids = []
threshold = 20
for bucket in tqdm(cat_id_items):
    random.shuffle(bucket)
    train_ids.extend(bucket[:threshold])
random.shuffle(train_ids)
print('train_ids', len(train_ids))


# cat_id_weights = [1 / count for count in cat_id_count]
# fam_id_weights = [1 / count for count in fam_id_count]

# print(len(cat_family_index))

## Prepare torch dataset

In [None]:
import albumentations

def make_albs(*args):
    """Build an augmentation function from one or more lists of albumentations transforms.

    Every positional argument is an iterable of transforms. They are
    concatenated in the order given and wrapped in a single
    ``albumentations.Compose``. The returned callable passes its argument to
    that pipeline as ``image`` and returns the ``'image'`` entry of the result.
    """
    pipeline = albumentations.Compose([aug for group in args for aug in group])

    def apply(im):
        return pipeline(image=im)['image']

    return apply

# Augmentations applied to training images: random flips, a small
# brightness/contrast change, and a rotation of up to 25 degrees with
# reflected borders.
_train_augmentations = [
    albumentations.HorizontalFlip(p=0.5),
    albumentations.VerticalFlip(p=0.3),
    albumentations.RandomBrightnessContrast(brightness_limit=0.05, contrast_limit=0.05),
    albumentations.Rotate(limit=25, border_mode=cv2.BORDER_REFLECT),
]
alb_f = make_albs(_train_augmentations)
    

# Status value returned with a sample that was loaded without errors.
DATASET_SUCCESS_STATUS = -1
# Dataset index of the sample substituted when loading another sample fails.
DATASET_ERROR_INDEX = 0

# Target image size as (height, width).
DEFAULT_RESIZE = (480, 320)
# Fraction of the image removed from each side as (vertical, horizontal).
DEFAULT_CROP = (0.1, 0.1)


class HerbariumDataset(D.Dataset):
    """Dataset yielding ``(status, image, category, family)`` tensors.

    ``data`` is indexed by column name (``'file_name'``, ``'category_id'``)
    and then by row position. When ``ids`` is given, the dataset exposes only
    those row positions (in that order) and applies the ``alb_f``
    augmentation to every image.

    ``status`` is ``DATASET_SUCCESS_STATUS`` for a correctly loaded sample.
    If loading fails, the sample at ``error_index`` is returned instead and
    ``status`` carries the index that failed.
    """
    success = DATASET_SUCCESS_STATUS

    def __init__(self, data, path, crop=DEFAULT_CROP, resize=DEFAULT_RESIZE, ids=None):
        """Store the data source and preprocessing options.

        Args:
            data: Table with the image file names and category ids.
            path: Directory the file names are relative to.
            crop: ``(vertical, horizontal)`` fractions cut from each side,
                or a falsy value to disable cropping.
            resize: Output size as ``(height, width)``.
            ids: Optional list of row positions to expose.
        """
        self.data = data
        self.path = path
        self.crop = crop
        self.resize = resize
        self.ids = ids

        self.error_index = DATASET_ERROR_INDEX

        self.output_labels = True
        self.output_fam = True

    def __len__(self):
        """Return the number of exposed samples."""
        if self.ids is not None:
            return len(self.ids)
        return len(self.data)

    def get_image_path(self, i):
        """Return the full path of the image at row position ``i``."""
        fname = str(self.data['file_name'].values[i])
        return os.path.join(self.path, fname)

    def __getitem__(self, i):
        """Load, preprocess and return the sample at dataset index ``i``.

        Raises:
            Exception: Whatever error occurred, when the fallback sample
                (``error_index``) itself cannot be loaded.
        """
        requested = i  # index as requested, before any remapping through self.ids
        try:
            if self.ids is not None:
                i = self.ids[i]
            # Load image
            fpath = self.get_image_path(i)
            image = cv2.imread(fpath)
            if image is None:
                # cv2.imread signals failure by returning None, not by raising.
                raise IOError('Could not read image: {}'.format(fpath))
            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

            # Crop the same margin from both sides; written with explicit end
            # offsets so that a zero margin keeps the full extent.
            if self.crop:
                h, w, _ = image.shape
                crop_h, crop_w = self.crop
                top, left = int(crop_h * h), int(crop_w * w)
                image = image[top:h - top, left:w - left]

            # Resize (cv2 expects the size as (width, height))
            resize_h, resize_w = self.resize
            image = cv2.resize(image, (resize_w, resize_h))

            if self.ids is not None:
                image = alb_f(image)

            # Normalize by the image maximum; an all-zero image is left as zeros
            # instead of being turned into NaNs by a division by zero.
            max_value = np.max(image)
            image = image.astype(np.float64)
            if max_value > 0:
                image = image / max_value

            # Convert to Channel First
            image = np.rollaxis(image, 2, 0)

            # Get label
            label = np.array([0])
            if self.output_labels:
                label = self.data['category_id'].values[i]

            # Get family
            fam_id = np.array([0])
            if self.output_fam:
                fam_id = cat_id_family_map[label]

            return (torch.tensor(self.success, dtype=torch.long),  # Status
                    torch.tensor(image.copy(), dtype=torch.float),  # Image
                    torch.tensor(label.copy(), dtype=torch.long),  # Cat
                    torch.tensor(fam_id, dtype=torch.long))  # Family

        except Exception:
            # The fallback sample itself is broken: re-raise rather than
            # recursing on the same index forever.
            if requested == self.error_index:
                raise
            _, image, label, fam = self.__getitem__(self.error_index)
            return (torch.tensor(i, dtype=torch.long),
                    image,
                    label,
                    fam)


class HerbariumDictDataset(HerbariumDataset):
    """Dataset variant without labels, whose ``data`` maps a key straight to an image file name."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # No category or family lookup: both outputs keep their placeholder values.
        self.output_fam = False
        self.output_labels = False

    def get_image_path(self, i):
        """Return the path of the file stored under key ``i``, joined onto ``self.path``."""
        return os.path.join(self.path, str(self.data[i]))


In [None]:
# if datset_count_limit:
#     train_data, test_data = train_test_split(train_df[:datset_count_limit])
# else:
#     train_data, test_data = train_test_split(train_df, test_size=0.125)  # To be close to submission count (139k)

# train_dataset = HerbariumDataset(train_data, TRAIN_PATH)
# test_dataset = HerbariumDataset(test_data, TRAIN_PATH)  # There should be train path, it is correct

train_data = train_df
if datset_count_limit:
    train_data = train_df[:datset_count_limit]
    # train_ids holds row positions computed on the full train_df; drop the
    # ones that no longer exist in the truncated frame so that every id can
    # be loaded.
    train_ids = [row_pos for row_pos in train_ids if row_pos < len(train_data)]
train_dataset = HerbariumDataset(train_data, TRAIN_PATH, ids=train_ids)

In [None]:
# Show one random training sample. randrange excludes the upper bound, whereas
# randint(0, len) could return len(train_dataset) and index past the end.
_, img, label, _ = train_dataset[random.randrange(len(train_dataset))]
print(img.shape)
print(label)
# The dataset returns channel-first images (C, H, W); imshow expects (H, W, C).
plt.imshow(img.permute(1, 2, 0))

In [None]:
# _, img, label, _ = test_dataset[random.randint(0, len(test_dataset))]
# print(img.shape)
# print(label)
# plt.imshow(img.transpose(0, 2))

## Declare model

In [None]:
# Work around an expired certificate when downloading pretrained weights.
# NOTE(review): this replaces the default HTTPS context for the whole process,
# so certificate verification is off for every later HTTPS request as well.
import ssl
ssl._create_default_https_context = ssl._create_unverified_context


# Pretrained ResNet-34 backbone; `net` initially refers to the same object.
num_classes = 32094
conv_net = timm.create_model(
    'resnet34',
    pretrained=True,
    num_classes=num_classes,
)
net = conv_net


class TwoStageNet(torch.nn.Module):
    """Shared backbone with two linear heads: plant family and category.

    ``forward`` returns the raw category logits, the raw family logits, and a
    combined category distribution in which each category probability is
    multiplied by the probability of the family that category belongs to.
    """
    # Family index of every category, in ascending category order.
    FAM2CAT = cat_family_index

    def __init__(self):
        super().__init__()
        # Shape (1, number of categories); used as the gather index in forward().
        self.fam_cat_map = torch.tensor([self.FAM2CAT])

        self.conv_net = conv_net
        # Remove the classifier: the last fc layer becomes a pass-through so the
        # backbone outputs its features. nn.Identity has no parameters, so the
        # state_dict keys are the same as with a plain callable.
        self.conv_net.fc = nn.Identity()

        self.fc_fam = nn.Linear(in_features=512,
                                out_features=310,
                                bias=True)

        self.fc_cat = nn.Linear(in_features=512,
                                out_features=32094,
                                bias=True)

    def forward(self, x):
        """Return ``(category_logits, family_logits, combined_probabilities)`` for batch ``x``."""
        x = self.conv_net(x)

        fam_id = self.fc_fam(x)
        cat_id = self.fc_cat(x)

        # Broadcast the category->family map over the batch without copying it.
        gather_index = self.fam_cat_map.to(fam_id.device).expand(fam_id.shape[0], -1)
        # Probability of each category's family, laid out per category.
        fam_id_as_cat_id = torch.gather(F.softmax(fam_id, dim=1), dim=1, index=gather_index)
        # dim=1 is the class dimension; it is what the implicit (deprecated)
        # default resolved to for these 2-D inputs.
        result = F.softmax(F.softmax(cat_id, dim=1) * fam_id_as_cat_id, dim=1)

        return cat_id, fam_id, result

## Training

In [None]:
################################################################################
# Metrics configuration
################################################################################
def accuracy(x, y):
    """Return the fraction of positions where predictions ``x`` equal targets ``y``.

    Both arguments are tensors; they are detached and converted to NumPy
    arrays on the CPU before being compared.
    """
    predicted, expected = (t.detach().cpu().numpy() for t in (x, y))
    hits = np.sum(predicted == expected)
    return hits / expected.shape[0]


def f1(x, y):
    """Return the macro-averaged F1 score of predictions ``x`` against targets ``y``.

    Both arguments are tensors; they are detached and converted to NumPy
    arrays on the CPU before scoring.
    """
    y_pred = x.detach().cpu().numpy()
    y_true = y.detach().cpu().numpy()
    return f1_score(y_true=y_true, y_pred=y_pred, average='macro')

In [None]:
################################################################################
# Model configuration
################################################################################
num_classes = 32094

net = TwoStageNet()
net.to(device)
# map_location remaps the stored tensors onto the active device, so a
# checkpoint saved on a GPU can also be loaded on a CPU-only machine.
net.load_state_dict(torch.load(WEIGHTS_PATH, map_location=device))
################################################################################
# Training configuration
################################################################################
# Run identifier used in the names of the saved weights and history files.
timestamp = datetime.now().strftime("%d_%m_%Y_%H_%M_%S")

batch_size = 64
epoch_count = 1

learning_rate = 0.00003  # 0.001
decay = 1e-6
optimizer = optim.Adam(
    net.parameters(),
    lr=learning_rate,
    weight_decay=decay,
)
# scheduler = ReduceLROnPlateau(optimizer,
#                               'min',
#                               patience=10,
#                               factor=0.8,
#                               min_lr=1e-8)


# Unweighted cross-entropy for each head; the wrappers cast targets to long.
_cat_cross_entropy_loss = nn.CrossEntropyLoss()  # weight=torch.tensor(cat_id_weights).to(device))
_fam_cross_entropy_loss = nn.CrossEntropyLoss()  # weight=torch.tensor(fam_id_weights).to(device))


def cat_criterion(x, y):
    """Cross-entropy between category logits ``x`` and category targets ``y``."""
    return _cat_cross_entropy_loss(x, y.long())


def fam_criterion(x, y):
    """Cross-entropy between family logits ``x`` and family targets ``y``."""
    return _fam_cross_entropy_loss(x, y.long())
# criterion = torch.nn.BCEWithLogitsLoss()

################################################################################
# Data configuration
################################################################################
# Training samples restricted to the per-category subset in train_ids,
# served in shuffled batches by three worker processes.
train_dataset = HerbariumDataset(data=train_data, path=TRAIN_PATH, ids=train_ids)

train_loader = D.DataLoader(
    dataset=train_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=3,
)
################################################################################
# Run
################################################################################
# Silence all warnings for the rest of the run.
warnings.filterwarnings("ignore")

# Keys of the per-batch history records.
EPOCH_KEY, BATCH_KEY = 'EPOCH', 'BATCH'
LOSS_KEY, ACC_KEY, F1_KEY = 'LOSS', 'ACC', 'F1'
VAL_LOSS_KEY, VAL_ACC_KEY, VAL_F1_KEY = 'VAL_LOSS', 'VAL_ACC', 'VAL_F1'

# Metrics are computed every LOSS_PRINT_STEP batches.
LOSS_PRINT_STEP = 1
best_loss = math.inf

# Per-batch statistics, and the batches in which a sample failed to load.
history = []
train_data_errors = []

# Gradients of this many consecutive batches are accumulated before each
# optimizer step.
ACCUMULATION_STEPS = 4

optimizer.zero_grad()
# Run training over the epochs
for epoch in range(1, epoch_count + 1):

    # Traverse the batches
    train_count = len(train_loader)
    with tqdm(train_loader, position=0) as train_data_iterator:
        net.train()
        for batch_n, data in enumerate(train_data_iterator, 1):
            progress_update = False
            is_last_batch = batch_n == train_count

            # Get the inputs; data is [status, inputs, labels, family labels]
            status, inputs, labels, fam_labels = data
            if not all(status == DATASET_SUCCESS_STATUS):
                # Remember batches in which at least one sample failed to load
                train_data_errors.append((epoch, batch_n, status))
            inputs, labels, fam_labels = inputs.to(device), labels.to(device), fam_labels.to(device)

            # Forward -> backward; gradients accumulate across batches
            cat, fam, res = net(inputs)

            cat_loss = cat_criterion(cat, labels)
            fam_loss = fam_criterion(fam, fam_labels)

            loss = cat_loss + fam_loss
            loss.backward()
            # batch_n starts at 1, so step on every ACCUMULATION_STEPS-th batch.
            # Also step on the final batch so that gradients left over at the
            # end of an epoch are applied instead of being dropped or carried
            # into the next epoch.
            if batch_n % ACCUMULATION_STEPS == 0 or is_last_batch:
                optimizer.step()
                optimizer.zero_grad()

            # Get statistics
            history_item = {EPOCH_KEY: epoch, BATCH_KEY: batch_n, LOSS_KEY: loss.item()}

            # Compute metrics on the first batch and on every LOSS_PRINT_STEP-th batch
            if batch_n == 1 or batch_n % LOSS_PRINT_STEP == 0:
                progress_update = True

                x = torch.argmax(cat, dim=1)
                history_item['CAT_ACC'] = accuracy(x, labels).item()
                history_item['CAT_F1'] = f1(x, labels).item()

                x = torch.argmax(fam, dim=1)
                history_item['FAM_ACC'] = accuracy(x, fam_labels).item()
                history_item['FAM_F1'] = f1(x, fam_labels).item()

                x = torch.argmax(res, dim=1)
                history_item[ACC_KEY] = accuracy(x, labels).item()
                history_item[F1_KEY] = f1(x, labels).item()

            # Record this batch before anything is written to disk, so the
            # saved history includes the final batch of the epoch.
            history.append(history_item)

            # On epoch end
            if is_last_batch:
                progress_update = True
                # Save weights
                w_path_fname = '/_weights_{}_{}.pth'.format(str(epoch), str(timestamp))
                w_path = OUTPUT_PATH + w_path_fname
                torch.save(net.state_dict(), w_path)
                # Save the history collected so far
                history_path = OUTPUT_PATH + '/hisotry' + timestamp + '.json'
                with open(history_path, 'w') as json_file:
                    json.dump(history, json_file)

            # Print changes to tqdm bar
            if progress_update:
                train_data_iterator.set_postfix(history_item)


print('Train dataset errors: ', str(train_data_errors))
print('Training fininshed!')

In [None]:
def smooth(x, window_len=11, window='hanning'):
    """Smooth a 1-D array by convolving it with a normalized window.

    The signal is extended at both ends with reflected copies of itself
    before the convolution, so the result has ``len(x) + window_len - 1``
    samples.

    Args:
        x: One-dimensional NumPy array.
        window_len: Size of the smoothing window. Values below 3 return
            ``x`` unchanged.
        window: One of 'flat' (moving average), 'hanning', 'hamming',
            'bartlett' or 'blackman'.

    Returns:
        The smoothed signal as a NumPy array.

    Raises:
        ValueError: If ``x`` is not one-dimensional, is shorter than
            ``window_len``, or ``window`` is not a supported name.
    """
    if x.ndim != 1:
        raise ValueError("smooth only accepts 1 dimension arrays.")
    if x.size < window_len:
        raise ValueError("Input vector needs to be bigger than window size.")
    if window_len < 3:
        return x
    if window not in ('flat', 'hanning', 'hamming', 'bartlett', 'blackman'):
        raise ValueError("Window is on of 'flat', 'hanning', 'hamming', 'bartlett', 'blackman'")

    # Pad both ends with mirrored samples so the window is always fully covered.
    padded = np.r_[x[window_len-1:0:-1], x, x[-2:-window_len-1:-1]]
    if window == 'flat':  # moving average
        weights = np.ones(window_len, 'd')
    else:
        # Look the window function up by name instead of eval()-ing a string.
        weights = getattr(np, window)(window_len)
    return np.convolve(weights / weights.sum(), padded, mode='valid')

# Per-batch F1 and loss collected during training.
history_f1 = [entry['F1'] for entry in history]
history_loss = [entry['LOSS'] for entry in history]

fig, (ax1, ax2) = plt.subplots(nrows=1, ncols=2, figsize=(15, 5))

# Each panel shows the raw curve with its smoothed version (51-sample window) on top.
for axis, title, series in ((ax1, 'f1', history_f1), (ax2, 'loss', history_loss)):
    axis.set_title(title)
    axis.plot(series)
    axis.plot(smooth(np.array(series), window_len=51))

## Validation

In [None]:
# valid_dataset = HerbariumDataset(test_data, TRAIN_PATH)  # There should be train path, it is correct

# valid_loader = torch.utils.data.DataLoader(valid_dataset, batch_size=batch_size, num_workers=2,
#                                            shuffle=False)

# valid_data_errors = []
# total_valid_loss = 0.0
# xs = torch.tensor([]).to(device).long()
# ys = torch.tensor([]).to(device).long()
# with torch.no_grad():
#     net.eval()
#     with tqdm(valid_loader, position=0) as valid_data_iterator:
#         net.train()
#         for valid_batch, (valid_status, valid_indputs, valid_labels, fam_labels) in enumerate(valid_data_iterator, 1):
#             if not all(valid_status == DATASET_SUCCESS_STATUS):
#                 valid_data_errors.append((epoch , status))
#             valid_indputs, valid_labels = valid_indputs.to(device), valid_labels.to(device)

#             # Evaluate batch
#             _, _, valid_pred = net(valid_indputs)

#             # Summing loss
#             total_valid_loss += criterion(valid_pred, valid_labels)

#             # Softmax and argmax results to reduce space consumption
#             x = valid_pred
#             # x = F.softmax(x, dim=1) # In case of lenet5 softmax is inside NN
#             x = torch.argmax(x, dim=1)
#             # Cat batch eval results and gt labels
#             xs = torch.cat((xs, x.long()), dim=0)
#             ys = torch.cat((ys, valid_labels.long()), dim=0)

# # Average validation metrics and store to history)
# avg_valid_loss = total_valid_loss / len(valid_loader)
# print('Valid loss', float(avg_valid_loss))
# print('Valid acc.', accuracy(xs, ys).item())
# print('Valid f1.', f1(xs, ys).item())

## Submission

In [None]:
# Map every test image id to its file name.
test_df = pd.DataFrame(test_meta['images'])
test_path_dict = dict(zip(test_df['id'], test_df['file_name']))
submission_dataset = HerbariumDictDataset(test_path_dict, TEST_PATH)
# print('Test images count', len(test_path_dict))

with torch.no_grad():
    net.eval()
    for n in trange(len(submission_dataset), desc='Submission'):
        i = sample_sub.at[n, 'Id']
        _, x, _, _ = submission_dataset[i]
        x = torch.unsqueeze(x, 0)  # add the batch dimension
        x = x.to(device)
        _, _, y = net(x)
        y = torch.argmax(y, dim=1)
        # Write through .at: assigning into sample_sub.xs(n) modifies a
        # temporary row object and may never reach the DataFrame.
        sample_sub.at[n, 'Predicted'] = y.item()

# Join with a separator; OUTPUT_PATH + 'my_submission.csv' produced
# '/kaggle/workingmy_submission.csv'.
sample_sub.to_csv(os.path.join(OUTPUT_PATH, 'my_submission.csv'), index=False)

In [None]:
# Write the submission table (without the index column) into the output directory.
submission_csv_path = OUTPUT_PATH + '/my_submission.csv'
sample_sub.to_csv(submission_csv_path, index=False)