# Ensembling
- Run inference using our fine-tuned models
- Our 1st place solution used **ensembling of many models**, taking the **average of the prediction probabilities for each word**; **the entire dataset was used for training, with no validation split**

## Steps:
1. [Preprocessing](https://www.kaggle.com/nguyncaoduy/1-place-scl-ds-2021-voidandtwotsts-preprocess)
2. [Training](https://www.kaggle.com/nguyncaoduy/1-place-scl-ds-2021-voidandtwotsts-train)
3. [Ensembling](https://www.kaggle.com/nguyncaoduy/1-place-scl-ds-2021-voidandtwotsts-ensemble) - This Notebook

In [None]:
!pip install ohmeow-blurr==0.0.22 datasets==1.3.0 fsspec==0.8.5 -qq

In [None]:
# Turn off tokenizer multithreading to avoid deadlock.
# The variable is set in this cell, before the cell that imports transformers/blurr runs.
import os
os.environ["TOKENIZERS_PARALLELISM"] = "false"

In [None]:
from transformers import *
from fastai.text.all import *

from blurr.data.all import *
from blurr.modeling.all import *

# Fixed seed for this session.
# NOTE(review): the second positional argument is presumably fastai's `reproducible` flag - confirm.
SEED = 42
set_seed(SEED, True)

In [None]:
import json

# Lookup table used during reconstruction to expand tokens whose tag contains SHORT
with open('../input/sclds2021preprocess/wordlist.json', 'r') as wordlist_file:
    wordlist = json.loads(wordlist_file.read())

In [None]:
import ast
# These columns are parsed back from their literal text form while the CSV is read
df_converters = {column: ast.literal_eval for column in ('tokens', 'labels')}

valid_df = pd.read_csv(
    '../input/sclds2021preprocess/valid.csv',
    converters=df_converters,
)

In [None]:
# Display the number of validation rows
len(valid_df)

In [None]:
# Directory of exported learners; every entry returned by `model_path.ls()` is loaded below
model_path = Path('../input/sclds2021train')

In [None]:
# Re-define certain things for 'load_learner' to work

@delegates()
class TokenCrossEntropyLossFlat(BaseLoss):
    "Same as `CrossEntropyLossFlat`, but for multiple tokens output"
    # NOTE(review): presumably marks the targets as integer class ids, as in `CrossEntropyLossFlat` - confirm
    y_int = True
    # Wraps `nn.CrossEntropyLoss`; the decorator lists the keyword defaults (ignore_index=-100, reduction='mean', ...)
    @use_kwargs_dict(keep=True, weight=None, ignore_index=-100, reduction='mean')
    def __init__(self, *args, axis=-1, **kwargs): super().__init__(nn.CrossEntropyLoss, *args, axis=axis, **kwargs)
    # Argmax along `self.axis` for every element of `x`, collected into an `L`
    def decodes(self, x):    return L([ i.argmax(dim=self.axis) for i in x ])
    # Softmax along `self.axis` for every element of `x`, collected into an `L`
    def activation(self, x): return L([ F.softmax(i, dim=self.axis) for i in x ])

def get_y(inp):
    "Pair every label of `inp` with the number of pieces `hf_tokenizer` splits its token into"
    # `hf_tokenizer` is looked up as a global when the function is called
    targets = []
    for entity, label in zip(inp.tokens, inp.labels):
        n_pieces = len(hf_tokenizer.tokenize(str(entity)))
        targets.append((label, n_pieces))
    return targets

# Evaluation
- This is only relevant during model selection and testing
- For the final training the full dataset was used, so the accuracy below is computed on data the models were trained on and does not reflect their real generalization performance.

In [None]:
@patch
def blurr_predict(self:Learner, items, rm_type_tfms=None):
    # Patched onto `Learner`: run prediction on `items` and return one
    # `(decoded_labels, decoded_preds, probs)` tuple per item.
    hf_before_batch_tfm = get_blurr_tfm(self.dls.before_batch)
    # True when the transform works on pre-split words and the first element is a plain string
    # (presumably meaning `items` is one example given as a list of words - confirm)
    is_split_str = hf_before_batch_tfm.is_split_into_words and isinstance(items[0], str)
    is_df = isinstance(items, pd.DataFrame)
    # Wrap a single example in a list so the code below always iterates over a batch
    if (not is_df and (is_split_str or not is_listy(items))): items = [items]
    # num_workers=0: the test DataLoader is built without worker processes
    dl = self.dls.test_dl(items, rm_type_tfms=rm_type_tfms, num_workers=0)
    # Progress bar suppressed; the second returned value is discarded
    with self.no_bar(): probs, _, decoded_preds = self.get_preds(dl=dl, with_input=False, with_decoded=True)
    # Transforms of the target side: everything after the first `n_inp` entries
    trg_tfms = self.dls.tfms[self.dls.n_inp:]
    outs = []
    probs, decoded_preds = L(probs), L(decoded_preds)
    for i in range(len(items)):
        # Each per-item value is wrapped in a one-element list
        item_probs = [probs[i]]
        item_dec_preds = [decoded_preds[i]]
        # Decode the predictions with the target transforms
        item_dec_labels = tuplify([tfm.decode(item_dec_preds[tfm_idx]) for tfm_idx, tfm in enumerate(trg_tfms)])
        outs.append((item_dec_labels, item_dec_preds, item_probs))
    return outs

In [None]:
from string import punctuation

def _complete_word(token):
    "Replace the punctuation-stripped core of `token` with its `wordlist` entry, if it has one"
    core = token.strip().strip(punctuation)
    if core != '' and core in wordlist:
        token = token.replace(core, wordlist[core])
    return token


def _normalize_bracket(text):
    "Add the missing half of an unbalanced `(` / `)` pair at the matching end of `text`"
    if '(' in text and ')' not in text:
        return text + ')'
    if ')' in text and '(' not in text:
        return '(' + text
    return text


def _find_span(labels, n_labels, keyword):
    "Return the first and last index (within the first `n_labels` labels) whose label contains `keyword`, or (-1, -1)"
    first, last = -1, -1
    for i in range(n_labels):
        if keyword in labels[i]:
            if first == -1:
                first = i
            last = i
    return first, last


def _extract_span(address, tokens, labels, first, last):
    "Cut tokens `first..last` out of `address`, expand the SHORT-tagged ones and tidy punctuation/brackets"
    text = address
    # Peel the tokens before the span off the front, and the ones after it off the back
    for i in range(first):
        text = text[len(tokens[i]):].strip()
    for i in range(len(tokens) - 1, last, -1):
        text = text[:-len(tokens[i])].strip()

    # What is left must be exactly the span's tokens, ignoring spaces
    expected = ''.join(tokens[first:last + 1]).replace(' ', '')
    assert text.replace(' ', '') == expected, 'address does not line up with its tokens'

    # Walk the span right-to-left so a replacement never shifts positions still to be visited
    pos = len(text)
    for i in range(last, first - 1, -1):
        while pos > 0 and text[pos - 1] == ' ':
            pos -= 1
        size = len(tokens[i])
        assert pos >= size, 'token is longer than the remaining text'
        pos -= size
        if 'SHORT' in labels[i]:
            text = text[:pos] + _complete_word(tokens[i]) + text[pos + size:]

    return _normalize_bracket(text.strip(punctuation))


def _reconstruct_row(idx, labels, n_labels, raw_tokens, raw_address):
    "Build the 'POI/street' string for row `idx`; a field with no tagged token is left empty"
    fields = []
    for keyword in ('POI', 'STR'):
        first, last = _find_span(labels, n_labels, keyword)
        if first == -1:
            fields.append('')
        else:
            # Row data is only looked up when a span was actually found
            fields.append(_extract_span(raw_address[idx], raw_tokens[idx], labels, first, last))
    return '/'.join(fields)


def reconstruct(num, pred, raw_tokens, raw_address):
    """Turn single-model predictions into 'POI/street' strings.

    num: number of rows to process (rows `0..num-1`).
    pred: per-row prediction; element 1 holds the per-token label strings and
        the length of element 0 decides how many labels are scanned.
    raw_tokens: per-row list of address tokens.
    raw_address: per-row original address string.

    For each of POI and STR, the text between the first and last token whose
    label contains that keyword is cut out of the address; tokens in that
    range whose label contains 'SHORT' are expanded through `wordlist`.
    Returns a list of `num` strings. Raises AssertionError when an address
    does not line up with its tokens.
    """
    return [
        _reconstruct_row(idx, pred[idx][1], len(pred[idx][0]), raw_tokens, raw_address)
        for idx in range(num)
    ]


def reconstruct_ensemble(num, pred, raw_tokens, raw_address):
    """Turn ensembled predictions into 'POI/street' strings.

    Same as `reconstruct`, except that `pred[idx]` is directly the sequence of
    per-token label strings for row `idx`.
    """
    return [
        _reconstruct_row(idx, pred[idx], len(pred[idx]), raw_tokens, raw_address)
        for idx in range(num)
    ]

In [None]:
def calc_acc(df):
    """Return the fraction of rows in `df` whose `pred` equals `POI/street`.

    `df` must have the columns `id`, `pred` and `POI/street`. Both compared
    columns are read from `df` itself, so the result no longer depends on the
    global `valid_df`.
    """
    is_correct = df['pred'] == df['POI/street']
    # `count()` only counts matching rows that have a non-missing `id`
    return df.loc[is_correct, 'id'].count() / len(df)

In [None]:
# Plain Python lists of the validation tokens and original address strings, in row order
raw_tokens = list(valid_df['tokens'])
raw_address = list(valid_df['raw_address'])

In [None]:
# One entry per model: the list of element-3 outputs for every validation row
raw_avg_pred = []

for model in model_path.ls():
    learn = load_learner(model)
    raw_pred = learn.blurr_predict_tokens(raw_tokens)
    # Element 3 of each prediction is kept for averaging across models in the next cell
    # (presumably the per-token class probabilities - confirm)
    raw_avg_pred.append([raw[3] for raw in raw_pred])
    # Score this model on its own; `valid_df['pred']` is overwritten on every iteration
    pred = reconstruct(len(valid_df), raw_pred, raw_tokens, raw_address)
    valid_df['pred'] = pred
    score = calc_acc(valid_df)
    print(f'{model.name} - {score:.5f}')
# NOTE: `learn` still refers to the last loaded model after the loop; the next cell uses its vocab

In [None]:
# zip(*raw_avg_pred) groups the outputs of all models for the same row; average them element-wise
raw_ensemble_pred = [(sum(col))/len(col) for col in zip(*raw_avg_pred)]
# Index of the highest averaged value along the last axis
raw_ensemble_pred = [pred.argmax(-1) for pred in raw_ensemble_pred]
# Map ids to labels with the vocab of `learn` (the last learner loaded in the loop above)
# NOTE(review): assumes every model shares the same vocab ordering - confirm
raw_ensemble_pred = learn.dls.vocab.map_ids(raw_ensemble_pred)
pred = reconstruct_ensemble(len(valid_df), raw_ensemble_pred, raw_tokens, raw_address)
valid_df['pred'] = pred
score = calc_acc(valid_df)
print(f'Ensemble - {score:.5f}')

# Submission

In [None]:
import re

def clean(s):
    """Insert spaces around brackets/punctuation glued between word characters and tidy whitespace.

    Applied in order:
    1. `x(y`  -> `x (y`   (space before an opening bracket between word characters)
    2. `x,y`  -> `x, y`   (space after a run of `),.:;` between word characters)
    3. `x.(y` -> `x. (y`  (space between a dot and an opening bracket)
    4. runs of whitespace collapse to a single space, and the ends are stripped.

    Returns the cleaned string.
    """
    # Raw strings for the templates: `\g` in a normal string literal is an invalid escape sequence
    res = re.sub(r'(\w)(\()(\w)', r'\g<1> \g<2>\g<3>', s)
    res = re.sub(r'(\w)([),.:;]+)(\w)', r'\g<1>\g<2> \g<3>', res)
    res = re.sub(r'(\w)(\.\()(\w)', r'\g<1>. (\g<3>', res)
    res = re.sub(r'\s+', ' ', res)
    return res.strip()

In [None]:
test_df = pd.read_csv('../input/scl-2021-ds/test.csv')
# Trim surrounding whitespace from every address
test_df['raw_address'] = test_df['raw_address'].apply(lambda x: x.strip())
# Tokenise: `clean` inserts the separating spaces, then the result is split on whitespace
test_df['tokens'] = test_df['raw_address'].apply(clean).str.split()
test_df.head()

In [None]:
# Rebind the same names to the test set; the cells below read them
raw_tokens = list(test_df['tokens'])
raw_address = list(test_df['raw_address'])

In [None]:
# One entry per model: the list of element-3 outputs for every test row
raw_avg_pred = []

for model in model_path.ls():
    learn = load_learner(model)
    raw_pred = learn.blurr_predict_tokens(raw_tokens)
    # Element 3 of each prediction is what gets averaged across models in the next cell
    raw_avg_pred.append([raw[3] for raw in raw_pred])
# NOTE: `learn` still refers to the last loaded model after the loop; the next cell uses its vocab

In [None]:
# zip(*raw_avg_pred) groups the outputs of all models for the same row; average them element-wise
raw_ensemble_pred = [(sum(col))/len(col) for col in zip(*raw_avg_pred)]
# Index of the highest averaged value along the last axis
raw_ensemble_pred = [pred.argmax(-1) for pred in raw_ensemble_pred]
# Map ids to labels with the vocab of `learn` (the last learner loaded in the loop above)
raw_ensemble_pred = learn.dls.vocab.map_ids(raw_ensemble_pred)
pred = reconstruct_ensemble(len(test_df), raw_ensemble_pred, raw_tokens, raw_address)
# Store the reconstructed 'POI/street' strings in the submission column
test_df['POI/street'] = pred

In [None]:
# Remove the input/helper columns in place so they are not written to the submission file
test_df.drop(columns=['raw_address', 'tokens'], inplace=True)
test_df.head()

In [None]:
# Write the remaining columns to `submission.csv` without the DataFrame index
test_df.to_csv('submission.csv', index=False)