## Introduction
This is the submission notebook. The discussion and introduction can be found in the first notebook <br>([Multi-head CNN for multi-label classification-pt1](https://www.kaggle.com/yunchenyang/multi-head-cnn-for-multi-label-classification-pt1))


## Source code
The code is also available on [GitHub](https://github.com/yunchen-yang/artcv).

utils

In [None]:
from PIL import Image
import numpy as np
from torchvision.transforms import Compose, Resize, RandomResizedCrop, Normalize, ToTensor
import pandas as pd
import torch
from sklearn.metrics import fbeta_score


def label_indexer_fine(labels_dataframe):
    """Build a three-level (tier1, tier2, tier3) index for every attribute.

    Each ``attribute_name`` is split on ``'::'`` and every piece is split again on ``';'``.
    The first token is tier 1, the second is tier 2, and the optional third token is
    tier 3. The string ``'None'`` is used when there is no third token.

    Index scheme:
      * tier-1 index: 0-based position of the tier-1 name among the sorted tier-1 names;
      * tier-2 index: 1-based position of the tier-2 name among the sorted tier-2 names
        that share the same tier-1 name;
      * tier-3 index: 1-based running counter, in row order, of attributes with the same
        (tier-1, tier-2) pair.

    Args:
        labels_dataframe: DataFrame with ``attribute_id`` and ``attribute_name`` columns.

    Returns:
        ``(labels_indexing_df, attr2indexing, indexing2attr)``:
          * the split label table plus an ``indexing`` column of ``[t1, t2, t3]`` lists;
          * a dict mapping ``attribute_id`` to ``[t1, t2, t3]``;
          * a dict mapping ``str([t1, t2, t3])`` back to ``attribute_id``.
    """
    split_labels_dict = dict(attribute_id=[], attr_tier1=[], attr_tier2=[], attr_tier3=[])
    for attr_id, attr_name in zip(labels_dataframe['attribute_id'], labels_dataframe['attribute_name']):
        tem = [item.strip() for part in attr_name.split('::') for item in part.split(';')]
        split_labels_dict['attribute_id'].append(attr_id)
        split_labels_dict['attr_tier1'].append(tem[0])
        split_labels_dict['attr_tier2'].append(tem[1])
        # Only some names carry a third token; record the sentinel string 'None' otherwise.
        split_labels_dict['attr_tier3'].append(tem[2] if len(tem) > 2 else 'None')
    split_labels = pd.DataFrame(split_labels_dict,
                                columns=['attribute_id', 'attr_tier1', 'attr_tier2', 'attr_tier3'])

    tier1 = dict()
    tier2 = dict()
    counting_dict = dict()
    attr2indexing = dict()
    indexing2attr = dict()
    label_indexing_list = []
    for i_1, item1 in enumerate(sorted(set(split_labels['attr_tier1']))):
        tier1[item1] = i_1
        list_tem = sorted(set(split_labels['attr_tier2'][split_labels['attr_tier1'] == item1]))
        # Tier-2 indices start at 1.
        tier2[item1] = {item2: i_2 + 1 for i_2, item2 in enumerate(list_tem)}
        # One running tier-3 counter per tier-2 entry, starting at 1.
        counting_dict[item1] = np.ones(len(list_tem), dtype='int')
    for idx in range(split_labels.shape[0]):
        name1 = split_labels['attr_tier1'][idx]
        attr_id = split_labels['attribute_id'][idx]
        tier1_idx = tier1[name1]
        tier2_idx = tier2[name1][split_labels['attr_tier2'][idx]]
        # Cast to a Python int so that str([...]) below is '[a, b, c]' regardless of how
        # numpy reprs its scalars (numpy >= 2 would otherwise print 'np.int64(c)').
        tier3_idx = int(counting_dict[name1][tier2_idx - 1])
        counting_dict[name1][tier2_idx - 1] += 1
        label_indexing_list.append([tier1_idx, tier2_idx, tier3_idx])
        attr2indexing[attr_id] = [tier1_idx, tier2_idx, tier3_idx]
        indexing2attr[str([tier1_idx, tier2_idx, tier3_idx])] = attr_id
    labels_indexing_df = split_labels.copy()
    labels_indexing_df['indexing'] = label_indexing_list
    return labels_indexing_df, attr2indexing, indexing2attr


def label_indexer_coarse(labels_dataframe):
    """Build a two-level (tier1, tier2) index for every attribute.

    Each ``attribute_name`` is split on ``'::'``. The first piece is tier 1 and the
    second piece is tier 2.

    Index scheme:
      * tier-1 index: 0-based position of the tier-1 name among the sorted tier-1 names;
      * tier-2 index: 0-based position of the row among the rows sharing that tier-1
        name, in table order (not sorted).

    Args:
        labels_dataframe: DataFrame with ``attribute_id`` and ``attribute_name`` columns.

    Returns:
        ``(labels_indexing_df, attr2indexing, indexing2attr)``: the split label table plus an
        ``indexing`` column, a dict mapping ``attribute_id`` to ``[t1, t2]``, and a dict
        mapping ``str([t1, t2])`` back to ``attribute_id``.

    Raises:
        AssertionError: if a tier-2 name repeats within one tier-1 group.
    """
    columns = ['attribute_id', 'attr_tier1', 'attr_tier2']
    records = {column: [] for column in columns}
    for row in range(labels_dataframe.shape[0]):
        pieces = [piece.strip() for piece in labels_dataframe['attribute_name'][row].split('::')]
        records['attribute_id'].append(labels_dataframe['attribute_id'][row])
        records['attr_tier1'].append(pieces[0])
        records['attr_tier2'].append(pieces[1])
    split_labels = pd.DataFrame(records, columns=columns)

    tier1 = {}
    tier2 = {}
    for position, name1 in enumerate(sorted(set(split_labels['attr_tier1']))):
        members = list(split_labels['attr_tier2'][split_labels['attr_tier1'] == name1])
        # Tier-2 names must be unique inside a tier-1 group, otherwise indices would collide.
        assert len(set(members)) == len(members)
        tier1[name1] = position
        tier2[name1] = {name2: rank for rank, name2 in enumerate(members)}

    attr2indexing = {}
    indexing2attr = {}
    label_indexing_list = []
    for row in range(split_labels.shape[0]):
        name1 = split_labels['attr_tier1'][row]
        attr_id = split_labels['attribute_id'][row]
        first = tier1[name1]
        second = tier2[name1][split_labels['attr_tier2'][row]]
        label_indexing_list.append([first, second])
        attr2indexing[attr_id] = [first, second]
        indexing2attr[str([first, second])] = attr_id

    labels_indexing_df = split_labels.copy()
    labels_indexing_df['indexing'] = label_indexing_list
    return labels_indexing_df, attr2indexing, indexing2attr


def imgreader(img_id, ext, path, attr_ids, attr2indexing, length_list, dimension=256,
              task=('ml', 'ml', 'mc', 'ml', 'ml'), transform='val', grey_scale=False):
    """Load one training image and encode its labels.

    Args:
        img_id, ext, path: the image is read from ``f'{path}/{img_id}.{ext}'``.
        attr_ids: whitespace-separated attribute ids of the image.
        attr2indexing: mapping from int attribute id to its index list.
        length_list: number of labels per head (passed to ``labels_list2array``).
        dimension: output side length of the resize/crop.
        task: per-head task tags (passed to ``labels_list2array``).
        transform: ``'train'`` (RandomResizedCrop) or ``'val'`` (Resize); other keys raise KeyError.
        grey_scale: convert to mode ``'L'`` instead of ``'RGB'``.

    Returns:
        ``(x, targets)``: the normalized image tensor and a tuple with one target per head.
    """
    mode = 'L' if grey_scale else 'RGB'
    with open(f'{path}/{img_id}.{ext}', 'rb') as handle:
        img = Image.open(handle).convert(mode)

    # NOTE(review): with grey_scale=True the tensor has one channel, while Normalize uses
    # three-channel ImageNet statistics. That looks like it would fail at normalization — confirm.
    resizers = {'train': RandomResizedCrop, 'val': Resize}
    pipeline = Compose([resizers[transform](size=(dimension, dimension)),
                        ToTensor(),
                        Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])
    x = pipeline(img)

    index_lists = [attr2indexing[int(attr_id)] for attr_id in attr_ids.split()]
    targets = labels_list2array(index_lists, length_list, task)
    return x, tuple(targets.values())


def imgreader_test(file_path, dimension=256, transform='val', grey_scale=False):
    """Load the image at ``file_path`` and return it as a normalized tensor (no labels).

    Args:
        file_path: path of the image file.
        dimension: output side length of the resize/crop.
        transform: ``'train'`` (RandomResizedCrop) or ``'val'`` (Resize); other keys raise KeyError.
        grey_scale: convert to mode ``'L'`` instead of ``'RGB'``.
    """
    with open(file_path, 'rb') as handle:
        img = Image.open(handle).convert('L' if grey_scale else 'RGB')

    resizers = {'train': RandomResizedCrop, 'val': Resize}
    pipeline = Compose([resizers[transform](size=(dimension, dimension)),
                        ToTensor(),
                        Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])
    return pipeline(img)


def counting_elements(labels_indexing_df):
    """Return the number of attributes in each tier-1 category, ordered by sorted category name."""
    indexing = labels_indexing_df['indexing']
    tiers = labels_indexing_df['attr_tier1']
    counts = []
    for category in sorted(set(tiers)):
        counts.append(len(indexing[tiers == category]))
    return counts


def labels_list2array(y_list, length_list, task):
    """Convert ``[head, label, ...]`` index lists into one target tensor per head.

    For a multi-label head (any tag other than ``'mc'``) the target is a float32 multi-hot
    vector of length ``length_list[head]``. For a multi-class head (``'mc'``) it is a
    one-element int64 tensor holding ``label + 1``, or ``0`` when the head has no label.
    A later label for the same multi-class head overwrites an earlier one.

    Returns:
        dict mapping head position to its target tensor, in head order.
    """
    targets = {}
    for head, kind in enumerate(task):
        if kind == 'mc':
            targets[head] = torch.zeros(1, dtype=torch.long, device='cpu')
        else:
            targets[head] = torch.zeros(length_list[head], dtype=torch.float32, device='cpu')

    for pair in y_list:
        head, label = pair[0], pair[1]
        if task[head] == 'mc':
            # Class 0 is reserved for "no label", hence the +1 shift.
            targets[head][0] = label + 1
        else:
            targets[head][label] = 1

    return targets


def image_list_scan(data_info, indices):
    """Return ``(ids, attribute_ids)`` lists from ``data_info``.

    ``indices`` (when not None) is applied with ``Series[...]`` to both columns,
    which selects rows by index label.
    """
    ids = data_info['id']
    attribute_ids = data_info['attribute_ids']
    if indices is not None:
        ids = ids[indices]
        attribute_ids = attribute_ids[indices]
    return list(ids), list(attribute_ids)


def f2score(ground_truth, pred, return_mean=True):
    """Compute the per-row F2 score (``fbeta_score`` with beta=2) of two label matrices.

    Rows are compared one at a time, ``ground_truth[i, :]`` against ``pred[i, :]``, for
    every row of ``ground_truth``.

    Returns:
        The arithmetic mean of the row scores when ``return_mean`` is true, otherwise
        the list of row scores.
    """
    scores = []
    for row in range(ground_truth.shape[0]):
        scores.append(fbeta_score(ground_truth[row, :], pred[row, :], beta=2))
    if not return_mean:
        return scores
    return sum(scores) / len(scores)

    
def regularized_pred(probs, thre, upper_bound=(3, 4, 17, 18), lower_bound=3,
                     boundary=([0, 100], [100, 781], [786, 2706], [2706, 3474])):
    """Binarize a 2-D probability matrix group by group, with a per-group top-k cap.

    For each column group ``boundary[i] = [start, end)``:
      * scores are divided by ``thre[i]``;
      * only the ``upper_bound[i]`` highest-scoring columns of each row are kept;
      * kept columns with scaled score ``>= 1`` become 1, everything else becomes 0.

    Columns not covered by any group are copied from ``probs`` unchanged. With the
    default boundaries this is columns 781-785. NOTE(review): these are presumably
    predictions of another head that the caller has already prepared — confirm.

    Any row with no positive entry afterwards gets its ``lower_bound`` best columns set
    to 1. "Best" is ranked by ``probs / threshold``, where uncovered columns use a
    threshold of 1.

    Args:
        probs: array of shape (n_samples, n_columns).
        thre: per-group thresholds, indexed like ``boundary``.
        upper_bound: per-group maximum number of positive labels.
        lower_bound: number of labels forced on for rows without any positive label.
        boundary: sequence of ``[start, end)`` column ranges, assumed sorted and
            non-overlapping. Any number of groups is supported.

    Returns:
        Array with the same number of columns as ``probs``, aligned column for column.
    """
    thres_array = np.ones((probs.shape[1]), dtype='float')
    pieces = []
    cursor = 0
    for i, (start, end) in enumerate(boundary):
        thres_array[start:end] = thre[i]
        # Scale by the group threshold so that "score >= 1" means "above threshold".
        scaled = probs[:, start:end] / thre[i]
        # Keep only the upper_bound[i] highest-scoring columns in every row.
        top_k = scaled.argsort(axis=-1)[:, ::-1][:, :upper_bound[i]]
        mask = np.zeros(scaled.shape, dtype='float')
        np.put_along_axis(mask, top_k, 1, axis=-1)
        scaled *= mask
        scaled[scaled >= 1] = 1
        scaled[scaled < 1] = 0
        # Columns between the previous group and this one pass through unchanged.
        pieces.append(probs[:, cursor:start])
        pieces.append(scaled)
        cursor = end
    pieces.append(probs[:, cursor:])
    pred_array = np.concatenate(pieces, axis=-1)

    # Guarantee at least one label per row: fall back to the best threshold-scaled columns.
    for idx in np.where(pred_array.max(axis=-1) == 0)[0]:
        ranked = (probs[idx, :] / thres_array).argsort(axis=-1)[::-1][:lower_bound]
        pred_array[idx, ranked] = 1
    return pred_array

datatool

In [None]:
from torch.utils.data import Dataset
import sys
import math
from glob import glob


class ImgDataset(Dataset):
    """Map-style dataset yielding a training image followed by one target per task.

    Args:
        x: image ids; each is read from ``f'{path}/{id}.{ext}'``.
        y: whitespace-separated attribute-id strings, aligned with ``x``.
        path: image directory.
        attr2indexing: mapping from int attribute id to its index list.
        length_list: number of labels per head.
        task: per-head task tags ('ml' or 'mc').
        ext, dimension, transform, grey_scale: forwarded to ``imgreader``.
    """

    def __init__(self, x, y, path, attr2indexing, length_list, task,
                 ext='png', dimension=256, transform='val', grey_scale=False):
        super().__init__()
        self.x = x
        self.y = y
        self.path = path
        self.attr2indexing = attr2indexing
        self.length_list = length_list
        self.task = task
        self.ext = ext
        self.dimension = dimension
        self.transform = transform
        self.grey_scale = grey_scale

    def __len__(self):
        return len(self.x)

    def __getitem__(self, index):
        """Return ``(img, y0, y1, ...)``, with one target per entry of ``task``."""
        img, ys = imgreader(self.x[index], self.ext, self.path, self.y[index],
                            self.attr2indexing, self.length_list,
                            self.dimension, self.task, self.transform, self.grey_scale)
        # Splat the targets so any number of tasks works; the default five-task
        # tuple still yields the same 6-tuple (img, y0, ..., y4).
        return (img, *ys)


class ImgTestset(Dataset):
    """Map-style dataset over test image file paths; items are image tensors only.

    Args:
        x: list of image file paths.
        dimension, transform, grey_scale: forwarded to ``imgreader_test``.
    """

    def __init__(self, x, dimension=256, transform='val', grey_scale=False):
        super().__init__()
        self.x = x
        self.dimension, self.transform, self.grey_scale = dimension, transform, grey_scale

    def __len__(self):
        return len(self.x)

    def __getitem__(self, index):
        return imgreader_test(self.x[index], dimension=self.dimension,
                              transform=self.transform, grey_scale=self.grey_scale)


class TrainValSet:
    """Assemble the train / validation / test datasets from the csv metadata.

    Reads ``labels_info_path`` and ``data_info_path`` with pandas, indexes the labels with
    ``label_indexer_coarse``, and wraps the images in ``ImgDataset`` / ``ImgTestset``.
    Every dataset is built with ``dimension`` as its image size.

    Args:
        ext: image file extension.
        path: training image directory (default ``f'{sys.path[0]}/train'``).
        indices: optional selection of training-csv rows, passed to ``image_list_scan``.
        dimension: image side length used by every dataset.
        data_info_path, labels_info_path: csv paths (defaults under ``sys.path[0]``).
        task: per-head task tags ('ml' or 'mc').
        train_transform: transform key used for the training split.
        train_val_split: fraction of images put in ``train``; a falsy value disables the split.
        seed: seed given to numpy's global RNG before shuffling.
        test_path: optional test image directory; when given, ``self.test`` is created.
        test_csv_path: optional csv whose ``id`` column lists the test images; otherwise
            ``test_path`` is globbed for ``*.{ext}``.
    """

    def __init__(self, ext='png', path=None, indices=None, dimension=256, data_info_path=None, labels_info_path=None,
                 task=('ml', 'ml', 'mc', 'ml', 'ml'), train_transform='train', train_val_split=0.7, seed=0,
                 test_path=None, test_csv_path=None):
        super().__init__()
        self.ext = ext
        self.dimension = dimension
        self.indices = indices
        self.seed = seed
        self.train_transform = train_transform

        # Default locations are resolved relative to sys.path[0].
        self.path = f'{sys.path[0]}/train' if path is None else path
        self.test_path = test_path
        self.data_info_path = f'{sys.path[0]}/train.csv' if data_info_path is None else data_info_path
        self.labels_info_path = f'{sys.path[0]}/labels.csv' if labels_info_path is None else labels_info_path

        self.labels_info = pd.read_csv(self.labels_info_path)
        self.data_info = pd.read_csv(self.data_info_path)

        self.labels_indexing_df, self.attr2indexing, self.indexing2attr = label_indexer_coarse(self.labels_info)
        self.length_list = counting_elements(self.labels_indexing_df)
        self.X_all, self.Y_all = image_list_scan(self.data_info, indices=self.indices)
        self.task = task

        if self.test_path is not None:
            self.test_csv_path = test_csv_path
            if self.test_csv_path is not None:
                self.test_csv = pd.read_csv(self.test_csv_path)
                self.X_test = [f'{self.test_path}/{_filename}.{self.ext}' for _filename in list(self.test_csv['id'])]
            else:
                self.X_test = glob(f'{self.test_path}/*.{self.ext}', recursive=True)
            self.test = ImgTestset(self.X_test, dimension=self.dimension, transform='val', grey_scale=False)

        self.train_val_split = train_val_split

        if bool(self.train_val_split):
            self.all = self._make_dataset(self.X_all, self.Y_all, 'val')
            assert(0 < self.train_val_split < 1)
            num_train = math.ceil(len(self.X_all) * self.train_val_split)
            # NOTE: this seeds numpy's *global* RNG, which also affects later np.random calls.
            np.random.seed(seed=self.seed)
            indices_array = np.random.permutation(len(self.X_all))
            train_idx, val_idx = indices_array[:num_train], indices_array[num_train:]
            self.X_train = [self.X_all[i] for i in train_idx]
            self.Y_train = [self.Y_all[i] for i in train_idx]
            self.train = self._make_dataset(self.X_train, self.Y_train, self.train_transform)
            self.X_val = [self.X_all[i] for i in val_idx]
            self.Y_val = [self.Y_all[i] for i in val_idx]
            self.val = self._make_dataset(self.X_val, self.Y_val, 'val')
        else:
            # Without a split, ``all`` uses the training transform only when a test directory is given.
            all_transform = self.train_transform if self.test_path is not None else 'val'
            self.all = self._make_dataset(self.X_all, self.Y_all, all_transform)

    def _make_dataset(self, x, y, transform):
        """Wrap image ids ``x`` and label strings ``y`` in an ImgDataset sized ``self.dimension``."""
        return ImgDataset(x, y, self.path, self.attr2indexing, self.length_list,
                          task=self.task, ext=self.ext, dimension=self.dimension,
                          transform=transform, grey_scale=False)
                

modules

In [None]:
from torchvision.models.resnet import ResNet
import torch
from torch import nn as nn
import collections
import torch.nn.functional as F


class ResNet_CNN(ResNet):
    """torchvision ResNet backbone with the ``fc`` head removed, optionally partly frozen.

    Args:
        block, layers, **kwargs: forwarded to ``torchvision.models.resnet.ResNet``.
        weight_path: optional path to a state dict. It is loaded before ``fc`` is removed,
            so it must also contain ``fc`` entries that match the constructed head.
        freeze_layers: ``True`` (same as 4) or an int k. The stem and ``layer1``..``layer{k}``
            get ``requires_grad=False``, and ``forward`` detaches the output of ``layer{k}``
            (for k == 4, the pooled features). A falsy value disables freezing.
    """

    def __init__(self, block, layers, weight_path, freeze_layers, **kwargs):
        super().__init__(block, layers, **kwargs)
        self.weight_path = weight_path
        # ``True`` means "freeze everything up to and including layer4".
        if isinstance(freeze_layers, bool) and freeze_layers:
            self.freeze_layers = 4
        else:
            self.freeze_layers = freeze_layers
        if self.weight_path is not None:
            # Map onto CPU so that GPU-saved checkpoints also load on CPU-only machines;
            # load_state_dict copies the tensors into this module's parameters anyway.
            self.load_state_dict(torch.load(self.weight_path, map_location='cpu'))
        del self.fc
        if bool(self.freeze_layers):
            self._freeze_layers()

    def forward(self, x):
        """Return flattened, average-pooled backbone features of shape (batch, channels)."""
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        # Detaching at the frozen boundary stops gradients from reaching the frozen layers.
        x = self.layer1(x) if self.freeze_layers != 1 else self.layer1(x).detach()
        x = self.layer2(x) if self.freeze_layers != 2 else self.layer2(x).detach()
        x = self.layer3(x) if self.freeze_layers != 3 else self.layer3(x).detach()
        x = self.layer4(x)

        x = self.avgpool(x)
        x = torch.flatten(x, 1)

        if self.freeze_layers != 4:
            return x
        else:
            return x.detach()

    def _freeze_layers(self):
        """Set requires_grad=False on every child before ``layer{freeze_layers+1}``, True from it on."""
        _bool = False
        for name, module in self.named_children():
            if name == f'layer{self.freeze_layers+1}' or _bool:
                _bool = True
            for p in module.parameters():
                p.requires_grad = _bool


class Classifier(nn.Module):
    """Fully connected head that turns features into per-label probabilities.

    The widths go ``dim_in -> dim_hidden (n_layers - 1 times) -> dim_out``. Each block
    ``'Layer {i}'`` is an ``nn.Sequential`` of (Linear, BatchNorm1d or None, ReLU or None,
    Dropout or None). ReLU is left out only in the last block. BatchNorm and Dropout,
    when enabled, are present in every block, including the last.

    Args:
        dim_in, dim_out, dim_hidden: layer widths.
        n_layers: number of linear blocks (values below 1 behave like 1).
        task: ``'ml'`` applies a sigmoid, ``'mc'`` a softmax over the last dim.
        use_batch_norm: insert BatchNorm1d after each Linear.
        dropout_rate: Dropout probability; 0 disables Dropout.
    """

    def __init__(self, dim_in, dim_out, dim_hidden, n_layers, task='ml', use_batch_norm=True, dropout_rate=0.01):
        super().__init__()
        self.task = task
        self.use_batch_norm = use_batch_norm
        self.dropout_rate = dropout_rate
        widths = [dim_in, *([dim_hidden] * (n_layers - 1)), dim_out]
        last_block = len(widths) - 2
        blocks = collections.OrderedDict()
        for i, (n_in, n_out) in enumerate(zip(widths[:-1], widths[1:])):
            # None placeholders keep sub-module positions 0..3 fixed, and with them the state_dict keys.
            blocks[f'Layer {i}'] = nn.Sequential(
                nn.Linear(n_in, n_out),
                nn.BatchNorm1d(n_out, momentum=.01, eps=0.001) if use_batch_norm else None,
                nn.ReLU() if i < last_block else None,
                nn.Dropout(p=dropout_rate) if dropout_rate > 0 else None,
            )
        self.classifier = nn.Sequential(blocks)

    def get_logits(self, x):
        """Run ``x`` through every non-None sub-module in order and return the raw scores."""
        active = (module for block in self.classifier for module in block if module is not None)
        for module in active:
            x = module(x)
        return x

    def forward(self, x):
        if self.task == 'ml':
            return torch.sigmoid(self.get_logits(x))
        if self.task == 'mc':
            return F.softmax(self.get_logits(x), dim=-1)
        raise ValueError("The task tag must be either 'ml' (multi-label) or 'mc' (multi-class)!")

groups information

In [None]:
# Per-head label groupings.
# NOTE(review): presumably None means "no grouping" for the head with tier-1 index 0 —
# confirm against the code that consumes these values.
_label_groups0 = None
# Group id for each label of one head. NOTE(review): the "1" suffix presumably refers to
# tier-1 index 1, and position i presumably corresponds to tier-2 index i — confirm.
# The visible ids range from 0 to 15; the tuple ends with a long run of 15.
_label_groups1 = (3, 6, 3, 4, 7, 7, 11, 11, 11, 11, 11, 11, 11, 2, 4, 9, 9, 5, 3, 11, 11, 11, 11, 11, 11, 2, 12, 13, 7, 10, 12, 
                  5, 3, 0, 10, 8, 5, 2, 7, 1, 4, 10, 13, 11, 12, 12, 12, 5, 11, 4, 10, 3, 9, 0, 4, 10, 10, 13, 14, 5, 1, 6, 9, 
                  12, 12, 8, 3, 5, 12, 8, 1, 9, 10, 3, 3, 11, 3, 3, 9, 11, 6, 11, 11, 11, 11, 11, 9, 12, 9, 5, 8, 7, 8, 10, 6, 7, 
                  4, 11, 13, 3, 5, 14, 5, 10, 10, 13, 13, 13, 13, 11, 9, 3, 9, 9, 11, 7, 6, 7, 11, 11, 11, 11, 8, 8, 0, 8, 6, 6, 
                  8, 14, 13, 13, 2, 13, 12, 5, 4, 13, 14, 14, 4, 3, 10, 4, 4, 5, 10, 5, 5, 1, 12, 9, 9, 13, 2, 12, 9, 11, 11, 13, 
                  13, 13, 13, 13, 13, 14, 8, 7, 9, 9, 2, 7, 4, 4, 10, 10, 10, 11, 11, 11, 13, 8, 11, 3, 3, 5, 11, 3, 3, 13, 11, 
                  11, 11, 6, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 12, 5, 14, 4, 4, 12, 11, 6, 11, 11, 11, 11, 11, 11, 11, 
                  8, 10, 2, 12, 3, 5, 11, 11, 11, 11, 11, 11, 11, 11, 11, 10, 13, 10, 10, 10, 10, 9, 14, 14, 14, 8, 8, 0, 8, 9, 
                  9, 3, 13, 10, 4, 5, 6, 9, 10, 2, 11, 10, 0, 7, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 7, 7, 7, 6, 6, 6, 7, 0, 0, 7, 6, 
                  11, 6, 10, 7, 4, 12, 11, 6, 6, 6, 6, 11, 11, 11, 11, 11, 11, 11, 11, 13, 11, 10, 7, 6, 4, 6, 13, 7, 2, 6, 7, 
                  14, 6, 9, 10, 10, 3, 3, 3, 12, 12, 9, 3, 9, 11, 9, 9, 9, 4, 9, 8, 5, 11, 6, 8, 8, 2, 11, 7, 8, 11, 4, 11, 4, 
                  14, 14, 3, 4, 4, 6, 4, 7, 6, 2, 5, 12, 2, 3, 12, 4, 1, 3, 3, 3, 9, 8, 11, 6, 6, 3, 14, 10, 0, 7, 6, 12, 3, 11, 
                  13, 6, 13, 13, 13, 13, 13, 13, 13, 9, 13, 13, 13, 13, 13, 5, 12, 12, 13, 0, 13, 13, 0, 10, 10, 12, 3, 7, 6, 6, 
                  3, 5, 12, 4, 10, 9, 5, 14, 14, 14, 14, 14, 0, 0, 0, 3, 7, 0, 4, 9, 11, 4, 5, 12, 12, 12, 6, 13, 4, 13, 0, 14, 
                  1, 8, 14, 10, 5, 3, 11, 3, 11, 11, 11, 12, 14, 4, 11, 7, 3, 8, 5, 4, 10, 3, 11, 9, 0, 3, 3, 8, 9, 12, 13, 13, 
                  13, 13, 13, 6, 13, 11, 13, 1, 13, 13, 13, 11, 13, 13, 14, 11, 11, 11, 4, 7, 9, 9, 12, 9, 3, 12, 2, 8, 8, 0, 9, 
                  5, 11, 11, 12, 5, 5, 11, 11, 1, 2, 4, 7, 6, 7, 1, 1, 1, 12, 4, 8, 6, 7, 6, 11, 0, 6, 2, 3, 4, 5, 3, 3, 1, 3, 
                  13, 13, 3, 3, 8, 14, 3, 3, 3, 6, 11, 11, 11, 11, 14, 3, 3, 12, 7, 6, 8, 3, 1, 4, 13, 13, 13, 9, 9, 2, 12, 15, 
                  15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 
                  15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 
                  15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 
                  15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15)

# Group id for each label of another head. NOTE(review): the "3" suffix presumably refers
# to tier-1 index 3, and "40" to the grouping size — confirm against the consumer.
# The visible ids range from 0 to 40; the tuple ends with a long run of 40.
_label_groups3_40 = (27, 19, 17, 17, 7, 23, 8, 26, 23, 30, 18, 18, 18, 16, 16, 27, 16, 27, 34, 14, 14, 27, 2, 24, 27,
                     8, 31, 0, 9, 9, 27, 27, 34, 27, 27, 38, 32, 1, 1, 27, 20, 27, 20, 17, 33, 29, 8, 31, 36, 9, 25,
                     32, 21, 20, 9, 19, 38, 8, 20, 20, 7, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 22, 11, 2, 2, 2, 14, 2, 22, 2,
                     2, 12, 14, 11, 14, 2, 2, 2, 2, 2, 2, 2, 20, 11, 5, 1, 1, 27, 27, 8, 37, 34, 34, 34, 34, 34, 34, 37,
                     37, 37, 37, 37, 37, 37, 34, 34, 34, 34, 37, 24, 36, 34, 24, 30, 30, 24, 24, 11, 24, 30, 27, 6, 6,
                     9, 35, 35, 30, 30, 35, 35, 35, 19, 27, 27, 9, 38, 23, 34, 28, 38, 38, 38, 38, 38, 31, 5, 28, 32,
                     28, 32, 32, 38, 38, 24, 27, 28, 28, 28, 28, 24, 11, 11, 24, 11, 2, 12, 11, 14, 24, 11, 2, 11, 11,
                     11, 12, 11, 11, 11, 11, 11, 12, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 12, 11, 11, 12, 11, 11, 11,
                     14, 11, 11, 11, 11, 6, 6, 6, 5, 5, 5, 9, 28, 1, 28, 5, 3, 19, 31, 23, 31, 25, 0, 28, 14, 28, 14,
                     27, 14, 27, 3, 23, 9, 9, 16, 13, 13, 8, 32, 8, 26, 5, 28, 28, 28, 20, 20, 33, 28, 33, 27, 20, 2,
                     20, 24, 33, 28, 20, 33, 28, 32, 38, 27, 27, 38, 38, 38, 26, 26, 20, 26, 36, 36, 30, 8, 8, 8, 22, 2,
                     36, 36, 22, 19, 20, 31, 31, 20, 27, 27, 27, 0, 31, 31, 31, 27, 14, 5, 8, 27, 26, 23, 30, 30, 30,
                     39, 36, 36, 10, 32, 3, 3, 24, 16, 23, 16, 14, 16, 12, 16, 24, 23, 16, 14, 11, 7, 23, 34, 34, 7, 7,
                     16, 16, 16, 19, 16, 37, 20, 38, 23, 23, 23, 38, 38, 38, 23, 23, 38, 38, 32, 38, 28, 23, 23, 23, 23,
                     38, 16, 17, 17, 35, 37, 4, 25, 25, 10, 13, 21, 4, 4, 10, 4, 4, 4, 4, 15, 5, 15, 7, 14, 14, 12, 25,
                     34, 23, 38, 38, 27, 27, 21, 34, 34, 14, 27, 14, 14, 27, 16, 1, 27, 24, 24, 12, 24, 24, 19, 23, 5,
                     32, 32, 32, 32, 32, 27, 27, 27, 27, 31, 16, 10, 33, 0, 27, 27, 27, 1, 1, 14, 15, 36, 39, 36, 30,
                     30, 30, 30, 30, 39, 39, 39, 39, 39, 39, 39, 39, 39, 39, 39, 20, 33, 2, 26, 24, 35, 38, 27, 5, 28,
                     21, 21, 21, 21, 21, 21, 8, 8, 36, 38, 38, 36, 36, 28, 38, 34, 32, 36, 28, 34, 27, 1, 1, 1, 1, 27,
                     1, 14, 1, 19, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 16, 1, 1, 1, 1, 14, 14, 1, 27, 27, 1, 1, 1, 1, 14,
                     14, 38, 1, 1, 1, 1, 27, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 21, 39, 30, 36, 18, 34, 15, 15, 4, 25, 28,
                     19, 19, 0, 23, 23, 3, 31, 21, 27, 13, 18, 4, 25, 12, 12, 12, 34, 34, 27, 34, 34, 37, 27, 34, 26,
                     27, 25, 31, 8, 8, 34, 8, 15, 3, 32, 27, 8, 8, 8, 7, 28, 38, 28, 28, 28, 28, 28, 28, 28, 28, 28, 20,
                     20, 28, 33, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 34, 23, 23, 34, 34, 34, 34,
                     34, 34, 34, 34, 34, 34, 34, 34, 34, 34, 34, 34, 39, 30, 36, 23, 36, 36, 36, 36, 36, 36, 7, 32, 23,
                     23, 23, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 28, 32, 32, 32, 32, 32, 38, 14, 14, 32, 14, 32, 32,
                     32, 13, 32, 32, 16, 16, 14, 14, 3, 26, 31, 29, 22, 22, 12, 11, 22, 14, 14, 14, 14, 14, 14, 14, 14,
                     14, 14, 22, 2, 14, 14, 14, 14, 14, 14, 22, 3, 24, 12, 11, 31, 24, 11, 14, 24, 24, 34, 36, 24, 34,
                     24, 30, 30, 24, 24, 31, 24, 24, 24, 31, 11, 31, 6, 31, 3, 12, 0, 0, 26, 26, 0, 32, 2, 16, 2, 27,
                     27, 27, 1, 27, 27, 27, 27, 39, 39, 39, 13, 13, 30, 30, 23, 30, 27, 27, 2, 23, 2, 2, 2, 2, 8, 8, 4,
                     20, 0, 2, 2, 2, 23, 32, 38, 14, 12, 3, 17, 23, 5, 19, 16, 16, 1, 1, 16, 33, 36, 36, 13, 19, 12, 29,
                     29, 4, 31, 12, 12, 12, 12, 14, 14, 12, 11, 14, 12, 12, 12, 12, 14, 12, 14, 12, 14, 14, 28, 2, 28,
                     35, 33, 28, 32, 23, 35, 1, 38, 38, 38, 38, 38, 38, 32, 17, 23, 17, 38, 23, 23, 23, 14, 17, 35, 38,
                     17, 8, 23, 23, 32, 8, 8, 12, 32, 35, 0, 9, 6, 36, 36, 35, 35, 6, 7, 14, 2, 9, 33, 33, 32, 34, 34,
                     34, 34, 19, 27, 27, 27, 8, 8, 19, 36, 22, 30, 23, 30, 28, 34, 34, 22, 13, 23, 13, 16, 24, 24, 24,
                     31, 12, 22, 31, 27, 31, 10, 10, 13, 21, 10, 10, 10, 21, 21, 21, 10, 10, 10, 10, 10, 10, 16, 12, 14,
                     1, 16, 14, 16, 16, 1, 27, 21, 30, 30, 39, 27, 5, 31, 9, 9, 20, 30, 8, 12, 20, 26, 26, 26, 33, 35,
                     16, 32, 23, 12, 38, 31, 38, 38, 38, 18, 14, 38, 38, 13, 13, 13, 13, 13, 13, 13, 38, 38, 12, 13, 21,
                     13, 13, 25, 5, 1, 1, 2, 1, 1, 7, 16, 6, 34, 34, 1, 19, 14, 37, 37, 37, 16, 15, 26, 22, 8, 27, 30,
                     32, 27, 27, 0, 26, 3, 15, 35, 2, 0, 27, 6, 21, 8, 23, 38, 32, 32, 27, 34, 27, 11, 20, 20, 20, 33,
                     20, 27, 20, 8, 27, 14, 14, 38, 34, 14, 20, 27, 14, 14, 33, 11, 14, 24, 24, 14, 14, 14, 14, 14, 14,
                     14, 14, 38, 14, 14, 14, 27, 14, 21, 14, 14, 27, 19, 16, 19, 8, 32, 23, 34, 34, 34, 34, 7, 16, 14,
                     14, 14, 24, 27, 27, 6, 24, 36, 27, 33, 2, 22, 2, 2, 23, 23, 2, 22, 22, 22, 19, 22, 19, 19, 2, 2,
                     22, 22, 22, 22, 22, 19, 22, 11, 23, 11, 33, 28, 28, 28, 28, 33, 27, 36, 38, 32, 28, 36, 26, 26, 30,
                     20, 23, 33, 20, 27, 14, 27, 14, 27, 14, 14, 14, 21, 7, 7, 16, 12, 12, 28, 28, 14, 23, 14, 39, 39,
                     27, 28, 28, 36, 28, 38, 36, 32, 28, 28, 27, 7, 7, 16, 27, 14, 28, 26, 37, 37, 15, 32, 32, 9, 34,
                     12, 12, 2, 12, 12, 2, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12,
                     12, 12, 12, 12, 12, 12, 12, 14, 12, 12, 12, 12, 12, 12, 11, 12, 28, 27, 27, 27, 27, 16, 14, 31, 28,
                     7, 7, 7, 5, 23, 20, 33, 20, 3, 24, 11, 23, 35, 26, 26, 26, 33, 37, 38, 1, 32, 32, 14, 19, 19, 19,
                     19, 19, 38, 39, 39, 36, 33, 16, 16, 23, 16, 28, 28, 26, 26, 26, 33, 33, 20, 28, 36, 30, 36, 36, 36,
                     36, 30, 30, 30, 36, 30, 30, 30, 36, 30, 30, 30, 30, 30, 30, 30, 36, 30, 23, 23, 23, 30, 30, 36, 26,
                     15, 31, 19, 24, 38, 27, 38, 34, 34, 34, 5, 39, 23, 23, 23, 12, 30, 12, 12, 32, 34, 27, 34, 34, 34,
                     34, 3, 34, 34, 34, 24, 34, 34, 16, 16, 16, 27, 12, 27, 12, 27, 27, 19, 19, 38, 39, 19, 19, 23, 6,
                     6, 6, 6, 6, 25, 9, 8, 31, 5, 5, 9, 5, 0, 23, 24, 2, 2, 32, 20, 2, 23, 2, 30, 34, 12, 24, 24, 20,
                     20, 24, 24, 30, 11, 39, 22, 3, 23, 38, 38, 5, 37, 15, 23, 21, 28, 27, 32, 32, 34, 34, 27, 27, 4,
                     13, 32, 32, 11, 9, 9, 22, 18, 8, 8, 19, 19, 27, 30, 14, 14, 14, 20, 31, 22, 31, 8, 5, 32, 6, 13,
                     29, 31, 0, 19, 19, 19, 32, 36, 30, 32, 8, 6, 29, 29, 31, 27, 27, 0, 5, 38, 32, 17, 23, 17, 17, 7,
                     17, 36, 36, 6, 13, 25, 6, 21, 6, 21, 21, 21, 32, 13, 21, 21, 10, 13, 13, 13, 13, 13, 13, 13, 13,
                     13, 13, 21, 13, 13, 21, 21, 6, 21, 21, 21, 6, 21, 21, 21, 21, 21, 21, 21, 21, 6, 6, 6, 23, 21, 6,
                     6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 32, 23, 38, 2, 32, 32, 32, 32, 32, 32, 28, 28, 32,
                     32, 32, 38, 14, 38, 32, 32, 32, 32, 13, 32, 32, 28, 28, 32, 32, 32, 32, 28, 23, 28, 38, 38, 0, 26,
                     22, 13, 37, 5, 9, 27, 26, 27, 20, 30, 30, 27, 27, 5, 5, 24, 20, 2, 5, 34, 34, 19, 19, 23, 38, 38,
                     35, 1, 1, 1, 5, 26, 26, 26, 26, 8, 39, 23, 30, 30, 30, 36, 30, 30, 23, 30, 39, 27, 3, 3, 33, 18,
                     19, 28, 26, 23, 23, 23, 26, 26, 26, 19, 5, 35, 7, 0, 0, 3, 5, 25, 23, 21, 9, 0, 7, 14, 33, 14, 14,
                     33, 7, 10, 26, 26, 26, 5, 25, 23, 23, 27, 19, 19, 20, 13, 19, 19, 26, 38, 36, 30, 23, 30, 30, 23,
                     30, 38, 38, 15, 1, 27, 35, 23, 5, 27, 8, 17, 5, 22, 22, 22, 22, 22, 22, 19, 16, 27, 16, 27, 33, 27,
                     36, 7, 14, 7, 14, 23, 16, 3, 20, 8, 17, 19, 27, 30, 10, 39, 27, 19, 39, 26, 18, 32, 27, 7, 9, 3,
                     16, 14, 6, 13, 9, 5, 24, 24, 7, 11, 20, 20, 9, 17, 11, 11, 14, 11, 16, 16, 14, 14, 27, 14, 14, 14,
                     22, 22, 22, 16, 7, 14, 26, 25, 25, 25, 35, 17, 5, 2, 24, 34, 2, 2, 2, 20, 2, 30, 34, 2, 2, 2, 2, 2,
                     2, 20, 20, 30, 30, 2, 24, 24, 24, 15, 38, 33, 23, 23, 23, 23, 23, 23, 23, 33, 33, 33, 33, 33, 33,
                     33, 12, 33, 33, 33, 33, 33, 33, 33, 16, 16, 16, 23, 16, 16, 23, 16, 1, 16, 16, 33, 25, 23, 4, 10,
                     21, 21, 25, 21, 4, 10, 4, 25, 25, 25, 25, 25, 25, 25, 25, 25, 21, 21, 32, 38, 38, 24, 24, 24, 24,
                     20, 20, 24, 11, 20, 38, 34, 40, 40, 40, 40, 40, 40, 40, 40, 40, 40, 40, 40, 40, 40, 40, 40, 40,
                     40, 40, 40, 40, 40, 40, 40, 40, 40, 40, 40, 40, 40, 40, 40, 40, 40, 40, 40, 40, 40, 40, 40, 40, 40,
                     40, 40, 40, 40, 40, 40, 40, 40, 40, 40, 40, 40, 40)

_label_groups3 = (36, 37, 21, 21, 8, 28, 9, 32, 28, 25, 23, 23, 23, 26, 26, 22, 26, 22, 26, 22, 22, 26, 10, 42, 33, 9,
                  34, 5, 19, 4, 33, 33, 42, 33, 33, 36, 41, 1, 1, 18, 29, 27, 29, 21, 29, 21, 9, 34, 7, 29, 3, 41, 15,
                  29, 4, 20, 36, 9, 29, 37, 8, 10, 11, 10, 10, 10, 10, 10, 32, 10, 10, 10, 10, 10, 12, 12, 12, 12, 10,
                  10, 10, 11, 12, 11, 12, 10, 37, 25, 10, 10, 10, 10, 10, 5, 4, 6, 6, 37, 37, 9, 42, 42, 42, 42, 42,
                  42, 42, 42, 42, 42, 42, 42, 42, 37, 42, 42, 42, 42, 42, 10, 39, 42, 39, 39, 39, 10, 30, 5, 16, 37,
                  37, 13, 13, 7, 43, 43, 45, 45, 43, 43, 43, 27, 23, 23, 4, 47, 28, 42, 24, 47, 47, 47, 47, 47, 34,
                  4, 48, 41, 24, 41, 41, 47, 47, 48, 27, 33, 24, 48, 48, 10, 5, 5, 10, 5, 10, 11, 11, 12, 10, 5, 5, 5,
                  5, 5, 11, 5, 5, 5, 5, 5, 11, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 11, 5, 5, 11, 5, 11, 5, 12, 5, 5, 5, 5, 13,
                  13, 13, 4, 1, 1, 29, 48, 1, 48, 1, 5, 27, 34, 28, 32, 3, 5, 48, 22, 37, 14, 37, 14, 37, 3, 28, 25, 4,
                  16, 14, 14, 9, 41, 9, 32, 0, 19, 25, 25, 25, 25, 25, 24, 25, 25, 25, 10, 19, 19, 25, 19, 19, 19, 19,
                  41, 23, 37, 37, 47, 48, 43, 38, 45, 29, 38, 45, 45, 39, 9, 9, 9, 16, 33, 48, 48, 16, 27, 29, 34, 34,
                  29, 33, 33, 33, 5, 34, 34, 34, 17, 12, 1, 9, 30, 45, 28, 31, 31, 7, 49, 48, 48, 15, 41, 3, 3, 10, 26,
                  28, 26, 12, 26, 18, 26, 10, 28, 26, 12, 5, 8, 42, 44, 42, 8, 8, 26, 26, 26, 27, 26, 18, 29, 47, 28,
                  28, 28, 47, 47, 47, 28, 28, 47, 47, 41, 47, 24, 28, 28, 28, 28, 47, 30, 21, 21, 40, 29, 15, 15, 28,
                  15, 14, 15, 15, 15, 15, 15, 15, 15, 15, 4, 4, 13, 16, 12, 12, 11, 13, 42, 28, 47, 36, 17, 25, 13, 42,
                  42, 22, 22, 22, 22, 37, 1, 1, 26, 10, 10, 10, 10, 10, 27, 28, 4, 41, 41, 9, 41, 41, 17, 17, 17, 17,
                  34, 26, 12, 25, 5, 22, 12, 22, 1, 1, 12, 15, 35, 49, 39, 39, 39, 45, 45, 45, 49, 49, 49, 49, 49, 49,
                  49, 49, 49, 49, 49, 29, 29, 25, 32, 10, 35, 47, 33, 34, 48, 13, 13, 13, 13, 13, 13, 9, 9, 48, 47, 47,
                  48, 48, 24, 48, 44, 48, 48, 6, 44, 22, 6, 6, 6, 6, 18, 1, 12, 6, 20, 6, 6, 6, 6, 6, 6, 6, 1, 6, 6, 18,
                  18, 6, 6, 18, 6, 18, 18, 6, 18, 18, 6, 6, 6, 1, 12, 12, 18, 6, 6, 18, 6, 18, 18, 6, 18, 7, 6, 6, 6, 6,
                  6, 1, 15, 49, 39, 45, 23, 44, 21, 21, 15, 3, 48, 20, 20, 5, 28, 28, 3, 32, 13, 23, 23, 23, 15, 15, 11,
                  11, 11, 42, 37, 42, 42, 37, 42, 37, 42, 32, 46, 5, 34, 9, 9, 42, 9, 48, 48, 26, 42, 9, 9, 9, 0, 24,
                  47, 25, 25, 48, 48, 47, 24, 0, 25, 25, 29, 29, 41, 25, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24,
                  24, 24, 24, 44, 28, 28, 42, 44, 44, 42, 42, 42, 44, 42, 37, 44, 44, 44, 44, 44, 44, 44, 49, 39, 39,
                  28, 39, 39, 39, 39, 39, 39, 49, 41, 28, 28, 28, 41, 41, 41, 41, 41, 41, 41, 41, 41, 41, 24, 41, 41,
                  41, 41, 41, 47, 22, 22, 41, 12, 41, 41, 41, 41, 41, 41, 16, 16, 12, 12, 3, 32, 34, 8, 16, 16, 11, 5,
                  16, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 16, 5, 10, 11, 11, 34, 10,
                  5, 12, 10, 10, 42, 39, 34, 42, 39, 39, 39, 10, 10, 34, 10, 10, 10, 34, 5, 34, 13, 34, 2, 2, 17, 17,
                  32, 32, 5, 41, 10, 26, 2, 37, 18, 18, 18, 18, 18, 26, 37, 0, 0, 0, 23, 23, 31, 31, 31, 31, 37, 17,
                  10, 28, 10, 16, 12, 10, 9, 9, 3, 29, 5, 10, 10, 10, 37, 41, 36, 22, 11, 29, 21, 28, 4, 27, 30, 18, 18,
                  18, 18, 18, 35, 35, 35, 27, 11, 19, 19, 15, 32, 11, 11, 11, 11, 22, 22, 11, 11, 12, 11, 11, 11, 11,
                  12, 11, 12, 11, 12, 12, 19, 25, 25, 25, 25, 25, 41, 28, 43, 1, 43, 47, 43, 43, 43, 43, 41, 21, 28, 21,
                  36, 28, 28, 28, 22, 21, 43, 36, 21, 9, 28, 28, 41, 9, 9, 11, 41, 35, 27, 4, 13, 35, 35, 35, 35, 13,
                  25, 12, 10, 25, 29, 25, 41, 42, 42, 42, 42, 27, 27, 12, 27, 9, 9, 27, 39, 16, 39, 28, 39, 48, 44, 44,
                  16, 14, 28, 14, 26, 10, 10, 10, 32, 11, 33, 32, 33, 34, 15, 15, 14, 15, 15, 15, 40, 15, 13, 15, 15,
                  14, 15, 15, 15, 15, 26, 11, 12, 1, 26, 12, 26, 26, 6, 37, 13, 31, 31, 49, 17, 4, 34, 29, 29, 29, 39,
                  9, 11, 29, 32, 32, 32, 33, 35, 26, 41, 28, 5, 43, 34, 36, 36, 36, 36, 22, 36, 36, 36, 14, 14, 14, 36,
                  36, 36, 36, 36, 11, 14, 14, 14, 14, 14, 7, 1, 1, 10, 1, 1, 8, 26, 13, 44, 44, 18, 33, 22, 36, 33, 36,
                  30, 25, 38, 33, 9, 37, 39, 41, 23, 23, 5, 38, 29, 8, 35, 37, 5, 33, 13, 13, 9, 28, 47, 41, 41, 37, 42,
                  42, 5, 29, 29, 29, 25, 29, 33, 19, 9, 22, 22, 22, 22, 22, 22, 29, 22, 22, 22, 22, 22, 22, 10, 10, 12,
                  12, 12, 12, 12, 12, 12, 12, 47, 12, 12, 12, 25, 12, 13, 12, 12, 23, 27, 26, 20, 9, 41, 28, 42, 44, 44,
                  42, 42, 16, 22, 12, 12, 10, 37, 37, 13, 10, 35, 33, 27, 10, 33, 10, 10, 28, 28, 12, 10, 16, 16, 27,
                  16, 27, 27, 10, 10, 33, 33, 33, 33, 33, 33, 33, 38, 28, 38, 25, 25, 25, 25, 25, 25, 23, 45, 47, 41,
                  24, 45, 32, 45, 45, 25, 28, 25, 25, 23, 22, 11, 22, 22, 22, 0, 22, 14, 0, 0, 0, 11, 11, 24, 24, 0, 28,
                  22, 49, 49, 23, 24, 25, 39, 25, 47, 48, 48, 25, 24, 27, 31, 9, 16, 16, 12, 48, 38, 36, 42, 21, 41, 9,
                  4, 44, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11,
                  11, 11, 11, 11, 11, 11, 11, 11, 12, 11, 11, 11, 11, 11, 11, 11, 16, 48, 22, 22, 22, 22, 26, 12, 34,
                  25, 8, 8, 8, 4, 28, 29, 33, 29, 3, 10, 5, 28, 43, 38, 38, 38, 25, 36, 36, 18, 41, 30, 12, 27, 27, 27,
                  27, 27, 47, 49, 49, 28, 45, 26, 28, 28, 30, 25, 24, 49, 45, 45, 25, 25, 29, 24, 45, 45, 45, 45, 45,
                  45, 45, 42, 31, 39, 31, 39, 39, 45, 31, 39, 31, 31, 31, 31, 45, 45, 45, 28, 28, 28, 39, 31, 45, 45,
                  21, 34, 27, 33, 43, 37, 36, 42, 42, 42, 4, 45, 28, 28, 28, 11, 39, 45, 45, 41, 42, 37, 42, 42, 42, 42,
                  3, 42, 42, 42, 42, 42, 42, 30, 30, 30, 23, 11, 30, 11, 23, 23, 27, 27, 47, 49, 20, 33, 28, 7, 7, 7, 7,
                  7, 13, 4, 9, 34, 4, 4, 29, 4, 5, 28, 10, 10, 10, 41, 29, 10, 28, 12, 39, 42, 11, 10, 10, 29, 29, 32,
                  32, 39, 5, 49, 12, 29, 28, 43, 43, 4, 8, 15, 28, 13, 24, 25, 48, 48, 42, 42, 30, 30, 30, 40, 41, 41,
                  5, 29, 29, 17, 23, 9, 9, 20, 20, 33, 39, 22, 22, 22, 29, 32, 33, 32, 9, 34, 41, 13, 14, 13, 34, 5, 20,
                  20, 20, 3, 45, 31, 37, 9, 13, 19, 19, 34, 23, 30, 5, 4, 36, 41, 21, 28, 21, 21, 8, 21, 45, 35, 13, 14,
                  28, 13, 13, 13, 13, 15, 15, 41, 14, 15, 15, 15, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14,
                  15, 15, 13, 13, 13, 13, 13, 15, 15, 14, 15, 15, 13, 15, 15, 13, 13, 13, 28, 14, 13, 13, 13, 13, 13,
                  13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 41, 28, 47, 41, 41, 41, 41, 41, 41, 41, 24, 24, 41, 41,
                  41, 36, 22, 41, 41, 41, 41, 37, 41, 41, 41, 24, 24, 41, 41, 41, 41, 48, 28, 25, 47, 47, 5, 32, 33, 14,
                  14, 4, 4, 37, 32, 18, 29, 31, 31, 33, 33, 4, 4, 37, 29, 33, 4, 25, 44, 20, 20, 28, 47, 47, 40, 6, 6,
                  6, 4, 19, 19, 19, 19, 19, 49, 28, 39, 39, 39, 39, 31, 31, 28, 31, 49, 37, 3, 3, 29, 36, 27, 25, 38,
                  28, 28, 28, 38, 38, 38, 27, 1, 35, 8, 5, 5, 2, 34, 8, 37, 13, 29, 5, 16, 22, 22, 16, 12, 12, 16, 23,
                  49, 49, 49, 34, 15, 28, 28, 27, 20, 20, 29, 14, 27, 27, 38, 47, 45, 31, 28, 31, 39, 28, 31, 48, 47, 4,
                  1, 26, 40, 37, 4, 12, 9, 21, 4, 33, 33, 33, 33, 33, 33, 27, 26, 18, 26, 30, 33, 23, 48, 8, 12, 39, 22,
                  28, 26, 29, 29, 9, 21, 27, 23, 31, 15, 49, 37, 20, 49, 45, 26, 41, 33, 8, 25, 3, 0, 12, 13, 14, 25, 7,
                  10, 10, 38, 42, 29, 29, 29, 21, 5, 5, 12, 5, 16, 2, 12, 12, 25, 12, 12, 12, 16, 16, 33, 16, 8, 22, 38,
                  13, 14, 13, 43, 21, 4, 10, 10, 42, 10, 22, 10, 29, 10, 39, 42, 16, 10, 10, 10, 10, 36, 29, 29, 45, 10,
                  10, 10, 10, 10, 38, 36, 29, 28, 28, 28, 28, 28, 28, 28, 33, 29, 18, 18, 18, 29, 25, 11, 25, 29, 29,
                  29, 29, 29, 29, 1, 30, 30, 28, 30, 1, 28, 26, 1, 26, 1, 25, 15, 28, 15, 15, 15, 15, 15, 15, 15, 15,
                  15, 14, 15, 15, 15, 15, 15, 15, 15, 15, 13, 30, 41, 19, 43, 10, 10, 10, 10, 29, 29, 32, 5, 29, 47, 44,
                  50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50,
                  50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50,
                  50, 50, 50, 50, 50)

# Default group id for every label of classifier head 4. ArtCV.__init__ falls back to this
# tuple when hierarchical_ml[3] is enabled and no label_groups4 argument is given.
# Entry j is the group of head-4 label j: get_probs uses `label_groups4 == i` as a boolean
# column mask over the head-4 predictions, so the length must equal num_labels[4]
# (768 by default -- TODO confirm this tuple has exactly that many entries).
# Group ids must be contiguous starting at 0, because one mask is built per id in
# range(number of distinct ids).
_label_groups4 = (27, 4, 18, 12, 0, 8, 14, 14, 14, 4, 12, 0, 9, 9, 3, 9, 6, 15, 7, 22, 22, 12, 15, 1, 2, 9, 11, 1, 23,
                  6, 0, 6, 0, 6, 8, 3, 4, 27, 24, 18, 0, 17, 9, 22, 15, 15, 0, 14, 9, 21, 0, 5, 16, 15, 0, 17, 26, 25,
                  14, 13, 26, 6, 15, 11, 28, 18, 27, 13, 17, 1, 10, 22, 27, 23, 29, 9, 7, 25, 25, 10, 25, 4, 2, 4, 28,
                  25, 10, 19, 7, 28, 20, 12, 9, 18, 17, 13, 9, 29, 25, 14, 5, 23, 27, 23, 3, 3, 7, 24, 11, 19, 19, 10,
                  25, 18, 3, 7, 23, 24, 27, 26, 26, 29, 21, 27, 14, 28, 12, 28, 12, 27, 27, 11, 29, 27, 25, 12, 15, 16,
                  0, 25, 29, 23, 19, 12, 18, 8, 18, 25, 11, 5, 12, 4, 2, 4, 24, 14, 24, 0, 25, 12, 16, 9, 29, 19, 15,
                  15, 14, 6, 13, 13, 3, 13, 8, 9, 19, 18, 13, 18, 17, 25, 5, 27, 11, 10, 20, 28, 10, 7, 9, 9, 23, 6, 15,
                  2, 13, 25, 17, 9, 8, 14, 14, 14, 1, 4, 16, 22, 11, 15, 9, 12, 9, 27, 25, 2, 12, 2, 2, 25, 25, 0, 13,
                  28, 5, 18, 11, 17, 7, 4, 11, 25, 25, 10, 10, 10, 18, 17, 13, 13, 13, 13, 26, 6, 10, 1, 10, 23, 13, 13,
                  8, 13, 22, 7, 15, 4, 19, 19, 2, 2, 4, 0, 16, 16, 29, 22, 9, 9, 24, 10, 10, 5, 9, 11, 11, 17, 5, 12, 9,
                  10, 9, 5, 16, 22, 25, 16, 16, 13, 28, 15, 16, 8, 26, 12, 13, 17, 27, 27, 27, 7, 10, 13, 19, 15, 25,
                  18, 18, 8, 27, 26, 28, 10, 22, 14, 18, 4, 7, 5, 26, 17, 11, 29, 9, 2, 15, 12, 8, 22, 8, 3, 12, 12, 26,
                  17, 28, 9, 27, 26, 8, 1, 17, 10, 9, 4, 9, 17, 0, 0, 12, 15, 27, 12, 3, 7, 18, 9, 26, 11, 11, 0, 24,
                  16, 24, 12, 24, 18, 11, 7, 16, 5, 26, 27, 18, 5, 24, 10, 17, 25, 1, 4, 0, 17, 1, 3, 29, 4, 4, 23, 14,
                  2, 4, 29, 0, 0, 0, 3, 29, 18, 19, 25, 19, 17, 1, 28, 29, 27, 11, 6, 25, 27, 26, 9, 16, 8, 1, 9, 24, 8,
                  7, 24, 25, 8, 7, 18, 15, 10, 10, 23, 12, 4, 4, 9, 5, 26, 12, 2, 5, 12, 14, 1, 5, 26, 18, 4, 24, 4, 2,
                  4, 17, 22, 18, 1, 23, 0, 0, 5, 24, 13, 10, 11, 22, 0, 25, 18, 7, 19, 10, 27, 16, 4, 24, 5, 27, 15, 12,
                  24, 14, 14, 14, 10, 15, 4, 4, 6, 17, 23, 0, 2, 12, 16, 17, 4, 10, 12, 27, 10, 0, 8, 14, 8, 12, 23, 0,
                  10, 11, 18, 27, 8, 7, 27, 10, 16, 19, 27, 8, 10, 8, 0, 10, 23, 6, 8, 18, 0, 0, 3, 10, 4, 26, 10, 11,
                  8, 20, 28, 29, 8, 9, 18, 0, 17, 14, 14, 22, 22, 8, 27, 19, 5, 0, 19, 10, 22, 24, 19, 18, 5, 9, 22, 12,
                  15, 2, 5, 11, 24, 27, 16, 10, 7, 11, 18, 12, 6, 23, 15, 27, 27, 17, 10, 23, 8, 28, 27, 9, 9, 28, 4, 4,
                  4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 19, 15, 14, 12, 29, 18, 12, 17, 24, 18, 10, 25, 15, 15,
                  28, 27, 27, 5, 10, 19, 4, 11, 13, 11, 17, 28, 1, 17, 12, 1, 14, 14, 15, 15, 16, 22, 16, 12, 5, 10, 10,
                  16, 22, 2, 9, 0, 10, 20, 18, 16, 10, 18, 25, 14, 9, 25, 16, 8, 27, 9, 24, 18, 15, 9, 16, 16, 21, 8,
                  10, 27, 17, 24, 6, 25, 13, 24, 13, 13, 13, 13, 18, 29, 21, 15, 25, 14, 7, 19, 18, 18, 27, 24, 17, 28,
                  25, 8, 0, 22, 23, 26, 8, 27, 29, 11, 10, 25, 10, 29, 15, 29, 13, 0, 19, 9, 27, 8, 26, 2, 1, 19, 21,
                  21, 28, 9, 24, 22, 13, 17, 8, 27, 12, 16, 22, 18, 27, 7, 16, 27, 25, 13, 16, 5, 18, 6, 26, 18, 18, 7,
                  0, 3, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30)

model

In [None]:
import torch
import torch.nn as nn
import torchvision.models.resnet as resnet
import collections
import torch.nn.functional as F


# Name of the torchvision residual-block class (looked up with getattr on
# torchvision.models.resnet) for each supported ResNet depth: the shallow
# variants use BasicBlock, the deeper ones Bottleneck.
BLOCK = dict(zip(('18', '34', '50', '101', '152'),
                 ('BasicBlock',) * 2 + ('Bottleneck',) * 3))


# Number of residual blocks in each of the four ResNet stages, keyed by depth
# (same keys as BLOCK); passed to ResNet_CNN as its `layers` argument.
LAYERS = dict(zip(('18', '34', '50', '101', '152'),
                  ([2, 2, 2, 2],
                   [3, 4, 6, 3],
                   [3, 4, 6, 3],
                   [3, 4, 23, 3],
                   [3, 8, 36, 3])))


class ArtCV(nn.Module):
    def __init__(self, tag='18', num_labels=(100, 681, 6, 1920, 768),
                 classifier_layers=(1, 1, 1, 1, 1), classifier_hidden=(2048, 2048, 2048, 2048, 2048),
                 task=('ml', 'ml', 'mc', 'ml', 'ml'), weights=(1, 1, 1, 1, 1),
                 use_batch_norm=True, dropout_rate=0.1,
                 weight_path=None, freeze_cnn=False,
                 focal_loss=False, focal_loss_mc=False, alpha_t_mc=True,
                 alpha=(0.25, 0.25, 0.25, 0.25), alpha_mc=(0.25, 0.75, 0.75, 0.75, 0.75, 0.75),
                 gamma_mc=2, gamma=(2, 2, 2, 2), alpha_t=True, alpha_group=(1, 1), gamma_group=2,
                 hierarchical=False, label_groups=(0, 1, 1, 1, 1, 1), group_classifier_kwargs=None,
                 weight_group=None,
                 hierarchical_ml=(False, False, False, False),
                 label_groups0=None, label_groups1=None, label_groups3=None, label_groups4=None,
                 group_classifier_kwargs0=None, group_classifier_kwargs1=None, group_classifier_kwargs3=None,
                 group_classifier_kwargs4=None, weight_group_ml=(None, None, None, None)):
        """Build a shared ResNet backbone followed by one ``Classifier`` head per label set.

        Head 2 is the head handled by ``get_probs_mc``/``get_loss_mc`` (cross-entropy or
        ``focal_loss_mc``); the other heads are scored with binary cross-entropy or
        ``focal_loss_ml`` in ``get_loss``.

        Args:
            tag: ResNet depth key into ``BLOCK``/``LAYERS`` ('18', '34', '50', '101', '152').
            num_labels: output size of each head (one head per entry).
            classifier_layers, classifier_hidden, task: per-head ``Classifier`` settings
                (``n_layers``, ``dim_hidden``, ``task``).
            weights: per-head loss weights.
            use_batch_norm, dropout_rate: forwarded to every ``Classifier``.
            weight_path, freeze_cnn: forwarded to ``ResNet_CNN`` as ``weight_path`` / ``freeze_layers``.
            focal_loss: use focal loss (``alpha``, ``gamma``, ``alpha_t``) for heads 0, 1, 3, 4.
            focal_loss_mc: use focal loss (``alpha_mc``, ``gamma_mc``, ``alpha_t_mc``) for head 2 and,
                together with ``hierarchical``, for its group head (``alpha_group``, ``gamma_group``).
            hierarchical: add a group classifier for head 2. ``label_groups`` gives the group id of each
                head-2 label, ``group_classifier_kwargs`` (dict or None) overrides the group
                ``Classifier`` settings and ``weight_group`` its loss weight (default ``weights[2]``).
            hierarchical_ml: four flags enabling the same structure for heads 0, 1, 3 and 4.
                ``label_groupsN`` (default: module-level ``_label_groupsN``), ``group_classifier_kwargsN``
                (dict or None) and ``weight_group_ml`` (entries default to the head's ``weights`` value)
                configure them.

        Group ids in every ``label_groups*`` must be contiguous integers starting at 0.
        """
        super().__init__()
        self.tag = tag
        self.num_labels = num_labels
        self.classifier_layers = classifier_layers
        self.classifier_hidden = classifier_hidden
        self.task = task
        self.weights = weights
        self.use_batch_norm = use_batch_norm
        self.dropout_rate = dropout_rate
        self.weight_path = weight_path
        self.freeze_cnn = freeze_cnn
        self.focal_loss = focal_loss
        self.focal_loss_mc = focal_loss_mc
        self.hierarchical = hierarchical
        self.hierarchical_ml = hierarchical_ml

        # Loss hyper-parameters are only stored when the loss that reads them is enabled.
        if self.focal_loss:
            self.alpha = alpha
            self.gamma = gamma
            self.alpha_t = alpha_t
        if self.focal_loss_mc:
            self.alpha_mc = alpha_mc
            self.gamma_mc = gamma_mc
            self.alpha_t_mc = alpha_t_mc
            if self.hierarchical:
                self.alpha_group = alpha_group
                self.gamma_group = gamma_group

        block = getattr(resnet, BLOCK[tag])
        # Width of the backbone output fed to every head: 512 times the block's expansion factor.
        feature_dim = 512 * block.expansion

        self.cnn = ResNet_CNN(block, LAYERS[tag],
                              weight_path=self.weight_path, freeze_layers=self.freeze_cnn)

        self.classifiers = nn.ModuleDict(
            collections.OrderedDict(
                [('classifier{}'.format(i), Classifier(dim_in=feature_dim,
                                                       dim_out=self.num_labels[i],
                                                       dim_hidden=self.classifier_hidden[i],
                                                       n_layers=self.classifier_layers[i],
                                                       task=self.task[i],
                                                       use_batch_norm=self.use_batch_norm,
                                                       dropout_rate=self.dropout_rate))
                 for i in range(len(self.num_labels))]))

        # Optional group-level heads, registered in the original order (head 2, then heads
        # 0, 1, 3, 4) so parameter order and state_dict keys are unchanged.
        if self.hierarchical:
            self._init_group_head('', 2, label_groups, group_classifier_kwargs, weight_group, feature_dim)
        if self.hierarchical_ml[0]:
            self._init_group_head('0', 0, label_groups0 if label_groups0 is not None else _label_groups0,
                                  group_classifier_kwargs0, weight_group_ml[0], feature_dim)
        if self.hierarchical_ml[1]:
            self._init_group_head('1', 1, label_groups1 if label_groups1 is not None else _label_groups1,
                                  group_classifier_kwargs1, weight_group_ml[1], feature_dim)
        if self.hierarchical_ml[2]:
            self._init_group_head('3', 3, label_groups3 if label_groups3 is not None else _label_groups3,
                                  group_classifier_kwargs3, weight_group_ml[2], feature_dim)
        if self.hierarchical_ml[3]:
            self._init_group_head('4', 4, label_groups4 if label_groups4 is not None else _label_groups4,
                                  group_classifier_kwargs4, weight_group_ml[3], feature_dim)

    def _init_group_head(self, suffix, head, label_groups, classifier_kwargs, weight, dim_in):
        """Create and register the group classifier for ``head`` under names ending in ``suffix``.

        Sets ``label_groups<suffix>``, ``num_groups<suffix>``, ``group_classifier_kwargs<suffix>``,
        ``group_classifier<suffix>``, ``weight_group<suffix>`` and ``group_idx_list<suffix>``.
        One boolean mask is built per id in ``range(number of distinct ids)``, so the ids in
        ``label_groups`` must be contiguous from 0.
        """
        groups = np.array(label_groups)
        num_groups = len(np.unique(groups))
        kwargs = {'dim_hidden': self.classifier_hidden[head],
                  'n_layers': self.classifier_layers[head],
                  'task': self.task[head],
                  'use_batch_norm': self.use_batch_norm,
                  'dropout_rate': self.dropout_rate}
        if classifier_kwargs is not None:
            kwargs.update(classifier_kwargs)
        setattr(self, 'label_groups' + suffix, groups)
        setattr(self, 'num_groups' + suffix, num_groups)
        setattr(self, 'group_classifier_kwargs' + suffix, kwargs)
        setattr(self, 'group_classifier' + suffix,
                Classifier(dim_in=dim_in, dim_out=num_groups, **kwargs))
        setattr(self, 'weight_group' + suffix, weight if weight is not None else self.weights[head])
        # Per-group boolean column masks, registered as frozen Parameters (rather than plain
        # tensors) so they are moved together with the module.
        setattr(self, 'group_idx_list' + suffix, torch.nn.ParameterList([torch.nn.Parameter(
            torch.tensor((groups == i).astype(np.uint8), dtype=torch.bool), requires_grad=False)
            for i in range(num_groups)]))

    def inference(self, x):
        return self.cnn(x)

    def get_probs_mc(self, x_features):
        y_pred2 = self.classifiers['classifier2'](x_features)

        if self.hierarchical:
            y_group2 = self.group_classifier(x_features)
            y_weighted2 = torch.zeros_like(y_pred2)
            for i, group_idx in enumerate(self.group_idx_list):
                y_weighted2[:, group_idx] = y_pred2[:, group_idx] / \
                                            (y_pred2[:, group_idx].sum(dim=-1, keepdim=True) + 1e-8) * \
                                            y_group2[:, [i]]
            return y_weighted2.view(-1, self.num_labels[2]), y_group2.view(-1, self.num_groups)
        else:
            return y_pred2.view(-1, self.num_labels[2])

    def get_probs(self, x):
        x_features = self.inference(x)
        y_pred0 = self.classifiers['classifier0'](x_features)
        y_pred1 = self.classifiers['classifier1'](x_features)
        y_pred3 = self.classifiers['classifier3'](x_features)
        y_pred4 = self.classifiers['classifier4'](x_features)
        if self.hierarchical_ml[0]:
            y_group0 = self.group_classifier0(x_features)
            y_probs0 = torch.zeros_like(y_pred0)
            for i, group_idx in enumerate(self.group_idx_list0):
                y_probs0[:, group_idx] = y_pred0[:, group_idx] / \
                                            (y_pred0[:, group_idx].sum(dim=-1, keepdim=True) + 1e-8) * \
                                            y_group0[:, [i]]
        else:
            y_probs0 = y_pred0.view(-1, self.num_labels[0])
            y_group0 = None

        if self.hierarchical_ml[1]:
            y_group1 = self.group_classifier1(x_features)
            y_probs1 = torch.zeros_like(y_pred1)
            for i, group_idx in enumerate(self.group_idx_list1):
                y_probs1[:, group_idx] = y_pred1[:, group_idx] / \
                                            (y_pred1[:, group_idx].sum(dim=-1, keepdim=True) + 1e-8) * \
                                            y_group1[:, [i]]
        else:
            y_probs1 = y_pred1.view(-1, self.num_labels[1])
            y_group1 = None

        if self.hierarchical_ml[2]:
            y_group3 = self.group_classifier3(x_features)
            y_probs3 = torch.zeros_like(y_pred3)
            for i, group_idx in enumerate(self.group_idx_list3):
                y_probs3[:, group_idx] = y_pred3[:, group_idx] / \
                                            (y_pred3[:, group_idx].sum(dim=-1, keepdim=True) + 1e-8) * \
                                            y_group3[:, [i]]
        else:
            y_probs3 = y_pred3.view(-1, self.num_labels[3])
            y_group3 = None

        if self.hierarchical_ml[3]:
            y_group4 = self.group_classifier4(x_features)
            y_probs4 = torch.zeros_like(y_pred4)
            for i, group_idx in enumerate(self.group_idx_list4):
                y_probs4[:, group_idx] = y_pred4[:, group_idx] / \
                                            (y_pred4[:, group_idx].sum(dim=-1, keepdim=True) + 1e-8) * \
                                            y_group4[:, [i]]
        else:
            y_probs4 = y_pred4.view(-1, self.num_labels[4])
            y_group4 = None

        if self.hierarchical:
            y_probs2, y_group2 = self.get_probs_mc(x_features)
        else:
            y_probs2 = self.get_probs_mc(x_features)
            y_group2 = None

        return y_probs0, y_probs1, y_probs2, y_probs3, y_probs4, (y_group0, y_group1, y_group2, y_group3, y_group4)

    def get_loss_mc(self, x, y2, reduction='sum'):
        if self.hierarchical:
            y_pred2, y_group2 = self.get_probs_mc(self.inference(x))
            if self.focal_loss_mc:
                loss_group = focal_loss_mc(y_group2,
                                           torch.tensor(list(map(lambda x: self.label_groups[x],
                                                                 y2.view(-1)))).view(-1),
                                           num_classes=self.num_groups, alpha=self.alpha_group,
                                           gamma=self.gamma_group, alpha_t=self.alpha_t_mc)
            else:
                loss_group = F.cross_entropy(y_group2,
                                             torch.tensor(list(map(lambda x: self.label_groups[x],
                                                                   y2.view(-1)))).view(-1), reduction='none')
        else:
            y_pred2 = self.get_probs_mc(self.inference(x))
        if self.focal_loss_mc:
            loss2 = focal_loss_mc(y_pred2, y2.view(-1), num_classes=self.num_labels[2],
                                  alpha=self.alpha_mc, gamma=self.gamma_mc, alpha_t=self.alpha_t_mc)
        else:
            loss2 = F.cross_entropy(y_pred2, y2.view(-1), reduction='none')

        if self.hierarchical:
            if reduction == 'sum':
                return loss2 * self.weights[2] + loss_group * self.weight_group
            elif reduction == 'none':
                return loss2 * self.weights[2], loss_group * self.weight_group
        else:
            return loss2 * self.weights[2]

    def get_loss(self, x, y0, y1, y2, y3, y4):
        y_pred0, y_pred1, y_pred2, y_pred3, y_pred4, y_groups_tuples = self.get_probs(x)
        y_group0, y_group1, y_group2, y_group3, y_group4 = y_groups_tuples
        if self.focal_loss_mc:
            if self.hierarchical:
                loss_group2 = focal_loss_mc(y_group2,
                                           torch.tensor(list(map(lambda x: self.label_groups[x],
                                                                 y2.view(-1)))).view(-1),
                                           num_classes=self.num_groups, alpha=self.alpha_group,
                                           gamma=self.gamma_group, alpha_t=self.alpha_t_mc)
            else:
                loss_group2 = None
            loss2 = focal_loss_mc(y_pred2, y2.view(-1), num_classes=self.num_labels[2],
                                  alpha=self.alpha_mc, gamma=self.gamma_mc, alpha_t=self.alpha_t_mc)
        else:
            if self.hierarchical:
                loss_group2 = F.cross_entropy(y_group2,
                                             torch.tensor(list(map(lambda x: self.label_groups[x],
                                                                   y2.view(-1)))).view(-1), reduction='none')
            else:
                loss_group2 = None
            loss2 = F.cross_entropy(y_pred2, y2.view(-1), reduction='none')

        if self.focal_loss:
            loss0 = torch.mean(focal_loss_ml(y_pred0, y0, alpha=self.alpha[0], gamma=self.gamma[0],
                                             alpha_t=self.alpha_t), dim=1)
            if self.hierarchical_ml[0]:
                loss_group0 = torch.mean(focal_loss_ml(y_group0,
                                                       torch.cat(list(map(lambda i:
                                                                          torch.sum(
                                                                              F.one_hot(
                                                                                  i, num_classes=self.num_groups0),
                                                                              dim=0).view(1, -1),
                                                                          list(map(lambda x: torch.unique((x * (
                                                                                  torch.tensor(
                                                                                      self.label_groups0)+1)-1),
                                                                                               dim=-1)[1:],
                                                                                   y0.long())))),
                                                                 dim=0).float(),
                                                       alpha=self.alpha[0], gamma=self.gamma[0],
                                                       alpha_t=self.alpha_t), dim=1)
            else:
                loss_group0 = None

            loss1 = torch.mean(focal_loss_ml(y_pred1, y1, alpha=self.alpha[1], gamma=self.gamma[1],
                                             alpha_t=self.alpha_t), dim=1)
            if self.hierarchical_ml[1]:
                loss_group1 = torch.mean(focal_loss_ml(y_group1,
                                                       torch.cat(list(map(lambda i:
                                                                          torch.sum(
                                                                              F.one_hot(
                                                                                  i, num_classes=self.num_groups1),
                                                                              dim=0).view(1, -1),
                                                                          list(map(lambda x: torch.unique((x * (
                                                                                  torch.tensor(
                                                                                      self.label_groups1)+1)-1),
                                                                                               dim=-1)[1:],
                                                                                   y1.long())))),
                                                                 dim=0).float(),
                                                       alpha=self.alpha[1], gamma=self.gamma[1],
                                                       alpha_t=self.alpha_t), dim=1)
            else:
                loss_group1 = None

            loss3 = torch.mean(focal_loss_ml(y_pred3, y3, alpha=self.alpha[2], gamma=self.gamma[2],
                                             alpha_t=self.alpha_t), dim=1)
            if self.hierarchical_ml[2]:
                loss_group3 = torch.mean(focal_loss_ml(y_group3,
                                                       torch.cat(list(map(lambda i:
                                                                          torch.sum(
                                                                              F.one_hot(
                                                                                  i, num_classes=self.num_groups3),
                                                                              dim=0).view(1, -1),
                                                                          list(map(lambda x: torch.unique((x * (
                                                                                  torch.tensor(
                                                                                      self.label_groups3)+1)-1),
                                                                                               dim=-1)[1:],
                                                                                   y3.long())))),
                                                                 dim=0).float(),
                                                       alpha=self.alpha[2], gamma=self.gamma[2],
                                                       alpha_t=self.alpha_t), dim=1)
            else:
                loss_group3 = None

            loss4 = torch.mean(focal_loss_ml(y_pred4, y4, alpha=self.alpha[3], gamma=self.gamma[3],
                                             alpha_t=self.alpha_t), dim=1)
            if self.hierarchical_ml[3]:
                loss_group4 = torch.mean(focal_loss_ml(y_group4,
                                                       torch.cat(list(map(lambda i:
                                                                          torch.sum(
                                                                              F.one_hot(
                                                                                  i, num_classes=self.num_groups4),
                                                                              dim=0).view(1, -1),
                                                                          list(map(lambda x: torch.unique((x * (
                                                                                  torch.tensor(
                                                                                      self.label_groups4)+1)-1),
                                                                                               dim=-1)[1:],
                                                                                   y4.long())))),
                                                                 dim=0).float(),
                                                       alpha=self.alpha[3], gamma=self.gamma[3],
                                                       alpha_t=self.alpha_t), dim=1)
            else:
                loss_group4 = None

        else:
            loss0 = torch.mean(F.binary_cross_entropy(y_pred0, y0, reduction='none'), dim=1)
            if self.hierarchical_ml[0]:
                loss_group0 = torch.mean(F.binary_cross_entropy(y_group0,
                                                                torch.cat(list(map(lambda i:
                                                                                   torch.sum(
                                                                                      F.one_hot(i,
                                                                                                num_classes=
                                                                                                self.num_groups0),
                                                                                      dim=0).view(1, -1),
                                                                                   list(map(lambda x: torch.unique((
                                                                                           x * (
                                                                                               torch.tensor(
                                                                                                   self.label_groups0)
                                                                                               + 1) - 1),
                                                                                                       dim=-1)[1:],
                                                                                            y0.long()))
                                                                                   )), dim=0).float(),
                                                                reduction='none'), dim=1)
            else:
                loss_group0 = None

            loss1 = torch.mean(F.binary_cross_entropy(y_pred1, y1, reduction='none'), dim=1)
            if self.hierarchical_ml[1]:
                loss_group1 = torch.mean(F.binary_cross_entropy(y_group1,
                                                                torch.cat(list(map(lambda i:
                                                                                   torch.sum(
                                                                                      F.one_hot(i,
                                                                                                num_classes=
                                                                                                self.num_groups1),
                                                                                      dim=0).view(1, -1),
                                                                                   list(map(lambda x: torch.unique((
                                                                                           x * (
                                                                                               torch.tensor(
                                                                                                   self.label_groups1)
                                                                                               + 1) - 1),
                                                                                                       dim=-1)[1:],
                                                                                            y1.long()))
                                                                                   )), dim=0).float(),
                                                                reduction='none'), dim=1)
            else:
                loss_group1 = None

            loss3 = torch.mean(F.binary_cross_entropy(y_pred3, y3, reduction='none'), dim=1)
            if self.hierarchical_ml[2]:
                loss_group3 = torch.mean(F.binary_cross_entropy(y_group3,
                                                                torch.cat(list(map(lambda i:
                                                                                   torch.sum(
                                                                                      F.one_hot(i,
                                                                                                num_classes=
                                                                                                self.num_groups3),
                                                                                      dim=0).view(1, -1),
                                                                                   list(map(lambda x: torch.unique((
                                                                                           x * (
                                                                                               torch.tensor(
                                                                                                   self.label_groups3)
                                                                                               + 1) - 1),
                                                                                                       dim=-1)[1:],
                                                                                            y3.long()))
                                                                                   )), dim=0).float(),
                                                                reduction='none'), dim=1)
            else:
                loss_group3 = None

            loss4 = torch.mean(F.binary_cross_entropy(y_pred4, y4, reduction='none'), dim=1)
            if self.hierarchical_ml[3]:
                loss_group4 = torch.mean(F.binary_cross_entropy(y_group4,
                                                                torch.cat(list(map(lambda i:
                                                                                   torch.sum(
                                                                                      F.one_hot(i,
                                                                                                num_classes=
                                                                                                self.num_groups4),
                                                                                      dim=0).view(1, -1),
                                                                                   list(map(lambda x: torch.unique((
                                                                                           x * (
                                                                                               torch.tensor(
                                                                                                   self.label_groups4)
                                                                                               + 1) - 1),
                                                                                                       dim=-1)[1:],
                                                                                            y4.long()))
                                                                                   )), dim=0).float(),
                                                                reduction='none'), dim=1)
            else:
                loss_group4 = None

        return loss0, loss1, loss2, loss3, loss4, (loss_group0, loss_group1, loss_group2, loss_group3, loss_group4)

    def forward(self, x, y0, y1, y2, y3, y4, reduction='sum'):
        """Compute the weighted loss of the five heads plus the group-classifier losses.

        Args:
            x: input batch, forwarded unchanged to ``self.get_loss``.
            y0, y1, y2, y3, y4: targets of heads 0-4, forwarded unchanged to ``self.get_loss``.
            reduction: ``'sum'`` returns a single value, the sum of every head loss
                multiplied by ``self.weights[i]`` plus the weighted group losses.
                ``'none'`` returns the tuple
                ``(w0 * loss0, w1 * loss1, w2 * loss2, w3 * loss3, w4 * loss4, loss_groups)``.

        Returns:
            See ``reduction``. ``loss_groups`` is the integer ``0`` when ``get_loss``
            reported no group loss at all (every group entry is ``None``).

        Raises:
            ValueError: if ``reduction`` is neither ``'sum'`` nor ``'none'``
                (such values used to silently return ``None``).
        """
        if reduction not in ('sum', 'none'):
            raise ValueError(f"reduction must be 'sum' or 'none', got {reduction!r}")

        loss0, loss1, loss2, loss3, loss4, loss_groups_tuple = self.get_loss(x, y0, y1, y2, y3, y4)
        loss_group0, loss_group1, loss_group2, loss_group3, loss_group4 = loss_groups_tuple

        # Group 2 is weighted by ``self.weight_group``; the others by ``self.weight_group<i>``.
        # The weight attribute is only read when that group actually produced a loss,
        # and the summation order (0, 1, 3, 4, 2) is kept for identical float results.
        loss_groups = 0
        for group_loss, weight_name in ((loss_group0, 'weight_group0'),
                                        (loss_group1, 'weight_group1'),
                                        (loss_group3, 'weight_group3'),
                                        (loss_group4, 'weight_group4'),
                                        (loss_group2, 'weight_group')):
            if group_loss is not None:
                loss_groups += group_loss * getattr(self, weight_name)

        w = self.weights
        if reduction == 'sum':
            return loss0 * w[0] + loss1 * w[1] + loss2 * w[2] + \
                   loss3 * w[3] + loss4 * w[4] + loss_groups
        return loss0 * w[0], loss1 * w[1], loss2 * w[2], \
               loss3 * w[3], loss4 * w[4], loss_groups

    def get_concat_probs(self, x, return_hier_pred=False):
        """Concatenate the per-label outputs of all five heads along dim 1.

        Heads 0, 1, 3 and 4 contribute the probabilities returned by
        ``self.get_probs`` unchanged. Head 2 is turned into a hard one-hot of its
        argmax over ``self.num_labels[2]`` classes with column 0 dropped, so a
        sample predicted as class 0 contributes only zeros.

        Args:
            x: input batch forwarded to ``self.get_probs``.
            return_hier_pred: when true, also return the one-hot (``self.num_groups``
                columns, integer dtype from ``F.one_hot``) of the argmax of the third
                entry of the group-prediction tuple.

        Returns:
            The concatenated tensor, or ``(concatenated, group_one_hot)``.
        """
        y_pred0, y_pred1, y_pred2, y_pred3, y_pred4, y_groups = self.get_probs(x)
        head2_hard = F.one_hot(torch.argmax(y_pred2, dim=-1), num_classes=self.num_labels[2]).float()
        concatenated = torch.cat((y_pred0, y_pred1, head2_hard[:, 1:], y_pred3, y_pred4), dim=1)
        if not return_hier_pred:
            return concatenated
        _, _, y_group2, _, _ = y_groups
        group_one_hot = F.one_hot(torch.argmax(y_group2, dim=-1), num_classes=self.num_groups)
        return concatenated, group_one_hot


def focal_loss_ml(inputs, targets, alpha=0.25, gamma=2, alpha_t=False):
    """Element-wise (unreduced) focal loss for multi-label outputs.

    Args:
        inputs: predicted probabilities in [0, 1], as required by
            ``F.binary_cross_entropy``.
        targets: targets with the same shape as ``inputs``.
        alpha: balancing factor.
        gamma: focusing exponent applied to ``(1 - p_t)``.
        alpha_t: if true, each element is weighted by
            ``-alpha * (2 * targets - 1) + targets``, which is ``1 - alpha`` for a
            target of 1 and ``alpha`` for a target of 0. Otherwise every element is
            weighted by ``alpha``.

    Returns:
        Loss tensor with the same shape as ``inputs``.

    NOTE(review): RetinaNet's alpha_t uses ``alpha`` for positives; this function
    gives positives ``1 - alpha``. Confirm that this is intended.
    """
    bce = F.binary_cross_entropy(inputs, targets, reduction='none')
    # exp(-BCE) recovers p_t, the probability assigned to the target value.
    p_t = torch.exp(-bce)
    modulating = (1 - p_t) ** gamma
    if alpha_t:
        weight = -alpha * (targets * 2 - 1) + targets
    else:
        weight = alpha
    return weight * modulating * bce


def focal_loss_mc(inputs, targets,
                  num_classes, alpha=(0.25, 0.75, 0.75, 0.75, 0.75, 0.75), gamma=2, alpha_t=False):
    """Per-sample focal loss for multi-class (probability) outputs.

    Computes ``sum_c -[alpha_c] * (1 - p_t)^gamma * log(p_t + 1e-6)``, where only the
    target class contributes because the log term is masked by the one-hot target.

    Args:
        inputs: class probabilities whose last dimension has ``num_classes`` entries.
        targets: integer class indices; assumed to have ``inputs``' shape without the
            last dimension (TODO confirm against callers).
        num_classes: number of classes for ``F.one_hot``.
        alpha: per-class weights, used only when ``alpha_t`` is true.
        gamma: focusing exponent.
        alpha_t: whether to apply the per-class ``alpha`` weights.

    Returns:
        Loss with the last dimension of ``inputs`` summed out.
    """
    targets_one_hot = F.one_hot(targets, num_classes=num_classes)
    pt = inputs * targets_one_hot
    # Off-target entries give 1 - 0 = 1 here, but they are zeroed by log_pt below.
    one_sub_pt = 1 - pt
    log_pt = targets_one_hot * torch.log(inputs + 1e-6)
    if alpha_t:
        # Build alpha on inputs' device/dtype; torch.tensor(alpha) used the global
        # default device and failed for GPU inputs unless the default was switched.
        alpha_vec = torch.as_tensor(alpha, dtype=inputs.dtype, device=inputs.device)
        return torch.sum((-alpha_vec) * one_sub_pt ** gamma * log_pt, dim=-1)
    return torch.sum((-1) * one_sub_pt ** gamma * log_pt, dim=-1)


def get_weight_mat(y, ratio=10, base=10):
    """Build per-element weights for the negative (zero) entries of a 2-D label matrix.

    For each row with ``k`` positives out of ``n = y.shape[1]`` labels, every negative
    entry gets ``(k + base) * ratio / (n - k)`` and every positive entry gets 0.

    Args:
        y: 2-D label tensor (rows are samples); ``.view(-1, 1)`` below assumes 2-D.
        ratio: multiplicative scale.
        base: offset added to the positive count.

    Returns:
        Tensor of ``y``'s shape, on ``y``'s device.
    """
    positives = torch.sum(y, dim=-1).view(-1, 1)
    # Rows with every label positive would divide 0 by 0 (NaN); the numerator is 0
    # there, so clamping the denominator to 1 yields a weight of 0 instead.
    negatives = (y.shape[1] - positives).clamp(min=1)
    # ones_like keeps y's device/dtype (torch.ones(y.shape) used the global default).
    return (torch.ones_like(y) - y) * (positives + base) * ratio / negatives.detach()

trainer

In [None]:
import torch
import sys
import copy
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import fbeta_score

import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch.utils.data.sampler import RandomSampler, SequentialSampler
from tqdm import trange


class Trainer:
    """Training / evaluation driver for the five-head classification model.

    Builds the data loaders from ``dataset``:

    * with ``dataset.train_val_split`` truthy: train and val loaders;
    * otherwise: a single training loader over ``dataset.all``;
    * always: an ordered loader over ``dataset.all``;
    * when ``dataset.test_path`` is set: an ordered test loader.

    It runs Adam training, optionally with extra passes over the multi-class head,
    and exposes loss, accuracy and prediction helpers.

    Labelled batches are unpacked as ``(x, y0, y1, y2, y3, y4)``; test batches are
    used as ``x`` directly.
    """

    # Number of classes of the multi-class head (y2). Column 0 is dropped when
    # building the concatenated ground truth.
    _NUM_CLASSES_MC = 6

    def __init__(self, model, dataset, use_cuda=True,
                 shuffle=True, epochs=100, extra_epochs_mc=1, train_mc_step=1,
                 batch_size_train=64, batch_size_val=64, batch_size_all=64, batch_size_test=64,
                 head_log=True,
                 monitor_frequency=False, compute_acc=True, printout=False,
                 dataloader_train_kwargs=None, dataloader_val_kwargs=None,
                 dataloader_all_kwargs=None,
                 dataloader_test_kwargs=None):
        """Create the data loaders and move the model to the GPU if requested and available.

        Args:
            model: model whose ``__call__`` returns the loss (see :meth:`loss`).
            dataset: object exposing ``train_val_split``, ``all``, ``test_path`` and,
                depending on them, ``train``, ``val`` and ``test``.
            use_cuda: use the GPU when ``torch.cuda.is_available()``.
            shuffle: random (True) or sequential (False) sampling for the training
                and validation loaders.
            epochs: number of epochs run by :meth:`train`.
            extra_epochs_mc: epochs per :meth:`train_mc` call; 0 disables it.
            train_mc_step: call :meth:`train_mc` every this many epochs.
            batch_size_*: batch sizes of the respective loaders.
            head_log: record the per-head losses during training.
            monitor_frequency: compute loss/accuracy every this many epochs (falsy: never).
            compute_acc: also compute accuracy when monitoring.
            printout: print the monitored values.
            dataloader_*_kwargs: extra ``DataLoader`` keyword arguments. They are
                copied, not mutated; ``batch_size`` and ``sampler`` are overridden.
        """
        self.model = model
        self.dataset = dataset
        self.train_val = bool(self.dataset.train_val_split)
        self.use_cuda = use_cuda and torch.cuda.is_available()
        if self.use_cuda:
            # Tensors created without an explicit device now default to CUDA.
            torch.set_default_tensor_type(torch.cuda.FloatTensor)
        self.epochs = epochs
        self.extra_epochs_mc = extra_epochs_mc
        self.train_mc_step = train_mc_step
        self.head_log = head_log
        self.running_loss = []
        if self.head_log:
            self.loss_log = {f'Head {i}': [] for i in range(5)}
            if self.model.hierarchical:
                self.loss_log['Head group classifiers'] = []

        sampler_cls = RandomSampler if shuffle else SequentialSampler
        if self.train_val:
            self.dataloader_train_kwargs, self.dataloader_train = self._make_loader(
                dataset.train, dataloader_train_kwargs, batch_size_train, sampler_cls(dataset.train))
            self.dataloader_val_kwargs, self.dataloader_val = self._make_loader(
                dataset.val, dataloader_val_kwargs, batch_size_val, sampler_cls(dataset.val))
            self.loss_history_val = []
            self.accuracy_history_val = []
        else:
            # ``shuffle`` is honoured here too; it used to be ignored without a val split.
            self.dataloader_train_kwargs, self.dataloader_train = self._make_loader(
                dataset.all, dataloader_train_kwargs, batch_size_train, sampler_cls(dataset.all))
        self.loss_history_train = []
        self.accuracy_history_train = []

        self.dataloader_all_kwargs, self.dataloader_all = self._make_loader(
            dataset.all, dataloader_all_kwargs, batch_size_all, SequentialSampler(dataset.all))

        if self.dataset.test_path is not None:
            self.dataloader_test_kwargs, self.dataloader_test = self._make_loader(
                dataset.test, dataloader_test_kwargs, batch_size_test, SequentialSampler(dataset.test))

        if self.use_cuda:
            self.model.cuda()
        self.monitor_frequency = monitor_frequency
        self.compute_acc = compute_acc
        self.printout = printout

    # ------------------------------------------------------------------ helpers

    @staticmethod
    def _make_loader(data, user_kwargs, batch_size, sampler):
        """Return ``(kwargs, DataLoader)``, with ``kwargs`` a copy of ``user_kwargs`` plus batch size and sampler."""
        kwargs = copy.deepcopy(user_kwargs) if user_kwargs is not None else {}
        kwargs.update({'batch_size': batch_size, 'sampler': sampler})
        return kwargs, DataLoader(data, **kwargs)

    def _to_device(self, *tensors):
        """Return ``tensors`` as a tuple, moved to the GPU when CUDA is in use."""
        if self.use_cuda and torch.cuda.is_available():
            return tuple(t.cuda() for t in tensors)
        return tensors

    def _select_dataloader(self, tag):
        """Map ``'train'`` / ``'val'`` / ``'all'`` to the matching loader; raise ``ValueError`` otherwise."""
        if tag == 'train':
            return self.dataloader_train
        if tag == 'val':
            return self.dataloader_val
        if tag == 'all':
            return self.dataloader_all
        raise ValueError('Invalid tag!')

    def _trainable_params(self, parameters):
        """Return the parameters to optimize as a list.

        A list, unlike a ``filter``/generator, can be iterated by both the optimizer
        and ``clip_grad_norm_``; before, clipping received an exhausted iterator.
        """
        if parameters is None:
            return [p for p in self.model.parameters() if p.requires_grad]
        return list(parameters)

    @staticmethod
    def _mc_one_hot(y2, num_classes=_NUM_CLASSES_MC):
        """One-hot encode the multi-class labels as a 2-D ``(batch, num_classes - 1)`` long tensor, dropping class 0.

        ``reshape`` replaces the former ``.squeeze()``, which also removed the batch
        dimension when a batch held a single sample and then crashed on ``[:, 1:]``.
        """
        return F.one_hot(y2, num_classes=num_classes).reshape(-1, num_classes)[:, 1:].long()

    @staticmethod
    def _concat_ground_truth(y0, y1, y2, y3, y4):
        """Concatenate the targets of all heads into one long tensor, in the column layout of ``get_concat_probs``."""
        return torch.cat((y0.long(),
                          y1.long(),
                          Trainer._mc_one_hot(y2),
                          y3.long(),
                          y4.long()), dim=1)

    def _x_axis(self, len_ticks, epochs_override=None):
        """Epoch positions for ``len_ticks`` logged points spread over ``[1, epochs]``."""
        last_epoch = self.epochs if epochs_override is None else epochs_override
        return np.linspace(1, last_epoch, len_ticks)

    def _plot_history(self, history_train, history_val, ylabel, epochs_override):
        """Plot a training history, with the validation history when a val split exists."""
        x_axis = self._x_axis(len(history_train), epochs_override)
        plt.figure()
        if self.train_val:
            assert (len(history_train) == len(history_val))
            plt.plot(x_axis, history_train, label='Training set')
            plt.plot(x_axis, history_val, label='Validation set')
            plt.legend()
        else:
            plt.plot(x_axis, history_train)
        plt.xlabel('Number of epochs')
        plt.ylabel(ylabel)
        plt.show()

    def _monitor(self, epoch_idx):
        """Record (and optionally print) loss/accuracy after epoch ``epoch_idx`` (0-based)."""
        current_loss_train = self.compute_loss(tag='train')
        self.loss_history_train.append(current_loss_train)
        current_accuracy_train = None
        if self.compute_acc:
            current_accuracy_train = self.compute_accuracy(tag='train')
            self.accuracy_history_train.append(current_accuracy_train)
        if self.train_val:
            self.loss_history_val.append(self.compute_loss(tag='val'))
            if self.compute_acc:
                self.accuracy_history_val.append(self.compute_accuracy(tag='val'))
        if self.printout:
            # epoch_idx is 0-based, so epoch_idx + 1 epochs have completed. The old
            # message referenced the accuracy even when compute_acc was False (NameError).
            if current_accuracy_train is None:
                print("After %i epochs, loss is %f." % (epoch_idx + 1, current_loss_train))
            else:
                print("After %i epochs, loss is %f and prediction accuracy is %f."
                      % (epoch_idx + 1, current_loss_train, current_accuracy_train))

    # ----------------------------------------------------------------- training

    def before_iter(self):
        """Hook called at the start of every training epoch; no-op by default."""
        pass

    def after_iter(self):
        """Hook called at the end of every training epoch; no-op by default."""
        pass

    def train_mc(self, lr, parameters=None, betas=(0.9, 0.999), eps=1e-8, weight_decay=0,
                 grad_clip=False, max_norm=1e-5):
        """Run ``self.extra_epochs_mc`` epochs optimizing only ``model.get_loss_mc`` on the training loader.

        Puts ``classifiers['classifier2']`` (and the group classifier when
        hierarchical) in train mode, and creates a fresh Adam optimizer on each call.
        """
        self.model.classifiers['classifier2'].train()
        if self.model.hierarchical:
            self.model.group_classifier.train()
        params = self._trainable_params(parameters)
        optim = torch.optim.Adam(params, lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
        for _epoch in range(self.extra_epochs_mc):
            for x, _y0, _y1, y2, _y3, _y4 in self.dataloader_train:
                x, y2 = self._to_device(x, y2)
                loss = torch.mean(self.model.get_loss_mc(x, y2))
                optim.zero_grad()
                loss.backward()
                if grad_clip:
                    torch.nn.utils.clip_grad_norm_(params, max_norm=max_norm)
                optim.step()

    def train(self, parameters=None, lr=1e-1, mc_lr=1e-1, betas=(0.9, 0.999), eps=1e-8, weight_decay=0,
              reduce_lr=False, step=5, gamma=0.8,
              reduce_lr_mc=False, step_mc=5, gamma_mc=0.8,
              grad_clip=False, max_norm=1e-5,
              train_mc_kwargs=None):
        """Train for ``self.epochs`` epochs with Adam.

        Args:
            parameters: parameters to optimize (default: all with ``requires_grad``).
            lr, betas, eps, weight_decay: Adam settings.
            mc_lr: learning rate passed to :meth:`train_mc`.
            reduce_lr, step, gamma: when ``reduce_lr``, multiply ``lr`` by ``gamma``
                every ``step`` epochs.
            reduce_lr_mc, step_mc, gamma_mc: when ``reduce_lr_mc``, multiply ``mc_lr``
                by ``gamma_mc`` every ``step_mc`` calls to :meth:`train_mc`.
            grad_clip, max_norm: clip the total gradient norm to ``max_norm``.
            train_mc_kwargs: extra keyword arguments for :meth:`train_mc`.

        The ``reduce_lr`` / ``reduce_lr_mc`` / ``extra_epochs_mc`` conditions used ``&``,
        which binds tighter than ``==``. That made the learning rates always decay and
        ran ``train_mc`` even when it was disabled. They now use ``and``.
        """
        train_mc_kwargs = {} if train_mc_kwargs is None else train_mc_kwargs
        self.model.train()
        params = self._trainable_params(parameters)
        optim = torch.optim.Adam(params, lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
        n_batches = len(self.dataloader_train)
        epoch_count_mc = 0
        # trange advances the bar on every iteration; an extra update(1) would double-count.
        with trange(self.epochs, desc='Training progress: ', file=sys.stdout) as progbar:
            for epoch_idx in progbar:
                self.before_iter()
                running_loss = 0
                head_losses = [0.0] * 5
                head_loss_group = 0.0
                for data_tensors in self.dataloader_train:
                    if self.head_log:
                        loss0, loss1, loss2, loss3, loss4, loss_group = self.loss(data_tensors, head_log=True)
                        head_sum = loss0 + loss1 + loss2 + loss3 + loss4
                        if self.model.hierarchical:
                            loss = torch.mean(head_sum + loss_group)
                            head_loss_group += torch.mean(loss_group).item()
                        else:
                            loss = torch.mean(head_sum)
                        for i, head_loss in enumerate((loss0, loss1, loss2, loss3, loss4)):
                            head_losses[i] += torch.mean(head_loss).item()
                    else:
                        loss = self.loss(data_tensors)
                    running_loss += loss.item()
                    optim.zero_grad()
                    loss.backward()
                    if grad_clip:
                        torch.nn.utils.clip_grad_norm_(params, max_norm=max_norm)
                    optim.step()

                self.running_loss.append(running_loss / n_batches)
                if self.head_log:
                    for i, total in enumerate(head_losses):
                        self.loss_log[f'Head {i}'].append(total / n_batches)
                    if self.model.hierarchical:
                        self.loss_log['Head group classifiers'].append(head_loss_group / n_batches)

                if reduce_lr and (epoch_idx + 1) % step == 0:
                    for group in optim.param_groups:
                        group['lr'] *= gamma

                if self.monitor_frequency and (epoch_idx + 1) % self.monitor_frequency == 0:
                    self._monitor(epoch_idx)

                if self.extra_epochs_mc and (epoch_idx + 1) % self.train_mc_step == 0:
                    self.train_mc(lr=mc_lr, **train_mc_kwargs)
                    epoch_count_mc += 1
                    if reduce_lr_mc and epoch_count_mc % step_mc == 0:
                        mc_lr *= gamma_mc

                self.after_iter()

    def loss(self, data_tensors, head_log=False):
        """Compute the model loss on one labelled batch ``(x, y0, y1, y2, y3, y4)``.

        Returns the model's 6-tuple of per-head losses (``reduction='none'``) when
        ``head_log`` is true; otherwise the mean of the summed loss.
        """
        x, y0, y1, y2, y3, y4 = self._to_device(*data_tensors)
        if head_log:
            loss0, loss1, loss2, loss3, loss4, loss_group = self.model(x, y0, y1, y2, y3, y4, reduction='none')
            return loss0, loss1, loss2, loss3, loss4, loss_group
        return torch.mean(self.model(x, y0, y1, y2, y3, y4, reduction='sum'))

    # ------------------------------------------------------------ plotting/eval

    @torch.no_grad()
    def plot_running_loss(self, epochs_override=None):
        """Plot the per-epoch mean training loss."""
        x_axis = self._x_axis(len(self.running_loss), epochs_override)
        plt.figure()
        plt.plot(x_axis, self.running_loss)
        plt.xlabel('Number of epochs')
        plt.ylabel('Estimated loss')
        plt.show()

    @torch.no_grad()
    def plot_head_loss(self, epochs_override=None):
        """Plot each head's loss, normalized by its first logged value; no-op without ``head_log``."""
        if not self.head_log:
            return
        x_axis = self._x_axis(len(self.running_loss), epochs_override)
        plt.figure()
        for key, values in self.loss_log.items():
            log_array = np.array(values)
            # Normalizing by the first value lets heads of different scale share one axis.
            plt.plot(x_axis, log_array / log_array[0], label=key)
        plt.legend(loc='lower left', bbox_to_anchor=(1, 0.25, 0.5, 0.5))
        plt.xlabel('Number of epochs')
        plt.ylabel('Estimated loss')
        plt.show()

    @torch.no_grad()
    def compute_loss(self, tag):
        """Return the model loss summed over a split divided by the split's size.

        Args:
            tag: ``'train'``, ``'val'`` or ``'all'``.

        Raises:
            ValueError: for any other tag.
        """
        self.model.eval()
        dataloader = self._select_dataloader(tag)
        loss_sum = 0
        for data_tensors in dataloader:
            x, y0, y1, y2, y3, y4 = self._to_device(*data_tensors)
            loss_sum += torch.sum(self.model(x, y0, y1, y2, y3, y4)).item()
        loss_mean = loss_sum / len(dataloader.dataset)
        self.model.train()
        return loss_mean

    @torch.no_grad()
    def loss_history_plot(self, epochs_override=None):
        """Plot the monitored loss history (train, plus val when available)."""
        self._plot_history(self.loss_history_train, getattr(self, 'loss_history_val', None),
                           'Loss', epochs_override)

    @torch.no_grad()
    def get_probs(self, tag, return_hier_pred=False, return_probs_only=False):
        """Return the concatenated head probabilities for a split as numpy arrays.

        Args:
            tag: ``'test'``, ``'train'``, ``'val'`` or ``'all'``.
            return_hier_pred: also return the group one-hot predictions
                (ignored for ``'test'``).
            return_probs_only: treat each batch as ``x`` only and skip the ground truth.

        Returns:
            For ``'test'``: the probability array. Otherwise, in this order,
            ``ground_truth`` (unless ``return_probs_only``), ``probabilities``, and
            ``hier_pred`` (if ``return_hier_pred``); a single array if only one applies.

        The model is put back in train mode afterwards, for the test split too.
        """
        self.model.eval()
        if tag == 'test':
            predictions = []
            for x in self.dataloader_test:
                x, = self._to_device(x)
                predictions.append(self.model.get_concat_probs(x))
            self.model.train()
            return torch.cat(predictions).detach().cpu().numpy()

        dataloader = self._select_dataloader(tag)
        ground_truth, predictions, hier_preds = [], [], []
        for data_tensors in dataloader:
            if return_probs_only:
                x, = self._to_device(data_tensors)
            else:
                x, y0, y1, y2, y3, y4 = self._to_device(*data_tensors)
                ground_truth.append(self._concat_ground_truth(y0, y1, y2, y3, y4))
            if return_hier_pred:
                y_concat_prob, y_group_pred = self.model.get_concat_probs(x, return_hier_pred=True)
                hier_preds.append(y_group_pred)
            else:
                y_concat_prob = self.model.get_concat_probs(x, return_hier_pred=False)
            predictions.append(y_concat_prob)
        self.model.train()

        outputs = [] if return_probs_only else [torch.cat(ground_truth).detach().cpu().numpy()]
        outputs.append(torch.cat(predictions).detach().cpu().numpy())
        if return_hier_pred:
            outputs.append(torch.cat(hier_preds).detach().cpu().numpy())
        return outputs[0] if len(outputs) == 1 else tuple(outputs)

    @torch.no_grad()
    def make_predictions(self, tag, return_pred_only=False,
                         thre=(0.08, 0.08, 0.08, 0.08), upper_bound=(3, 4, 17, 18), lower_bound=3,
                         boundary=([0, 100], [100, 781], [786, 2706], [2706, 3474])):
        """Return the binarized predictions produced by ``regularized_pred`` for a split.

        ``thre``, ``upper_bound``, ``lower_bound`` and ``boundary`` are forwarded to
        ``regularized_pred``. For non-test tags, returns ``(ground_truth, predictions)``
        unless ``return_pred_only``. The model is put back in train mode afterwards.
        """
        self.model.eval()

        def _predict(x):
            probs = self.model.get_concat_probs(x).detach().cpu().numpy()
            return regularized_pred(probs, thre=thre, upper_bound=upper_bound,
                                    lower_bound=lower_bound, boundary=boundary)

        if tag == 'test':
            predictions = []
            for x in self.dataloader_test:
                x, = self._to_device(x)
                predictions.append(_predict(x))
            self.model.train()
            return np.concatenate(predictions)

        dataloader = self._select_dataloader(tag)
        ground_truth, predictions = [], []
        for data_tensors in dataloader:
            if return_pred_only:
                x, = self._to_device(data_tensors)
            else:
                x, y0, y1, y2, y3, y4 = self._to_device(*data_tensors)
                ground_truth.append(self._concat_ground_truth(y0, y1, y2, y3, y4))
            predictions.append(_predict(x))
        predictions_array = np.concatenate(predictions)
        self.model.train()
        if return_pred_only:
            return predictions_array
        return torch.cat(ground_truth).detach().cpu().numpy(), predictions_array

    @torch.no_grad()
    def compute_hier_acc(self, tag):
        """Return the accuracy of the group prediction for the multi-class head.

        The target is 1 when ``y2`` is non-zero and 0 otherwise. The prediction is
        the group classifier's argmax when hierarchical, otherwise the sum of
        columns 781:786 of the concatenated probabilities. The model is now
        evaluated in eval mode (it previously ran in whatever mode it was in).
        """
        dataloader = self._select_dataloader(tag)
        self.model.eval()
        ground_truth, hier_preds = [], []
        for data_tensors in dataloader:
            x, _, _, y2, _, _ = data_tensors
            x, y2 = self._to_device(x, y2)
            ground_truth.append(torch.sum(self._mc_one_hot(y2), dim=-1))
            if self.model.hierarchical:
                _, y_group_pred = self.model.get_concat_probs(x, return_hier_pred=True)
                hier_preds.append(y_group_pred)
            else:
                y_concat_prob = self.model.get_concat_probs(x, return_hier_pred=False)
                hier_preds.append(torch.sum(y_concat_prob[:, 781:786], dim=-1))
        self.model.train()
        truth = torch.cat(ground_truth).detach().cpu().numpy()
        preds = torch.cat(hier_preds).detach().cpu().numpy()
        if self.model.hierarchical:
            # One-hot rows -> index of the predicted group.
            preds = np.where(preds == 1)[1]
        return np.mean(truth == preds)

    @torch.no_grad()
    def compute_accuracy(self, tag,
                         thre=(0.08, 0.08, 0.08, 0.08), upper_bound=(3, 4, 17, 18), lower_bound=3,
                         boundary=([0, 100], [100, 781], [786, 2706], [2706, 3474])):
        """Return the mean per-sample F2 score of :meth:`make_predictions` on a split."""
        y_true, y_pred = self.make_predictions(tag=tag, thre=thre,
                                               upper_bound=upper_bound, lower_bound=lower_bound, boundary=boundary)
        f_beta = [fbeta_score(y_true[i, :], y_pred[i, :], beta=2) for i in range(y_true.shape[0])]
        return sum(f_beta) / len(f_beta)

    @torch.no_grad()
    def accuracy_history_plot(self, epochs_override=None):
        """Plot the monitored accuracy history (train, plus val when available)."""
        self._plot_history(self.accuracy_history_train, getattr(self, 'accuracy_history_val', None),
                           'Accuracy', epochs_override)

The inference section uses the trained model to predict labels for the test set.

In [None]:
# Report whether PyTorch can see a CUDA device; the notebook echoes the value.
cuda_available = torch.cuda.is_available()
cuda_available

Let's define the dataset, model, and trainer, then load the trained weights.

In [None]:
# Locations of the competition data inside the Kaggle input mount.
_root_path = '/kaggle/input/imet-2020-fgvc7'
path = f'{_root_path}/train'
test_path = f'{_root_path}/test'
data_info_path = f'{_root_path}/train.csv'
labels_info_path = f'{_root_path}/labels.csv'
test_csv_path = f'{_root_path}/sample_submission.csv'

# NOTE(review): train_transform='val' presumably selects the non-augmenting
# transform pipeline for inference — confirm against TrainValSet.
dataset = TrainValSet(train_val_split=False, train_transform='val',
                      path=path, data_info_path=data_info_path, labels_info_path=labels_info_path,
                      test_path=test_path, test_csv_path=test_csv_path)

# Backbone variant; also selects the pretrained-weight file path below.
tag = '50'
weight_path = f'/kaggle/input/resnet{tag}/resnet{tag}.pth'
# NOTE(review): the architecture-affecting arguments below presumably must match
# the configuration the checkpoint was trained with; otherwise the strict
# load_state_dict call at the end of this cell raises on mismatched keys/shapes.
model = ArtCV(tag=tag, weight_path=weight_path, freeze_cnn=True, dropout_rate=0, weights=(1, 1, 10, 1, 1),
              classifier_layers=(2, 2, 2, 2, 2), focal_loss=False,
              alpha=(0.25, 0.25, 0.25, 0.25), alpha_mc=(0.25, 0.75, 0.75, 0.75, 0.75, 0.75),
              gamma_mc=2, gamma=(2, 2, 2, 2), alpha_t=True,
              hierarchical=True, label_groups=(0, 1, 2, 1, 1, 1), alpha_group=(1, 1), gamma_group=2, weight_group=5,
              hierarchical_ml=(False, True, True, True), label_groups3=_label_groups3_40)

trainer = Trainer(model, dataset, batch_size_train=256, batch_size_val=256, batch_size_all=64,
                  epochs=15, extra_epochs_mc=0, head_log=True, compute_acc=False,
                  monitor_frequency=False,
                  dataloader_train_kwargs={'num_workers': 2}, dataloader_val_kwargs={'num_workers': 2},
                  dataloader_all_kwargs={'num_workers': 2})

# Trained weights for this configuration.
file_name = 'best0526/frozen_hier_all_no_crop_resnet50_2layer_12_epochs_reduced_lr.model.pkl'
save_path = f'/kaggle/input/{file_name}'
# map_location='cpu' lets a checkpoint saved from GPU tensors deserialize even when
# no CUDA device is available. load_state_dict then copies the values into the
# model's existing parameters, whichever device those live on.
model.load_state_dict(torch.load(save_path, map_location='cpu'))

Let's predict labels for the test set and write the submission file.

In [None]:
# Run inference on the test split with per-tier thresholds and label-count bounds.
predictions_array = trainer.make_predictions(
    tag='test',
    thre=(0.1, 0.1, 0.1, 0.1),
    upper_bound=(3, 4, 10, 10),
    lower_bound=3,
)

In [None]:
# Fill the sample submission with the predicted attribute ids.
# NOTE(review): assumes the rows of predictions_array are in the same order as the
# rows of sample_submission.csv (the test dataset is built from that file) — confirm.
submission = pd.read_csv(test_csv_path)

# Each prediction row is treated as an indicator vector: the positions of its
# non-zero entries are the attribute ids, written space-separated.
attribute_ids = [' '.join(str(x) for x in np.where(one_hot)[0]) for one_hot in predictions_array]
if len(attribute_ids) != len(submission):
    raise ValueError(
        f'Got {len(attribute_ids)} predictions for {len(submission)} submission rows.')

# Assign the whole column at once. Writing through `submission.iloc[i]` is chained
# assignment and is not guaranteed to modify `submission` (it never does under
# pandas Copy-on-Write).
submission['attribute_ids'] = attribute_ids

submission.head()

In [None]:
# Write the submission file without the DataFrame index column.
submission.to_csv(path_or_buf='submission.csv', index=False)

## Acknowledgements
<br>If you find this notebook helpful, please upvote.
<br>Team members: 
<br>[yunchenyang](https://www.kaggle.com/yunchenyang), email: yunchenyang@hotmail.com; 
<br>[ytisserant](https://www.kaggle.com/ytisserant), email: ytisserant@gmail.com