# Training
- Fine-tune a pretrained model from `HuggingFace` using `fastai` and `blurr`
- Our 1st place solution used **ensembling of many models** by taking the **average of the prediction probabilities for each word** and **the entire dataset was used for training with no validation**

## Steps:
1. [Preprocessing](https://www.kaggle.com/nguyncaoduy/1-place-scl-ds-2021-voidandtwotsts-preprocess)
2. [Training](https://www.kaggle.com/nguyncaoduy/1-place-scl-ds-2021-voidandtwotsts-train) - This Notebook
3. [Ensembling](https://www.kaggle.com/nguyncaoduy/1-place-scl-ds-2021-voidandtwotsts-ensemble)

In [None]:
!pip install ohmeow-blurr==0.0.22 datasets==1.3.0 fsspec==0.8.5 -qq

In [None]:
# Change these two lines to use a different pretrained model.
# `pretrained_model_name` is the checkpoint passed to `AutoConfig.from_pretrained`
# and `get_hf_objects` below; `model_name` is the stem of the exported `.pkl` file.
pretrained_model_name = 'xlm-roberta-large'
model_name = 'xlm-roberta-large'

In [None]:
# Turn off tokenizer multithreading to avoid deadlocks.
# The variable is set before the `transformers` import in the next cell.
import os
os.environ["TOKENIZERS_PARALLELISM"] = "false"

In [None]:
from transformers import *
from fastai.text.all import *

from blurr.data.all import *
from blurr.modeling.all import *

# Single seed shared by `set_seed` here and by the DataBlock's `RandomSplitter` below.
SEED = 42
# NOTE(review): the second positional argument is presumably fastai's
# `reproducible` flag — confirm against the installed fastai version.
set_seed(SEED, True)

In [None]:
import json

# `wordlist` is the lookup that `reconstruct` uses to replace a token
# labelled as abbreviated with `wordlist[token]`.
# The encoding is pinned to UTF-8 so the JSON file is decoded the same way
# regardless of the platform's locale default.
with open('../input/sclds2021preprocess/wordlist.json', 'r', encoding='utf-8') as f:
    wordlist = json.load(f)

In [None]:
import ast
# The 'tokens' and 'labels' columns are stored as text in the CSV files;
# `ast.literal_eval` parses each cell back into a Python literal on load.
df_converters = {'tokens': ast.literal_eval, 'labels': ast.literal_eval}

# Both frames come from the preprocessing notebook's output directory.
train_df = pd.read_csv('../input/sclds2021preprocess/train.csv', converters=df_converters)
valid_df = pd.read_csv('../input/sclds2021preprocess/valid.csv', converters=df_converters)

In [None]:
# Row counts of the training and validation frames (shown as the cell output).
len(train_df), len(valid_df)

In [None]:
# Sorted list of every distinct label that occurs anywhere in the training rows;
# its length sets `config.num_labels` and it is the vocab of the target block.
labels = sorted({tag for row_tags in train_df.labels.tolist() for tag in row_tags})
print(labels)

In [None]:
# Token-level classification: one output class per distinct label found above.
task = HF_TASKS_AUTO.TokenClassification
config = AutoConfig.from_pretrained(pretrained_model_name)
config.num_labels = len(labels)

# The helper's result is unpacked as (architecture, config, tokenizer, model);
# these four objects are reused by the transform, `get_y` and the model wrapper below.
hf_arch, hf_config, hf_tokenizer, hf_model = BLURR_MODEL_HELPER.get_hf_objects(pretrained_model_name, 
                                                                               task=task, 
                                                                               config=config)
# Displayed as the cell output for a quick sanity check of what was loaded.
hf_arch, type(hf_config), type(hf_tokenizer), type(hf_model)

In [None]:
# Before-batch transform built from the HF objects above.
# `is_split_into_words=True`: the inputs are read from the 'tokens' column,
# i.e. each item is already a sequence of words rather than one raw string.
# NOTE(review): `return_special_tokens_mask` is forwarded as a tokenizer kwarg —
# presumably blurr needs the mask for token classification; confirm.
before_batch_tfm = HF_TokenClassBeforeBatchTransform(hf_arch, hf_config, hf_tokenizer, hf_model,
                                                     is_split_into_words=True, 
                                                     tok_kwargs={'return_special_tokens_mask': True})

# (input block, target block) pair handed to the DataBlock; the target vocab is `labels`.
blocks = (
    HF_TextBlock(before_batch_tfm=before_batch_tfm, input_return_type=HF_TokenClassInput), 
    HF_TokenCategoryBlock(vocab=labels)
)

def get_y(inp):
    """Build the target for one row: a `(label, n_subtokens)` pair per word.

    `inp.tokens` and `inp.labels` are walked in lockstep; `n_subtokens` is the
    number of pieces `hf_tokenizer.tokenize` produces for the word's string form.
    """
    targets = []
    for word, tag in zip(inp.tokens, inp.labels):
        n_subtokens = len(hf_tokenizer.tokenize(str(word)))
        targets.append((tag, n_subtokens))
    return targets

In [None]:
# Data pipeline: inputs come from the 'tokens' column, targets from `get_y`,
# and a random 10% of the rows (seeded with SEED) is held out for validation.
db = DataBlock(
    blocks=blocks, 
    splitter=RandomSplitter(valid_pct=0.1, seed=SEED),
    get_x=ColReader('tokens'),
    get_y=get_y,
)

In [None]:
# Build the dataloaders from `train_df` with a batch size of 32, then preview a batch.
dls = db.dataloaders(train_df, bs=32)
dls.show_batch(dataloaders=dls)

In [None]:
@delegates()
class TokenCrossEntropyLossFlat(BaseLoss):
    "Same as `CrossEntropyLossFlat`, but for multiple tokens output"
    y_int = True

    @use_kwargs_dict(keep=True, weight=None, ignore_index=-100, reduction='mean')
    def __init__(self, *args, axis=-1, **kwargs):
        # Delegate to BaseLoss with `nn.CrossEntropyLoss` as the wrapped loss class.
        super().__init__(nn.CrossEntropyLoss, *args, axis=axis, **kwargs)

    def decodes(self, x):
        # Argmax along the loss axis, applied to each element of `x` separately.
        return L([item_out.argmax(dim=self.axis) for item_out in x])

    def activation(self, x):
        # Softmax along the loss axis, applied to each element of `x` separately.
        return L([F.softmax(item_out, dim=self.axis) for item_out in x])

In [None]:
# Ingredients for the Learner built in the next cell:
# the wrapped HF model, the per-token loss defined above, the optimizer,
# callbacks attached to the Learner (`learn_cbs`) and to the fit call (`fit_cbs`),
# and the parameter-group splitter.
model = HF_BaseModelWrapper(hf_model)
loss_func = TokenCrossEntropyLossFlat()
opt_func = partial(Adam)
learn_cbs = [HF_BaseModelCallback]
fit_cbs = [HF_TokenClassMetricsCallback()]
splitter = hf_splitter

In [None]:
# Assemble the Learner from the pieces above and switch it to fp16 via `to_fp16()`.
learn = Learner(dls, model, loss_func=loss_func, opt_func=opt_func, splitter=splitter, cbs=learn_cbs).to_fp16()

In [None]:
# Unfreeze the learner before fitting, so training is not limited to frozen-state layers.
learn.unfreeze()

In [None]:
# One-cycle training run: first argument 5 (epochs), second 1e-4 (learning rate),
# with the token-classification metrics callback attached for this fit only.
learn.fit_one_cycle(5, 1e-4, moms=(0.8, 0.7, 0.8), cbs=fit_cbs)

In [None]:
# Plot the losses recorded during the fit above.
learn.recorder.plot_loss()

In [None]:
# NOTE(review): `token_classification_report` is not set anywhere in this file —
# presumably populated on the learner by `HF_TokenClassMetricsCallback`; confirm.
print(learn.token_classification_report)

In [None]:
# Export the trained learner to `<model_name>.pkl`.
learn.export(f'{model_name}.pkl')

# Evaluation
- This is only relevant during model selection and testing
- For the final training, the full dataset is used, so the accuracy below doesn't really reflect the power of the model.

In [None]:
@patch
def blurr_predict(self:Learner, items, rm_type_tfms=None):
    """Predict for `items` through a single test dataloader and return one result per item.

    Attached to `Learner` via `@patch`.

    Args:
        items: a DataFrame, a list-like of inputs, or one single input
            (a single input is wrapped into a one-element list below).
        rm_type_tfms: forwarded unchanged to `self.dls.test_dl`.

    Returns:
        A list with one `(decoded_labels, [decoded_pred], [probs])` tuple per item.
    """
    hf_before_batch_tfm = get_blurr_tfm(self.dls.before_batch)
    # True when the transform expects pre-split words and the first element is a
    # string, i.e. `items` is ONE example given as a list of words, not a batch.
    is_split_str = hf_before_batch_tfm.is_split_into_words and isinstance(items[0], str)
    is_df = isinstance(items, pd.DataFrame)
    # Normalise a single example to a one-element batch; DataFrames are passed through.
    if (not is_df and (is_split_str or not is_listy(items))): items = [items]
    # num_workers=0 keeps data loading in the main process.
    dl = self.dls.test_dl(items, rm_type_tfms=rm_type_tfms, num_workers=0)
    # The middle value returned by `get_preds` is discarded.
    with self.no_bar(): probs, _, decoded_preds = self.get_preds(dl=dl, with_input=False, with_decoded=True)
    # Transforms belonging to the targets (everything after the first `n_inp` entries).
    trg_tfms = self.dls.tfms[self.dls.n_inp:]
    outs = []
    probs, decoded_preds = L(probs), L(decoded_preds)
    for i in range(len(items)):
        # Each item's values are wrapped in one-element lists so that the target
        # transforms can index them by transform position.
        item_probs = [probs[i]]
        item_dec_preds = [decoded_preds[i]]
        item_dec_labels = tuplify([tfm.decode(item_dec_preds[tfm_idx]) for tfm_idx, tfm in enumerate(trg_tfms)])
        outs.append((item_dec_labels, item_dec_preds, item_probs))
    return outs

In [None]:
from string import punctuation


def _complete_word(word):
    """Replace the punctuation-stripped core of `word` with its `wordlist` entry, if any."""
    core = word.strip().strip(punctuation)
    if core != '' and core in wordlist:
        word = word.replace(core, wordlist[core])
    return word


def _normalize_bracket(text):
    """Add the missing parenthesis when `text` contains only one side of a pair."""
    if '(' in text and ')' not in text:
        return text + ')'
    if ')' in text and '(' not in text:
        return '(' + text
    return text


def _label_span(res, tag):
    """Return (first, last) token indices whose label contains `tag`, or (-1, -1)."""
    hits = [i for i in range(len(res[0])) if tag in res[1][i]]
    return (hits[0], hits[-1]) if hits else (-1, -1)


def _extract_span(address, tokens, token_labels, start, end):
    """Cut tokens `start..end` out of `address` and expand the ones labelled 'SHORT'."""
    text = address
    # Peel the tokens before the span off the front, one token length at a time.
    for i in range(start):
        text = text[len(tokens[i]):].strip()
    # Peel the tokens after the span off the back. An empty token must be
    # skipped explicitly: `text[:-0]` would be `text[:0]`, i.e. the empty string.
    for i in range(len(tokens) - 1, end, -1):
        if tokens[i]:
            text = text[:-len(tokens[i])]
        text = text.strip()

    # Sanity check: ignoring spaces, what is left must be exactly the span's tokens.
    expected = ''.join(tokens[start:end + 1]).replace(' ', '')
    assert text.replace(' ', '') == expected

    # Walk the span right-to-left so that replacing a token with a longer word
    # never shifts the offsets of the tokens still to be visited.
    last = len(text)
    for i in range(end, start - 1, -1):
        while last > 0 and text[last - 1] == ' ':
            last -= 1
        tok_len = len(tokens[i])
        assert last >= tok_len
        last -= tok_len
        if 'SHORT' in token_labels[i]:
            text = text[:last] + _complete_word(tokens[i]) + text[last + tok_len:]
    return text


def reconstruct(num, pred, raw_tokens, raw_address):
    """Rebuild the 'POI/street' answer string for the first `num` rows.

    For each row, the span from the first to the last token whose label
    contains 'POI' (respectively 'STR') is cut out of the raw address, tokens
    labelled 'SHORT' are expanded through the global `wordlist`, surrounding
    punctuation is stripped and an unbalanced parenthesis is closed.

    Args:
        num: number of rows to process.
        pred: per-row predictions; `pred[i][0]` is the token sequence (only its
            length is used) and `pred[i][1]` holds one label string per token.
        raw_tokens: per-row list of the original token strings.
        raw_address: per-row original address string.

    Returns:
        A list of `num` strings of the form '<poi>/<street>'; a missing part is
        empty, so a row with neither yields '/'.

    Raises:
        AssertionError: if a row's tokens do not line up with its raw address.
    """
    ans = ['/'] * num
    for idx in range(num):
        res = pred[idx]
        parts = []
        # POI first, then street: the same extraction applies to both.
        for tag in ('POI', 'STR'):
            start, end = _label_span(res, tag)
            if start != -1:
                text = _extract_span(raw_address[idx], raw_tokens[idx], res[1], start, end)
            else:
                text = ''
            parts.append(_normalize_bracket(text.strip(punctuation)))
        ans[idx] = parts[0] + '/' + parts[1]

    return ans

In [None]:
def show_diff(df):
    """Print the rows of `df` where 'pred' differs from 'POI/street'.

    Each printed line holds the row position, the row's 'id', the expected
    value, the literal 'vs' and the prediction. Output stops after 50
    mismatches.
    """
    limit = 50
    printed = 0
    pos = 0
    total = len(df)
    while pos < total and printed != limit:
        record = df.iloc[pos]
        if record['POI/street'] != record['pred']:
            printed += 1
            print(pos, record['id'], record['POI/street'], 'vs', record['pred'])
        pos += 1

In [None]:
def calc_acc(df):
    """Return the fraction of rows in `df` whose 'pred' equals 'POI/street'.

    Correct rows are counted through the non-null values of the 'id' column,
    so a matching row with a missing 'id' is not counted.

    Args:
        df: DataFrame with 'id', 'pred' and 'POI/street' columns.

    Returns:
        Number of exactly matching rows divided by `len(df)`.
    """
    # The mask is built from `df` itself, not from the notebook-global
    # `valid_df`, so the function scores whichever frame it is given.
    matches = df['pred'] == df['POI/street']
    return df.loc[matches, 'id'].count() / len(df)

In [None]:
# Plain Python lists, in `valid_df` row order, of each row's token list and
# its original address string; both are consumed by `reconstruct` below.
raw_tokens = list(valid_df['tokens'])
raw_address = list(valid_df['raw_address'])

In [None]:
# Predict labels for every validation row's tokens.
# NOTE(review): `blurr_predict_tokens` is not defined in this file — presumably
# supplied by blurr and routed through the patched `blurr_predict` above; confirm.
raw_pred = learn.blurr_predict_tokens(raw_tokens)

In [None]:
# Turn the per-token predictions into one 'POI/street' string per validation row.
pred = reconstruct(len(valid_df), raw_pred, raw_tokens, raw_address)

In [None]:
# Store the predictions next to the ground truth so `calc_acc` and `show_diff`
# can compare the 'pred' and 'POI/street' columns.
valid_df['pred'] = pred
valid_df.head()

In [None]:
# Final evaluation with the same metric used for the competition:
# the share of rows whose 'pred' string equals 'POI/street' exactly.
calc_acc(valid_df)

In [None]:
# List the mismatching validation rows (at most 50) for manual inspection.
show_diff(valid_df)