In [1]:
import os
from datetime import datetime
from typing import Dict, Tuple, Any
from tqdm import tqdm
import pickle

import math
import numpy as np
import pandas as pd

import cv2
import albumentations
from torch.utils.data import Dataset

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.parameter import Parameter
from torch.autograd import Variable
from torch.cuda.amp import autocast, GradScaler
from torch.utils.data import DataLoader, Dataset
from torch.optim.lr_scheduler import OneCycleLR, ReduceLROnPlateau

import timm

In [2]:
# =======================================
# Run configuration

# -- filesystem layout --
DATA_DIR = '../input/'
MODEL_DIR = './model_checkpoints/'
LOG_DIR = './logs/'

# -- hardware / backbone --
DEVICE = 'cuda:0'
MODEL_NAME = 'rexnet_200'

# -- run bookkeeping (how these are consumed is not visible in this part of the notebook) --
FOLD = 0
TRAIN_STEP = 0

# -- data pipeline --
IMAGE_SIZE = 256
BATCH_SIZE = 64
NUM_WORKERS = 4

# -- optimisation --
NUM_EPOCHS = 10
LR = 1e-4
USE_AMP = True  # presumably toggles mixed-precision training -- confirm at the use site


In [3]:
# Read the saved training checkpoint (a dict holding at least 'model_state_dict'
# and 'optimizer_state_dict').
# map_location='cpu' makes loading independent of the GPU the checkpoint was
# saved from and avoids holding a second copy of the weights on the GPU; the
# tensors are copied to the right device when the state dicts are loaded.
# NOTE: torch.load unpickles the file -- only use it on checkpoints you trust.
load = torch.load(
    './model_checkpoints/effnetb3_600_fold1_epoch1.pth',
    map_location='cpu',
)

# Keys that start with 'module.' (presumably because the checkpoint was saved
# from an nn.DataParallel-wrapped model) get that prefix removed so they match
# the parameter names of the bare model; all other keys are kept unchanged.
model_only_weight = {
    (k[len('module.'):] if k.startswith('module.') else k): v
    for k, v in load['model_state_dict'].items()
}

In [4]:
# Build the network on the GPU, restore the checkpoint weights, then wrap the
# model in nn.DataParallel.
# NOTE: EffnetB3_Landmark is defined in a later cell of this notebook; that cell
# has to be executed before this one or this line raises NameError.
# out_dim=81313 is the size of the classification head -- presumably the number
# of landmark classes; confirm against the training labels.
# load_pretrained=False: load_state_dict() below is strict, so every weight is
# replaced by the checkpoint and a pretrained initialisation would be discarded.
model = EffnetB3_Landmark(out_dim=81313, load_pretrained=False).cuda()
model.load_state_dict(model_only_weight)
model = nn.DataParallel(model)

In [6]:
# Adam over every parameter of `model`, created with a learning rate of 1e-3.
# NOTE(review): this literal differs from the LR constant in the parameters
# cell (1e-4) -- confirm which value is intended.
optimizer = torch.optim.Adam(
    model.parameters(),
    lr=1e-3,
)

In [7]:
# Restore the optimizer from the 'optimizer_state_dict' entry of the checkpoint
# dict `load`.
# NOTE(review): this presumably also replaces the lr passed when the optimizer
# was constructed with the checkpointed one -- confirm that is intended.
optimizer.load_state_dict(load['optimizer_state_dict'])

In [2]:
class Swish(torch.autograd.Function):
    """Swish activation ``x * sigmoid(x)`` with a hand-written backward pass.

    Only the input tensor is saved for the backward pass; the sigmoid is
    recomputed there instead of being kept from the forward pass.
    Use it through ``Swish.apply(tensor)``.
    """

    @staticmethod
    def forward(ctx, i):
        """Return ``i * sigmoid(i)`` and stash ``i`` for ``backward``."""
        ctx.save_for_backward(i)
        return i * torch.sigmoid(i)

    @staticmethod
    def backward(ctx, grad_output):
        """Return ``grad_output`` times d/di [i * sigmoid(i)].

        The derivative is ``sigmoid(i) * (1 + i * (1 - sigmoid(i)))``.
        """
        # `saved_tensors` is the supported accessor; `saved_variables` is the
        # deprecated legacy spelling.
        (i,) = ctx.saved_tensors
        sigmoid_i = torch.sigmoid(i)
        return grad_output * (sigmoid_i * (1 + i * (1 - sigmoid_i)))


class Swish_module(nn.Module):
    """``nn.Module`` wrapper around the custom ``Swish`` autograd function."""

    @autocast()
    def forward(self, x):
        """Return ``Swish.apply(x)``; the call runs inside an autocast context."""
        activated = Swish.apply(x)
        return activated


class DenseCrossEntropy(nn.Module):
    """Cross-entropy between logits and dense target rows (e.g. one-hot vectors)."""

    @autocast()
    def forward(self, x, target):
        """Return the batch mean of ``sum(-log_softmax(x) * target)`` over the last dim.

        Both ``x`` and ``target`` are cast to float32 before the computation.
        """
        log_probs = torch.nn.functional.log_softmax(x.float(), dim=-1)
        per_sample = (-log_probs * target.float()).sum(-1)
        return per_sample.mean()


class ArcMarginProduct_subcenter(nn.Module):
    """Cosine-similarity head that keeps ``k`` weight rows (sub-centers) per class.

    Args:
        in_features: size of the incoming feature vectors.
        out_features: number of output classes.
        k: number of sub-centers per class.
    """

    def __init__(self, in_features, out_features, k=3):
        super().__init__()
        self.k = k
        self.out_features = out_features
        # One row per (class, sub-center) pair; values are filled in by
        # reset_parameters() right after allocation.
        self.weight = nn.Parameter(torch.FloatTensor(out_features * k, in_features))
        self.reset_parameters()

    def reset_parameters(self):
        """Fill the weight uniformly in ``[-1/sqrt(in_features), 1/sqrt(in_features)]``."""
        bound = 1. / math.sqrt(self.weight.size(1))
        self.weight.data.uniform_(-bound, bound)

    @autocast()
    def forward(self, features):
        """Return, per sample and class, the largest cosine over that class's sub-centers."""
        unit_features = F.normalize(features)
        unit_centers = F.normalize(self.weight)
        # Consecutive groups of k weight rows belong to the same class.
        per_center = F.linear(unit_features, unit_centers).view(-1, self.out_features, self.k)
        best_cosine = per_center.max(dim=2)[0]
        return best_cosine


class ArcFaceLossAdaptiveMargin(nn.modules.Module):
    """ArcFace loss in which every class has its own angular margin.

    Args:
        margins: per-class margins in radians, indexable by a numpy integer
            array of class ids (e.g. a 1-D numpy array).
        s: factor the margin-adjusted cosines are multiplied by before the
            cross-entropy.
    """

    def __init__(self, margins, s=30.0):
        super().__init__()
        self.crit = DenseCrossEntropy()
        self.s = s
        self.margins = margins

    @autocast()
    def forward(self, logits, labels, out_dim):
        """Compute the loss for a batch.

        Args:
            logits: ``(batch, out_dim)`` cosine similarities.
            labels: ``(batch,)`` integer class ids.
            out_dim: number of classes, used for the one-hot encoding.

        Returns:
            Scalar loss tensor.
        """
        # Follow the device of the logits rather than assuming the default GPU.
        device = logits.device
        ms = self.margins[labels.detach().cpu().numpy()]

        def _column(values):
            # per-sample numpy values -> float32 column vector on `device`
            return torch.from_numpy(values).float().to(device).view(-1, 1)

        cos_m = _column(np.cos(ms))
        sin_m = _column(np.sin(ms))
        th = _column(np.cos(math.pi - ms))
        mm = _column(np.sin(math.pi - ms) * ms)

        labels = F.one_hot(labels, out_dim).float()
        cosine = logits.float()
        # Rounding can push |cosine| marginally above 1; clamp the radicand so
        # the square root cannot turn into NaN.
        sine = torch.sqrt((1.0 - torch.pow(cosine, 2)).clamp(min=0.0))
        # cos(theta + m) through the angle-addition formula
        phi = cosine * cos_m - sine * sin_m
        # Where cosine <= th the fallback `cosine - mm` replaces phi.
        phi = torch.where(cosine > th, phi, cosine - mm)
        # The margin applies to the target class only; other classes keep their cosine.
        output = (labels * phi) + ((1.0 - labels) * cosine)
        output = output * self.s
        return self.crit(output, labels)


def gem(x, p=3, eps=1e-6):
    """Generalized-mean pooling over the last two dimensions of ``x``.

    Values are clamped to at least ``eps``, raised to the power ``p``, averaged
    with a pooling window that spans the whole of the last two dimensions, and
    the ``p``-th root of that average is returned (last two dims become 1 x 1).
    """
    full_window = (x.size(-2), x.size(-1))
    powered = x.clamp(min=eps).pow(p)
    pooled = F.avg_pool2d(powered, full_window)
    return pooled.pow(1. / p)


class GeM(nn.Module):
    """Generalized-mean pooling layer built on the ``gem`` function.

    Args:
        p: pooling exponent.
        eps: lower clamp applied to the input before exponentiation.
        p_trainable: when True ``p`` is stored as a one-element ``Parameter``;
            otherwise it is kept as the plain value that was passed in.
    """

    def __init__(self, p=3, eps=1e-6, p_trainable=True):
        super(GeM, self).__init__()
        if p_trainable:
            self.p = Parameter(torch.ones(1) * p)
        else:
            self.p = p
        self.eps = eps

    def forward(self, x):
        """Pool ``x`` with the current exponent and clamp value."""
        return gem(x, p=self.p, eps=self.eps)

    def __repr__(self):
        # `p` is a tensor only when it is trainable; a plain number has no
        # `.data`, so both representations are handled here.
        if isinstance(self.p, torch.Tensor):
            p_value = self.p.data.tolist()[0]
        else:
            p_value = float(self.p)
        return self.__class__.__name__ + '(' + 'p=' + '{:.4f}'.format(p_value) + ', ' + 'eps=' + str(self.eps) + ')'


class EffnetB3_Landmark(nn.Module):
    """timm ``tf_efficientnet_b3_ns`` backbone with GeM pooling and a sub-center cosine head.

    Args:
        out_dim: number of classes produced by the cosine head.
        load_pretrained: forwarded to ``timm.create_model(pretrained=...)``.
            Pass False when the weights are restored from a checkpoint
            straight after construction.
    """

    def __init__(self, out_dim, load_pretrained=True):
        super().__init__()

        # Honour the caller's flag (it used to be ignored and pretrained
        # weights were always requested).
        self.backbone = timm.create_model('tf_efficientnet_b3_ns', pretrained=load_pretrained)
        # Neck: backbone features -> 512-d embedding
        self.feat = nn.Sequential(
            nn.Linear(self.backbone.num_features, 512, bias=True),
            nn.BatchNorm1d(512),
            Swish_module()
        )
        # Replace the backbone's own pooling and classifier so that it returns
        # GeM-pooled features instead of class scores.
        self.backbone.global_pool = GeM()
        self.backbone.classifier = nn.Identity()

        self.metric_classify = ArcMarginProduct_subcenter(512, out_dim)

    def extract(self, x):
        """Run the backbone and index ``[:, :, 0, 0]`` on its output.

        Assumes the backbone returns a 4-D ``(N, C, 1, 1)`` tensor after GeM
        pooling -- TODO confirm for the installed timm version.
        """
        return self.backbone(x)[:, :, 0, 0]

    @autocast()
    def forward(self, x):
        """Return the head's per-class cosine scores for the input batch ``x``."""
        x = self.extract(x)
        logits_m = self.metric_classify(self.feat(x))
        return logits_m