## **Flat NER model for roles and entities**
### **Description:**

*   Performs named entity recognition (NER) for entities and roles at the token level.


*   Entities and roles excluded because there was insufficient training data:
  1.  discarded_enities=['EnvironmentalExposure','SexualHistory','InfectiousDiseases','PhysicalActivity','Residence','LivingSituation','MaritalStatus','Occupation']

  2.  discarded_roles=['LivingStatus','Other','MedicalCondition','Extent','History']

*   Labelling uses the BIO scheme.
*   Labels are generated by mapping annotation character offsets to BERT token offsets.
*   Some documents contain overlapping spans among entities, among roles, and between roles and entities.
*   To handle overlap between roles and entities, each is classified with a separate head.

*   Analysis showed that separating Status and Method from the other roles removes the overlap among roles, so these two are also classified with separate heads.

*   Overlap among entities occurred between 'Family', 'Residence', 'LivingSituation' and 'MaritalStatus'. All of these except 'Family' have been dropped.
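
The overlap analysis described above can be illustrated with a small check for intersecting character spans. This is only a sketch: the `find_overlaps` helper, the `(category, start, end)` tuple layout, the exclusive end offset and the sample annotations are assumptions made for illustration, not code from this notebook.

```python
from itertools import combinations

def find_overlaps(spans):
    """Return pairs of annotations whose character ranges intersect.

    Each span is (category, start, end) with an exclusive end offset (assumed).
    """
    overlaps = []
    ordered = sorted(spans, key=lambda s: (s[1], s[2]))
    for a, b in combinations(ordered, 2):
        # Two half-open intervals intersect when each starts before the other ends.
        if a[1] < b[2] and b[1] < a[2]:
            overlaps.append((a, b))
    return overlaps

# Hypothetical annotations for the text "lives with his wife ..."
spans = [
    ("LivingSituation", 0, 19),
    ("Family", 15, 19),
    ("Status", 30, 35),
]
print(find_overlaps(spans))
```

Any pair reported by such a check cannot be encoded in a single BIO label sequence, which is why the overlapping categories are either dropped or routed to a separate head.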

### **Experiments:**
*   Does eliminating overlapping entities improve the model's recognition of Family entities?
*   Does the model benefit from training entities (with overlapping ones discarded) and roles together?




### **Model**
*   ***Tokenizer:*** BertTokenizerFast
*   ***Pre-trained BERT model:*** 'emilyalsentzer/Bio_ClinicalBERT'

*   ***Hyperparameters:***
  * eps=1e-8
  * learning_rate=7e-5
  * weight_decay=0
  * num_train_epochs=15
  * patience=3
  * batch_size=16
  * max_len_token=512
*   Fine-tuned a pre-trained BERT model with 6 layers frozen and multiple classification heads, one each for status, methods, entities, and roles.


*   Because multiple tasks are trained simultaneously, the per-head losses are averaged into a single loss before backpropagation.

*   NER performance is measured with sequence-level evaluation (seqeval).
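
The loss aggregation across heads can be sketched in plain Python. The notebook itself uses `nn.CrossEntropyLoss(ignore_index=-100)`; the `masked_cross_entropy` and `aggregate` helpers below are hypothetical stand-ins that show roughly what that computation does for padded positions, and the example uses only two heads and made-up logits for brevity.

```python
import math

IGNORE_INDEX = -100  # label value assigned to padded positions

def masked_cross_entropy(logits, labels, ignore_index=IGNORE_INDEX):
    """Mean negative log-likelihood over tokens whose label is not ignore_index."""
    total, count = 0.0, 0
    for token_logits, label in zip(logits, labels):
        if label == ignore_index:
            continue  # padded tokens contribute nothing to the loss
        # log-sum-exp with a max shift for numerical stability
        m = max(token_logits)
        log_norm = m + math.log(sum(math.exp(x - m) for x in token_logits))
        total += log_norm - token_logits[label]
        count += 1
    # Sketch choice: return 0.0 when every token is ignored.
    return total / count if count else 0.0

def aggregate(head_losses):
    """Equal-weight average of the per-head losses."""
    return sum(head_losses.values()) / len(head_losses)

# Made-up logits for two heads over a short sequence
head_losses = {
    "entity": masked_cross_entropy([[2.0, 0.0], [0.0, 0.0]], [0, IGNORE_INDEX]),
    "role": masked_cross_entropy([[0.0, 0.0, 0.0]], [1]),
}
print(aggregate(head_losses))
```

Equal weighting is the simplest choice; it lets heads with larger losses dominate the gradient, which may or may not be desirable for the rarer categories.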

### **Run the notebook:**
* Provide model_type, ver and project_directory in the input section and run all cells.
* Datasets and their paths:
  *   train_data_set path: project_directory+'/data/trainset.json'
  *   test_data_set path: project_directory+'/data/testset.json'
* Model stored at project_directory+'/models/final_models/'+model_name+'_'+ver+'.pth'.
* Evaluation dataset predictions are stored at 'tst_'+model_type+'_'+str(ver)+'.json'
### **Metrics:**
Values are rounded to four decimal places; `number` is the count of gold spans per class.

Entity_metrics:

| Class | precision | recall | f1 | number |
|---|---|---|---|---|
| Alcohol | 0.8491 | 0.8182 | 0.8333 | 55 |
| Drug | 0.9032 | 0.9333 | 0.9180 | 30 |
| Family | 0.5952 | 0.6757 | 0.6329 | 37 |
| Tobacco | 0.9057 | 0.9057 | 0.9057 | 53 |
| overall | 0.8156 | 0.8343 | 0.8249 | |

overall_accuracy: 0.9814

Role_metrics:

| Class | precision | recall | f1 | number |
|---|---|---|---|---|
| Amount | 0.7966 | 0.8545 | 0.8246 | 55 |
| ExposureHistory | 0.6000 | 0.6000 | 0.6000 | 10 |
| Frequency | 0.8462 | 0.7857 | 0.8148 | 28 |
| Location | 0.2581 | 0.5000 | 0.3404 | 16 |
| QuitHistory | 0.6923 | 0.6000 | 0.6429 | 15 |
| Temporal | 0.3571 | 0.4545 | 0.4000 | 11 |
| Type | 0.6644 | 0.7293 | 0.6953 | 133 |
| overall | 0.6488 | 0.7239 | 0.6843 | |

overall_accuracy: 0.9313

Status_metrics:

| Class | precision | recall | f1 | number |
|---|---|---|---|---|
| Status | 0.6633 | 0.6701 | 0.6667 | 194 |
| overall | 0.6633 | 0.6701 | 0.6667 | |

overall_accuracy: 0.9569

Method_metrics:

| Class | precision | recall | f1 | number |
|---|---|---|---|---|
| Method | 0.4091 | 0.5455 | 0.4675 | 33 |
| overall | 0.4091 | 0.5455 | 0.4675 | |

overall_accuracy: 0.9784

In [1]:
!pip install evaluate


In [2]:
import json
import pandas as pd
import evaluate
from transformers import BertTokenizerFast, BertModel, AdamW, get_linear_schedule_with_warmup, DataCollatorForTokenClassification
from torch.utils.data import DataLoader, Dataset
from sklearn.metrics import classification_report
import torch
import torch.nn as nn
import torch.optim as optim
from torch.nn.utils.rnn import pad_sequence
from tqdm import tqdm, trange

In [3]:
!pip install seqeval



In [4]:
#Input section
model_type='Flt_ent_role_model'
ver=5
project_directory='/content/drive/MyDrive/PHD_assessment_gmu/'

metric=evaluate.load("seqeval")
device= torch.device("cuda" if torch.cuda.is_available() else "cpu")
tokenizer = BertTokenizerFast.from_pretrained('emilyalsentzer/Bio_ClinicalBERT')
num_freeze_layers=4


## Parameter Initialization

Hyperparameters

In [5]:
eps=1e-8
learning_rate=7e-5
weight_decay=0
num_train_epochs=15
patience=3
save_model_path=project_directory+'/models/final_models/'
batch_size=16
max_len=512


Dataset and model paths, and discarded classes


In [6]:
discarded_enities=['EnvironmentalExposure','SexualHistory','InfectiousDiseases','PhysicalActivity','Residence','LivingSituation','MaritalStatus','Occupation']
discarded_roles=['LivingStatus','Other','MedicalCondition','Extent','History']


raw_dataset_path=project_directory+'data/'+'SocialHistoryMTSamples.json'
train_dataset_path=project_directory+'data/'+'trainset.json'
test_dataset_path=project_directory+'data/'+'testset.json'
bert_model_name='emilyalsentzer/Bio_ClinicalBERT'

Mapping the classes to ids.

In [7]:
id_label_status={0:'O',1:'B-Status',2:'I-Status'}
id_label_method={0:'O',1:'B-Method',2:'I-Method'}
id_label_role={0:'O',1:'B-Type',2:'I-Type',3:'B-Amount',4:'I-Amount',5:'B-Temporal',6:'I-Temporal',7:'B-Frequency',8:'I-Frequency',9:'B-QuitHistory',10:'I-QuitHistory',11:'B-ExposureHistory',12:'I-ExposureHistory',13:'B-Location',14:'I-Location'}
id_label_event={0:'No Relation',1:'Relation'}
label_id_status = {v: k for k, v in id_label_status.items()}
label_id_method = {v: k for k, v in id_label_method.items()}
label_id_role = {v: k for k, v in id_label_role.items()}
label_id_ent = {'B-Alcohol':1,
 'B-Drug':3,
 'B-Family':5,
 'B-Tobacco':7,
 'I-Alcohol':2,
 'I-Drug':4,
 'I-Family':6,
 'I-Tobacco':8,
 'O':0}
id_label_ent = {v: k for k, v in label_id_ent.items()}
label_id_event = {v: k for k, v in id_label_event.items()}


## Data Loading and Input processing


### Generate Entity and Role labels for a doc

In [8]:

class GenerateLabel:

  @staticmethod
  def generate_enity_labels(entity_list, token_len, token_offsets):
    '''
    Generates entity labels for each token in the document by mapping each entity's character offsets to the token offsets.

    Inputs:
    entity_list: List of entity dicts; each dict holds the offset positions, id, category and text.
    token_len: Number of tokens; sets the length of the label list.
    token_offsets: List of (start, end) character offsets of the BERT tokens.

    Outputs:
    entity_labels: List of BIO label ids, one per token.
    '''
    # Initialize a list to store the BIO labels for each token

    entity_labels = [label_id_ent['O']] * token_len
    for entity in entity_list:
        category = entity['entity_category']
        if category not in discarded_enities:
          entity_start_pos = int(entity['entity_strt_pos'])
          entity_end_pos = int(entity['entity_end_pos'])-1

          # Find tokens that correspond to the entity's position
          entity_start_token = None
          entity_end_token = None

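          for i, (start_offset, end_offset) in enumerate(token_offsets):
              if start_offset == end_offset:
                  continue  # skip special tokens such as [CLS]/[SEP], whose offsets are (0, 0)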
              if entity_start_token is None and start_offset >= entity_start_pos:
                  entity_start_token = i
              if end_offset > entity_end_pos:
                  entity_end_token = i
                  break

          # Assign BIO labels to the tokens
          if entity_start_token is not None:
            entity_labels[entity_start_token] = label_id_ent['B-' + category]
            if entity_end_token is not None:
              entity_labels[entity_start_token + 1:entity_end_token + 1] = [label_id_ent['I-' + category]] * (entity_end_token - entity_start_token)

    return entity_labels

  @staticmethod
  def generate_role_labels(role_list, token_len, token_offsets):
    '''
    Generates role labels for each token in the document by mapping each role's character offsets to the token offsets.
    Also handles overlapping spans by separating Status and Method from the other roles and creating separate label lists for them.

    Inputs:
    role_list: List of role dicts; each dict holds the offset positions, id, category and text.
    token_len: Number of tokens; sets the length of each label list.
    token_offsets: List of (start, end) character offsets of the BERT tokens.

    Outputs:
    role_labels: List of BIO label ids, one per token.
    status_labels: List of BIO label ids, one per token.
    method_labels: List of BIO label ids, one per token.
    '''
    # Initialize a list to store the BIO labels for each token
    role_labels = [label_id_role['O']] * token_len
    status_labels = [label_id_status['O']] * token_len
    method_labels = [label_id_method['O']] * token_len

    for role in role_list:
        category = role['entity_category']
        entity_start_pos = int(role['entity_strt_pos'])
        entity_end_pos = int(role['entity_end_pos'])-1
        if category in discarded_roles:
          continue
        # Find tokens that correspond to the entity's position
        entity_start_token = None
        entity_end_token = None

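        for i, (start_offset, end_offset) in enumerate(token_offsets):
            if start_offset == end_offset:
                continue  # skip special tokens such as [CLS]/[SEP], whose offsets are (0, 0)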
            if entity_start_token is None and start_offset >= entity_start_pos:
                entity_start_token = i
            if end_offset > entity_end_pos:
                entity_end_token = i
                break

        # Assign BIO labels to the tokens
        if category == 'Status':
          if entity_start_token is not None:
            status_labels[entity_start_token] = label_id_status['B-' + category]
            if entity_end_token is not None:
                status_labels[entity_start_token + 1:entity_end_token + 1] = [label_id_status['I-' + category]] * (entity_end_token - entity_start_token)

        elif category == 'Method':
          if entity_start_token is not None:
            method_labels[entity_start_token] = label_id_method['B-' + category]
            if entity_end_token is not None:
                method_labels[entity_start_token + 1:entity_end_token + 1] = [label_id_method['I-' + category]] * (entity_end_token - entity_start_token)

        else:
          if entity_start_token is not None:
            role_labels[entity_start_token] = label_id_role['B-' + category]
            if entity_end_token is not None:
                role_labels[entity_start_token + 1:entity_end_token + 1] = [label_id_role['I-' + category]] * (entity_end_token - entity_start_token)


    return role_labels, status_labels, method_labels

  @staticmethod
  def generate_relation_labels( token_offsets, data):
    '''
    Generates relation mapping with token indices for each event present in a document.
    Reads events and maps them to their related entities and roles with their token indices.


    Inputs:
    token_offsets: list of tuples of offset mappings of bert tokens.
    data: Data dict which consists of entity_list, role_list and events_list.

    Outputs:
    relation_labels: The relation mapping is a dictionary where key is a tuple of entity token indices and value is a list of tuple of role token indices.

    '''
    entity_token_indices = {}
    role_token_indices = {}

    entity_list = data.get('entity_list', [])
    role_list = data.get('role_list', [])
    events_list = data.get('events_list', [])

    # Create a dictionary to map entity IDs to their token indices
    for entity in entity_list:
        entity_id = entity['entity_id']
        entity_start_pos = int(entity['entity_strt_pos'])
        entity_end_pos = int(entity['entity_end_pos'])-1

        entity_start_token = None
        entity_end_token = None

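        for i, (start_offset, end_offset) in enumerate(token_offsets):
            if start_offset == end_offset:
                continue  # skip special tokens such as [CLS]/[SEP], whose offsets are (0, 0)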
            if entity_start_token is None and start_offset >= entity_start_pos:
                entity_start_token = i
            if end_offset > entity_end_pos:
                entity_end_token = i
                break

        if entity_start_token is not None:
            entity_token_indices[entity_id] = (entity_start_token, entity_end_token)

    # Create a dictionary to map event-related role IDs to their token indices
    for event in events_list:
        entity_id = event['entity_id']
        related_roles = event['Related_roles']

        entity_indices = entity_token_indices.get(entity_id, None)

        if entity_indices:
            role_indices = []

            for role_id in related_roles:
                role = next((role for role in role_list if role['role_id'] == role_id), None)

                if role:
                    role_start_pos = int(role['entity_strt_pos'])
                    role_end_pos = int(role['entity_end_pos'])-1

                    role_start_token = None
                    role_end_token = None

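                    for i, (start_offset, end_offset) in enumerate(token_offsets):
                        if start_offset == end_offset:
                            continue  # skip special tokens such as [CLS]/[SEP], whose offsets are (0, 0)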
                        if role_start_token is None and start_offset >= role_start_pos:
                            role_start_token = i
                        if end_offset > role_end_pos:
                            role_end_token = i
                            break

                    if role_start_token is not None:
                        role_indices.append((role_start_token, role_end_token))

            if role_indices:
                role_token_indices[(entity_indices[0],entity_indices[1])] = role_indices

    return role_token_indices
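
The offset-based labelling used by the generators above can be illustrated on a toy input. This is a simplified sketch rather than the notebook's exact loop: it uses an interval-overlap test, and the `char_span_to_bio` helper, the sample sentence, the word-piece split and the hand-written offsets are all assumptions (a fast tokenizer called with `return_offsets_mapping=True` would supply real offsets).

```python
def char_span_to_bio(offsets, start, end, category):
    """Tag every token overlapping the character range [start, end) with B-/I- tags."""
    tags = ["O"] * len(offsets)
    inside = False
    for i, (tok_start, tok_end) in enumerate(offsets):
        if tok_start == tok_end:
            continue  # zero-width offsets mark special tokens such as [CLS] and [SEP]
        if tok_start < end and start < tok_end:
            # First overlapping token opens the span, later ones continue it.
            tags[i] = ("I-" if inside else "B-") + category
            inside = True
    return tags

text = "He smokes cigarettes daily"
# Hypothetical tokenization: [CLS] He smokes cigaret ##tes daily [SEP]
offsets = [(0, 0), (0, 2), (3, 9), (10, 17), (17, 20), (21, 26), (0, 0)]

# A 'Type' annotation covering "cigarettes" (characters 10 to 20, end exclusive)
print(char_span_to_bio(offsets, 10, 20, "Type"))
```

Because word pieces of the same word share the annotation, the second piece receives an `I-` tag rather than a new `B-` tag.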

In [9]:
def collate_fn_entity_role(batch):
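  '''
  Pads every field in the batch to the length of its longest sequence.
  Labels are padded with -100 so that padded positions are ignored by the loss.
  '''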
  input_ids = [torch.tensor(x['input_ids']) for x in batch]
  attention_mask = [torch.tensor(x['attention_mask']) for x in batch]
  #token_type_ids = [torch.tensor(x['token_type_ids']) for x in batch]
  tokens=[x['tokens'] for x in batch]
  file_name=[x['file_name'] for x in batch]

  entity_labels = [torch.tensor(x['entity_labels']) for x in batch]
  role_labels = [torch.tensor(x['role_labels']) for x in batch]
  status_labels = [torch.tensor(x['status_labels']) for x in batch]
  method_labels = [torch.tensor(x['method_labels']) for x in batch]
  input_ids = pad_sequence(input_ids, batch_first=True,padding_value=tokenizer.pad_token_id)
  attention_mask = pad_sequence(attention_mask, batch_first=True,padding_value=0)
  #token_type_ids = pad_sequence(token_type_ids, batch_first=True,padding_value=0)
  entity_labels = pad_sequence(entity_labels, batch_first=True,padding_value=-100)
  role_labels = pad_sequence(role_labels, batch_first=True,padding_value=-100)
  status_labels = pad_sequence(status_labels, batch_first=True,padding_value=-100)
  method_labels = pad_sequence(method_labels, batch_first=True,padding_value=-100)
  return {
    'input_ids':input_ids,
    'attention_mask':attention_mask,
    'entity_labels':entity_labels,
    'role_labels':role_labels,
    'status_labels':status_labels,
    'method_labels':method_labels,
    'tokens':tokens,
    'file_name':file_name,

  }
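
What `pad_sequence` does for the label lists can be shown with a plain-Python equivalent. The `pad_batch` helper and the sample label ids below are hypothetical and only illustrate the right-padding behaviour; the notebook itself pads tensors with `torch.nn.utils.rnn.pad_sequence`.

```python
def pad_batch(sequences, padding_value):
    """Right-pad each list to the length of the longest list in the batch."""
    longest = max(len(seq) for seq in sequences)
    return [seq + [padding_value] * (longest - len(seq)) for seq in sequences]

# Two documents of different length; -100 marks positions the loss should skip
labels = [[0, 1, 2, 0], [0, 3]]
attention = [[1, 1, 1, 1], [1, 1]]

print(pad_batch(labels, -100))
print(pad_batch(attention, 0))
```

Padding per batch instead of to a fixed 512 tokens keeps short batches small, which reduces wasted computation.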


### Dataset and Dataloader

In [10]:
class ERDataset(Dataset):
  def __init__(self, data, tokenizer, max_len):
    self.data = data
    self.tokenizer = tokenizer
    self.max_len = max_len
  def __len__(self):
    return len(self.data)
  def __getitem__(self, idx):
    item=self.data[idx]
    text = item['text']
    inputs = self.tokenizer(text,max_length=self.max_len,truncation=True,return_offsets_mapping=True)
    tokens_len=len(inputs['input_ids'])
    offset_mapping_list=inputs['offset_mapping']
    entity_labels=GenerateLabel.generate_enity_labels(item['entity_list'],tokens_len,offset_mapping_list)
    role_labels, status_labels, method_labels=GenerateLabel.generate_role_labels(item['role_list'],tokens_len,offset_mapping_list)
    relation_labels=GenerateLabel.generate_relation_labels(offset_mapping_list,item)
    return {
      'input_ids':inputs['input_ids'],
      'attention_mask':inputs['attention_mask'],
      'entity_labels':entity_labels,
      'role_labels':role_labels,
      'status_labels':status_labels,
      'method_labels':method_labels,
      'relation_labels':relation_labels,
      'text':text,
      'file_name':item['file_name'],
      'offset_mapping':offset_mapping_list,
      'tokens':inputs.tokens()
    }

In [11]:
with open(train_dataset_path,'r') as f:
  tr_dataset=json.load(f)
with open(test_dataset_path,'r') as f:
  tst_dataset=json.load(f)

In [12]:
train_dataset=ERDataset(tr_dataset,tokenizer,max_len)
test_dataset=ERDataset(tst_dataset,tokenizer,max_len)

In [13]:
#train_dataset=torch.load(train_dataset_path)
#test_dataset=torch.load(test_dataset_path)


In [14]:
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, collate_fn=collate_fn_entity_role)
test_dataloader = DataLoader(test_dataset, batch_size=batch_size,collate_fn=collate_fn_entity_role)

## NER model

In [15]:
class EntityBertModel(nn.Module):
  def __init__(self, model_name, num_freeze_layers,num_status_labels,num_method_labels,num_role_labels,num_entity_labels, dropout=0.1):
    super(EntityBertModel, self).__init__()
    self.bertmodel = BertModel.from_pretrained(model_name)
    #Freeze the first num_freeze_layers encoder layers so they are not updated during fine-tuning
    for layer in self.bertmodel.encoder.layer[:num_freeze_layers]:
      for param in layer.parameters():
          param.requires_grad = False
    self.dropout = nn.Dropout(dropout)
    self.status_classifier = nn.Linear(self.bertmodel.config.hidden_size, num_status_labels)
    self.method_classifier = nn.Linear(self.bertmodel.config.hidden_size, num_method_labels)
    self.role_classifier = nn.Linear(self.bertmodel.config.hidden_size, num_role_labels)
    self.entity_classifier = nn.Linear(self.bertmodel.config.hidden_size, num_entity_labels)
  def forward(self, input_ids, attention_mask):
    bert_output = self.bertmodel(input_ids=input_ids, attention_mask=attention_mask)
    sequence_output = self.dropout(bert_output[0])
    status_logits = self.status_classifier(sequence_output)
    method_logits = self.method_classifier(sequence_output)
    role_logits = self.role_classifier(sequence_output)
    entity_logits = self.entity_classifier(sequence_output)

    return status_logits, method_logits, role_logits, entity_logits

In [16]:
model= EntityBertModel(model_name=bert_model_name,num_freeze_layers=num_freeze_layers,num_status_labels=len(id_label_status),num_method_labels=len(id_label_method),num_role_labels=len(id_label_role),num_entity_labels=len(id_label_ent))


In [17]:
no_decay=['bias','LayerNorm.weight']
optimizer_grouped_parameters = [
    {
        # Regular weights receive weight decay
        "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
        "weight_decay": weight_decay,
    },
    {
        # Biases and LayerNorm weights are excluded from weight decay
        "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
        "weight_decay": 0.0,
    }
]


In [18]:
optimizer = AdamW(optimizer_grouped_parameters, lr=learning_rate, eps=eps)
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=0, num_training_steps=len(train_dataloader)*num_train_epochs)
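
The learning-rate curve produced by a linear schedule with warmup can be sketched as a multiplier on the base rate. The `linear_schedule_factor` function below is an illustrative re-implementation, not the library code, and the step count of 150 is a made-up figure (for example 10 batches per epoch over 15 epochs).

```python
def linear_schedule_factor(step, num_training_steps, num_warmup_steps=0):
    """Multiplier applied to the base learning rate at a given optimizer step."""
    if step < num_warmup_steps:
        # Linear ramp from 0 to 1 during warmup
        return step / max(1, num_warmup_steps)
    # Linear decay from 1 to 0 over the remaining steps
    remaining = num_training_steps - step
    return max(0.0, remaining / max(1, num_training_steps - num_warmup_steps))

base_lr = 7e-5
total_steps = 150  # hypothetical number of optimizer steps
for step in (0, 75, 150):
    print(step, base_lr * linear_schedule_factor(step, total_steps))
```

With `num_warmup_steps=0`, as in the cell above, training starts at the full learning rate and decays to zero by the final step.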



In [19]:
training_args={
    'output_dir':save_model_path,
    'num_train_epochs':num_train_epochs,
    'optimizer':optimizer,
    'scheduler':scheduler,
    'patience':patience,
    'run_name':model_type+'_'+str(ver)
}

### Metric

In [20]:
def compute_ner_metric(preds,labels,id_label):
    '''
    Converts ids to their respective labels and removes padded positions (label -100).
    Computes seqeval metrics with the library's default mode and scheme.
    '''
    ground_truths=[[id_label[l] for l in lab if l!=-100] for lab in labels]
    prediction_labels=[[id_label[p] for p,l in zip(predict,lab) if l!=-100] for predict,lab in zip(preds,labels)]
    metric_res=metric.compute(predictions=prediction_labels,references=ground_truths)
    return metric_res
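
Sequence-level evaluation scores whole spans rather than individual tokens. The following is a simplified approximation of that idea, not the seqeval implementation: the `extract_spans` and `span_prf` helpers and the sample tag sequences are assumptions made for illustration.

```python
def extract_spans(tags):
    """Collect (category, start, end) spans from a BIO tag sequence."""
    spans, start, category = set(), None, None
    for i, tag in enumerate(tags + ["O"]):  # sentinel closes a trailing span
        # A span starts on B-, or on an I- whose category differs from the open span
        starts_new = tag.startswith("B-") or (tag.startswith("I-") and tag[2:] != category)
        if start is not None and (tag == "O" or starts_new):
            spans.add((category, start, i))
            start, category = None, None
        if starts_new:
            start, category = i, tag[2:]
    return spans

def span_prf(gold_tags, pred_tags):
    """Precision, recall and F1 over exactly matching spans."""
    gold, pred = extract_spans(gold_tags), extract_spans(pred_tags)
    tp = len(gold & pred)
    precision = tp / len(pred) if pred else 0.0
    recall = tp / len(gold) if gold else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    return precision, recall, f1

gold = ["O", "B-Type", "I-Type", "O", "B-Amount"]
pred = ["O", "B-Type", "O", "O", "B-Amount"]
print(span_prf(gold, pred))
```

In this example four of five tokens are tagged correctly, yet only one of the two spans matches exactly, which is why span-level F1 can sit well below token-level accuracy.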


### Evaluation

In [21]:
def evaluate_entity_model(model,val_loader,loss_fn,id_label_ent,id_label_role,id_label_status,id_label_method,extract_predictions=False,dt_set=''):
  '''
  Evaluates the model on a dataset and returns the loss and metrics for Entity, Role, Status and Method.
  When extract_predictions is True, saves token-level predictions of entity, role, method and status to data/{dt_set}_{model_type}_{ver}.json
  '''
  model.eval()
  with torch.no_grad():

     #Initializing Validation losses for each of categories
    val_ent_loss=0
    val_status_loss=0
    val_method_loss=0
    val_role_loss=0
    val_average_loss=0

    #Initializing labels and prediction list, in-turn used for evaluating metrics.
    all_entity_labels=[]
    all_entity_predictions=[]
    all_role_labels=[]
    all_role_predictions=[]
    all_method_labels=[]
    all_method_predictions=[]
    all_status_labels=[]
    all_status_predictions=[]

    all_res_entity={}
    all_res_role={}
    all_res_method={}
    all_res_status={}
    all_res={}

    for step,batch in enumerate(val_loader):
      #Extracting inputs from the batch.
      inputs={'input_ids':batch['input_ids'].to(device),'attention_mask':batch['attention_mask'].to(device)}
      entity_labels=batch['entity_labels'].to(device)
      role_labels=batch['role_labels'].to(device)
      method_labels=batch['method_labels'].to(device)
      status_labels=batch['status_labels'].to(device)
      tokens=batch['tokens']

      #Model computation
      status_logits, method_logits, role_logits, entity_logits=model(**inputs)

      #Loss computation
      entity_loss=loss_fn(entity_logits.view(-1,len(id_label_ent)),entity_labels.view(-1))
      role_loss=loss_fn(role_logits.view(-1,len(id_label_role)),role_labels.view(-1))
      status_loss=loss_fn(status_logits.view(-1,len(id_label_status)),status_labels.view(-1))
      method_loss=loss_fn(method_logits.view(-1,len(id_label_method)),method_labels.view(-1))
      avg_loss_step=(entity_loss+role_loss+status_loss+method_loss)/4
      #Accumulate Python floats rather than tensors
      val_ent_loss+=entity_loss.item()
      val_role_loss += role_loss.item()
      val_method_loss += method_loss.item()
      val_status_loss += status_loss.item()
      val_average_loss += avg_loss_step.item()

      #Calculating probabilities and predictions
      entity_probabilities = torch.softmax(entity_logits, dim=-1)
      entity_predictions = torch.argmax(entity_probabilities, dim=-1)
      role_probabilities  = torch.softmax(role_logits,dim=-1)
      role_predictions = torch.argmax(role_probabilities, dim =-1)
      status_probabilities=torch.softmax(status_logits,dim=-1)
      status_predictions=torch.argmax(status_probabilities,dim=-1)

      method_probabilities=torch.softmax(method_logits,dim=-1)
      method_predictions=torch.argmax(method_probabilities,dim=-1)

      #Collecting all labels and predictions
      all_status_predictions.extend(status_predictions.tolist())
      all_status_labels.extend(batch['status_labels'].tolist())
      all_method_labels.extend(batch['method_labels'].tolist())
      all_method_predictions.extend(method_predictions.tolist())

      # Mapping Token,Prediction,label and saving to a json file

      status_labels=batch['status_labels'].tolist()
      status_predictions=status_predictions.tolist()
      for idx in range(0,len(batch['file_name'])):
        all_res_status[batch['file_name'][idx]]=list(zip(batch['tokens'][idx],status_labels[idx],status_predictions[idx]))


      method_labels=batch['method_labels'].tolist()
      method_predictions=method_predictions.tolist()
      for idx in range(0,len(batch['file_name'])):
        all_res_method[batch['file_name'][idx]]=list(zip(batch['tokens'][idx],method_labels[idx],method_predictions[idx]))


      all_role_labels.extend(batch['role_labels'].tolist())
      all_role_predictions.extend(role_predictions.tolist())
      role_labels=batch['role_labels'].tolist()
      role_predictions=role_predictions.tolist()
      for idx in range(0,len(batch['file_name'])):
        all_res_role[batch['file_name'][idx]]=list(zip(batch['tokens'][idx],role_labels[idx],role_predictions[idx]))

      all_entity_labels.extend(entity_labels.tolist())
      all_entity_predictions.extend(entity_predictions.tolist())
      entity_predictions=entity_predictions.tolist()
      entity_labels=entity_labels.tolist()
      for idx in range(0,len(batch['file_name'])):
        all_res_entity[batch['file_name'][idx]]=list(zip(batch['tokens'][idx],entity_labels[idx],entity_predictions[idx]))

    all_res={'entity':all_res_entity,'role':all_res_role,'status':all_res_status,'method':all_res_method}

    #Computing NER metrics
    entity_metrics=compute_ner_metric(all_entity_predictions,all_entity_labels,id_label_ent)
    role_metrics=compute_ner_metric(all_role_predictions,all_role_labels,id_label_role)
    status_metrics=compute_ner_metric(all_status_predictions,all_status_labels,id_label_status)
    method_metrics=compute_ner_metric(all_method_predictions,all_method_labels,id_label_method)

    #Saving Results
    if extract_predictions:
      with open(project_directory+'data/'+dt_set+'_'+model_type+'_'+str(ver)+'.json','w') as f:
        json.dump(all_res,f)
  return {'avg_loss':val_average_loss/len(val_loader),'entity_evaluation':(val_ent_loss/len(val_loader),entity_metrics),'role_evaluation':(val_role_loss/len(val_loader),role_metrics),
            'method_evaluation':(val_method_loss/len(val_loader),method_metrics),'status_evaluation':(val_status_loss/len(val_loader),status_metrics)}




### Training

In [22]:
def train_entity_bert_model(model,tr_dataloader,tst_dataloader,id_label_ent,id_label_role,id_label_status,id_label_method,bert_model_name,training_args):
  '''
  Trains the model with early stopping.
  The losses of the individual tasks are averaged and the average is used for backpropagation.

  '''
  if torch.cuda.is_available():
    model.cuda()
  train_cycles=trange(training_args['num_train_epochs'],desc='Epoch',disable=0)
  loss_fn = nn.CrossEntropyLoss(ignore_index=-100)
  #Initializing optimizer and scheduler
  optimizer=training_args['optimizer']
  scheduler=training_args['scheduler']
  patience=training_args['patience']

  #Initializing training and validation loss lists, which accumulate across epochs
  tr_ent_loss_lst=[]
  tr_status_loss_lst=[]
  tr_method_loss_lst=[]
  tr_role_loss_lst=[]
  tr_avg_loss_lst=[]

  val_ent_loss_list=[]
  val_status_loss_list=[]
  val_method_loss_list=[]
  val_role_loss_list=[]
  val_avg_loss_list=[]

  early_stopping_count=0
  early_stopping_count_ent=0
  early_stopping_count_status=0
  early_stopping_count_method=0
  early_stopping_count_role=0

  for cycle in train_cycles:
    epoch_cycles=tqdm(tr_dataloader,desc='Iteration',disable=True)
    model.train()
    #Initializing epoch loss
    tr_ent_loss=0
    tr_status_loss=0
    tr_method_loss=0
    tr_role_loss=0
    average_loss=0

    for step,batch in enumerate(epoch_cycles):
      #Model Computation
      optimizer.zero_grad()
      inputs={'input_ids':batch['input_ids'].to(device),'attention_mask':batch['attention_mask'].to(device)}

      entity_labels=batch['entity_labels'].to(device)
      status_logits, method_logits, role_logits, entity_logits=model(**inputs)
      role_labels=batch['role_labels'].to(device)
      method_labels=batch['method_labels'].to(device)
      status_labels=batch['status_labels'].to(device)

      #Loss computation for a batch
      entity_loss=loss_fn(entity_logits.view(-1,len(id_label_ent)),entity_labels.view(-1))
      role_loss=loss_fn(role_logits.view(-1,len(id_label_role)),role_labels.view(-1))
      status_loss=loss_fn(status_logits.view(-1,len(id_label_status)),status_labels.view(-1))
      method_loss=loss_fn(method_logits.view(-1,len(id_label_method)),method_labels.view(-1))
      avg_loss_step=(entity_loss+role_loss+status_loss+method_loss)/4

      avg_loss_step.backward()
      optimizer.step()
      scheduler.step()

      #Loss computation across an epoch
      tr_ent_loss+=entity_loss.item()
      tr_status_loss+=status_loss.item()
      tr_method_loss+=method_loss.item()
      tr_role_loss+=role_loss.item()
      average_loss+=avg_loss_step.item()

    #Accumulating losses for all epochs.
    tr_ent_loss_lst.append(tr_ent_loss/len(tr_dataloader))
    tr_status_loss_lst.append(tr_status_loss/len(tr_dataloader))
    tr_method_loss_lst.append(tr_method_loss/len(tr_dataloader))
    tr_role_loss_lst.append(tr_role_loss/len(tr_dataloader))
    tr_avg_loss_lst.append(average_loss/len(tr_dataloader))

    print('Epoch: {}  Train Entity Loss: {}'.format(cycle,tr_ent_loss_lst[-1]))
    print('Epoch: {}  Train Status Loss: {}'.format(cycle,tr_status_loss_lst[-1]))
    print('Epoch: {}  Train Method Loss: {}'.format(cycle,tr_method_loss_lst[-1]))
    print('Epoch: {}  Train Role Loss: {}'.format(cycle,tr_role_loss_lst[-1]))
    print('Epoch: {}  Train Average Loss: {}'.format(cycle,tr_avg_loss_lst[-1]))

    #Computing metrics and evaluating the model on validation data.
    #Saves results on the last epoch or when a head is one epoch away from exhausting its patience.
    if cycle == training_args['num_train_epochs']-1 or early_stopping_count_ent >= patience-1 or early_stopping_count_status >= patience-1  or early_stopping_count_method >= patience-1 or early_stopping_count_role >= patience-1 :
      val_results=evaluate_entity_model(model,tst_dataloader,loss_fn,id_label_ent,id_label_role,id_label_status,id_label_method,True,'tst')
    else:
      val_results=evaluate_entity_model(model,tst_dataloader,loss_fn,id_label_ent,id_label_role,id_label_status,id_label_method)

    val_ent_loss_list.append(val_results['entity_evaluation'][0])
    val_status_loss_list.append(val_results['status_evaluation'][0])
    val_method_loss_list.append(val_results['method_evaluation'][0])
    val_role_loss_list.append(val_results['role_evaluation'][0])
    val_avg_loss_list.append(val_results['avg_loss'])

    print('Epoch: {}  Val Entity Loss: {}'.format(cycle,val_ent_loss_list[-1]))
    print('Epoch: {}  Val Status Loss: {}'.format(cycle,val_status_loss_list[-1]))
    print('Epoch: {}  Val Method Loss: {}'.format(cycle,val_method_loss_list[-1]))
    print('Epoch: {}  Val Role Loss: {}'.format(cycle,val_role_loss_list[-1]))
    print('Epoch: {}  Avg Val Loss: {}'.format(cycle, val_avg_loss_list[-1]))
    print('Epoch: {}  Metrics below: \n'.format(cycle))

    #Computing metrics and evaluating the model on training data.
    if cycle == training_args['num_train_epochs']-1 or early_stopping_count_ent >= patience-1 or early_stopping_count_status >= patience-1  or early_stopping_count_method >= patience-1 or early_stopping_count_role >= patience-1 :
      tr_results=evaluate_entity_model(model,tr_dataloader,loss_fn,id_label_ent,id_label_role,id_label_status,id_label_method,True,'tr')
    else:
      tr_results=evaluate_entity_model(model,tr_dataloader,loss_fn,id_label_ent,id_label_role,id_label_status,id_label_method)


    print('Entity_metrics: \n')
    print(val_results['entity_evaluation'][1])
    print('\n')
    print(tr_results['entity_evaluation'][1])
    print('\n')
    print('Role_metrics: \n')
    print(val_results['role_evaluation'][1])
    print('\n')
    print(tr_results['role_evaluation'][1])
    print('\n')
    print('Status_metrics: \n')
    print(val_results['status_evaluation'][1])
    print('\n')
    print(tr_results['status_evaluation'][1])
    print('\n')
    print('Method_metrics: \n')
    print(val_results['method_evaluation'][1])
    print('\n')
    print(tr_results['method_evaluation'][1])
    print('\n')

    #Early stopping: for each head and for the average loss, count consecutive epochs without a new minimum validation loss; stop once any counter reaches patience.
    if cycle==0:
      min_ent_val_loss=val_ent_loss_list[-1]
      min_status_val_loss=val_status_loss_list[-1]
      min_method_val_loss=val_method_loss_list[-1]
      min_role_val_loss=val_role_loss_list[-1]
      min_val_loss=val_results['avg_loss']
      early_stopping_count=0  #overall counter must exist before the first increment at a later epoch
      early_stopping_count_ent=0
      early_stopping_count_status=0
      early_stopping_count_method=0
      early_stopping_count_role=0
    else:
      if val_results['avg_loss']<min_val_loss:
        min_val_loss=val_results['avg_loss']
        early_stopping_count=0
      else:
        early_stopping_count+=1
      if val_ent_loss_list[-1]<min_ent_val_loss:
        min_ent_val_loss=val_ent_loss_list[-1]
        early_stopping_count_ent=0
      else:
        early_stopping_count_ent+=1
      if val_status_loss_list[-1]<min_status_val_loss:
        min_status_val_loss=val_status_loss_list[-1]
        early_stopping_count_status=0
      else:
        early_stopping_count_status+=1
      if val_method_loss_list[-1]<min_method_val_loss:
        min_method_val_loss=val_method_loss_list[-1]
        early_stopping_count_method=0
      else:
        early_stopping_count_method+=1
      if val_role_loss_list[-1]<min_role_val_loss:
        min_role_val_loss=val_role_loss_list[-1]
        early_stopping_count_role=0
      else:
        early_stopping_count_role+=1

      if early_stopping_count>=patience or early_stopping_count_ent>=patience or early_stopping_count_status>=patience or early_stopping_count_method>=patience or early_stopping_count_role>=patience:
        print('Early stopping counter for entity model : {}'.format(early_stopping_count_ent))
        print('Early stopping counter for status model : {}'.format(early_stopping_count_status))
        print('Early stopping counter for method model : {}'.format(early_stopping_count_method))
        print('Early stopping counter for role model : {}'.format(early_stopping_count_role))
        print('Early stopping counter for overall model : {}'.format(early_stopping_count))
        print('Early stopping at epoch: {}'.format(cycle))
        break

  #Model save:
  torch.save(model.state_dict(),training_args['output_dir']+training_args['run_name']+'.pth')
  return model, {'train_loss':{'tr_ent_loss_lst':tr_ent_loss_lst,
                 'tr_method_loss_lst':tr_method_loss_lst,'tr_status_loss_lst':tr_status_loss_lst,'tr_role_loss_lst':tr_role_loss_lst,'tr_avg_loss_lst':tr_avg_loss_lst},
                 'val_loss':{'val_ent_loss_list':val_ent_loss_list,'val_status_loss_list':val_status_loss_list,'val_method_loss_list':val_method_loss_list,'val_role_loss_list':val_role_loss_list,'val_avg_loss_list':val_avg_loss_list}}
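
The per-head patience bookkeeping in the loop above can be condensed into a small helper. The following is a minimal sketch, not the notebook's implementation: `PatienceTracker` and its method names are hypothetical. It assumes the same rule as the loop, where a counter resets when a head's validation loss drops below its best value so far and increments otherwise.

```python
class PatienceTracker:
    """Tracks the best validation loss per head and counts epochs without improvement.

    Hypothetical helper (not part of the notebook); it mirrors the counters in the loop.
    """

    def __init__(self, heads, patience):
        self.patience = patience
        # Starting from +inf makes the first epoch always count as an improvement.
        self.best = {head: float('inf') for head in heads}
        self.counts = {head: 0 for head in heads}

    def update(self, losses):
        """Record one epoch of validation losses; return True if training should stop."""
        for head, loss in losses.items():
            if loss < self.best[head]:
                self.best[head] = loss
                self.counts[head] = 0
            else:
                self.counts[head] += 1
        return any(count >= self.patience for count in self.counts.values())


tracker = PatienceTracker(['entity', 'status', 'method', 'role', 'avg'], patience=3)
# Values rounded from the epoch 0 validation losses in the log.
stop = tracker.update({'entity': 0.47, 'status': 0.39, 'method': 0.21, 'role': 0.89, 'avg': 0.49})
```

Keeping one counter per head, as the loop does, stops training as soon as any single head stalls. Tracking only the `'avg'` entry would be the more lenient alternative.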







In [23]:
trained_model,loss=train_entity_bert_model(model,train_dataloader,test_dataloader,id_label_ent,id_label_role,id_label_status,id_label_method,bert_model_name,training_args)

Epoch:   0%|          | 0/15 [00:00<?, ?it/s]

Epoch: 0  Train Entity Loss: 0.8092200370395884
Epoch: 0  Train Status Loss: 0.48840374631040234
Epoch: 0  Train Method Loss: 0.31141672677853527
Epoch: 0  Train Role Loss: 1.3964523567872889
Epoch: 0  Train Average Loss: 0.7513732208925135




Epoch: 0  Val Entity Loss: 0.47225427627563477
Epoch: 0  Val Status Loss: 0.39268621802330017
Epoch: 0  Val Method Loss: 0.2066621333360672
Epoch: 0  Val Role Loss: 0.887981116771698
Epoch: 0  Avg Val Loss: 0.48989588022232056
Epoch: {}  Metrics below: 



Epoch:   7%|▋         | 1/15 [02:43<38:09, 163.55s/it]

Entity_metrics: 

{'Alcohol': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0, 'number': 55}, 'Drug': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0, 'number': 30}, 'Family': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0, 'number': 37}, 'Tobacco': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0, 'number': 53}, 'overall_precision': 0.0, 'overall_recall': 0.0, 'overall_f1': 0.0, 'overall_accuracy': 0.9236819108651325}


{'Alcohol': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0, 'number': 199}, 'Drug': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0, 'number': 124}, 'Family': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0, 'number': 146}, 'Tobacco': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0, 'number': 225}, 'overall_precision': 0.0, 'overall_recall': 0.0, 'overall_f1': 0.0, 'overall_accuracy': 0.9171234557800867}


Role_metrics: 

{'Amount': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0, 'number': 55}, 'ExposureHistory': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0, 'number': 10}, 'Frequency': {'precision': 0.0

Epoch:  13%|█▎        | 2/15 [05:13<33:42, 155.61s/it]

Entity_metrics: 

{'Alcohol': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0, 'number': 55}, 'Drug': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0, 'number': 30}, 'Family': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0, 'number': 37}, 'Tobacco': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0, 'number': 53}, 'overall_precision': 0.0, 'overall_recall': 0.0, 'overall_f1': 0.0, 'overall_accuracy': 0.9236819108651325}


{'Alcohol': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0, 'number': 199}, 'Drug': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0, 'number': 124}, 'Family': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0, 'number': 146}, 'Tobacco': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0, 'number': 225}, 'overall_precision': 0.0, 'overall_recall': 0.0, 'overall_f1': 0.0, 'overall_accuracy': 0.9171234557800867}


Role_metrics: 

{'Amount': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0, 'number': 55}, 'ExposureHistory': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0, 'number': 10}, 'Frequency': {'precision': 0.0

Epoch:  20%|██        | 3/15 [07:47<30:55, 154.65s/it]

Entity_metrics: 

{'Alcohol': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0, 'number': 55}, 'Drug': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0, 'number': 30}, 'Family': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0, 'number': 37}, 'Tobacco': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0, 'number': 53}, 'overall_precision': 0.0, 'overall_recall': 0.0, 'overall_f1': 0.0, 'overall_accuracy': 0.9236819108651325}


{'Alcohol': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0, 'number': 199}, 'Drug': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0, 'number': 124}, 'Family': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0, 'number': 146}, 'Tobacco': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0, 'number': 225}, 'overall_precision': 0.0, 'overall_recall': 0.0, 'overall_f1': 0.0, 'overall_accuracy': 0.9171234557800867}


Role_metrics: 

{'Amount': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0, 'number': 55}, 'ExposureHistory': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0, 'number': 10}, 'Frequency': {'precision': 0.0

Epoch:  27%|██▋       | 4/15 [10:15<27:53, 152.13s/it]

Entity_metrics: 

{'Alcohol': {'precision': 0.675, 'recall': 0.4909090909090909, 'f1': 0.5684210526315789, 'number': 55}, 'Drug': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0, 'number': 30}, 'Family': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0, 'number': 37}, 'Tobacco': {'precision': 0.4444444444444444, 'recall': 0.37735849056603776, 'f1': 0.40816326530612246, 'number': 53}, 'overall_precision': 0.5529411764705883, 'overall_recall': 0.26857142857142857, 'overall_f1': 0.36153846153846153, 'overall_accuracy': 0.940868045441305}


{'Alcohol': {'precision': 0.5104166666666666, 'recall': 0.49246231155778897, 'f1': 0.5012787723785166, 'number': 199}, 'Drug': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0, 'number': 124}, 'Family': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0, 'number': 146}, 'Tobacco': {'precision': 0.5418994413407822, 'recall': 0.4311111111111111, 'f1': 0.4801980198019802, 'number': 225}, 'overall_precision': 0.5256064690026954, 'overall_recall': 0.28097982708933716, 'overall_f

Epoch:  33%|███▎      | 5/15 [12:46<25:17, 151.75s/it]

Entity_metrics: 

{'Alcohol': {'precision': 0.5735294117647058, 'recall': 0.7090909090909091, 'f1': 0.6341463414634145, 'number': 55}, 'Drug': {'precision': 0.9375, 'recall': 0.5, 'f1': 0.6521739130434783, 'number': 30}, 'Family': {'precision': 0.38461538461538464, 'recall': 0.40540540540540543, 'f1': 0.39473684210526316, 'number': 37}, 'Tobacco': {'precision': 0.6515151515151515, 'recall': 0.8113207547169812, 'f1': 0.7226890756302522, 'number': 53}, 'overall_precision': 0.5925925925925926, 'overall_recall': 0.64, 'overall_f1': 0.6153846153846153, 'overall_accuracy': 0.958636760850568}


{'Alcohol': {'precision': 0.4701492537313433, 'recall': 0.6331658291457286, 'f1': 0.5396145610278373, 'number': 199}, 'Drug': {'precision': 0.7634408602150538, 'recall': 0.5725806451612904, 'f1': 0.6543778801843319, 'number': 124}, 'Family': {'precision': 0.3560606060606061, 'recall': 0.3219178082191781, 'f1': 0.3381294964028777, 'number': 146}, 'Tobacco': {'precision': 0.6513409961685823, 'recall': 0.

Epoch:  40%|████      | 6/15 [15:16<22:40, 151.14s/it]

Entity_metrics: 

{'Alcohol': {'precision': 0.6307692307692307, 'recall': 0.7454545454545455, 'f1': 0.6833333333333332, 'number': 55}, 'Drug': {'precision': 0.7575757575757576, 'recall': 0.8333333333333334, 'f1': 0.7936507936507938, 'number': 30}, 'Family': {'precision': 0.41304347826086957, 'recall': 0.5135135135135135, 'f1': 0.4578313253012048, 'number': 37}, 'Tobacco': {'precision': 0.7619047619047619, 'recall': 0.9056603773584906, 'f1': 0.8275862068965516, 'number': 53}, 'overall_precision': 0.642512077294686, 'overall_recall': 0.76, 'overall_f1': 0.6963350785340314, 'overall_accuracy': 0.968540635013108}


{'Alcohol': {'precision': 0.5057034220532319, 'recall': 0.6683417085427136, 'f1': 0.5757575757575757, 'number': 199}, 'Drug': {'precision': 0.6, 'recall': 0.7016129032258065, 'f1': 0.6468401486988847, 'number': 124}, 'Family': {'precision': 0.4601226993865031, 'recall': 0.5136986301369864, 'f1': 0.48543689320388356, 'number': 146}, 'Tobacco': {'precision': 0.7722007722007722, 'r

Epoch:  47%|████▋     | 7/15 [17:50<20:16, 152.00s/it]

Entity_metrics: 

{'Alcohol': {'precision': 0.6451612903225806, 'recall': 0.7272727272727273, 'f1': 0.6837606837606838, 'number': 55}, 'Drug': {'precision': 0.7027027027027027, 'recall': 0.8666666666666667, 'f1': 0.7761194029850748, 'number': 30}, 'Family': {'precision': 0.5, 'recall': 0.5945945945945946, 'f1': 0.5432098765432098, 'number': 37}, 'Tobacco': {'precision': 0.7619047619047619, 'recall': 0.9056603773584906, 'f1': 0.8275862068965516, 'number': 53}, 'overall_precision': 0.6601941747572816, 'overall_recall': 0.7771428571428571, 'overall_f1': 0.7139107611548556, 'overall_accuracy': 0.9720361200116516}


{'Alcohol': {'precision': 0.5153846153846153, 'recall': 0.6733668341708543, 'f1': 0.5838779956427015, 'number': 199}, 'Drug': {'precision': 0.5827814569536424, 'recall': 0.7096774193548387, 'f1': 0.64, 'number': 124}, 'Family': {'precision': 0.56, 'recall': 0.5753424657534246, 'f1': 0.5675675675675674, 'number': 146}, 'Tobacco': {'precision': 0.7790697674418605, 'recall': 0.8933

Epoch:  53%|█████▎    | 8/15 [20:18<17:36, 150.96s/it]

Entity_metrics: 

{'Alcohol': {'precision': 0.6557377049180327, 'recall': 0.7272727272727273, 'f1': 0.689655172413793, 'number': 55}, 'Drug': {'precision': 0.7352941176470589, 'recall': 0.8333333333333334, 'f1': 0.78125, 'number': 30}, 'Family': {'precision': 0.54, 'recall': 0.7297297297297297, 'f1': 0.6206896551724138, 'number': 37}, 'Tobacco': {'precision': 0.7868852459016393, 'recall': 0.9056603773584906, 'f1': 0.8421052631578947, 'number': 53}, 'overall_precision': 0.6796116504854369, 'overall_recall': 0.8, 'overall_f1': 0.7349081364829397, 'overall_accuracy': 0.9740751529274687}


{'Alcohol': {'precision': 0.5942622950819673, 'recall': 0.7286432160804021, 'f1': 0.654627539503386, 'number': 199}, 'Drug': {'precision': 0.6766917293233082, 'recall': 0.7258064516129032, 'f1': 0.7003891050583658, 'number': 124}, 'Family': {'precision': 0.6032608695652174, 'recall': 0.7602739726027398, 'f1': 0.6727272727272727, 'number': 146}, 'Tobacco': {'precision': 0.8112449799196787, 'recall': 0.897

Epoch:  60%|██████    | 9/15 [22:53<15:12, 152.12s/it]

Entity_metrics: 

{'Alcohol': {'precision': 0.6666666666666666, 'recall': 0.7272727272727273, 'f1': 0.6956521739130435, 'number': 55}, 'Drug': {'precision': 0.7428571428571429, 'recall': 0.8666666666666667, 'f1': 0.8, 'number': 30}, 'Family': {'precision': 0.5, 'recall': 0.7297297297297297, 'f1': 0.5934065934065933, 'number': 37}, 'Tobacco': {'precision': 0.7833333333333333, 'recall': 0.8867924528301887, 'f1': 0.8318584070796461, 'number': 53}, 'overall_precision': 0.6698564593301436, 'overall_recall': 0.8, 'overall_f1': 0.7291666666666666, 'overall_accuracy': 0.974366443344014}


{'Alcohol': {'precision': 0.606694560669456, 'recall': 0.7286432160804021, 'f1': 0.6621004566210045, 'number': 199}, 'Drug': {'precision': 0.7286821705426356, 'recall': 0.7580645161290323, 'f1': 0.7430830039525692, 'number': 124}, 'Family': {'precision': 0.5902439024390244, 'recall': 0.8287671232876712, 'f1': 0.6894586894586894, 'number': 146}, 'Tobacco': {'precision': 0.8278688524590164, 'recall': 0.89777777

Epoch:  67%|██████▋   | 10/15 [25:26<12:41, 152.36s/it]

Entity_metrics: 

{'Alcohol': {'precision': 0.6833333333333333, 'recall': 0.7454545454545455, 'f1': 0.7130434782608696, 'number': 55}, 'Drug': {'precision': 0.8181818181818182, 'recall': 0.9, 'f1': 0.8571428571428572, 'number': 30}, 'Family': {'precision': 0.5813953488372093, 'recall': 0.6756756756756757, 'f1': 0.6250000000000001, 'number': 37}, 'Tobacco': {'precision': 0.7619047619047619, 'recall': 0.9056603773584906, 'f1': 0.8275862068965516, 'number': 53}, 'overall_precision': 0.7085427135678392, 'overall_recall': 0.8057142857142857, 'overall_f1': 0.7540106951871658, 'overall_accuracy': 0.9749490241771046}


{'Alcohol': {'precision': 0.6115702479338843, 'recall': 0.7437185929648241, 'f1': 0.671201814058957, 'number': 199}, 'Drug': {'precision': 0.7734375, 'recall': 0.7983870967741935, 'f1': 0.7857142857142857, 'number': 124}, 'Family': {'precision': 0.75, 'recall': 0.8424657534246576, 'f1': 0.7935483870967742, 'number': 146}, 'Tobacco': {'precision': 0.812, 'recall': 0.9022222222222

Epoch:  67%|██████▋   | 10/15 [27:58<13:59, 167.85s/it]

Entity_metrics: 

{'Alcohol': {'precision': 0.711864406779661, 'recall': 0.7636363636363637, 'f1': 0.736842105263158, 'number': 55}, 'Drug': {'precision': 0.875, 'recall': 0.9333333333333333, 'f1': 0.9032258064516129, 'number': 30}, 'Family': {'precision': 0.5471698113207547, 'recall': 0.7837837837837838, 'f1': 0.6444444444444444, 'number': 37}, 'Tobacco': {'precision': 0.7868852459016393, 'recall': 0.9056603773584906, 'f1': 0.8421052631578947, 'number': 53}, 'overall_precision': 0.7170731707317073, 'overall_recall': 0.84, 'overall_f1': 0.7736842105263158, 'overall_accuracy': 0.9764054762598311}


{'Alcohol': {'precision': 0.6398305084745762, 'recall': 0.7587939698492462, 'f1': 0.6942528735632183, 'number': 199}, 'Drug': {'precision': 0.7716535433070866, 'recall': 0.7903225806451613, 'f1': 0.7808764940239042, 'number': 124}, 'Family': {'precision': 0.6528497409326425, 'recall': 0.863013698630137, 'f1': 0.743362831858407, 'number': 146}, 'Tobacco': {'precision': 0.8285714285714286, 'rec
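
The metric dictionaries printed each epoch mix per-label entries (dicts with `precision`, `recall`, `f1`, `number`) with scalar `overall_*` entries. Below is a small sketch for pulling out per-label F1, assuming that layout. `per_label_f1` is a hypothetical helper, and the sample values are copied (rounded to four decimals) from the validation entity dictionary of the last epoch shown in the log.

```python
def per_label_f1(metrics):
    """Return {label: f1} for per-label entries, skipping scalar 'overall_*' keys."""
    return {label: scores['f1'] for label, scores in metrics.items()
            if isinstance(scores, dict)}


# Rounded copy of the last validation Entity_metrics dictionary in the log.
val_entity_metrics = {
    'Alcohol': {'precision': 0.7119, 'recall': 0.7636, 'f1': 0.7368, 'number': 55},
    'Drug': {'precision': 0.875, 'recall': 0.9333, 'f1': 0.9032, 'number': 30},
    'Family': {'precision': 0.5472, 'recall': 0.7838, 'f1': 0.6444, 'number': 37},
    'Tobacco': {'precision': 0.7869, 'recall': 0.9057, 'f1': 0.8421, 'number': 53},
    'overall_f1': 0.7737,
}

f1_by_label = per_label_f1(val_entity_metrics)
weakest = min(f1_by_label, key=f1_by_label.get)
print(weakest)  # prints: Family
```

Applied to each epoch's dictionary, this gives a per-label F1 trajectory, which makes it easier to follow how recognition of `Family` evolves relative to the other entities.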


