In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
from fastai.vision.all import *
import sklearn.metrics as skm
from tqdm.notebook import tqdm
import sklearn.feature_extraction.text
from transformers import (BertTokenizer, BertModel,
                          DistilBertTokenizer, DistilBertModel)

In [None]:
from shopee_utils import *

In [None]:
# Data and weight locations: use the Kaggle input layout when it exists,
# otherwise fall back to the author's local machine.
KAGGLE_PATH = Path('../input/shopee-product-matching')
if KAGGLE_PATH.is_dir():
    PATH, model_file = KAGGLE_PATH, '../input/resnet-model/bert_model.pth'
else:
    PATH, model_file = Path('/home/slex/data/shopee'), 'models/bert_model.pth'
# NOTE(review): model_name is not used by the cells shown here; the tokenizer
# and BertModel are loaded from '../input/distilbert-base-indonesian' instead.
model_name = 'cahya/bert-base-indonesian-522M'

In [None]:
# Number of best-scoring neighbour pairs kept at inference, expressed per test row:
# the pairs cell keeps sorted_pairs(...)[:int(len(test_df) * TAKE_PAIRS)].
TAKE_PAIRS = 5.1

In [None]:
# Training rows plus the split column produced by shopee_utils.add_splits
# (presumably `is_valid`, which ColSplitter reads by default — TODO confirm).
train_df = add_splits(pd.read_csv(PATH / 'train.csv'))

In [None]:
# Rows to produce `matches` for in the submission.
test_df = pd.read_csv(PATH.joinpath('test.csv'))

In [None]:
class TitleTransform(Transform):
    """Tokenise a DataFrame row's `title` into fixed-length BERT input tensors.

    Args:
        tokenizer_path: directory or hub id passed to `BertTokenizer.from_pretrained`.
            Defaults to the Kaggle dataset path this notebook was written against;
            override it when running outside Kaggle (the local PATH fallback above
            does not change the tokenizer location by itself).
        max_length: every title is padded/truncated to exactly this many tokens.
    """
    # Kept in this order because BertModel.forward takes
    # (input_ids, attention_mask, token_type_ids) positionally.
    input_keys = ('input_ids', 'attention_mask', 'token_type_ids')

    def __init__(self, tokenizer_path='../input/distilbert-base-indonesian', max_length=50):
        super().__init__()
        self.tokenizer = BertTokenizer.from_pretrained(tokenizer_path)
        self.max_length = max_length

    def encodes(self, row):
        """Return `(input_ids, attention_mask, token_type_ids)` for `row.title`.

        Each element is a 1-D tensor of length `self.max_length`.
        """
        encodings = self.tokenizer(row.title, padding='max_length', max_length=self.max_length,
                                   truncation=True, return_tensors='pt')
        # return_tensors='pt' adds a leading batch dim of size 1; drop only that
        # dim so no other size-1 axis (e.g. max_length=1) is ever collapsed.
        return tuple(encodings[key].squeeze(0) for key in self.input_keys)

In [None]:
class EmbsModel(nn.Module):
    """Wrap a HuggingFace BERT-style model so it returns one vector per sequence.

    `forward` receives a tuple of positional model inputs (presumably the batched
    `(input_ids, attention_mask, token_type_ids)` produced via `TitleTransform`)
    and returns the final hidden state at position 0 of every sequence.
    """

    def __init__(self, bert_model):
        super().__init__()
        self.bert_model = bert_model

    def forward(self, x):
        hidden = self.bert_model(*x).last_hidden_state  # (batch, seq_len, hidden_size)
        # Position 0 holds the [CLS] token when inputs come from a BERT tokenizer.
        first_token = hidden[:, 0, :]
        return first_token

In [None]:
# Inputs: tokenised titles. Targets: label_group categories, with the vocab taken
# from every training row. The train/valid split comes from ColSplitter's column.
tfm = TitleTransform()
title_block = TransformBlock(type_tfms=tfm)
label_block = CategoryBlock(vocab=train_df.label_group.to_list())

data_block = DataBlock(blocks=(title_block, label_block),
                       splitter=ColSplitter(),
                       get_y=ColReader('label_group'))
dls = data_block.dataloaders(train_df, bs=128)


In [None]:
# Load the fine-tuned weights into a fresh BERT backbone.
# map_location='cpu': without it a GPU-saved checkpoint is first materialised on
# its original GPU (wasting memory, or failing if that device is absent) before
# being copied into the CPU-resident params; the next cell moves the model to CUDA.
state_d = torch.load(model_file, map_location='cpu')

model = BertModel.from_pretrained('../input/distilbert-base-indonesian')

model.load_state_dict(state_d)

In [None]:
# Wrap the backbone so forward() yields one embedding per title, then put it on
# the GPU in eval mode (disables dropout) for inference.
model = EmbsModel(bert_model=model)
model = model.cuda().eval()

## Verify on validation set (disabled — uncomment the next cell to run it)

In [None]:
# valid_embs, _ = embs_from_model(model, dls.valid)

# dists, inds = get_nearest(valid_embs, do_chunk(valid_embs))

# valid_df=train_df[train_df.is_valid==True].copy().reset_index()
# valid_df = add_target_groups(valid_df)

# pairs = sorted_pairs(dists, inds)[:len(valid_df)*10]

# _=build_from_pairs(pairs, valid_df.target.to_list())

## Run test set inference

In [None]:
# fake_test_df = train_df[['posting_id', 'image', 'image_phash', 'title', 'label_group']].copy()
# fake_test_df = pd.concat([fake_test_df, fake_test_df])
# fake_test_df = add_target_groups(fake_test_df)
# test_df = fake_test_df

In [None]:
# Unlabelled test loader built from the same DataBlock pipeline (TitleTransform).
test_dl = dls.test_dl(test_df)

In [None]:
# Embed every test title; embs_from_model returns two values and only the first is used here.
test_embs, _ = embs_from_model(model, test_dl)

In [None]:
dists, inds = get_nearest(test_embs, do_chunk(test_embs))

# Keep only the best int(len(test_df) * TAKE_PAIRS) candidate pairs.
n_pairs = int(len(test_df) * TAKE_PAIRS)
pairs = sorted_pairs(dists, inds)[:n_pairs]

# When ground-truth groups are available (e.g. the commented-out fake test set),
# evaluate the selected pairs with build_from_pairs.
if 'target' in test_df.columns:
    _ = build_from_pairs(pairs, test_df.target.to_list())

In [None]:
# Collect, for every test row index x, the row indices y it was paired with.
# Each group is seeded with its own index: per the Shopee data description, posts
# always self-match. Seeding also guarantees no row ends up with an empty match
# string, whatever sorted_pairs does with self-pairs. Repeats are skipped so a
# (x, x) pair or a duplicated pair does not emit the same posting_id twice.
groups = [[i] for i in range(len(test_df))]
for x, y, _score in pairs:
    if y not in groups[x]:
        groups[x].append(y)

In [None]:
# Turn each positional row-index group into the space-separated posting_id string
# the submission expects, then write the two submission columns.
posting_ids = test_df.posting_id.to_list()
matches = [' '.join([posting_ids[i] for i in group]) for group in groups]
test_df['matches'] = matches

test_df[['posting_id', 'matches']].to_csv('submission.csv', index=False)

In [None]:
# Read the written file back to eyeball the submission.
pd.read_csv('./submission.csv')