# Entity, Relations prediction.
**Description**


1.   Loads the NER and relation models.
2.   Given a sentence, predicts entities, roles, statuses and methods.
3.  Identifies and extracts spans from the above category predictions: a span is a run of tokens with the same label, tolerating an internal gap of at most one unlabeled token.
4.  Generates span pairs from the predicted spans, pairing each entity with every Role, Status and Method span.
5.  Inserts span markers into the text according to the character positions of each span pair. Entity spans are marked with [SPAN1_START] and [SPAN1_END]; Role, Status and Method spans are marked with [SPAN2_START] and [SPAN2_END].
6.  These marked span pairs serve as inputs to the relation model, which then predicts whether a relation holds between them.

The prediction output is a data frame containing the text with span markers, the entity, the role/status/method, and the relation prediction.



In [None]:
import json
import pandas as pd

from transformers import BertTokenizerFast, BertModel, AdamW, get_linear_schedule_with_warmup, DataCollatorForTokenClassification
from torch.utils.data import DataLoader, Dataset
from sklearn.metrics import classification_report
import torch
import torch.nn as nn
import torch.optim as optim
from torch.nn.utils.rnn import pad_sequence
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
import itertools

In [None]:
# Name and version of the saved NER checkpoint; the checkpoint path is built
# from these two values.
model_type = 'Flt_ent_role_model'
ver = 2

# Sample clinical note used to demonstrate the prediction pipeline.
txt = (
    'SOCIAL HISTORY:  The patient is retired. He is married. He had 4 children. '
    'He quite smoking 25 years ago after a 35-year history of smoking. '
    'He does not drink alcohol.'
)

In [None]:
# id -> label tables for each tagging head; a label's position in the list is
# its integer id.
id_label_status = dict(enumerate(['O', 'B-Status', 'I-Status']))
id_label_method = dict(enumerate(['O', 'B-Method', 'I-Method']))
id_label_role = dict(enumerate([
    'O',
    'B-Type', 'I-Type',
    'B-Amount', 'I-Amount',
    'B-Temporal', 'I-Temporal',
    'B-Frequency', 'I-Frequency',
    'B-QuitHistory', 'I-QuitHistory',
    'B-ExposureHistory', 'I-ExposureHistory',
    'B-Location', 'I-Location',
]))
id_label_event = dict(enumerate(['No Relation', 'Relation']))

# label -> id tables (inverse of the tables above).
label_id_status = dict(zip(id_label_status.values(), id_label_status.keys()))
label_id_method = dict(zip(id_label_method.values(), id_label_method.keys()))
label_id_role = dict(zip(id_label_role.values(), id_label_role.keys()))

# The entity table is declared label -> id; ids are not in list order here.
label_id_ent = {
    'B-Alcohol': 1,
    'B-Drug': 3,
    'B-Family': 5,
    'B-Tobacco': 7,
    'I-Alcohol': 2,
    'I-Drug': 4,
    'I-Family': 6,
    'I-Tobacco': 8,
    'O': 0,
}
id_label_ent = dict(zip(label_id_ent.values(), label_id_ent.keys()))
label_id_event = dict(zip(id_label_event.values(), id_label_event.keys()))
# Number of leading BERT encoder layers frozen inside the NER model.
num_freeze_layers = 6
# Maximum token length passed to the tokenizers (longer inputs are truncated).
max_len = 512
# Pretrained checkpoint used for the NER encoder and its tokenizer.
bert_model_name = 'emilyalsentzer/Bio_ClinicalBERT'
tokenizer = BertTokenizerFast.from_pretrained(bert_model_name)
# Locations of the fine-tuned NER weights and of the saved relation classifier.
ner_model_path = f'/content/drive/MyDrive/PHD_assessment_gmu/models/final_models/{model_type}_{ver}.pth'
rel_classifier_pth = '/content/drive/MyDrive/PHD_assessment_gmu/models/final_models/Indepent_relation_classifier_v6/'

# Model Loading

## Ner Model

In [None]:
class EntityBertModel(nn.Module):
    """BERT encoder shared by four independent token-classification heads.

    The heads predict, for every token, a status label, a method label, a role
    label and an entity label. Each head is a single linear layer applied to
    the (dropout-regularised) encoder output.
    """

    def __init__(self, model_name, num_freeze_layers,num_status_labels,num_method_labels,num_role_labels,num_entity_labels, dropout=0.1):
        """Build the encoder and the four heads.

        Args:
            model_name: Name or path handed to ``BertModel.from_pretrained``.
            num_freeze_layers: How many leading encoder layers get
                ``requires_grad = False``.
            num_status_labels: Output size of the status head.
            num_method_labels: Output size of the method head.
            num_role_labels: Output size of the role head.
            num_entity_labels: Output size of the entity head.
            dropout: Dropout probability applied to the encoder output.
        """
        super().__init__()
        self.bertmodel = BertModel.from_pretrained(model_name)

        # Freeze the first `num_freeze_layers` encoder layers.
        frozen_layers = self.bertmodel.encoder.layer[:num_freeze_layers]
        for frozen_param in (p for layer in frozen_layers for p in layer.parameters()):
            frozen_param.requires_grad = False

        self.dropout = nn.Dropout(dropout)

        # All heads read the same hidden size; attribute names must stay as-is
        # because saved checkpoints are keyed by them.
        hidden_size = self.bertmodel.config.hidden_size
        self.status_classifier = nn.Linear(hidden_size, num_status_labels)
        self.method_classifier = nn.Linear(hidden_size, num_method_labels)
        self.role_classifier = nn.Linear(hidden_size, num_role_labels)
        self.entity_classifier = nn.Linear(hidden_size, num_entity_labels)

    def forward(self, input_ids, attention_mask):
        """Return ``(status_logits, method_logits, role_logits, entity_logits)``.

        Each element is the corresponding head applied to the first element of
        the encoder output (presumably the per-token hidden states) after
        dropout.
        """
        encoded = self.bertmodel(input_ids=input_ids, attention_mask=attention_mask)
        token_states = self.dropout(encoded[0])
        heads = (
            self.status_classifier,
            self.method_classifier,
            self.role_classifier,
            self.entity_classifier,
        )
        return tuple(head(token_states) for head in heads)


In [None]:
# Build the multi-head tagger; every head is sized from its id -> label table.
ner_model = EntityBertModel(
    model_name=bert_model_name,
    num_freeze_layers=num_freeze_layers,
    num_status_labels=len(id_label_status),
    num_method_labels=len(id_label_method),
    num_role_labels=len(id_label_role),
    num_entity_labels=len(id_label_ent),
)

pytorch_model.bin:   0%|          | 0.00/436M [00:00<?, ?B/s]

In [None]:
# Put the model on the device the inference inputs are moved to; without this
# the model stays on the CPU and a CUDA runtime fails with a device mismatch.
ner_model.to(device)
# Restore the fine-tuned weights. Kept as the last expression so the cell still
# shows the key-matching report returned by load_state_dict.
ner_model.load_state_dict(torch.load(ner_model_path, map_location=device))

<All keys matched successfully>

## Relation Model

### Extracting spans and forming span pairs and predicting relations

In [None]:
def extract_spans_from_labels(labels, max_gap=1):
    """Group token labels into spans, tolerating short runs of 'O' inside a span.

    The ``B-``/``I-`` prefix is ignored: only the category after the prefix
    matters. A labelled token joins the current span when it has the same
    category and at most ``max_gap`` 'O' tokens separate it from the previous
    labelled token; otherwise it starts a new span.

    Args:
        labels: Sequence of label strings such as ``'B-Tobacco'``, ``'I-Type'``
            or ``'O'``.
        max_gap: Maximum number of consecutive 'O' tokens allowed inside a
            span. The default of 1 reproduces the previous hard-coded rule.

    Returns:
        List of ``(category, token_indices)`` tuples in order of appearance.
        ``token_indices`` holds only the labelled tokens; the 'O' tokens that
        were bridged are not included.
    """
    spans = []
    current_label = None
    current_span = []

    for i, label in enumerate(labels):
        # Ignore the BIO scheme and keep only the category.
        simplified_label = label[2:] if label.startswith(('B-', 'I-')) else label
        if simplified_label == 'O':
            # 'O' tokens never extend a span; whether the gap was too long is
            # decided when the next labelled token arrives.
            continue

        starts_new_span = (
            not current_span
            or simplified_label != current_label
            # Distance of k + 1 between labelled tokens means k 'O' in between.
            or i - current_span[-1] > max_gap + 1
        )
        if starts_new_span:
            if current_span:
                spans.append((current_label, current_span))
            current_label = simplified_label
            current_span = []
        current_span.append(i)

    # Flush the span still open at the end of the sequence.
    if current_span:
        spans.append((current_label, current_span))

    return spans
def insert_span_markers(text, spans):
    '''
    Wrap the entity span and the role span of `text` in marker tokens.

    spans: iterable of (start, end, category) character positions, where
           category is 'entity' or 'role'; other categories are ignored.
    Entity spans get [SPAN1_START] ... [SPAN1_END].
    Role spans get [SPAN2_START] ... [SPAN2_END].
    Outputs: the text with the markers inserted at the given positions.
    '''
    text_len = len(text)
    # One bucket of pending markers per insertion point (before each character,
    # plus one after the last character).
    pending = {pos: [] for pos in range(text_len + 1)}

    for start, end, category in spans:
        if category == 'entity':
            opening, closing = '[SPAN1_START]', '[SPAN1_END]'
        elif category == 'role':
            opening, closing = '[SPAN2_START]', '[SPAN2_END]'
        else:
            continue
        pending[start].append(opening)
        pending[end].append(closing)

    pieces = []
    for pos, char in enumerate(text):
        markers_here = pending[pos]
        if markers_here:
            # Markers in front of a character are padded with a space on both sides.
            pieces.append(' ' + ' '.join(markers_here) + ' ')
        pieces.append(char)

    trailing = pending[text_len]
    if trailing:
        # Markers after the final character only get a leading space.
        pieces.append(' ' + ' '.join(trailing))

    return ''.join(pieces)

def correct_span_positions(span_pos):
    '''
    Collapse each span's list of token indices into a (first, last) range.

    span_pos: list of (category, token_indices) entries.
    Outputs: list of (category, (first_index, last_index)); a one-token span
             yields the same index twice, and entries with no indices are
             dropped.
    '''
    return [
        (span[0], (span[1][0], span[1][-1]))
        for span in span_pos
        if len(span[1]) >= 1
    ]
def convert_id_label(prediction_data,id_label):
    '''
    Translate a sequence of predicted ids into labels by looking each one up
    in `id_label`. Outputs: list of labels, in the same order.
    '''
    return list(map(id_label.__getitem__, prediction_data))
def generate_prediction_spans(prediction_data_ent,prediction_data_role,prediction_data_status,prediction_data_method,id_label_ent,id_label_method,id_label_role,id_label_status):
    '''
    Turn per-token id predictions into spans and candidate relation pairs.

    For each of the four prediction sequences the ids are converted to labels,
    grouped into spans, and reduced to (category, (first_token, last_token)).
    Every entity span is then paired with every role, status and method span.

    Outputs: (entity spans, role spans, status spans, method spans,
              relation pair dict), where the dict maps each
              (entity_span, other_span) pair to the placeholder label
              'No Relation'.
    '''
    def _span_ranges(label_ids, id_label):
        # ids -> labels -> spans of token indices -> (first, last) ranges
        tags = convert_id_label(label_ids, id_label)
        return correct_span_positions(extract_spans_from_labels(tags))

    pred_spans_ent = _span_ranges(prediction_data_ent, id_label_ent)
    pred_spans_roles = _span_ranges(prediction_data_role, id_label_role)
    pred_spans_status = _span_ranges(prediction_data_status, id_label_status)
    pred_spans_method = _span_ranges(prediction_data_method, id_label_method)

    # Entity x {role, status, method}, in that order.
    candidate_pairs = itertools.chain(
        itertools.product(pred_spans_ent, pred_spans_roles),
        itertools.product(pred_spans_ent, pred_spans_status),
        itertools.product(pred_spans_ent, pred_spans_method),
    )
    relation_pair_labels_dict = dict.fromkeys(candidate_pairs, 'No Relation')

    return (
        pred_spans_ent,
        pred_spans_roles,
        pred_spans_status,
        pred_spans_method,
        relation_pair_labels_dict,
    )
def generate_relation_prediction_data(sentence,tokenizer_outputs,relation_pair_labels_dict):
    '''
    Build one relation-model input record per candidate span pair.

    Token ranges are mapped to character positions through
    tokenizer_outputs['offset_mapping'] (indexed by token position), and the
    sentence is re-written with span markers around both spans. The second span
    of every pair is marked as 'role' whatever its own category is.

    Outputs: list of dicts with keys 'text', 'Entity', 'Entity_pos', 'Role',
             'Role_pos' and 'label'.
    '''
    offsets = tokenizer_outputs['offset_mapping']

    def _char_span(span, category):
        # Start offset of the first token, end offset of the last token.
        first_token, last_token = span[1][0], span[1][1]
        return [offsets[first_token][0], offsets[last_token][1], category]

    records = []
    for pair, pair_label in relation_pair_labels_dict.items():
        entity_span, role_span = pair[0], pair[1]
        entity_pos = _char_span(entity_span, 'entity')
        role_pos = _char_span(role_span, 'role')
        marked_sentence = insert_span_markers(sentence, [entity_pos, role_pos])
        records.append({
            'text': marked_sentence,
            'Entity': entity_span[0],
            'Entity_pos': entity_pos,
            'Role': role_span[0],
            'Role_pos': role_pos,
            'label': pair_label,
        })
    return records

def generate_relation_data(text,inputs,prediction_data_ent,prediction_data_role,prediction_data_status,prediction_data_method,id_label_ent,id_label_method,id_label_role,id_label_status,relation_labels=None):
    '''
    Generates the relation data input list for one sentence from its predictions.

    inputs: tokenizer output for `text`, holding 'input_ids' and
            'offset_mapping'. Both the flat form (one list per key) and the
            batched form produced with return_tensors (a leading batch
            dimension, possibly tensors) are accepted; for a batch only the
            first sequence is used, since the predictions describe one sentence.
    prediction_data_*: per-token predicted ids for that sentence.
    relation_labels: unused; kept for interface compatibility.

    Outputs: Relation data input list, or None when the number of tokens does
             not match the number of entity predictions.
    '''
    def _single_sequence(values, looks_batched):
        # Tensors/arrays become plain Python lists so that offsets are ints
        # (usable as indices and dict keys downstream).
        values = values.tolist() if hasattr(values, 'tolist') else list(values)
        if values and looks_batched(values[0]):
            values = values[0]
        return values

    # Count tokens, not batch entries: with a leading batch dimension
    # len(inputs['input_ids']) would be the batch size.
    input_ids = _single_sequence(
        inputs['input_ids'],
        lambda first: isinstance(first, (list, tuple)),
    )
    if len(input_ids) != len(prediction_data_ent):
        return None

    # A flat offset mapping is a list of (start, end) pairs, so it is batched
    # only when the first element itself contains pairs.
    offsets = _single_sequence(
        inputs['offset_mapping'],
        lambda first: len(first) > 0 and isinstance(first[0], (list, tuple)),
    )

    *_, relation_pair_labels_dict = generate_prediction_spans(prediction_data_ent,prediction_data_role,prediction_data_status,prediction_data_method,id_label_ent,id_label_method,id_label_role,id_label_status)
    return generate_relation_prediction_data(text, {'offset_mapping': offsets}, relation_pair_labels_dict)
def predict_relation(text):
    '''
    Predicts the relation label for a text that already contains span markers.

    The text is tokenized with `tokenizer2` (padded/truncated to `max_len`),
    run through `rel_model` in eval mode without gradients, and the arg-max
    class id is mapped through `id_label_event`.

    Outputs: Relation label for the given text
    '''
    rel_model.eval()
    inputs = tokenizer2(text, add_special_tokens=True, padding='max_length', max_length=max_len, truncation=True, return_tensors='pt')
    # Send the inputs to wherever the model's weights live, so inference works
    # whether or not the model was moved to an accelerator.
    model_device = next(rel_model.parameters()).device
    with torch.no_grad():
        outputs = rel_model(**{k: v.to(model_device) for k, v in inputs.items()})
    predicted_id = torch.argmax(outputs.logits, dim=-1).item()
    return id_label_event[predicted_id]


In [None]:
from transformers import BertForSequenceClassification
# Tokenizer saved alongside the relation classifier (presumably the one whose
# vocabulary was extended with the span-marker tokens; confirm).
tokenizer2 = BertTokenizerFast.from_pretrained(rel_classifier_pth)
# Load the relation classifier and place it on the device that inference
# inputs are moved to; left on the CPU it fails on a CUDA runtime.
rel_model = BertForSequenceClassification.from_pretrained(rel_classifier_pth).to(device)

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


# Predicting entity relation

In [None]:
def predict_entity_relation(txt):
    '''
    Predicts entities, roles, status and method for `txt`, then predicts for
    every (entity, role/status/method) span pair whether a relation holds.

    Outputs: DataFrame with one row per span pair and the columns 'text',
             'Entity', 'Entity_pos', 'Role', 'Role_pos', 'label' and
             'predict_relation'. The frame is empty (same columns) when no
             span pair is found or the relation data could not be built.
    '''
    encoding = tokenizer(txt, max_length=max_len, truncation=True, return_offsets_mapping=True, return_tensors='pt')

    # Send the inputs to wherever the model's weights live.
    model_device = next(ner_model.parameters()).device
    model_input = {
        'input_ids': encoding['input_ids'].to(model_device),
        'attention_mask': encoding['attention_mask'].to(model_device),
    }

    ner_model.eval()
    with torch.no_grad():
        status_logits, method_logits, role_logits, entity_logits = ner_model(**model_input)

    def _token_label_ids(logits):
        # Softmax is monotonic, so arg-max over the raw logits selects the same
        # class; [0] drops the batch dimension of the single sentence.
        return torch.argmax(logits, dim=-1)[0].tolist()

    # The span helpers index tokens directly, so hand them the single sentence
    # as flat Python lists instead of batched tensors.
    sentence_inputs = {
        'input_ids': encoding['input_ids'][0].tolist(),
        'offset_mapping': encoding['offset_mapping'][0].tolist(),
    }
    rel_dt_lst = generate_relation_data(
        txt,
        sentence_inputs,
        _token_label_ids(entity_logits),
        _token_label_ids(role_logits),
        _token_label_ids(status_logits),
        _token_label_ids(method_logits),
        id_label_ent,
        id_label_method,
        id_label_role,
        id_label_status,
    )

    # Fixing the columns keeps the 'text' lookup valid when there are no rows.
    columns = ['text', 'Entity', 'Entity_pos', 'Role', 'Role_pos', 'label']
    rel_df = pd.DataFrame(rel_dt_lst or [], columns=columns)
    rel_df['predict_relation'] = rel_df['text'].apply(predict_relation)
    return rel_df


In [None]:
# Run the full pipeline on the sample note; as the last expression of the cell,
# the returned DataFrame is what the notebook displays.
predict_entity_relation(txt)