## **Flat NER model for roles and entities**
### **Description:**

*   Performs token-level named entity recognition (NER) for entities and roles.

*   Entities and roles excluded because there was too little training data:
  1.  discarded_enities=['EnvironmentalExposure','SexualHistory','InfectiousDiseases','PhysicalActivity','Residence']
  2.  discarded_roles=['LivingStatus','Other','MedicalCondition','Extent','History']

*   Labels follow the BIO scheme.
*   Labels are generated by mapping the annotation character offsets onto the BERT token offsets.
*   Some documents contain overlapping spans: among entities, among roles, and between roles and entities.
*   Overlap between roles and entities is handled by classifying them with separate heads.

*   Analysis showed that separating Status and Method from the remaining roles avoids overlap among roles, so these two also get their own classification heads.

*   Overlap among entities is handled by attaching multiple labels to a token, joined with the separator '900'. During training the labels are updated dynamically: if the model's prediction is one of the token's candidate labels, that label becomes the target; otherwise the first candidate is used.
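
The offset-based BIO labelling described above can be illustrated with a small, self-contained sketch. It is a simplification rather than the notebook's exact logic: it marks every token whose character range overlaps a span, and the function name `bio_labels`, the sample sentence and the offsets are assumptions made for the example.

```python
def bio_labels(spans, token_offsets):
    """Assign BIO tags to tokens from character-level spans.

    spans: list of (start_char, end_char_exclusive, category)
    token_offsets: list of (start_char, end_char_exclusive) per token;
                   special tokens are assumed to carry the offset (0, 0).
    """
    labels = ["O"] * len(token_offsets)
    for start, end, category in spans:
        inside = False
        for i, (tok_start, tok_end) in enumerate(token_offsets):
            if tok_start == tok_end:      # special token such as [CLS] or [SEP]
                continue
            if tok_start < end and tok_end > start:   # token overlaps the span
                labels[i] = ("I-" if inside else "B-") + category
                inside = True
    return labels


# "smokes two packs daily", with offsets as a fast tokenizer might report them
offsets = [(0, 0), (0, 6), (7, 10), (11, 16), (17, 22), (0, 0)]
spans = [(7, 16, "Amount"), (17, 22, "Frequency")]
labels = bio_labels(spans, offsets)
print(labels)
# ['O', 'O', 'B-Amount', 'I-Amount', 'B-Frequency', 'O']
```

The first token of a span receives the `B-` tag and every following token inside the span receives `I-`; tokens outside any span stay `O`.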
### **Experimentations:**
*  Build a baseline model and find evidence on whether each category has enough data to be trained.
*  Are the BERT embeddings tuned better when entities and roles are trained together, so that the tasks support each other?
*  Does dynamic labelling help the model during training, or does it introduce additional complications?

*  Does the NER model perform better when entities and roles are trained together?

**Run the notebook:**
* Set `model_type`, `ver` and `project_directory` in the Input Section cell, then run all cells.

### **Model:**
*   ***Tokenizer:*** BertTokenizerFast
*   ***Pre-trained BERT model:*** 'emilyalsentzer/Bio_ClinicalBERT'

*   ***Hyperparameters:***
  * eps=1e-8
  * learning_rate=7e-5
  * weight_decay=0
  * num_train_epochs=15
  * patience=3
  * batch_size=16
  * max_len_token=512
*   A pre-trained BERT model is fine-tuned with its first 6 encoder layers frozen and four token-classification heads, one each for status, methods, entities and roles.

*   Because the tasks are trained simultaneously, the per-head losses are averaged into a single loss that is backpropagated.

*   NER metrics are computed with seqeval (span-level sequence evaluation).

*   The model is saved to project_directory+'models/final_models/'+str(ver)+'_'+model_type+'.pth'.
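
Written out, the loss aggregation used in the training loop is the unweighted mean of the four per-head cross-entropy losses. The symbol names below are chosen for this note; each term is averaged over the tokens whose label is not the padding value -100.

$$
\mathcal{L} = \frac{1}{4}\left(\mathcal{L}_{\text{entity}} + \mathcal{L}_{\text{role}} + \mathcal{L}_{\text{status}} + \mathcal{L}_{\text{method}}\right)
$$

Because the mean is unweighted, every head contributes equally to the gradient regardless of how many labelled spans it has.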
### **Inputs:**


Prebuilt datasets are loaded from the following paths:
*   train_data_set path: 'PHD_assessment_gmu/data/train_dataset.pth'
*   test_data_set path: 'PHD_assessment_gmu/data/test_dataset.pth'
### **Metrics:**

**Entity_metric:**

| Entity | Precision | Recall | F1 | Number |
|---|---|---|---|---|
| Alcohol | 0.7414 | 0.7818 | 0.7611 | 55 |
| Drug | 0.8750 | 0.9333 | 0.9032 | 30 |
| Family | 0.4583 | 0.5500 | 0.5000 | 20 |
| LivingSituation | 0.3103 | 0.4737 | 0.3750 | 19 |
| MaritalStatus | 0.4211 | 0.5333 | 0.4706 | 30 |
| Occupation | 0.1591 | 0.2414 | 0.1918 | 29 |
| Tobacco | 0.7833 | 0.8868 | 0.8319 | 53 |
| **Overall** | 0.5649 | 0.6822 | 0.6180 | |

Overall accuracy: 0.9348 (values rounded to four decimals)

**Role_metrics:**

| Role | Precision | Recall | F1 | Number |
|---|---|---|---|---|
| Amount | 0.7377 | 0.8182 | 0.7759 | 55 |
| ExposureHistory | 0.5000 | 0.6000 | 0.5455 | 10 |
| Frequency | 0.7857 | 0.7857 | 0.7857 | 28 |
| Location | 0.2414 | 0.4375 | 0.3111 | 16 |
| QuitHistory | 0.4667 | 0.4667 | 0.4667 | 15 |
| Temporal | 0.4286 | 0.5455 | 0.4800 | 11 |
| Type | 0.5723 | 0.7444 | 0.6471 | 133 |
| **Overall** | 0.5783 | 0.7164 | 0.6400 | |

Overall accuracy: 0.9231 (values rounded to four decimals)

**Status_metrics:**

| Role | Precision | Recall | F1 | Number |
|---|---|---|---|---|
| Status | 0.6126 | 0.7010 | 0.6538 | 194 |
| **Overall** | 0.6126 | 0.7010 | 0.6538 | |

Overall accuracy: 0.9554 (values rounded to four decimals)

**Method_metrics:**

| Role | Precision | Recall | F1 | Number |
|---|---|---|---|---|
| Method | 0.3333 | 0.4242 | 0.3733 | 33 |
| **Overall** | 0.3333 | 0.4242 | 0.3733 | |

Overall accuracy: 0.9796 (values rounded to four decimals)












In [None]:
!pip install evaluate

In [None]:
import pandas as pd
import evaluate
from transformers import BertTokenizerFast, BertModel, AdamW, get_linear_schedule_with_warmup, DataCollatorForTokenClassification
from torch.utils.data import DataLoader, Dataset
from sklearn.metrics import classification_report
import torch
import torch.nn as nn
import torch.optim as optim
from torch.nn.utils.rnn import pad_sequence

In [None]:
from tqdm import tqdm, trange

In [None]:
!pip install seqeval

In [None]:
#Input Section
model_type='Flt_ent_role_model'
ver=1
project_directory='/content/drive/MyDrive/PHD_assessment_gmu/'

metric=evaluate.load("seqeval")
device= torch.device("cuda" if torch.cuda.is_available() else "cpu")
bert_model_name='emilyalsentzer/Bio_ClinicalBERT'
tokenizer = BertTokenizerFast.from_pretrained(bert_model_name)
num_freeze_layers=6

Hyperparameters

In [None]:
eps=1e-8
learning_rate=7e-5
weight_decay=0
num_train_epochs=15
patience=3
batch_size=16
max_len=512

Dataset, Model paths and discarding classes





In [None]:
discarded_enities=['EnvironmentalExposure','SexualHistory','InfectiousDiseases','PhysicalActivity','Residence']
discarded_roles=['LivingStatus','Other','MedicalCondition','Extent','History']

train_dataset_path=project_directory+'data/'+'train_dataset.pth'
test_dataset_path=project_directory+'data/'+'test_dataset.pth'
save_model_path=project_directory+'models/final_models/'


Mapping the classes to ids.


In [None]:
id_label_status={0:'O',1:'B-Status',2:'I-Status'}
id_label_method={0:'O',1:'B-Method',2:'I-Method'}
id_label_role={0:'O',1:'B-Type',2:'I-Type',3:'B-Amount',4:'I-Amount',5:'B-Temporal',6:'I-Temporal',7:'B-Frequency',8:'I-Frequency',9:'B-QuitHistory',10:'I-QuitHistory',11:'B-ExposureHistory',12:'I-ExposureHistory',13:'B-Location',14:'I-Location'}
id_label_ent={0:'O',1:'B-Tobacco',2:'I-Tobacco',3:'B-Alcohol',4:'I-Alcohol',5:'B-Family',6:'I-Family',7:'B-Drug',8:'I-Drug',9:'B-Occupation',10:'I-Occupation',11:'B-MaritalStatus',12:'I-MaritalStatus',13:'B-LivingSituation',14:'I-LivingSituation'}
id_label_event={0:'No Relation',1:'Relation'}
label_id_status = {v: k for k, v in id_label_status.items()}
label_id_method = {v: k for k, v in id_label_method.items()}
label_id_role = {v: k for k, v in id_label_role.items()}
label_id_ent = {v: k for k, v in id_label_ent.items()}
label_id_event = {v: k for k, v in id_label_event.items()}

Generate Entity and Role labels for a doc





In [None]:

class GenerateLabel:

  @staticmethod
  def generate_enity_labels(entity_list, token_len, token_offsets):
    # Initialize a list to store the BIO labels for each token
    entity_labels = [label_id_ent['O']] * token_len
    for entity in entity_list:
        category = entity['entity_category']
        # Labelling the entities which are not under discarded_entities.
        if category not in discarded_enities:
          entity_start_pos = int(entity['entity_strt_pos'])
          entity_end_pos = int(entity['entity_end_pos'])-1

          # Find tokens that correspond to the entity's position
          entity_start_token = None
          entity_end_token = None

          for i, (start_offset, end_offset) in enumerate(token_offsets):
              if entity_start_token is None and start_offset >= entity_start_pos:
                  entity_start_token = i
              if end_offset > entity_end_pos:
                  entity_end_token = i
                  break

          # Assign BIO labels to the tokens, handling overlapping labels
          if entity_start_token is not None:
              if entity_labels[entity_start_token] == label_id_ent['O']:
                  entity_labels[entity_start_token] = label_id_ent['B-' + category]
              else:
                  if isinstance(entity_labels[entity_start_token], list):
                    entity_labels[entity_start_token].append(label_id_ent['B-' + category])
                  else:
                    #If multiple labels should be allocated to a token, then they are assigned to a list
                    entity_labels[entity_start_token] = [entity_labels[entity_start_token],label_id_ent['B-' + category]]


              if entity_end_token is not None:
                  for i in range(entity_start_token + 1, entity_end_token + 1):
                      if entity_labels[i] == label_id_ent['O']:
                          entity_labels[i] = label_id_ent['I-' + category]
                      else:
                        if isinstance(entity_labels[i],list):
                          entity_labels[i].append(label_id_ent['I-' + category])
                        else:
                          #If multiple labels should be allocated to a token, then they are assigned to a list
                          entity_labels[i] = [entity_labels[i],label_id_ent['I-' + category]]


    return entity_labels

  @staticmethod
  def generate_role_labels(role_list, token_len, token_offsets):
    # Initialize a list to store the BIO labels for each token
    role_labels = [label_id_role['O']] * token_len
    status_labels = [label_id_status['O']] * token_len
    method_labels = [label_id_method['O']] * token_len

    for role in role_list:
        category = role['entity_category']
        entity_start_pos = int(role['entity_strt_pos'])
        entity_end_pos = int(role['entity_end_pos'])-1
        if category in discarded_roles:
          continue
        # Find tokens that correspond to the role's position
        entity_start_token = None
        entity_end_token = None

        for i, (start_offset, end_offset) in enumerate(token_offsets):
            if entity_start_token is None and start_offset >= entity_start_pos:
                entity_start_token = i
            if end_offset > entity_end_pos:
                entity_end_token = i
                break

        # Assign BIO labels to the tokens
        if category == 'Status':
          if entity_start_token is not None:
            status_labels[entity_start_token] = label_id_status['B-' + category]
            if entity_end_token is not None:
                status_labels[entity_start_token + 1:entity_end_token + 1] = [label_id_status['I-' + category]] * (entity_end_token - entity_start_token)

        elif category == 'Method':
          if entity_start_token is not None:
            method_labels[entity_start_token] = label_id_method['B-' + category]
            if entity_end_token is not None:
                method_labels[entity_start_token + 1:entity_end_token + 1] = [label_id_method['I-' + category]] * (entity_end_token - entity_start_token)

        else:
          if entity_start_token is not None:
            role_labels[entity_start_token] = label_id_role['B-' + category]
            if entity_end_token is not None:
                role_labels[entity_start_token + 1:entity_end_token + 1] = [label_id_role['I-' + category]] * (entity_end_token - entity_start_token)


    return role_labels, status_labels, method_labels

  @staticmethod
  def generate_relation_labels( token_offsets, data):
    entity_token_indices = {}
    role_token_indices = {}

    entity_list = data.get('entity_list', [])
    role_list = data.get('role_list', [])
    events_list = data.get('events_list', [])

    # Create a dictionary to map entity IDs to their token indices
    for entity in entity_list:
        entity_id = entity['entity_id']
        entity_start_pos = int(entity['entity_strt_pos'])
        entity_end_pos = int(entity['entity_end_pos'])-1

        entity_start_token = None
        entity_end_token = None

        for i, (start_offset, end_offset) in enumerate(token_offsets):
            if entity_start_token is None and start_offset >= entity_start_pos:
                entity_start_token = i
            if end_offset > entity_end_pos:
                entity_end_token = i
                break

        if entity_start_token is not None:
            entity_token_indices[entity_id] = (entity_start_token, entity_end_token)

    # Create a dictionary to map event-related role IDs to their token indices
    for event in events_list:
        entity_id = event['entity_id']
        related_roles = event['Related_roles']

        entity_indices = entity_token_indices.get(entity_id, None)

        if entity_indices:
            role_indices = []

            for role_id in related_roles:
                role = next((role for role in role_list if role['role_id'] == role_id), None)

                if role:
                    role_start_pos = int(role['entity_strt_pos'])
                    role_end_pos = int(role['entity_end_pos'])-1

                    role_start_token = None
                    role_end_token = None

                    for i, (start_offset, end_offset) in enumerate(token_offsets):
                        if role_start_token is None and start_offset >= role_start_pos:
                            role_start_token = i
                        if end_offset > role_end_pos:
                            role_end_token = i
                            break

                    if role_start_token is not None:
                        role_indices.append((role_start_token, role_end_token))

            if role_indices:
                role_token_indices[(entity_indices[0],entity_indices[1])] = role_indices

    return role_token_indices

Padding function used by the data loaders so that all sequences in a batch have the same length.

In [None]:
def collate_fn_entity_role(batch):
  input_ids = [torch.tensor(x['input_ids']) for x in batch]
  attention_mask = [torch.tensor(x['attention_mask']) for x in batch]
  #Handling multiple labels for an entity token by concatenating them with the separator '900'
  for x in batch:
    for idx in range(0,len(x['entity_labels'])):
      if isinstance(x['entity_labels'][idx],list):
        multiple_labels=x['entity_labels'][idx]
        #Joining the label ids into a single integer, e.g. [1, 3] -> 19003
        x['entity_labels'][idx]=int('900'.join(str(lab) for lab in multiple_labels))

  entity_labels = [torch.tensor(x['entity_labels']) for x in batch]
  role_labels = [torch.tensor(x['role_labels']) for x in batch]
  status_labels = [torch.tensor(x['status_labels']) for x in batch]
  method_labels = [torch.tensor(x['method_labels']) for x in batch]

  #Performing padding with tokenizer pad value for inputs and -100 pad value for labels.
  input_ids = pad_sequence(input_ids, batch_first=True,padding_value=tokenizer.pad_token_id)
  attention_mask = pad_sequence(attention_mask, batch_first=True,padding_value=0)
  entity_labels = pad_sequence(entity_labels, batch_first=True,padding_value=-100)
  role_labels = pad_sequence(role_labels, batch_first=True,padding_value=-100)
  status_labels = pad_sequence(status_labels, batch_first=True,padding_value=-100)
  method_labels = pad_sequence(method_labels, batch_first=True,padding_value=-100)
  return {
    'input_ids':input_ids,
    'attention_mask':attention_mask,
    'entity_labels':entity_labels,
    'role_labels':role_labels,
    'status_labels':status_labels,
    'method_labels':method_labels
  }
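
As a rough illustration of what the padding step does to the label tensors, the sketch below pads plain Python lists in a batch-first layout. The helper `pad_batch` is a hypothetical stand-in written for this note, not a library function; the value -100 matches the `ignore_index` passed to the loss later in the notebook.

```python
def pad_batch(sequences, padding_value):
    """Right-pad lists to the length of the longest one (batch-first layout)."""
    longest = max(len(seq) for seq in sequences)
    return [seq + [padding_value] * (longest - len(seq)) for seq in sequences]


labels = [[0, 1, 2, 0], [0, 3]]
padded = pad_batch(labels, padding_value=-100)
print(padded)  # [[0, 1, 2, 0], [0, 3, -100, -100]]

# Positions holding -100 are skipped by the loss, so only real tokens count.
real_tokens = sum(1 for row in padded for lab in row if lab != -100)
print(real_tokens)  # 6
```

Using a label value that can never be a class id keeps padded positions out of both the loss and the metrics.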


Dataset class which assembles the inputs and generates the labels for each document.

In [None]:
class ERDataset(Dataset):
  def __init__(self, data, tokenizer, max_len):
    self.data = data
    self.tokenizer = tokenizer
    self.max_len = max_len
  def __len__(self):
    return len(self.data)
  def __getitem__(self, idx):
    item=self.data[idx]
    text = item['text']
    #Using the instance's max_len rather than the global variable
    inputs = self.tokenizer(text,max_length=self.max_len,truncation=True,return_offsets_mapping=True)
    tokens_len=len(inputs['input_ids'])
    offset_mapping_list=inputs['offset_mapping']
    #Generating labels
    entity_labels=GenerateLabel.generate_enity_labels(item['entity_list'],tokens_len,offset_mapping_list)
    role_labels, status_labels, method_labels=GenerateLabel.generate_role_labels(item['role_list'],tokens_len,offset_mapping_list)
    relation_labels=GenerateLabel.generate_relation_labels(offset_mapping_list,item)
    return {
      'input_ids':inputs['input_ids'],
      'attention_mask':inputs['attention_mask'],
      'entity_labels':entity_labels,
      'role_labels':role_labels,
      'status_labels':status_labels,
      'method_labels':method_labels,
      'relation_labels':relation_labels,
      'text':text,
      'file_name':item['file_name'],
      'offset_mapping':offset_mapping_list
    }

Loading Datasets and Dataloaders

In [None]:
train_dataset=torch.load(train_dataset_path)
test_dataset=torch.load(test_dataset_path)


In [None]:
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True,collate_fn=collate_fn_entity_role)
test_dataloader = DataLoader(test_dataset, batch_size=batch_size,collate_fn=collate_fn_entity_role)

Bert Model Architecture

In [None]:
class EntityBertModel(nn.Module):
  def __init__(self, model_name, num_freeze_layers,num_status_labels,num_method_labels,num_role_labels,num_entity_labels, dropout=0.1):
    super(EntityBertModel, self).__init__()
    self.bertmodel = BertModel.from_pretrained(model_name)
    #Freezing layers
    for layer in self.bertmodel.encoder.layer[:num_freeze_layers]:
      for param in layer.parameters():
          param.requires_grad = False
    self.dropout = nn.Dropout(dropout)
    self.status_classifier = nn.Linear(self.bertmodel.config.hidden_size, num_status_labels)
    self.method_classifier = nn.Linear(self.bertmodel.config.hidden_size, num_method_labels)
    self.role_classifier = nn.Linear(self.bertmodel.config.hidden_size, num_role_labels)
    self.entity_classifier = nn.Linear(self.bertmodel.config.hidden_size, num_entity_labels)
  def forward(self, input_ids, attention_mask):
    bert_output = self.bertmodel(input_ids=input_ids, attention_mask=attention_mask)
    sequence_output = self.dropout(bert_output[0])
    status_logits = self.status_classifier(sequence_output)
    method_logits = self.method_classifier(sequence_output)
    role_logits = self.role_classifier(sequence_output)
    entity_logits = self.entity_classifier(sequence_output)

    return status_logits, method_logits, role_logits, entity_logits

Initializing Model

In [None]:
model= EntityBertModel(model_name=bert_model_name,num_freeze_layers=num_freeze_layers,num_status_labels=len(id_label_status),num_method_labels=len(id_label_method),num_role_labels=len(id_label_role),num_entity_labels=len(id_label_ent))

Loading optimizer and schedulers

In [None]:
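#torch's AdamW is used so that the weight_decay hyperparameter defined above is actually applied
optimizer = optim.AdamW(model.parameters(), lr=learning_rate, eps=eps, weight_decay=weight_decay)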
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=0, num_training_steps=len(train_dataloader)*num_train_epochs)
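
For intuition, the sketch below reproduces the shape a linear schedule with warmup is generally understood to have: a linear ramp up during warmup, then a linear decay to zero. The function `linear_lr_factor` and the step count of 150 are assumptions for illustration; the real number of steps depends on the size of the training set.

```python
def linear_lr_factor(step, num_training_steps, num_warmup_steps=0):
    """Multiplier applied to the base learning rate at a given optimizer step."""
    if step < num_warmup_steps:
        return step / max(1, num_warmup_steps)
    remaining = num_training_steps - step
    return max(0.0, remaining / max(1, num_training_steps - num_warmup_steps))


base_lr = 7e-5
total_steps = 150  # hypothetical: 10 batches per epoch * 15 epochs
for step in (0, 75, 150):
    print(step, base_lr * linear_lr_factor(step, total_steps))
```

With `num_warmup_steps=0`, as in this notebook, the learning rate starts at its full value and falls linearly to zero by the last step. If early stopping ends training sooner, the rate never reaches zero.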

Training arguments

In [None]:
training_args={
    'output_dir':save_model_path,
    'num_train_epochs':num_train_epochs,
    'optimizer':optimizer,
    'scheduler':scheduler,
    'patience':patience,
    'run_name':str(ver)+'_'+model_type
}

Updates nested labels based on model prediction

In [None]:
def update_nested_labels(entity_logits,ground_truths):
  #extracts indices which consists of multiple labels.
  indices=torch.where(ground_truths>1000)
  index_pairs = list(zip(indices[0].tolist(), indices[1].tolist()))
  probabilities = torch.softmax(entity_logits, dim=-1)
  predictions = torch.argmax(probabilities, dim=-1)
  #updates labels based upon model prediction.
  for i,j in index_pairs:
    multi_label=str(ground_truths[i,j].item())
    parts_str = multi_label.split('900')
    parts_tensors = [int(part) for part in parts_str if part]
    pred_logit=predictions[i, j].item()
    if pred_logit in parts_tensors:
      ground_truths[i,j]=pred_logit
    else:
      ground_truths[i,j]=parts_tensors[0]
  return ground_truths
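
The '900' packing and the resolution rule can be tried out on plain integers. The sketch below mirrors the logic of the collate function and `update_nested_labels` without tensors; the helper names `encode_labels`, `decode_labels` and `resolve` are introduced only for this example.

```python
SEPARATOR = "900"


def encode_labels(label_ids):
    """Pack several label ids into one integer, e.g. [1, 3] -> 19003."""
    return int(SEPARATOR.join(str(i) for i in label_ids))


def decode_labels(packed):
    """Recover the candidate label ids from a packed integer."""
    return [int(part) for part in str(packed).split(SEPARATOR) if part]


def resolve(packed, predicted_id):
    """Keep the prediction if it is a valid candidate, else use the first one."""
    candidates = decode_labels(packed)
    return predicted_id if predicted_id in candidates else candidates[0]


packed = encode_labels([1, 3])   # B-Tobacco and B-Alcohol on the same token
print(packed)                    # 19003
print(decode_labels(packed))     # [1, 3]
print(resolve(packed, 3))        # 3
print(resolve(packed, 7))        # 1
```

The scheme depends on label ids that do not themselves contain the digits `900`, and on every packed value being larger than 1000 so that it can be told apart from an ordinary label id.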

Computing NER metrics

In [None]:
def compute_ner_metric(preds,labels,id_label):
    #Seqeval is used as NER metrics
    # Removing padded labels and predictions when computing the metrics.
    ground_truths=[[id_label[l] for l in lab if l!=-100] for lab in labels]
    prediction_labels=[[id_label[p] for p,l in zip(predict,lab) if l!=-100] for predict,lab in zip(preds,labels)]
    metric_res=metric.compute(predictions=prediction_labels,references=ground_truths)
    return metric_res
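
To make clear what a span-level score measures, here is a simplified sketch of entity-level precision, recall and F1 over BIO sequences. It is not seqeval's implementation and may differ from it on malformed tag sequences; the function names and the sample tags are assumptions for the example.

```python
def extract_spans(tags):
    """Return the set of (category, start, end) spans in a BIO tag sequence."""
    spans, start, category = set(), None, None
    for i, tag in enumerate(tags + ["O"]):   # sentinel closes a trailing span
        continues = tag.startswith("I-") and tag[2:] == category
        if start is not None and not continues:
            spans.add((category, start, i - 1))
            start, category = None, None
        if tag.startswith("B-") or (tag.startswith("I-") and start is None):
            start, category = i, tag[2:]
    return spans


def span_prf(references, predictions):
    """Micro-averaged precision, recall and F1 over exactly matching spans."""
    true_spans, pred_spans = set(), set()
    for doc_id, (ref, pred) in enumerate(zip(references, predictions)):
        true_spans |= {(doc_id,) + s for s in extract_spans(ref)}
        pred_spans |= {(doc_id,) + s for s in extract_spans(pred)}
    correct = len(true_spans & pred_spans)
    precision = correct / len(pred_spans) if pred_spans else 0.0
    recall = correct / len(true_spans) if true_spans else 0.0
    f1 = 2 * precision * recall / (precision + recall) if correct else 0.0
    return precision, recall, f1


refs = [["B-Drug", "I-Drug", "O", "B-Alcohol"]]
preds = [["B-Drug", "I-Drug", "O", "O"]]
precision, recall, f1 = span_prf(refs, preds)
print(precision, recall, round(f1, 4))  # 1.0 0.5 0.6667
```

A span counts as correct only when its category and both boundaries match, which is why span-level F1 can be far lower than the token-level accuracy reported alongside it.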


Evaluating the validation dataset

In [None]:

def evaluate_entity_model(model,val_loader,loss_fn,id_label_ent,id_label_role,id_label_status,id_label_method):
  model.eval()
  with torch.no_grad():

    #Initializing Validation losses for each of categories
    val_ent_loss=0
    val_status_loss=0
    val_method_loss=0
    val_role_loss=0
    val_average_loss=0

    #Initializing labels and prediction list, in-turn used for evaluating metrics.
    all_entity_labels=[]
    all_entity_predictions=[]
    all_role_labels=[]
    all_role_predictions=[]
    all_method_labels=[]
    all_method_predictions=[]
    all_status_labels=[]
    all_status_predictions=[]

    for step,batch in enumerate(val_loader):

      #Extracting inputs from the batch.
      inputs={'input_ids':batch['input_ids'].to(device),'attention_mask':batch['attention_mask'].to(device)}
      entity_labels=batch['entity_labels'].to(device)
      role_labels=batch['role_labels'].to(device)
      method_labels=batch['method_labels'].to(device)
      status_labels=batch['status_labels'].to(device)

      #Model computation
      status_logits, method_logits, role_logits, entity_logits=model(**inputs)

      #Updating nested labels for entity
      updated_entity_labels=update_nested_labels(entity_logits,entity_labels)

      #Loss computation
      entity_loss=loss_fn(entity_logits.view(-1,len(id_label_ent)),updated_entity_labels.view(-1))
      role_loss=loss_fn(role_logits.view(-1,len(id_label_role)),role_labels.view(-1))
      status_loss=loss_fn(status_logits.view(-1,len(id_label_status)),status_labels.view(-1))
      method_loss=loss_fn(method_logits.view(-1,len(id_label_method)),method_labels.view(-1))
      avg_loss_step=(entity_loss+role_loss+status_loss+method_loss)/4
      #Accumulating Python floats rather than tensors
      val_ent_loss+=entity_loss.item()
      val_role_loss += role_loss.item()
      val_method_loss += method_loss.item()
      val_status_loss += status_loss.item()
      val_average_loss += avg_loss_step.item()

      #Calculating probabilities and predictions
      entity_probabilities = torch.softmax(entity_logits, dim=-1)
      entity_predictions = torch.argmax(entity_probabilities, dim=-1)
      role_probabilities  = torch.softmax(role_logits,dim=-1)
      role_predictions = torch.argmax(role_probabilities, dim =-1)
      status_probabilities=torch.softmax(status_logits,dim=-1)
      status_predictions=torch.argmax(status_probabilities,dim=-1)
      method_probabilities=torch.softmax(method_logits,dim=-1)
      method_predictions=torch.argmax(method_probabilities,dim=-1)

      #Collecting all labels and predictions
      all_status_predictions.extend(status_predictions.tolist())
      all_status_labels.extend(batch['status_labels'].tolist())
      all_method_labels.extend(batch['method_labels'].tolist())
      all_method_predictions.extend(method_predictions.tolist())
      all_role_labels.extend(batch['role_labels'].tolist())
      all_role_predictions.extend(role_predictions.tolist())
      all_entity_labels.extend(updated_entity_labels.tolist())
      all_entity_predictions.extend(entity_predictions.tolist())

    #Computing NER metrics
    entity_metrics=compute_ner_metric(all_entity_predictions,all_entity_labels,id_label_ent)
    role_metrics=compute_ner_metric(all_role_predictions,all_role_labels,id_label_role)
    status_metrics=compute_ner_metric(all_status_predictions,all_status_labels,id_label_status)
    method_metrics=compute_ner_metric(all_method_predictions,all_method_labels,id_label_method)

  return {'avg_loss':val_average_loss/len(val_loader),'entity_evaluation':(val_ent_loss/len(val_loader),entity_metrics),'role_evaluation':(val_role_loss/len(val_loader),role_metrics),
            'method_evaluation':(val_method_loss/len(val_loader),method_metrics),'status_evaluation':(val_status_loss/len(val_loader),status_metrics)}




Train function

In [None]:
def train_entity_bert_model(model,tr_dataloader,tst_dataloader,id_label_ent,id_label_role,id_label_status,id_label_method,bert_model_name,training_args):
  if torch.cuda.is_available():
    model.cuda()
  train_cycles=trange(training_args['num_train_epochs'],desc='Epoch',disable=0)

  loss_fn = nn.CrossEntropyLoss(ignore_index=-100)
  #Intializing optimizers and schedulers
  optimizer=training_args['optimizer']
  scheduler=training_args['scheduler']
  patience=training_args['patience']

  #Intializing training and validation loss lists for accumulating around the epochs
  tr_ent_loss_lst=[]
  tr_status_loss_lst=[]
  tr_method_loss_lst=[]
  tr_role_loss_lst=[]
  tr_avg_loss_lst=[]

  val_ent_loss_list=[]
  val_status_loss_list=[]
  val_method_loss_list=[]
  val_role_loss_list=[]
  val_avg_loss_list=[]

  early_stopping_count=0
  for cycle in train_cycles:
    epoch_cycles=tqdm(tr_dataloader,desc='Iteration',disable=-1)
    model.train()
    #Intializing epoch loss
    tr_ent_loss=0
    tr_status_loss=0
    tr_method_loss=0
    tr_role_loss=0
    average_loss=0
    for step,batch in enumerate(epoch_cycles):
      #Model Computation
      optimizer.zero_grad()
      inputs={'input_ids':batch['input_ids'].to(device),'attention_mask':batch['attention_mask'].to(device)}
      entity_labels=batch['entity_labels'].to(device)

      status_logits, method_logits, role_logits, entity_logits=model(**inputs)

      role_labels=batch['role_labels'].to(device)
      method_labels=batch['method_labels'].to(device)
      status_labels=batch['status_labels'].to(device)

      #updates nested labels with the prediction
      updated_entity_labels=update_nested_labels(entity_logits,entity_labels)

      #Loss computation for a batch
      entity_loss=loss_fn(entity_logits.view(-1,len(id_label_ent)),updated_entity_labels.view(-1))
      role_loss=loss_fn(role_logits.view(-1,len(id_label_role)),role_labels.view(-1))
      status_loss=loss_fn(status_logits.view(-1,len(id_label_status)),status_labels.view(-1))
      method_loss=loss_fn(method_logits.view(-1,len(id_label_method)),method_labels.view(-1))
      avg_loss_step=(entity_loss+role_loss+status_loss+method_loss)/4

      avg_loss_step.backward()
      optimizer.step()
      scheduler.step()

      #Loss computation across an epoch
      tr_ent_loss+=entity_loss.item()
      tr_status_loss+=status_loss.item()
      tr_method_loss+=method_loss.item()
      tr_role_loss+=role_loss.item()
      average_loss+=avg_loss_step.item()

    #Accumulation losses for all epochs.
    tr_ent_loss_lst.append(tr_ent_loss/len(tr_dataloader))
    tr_status_loss_lst.append(tr_status_loss/len(tr_dataloader))
    tr_method_loss_lst.append(tr_method_loss/len(tr_dataloader))
    tr_role_loss_lst.append(tr_role_loss/len(tr_dataloader))
    tr_avg_loss_lst.append(average_loss/len(tr_dataloader))
    print('Epoch: {}  Train Entity Loss: {}'.format(cycle,tr_ent_loss_lst[-1]))
    print('Epoch: {}  Train Status Loss: {}'.format(cycle,tr_status_loss_lst[-1]))
    print('Epoch: {}  Train Method Loss: {}'.format(cycle,tr_method_loss_lst[-1]))
    print('Epoch: {}  Train Role Loss: {}'.format(cycle,tr_role_loss_lst[-1]))
    print('Epoch: {}  Train Average Loss: {}'.format(cycle,tr_avg_loss_lst[-1]))

    #Computing metrics and evaluating models.
    val_results=evaluate_entity_model(model,tst_dataloader,loss_fn,id_label_ent,id_label_role,id_label_status,id_label_method)

    #Accumulation losses for all epochs.
    val_ent_loss_list.append(val_results['entity_evaluation'][0])
    val_status_loss_list.append(val_results['status_evaluation'][0])
    val_method_loss_list.append(val_results['method_evaluation'][0])
    val_role_loss_list.append(val_results['role_evaluation'][0])
    val_avg_loss_list.append(val_results['avg_loss'])
    print('Epoch: {}  Val Entity Loss: {}'.format(cycle,val_ent_loss_list[-1]))
    print('Epoch: {}  Val Status Loss: {}'.format(cycle,val_status_loss_list[-1]))
    print('Epoch: {}  Val Method Loss: {}'.format(cycle,val_method_loss_list[-1]))
    print('Epoch: {}  Val Role Loss: {}'.format(cycle,val_role_loss_list[-1]))
    print('Epoch: {}  Avg Val Loss: {}'.format(cycle, val_avg_loss_list[-1]))
    print('Epoch: {}  Metrics below: \n'.format(cycle))

    #Computing metrics and evaluating models.
    tr_results=evaluate_entity_model(model,tr_dataloader,loss_fn,id_label_ent,id_label_role,id_label_status,id_label_method)

    print('Entity_metrics: \n')
    print(val_results['entity_evaluation'][1])
    print('\n')
    print(tr_results['entity_evaluation'][1])
    print('\n')
    print('Role_metrics: \n')
    print(val_results['role_evaluation'][1])
    print('\n')
    print(tr_results['role_evaluation'][1])
    print('\n')
    print('Status_metrics: \n')
    print(val_results['status_evaluation'][1])
    print('\n')
    print(tr_results['status_evaluation'][1])
    print('\n')
    print('Method_metrics: \n')
    print(val_results['method_evaluation'][1])
    print('\n')
    print(tr_results['method_evaluation'][1])
    print('\n')

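    #Early stopping: each head counts the consecutive epochs in which its validation loss failed to improve on its best value.
    #Training stops once any head's counter reaches the patience; the overall counter is tracked and reported but does not trigger the stop.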

    if cycle==0:
      min_ent_val_loss=val_ent_loss_list[-1]
      min_status_val_loss=val_status_loss_list[-1]
      min_method_val_loss=val_method_loss_list[-1]
      min_role_val_loss=val_role_loss_list[-1]
      min_val_loss=val_results['avg_loss']
      early_stopping_count_ent=0
      early_stopping_count_status=0
      early_stopping_count_method=0
      early_stopping_count_role=0
    else:
      if val_results['avg_loss']<min_val_loss:
        min_val_loss=val_results['avg_loss']
        early_stopping_count=0
      else:
        early_stopping_count+=1
      if val_ent_loss_list[-1]<min_ent_val_loss:
        min_ent_val_loss=val_ent_loss_list[-1]
        early_stopping_count_ent=0
      else:
        early_stopping_count_ent+=1
      if val_status_loss_list[-1]<min_status_val_loss:
        min_status_val_loss=val_status_loss_list[-1]
        early_stopping_count_status=0
      else:
        early_stopping_count_status+=1
      if val_method_loss_list[-1]<min_method_val_loss:
        min_method_val_loss=val_method_loss_list[-1]
        early_stopping_count_method=0
      else:
        early_stopping_count_method+=1
      if val_role_loss_list[-1]<min_role_val_loss:
        min_role_val_loss=val_role_loss_list[-1]
        early_stopping_count_role=0
      else:
        early_stopping_count_role+=1
      if early_stopping_count_ent>=patience or early_stopping_count_status>=patience or early_stopping_count_method>=patience or early_stopping_count_role>=patience:
        print('Early stopping counter for entity model : {}'.format(early_stopping_count_ent))
        print('Early stopping counter for status model : {}'.format(early_stopping_count_status))
        print('Early stopping counter for method model : {}'.format(early_stopping_count_method))
        print('Early stopping counter for role model : {}'.format(early_stopping_count_role))
        print('Early stopping counter for overall model : {}'.format(early_stopping_count))
        print('Early stopping at epoch: {}'.format(cycle))
        break
  # Save the weights from the last completed epoch (not necessarily the epoch with the lowest validation loss)
  torch.save(model.state_dict(),training_args['output_dir']+training_args['run_name']+'.pth')

  # Return the model together with the per-epoch loss histories of every head.
  return model, {
      'train_loss':{'tr_ent_loss_lst':tr_ent_loss_lst,'tr_method_loss_lst':tr_method_loss_lst,
                    'tr_status_loss_lst':tr_status_loss_lst,'tr_role_loss_lst':tr_role_loss_lst,
                    'tr_avg_loss_lst':tr_avg_loss_lst},
      'val_loss':{'val_ent_loss_list':val_ent_loss_list,'val_status_loss_list':val_status_loss_list,
                  'val_method_loss_list':val_method_loss_list,'val_role_loss_list':val_role_loss_list,
                  'val_avg_loss_list':val_avg_loss_list}}
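
The per-head early-stopping bookkeeping in the function above repeats the same pattern for each head. A minimal sketch of how it could be collected into one helper is shown below. The class name `EarlyStopper`, the head names, and the loss values are hypothetical and only illustrate the idea; this is not the code used in the notebook.

```python
class EarlyStopper:
    """Track the best validation loss per head and count epochs without improvement."""

    def __init__(self, heads, patience):
        self.patience = patience
        # Best loss starts at +inf so the first epoch always counts as an improvement.
        self.best = {head: float("inf") for head in heads}
        self.counts = {head: 0 for head in heads}

    def update(self, losses):
        """Record one epoch of validation losses; return True if training should stop."""
        for head, loss in losses.items():
            if loss < self.best[head]:
                self.best[head] = loss
                self.counts[head] = 0
            else:
                self.counts[head] += 1
        # Same rule as the training loop: stop as soon as any head exhausts its patience.
        return any(count >= self.patience for count in self.counts.values())


# Illustrative (made-up) validation losses for three epochs.
stopper = EarlyStopper(["entity", "status", "method", "role"], patience=2)
history = [
    {"entity": 0.95, "status": 0.35, "method": 0.20, "role": 0.84},
    {"entity": 0.80, "status": 0.30, "method": 0.21, "role": 0.70},
    {"entity": 0.70, "status": 0.28, "method": 0.22, "role": 0.65},
]
stop_flags = [stopper.update(epoch_losses) for epoch_losses in history]
print(stop_flags)
```

In this example the method head fails to improve for two consecutive epochs, so the third call signals a stop even though the other heads are still improving. That matches the `or` condition in the training loop.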







In [None]:
trained_model,loss=train_entity_bert_model(model,train_dataloader,test_dataloader,id_label_ent,id_label_role,id_label_status,id_label_method,bert_model_name,training_args)

Epoch:   0%|          | 0/15 [00:00<?, ?it/s]

Epoch: 0  Train Entity Loss: 1.2045435063979204
Epoch: 0  Train Status Loss: 0.472620180424522
Epoch: 0  Train Method Loss: 0.3584589782883139
Epoch: 0  Train Role Loss: 1.2009985411868376
Epoch: 0  Train Average Loss: 0.809155309901518


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Epoch: 0  Val Entity Loss: 0.9496631622314453
Epoch: 0  Val Status Loss: 0.3481992185115814
Epoch: 0  Val Method Loss: 0.20097143948078156
Epoch: 0  Val Role Loss: 0.8422054648399353
Epoch: 0  Avg Val Loss: 0.5852597951889038
Epoch: {}  Metrics below: 



Epoch:   7%|▋         | 1/15 [00:11<02:36, 11.17s/it]

Entity_metrics: 

{'Alcohol': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0, 'number': 55}, 'Drug': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0, 'number': 30}, 'Family': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0, 'number': 21}, 'LivingSituation': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0, 'number': 19}, 'MaritalStatus': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0, 'number': 30}, 'Occupation': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0, 'number': 29}, 'Tobacco': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0, 'number': 53}, 'overall_precision': 0.0, 'overall_recall': 0.0, 'overall_f1': 0.0, 'overall_accuracy': 0.8505680163122633}


{'Alcohol': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0, 'number': 199}, 'Drug': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0, 'number': 124}, 'Family': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0, 'number': 102}, 'LivingSituation': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0, 'number': 79}, 'MaritalStatus': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0, 'number

Epoch:  13%|█▎        | 2/15 [00:20<02:12, 10.18s/it]

Entity_metrics: 

{'Alcohol': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0, 'number': 55}, 'Drug': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0, 'number': 30}, 'Family': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0, 'number': 21}, 'LivingSituation': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0, 'number': 19}, 'MaritalStatus': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0, 'number': 30}, 'Occupation': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0, 'number': 29}, 'Tobacco': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0, 'number': 53}, 'overall_precision': 0.0, 'overall_recall': 0.0, 'overall_f1': 0.0, 'overall_accuracy': 0.8505680163122633}


{'Alcohol': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0, 'number': 199}, 'Drug': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0, 'number': 124}, 'Family': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0, 'number': 102}, 'LivingSituation': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0, 'number': 79}, 'MaritalStatus': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0, 'number

Epoch:  20%|██        | 3/15 [00:30<02:00, 10.06s/it]

Entity_metrics: 

{'Alcohol': {'precision': 0.6904761904761905, 'recall': 0.5272727272727272, 'f1': 0.5979381443298969, 'number': 55}, 'Drug': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0, 'number': 30}, 'Family': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0, 'number': 20}, 'LivingSituation': {'precision': 0.12244897959183673, 'recall': 0.3157894736842105, 'f1': 0.1764705882352941, 'number': 19}, 'MaritalStatus': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0, 'number': 30}, 'Occupation': {'precision': 0.1111111111111111, 'recall': 0.10344827586206896, 'f1': 0.10714285714285715, 'number': 29}, 'Tobacco': {'precision': 0.21641791044776118, 'recall': 0.5471698113207547, 'f1': 0.31016042780748665, 'number': 53}, 'overall_precision': 0.2648221343873518, 'overall_recall': 0.2838983050847458, 'overall_f1': 0.2740286298568507, 'overall_accuracy': 0.8884357704631518}


{'Alcohol': {'precision': 0.5879120879120879, 'recall': 0.5376884422110553, 'f1': 0.5616797900262468, 'number': 199}, 'Drug': {'pre

Epoch:  27%|██▋       | 4/15 [00:40<01:47,  9.81s/it]

Entity_metrics: 

{'Alcohol': {'precision': 0.5735294117647058, 'recall': 0.7090909090909091, 'f1': 0.6341463414634145, 'number': 55}, 'Drug': {'precision': 0.875, 'recall': 0.23333333333333334, 'f1': 0.3684210526315789, 'number': 30}, 'Family': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0, 'number': 20}, 'LivingSituation': {'precision': 0.4, 'recall': 0.7368421052631579, 'f1': 0.5185185185185185, 'number': 19}, 'MaritalStatus': {'precision': 0.43478260869565216, 'recall': 0.3333333333333333, 'f1': 0.3773584905660377, 'number': 30}, 'Occupation': {'precision': 0.391304347826087, 'recall': 0.3103448275862069, 'f1': 0.34615384615384615, 'number': 29}, 'Tobacco': {'precision': 0.4027777777777778, 'recall': 0.5471698113207547, 'f1': 0.46399999999999997, 'number': 53}, 'overall_precision': 0.4675324675324675, 'overall_recall': 0.4576271186440678, 'overall_f1': 0.462526766595289, 'overall_accuracy': 0.9079522283716865}


{'Alcohol': {'precision': 0.4781021897810219, 'recall': 0.65829145728643

Epoch:  33%|███▎      | 5/15 [00:49<01:37,  9.75s/it]

Entity_metrics: 

{'Alcohol': {'precision': 0.6610169491525424, 'recall': 0.7090909090909091, 'f1': 0.6842105263157895, 'number': 55}, 'Drug': {'precision': 0.6875, 'recall': 0.7333333333333333, 'f1': 0.7096774193548386, 'number': 30}, 'Family': {'precision': 0.6666666666666666, 'recall': 0.4, 'f1': 0.5, 'number': 20}, 'LivingSituation': {'precision': 0.34285714285714286, 'recall': 0.631578947368421, 'f1': 0.4444444444444445, 'number': 19}, 'MaritalStatus': {'precision': 0.5172413793103449, 'recall': 0.5, 'f1': 0.5084745762711865, 'number': 30}, 'Occupation': {'precision': 0.23255813953488372, 'recall': 0.3448275862068966, 'f1': 0.2777777777777778, 'number': 29}, 'Tobacco': {'precision': 0.7580645161290323, 'recall': 0.8867924528301887, 'f1': 0.8173913043478261, 'number': 53}, 'overall_precision': 0.5625, 'overall_recall': 0.6483050847457628, 'overall_f1': 0.6023622047244095, 'overall_accuracy': 0.925138362947859}


{'Alcohol': {'precision': 0.536, 'recall': 0.6733668341708543, 'f1': 0

Epoch:  40%|████      | 6/15 [00:59<01:28,  9.87s/it]

Entity_metrics: 

{'Alcohol': {'precision': 0.6212121212121212, 'recall': 0.7454545454545455, 'f1': 0.6776859504132231, 'number': 55}, 'Drug': {'precision': 0.6111111111111112, 'recall': 0.7333333333333333, 'f1': 0.6666666666666666, 'number': 30}, 'Family': {'precision': 0.5882352941176471, 'recall': 0.5, 'f1': 0.5405405405405405, 'number': 20}, 'LivingSituation': {'precision': 0.3870967741935484, 'recall': 0.631578947368421, 'f1': 0.48000000000000004, 'number': 19}, 'MaritalStatus': {'precision': 0.3333333333333333, 'recall': 0.5, 'f1': 0.4, 'number': 30}, 'Occupation': {'precision': 0.17391304347826086, 'recall': 0.27586206896551724, 'f1': 0.21333333333333332, 'number': 29}, 'Tobacco': {'precision': 0.7619047619047619, 'recall': 0.9056603773584906, 'f1': 0.8275862068965516, 'number': 53}, 'overall_precision': 0.5131578947368421, 'overall_recall': 0.6610169491525424, 'overall_f1': 0.5777777777777778, 'overall_accuracy': 0.9242644916982231}


{'Alcohol': {'precision': 0.509505703422053

Epoch:  47%|████▋     | 7/15 [01:10<01:21, 10.13s/it]

Entity_metrics: 

{'Alcohol': {'precision': 0.640625, 'recall': 0.7454545454545455, 'f1': 0.6890756302521008, 'number': 55}, 'Drug': {'precision': 0.65, 'recall': 0.8666666666666667, 'f1': 0.7428571428571429, 'number': 30}, 'Family': {'precision': 0.55, 'recall': 0.55, 'f1': 0.55, 'number': 20}, 'LivingSituation': {'precision': 0.39285714285714285, 'recall': 0.5789473684210527, 'f1': 0.46808510638297873, 'number': 19}, 'MaritalStatus': {'precision': 0.28846153846153844, 'recall': 0.5, 'f1': 0.36585365853658536, 'number': 30}, 'Occupation': {'precision': 0.21951219512195122, 'recall': 0.3103448275862069, 'f1': 0.2571428571428571, 'number': 29}, 'Tobacco': {'precision': 0.7741935483870968, 'recall': 0.9056603773584906, 'f1': 0.8347826086956522, 'number': 53}, 'overall_precision': 0.5244299674267101, 'overall_recall': 0.6822033898305084, 'overall_f1': 0.5930018416206262, 'overall_accuracy': 0.9268861054471308}


{'Alcohol': {'precision': 0.50187265917603, 'recall': 0.6733668341708543, 'f1

Epoch:  53%|█████▎    | 8/15 [01:21<01:12, 10.31s/it]

Entity_metrics: 

{'Alcohol': {'precision': 0.6557377049180327, 'recall': 0.7272727272727273, 'f1': 0.689655172413793, 'number': 55}, 'Drug': {'precision': 0.6756756756756757, 'recall': 0.8333333333333334, 'f1': 0.746268656716418, 'number': 30}, 'Family': {'precision': 0.5625, 'recall': 0.45, 'f1': 0.5, 'number': 20}, 'LivingSituation': {'precision': 0.26666666666666666, 'recall': 0.42105263157894735, 'f1': 0.326530612244898, 'number': 19}, 'MaritalStatus': {'precision': 0.36585365853658536, 'recall': 0.5, 'f1': 0.4225352112676056, 'number': 30}, 'Occupation': {'precision': 0.175, 'recall': 0.2413793103448276, 'f1': 0.2028985507246377, 'number': 29}, 'Tobacco': {'precision': 0.7704918032786885, 'recall': 0.8867924528301887, 'f1': 0.8245614035087719, 'number': 53}, 'overall_precision': 0.527972027972028, 'overall_recall': 0.6398305084745762, 'overall_f1': 0.578544061302682, 'overall_accuracy': 0.9277599766967667}


{'Alcohol': {'precision': 0.5375494071146245, 'recall': 0.68341708542713

Epoch:  60%|██████    | 9/15 [01:31<01:01, 10.32s/it]

Entity_metrics: 

{'Alcohol': {'precision': 0.6557377049180327, 'recall': 0.7272727272727273, 'f1': 0.689655172413793, 'number': 55}, 'Drug': {'precision': 0.7428571428571429, 'recall': 0.8666666666666667, 'f1': 0.8, 'number': 30}, 'Family': {'precision': 0.5263157894736842, 'recall': 0.5, 'f1': 0.5128205128205129, 'number': 20}, 'LivingSituation': {'precision': 0.28125, 'recall': 0.47368421052631576, 'f1': 0.35294117647058826, 'number': 19}, 'MaritalStatus': {'precision': 0.3783783783783784, 'recall': 0.4666666666666667, 'f1': 0.417910447761194, 'number': 30}, 'Occupation': {'precision': 0.17073170731707318, 'recall': 0.2413793103448276, 'f1': 0.20000000000000004, 'number': 29}, 'Tobacco': {'precision': 0.7704918032786885, 'recall': 0.8867924528301887, 'f1': 0.8245614035087719, 'number': 53}, 'overall_precision': 0.534965034965035, 'overall_recall': 0.6483050847457628, 'overall_f1': 0.5862068965517242, 'overall_accuracy': 0.9315467521118556}


{'Alcohol': {'precision': 0.5450980392156

Epoch:  67%|██████▋   | 10/15 [01:42<00:52, 10.43s/it]

Entity_metrics: 

{'Alcohol': {'precision': 0.6666666666666666, 'recall': 0.7272727272727273, 'f1': 0.6956521739130435, 'number': 55}, 'Drug': {'precision': 0.9032258064516129, 'recall': 0.9333333333333333, 'f1': 0.9180327868852459, 'number': 30}, 'Family': {'precision': 0.42857142857142855, 'recall': 0.45, 'f1': 0.4390243902439024, 'number': 20}, 'LivingSituation': {'precision': 0.21212121212121213, 'recall': 0.3684210526315789, 'f1': 0.2692307692307693, 'number': 19}, 'MaritalStatus': {'precision': 0.36585365853658536, 'recall': 0.5, 'f1': 0.4225352112676056, 'number': 30}, 'Occupation': {'precision': 0.15, 'recall': 0.20689655172413793, 'f1': 0.17391304347826086, 'number': 29}, 'Tobacco': {'precision': 0.7704918032786885, 'recall': 0.8867924528301887, 'f1': 0.8245614035087719, 'number': 53}, 'overall_precision': 0.5296167247386759, 'overall_recall': 0.6440677966101694, 'overall_f1': 0.5812619502868068, 'overall_accuracy': 0.9265948150305855}


{'Alcohol': {'precision': 0.60995850622

Epoch:  73%|███████▎  | 11/15 [01:52<00:42, 10.55s/it]

Entity_metrics: 

{'Alcohol': {'precision': 0.6461538461538462, 'recall': 0.7636363636363637, 'f1': 0.7000000000000001, 'number': 55}, 'Drug': {'precision': 0.875, 'recall': 0.9333333333333333, 'f1': 0.9032258064516129, 'number': 30}, 'Family': {'precision': 0.4782608695652174, 'recall': 0.55, 'f1': 0.5116279069767442, 'number': 20}, 'LivingSituation': {'precision': 0.2903225806451613, 'recall': 0.47368421052631576, 'f1': 0.36, 'number': 19}, 'MaritalStatus': {'precision': 0.2916666666666667, 'recall': 0.4666666666666667, 'f1': 0.35897435897435903, 'number': 30}, 'Occupation': {'precision': 0.13953488372093023, 'recall': 0.20689655172413793, 'f1': 0.16666666666666666, 'number': 29}, 'Tobacco': {'precision': 0.7704918032786885, 'recall': 0.8867924528301887, 'f1': 0.8245614035087719, 'number': 53}, 'overall_precision': 0.5181518151815182, 'overall_recall': 0.6652542372881356, 'overall_f1': 0.5825602968460112, 'overall_accuracy': 0.9254296533644043}


{'Alcohol': {'precision': 0.562015503

Epoch:  73%|███████▎  | 11/15 [02:02<00:44, 11.18s/it]

Entity_metrics: 

{'Alcohol': {'precision': 0.7413793103448276, 'recall': 0.7818181818181819, 'f1': 0.7610619469026548, 'number': 55}, 'Drug': {'precision': 0.875, 'recall': 0.9333333333333333, 'f1': 0.9032258064516129, 'number': 30}, 'Family': {'precision': 0.4583333333333333, 'recall': 0.55, 'f1': 0.5, 'number': 20}, 'LivingSituation': {'precision': 0.3103448275862069, 'recall': 0.47368421052631576, 'f1': 0.375, 'number': 19}, 'MaritalStatus': {'precision': 0.42105263157894735, 'recall': 0.5333333333333333, 'f1': 0.47058823529411764, 'number': 30}, 'Occupation': {'precision': 0.1590909090909091, 'recall': 0.2413793103448276, 'f1': 0.1917808219178082, 'number': 29}, 'Tobacco': {'precision': 0.7833333333333333, 'recall': 0.8867924528301887, 'f1': 0.8318584070796461, 'number': 53}, 'overall_precision': 0.5649122807017544, 'overall_recall': 0.6822033898305084, 'overall_f1': 0.6180422264875239, 'overall_accuracy': 0.9347509466938537}


{'Alcohol': {'precision': 0.6567796610169492, 'recall
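
The metric dictionaries printed each epoch are hard to compare by eye. A minimal sketch, assuming the layout shown in the log above (one nested dict per label plus flat `overall_*` keys), flattens them into rows sorted by F1. The helper name `per_label_f1` is hypothetical, and the sample values are a rounded subset of the last validation entity metrics in the log.

```python
def per_label_f1(metrics):
    """Return (label, f1, support) rows sorted by descending F1, skipping overall_* keys."""
    rows = [
        (label, scores["f1"], scores["number"])
        for label, scores in metrics.items()
        if isinstance(scores, dict)  # overall_* entries are plain floats
    ]
    return sorted(rows, key=lambda row: row[1], reverse=True)


# Rounded subset of the final validation entity metrics from the log.
final_val = {
    "Alcohol": {"precision": 0.741, "recall": 0.782, "f1": 0.761, "number": 55},
    "Drug": {"precision": 0.875, "recall": 0.933, "f1": 0.903, "number": 30},
    "Occupation": {"precision": 0.159, "recall": 0.241, "f1": 0.192, "number": 29},
    "overall_f1": 0.618,
}
rows = per_label_f1(final_val)
for label, f1, support in rows:
    print(f"{label:<12} f1={f1:.3f} n={support}")
```

Sorting by F1 puts the weakest labels, such as Occupation here, at the bottom of the list.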


