In [None]:
import copy
import json
import os
import random
import re
from collections import Counter

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import torch
import torch.nn as nn
import wandb
from sklearn.metrics import f1_score, confusion_matrix, precision_recall_fscore_support
from sklearn.utils.class_weight import compute_class_weight
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import Dataset, DataLoader, random_split
from tqdm import tqdm
from transformers import BertForTokenClassification, BertTokenizer, AdamW, get_scheduler, AutoTokenizer, AutoModel

In [None]:
# Record the Python interpreter version used for this run (reproducibility).
!python --version

In [None]:
# Snapshot installed package versions into requirements.txt (reproducibility).
# NOTE(review): `!pip` runs the shell's pip, which may differ from the kernel's environment; `%pip` targets the kernel.
!pip freeze > requirements.txt

In [None]:
DATA_DIR = "/kaggle/input/gutbrainie2025/gutbrainie2025"
# Never hard-code credentials in a notebook: read the W&B API key from the environment (e.g. a
# Kaggle/Colab secret exported as WANDB_API_KEY). The key that was previously hard-coded here has
# been exposed and should be revoked.
wandb.login(key=os.environ.get("WANDB_API_KEY"))

In [None]:
class AnnotationDataset(Dataset):
    """
    Article-level dataset of annotation JSON files.

    Each sample is a tuple ``(article_id, annotations)`` taken from the top-level items of the JSON
    files; ``annotations`` is expected to contain a ``'metadata'`` dict with ``'title'`` and ``'abstract'``.
    """

    def __init__(self, root_path, tokenizer=None, split='Train', quality_filter=('platinum_quality', 'gold_quality', 'silver_quality')):
        """
        Args:
            root_path: dataset root; annotations are read from ``<root_path>/Annotations/<split>/``.
            tokenizer: optional tokenizer (used by ``plot_abstract_lengths``; subclasses may use it too).
            split: ``'Train'`` reads ``<quality>/json_format/*.json`` for every quality in ``quality_filter``
                (bronze is excluded by default because it contains auto-generated annotations);
                ``'Dev'`` reads ``json_format/*.json``.
            quality_filter: quality sub-folders to load for the Train split. Missing folders are reported
                and skipped.
        Raises:
            FileNotFoundError: if the Dev ``json_format`` folder does not exist.
            ValueError: if ``split`` is neither ``'Train'`` nor ``'Dev'``.
        """
        self.samples = []
        self.tokenizer = tokenizer
        annotations_dir = os.path.join(root_path, 'Annotations', split)

        if split == 'Train':
            for quality in quality_filter:
                json_format_dir = os.path.join(annotations_dir, quality, 'json_format')
                if not os.path.exists(json_format_dir):
                    print(f"No folder {json_format_dir} was found!")
                    continue
                self._load_json_dir(json_format_dir)
        elif split == 'Dev':
            json_format_dir = os.path.join(annotations_dir, 'json_format')
            if not os.path.exists(json_format_dir):
                raise FileNotFoundError(f"No folder {json_format_dir} was found!")
            self._load_json_dir(json_format_dir)
        else:
            raise ValueError("Specify a split, must be either 'Train' or 'Dev'!")

    def _load_json_dir(self, json_format_dir):
        """
        Appends the ``(article_id, annotations)`` items of every ``*.json`` file in ``json_format_dir``.

        Files are visited in sorted name order (``os.listdir`` order is arbitrary), and the items of each
        file are sorted by article id (string order). The sample order, and therefore any seeded random
        split built on it, is then reproducible across machines.
        """
        for file_name in sorted(os.listdir(json_format_dir)):
            if not file_name.endswith('.json'):
                continue
            file_path = os.path.join(json_format_dir, file_name)
            with open(file_path, 'r', encoding='utf-8') as f:
                data = json.load(f)
            self.samples.extend(sorted(data.items(), key=lambda item: item[0]))

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        """Returns one ``(article_id, annotations)`` tuple."""
        return self.samples[idx]

    def plot_abstract_lengths(self):
        """
        Plots a histogram of abstract lengths and prints the maximum length.

        Lengths are counted with ``self.tokenizer.tokenize`` when a tokenizer was given (sub-word tokens),
        otherwise by whitespace splitting (just an overview; the baselines use the NLTK tokenizer).
        If there are no samples, a message is printed and nothing is plotted.
        """
        use_tokenizer = bool(self.tokenizer)
        abstract_lengths = []
        for _, data in self.samples:
            abstract = data['metadata'].get('abstract') or ''
            if use_tokenizer:
                abstract_lengths.append(len(self.tokenizer.tokenize(abstract)))
            else:
                abstract_lengths.append(len(abstract.split()))

        if not abstract_lengths:
            print("No samples available, nothing to plot.")
            return

        tokenizer_type = "BERT Tokenized" if use_tokenizer else "Whitespace Tokenized"
        print("Maximum number of tokens per abstract: ", max(abstract_lengths))
        plt.figure(figsize=(8, 4))
        plt.hist(abstract_lengths, bins=30, color='#E6E6FA', edgecolor='#D1C8E3')
        plt.title(f"Distribution of Abstract Lengths ({tokenizer_type})", fontsize=14, fontweight='bold')
        plt.xlabel("Token Count" if use_tokenizer else "Word Count", fontsize=12, fontweight='medium')
        plt.ylabel("Frequency", fontsize=12, fontweight='medium')
        plt.grid(True, linestyle='--', alpha=0.5)
        plt.tick_params(axis='both', which='major', labelsize=12, length=6, width=1.2, direction='in', grid_alpha=0.5)

        plt.tight_layout()
        plt.show()

    def get_text_data(self):
        """
        Returns ``(titles, abstracts)``: all non-empty titles joined by spaces, and all non-empty
        abstracts joined by spaces.
        """
        all_titles = []
        all_abstracts = []

        for _, data in self.samples:
            metadata = data.get('metadata')
            if not metadata:
                continue
            if metadata.get('title'):
                all_titles.append(metadata['title'])
            if metadata.get('abstract'):
                all_abstracts.append(metadata['abstract'])

        return " ".join(all_titles), " ".join(all_abstracts)

    def build_vocab(self):
        """
        Returns a Counter of lower-cased word tokens (regex ``\\b\\w+\\b``) over all titles and abstracts.
        Used for the vocabulary coverage check.
        """
        all_titles, all_abstracts = self.get_text_data()
        words = re.findall(r'\b\w+\b', all_titles.lower()) + re.findall(r'\b\w+\b', all_abstracts.lower())
        return Counter(words)


def split_datasets(train_data, val_data, test_data, generator=None):
    '''
    Splits the training articles into a new training and a validation set.

    The validation set gets as many articles as the test set has. The split is done at ARTICLE level:
    ``train_data.samples`` / ``val_data.samples`` are replaced by the two random subsets. For datasets
    that precompute their items in ``relation_samples`` (REDataset, whose ``__len__``/``__getitem__``
    read that list), those items are filtered to the articles of the respective subset. Without this
    filtering the split would not affect such datasets, and the validation set would equal the
    training set.

    Args:
        train_data: dataset built from the Train split (modified in place and returned).
        val_data: a second dataset built from the same Train split (modified in place and returned).
        test_data: dataset whose number of articles determines the validation size.
        generator: optional ``torch.Generator`` for a reproducible split (default: torch's global RNG).
    Returns:
        ``(train_data, val_data)``
    Raises:
        ValueError: if the test set has at least as many articles as the training set.
    '''
    all_samples = train_data.samples
    train_size = len(all_samples)
    val_size = len(test_data.samples)
    if val_size >= train_size:
        raise ValueError(
            f"Cannot hold out {val_size} validation articles from only {train_size} training articles."
        )

    split_kwargs = {} if generator is None else {"generator": generator}
    train_subset, val_subset = random_split(all_samples, [train_size - val_size, val_size], **split_kwargs)

    train_data.samples = train_subset
    val_data.samples = val_subset

    for dataset, subset in ((train_data, train_subset), (val_data, val_subset)):
        if hasattr(dataset, "relation_samples"):
            # NOTE(review): if the same article id occurs in several annotation files, its items follow
            # every subset that contains that id.
            article_ids = {all_samples[i][0] for i in subset.indices}
            dataset.relation_samples = [s for s in dataset.relation_samples if s["article_id"] in article_ids]

    return train_data, val_data

### RE Dataset for binary tag based relation extraction

Unlike the NER dataset, title and abstract have to be processed together, since an entity in the title can be in a relationship with an entity mentioned in the abstract. This further increases the need for striding. Moreover, relationships are directional: entity 1 serves as the subject and entity 2 as the object, and entity 2 may precede entity 1 in the text. This complicates the enumeration of candidate entity pairs, as the number of ordered candidate pairs grows quadratically with the number of entities (n·(n−1) pairs for n entities).

**Considerations:**

    - If there are n relationships, create n negative pairs (either randomly or with a heuristic filter).
    - Smart negative sampling: generate negatives where Entity1 and Entity2 are of the same types as in positive samples. Select entities that co-occur in the same document but are far apart (maybe 50% of the negatives; the rest can be drawn randomly). Difficulty levels: easy (just in the same document), medium (same entity type), and hard (distance based). Also take into account that a subject can appear after an object (passive structures etc.).
    - Take all relationships as data points (but consider dummification/replacement with tag)? For now, focus on mention based (only add entity markers, see Gu et al. (2021))
    - special tokens have to be added to tokenizer's vocabulary
    - one data point = title + abstract, marked with entity 1 (subject) and entity 2 (object), labels = set(0,1)
    - Relation encoding: 1) entity 1 + entity 2 concatenated, 2) CLS, 3) CLS + entity 1 + entity 2
    - Striding? Or is there any way to use PubMedBERT with contexts longer than 512 tokens? Take into account, however, that striding misses relations whose entities fall into two separate chunks. This is not an issue for NER, but it is for RE.
    - This time take the training loop from the GutBrainie challenge
    - data augmentation from external resources is more complicated here than for NER
    - approx. 32,720 training data points if every article ID has 20 relations (and 20 negative relations to balance classes)
    - for training: using the ground truth NEs probably is fine.
    - for inference: candidate generation, probably with our own NER model. Run an experiment: A) own NER model, B) ground truth. Return the set of relations with tags (tag before the mention-based NEs).
    - apply self attention on concatenation of entity 1 and entity 2
    - entity aware attention through attention biasing
    - capture global context better through an additional attention layer

    Gu et al. (2021): For featurization, the relation instance is either represented by a special [CLS] token or by concatenating the mention representations. In the latter case, if an entity mention contains multiple tokens, its representation is usually produced by pooling those of individual tokens (max or average).

In [None]:
# more strategic negative sampling (=harder negatives)

# helper functions for modularity of code
def get_adjusted_indices(entity, offset):
    """
    Maps an entity's character span onto the concatenated ``title + " " + abstract`` text.

    Returns a half-open ``(start, end)`` pair: the annotated end index is inclusive, so 1 is added.
    Entities whose location is ``"abstract"`` are additionally shifted by ``offset`` (title length
    plus separator); entities anywhere else keep their position.
    """
    shift = offset if entity.get("location") == "abstract" else 0
    return entity["start_idx"] + shift, entity["end_idx"] + 1 + shift


def mark_entities(text, subj_start_idx, subj_end_idx, obj_start_idx, obj_end_idx):
    """
    Wraps the subject in ``<ent1>...</ent1>`` and the object in ``<ent2>...</ent2>``.

    All indices are half-open character offsets into the ORIGINAL ``text``, so the markers are
    placed correctly regardless of whether the subject precedes or follows the object. (Inserting
    them one by one into a list shifts later positions and misplaces ``ent1`` when the subject comes
    after the object.)

    Ties at the same position are resolved so the markup stays well nested where possible:
    closing markers come before opening markers (adjacent entities give ``</ent1><ent2>``). At the
    same start, the longer span opens first; at the same end, the span that started later closes
    first. Identical spans give ``<ent1><ent2>...</ent2></ent1>``.

    Args:
        text: the text to mark.
        subj_start_idx, subj_end_idx: subject span [start, end).
        obj_start_idx, obj_end_idx: object span [start, end).
    Returns:
        The text with the four markers inserted.
    """
    # sort key: (position, 0 = closing / 1 = opening, nesting tie-break, ent1/ent2 tie-break)
    events = [
        ((subj_start_idx, 1, -subj_end_idx, 0), "<ent1>"),
        ((obj_start_idx, 1, -obj_end_idx, 1), "<ent2>"),
        ((subj_end_idx, 0, -subj_start_idx, 1), "</ent1>"),
        ((obj_end_idx, 0, -obj_start_idx, 0), "</ent2>"),
    ]
    events.sort(key=lambda event: event[0])

    pieces = []
    prev = 0
    for (pos, *_), marker in events:
        pieces.append(text[prev:pos])
        pieces.append(marker)
        prev = max(prev, pos)
    pieces.append(text[prev:])
    return "".join(pieces)


class REDataset(AnnotationDataset):
    def __init__(self, root_path, tokenizer, max_length=512, split="Train",
                 quality_filter=('platinum_quality', 'gold_quality', 'silver_quality'),
                 reversed_ratio=0.25, same_tag_ratio=0.25):
        """
        Creates a binary relation extraction dataset.

        Each data point is the text ``title + " " + abstract``. Title and abstract are concatenated because
        a relation can hold between an entity in the title and one in the abstract. The two entities in
        question are wrapped in markers: ``<ent1>...</ent1>`` for the subject and ``<ent2>...</ent2>`` for
        the object. These markers have to be added to the tokenizer passed in here (and the model's
        embeddings resized accordingly).

        Positives (label 1): every mention-based relation of the ground truth.
        Negatives (label 0): per article, up to as many as there are positives, drawn from ordered pairs
        of distinct ground-truth entities whose (subject span, object span) text is not a positive relation:
            - ``reversed_ratio`` of them: pairs whose reversed direction is a positive relation (hard negatives),
            - ``same_tag_ratio`` of them: pairs with the same (subject label, object label) as some positive,
            - the remainder (including any shortfall of the two groups above): other random pairs.
        Sampling uses Python's ``random`` module; seed it beforehand for reproducible negatives.
        (Tag-based relations are inferred later, at inference time.)

        Args:
            root_path: dataset root containing the ``Annotations`` folder.
            tokenizer: tokenizer used lazily in ``__getitem__``.
            max_length: maximum number of tokens per sample (longer inputs are truncated).
            split: ``'Train'`` or ``'Dev'``.
            quality_filter: quality folders to load for the Train split.
            reversed_ratio: fraction of the negative budget reserved for reversed pairs.
            same_tag_ratio: fraction of the negative budget reserved for same-tag pairs.
        """
        super().__init__(root_path, tokenizer=tokenizer, split=split, quality_filter=quality_filter)
        self.max_length = max_length
        self.relation_samples = []

        for article_id, data in self.samples:
            title = data['metadata'].get('title') or ''
            abstract = data['metadata'].get('abstract') or ''
            # Do NOT strip the concatenation: abstract entities are shifted by exactly len(title) + 1,
            # which only holds if the title and the separator space are kept verbatim (stripping an
            # empty title's leading space would shift every abstract entity by one character).
            full_text = title + " " + abstract
            if not full_text.strip():
                continue
            offset = len(title) + 1

            entities = data.get("entities", [])
            relations = data.get("relations", [])

            for rel in relations:
                subj_entity = {"start_idx": rel["subject_start_idx"], "end_idx": rel["subject_end_idx"], "location": rel["subject_location"]}
                obj_entity = {"start_idx": rel["object_start_idx"], "end_idx": rel["object_end_idx"], "location": rel["object_location"]}
                self._add_sample(article_id, full_text, offset, subj_entity, obj_entity, label=1)

            for subj, obj in self._sample_negatives(entities, relations, reversed_ratio, same_tag_ratio):
                self._add_sample(article_id, full_text, offset, subj, obj, label=0)

    def _add_sample(self, article_id, full_text, offset, subj, obj, label):
        """Marks subj/obj (dicts with start_idx, inclusive end_idx, location) in full_text and stores the sample."""
        subj_start_idx, subj_end_idx = get_adjusted_indices(subj, offset)
        obj_start_idx, obj_end_idx = get_adjusted_indices(obj, offset)
        self.relation_samples.append({
            "article_id": article_id,
            "text": mark_entities(full_text, subj_start_idx, subj_end_idx, obj_start_idx, obj_end_idx),
            "label": label,
        })

    @staticmethod
    def _sample_negatives(entities, relations, reversed_ratio, same_tag_ratio):
        """Returns up to len(relations) ordered (subj, obj) entity pairs that are not annotated relations."""
        # relations are directional, so order matters; matching is done on the text spans
        positive_spans = {(r["subject_text_span"], r["object_text_span"]) for r in relations}
        positive_labels = {(r["subject_label"], r["object_label"]) for r in relations}

        reversed_pairs, same_tag_pairs, random_pairs = [], [], []
        for i, subj in enumerate(entities):
            for j, obj in enumerate(entities):
                if i == j:
                    continue  # an entity paired with itself is not a meaningful candidate
                spans = (subj["text_span"], obj["text_span"])
                if spans in positive_spans:
                    continue
                if (spans[1], spans[0]) in positive_spans:
                    reversed_pairs.append((subj, obj))
                elif (subj["label"], obj["label"]) in positive_labels:
                    same_tag_pairs.append((subj, obj))
                else:
                    random_pairs.append((subj, obj))

        for pool in (reversed_pairs, same_tag_pairs, random_pairs):
            random.shuffle(pool)

        num_pos = len(relations)
        sampled_reversed = reversed_pairs[:int(num_pos * reversed_ratio)]
        sampled_same_tag = same_tag_pairs[:int(num_pos * same_tag_ratio)]
        # random pairs fill the remaining budget, including any shortfall of the two groups above
        num_random = max(0, num_pos - len(sampled_reversed) - len(sampled_same_tag))
        sampled_random = random_pairs[:num_random]
        return sampled_reversed + sampled_same_tag + sampled_random

    def __len__(self):
        return len(self.relation_samples)

    def __getitem__(self, idx):
        """
        Returns one tokenized relation sample as a dict:
            - input_ids: 1D tensor of token ids (no padding; collate_fn pads per batch)
            - attention_mask: 1D tensor of the same length
            - label: 0-dim long tensor, 0 or 1, for binary classification

        NOTE: no sliding window is applied. Texts longer than ``max_length`` tokens are truncated, so
        entity markers located beyond the limit are lost.
        """
        sample = self.relation_samples[idx]

        tokenized_text = self.tokenizer(
            sample["text"],
            truncation=True,
            max_length=self.max_length,
            return_tensors="pt"
        )

        # return_tensors="pt" adds a leading batch dimension of size 1; drop it so that
        # collate_fn can pad variable-length 1D sequences into a batch.
        for key in tokenized_text:
            tokenized_text[key] = tokenized_text[key].squeeze(0)
        return {
            "input_ids": tokenized_text["input_ids"],
            "attention_mask": tokenized_text["attention_mask"],
            "label": torch.tensor(sample["label"], dtype=torch.long)
        }

In [None]:
# An earlier variant of get_adjusted_indices / mark_entities / REDataset used to live here as a
# disabled triple-quoted string. That string was echoed as this cell's output and duplicated the
# definitions above. The variant drew negatives uniformly at random from all non-positive candidate
# pairs (stopping at one negative per positive). It was superseded by the stratified negative
# sampling in the REDataset cell above; retrieve it from version control if needed.

#### Sanity checks

In [None]:
# Sanity check on the Dev split.
# The entity markers are registered on the tokenizer before the dataset is built (tokenization itself
# happens lazily in REDataset.__getitem__). Tokens added via add_tokens are kept as single tokens by
# the tokenizer. Remember to resize the model's token embeddings to len(tokenizer) later.
TOKENIZER_NAME = "NeuML/pubmedbert-base-embeddings"
special_tokens = ['<ent1>', '</ent1>', '<ent2>', '</ent2>']
tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_NAME)
tokenizer.add_tokens(special_tokens)
test_dataset = REDataset(DATA_DIR, tokenizer=tokenizer, split="Dev")
print(len(test_dataset))

In [None]:
# Inspect how one marked relation sample looks after tokenization (decoded text, sub-tokens, raw item).
index = 4
sample_item = test_dataset[index]
input_ids = sample_item['input_ids']

tokens = test_dataset.tokenizer.convert_ids_to_tokens(input_ids)
decoded_text = test_dataset.tokenizer.decode(input_ids, skip_special_tokens=True)

print("Decoded text:", decoded_text)
print("Tokens:", tokens)
print(sample_item)

### Prepare data loaders

In [None]:
# https://pytorch.org/docs/stable/data.html

def collate_fn(batch, pad_token_id=0):
    """
    Pads a list of samples (as returned by REDataset.__getitem__) into batch tensors.

    Dynamic padding to the longest sequence in the batch keeps computation small.

    Args:
        batch: list of dicts with 1D ``input_ids`` / ``attention_mask`` tensors of varying length
            and a 0-dim ``label`` tensor.
        pad_token_id: value used to pad ``input_ids``. The default 0 is the [PAD] id of standard BERT
            vocabularies; for other tokenizers pass ``tokenizer.pad_token_id``
            (e.g. ``collate_fn=functools.partial(collate_fn, pad_token_id=tokenizer.pad_token_id)``).
    Returns:
        dict with ``input_ids`` and ``attention_mask`` of shape (batch_size, longest_seq_len), where the
        attention mask is 0 on padding, and ``label`` of shape (batch_size,).
    """
    input_ids = [item["input_ids"] for item in batch]
    attention_masks = [item["attention_mask"] for item in batch]
    labels = torch.stack([item["label"] for item in batch])

    # batch_first=True keeps the batch size as the first dimension: B x T
    input_ids_padded = pad_sequence(input_ids, batch_first=True, padding_value=pad_token_id)
    attention_masks_padded = pad_sequence(attention_masks, batch_first=True, padding_value=0)

    return {
        "input_ids": input_ids_padded,
        "attention_mask": attention_masks_padded,
        "label": labels
    }

def create_dataloaders(batch_size, tokenizer, device):
    """
    Builds the train/val/test DataLoaders for relation classification.

    The Train split is parsed once. The validation dataset starts as a shallow copy of it, and
    split_datasets then divides the Train articles between the two (the validation part gets as many
    articles as the Dev set). The Dev split serves as the test set until the official test set is
    released.

    Args:
        batch_size: batch size for all three loaders.
        tokenizer: tokenizer that already contains the entity marker tokens.
        device: unused; kept for backward compatibility of the call signature.
    Returns:
        (train_dataloader, val_dataloader, test_dataloader)
    """
    train_dataset = REDataset(DATA_DIR, tokenizer=tokenizer, split="Train")
    # shallow copy: avoids re-reading and re-sampling the whole Train split; split_datasets rebinds
    # attributes on each object, so the two datasets do not interfere
    val_dataset = copy.copy(train_dataset)
    test_dataset = REDataset(DATA_DIR, tokenizer=tokenizer, split="Dev")

    train_dataset, val_dataset = split_datasets(train_dataset, val_dataset, test_dataset)

    train_dataloader = DataLoader(train_dataset, batch_size, shuffle=True, collate_fn=collate_fn)
    # evaluation needs no shuffling; a fixed order keeps validation runs comparable
    val_dataloader = DataLoader(val_dataset, batch_size, shuffle=False, collate_fn=collate_fn)
    test_dataloader = DataLoader(test_dataset, batch_size, shuffle=False, collate_fn=collate_fn)

    return train_dataloader, val_dataloader, test_dataloader

In [None]:
# Build the RE datasets at top level for inspection. The Train split is parsed a single time; the
# validation dataset starts as a shallow copy and split_datasets divides the Train articles between them.
train_dataset = REDataset(DATA_DIR, tokenizer=tokenizer, split="Train")
val_dataset = copy.copy(train_dataset)
test_dataset = REDataset(DATA_DIR, tokenizer=tokenizer, split="Dev")  # Dev split used as test set (until the official test set release)

train_dataset, val_dataset = split_datasets(train_dataset, val_dataset, test_dataset)

### Architecture



In [None]:
class RelationClassifier(nn.Module):
    def __init__(self, model, hidden_size, dropout, ent1_start_id, ent1_end_id, ent2_start_id, ent2_end_id):
        """
        Binary relation classifier on top of a transformer encoder.

        The relation is represented as the concatenation [CLS; ent1; ent2]. Each entity vector is the
        mean of the token representations strictly between its marker pair (<ent1>...</ent1>,
        <ent2>...</ent2>). Dropout and a linear layer map this 3 * hidden_size vector to one logit.

        Args:
            model: encoder called as model(input_ids=..., attention_mask=...) whose output has a
                ``last_hidden_state`` of shape (batch_size, seq_len, hidden_size), e.g. a Hugging Face
                AutoModel whose token embeddings were resized after adding the marker tokens.
            hidden_size: the encoder's hidden dimension (model.config.hidden_size). Adding tokens to the
                tokenizer changes the vocabulary size, not this value.
            dropout: dropout probability applied to the concatenated representation.
            ent1_start_id, ent1_end_id, ent2_start_id, ent2_end_id: token ids of the four entity markers.
        """
        super().__init__()
        self.transformer = model
        self.dropout = nn.Dropout(dropout)
        self.hidden_size = hidden_size
        # input: concatenation of CLS + ent1 + ent2; output: one logit (binary classification)
        self.classifier = nn.Linear(hidden_size * 3, 1)

        self.ent1_start_id = ent1_start_id
        self.ent1_end_id = ent1_end_id
        self.ent2_start_id = ent2_start_id
        self.ent2_end_id = ent2_end_id

    @staticmethod
    def _entity_repr(tokens, token_reps, start_id, end_id):
        """Mean of the token representations strictly between the first start/end marker; zeros if unavailable."""
        start_pos = (tokens == start_id).nonzero(as_tuple=True)[0]
        end_pos = (tokens == end_id).nonzero(as_tuple=True)[0]
        if len(start_pos) > 0 and len(end_pos) > 0:
            start_idx = start_pos[0].item() + 1  # exclude the start marker itself
            end_idx = end_pos[0].item()
            if end_idx > start_idx:  # mean over an empty slice would be NaN
                return token_reps[start_idx:end_idx].mean(dim=0)
        # A marker was truncated away or the span is empty. Return a zero vector with the encoder
        # output's dtype and device, so torch.stack also works under fp16/bf16.
        return token_reps.new_zeros(token_reps.size(1))

    def forward(self, input_ids, attention_mask):
        """
        Args:
            input_ids: (batch_size, seq_len) token ids containing the entity markers.
            attention_mask: (batch_size, seq_len) attention mask.
        Returns:
            (batch_size,) tensor of logits (apply a sigmoid for probabilities).
        """
        outputs = self.transformer(input_ids=input_ids, attention_mask=attention_mask)
        sequence_output = outputs.last_hidden_state  # (batch_size, seq_len, hidden_size)
        cls_repr = sequence_output[:, 0, :]  # CLS token = first token of each sequence

        # one entity vector per sequence, stacked along the batch dimension -> (batch_size, hidden_size)
        ent1_repr = torch.stack([
            self._entity_repr(tokens, reps, self.ent1_start_id, self.ent1_end_id)
            for tokens, reps in zip(input_ids, sequence_output)
        ], dim=0)
        ent2_repr = torch.stack([
            self._entity_repr(tokens, reps, self.ent2_start_id, self.ent2_end_id)
            for tokens, reps in zip(input_ids, sequence_output)
        ], dim=0)

        combined_repr = torch.cat([cls_repr, ent1_repr, ent2_repr], dim=1)  # (batch_size, 3 * hidden_size)
        combined_repr = self.dropout(combined_repr)
        return self.classifier(combined_repr).squeeze(1)  # (batch_size,)

In [None]:
# seed function adapted from https://github.com/heraclex12/R-BERT-Relation-Classification/blob/master/BERT_for_Relation_Classification.ipynb
def set_seed(seed):
    """
    Seeds Python's ``random``, NumPy and PyTorch (CPU and all CUDA devices), and makes cuDNN deterministic.

    ``random`` must be seeded too: REDataset draws its negative samples with ``random.shuffle``.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False  # autotuning may pick non-deterministic kernels


def train_and_evaluate(model_name, tokenizer_voc_size, seed, train_dataloader, val_dataloader, test_dataloader, lr, weight_decay, num_epochs, dropout, device, max_norm, ent1_start_id, ent1_end_id, ent2_start_id, ent2_end_id, threshold):
    """
    Fine-tune a binary relation classifier with early stopping, then evaluate it on the test set.

    A fresh encoder for ``MODEL_CONFIGS[model_name]["model_name"]`` is loaded. Its embedding matrix
    is resized to ``tokenizer_voc_size`` so the added entity-marker tokens get embeddings, and the
    encoder is wrapped in ``RelationClassifier``. Training uses:
    - BCE-with-logits loss
    - AdamW
    - a linear LR schedule with 10% warm-up
    - gradient-norm clipping

    After every epoch the validation loss is computed. Training stops once it has not decreased for
    ``patience`` (=1) epoch(s). The weights of the epoch with the lowest validation loss are restored
    before evaluating on the test set. Batch-, epoch- and test-level metrics are logged to wandb, and
    the wandb run is finished at the end.

    Args:
        model_name: key into MODEL_CONFIGS.
        tokenizer_voc_size: vocabulary size of the tokenizer, including the added marker tokens.
        seed: random seed passed to set_seed.
        train_dataloader, val_dataloader, test_dataloader: loaders whose batches are dicts with
            'input_ids', 'attention_mask' and 'label'.
        lr, weight_decay: AdamW hyper-parameters.
        num_epochs: maximum number of training epochs.
        dropout: dropout probability passed to RelationClassifier.
        device: torch device the model and batches are moved to.
        max_norm: maximum gradient norm for clipping.
        ent1_start_id, ent1_end_id, ent2_start_id, ent2_end_id: token ids of the entity markers.
        threshold: sigmoid probability above which a pair is predicted as related.

    Returns:
        (model, test_f1_micro, test_f1_macro, train_losses, val_losses,
         train_f1s_micro, train_f1s_macro, val_f1s_micro, val_f1s_macro).
        The last six are per-epoch lists.
    """
    set_seed(seed)

    model_str = MODEL_CONFIGS[model_name]["model_name"]
    encoder = AutoModel.from_pretrained(model_str)
    encoder.resize_token_embeddings(tokenizer_voc_size)  # add embedding rows for the entity-marker tokens
    hidden_size = encoder.config.hidden_size

    model = RelationClassifier(encoder, hidden_size, dropout, ent1_start_id, ent1_end_id, ent2_start_id, ent2_end_id).to(device)

    # BCEWithLogitsLoss combines a sigmoid layer and BCELoss,
    # cf. https://pytorch.org/docs/stable/generated/torch.nn.BCEWithLogitsLoss.html
    bce_loss = nn.BCEWithLogitsLoss()
    optimizer = AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)

    num_training_steps = len(train_dataloader) * num_epochs  # this is the x-axis of the batch-level wandb plots
    num_warmup_steps = int(0.1 * num_training_steps)  # 10% warm-up steps
    scheduler = get_scheduler("linear", optimizer=optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps)

    best_val_loss = float('inf')
    best_model_state = None  # stays None if no epoch ever improves (e.g. num_epochs == 0)
    patience = 1  # number of epochs without a lower validation loss before stopping
    patience_counter = 0

    train_losses, val_losses = [], []
    train_f1s_micro, val_f1s_micro = [], []
    train_f1s_macro, val_f1s_macro = [], []

    global_step = 0
    global_step_val = 0

    for epoch in range(num_epochs):
        model.train()
        total_train_loss = 0
        all_train_labels, all_train_preds = [], []

        # cf. https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html
        for batch in tqdm(train_dataloader, desc=f"Epoch {epoch+1}/{num_epochs}", leave=False):
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['label'].to(device).float()

            optimizer.zero_grad()
            outputs = model(input_ids, attention_mask)  # forward pass: one logit per sequence
            loss = bce_loss(outputs, labels)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
            optimizer.step()
            scheduler.step()

            batch_loss = loss.item()
            total_train_loss += batch_loss

            batch_labels = labels.cpu().numpy()
            preds = (torch.sigmoid(outputs).detach().cpu().numpy() > threshold).astype(int)
            all_train_labels.extend(batch_labels)
            all_train_preds.extend(preds)

            # log batch-level loss and micro/macro F1 to wandb
            wandb.log({
                "step": global_step,
                "train_loss_batch": batch_loss,
                "train_f1_micro_batch": f1_score(batch_labels, preds, average="micro"),
                "train_f1_macro_batch": f1_score(batch_labels, preds, average="macro"),
            })
            global_step += 1

        train_losses.append(total_train_loss / len(train_dataloader))
        train_f1s_micro.append(f1_score(all_train_labels, all_train_preds, average="micro"))
        train_f1s_macro.append(f1_score(all_train_labels, all_train_preds, average="macro"))

        wandb.log({
            "train_loss": train_losses[-1],
            "train_f1_micro": train_f1s_micro[-1],
            "train_f1_macro": train_f1s_macro[-1],
        })

        model.eval()
        total_val_loss = 0
        all_val_labels, all_val_preds = [], []

        with torch.no_grad():
            for batch in tqdm(val_dataloader, desc="Validation", leave=False):
                input_ids = batch['input_ids'].to(device)
                attention_mask = batch['attention_mask'].to(device)
                labels = batch['label'].to(device).float()

                outputs = model(input_ids, attention_mask)
                loss = bce_loss(outputs, labels)

                total_val_loss += loss.item()
                all_val_labels.extend(labels.cpu().numpy())
                # int predictions, same dtype as the training predictions
                all_val_preds.extend((torch.sigmoid(outputs).cpu().numpy() > threshold).astype(int))

                global_step_val += 1

        val_losses.append(total_val_loss / len(val_dataloader))
        val_f1s_micro.append(f1_score(all_val_labels, all_val_preds, average="micro"))
        val_f1s_macro.append(f1_score(all_val_labels, all_val_preds, average="macro"))

        wandb.log({
            "step": global_step_val,
            "val_loss": val_losses[-1],
            "val_f1_micro": val_f1s_micro[-1],
            "val_f1_macro": val_f1s_macro[-1],
        })

        print(f"Epoch {epoch + 1}/{num_epochs}")
        print(f"Training Loss: {train_losses[-1]:.3f}, Train F1 macro: {train_f1s_macro[-1]:.3f},Train F1 micro: {train_f1s_micro[-1]:.3f} ")
        print(f"Validation Loss: {val_losses[-1]:.3f}, Val F1 macro: {val_f1s_macro[-1]:.3f}, Val F1 micro: {val_f1s_micro[-1]:.3f} ")

        # early stopping: stop when the validation loss has not decreased for `patience` epochs
        if val_losses[-1] < best_val_loss:
            best_val_loss = val_losses[-1]
            patience_counter = 0
            # state_dict() returns references to the live parameter tensors, so snapshot a copy.
            # Otherwise later epochs would silently overwrite the "best" weights.
            best_model_state = {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}
        else:
            patience_counter += 1
            if patience_counter >= patience:
                print("Stopping early, no improvement in decreasing loss!")
                break

    # restore the weights of the epoch with the lowest validation loss
    if best_model_state is not None:
        model.load_state_dict(best_model_state)
    model.eval()

    # evaluate on the test set
    all_test_labels, all_test_preds = [], []
    with torch.no_grad():
        for batch in tqdm(test_dataloader, desc="Test Evaluation", leave=False):
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['label'].to(device).float()

            logits = model(input_ids, attention_mask)
            all_test_labels.extend(labels.cpu().numpy())
            all_test_preds.extend((torch.sigmoid(logits).cpu().numpy() > threshold).astype(int))

    test_f1_micro = f1_score(all_test_labels, all_test_preds, average='micro')
    test_f1_macro = f1_score(all_test_labels, all_test_preds, average='macro')

    print(f"Test micro f1: {test_f1_micro:.4f}")
    print(f"Test macro f1: {test_f1_macro:.4f}")

    wandb.log({
        "test_f1_micro": test_f1_micro,
        "test_f1_macro": test_f1_macro,
    })

    wandb.finish()

    return model, test_f1_micro, test_f1_macro, train_losses, val_losses, train_f1s_micro, train_f1s_macro, val_f1s_micro, val_f1s_macro

# modified after Chollet 2018: 75 ADD REFERENCE
def plot_train_val_metrics(train_losses, val_losses, train_f1s, val_f1s, model_name, seed, save_dir="/kaggle/working"):
    """Plot per-epoch train/validation loss (left) and micro F1 (right), save the figure as PNG and show it.

    Args:
        train_losses, val_losses: per-epoch mean losses.
        train_f1s, val_f1s: per-epoch micro F1 scores.
        model_name, seed: used in the output file name.
        save_dir: directory the PNG ``train_val_loss_f1_{model_name}_{seed}.png`` is written to.
            Defaults to the Kaggle working directory used elsewhere in this notebook.
    """
    def _epochs(values):
        # 1-based epoch numbers, matching the "Epoch k/N" lines printed during training
        return range(1, len(values) + 1)

    fig, (loss_ax, f1_ax) = plt.subplots(1, 2, figsize=(12, 6))

    loss_ax.plot(_epochs(train_losses), train_losses, label='Training loss', color="cyan")
    loss_ax.plot(_epochs(val_losses), val_losses, label='Validation loss', linestyle='--', color="magenta")
    loss_ax.set_xlabel("Epochs")
    loss_ax.set_ylabel("Loss")
    loss_ax.legend()

    f1_ax.plot(_epochs(train_f1s), train_f1s, label='Training Micro F1', color="cyan")
    f1_ax.plot(_epochs(val_f1s), val_f1s, label='Validation Micro F1', linestyle='--', color="magenta")
    f1_ax.set_xlabel("Epochs")
    f1_ax.set_ylabel("Micro F1")
    f1_ax.legend()

    fig.tight_layout()
    fig.savefig(os.path.join(save_dir, f"train_val_loss_f1_{model_name}_{seed}.png"))
    plt.show()
    # release the figure: this is called once per seed inside a loop, and open figures would accumulate
    plt.close(fig)

In [None]:
# Catalogue of supported encoders (Hugging Face model id + tokenizer id).
ALL_MODEL_CONFIGS = {
    "PubMedBERT": {
        "model_name": "NeuML/pubmedbert-base-embeddings",
        "tokenizer": "NeuML/pubmedbert-base-embeddings"
    },
    "BERT": {
        "model_name": "bert-base-uncased",
        "tokenizer": "bert-base-uncased"
    },
    "BioBERT": {
        "model_name": "dmis-lab/biobert-v1.1",
        "tokenizer": "dmis-lab/biobert-v1.1"
    }
}

# Only the encoders listed here are trained in this run; add keys of ALL_MODEL_CONFIGS to compare more.
MODELS_TO_RUN = ["PubMedBERT"]
MODEL_CONFIGS = {name: ALL_MODEL_CONFIGS[name] for name in MODELS_TO_RUN}

special_tokens = ['<ent1>', '</ent1>', '<ent2>', '</ent2>']  # entity markers inserted around subject/object
# For multiple seeds use: seeds = [42] + random.sample(range(1, 1000), 4)
seeds = [42]
print(seeds)

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

MODEL_SAVE_DIR = "/kaggle/working/best_models/"
os.makedirs(MODEL_SAVE_DIR, exist_ok=True)

# hyper-parameters (shared by all models)
BATCH_SIZE = 16
NUM_EPOCHS = 6
LR = 1e-5
WEIGHT_DECAY = 0.01
DROPOUT = 0.1
MAX_NORM = 1.0
THRESHOLD = 0.6  # sigmoid probability above which the positive class is predicted

best_models = {}
results = []

for model_name in MODEL_CONFIGS.keys():
    # add the entity markers to the tokenizer (the model's embeddings are resized in train_and_evaluate)
    tokenizer_str = MODEL_CONFIGS[model_name]["tokenizer"]
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_str)
    tokenizer.add_tokens(special_tokens)
    tokenizer_voc_size = len(tokenizer)
    # look the marker ids up by name instead of assuming they are the four largest ids in the vocabulary
    ent1_start_id, ent1_end_id, ent2_start_id, ent2_end_id = tokenizer.convert_tokens_to_ids(special_tokens)

    train_dataloader, val_dataloader, test_dataloader = create_dataloaders(BATCH_SIZE, tokenizer, DEVICE)

    test_micro_f1_scores = []
    test_macro_f1_scores = []
    all_train_losses = []
    all_val_losses = []
    all_train_f1s_micro = []
    all_train_f1s_macro = []
    all_val_f1s_micro = []
    all_val_f1s_macro = []

    best_micro_f1 = -1
    best_macro_f1 = -1
    best_model_state = None
    best_model_seed = None

    for seed in seeds:
        wandb.init(
            project="Relation_Classification",
            entity="lp2",
            config={
                "model": model_name,
                "learning_rate": LR,
                "batch_size": BATCH_SIZE,
                "epochs": NUM_EPOCHS,
                "dropout": DROPOUT,
                "weight_decay": WEIGHT_DECAY,
                "max_norm": MAX_NORM,
                "threshold": THRESHOLD,
                "seed": seed,
            },
            name=f"{model_name}_seed{seed}_{THRESHOLD}",
            group=model_name,  # groups all runs of a given model together
            tags=[model_name, f"seed-{seed}"])
        config = wandb.config

        # train and evaluate the model
        (model, test_micro_f1, test_macro_f1, train_losses, val_losses,
         train_f1s_micro, train_f1_macro, val_f1s_micro, val_f1_macro) = train_and_evaluate(
            model_name, tokenizer_voc_size, seed, train_dataloader, val_dataloader, test_dataloader, LR, WEIGHT_DECAY, NUM_EPOCHS, DROPOUT, DEVICE, MAX_NORM, ent1_start_id, ent1_end_id, ent2_start_id, ent2_end_id, THRESHOLD)

        test_micro_f1_scores.append(test_micro_f1)
        test_macro_f1_scores.append(test_macro_f1)
        all_train_losses.append(train_losses)
        all_val_losses.append(val_losses)
        all_train_f1s_micro.append(train_f1s_micro)
        all_train_f1s_macro.append(train_f1_macro)
        all_val_f1s_micro.append(val_f1s_micro)
        all_val_f1s_macro.append(val_f1_macro)

        plot_train_val_metrics(train_losses, val_losses, train_f1s_micro, val_f1s_micro, model_name, seed)

        # keep the seed with the best test micro F1
        if test_micro_f1 > best_micro_f1:
            best_micro_f1 = test_micro_f1
            best_macro_f1 = test_macro_f1  # macro F1 of the selected (best-micro) run
            # CPU snapshot so the selected weights do not keep an extra model alive on the GPU
            best_model_state = {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}
            best_model_seed = seed

        seed_model_path = os.path.join(MODEL_SAVE_DIR, f"{model_name}_{seed}.pt")
        torch.save(model.state_dict(), seed_model_path)

    best_model_path = os.path.join(MODEL_SAVE_DIR, f"{model_name}_best.pt")
    best_models[model_name] = best_model_path
    torch.save(best_model_state, best_model_path)

    results.append({
        "model": model_name,
        "avg_test_micro_f1": np.mean(test_micro_f1_scores),
        "std_test_micro_f1": np.std(test_micro_f1_scores),
        "avg_test_macro_f1": np.mean(test_macro_f1_scores),
        "std_test_macro_f1": np.std(test_macro_f1_scores),
        "best_micro_f1": best_micro_f1,
        "best_macro_f1": best_macro_f1,
        "best_model_seed": best_model_seed,
        "best_model_path": best_model_path
    })


results_df = pd.DataFrame(results)
results_json_path = "/kaggle/working/relation_classification_results.json"
results_df.to_json(results_json_path, orient="records", indent=4)

wandb.finish()

### Evaluation: Step 1: Generating predictions

Prepare the predictions in the format required by the official GutBrainIE2025 evaluation script.

Given a title+abstract (one article ID) of the test set:

1) Generate candidate pairs. This can be done either by enumerating all ordered pairs of entities in the ground truth OR by using my own NER model to extract entities together with their spans and labels.
2) predict whether there is a relationship between them. Store the labels (as tuple) in a list of relations for this article ID.
3) Deduplicate the predicted (subject label, object label) tuples per article and write them in the required output format.

In [None]:
# Inference setup: load the best checkpoint of the last trained model family into `model`.
# NOTE(review): `model`, `tokenizer`, `model_name`, `best_models` and `DEVICE` are globals from the training cell above.
model.load_state_dict(torch.load(best_models[model_name], map_location=DEVICE))
model = model.eval()
THRESHOLD_INFERENCE = 0.85  # sigmoid cut-off used for the final predictions (stricter than the training THRESHOLD)
THRESHOLD = 0.6
use_ground_truth = True  # True: build candidate pairs from the gold entities in dev.json
# TODO: when use_ground_truth is False, entities should come from an NER model, e.g. entities = run_ner_model(text)
predictions = {}

# ground truth path for NERs
GROUND_TRUTH_PATH = "/kaggle/input/gutbrainie2025/gutbrainie2025/Annotations/Dev/json_format/dev.json"

with open(GROUND_TRUTH_PATH, "r", encoding="utf-8") as f:
    ground_truth_data = json.load(f)
    
def get_ground_truth_entities(abstract_id):
    """Return the gold entity list of one article from `ground_truth_data`.

    Missing articles and articles without an "entities" field both yield an empty list.
    """
    return ground_truth_data.get(abstract_id, {}).get("entities", [])

def generate_candidate_pairs(entities):
    """Return every ordered (subject, object) pair of two different entity mentions.

    Pairs come out in row-major order of the mention positions, e.g. for [a, b, c]:
    (a, b), (a, c), (b, a), (b, c), (c, a), (c, b). Entities are dicts such as
    {'start_idx': 0, 'end_idx': 26, 'location': 'title', 'text_span': 'Lactobacillus fermentum NS9', 'label': 'dietary supplement'}.
    """
    # NOTE(review): a mention is never paired with itself; open question whether reflexive relations (subject == object) can occur.
    return [
        (subject, obj)
        for subject_pos, subject in enumerate(entities)
        for object_pos, obj in enumerate(entities)
        if subject_pos != object_pos
    ]
    
    
if use_ground_truth:
    # inference only: no gradients needed for any candidate pair
    with torch.no_grad():
        for abstract_id, article_data in tqdm(ground_truth_data.items(), desc="Processing Abstracts", unit="abstract"):
            entities = get_ground_truth_entities(abstract_id)
            candidate_pairs = generate_candidate_pairs(entities)

            metadata = article_data.get("metadata", {})
            title = metadata.get("title", "")
            abstract = metadata.get("abstract", "")
            full_text = (title + " " + abstract).strip()  # title and abstract joined by a single space
            offset = len(title) + 1  # offset for abstract positions: title length + the joining space

            # The output is tag-based (one (subject_label, object_label) tuple per article), so once a
            # label pair is predicted for this article, further mentions with the same labels cannot
            # change the result. Their forward pass is skipped.
            predicted_label_pairs = set()
            for entity1, entity2 in candidate_pairs:
                label_pair = (entity1["label"], entity2["label"])
                if label_pair in predicted_label_pairs:
                    continue

                subj_start_idx, subj_end_idx = get_adjusted_indices(entity1, offset)
                obj_start_idx, obj_end_idx = get_adjusted_indices(entity2, offset)

                # insert entity markers in the text for entity1 (subject) and entity2 (object)
                marked_text = mark_entities(full_text, subj_start_idx, subj_end_idx, obj_start_idx, obj_end_idx)

                inputs = tokenizer(marked_text, return_tensors="pt", truncation=True, max_length=512, padding=True)
                inputs = {k: v.to(DEVICE) for k, v in inputs.items()}
                output = model(inputs["input_ids"], inputs["attention_mask"])
                relation_exists = (torch.sigmoid(output) > THRESHOLD_INFERENCE).item()  # True if a relation is predicted

                if relation_exists:
                    predicted_label_pairs.add(label_pair)
                    article_predictions = predictions.setdefault(abstract_id, {"binary_tag_based_relations": []})
                    article_predictions["binary_tag_based_relations"].append(
                        {"subject_label": label_pair[0], "object_label": label_pair[1]}
                    )

# write to the Kaggle working directory, where the evaluation cell reads the predictions from
PREDICTIONS_OUTPUT_PATH = os.path.join("/kaggle/working", f"relation_predictions_{model_name}.json")
with open(PREDICTIONS_OUTPUT_PATH, "w", encoding="utf-8") as f:
    json.dump(predictions, f, indent=4)

### Step 2: Computing metrics with the official challenge script (cf. https://github.com/MMartinelli-hub/GutBrainIE_2025_Baseline/blob/main/Eval/evaluate.py)

In [None]:
# DEFINE HERE THE PATH(S) TO YOUR PREDICTIONS
#PREDICTIONS_PATH_6_1 = 'org_T61_BaselineRun_NuNerZero.json'
PREDICTIONS_PATH_6_2 = "/kaggle/working/relation_predictions_PubMedBERT.json"
#PREDICTIONS_PATH_6_3 = 'org_T622_BaselineRun_ATLOP.json'
#PREDICTIONS_PATH_6_4 = 'org_T623_BaselineRun_ATLOP.json'

# DEFINE HERE FOR WHICH SUBTASK(S) YOU WANT TO EVAL YOUR PREDICTIONS
eval_6_1_NER = False
eval_6_2_binary_tag_RE = True
eval_6_3_ternary_tag_RE = False
eval_6_4_ternary_mention_RE = False

# gold annotations of the dev split; read as the global `ground_truth` by the eval_submission_* functions
GROUND_TRUTH_PATH = "/kaggle/input/gutbrainie2025/gutbrainie2025/Annotations/Dev/json_format/dev.json"
try:
    with open(GROUND_TRUTH_PATH, 'r', encoding='utf-8') as file:
        ground_truth = json.load(file)
except OSError as exc:
    # chain the original error so the underlying cause (missing file, permissions, ...) stays visible
    raise OSError(f'Error in opening the specified json file: {GROUND_TRUTH_PATH}') from exc

# entity labels accepted in predictions (entity labels in 6.1, subject/object labels in the RE subtasks)
LEGAL_ENTITY_LABELS = [
    "anatomical location",
    "animal",
    "bacteria",
    "biomedical technique",
    "chemical",
    "DDF",
    "dietary supplement",
    "drug",
    "food",
    "gene",
    "human",
    "microbiome",
    "statistical technique"
]

# predicate labels accepted in the ternary relation subtasks
LEGAL_RELATION_LABELS = [
    "administered",
    "affect",
    "change abundance",
    "change effect",
    "change expression",
    "compared to",
    "impact",
    "influence",
    "interact",
    "is a",
    "is linked to",
    "located in",
    "part of",
    "produced by",
    "strike",
    "target",
    "used by"
]


def eval_submission_6_1_NER(path):
    """Score NER predictions (subtask 6.1) against the global ``ground_truth``.

    A predicted entity is a true positive only if the exact tuple
    (start_idx, end_idx, location, text_span, label) is annotated for the same article. Only labels
    that occur in the ground truth are scored; predictions with other (legal) labels are ignored in
    the counts. Every denominator carries a 1e-10 smoothing term, as in the official GutBrainIE2025
    evaluation script.

    Args:
        path: path to a JSON file mapping PMID -> {"entities": [...]}.

    Returns:
        (macro_precision, macro_recall, macro_f1, micro_precision, micro_recall, micro_f1).
        Macro values are 0.0 when the ground truth contains no entities.

    Raises:
        OSError: the prediction file cannot be opened.
        KeyError: a required field is missing, or a PMID with predictions is not in the ground truth.
        NameError: an entity label is not in LEGAL_ENTITY_LABELS.
    """
    try:
        with open(path, 'r', encoding='utf-8') as file:
            predictions = json.load(file)
    except OSError as exc:
        raise OSError(f'Error in opening the specified json file: {path}') from exc

    # gold entity tuples per article (set: O(1) membership) and number of gold entities per label
    ground_truth_NER = {}
    count_annotated_entities_per_label = {}
    for pmid, article in ground_truth.items():
        gold_entries = ground_truth_NER.setdefault(pmid, set())
        for entity in article['entities']:
            start_idx = int(entity["start_idx"])
            end_idx = int(entity["end_idx"])
            location = str(entity["location"])
            text_span = str(entity["text_span"])
            label = str(entity["label"])
            gold_entries.add((start_idx, end_idx, location, text_span, label))
            count_annotated_entities_per_label[label] = count_annotated_entities_per_label.get(label, 0) + 1

    count_predicted_entities_per_label = dict.fromkeys(count_annotated_entities_per_label, 0)
    count_true_positives_per_label = dict.fromkeys(count_annotated_entities_per_label, 0)

    for pmid in predictions.keys():
        try:
            entities = predictions[pmid]['entities']
        except KeyError as exc:
            raise KeyError(f'{pmid} - Not able to find field \"entities\" within article') from exc
        if entities and pmid not in ground_truth_NER:
            raise KeyError(f'{pmid} - Article not found in the ground truth')

        for entity in entities:
            try:
                start_idx = int(entity["start_idx"])
                end_idx = int(entity["end_idx"])
                location = str(entity["location"])
                text_span = str(entity["text_span"])
                label = str(entity["label"])
            except KeyError as exc:
                raise KeyError(f'{pmid} - Not able to find one or more of the expected fields for entity: {entity}') from exc

            if label not in LEGAL_ENTITY_LABELS:
                raise NameError(f'{pmid} - Illegal label {label} for entity: {entity}')

            # labels that never occur in the ground truth are not counted as predictions
            if label in count_predicted_entities_per_label:
                count_predicted_entities_per_label[label] += 1

            if (start_idx, end_idx, location, text_span, label) in ground_truth_NER[pmid]:
                count_true_positives_per_label[label] += 1

    count_annotated_entities = sum(count_annotated_entities_per_label.values())
    count_predicted_entities = sum(count_predicted_entities_per_label.values())
    count_true_positives = sum(count_true_positives_per_label.values())

    micro_precision = count_true_positives / (count_predicted_entities + 1e-10)
    micro_recall = count_true_positives / (count_annotated_entities + 1e-10)
    micro_f1 = 2 * ((micro_precision * micro_recall) / (micro_precision + micro_recall + 1e-10))

    n = len(count_annotated_entities_per_label)
    if n == 0:
        # no annotated entities: macro averages are undefined, report 0 instead of dividing by zero
        return 0.0, 0.0, 0.0, micro_precision, micro_recall, micro_f1

    precision, recall, f1 = 0, 0, 0
    for label, annotated in count_annotated_entities_per_label.items():
        current_precision = count_true_positives_per_label[label] / (count_predicted_entities_per_label[label] + 1e-10)
        current_recall = count_true_positives_per_label[label] / (annotated + 1e-10)

        precision += current_precision
        recall += current_recall
        f1 += 2 * ((current_precision * current_recall) / (current_precision + current_recall + 1e-10))

    return precision / n, recall / n, f1 / n, micro_precision, micro_recall, micro_f1

def eval_submission_6_2_binary_tag_RE(path):
    """Score binary tag-based relation predictions (subtask 6.2) against the global ``ground_truth``.

    Each relation is reduced to its (subject_label, object_label) tuple. A predicted tuple is a true
    positive if the same tuple is annotated for the same article. Only tuples that occur somewhere in
    the ground truth are counted as predictions; tuples that are never annotated are ignored. Every
    denominator carries a 1e-10 smoothing term, as in the official GutBrainIE2025 evaluation script.

    Args:
        path: path to a JSON file mapping PMID -> {"binary_tag_based_relations": [...]}.

    Returns:
        (macro_precision, macro_recall, macro_f1, micro_precision, micro_recall, micro_f1).
        Macro values are 0.0 when the ground truth contains no relations.

    Raises:
        OSError: the prediction file cannot be opened.
        KeyError: a required field is missing, or a PMID with predictions is not in the ground truth.
        NameError: a subject/object label is not in LEGAL_ENTITY_LABELS.
    """
    try:
        with open(path, 'r', encoding='utf-8') as file:
            predictions = json.load(file)
    except OSError as exc:
        raise OSError(f'Error in opening the specified json file: {path}') from exc

    # gold (subject_label, object_label) tuples per article (set: O(1) membership) and gold counts per tuple
    ground_truth_binary_tag_RE = {}
    count_annotated_relations_per_label = {}
    for pmid, article in ground_truth.items():
        gold_labels = ground_truth_binary_tag_RE.setdefault(pmid, set())
        for relation in article['binary_tag_based_relations']:
            label = (str(relation["subject_label"]), str(relation["object_label"]))
            gold_labels.add(label)
            count_annotated_relations_per_label[label] = count_annotated_relations_per_label.get(label, 0) + 1

    count_predicted_relations_per_label = dict.fromkeys(count_annotated_relations_per_label, 0)
    count_true_positives_per_label = dict.fromkeys(count_annotated_relations_per_label, 0)

    for pmid in predictions.keys():
        try:
            relations = predictions[pmid]['binary_tag_based_relations']
        except KeyError as exc:
            raise KeyError(f'{pmid} - Not able to find field \"binary_tag_based_relations\" within article') from exc
        if relations and pmid not in ground_truth_binary_tag_RE:
            raise KeyError(f'{pmid} - Article not found in the ground truth')

        for relation in relations:
            try:
                subject_label = str(relation["subject_label"])
                object_label = str(relation["object_label"])
            except KeyError as exc:
                raise KeyError(f'{pmid} - Not able to find one or more of the expected fields for relation: {relation}') from exc

            if subject_label not in LEGAL_ENTITY_LABELS:
                raise NameError(f'{pmid} - Illegal subject entity label {subject_label} for relation: {relation}')

            if object_label not in LEGAL_ENTITY_LABELS:
                raise NameError(f'{pmid} - Illegal object entity label {object_label} for relation: {relation}')

            label = (subject_label, object_label)
            # tuples that never occur in the ground truth are not counted as predictions
            if label in count_predicted_relations_per_label:
                count_predicted_relations_per_label[label] += 1

            if label in ground_truth_binary_tag_RE[pmid]:
                count_true_positives_per_label[label] += 1

    count_annotated_relations = sum(count_annotated_relations_per_label.values())
    count_predicted_relations = sum(count_predicted_relations_per_label.values())
    count_true_positives = sum(count_true_positives_per_label.values())

    micro_precision = count_true_positives / (count_predicted_relations + 1e-10)
    micro_recall = count_true_positives / (count_annotated_relations + 1e-10)
    micro_f1 = 2 * ((micro_precision * micro_recall) / (micro_precision + micro_recall + 1e-10))

    n = len(count_annotated_relations_per_label)
    if n == 0:
        # no annotated relations: macro averages are undefined, report 0 instead of dividing by zero
        return 0.0, 0.0, 0.0, micro_precision, micro_recall, micro_f1

    precision, recall, f1 = 0, 0, 0
    for label, annotated in count_annotated_relations_per_label.items():
        current_precision = count_true_positives_per_label[label] / (count_predicted_relations_per_label[label] + 1e-10)
        current_recall = count_true_positives_per_label[label] / (annotated + 1e-10)

        precision += current_precision
        recall += current_recall
        f1 += 2 * ((current_precision * current_recall) / (current_precision + current_recall + 1e-10))

    return precision / n, recall / n, f1 / n, micro_precision, micro_recall, micro_f1


def eval_submission_6_3_ternary_tag_RE(path):
    """Score ternary tag-based relation predictions (subtask 6.3) against the global ``ground_truth``.

    Each relation is reduced to its (subject_label, predicate, object_label) tuple. A predicted tuple
    is a true positive if the same tuple is annotated for the same article. Only tuples that occur
    somewhere in the ground truth are counted as predictions; tuples that are never annotated are
    ignored. Every denominator carries a 1e-10 smoothing term, as in the official GutBrainIE2025
    evaluation script.

    Args:
        path: path to a JSON file mapping PMID -> {"ternary_tag_based_relations": [...]}.

    Returns:
        (macro_precision, macro_recall, macro_f1, micro_precision, micro_recall, micro_f1).
        Macro values are 0.0 when the ground truth contains no relations.

    Raises:
        OSError: the prediction file cannot be opened.
        KeyError: a required field is missing, or a PMID with predictions is not in the ground truth.
        NameError: an entity label is not in LEGAL_ENTITY_LABELS, or a predicate is not in LEGAL_RELATION_LABELS.
    """
    try:
        with open(path, 'r', encoding='utf-8') as file:
            predictions = json.load(file)
    except OSError as exc:
        raise OSError(f'Error in opening the specified json file: {path}') from exc

    # gold (subject, predicate, object) tuples per article (set: O(1) membership) and gold counts per tuple
    ground_truth_ternary_tag_RE = {}
    count_annotated_relations_per_label = {}
    for pmid, article in ground_truth.items():
        gold_labels = ground_truth_ternary_tag_RE.setdefault(pmid, set())
        for relation in article['ternary_tag_based_relations']:
            label = (str(relation["subject_label"]), str(relation["predicate"]), str(relation["object_label"]))
            gold_labels.add(label)
            count_annotated_relations_per_label[label] = count_annotated_relations_per_label.get(label, 0) + 1

    count_predicted_relations_per_label = dict.fromkeys(count_annotated_relations_per_label, 0)
    count_true_positives_per_label = dict.fromkeys(count_annotated_relations_per_label, 0)

    for pmid in predictions.keys():
        try:
            relations = predictions[pmid]['ternary_tag_based_relations']
        except KeyError as exc:
            raise KeyError(f'{pmid} - Not able to find field \"ternary_tag_based_relations\" within article') from exc
        if relations and pmid not in ground_truth_ternary_tag_RE:
            raise KeyError(f'{pmid} - Article not found in the ground truth')

        for relation in relations:
            try:
                subject_label = str(relation["subject_label"])
                predicate = str(relation["predicate"])
                object_label = str(relation["object_label"])
            except KeyError as exc:
                raise KeyError(f'{pmid} - Not able to find one or more of the expected fields for relation: {relation}') from exc

            if subject_label not in LEGAL_ENTITY_LABELS:
                raise NameError(f'{pmid} - Illegal subject entity label {subject_label} for relation: {relation}')

            if object_label not in LEGAL_ENTITY_LABELS:
                raise NameError(f'{pmid} - Illegal object entity label {object_label} for relation: {relation}')

            if predicate not in LEGAL_RELATION_LABELS:
                raise NameError(f'{pmid} - Illegal predicate {predicate} for relation: {relation}')

            label = (subject_label, predicate, object_label)
            # tuples that never occur in the ground truth are not counted as predictions
            if label in count_predicted_relations_per_label:
                count_predicted_relations_per_label[label] += 1

            if label in ground_truth_ternary_tag_RE[pmid]:
                count_true_positives_per_label[label] += 1

    count_annotated_relations = sum(count_annotated_relations_per_label.values())
    count_predicted_relations = sum(count_predicted_relations_per_label.values())
    count_true_positives = sum(count_true_positives_per_label.values())

    micro_precision = count_true_positives / (count_predicted_relations + 1e-10)
    micro_recall = count_true_positives / (count_annotated_relations + 1e-10)
    micro_f1 = 2 * ((micro_precision * micro_recall) / (micro_precision + micro_recall + 1e-10))

    n = len(count_annotated_relations_per_label)
    if n == 0:
        # no annotated relations: macro averages are undefined, report 0 instead of dividing by zero
        return 0.0, 0.0, 0.0, micro_precision, micro_recall, micro_f1

    precision, recall, f1 = 0, 0, 0
    for label, annotated in count_annotated_relations_per_label.items():
        current_precision = count_true_positives_per_label[label] / (count_predicted_relations_per_label[label] + 1e-10)
        current_recall = count_true_positives_per_label[label] / (annotated + 1e-10)

        precision += current_precision
        recall += current_recall
        f1 += 2 * ((current_precision * current_recall) / (current_precision + current_recall + 1e-10))

    return precision / n, recall / n, f1 / n, micro_precision, micro_recall, micro_f1


def eval_submission_6_4_ternary_mention_RE(path):
    """Score ternary mention-based relation extraction predictions (subtask 6.4).

    A predicted relation is a true positive when its full tuple
    ``(subject_text_span, subject_label, predicate, object_text_span, object_label)``
    appears among the ground-truth relations of the same article. Counts are
    aggregated per ``(subject_label, predicate, object_label)`` label; predictions
    whose label never occurs in the ground truth are ignored by both the micro and
    the macro scores.

    Reads the module-level ``ground_truth`` dict (pmid -> article holding a
    ``"ternary_mention_based_relations"`` list) and validates predictions against
    the module-level ``LEGAL_ENTITY_LABELS`` / ``LEGAL_RELATION_LABELS``.

    Args:
        path: Path to a UTF-8 JSON file mapping pmid -> article, where each article
            has a ``"ternary_mention_based_relations"`` list of relation dicts.

    Returns:
        Tuple ``(precision, recall, f1, micro_precision, micro_recall, micro_f1)``.
        The first three are macro-averaged over the ground-truth labels and are
        0.0 when the ground truth contains no relations at all.

    Raises:
        OSError: If the predictions file cannot be opened.
        KeyError: If an article or a relation lacks an expected field.
        NameError: If a relation uses an entity label or predicate that is not legal.
    """
    eps = 1e-10  # keeps every ratio finite when a denominator count is zero

    try:
        with open(path, 'r', encoding='utf-8') as file:
            predictions = json.load(file)
    except OSError as err:
        raise OSError(f'Error in opening the specified json file: {path}') from err

    # pmid -> set of gold (subject_text_span, subject_label, predicate,
    # object_text_span, object_label) tuples; a set gives O(1) membership checks.
    ground_truth_ternary_mention_RE = {}
    count_annotated_relations_per_label = {}

    for pmid, article in ground_truth.items():
        gold_entries = ground_truth_ternary_mention_RE.setdefault(pmid, set())
        for relation in article['ternary_mention_based_relations']:
            subject_text_span = str(relation["subject_text_span"])
            subject_label = str(relation["subject_label"])
            predicate = str(relation["predicate"])
            object_text_span = str(relation["object_text_span"])
            object_label = str(relation["object_label"])

            gold_entries.add((subject_text_span, subject_label, predicate, object_text_span, object_label))

            label = (subject_label, predicate, object_label)
            count_annotated_relations_per_label[label] = count_annotated_relations_per_label.get(label, 0) + 1

    count_predicted_relations_per_label = dict.fromkeys(count_annotated_relations_per_label, 0)
    count_true_positives_per_label = dict.fromkeys(count_annotated_relations_per_label, 0)

    for pmid, article in predictions.items():
        try:
            relations = article['ternary_mention_based_relations']
        except KeyError:
            raise KeyError(f'{pmid} - Not able to find field \"ternary_mention_based_relations\" within article')

        # Articles absent from the ground truth contain no gold relations, so their
        # predictions can only be false positives (rather than a bare KeyError).
        gold_entries = ground_truth_ternary_mention_RE.get(pmid, frozenset())

        for relation in relations:
            try:
                subject_text_span = str(relation["subject_text_span"])
                subject_label = str(relation["subject_label"])
                predicate = str(relation["predicate"])
                object_text_span = str(relation["object_text_span"])
                object_label = str(relation["object_label"])
            except KeyError:
                raise KeyError(f'{pmid} - Not able to find one or more of the expected fields for relation: {relation}')

            if subject_label not in LEGAL_ENTITY_LABELS:
                raise NameError(f'{pmid} - Illegal subject entity label {subject_label} for relation: {relation}')

            if object_label not in LEGAL_ENTITY_LABELS:
                raise NameError(f'{pmid} - Illegal object entity label {object_label} for relation: {relation}')

            if predicate not in LEGAL_RELATION_LABELS:
                raise NameError(f'{pmid} - Illegal predicate {predicate} for relation: {relation}')

            entry = (subject_text_span, subject_label, predicate, object_text_span, object_label)
            label = (subject_label, predicate, object_label)

            # Only labels that exist in the ground truth take part in scoring.
            if label in count_predicted_relations_per_label:
                count_predicted_relations_per_label[label] += 1

            # NOTE(review): duplicate predictions of the same gold tuple are each
            # counted as a true positive, so recall can exceed 1; kept as-is to
            # match the existing scoring behavior.
            if entry in gold_entries:
                count_true_positives_per_label[label] += 1

    count_annotated_relations = sum(count_annotated_relations_per_label.values())
    count_predicted_relations = sum(count_predicted_relations_per_label.values())
    count_true_positives = sum(count_true_positives_per_label.values())

    micro_precision = count_true_positives / (count_predicted_relations + eps)
    micro_recall = count_true_positives / (count_annotated_relations + eps)
    micro_f1 = 2 * ((micro_precision * micro_recall) / (micro_precision + micro_recall + eps))

    n_labels = len(count_annotated_relations_per_label)
    if n_labels == 0:
        # No gold relations: the macro average over labels is empty, so report
        # 0.0 instead of raising ZeroDivisionError.
        return 0.0, 0.0, 0.0, micro_precision, micro_recall, micro_f1

    precision, recall, f1 = 0, 0, 0
    for label, annotated in count_annotated_relations_per_label.items():
        true_positives = count_true_positives_per_label[label]
        current_precision = true_positives / (count_predicted_relations_per_label[label] + eps)
        current_recall = true_positives / (annotated + eps)

        precision += current_precision
        recall += current_recall
        f1 += 2 * ((current_precision * current_recall) / (current_precision + current_recall + eps))

    return precision / n_labels, recall / n_labels, f1 / n_labels, micro_precision, micro_recall, micro_f1


def format_eval_results(title, precision, recall, f1, micro_precision, micro_recall, micro_f1, ndigits=4):
    """Render one subtask's macro/micro scores as a console report.

    Args:
        title: Subtask name shown in the ``=== title ===`` header.
        precision, recall, f1: Macro-averaged scores.
        micro_precision, micro_recall, micro_f1: Micro-averaged scores.
        ndigits: Number of decimal places passed to ``round`` for every score.

    Returns:
        A string that starts with two newlines and the header, followed by one
        ``Name: value`` line per metric, with no trailing newline.
    """
    metrics = (
        ("Macro-precision", precision),
        ("Macro-recall", recall),
        ("Macro-F1", f1),
        ("Micro-precision", micro_precision),
        ("Micro-recall", micro_recall),
        ("Micro-F1", micro_f1),
    )
    lines = [f"\n\n=== {title} ==="]
    lines.extend(f"{name}: {round(value, ndigits)}" for name, value in metrics)
    return "\n".join(lines)


if __name__ == '__main__':
    round_to_decimal_position = 4

    if eval_6_1_NER:
        precision, recall, f1, micro_precision, micro_recall, micro_f1 = eval_submission_6_1_NER(PREDICTIONS_PATH_6_1)
        print(format_eval_results("6_1_NER", precision, recall, f1, micro_precision, micro_recall, micro_f1,
                                  ndigits=round_to_decimal_position))

    if eval_6_2_binary_tag_RE:
        precision, recall, f1, micro_precision, micro_recall, micro_f1 = eval_submission_6_2_binary_tag_RE(PREDICTIONS_PATH_6_2)
        results_text = format_eval_results("6_2_binary_tag_RE", precision, recall, f1,
                                           micro_precision, micro_recall, micro_f1,
                                           ndigits=round_to_decimal_position) + "\n"
        print(results_text)

        # Saving lives inside this branch because results_text only exists when
        # the 6_2 evaluation actually ran.
        with open("evaluation_results.txt", "w", encoding="utf-8") as file:
            file.write(results_text)

    if eval_6_3_ternary_tag_RE:
        precision, recall, f1, micro_precision, micro_recall, micro_f1 = eval_submission_6_3_ternary_tag_RE(PREDICTIONS_PATH_6_3)
        print(format_eval_results("6_3_ternary_tag_RE", precision, recall, f1, micro_precision, micro_recall, micro_f1,
                                  ndigits=round_to_decimal_position))

    if eval_6_4_ternary_mention_RE:
        precision, recall, f1, micro_precision, micro_recall, micro_f1 = eval_submission_6_4_ternary_mention_RE(PREDICTIONS_PATH_6_4)
        print(format_eval_results("6_4_ternary_mention_RE", precision, recall, f1, micro_precision, micro_recall, micro_f1,
                                  ndigits=round_to_decimal_position))