# Generating Train and test data.
## Description:


1.   Reads raw processed data and generates entity and roles list present in each document.
2.   Splits the data into train and test having a satisfiable stratified split for each role category.
3.  Validates the split is stratified in both entities and roles.






In [None]:
from torch.utils.data import Dataset, DataLoader
from transformers import BertTokenizerFast, BertModel, AdamW, get_linear_schedule_with_warmup
import json
import pandas as pd
from itertools import chain
from collections import Counter
import sys
import torch

In [None]:
sys.path.append('/content/drive/MyDrive/PHD_assessment_gmu/')

### Parameter Initialization

In [None]:
discarded_enities=['EnvironmentalExposure','SexualHistory','InfectiousDiseases','PhysicalActivity']
discarded_roles=['LivingStatus','Other','MedicalCondition','Extent','History']
priority_ord_list=['QuitHistory','ExposureHistory','Temporal','Location','Frequency','Method','Amount','Type','Status']
max_len=512
min_label_size=10
project_directory='/content/drive/MyDrive/PHD_assessment_gmu/'
raw_dataset_path=project_directory+'data/'+'SocialHistoryMTSamples.json'
train_dataset_path=project_directory+'data/'+'train_dataset.pth'
test_dataset_path=project_directory+'data/'+'test_dataset.pth'
tokenizer = BertTokenizerFast.from_pretrained('emilyalsentzer/Bio_ClinicalBERT')

In [None]:
id_label_status={0:'O',1:'B-Status',2:'I-Status'}
id_label_method={0:'O',1:'B-Method',2:'I-Method'}
id_label_role={0:'O',1:'B-Type',2:'I-Type',3:'B-Amount',4:'I-Amount',5:'B-Temporal',6:'I-Temporal',7:'B-Frequency',8:'I-Frequency',9:'B-QuitHistory',10:'I-QuitHistory',11:'B-ExposureHistory',12:'I-ExposureHistory',13:'B-Location',14:'I-Location'}
id_label_ent={0:'O',1:'B-Tobacco',2:'I-Tobacco',3:'B-Alcohol',4:'I-Alcohol',5:'B-Family',6:'I-Family',7:'B-Drug',8:'I-Drug',9:'B-Occupation',10:'I-Occupation',11:'B-MaritalStatus',12:'I-MaritalStatus',13:'B-LivingSituation',14:'I-LivingSituation',15:'B-Residence',16:'I-Residence'}
id_label_event={0:'No Relation',1:'Relation'}

In [None]:
label_id_status = {v: k for k, v in id_label_status.items()}
label_id_method = {v: k for k, v in id_label_method.items()}
label_id_role = {v: k for k, v in id_label_role.items()}
label_id_ent = {v: k for k, v in id_label_ent.items()}
label_id_event = {v: k for k, v in id_label_event.items()}
nested_entities_files=[]

### Labels Generation

In [None]:

class GenerateLabel:

  @staticmethod
  def generate_enity_labels(entity_list, token_len, token_offsets,file_name):
    '''
    Generates entity labels for each token in the sentence. Performs by mapping the entity's offset positions to the tokens offset.
    Also handles overlapping entities by assigning them as a list to that token.
    Inputs:
    entity_list: List of entity dicts, each entity dict has the information such as offset positions, id , category and text.
    token_len: Initializes the labels size
    token_offsets: list of tuples of offset mappings of bert tokens.
    file_name: Name of the file from which the entity list is generated.

    Outputs:
    entity_labels: List of BIO labels for each token.
    '''
    # Initialize a list to store the BIO labels for each token
    global nested_entities_files
    entity_labels = [label_id_ent['O']] * token_len
    for entity in entity_list:
        category = entity['entity_category']
        # Labelling the entities which are not under discarded_entities.
        if category not in discarded_enities:
          entity_start_pos = int(entity['entity_strt_pos'])
          entity_end_pos = int(entity['entity_end_pos'])-1

          # Find tokens that correspond to the entity's position
          entity_start_token = None
          entity_end_token = None

          for i, (start_offset, end_offset) in enumerate(token_offsets):
              if entity_start_token is None and start_offset >= entity_start_pos:
                  entity_start_token = i
              if end_offset > entity_end_pos:
                  entity_end_token = i
                  break

          # Assign BIO labels to the tokens, handling overlapping labels
          if entity_start_token is not None:
              if entity_labels[entity_start_token] == label_id_ent['O']:
                  entity_labels[entity_start_token] = label_id_ent['B-' + category]
              else:
                  if isinstance(entity_labels[entity_start_token], list):
                      entity_labels[entity_start_token].append(label_id_ent['B-' + category])
                  else:
                      entity_labels[entity_start_token] = [entity_labels[entity_start_token],label_id_ent['B-' + category]]
                  nested_entities_files.append(file_name)
              # Assign BIO labels to the tokens, handling overlapping labels
              if entity_end_token is not None:
                  for i in range(entity_start_token + 1, entity_end_token + 1):
                      if entity_labels[i] == label_id_ent['O']:
                          entity_labels[i] = label_id_ent['I-' + category]
                      else:
                          if isinstance(entity_labels[i], list):
                              entity_labels[i].append(label_id_ent['I-' + category])
                          else:
                            #If multiple labels should be allocated to a token, then they are assigned to a list
                              entity_labels[i] = [entity_labels[i],label_id_ent['I-' + category]]
                          nested_entities_files.append(file_name)

    return entity_labels

  @staticmethod
  def generate_role_labels(role_list, token_len, token_offsets):
    '''
    Generates role labels for each token in the sentence. Performs by mapping the entity's offset positions to the tokens offset.
    Also handles overlapping entities by seperating status and method from the role and creating seperate labels for it.

    Inputs:
    role_list: List of role dicts, each entity dict has the information such as offset positions, id , category and text.
    token_len: Initializes the labels size
    token_offsets: list of tuples of offset mappings of bert tokens.
    file_name: Name of the file from which the entity list is generated.

    Outputs:
    role_labels: List of BIO labels for each token.
    status_labels: List of BIO labels for each token.
    method_labels: List of BIO labels for each token.
    '''
    # Initialize a list to store the BIO labels for each token
    role_labels = [label_id_role['O']] * token_len
    status_labels = [label_id_status['O']] * token_len
    method_labels = [label_id_method['O']] * token_len

    for role in role_list:
        category = role['entity_category']
        entity_start_pos = int(role['entity_strt_pos'])
        entity_end_pos = int(role['entity_end_pos'])-1
        if category in discarded_roles:
          continue
        # Find tokens that correspond to the entity's position
        entity_start_token = None
        entity_end_token = None

        for i, (start_offset, end_offset) in enumerate(token_offsets):
            if entity_start_token is None and start_offset >= entity_start_pos:
                entity_start_token = i
            if end_offset > entity_end_pos:
                entity_end_token = i
                break

        # Assign BIO labels to the tokens
        if category == 'Status':
          if entity_start_token is not None:
            status_labels[entity_start_token] = label_id_status['B-' + category]
            if entity_end_token is not None:
                status_labels[entity_start_token + 1:entity_end_token + 1] = [label_id_status['I-' + category]] * (entity_end_token - entity_start_token)

        elif category == 'Method':
          if entity_start_token is not None:
            method_labels[entity_start_token] = label_id_method['B-' + category]
            if entity_end_token is not None:
                method_labels[entity_start_token + 1:entity_end_token + 1] = [label_id_method['I-' + category]] * (entity_end_token - entity_start_token)

        else:
          if entity_start_token is not None:
            role_labels[entity_start_token] = label_id_role['B-' + category]
            if entity_end_token is not None:
                role_labels[entity_start_token + 1:entity_end_token + 1] = [label_id_role['I-' + category]] * (entity_end_token - entity_start_token)


    return role_labels, status_labels, method_labels

  @staticmethod
  def generate_relation_labels( token_offsets, data):
    '''
    Generates relation mapping with token indices for each event present in a document.
    Reads events and maps them to their related entities and roles with their token indices.


    Inputs:
    token_offsets: list of tuples of offset mappings of bert tokens.
    data: Data dict which consists of entity_list, role_list and events_list.

    Outputs:
    relation_labels: The relation mapping is a dictionary where key is a tuple of entity token indices and value is a list of tuple of role token indices.

    '''
    entity_token_indices = {}
    role_token_indices = {}

    entity_list = data.get('entity_list', [])
    role_list = data.get('role_list', [])
    events_list = data.get('events_list', [])

    # Create a dictionary to map entity IDs to their token indices
    for entity in entity_list:
        entity_id = entity['entity_id']
        entity_start_pos = int(entity['entity_strt_pos'])
        entity_end_pos = int(entity['entity_end_pos'])-1

        entity_start_token = None
        entity_end_token = None

        for i, (start_offset, end_offset) in enumerate(token_offsets):
            if entity_start_token is None and start_offset >= entity_start_pos:
                entity_start_token = i
            if end_offset > entity_end_pos:
                entity_end_token = i
                break

        if entity_start_token is not None:
            entity_token_indices[entity_id] = (entity_start_token, entity_end_token)

    # Create a dictionary to map event-related role IDs to their token indices
    for event in events_list:
        entity_id = event['entity_id']
        related_roles = event['Related_roles']

        entity_indices = entity_token_indices.get(entity_id, None)

        if entity_indices:
            role_indices = []

            for role_id in related_roles:
                role = next((role for role in role_list if role['role_id'] == role_id), None)

                if role:
                    role_start_pos = int(role['entity_strt_pos'])
                    role_end_pos = int(role['entity_end_pos'])-1

                    role_start_token = None
                    role_end_token = None

                    for i, (start_offset, end_offset) in enumerate(token_offsets):
                        if role_start_token is None and start_offset >= role_start_pos:
                            role_start_token = i
                        if end_offset > role_end_pos:
                            role_end_token = i
                            break

                    if role_start_token is not None:
                        role_indices.append((role_start_token, role_end_token))

            if role_indices:
                role_token_indices[(entity_indices[0],entity_indices[1])] = role_indices

    return role_token_indices

### Dataset

In [None]:

class ERDataset(Dataset):
  def __init__(self, data, tokenizer, max_len):
    self.data = data
    self.tokenizer = tokenizer
    self.max_len = max_len
  def __len__(self):
    return len(self.data)
  def __getitem__(self, idx):
    item=self.data[idx]
    text = item['text']
    inputs = self.tokenizer(text,max_length=max_len,truncation=True,return_offsets_mapping=True)
    tokens_len=len(inputs['input_ids'])
    offset_mapping_list=inputs['offset_mapping']
    entity_labels=GenerateLabel.generate_enity_labels(item['entity_list'],tokens_len,offset_mapping_list,item['file_name'])
    role_labels, status_labels, method_labels=GenerateLabel.generate_role_labels(item['role_list'],tokens_len,offset_mapping_list)
    relation_labels=GenerateLabel.generate_relation_labels(offset_mapping_list,item)
    return {
      'input_ids':inputs['input_ids'],
      'attention_mask':inputs['attention_mask'],
      'entity_labels':entity_labels,
      'role_labels':role_labels,
      'status_labels':status_labels,
      'method_labels':method_labels,
      'relation_labels':relation_labels,
      'text':text,
      'file_name':item['file_name'],
      'offset_mapping':offset_mapping_list
    }


## Splitting Dataset

In [None]:
def split_train_test(list_of_dicts, test_size, stratify_attribute, priority_order,min_label_size):
    '''
    Splits a list of dictionaries into a training and test set
    In each document we specified what entities and roles are present in a list.
    Dataset is stratified split based on the stratify_attribute which specified entity/role and their frequency .
    The process ensures a minimum of 10 instances for each class in the test set, with the smallest class constituting atleast 20% split.

    :param list_of_dicts: The list of dictionaries to split
    :param test_size: The fraction of the data to
    :param stratify_attribute: Focusing on entity or role
    :priority order: Priortizing Entities/roles when performing a satisfied split

    :return: A tuple of the form (train_set, test_set)
    '''
    # Initialize counters for the elements
    train_counter = Counter()
    test_counter = Counter()

    # Initialize the train and test sets
    train_set = []
    test_set = []

    # Identify all unique elements and their overall frequencies
    all_elements = [item for sublist in list_of_dicts for item in sublist[stratify_attribute]]
    total_element_frequency = Counter(all_elements)
    unique_elements_dict = {key: 0 for key in set(all_elements)}
    # Ensure each unique element appears at least 10 times(been decided by getting 20% of the size of the least amount of class) in the test set, if possible
    file_name_list=[]
    for unique in priority_order:
        if unique_elements_dict[unique] >= min_label_size:
          continue
        for d in list_of_dicts:
          if unique_elements_dict[unique] >= min_label_size:
            break
          if unique in d[stratify_attribute]:
              if d['file_name'] not in file_name_list:
                file_name_list.append(d['file_name'])
                test_set.append(d)
                for ele in d[stratify_attribute]:
                    unique_elements_dict[ele] += 1
    test_counter = Counter(item for d in test_set for item in d[stratify_attribute])
    print(unique_elements_dict)
    print(test_counter)
    # Fill the remaining slots in the test set, prioritizing stratified distribution
    remaining_dicts = [d for d in list_of_dicts if d not in test_set]
    while remaining_dicts and len(test_set) < len(list_of_dicts) * test_size:
        # Select a dictionary that best improves the stratified distribution
        best_dict = None
        best_improvement = float('inf')
        for d in remaining_dicts:
            potential_test_counter = test_counter.copy()
            potential_test_counter.update(d[stratify_attribute])
            improvement = sum(abs((potential_test_counter[element] / len(test_set + [d])) -
                                  (total_element_frequency[element] / len(list_of_dicts)))
                              for element in priority_order)
            if improvement < best_improvement:
                best_improvement = improvement
                best_dict = d

        if best_dict:
            test_set.append(best_dict)
            test_counter.update(best_dict[stratify_attribute])
            remaining_dicts.remove(best_dict)

    # Add any remaining dictionaries to the training set
    train_set.extend(remaining_dicts)

    return train_set, test_set



In [None]:
#Load raw dataset
with open(raw_dataset_path, 'r') as f:
    data = json.load(f)

In [None]:
min_label_size=10
#extracts entities and roles list in each doc.
for item in data:
  entity_category_list=[]
  for entity in item['role_list']:
    if entity['entity_category'] not in discarded_roles:
      entity_category_list.append(entity['entity_category'])
  item['entity_category_list']=entity_category_list
#Splitting train and test data.
trainset,testset = split_train_test(data,0.2,'entity_category_list',priority_ord_list,min_label_size)
print('Total Documents present: {}'.format(len(data)))
print('Train Documents: {}'.format(len(trainset)))
print('Test Documents: {}'.format(len(testset)))

### Extracts nested filenames from train and test sets

In [None]:
nested_entities_files=[]

In [None]:

train_dataset = ERDataset(trainset, tokenizer, max_len)
for ele in train_dataset:
  pass
train_nested_entities_filenames=list(set(nested_entities_files))

In [None]:
test_dataset = ERDataset(testset, tokenizer, max_len)
for ele in test_dataset:
  pass
test_nested_entities_filenames=list(set(nested_entities_files))


In [None]:
len(train_nested_entities_filenames)

In [None]:
nested_entities_file_names={'train_files':train_nested_entities_filenames,'test_files':test_nested_entities_filenames}

### Saving datasets

In [None]:
with open(project_directory+'data/'+'trainset.json', 'w', encoding='utf-8') as file:
  json.dump(trainset, file)
with open(project_directory+'data/'+'testset.json', 'w', encoding='utf-8') as file:
  json.dump(testset, file)
with open(project_directory+'data/'+'nested_ent_filenames.json', 'w', encoding='utf-8') as file:
  json.dump(nested_entities_file_names, file)

In [None]:
torch.save(train_dataset,train_dataset_path)
torch.save(test_dataset,test_dataset_path)

In [None]:
loaded_train_dataset=torch.load(train_dataset_path)
loaded_test_dataset=torch.load(test_dataset_path)

In [None]:
loaded_train_dataset[0]

### Validating split counts for entity and roles

In [3]:
import json
from itertools import chain
with open(project_directory+'data/'+'trainset.json', 'r', encoding='utf-8') as file:
  trainset=json.load( file)
with open(project_directory+'data/'+'testset.json', 'r', encoding='utf-8') as file:
  testset=json.load(file)

In [5]:
import pandas as pd
traindf=pd.DataFrame(trainset)
testdf=pd.DataFrame(testset)

In [6]:

tr_entity_df=pd.DataFrame(list(chain(*(traindf['entity_list'].tolist()))))
tst_entity_df=pd.DataFrame(list(chain(*(testdf['entity_list'].tolist()))))


In [7]:
tr_entity_df['entity_category'].value_counts()

Tobacco                  225
Alcohol                  199
Family                   146
Drug                     124
Occupation               106
MaritalStatus             91
LivingSituation           79
Residence                 39
PhysicalActivity          26
EnvironmentalExposure     17
InfectiousDiseases        11
SexualHistory              8
Name: entity_category, dtype: int64

In [8]:
tst_entity_df['entity_category'].value_counts()

Alcohol                  55
Tobacco                  53
Family                   37
MaritalStatus            31
Drug                     30
Occupation               29
LivingSituation          19
Residence                12
PhysicalActivity          5
EnvironmentalExposure     3
Name: entity_category, dtype: int64

In [9]:
tr_role_df=pd.DataFrame(list(chain(*(traindf['role_list'].tolist()))))
tst_role_df=pd.DataFrame(list(chain(*(testdf['role_list'].tolist()))))


In [10]:
tr_role_df['entity_category'].value_counts()

Status              764
Type                522
Amount              214
Method              134
Frequency           111
Location             64
Temporal             37
ExposureHistory      36
QuitHistory          30
LivingStatus         16
Other                 9
MedicalCondition      8
Extent                2
History               1
Name: entity_category, dtype: int64

In [12]:
tst_role_df['entity_category'].value_counts()

Status              194
Type                133
Amount               55
Method               33
Frequency            28
Location             16
QuitHistory          15
Temporal             11
ExposureHistory      10
LivingStatus          3
Extent                1
MedicalCondition      1
Name: entity_category, dtype: int64

In [13]:
len(testset)

68

In [14]:
len(trainset)

269

### Extracting  and saving independent entity/Role train and test data

In [10]:
project_directory='/content/drive/MyDrive/PHD_assessment_gmu/'
import json

In [11]:
with open(project_directory+'data/'+'trainset.json', 'r', encoding='utf-8') as file:
  train_data=json.load( file)
with open(project_directory+'data/'+'testset.json', 'r', encoding='utf-8') as file:
  test_data=json.load( file)

In [21]:
train_data[0]

{'text': 'SOCIAL HISTORY:  Negative for illicit drugs, alcohol, and tobacco.\n\n',
 'entity_list': [{'entity_id': 1,
   'entity_type': 'Entity',
   'entity_category': 'Drug',
   'entity_strt_pos': '38',
   'entity_end_pos': '43',
   'entity_text': 'drugs'},
  {'entity_id': 2,
   'entity_type': 'Entity',
   'entity_category': 'Alcohol',
   'entity_strt_pos': '45',
   'entity_end_pos': '52',
   'entity_text': 'alcohol'},
  {'entity_id': 3,
   'entity_type': 'Entity',
   'entity_category': 'Tobacco',
   'entity_strt_pos': '58',
   'entity_end_pos': '65',
   'entity_text': 'tobacco'}],
 'role_list': [{'role_id': 4,
   'entity_type': 'Role',
   'entity_category': 'Status',
   'entity_strt_pos': '16',
   'entity_end_pos': '25',
   'entity_text': ' Negative'},
  {'role_id': 5,
   'entity_type': 'Role',
   'entity_category': 'Type',
   'entity_strt_pos': '30',
   'entity_end_pos': '37',
   'entity_text': 'illicit'}],
 'events_list': [{'Event_id': 1, 'entity_id': 1, 'Related_roles': [5, 4]},
  

In [12]:
id_label_status={0:'O',1:'B-Status',2:'I-Status'}
id_label_method={0:'O',1:'B-Method',2:'I-Method'}
id_label_role={0:'O',1:'B-Type',2:'I-Type',3:'B-Amount',4:'I-Amount',5:'B-Temporal',6:'I-Temporal',7:'B-Frequency',8:'I-Frequency',9:'B-QuitHistory',10:'I-QuitHistory',11:'B-ExposureHistory',12:'I-ExposureHistory',13:'B-Location',14:'I-Location'}
id_label_ent={0:'O',1:'B-Tobacco',2:'I-Tobacco',3:'B-Alcohol',4:'I-Alcohol',5:'B-Family',6:'I-Family',7:'B-Drug',8:'I-Drug',9:'B-Occupation',10:'I-Occupation',11:'B-MaritalStatus',12:'I-MaritalStatus',13:'B-LivingSituation',14:'I-LivingSituation'}
id_label_event={0:'No Relation',1:'Relation'}
label_id_status = {v: k for k, v in id_label_status.items()}
label_id_method = {v: k for k, v in id_label_method.items()}
label_id_role = {v: k for k, v in id_label_role.items()}
label_id_ent = {v: k for k, v in id_label_ent.items()}
label_id_event = {v: k for k, v in id_label_event.items()}

In [13]:
def extract_entity_cat(label_id):
  entity_cat=[]
  for ele in label_id.keys():
    if ele !='O':
      entity_cat.append(ele[2:])
  return set(entity_cat)

In [14]:
ent_lab=extract_entity_cat(label_id_ent)
role_lab=extract_entity_cat(label_id_role)
status_lab=extract_entity_cat(label_id_status)
method_lab=extract_entity_cat(label_id_method)

In [17]:
ent_lab

{'Alcohol',
 'Drug',
 'Family',
 'LivingSituation',
 'MaritalStatus',
 'Occupation',
 'Tobacco'}

In [19]:
ent_lab & set(['Occupation', 'Family', 'Tobacco', 'Alcohol'])

{'Alcohol', 'Family', 'Occupation', 'Tobacco'}

In [25]:
for ele in ['Ent','Role','Status','Method']:
  if ele == 'Ent':
    entity_cat_list=ent_lab
    entity_cat='entity'
  elif ele == 'Role':
    entity_cat_list=role_lab
    entity_cat='role'
  elif ele == 'Status':
    entity_cat_list=status_lab
    entity_cat='status'
  elif ele == 'Method':
    entity_cat_list=method_lab
    entity_cat='method'
  train_method_data=[]
  test_method_data=[]
  for item in train_data:
    if ele == 'Ent':
      if entity_cat_list & set(item['entity_category_list']):
        train_method_data.append(item)
    else:
      if entity_cat_list & set(item['role_category_list']):
        train_method_data.append(item)
  for item in test_data:
    if ele == 'Ent':
      if entity_cat_list & set(item['entity_category_list']):
        test_method_data.append(item)
    else:
      if entity_cat_list & set(item['role_category_list']):
        test_method_data.append(item)
  with open(project_directory+'data/'+'tr_'+entity_cat+'_set.json', 'w', encoding='utf-8') as file:
    json.dump(train_method_data,file)
  with open(project_directory+'data/'+'test_'+entity_cat+'_set.json', 'w', encoding='utf-8') as file:
    json.dump(test_method_data,file)

In [23]:
len(test_method_data)

0