In [None]:
# Install PyTorch Lightning 1.6.0 from a local wheel file under ../input.
!pip install ../input/pytorchlightning160/pytorch_lightning-1.6.0-py3-none-any.whl
# !pip install pytorch-crf

In [None]:
# DeBERTa in this transformers install does not provide a FastTokenizer, but we need
# `offset_mapping`, which only fast tokenizers return — so patch one in by copying
# replacement source files into the installed transformers package.
# The following is necessary if you want to use the fast tokenizer for deberta v2 or v3
import shutil
from pathlib import Path

# NOTE(review): hard-coded to a Python 3.7 site-packages layout; adjust if the
# interpreter or install location differs.
transformers_path = Path("/opt/conda/lib/python3.7/site-packages/transformers")

# Input dataset that ships the replacement source files.
input_dir = Path("../input/deberta-v2-3-fast-tokenizer")

# Replace transformers/convert_slow_tokenizer.py with the version from input_dir.
convert_file = input_dir / "convert_slow_tokenizer.py"
conversion_path = transformers_path/convert_file.name

if conversion_path.exists():
    conversion_path.unlink()

shutil.copy(convert_file, transformers_path)
deberta_v2_path = transformers_path / "models" / "deberta_v2"

# Copy the DeBERTa-v2 tokenizer modules into transformers/models/deberta_v2/.
# "deberta__init__.py" is installed as "__init__.py": the "deberta" prefix is
# stripped from any file name that starts with it.
for filename in ['tokenization_deberta_v2.py', 'tokenization_deberta_v2_fast.py', "deberta__init__.py"]:
    if str(filename).startswith("deberta"):
        filepath = deberta_v2_path/str(filename).replace("deberta", "")
    else:
        filepath = deberta_v2_path/filename
    # Remove any existing file first so the copy fully replaces it.
    if filepath.exists():
        filepath.unlink()

    shutil.copy(input_dir/filename, filepath)

## Library

In [None]:
import os
import gc
import re
import ast
import sys
import copy
import json
import time
import math
import string
import pickle
import random
import joblib
import itertools
import warnings
warnings.filterwarnings("ignore")

from pathlib import Path

import scipy as sp
import numpy as np
import pandas as pd
pd.set_option('display.max_rows', 500)
pd.set_option('display.max_columns', 500)
pd.set_option('display.width', 1000)
from tqdm.auto import tqdm
from sklearn.metrics import f1_score
from sklearn.model_selection import StratifiedKFold, GroupKFold, KFold

import torch
import torch.nn as nn
from torch.nn import Parameter
import torch.nn.functional as F
from torch.optim import Adam, SGD, AdamW
from torch.utils.data import DataLoader, Dataset
# from torchcrf import CRF
import pytorch_lightning as pl
from pathlib import Path

# os.system('pip uninstall -y transformers')
# os.system('python -m pip install --no-index --find-links=../input/nbme-pip-wheels transformers')
import tokenizers
import transformers
print(f"tokenizers.__version__: {tokenizers.__version__}")
print(f"transformers.__version__: {transformers.__version__}")
from transformers import AutoTokenizer, AutoModel, AutoConfig
from transformers import get_linear_schedule_with_warmup, get_cosine_schedule_with_warmup
%env TOKENIZERS_PARALLELISM=true

# device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

**数据解析**

本次比赛提供了 5 份数据，分别是 train、test、features、patient_notes 和 submission，其中 test 和 submission 仅在提交答案时使用。

重点是如下3个文件

train 文件标注了每个病例中不同症状的相关描述（annotation 文本及其字符位置 location）。

features 中给出了所有病症的名称和id

patient_notes 中给出了每份病例的详细描述

总体来说，我们希望通过对病例和症状的分析与挖掘，自动找出不同病症在病例中的相关描述。

## CFG

In [None]:
class CFG:
    """Global experiment configuration, read as class attributes (never instantiated)."""
    # --- General ---
    data_dir = '../input/nbme-score-clinical-patient-notes'
    output_dir = './'
    debug = False
    debug_size = 0
    train = True
    seed = 6001
    n_fold = 5
    trn_fold = [0]
    print_freq = 100

    # --- Data ---
    num_workers = 4
    batch_size = 4
    max_len = 512
    pin_memory = True

    # --- Model ---
    model = "../input/deberta-v3-large/deberta-v3-large"
    fc_dropout = 0.2
    fgm = False
    label_smooth = False
    smoothing = 0.05

    # --- Optimizer / scheduler ---
    scheduler = 'cosine'  # ['linear', 'cosine']
    num_cycles = 0.5
    warmup_steps = 0.1
    encoder_lr = 2e-5
    decoder_lr = 2e-5
    eps = 1e-6
    betas = (0.9, 0.999)
    weight_decay = 0.01

    # --- Trainer ---
    apex = False
    apex_level = 'O1'
    max_epochs = 5
    gradient_accumulation_steps = 4
    precision = 32
    max_grad_norm = 1
    fast_dev_run = 0  # quick sanity run: use only n train/val/test batches
    num_sanity_val_steps = 0  # number of val batches run before training starts
    val_check_interval = 0.5


def apply_debug_overrides(cfg):
    """Shrink a CFG-like class for a quick debug run (mutates and returns ``cfg``).

    The epoch count is written to ``max_epochs``, the attribute the trainer
    settings use; assigning a separate ``epochs`` attribute would have no effect.
    """
    cfg.debug_size = 100
    cfg.max_epochs = 2
    cfg.trn_fold = [0]
    cfg.fast_dev_run = 2
    cfg.num_sanity_val_steps = 0
    cfg.val_check_interval = 0.5
    return cfg


if CFG.debug:
    apply_debug_overrides(CFG)


## Helper functions for scoring

In [None]:
# From https://www.kaggle.com/theoviel/evaluation-metric-folds-baseline

def micro_f1(preds, truths):
    """Micro-averaged F1 over binary arrays.

    All per-sample arrays are pooled into one long vector each before scoring,
    so every character counts equally regardless of which sample it came from.

    Args:
        preds (list of array-like of 0/1): Predicted binary arrays.
        truths (list of array-like of 0/1): Ground-truth binary arrays.

    Returns:
        float: F1 score of the pooled arrays.
    """
    pooled_preds = np.concatenate(preds)
    pooled_truths = np.concatenate(truths)
    # sklearn expects (y_true, y_pred).
    return f1_score(pooled_truths, pooled_preds)


def spans_to_binary(spans, length=None):
    """Rasterise character spans into a 0/1 float mask.

    Args:
        spans (list of [start, end]): Half-open character spans.
        length (int, optional): Mask length; defaults to the largest number in ``spans``.

    Returns:
        np.ndarray of shape [length]: 1.0 where a character lies inside any span.
    """
    if length is None:
        length = np.max(spans)
    mask = np.zeros(length)
    for span_start, span_end in spans:
        mask[span_start:span_end] = 1
    return mask


def span_micro_f1(truths, preds):
    """Micro F1 between ground-truth and predicted character spans.

    Each (truth, pred) pair is rasterised to 0/1 arrays of a common length (the
    largest span boundary in either), pairs where both are empty are skipped,
    and the pooled arrays are scored with ``micro_f1``.

    Args:
        truths (list of lists of [start, end]): Ground-truth spans per sample.
        preds (list of lists of [start, end]): Predicted spans per sample.

    Returns:
        float: F1 score.
    """
    pred_masks = []
    truth_masks = []
    for pred, truth in zip(preds, truths):
        if len(pred) == 0 and len(truth) == 0:
            continue
        pred_extent = np.max(pred) if len(pred) else 0
        truth_extent = np.max(truth) if len(truth) else 0
        common_len = max(pred_extent, truth_extent)
        pred_masks.append(spans_to_binary(pred, common_len))
        truth_masks.append(spans_to_binary(truth, common_len))
    return micro_f1(pred_masks, truth_masks)

### Label Smoothing

In [None]:
class LabelSmoothLoss(nn.Module):
    """Label smoothing for per-element binary classification.

    Every output is treated as an independent binary problem. Positions whose
    target equals 1 get the soft target ``1 - smoothing``. Every other position
    (any target value other than 1) gets ``smoothing``. The wrapped
    ``loss_func`` is then applied to the raw logits and these soft targets.

    Args:
        smoothing: Amount of label smoothing.
        loss_func: Criterion called as ``loss_func(logits, soft_targets)``.
            Defaults to a fresh ``nn.BCEWithLogitsLoss(reduction='sum')`` per
            instance.
    """

    def __init__(self, smoothing=0.0, loss_func=None):
        super(LabelSmoothLoss, self).__init__()
        self.smoothing = smoothing
        # Build the default per instance: a module object used as a default
        # argument is created once and shared by every LabelSmoothLoss.
        if loss_func is None:
            loss_func = nn.BCEWithLogitsLoss(reduction='sum')
        self.loss_func = loss_func

    def forward(self, inputs, target):
        """Compute the loss.

        Args:
            inputs: Raw logits (no sigmoid applied).
            target: Hard labels with the same shape as ``inputs``.
                Only entries equal to 1 count as positives.

        Returns:
            Whatever ``loss_func`` returns: a scalar for reduction='sum',
            per-element losses for reduction='none'.
        """
        # Multi-label is viewed as many binary problems, so there is no division
        # by the number of classes (unlike softmax label smoothing).
        soft_target = inputs.new_ones(inputs.size()) * self.smoothing
        soft_target[target == 1] = 1. - self.smoothing
        return self.loss_func(inputs, soft_target)

In [None]:
def create_labels_for_scoring(df):
    """Turn the ``location`` column into integer spans for scoring.

    Each row's ``location`` is a list of strings such as ``'70 91'``. Some
    contain ';' (e.g. ``'70 91;176 183'``), which joins the pieces of a
    discontinuous annotation; every piece becomes one ``[start, end]`` pair.
    A row with no annotation yields ``[]``.

    Returns:
        list of per-row span lists, e.g.
        [[[696, 724]], [], [[70, 91], [176, 183]]]
    """
    def _row_to_spans(locations):
        pieces = (piece for location in locations for piece in location.split(';'))
        return [[int(start), int(end)] for start, end in (piece.split() for piece in pieces)]

    return [_row_to_spans(locations) for locations in df['location']]

def get_char_probs(features, texts, predictions, tokenizer):
    """Spread token-level probabilities back onto the characters of each text.

    Each (feature, text) pair is re-tokenized with offsets. Every token in the
    second sequence (the patient note, sequence id 1) copies its predicted
    probability onto the characters covered by its offset span. All other
    characters stay 0.

    Args:
        features: Feature texts (first sequence of each pair).
        texts: Patient-note texts (second sequence of each pair).
        predictions: Per-token probabilities for each pair.
        tokenizer: Fast tokenizer (must support ``return_offsets_mapping``
            and ``sequence_ids()``).

    Returns:
        list of np.ndarray, one per text, with one probability per character.
    """
    char_probs = [np.zeros(len(text)) for text in texts]
    for note_idx, (feature, text, token_preds) in enumerate(zip(features, texts, predictions)):
        encoding = tokenizer(feature, text, add_special_tokens=True, return_offsets_mapping=True)
        seq_ids = encoding.sequence_ids()
        # Predictions may be padded beyond the encoding; zip stops at the shorter one.
        token_iter = zip(encoding['offset_mapping'], token_preds)
        for token_idx, ((char_start, char_end), prob) in enumerate(token_iter):
            if seq_ids[token_idx] == 1:
                # All characters of one token share that token's probability.
                char_probs[note_idx][char_start:char_end] = prob
    return char_probs

def get_results(char_probs, th=0.5):
    """Threshold per-character probabilities into "start end" span strings.

    Character indices with probability >= ``th`` are shifted by +1, grouped
    into runs of consecutive indices, and each run is written as
    ``"<first> <last>"``. Runs are joined with ';'. A run over 0-based chars
    i..j is therefore emitted as ``"i+1 j+1"``: the end is the exclusive end
    of the run, and the start skips its first character.

    NOTE(review): the shift looks intended to drop the leading space that
    DeBERTa offsets attach to words (an unshifted variant was left commented
    out in earlier versions); confirm before changing.

    Returns:
        list of str, e.g. ['2 3;11 14;30 32;44 46']; '' when nothing passes.
    """
    span_strings = []
    for char_prob in char_probs:
        positions = np.where(char_prob >= th)[0] + 1
        runs = []
        for pos in positions:
            if runs and pos == runs[-1][1] + 1:
                runs[-1][1] = pos  # extend the current run of consecutive indices
            else:
                runs.append([pos, pos])  # start a new run
        span_strings.append(";".join(f"{first} {last}" for first, last in runs))
    return span_strings


def get_predictions(results):
    """Parse span strings like ``'2 3;11 14'`` into integer ``[start, end]`` lists.

    Returns:
        list of span lists, e.g. [[[2, 3], [11, 14]], []] for ['2 3;11 14', ''].
    """
    def _parse(span_str):
        if span_str == "":
            return []
        fields_per_span = (chunk.split() for chunk in span_str.split(';'))
        return [[int(fields[0]), int(fields[1])] for fields in fields_per_span]

    return [_parse(result) for result in results]


def get_score(y_true, y_pred):
    """Competition metric: span-level micro F1 of ``y_pred`` against ``y_true``."""
    return span_micro_f1(y_true, y_pred)

In [None]:
# def create_labels_for_scoring_n(df):
#     # example: ['0 1', '3 4'] -> ['0 1; 3 4']
#     df['location_for_create_labels'] = [ast.literal_eval(f'[]')] * len(df)
#     for i in range(len(df)):
#         lst = df.loc[i, 'location']
#         if lst:
#             new_lst = ';'.join(lst)
#             df.loc[i, 'location_for_create_labels'] = ast.literal_eval(f'[["{new_lst}"]]')
#     # create labels
#     truths = []
#     for location_list in df['location_for_create_labels'].values:
#         truth = []
#         if len(location_list) > 0:
#             location = location_list[0]
#             for loc in [s.split() for s in location.split(';')]:
#                 start, end = int(loc[0]), int(loc[1])
#                 truth.append([start, end])
#         truths.append(truth)
#     return truths



# def get_char_probs_n(texts, predictions, tokenizer):
#     # 获取每个字符所属类别的概率
#     results = [np.zeros(len(t)) for t in texts]

#     for i, (text, prediction) in enumerate(zip(texts, predictions)):
#         encoded = tokenizer(text,
#                             add_special_tokens=True,
#                             return_offsets_mapping=True)
#         # 这里 offset_mapping 和 prediction 的长度可能不一致，因为 predictions 带有填充，但是 zip 自动丢弃了多余的部分
#         for idx, (offset_mapping, pred) in enumerate(zip(encoded['offset_mapping'], prediction)):
#             start = offset_mapping[0]
#             end = offset_mapping[1]
#             # 属于同一个 token 的 char 的概率统一为该 token 的概率
#             results[i][start:end] = pred
#     return results

# def get_results(char_probs, th=0.5):
#     # 获得预测结果大于 th 的 char index， 并获得其起止位置， 用 “；”隔开每一对
#     results = []
#     for char_prob in char_probs:
#         result = np.where(char_prob >= th)[0] + 1
# #         result = np.where(char_prob >= th)[0]
#         # itertools.count()： 计数器，默认从 0 开始
#         # itertools.groupby(res, key)，将 res 中所有 key 相同的元素进行分组
#         result = [list(g) for _, g in itertools.groupby(result, key=lambda n, c=itertools.count(): n - next(c))]
#         result = [f"{min(r)} {max(r)}" for r in result]
# #         result = [f"{min(r)} {max(r) + 1}" for r in result]
#         result = ";".join(result)
#         results.append(result)
#         '''
#         返回形式如下：
#         ['2 3;11 14;30 32;44 46']
#         '''
#     return results


# def get_predictions(results):
#     # 将 str 类型的预测的 char index 转为 int 型
#     predictions = []
#     for result in results:
#         prediction = []
#         if result != "":
#             for loc in [s.split() for s in result.split(';')]:
#                 start, end = int(loc[0]), int(loc[1])
#                 prediction.append([start, end])
#         predictions.append(prediction)
#     '''
#     返回形式如下：
#     [[2,3], [11, 14], [30, 32], [44, 46]]
#     '''
#     return predictions


# def get_score(y_true, y_pred):
#     score = span_micro_f1(y_true, y_pred)
#     return score

## Utils

### Fixing incorrect annotations

In [None]:
def process_feature_text(text):
    """Normalise a feature description.

    Fixes 'I-year' to '1-year', turns '-OR-' into ' or ', then replaces every
    remaining hyphen with a space.
    """
    for pattern, replacement in (('I-year', '1-year'), ('-OR-', " or "), ('-', ' ')):
        text = re.sub(pattern, replacement, text)
    return text

# incorrect annotation
def correcting(train, features, patient_notes):
    """Apply hand-written fixes to known-bad training annotations and normalise feature text.

    The three frames are modified in place (``.loc`` cell assignment / column
    assignment) and also returned:

    * for each listed row label of ``train``, the ``annotation`` and
      ``location`` cells are overwritten with corrected values parsed by
      ``ast.literal_eval``;
    * ``features.loc[27, 'feature_text']`` is overwritten;
    * in ``pn_history``, 'heart attcak' is replaced by 'heart attack' (same
      length, so character offsets are unchanged);
    * ``process_feature_text`` is applied to every feature text.

    The row labels (338, 621, ...) address ``train`` by index label. They only
    point at the intended rows if ``train`` still has the index it had when
    read from train.csv (no filtering or re-indexing beforehand).

    NOTE(review): each corrected value is a list of one-element lists. Confirm
    how ``.loc`` stores it in a single cell with the pandas version in use:
    ``create_labels_for_scoring`` and ``add_labels`` iterate a ``location``
    cell expecting its elements to be strings.
    """
    train.loc[338, 'annotation'] = ast.literal_eval('[["father heart attack"]]')
    train.loc[338, 'location'] = ast.literal_eval('[["764 783"]]')

    train.loc[621, 'annotation'] = ast.literal_eval('[["for the last 2-3 months"]]')
    train.loc[621, 'location'] = ast.literal_eval('[["77 100"]]')

    train.loc[655, 'annotation'] = ast.literal_eval('[["no heat intolerance"], ["no cold intolerance"]]')
    train.loc[655, 'location'] = ast.literal_eval('[["285 292;301 312"], ["285 287;296 312"]]')

    train.loc[1262, 'annotation'] = ast.literal_eval('[["mother thyroid problem"]]')
    train.loc[1262, 'location'] = ast.literal_eval('[["551 557;565 580"]]')

    train.loc[1265, 'annotation'] = ast.literal_eval('[[\'felt like he was going to "pass out"\']]')
    train.loc[1265, 'location'] = ast.literal_eval('[["131 135;181 212"]]')

    train.loc[1396, 'annotation'] = ast.literal_eval('[["stool , with no blood"]]')
    train.loc[1396, 'location'] = ast.literal_eval('[["259 280"]]')

    train.loc[1591, 'annotation'] = ast.literal_eval('[["diarrhoe non blooody"]]')
    train.loc[1591, 'location'] = ast.literal_eval('[["176 184;201 212"]]')

    train.loc[1615, 'annotation'] = ast.literal_eval('[["diarrhea for last 2-3 days"]]')
    train.loc[1615, 'location'] = ast.literal_eval('[["249 257;271 288"]]')

    train.loc[1664, 'annotation'] = ast.literal_eval('[["no vaginal discharge"]]')
    train.loc[1664, 'location'] = ast.literal_eval('[["822 824;907 924"]]')

    train.loc[1714, 'annotation'] = ast.literal_eval('[["started about 8-10 hours ago"]]')
    train.loc[1714, 'location'] = ast.literal_eval('[["101 129"]]')

    train.loc[1929, 'annotation'] = ast.literal_eval('[["no blood in the stool"]]')
    train.loc[1929, 'location'] = ast.literal_eval('[["531 539;549 561"]]')

    train.loc[2134, 'annotation'] = ast.literal_eval('[["last sexually active 9 months ago"]]')
    train.loc[2134, 'location'] = ast.literal_eval('[["540 560;581 593"]]')

    train.loc[2191, 'annotation'] = ast.literal_eval('[["right lower quadrant pain"]]')
    train.loc[2191, 'location'] = ast.literal_eval('[["32 57"]]')

    train.loc[2553, 'annotation'] = ast.literal_eval('[["diarrhoea no blood"]]')
    train.loc[2553, 'location'] = ast.literal_eval('[["308 317;376 384"]]')

    train.loc[3124, 'annotation'] = ast.literal_eval('[["sweating"]]')
    train.loc[3124, 'location'] = ast.literal_eval('[["549 557"]]')

    train.loc[3858, 'annotation'] = ast.literal_eval(
        '[["previously as regular"], ["previously eveyr 28-29 days"], ["previously lasting 5 days"], ["previously regular flow"]]')
    train.loc[3858, 'location'] = ast.literal_eval(
        '[["102 123"], ["102 112;125 141"], ["102 112;143 157"], ["102 112;159 171"]]')

    train.loc[4373, 'annotation'] = ast.literal_eval('[["for 2 months"]]')
    train.loc[4373, 'location'] = ast.literal_eval('[["33 45"]]')

    train.loc[4763, 'annotation'] = ast.literal_eval('[["35 year old"]]')
    train.loc[4763, 'location'] = ast.literal_eval('[["5 16"]]')

    train.loc[4782, 'annotation'] = ast.literal_eval('[["darker brown stools"]]')
    train.loc[4782, 'location'] = ast.literal_eval('[["175 194"]]')

    train.loc[4908, 'annotation'] = ast.literal_eval('[["uncle with peptic ulcer"]]')
    train.loc[4908, 'location'] = ast.literal_eval('[["700 723"]]')

    train.loc[6016, 'annotation'] = ast.literal_eval('[["difficulty falling asleep"]]')
    train.loc[6016, 'location'] = ast.literal_eval('[["225 250"]]')

    train.loc[6192, 'annotation'] = ast.literal_eval('[["helps to take care of aging mother and in-laws"]]')
    train.loc[6192, 'location'] = ast.literal_eval('[["197 218;236 260"]]')

    train.loc[6380, 'annotation'] = ast.literal_eval(
        '[["No hair changes"], ["No skin changes"], ["No GI changes"], ["No palpitations"], ["No excessive sweating"]]')
    train.loc[6380, 'location'] = ast.literal_eval(
        '[["480 482;507 519"], ["480 482;499 503;512 519"], ["480 482;521 531"], ["480 482;533 545"], ["480 482;564 582"]]')

    train.loc[6562, 'annotation'] = ast.literal_eval(
        '[["stressed due to taking care of her mother"], ["stressed due to taking care of husbands parents"]]')
    train.loc[6562, 'location'] = ast.literal_eval('[["290 320;327 337"], ["290 320;342 358"]]')

    train.loc[6862, 'annotation'] = ast.literal_eval('[["stressor taking care of many sick family members"]]')
    train.loc[6862, 'location'] = ast.literal_eval('[["288 296;324 363"]]')

    train.loc[7022, 'annotation'] = ast.literal_eval(
        '[["heart started racing and felt numbness for the 1st time in her finger tips"]]')
    train.loc[7022, 'location'] = ast.literal_eval('[["108 182"]]')

    train.loc[7422, 'annotation'] = ast.literal_eval('[["first started 5 yrs"]]')
    train.loc[7422, 'location'] = ast.literal_eval('[["102 121"]]')

    train.loc[8876, 'annotation'] = ast.literal_eval('[["No shortness of breath"]]')
    train.loc[8876, 'location'] = ast.literal_eval('[["481 483;533 552"]]')

    train.loc[9027, 'annotation'] = ast.literal_eval('[["recent URI"], ["nasal stuffines, rhinorrhea, for 3-4 days"]]')
    train.loc[9027, 'location'] = ast.literal_eval('[["92 102"], ["123 164"]]')

    train.loc[9938, 'annotation'] = ast.literal_eval(
        '[["irregularity with her cycles"], ["heavier bleeding"], ["changes her pad every couple hours"]]')
    train.loc[9938, 'location'] = ast.literal_eval('[["89 117"], ["122 138"], ["368 402"]]')

    train.loc[9973, 'annotation'] = ast.literal_eval('[["gaining 10-15 lbs"]]')
    train.loc[9973, 'location'] = ast.literal_eval('[["344 361"]]')

    train.loc[10513, 'annotation'] = ast.literal_eval('[["weight gain"], ["gain of 10-16lbs"]]')
    train.loc[10513, 'location'] = ast.literal_eval('[["600 611"], ["607 623"]]')

    train.loc[11551, 'annotation'] = ast.literal_eval('[["seeing her son knows are not real"]]')
    train.loc[11551, 'location'] = ast.literal_eval('[["386 400;443 461"]]')

    train.loc[11677, 'annotation'] = ast.literal_eval('[["saw him once in the kitchen after he died"]]')
    train.loc[11677, 'location'] = ast.literal_eval('[["160 201"]]')

    train.loc[12124, 'annotation'] = ast.literal_eval('[["tried Ambien but it didnt work"]]')
    train.loc[12124, 'location'] = ast.literal_eval('[["325 337;349 366"]]')

    train.loc[12279, 'annotation'] = ast.literal_eval(
        '[["heard what she described as a party later than evening these things did not actually happen"]]')
    train.loc[12279, 'location'] = ast.literal_eval('[["405 459;488 524"]]')

    train.loc[12289, 'annotation'] = ast.literal_eval(
        '[["experienced seeing her son at the kitchen table these things did not actually happen"]]')
    train.loc[12289, 'location'] = ast.literal_eval('[["353 400;488 524"]]')

    train.loc[13238, 'annotation'] = ast.literal_eval('[["SCRACHY THROAT"], ["RUNNY NOSE"]]')
    train.loc[13238, 'location'] = ast.literal_eval('[["293 307"], ["321 331"]]')

    train.loc[13297, 'annotation'] = ast.literal_eval(
        '[["without improvement when taking tylenol"], ["without improvement when taking ibuprofen"]]')
    train.loc[13297, 'location'] = ast.literal_eval('[["182 221"], ["182 213;225 234"]]')

    train.loc[13299, 'annotation'] = ast.literal_eval('[["yesterday"], ["yesterday"]]')
    train.loc[13299, 'location'] = ast.literal_eval('[["79 88"], ["409 418"]]')

    train.loc[13845, 'annotation'] = ast.literal_eval('[["headache global"], ["headache throughout her head"]]')
    train.loc[13845, 'location'] = ast.literal_eval('[["86 94;230 236"], ["86 94;237 256"]]')

    train.loc[14083, 'annotation'] = ast.literal_eval('[["headache generalized in her head"]]')
    train.loc[14083, 'location'] = ast.literal_eval('[["56 64;156 179"]]')

    # Overwrite feature 27's text outright; process_feature_text below then
    # turns its hyphens into spaces.
    features.loc[27, 'feature_text'] = "Last-Pap-smear-1-year-ago"

    # Typo fix in the note text; both spellings have the same length, so
    # existing character locations stay valid.
    patient_notes['pn_history'] = patient_notes['pn_history'].apply(
        lambda x: x.replace('dad with recent heart attcak', 'dad with recent heart attack'))

    features['feature_text'] = features['feature_text'].apply(process_feature_text)

    return train, features, patient_notes

### CV split

In [None]:
def CV_group_split(dataset, n_splits=5, debug=False, debug_size=1000):
    """Assign a GroupKFold fold id to every row, grouping by patient note.

    One patient note (``pn_num``) carries several feature annotations, so the
    split is done at the note level. All rows of a note land in the same fold,
    and a note never appears in both the training and the validation split.

    Args:
        dataset: DataFrame with ``pn_num`` and ``location`` columns. An int
            ``fold`` column is written onto it in place.
        n_splits: Number of folds.
        debug: If True, return a random sample of ``debug_size`` rows.
        debug_size: Sample size used when ``debug`` is True.

    Returns:
        ``dataset`` itself (with ``fold`` added), or a re-indexed random sample
        of it in debug mode.
    """
    splitter = GroupKFold(n_splits=n_splits)
    groups = dataset['pn_num'].values
    # GroupKFold yields *positional* indices. Collect them in a positional array
    # instead of label-based ``.loc``, which misassigns folds (or appends new
    # rows) when ``dataset`` does not have a 0..n-1 RangeIndex.
    fold_ids = np.full(len(dataset), -1, dtype=int)
    for fold_id, (_, val_index) in enumerate(splitter.split(dataset, dataset['location'], groups)):
        fold_ids[val_index] = fold_id
    dataset['fold'] = fold_ids
    if debug:
        dataset = dataset.sample(n=debug_size, random_state=0).reset_index(drop=True)
    return dataset

### Tokenizer

In [None]:
def get_tokenizer(tokenizer_path):
    """Load the tokenizer stored at ``tokenizer_path`` (a str or ``os.PathLike``).

    Paths whose text contains "deberta" are loaded with
    ``DebertaV2TokenizerFast``, a fast tokenizer that can return
    ``offset_mapping``. Everything else goes through ``AutoTokenizer``.

    NOTE(review): the substring test also matches DeBERTa v1 checkpoints, which
    use a different tokenizer class.
    """
    # str() so that pathlib.Path inputs work (``'x' in Path(...)`` raises TypeError).
    if 'deberta' in str(tokenizer_path):
        from transformers.models.deberta_v2 import DebertaV2TokenizerFast
        return DebertaV2TokenizerFast.from_pretrained(tokenizer_path)
    return AutoTokenizer.from_pretrained(tokenizer_path)


## DataModule

In [None]:
class NBMEDataModule(pl.LightningDataModule):
    """LightningDataModule for the NBME patient-note span task.

    ``prepare_data`` loads the CSVs once. For train, it also corrects them and
    assigns folds. ``setup`` builds the torch Datasets for the requested stage,
    and the ``*_dataloader`` hooks wrap them in DataLoaders.

    Call ``set_trn_fold(k)`` before ``setup('fit')``: fold ``k`` becomes the
    validation split and all other folds are used for training.
    """

    def __init__(self, config, prepare_train=True, prepare_test=True):
        super().__init__()
        # NOTE(review): Lightning runs prepare_data on a single process, but
        # prepare_data below also stores state (self.train / self.test) that
        # setup() reads. This is only safe for single-process training.
        self.prepare_data_per_node = False
        self.debug = config.debug
        self.debug_size = config.debug_size if self.debug else 0
        # Keep batch order fixed while debugging.
        self.shuffle = not self.debug
        self.batch_size = config.batch_size
        self.pin_memory = config.pin_memory
        self.num_workers = config.num_workers
        self.max_len = config.max_len
        self.data_dir = config.data_dir
        self.n_fold = config.n_fold

        self.tokenizer = get_tokenizer(config.model)
        # One-shot flags: prepare_data() loads each split once, then clears its flag.
        self.prepare_train = prepare_train
        self.prepare_test = prepare_test

    def set_trn_fold(self, trn_fold):
        """Select the fold used for validation by the next ``setup('fit')``."""
        self.trn_fold = trn_fold

    def load_train(self):
        """Load train.csv, apply annotation fixes and join feature / patient-note text."""
        data_dir = Path(self.data_dir)
        train = pd.read_csv(data_dir / 'train.csv')
        features = pd.read_csv(data_dir / 'features.csv')
        patient_notes = pd.read_csv(data_dir / 'patient_notes.csv')
        # annotation / location are stored as Python-literal lists in the CSV.
        train['annotation'] = train['annotation'].apply(ast.literal_eval)
        train['location'] = train['location'].apply(ast.literal_eval)

        train, features, patient_notes = correcting(train, features, patient_notes)
        train = train.merge(features, on=['feature_num', 'case_num'], how='left')
        train = train.merge(patient_notes, on=['pn_num', 'case_num'], how='left')
        train['annotation_length'] = train['annotation'].apply(len)
        return train

    def load_test(self):
        """Load test.csv (joined with feature / patient-note text) and the sample submission."""
        data_dir = Path(self.data_dir)
        test = pd.read_csv(data_dir / 'test.csv')
        features = pd.read_csv(data_dir / 'features.csv')
        patient_notes = pd.read_csv(data_dir / 'patient_notes.csv')
        submission = pd.read_csv(data_dir / 'sample_submission.csv')

        # Normalise feature text exactly as ``correcting`` does for the training
        # data. Otherwise the model would be trained on e.g. "Last Pap smear 1
        # year ago" but queried with "Last-Pap-smear-1-year-ago" at inference.
        # (process_feature_text is idempotent, so applying it again is harmless.)
        features.loc[27, 'feature_text'] = "Last-Pap-smear-1-year-ago"
        features['feature_text'] = features['feature_text'].apply(process_feature_text)

        test = test.merge(features, on=['feature_num', 'case_num'], how='left')
        test = test.merge(patient_notes, on=['pn_num', 'case_num'], how='left')
        return test, submission

    def calculate_max_len(self, dataset):
        """Token-length statistics for the (feature_text, pn_history) pairs.

        Returns:
            (feature_text_lens, pn_history_lens, max_len). The first two are
            per-row token counts without special tokens. ``max_len`` is the sum
            of their maxima plus 3 (cls & sep & sep).
        """
        # fillna returns a new Series; use it (missing text counts as empty)
        # without mutating ``dataset``.
        pn_histories = dataset['pn_history'].fillna('')
        feature_texts = dataset['feature_text'].fillna('')
        tqdm.pandas(desc="pn_history_lens")
        pn_history_lens = pn_histories.progress_apply(
            lambda x: len(self.tokenizer(x, add_special_tokens=False)['input_ids']))
        tqdm.pandas(desc="feature_text_lens")
        feature_text_lens = feature_texts.progress_apply(
            lambda x: len(self.tokenizer(x, add_special_tokens=False)['input_ids']))
        max_len_feat = feature_text_lens.max()
        max_len_pn = pn_history_lens.max()
        return (feature_text_lens, pn_history_lens, max_len_feat + max_len_pn + 3)  # cls & sep & sep

    def prepare_data(self):
        """Load (once) the train split with fold ids and/or the test split."""
        if self.prepare_train:
            train = self.load_train()
            # Split the data into n folds, grouped by patient note.
            train = CV_group_split(train, self.n_fold, self.debug, self.debug_size)
            self.train_max_len = self.calculate_max_len(train)[2]
            self.train = train
            self.prepare_train = False
            print('Train data prepared!')

        if self.prepare_test:
            self.test, self.submission = self.load_test()
            self.prepare_test = False
            print('Test data prepared!')

    def setup(self, stage='fit'):
        """Build the dataset(s) for ``stage``; 'validate' reuses the fit datasets."""
        if stage in ('fit', 'validate'):
            self.build_fit_dataset(trn_fold=self.trn_fold)

        elif stage == 'test':
            self.build_test_dataset()

        elif stage == 'predict':
            self.build_predict_dataset()

    def build_fit_dataset(self, trn_fold=None):
        """Split ``self.train`` into train / validation datasets; fold ``trn_fold`` is validation."""
        df = self.train
        if trn_fold is not None:
            self.train_df = df[df['fold'] != trn_fold].reset_index(drop=True)
            self.val_df = df[df['fold'] == trn_fold].reset_index(drop=True)
            self.train_dataset = NBMEDataset(self.train_df, self.tokenizer, self.train_max_len)
            self.val_dataset = NBMEDataset(self.val_df, self.tokenizer, self.train_max_len)

    def build_test_dataset(self):
        self.test_dataset = NBMEInferDataset(self.test, self.tokenizer, self.max_len)

    def build_predict_dataset(self):
        self.predict_dataset = NBMEInferDataset(self.test, self.tokenizer, self.max_len)

    def train_dataloader(self):
        return DataLoader(self.train_dataset, batch_size=self.batch_size, num_workers=self.num_workers,
                          pin_memory=self.pin_memory, shuffle=self.shuffle)

    # Evaluation loaders use 4x the training batch size and never shuffle.
    def val_dataloader(self):
        return DataLoader(self.val_dataset, batch_size=self.batch_size * 4, num_workers=self.num_workers,
                          pin_memory=self.pin_memory, shuffle=False)

    def test_dataloader(self):
        return DataLoader(self.test_dataset, batch_size=self.batch_size * 4, num_workers=self.num_workers,
                          pin_memory=self.pin_memory, shuffle=False)

    def predict_dataloader(self):
        return DataLoader(self.predict_dataset, batch_size=self.batch_size * 4, num_workers=self.num_workers,
                          pin_memory=self.pin_memory, shuffle=False)


### Dataset

In [None]:
def tokenize(tokenizer, singel_data, max_len, return_offsets_mapping=False):
    """Encode one (feature_text, pn_history) row as a padded sentence pair.

    The feature text is the first sequence ("question") and the patient note
    the second ("paragraph"). Special tokens are added and the result is padded
    to ``max_len``. Every field of the tokenizer output is replaced in place by
    a ``torch.long`` tensor.

    Args:
        tokenizer: Tokenizer callable on a text pair.
        singel_data: Row with ``feature_text`` and ``pn_history`` entries.
        max_len: Padding length.
        return_offsets_mapping: Forwarded to the tokenizer.

    Returns:
        The tokenizer's output mapping, with tensor values.
    """
    feature_text = singel_data['feature_text']
    note_text = singel_data['pn_history']
    encoded = tokenizer(
        feature_text,
        note_text,
        add_special_tokens=True,
        max_length=max_len,
        padding='max_length',
        return_offsets_mapping=return_offsets_mapping,
    )
    for field in list(encoded.keys()):
        encoded[field] = torch.tensor(encoded[field], dtype=torch.long)
    return encoded

def add_labels(tokenizer, singel_data, max_len, return_offsets_mapping=False):
    """Build per-token training labels for one (feature_text, pn_history) row.

    The pair is tokenized with the same order, special tokens, max_length and
    padding settings as ``tokenize``, plus offsets. Then:

    * tokens whose sequence id is not 1 (i.e. not part of the patient note;
      presumably special tokens, feature-text tokens and padding) get -1;
    * patient-note tokens start at 0;
    * for every annotated character span in ``location``, the tokens from the
      last one starting at or before the span start through the first one
      ending at or after the span end get 1.

    Args:
        tokenizer: Fast tokenizer (needs ``return_offsets_mapping`` and ``sequence_ids()``).
        singel_data: Row with ``feature_text``, ``pn_history``, ``location`` (list
            of "start end" strings, possibly ';'-joined) and ``annotation`` (unused).
        max_len: Padding length.
        return_offsets_mapping: Unused; offsets are always requested internally.

    Returns:
        float32 tensor with one label per token position.
    """
    features, text, location_list, annotation_list = singel_data[
        ['feature_text', 'pn_history', 'location', 'annotation']]
    inputs = tokenizer(
        features,  # question
        text,  # paragraph
        add_special_tokens=True,  # cls, sep
        max_length=max_len,
        padding='max_length',
        return_offsets_mapping=True,
        return_tensors='pt')
    # Drop the batch dimension added by return_tensors='pt'.
    for k, v in inputs.items():
        inputs[k].squeeze_()

    offset_mapping = inputs.pop('offset_mapping')
    # May contain None entries, so this can be an object array; `!= 1` is elementwise either way.
    sequence_ids = np.array(inputs.sequence_ids())

    # Integer array: assigning 1.0 below stores 1; converted to float32 at the end.
    label = np.where(sequence_ids != 1, -1, 0)
    # First and last token index of the patient note.
    token_start_idx = np.where(label == 0)[0][0]
    token_end_idx = np.where(label == 0)[0][-1]

    # A feature may be annotated several times in the same patient note.
    for i, location in enumerate(location_list):
        # Some locations contain ';' (e.g. row 8478), meaning the annotation is discontinuous:
        # '79 94;100 116' -> ['79 94', '100 116']
        location = [s for s in location.split(';')]
        for loc in location:
            char_start, char_end = map(int, loc.split())
            cur_token_start, cur_token_end = token_start_idx, token_end_idx

            # Each (start, end) in offset_mapping may include a leading space.
            # Advance while in bounds and the token starts at or before char_start,
            # then step back: the last token whose start <= char_start.
            while cur_token_start <= cur_token_end and offset_mapping[cur_token_start][0] <= char_start:
                cur_token_start += 1
            cur_token_start -= 1

            # Walk back from the note's last token while its end >= char_end, then
            # step forward: the first token (scanning back) whose end reaches char_end.
            while cur_token_start <= cur_token_end and offset_mapping[cur_token_end][1] >= char_end:
                cur_token_end -= 1
            cur_token_end += 1

            label[cur_token_start: cur_token_end + 1] = 1.0
    label = torch.tensor(label, dtype=torch.float32)
    return label

In [None]:
class NBMEDataset(Dataset):
    """Training/validation dataset: one (encoded inputs, token labels) pair per row of ``df``."""

    def __init__(self, df, tokenizer, max_len):
        super().__init__()
        self.df, self.tokenizer, self.max_len = df, tokenizer, max_len

    def __len__(self):
        return self.df.shape[0]

    def __getitem__(self, index):
        row = self.df.iloc[index]
        encoded = tokenize(self.tokenizer, row, self.max_len)
        targets = add_labels(self.tokenizer, row, self.max_len, return_offsets_mapping=False)
        return encoded, targets


### FGM

In [None]:
class FGM:
    """Fast Gradient Method (FGM): adversarially perturb embedding weights along their gradient.

    ``attack()`` backs up every trainable parameter whose name contains
    ``embed_name`` and shifts it by ``epsilon * grad / ||grad||``.
    ``restore()`` puts the backed-up weights back. The usual pattern is
    backward -> attack -> forward/backward -> restore -> optimizer step; the
    training step that drives it is outside this class.
    """

    def __init__(self, model, epsilon=0.25):
        self.model = model
        self.epsilon = epsilon
        self.backup = {}

    def attack(self, embed_name='word_embeddings'):
        """Perturb the parameters whose name contains ``embed_name``.

        :param embed_name: substring identifying the embedding parameter(s) in the model
        :return: None
        """
        for name, param in self.model.named_parameters():
            if not (param.requires_grad and embed_name in name):
                continue
            # Always back up so restore() finds every matching parameter.
            self.backup[name] = param.data.clone()
            # A parameter that took no part in the backward pass has no gradient;
            # skip it instead of failing in torch.norm(None).
            if param.grad is None:
                continue
            norm = torch.norm(param.grad)
            if norm != 0 and not torch.isnan(norm):
                r_at = self.epsilon * param.grad / norm
                param.data.add_(r_at)

    def restore(self, embed_name='word_embeddings'):
        """Restore the weights saved by the last ``attack`` and clear the backup."""
        for name, param in self.model.named_parameters():
            if param.requires_grad and embed_name in name:
                assert name in self.backup
                param.data = self.backup[name]
        self.backup = {}




## Model

In [None]:
class AverageMeter:
    """Keep the most recent value plus a running, count-weighted average."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Zero every statistic."""
        self.val = self.avg = self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as observed ``n`` times (e.g. a batch mean over ``n`` samples)."""
        self.val = val
        self.count += n
        self.sum += val * n
        self.avg = self.sum / self.count

In [None]:
class NBMEModel(pl.LightningModule):
    """Token-level span scorer: a Hugging Face encoder with a 1-unit linear head.

    The same head is applied behind five dropout layers of different rates and
    the five logits are averaged (multi-sample dropout). When ``config.fgm`` is
    truthy, training uses FGM adversarial training with Lightning's manual
    optimisation; otherwise Lightning optimises automatically.

    :param config: experiment config object (the notebook's ``CFG``); stored as
        ``self.hparams.config``.
    :param model_config_path: optional path to a ``torch.save``-d model config. When
        falsy, ``AutoConfig.from_pretrained(config.model, output_hidden_states=True)``
        is used.
    :param pretrained: if True load pretrained encoder weights, otherwise build the
        encoder from the config only.
    :param weight_path: optional file holding a state dict (or a dict with a
        ``'state_dict'`` entry) loaded at the end of construction.
    """

    def __init__(self, config, model_config_path=None, pretrained=False, weight_path=None):
        super().__init__()
        self.save_hyperparameters('config')

        if model_config_path:
            # NOTE(review): torch.load unpickles the file -- only load trusted paths.
            self.model_config = torch.load(model_config_path)
        else:
            self.model_config = AutoConfig.from_pretrained(config.model, output_hidden_states=True)
        if pretrained:
            self.model = AutoModel.from_pretrained(config.model, config=self.model_config)
        else:
            self.model = AutoModel.from_config(self.model_config)

        self.fc = nn.Linear(self.model_config.hidden_size, 1)

        # Multi-sample dropout: five rates spread around config.fc_dropout.
        # TODO: layer norm.
        self.dropout_0 = nn.Dropout(config.fc_dropout / 2.)
        self.dropout_1 = nn.Dropout(config.fc_dropout / 1.5)
        self.dropout_2 = nn.Dropout(config.fc_dropout)
        self.dropout_3 = nn.Dropout(config.fc_dropout * 1.5)
        self.dropout_4 = nn.Dropout(config.fc_dropout * 2.)
        self.__init_weight(self.fc)
        self.__set_metrics()

        # reduction="none" keeps a per-position loss so ignored positions can be masked.
        if config.label_smooth:
            self.criterion = LabelSmoothLoss(smoothing=config.smoothing, loss_func=nn.BCEWithLogitsLoss(reduction="none"))
        else:
            self.criterion = nn.BCEWithLogitsLoss(reduction="none")

        # Resolve the FGM switch once (a config without `fgm` means "off") and reuse it
        # in training_step, so both places agree.
        self.use_fgm = bool(getattr(config, 'fgm', False))
        if self.use_fgm:
            # FGM needs a second (adversarial) backward per step -> manual optimisation.
            self.automatic_optimization = False
            self.fgm = FGM(self)

        if weight_path is not None:
            weight = torch.load(weight_path, map_location='cpu')
            if 'state_dict' in weight.keys():
                weight = weight['state_dict']
            self.load_state_dict(weight)

    def __set_metrics(self):
        """Create the loss/accuracy meters (a fresh AverageMeter already starts at zero)."""
        self.train_losses = AverageMeter()
        self.val_losses = AverageMeter()
        self.val_acc = AverageMeter()

    def __init_weight(self, module):
        """Initialise ``module``: weights ~ N(0, initializer_range) for Linear/Embedding,
        zero biases and padding row, LayerNorm set to weight=1 / bias=0."""
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=self.model_config.initializer_range)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=self.model_config.initializer_range)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)

    def forward(self, inputs):
        """Return head logits averaged over the five dropout branches.

        :param inputs: dict of keyword arguments forwarded to the encoder.
        :return: output of ``self.fc``; presumably (batch, seq_len, 1) -- TODO confirm.
        """
        # Only the last hidden state is used. The previous code also read outputs[1],
        # which raises IndexError when the encoder returns a single output.
        last_hidden_states = self.model(**inputs)[0]
        dropouts = (self.dropout_0, self.dropout_1, self.dropout_2, self.dropout_3, self.dropout_4)
        return sum(self.fc(dropout(last_hidden_states)) for dropout in dropouts) / len(dropouts)

    def _masked_loss(self, y_preds, labels):
        """Mean element-wise loss over positions whose label is not -1."""
        loss = self.criterion(y_preds.view(-1, 1), labels.view(-1, 1))
        return torch.masked_select(loss, labels.view(-1, 1) != -1).mean()

    def _current_lrs(self):
        """Return ``(encoder_lr, decoder_lr)`` from the first and last optimizer param groups.

        Read from the optimizer rather than the scheduler so logging also works when
        no scheduler is configured (a LambdaLR's last lr equals the param-group lr).
        """
        param_groups = self.trainer.optimizers[0].param_groups
        return param_groups[0]['lr'], param_groups[-1]['lr']

    def training_step(self, batch, batch_idx):
        inputs, labels = batch
        loss = self._masked_loss(self.forward(inputs), labels)
        self.train_losses.update(loss.item(), len(labels))
        self.log('train/avg_loss', self.train_losses.avg)
        # The optimizer has 3 param groups (encoder decay / encoder no-decay / head):
        # the first carries the encoder LR, the last the decoder (head) LR.
        en_lr, de_lr = self._current_lrs()
        self.log('train/en_lr', en_lr, prog_bar=True)
        self.log('train/de_lr', de_lr, prog_bar=True)

        if self.trainer.global_step % self.hparams.config.print_freq == 0:
            self.print('Global step:{global_step}.'
                       'Train Loss: {loss.val:.4f}(avg: {loss.avg:.4f}) '
                       'Encoder LR: {en_lr:.8f}, Decoder LR: {de_lr:.8f}'
                       .format(global_step=self.trainer.global_step,
                               loss=self.train_losses,
                               en_lr=en_lr,
                               de_lr=de_lr))
        # Without FGM, Lightning optimises automatically from the returned loss.
        # With FGM, backward passes and optimizer/scheduler steps are done by hand.
        if self.use_fgm:
            self._fgm_manual_step(inputs, labels, loss, batch_idx)
        return loss

    def _fgm_manual_step(self, inputs, labels, loss, batch_idx):
        """Backward the clean loss; every `gradient_accumulation_steps` batches also run
        the FGM adversarial pass, clip, and step the optimizer and scheduler."""
        config = self.hparams.config
        # Dividing the loss by gradient_accumulation_steps was left out on purpose:
        # per the original author's note, results were better without it.
        self.manual_backward(loss)
        # Key on batch_idx, not global_step: with automatic optimisation off, global_step
        # only advances after an optimizer step, so keying on it would never reach the
        # boundary (the original author's note calls this an endless loop).
        if (batch_idx + 1) % config.gradient_accumulation_steps != 0:
            return
        self.fgm.attack()
        loss_adv = self._masked_loss(self.forward(inputs), labels)
        self.manual_backward(loss_adv)
        self.fgm.restore()

        # Clip once, on the full (clean + adversarial, accumulated) gradient that is about
        # to be applied. Clipping before the adversarial backward left that part unclipped.
        # NOTE(review): under native AMP these gradients are still loss-scaled -- confirm
        # whether Lightning's self.clip_gradients(...) should be used instead.
        torch.nn.utils.clip_grad_norm_(self.parameters(), config.max_grad_norm)

        opt = self.optimizers()
        opt.step()
        opt.zero_grad()
        sch = self.lr_schedulers()
        if sch is not None:
            sch.step()

    def training_epoch_end(self, outs):
        torch.cuda.empty_cache()

    def validation_step(self, batch, batch_idx):
        """Return ``(masked loss, sigmoid probabilities as a NumPy array)`` for one batch."""
        inputs, labels = batch
        y_preds = self.forward(inputs)
        loss = self._masked_loss(y_preds, labels)
        self.val_losses.update(loss.item(), len(labels))
        self.log('val/avg_loss', self.val_losses.avg)
        return loss, y_preds.sigmoid().cpu().numpy()

    def validation_epoch_end(self, outs):
        """Score the epoch's predictions at character level, log loss/score, reset meters."""
        preds = np.concatenate([item[1] for item in outs])
        val_df = self.trainer.datamodule.val_df
        if self.trainer.sanity_checking and len(preds) != len(val_df):
            # The sanity check usually covers only a few batches, so its predictions
            # cannot be aligned with every row of val_df (assumes one prediction row
            # per val_df row -- TODO confirm). Skip scoring; just reset the meters.
            self.val_losses.reset()
            self.val_acc.reset()
            return
        val_labels = create_labels_for_scoring(val_df)
        valid_features, valid_texts = val_df['feature_text'], val_df['pn_history']
        val_loss_avg = self.val_losses.avg
        #  ======================== scoring ============================
        char_probs = get_char_probs(valid_features, valid_texts, preds, self.trainer.datamodule.tokenizer)
        results = get_results(char_probs)
        predictions = get_predictions(results)
        score = get_score(val_labels, predictions)
        self.log('val/loss_avg', val_loss_avg)
        self.log('val/score', score)
        self.print(f'Global step:{self.trainer.global_step}.\n Val loss avg: {val_loss_avg}, score: {score}')

        self.val_losses.reset()
        self.val_acc.reset()

    def predict_step(self, batch, batch_idx, dataloader_idx=None):
        """Return sigmoid probabilities for one batch of encoder inputs as a NumPy array."""
        inputs = batch
        y_preds = self.forward(inputs)
        return y_preds.sigmoid().cpu().numpy()

    def _total_training_steps(self):
        """Number of optimizer steps the LR schedule should span."""
        trainer = self.trainer
        max_epochs, max_steps = trainer.max_epochs, trainer.max_steps
        # Lightning 1.6 uses -1 (not None) for an unset max_steps, and reports
        # max_epochs as -1 when only max_steps is given; treat non-positive as unset.
        has_max_steps = max_steps is not None and max_steps > 0
        if max_epochs is not None and max_epochs > 0:
            # With FGM the trainer gets accumulate_grad_batches=None (manual optimisation),
            # so the accumulation factor must come from the config, not the trainer.
            steps = (
                len(trainer.datamodule.train_dataloader()) * max_epochs
                // self.hparams.config.gradient_accumulation_steps
            )
            # Training stops at whichever limit is reached first.
            return min(steps, max_steps) if has_max_steps else steps
        return max_steps

    def configure_optimizers(self):
        """AdamW with three param groups plus an optional per-step warmup scheduler.

        Groups: encoder params with weight decay, encoder bias/LayerNorm params
        without, and all non-encoder (head) params at ``decoder_lr`` without decay.
        """
        config = self.hparams.config
        no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
        encoder_lr = config.encoder_lr
        decoder_lr = config.decoder_lr
        optimizer_parameters = [
            {'params': [p for n, p in self.model.named_parameters()
                        if not any(nd in n for nd in no_decay)],
             'lr': encoder_lr, 'weight_decay': config.weight_decay,
             },
            {'params': [p for n, p in self.model.named_parameters()
                        if any(nd in n for nd in no_decay)],
             'lr': encoder_lr, 'weight_decay': 0.0,
             },
            # Prefix match: a substring test on "model" would silently drop any
            # non-encoder parameter whose name merely contains "model".
            {'params': [p for n, p in self.named_parameters() if not n.startswith('model.')],
             'lr': decoder_lr, 'weight_decay': 0.0,
             },
        ]
        optimizer = AdamW(optimizer_parameters,
                          lr=encoder_lr, eps=config.eps, betas=config.betas)

        max_steps = self._total_training_steps()
        warmup_steps = config.warmup_steps
        if isinstance(warmup_steps, float):
            # A float warmup is a fraction of the total number of steps.
            warmup_steps = int(warmup_steps * max_steps)

        print(f'====== Max steps: {max_steps},\t Warm up steps: {warmup_steps} =========')

        if config.scheduler == 'linear':
            scheduler = get_linear_schedule_with_warmup(
                optimizer, num_warmup_steps=warmup_steps, num_training_steps=max_steps,
            )
        elif config.scheduler == 'cosine':
            scheduler = get_cosine_schedule_with_warmup(
                optimizer, num_warmup_steps=warmup_steps, num_training_steps=max_steps,
                num_cycles=config.num_cycles,
            )
        else:
            scheduler = None

        if scheduler is None:
            # A {'scheduler': None} entry is not a valid LR scheduler config.
            return [optimizer]
        return [optimizer], [{'scheduler': scheduler, 'interval': 'step'}]



## Run

In [None]:
# Seed the global RNGs, then build the data module once. The same `dm` is reused for
# every fold in the training loop, which switches folds with `dm.set_trn_fold(...)`.
# prepare_test=False presumably skips building the test split -- TODO confirm.
pl.seed_everything(CFG.seed)
dm = NBMEDataModule(CFG, prepare_test=False)
dm.prepare_data()

In [None]:
# Train only fold 4 in this run; this overrides whatever CFG.trn_fold listed before.
CFG.trn_fold = [4]
# Checkpoint-name prefixes recording which training tricks are enabled.
fgm_p = 'fgm_' if CFG.fgm else ''
ls = 'ls_' if CFG.label_smooth else ''
print(f"FGM:{CFG.fgm}, \t label_smooth:{CFG.label_smooth}_{CFG.smoothing}")
print(f"decoder_lr:{CFG.decoder_lr}, \t batch_size:{CFG.batch_size} * {CFG.gradient_accumulation_steps}")
print(f"precision:{CFG.precision}, grad_norm:{CFG.max_grad_norm}, \t apex:{CFG.apex}_{CFG.apex_level}")
for train_fold in CFG.trn_fold:
    prefix = f'{fgm_p}{ls}fold{train_fold}'
    # Keep only the best checkpoint by val/score (weights only). The file name encodes
    # the step, val loss and val score.
    checkpoint_callback = pl.callbacks.ModelCheckpoint(
        filename=prefix+'step{step}-val_loss{val/loss_avg:.4f}-val_score{val/score:.4f}',
        auto_insert_metric_name=False,
        save_top_k=1, monitor='val/score', mode='max', save_last=False, verbose=True, save_weights_only=True,
    )
    callbacks = [checkpoint_callback]
    dm.set_trn_fold(train_fold)

    # Release the previous fold's trainer/model before allocating a new pretrained model.
    # Otherwise they stay reachable until `model`/`trainer` are rebound, i.e. after the
    # new model already exists. NOTE(review): Lightning attaches the trainer to the
    # datamodule during fit (dm.trainer), so that reference is cleared too; fit sets it again.
    dm.trainer = None
    model = trainer = None
    gc.collect()
    torch.cuda.empty_cache()

    model = NBMEModel(CFG, model_config_path=None, pretrained=True)
    trainer = pl.Trainer(
        gpus=[0],
        default_root_dir=f"seed{CFG.seed}_fold{train_fold}_from_{CFG.model.split('/')[-1]}",
        log_every_n_steps=10,
        amp_backend="apex" if CFG.apex else "native",
        amp_level=CFG.apex_level if CFG.apex else None,
        precision=16 if CFG.apex else CFG.precision,
        max_epochs=CFG.max_epochs,
        callbacks=callbacks,
        # With FGM the model optimises manually, so clipping/accumulation are handled
        # inside NBMEModel instead of by the trainer.
        gradient_clip_val=None if CFG.fgm else CFG.max_grad_norm,
        accumulate_grad_batches=None if CFG.fgm else CFG.gradient_accumulation_steps,
        fast_dev_run=CFG.fast_dev_run,
        num_sanity_val_steps=CFG.num_sanity_val_steps,
        val_check_interval=CFG.val_check_interval,
    )
    trainer.fit(model, datamodule=dm)
    


In [None]:
# After training, return cached-but-unused CUDA memory held by PyTorch's allocator.
torch.cuda.empty_cache()