In [35]:
# Enable IPython's autoreload extension in mode 2, so imported modules are
# reloaded before each cell executes and edits to local .py files take effect
# without restarting the kernel.
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [36]:
from fastai.vision.all import *
import sklearn.metrics as skm
from tqdm.notebook import tqdm
import sklearn.feature_extraction.text
from transformers import (BertTokenizer, BertModel,
                          DistilBertTokenizer, DistilBertModel)

In [37]:
from shopee_utils import *
from train_utils import *

In [38]:
# Dataset root and checkpoint path: use the ../input layout when that directory
# exists, otherwise fall back to the local development locations.
PATH = Path('../input/shopee-product-matching')
model_file = '../input/resnet-model/bert814.pth'
if not PATH.is_dir():
    # NOTE(review): machine-specific absolute path; adjust when running elsewhere.
    PATH = Path('/home/slex/data/shopee')
    model_file = 'models/bert814.pth'

# Identifier handed to `from_pretrained` for both the tokenizer and the text model.
# A local './bert_indonesian' value used to be assigned here first, but it was
# overwritten on the very next line and never took effect, so that dead
# assignment has been dropped.
BERT_PATH = 'cahya/distilbert-base-indonesian'

In [39]:
# Load the training table and mark rows whose `split` value is 0 as validation
# rows; `is_valid` is the column name ColSplitter reads by default.
train_df = pd.read_csv(PATH/'train_split.csv').assign(
    is_valid=lambda frame: frame['split'].eq(0)
)

In [40]:
def get_img_file(row):
    """Return the image path for `row`, preferring train_images over test_images.

    `row.image` is joined onto `PATH/'train_images'`; when no file exists there,
    the `PATH/'test_images'` location is returned instead (its existence is not
    checked).
    """
    train_candidate = PATH/'train_images'/row.image
    if train_candidate.is_file():
        return train_candidate
    return PATH/'test_images'/row.image

In [41]:
class TitleTransform(Transform):
    """Turn a row's `title` into fixed-length DistilBERT token tensors."""

    def __init__(self):
        super().__init__()
        # The tokenizer is loaded from the module-level BERT_PATH.
        self.tokenizer = DistilBertTokenizer.from_pretrained(BERT_PATH)

    # NOTE: `encodes` is deliberately left without type annotations; fastai
    # Transforms dispatch on them, so adding any could change which items it handles.
    def encodes(self, row):
        """Tokenize `row.title`, padded/truncated to 50 tokens.

        Returns the tuple `(input_ids, attention_mask)`, each with its size-1
        dimensions squeezed away.
        """
        tokenized = self.tokenizer(
            row.title,
            max_length=50,
            padding='max_length',
            truncation=True,
            return_tensors='pt',
        )
        input_ids = tokenized['input_ids'].squeeze()
        attention_mask = tokenized['attention_mask'].squeeze()
        return (input_ids, attention_mask)

In [42]:
text_tfm = TitleTransform()

# Two inputs (image, tokenized title) and one target (label_group category).
# Train/validation membership comes from the `is_valid` column via ColSplitter.
data_block = DataBlock(
    blocks=(
        ImageBlock(),
        TransformBlock(type_tfms=text_tfm),
        CategoryBlock(vocab=train_df.label_group.to_list()),
    ),
    # First getter resolves the image path; the second passes the whole row
    # through unchanged so that the title transform can read `row.title`.
    get_x=[get_img_file, lambda x: x],
    get_y=ColReader('label_group'),
    splitter=ColSplitter(),
    item_tfms=Resize(460),
    batch_tfms=aug_transforms(size=224, min_scale=0.75),
)
dls = data_block.dataloaders(train_df, bs=64, num_workers=16)

In [43]:
# Pull one batch to sanity-check the pipeline: images, tokenized titles, labels.
b_im, b_txt, by = dls.one_batch()

In [44]:
# Display the image batch shape (the recorded output below shows [64, 3, 224, 224]).
b_im.shape

torch.Size([64, 3, 224, 224])

In [45]:
class ArcFaceClassifier(nn.Module):
    """Cosine-similarity head: scores each input row against one learned vector per class.

    Input rows and class weight columns are both L2-normalised before the matrix
    product, so every output entry is a cosine similarity.
    """

    def __init__(self, in_features, output_classes):
        super().__init__()
        # One column per class; the values are filled by the Kaiming-uniform init.
        weight = torch.Tensor(in_features, output_classes)
        nn.init.kaiming_uniform_(weight)
        self.W = nn.Parameter(weight)

    def forward(self, x):
        """Return the (batch, output_classes) cosines between rows of `x` and the class vectors."""
        unit_rows = F.normalize(x, dim=1)
        unit_classes = F.normalize(self.W, dim=0)
        return torch.matmul(unit_rows, unit_classes)

In [61]:
class MultiModalModel(nn.Module):
    """Image + text model that classifies the concatenated stem embeddings.

    The image stem output and the first-token hidden state of the text stem are
    concatenated, passed through BatchNorm1d + Dropout, and then either returned
    as embeddings (`outputEmbs = True`) or fed to the ArcFace classifier.

    Args:
        img_stem: module mapping an image batch to a 2-D feature tensor.
        text_stem: module called as `text_stem(*text_x)` whose result exposes
            `last_hidden_state`.
        emb_dim: width of the concatenated embedding. The default `1024 + 768`
            is the value previously hard-coded here (presumably 1024 image
            features plus 768 text features for the stems built in this notebook).
        n_classes: number of output classes. When omitted, falls back to the
            notebook-level `dls.c`, which is what the original code always used.
    """

    def __init__(self, img_stem, text_stem, emb_dim=1024+768, n_classes=None):
        super().__init__()
        if n_classes is None:
            # Backward-compatible default: read the class count from the global dataloaders.
            n_classes = dls.c
        self.img_stem = img_stem
        self.text_stem = text_stem
        self.regularizers = nn.Sequential(
            nn.BatchNorm1d(emb_dim),
            nn.Dropout()
        )
        self.classifier = ArcFaceClassifier(emb_dim, n_classes)
        # Flip to True to make forward() return embeddings instead of class scores.
        self.outputEmbs = False

    def forward(self, img_x, text_x):
        """Return class scores, or the regularized embeddings when `outputEmbs` is set.

        `text_x` is unpacked positionally into the text stem.
        """
        img_out = self.img_stem(img_x)
        text_out = self.text_stem(*text_x)
        # Keep only the hidden state at sequence position 0 for each item.
        text_out = text_out.last_hidden_state[:, 0, :]
        embs = torch.cat([img_out, text_out], dim=1)
        embs = self.regularizers(embs)
        if self.outputEmbs:
            return embs
        return self.classifier(embs)

In [22]:
def new_model():
    """Build a fresh MultiModalModel and move it to the GPU.

    The image stem is a resnet34 body (cut=-2) followed by AdaptiveConcatPool2d
    and Flatten; the text stem is a DistilBertModel loaded from BERT_PATH.
    """
    image_layers = [create_body(resnet34, cut=-2), AdaptiveConcatPool2d(), Flatten()]
    image_encoder = nn.Sequential(*image_layers)
    text_encoder = DistilBertModel.from_pretrained(BERT_PATH)
    multimodal = MultiModalModel(image_encoder, text_encoder)
    return multimodal.cuda()

In [32]:
def split_2way(model):
    """Split `model`'s parameters into two groups: both stems, then the head.

    Group 0 holds the image and text stem parameters; group 1 holds the
    classifier and regularizer parameters.
    """
    stem_params = params(model.img_stem) + params(model.text_stem)
    head_params = params(model.classifier) + params(model.regularizers)
    return L(stem_params, head_params)

In [59]:
# arcface_loss, F1FromEmbs and FakeMetric are not defined in this notebook;
# presumably they come from the shopee_utils / train_utils star imports.
learn = Learner(
    dls,
    new_model(),
    loss_func=arcface_loss,
    splitter=split_2way,
    cbs=F1FromEmbs,
    metrics=FakeMetric(),
)


In [60]:
# Train frozen for 2 epochs, then unfrozen for 20 epochs, with base learning rate 1e-3.
learn.fine_tune(20, base_lr=1e-3, freeze_epochs=2)

epoch,train_loss,valid_loss,F1 embeddings,time
0,8.07724,,0.754148,01:19
1,4.304684,,0.787597,01:19


epoch,train_loss,valid_loss,F1 embeddings,time
0,2.869407,,0.807468,01:38
1,2.27852,,0.816945,01:40
2,1.53528,,0.826359,01:40
3,0.86095,,0.833761,01:41
4,0.336392,,0.837497,01:39
5,0.128855,,0.841752,01:41
6,0.060897,,0.842654,01:41
7,0.038779,,0.841713,01:38
8,0.023473,,0.842289,01:38
9,0.0139,,0.841026,01:38
