# Nested NER Model
### **Description:**


*   Addresses nested entities in named-entity recognition.
*   Family and LivingSituation are the main entity types with a nested structure.


*   Generates two levels of labels to handle nested entities. Independent and nested token positions are identified and labeled as follows:
    1.  Level 1: independent entities and the outermost entity of each nested structure.
    2.  Level 2: the inner entities of each nested structure; if the nesting depth is greater than 2, the priority order decides which inner entity is kept.
*   Built a BERT-CRF architecture with two classification heads (linear + CRF).
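
To make the two-level scheme concrete, here is a minimal pure-Python sketch. It assumes whitespace tokenization with character offsets instead of BertTokenizerFast, uses a hypothetical sentence and annotation, and omits the priority order used for nesting deeper than two levels. The containment test follows the same idea as the notebook's `is_nested`: an entity whose span is fully covered by another entity goes to Level 2.

```python
# Illustrative sketch only: whitespace tokens stand in for BertTokenizerFast
# offsets, and the example sentence and spans are hypothetical.

def whitespace_tokens(text):
    """Return (token, start_char, end_char) for each whitespace-separated token."""
    tokens, pos = [], 0
    for tok in text.split():
        start = text.index(tok, pos)
        tokens.append((tok, start, start + len(tok)))
        pos = start + len(tok)
    return tokens

def is_nested(ent, others):
    """An entity is nested if another entity's character span fully contains it."""
    return any(
        other is not ent and other["start"] <= ent["start"] and ent["end"] <= other["end"]
        for other in others
    )

def two_level_bio(text, entities):
    """Label outer/independent entities in level1 and contained entities in level2."""
    tokens = whitespace_tokens(text)
    level1 = ["O"] * len(tokens)
    level2 = ["O"] * len(tokens)
    for ent in entities:
        target = level2 if is_nested(ent, entities) else level1
        covered = [i for i, (_, s, e) in enumerate(tokens) if s >= ent["start"] and e <= ent["end"]]
        for k, i in enumerate(covered):
            target[i] = ("B-" if k == 0 else "I-") + ent["label"]
    return [t for t, _, _ in tokens], level1, level2

text = "He lives with his wife and son"
entities = [
    {"label": "LivingSituation", "start": 3, "end": 30},  # "lives with his wife and son"
    {"label": "Family", "start": 18, "end": 22},          # "wife"
    {"label": "Family", "start": 27, "end": 30},          # "son"
]
tokens, level1, level2 = two_level_bio(text, entities)
for row in zip(tokens, level1, level2):
    print(row)
```

The outer `LivingSituation` span stays in Level 1, while both `Family` mentions inside it move to Level 2.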
### **Experimentation:**
*   **Approach 1**: Two model architectures were developed. Model 1 fine-tunes both Level 1 and Level 2 on sentence-wide labels, as described above. A potential concern is whether Level 2 can learn nested entities accurately: its input covers not only the nested entities but also the independent and outermost entities, which may make nested relationships harder to distinguish.

*   **Approach 2 (Nested NER Span model):** uses the spans of the outermost entities as the input to Level 2. The underlying idea is that by focusing on the spans of the outermost entities from Level 1, the model can learn the inner entities at Level 2 more effectively and with less confusion. The model also aggregates the span-level predictions for nested entities and maps them back to their token positions in the sentence.
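
The Approach 2 data flow (gather the outermost-entity spans, pad them into a batch with a mask, then map the span predictions back to sentence positions) can be sketched in plain Python. Lists of integers stand in for BERT embeddings, the helper names are hypothetical, and the "predictions" are hard-coded rather than produced by a CRF; in the notebook the corresponding steps live in `NestedNERSpanModel.forward` and `update_prediction_layer2`.

```python
# Illustrative sketch of the Approach 2 data flow; plain lists stand in for
# tensors and all function names are hypothetical.
PAD_SPAN = (-1, -1)  # same padding convention as the notebook

def gather_spans(batch_features, batch_spans):
    """Collect the per-token features of every real span and remember where it came from."""
    span_ids, span_feats = [], []
    for sent_idx, spans in enumerate(batch_spans):
        for start, end in spans:
            if (start, end) == PAD_SPAN:
                continue  # padded slot, not a real span
            span_ids.append((sent_idx, (start, end)))
            span_feats.append(batch_features[sent_idx][start:end])
    return span_ids, span_feats

def pad_with_mask(span_feats, pad_value=0):
    """Right-pad spans to a common length and build the matching 0/1 mask."""
    max_len = max((len(f) for f in span_feats), default=0)
    padded = [f + [pad_value] * (max_len - len(f)) for f in span_feats]
    masks = [[1] * len(f) + [0] * (max_len - len(f)) for f in span_feats]
    return padded, masks

def scatter_predictions(span_ids, span_preds, sentence_lengths, outside_id=0):
    """Write each span's label ids back into its sentence; other tokens stay 'O' (id 0)."""
    out = [[outside_id] * n for n in sentence_lengths]
    for (sent_idx, (start, end)), preds in zip(span_ids, span_preds):
        if len(preds) != end - start:
            raise ValueError("expected one prediction per span token")
        out[sent_idx][start:end] = preds
    return out

# Two sentences; integers stand in for token embeddings.
batch_features = [[10, 11, 12, 13, 14], [20, 21, 22]]
batch_spans = [[(1, 4), (4, 5)], [(0, 2), PAD_SPAN]]

span_ids, span_feats = gather_spans(batch_features, batch_spans)
padded, masks = pad_with_mask(span_feats)
# Pretend a CRF decoded these unpadded tag sequences, one per span.
span_preds = [[0, 1, 2], [1], [1, 2]]
print(scatter_predictions(span_ids, span_preds, [5, 3]))  # [[0, 0, 1, 2, 1], [1, 2, 0]]
```

Padding only happens at the span level, so a sentence with no nested structure contributes nothing to the Level 2 batch and simply keeps all-`O` predictions.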

### **Model**

1.  Tokenizer: BertTokenizerFast
2.  Pre-trained BERT model: 'emilyalsentzer/Bio_ClinicalBERT'
3.  Hyperparameters:
    *   eps=1e-8
    *   learning_rate=7e-5
    *   weight_decay=0
    *   num_train_epochs=15
    *   patience=3
    *   batch_size=8
    *   max_len_token=512


4.  Fine-tuned the pre-trained BERT model with two classification heads, each followed by a CRF layer: one for independent and outermost entities (Level 1) and one for nested entities (Level 2).
5.  Since both levels are trained simultaneously, their losses are summed and backpropagated together.
6.  Entity-level NER metrics are computed with seqeval.
7.  Model stored at project_directory+'/models/final_models/Nested_Entity_CRFModel_v1.pth'.
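
Written out, the summed objective of item 5 looks as follows. This is a sketch assuming the `reduction='mean'` setting used in the code, under which pytorch-crf averages the log-likelihood over the first (batch) dimension: over the $B$ sentences of a batch for Level 1, and over the $S$ sequences fed to the Level 2 CRF (whole sentences in Approach 1, outermost-entity spans in Approach 2, where the Level 2 term is set to 0 for a batch without spans).

$$
\mathcal{L} = -\frac{1}{B}\sum_{b=1}^{B}\log p_1\big(\mathbf{y}^{(1)}_b \mid \mathbf{x}_b\big) \;-\; \frac{1}{S}\sum_{s=1}^{S}\log p_2\big(\mathbf{y}^{(2)}_s \mid \mathbf{x}_s\big),
\qquad
\log p(\mathbf{y}\mid\mathbf{x}) = \mathrm{score}(\mathbf{x},\mathbf{y}) - \log\sum_{\mathbf{y}'}\exp\,\mathrm{score}(\mathbf{x},\mathbf{y}')
$$

Here $\mathrm{score}$ sums the linear-head emission scores of the tags and the CRF transition scores (including start and end transitions), and padded positions are excluded through the attention or span mask.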

### **Inputs:**
*   train_data_set path: 'PHD_assessment_gmu/data/train_dataset.pth'
*   test_data_set path: 'PHD_assessment_gmu/data/test_dataset.pth'

**Run:**
Set `project_directory` to your project path and run all cells.
### **Metrics:**

Level 1 (independent and outermost entities):

| Entity | Precision | Recall | F1 | Support |
|---|---|---|---|---|
| Alcohol | 0.7368 | 0.7636 | 0.7500 | 55 |
| Drug | 0.8000 | 0.9333 | 0.8615 | 30 |
| Family | 0.4828 | 0.7000 | 0.5714 | 20 |
| LivingSituation | 0.4800 | 0.6316 | 0.5455 | 19 |
| MaritalStatus | 0.4722 | 0.5667 | 0.5152 | 30 |
| Occupation | 0.2941 | 0.3448 | 0.3175 | 29 |
| Tobacco | 0.8070 | 0.8679 | 0.8364 | 53 |
| **Overall** | 0.6190 | 0.7161 | 0.6640 | |

Overall token accuracy: 0.9269

Level 2 (nested entities):

| Entity | Precision | Recall | F1 | Support |
|---|---|---|---|---|
| Family | 0.7143 | 0.8824 | 0.7895 | 17 |
| LivingSituation | 0.0000 | 0.0000 | 0.0000 | 0 |
| **Overall** | 0.6818 | 0.8824 | 0.7692 | |

Overall token accuracy: 0.9977

### Library installation

In [None]:
!pip install evaluate

In [None]:
!pip install seqeval

In [None]:
!pip install pytorch-crf



### Importing

In [None]:
import json
import pandas as pd
import evaluate
from transformers import BertTokenizerFast, BertModel, AdamW, get_linear_schedule_with_warmup, DataCollatorForTokenClassification
from torch.utils.data import DataLoader, Dataset
from sklearn.metrics import classification_report
import torch
import torch.nn as nn
import torch.optim as optim
from torch.nn.utils.rnn import pad_sequence
from tqdm import tqdm, trange
from torchcrf import CRF

In [None]:
metric=evaluate.load("seqeval")
device= torch.device("cuda" if torch.cuda.is_available() else "cpu")
num_freeze_layers=6

In [None]:
project_directory='/content/drive/MyDrive/PHD_assessment_gmu/'

### Hyperparameters and data paths

In [None]:
eps=1e-8
learning_rate=7e-5
weight_decay=0.01
num_train_epochs=8
patience=5
batch_size=8
max_len=512
min_label_size=10

discarded_enities=['EnvironmentalExposure','SexualHistory','InfectiousDiseases','PhysicalActivity','Residence']
discarded_roles=['LivingStatus','Other','MedicalCondition','Extent','History']


raw_dataset_path=project_directory+'data/'+'SocialHistoryMTSamples.json'
train_dataset_path=project_directory+'data/'+'train_dataset.pth'
test_dataset_path=project_directory+'data/'+'test_dataset.pth'
bert_model_name='emilyalsentzer/Bio_ClinicalBERT'
priority_order = ['LivingSituation', 'Family', 'MaritalStatus']
tokenizer = BertTokenizerFast.from_pretrained(bert_model_name)
save_model_path=project_directory+'models/'  # project_directory already ends with '/'

###  Data Loading

In [None]:
label_id_l1={'B-Alcohol':1,
 'B-Drug':3,
 'B-Family':5,
 'B-LivingSituation':7,
 'B-MaritalStatus':9,
 'B-Occupation':11,
 'B-Tobacco':13,
 'I-Alcohol':2,
 'I-Drug':4,
 'I-Family':6,
 'I-LivingSituation':8,
 'I-MaritalStatus':10,
 'I-Occupation':12,
 'I-Tobacco':14,
 'O':0}
id_label_l1 = {v: k for k, v in label_id_l1.items()}
label_id_l2={'B-Family':1,
 'B-LivingSituation':3,
 'I-Family':2,
 'I-LivingSituation':4,
 'O':0}
id_label_l2 = {v: k for k, v in label_id_l2.items()}
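
Both dictionaries follow a regular pattern: `'O'` is 0 and the k-th entity type (0-based) gets `B-` = 2k+1 and `I-` = 2k+2. A hypothetical helper like the one below, assuming the same ordering of entity types, regenerates them, which avoids gaps or collisions in the ids when a type is added or removed.

```python
# Sketch: rebuild the BIO label maps from an ordered list of entity types.
# 'O' is 0 and each type k (0-based) gets B = 2k + 1 and I = 2k + 2.
def build_bio_maps(entity_types):
    label_to_id = {"O": 0}
    for k, entity in enumerate(entity_types):
        label_to_id[f"B-{entity}"] = 2 * k + 1
        label_to_id[f"I-{entity}"] = 2 * k + 2
    id_to_label = {i: label for label, i in label_to_id.items()}
    return label_to_id, id_to_label

level1_types = ["Alcohol", "Drug", "Family", "LivingSituation",
                "MaritalStatus", "Occupation", "Tobacco"]
level2_types = ["Family", "LivingSituation"]

l1_map, l1_inv = build_bio_maps(level1_types)
l2_map, l2_inv = build_bio_maps(level2_types)
print(l2_map)  # {'O': 0, 'B-Family': 1, 'I-Family': 2, 'B-LivingSituation': 3, 'I-LivingSituation': 4}
```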

In [None]:
l2_entities=['Family','LivingSituation']

In [None]:
with open(project_directory+'data/'+'trainset.json', 'r', encoding='utf-8') as file:
  train_data=json.load(file)
with open(project_directory+'data/'+'testset.json', 'r', encoding='utf-8') as file:
  test_data=json.load(file)

In [None]:

with open(project_directory+'data/'+'nested_ent_filenames.json', 'r', encoding='utf-8') as file:
  nested_ent_filenames=json.load(file)
nested_lst_tr=[]
nested_lst_tst=[]
for ele in train_data:
  if ele['file_name'] in nested_ent_filenames['train_files']:
    nested_lst_tr.append(ele)
for ele in test_data:
  if ele['file_name'] in nested_ent_filenames['test_files']:
    nested_lst_tst.append(ele)

### Generate nested entity Labels

In [None]:
def generate_nested_bio_labels_with_prioritization_and_positions(text, entity_list, priority_order,tokenizer):
    '''
    Organizes entities into a two-layer hierarchy: outer entities are labeled in the first layer
    and the entities nested inside them in the second. Entities that are not nested are labeled
    only in the first layer. Up to two layers of nesting are represented; for deeper nesting,
    priority_order decides which label a nested token keeps.
    Returns the level-1 label ids, the level-2 label ids, the (start, end-exclusive) token spans
    of outermost entities that contain nested entities, and the tokenizer output.
    '''
    # Tokenize the text
    tokenized_output = tokenizer(text,max_length=max_len,truncation=True,return_offsets_mapping=True)
    tokens = tokenized_output.tokens()
    layer1_labels = ['O'] * len(tokens)  # Layer 1
    layer2_labels = ['O'] * len(tokens)  # Layer 2
    tokenized_output['tokens']=tokens

    def update_labels(labels, token_idx, label, is_begin, priority):
        '''
        Update labels when the new label has a higher priority.
        '''
        current_label = labels[token_idx].split('-')[-1] if labels[token_idx] != 'O' else None
        current_priority = priority_order.index(current_label) if current_label in priority_order else len(priority_order)
        if current_label is None or priority < current_priority:
            prefix = 'B-' if is_begin else 'I-'  # 'B-' for Begin, 'I-' for Inside
            labels[token_idx] = prefix + label

    def is_nested(entity, other_entities):
        '''
        Determines whether an entity is nested
        '''
        entity_start = int(entity['entity_strt_pos'])
        entity_end = int(entity['entity_end_pos'])
        for other in other_entities:
            if other == entity:
                continue
            other_start = int(other['entity_strt_pos'])
            other_end = int(other['entity_end_pos'])
            if other_start <= entity_start and other_end >= entity_end:
                return True
        return False

    def has_nested_entities(entity, all_entities):
      '''
      Identifies whether an entity has nested entities or not.
      '''
      return any(is_nested(other, [entity]) for other in all_entities if other != entity)
    # To store start and end positions of encompassing entities
    encompassing_positions = []
    # Assign labels to the layers and track positions
    for entity in entity_list:
        entity_category = entity['entity_category']
        entity_start_char = int(entity['entity_strt_pos'])
        entity_end_char = int(entity['entity_end_pos'])
        start_token_idx = tokenized_output.char_to_token(entity_start_char)
        end_token_idx = tokenized_output.char_to_token(entity_end_char - 1)
        entity_priority = priority_order.index(entity_category) if entity_category in priority_order else len(priority_order)

        if start_token_idx is not None and end_token_idx is not None:
            for token_idx in range(start_token_idx, end_token_idx + 1):
                is_begin = token_idx == start_token_idx
                if not is_nested(entity, entity_list):
                    # Non-nested or encompassing entities in Layer 1
                    update_labels(layer1_labels, token_idx, entity_category, is_begin, entity_priority)
                    if is_begin and has_nested_entities(entity, entity_list):
                        encompassing_positions.append((start_token_idx, end_token_idx+1))
                else:
                    # Nested entities in Layer 2, with priority
                    update_labels(layer2_labels, token_idx, entity_category, is_begin, entity_priority)

    layer1_labels_updated=[label_id_l1['O'] if ele.split('-',1)[-1] in discarded_enities else label_id_l1[ele] for ele in layer1_labels]
    layer2_labels_updated=[label_id_l2[ele] if ele.split('-',1)[-1] in l2_entities else label_id_l2['O'] for ele in layer2_labels]

    return layer1_labels_updated, layer2_labels_updated, encompassing_positions, tokenized_output
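
The priority rule inside `update_labels` decides which label a token keeps when several entities claim it. The sketch below replays that rule on one token with a hypothetical sequence of competing entities; it uses a local copy of the priority list (`demo_priority_order`) so it does not overwrite the notebook's `priority_order`.

```python
# Sketch of the update_labels priority rule. Categories earlier in the list
# win; categories not in the list rank last; ties keep the existing label.
demo_priority_order = ["LivingSituation", "Family", "MaritalStatus"]

def rank(category):
    if category in demo_priority_order:
        return demo_priority_order.index(category)
    return len(demo_priority_order)

def resolve(current_label, category, is_begin):
    """Return the label a token keeps after an entity of `category` tries to claim it."""
    current = None if current_label == "O" else current_label.split("-")[-1]
    if current is None or rank(category) < rank(current):
        return ("B-" if is_begin else "I-") + category
    return current_label

label = "O"
label = resolve(label, "Family", True)            # free token -> 'B-Family'
label = resolve(label, "LivingSituation", False)  # higher priority -> 'I-LivingSituation'
label = resolve(label, "MaritalStatus", True)     # lower priority -> unchanged
print(label)  # I-LivingSituation
```

Note that the `B-`/`I-` prefix comes from the winning entity, so a token can become `I-` for the winner even where it was the first token of the entity it displaced.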




### Dataset and Padding

In [None]:
def pad_span_positions(span_positions):
  '''
  Pads each sentence's list of (start, end) spans with (-1, -1) so that all
  sentences in the batch have the same number of spans.
  '''
  max_length = max(len(spans) for spans in span_positions)
  padded_spans = [list(spans) + [(-1, -1)] * (max_length - len(spans)) for spans in span_positions]
  # dtype=torch.long keeps integer indices even when no sentence in the batch has a span
  return torch.tensor(padded_spans, dtype=torch.long)
def collate_fn_entity_role(batch):
  '''
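  Collate function for NestedEntityDataset: pads token ids, attention masks, label ids (with -100) and outermost-entity span positions (with (-1, -1)).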
  '''
  input_ids = [torch.tensor(x['input_ids']) for x in batch]
  attention_mask = [torch.tensor(x['attention_mask']) for x in batch]
  l_1_labels = [torch.tensor(x['l_1_labels']) for x in batch]
  l_2_labels = [torch.tensor(x['l_2_labels']) for x in batch]
  tokens=[x['tokens'] for x in batch]
  outermost_entities_pos=[x['outermost_entities_pos'] for x in batch]

  input_ids = pad_sequence(input_ids, batch_first=True,padding_value=tokenizer.pad_token_id)
  attention_mask = pad_sequence(attention_mask, batch_first=True,padding_value=0)
  l_1_labels = pad_sequence(l_1_labels, batch_first=True,padding_value=-100)
  l_2_labels = pad_sequence(l_2_labels, batch_first=True,padding_value=-100)
  outermost_entities_pos=pad_span_positions(outermost_entities_pos)
  return {
    'input_ids':input_ids,
    'attention_mask':attention_mask,
    'l_1_labels':l_1_labels,
    'l_2_labels':l_2_labels,
    'tokens':tokens,
    'outermost_entities_pos':outermost_entities_pos
  }


In [None]:
class NestedEntityDataset(Dataset):
  def __init__(self, data, tokenizer, max_len):
    self.data = data
    self.tokenizer = tokenizer
    self.max_len = max_len
  def __len__(self):
    return len(self.data)
  def __getitem__(self, idx):
    item=self.data[idx]
    text = item['text']
    level_1_lab,level_2_lab,outermost_entities_pos,inputs=generate_nested_bio_labels_with_prioritization_and_positions(text, item['entity_list'], priority_order,self.tokenizer)
    #inputs = self.tokenizer(text,max_length=max_len,truncation=True,return_offsets_mapping=True)
    tokens_len=len(inputs['input_ids'])
    offset_mapping_list=inputs['offset_mapping']
    #entity_labels=GenerateLabel.generate_enity_labels(item['entity_list'],tokens_len,offset_mapping_list)
    #role_labels, status_labels, method_labels=GenerateLabel.generate_role_labels(item['role_list'],tokens_len,offset_mapping_list)
    #relation_labels=GenerateLabel.generate_relation_labels(offset_mapping_list,item)
    return {
      'input_ids':inputs['input_ids'],
      'attention_mask':inputs['attention_mask'],
      'l_1_labels':level_1_lab,
      'l_2_labels':level_2_lab,
      'outermost_entities_pos':outermost_entities_pos,
      'text':text,
      'file_name':item['file_name'],
      'offset_mapping':offset_mapping_list,
      'tokens':inputs['tokens']
    }

In [None]:
train_dataset=NestedEntityDataset(train_data,tokenizer,max_len)
test_dataset=NestedEntityDataset(test_data,tokenizer,max_len)

In [None]:
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True,collate_fn=collate_fn_entity_role)
test_dataloader = DataLoader(test_dataset, batch_size=batch_size,collate_fn=collate_fn_entity_role)

### NER model with 2 layers processing whole sentence

In [None]:
class NestedNERModel(nn.Module):
    '''
    BERT-CRF model for Approach 1.
    Two classification heads (linear + CRF), each processing the whole sentence and trained on its own level of labels:
    level1: independent entities, and the outermost entity of each nested structure
    level2: nested (inner) entities; for nesting deeper than two levels, the entity selected by the priority order
    '''
    def __init__(self, bert_model_name, num_labels_level1, num_labels_level2):
        super(NestedNERModel, self).__init__()
        self.bert = BertModel.from_pretrained(bert_model_name)

        self.level1_classifier = nn.Linear(self.bert.config.hidden_size, num_labels_level1)
        self.level2_classifier = nn.Linear(self.bert.config.hidden_size, num_labels_level2)
        self.dropout = nn.Dropout(0.1)
        self.crf_level1 = CRF(num_labels_level1, batch_first=True)
        self.crf_level2 = CRF(num_labels_level2, batch_first=True)

    def forward(self, input_ids, attention_mask, labels_level1=None, labels_level2=None):
        outputs = self.bert(input_ids, attention_mask=attention_mask)
        sequence_output = outputs[0]
        sequence_output = self.dropout(sequence_output)
        level1_feats = self.level1_classifier(sequence_output)
        level2_feats = self.level2_classifier(sequence_output)
        mask = attention_mask.type(torch.uint8)
        if labels_level1 is not None and labels_level2 is not None:
          labels_level1 = torch.where(mask == 1, labels_level1, torch.tensor(-1).to(labels_level1.device))
          labels_level2 = torch.where(mask == 1, labels_level2, torch.tensor(-1).to(labels_level2.device))

          loss_level1 = -self.crf_level1(level1_feats, labels_level1, mask=mask,reduction='mean')
          loss_level2 = -self.crf_level2(level2_feats, labels_level2, mask=mask,reduction='mean')
          return loss_level1 + loss_level2
        else:
          prediction_level1 = self.crf_level1.decode(level1_feats, mask=attention_mask.byte())
          prediction_level2 = self.crf_level2.decode(level2_feats, mask=attention_mask.byte())
          return prediction_level1, prediction_level2


### Nested NER Span Model

In [None]:

def update_prediction_layer2(span_type_id,batch_size,predictions,labels):
  '''
  This function processes a batch of span predictions, mapping them back to the document level in a batch.
  It generates single-layer prediction labels for each token, updated comprehensively by the predictions from all spans relevant to a document.
  '''
  updated_pred_batch={}
  for sp_id,pred in zip(span_type_id,predictions):
    if sp_id[0] not in updated_pred_batch.keys():
      updated_pred_batch[sp_id[0]]=[0]*len(labels[sp_id[0]])
    updated_pred_batch[sp_id[0]][sp_id[1][0]:sp_id[1][1]]=pred
  batch_indexes=set(range(0,batch_size))
  missing_indexes=batch_indexes-set(list(updated_pred_batch.keys()))
  for ind in missing_indexes:
    updated_pred_batch[ind]=[0]*len(labels[ind])
  values_list = [updated_pred_batch[i] for i in range(0, len(updated_pred_batch))]
  return values_list

In [None]:

class NestedNERSpanModel(nn.Module):
    '''
    BERT-CRF model for Approach 2 (Nested NER Span model).
    Two classification heads with CRF layers: level1 classifies the whole sentence and
    level2 classifies the tokens inside each outermost-entity span.
    level1: independent entities, and the outermost entity of each nested structure
    level2: nested (inner) entities; for nesting deeper than two levels, the entity selected by the priority order
    '''
    def __init__(self, bert_model_name, num_labels_level1, num_labels_level2):
        super(NestedNERSpanModel, self).__init__()
        self.bert = BertModel.from_pretrained(bert_model_name)
        self.level1_classifier = nn.Linear(self.bert.config.hidden_size, num_labels_level1)
        self.level2_classifier = nn.Linear(self.bert.config.hidden_size, num_labels_level2)
        self.dropout = nn.Dropout(0.1)
        self.crf_level1 = CRF(num_labels_level1, batch_first=True)
        self.crf_level2 = CRF(num_labels_level2, batch_first=True)

    def forward(self, input_ids, attention_mask, span_positions, labels_level1=None, labels_level2=None):
        '''
        Performs sentence-level prediction for level1 and span-level prediction for level2.
        Input for level1: BERT embeddings of the whole sentence.
        Input for level2: BERT embeddings of each outermost-entity span listed in span_positions.
        span_positions are derived from the gold annotations (outermost_entities_pos), and
        evaluate_entity_model passes the same gold spans, so level2 does not use level1 predictions.
        Span padding and masking for the batch are done here; a batch without spans gets a
        level2 loss of 0 and all-'O' (id 0) level2 predictions.
        '''


        outputs = self.bert(input_ids, attention_mask=attention_mask)
        sequence_output = outputs[0]
        drp_output = self.dropout(sequence_output)

        #Predicting level1 labels for independent and outer-most entities.
        level1_feats = self.level1_classifier(drp_output)

        batch_labels_2=[]
        span_emb_lst=[]
        span_mask_list=[]
        max_span_len=0
        span_type_id=[]

        #Creation of level2 features
        for batch_ind in range(input_ids.size(0)):
          sen_span_pos=0
          for strt, end in span_positions[batch_ind]:
            #Ignores padded span positions.
            if strt == -1 and end == -1:
              continue
            span_len=end-strt

            #creation of span embeddings
            max_span_len=max(max_span_len,span_len)
            span_embedding = sequence_output[batch_ind,strt:end,:]
            span_emb_lst.append(span_embedding)

            #Extracting labels for level2 feats
            if labels_level2 is not None:
              batch_labels_2.append(labels_level2[batch_ind][strt:end])
            #creation of span masks for level2 features
            span_mask = torch.ones(span_len, dtype=torch.uint8, device=input_ids.device)
            span_mask_list.append(span_mask)
            #Collects spans offset position and maps the position of the sentence in a batch
            span_type_id.append((batch_ind,(strt,end)))
            sen_span_pos+=1

        #Span padding and masking for a batch is done here for level2 inputs.
        if span_emb_lst != []:
          padded_span_embs = pad_sequence(span_emb_lst, batch_first=True, padding_value=0.0)
          padded_span_masks = pad_sequence(span_mask_list, batch_first=True, padding_value=0)
          padded_span_msk=padded_span_masks.type(torch.uint8)
          if labels_level2 is not None:
              padded_labels_level2 = pad_sequence(batch_labels_2, batch_first=True, padding_value=-100)
          level2_feats = self.level2_classifier(padded_span_embs)
        mask = attention_mask.type(torch.uint8)

        #Prediction of level1 and level2
        if labels_level1 is not None and labels_level2 is not None:
          labels_level1 = torch.where(mask == 1, labels_level1, torch.tensor(-1).to(labels_level1.device))
          loss_level1 = -self.crf_level1(level1_feats, labels_level1, mask=mask,reduction='mean')
          if span_emb_lst != []:
            padded_labels_level2 = torch.where(padded_span_masks == 1, padded_labels_level2, torch.tensor(-1).to(padded_labels_level2.device))
            loss_level2 = -self.crf_level2(level2_feats, padded_labels_level2, mask=padded_span_msk,reduction='mean')
          else:
            #If a batch encounters no nested spans then loss is 0 for level2.
            loss_level2=0
          return loss_level1 + loss_level2
        else:
          prediction_level1 = self.crf_level1.decode(level1_feats, mask=mask)
          if span_emb_lst != []:
            prediction_level2 = self.crf_level2.decode(level2_feats, mask=padded_span_msk)

            #Updates span predictions to the token positions in a sentence
            prediction_level2=update_prediction_layer2(span_type_id,input_ids.size(0),prediction_level2,prediction_level1)

          else:
            prediction_level2 = [[0 for _ in sublist] for sublist in prediction_level1]

          return prediction_level1, prediction_level2


In [None]:
model=NestedNERSpanModel(bert_model_name, len(id_label_l1), len(id_label_l2))

### Model Training and evaluation

In [None]:
optimizer = AdamW(model.parameters(), lr=learning_rate,eps=eps)
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=0, num_training_steps=len(train_dataloader)*num_train_epochs)



In [None]:
training_args={
    'output_dir':save_model_path,
    'num_train_epochs':num_train_epochs,
    'optimizer':optimizer,
    'scheduler':scheduler,
    'patience':patience,
    'run_name':'Nested_Entity_CRFModel_v1'
}

In [None]:
def compute_ner_metric(preds,labels,id_label):
    ground_truths=[[id_label[l] for l in lab if l!=-100] for lab in labels]
    prediction_labels=[[id_label[p] for p,l in zip(predict,lab) if l!=-100] for predict,lab in zip(preds,labels)]
    metric_res=metric.compute(predictions=prediction_labels,references=ground_truths)
    return metric_res
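
`compute_ner_metric` passes string tags to seqeval, which scores whole entities rather than tokens: a predicted entity counts only if its type and both boundaries match a gold entity. That is one reason `overall_accuracy` (a token-level score dominated by `O` tokens) sits well above `overall_f1` in the logs. The plain-Python sketch below illustrates the idea; `bio_spans` and `entity_prf` are hypothetical helpers that approximate seqeval's default (conlleval-style) chunking, not seqeval's implementation.

```python
# Illustrative re-implementation of entity-level (exact-match) scoring for
# intuition only; it approximates seqeval's default mode and is not its code.
def bio_spans(tags):
    """Extract (type, start, end_exclusive) chunks from a BIO tag sequence."""
    spans, start, typ = [], None, None
    for i, tag in enumerate(list(tags) + ["O"]):  # trailing 'O' closes an open chunk
        continues = tag.startswith("I-") and tag[2:] == typ
        if not continues:
            if typ is not None:
                spans.append((typ, start, i))
            typ, start = (tag[2:], i) if tag != "O" else (None, None)
    return spans

def entity_prf(refs, hyps):
    """Micro precision, recall and F1 over exact (sentence, type, start, end) matches."""
    gold = {(n,) + s for n, seq in enumerate(refs) for s in bio_spans(seq)}
    pred = {(n,) + s for n, seq in enumerate(hyps) for s in bio_spans(seq)}
    tp = len(gold & pred)
    precision = tp / len(pred) if pred else 0.0
    recall = tp / len(gold) if gold else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    return precision, recall, f1

refs = [["O", "B-Family", "I-Family", "O", "B-Drug"]]
hyps = [["O", "B-Family", "O", "O", "B-Drug"]]
token_acc = sum(r == h for r, h in zip(refs[0], hyps[0])) / len(refs[0])
print(token_acc, entity_prf(refs, hyps))  # 0.8 (0.5, 0.5, 0.5)
```

One wrong token truncates the `Family` entity, so token accuracy stays at 0.8 while entity-level F1 drops to 0.5.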


In [None]:
def evaluate_entity_model(model,val_loader,id_label_l1,id_label_l2,extract_predictions=False,dt_set=''):
  model.eval()
  with torch.no_grad():
    all_res_l1=[]
    all_res_l2=[]
    l1_predlst=[]
    l2_predlst=[]
    l2_lab=[]
    l1_lab=[]
    val_loss=0

    for step,batch in enumerate(val_loader):
      l_1_labels=batch['l_1_labels'].to(device)
      l_2_labels=batch['l_2_labels'].to(device)
      span_pos=batch['outermost_entities_pos'].to(device)

      inputs={'input_ids':batch['input_ids'].to(device),'attention_mask':batch['attention_mask'].to(device),'span_positions':span_pos}
      #Model Computation
      batch_loss=model(input_ids=batch['input_ids'].to(device),attention_mask=batch['attention_mask'].to(device),span_positions=span_pos,labels_level1=l_1_labels, labels_level2=l_2_labels)
      #Compute loss
      val_loss+=batch_loss.item()  # accumulate a Python float rather than a tensor
      pred_l1,pred_l2=model(**inputs)
      l1_predlst.extend(pred_l1)
      l2_predlst.extend(pred_l2)
      l1_lab.extend(l_1_labels.tolist())
      l2_lab.extend(l_2_labels.tolist())
      #Extract predictions and map to token and label
      for tok_lst,lab_lst,pred_lst in zip(batch['tokens'],l_1_labels.tolist(),pred_l1):
        all_res_l1.append(list(zip(tok_lst,lab_lst,pred_lst)))
      for tok_lst,lab_lst,pred_lst in zip(batch['tokens'],l_2_labels.tolist(),pred_l2):
        all_res_l2.append(list(zip(tok_lst,lab_lst,pred_lst)))
    #Dump results
    all_res={'level_1':all_res_l1,'level_2':all_res_l2}
    if extract_predictions:
      with open(project_directory+'data/'+'nested_ner_predictions_'+dt_set+'.json','w') as f:
        json.dump(all_res,f)
    #Compute metrics
    l1_metrics=compute_ner_metric(l1_predlst,l1_lab,id_label_l1)
    l2_metrics=compute_ner_metric(l2_predlst,l2_lab,id_label_l2)

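  # Average over batches so the value is on the same scale as the per-batch training loss
  return {'val_loss':val_loss/len(val_loader),'l1_metrics':l1_metrics,'l2_metrics':l2_metrics}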




In [None]:
def train_entity_bert_model(model,tr_dataloader,tst_dataloader,id_label_l1,id_label_l2,bert_model_name,training_args):
  if torch.cuda.is_available():
    model.cuda()
  train_cycles=trange(training_args['num_train_epochs'],desc='Epoch',disable=0)

  optimizer=training_args['optimizer']
  scheduler=training_args['scheduler']
  patience=training_args['patience']

  tr_ent_loss_lst=[]
  val_ent_loss_list=[]

  early_stopping_count=0
  min_ent_val_loss=0

  for cycle in train_cycles:
    epoch_cycles=tqdm(tr_dataloader,desc='Iteration',disable=-1)
    model.train()
    tr_ent_loss=0

    for step,batch in enumerate(epoch_cycles):
      optimizer.zero_grad()
      l_1_labels=batch['l_1_labels'].to(device)
      l_2_labels=batch['l_2_labels'].to(device)
      span_pos=batch['outermost_entities_pos'].to(device)
      #Model Computation
      batch_loss=model(input_ids=batch['input_ids'].to(device),attention_mask=batch['attention_mask'].to(device),span_positions=span_pos,labels_level1=l_1_labels, labels_level2=l_2_labels)
      #Loss
      batch_loss.backward()
      optimizer.step()
      scheduler.step()
      tr_ent_loss+=batch_loss.item()  # .item() detaches the loss so each batch's autograd graph can be freed
    #Evaluation
    tr_ent_loss_lst.append(tr_ent_loss/len(tr_dataloader))
    if early_stopping_count == patience or early_stopping_count == patience-1:
      tr_results=evaluate_entity_model(model,tr_dataloader,id_label_l1,id_label_l2,True,'tr')
      val_results=evaluate_entity_model(model,tst_dataloader,id_label_l1,id_label_l2,True,'tst')
    else:
      val_results=evaluate_entity_model(model,tst_dataloader,id_label_l1,id_label_l2)
      tr_results=evaluate_entity_model(model,tr_dataloader,id_label_l1,id_label_l2)

    val_ent_loss_list.append(val_results['val_loss'])
    print('Epoch: {}  Train Method Loss: {}'.format(cycle,tr_ent_loss_lst[-1]))

    print('Epoch: {}  Val Method Loss: {}'.format(cycle,val_ent_loss_list[-1]))
    print('Method_metrics: \n')
    print(val_results['l1_metrics'])
    print('\n')
    print(val_results['l2_metrics'])
    print('\n')
    #Early stopping
    if cycle==0:
      min_ent_val_loss=val_ent_loss_list[-1]
      early_stopping_count=0
    else:
      if val_ent_loss_list[-1]<min_ent_val_loss:
        min_ent_val_loss=val_ent_loss_list[-1]
        early_stopping_count=0
      else:
        early_stopping_count+=1
      if early_stopping_count>=patience:
        print('Early stopping counter for method model : {}'.format(early_stopping_count))
        break
  #Save model
  torch.save(model.state_dict(),training_args['output_dir']+training_args['run_name']+'.pth')
  return model , {'train_loss':{'tr_method_loss_lst':tr_ent_loss_lst}
                  ,'val_loss':{'val_method_loss_list':val_ent_loss_list}}
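
The early-stopping rule in `train_entity_bert_model` resets a counter whenever the validation loss improves and stops once the counter reaches `patience`. The hypothetical helper below replays that rule on a list of per-epoch validation losses; the example uses the six validation losses printed in the training log below, once with the notebook's `patience=5` and once with a hypothetical `patience=2`. Note that the training loop saves the weights from the last epoch it runs, not those from the best epoch.

```python
def early_stopping_epochs(val_losses, patience):
    """Replay the counter-based rule; return (stop_epoch or None, best_epoch)."""
    best_loss, best_epoch, counter = None, None, 0
    for epoch, loss in enumerate(val_losses):
        if best_loss is None or loss < best_loss:
            best_loss, best_epoch, counter = loss, epoch, 0  # improvement resets the counter
        else:
            counter += 1
            if counter >= patience:
                return epoch, best_epoch  # the loop would break after this epoch
    return None, best_epoch  # patience never exhausted

# Validation losses for epochs 0-5 as printed in the training log.
logged = [278.312255859375, 155.7523193359375, 140.9980010986328,
          132.5026397705078, 142.71470642089844, 141.7391357421875]
print(early_stopping_epochs(logged, patience=5))  # (None, 3)
print(early_stopping_epochs(logged, patience=2))  # (5, 3)
```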







In [None]:
trained_model,loss=train_entity_bert_model(model,train_dataloader,test_dataloader,id_label_l1,id_label_l2,bert_model_name,training_args)

Epoch:  12%|█▎        | 1/8 [00:11<01:17, 11.07s/it]

Epoch: 0  Train Method Loss: 44.31918716430664
Epoch: 0  Val Method Loss: 278.312255859375
Method_metrics: 

{'Alcohol': {'precision': 0.45054945054945056, 'recall': 0.7454545454545455, 'f1': 0.5616438356164384, 'number': 55}, 'Drug': {'precision': 1.0, 'recall': 0.03333333333333333, 'f1': 0.06451612903225806, 'number': 30}, 'Family': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0, 'number': 20}, 'LivingSituation': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0, 'number': 19}, 'MaritalStatus': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0, 'number': 30}, 'Occupation': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0, 'number': 29}, 'Tobacco': {'precision': 0.3111111111111111, 'recall': 0.5283018867924528, 'f1': 0.3916083916083916, 'number': 53}, 'overall_precision': 0.38461538461538464, 'overall_recall': 0.2966101694915254, 'overall_f1': 0.33492822966507174, 'overall_accuracy': 0.8732886688027964}


{'Family': {'precision': 0.7142857142857143, 'recall': 0.8823529411764706, 'f1': 0.7894736842105262,

Epoch:  25%|██▌       | 2/8 [00:22<01:06, 11.10s/it]

Epoch: 1  Train Method Loss: 23.1181583404541
Epoch: 1  Val Method Loss: 155.7523193359375
Method_metrics: 

{'Alcohol': {'precision': 0.6176470588235294, 'recall': 0.7636363636363637, 'f1': 0.6829268292682927, 'number': 55}, 'Drug': {'precision': 0.8333333333333334, 'recall': 0.8333333333333334, 'f1': 0.8333333333333334, 'number': 30}, 'Family': {'precision': 0.2857142857142857, 'recall': 0.3, 'f1': 0.2926829268292683, 'number': 20}, 'LivingSituation': {'precision': 0.20930232558139536, 'recall': 0.47368421052631576, 'f1': 0.2903225806451613, 'number': 19}, 'MaritalStatus': {'precision': 0.3488372093023256, 'recall': 0.5, 'f1': 0.4109589041095891, 'number': 30}, 'Occupation': {'precision': 0.2222222222222222, 'recall': 0.27586206896551724, 'f1': 0.24615384615384614, 'number': 29}, 'Tobacco': {'precision': 0.75, 'recall': 0.9056603773584906, 'f1': 0.8205128205128206, 'number': 53}, 'overall_precision': 0.5016393442622951, 'overall_recall': 0.6483050847457628, 'overall_f1': 0.5656192236

Epoch:  38%|███▊      | 3/8 [00:32<00:54, 10.92s/it]

Epoch: 2  Train Method Loss: 13.854171752929688
Epoch: 2  Val Method Loss: 140.9980010986328
Method_metrics: 

{'Alcohol': {'precision': 0.8035714285714286, 'recall': 0.8181818181818182, 'f1': 0.8108108108108109, 'number': 55}, 'Drug': {'precision': 0.7647058823529411, 'recall': 0.8666666666666667, 'f1': 0.8125, 'number': 30}, 'Family': {'precision': 0.39285714285714285, 'recall': 0.55, 'f1': 0.45833333333333337, 'number': 20}, 'LivingSituation': {'precision': 0.27586206896551724, 'recall': 0.42105263157894735, 'f1': 0.3333333333333333, 'number': 19}, 'MaritalStatus': {'precision': 0.5, 'recall': 0.5, 'f1': 0.5, 'number': 30}, 'Occupation': {'precision': 0.35714285714285715, 'recall': 0.3448275862068966, 'f1': 0.3508771929824561, 'number': 29}, 'Tobacco': {'precision': 0.8727272727272727, 'recall': 0.9056603773584906, 'f1': 0.8888888888888888, 'number': 53}, 'overall_precision': 0.6269230769230769, 'overall_recall': 0.690677966101695, 'overall_f1': 0.657258064516129, 'overall_accuracy'

Epoch:  50%|█████     | 4/8 [00:43<00:43, 10.98s/it]

Epoch: 3  Train Method Loss: 9.635627746582031
Epoch: 3  Val Method Loss: 132.5026397705078
Method_metrics: 

{'Alcohol': {'precision': 0.7796610169491526, 'recall': 0.8363636363636363, 'f1': 0.8070175438596492, 'number': 55}, 'Drug': {'precision': 0.875, 'recall': 0.9333333333333333, 'f1': 0.9032258064516129, 'number': 30}, 'Family': {'precision': 0.5, 'recall': 0.7, 'f1': 0.5833333333333334, 'number': 20}, 'LivingSituation': {'precision': 0.3793103448275862, 'recall': 0.5789473684210527, 'f1': 0.45833333333333337, 'number': 19}, 'MaritalStatus': {'precision': 0.375, 'recall': 0.5, 'f1': 0.42857142857142855, 'number': 30}, 'Occupation': {'precision': 0.28125, 'recall': 0.3103448275862069, 'f1': 0.2950819672131148, 'number': 29}, 'Tobacco': {'precision': 0.8545454545454545, 'recall': 0.8867924528301887, 'f1': 0.8703703703703703, 'number': 53}, 'overall_precision': 0.6181818181818182, 'overall_recall': 0.7203389830508474, 'overall_f1': 0.665362035225049, 'overall_accuracy': 0.9265948150

Epoch: 4  Train Method Loss: 6.908662796020508
Epoch: 4  Val Method Loss: 142.71470642089844
Method_metrics: 

{'Alcohol': {'precision': 0.711864406779661, 'recall': 0.7636363636363637, 'f1': 0.736842105263158, 'number': 55}, 'Drug': {'precision': 0.8, 'recall': 0.9333333333333333, 'f1': 0.8615384615384616, 'number': 30}, 'Family': {'precision': 0.46153846153846156, 'recall': 0.6, 'f1': 0.5217391304347826, 'number': 20}, 'LivingSituation': {'precision': 0.4074074074074074, 'recall': 0.5789473684210527, 'f1': 0.47826086956521735, 'number': 19}, 'MaritalStatus': {'precision': 0.5625, 'recall': 0.6, 'f1': 0.5806451612903225, 'number': 30}, 'Occupation': {'precision': 0.35135135135135137, 'recall': 0.4482758620689655, 'f1': 0.393939393939394, 'number': 29}, 'Tobacco': {'precision': 0.8070175438596491, 'recall': 0.8679245283018868, 'f1': 0.8363636363636363, 'number': 53}, 'overall_precision': 0.6227106227106227, 'overall_recall': 0.7203389830508474, 'overall_f1': 0.6679764243614932, 'overal

Epoch: 5  Train Method Loss: 5.189341068267822
Epoch: 5  Val Method Loss: 141.7391357421875
Method_metrics: 

{'Alcohol': {'precision': 0.8333333333333334, 'recall': 0.8181818181818182, 'f1': 0.8256880733944955, 'number': 55}, 'Drug': {'precision': 0.8484848484848485, 'recall': 0.9333333333333333, 'f1': 0.888888888888889, 'number': 30}, 'Family': {'precision': 0.48, 'recall': 0.6, 'f1': 0.5333333333333332, 'number': 20}, 'LivingSituation': {'precision': 0.4074074074074074, 'recall': 0.5789473684210527, 'f1': 0.47826086956521735, 'number': 19}, 'MaritalStatus': {'precision': 0.5151515151515151, 'recall': 0.5666666666666667, 'f1': 0.5396825396825397, 'number': 30}, 'Occupation': {'precision': 0.42857142857142855, 'recall': 0.41379310344827586, 'f1': 0.42105263157894735, 'number': 29}, 'Tobacco': {'precision': 0.8867924528301887, 'recall': 0.8867924528301887, 'f1': 0.8867924528301887, 'number': 53}, 'overall_precision': 0.6798418972332015, 'overall_recall': 0.7288135593220338, 'overall_f1

Epoch: 6  Train Method Loss: 4.184082984924316
Epoch: 6  Val Method Loss: 144.53701782226562
Method_metrics: 

{'Alcohol': {'precision': 0.7543859649122807, 'recall': 0.7818181818181819, 'f1': 0.7678571428571429, 'number': 55}, 'Drug': {'precision': 0.8, 'recall': 0.9333333333333333, 'f1': 0.8615384615384616, 'number': 30}, 'Family': {'precision': 0.4666666666666667, 'recall': 0.7, 'f1': 0.56, 'number': 20}, 'LivingSituation': {'precision': 0.46153846153846156, 'recall': 0.631578947368421, 'f1': 0.5333333333333333, 'number': 19}, 'MaritalStatus': {'precision': 0.4594594594594595, 'recall': 0.5666666666666667, 'f1': 0.5074626865671642, 'number': 30}, 'Occupation': {'precision': 0.30303030303030304, 'recall': 0.3448275862068966, 'f1': 0.32258064516129037, 'number': 29}, 'Tobacco': {'precision': 0.8392857142857143, 'recall': 0.8867924528301887, 'f1': 0.8623853211009174, 'number': 53}, 'overall_precision': 0.6240875912408759, 'overall_recall': 0.7245762711864406, 'overall_f1': 0.6705882352

Epoch: 7  Train Method Loss: 3.6720736026763916
Epoch: 7  Val Method Loss: 145.64718627929688
Method_metrics: 

{'Alcohol': {'precision': 0.7368421052631579, 'recall': 0.7636363636363637, 'f1': 0.7499999999999999, 'number': 55}, 'Drug': {'precision': 0.8, 'recall': 0.9333333333333333, 'f1': 0.8615384615384616, 'number': 30}, 'Family': {'precision': 0.4827586206896552, 'recall': 0.7, 'f1': 0.5714285714285714, 'number': 20}, 'LivingSituation': {'precision': 0.48, 'recall': 0.631578947368421, 'f1': 0.5454545454545454, 'number': 19}, 'MaritalStatus': {'precision': 0.4722222222222222, 'recall': 0.5666666666666667, 'f1': 0.5151515151515152, 'number': 30}, 'Occupation': {'precision': 0.29411764705882354, 'recall': 0.3448275862068966, 'f1': 0.31746031746031744, 'number': 29}, 'Tobacco': {'precision': 0.8070175438596491, 'recall': 0.8679245283018868, 'f1': 0.8363636363636363, 'number': 53}, 'overall_precision': 0.6190476190476191, 'overall_recall': 0.7161016949152542, 'overall_f1': 0.6640471512
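The hyperparameters list `patience=3`. In the epochs logged above, training loss keeps falling, but validation loss is lowest at epoch 3 and higher at every later epoch. The sketch below shows patience-based early stopping applied to those logged validation losses. The training loop is not shown here, so two things are assumptions: that validation loss is the monitored quantity, and that "patience" means consecutive epochs without improvement. The function name and the `min_delta` parameter are hypothetical.

```python
def early_stop_epoch(val_losses, patience=3, min_delta=0.0):
    """Return (best_epoch, stop_epoch) for loss-based early stopping.

    val_losses maps epoch -> validation loss. An epoch counts as an
    improvement only if its loss beats the best so far by more than
    min_delta. stop_epoch is None if patience never runs out.
    """
    best_epoch, best_loss = None, float("inf")
    epochs_without_improvement = 0
    for epoch in sorted(val_losses):
        loss = val_losses[epoch]
        if loss < best_loss - min_delta:
            best_epoch, best_loss = epoch, loss
            epochs_without_improvement = 0
        else:
            epochs_without_improvement += 1
            if epochs_without_improvement >= patience:
                return best_epoch, epoch
    return best_epoch, None


# Validation losses for epochs 2-7, copied from the log above.
VAL_LOSS = {
    2: 140.9980010986328,
    3: 132.5026397705078,
    4: 142.71470642089844,
    5: 141.7391357421875,
    6: 144.53701782226562,
    7: 145.64718627929688,
}
print(early_stop_epoch(VAL_LOSS))  # (3, 6)
```

On these values the sketch keeps the epoch-3 checkpoint and would stop after epoch 6. Monitoring `overall_f1` instead would only require flipping the comparison to "higher is better". That may suit this model better, because the entity-level scores keep improving at epochs where validation loss does not.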


