In [1]:
from fastai.vision.all import *
from transformers import (BertTokenizer, AutoModel)
from train_utils import *

In [2]:
import os

# Dataset directory; set the SHOPEE_DATA environment variable to override it
# without editing the notebook.
PATH = Path(os.environ.get('SHOPEE_DATA', '/mnt/dysk25/data/shopee/'))
# Local directory with the pretrained Indonesian BERT weights and tokenizer.
BERT_PATH = './bert_indonesian'
# Width of the ArcFace head; must cover every `label_group` in train_df (checked below).
OUTPUT_CLASSES = 11014
train_df = add_splits(pd.read_csv(PATH/'train.csv'))
# The category vocab is built from train_df.label_group, and the head has one
# output per class, so more groups than OUTPUT_CLASSES would give out-of-range targets.
assert train_df.label_group.nunique() <= OUTPUT_CLASSES, \
    f'{train_df.label_group.nunique()} label groups but OUTPUT_CLASSES={OUTPUT_CLASSES}'




In [3]:
class ArcFaceClassifier(nn.Module):
    """Cosine-similarity head for ArcFace-style training.

    Incoming features go through BatchNorm1d and Dropout(0.25). The result is
    L2-normalised and multiplied by the column-normalised weight matrix ``W``.
    Each output entry is therefore the cosine between a sample and a class column.

    Args:
        in_features: size of the incoming feature vector.
        output_classes: number of classes (columns of ``W``).
    """
    def __init__(self, in_features, output_classes):
        super().__init__()
        pre_norm = [nn.BatchNorm1d(in_features), nn.Dropout(.25)]
        self.initial_layers = nn.Sequential(*pre_norm)
        # One column per class. Columns are re-normalised in forward, so the
        # scale chosen by the initialiser does not affect the outputs.
        raw_weight = torch.Tensor(in_features, output_classes)
        self.W = nn.Parameter(raw_weight)
        nn.init.kaiming_uniform_(self.W)

    def forward(self, x):
        # assumes x is (batch, in_features) — TODO confirm for other callers
        unit_feats = F.normalize(self.initial_layers(x))
        unit_centers = F.normalize(self.W, dim=0)
        # Both operands are unit-length, so every entry lies in [-1, 1].
        return torch.matmul(unit_feats, unit_centers)
    
class BertArcFace(nn.Module):
    """BERT encoder followed by an ArcFace cosine head.

    The sentence embedding is the last hidden state at position 0. With
    BertTokenizer's default special tokens, that position is the [CLS] token.
    Set ``outputEmbs = True`` to make ``forward`` return these embeddings
    instead of the classifier outputs.
    """
    def __init__(self):
        super().__init__()
        self.bert_model = AutoModel.from_pretrained(BERT_PATH)
        # Take the head's input width from the loaded checkpoint instead of
        # hardcoding 768, so a model with a different hidden size still fits.
        hidden_size = self.bert_model.config.hidden_size
        self.classifier = ArcFaceClassifier(hidden_size, OUTPUT_CLASSES)
        # When True, forward returns embeddings rather than class cosines.
        self.outputEmbs = False

    def forward(self, x):
        # x is unpacked positionally into the encoder; presumably
        # (input_ids, attention_mask) as produced by TitleTransform — confirm.
        output = self.bert_model(*x)
        embeddings = output.last_hidden_state[:, 0, :]
        if self.outputEmbs:
            return embeddings
        return self.classifier(embeddings)


In [4]:
# Taken from https://www.kaggle.com/c/shopee-product-matching/discussion/233605#1278984
def string_escape(s, encoding='utf-8'):
    """Turn literal backslash escapes in ``s`` (e.g. ``'\\xc3\\xa9'``) into real text.

    The escape sequences are first resolved to raw bytes, which are then
    decoded with ``encoding``. This is presumably needed because the titles
    store non-ASCII characters as escaped byte sequences.
    """
    unescaped = s.encode('latin1').decode('unicode-escape')
    # latin1 maps code points 0-255 one-to-one onto bytes, recovering the raw byte string.
    raw_bytes = unescaped.encode('latin1')
    return raw_bytes.decode(encoding)

class TitleTransform(Transform):
    """fastai Transform that tokenizes a DataFrame row's ``title`` for BERT.

    The title is first unescaped with ``string_escape``, then tokenized. The
    result is padded/truncated to ``max_length`` tokens.

    Args:
        max_length: token length of the returned tensors (default 100, as before).

    Returns (from ``encodes``):
        tuple ``(input_ids, attention_mask)`` of 1-D tensors of length ``max_length``.
    """
    def __init__(self, max_length=100):
        super().__init__()
        self.tokenizer = BertTokenizer.from_pretrained(BERT_PATH)
        self.max_length = max_length

    def encodes(self, row):
        text = string_escape(row.title)
        encodings = self.tokenizer(text, padding='max_length', max_length=self.max_length,
                                   truncation=True, return_tensors='pt')
        # return_tensors='pt' yields a batch of one. Drop only that batch dim:
        # a bare squeeze() would also collapse the token dim when max_length == 1.
        return tuple(encodings[key].squeeze(0) for key in ('input_ids', 'attention_mask'))

def get_text_dls(bs=64):
    """Build fastai DataLoaders over ``train_df`` for title -> label_group classification.

    Inputs come from ``TitleTransform``. Targets come from the ``label_group``
    column. The train/valid split is taken from a column via ``ColSplitter``
    (fastai default column: ``is_valid``).

    Args:
        bs: batch size (default 64, as before).
    """
    tfm = TitleTransform()
    # Build the vocab from the whole frame rather than only the training split,
    # so every label group gets a class index. Unique values suffice: the
    # category map dedups and sorts anyway.
    vocab = train_df.label_group.unique().tolist()
    data_block = DataBlock(
        blocks=(TransformBlock(type_tfms=tfm), CategoryBlock(vocab=vocab)),
        splitter=ColSplitter(),
        get_y=ColReader('label_group'),
    )
    return data_block.dataloaders(train_df, bs=bs)

In [5]:
def split_2way(model):
    """fastai splitter returning two parameter groups: [BERT encoder, classifier head].

    With this split, fastai's freeze leaves only the last group (the head)
    trainable. Discriminative learning rates then apply per group.
    """
    groups = (model.bert_model, model.classifier)
    return L(params(group) for group in groups)

In [6]:
# Seed Python/NumPy/torch RNGs before building data and model. This covers
# weight init, dropout and shuffling, so runs are comparable. Full GPU
# determinism would additionally need reproducible=True, which is slower.
set_seed(42)
# F1FromEmbs presumably produces the "F1 embeddings" column in the log below.
learn = Learner(get_text_dls(), BertArcFace(), splitter=split_2way, loss_func=arcface_loss,
                cbs=[F1FromEmbs], metrics=FakeMetric()).to_fp16()

In [7]:
%%time
learn.fine_tune(7, 1e-2)

epoch,train_loss,valid_loss,F1 embeddings,time
0,19.505682,,0.770121,00:54


epoch,train_loss,valid_loss,F1 embeddings,time
0,16.139585,,0.797005,01:12
1,14.134651,,0.810589,01:13
2,11.903905,,0.81672,01:12
3,9.619471,,0.819458,01:13
4,7.618684,,0.820368,01:12
5,6.240385,,0.819968,01:13
6,5.451341,,0.820975,01:12


CPU times: user 9min 14s, sys: 8.11 s, total: 9min 22s
Wall time: 9min 24s
