# Entity Relation model (parallel task learning):
*   Performs token-level named entity recognition (NER) for entities and roles, forms entity-role relation pairs, and predicts the relations between them.
*   Evaluates the relation and NER models.

*   Entities and roles that are not considered because there is insufficient training data for them:
  1.  discarded_enities=['EnvironmentalExposure','SexualHistory','InfectiousDiseases','PhysicalActivity','Residence','LivingSituation','MaritalStatus','Occupation']

  2.   discarded_roles=['LivingStatus','Other','MedicalCondition','Extent','History']

*   Labelling uses the BIO scheme.
*   Token labels are generated by mapping the annotations' character offsets to the BERT token offsets.
*   Some documents contain overlapping spans among entities, among roles, and between roles and entities.
*   To handle the overlap between roles and entities, they are classified with separate heads.

*   Analysis showed that separating Status and Method from the other roles removes the overlap among roles, so Status and Method are also classified with their own heads.

*   Overlap among entities occurs between 'Family', 'Residence', 'LivingSituation' and 'MaritalStatus'; all of these except 'Family' have been dropped.
*   A relation mapping with token indices and labels is generated for each event in a document.
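The offset-to-token mapping can be illustrated in isolation. The snippet below is a minimal, self-contained sketch, not the notebook's code: the offsets are hand-written for a hypothetical sentence, special tokens are assumed to carry offset (0, 0) as in a fast tokenizer's offset mapping, and a token counts as part of a span whenever it overlaps it (the notebook's own loop instead starts at the first token whose start offset is at or after the span start). Entity and role tags are built in separate calls, mirroring the separate heads.

```python
# Sketch only: hand-written offsets for the hypothetical text "Drinks two beers daily".
# Special tokens ([CLS], [SEP]) are assumed to have offset (0, 0).
offsets = [(0, 0), (0, 6), (7, 10), (11, 16), (17, 22), (0, 0)]


def char_span_to_token_span(offsets, char_start, char_end):
    """Return (first_token, last_token) for tokens overlapping [char_start, char_end), or None."""
    first = last = None
    for i, (tok_start, tok_end) in enumerate(offsets):
        if tok_start == tok_end:
            continue  # special token, never part of an annotation
        if tok_end <= char_start or tok_start >= char_end:
            continue  # token lies entirely outside the span
        if first is None:
            first = i
        last = i
    return None if first is None else (first, last)


def bio_tags(offsets, spans):
    """spans: list of (char_start, char_end, category) with exclusive char_end; one tag per token."""
    tags = ["O"] * len(offsets)
    for char_start, char_end, category in spans:
        token_span = char_span_to_token_span(offsets, char_start, char_end)
        if token_span is None:
            continue  # e.g. the span lies beyond the truncated sequence
        first, last = token_span
        tags[first] = "B-" + category
        for i in range(first + 1, last + 1):
            tags[i] = "I-" + category
    return tags


# Entity and role tags are built separately, mirroring the separate classification heads
print(bio_tags(offsets, [(0, 6, "Alcohol")]))
# ['O', 'B-Alcohol', 'O', 'O', 'O', 'O']
print(bio_tags(offsets, [(7, 16, "Amount"), (17, 22, "Frequency")]))
# ['O', 'O', 'B-Amount', 'I-Amount', 'B-Frequency', 'O']
```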
### **Experiments:**
*   Does eliminating overlapping entities improve the model's performance in recognizing Family entities?
*   Does the model improve when entities (with overlapping ones discarded), roles and relations are trained together?


### **Model**
*   ***Tokenizer:*** BertTokenizerFast
*   ***Pre-trained BERT model:*** 'emilyalsentzer/Bio_ClinicalBERT'

*   ***Hyperparameters:***
  * eps=1e-8
  * learning_rate=7e-5
  * weight_decay=0.01
  * num_train_epochs=15
  * patience=5
  * batch_size=16
  * max_len=512
*   The pre-trained BERT model is fine-tuned with multiple classification heads: one each for status, method, entity and role tags, plus a relation head that classifies entity-role pairs using BERT embeddings and cross-span attention. Cross-span attention is applied in both directions (entity span to role span and role span to entity span) to capture complex relations.
*   Because multiple tasks are trained simultaneously, their losses are aggregated into a single loss for backpropagation.

*   NER is evaluated with sequence-level (entity-span) metrics from seqeval.

*   The trained model is stored under save_model_path, with a file name built from model_type and ver (.pth).
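A minimal sketch of the loss aggregation, assuming an unweighted sum of the token-level cross-entropy losses (with the -100 padding label ignored, which is also the default ignore_index of nn.CrossEntropyLoss) plus the pair-level relation loss. The head sizes match the notebook's label maps; the tensors are random stand-ins, and multitask_loss is a hypothetical helper rather than the notebook's training loop, which may weight or average the losses differently.

```python
import torch
import torch.nn as nn

loss_fn = nn.CrossEntropyLoss(ignore_index=-100)  # -100 marks padding and is skipped


def multitask_loss(head_logits, head_labels, relation_logits, relation_labels):
    """Unweighted sum of the per-head token losses and the relation loss (an assumption)."""
    losses = [
        loss_fn(logits.reshape(-1, logits.size(-1)), head_labels[name].reshape(-1))
        for name, logits in head_logits.items()
    ]
    losses.append(loss_fn(relation_logits, relation_labels))
    return torch.stack(losses).sum()  # one scalar, so one backward() reaches every head


torch.manual_seed(0)
batch, seq_len = 2, 5
num_labels = {"status": 3, "method": 3, "role": 15, "entity": 9}
logits = {name: torch.randn(batch, seq_len, n, requires_grad=True) for name, n in num_labels.items()}
labels = {name: torch.randint(0, n, (batch, seq_len)) for name, n in num_labels.items()}
for name in labels:
    labels[name][:, -1] = -100  # pretend the last position is padding

relation_logits = torch.randn(4, 2, requires_grad=True)  # 4 candidate entity-role pairs
relation_labels = torch.tensor([1, 0, 0, 1])

loss = multitask_loss(logits, labels, relation_logits, relation_labels)
loss.backward()
print(loss.item())
```

Because the ignored positions contribute nothing to the loss, their gradients are zero, so padding never pulls on any head.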
 ****
### **Running the notebook:**
* Set model_type, ver and project_directory in the input section, then run all cells.
* Datasets and their paths:
  * train_data_set path: project_directory+'/data/trainset.json'
  * test_data_set path: project_directory+'/data/testset.json'
* The model is stored at project_directory+'/models/'+model_type+'_'+str(ver)+'.pth'.
* Evaluation-set predictions are stored at 'tst_'+model_type+'_'+str(ver)+'.json'.
### **Metrics:**
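Sequence-level evaluation scores whole entity spans rather than individual tokens: a predicted span counts only if its type, start and end all match. The hand-rolled sketch below illustrates the idea on made-up tags. It roughly follows seqeval's default (conlleval-style) chunking, where a stray I- tag opens a new span, but it is not seqeval's implementation and reports only micro-averaged precision, recall and F1.

```python
def extract_spans(tags):
    """Collect (category, start, end) spans, end inclusive, from a list of BIO tags."""
    spans, category, start = set(), None, None
    for i, tag in enumerate(tags + ["O"]):  # trailing "O" closes a span that runs to the end
        continues_span = tag.startswith("I-") and tag[2:] == category
        if category is not None and not continues_span:
            spans.add((category, start, i - 1))
            category = None
        if tag.startswith("B-") or (tag.startswith("I-") and category is None):
            category, start = tag[2:], i  # B- opens a span; so does an I- that cannot continue one
    return spans


def span_f1(true_seqs, pred_seqs):
    """Micro-averaged exact-match span precision, recall and F1."""
    tp = n_pred = n_true = 0
    for true_tags, pred_tags in zip(true_seqs, pred_seqs):
        true_spans, pred_spans = extract_spans(true_tags), extract_spans(pred_tags)
        tp += len(true_spans & pred_spans)
        n_pred += len(pred_spans)
        n_true += len(true_spans)
    precision = tp / n_pred if n_pred else 0.0
    recall = tp / n_true if n_true else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    return precision, recall, f1


truth = [["O", "B-Alcohol", "I-Alcohol", "O", "B-Drug"]]
pred = [["O", "B-Alcohol", "O", "O", "B-Drug"]]
print(span_f1(truth, pred))  # (0.5, 0.5, 0.5): the truncated Alcohol span earns no credit
```

Token-level accuracy on this example would be 4/5, while span-level F1 is 0.5, which is why span metrics are the stricter measure for NER.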

In [None]:
!pip install evaluate

In [None]:
import json
import pandas as pd
import evaluate
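from transformers import BertTokenizerFast, BertModel, get_linear_schedule_with_warmup, DataCollatorForTokenClassification
from torch.optim import AdamW  # transformers.AdamW is deprecated and has been removed from recent transformers releases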
from torch.utils.data import DataLoader, Dataset
from sklearn.metrics import classification_report
import torch
import torch.nn as nn
import torch.optim as optim
from torch.nn.utils.rnn import pad_sequence
from tqdm import tqdm, trange
import itertools

In [None]:
!pip install seqeval

In [None]:
metric=evaluate.load("seqeval")
device= torch.device("cuda" if torch.cuda.is_available() else "cpu")
tokenizer = BertTokenizerFast.from_pretrained('emilyalsentzer/Bio_ClinicalBERT')
num_freeze_layers=6


## Datasets

### Parameter initialization and Dataset path

In [None]:
model_type='Relation_Ner_parallel'
ver=2

discarded_enities=['EnvironmentalExposure','SexualHistory','InfectiousDiseases','PhysicalActivity','Residence','LivingSituation','MaritalStatus','Occupation']
discarded_roles=['LivingStatus','Other','MedicalCondition','Extent','History']


project_directory='/content/drive/MyDrive/PHD_assessment_gmu/'
raw_dataset_path=project_directory+'data/'+'SocialHistoryMTSamples.json'
train_dataset_path=project_directory+'data/'+'trainset.json'
test_dataset_path=project_directory+'data/'+'testset.json'
save_model_path=project_directory+'/models/'
bert_model_name='emilyalsentzer/Bio_ClinicalBERT'

eps=1e-8
learning_rate=7e-5
weight_decay=0.01
num_train_epochs=15
patience=5
max_len=512

batch_size=16

In [None]:
id_label_status={0:'O',1:'B-Status',2:'I-Status'}
id_label_method={0:'O',1:'B-Method',2:'I-Method'}
id_label_role={0:'O',1:'B-Type',2:'I-Type',3:'B-Amount',4:'I-Amount',5:'B-Temporal',6:'I-Temporal',7:'B-Frequency',8:'I-Frequency',9:'B-QuitHistory',10:'I-QuitHistory',11:'B-ExposureHistory',12:'I-ExposureHistory',13:'B-Location',14:'I-Location'}
id_label_event={0:'No Relation',1:'Relation'}
label_id_status = {v: k for k, v in id_label_status.items()}
label_id_method = {v: k for k, v in id_label_method.items()}
label_id_role = {v: k for k, v in id_label_role.items()}
label_id_ent = {'B-Alcohol':1,
 'B-Drug':3,
 'B-Family':5,
 'B-Tobacco':7,
 'I-Alcohol':2,
 'I-Drug':4,
 'I-Family':6,
 'I-Tobacco':8,
 'O':0}

id_label_ent = {v: k for k, v in label_id_ent.items()}
label_id_event = {v: k for k, v in id_label_event.items()}
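The hand-written maps above follow one pattern: 'O' at id 0, then a B-/I- pair per category. A small hypothetical helper can generate them from category lists, which keeps the ids consistent if categories are added or dropped later; the category order below is read off the maps above.

```python
def build_bio_maps(categories):
    """Return (id_to_label, label_to_id) with 'O' at id 0 and a B-/I- pair per category."""
    id_to_label = {0: "O"}
    for category in categories:
        id_to_label[len(id_to_label)] = "B-" + category
        id_to_label[len(id_to_label)] = "I-" + category
    label_to_id = {label: idx for idx, label in id_to_label.items()}
    return id_to_label, label_to_id


role_categories = ["Type", "Amount", "Temporal", "Frequency", "QuitHistory", "ExposureHistory", "Location"]
entity_categories = ["Alcohol", "Drug", "Family", "Tobacco"]
id_label_role_built, _ = build_bio_maps(role_categories)
_, label_id_ent_built = build_bio_maps(entity_categories)
print(len(id_label_role_built), label_id_ent_built["I-Family"])  # 15 6
```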


### Generate entity and role labels, relation labels and relation pairs


In [None]:

class GenerateLabel:

  @staticmethod
  def generate_enity_labels(entity_list, token_len, token_offsets):
    '''
    Generates entity labels for each token in the sentence by mapping each entity's character offsets to the token offsets.

    Inputs:
    entity_list: list of entity dicts; each entity dict holds the offset positions, id, category and text.
    token_len: number of tokens; sets the length of the label list.
    token_offsets: list of (start, end) character-offset tuples for the BERT tokens.

    Outputs:
    entity_labels: list of BIO label ids, one per token.
    '''
    # Initialize a list to store the BIO labels for each token

    entity_labels = [label_id_ent['O']] * token_len
    for entity in entity_list:
        category = entity['entity_category']
        if category not in discarded_enities:
          entity_start_pos = int(entity['entity_strt_pos'])
          entity_end_pos = int(entity['entity_end_pos'])-1

          # Find tokens that correspond to the entity's position
          entity_start_token = None
          entity_end_token = None

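          for i, (start_offset, end_offset) in enumerate(token_offsets):
              if start_offset == end_offset:
                  continue  # skip special tokens ([CLS], [SEP]), whose offset mapping is (0, 0)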
              if entity_start_token is None and start_offset >= entity_start_pos:
                  entity_start_token = i
              if end_offset > entity_end_pos:
                  entity_end_token = i
                  break
          if entity_start_token is not None:
            entity_labels[entity_start_token] = label_id_ent['B-' + category]
            if entity_end_token is not None:
                entity_labels[entity_start_token + 1:entity_end_token + 1] = [label_id_ent['I-' + category]] * (entity_end_token - entity_start_token)



    return entity_labels

  @staticmethod
  def generate_role_labels(role_list, token_len, token_offsets):
    '''
    Generates role, status and method labels for each token in the sentence by mapping each role's character offsets to the token offsets.
    Overlapping roles are handled by separating Status and Method from the other roles and labelling them in their own sequences.

    Inputs:
    role_list: list of role dicts; each role dict holds the offset positions, id, category and text.
    token_len: number of tokens; sets the length of each label list.
    token_offsets: list of (start, end) character-offset tuples for the BERT tokens.

    Outputs:
    role_labels: list of BIO label ids, one per token (roles other than Status and Method).
    status_labels: list of BIO label ids, one per token (Status).
    method_labels: list of BIO label ids, one per token (Method).
    '''
    # Initialize a list to store the BIO labels for each token
    role_labels = [label_id_role['O']] * token_len
    status_labels = [label_id_status['O']] * token_len
    method_labels = [label_id_method['O']] * token_len

    for role in role_list:
        category = role['entity_category']
        entity_start_pos = int(role['entity_strt_pos'])
        entity_end_pos = int(role['entity_end_pos'])-1
        if category in discarded_roles:
          continue
        # Find tokens that correspond to the entity's position
        entity_start_token = None
        entity_end_token = None

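        for i, (start_offset, end_offset) in enumerate(token_offsets):
            if start_offset == end_offset:
                continue  # skip special tokens ([CLS], [SEP]), whose offset mapping is (0, 0)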
            if entity_start_token is None and start_offset >= entity_start_pos:
                entity_start_token = i
            if end_offset > entity_end_pos:
                entity_end_token = i
                break

        # Assign BIO labels to the tokens
        if category == 'Status':
          if entity_start_token is not None:
            status_labels[entity_start_token] = label_id_status['B-' + category]
            if entity_end_token is not None:
                status_labels[entity_start_token + 1:entity_end_token + 1] = [label_id_status['I-' + category]] * (entity_end_token - entity_start_token)

        elif category == 'Method':
          if entity_start_token is not None:
            method_labels[entity_start_token] = label_id_method['B-' + category]
            if entity_end_token is not None:
                method_labels[entity_start_token + 1:entity_end_token + 1] = [label_id_method['I-' + category]] * (entity_end_token - entity_start_token)

        else:
          if entity_start_token is not None:
            role_labels[entity_start_token] = label_id_role['B-' + category]
            if entity_end_token is not None:
                role_labels[entity_start_token + 1:entity_end_token + 1] = [label_id_role['I-' + category]] * (entity_end_token - entity_start_token)


    return role_labels, status_labels, method_labels

  @staticmethod
  def generate_relation_labels( token_offsets, data):
    '''
    Generates a relation mapping with token indices for each event in a document.
    Reads the events and maps each event's entity and its related roles to their token indices.

    Inputs:
    token_offsets: list of (start, end) character-offset tuples for the BERT tokens.
    data: data dict containing entity_list, role_list and events_list.

    Outputs:
    role_token_indices: dict whose key is the (start, end) token-index tuple of an event's entity and whose value is a list of (start, end) token-index tuples of its roles.
    event_categories: dict whose key is ((start, end), entity_category) and whose value is a list of ((start, end), role_category) tuples; this is the input to generate_relation_pairs.
    '''
    entity_token_indices = {}
    role_token_indices = {}
    event_categories = {}

    entity_list = data.get('entity_list', [])
    role_list = data.get('role_list', [])
    events_list = data.get('events_list', [])

    # Create a dictionary to map entity IDs to their token indices
    for entity in entity_list:
        entity_id = entity['entity_id']
        entity_start_pos = int(entity['entity_strt_pos'])
        entity_end_pos = int(entity['entity_end_pos'])-1
        entity_category = entity['entity_category']

        entity_start_token = None
        entity_end_token = None

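        for i, (start_offset, end_offset) in enumerate(token_offsets):
            if start_offset == end_offset:
                continue  # skip special tokens ([CLS], [SEP]), whose offset mapping is (0, 0)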
            if entity_start_token is None and start_offset >= entity_start_pos:
                entity_start_token = i
            if end_offset > entity_end_pos:
                entity_end_token = i
                break

        if entity_start_token is not None:
            entity_token_indices[entity_id] = (entity_start_token, entity_end_token, entity_category)

    # Create a dictionary to map event-related role IDs to their token indices
    for event in events_list:
        entity_id = event['entity_id']
        related_roles = event['Related_roles']
        ent_info=entity_token_indices.get(entity_id, None)


        if ent_info:
            en_strt,en_end,ent_cat=ent_info
            entity_indices = (en_strt,en_end)
            role_indices = []
            role_categories = []

            for role_id in related_roles:
                role = next((role for role in role_list if role['role_id'] == role_id), None)

                if role:
                    role_start_pos = int(role['entity_strt_pos'])
                    role_end_pos = int(role['entity_end_pos'])-1

                    role_start_token = None
                    role_end_token = None

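                    for i, (start_offset, end_offset) in enumerate(token_offsets):
                        if start_offset == end_offset:
                            continue  # skip special tokens ([CLS], [SEP]), whose offset mapping is (0, 0)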
                        if role_start_token is None and start_offset >= role_start_pos:
                            role_start_token = i
                        if end_offset > role_end_pos:
                            role_end_token = i
                            break

                    if role_start_token is not None:
                        role_indices.append((role_start_token, role_end_token))
                        role_categories.append(((role_start_token, role_end_token),role['entity_category']))

            if role_indices:
                role_token_indices[(entity_indices[0],entity_indices[1])] = role_indices
                event_categories[((entity_indices[0],entity_indices[1]),ent_cat)] = role_categories

    return role_token_indices, event_categories

  @staticmethod
  def generate_relation_pairs(relation_labels):
    '''
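    Generates candidate (entity, role) pairs and their relation labels from the event-category mapping.
    Every role linked to an entity's event gives a positive pair (label 1); every role linked only to a different event in the same document gives a negative pair (label 0).

    Inputs:
    relation_labels: the event_categories dict from generate_relation_labels, whose key is ((start, end), entity_category) and whose value is a list of ((start, end), role_category) tuples.

    Outputs (returned as a tuple):
    related_pairs: list of ((entity_start, entity_end), (role_start, role_end)) token-index pairs, positives first.
    related_logits: relation label (1 or 0) for each pair.
    related_pairs_categories: (entity_category, role_category) for each pair.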
    '''
    related_pairs = []
    not_related_pairs = []
    related_pairs_categories = []
    not_related_pairs_categories = []

    # Extract all unique value pairs
    all_value_pairs = set()
    for value_list in relation_labels.values():
        for value in value_list:
            all_value_pairs.add(value)

    # Iterate through each key-value pair in the dictionary
    for key, values in relation_labels.items():
        value_set = set(values)

        # Add related pairs
        for value in value_set:
            related_pairs.append((key[0], value[0]))
            related_pairs_categories.append((key[1], value[1]))

        # Add not related pairs
        not_related_value_pairs = all_value_pairs - value_set
        for value in not_related_value_pairs:
            not_related_pairs.append((key[0], value[0]))
            not_related_pairs_categories.append((key[1], value[1]))

    related_logits = [1]*len(related_pairs)
    not_related_logits = [0]*len(not_related_pairs)
    related_pairs = related_pairs + not_related_pairs
    related_logits = related_logits + not_related_logits
    related_pairs_categories = related_pairs_categories + not_related_pairs_categories
    return related_pairs, related_logits, related_pairs_categories
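To see which pairs generate_relation_pairs produces, here is a standalone re-implementation of the same pairing rule on a made-up document with two events (all token indices are hypothetical). One deliberate difference: roles are sorted so the output order is deterministic, whereas the notebook iterates over sets. Negatives come only from roles attached to another event in the same document; a role that belongs to no event never becomes a candidate pair.

```python
def make_relation_pairs(event_categories):
    """event_categories: {((ent_start, ent_end), ent_cat): [((role_start, role_end), role_cat), ...]}.
    Returns (pairs, labels, categories) with positives first, like generate_relation_pairs."""
    all_roles = {role for roles in event_categories.values() for role in roles}
    pos_pairs, pos_cats, neg_pairs, neg_cats = [], [], [], []
    for (ent_span, ent_cat), roles in event_categories.items():
        linked = set(roles)
        for role_span, role_cat in sorted(linked):
            pos_pairs.append((ent_span, role_span))
            pos_cats.append((ent_cat, role_cat))
        for role_span, role_cat in sorted(all_roles - linked):  # roles of other events only
            neg_pairs.append((ent_span, role_span))
            neg_cats.append((ent_cat, role_cat))
    labels = [1] * len(pos_pairs) + [0] * len(neg_pairs)
    return pos_pairs + neg_pairs, labels, pos_cats + neg_cats


events = {
    ((1, 1), "Alcohol"): [((2, 3), "Amount"), ((4, 4), "Frequency")],
    ((6, 6), "Tobacco"): [((7, 7), "Status")],
}
pairs, labels, cats = make_relation_pairs(events)
for pair, label, cat in zip(pairs, labels, cats):
    print(pair, cat, label)
```

With two events the document yields three positive and three negative pairs; in general each entity is paired with every role in the document that belongs to some event.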

### Dataset creation and loading

In [None]:
class ERDataset(Dataset):
  def __init__(self, data, tokenizer, max_len):
    self.data = data
    self.tokenizer = tokenizer
    self.max_len = max_len
  def __len__(self):
    return len(self.data)
  def __getitem__(self, idx):
    item=self.data[idx]
    text = item['text']
    inputs = self.tokenizer(text,max_length=self.max_len,truncation=True,return_offsets_mapping=True)
    tokens_len=len(inputs['input_ids'])
    offset_mapping_list=inputs['offset_mapping']
    entity_labels=GenerateLabel.generate_enity_labels(item['entity_list'],tokens_len,offset_mapping_list)
    role_labels, status_labels, method_labels=GenerateLabel.generate_role_labels(item['role_list'],tokens_len,offset_mapping_list)
    relation_labels,event_categories_relation=GenerateLabel.generate_relation_labels(offset_mapping_list,item)
    related_pairs, related_logits, related_pairs_categories=GenerateLabel.generate_relation_pairs(event_categories_relation)

    return {
      'input_ids':inputs['input_ids'],
      'attention_mask':inputs['attention_mask'],
      'entity_labels':entity_labels,
      'role_labels':role_labels,
      'status_labels':status_labels,
      'method_labels':method_labels,
      'relation_labels':event_categories_relation,
      'text':text,
      'file_name':item['file_name'],
      'offset_mapping':offset_mapping_list,
      'tokens':inputs.tokens(),
      'relation_pairs':related_pairs,
      'related_logits':related_logits,
      'related_pairs_categories':related_pairs_categories
    }
def pad_span_positions(span_positions):
  '''
  Pads each document's list of (entity_span, role_span) pairs to the longest list in the batch,
  using ((-1, -1), (-1, -1)) as the padding pair, and returns them as a single tensor.
  '''
  max_length = max(len(spans) for spans in span_positions)
  pad_pair = ((-1, -1), (-1, -1))
  padded_spans = [spans + [pad_pair] * (max_length - len(spans)) for spans in span_positions]
  return torch.tensor(padded_spans)
def collate_fn_entity_role(batch):
  input_ids = [torch.tensor(x['input_ids']) for x in batch]
  attention_mask = [torch.tensor(x['attention_mask']) for x in batch]
  #token_type_ids = [torch.tensor(x['token_type_ids']) for x in batch]
  tokens=[x['tokens'] for x in batch]
  file_name=[x['file_name'] for x in batch]
  relation_labels=[x['relation_labels'] for x in batch]
  relation_pairs_btch=[x['relation_pairs'] for x in batch]
  related_logits=[x['related_logits'] for x in batch]
  text=[x['text'] for x in batch]
  related_pairs_categories=[x['related_pairs_categories'] for x in batch]
  '''
  for x in batch:
    for idx in range(0,len(x['entity_labels'])):
      if isinstance(x['entity_labels'][idx],list):
        mutiple_labels=x['entity_labels'][idx]
        print(mutiple_labels)
        sing_mul_label=''
        for lab in mutiple_labels:
          if sing_mul_label=='':
            sing_mul_label=str(lab)
          else:
            sing_mul_label=sing_mul_label+'900'+str(lab)
        x['entity_labels'][idx] =int(sing_mul_label)
        print(x['entity_labels'][idx])
  '''
  entity_labels = [torch.tensor(x['entity_labels']) for x in batch]
  role_labels = [torch.tensor(x['role_labels']) for x in batch]
  status_labels = [torch.tensor(x['status_labels']) for x in batch]
  method_labels = [torch.tensor(x['method_labels']) for x in batch]
  input_ids = pad_sequence(input_ids, batch_first=True,padding_value=tokenizer.pad_token_id)
  attention_mask = pad_sequence(attention_mask, batch_first=True,padding_value=0)
  #token_type_ids = pad_sequence(token_type_ids, batch_first=True,padding_value=0)
  entity_labels = pad_sequence(entity_labels, batch_first=True,padding_value=-100)
  role_labels = pad_sequence(role_labels, batch_first=True,padding_value=-100)
  status_labels = pad_sequence(status_labels, batch_first=True,padding_value=-100)
  method_labels = pad_sequence(method_labels, batch_first=True,padding_value=-100)
  relation_pairs = pad_span_positions(relation_pairs_btch)
  return {
    'input_ids':input_ids,
    'attention_mask':attention_mask,
    'entity_labels':entity_labels,
    'role_labels':role_labels,
    'status_labels':status_labels,
    'method_labels':method_labels,
    'tokens':tokens,
    'file_name':file_name,
    'relation_labels':relation_labels,
    'relation_pairs':relation_pairs,
    'related_logits':related_logits,
    'text':text,
    'related_pairs_categories':related_pairs_categories,
    'relation_pairs_btch':relation_pairs_btch
  }
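The padding convention used by pad_span_positions can be sketched in plain Python. The hypothetical pad_pairs helper below pads with the same ((-1, -1), (-1, -1)) sentinel and also returns an explicit validity mask, an alternative to comparing each pair against the sentinel inside the model. It also shows the edge case of a batch in which no document has a candidate pair: every padded list is then empty, and the model's final torch.cat over an empty list of relation logits would raise an error, so such batches would need separate handling.

```python
PAD_PAIR = ((-1, -1), (-1, -1))  # same sentinel as pad_span_positions


def pad_pairs(batch_pairs):
    """Pad each document's pair list to the batch maximum; also return a validity mask."""
    max_len = max((len(pairs) for pairs in batch_pairs), default=0)
    padded, mask = [], []
    for pairs in batch_pairs:
        n_pad = max_len - len(pairs)
        padded.append(list(pairs) + [PAD_PAIR] * n_pad)
        mask.append([True] * len(pairs) + [False] * n_pad)
    return padded, mask


batch_pairs = [
    [((1, 1), (2, 3)), ((1, 1), (7, 7))],  # document with two candidate pairs
    [],                                    # document with no events
]
padded, mask = pad_pairs(batch_pairs)
print(padded[1], mask)
```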


In [None]:
with open(train_dataset_path, 'r', encoding='utf-8') as file:
  train_data=json.load(file)
with open(test_dataset_path, 'r', encoding='utf-8') as file:
  test_data=json.load(file)
train_dataset=ERDataset(train_data,tokenizer, max_len)
test_dataset=ERDataset(test_data,tokenizer, max_len)
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn_entity_role)
test_dataloader = DataLoader(test_dataset, batch_size=batch_size,collate_fn=collate_fn_entity_role)

## Model

In [None]:
class CrossSpanAttention(nn.Module):
    '''
    Cross Span Attention Layer:
    This layer performs an attention mechanism where one set of spans (span2) attends to another set of spans (span1).
    It is designed to capture inter-span relationships and contextual dependencies within a sequence.

    '''
    def __init__(self, embed_size):
        super(CrossSpanAttention, self).__init__()
        self.embed_size = embed_size
        self.keys = nn.Linear(self.embed_size, self.embed_size)
        self.queries = nn.Linear(self.embed_size, self.embed_size)
        self.values = nn.Linear(self.embed_size, self.embed_size)

    def forward(self, span1_embeddings, span2_embeddings, span1_mask, span2_mask):
        '''
        Computes the attention-weighted output.
        - span1_embeddings: embeddings of the span1 tokens, shape (max_length, hidden).
        - span2_embeddings: embeddings of the span2 tokens, shape (max_length, hidden).
        - span1_mask: bool mask for span1, True for real tokens; padded span1 positions receive no attention.
        - span2_mask: bool mask for span2, True for real tokens; only these positions are averaged into the output.

        Queries come from span2 and keys and values from span1, so each span2 token attends over the span1 tokens
        (scaled dot-product attention). The per-token outputs are averaged over the real span2 tokens, giving a (1, hidden) vector.
        '''
        span1_keys = self.keys(span1_embeddings)
        span1_values = self.values(span1_embeddings)
        span2_queries = self.queries(span2_embeddings)

        # (len2, len1) attention scores
        attention = torch.matmul(span2_queries, span1_keys.transpose(-2, -1)) / self.embed_size ** 0.5
        # The mask is True for real tokens, so hide the padded (False) span1 positions along the key axis
        attention = attention.masked_fill(~span1_mask.unsqueeze(0), -1e9)
        attention = torch.softmax(attention, dim=-1)

        assert not torch.isnan(attention).any(), "NaNs found in attention after softmax"
        out = torch.matmul(attention, span1_values)

        # Average over the real span2 tokens only
        span2_weights = span2_mask.unsqueeze(-1).to(out.dtype)
        out = (out * span2_weights).sum(dim=0, keepdim=True) / span2_weights.sum().clamp(min=1.0)
        return out

class EntityRelationClassifier(nn.Module):
    '''
    Module that tags status, method, role and entity tokens and classifies entity-role relations using BERT embeddings and cross-span attention.

    '''
    def __init__(self, model_name, num_freeze_layers,num_status_labels,num_method_labels,num_role_labels,num_entity_labels,num_relation_classes, dropout=0.1):
        super().__init__()
        self.bertmodel = BertModel.from_pretrained(model_name)
        self.cross_span_attention = CrossSpanAttention(self.bertmodel.config.hidden_size)
        self.dropout = nn.Dropout(dropout)
        self.status_classifier = nn.Linear(self.bertmodel.config.hidden_size, num_status_labels)
        self.method_classifier = nn.Linear(self.bertmodel.config.hidden_size, num_method_labels)
        self.role_classifier = nn.Linear(self.bertmodel.config.hidden_size, num_role_labels)
        self.entity_classifier = nn.Linear(self.bertmodel.config.hidden_size, num_entity_labels)
        self.relation_classifier = nn.Linear(self.bertmodel.config.hidden_size * 2, num_relation_classes)

    def forward(self, input_ids, attention_mask, all_span_positions):
        '''
        Args:
        input_ids (torch.Tensor): Indices of input sequence tokens in the vocabulary.
        attention_mask (torch.Tensor): Mask to avoid attention on padding token indices.
        all_span_positions (List of Lists): Span positions for all potential entities and roles in each example of the batch.

        Returns:
        Tuple containing logits for status, method, role, entity predictions, and a list of all relation predictions,
        and their corresponding logits.
        '''

        # Pass inputs through BERT model and apply dropout to the sequence output
        bert_output = self.bertmodel(input_ids=input_ids, attention_mask=attention_mask)
        sequence_output= bert_output[0]
        bert_output = self.dropout(bert_output[0])

        # Generate logits for each classification task using the respective linear layers
        status_logits = self.status_classifier(bert_output)
        method_logits = self.method_classifier(bert_output)
        role_logits = self.role_classifier(bert_output)
        entity_logits = self.entity_classifier(bert_output)

        # Initialize lists to store the logits and predictions for relationships
        all_relation_logits = []
        all_relation_preds = []

        # Process each example in the batch
        batch_size = input_ids.size(0)
        for i in range(batch_size):

            entity_pairs = all_span_positions[i]
            batch_relation_logits = []

            # Process each entity pair

            for span1, span2 in entity_pairs:
                # Skip padded spans
                if (span1 < 0).any() or (span2 < 0).any():
                    continue
                # Determine the maximum length for padding based on the spans
                max_length=max((span1[1]+1) - span1[0], (span2[1]+1) - span2[0])
                # Extract embeddings and padding for each span
                span1_embeddings = sequence_output[i, span1[0]:span1[1]+1, :]
                span1_embeddings, span1_mask = self.pad_and_create_mask(span1_embeddings, max_length)
                span2_embeddings = sequence_output[i, span2[0]:span2[1]+1, :]
                span2_embeddings, span2_mask = self.pad_and_create_mask(span2_embeddings, max_length)

                # Apply bi-directional cross-span attention for capturing the complex relationships and contextual nuances between text segments
                span1_to_span2_attention = self.cross_span_attention(span1_embeddings, span2_embeddings, span1_mask, span2_mask)
                span2_to_span1_attention = self.cross_span_attention(span2_embeddings, span1_embeddings, span2_mask, span1_mask)



                # Concatenate and classify the relation
                combined_embedding = torch.cat((span1_to_span2_attention, span2_to_span1_attention), dim=1)
                #predict relations
                relation_logits = self.relation_classifier(combined_embedding)

                all_relation_logits.append(relation_logits)

                probabilities = nn.functional.softmax(relation_logits, dim=1)

                labels = torch.argmax(probabilities, dim=1)

                batch_relation_logits.append(labels.item())
            all_relation_preds.append(batch_relation_logits)
        # Combine logits from all entities and relations
        all_relation_logits = torch.cat(all_relation_logits, dim=0)
        return status_logits,method_logits,role_logits,entity_logits, all_relation_preds,all_relation_logits
    def pad_and_create_mask(self, embeddings, max_length):
      '''
      Pads the embeddings with zero rows up to max_length and returns them with a bool mask that is True for real tokens and False for padding.
      '''
      pad_size = max_length - embeddings.size(0)
      if pad_size > 0:
          pad = torch.zeros((pad_size, embeddings.size(1)), device=embeddings.device)
          mask = torch.cat([torch.ones(embeddings.size(0), device=embeddings.device), torch.zeros(pad_size, device=embeddings.device)])
          embeddings = torch.cat([embeddings, pad], dim=0)
      else:
          mask = torch.ones(embeddings.size(0), device=embeddings.device)

      # Convert the mask to boolean
      mask = mask.bool()

      return embeddings, mask
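For reference, standard single-head cross-attention between two spans can be sketched as follows; this is an illustrative module with a hypothetical name, not the notebook's class. Queries come from span2 and keys and values from span1; masks are bool tensors that are True for real tokens (the convention pad_and_create_mask produces); padded keys get a large negative score before the softmax; and the per-token outputs are averaged over the real span2 tokens only. Running it in both directions and concatenating gives the (1, 2 * hidden) vector that a relation head of input size hidden_size * 2 expects.

```python
import torch
import torch.nn as nn


class MaskedCrossAttentionSketch(nn.Module):
    """Single-head cross-attention: span2 (queries) attends to span1 (keys and values)."""

    def __init__(self, embed_size):
        super().__init__()
        self.scale = embed_size ** 0.5
        self.keys = nn.Linear(embed_size, embed_size)
        self.queries = nn.Linear(embed_size, embed_size)
        self.values = nn.Linear(embed_size, embed_size)

    def forward(self, span1, span2, span1_mask, span2_mask):
        # span1: (len1, H), span2: (len2, H); the two lengths may differ here
        scores = self.queries(span2) @ self.keys(span1).transpose(-2, -1) / self.scale  # (len2, len1)
        scores = scores.masked_fill(~span1_mask.unsqueeze(0), -1e9)  # hide padded keys
        weights = torch.softmax(scores, dim=-1)
        out = weights @ self.values(span1)  # (len2, H)
        q_mask = span2_mask.unsqueeze(-1).to(out.dtype)
        pooled = (out * q_mask).sum(dim=0, keepdim=True) / q_mask.sum().clamp(min=1.0)
        return pooled, weights  # (1, H), (len2, len1)


torch.manual_seed(0)
H = 8
attn = MaskedCrossAttentionSketch(H)
span1 = torch.randn(4, H)
span1_mask = torch.tensor([True, True, True, False])  # last span1 token is padding
span2 = torch.randn(2, H)
span2_mask = torch.tensor([True, True])

pooled, weights = attn(span1, span2, span1_mask, span2_mask)
backward, _ = attn(span2, span1, span2_mask, span1_mask)  # the other direction
combined = torch.cat([pooled, backward], dim=1)  # (1, 2 * H), input size of the relation head
print(pooled.shape, weights[:, -1], combined.shape)
```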




### Model Evaluation

In [None]:
def compute_ner_metric(preds,labels,id_label):
    ground_truths=[[id_label[l] for l in lab if l!=-100] for lab in labels]
    prediction_labels=[[id_label[p] for p,l in zip(predict,lab) if l!=-100] for predict,lab in zip(preds,labels)]
    metric_res=metric.compute(predictions=prediction_labels,references=ground_truths)
    return metric_res
def compute_relation_metric(preds,labels,id_label):
    metric_res=classification_report(labels, preds, target_names=[id_label[0],id_label[1]])
    return metric_res
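classification_report expects one label per entity-role pair, so per-document lists of relation labels and predictions (such as those collected batch by batch during evaluation) have to be flattened first. A sketch with made-up values; the hand-rolled precision_recall_f1 helper only illustrates what the 'Relation' row of the report measures, and the flat lists can be passed to classification_report directly.

```python
from itertools import chain

# Per-document relation labels / predictions (0 = No Relation, 1 = Relation); made-up values.
doc_labels = [[1, 1, 0], [], [1, 0]]
doc_preds = [[1, 0, 0], [], [1, 1]]

flat_labels = list(chain.from_iterable(doc_labels))
flat_preds = list(chain.from_iterable(doc_preds))


def precision_recall_f1(labels, preds, positive=1):
    """Precision, recall and F1 for the positive ('Relation') class."""
    tp = sum(1 for l, p in zip(labels, preds) if l == positive and p == positive)
    fp = sum(1 for l, p in zip(labels, preds) if l != positive and p == positive)
    fn = sum(1 for l, p in zip(labels, preds) if l == positive and p != positive)
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    return precision, recall, f1


print(flat_labels, precision_recall_f1(flat_labels, flat_preds))
```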

In [None]:
def process_relation_labels_preds(related_pairs_categories, relation_pairs, related_logits,text,relation_preds,tokens,id_label_event):
  doc_relations=[]
  for idx in range(0,len(related_pairs_categories)):
    entity_token_ind=relation_pairs[idx][0]
    role_token_ind=relation_pairs[idx][1]
    doc_relations.append({'text':text,'entity':related_pairs_categories[idx][0],
    'role':related_pairs_categories[idx][1],

    'entity_tokens':tokens[entity_token_ind[0]:entity_token_ind[1]+1],
    'role_tokens':tokens[role_token_ind[0]:role_token_ind[1]+1],
    'ground_truth':id_label_event[related_logits[idx]],
    'prediction':id_label_event[relation_preds[idx]]})

  return doc_relations

In [None]:
def evaluate_entity_model(model,val_loader,loss_fn,id_label_ent,id_label_role,id_label_status,id_label_method,id_label_event,extract_predictions=False,dt_set=''):
  '''
  Performs evaluation for a given dataset and returns the computed loss and metrics for Entity, Role, Status, Method and relations.
  Saves Token level predictions of entity, role, method, status and relations in data/tr_{runname and version}.json
  '''
  model.eval()
  with torch.no_grad():
    #Initializing Validation losses for each of categories
    val_ent_loss=0
    val_status_loss=0
    val_method_loss=0
    val_role_loss=0
    val_average_loss=0
    val_event_loss=0

    #Initializing labels and prediction list, in-turn used for evaluating metrics.
    all_entity_labels=[]
    all_entity_predictions=[]
    all_role_labels=[]
    all_role_predictions=[]
    all_method_labels=[]
    all_method_predictions=[]
    all_status_labels=[]
    all_status_predictions=[]
    all_event_labels=[]
    all_event_predictions=[]

    all_res_entity={}
    all_res_role={}
    all_res_method={}
    all_res_status={}
    all_res_event=[]
    all_res={}
    for step,batch in enumerate(val_loader):
      #Extracting inputs from the batch.
      inputs={'input_ids':batch['input_ids'].to(device),'attention_mask':batch['attention_mask'].to(device),'all_span_positions':batch['relation_pairs'].to(device)}
      entity_labels=batch['entity_labels'].to(device)
      role_labels=batch['role_labels'].to(device)
      method_labels=batch['method_labels'].to(device)
      status_labels=batch['status_labels'].to(device)
      relation_labels=batch['related_logits']
      relation_labels_tensor=[torch.tensor(x, dtype=torch.long) for x in batch['related_logits']]
      relation_labels_tensor=torch.cat(relation_labels_tensor,dim=0)
      tokens=batch['tokens']
      entity_role_cat_rel=batch['related_pairs_categories']

      #Forward Pass
      status_logits,method_logits,role_logits,entity_logits, relation_preds,relation_logits=model(**inputs)

      #Computing Loss
      relation_loss=loss_fn(relation_logits.view(-1,len(id_label_event)),relation_labels_tensor.view(-1).to(device))
      entity_loss=loss_fn(entity_logits.view(-1,len(id_label_ent)),entity_labels.view(-1))
      role_loss=loss_fn(role_logits.view(-1,len(id_label_role)),role_labels.view(-1))
      status_loss=loss_fn(status_logits.view(-1,len(id_label_status)),status_labels.view(-1))
      method_loss=loss_fn(method_logits.view(-1,len(id_label_method)),method_labels.view(-1))
      avg_loss_step=(entity_loss+role_loss+status_loss+method_loss)/4
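      #Accumulating Python floats (.item()) so the returned losses are plain numbers rather than device tensors.
      val_ent_loss+=entity_loss.item()
      val_role_loss+=role_loss.item()
      val_method_loss+=method_loss.item()
      val_status_loss+=status_loss.item()
      val_event_loss+=relation_loss.item()
      val_average_loss+=avg_loss_step.item()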
      #Predictions: argmax over the label dimension (softmax is monotonic, so it is not needed before argmax)
      entity_predictions=torch.argmax(entity_logits,dim=-1)
      role_predictions=torch.argmax(role_logits,dim=-1)
      status_predictions=torch.argmax(status_logits,dim=-1)
      method_predictions=torch.argmax(method_logits,dim=-1)
      all_status_predictions.extend(status_predictions.tolist())
      all_status_labels.extend(batch['status_labels'].tolist())
      all_method_labels.extend(batch['method_labels'].tolist())
      all_method_predictions.extend(method_predictions.tolist())
      all_event_labels.extend(batch['related_logits'])
      all_event_predictions.extend(relation_preds)

      #Collecting per-document relation labels and predictions (written to JSON at the end when extract_predictions is True)
      for idx in range(0,len(batch['file_name'])):
        document_relations=process_relation_labels_preds(batch['related_pairs_categories'][idx], batch['relation_pairs_btch'][idx], batch['related_logits'][idx],batch['text'][idx],relation_preds[idx],batch['tokens'][idx],id_label_event)
        all_res_event.extend(document_relations)

      status_labels=batch['status_labels'].tolist()
      status_predictions=status_predictions.tolist()
      for idx in range(0,len(batch['file_name'])):
        all_res_status[batch['file_name'][idx]]=list(zip(batch['tokens'][idx],status_labels[idx],status_predictions[idx]))


      method_labels=batch['method_labels'].tolist()
      method_predictions=method_predictions.tolist()
      for idx in range(0,len(batch['file_name'])):
        all_res_method[batch['file_name'][idx]]=list(zip(batch['tokens'][idx],method_labels[idx],method_predictions[idx]))


      all_role_labels.extend(batch['role_labels'].tolist())
      all_role_predictions.extend(role_predictions.tolist())
      role_labels=batch['role_labels'].tolist()
      role_predictions=role_predictions.tolist()
      for idx in range(0,len(batch['file_name'])):
        all_res_role[batch['file_name'][idx]]=list(zip(batch['tokens'][idx],role_labels[idx],role_predictions[idx]))

      all_entity_labels.extend(entity_labels.tolist())
      all_entity_predictions.extend(entity_predictions.tolist())
      entity_predictions=entity_predictions.tolist()
      entity_labels=entity_labels.tolist()
      for idx in range(0,len(batch['file_name'])):
        all_res_entity[batch['file_name'][idx]]=list(zip(batch['tokens'][idx],entity_labels[idx],entity_predictions[idx]))
    all_res={'entity':all_res_entity,'role':all_res_role,'status':all_res_status,'method':all_res_method,'relations':all_res_event}

    #Computing NER metrics
    entity_metrics=compute_ner_metric(all_entity_predictions,all_entity_labels,id_label_ent)
    role_metrics=compute_ner_metric(all_role_predictions,all_role_labels,id_label_role)
    status_metrics=compute_ner_metric(all_status_predictions,all_status_labels,id_label_status)
    method_metrics=compute_ner_metric(all_method_predictions,all_method_labels,id_label_method)
    all_event_predictions = list(itertools.chain.from_iterable(all_event_predictions))
    all_event_labels = list(itertools.chain.from_iterable(all_event_labels))
    relation_metrics=compute_relation_metric(all_event_predictions,all_event_labels,id_label_event)
    #Saving Results
    if extract_predictions:
      with open(project_directory+'data/'+dt_set+'_'+model_type+'_'+str(ver)+'.json','w') as f:
        json.dump(all_res,f)
  return {'avg_loss':val_average_loss/len(val_loader),'entity_evaluation':(val_ent_loss/len(val_loader),entity_metrics),'role_evaluation':(val_role_loss/len(val_loader),role_metrics),
            'method_evaluation':(val_method_loss/len(val_loader),method_metrics),'status_evaluation':(val_status_loss/len(val_loader),status_metrics),'event_evaluation':(val_event_loss/len(val_loader),relation_metrics)}
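
`compute_ner_metric` is defined elsewhere in the notebook, so its internals are not shown here. Its output (per-type `precision`, `recall`, `f1` and `number`, plus `overall_*` keys, as in the log below) has the shape of a span-level, seqeval-style report. As an illustration only, the sketch below computes span-level scores from BIO tag-id sequences after dropping positions whose gold label is `-100`. The helpers `bio_spans` and `span_prf` are hypothetical and may differ from the notebook's implementation, for example in how a stray `I-` tag is handled (here it opens a new span).

```python
from collections import Counter

IGNORE_INDEX = -100  # label id skipped by the loss; assumed to mark positions without a real label


def bio_spans(tags):
    """Return the set of (type, start, end) spans in a list of BIO tag strings."""
    spans = set()
    start, ent_type = None, None
    for i, tag in enumerate(tags + ['O']):  # trailing 'O' closes a span that ends the sequence
        starts_new = tag.startswith('B-') or (tag.startswith('I-') and tag[2:] != ent_type)
        if tag == 'O' or starts_new:
            if ent_type is not None:
                spans.add((ent_type, start, i))  # end index is exclusive
            start, ent_type = (i, tag[2:]) if starts_new else (None, None)
        # an I- tag of the current type simply extends the open span
    return spans


def span_prf(pred_ids, label_ids, id2label):
    """Per-type span precision/recall/F1; positions whose gold label is IGNORE_INDEX are dropped."""
    tp, n_pred, n_gold = Counter(), Counter(), Counter()
    for pred_seq, gold_seq in zip(pred_ids, label_ids):
        kept = [(p, g) for p, g in zip(pred_seq, gold_seq) if g != IGNORE_INDEX]
        pred_spans = bio_spans([id2label[p] for p, _ in kept])
        gold_spans = bio_spans([id2label[g] for _, g in kept])
        n_pred.update(t for t, _, _ in pred_spans)
        n_gold.update(t for t, _, _ in gold_spans)
        tp.update(t for t, _, _ in pred_spans & gold_spans)  # exact type and boundary match
    scores = {}
    for t in sorted(set(n_pred) | set(n_gold)):
        p = tp[t] / n_pred[t] if n_pred[t] else 0.0
        r = tp[t] / n_gold[t] if n_gold[t] else 0.0
        f1 = 2 * p * r / (p + r) if p + r else 0.0
        scores[t] = {'precision': p, 'recall': r, 'f1': f1, 'number': n_gold[t]}
    return scores


# Toy example: the first and last positions carry the ignore label.
id2label = {0: 'O', 1: 'B-Alcohol', 2: 'I-Alcohol', 3: 'B-Drug'}
labels = [[IGNORE_INDEX, 1, 2, 0, 3, IGNORE_INDEX]]
preds = [[0, 1, 2, 0, 0, 0]]
print(span_prf(preds, labels, id2label))
```

Matching is exact on (type, start, end), so a prediction that only partly overlaps a gold span counts as one false positive and one false negative.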




### Model Training

In [None]:
def train_entity_bert_model(model,tr_dataloader,tst_dataloader,id_label_ent,id_label_role,id_label_status,id_label_method,id_label_event,bert_model_name,training_args):
  '''
  Trains the multi-task model with early stopping.
  At every step the entity, role, status, method and relation losses are averaged,
  and that average is backpropagated.
  '''
  #Moving the model to the same device as the input batches.
  model.to(device)
  train_cycles=trange(training_args['num_train_epochs'],desc='Epoch')
  #Positions labelled -100 are excluded from every loss.
  loss_fn = nn.CrossEntropyLoss(ignore_index=-100)
  #Initializing optimizer, scheduler and patience
  optimizer=training_args['optimizer']
  scheduler=training_args['scheduler']
  patience=training_args['patience']

  #Initializing per-epoch training and validation loss histories
  tr_ent_loss_lst=[]
  tr_status_loss_lst=[]
  tr_method_loss_lst=[]
  tr_role_loss_lst=[]
  tr_event_loss_lst=[]
  tr_avg_loss_lst=[]

  val_ent_loss_list=[]
  val_status_loss_list=[]
  val_method_loss_list=[]
  val_role_loss_list=[]
  val_avg_loss_list=[]
  val_event_loss_list=[]

  early_stopping_count=0
  early_stopping_count_ent=0
  early_stopping_count_status=0
  early_stopping_count_method=0
  early_stopping_count_role=0

  for cycle in train_cycles:
    #The per-batch progress bar is hidden; only the epoch bar is shown.
    epoch_cycles=tqdm(tr_dataloader,desc='Iteration',disable=True)
    model.train()
    #Initializing epoch losses
    tr_ent_loss=0
    tr_status_loss=0
    tr_method_loss=0
    tr_role_loss=0
    tr_event_loss=0
    average_loss=0

    for step,batch in enumerate(epoch_cycles):
      optimizer.zero_grad()
      inputs={'input_ids':batch['input_ids'].to(device),'attention_mask':batch['attention_mask'].to(device),'all_span_positions':batch['relation_pairs'].to(device)}
      entity_labels=batch['entity_labels'].to(device)
      #Model Computation
      status_logits,method_logits,role_logits,entity_logits, relation_preds, relation_logits=model(**inputs)
      role_labels=batch['role_labels'].to(device)
      method_labels=batch['method_labels'].to(device)
      status_labels=batch['status_labels'].to(device)
      #Relation labels arrive as one list per document; concatenating them into a single tensor.
      relation_labels_tensor=torch.cat([torch.tensor(x) for x in batch['related_logits']],dim=0)

      #Loss computation
      relation_loss=loss_fn(relation_logits.view(-1,len(id_label_event)),relation_labels_tensor.view(-1).to(device))
      entity_loss=loss_fn(entity_logits.view(-1,len(id_label_ent)),entity_labels.view(-1))
      role_loss=loss_fn(role_logits.view(-1,len(id_label_role)),role_labels.view(-1))
      status_loss=loss_fn(status_logits.view(-1,len(id_label_status)),status_labels.view(-1))
      method_loss=loss_fn(method_logits.view(-1,len(id_label_method)),method_labels.view(-1))
      avg_loss_step=(entity_loss+role_loss+status_loss+method_loss+relation_loss)/5

      avg_loss_step.backward()
      optimizer.step()
      scheduler.step()

      #Accumulating per-task losses over the epoch
      tr_ent_loss+=entity_loss.item()
      tr_status_loss+=status_loss.item()
      tr_method_loss+=method_loss.item()
      tr_role_loss+=role_loss.item()
      tr_event_loss+=relation_loss.item()
      average_loss+=avg_loss_step.item()

    #Recording each loss averaged over the batches of this epoch.
    tr_ent_loss_lst.append(tr_ent_loss/len(tr_dataloader))
    tr_status_loss_lst.append(tr_status_loss/len(tr_dataloader))
    tr_method_loss_lst.append(tr_method_loss/len(tr_dataloader))
    tr_role_loss_lst.append(tr_role_loss/len(tr_dataloader))
    tr_event_loss_lst.append(tr_event_loss/len(tr_dataloader))
    tr_avg_loss_lst.append(average_loss/len(tr_dataloader))

    print('Epoch: {}  Train Entity Loss: {}'.format(cycle,tr_ent_loss_lst[-1]))
    print('Epoch: {}  Train Status Loss: {}'.format(cycle,tr_status_loss_lst[-1]))
    print('Epoch: {}  Train Method Loss: {}'.format(cycle,tr_method_loss_lst[-1]))
    print('Epoch: {}  Train Role Loss:   {}'.format(cycle,tr_role_loss_lst[-1]))
    print('Epoch: {}  Train Relation loss:  {}'.format(cycle,tr_event_loss_lst[-1]))
    print('Epoch: {}  Train Average Loss: {}'.format(cycle,tr_avg_loss_lst[-1]))

    #Computing metrics and evaluating the model on validation data.
    #Predictions are saved on the last epoch, or when any early-stopping counter is one epoch away from patience.
    if cycle == training_args['num_train_epochs']-1 or early_stopping_count >= patience-1 or early_stopping_count_ent >= patience-1 or early_stopping_count_status >= patience-1 or early_stopping_count_method >= patience-1 or early_stopping_count_role >= patience-1:
      val_results=evaluate_entity_model(model,tst_dataloader,loss_fn,id_label_ent,id_label_role,id_label_status,id_label_method,id_label_event,True,'tst')
    else:
      val_results=evaluate_entity_model(model,tst_dataloader,loss_fn,id_label_ent,id_label_role,id_label_status,id_label_method,id_label_event)

    val_ent_loss_list.append(val_results['entity_evaluation'][0])
    val_status_loss_list.append(val_results['status_evaluation'][0])
    val_method_loss_list.append(val_results['method_evaluation'][0])
    val_role_loss_list.append(val_results['role_evaluation'][0])
    val_event_loss_list.append(val_results['event_evaluation'][0])
    val_avg_loss_list.append(val_results['avg_loss'])

    print('Epoch: {}  Val Entity Loss: {}'.format(cycle,val_ent_loss_list[-1]))
    print('Epoch: {}  Val Status Loss: {}'.format(cycle,val_status_loss_list[-1]))
    print('Epoch: {}  Val Method Loss: {}'.format(cycle,val_method_loss_list[-1]))
    print('Epoch: {}  Val Role Loss: {}'.format(cycle,val_role_loss_list[-1]))
    print('Epoch: {}  Val Event Loss: {}'.format(cycle,val_event_loss_list[-1]))
    print('Epoch: {}  Avg Val Loss: {}'.format(cycle, val_avg_loss_list[-1]))
    print('Epoch: {}  Metrics below: \n'.format(cycle))

    #Computing metrics and evaluating the model on training data.
    if cycle == training_args['num_train_epochs']-1 or early_stopping_count >= patience-1 or early_stopping_count_ent >= patience-1 or early_stopping_count_status >= patience-1 or early_stopping_count_method >= patience-1 or early_stopping_count_role >= patience-1:
      tr_results=evaluate_entity_model(model,tr_dataloader,loss_fn,id_label_ent,id_label_role,id_label_status,id_label_method,id_label_event,True,'tr')
    else:
      tr_results=evaluate_entity_model(model,tr_dataloader,loss_fn,id_label_ent,id_label_role,id_label_status,id_label_method,id_label_event)


    #Printing validation metrics followed by training metrics for each task.
    for title,key in [('Entity','entity_evaluation'),('Role','role_evaluation'),('Status','status_evaluation'),('Method','method_evaluation'),('Relation','event_evaluation')]:
      print('{}_metrics: \n'.format(title))
      print(val_results[key][1])
      print('\n')
      print(tr_results[key][1])
      print('\n')

    #Early stopping: for each monitored validation loss, count consecutive epochs without a new minimum; stop once any count reaches patience.
    if cycle==0:
      min_ent_val_loss=val_ent_loss_list[-1]
      min_status_val_loss=val_status_loss_list[-1]
      min_method_val_loss=val_method_loss_list[-1]
      min_role_val_loss=val_role_loss_list[-1]
      min_val_loss=val_results['avg_loss']
      early_stopping_count_ent=0
      early_stopping_count_status=0
      early_stopping_count_method=0
      early_stopping_count_role=0
    else:
      if val_results['avg_loss']<min_val_loss:
        min_val_loss=val_results['avg_loss']
        early_stopping_count=0
      else:
        early_stopping_count+=1
      if val_ent_loss_list[-1]<min_ent_val_loss:
        min_ent_val_loss=val_ent_loss_list[-1]
        early_stopping_count_ent=0
      else:
        early_stopping_count_ent+=1
      if val_status_loss_list[-1]<min_status_val_loss:
        min_status_val_loss=val_status_loss_list[-1]
        early_stopping_count_status=0
      else:
        early_stopping_count_status+=1
      if val_method_loss_list[-1]<min_method_val_loss:
        min_method_val_loss=val_method_loss_list[-1]
        early_stopping_count_method=0
      else:
        early_stopping_count_method+=1
      if val_role_loss_list[-1]<min_role_val_loss:
        min_role_val_loss=val_role_loss_list[-1]
        early_stopping_count_role=0
      else:
        early_stopping_count_role+=1

      if early_stopping_count>=patience or early_stopping_count_ent>=patience or early_stopping_count_status>=patience or early_stopping_count_method>=patience or early_stopping_count_role>=patience:
        print('Early stopping counter for entity model : {}'.format(early_stopping_count_ent))
        print('Early stopping counter for status model : {}'.format(early_stopping_count_status))
        print('Early stopping counter for method model : {}'.format(early_stopping_count_method))
        print('Early stopping counter for role model : {}'.format(early_stopping_count_role))
        print('Early stopping counter for overall model : {}'.format(early_stopping_count))
        print('Early stopping at epoch: {}'.format(cycle))
        break

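  #Saving the weights from the last epoch that was run (not necessarily the epoch with the lowest validation loss).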
  torch.save(model.state_dict(),training_args['output_dir']+training_args['run_name']+'.pth')
  return model, {'train_loss':{'tr_ent_loss_lst':tr_ent_loss_lst,'tr_event_loss_lst':tr_event_loss_lst,
                 'tr_method_loss_lst':tr_method_loss_lst,'tr_status_loss_lst':tr_status_loss_lst,'tr_role_loss_lst':tr_role_loss_lst,'tr_avg_loss_lst':tr_avg_loss_lst},
                 'val_loss':{'val_event_loss_list':val_event_loss_list,'val_ent_loss_list':val_ent_loss_list,'val_status_loss_list':val_status_loss_list,'val_method_loss_list':val_method_loss_list,'val_role_loss_list':val_role_loss_list,'val_avg_loss_list':val_avg_loss_list}}
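
`train_entity_bert_model` keeps five counters (the overall average plus entity, status, method and role) and stops as soon as any of them reaches `patience`; the relation loss has no counter of its own. Each counter follows the same rule: reset to zero when the validation loss reaches a new minimum, otherwise increment. A minimal, framework-free sketch of that per-loss rule is shown below; the `EarlyStopper` class is hypothetical. Unlike the notebook, it also remembers the best epoch, which is what one would need to save or restore the best weights instead of the final ones (for example by copying `model.state_dict()` whenever `improved` is true; that is an assumed extension, not something the notebook does).

```python
class EarlyStopper:
    """Patience-based early stopping on one validation loss (hypothetical helper)."""

    def __init__(self, patience):
        self.patience = patience
        self.best = float('inf')  # lowest validation loss seen so far
        self.best_epoch = None    # epoch at which that minimum occurred
        self.count = 0            # consecutive epochs without a new minimum

    def step(self, epoch, val_loss):
        """Record one epoch's validation loss and return (improved, should_stop)."""
        if val_loss < self.best:  # strict '<', as in the notebook
            self.best, self.best_epoch, self.count = val_loss, epoch, 0
            return True, False
        self.count += 1
        return False, self.count >= self.patience


# Illustrative losses (made up, not from this run) with patience=3.
history = [0.30, 0.25, 0.26, 0.27, 0.28]
stopper = EarlyStopper(patience=3)
for epoch, val_loss in enumerate(history):
    improved, stop = stopper.step(epoch, val_loss)
    # a best-weights snapshot would be taken here when `improved` is True
    if stop:
        print(f'stopping at epoch {epoch}; best epoch was {stopper.best_epoch}')
        break
```

To mirror the notebook, one instance per monitored loss would be stepped every epoch (all of them, without short-circuiting), and training would stop if any returns `should_stop`.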







In [None]:
model= EntityRelationClassifier(model_name=bert_model_name,num_freeze_layers=num_freeze_layers,num_status_labels=len(id_label_status),num_method_labels=len(id_label_method),num_role_labels=len(id_label_role),num_entity_labels=len(id_label_ent),num_relation_classes=len(id_label_event))


In [None]:
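# weight_decay is passed explicitly (assumes the weight_decay=0 hyperparameter is defined as a variable); torch.optim.AdamW defaults to 0.01.
optimizer = AdamW(model.parameters(), lr=learning_rate, eps=eps, weight_decay=weight_decay)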
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=0, num_training_steps=len(train_dataloader)*num_train_epochs)
training_args={
    'output_dir':save_model_path,
    'num_train_epochs':num_train_epochs,
    'optimizer':optimizer,
    'scheduler':scheduler,
    'patience':patience,
    'run_name':model_type+'_'+str(ver)

}
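
The scheduler is created with `num_warmup_steps=0` and `scheduler.step()` is called once per batch, so the learning rate should fall linearly from `learning_rate` to zero over `len(train_dataloader)*num_train_epochs` steps. The pure-Python sketch below mirrors the multiplier that `transformers.get_linear_schedule_with_warmup` is documented to apply (linear warm-up, then linear decay to 0). The name `lr_multiplier` and the `steps_per_epoch` value are illustrative assumptions, not part of the notebook.

```python
def lr_multiplier(step, num_warmup_steps, num_training_steps):
    """Factor applied to the base learning rate after `step` scheduler steps."""
    if step < num_warmup_steps:
        return step / max(1, num_warmup_steps)  # linear warm-up from 0 to 1
    # linear decay from 1 to 0 over the remaining steps, clamped at 0
    return max(0.0, (num_training_steps - step) / max(1, num_training_steps - num_warmup_steps))


base_lr = 7e-5          # learning_rate from the hyperparameter list
steps_per_epoch = 100   # illustrative; the notebook uses len(train_dataloader)
epochs = 15             # num_train_epochs from the hyperparameter list
total_steps = steps_per_epoch * epochs
for epoch in (0, 5, 10, 14):
    step = epoch * steps_per_epoch
    print(f'epoch {epoch:2d}: lr = {base_lr * lr_multiplier(step, 0, total_steps):.2e}')
```

Because the schedule is sized for all `num_train_epochs` epochs, a run that stops early (the progress bar at the end of the log stops at 5/15) ends while the learning rate is still well above zero.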



In [None]:
trained_model,loss=train_entity_bert_model(model,train_dataloader,test_dataloader,id_label_ent,id_label_role,id_label_status,id_label_method,id_label_event,bert_model_name,training_args)

Epoch:   0%|          | 0/15 [00:00<?, ?it/s]

Epoch: 0  Train Entity Loss: 0.07764063424923841
Epoch: 0  Train Status Loss: 0.07206000541062917
Epoch: 0  Train Method Loss: 0.02911323495209217
Epoch: 0  Train Role Loss:   0.10538096572546397
Epoch: 0  Train Relation loss:  0.4622990962337045
Epoch: 0  Train Average Loss: 0.14929878843181274
Epoch: 0  Val Entity Loss: 0.175241157412529
Epoch: 0  Val Status Loss: 0.1376015692949295
Epoch: 0  Val Method Loss: 0.06696680933237076
Epoch: 0  Val Role Loss: 0.31744521856307983
Epoch: 0  Val Event Loss: 0.5046505331993103
Epoch: 0  Avg Val Loss: 0.17431369423866272
Epoch: 0  Metrics below: 



Epoch:   7%|▋         | 1/15 [01:06<15:27, 66.28s/it]

Entity_metrics: 

{'Alcohol': {'precision': 0.8148148148148148, 'recall': 0.8, 'f1': 0.8073394495412846, 'number': 55}, 'Drug': {'precision': 0.9032258064516129, 'recall': 0.9333333333333333, 'f1': 0.9180327868852459, 'number': 30}, 'Family': {'precision': 0.5416666666666666, 'recall': 0.7027027027027027, 'f1': 0.611764705882353, 'number': 37}, 'Tobacco': {'precision': 0.8392857142857143, 'recall': 0.8867924528301887, 'f1': 0.8623853211009174, 'number': 53}, 'overall_precision': 0.7671957671957672, 'overall_recall': 0.8285714285714286, 'overall_f1': 0.7967032967032968, 'overall_accuracy': 0.9784445091756481}


{'Alcohol': {'precision': 0.8398058252427184, 'recall': 0.8693467336683417, 'f1': 0.8543209876543211, 'number': 199}, 'Drug': {'precision': 0.84375, 'recall': 0.8709677419354839, 'f1': 0.8571428571428571, 'number': 124}, 'Family': {'precision': 0.83125, 'recall': 0.910958904109589, 'f1': 0.8692810457516339, 'number': 146}, 'Tobacco': {'precision': 0.9026548672566371, 'recall': 0.
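
The epoch-0 numbers in the log are consistent with how the task losses are averaged. Up to floating-point rounding, the logged train average equals the mean of all five task losses, whereas the logged validation average equals the mean of the entity, status, method and role losses only. In this run, therefore, the validation average left out the relation loss, so the two averages are not directly comparable, and the relation loss did not feed the overall early-stopping counter. The check below only redoes that arithmetic with values copied from the log; nothing is recomputed from the model.

```python
# Epoch-0 losses copied from the training log
train_losses = {'entity': 0.07764063424923841, 'status': 0.07206000541062917,
                'method': 0.02911323495209217, 'role': 0.10538096572546397,
                'relation': 0.4622990962337045}
val_losses = {'entity': 0.175241157412529, 'status': 0.1376015692949295,
              'method': 0.06696680933237076, 'role': 0.31744521856307983,
              'relation': 0.5046505331993103}
logged_train_avg = 0.14929878843181274
logged_val_avg = 0.17431369423866272

train_mean_all = sum(train_losses.values()) / 5                                  # all five tasks
val_mean_tokens = sum(v for k, v in val_losses.items() if k != 'relation') / 4   # relation left out
val_mean_all = sum(val_losses.values()) / 5                                      # all five tasks

print(f'train: mean of 5 = {train_mean_all:.6f} vs logged {logged_train_avg:.6f}')
print(f'val:   mean of 4 = {val_mean_tokens:.6f}, mean of 5 = {val_mean_all:.6f} vs logged {logged_val_avg:.6f}')
```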

Epoch:  33%|███▎      | 5/15 [06:39<13:18, 79.86s/it]



