In [None]:
import json
import gzip
import os
import random
from google.colab import drive
from transformers import AutoTokenizer, AutoModelForSequenceClassification, Trainer, TrainingArguments
import torch
from sklearn.model_selection import train_test_split
from datasets import Dataset
import pandas as pd
from sklearn.metrics import accuracy_score, precision_recall_fscore_support


# Expose the Colab user's Google Drive under /content/drive so the gzipped
# MedMentions/UMLS files can be opened with ordinary file paths later on.
drive.mount(mountpoint='/content/drive')

# Data loading and positive/negative datasets construction
def load_json_gz(filepath):
    """Read a gzip-compressed JSON file and return the decoded value.

    Args:
        filepath: Path to a ``.json.gz`` file whose content is UTF-8 JSON.

    Returns:
        Whatever top-level JSON value the file holds (list, dict, ...).
    """
    handle = gzip.open(filepath, mode='rt', encoding='utf-8')
    with handle:
        payload = json.load(handle)
    return payload

def create_entity_linking_samples(medmentions_data, umls_data, max_negative_per_positive=2, sample_size=100, seed=None):
    """
    Generates positive and negative samples for entity linking with balanced sampling.

    Every annotation that has a 'text' and a 'cui' present in ``umls_data`` yields one
    positive pair "<mention> [SEP] <gold concept name>" (label 1) and up to
    ``max_negative_per_positive`` negative pairs "<mention> [SEP] <other concept name>"
    (label 0), the other concept drawn uniformly from the remaining UMLS CUIs. The larger
    class is then randomly down-sampled so both labels occur equally often.

    Args:
    medmentions_data: A list of MedMentions data dictionaries; documents without an
        'annotations' key are skipped.
    umls_data: A list of UMLS data dictionaries; entries missing 'cui' or 'name' are ignored.
    max_negative_per_positive: Maximum number of negative samples to generate for each positive sample.
    sample_size: The number of MedMentions entries to use, or None to use all of them.
    seed: Optional seed for a private random.Random. None (default) draws from the
        global ``random`` module, so a prior ``random.seed(...)`` still applies.

    Returns:
    A shuffled list of dictionaries, each representing a sample with "text", "label".
    Empty when no negatives are generated (e.g. max_negative_per_positive <= 0, or fewer
    than two usable UMLS concepts) or when no annotation matches a UMLS concept.
    """
    rng = random if seed is None else random.Random(seed)

    umls_dict = {entry['cui']: entry for entry in umls_data if 'cui' in entry and 'name' in entry}
    # Build the candidate list once instead of re-filtering every CUI for each negative
    # draw (O(|UMLS|) per draw). With fewer than two CUIs no negative can exist.
    all_cuis = list(umls_dict)
    can_draw_negatives = len(all_cuis) > 1

    if sample_size is None:
        sampled_docs = list(medmentions_data)
    else:
        sampled_docs = rng.sample(medmentions_data, min(sample_size, len(medmentions_data)))

    positives = []
    negatives = []
    for doc in sampled_docs:
        if 'annotations' not in doc:
            continue
        for annotation in doc['annotations']:
            if not ('cui' in annotation and annotation['cui'] in umls_dict and 'text' in annotation):
                continue
            gold_cui = annotation['cui']
            mention = annotation['text']
            positives.append({"text": f"{mention} [SEP] {umls_dict[gold_cui]['name']}", "label": 1})
            if not can_draw_negatives:
                continue
            for _ in range(max_negative_per_positive):
                # Rejection sampling: gold_cui is one of >= 2 candidates, so this needs
                # at most two draws on average.
                negative_cui = rng.choice(all_cuis)
                while negative_cui == gold_cui:
                    negative_cui = rng.choice(all_cuis)
                negatives.append({"text": f"{mention} [SEP] {umls_dict[negative_cui]['name']}", "label": 0})

    # Balance the dataset, ensure 50/50 positive/negative by down-sampling the larger class.
    n_per_class = min(len(positives), len(negatives))
    final_samples = rng.sample(positives, n_per_class) + rng.sample(negatives, n_per_class)
    # Shuffle so labels never arrive grouped in blocks, even if the caller does not shuffle.
    rng.shuffle(final_samples)
    return final_samples

"""# load MedMentions dataset
def load_medmentions(file_path):
    with gzip.open(file_path, 'rt', encoding='utf-8') as f:
        return json.load(f)

# load UMLS dataset
def load_umls(file_path):
    with gzip.open(file_path, 'rt', encoding='utf-8') as f:
        return json.load(f)


medmentions_data = load_medmentions(medmentions_path)
umls_data = load_umls(umls_path)

# create index of UMLS 
umls_index = {entry['cui']: entry for entry in umls_data}

# construct positive and negative samples
positive_examples = []
negative_examples = []

for doc in medmentions_data[:10]:
    if 'annotations' in doc:
        for annotation in doc['annotations']:
            if 'cui' in annotation and annotation['cui'] in umls_index:
                # 正例
                positive_examples.append({
                    "text": f"{annotation['text']} [SEP] {umls_index[annotation['cui']]['name']}",
                    "label": 1
                })
                # 负例
                random_cui = random.choice(list(umls_index.keys()))
                while random_cui == annotation['cui']:
                  random_cui = random.choice(list(umls_index.keys()))

                  #random_cui = random.choice([c for c in umls_index.keys() if c != annotation['cui']])
                negative_examples.append({
                    "text": f"{annotation['text']} [SEP] {umls_index[random_cui]['name']}",
                    "label": 0
                })

all_examples = positive_examples + negative_examples
random.shuffle(all_examples)
df = pd.DataFrame(all_examples)
print(df)"""

In [None]:
if __name__ == '__main__':
    import inspect

    # Seed every RNG the pipeline touches so "Restart & Run All" reproduces the same
    # sampled documents, splits and fine-tuning run.
    SEED = 42
    random.seed(SEED)
    torch.manual_seed(SEED)

    # Assume the files are in the "Colab" directory in Google Drive
    medmentions_filepath = "/content/drive/My Drive/Colab/medmentions.json.gz"
    umls_filepath = "/content/drive/My Drive/Colab/umls.json.gz"

    # Load data from the files. A missing file must stop the run with a visible error;
    # falling through would reach the split below with `df` undefined.
    try:
        medmentions_data = load_json_gz(medmentions_filepath)
        umls_data = load_json_gz(umls_filepath)
    except FileNotFoundError as err:
        raise FileNotFoundError(
            f"Error: Could not find either '{medmentions_filepath}' or '{umls_filepath}'. "
            "Please make sure these files are in the specified directory on Google Drive."
        ) from err

    # Create samples
    samples = create_entity_linking_samples(medmentions_data, umls_data, sample_size=100, max_negative_per_positive=2)
    if not samples:
        raise ValueError("No entity-linking samples were generated; check that MedMentions CUIs overlap the UMLS file.")

    df = pd.DataFrame(samples).sample(frac=1, random_state=SEED).reset_index(drop=True)  # shuffle the dataframe
    print(df.head())
    print(df['label'].value_counts())

    # Split dataset 70/15/15; stratify so the 50/50 label balance holds in every split.
    train_texts, temp_texts, train_labels, temp_labels = train_test_split(
        df['text'], df['label'], test_size=0.3, random_state=SEED, stratify=df['label']
    )
    val_texts, test_texts, val_labels, test_labels = train_test_split(
        temp_texts, temp_labels, test_size=0.5, random_state=SEED, stratify=temp_labels
    )

    # Transfer to HuggingFace data structure
    train_dataset = Dataset.from_dict({"text": train_texts.tolist(), "label": train_labels.tolist()})
    val_dataset = Dataset.from_dict({"text": val_texts.tolist(), "label": val_labels.tolist()})
    test_dataset = Dataset.from_dict({"text": test_texts.tolist(), "label": test_labels.tolist()})

    # Load tokenizer and model
    model_name = "allenai/biomed_roberta_base"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=2)

    # Tokenize datasets
    def tokenize_function(examples):
        """Encode "<mention> [SEP] <concept name>" strings as tokenizer sentence pairs.

        RoBERTa-style tokenizers do not treat the literal "[SEP]" as a special token, so
        the two halves are split apart and passed as a pair, letting the tokenizer insert
        its own separator. No padding here: each batch is padded dynamically by the
        Trainer's default collator (a tokenizer is handed to the Trainer below), instead
        of padding every short pair to 512 tokens.
        """
        mentions, concept_names = [], []
        for text in examples['text']:
            mention, _, concept_name = text.partition(" [SEP] ")
            mentions.append(mention)
            concept_names.append(concept_name)
        return tokenizer(mentions, concept_names, truncation=True, max_length=512)

    tokenized_train = train_dataset.map(tokenize_function, batched=True)
    tokenized_val = val_dataset.map(tokenize_function, batched=True)
    tokenized_test = test_dataset.map(tokenize_function, batched=True)

    # Recent transformers releases renamed `evaluation_strategy` -> `eval_strategy` and
    # Trainer's `tokenizer` -> `processing_class`; use whichever this install accepts.
    training_arg_params = inspect.signature(TrainingArguments.__init__).parameters
    eval_strategy_kw = "eval_strategy" if "eval_strategy" in training_arg_params else "evaluation_strategy"
    trainer_params = inspect.signature(Trainer.__init__).parameters
    tokenizer_kw = "processing_class" if "processing_class" in trainer_params else "tokenizer"

    # HuggingFace Trainer
    training_args = TrainingArguments(
        output_dir="./results",
        learning_rate=5e-5,
        per_device_train_batch_size=16,
        per_device_eval_batch_size=16,
        num_train_epochs=1,
        weight_decay=0.01,
        logging_dir="./logs",
        save_strategy="epoch",
        seed=SEED,
        **{eval_strategy_kw: "epoch"},
    )

    # Initialize Trainer
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=tokenized_train,
        eval_dataset=tokenized_val,
        **{tokenizer_kw: tokenizer},
    )

    # Training
    trainer.train()
    print("Training finished")

In [None]:
# Evaluation function
def compute_metrics(eval_pred):
    """Score binary entity-linking predictions.

    Args:
        eval_pred: ``(logits, labels)`` pair, e.g. ``(test_results.predictions,
            test_results.label_ids)``. Logits are assumed to be a NumPy-like array whose
            last axis holds one score per class (TODO confirm for other callers).

    Returns:
        dict with "accuracy", "precision", "recall" and "f1"; precision/recall/F1 treat
        label 1 (correct link) as the positive class.
    """
    logits, labels = eval_pred
    predictions = logits.argmax(axis=-1)
    # zero_division=0: a briefly trained model may predict a single class; report 0
    # instead of emitting UndefinedMetricWarning.
    precision, recall, f1, _ = precision_recall_fscore_support(
        labels, predictions, average="binary", zero_division=0
    )
    acc = accuracy_score(labels, predictions)
    return {"accuracy": acc, "precision": precision, "recall": recall, "f1": f1}


# Use Trainer to evaluate test datasets
test_results = trainer.predict(tokenized_test)
metrics = compute_metrics((test_results.predictions, test_results.label_ids))
print(metrics)