In [None]:
# The following is necessary if you want to use the fast tokenizer for deberta v2 or v3
import shutil
from pathlib import Path

# Overwrite transformers' slow->fast tokenizer converter and the deberta_v2
# tokenizer modules with the copies shipped in the Kaggle dataset, so that
# DebertaV2TokenizerFast can be used.
# NOTE(review): the site-packages path is pinned to the Python 3.7 Kaggle image.
transformers_path = Path("/opt/conda/lib/python3.7/site-packages/transformers")
input_dir = Path("../input/deberta-v2-3-fast-tokenizer")

convert_file = input_dir / "convert_slow_tokenizer.py"
conversion_path = transformers_path / convert_file.name
if conversion_path.exists():
    conversion_path.unlink()
shutil.copy(convert_file, transformers_path)

deberta_v2_path = transformers_path / "models" / "deberta_v2"

# Files whose name starts with "deberta" lose that prefix when installed
# ("deberta__init__.py" -> "__init__.py"); the others keep their name.
for filename in ['tokenization_deberta_v2.py', 'tokenization_deberta_v2_fast.py', "deberta__init__.py"]:
    installed_name = filename.replace("deberta", "") if filename.startswith("deberta") else filename
    filepath = deberta_v2_path / installed_name
    if filepath.exists():
        filepath.unlink()
    shutil.copy(input_dir / filename, filepath)

In [None]:
import os

# Directory for checkpoints, the saved tokenizer, the model config and OOF predictions.
OUTPUT_DIR = './'
# exist_ok=True replaces the racy "check then create" pattern in a single call.
os.makedirs(OUTPUT_DIR, exist_ok=True)

# CFG

In [None]:
class CFG:
    """Hyper-parameters and run settings shared by every cell of the notebook."""

    # --- backbone & input ---
    model = "microsoft/deberta-v3-large"
    max_len = 512  # replaced further down by a measured value (CFG.max_len = 354)

    # --- schedule (an earlier run used epochs=4 with lr=3e-5) ---
    epochs = 3
    encoder_lr = 2e-5
    decoder_lr = 2e-5
    scheduler = 'cosine'  # ['linear', 'cosine']
    batch_scheduler = True  # step the LR scheduler after every optimizer update
    num_cycles = 0.5
    num_warmup_steps = 0

    # --- batching / data loading ---
    batch_size = 4
    valid_batch_size = 32
    num_workers = 2
    gradient_accumulation_steps = 1

    # --- optimizer & regularisation ---
    eps = 1e-8
    betas = (0.9, 0.999)
    weight_decay = 0.0001
    fc_dropout = 0.2
    max_grad_norm = 1000

    # --- cross-validation ---
    n_fold = 5
    trn_fold = [4]  # folds actually trained in this run
    seed = 1119

    # --- misc ---
    apex = True  # NOTE(review): not referenced anywhere in this notebook (no autocast/GradScaler)
    print_freq = 100

In [None]:
import os
import gc
import re
import ast
import sys
import copy
import json
import time
import math
import string
import pickle
import random
import joblib
import itertools
import warnings
# warnings.filterwarnings("ignore")

import scipy as sp
import numpy as np
import pandas as pd
pd.set_option('display.max_rows', 500)
pd.set_option('display.max_columns', 500)
pd.set_option('display.width', 1000)
from tqdm.auto import tqdm
from sklearn.metrics import f1_score
from sklearn.model_selection import StratifiedKFold, GroupKFold, KFold

import torch
import torch.nn as nn
from torch.nn import Parameter
import torch.nn.functional as F
from torch.optim import Adam, SGD, AdamW
from torch.utils.data import DataLoader, Dataset

import transformers
from transformers import AutoTokenizer, AutoModel, AutoConfig
from transformers import get_linear_schedule_with_warmup, get_cosine_schedule_with_warmup
%env TOKENIZERS_PARALLELISM=true

# Use the GPU when one is visible to PyTorch, otherwise run on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
# From https://www.kaggle.com/theoviel/evaluation-metric-folds-baseline
def spans_to_binary(spans, length=None):
    """
    Convert character spans into a 0/1 mask.

    Args:
        spans: iterable of [start, end) character offsets.
        length: size of the returned mask. Defaults to the largest offset found
            in ``spans`` (0 when ``spans`` is empty, instead of letting
            ``np.max`` raise on an empty sequence).

    Returns:
        np.ndarray of floats with 1.0 at every position covered by a span.

    Example:
        spans_to_binary([[0, 5], [10, 15]], length=30)
    """
    # From https://www.kaggle.com/theoviel/evaluation-metric-folds-baseline
    if length is None:
        length = int(np.max(spans)) if len(spans) else 0
    binary = np.zeros(length)
    for start, end in spans:
        binary[start:end] = 1
    return binary

def micro_f1(preds, truths):
    """
    Micro-averaged binary F1 over a collection of per-example 0/1 arrays.

    Every example is flattened into one long vector first, so each character
    counts equally no matter which example it came from.

    Example:
        micro_f1([[0, 0, 1], [0, 0, 0]], [[0, 0, 1], [1, 0, 0]])
    """
    flat_truths = np.concatenate(truths)
    flat_preds = np.concatenate(preds)
    return f1_score(flat_truths, flat_preds)


# Scores the whole dataset at once: each example is binarised on its own,
# then all characters are pooled before computing F1.
def span_micro_f1(preds, truths):
    """
    Character-level micro F1 between two aligned collections of span lists.

    Args:
        preds: per-example lists of [start, end] spans.
        truths: per-example lists of [start, end] spans, aligned with ``preds``.

    Returns:
        Micro F1 (via ``micro_f1``) over every character of the examples that
        have at least one span on either side; examples where both lists are
        empty are skipped. The number of examples whose masks differ is printed.

    Example:
        span_micro_f1([[[1, 2]], [[3, 4]]], [[[1, 2]], [[3, 6]]])  # each example is 2-D
    """
    bin_preds, bin_truths = [], []
    count_diff = 0
    for pred, truth in zip(preds, truths):
        if len(pred) == 0 and len(truth) == 0:
            continue

        # Both masks share one length: the furthest end offset on either side.
        pred_end = np.max(pred) if len(pred) else 0
        truth_end = np.max(truth) if len(truth) else 0
        length = max(pred_end, truth_end)

        bin_pred = spans_to_binary(pred, length)
        bin_truth = spans_to_binary(truth, length)
        count_diff += int((bin_pred != bin_truth).any())

        bin_preds.append(bin_pred)
        bin_truths.append(bin_truth)

    print(f"count_diff: {count_diff}")
    return micro_f1(bin_preds, bin_truths)

In [None]:
def get_char_probs(texts, predictions, tokenizer):
    """
    Project token-level probabilities back onto the characters of each text.

    The ground truth is character-level, so every character covered by a
    token's offset span receives that token's probability. A space character
    is then forced to 1.0 when both neighbours are >= 0.5, and to 0.0 when
    either neighbour is < 0.5, so that predicted words separated by one space
    merge into a single span.

    Args:
        texts: sequence of strings.
        predictions: per-text sequences of token probabilities. These are
            presumably in the same token order as ``tokenizer(text)`` produces
            offsets for the text alone; verify against the caller.
        tokenizer: callable returning a mapping with ``'offset_mapping'``.

    Returns:
        list of float np.ndarrays, one per text, each of length ``len(text)``.
    """
    char_probs = [np.zeros(len(text)) for text in texts]
    for text, prediction, probs in zip(texts, predictions, char_probs):
        offsets = tokenizer(text, return_offsets_mapping=True)['offset_mapping']
        for (start, end), token_prob in zip(offsets, prediction):
            probs[start:end] = token_prob

        # Interior characters only; updates happen left to right, so a change
        # at position j is visible when position j + 1 looks back at it.
        for j in range(1, len(probs) - 1):
            if text[j] != ' ':
                continue
            left, right = probs[j - 1], probs[j + 1]
            if left >= 0.5 and right >= 0.5:
                probs[j] = 1.0
            elif left < 0.5 or right < 0.5:
                probs[j] = 0.0
    return char_probs


def get_results(char_probs, th=0.5):
    """
    Threshold character probabilities into location strings.

    Consecutive characters with probability >= ``th`` form one half-open span
    "start end"; spans are joined with ';'. For example,
    probabilities [0.1, 0.6, 0.7, 0.2, 0.9] with th=0.5 give "1 3;4 5".
    A text with no character above the threshold yields "".

    Args:
        char_probs: iterable of 1-D np.ndarrays of character probabilities.
        th: inclusive probability threshold.

    Returns:
        list of location strings, one per input array.
    """
    results = []
    for char_prob in char_probs:
        positive = np.where(char_prob >= th)[0]
        runs = []  # each run is [start, end) over consecutive positive indices
        for idx in positive:
            if runs and idx == runs[-1][1]:
                runs[-1][1] = idx + 1
            else:
                runs.append([idx, idx + 1])
        results.append(";".join(f"{start} {end}" for start, end in runs))
    return results


def get_predictions(results):
    """
    Parse location strings such as '70 91;176 183' into [[70, 91], [176, 183]].

    An empty string yields an empty list. Only the first two whitespace-separated
    numbers of each ';'-separated piece are read.

    Args:
        results: iterable of location strings.

    Returns:
        list (one per string) of [start, end] integer pairs.
    """
    def _parse(result):
        if result == "":
            return []
        pieces = [chunk.split() for chunk in result.split(';')]
        return [[int(piece[0]), int(piece[1])] for piece in pieces]

    return [_parse(result) for result in results]

def create_labels_for_scoring(df):
    """
    Convert ground-truth ``location`` entries into lists of [start, end] spans.

    Each row's ``location`` is a list of "start end" strings; one string may
    hold several ';'-separated pieces, e.g.
    ['70 91', '176 183;200 210'] -> [[70, 91], [176, 183], [200, 210]].
    Empty locations (``[]`` or ``['']``) produce an empty span list. Previously,
    ``[]`` raised IndexError.

    Side effect, kept for compatibility: adds or overwrites the column
    ``location_for_create_labels`` on ``df``. Each value is the ';'-joined
    location wrapped in a one-element list, or ``[]`` for an empty location.

    Args:
        df: DataFrame with a ``location`` column of lists of strings.

    Returns:
        list (one per row, in ``df`` order) of [start, end] integer pairs.
    """
    locations = list(df['location'])
    df['location_for_create_labels'] = pd.Series(
        [[';'.join(lst)] if lst else [] for lst in locations],
        index=df.index, dtype=object)

    truths = []
    for lst in locations:
        truth = []
        joined = ';'.join(lst) if lst else ''
        if joined:
            for piece in joined.split(';'):
                bounds = piece.split()
                truth.append([int(bounds[0]), int(bounds[1])])
        truths.append(truth)
    return truths

# Utils

In [None]:
def get_score(y_true, y_pred):
    """
    Character-level micro F1 between ground-truth and predicted spans.

    Args:
        y_true: per-example lists of [start, end] ground-truth spans.
        y_pred: per-example lists of [start, end] predicted spans.

    Returns:
        float micro F1 over all characters.
    """
    # span_micro_f1's signature is (preds, truths). F1 is symmetric in its two
    # inputs, so this matches the value of the previously swapped call.
    return span_micro_f1(y_pred, y_true)

def set_seed(seed=42):
    """
    Seed Python, NumPy and PyTorch (CPU + current CUDA device) RNGs and make cuDNN deterministic.

    Also exports PYTHONHASHSEED. Note that the running interpreter reads that
    variable only at start-up, so it affects only child processes.
    """
    os.environ['PYTHONHASHSEED'] = str(seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    # cuDNN: deterministic kernels and no autotuner-driven algorithm selection.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    
# Seed every RNG once with the configured seed so the run is reproducible.
set_seed(CFG.seed)

# Data Loading

In [None]:
# Labels are read from train_new.csv, presumably a modified copy of the
# competition file '../input/nbme-score-clinical-patient-notes/train.csv'.
train = pd.read_csv(
    '../input/nbme-train-new/train_new.csv',
    usecols=['id', 'case_num', 'pn_num', 'feature_num', 'annotation', 'location'],
)
features = pd.read_csv('../input/nbme-score-clinical-patient-notes/features.csv')
patient_notes = pd.read_csv('../input/nbme-score-clinical-patient-notes/patient_notes.csv')

# annotation / location are stored as Python-literal strings; parse them into lists.
train = train.assign(
    annotation=train['annotation'].apply(ast.literal_eval),
    location=train['location'].apply(ast.literal_eval),
)

In [None]:
# Override the text of feature row 27 (presumably a typo fix in the source CSV),
# then attach feature descriptions and patient-note text to every training row.
features.loc[27, 'feature_text'] = "Last-Pap-smear-1-year-ago"

train = (
    train
    .merge(features, on=['feature_num', 'case_num'], how='left')
    .merge(patient_notes, on=['pn_num', 'case_num'], how='left')
)

In [None]:
# Manual corrections for rows whose annotation/location are wrong or incomplete.
# Each entry is (row label, annotation literal, location literal); the literals
# are parsed with ast.literal_eval and written in the order listed.
# NOTE(review): the extra outer list seems to rely on pandas unwrapping a
# length-1 list-like when setting a single cell; rows with several entries
# (e.g. 655, 3858) may be stored differently — confirm the resulting values.
annotation_fixes = [
    (338, '[["father heart attack"]]', '[["764 783"]]'),
    (621, '[["for the last 2-3 months"]]', '[["77 100"]]'),
    (655, '[["no heat intolerance"], ["no cold intolerance"]]', '[["285 292;301 312"], ["285 287;296 312"]]'),
    (1262, '[["mother thyroid problem"]]', '[["551 557;565 580"]]'),
    (1265, '[[\'felt like he was going to "pass out"\']]', '[["131 135;181 212"]]'),
    (1396, '[["stool , with no blood"]]', '[["259 280"]]'),
    (1591, '[["diarrhoe non blooody"]]', '[["176 184;201 212"]]'),
    (1615, '[["diarrhea for last 2-3 days"]]', '[["249 257;271 288"]]'),
    (1664, '[["no vaginal discharge"]]', '[["822 824;907 924"]]'),
    (1714, '[["started about 8-10 hours ago"]]', '[["101 129"]]'),
    (1929, '[["no blood in the stool"]]', '[["531 539;549 561"]]'),
    (2134, '[["last sexually active 9 months ago"]]', '[["540 560;581 593"]]'),
    (2191, '[["right lower quadrant pain"]]', '[["32 57"]]'),
    (2553, '[["diarrhoea no blood"]]', '[["308 317;376 384"]]'),
    (3124, '[["sweating"]]', '[["549 557"]]'),
    (3858, '[["previously as regular"], ["previously eveyr 28-29 days"], ["previously lasting 5 days"], ["previously regular flow"]]', '[["102 123"], ["102 112;125 141"], ["102 112;143 157"], ["102 112;159 171"]]'),
    (4373, '[["for 2 months"]]', '[["33 45"]]'),
    (4763, '[["35 year old"]]', '[["5 16"]]'),
    (4782, '[["darker brown stools"]]', '[["175 194"]]'),
    (4908, '[["uncle with peptic ulcer"]]', '[["700 723"]]'),
    (6016, '[["difficulty falling asleep"]]', '[["225 250"]]'),
    (6192, '[["helps to take care of aging mother and in-laws"]]', '[["197 218;236 260"]]'),
    (6380, '[["No hair changes"], ["No skin changes"], ["No GI changes"], ["No palpitations"], ["No excessive sweating"]]', '[["480 482;507 519"], ["480 482;499 503;512 519"], ["480 482;521 531"], ["480 482;533 545"], ["480 482;564 582"]]'),
    (6562, '[["stressed due to taking care of her mother"], ["stressed due to taking care of husbands parents"]]', '[["290 320;327 337"], ["290 320;342 358"]]'),
    (6862, '[["stressor taking care of many sick family members"]]', '[["288 296;324 363"]]'),
    (7022, '[["heart started racing and felt numbness for the 1st time in her finger tips"]]', '[["108 182"]]'),
    (7422, '[["first started 5 yrs"]]', '[["102 121"]]'),
    (8876, '[["No shortness of breath"]]', '[["481 483;533 552"]]'),
    (9027, '[["recent URI"], ["nasal stuffines, rhinorrhea, for 3-4 days"]]', '[["92 102"], ["123 164"]]'),
    (9938, '[["irregularity with her cycles"], ["heavier bleeding"], ["changes her pad every couple hours"]]', '[["89 117"], ["122 138"], ["368 402"]]'),
    (9973, '[["gaining 10-15 lbs"]]', '[["344 361"]]'),
    (10513, '[["weight gain"], ["gain of 10-16lbs"]]', '[["600 611"], ["607 623"]]'),
    (11551, '[["seeing her son knows are not real"]]', '[["386 400;443 461"]]'),
    (11677, '[["saw him once in the kitchen after he died"]]', '[["160 201"]]'),
    (12124, '[["tried Ambien but it didnt work"]]', '[["325 337;349 366"]]'),
    (12279, '[["heard what she described as a party later than evening these things did not actually happen"]]', '[["405 459;488 524"]]'),
    (12289, '[["experienced seeing her son at the kitchen table these things did not actually happen"]]', '[["353 400;488 524"]]'),
    (13238, '[["SCRACHY THROAT"], ["RUNNY NOSE"]]', '[["293 307"], ["321 331"]]'),
    (13297, '[["without improvement when taking tylenol"], ["without improvement when taking ibuprofen"]]', '[["182 221"], ["182 213;225 234"]]'),
    (13299, '[["yesterday"], ["yesterday"]]', '[["79 88"], ["409 418"]]'),
    (13845, '[["headache global"], ["headache throughout her head"]]', '[["86 94;230 236"], ["86 94;237 256"]]'),
    (14083, '[["headache generalized in her head"]]', '[["56 64;156 179"]]'),
]

for row, annotation_src, location_src in annotation_fixes:
    train.loc[row, 'annotation'] = ast.literal_eval(annotation_src)
    train.loc[row, 'location'] = ast.literal_eval(location_src)

In [None]:
# Number of annotations per row; the distribution is shown below.
train['annotation_length'] = train['annotation'].map(len)
display(train['annotation_length'].value_counts())

In [None]:
# train = train[train['annotation_length'] != 0].reset_index(drop=True)

# CV split

In [None]:
# Group folds by patient note (pn_num) so that all rows of one note land in
# the same fold. The y argument is ignored by GroupKFold.
# val_index is positional; .loc works because train has a default RangeIndex
# after the merges above (assumption — re-check if the index ever changes).
group_kfold = GroupKFold(n_splits=CFG.n_fold)
splits = group_kfold.split(train, train['location'], groups=train['pn_num'].values)
for fold_id, (_, val_index) in enumerate(splits):
    train.loc[val_index, 'fold'] = int(fold_id)
train['fold'] = train['fold'].astype(int)
display(train.groupby('fold').size())

# tokenizer

In [None]:
from transformers.models.deberta_v2 import DebertaV2TokenizerFast

# The fast tokenizer (enabled by the patch in the first cell) is required
# because the labels are built from return_offsets_mapping.
tokenizer = DebertaV2TokenizerFast.from_pretrained(CFG.model)
# Keep a copy next to the checkpoints, and expose it through CFG for the helpers.
tokenizer.save_pretrained(OUTPUT_DIR + 'tokenizer/')
CFG.tokenizer = tokenizer

# Dataset

In [None]:
# Sequence length for the (pn_history, feature_text) pair:
#   max tokenized pn_history length + max tokenized feature_text length
#   + 3 special tokens ([CLS], [SEP], [SEP]).
# It was measured once by tokenizing every patient note and feature (without
# special tokens) and is hard-coded here to skip that slow scan.
# This overrides the default max_len declared on CFG.
CFG.max_len = 354

In [None]:
def prepare_input(cfg, text, feature_text):
    """
    Tokenize a (patient note, feature text) pair into fixed-length model inputs.

    The pair is padded to ``cfg.max_len``; only the note (first sequence) may be
    truncated. Reads ``max_len`` from ``cfg`` rather than the global ``CFG``, so
    it stays consistent with the tokenizer passed in the same object.

    Args:
        cfg: config exposing ``tokenizer`` and ``max_len``.
        text: patient note (pn_history).
        feature_text: feature description.

    Returns:
        The tokenizer's mapping with every value converted to a torch.long tensor.
    """
    inputs = cfg.tokenizer(text, feature_text,
                           add_special_tokens=True,
                           max_length=cfg.max_len,
                           padding="max_length",
                           return_offsets_mapping=False,
                           truncation='only_first')
    for k, v in inputs.items():
        inputs[k] = torch.tensor(v, dtype=torch.long)
    return inputs

def create_label(cfg, text, annotation_length, location_list):
    """
    Build per-token binary targets for one patient note.

    The note is tokenized on its own, with the same length as ``prepare_input``.
    Positions whose sequence id is not 0 (special tokens and padding) get -1 so
    the loss can mask them out. Assuming the note tokenizes identically inside
    the (note, feature) pair, this also masks the feature-text positions.

    A text token is labelled 1 when its character offsets lie inside an
    annotated span, or straddle either boundary of one. Uses ``cfg.tokenizer``
    and ``cfg.max_len`` instead of the module-level globals.

    Args:
        cfg: config exposing ``tokenizer`` and ``max_len``.
        text: patient note.
        annotation_length: unused; kept for call compatibility.
        location_list: list of "start end" strings. One string may hold several
            ';'-separated pieces, e.g. ['1 2;5 6'].

    Returns:
        torch.float tensor of length cfg.max_len with values in {-1, 0, 1}.
    """
    inputs = cfg.tokenizer(text, max_length=cfg.max_len, padding="max_length",
                           return_offsets_mapping=True, truncation='only_first')
    offset_mapping = inputs['offset_mapping']
    sequence_ids = np.array(inputs.sequence_ids())  # None for special/padding tokens

    label = np.zeros(len(offset_mapping))
    label[np.where(sequence_ids != 0)[0]] = -1
    is_text = np.where(sequence_ids == 0)[0]

    if not location_list or not location_list[0]:
        return torch.tensor(label, dtype=torch.float)

    # Split multi-piece locations ('1 2;5 6') and parse the bounds once.
    spans = []
    for loc in location_list:
        for piece in loc.split(';'):
            bounds = piece.split()
            spans.append((int(bounds[0]), int(bounds[1])))

    for i in is_text:
        tok_start, tok_end = offset_mapping[i][0], offset_mapping[i][1]
        for loc_left, loc_right in spans:
            inside = tok_start >= loc_left and tok_end <= loc_right
            straddles = tok_start <= loc_left < tok_end or tok_start < loc_right <= tok_end
            if inside or straddles:
                label[i] = 1
                break

    return torch.tensor(label, dtype=torch.float)

In [None]:
class TrainDataset(Dataset):
    """
    Map-style dataset yielding ``(model_inputs, token_labels)`` per training row.

    Reads the ``feature_text``, ``pn_history``, ``annotation_length`` and
    ``location`` columns of ``df`` up front. Tokenization happens lazily in
    ``__getitem__`` via ``prepare_input`` and ``create_label``.
    """

    def __init__(self, cfg, df):
        self.cfg = cfg
        self.feature_texts = df['feature_text'].values
        self.pn_historys = df['pn_history'].values
        self.annotation_lengths = df['annotation_length'].values
        self.locations = df['location'].values

    def __len__(self):
        return len(self.feature_texts)

    def __getitem__(self, item):
        note = self.pn_historys[item]
        model_inputs = prepare_input(self.cfg, note, self.feature_texts[item])
        token_labels = create_label(self.cfg, note,
                                    self.annotation_lengths[item],
                                    self.locations[item])
        return model_inputs, token_labels

# Model

In [None]:
class CustomModel(nn.Module):
    """
    Transformer encoder followed by dropout and a linear head that emits one
    logit for every token position.

    The submodule attribute names ``model``, ``fc_dropout`` and ``fc`` determine
    the saved state_dict keys and are matched by name in the optimizer grouping,
    so they must stay stable.
    """

    def __init__(self, cfg, config_path=None, pretrained=False):
        super().__init__()
        self.cfg = cfg
        # NOTE(review): torch.load unpickles arbitrary objects — only pass a trusted config file.
        self.config = (AutoConfig.from_pretrained(cfg.model, output_hidden_states=True)
                       if config_path is None else torch.load(config_path))
        self.model = (AutoModel.from_pretrained(cfg.model, config=self.config)
                      if pretrained else AutoModel.from_config(self.config))
        self.fc_dropout = nn.Dropout(cfg.fc_dropout)
        self.fc = nn.Linear(self.config.hidden_size, 1)
        self._init_weights(self.fc)

    def _init_weights(self, module):
        """Re-initialise a Linear / Embedding / LayerNorm module in place; other modules are left untouched."""
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)

    def forward(self, inputs):
        encoder_outputs = self.model(**inputs)
        token_states = encoder_outputs[0]  # last hidden states
        return self.fc(self.fc_dropout(token_states))

# Helper functions

In [None]:
class AverageMeter(object):
    """Keep the most recent value plus a running, count-weighted mean of a metric."""

    def __init__(self):
        self.reset()

    def reset(self):
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as the mean over ``n`` samples and refresh the running average."""
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


def asMinutes(s):
    """Format a duration in seconds as '<m>m <s>s' (both parts truncated to integers)."""
    minutes = math.floor(s / 60)
    seconds = s - minutes * 60
    return '%dm %ds' % (minutes, seconds)


def timeSince(since, percent):
    """
    Report the elapsed time since ``since`` and a linear estimate of the time remaining.

    Args:
        since: start timestamp from ``time.time()``.
        percent: completed fraction of the work, in (0, 1].
    """
    elapsed = time.time() - since
    remaining = elapsed / percent - elapsed
    return '%s (remain %s)' % (asMinutes(elapsed), asMinutes(remaining))

In [None]:
def train_fn(fold, train_loader, model, criterion, optimizer, epoch, scheduler, device):
    """
    Run one training epoch and return the average masked BCE loss.

    Positions labelled -1 are excluded from the loss. Gradients are accumulated
    over ``CFG.gradient_accumulation_steps`` batches and clipped exactly once,
    just before each optimizer update. Earlier, clipping ran on every micro-batch,
    which distorted partially accumulated gradients. The logged/returned loss is
    the unscaled per-batch loss, so it no longer depends on the accumulation factor.
    With ``gradient_accumulation_steps == 1`` the update is identical to before.

    Args:
        fold: unused; kept for call compatibility.
        train_loader: yields ``(inputs_dict, labels)`` batches.
        model, criterion, optimizer, scheduler: training objects; ``criterion``
            must use ``reduction="none"`` so the loss can be masked.
        epoch: zero-based epoch index, used for logging.
        device: device to move the batch to.

    Returns:
        float average training loss over the epoch.
    """
    model.train()
    losses = AverageMeter()
    start = time.time()
    grad_norm = 0.0  # most recent clipped gradient norm, for logging
    accum_steps = CFG.gradient_accumulation_steps
    for step, (inputs, labels) in enumerate(train_loader):
        for k, v in inputs.items():
            inputs[k] = v.to(device)
        labels = labels.to(device)
        batch_size = labels.size(0)
        y_preds = model(inputs)
        loss = criterion(y_preds.view(-1, 1), labels.view(-1, 1))
        # Drop positions labelled -1 before averaging.
        loss = torch.masked_select(loss, labels.view(-1, 1) != -1).mean()
        losses.update(loss.item(), batch_size)

        if accum_steps > 1:
            loss = loss / accum_steps
        loss.backward()

        if (step + 1) % accum_steps == 0:
            # Clip the fully accumulated gradient right before the update.
            grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), CFG.max_grad_norm)
            optimizer.step()
            optimizer.zero_grad()
            if CFG.batch_scheduler:
                scheduler.step()

        if step % CFG.print_freq == 0 or step == (len(train_loader)-1):
            print(f"Epoch:[{epoch+1}/{CFG.epochs}], batch:[{step}/{len(train_loader)}], maxlen:{CFG.max_len}, Elapsed:{timeSince(start, float(step+1) / len(train_loader)):s}, Loss: {losses.avg:.5f}, Grad: {grad_norm:.4f}, LR: {optimizer.param_groups[0]['lr']:.8f}")
    return losses.avg


def valid_fn(valid_loader, model, criterion, device):
    """
    Evaluate the model on the validation loader.

    The masked BCE loss (positions labelled -1 excluded) is no longer divided by
    ``CFG.gradient_accumulation_steps``. Nothing is back-propagated here, so that
    scaling only under-reported the validation loss.

    Args:
        valid_loader: yields ``(inputs_dict, labels)`` batches.
        model: model returning per-token logits.
        criterion: loss with ``reduction="none"``.
        device: device to move the batch to.

    Returns:
        (average validation loss, np.ndarray of sigmoid probabilities
        concatenated over batches along axis 0).
    """
    losses = AverageMeter()
    model.eval()
    preds = []
    start = time.time()
    for step, (inputs, labels) in enumerate(valid_loader):
        for k, v in inputs.items():
            inputs[k] = v.to(device)
        labels = labels.to(device)
        batch_size = labels.size(0)
        with torch.no_grad():
            y_preds = model(inputs)
            loss = criterion(y_preds.view(-1, 1), labels.view(-1, 1))
            loss = torch.masked_select(loss, labels.view(-1, 1) != -1).mean()
        losses.update(loss.item(), batch_size)
        preds.append(y_preds.sigmoid().to('cpu').numpy())
        if step % CFG.print_freq == 0 or step == (len(valid_loader)-1):
            print(f"EVAL: [{step}/{len(valid_loader)}], Elapsed {timeSince(start, float(step+1)/len(valid_loader)):s}, Loss: {losses.avg:.5f}")

    predictions = np.concatenate(preds)
    return losses.avg, predictions


def inference_fn(test_loader, model, device):
    """
    Run the model over ``test_loader`` and return sigmoid probabilities.

    Each batch is a dict of tensors; its values are moved to ``device`` in place.
    Outputs are concatenated along axis 0.
    """
    model.eval()
    model.to(device)
    batch_probs = []
    for batch in tqdm(test_loader, total=len(test_loader)):
        for key in list(batch.keys()):
            batch[key] = batch[key].to(device)
        with torch.no_grad():
            logits = model(batch)
        batch_probs.append(logits.sigmoid().to('cpu').numpy())
    return np.concatenate(batch_probs)

In [None]:
def train_loop(folds, fold):
    """
    Train and validate a single fold.

    Rows with ``folds['fold'] == fold`` form the validation set; all other rows
    are used for training. After each epoch the validation predictions are scored
    with the character-level micro F1. Whenever the score improves, the weights
    and predictions are saved to ``OUTPUT_DIR + '<model>_fold<fold>_best.pth'``.

    Args:
        folds: training DataFrame with a ``fold`` column.
        fold: id of the validation fold.

    Returns:
        The validation rows (index reset), with integer columns
        ``0 .. CFG.max_len - 1`` holding the best epoch's per-token probabilities.

    Raises:
        RuntimeError: if no epoch produced a comparable score (e.g. CFG.epochs == 0).
    """
    print(f"========== fold: {fold} training ==========")

    train_folds = folds[folds['fold'] != fold].reset_index(drop=True)
    valid_folds = folds[folds['fold'] == fold].reset_index(drop=True)
    valid_texts = valid_folds['pn_history'].values
    valid_labels = create_labels_for_scoring(valid_folds)

    train_dataset = TrainDataset(CFG, train_folds)
    valid_dataset = TrainDataset(CFG, valid_folds)

    train_loader = DataLoader(train_dataset,
                              batch_size=CFG.batch_size,
                              shuffle=True,
                              num_workers=CFG.num_workers, pin_memory=True, drop_last=True)
    valid_loader = DataLoader(valid_dataset,
                              batch_size=CFG.valid_batch_size,
                              shuffle=False,
                              num_workers=CFG.num_workers, pin_memory=True, drop_last=False)

    model = CustomModel(CFG, config_path=None, pretrained=True)
    # Save the model config so CustomModel can be rebuilt from config_path later.
    torch.save(model.config, OUTPUT_DIR + 'config.pth')
    model.to(device)

    def get_optimizer_params(model, encoder_lr, decoder_lr, weight_decay=0.0):
        """Three groups: encoder weights (decayed), encoder bias/LayerNorm (no decay), and the head."""
        no_decay = ["bias", "LayerNorm.bias", "LayerNorm.weight"]
        encoder_params = list(model.model.named_parameters())
        return [
            {'params': [p for n, p in encoder_params if not any(nd in n for nd in no_decay)],
             'lr': encoder_lr, 'weight_decay': weight_decay},
            {'params': [p for n, p in encoder_params if any(nd in n for nd in no_decay)],
             'lr': encoder_lr, 'weight_decay': 0.0},
            {'params': [p for n, p in model.named_parameters() if "model" not in n],
             'lr': decoder_lr, 'weight_decay': 0.0},
        ]

    optimizer_parameters = get_optimizer_params(model,
                                                encoder_lr=CFG.encoder_lr,
                                                decoder_lr=CFG.decoder_lr,
                                                weight_decay=CFG.weight_decay)
    optimizer = AdamW(optimizer_parameters, lr=CFG.encoder_lr, eps=CFG.eps, betas=CFG.betas)

    def get_scheduler(cfg, optimizer, num_train_steps):
        """Build the LR scheduler named by ``cfg.scheduler`` ('linear' or 'cosine')."""
        if cfg.scheduler == 'linear':
            return get_linear_schedule_with_warmup(
                optimizer, num_warmup_steps=cfg.num_warmup_steps, num_training_steps=num_train_steps
            )
        if cfg.scheduler == 'cosine':
            return get_cosine_schedule_with_warmup(
                optimizer, num_warmup_steps=cfg.num_warmup_steps, num_training_steps=num_train_steps, num_cycles=cfg.num_cycles
            )
        raise ValueError(f"Unknown scheduler: {cfg.scheduler!r}")

    # One scheduler step per optimizer update: the loader drops the last partial
    # batch, and gradients are accumulated over CFG.gradient_accumulation_steps batches.
    num_train_steps = len(train_loader) // CFG.gradient_accumulation_steps * CFG.epochs
    scheduler = get_scheduler(CFG, optimizer, num_train_steps)

    criterion = nn.BCEWithLogitsLoss(reduction="none")
    checkpoint_path = OUTPUT_DIR + f"{CFG.model.replace('/', '-')}_fold{fold}_best.pth"

    # Start below any real score, so the first epoch is always saved even if it scores 0.
    best_score = float('-inf')
    best_predictions = None

    for epoch in range(CFG.epochs):
        start_time = time.time()

        # train
        avg_loss = train_fn(fold, train_loader, model, criterion, optimizer, epoch, scheduler, device)

        # eval
        avg_val_loss, predictions = valid_fn(valid_loader, model, criterion, device)
        predictions = predictions.reshape((len(valid_folds), CFG.max_len))

        # scoring from valid set
        char_probs = get_char_probs(valid_texts, predictions, CFG.tokenizer)
        results = get_results(char_probs, th=0.5)
        preds = get_predictions(results)
        score = get_score(valid_labels, preds)

        elapsed = time.time() - start_time

        print(f'Epoch {epoch+1} - avg_train_loss: {avg_loss:.5f}  avg_val_loss: {avg_val_loss:.5f}  time: {elapsed:.0f}s')
        print(f'Epoch {epoch+1} - Score: {score:.5f}')

        if best_score < score:
            best_score = score
            best_predictions = predictions
            print(f'Epoch {epoch+1} - Save Best Score: {best_score:.5f} Model')
            torch.save({'model': model.state_dict(),
                        'predictions': predictions},
                       checkpoint_path)

    if best_predictions is None:
        raise RuntimeError(f"fold {fold}: no epoch produced a valid score; no checkpoint was saved")

    # Reuse the in-memory predictions instead of reloading the whole checkpoint.
    valid_folds[[i for i in range(CFG.max_len)]] = best_predictions

    # Release references to the model/optimizer before freeing cached GPU memory.
    del model, optimizer, scheduler
    gc.collect()
    torch.cuda.empty_cache()

    return valid_folds

In [None]:
if __name__ == '__main__':

    def get_result(oof_df):
        """Print the character-level micro F1 of the out-of-fold predictions stored in ``oof_df``."""
        labels = create_labels_for_scoring(oof_df)
        predictions = oof_df[[i for i in range(CFG.max_len)]].values
        char_probs = get_char_probs(oof_df['pn_history'].values, predictions, CFG.tokenizer)
        results = get_results(char_probs, th=0.5)
        preds = get_predictions(results)
        score = get_score(labels, preds)
        print(f'Score: {score:<.5f}')

    # Collect per-fold frames and concatenate once, instead of growing a DataFrame in the loop.
    fold_frames = []
    for fold in range(CFG.n_fold):
        if fold in CFG.trn_fold:
            _oof_df = train_loop(train, fold)
            fold_frames.append(_oof_df)
            print(f"========== fold: {fold} result ==========")
            get_result(_oof_df)

    if fold_frames:
        oof_df = pd.concat(fold_frames).reset_index(drop=True)
        print(f"========== CV ==========")
        get_result(oof_df)
        oof_df.to_pickle(OUTPUT_DIR + 'oof_df.pkl')
    else:
        print("No folds trained (CFG.trn_fold is empty); skipping CV scoring.")

In [None]:
from IPython.display import FileLink

# Download links for the best checkpoint of every trained fold. The names are
# built the same way train_loop builds its save path, instead of being hard-coded.
checkpoint_links = [
    FileLink(OUTPUT_DIR + f"{CFG.model.replace('/', '-')}_fold{fold}_best.pth")
    for fold in CFG.trn_fold
]
display(*checkpoint_links)

In [None]:
"""
2e-5, epoch 3, fold 0:
count_diff: 670
Epoch 3 - avg_train_loss: 0.00624  avg_val_loss: 0.01063  time: 2410s
Epoch 3 - Score: 0.88539
"""

In [None]:
"""
3e-5, epoch 4
fold 0:
Epoch 4 - avg_train_loss: 0.00447  avg_val_loss: 0.01256
Epoch 4 - Score: 0.88294


fold 1:
count_diff: 634
Epoch 3 - avg_train_loss: 0.00701  avg_val_loss: 0.01207
Epoch 3 - Score: 0.88259

fold 2:
count_diff: 696
Epoch 4 - avg_train_loss: 0.00491  avg_val_loss: 0.01241
Epoch 4 - Score: 0.88047
Epoch 4 - Save Best Score: 0.88047 Model

fold 3:
count_diff: 664
Epoch 4 - avg_train_loss: 0.00424  avg_val_loss: 0.01338
Epoch 4 - Score: 0.87419


fold4: 
weight_decay: 0.01
count_diff: 679
Epoch 4 - avg_train_loss: 0.00406  avg_val_loss: 0.01271
Epoch 4 - Score: 0.88178

weight_decay: 0.0001
Epoch 4 - avg_train_loss: 0.00408  avg_val_loss: 0.01286
Epoch 4 - Score: 0.88392
"""