In [None]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here's several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

In [None]:
import shutil
from pathlib import Path

# Overwrite the installed transformers DeBERTa-v2/v3 tokenizer sources with the
# patched copies shipped in the "deberta-v23-fix" input dataset.
transformers_path = Path("/opt/conda/lib/python3.7/site-packages/transformers")
input_dir = Path("../input/deberta-v23-fix")
deberta_v2_path = transformers_path / "models" / "deberta_v2"

# 1) Package-level slow->fast tokenizer conversion module.
convert_file = input_dir / "convert_slow_tokenizer.py"
conversion_path = transformers_path / convert_file.name
if conversion_path.exists():
    # Remove the stock file first so the copy below never has to overwrite it.
    conversion_path.unlink()
shutil.copy(convert_file, transformers_path)

# 2) The two DeBERTa-v2 tokenizer modules inside models/deberta_v2.
for filename in ['tokenization_deberta_v2.py', 'tokenization_deberta_v2_fast.py']:
    filepath = deberta_v2_path / filename
    if filepath.exists():
        filepath.unlink()
    shutil.copy(input_dir / filename, filepath)

In [None]:
def get_deberta_tokenizer(model_name, use_trim_offsets=False):
    """
    Load the fast DeBERTa-v2 tokenizer for a checkpoint.

    The import is kept inside the function so the tokenizer module is only
    loaded when the function is called, not when this cell is defined.

    Args:
        model_name (str): Name or path handed to `from_pretrained`.
        use_trim_offsets (bool): Forwarded as `trim_offsets`.

    Returns:
        DebertaV2TokenizerFast: The loaded tokenizer.
    """
    from transformers.models.deberta_v2.tokenization_deberta_v2_fast import (
        DebertaV2TokenizerFast as tokenizer_cls,
    )

    tokenizer = tokenizer_cls.from_pretrained(model_name, trim_offsets=use_trim_offsets)
    return tokenizer

In [None]:
import torch
import torch.nn as nn
from transformers import AutoModel, AutoTokenizer, T5EncoderModel, AutoConfig
from tqdm import tqdm
import gc
import itertools

## Model

In [None]:
class ModelConfig:
    """
    Model-related settings extracted from the global argument object.

    Attributes are `pretrain_path` and `device` (both copied from `args`) and
    `dropout`, which is fixed at 0.1 regardless of `args`.
    """

    def __init__(self, args):
        # Dropout is a constant here, not read from args.
        self.dropout = 0.1
        self.device = args.device
        self.pretrain_path = args.pretrain_path


class Swish(nn.Module):
    """
    Swish activation: x * sigmoid(x).

    Args:
        inplace (bool): When True (default) the input tensor itself is
            overwritten with the result and returned; otherwise a new tensor
            is returned and the input is left untouched.
    """

    def __init__(self, inplace=True):
        super().__init__()
        self.inplace = inplace

    def forward(self, x):
        gate = torch.sigmoid(x)
        if not self.inplace:
            return x * gate
        # mul_ mutates x and returns that same tensor object.
        return x.mul_(gate)


class Activation(nn.Module):
    """
    Thin wrapper that selects an activation layer by name.

    Args:
        name (str): One of "swish", "relu" or "gelu".

    Raises:
        ValueError: If `name` is not a supported activation. (The previous
            bare `raise` had no active exception to re-raise and therefore
            surfaced as an unhelpful RuntimeError.)
    """

    SUPPORTED = ("swish", "relu", "gelu")

    def __init__(self, name="swish"):
        super(Activation, self).__init__()
        if name not in self.SUPPORTED:
            raise ValueError(
                f"Unsupported activation {name!r}; expected one of {list(self.SUPPORTED)}"
            )
        if name == "swish":
            self.net = Swish()
        elif name == "relu":
            self.net = nn.ReLU()
        else:  # "gelu" is the only remaining supported name
            self.net = nn.GELU()

    def forward(self, x):
        return self.net(x)


class Dence(nn.Module):
    """
    Fully connected layer followed by a named activation.

    Args:
        i_dim (int): Input feature size of the linear layer.
        o_dim (int): Output feature size of the linear layer.
        activation (str): Activation name forwarded to `Activation`.
    """

    def __init__(self, i_dim, o_dim, activation="swish"):
        super().__init__()
        # Layer order (and therefore the Sequential indices) must stay
        # linear -> activation.
        layers = [nn.Linear(i_dim, o_dim), Activation(activation)]
        self.dence = nn.Sequential(*layers)

    def forward(self, x):
        out = self.dence(x)
        return out


class NBMEModel(nn.Module):
    """
    Token-level binary classifier on top of a pretrained transformer encoder.

    The encoder's last hidden state is passed through dropout and a single
    linear unit, giving one logit per token.

    Args:
        args: Object exposing `pretrain_path` (checkpoint/config location) and
            `dropout` (dropout probability for the classification head).
    """

    def __init__(self, args):
        super(NBMEModel, self).__init__()
        config = AutoConfig.from_pretrained(args.pretrain_path)
        # Paths containing "t5" are loaded as encoder-only T5 models.
        if "t5" in args.pretrain_path:
            self.transformer = T5EncoderModel.from_pretrained(args.pretrain_path)
        else:
            self.transformer = AutoModel.from_pretrained(args.pretrain_path)
        self.dropout = nn.Dropout(args.dropout)
        self.output = nn.Linear(config.hidden_size, 1)

    def forward(self, input_ids, attention_mask, token_type_ids=None, labels=None):
        """
        Run the encoder and the classification head.

        Args:
            input_ids: Token id tensor fed to the encoder.
            attention_mask: Attention mask tensor fed to the encoder.
            token_type_ids: Optional segment ids; only forwarded when given.
            labels: Optional targets; entries equal to -1 are ignored by the loss.

        Returns:
            tuple: `(probabilities, loss)` where `probabilities` is
            `sigmoid(logits)` and `loss` is the masked BCE loss, or the int 0
            when `labels` is None.
        """
        encoder_inputs = {"input_ids": input_ids, "attention_mask": attention_mask}
        # `is not None` instead of truthiness: bool() of a multi-element tensor
        # raises, so the old `if token_type_ids:` crashed whenever ids were passed.
        if token_type_ids is not None:
            encoder_inputs["token_type_ids"] = token_type_ids
        # Keyword arguments keep each tensor bound to the intended parameter
        # whatever the positional order of the encoder's forward().
        transformer_out = self.transformer(**encoder_inputs)
        sequence_output = transformer_out.last_hidden_state
        sequence_output = self.dropout(sequence_output)

        logits = self.output(sequence_output)
        logits_out = torch.sigmoid(logits)
        loss = 0
        if labels is not None:
            # The loss consumes raw logits (BCEWithLogitsLoss applies the sigmoid).
            loss = self.loss(logits, labels)

        return logits_out, loss

    def loss(self, logits, labels):
        """
        Binary cross-entropy over the positions whose label is not -1.

        Args:
            logits: Raw head output with a trailing singleton dimension.
            labels: Targets with -1 marking positions to ignore.

        Returns:
            Tensor: Scalar loss.
        """
        loss_fct = nn.BCEWithLogitsLoss()
        logits = logits.squeeze(-1)
        labels_mask = (labels != -1)
        active_logits = torch.masked_select(logits, labels_mask)
        true_labels = torch.masked_select(labels, labels_mask)
        loss = loss_fct(active_logits, true_labels)
        return loss

## Datasets

In [None]:
class NBMEDataset:
    """
    Minimal map-style dataset over pre-tokenized samples.

    Args:
        samples (list of dict): Each dict must provide "input_ids" and
            "attention_mask".
        tokenizer: Stored on the instance; not used by `__getitem__`.
        mask_prob (float): Stored on the instance; not used by `__getitem__`.
        mask_ratio (float): Stored on the instance; not used by `__getitem__`.
    """

    _FIELDS = ("input_ids", "attention_mask")

    def __init__(self, samples, tokenizer, mask_prob=0.0, mask_ratio=0.0):
        self.samples = samples
        self.tokenizer = tokenizer
        self.mask_prob = mask_prob
        self.mask_ratio = mask_ratio
        # Length is captured once at construction time.
        self.length = len(samples)

    def __len__(self):
        return self.length

    def __getitem__(self, idx):
        # Only the two model inputs are exposed; any extra keys are dropped.
        return {field: self.samples[idx][field] for field in self._FIELDS}

    
class Collate:
    """
    Batch collator that pads token sequences to a common length.

    Args:
        tokenizer: Object exposing `padding_side` ("right" pads at the end,
            anything else pads at the front) and `pad_token_id`.
        fix_length (int): Upper bound on the batch length; -1 disables it.
        fixed (bool): When True and `fix_length` is set, every batch is padded
            to exactly `fix_length` rather than to the longest sequence.
    """

    def __init__(self, tokenizer, fix_length=-1, fixed=False):
        self.tokenizer = tokenizer
        self.fix_length = fix_length
        self.fixed = fixed

    def _fit(self, sequence, target_len, fill_value):
        """Truncate `sequence` to `target_len` and pad it with `fill_value`."""
        # Sequences longer than the batch width used to be left untouched,
        # producing ragged lists that torch.tensor() rejects; keep the leading
        # `target_len` items instead (this drops trailing tokens).
        kept = list(sequence[:target_len])
        padding = [fill_value] * (target_len - len(kept))
        if self.tokenizer.padding_side == "right":
            return kept + padding
        return padding + kept

    def __call__(self, batch):
        """
        Collate a list of samples into padded LongTensors.

        Args:
            batch (list of dict): Samples with "input_ids" and "attention_mask".

        Returns:
            dict: "input_ids" and "attention_mask" tensors of equal shape.
        """
        input_ids = [sample["input_ids"] for sample in batch]
        attention_mask = [sample["attention_mask"] for sample in batch]

        # Width of this batch: longest sequence, optionally capped/forced.
        batch_max = max(len(ids) for ids in input_ids)
        if self.fix_length != -1:
            batch_max = self.fix_length if self.fixed else min(batch_max, self.fix_length)

        pad_id = self.tokenizer.pad_token_id
        output = dict()
        output["input_ids"] = torch.tensor(
            [self._fit(ids, batch_max, pad_id) for ids in input_ids], dtype=torch.long
        )
        output["attention_mask"] = torch.tensor(
            [self._fit(mask, batch_max, 0) for mask in attention_mask], dtype=torch.long
        )
        return output

## Predicter

In [None]:
class TrainerConfig:
    """
    Inference-loop settings copied from the global argument object.

    Attributes are `model_load`, `valid_batch_size` and `fix_length`, each
    read from the attribute of the same name on `args`.
    """

    def __init__(self, args):
        self.fix_length = args.fix_length
        self.valid_batch_size = args.valid_batch_size
        self.model_load = args.model_load


class Predicter:
    """
    Inference helper: builds the model, loads checkpoints and predicts.

    Args:
        args: Object providing the attributes read by `TrainerConfig` and
            `ModelConfig`, plus `device`.
    """

    def __init__(self, args):
        self.device = args.device
        self.model_config = ModelConfig(args)
        self.trainer_config = TrainerConfig(args)

    def build_model(self):
        """Instantiate the network from the stored model configuration."""
        self.model = NBMEModel(self.model_config)

    def model_init(self):
        """Build the model, move it to the target device and switch to eval mode."""
        self.build_model()
        self.model.to(self.device)
        self.model.eval()

    def model_load(self, path):
        """Load a state dict from `path` (deserialized on CPU) into the model."""
        state_dict = torch.load(path, map_location=torch.device('cpu'))
        self.model.load_state_dict(state_dict)
        self.model.to(self.device)
        self.model.eval()

    def get_logits(self, batch, return_loss=False):
        """
        Run the model on one collated batch.

        Args:
            batch (dict): Needs "input_ids" and "attention_mask"; also
                "labels" when `return_loss` is True.
            return_loss (bool): Whether to also compute and return the loss.

        Returns:
            The model's first output, or `(first output, loss)` when
            `return_loss` is True.
        """
        input_ids = batch["input_ids"].to(self.device)
        attention_mask = batch["attention_mask"].to(self.device)
        if not return_loss:
            logits, _ = self.model(input_ids, attention_mask)
            return logits
        labels = batch["labels"].to(self.device)
        logits, loss = self.model(input_ids, attention_mask, labels=labels)
        return logits, loss

    @torch.no_grad()
    def predict(self, valid_datasets, valid_collate):
        """
        Predict per-token scores for a whole dataset.

        Each batch is right-padded with zeros up to 512 columns so that all
        batches can be stacked into one tensor.

        Args:
            valid_datasets: Dataset handed to the DataLoader.
            valid_collate: Collate function handed to the DataLoader.

        Returns:
            Tensor: All batches concatenated along dim 0.
        """
        self.model.eval()
        loader = torch.utils.data.DataLoader(
            valid_datasets,
            batch_size=self.trainer_config.valid_batch_size,
            collate_fn=valid_collate,
        )
        padded_batches = []
        for batch in tqdm(loader):
            scores = self.get_logits(batch).cpu().squeeze(-1)
            n_rows, n_tokens = scores.shape
            missing = 512 - n_tokens
            if missing > 0:
                # Float zeros, matching the dtype of the original pad column.
                filler = torch.zeros((n_rows, missing), dtype=torch.float)
                scores = torch.cat([scores, filler], dim=1)
            padded_batches.append(scores)
        preds = torch.cat(padded_batches, dim=0)
        return preds

## Utils

In [None]:
def prepare_test_data(df, tokenizer, max_len):
    """
    Tokenize every (patient note, feature text) row of a dataframe.

    Two encodings are made per row: the note/feature pair supplies the model
    inputs, while the note encoded alone supplies the character offsets.

    Args:
        df (pd.DataFrame): Needs the columns "id", "pn_history" and "feature_text".
        tokenizer: Tokenizer exposing `encode_plus`.
        max_len (int): Accepted for interface compatibility; not used.

    Returns:
        list of dict: One dict per row with "input_ids", "attention_mask",
        "offset_mapping", "id" and "length" (character length of the note).
    """
    test_sample = []
    for _, row in df.iterrows():
        note = row["pn_history"]
        feature = row["feature_text"]
        # Pair encoding -> what the model actually consumes.
        pair_encoding = tokenizer.encode_plus(
            note, feature,
            add_special_tokens=True,
            return_offsets_mapping=False,
        )
        # Note-only encoding -> token-to-character offsets.
        note_encoding = tokenizer.encode_plus(
            note,
            add_special_tokens=True,
            return_offsets_mapping=True,
        )
        sample = {
            "input_ids": pair_encoding.input_ids,
            "attention_mask": pair_encoding.attention_mask,
            "offset_mapping": note_encoding.offset_mapping,
            "id": row["id"],
            "length": len(note),
        }
        test_sample.append(sample)
    return test_sample


def micro_f1(preds, truths):
    """
    Micro f1 on binary arrays.

    The score is computed directly with numpy: the previous implementation
    called `f1_score`, which is never imported in this notebook and therefore
    raised NameError.

    Args:
        preds (list of lists of ints): Predictions.
        truths (list of lists of ints): Ground truths.

    Returns:
        float: f1 score of the positive class over all concatenated
        elements; 0.0 when there is nothing to score or when there are no
        positive predictions and no positive truths.
    """
    # Nothing to aggregate (np.concatenate rejects an empty list).
    if len(preds) == 0 or len(truths) == 0:
        return 0.0

    # Micro : aggregating over all instances
    preds = np.concatenate(preds).astype(bool)
    truths = np.concatenate(truths).astype(bool)

    true_pos = np.count_nonzero(preds & truths)
    false_pos = np.count_nonzero(preds & ~truths)
    false_neg = np.count_nonzero(~preds & truths)

    denominator = 2 * true_pos + false_pos + false_neg
    if denominator == 0:
        return 0.0
    # F1 = 2TP / (2TP + FP + FN)
    return float(2 * true_pos / denominator)


def spans_to_binary(spans, length=None):
    """
    Rasterize [start, end) spans into a 0/1 array with one entry per character.

    Args:
        spans (list of lists of two ints): Spans.
        length (int, optional): Size of the output; defaults to the largest
            value found in `spans`.

    Returns:
        np array [length]: Binarized spans.
    """
    if length is None:
        length = np.max(spans)
    binary = np.zeros(length)
    for span in spans:
        start, end = span
        binary[slice(start, end)] = 1
    return binary


def span_micro_f1(preds, truths):
    """
    Micro f1 computed at character level from span lists.

    Pairs where both the prediction and the truth are empty are skipped; every
    other pair is binarized to the largest offset appearing in either side.

    Args:
        preds (list of lists of two ints): Prediction spans.
        truths (list of lists of two ints): Ground truth spans.

    Returns:
        float: f1 score.
    """
    bin_preds, bin_truths = [], []
    for pred, truth in zip(preds, truths):
        has_pred = bool(len(pred))
        has_truth = bool(len(truth))
        if not (has_pred or has_truth):
            continue
        pred_extent = np.max(pred) if has_pred else 0
        truth_extent = np.max(truth) if has_truth else 0
        length = max(pred_extent, truth_extent)
        bin_preds.append(spans_to_binary(pred, length))
        bin_truths.append(spans_to_binary(truth, length))
    return micro_f1(bin_preds, bin_truths)


def create_labels_for_scoring(df):
    """
    Parse the string-encoded "location" column into lists of [start, end] spans.

    Side effect: a copy of "location" is written to the new column
    "location_for_create_labels" on `df`.

    Args:
        df (pd.DataFrame): Needs a "location" column whose cells are string
            representations of Python lists such as "['0 5', '10 15;20 25']".

    Returns:
        list of lists of [int, int]: One span list per row.
    """
    df["location_for_create_labels"] = df["location"]
    truths = []
    for raw_locations in df['location_for_create_labels'].values:
        # NOTE(review): eval() executes whatever the cell contains. It is kept
        # to preserve behaviour; ast.literal_eval would be the safe choice if
        # the input can ever be untrusted.
        annotations = eval(raw_locations)
        spans = [
            [int(bounds[0]), int(bounds[1])]
            for annotation in annotations
            for bounds in (piece.split(" ") for piece in annotation.split(";"))
        ]
        truths.append(spans)
    return truths


def get_char_probs(samples, predictions):
    """
    Spread token-level scores over the characters each token covers.

    Args:
        samples (list of dict): Each needs "length" (number of characters) and
            "offset_mapping" (per-token (start, end) character offsets).
        predictions: Per-sample sequences of token scores; tokens and scores
            are paired in order and the longer sequence is cut off.

    Returns:
        list of np arrays: One array of size "length" per sample. Characters
        covered by no token stay 0; later tokens overwrite earlier ones.
    """
    results = [np.zeros(item["length"]) for item in samples]
    for char_scores, sample, token_scores in zip(results, samples, predictions):
        for span, score in zip(sample["offset_mapping"], token_scores):
            char_scores[span[0]:span[1]] = score
    return results


def get_results(char_probs, th=0.5):
    """
    Turn per-character scores into "start end;start end" span strings.

    Indices whose score is >= `th` are shifted by +1 and grouped into runs of
    consecutive values; each run is reported as "<first> <last>".

    Args:
        char_probs (list of np arrays): Per-character scores per sample.
        th (float): Decision threshold.

    Returns:
        list of str: One span string per sample ("" when nothing passes).
    """
    results = []
    for char_prob in char_probs:
        hits = np.where(char_prob >= th)[0] + 1
        if len(hits) == 0:
            results.append("")
            continue
        # A new run starts wherever two neighbouring hits are not consecutive.
        run_starts = np.where(np.diff(hits) != 1)[0] + 1
        runs = np.split(hits, run_starts)
        results.append(";".join(f"{run[0]} {run[-1]}" for run in runs))
    return results


def get_results_v2(char_probs, valid_texts, th=0.5):
    """
    Turn per-character scores into span strings, skipping a leading whitespace.

    Each maximal run of characters scoring >= `th` contributes the positions
    from its start (or start + 1 when the run begins on a whitespace
    character) up to and including the index just past the run. Consecutive
    positions are then merged and written as "<first> <last>".

    Args:
        char_probs (list of sequences of float): Per-character scores per sample.
        valid_texts (list of str): The text each score sequence belongs to.
        th (float): Decision threshold.

    Returns:
        list of str: One ";"-joined span string per sample.
    """
    results = []
    for row, probs in enumerate(char_probs):
        text = valid_texts[row]
        n_chars = len(probs)
        positions = []
        cursor = 0
        while cursor < n_chars:
            if probs[cursor] < th:
                cursor += 1
                continue
            run_start = cursor
            while cursor < n_chars and probs[cursor] >= th:
                cursor += 1
            first = run_start + 1 if text[run_start].isspace() else run_start
            positions.extend(range(first, cursor + 1))

        # Merge consecutive positions into [first, last] pairs.
        merged = []
        for position in positions:
            if merged and position == merged[-1][1] + 1:
                merged[-1][1] = position
            else:
                merged.append([position, position])
        results.append(";".join(f"{first} {last}" for first, last in merged))
    return results


def get_predictions(results):
    """
    Parse "start end;start end" span strings back into integer pairs.

    Args:
        results (list of str): Span strings; "" means no span.

    Returns:
        list of lists of [int, int]: Parsed spans, one list per input string.
    """
    predictions = []
    for result in results:
        if result == "":
            predictions.append([])
            continue
        fields_per_span = (chunk.split() for chunk in result.split(';'))
        predictions.append([[int(fields[0]), int(fields[1])] for fields in fields_per_span])
    return predictions


def text_to_word(text):
    """
    Split text on whitespace and locate every word in the original string.

    Args:
        text (str): Input text.

    Returns:
        tuple: `(words, offsets)` where `words` is `text.split()` and
        `offsets` holds one `(start, end)` character pair per word.

    Raises:
        NotImplementedError: If a word cannot be found after the previous one.
    """
    word = text.split()
    word_offset = []

    cursor = 0
    for token in word:
        # Search only from the end of the previous word onwards.
        begin = text.find(token, cursor)
        if begin == -1:
            raise NotImplementedError
        cursor = begin + len(token)
        word_offset.append((begin, cursor))

    return word, word_offset


def get_score(y_true, y_pred):
    """
    Span-level micro f1 between two span collections.

    Args:
        y_true: Span lists passed as the first argument of `span_micro_f1`.
        y_pred: Span lists passed as the second argument of `span_micro_f1`.

    Returns:
        float: Whatever `span_micro_f1` returns.
    """
    return span_micro_f1(y_true, y_pred)


def get_results_v3(char_probs, valid_texts, valid_id, th_mp=None, th=0.5):
    """
    Turn per-character scores into span strings using per-feature thresholds.

    Behaves like `get_results_v2` with two post-processing steps: the
    threshold is looked up per feature, and for feature "708" any run whose
    text contains the digit "5" or "3" is discarded.

    Args:
        char_probs (list of sequences of float): Per-character scores per sample.
        valid_texts (list of str): The text each score sequence belongs to.
        valid_id (list of str): Sample ids; the part after the last "_" is
            used as the feature key.
        th_mp (dict, optional): Feature key -> threshold overrides.
        th (float): Threshold for features missing from `th_mp`. (It used to
            be ignored because the fallback was hard-coded to 0.5.)

    Returns:
        list of str: One ";"-joined "<first> <last>" span string per sample.
    """
    # Avoid a shared mutable default argument.
    if th_mp is None:
        th_mp = {}

    results = []
    for i, char_prob in enumerate(char_probs):
        text = valid_texts[i]
        feature_num = valid_id[i].split("_")[-1]
        ### post 2: per-feature threshold, falling back to the caller's `th`
        fth = th_mp.get(feature_num, th)

        positions = []
        n_chars = len(char_prob)
        end = 0
        while end < n_chars:
            # `not >=` (rather than `<`) also steps over NaN scores, which
            # previously made this loop spin forever.
            if not (char_prob[end] >= fth):
                end += 1
                continue
            start = end
            while end < n_chars and char_prob[end] >= fth:
                end += 1
            ### post 1: drop runs for feature 708 whose text contains "5" or "3"
            if feature_num == "708":
                window = text[start:end + 1]
                if "5" in window or "3" in window:
                    continue
            # Skip a leading whitespace character; `end` itself is included.
            first = start + 1 if text[start].isspace() else start
            positions.extend(range(first, end + 1))

        # Merge consecutive positions into [first, last] pairs.
        merged = []
        for position in positions:
            if merged and position == merged[-1][1] + 1:
                merged[-1][1] = position
            else:
                merged.append([position, position])
        results.append(";".join(f"{first} {last}" for first, last in merged))
    return results

## Main

In [None]:
class Config:
    """Global run settings; the class object itself is used as `args`."""

    # Filled in at runtime.
    device = None
    pretrain_path = ""
    model_load = ""

    # Inference settings.
    valid_batch_size = 32
    fix_length = 512


args = Config
# Use the GPU when one is visible, otherwise fall back to the CPU.
args.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# th_mp = {
#     '103': 0.33,
#     '911': 0.31,
#     '003': 0.35,
#     '313': 0.35,
#     '904': 0.63,
#     '200': 0.31,
#     '207': 0.35,
#     '510': 0.69,
#     '508': 0.31,
#     "516": 0.38,
#     "708": 0.51,
#     "206": 0.32,
#     "702": 0.66,
# }
# no del
# th_mp = {'103': 0.33, '911': 0.31, '003': 0.35, '313': 0.35, '904': 0.63, '200': 0.31, '207': 0.35, '510': 0.69, '508': 0.31, '516': 0.38, '708': 0.51, '206': 0.32, '702': 0.66, '703': 0.5, '203': 0.42, '807': 0.5, '812': 0.5, '215': 0.62, '000': 0.6, '108': 0.41000000000000003, '816': 0.4, '010': 0.64, '004': 0.56, '502': 0.5, '815': 0.41000000000000003, '914': 0.35000000000000003, '112': 0.5, '403': 0.6, '900': 0.49, '802': 0.5, '915': 0.45, '111': 0.64, '214': 0.37, '208': 0.5700000000000001, '201': 0.53, '314': 0.5, '504': 0.5, '505': 0.36, '101': 0.5, '801': 0.54, '404': 0.43, '808': 0.36, '513': 0.5, '309': 0.5, '408': 0.38, '608': 0.5, '012': 0.5, '211': 0.5, '814': 0.54, '601': 0.56, '507': 0.4, '512': 0.36, '511': 0.64, '800': 0.5, '002': 0.39, '210': 0.5, '301': 0.61, '311': 0.46, '907': 0.5, '302': 0.4, '100': 0.39, '501': 0.36, '600': 0.5, '304': 0.5, '804': 0.5, '705': 0.5, '102': 0.41000000000000003, '707': 0.38, '811': 0.37, '110': 0.36, '906': 0.58, '401': 0.5, '500': 0.35000000000000003, '701': 0.59, '609': 0.64, '312': 0.63, '611': 0.49, '006': 0.6, '310': 0.5, '105': 0.39, '107': 0.35000000000000003, '603': 0.62, '809': 0.5, '916': 0.39, '913': 0.5, '606': 0.5700000000000001, '104': 0.56, '605': 0.45, '213': 0.55, '602': 0.45, '305': 0.5, '517': 0.5, '704': 0.35000000000000003, '908': 0.5, '005': 0.38, '910': 0.5, '514': 0.55, '300': 0.35000000000000003, '803': 0.56, '706': 0.58, '001': 0.64, '901': 0.5, '813': 0.45, '202': 0.56, '204': 0.61, '604': 0.56, '007': 0.5, '905': 0.36, '607': 0.43, '109': 0.5, '810': 0.62, '902': 0.4, '402': 0.5700000000000001, '405': 0.5, '205': 0.59, '303': 0.52, '506': 0.55, '610': 0.45, '806': 0.63, '308': 0.5, '307': 0.64, '315': 0.5, '306': 0.6, '209': 0.58, '407': 0.5, '515': 0.5, '212': 0.39, '903': 0.59, '216': 0.35000000000000003, '406': 0.61, '805': 0.5, '106': 0.5, '008': 0.5, '009': 0.43, '909': 0.42, '912': 0.39, '011': 0.61, '700': 0.54, '503': 0.38, '400': 0.54, '409': 0.55, '817': 
# 0.35000000000000003, '509': 0.55}
# have del
# th_mp = {  "103": 0.33,  "911": 0.31,  "003": 0.35,  "313": 0.35,  "904": 0.63,  "200": 0.31,  "207": 0.35,  "510": 0.69,  "508": 0.31,  "516": 0.38,  "708": 0.51,  "206": 0.32,  "702": 0.66,  "703": 0.5,  "203": 0.42,  "807": 0.5,  "812": 0.5,  "215": 0.62,  "000": 0.6,  "108": 0.41000000000000003,  "816": 0.4,  "010": 0.64,  "502": 0.5,  "815": 0.41000000000000003,  "914": 0.35000000000000003,  "112": 0.5,  "403": 0.6,  "900": 0.49,  "802": 0.5,  "915": 0.45,  "111": 0.64,  "214": 0.37,  "208": 0.5700000000000001,  "201": 0.53,  "314": 0.5,  "504": 0.5,  "505": 0.36,  "101": 0.5,  "801": 0.54,  "404": 0.43,  "513": 0.5,  "309": 0.5,  "408": 0.38,  "608": 0.5,  "012": 0.5,  "211": 0.5,  "601": 0.56,  "507": 0.4,  "512": 0.36,  "511": 0.64,  "800": 0.5,  "002": 0.39,  "210": 0.5,  "301": 0.61,  "907": 0.5,  "302": 0.4,  "100": 0.39,  "501": 0.36,  "600": 0.5,  "304": 0.5,  "804": 0.5,  "705": 0.5,  "102": 0.41000000000000003,  "811": 0.37,  "110": 0.36,  "906": 0.58,  "401": 0.5,  "500": 0.35000000000000003,  "609": 0.64,  "312": 0.63,  "006": 0.6,  "310": 0.5,  "105": 0.39,  "107": 0.35000000000000003,  "603": 0.62,  "809": 0.5,  "916": 0.39,  "913": 0.5,  "606": 0.5700000000000001,  "605": 0.45,  "213": 0.55,  "305": 0.5,  "517": 0.5,  "704": 0.35000000000000003,  "908": 0.5,  "005": 0.38,  "910": 0.5,  "514": 0.55,  "300": 0.35000000000000003,  "803": 0.56,  "706": 0.58,  "001": 0.64,  "901": 0.5,  "813": 0.45,  "204": 0.61,  "604": 0.56,  "007": 0.5,  "905": 0.36,  "607": 0.43,  "109": 0.5,  "810": 0.62,  "902": 0.4,  "405": 0.5,  "205": 0.59,  "506": 0.55,  "610": 0.45,  "806": 0.63,  "308": 0.5,  "307": 0.64,  "315": 0.5,  "306": 0.6,  "209": 0.58,  "407": 0.5,  "515": 0.5,  "212": 0.39,  "903": 0.59,  "406": 0.61,  "805": 0.5,  "106": 0.5,  "008": 0.5,  "009": 0.43,  "909": 0.42,  "912": 0.39,  "700": 0.54,  "503": 0.38,  "400": 0.54,  "817": 0.35000000000000003,  "509": 0.55 }
# deep del
# Active per-feature decision thresholds. Keys are 3-character feature strings,
# matched by get_results_v3 against the part of each sample id after the last "_".
th_mp = { "103": 0.33, "911": 0.31, "003": 0.35, "313": 0.35, "904": 0.63, "200": 0.31, "207": 0.35, "508": 0.31, "516": 0.38, "708": 0.51, "206": 0.32, "702": 0.66, "000": 0.6, "108": 0.41, "816": 0.4, "010": 0.64, "914": 0.35, "403": 0.6, "111": 0.64, "214": 0.37, "505": 0.36, "801": 0.54, "404": 0.43, "408": 0.38, "601": 0.56, "507": 0.4, "512": 0.36, "511": 0.64, "002": 0.39, "301": 0.61, "302": 0.4, "100": 0.39, "501": 0.36, "102": 0.41, "811": 0.37, "110": 0.36, "906": 0.58, "500": 0.35, "312": 0.63, "006": 0.6, "105": 0.39, "107": 0.35, "603": 0.62, "916": 0.39, "213": 0.55, "704": 0.35, "514": 0.55, "300": 0.35, "803": 0.56, "706": 0.58, "813": 0.45, "204": 0.61, "604": 0.56, "607": 0.43, "810": 0.62, "205": 0.59, "610": 0.45, "307": 0.64, "306": 0.6, "903": 0.59, "406": 0.61, "009": 0.43, "909": 0.42, "912": 0.39, "503": 0.38, "400": 0.54, "817": 0.35, }
# 0.001
# th_mp = {'103': 0.33, '911': 0.31, '003': 0.35, '313': 0.35, '904': 0.63, '200': 0.31, '207': 0.35, '510': 0.69, '508': 0.31, '516': 0.38, '708': 0.51, '206': 0.32, '702': 0.66, '909': 0.5, '606': 0.5, '306': 0.5, '805': 0.5, '305': 0.5, '603': 0.62, '506': 0.5, '107': 0.51, '212': 0.5, '812': 0.5, '916': 0.5, '405': 0.5, '404': 0.5, '605': 0.5, '814': 0.5, '100': 0.5, '309': 0.5, '906': 0.5, '000': 0.5, '806': 0.5, '304': 0.5, '310': 0.5, '802': 0.5, '701': 0.5, '813': 0.49, '913': 0.5, '815': 0.5, '112': 0.5, '201': 0.5, '502': 0.5, '801': 0.5, '503': 0.4, '109': 0.5, '008': 0.5, '101': 0.5, '007': 0.5, '002': 0.5, '908': 0.5, '300': 0.47000000000000003, '905': 0.5, '406': 0.5, '308': 0.5, '203': 0.5, '110': 0.48, '816': 0.5, '213': 0.51, '401': 0.5, '803': 0.54, '501': 0.5, '602': 0.5, '006': 0.5, '408': 0.5, '914': 0.5, '509': 0.5, '600': 0.5, '507': 0.5, '513': 0.5, '804': 0.5, '303': 0.5, '901': 0.5, '216': 0.5, '106': 0.5, '314': 0.5, '817': 0.36, '004': 0.5, '809': 0.5, '604': 0.5, '312': 0.5, '705': 0.5, '514': 0.5, '900': 0.49, '409': 0.5, '005': 0.5, '607': 0.5, '301': 0.51, '210': 0.5, '009': 0.43, '307': 0.5, '609': 0.5, '700': 0.5, '704': 0.5, '211': 0.5, '315': 0.5, '400': 0.54, '108': 0.5, '500': 0.5, '204': 0.5, '915': 0.5, '407': 0.5, '311': 0.5, '511': 0.5, '703': 0.5, '903': 0.5, '810': 0.5, '517': 0.5, '102': 0.5, '902': 0.5, '111': 0.53, '707': 0.5, '209': 0.5, '205': 0.5, '706': 0.49, '800': 0.5, '808': 0.5, '011': 0.5, '608': 0.5, '012': 0.5, '215': 0.5, '512': 0.5, '505': 0.51, '910': 0.5, '907': 0.5, '403': 0.5, '202': 0.5, '504': 0.5, '001': 0.5, '811': 0.5, '515': 0.5, '611': 0.5, '601': 0.5, '807': 0.5, '214': 0.5, '010': 0.5, '402': 0.5, '610': 0.5, '302': 0.48, '912': 0.5, '208': 0.5, '104': 0.5, '105': 0.5}
# 0.0005
# th_mp = {'103': 0.33, '911': 0.31, '003': 0.35, '313': 0.35, '904': 0.63, '200': 0.31, '207': 0.35, '510': 0.69, '508': 0.31, '516': 0.38, '708': 0.51, '206': 0.32, '702': 0.66, '901': 0.5, '402': 0.5, '902': 0.5, '611': 0.49, '211': 0.5, '408': 0.5, '900': 0.49, '913': 0.5, '707': 0.5, '107': 0.48, '405': 0.5, '906': 0.5, '813': 0.45, '511': 0.64, '404': 0.43, '809': 0.5, '805': 0.5, '001': 0.5, '804': 0.5, '507': 0.5, '610': 0.47000000000000003, '608': 0.5, '512': 0.5, '000': 0.5, '912': 0.5, '102': 0.41000000000000003, '600': 0.5, '400': 0.54, '406': 0.5, '704': 0.37, '011': 0.5, '903': 0.54, '604': 0.56, '803': 0.54, '602': 0.5, '008': 0.5, '701': 0.5, '105': 0.41000000000000003, '910': 0.5, '802': 0.5, '007': 0.5, '308': 0.5, '505': 0.51, '806': 0.5, '010': 0.53, '514': 0.55, '311': 0.5, '703': 0.5, '407': 0.5, '814': 0.5, '108': 0.41000000000000003, '302': 0.48, '315': 0.5, '908': 0.5, '817': 0.35000000000000003, '208': 0.5, '403': 0.53, '203': 0.5, '603': 0.62, '914': 0.5, '104': 0.5, '601': 0.5, '916': 0.5, '706': 0.55, '009': 0.43, '606': 0.49, '005': 0.5, '312': 0.5, '213': 0.51, '800': 0.5, '109': 0.5, '504': 0.5, '607': 0.43, '209': 0.5, '609': 0.5, '909': 0.42, '112': 0.5, '100': 0.5, '816': 0.4, '502': 0.5, '310': 0.5, '305': 0.5, '705': 0.5, '210': 0.5, '501': 0.5, '309': 0.5, '905': 0.5, '409': 0.5, '006': 0.5, '303': 0.5, '509': 0.5, '300': 0.35000000000000003, '101': 0.5, '304': 0.5, '314': 0.5, '500': 0.52, '204': 0.54, '517': 0.5, '812': 0.5, '700': 0.52, '004': 0.5, '106': 0.5, '515': 0.5, '205': 0.59, '801': 0.54, '815': 0.5, '012': 0.5, '201': 0.5, '306': 0.5, '216': 0.5, '307': 0.5, '212': 0.5, '915': 0.5, '002': 0.5, '401': 0.5, '807': 0.5, '506': 0.5, '215': 0.5, '301': 0.61, '811': 0.38, '503': 0.4, '810': 0.61, '202': 0.5, '111': 0.53, '808': 0.5, '513': 0.5, '907': 0.5, '214': 0.5, '605': 0.5, '110': 0.48}
# pt5 th f1 <= 0.90
# th_mp = {'207': 0.37, '209': 0.61, '911': 0.5, '103': 0.35000000000000003, '003': 0.36, '215': 0.61, '603': 0.63, '510': 0.64, '809': 0.51, '509': 0.52, '206': 0.45, '513': 0.48, '402': 0.5, '508': 0.35000000000000003, '200': 0.5, '505': 0.47000000000000003, '504': 0.35000000000000003, '702': 0.62, '708': 0.39, '313': 0.47000000000000003, '516': 0.58, '008': 0.63, '503': 0.35000000000000003, '216': 0.5, '514': 0.61, '212': 0.44, '403': 0.63, '904': 0.63, '204': 0.5, '201': 0.44, '500': 0.42, '511': 0.63, '111': 0.6, '902': 0.39, '609': 0.51, '813': 0.37, '517': 0.5, '811': 0.36, '010': 0.63, '512': 0.38, '203': 0.45, '213': 0.58, '214': 0.5, '604': 0.64, '706': 0.55, '205': 0.5, '310': 0.51, '801': 0.41000000000000003, '306': 0.63, '009': 0.35000000000000003, '005': 0.5, '903': 0.61, '803': 0.37, '312': 0.36, '810': 0.5, '607': 0.51}
# pt5 th f1 <= 0.90 up > 1e-5
# th_mp = {'207': 0.37, '911': 0.5, '103': 0.35000000000000003, '003': 0.36, '215': 0.61, '603': 0.63, '510': 0.64, '509': 0.52, '206': 0.45, '513': 0.48, '402': 0.5, '508': 0.35000000000000003, '200': 0.5, '505': 0.47000000000000003, '504': 0.35000000000000003, '702': 0.62, '708': 0.39, '313': 0.47000000000000003, '503': 0.35000000000000003, '216': 0.5, '514': 0.61, '212': 0.44, '403': 0.63, '204': 0.5, '201': 0.44, '511': 0.63, '111': 0.6, '902': 0.39, '813': 0.37, '517': 0.5, '811': 0.36, '010': 0.63, '512': 0.38, '203': 0.45, '214': 0.5, '604': 0.64, '706': 0.55, '205': 0.5, '801': 0.41000000000000003, '009': 0.35000000000000003, '005': 0.5, '903': 0.61, '810': 0.5}
# punish_th = {
#     "207": 0.35,
#     "911": 0.31,
#     "103": 0.33,
#     "003": 0.35,
#     "206": 0.32,
#     "603": 0.62,
#     "508": 0.31,
#     "200": 0.31,
#     "702": 0.66,
#     "505": 0.36,
#     "313": 0.35,
#     "516": 0.38,
#     "708": 0.51,
#     "503": 0.38,
#     "403": 0.6,
#     "904": 0.63,
#     "514": 0.55,
#     "500": 0.35,
#     "204": 0.61,
#     "111": 0.64,
#     "511": 0.64,
#     "811": 0.37,
#     "512": 0.36,
#     "813": 0.45,
#     "213": 0.55,
#     "010": 0.64,
#     "214": 0.37,
#     "706": 0.58,
#     "810": 0.62,
#     "205": 0.59,
#     "604": 0.56,
#     "803": 0.56,
# }

# Raw blend weights, one per group of four checkpoints below; each checkpoint
# in a group receives params[k] / 4.
params = [96, 79, 40, 88, 48, 39]
# params = [99, 69, 72, 15, 1, 2, 19, 75] # pt5
# params = [70, 10, 86, 39, 96, 1, 3, 94] # pt6
# Ensemble definition: one entry per backbone. "model_path" lists
# (checkpoint file, blend weight) pairs and "pretrain_path" is the directory
# used to build both the model and its tokenizer. The weights are divided by
# their overall sum (`model_num`) when predictions are blended.
model_list = [
    {
        "model_path": [
            # deberta-xlarge
            ("../input/nbme-deberta-xlarge-4fold/dxl00_4fold.bin", params[2]/4),
            ("../input/nbme-deberta-xlarge-4fold/dxl10_4fold.bin", params[2]/4),
            ("../input/nbme-deberta-xlarge-4fold/dxl20_4fold.bin", params[2]/4),
            ("../input/nbme-deberta-xlarge-4fold/dxl30_4fold.bin", params[2]/4),
            # pseudo deberta-xlarge
            ("../input/nbme-deberta-xlarge-pretrain/dxl03.bin", params[3]/4),
            ("../input/nbme-deberta-xlarge-pretrain/dxl13.bin", params[3]/4),
            ("../input/nbme-deberta-xlarge-pretrain/dxl23.bin", params[3]/4),
            ("../input/nbme-deberta-xlarge-pretrain/dxl33.bin", params[3]/4),
        ],
        "pretrain_path": "../input/d/datasets/shujun717/deberta-xlarge",
    },
    {
        "model_path": [
            # deberta-v3-large
            ("../input/nbme-deberta-v3-large-4fold/dvl01_4fold.bin", params[0]/4),
            ("../input/nbme-deberta-v3-large-4fold/dvl12_4fold.bin", params[0]/4),
            ("../input/nbme-deberta-v3-large-4fold/dvl22_4fold.bin", params[0]/4),
            ("../input/nbme-deberta-v3-large-4fold/dvl30_4fold.bin", params[0]/4),
            # pseudo deberta-v3-large
            # NOTE(review): the second checkpoint below lives in the
            # "nbme-deberta-xlarge-pretrain" dataset, unlike its siblings - confirm
            # this path is intentional.
            ("../input/nbme-deberta-v3-large-4fold/dvl03_4fold.bin", params[1]/4),
            ("../input/nbme-deberta-xlarge-pretrain/dvl14.bin", params[1]/4),
            ("../input/nbme-deberta-v3-large-4fold/dvl23_fold4.bin", params[1]/4),
            ("../input/nbme-deberta-v3-large-4fold/dvl33_fold4.bin", params[1]/4),
            # pseudo round 2 deberta-v3-large
#             ("../input/nbme-deberta-v3-large/dvl05.bin", params[2]/4),
#             ("../input/nbme-deberta-v3-large/dvl15.bin", params[2]/4),
#             ("../input/nbme-deberta-v3-large/dvl25.bin", params[2]/4),
#             ("../input/nbme-deberta-v3-large/dvl35.bin", params[2]/4),
        ],
        "pretrain_path": "../input/deberta-v3-large/deberta-v3-large",
    },
    {
        "model_path": [
            # pretrain roberta-large
            ("../input/nbme-roberta-large/Fold_0.bin", params[4]/4),
            ("../input/nbme-roberta-large/Fold_1.bin", params[4]/4),
            ("../input/nbme-roberta-large/Fold_2.bin", params[4]/4),
            ("../input/nbme-roberta-large/Fold_3.bin", params[4]/4),
            # pretrain + pseudo roberta-large
            ("../input/nbme-roberta-large/rl02.bin", params[5]/4),
            ("../input/nbme-roberta-large/rl12.bin", params[5]/4),
            ("../input/nbme-roberta-large/rl22.bin", params[5]/4),
            ("../input/nbme-roberta-large/rl32.bin", params[5]/4),
            # pretrain + pseudo round 2 roberta-large
#             ("../input/nbme-roberta-large/rl04.bin", params[7]/4),
#             ("../input/nbme-roberta-large/rl14.bin", params[7]/4),
#             ("../input/nbme-roberta-large/rl24.bin", params[7]/4),
#             ("../input/nbme-roberta-large/rl34.bin", params[7]/4),
        ],
        "pretrain_path": "../input/robertalarge",
    },
]
# Feature groups (given as integers, zero-padded to 3-character keys) whose
# thresholds in th_mp are overwritten with one shared value per group.
default_list = [str(item).zfill(3) for item in [209, 807, 911, 809, 207, 907, 215, 515, 502, 804, 910, 8, 100, 201, 206, 403, 903, 311, 310, 609, 905, 101, 510, 915, 506, 909, 5, 705, 511, 108, 505, 202, 703, 4, 404, 210, 312, 512, 305, 405, 803, 904, 608, 810, 402, 204, 106, 815, 813, 814, 514, 105, 610, 200, 704, 916, 109, 2, 508, 816, 900, 908, 7, 313, 706, 211, 306, 509, 513, 800, 812, 203, 214, 304, 307, 600, 516, 701, 806, 811]]
p2r = [str(item).zfill(3) for item in [911, 510, 508, 200, 505, 313, 516, 201, 904, 511, 512, 813, 706, 803, 4, 608, 916, 502, 703, 2, 900, 100, 814, 910, 405]]
r2p = [str(item).zfill(3) for item in [207, 215, 509, 809, 402, 513, 8, 514, 204, 811, 214, 203, 810, 310, 306, 905, 903, 312, 311, 515, 307, 109, 610, 202, 600, 506, 815, 305, 806, 211, 101]]
o = [str(item).zfill(3) for item in [807, 209, 206, 403, 609, 5, 106, 915, 108, 105, 704, 701, 404, 800, 816, 909, 7, 304, 210, 812, 705, 907, 908, 804]]

# default_list is currently not applied (a shared 0.419 override for it is disabled).
# Groups are applied in this order, so a later group wins on any shared key.
for group_keys, group_threshold in ((p2r, 0.385), (r2p, 0.503), (o, 0.408)):
    for key in group_keys:
        th_mp[str(key)] = group_threshold
# Total blend weight over every checkpoint of every backbone; used as the
# normaliser when the per-checkpoint predictions are summed.
model_num = sum(
    weight
    for model_config in model_list
    for _, weight in model_config["model_path"]
)
# Load the competition tables and attach feature text and patient notes to
# every test row (swap test.csv for train.csv to run on the training set).
features_df = pd.read_csv("../input/nbme-score-clinical-patient-notes/features.csv")
patient_df = pd.read_csv("../input/nbme-score-clinical-patient-notes/patient_notes.csv")
valid_df = (
    pd.read_csv("../input/nbme-score-clinical-patient-notes/test.csv")
    .merge(features_df, on=["feature_num", "case_num"], how="left")
    .merge(patient_df, on=["pn_num", "case_num"], how="left")
)

# Order rows by note length so that batches hold similarly sized notes.
valid_df['text_len'] = valid_df['pn_history'].map(len)
valid_df = valid_df.sort_values(by="text_len").reset_index(drop=True)

text = valid_df["pn_history"].tolist()
# One zero-initialised accumulator per note, one slot per character.
results = [np.zeros(len(item)) for item in text]
# Weighted ensemble inference: for every backbone, build the model and tokenizer
# once, then load each checkpoint in turn and add its weighted character-level
# probabilities into `results`.
for model_config in model_list:
    model_path = model_config["model_path"]
    pretrain_path = model_config["pretrain_path"]
    # Predicter reads its settings from the shared `args` object, so the
    # backbone is switched by mutating it before each construction.
    args.pretrain_path = pretrain_path
    args.tokenizer_path = pretrain_path
    predicter = Predicter(args)
    predicter.model_init()
    # Paths containing "deberta-v" use the dedicated fast DeBERTa-v2 tokenizer
    # helper; every other backbone goes through AutoTokenizer.
    if "deberta-v" in args.tokenizer_path:
        tokenizer = get_deberta_tokenizer(args.tokenizer_path)
    else:
        tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_path, trim_offsets=False)
    # Tokenization depends on the backbone, so samples are rebuilt per backbone.
    valid_samples = prepare_test_data(valid_df, tokenizer, args.fix_length)
    valid_datasets = NBMEDataset(valid_samples, tokenizer)
    valid_collate = Collate(tokenizer)
    # Each `path` is a (checkpoint file, blend weight) pair.
    for path in model_path:
        predicter.model_load(path[0])
        preds = predicter.predict(valid_datasets, valid_collate)
        # Scale by this checkpoint's share of the total blend weight.
        preds = preds * (path[1] / model_num)
        char_probs = get_char_probs(valid_samples, preds)
        for i in range(len(results)):
            results[i] += char_probs[i]
            # logging.info(f"{len(results[i])} {len(char_probs[i])}")
        # Release per-checkpoint memory before loading the next one.
        torch.cuda.empty_cache()
        del preds, char_probs
        gc.collect()
    # Drop everything tied to this backbone before building the next one.
    del valid_samples, valid_datasets
    del predicter.model
    del predicter
    gc.collect()
# Convert the blended character scores into span strings using the per-feature
# thresholds (get_results_v2 with a single 0.50 threshold is the simpler alternative).
sample_ids = valid_df["id"].values.tolist()
results = get_results_v3(results, text, sample_ids, th_mp, th=0.5)

# `submission` is the same object as valid_df, so valid_df also gains the column.
submission = valid_df
submission['location'] = results
display(submission.head())
submission[['id', 'location']].to_csv('submission.csv', index=False)

In [None]:
# Preview the first rows of the two columns that were written to submission.csv.
submission[["id", "location"]].head()