This notebook has two parts:
1. construct model
2. construct dataset

For details, please refer to the discussion thread:
https://www.kaggle.com/competitions/nbme-score-clinical-patient-notes/discussion/315707


The implementation may contain bugs. Please report any problems. Thanks!

In [None]:
# set up fast tokenizer from 
# https://www.kaggle.com/code/thanhns/deberta-v3-large-0-883-lb
# https://www.kaggle.com/datasets/thanhns/deberta-tokenizer

# The following is necessary if you want to use the fast tokenizer for deberta v2 or v3
# This must be done before importing transformers
import shutil
from pathlib import Path

# Where the installed transformers package lives and where the replacement
# tokenizer sources come from.
transformers_path = Path("/opt/conda/lib/python3.7/site-packages/transformers")
input_dir = Path("../input/deberta-v2-3-fast-tokenizer")
deberta_v2_path = transformers_path / "models" / "deberta_v2"

# Replace the slow->fast conversion module: remove the installed copy (if any),
# then copy the new file into the package directory.
convert_file = input_dir / "convert_slow_tokenizer.py"
conversion_path = transformers_path / convert_file.name
if conversion_path.exists():
    conversion_path.unlink()
shutil.copy(convert_file, transformers_path)

# Same remove-then-copy step for both DeBERTa v2 tokenizer modules.
for filename in ('tokenization_deberta_v2.py', 'tokenization_deberta_v2_fast.py'):
    filepath = deberta_v2_path / filename
    if filepath.exists():
        filepath.unlink()
    shutil.copy(input_dir / filename, filepath)

In [None]:
import numpy as np
import pandas as pd
import ast
import itertools



import torch
from torch.utils.data.dataset import Dataset
from torch.utils.data import DataLoader
from torch.utils.data.sampler import SequentialSampler

import torch.nn as nn
import torch.nn.functional as F
import torch.cuda.amp as amp
    
from transformers import AutoTokenizer, AutoConfig, AutoModel, AutoModelForTokenClassification
from transformers.models.deberta_v2.tokenization_deberta_v2_fast import DebertaV2TokenizerFast as DebertaV3TokenizerFast
from transformers.models.deberta_v2.modeling_deberta_v2 import DebertaV2Model as DebertaV3Model


In [None]:
#tokenizer

# Pretrained checkpoint name passed to `from_pretrained` for the tokenizer, config and backbone.
arch = 'microsoft/deberta-v3-small'
# Expected tokenizer length: asserted in get_tokenizer() and used by Net to resize the token embeddings.
len_tokenizer =128001  

def get_tokenizer():
    """Load the fast DeBERTa tokenizer for `arch` and sanity-check its size.

    Prints the tokenizer length and asserts that it equals `len_tokenizer`.

    Returns:
        The tokenizer instance created by `DebertaV3TokenizerFast.from_pretrained`.
    """
    loaded = DebertaV3TokenizerFast.from_pretrained(arch)
    n_token = len(loaded)
    print('len(tokenizer)', n_token)
    assert len_tokenizer == n_token
    return loaded


In [None]:
#model

class RNNCharHead(nn.Module):
    """Two-layer bidirectional LSTM followed by a linear projection per position.

    Args:
        in_dim: feature size of every input position.
        out_dim: number of values produced per position.
    """

    def __init__(self, in_dim, out_dim):
        super().__init__()
        hidden_size = 512
        self.lstm = nn.LSTM(
            in_dim,
            hidden_size=hidden_size,
            num_layers=2,
            batch_first=True,
            bidirectional=True,
        )
        # Forward and backward states are concatenated, hence 2 * hidden_size.
        self.out = nn.Linear(hidden_size * 2, out_dim)

    def forward(self, x, x_length):
        """Run the LSTM over variable-length sequences.

        Args:
            x: padded input of shape (batch, max_length, in_dim).
            x_length: true length of every sequence in the batch.

        Returns:
            Tensor of shape (batch, max_length, out_dim). Positions at or beyond
            the longest sequence in the batch are filled with zeros.
        """
        _, max_length, _ = x.shape
        rnn_utils = torch.nn.utils.rnn

        packed = rnn_utils.pack_padded_sequence(x, x_length, batch_first=True, enforce_sorted=False)
        packed, _ = self.lstm(packed)
        # Unpacking only restores up to the longest sequence in the batch.
        unpacked, _ = rnn_utils.pad_packed_sequence(packed, batch_first=True)

        projected = self.out(unpacked)
        # Pad the time axis back to the original padded length.
        missing = max_length - max(x_length)
        return F.pad(projected, (0, 0, 0, missing), mode='constant', value=0)
    
    
def token_to_char(x, token_to_char_index):
    """Expand token features to character positions.

    Args:
        x: token features of shape (batch_size, L, dim).
        token_to_char_index: integer tensor of shape (batch_size, n_char); entry
            [b, c] is the index (within sample b) of the token covering
            character c. Index 0 is the "no token" marker.

    Returns:
        Tensor of shape (batch_size, n_char, dim) holding, for every character,
        the features of its token, or zeros where the index is 0.
    """
    batch_size, L, dim = x.shape
    flat = x.reshape(-1, dim)

    # Remember the unmapped characters BEFORE adding the per-sample offsets:
    # after the shift only sample 0 would still have index 0 for them.
    no_token = (token_to_char_index == 0).reshape(-1)

    offset = (torch.arange(batch_size) * L).reshape(-1, 1).to(x.device)
    i = (token_to_char_index + offset).reshape(-1)

    c = flat[i]          # advanced indexing -> a fresh tensor, safe to edit in place
    c[no_token] = 0
    return c.reshape(batch_size, -1, dim)
    
    
class Net(nn.Module):
    """Transformer backbone with a token-level linear head and a char-level LSTM head.

    `output_type` selects what `forward` returns: 'loss' adds the loss terms,
    'probability' adds the sigmoid outputs of both heads.
    """

    def __init__(self):
        super().__init__()
        self.output_type = ['probability', 'loss']

        overrides = {
            'output_hidden_states': True,
            'hidden_dropout_prob': 0.1,
            'layer_norm_eps':  1e-7,
            'add_pooling_layer': False,
            'num_labels': 1,
        }
        config = AutoConfig.from_pretrained(arch)
        config.update(overrides)

        self.transformer = AutoModel.from_pretrained(arch, config=config)
        # Match the embedding matrix to the expected tokenizer length.
        self.transformer.resize_token_embeddings(len_tokenizer)

        hidden_size = config.hidden_size
        self.token_label = nn.Linear(hidden_size, 1)   # per-token logit
        self.rnn = RNNCharHead(hidden_size, 1)         # per-character logit

    def forward(self, batch):
        """Compute losses and/or probabilities for one collated batch.

        Args:
            batch: dict with the keys 'token_id', 'token_in_mask',
                'token_type_id', 'token_to_char_index' and 'char_length';
                when 'loss' is requested also 'token_label', 'token_out_mask',
                'char_label' and 'char_mask'.

        Returns:
            dict that may contain 'token_loss', 'char_loss' (scalars) and
            'token_label', 'char_label' (sigmoid outputs with the last axis
            squeezed), depending on `self.output_type`.
        """
        encoded = self.transformer(
            input_ids=batch['token_id'],
            attention_mask=batch['token_in_mask'],
            token_type_ids=batch['token_type_id'],
        )
        last = encoded.last_hidden_state
        token_label = self.token_label(last)

        # Spread token features over characters, then run the char-level head.
        char_feature = token_to_char(last, batch['token_to_char_index'])
        char_label = self.rnn(char_feature, batch['char_length'])

        output = {}
        if 'loss' in self.output_type:
            if self.training:
                # Multi-Sample Dropout https://github.com/abhishekkrthakur/long-text-token-classification/issues/3
                token_loss = 0
                for dropout_rate in (0.1, 0.2, 0.3, 0.4, 0.5):
                    dropped = F.dropout(last, dropout_rate, training=self.training)
                    sample_logit = self.token_label(dropped)
                    token_loss += compute_token_label_loss(
                        sample_logit, batch['token_label'], batch['token_out_mask']
                    ) / 5
                output['token_loss'] = token_loss
            else:
                output['token_loss'] = compute_token_label_loss(
                    token_label, batch['token_label'], batch['token_out_mask']
                )
            output['char_loss'] = compute_char_label_loss(
                char_label, batch['char_label'], batch['char_mask']
            )

        if 'probability' in self.output_type:
            output['token_label'] = torch.sigmoid(token_label).squeeze(-1)
            output['char_label'] = torch.sigmoid(char_label).squeeze(-1)

        return output

#loss function
def compute_token_label_loss(logit, target, mask):
    """Binary cross-entropy over the tokens that should be scored.

    Args:
        logit: token logits, a 3-D tensor (batch, length, 1).
        target: token targets; entries below 0 are skipped.
        mask: entries equal to 1 mark tokens that take part in the loss.

    Returns:
        Scalar tensor: mean BCE-with-logits over the selected tokens.
    """
    batch_size, max_len, _ = logit.shape   # also checks that logit is 3-D

    flat_target = target.reshape(-1)
    selected = (mask.reshape(-1) == 1) & (flat_target >= 0)

    return F.binary_cross_entropy_with_logits(
        logit.reshape(-1)[selected],
        flat_target[selected],
    )


def compute_char_label_loss(logit, target, mask):
    """Binary cross-entropy over the characters selected by `mask`.

    Args:
        logit: character logits (flattened internally).
        target: character targets, same number of elements as `logit`.
        mask: entries equal to 1 mark characters that take part in the loss.

    Returns:
        Scalar tensor: mean BCE-with-logits over the selected characters.
    """
    selected = mask.reshape(-1) == 1
    return F.binary_cross_entropy_with_logits(
        logit.reshape(-1)[selected],
        target.reshape(-1)[selected],
    )



In [None]:
#dataset

def d_to_token(d, tokenizer, max_length, max_char_length, debug=False):
    """Encode one (patient note, feature text) pair into model inputs and labels.

    Args:
        d: record exposing `pn_history` (str), `feature_text` (str) and `span`
            (iterable of (start, end) character ranges inside `pn_history`).
        tokenizer: fast tokenizer providing `encode_plus` with offset mapping.
        max_length: number of tokens the encoding is padded/truncated to.
        max_char_length: fixed length of the character-level arrays.
        debug: when True, print the encoding token by token and wait for a key
            press (interactive use only).

    Returns:
        dict with the raw texts, 'token_offset', the token-level tensors
        ('token_in_mask', 'token_out_mask', 'token_type_id', 'token_id',
        'token_label'), the character-level tensors ('char_label', 'char_mask',
        'token_to_char_index') and 'char_length' (int, never larger than
        `max_char_length`). 'token_label' is -1 for tokens that do not belong
        to the first sequence.
    """
    e = tokenizer.encode_plus(
        d.pn_history,
        d.feature_text,
        add_special_tokens=True,
        max_length=max_length,
        padding='max_length',
        truncation=True,
        return_offsets_mapping=True,
        return_token_type_ids=True,
    )

    token_in_mask  = e['attention_mask']
    token_type_id  = e['token_type_ids']
    token_id       = e['input_ids']
    token_offset   = e['offset_mapping']
    # Parenthesised so that the cast and tolist() apply to the combined mask,
    # not just to the second operand.
    token_out_mask = (
        (np.array(token_in_mask) == 1) & (np.array(token_type_id) == 0)
    ).astype(np.int8).tolist()
    token_label    = np.zeros(max_length, np.int8)

    # Character labels over the whole note, so token labels stay correct even
    # for tokens lying beyond `max_char_length`.
    note_length = len(d.pn_history)
    note_label = np.zeros(max(note_length, max_char_length), np.int8)
    for (start, end) in d.span:
        note_label[start:end] = 1

    token_to_char_index = np.zeros(max_char_length, np.int32)
    char_label  = note_label[:max_char_length].copy()
    # Clamp: downstream sequence packing cannot use a length larger than the array.
    char_length = min(note_length, max_char_length)
    char_mask   = np.zeros(max_char_length, np.int8)
    char_mask[:char_length] = 1

    for i, (start, end) in enumerate(token_offset):
        if start == end:
            continue  # special / padding tokens carry an empty offset
        covered = note_label[start:end]
        if covered.size:  # guard: offsets may point past the label array
            token_label[i] = covered.max()
        if token_type_id[i] == 0:
            # Slice assignment silently drops the part beyond max_char_length.
            token_to_char_index[start:end] = i

    # Everything that is not part of the first sequence is excluded from the loss.
    ignore = np.where(np.array(e.sequence_ids()) != 0)[0]
    token_label[ignore] = -1

    if debug:  # print encoding
        print('pn_history:\n',d.pn_history)
        print('feature_text:\n',d.feature_text)
        print('')
        print('sum(token_out_mask):',sum(token_out_mask))
        token = tokenizer.convert_ids_to_tokens(token_id)
        for i,(start,end) in enumerate(token_offset):
            print(
                '%4d'%i,
                '%12s'%str(token_offset[i]),
                '%d'%token_in_mask[i],
                '%d'%token_out_mask[i],
                '%d'%token_type_id[i],
                '%6d'%token_id[i],
                '%2d'%token_label[i],
                '%20s'%token[i],
                '%20s'%repr(d.pn_history[start:end]) if token_type_id[i]==0 else '%20s'%repr(d.feature_text[start:end]),
                '%+40s'%str(char_label[start:end]) if token_type_id[i]==0 else '',
                '%+40s'%str(token_to_char_index[start:end]) if token_type_id[i]==0 else '',
            )
        input('wait for key press')

    r = {}
    r['pn_history'    ] = d.pn_history
    r['feature_text'  ] = d.feature_text
    r['token_offset'  ] = token_offset
    r['token_out_mask'] = torch.tensor(token_out_mask, dtype=torch.long)
    r['token_in_mask' ] = torch.tensor(token_in_mask , dtype=torch.long)
    r['token_type_id' ] = torch.tensor(token_type_id , dtype=torch.long)
    r['token_id'      ] = torch.tensor(token_id      , dtype=torch.long)
    r['token_label'   ] = torch.tensor(token_label   , dtype=torch.float)

    r['char_length'   ] = char_length
    r['char_label'    ] = torch.tensor(char_label   , dtype=torch.float)
    r['char_mask'     ] = torch.tensor(char_mask    , dtype=torch.long)
    r['token_to_char_index' ] = torch.tensor(token_to_char_index   , dtype=torch.long)
    return r


# Sample keys whose values are tensors: `null_collate_fn` stacks these across
# the batch and leaves every other key as a plain list of per-sample values.
tensor_list = [
    'token_in_mask', 'token_out_mask', 'token_type_id', 'token_id',
    'token_label',
    'char_label', 'char_mask', 'token_to_char_index',
]


class NBMEDataset(Dataset):
    """Dataset that encodes one dataframe row per item with `d_to_token`.

    Args:
        df: dataframe whose rows provide the fields read by `d_to_token`
            plus an `id` column.
        tokenizer: tokenizer forwarded to `d_to_token`.
        max_length: token length forwarded to `d_to_token`.
        max_char_length: character length forwarded to `d_to_token`.
    """

    def __init__(self, df, tokenizer, max_length, max_char_length):
        self.df = df
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.max_char_length = max_char_length
        self.length = len(df)

    def __len__(self):
        return self.length

    def __str__(self):
        return '\tlen = %d\n' % len(self)

    def __getitem__(self, index):
        """Return the encoded sample for row `index`, tagged with 'index' and 'id'."""
        row = self.df.iloc[index]
        sample = d_to_token(row, self.tokenizer, self.max_length, self.max_char_length)
        sample['index'] = index
        sample['id'] = row.id
        return sample


def null_collate_fn(batch):
    """Collate a list of sample dicts into one dict of batched values.

    Keys listed in `tensor_list` are stacked with `torch.stack`; every other
    key becomes a plain list with one entry per sample. The key set is taken
    from the first sample.
    """
    collated = {}
    for name in batch[0].keys():
        column = [sample[name] for sample in batch]
        collated[name] = torch.stack(column) if name in tensor_list else column
    return collated

In [None]:
def location_to_array(location, pn_history, format='truth'):
    """Turn a location description into a 0/1 array over the note's characters.

    Args:
        location: for format='truth', the string form of a list of
            'start end' items (several ranges may be joined by ';' inside one
            item); for format='predict', 'start end' ranges joined by ';' or
            ','. For any other format value, `location` is iterated as-is.
        pn_history: the note text; only its length is used.
        format: 'truth' or 'predict'.

    Returns:
        Float array of length len(pn_history) with 1 on every covered character.
    """
    if format == 'truth':
        # "['1 3;5 7']" -> "['1 3','5 7']" so it parses as a list of ranges.
        location = ast.literal_eval(location.replace('"', "'").replace(';', "','"))
    elif format == 'predict':
        # Dropping blank pieces covers the empty prediction ('') as well as
        # stray separators, without relying on an identity test against ''.
        location = [piece for piece in location.replace(';', ',').split(',') if piece.strip()]

    array = np.zeros(len(pn_history))
    for loc in location:
        start, end = loc.split()
        array[int(start):int(end)] = 1
    return array

def array_to_span(array, format='string'):
    """Group the positions where `array == 1` into runs of consecutive indices.

    Args:
        array: array compared element-wise against 1.
        format: 'string' -> each run as 'start end'; 'list' -> each run as
            [start, end]; anything else -> each run as the list of its indices.
            `end` is exclusive (last index + 1).

    Returns:
        List with one entry per run, in ascending order.
    """
    runs = []
    for pos in np.where(array == 1)[0]:
        if runs and pos == runs[-1][-1] + 1:
            runs[-1].append(pos)      # extends the current run
        else:
            runs.append([pos])        # gap -> start a new run

    if format == 'string':
        return ['%d %d' % (run[0], run[-1] + 1) for run in runs]
    if format == 'list':
        return [[run[0], run[-1] + 1] for run in runs]
    return runs

def location_to_span(location, pn_history):
    """Convert a ground-truth location string into a list of [start, end] spans.

    The location is rasterised with `location_to_array(..., format='truth')`
    and regrouped with `array_to_span(..., format='list')`.
    """
    return array_to_span(
        location_to_array(location, pn_history, format='truth'),
        format='list',
    )

In [None]:
#example

# Read the three CSV tables.
train_df = pd.read_csv('../input/nbme-score-clinical-patient-notes/train.csv')
feature_df = pd.read_csv('../input/nbme-score-clinical-patient-notes/features.csv')
patient_note_df = pd.read_csv('../input/nbme-score-clinical-patient-notes/patient_notes.csv')

# Attach the feature text and the note text to every annotation row, then
# parse the location strings into [start, end] spans.
all_df = (
    train_df
    .merge(feature_df, on=['feature_num', 'case_num'], how='left')
    .merge(patient_note_df, on=['pn_num', 'case_num'], how='left')
)
all_df['span'] = all_df.apply(lambda d: location_to_span(d.location, d.pn_history), axis=1)

tokenizer = get_tokenizer()
dataset = NBMEDataset(all_df, tokenizer, max_length=256, max_char_length=1024)


def _seed_worker(worker_id):
    """Give every DataLoader worker its own NumPy seed."""
    np.random.seed(torch.initial_seed() // 2 ** 32 + worker_id)


loader = DataLoader(
    dataset,
    batch_size=8,
    sampler=SequentialSampler(dataset),
    collate_fn=null_collate_fn,
    worker_init_fn=_seed_worker,
    num_workers=2,
    drop_last=True,
    pin_memory=False,
)
print(dataset)
batch = next(iter(loader))

#---

# Train mode exercises the multi-sample-dropout loss branch; gradients are
# disabled because this is only a smoke test of one forward pass.
net = Net()
net.train()

with torch.no_grad():
    output = net(batch)

print('batch')
for k, v in batch.items():
    if k not in tensor_list:
        continue
    print('%32s :'%k, v.shape)

print('output')
# First the tensor outputs (shapes), then the scalar losses (values).
for k, v in output.items():
    if 'loss' in k:
        continue
    print('%32s :'%k, v.shape)
for k, v in output.items():
    if 'loss' not in k:
        continue
    print('%32s :'%k, v.item())

