# Set Up

In [None]:
import pandas as pd
from tqdm.auto import tqdm
import re
import ast
import nltk
from nltk.tokenize import sent_tokenize

# Register tqdm with pandas so the `progress_apply` calls used below are available
tqdm.pandas()

# Filter

In [None]:
# Load the PubMedQA "pqa_artificial" training split from the Hugging Face Hub
pubmed = pd.read_parquet("hf://datasets/qiaojin/PubMedQA/pqa_artificial/train-00000-of-00001.parquet")

# Keep only the columns used downstream and expose `long_answer` as `answer`.
# `.rename` hands back a new frame, so the column assignments made in later
# cells do not write into a selection of `pubmed` (the pattern that triggers
# pandas' SettingWithCopyWarning).
pubmed_raw_df = (
    pubmed[['question', 'context', 'long_answer', 'final_decision']]
    .rename(columns={'long_answer': 'answer'})
)

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


In [None]:
# Extract context from context dictionary

def extract_contexts(row):
    """Return the passages under ``row['context']['contexts']`` as one string.

    The passages are joined with a single space. When the ``"contexts"`` key
    is absent the result is an empty string.
    """
    passages = row['context'].get("contexts", [])
    return " ".join(passages)

def extract_meshes(row):
    """Return the MeSH terms stored under ``row['context']['meshes']``.

    The stored value is handed back as-is (it is not joined into a string).
    When the ``"meshes"`` key is absent the result is an empty list.
    """
    context_fields = row['context']
    mesh_terms = context_fields.get("meshes", [])
    return mesh_terms

# Apply to the dataframe.
# Both derived columns are added and the nested `context` column is dropped in
# one chained expression: `.assign` returns a new frame, so nothing is written
# in place into a frame that may itself be a selection of another one.
pubmed_raw_df = (
    pubmed_raw_df
    .assign(
        extracted_contexts=pubmed_raw_df.progress_apply(extract_contexts, axis=1),
        meshes=pubmed_raw_df.progress_apply(extract_meshes, axis=1),
    )
    .drop(columns='context')
)

  0%|          | 0/211269 [00:00<?, ?it/s]

  0%|          | 0/211269 [00:00<?, ?it/s]

In [None]:
# Filter to diabetes related QAs
# Lower-case keywords; a record counts as diabetes-related when any of them
# occurs as a substring of its (lower-cased) MeSH terms.
diabetes_mesh_terms = [
    "diabetes mellitus",
    "diabetes mellitus, type 1",
    "diabetes mellitus, type 2",
    "gestational diabetes",
    "prediabetic state",
    "hyperglycemia",
    "hypoglycemia",
    "insulin resistance",
    "glucose intolerance",
    "glucose metabolism disorders",
    "glycated hemoglobin a",
    "hba1c",
    "insulin",
    "metformin",
    "diabetic retinopathy",
    "diabetic nephropathy",
    "diabetic neuropathy",
    "diabetic foot",
    "diabetic ketoacidosis",
    "metabolic syndrome"
]

def is_diabetes_related(meshes):
    """Check if MeSH terms contain diabetes-related keywords.

    Args:
        meshes: the MeSH terms of one record. May be a single string or any
            iterable of terms (list, tuple, numpy array, ...). Elements are
            converted with ``str`` so non-string entries do not raise.
            Anything that cannot be iterated (e.g. ``None``) is matched
            against its ``str()`` form.

    Returns:
        True when at least one entry of ``diabetes_mesh_terms`` occurs as a
        substring of the lower-cased terms, otherwise False.
    """
    if isinstance(meshes, str):
        mesh_text = meshes
    else:
        try:
            # Join the individual terms instead of relying on the container's
            # repr: the repr of a long numpy array is abbreviated with "...",
            # which would silently hide terms from the substring search.
            mesh_text = " ".join(str(mesh) for mesh in meshes)
        except TypeError:
            # Not iterable (None, NaN, ...)
            mesh_text = str(meshes)

    mesh_text = mesh_text.lower()
    return any(term in mesh_text for term in diabetes_mesh_terms)

# Flag every record whose MeSH terms mention one of the diabetes keywords
tqdm.pandas(desc="Filtering for diabetes-related MeSH terms")
pubmed_raw_df["is_diabetes_related"] = pubmed_raw_df["meshes"].progress_apply(is_diabetes_related)

# Keep only the flagged rows, renumber them, and drop the two helper columns
diabetes_df = (
    pubmed_raw_df.loc[pubmed_raw_df["is_diabetes_related"]]
    .reset_index(drop=True)
    .drop(columns=['is_diabetes_related', 'meshes'])
)

# Persist the raw subset, then report how many records passed the filter
diabetes_df.to_excel("raw_diabetes_df.xlsx", index=False)

print(f"Extracted {len(diabetes_df)} diabetes-related samples out of {len(pubmed_raw_df)} total.")

Filtering for diabetes-related MeSH terms:   0%|          | 0/211269 [00:00<?, ?it/s]

Extracted 13739 diabetes-related samples out of 211269 total.


# Cleaning

In [None]:
# Data cleaning

def remove_irrelevant(text):
    """Strip registry, URL and boilerplate metadata from a text.

    Steps, in order:
      1. Drop parenthesised asides that mention a URL, a web domain, a DOI,
         "clinicaltrials" or an NCT identifier.
      2. Convert an all-uppercase text to sentence case, then restore the
         casing of a fixed list of acronyms (applied to every text).
      3. Drop parenthesised "ABSTRACT TRUNCATED" / "REFERENCE NUMBER" /
         "REGISTRATION NUMBER" notes and any empty parentheses left behind.
      4. Drop whole sentences that contain a bare web address or "GAUK No".
      5. Drop markers of the form "<Word> Abstract." (e.g. "Chinese Abstract.").

    Args:
        text: the string to clean.

    Returns:
        The cleaned string with leading/trailing whitespace removed.
    """
    # Iteratively remove parentheses that contain URLs, NCT ids, DOI, ClinicalTrials, etc.
    # `doi` must be a token of its own (not followed by a letter) so that words
    # such as "doing" do not cause an ordinary parenthetical to be deleted.
    paren_pattern = re.compile(
        r'\s*\([^()]*?(?:https?://|www\.|\.gov\b|\.edu\b|\.org\b|\.com\b|\.net\b|\bdoi(?![a-z])|clinicaltrials|nct\d+)[^()]*\)',
        flags=re.IGNORECASE
    )
    # remove inner-most matches repeatedly so nested parentheses are handled
    while True:
        new_text = paren_pattern.sub('', text)
        if new_text == text:
            break
        text = new_text

    # Normalise all caps text
    acronyms = ["PCOS", "ICSI", "HbA1c", "HIV", "BMI", "DNA", "RNA", "mRNA", "IRCT"]
    # Only texts whose cased characters are ALL uppercase (and that contain a
    # run of 3+ capitals) are converted to sentence case
    if re.search(r'[A-Z]{3,}', text) and text.isupper():
        text = text.lower().capitalize()
    # Restore known acronyms (case-insensitive whole-word match)
    for acronym in acronyms:
        text = re.sub(r'\b' + acronym.lower() + r'\b', acronym, text, flags=re.IGNORECASE)

    # Remove anything in parentheses containing "ABSTRACT TRUNCATED",
    # "REFERENCE NUMBER" or "REGISTRATION NUMBER"
    text = re.sub(r'\s*\([^)]*(ABSTRACT TRUNCATED|REFERENCE NUMBER|REGISTRATION NUMBER)[^)]*\)', '', text, flags=re.IGNORECASE)

    # Remove empty parentheses that may remain (e.g. "()")
    text = re.sub(r'\(\s*\)', '', text)

    # Remove entire sentences that contain standalone websites (not in parentheses)
    #    (after removing parentheses above, these are truly standalone)
    text = re.sub(
        r'[^.?!]*\b(?:https?://\S+|www\.\S+|\S+\.gov|\S+\.edu|\S+\.org|\S+\.com|\S+\.net)\b[^.?!]*[.?!]',
        '',
        text,
        flags=re.IGNORECASE
    )

    # Remove any sentence that mentions the grant reference "GAUK No"
    # (the phrase is escaped so it is matched literally)
    pattern = re.compile(r'[^.?!]*\b' + re.escape("GAUK No") + r'\b[^.?!]*[.?!]', re.IGNORECASE)
    text = re.sub(pattern, '', text)

    # Remove phrases like "Chinese Abstract." or similar
    text = re.sub(r':?\s*[A-Z][a-z]+\sAbstract\.', '', text)

    return text.strip()

def clean_symbols(text):
    """Normalise mathematical symbols, decimals and spacing in a text.

    Conversions performed:
      * a lone ``x`` surrounded by whitespace -> ``×``
      * ``<=``, ``</=``, ``< or =``, ``=<`` -> ``≤`` (and the ``>`` forms -> ``≥``)
      * ``+/-``, ``+ or -`` and a bare ``+-`` -> ``±``. A bare ``+-`` that is
        directly followed by a letter is left alone, because that is ion
        notation ("Ca2+-dependent", "Na+-K+"), not a plus-minus sign.
      * ``.5`` -> ``0.5`` and ``1·5`` -> ``1.5``
      * whitespace before ``%``, inside parentheses and before punctuation is
        removed; runs of whitespace collapse to one space.

    Args:
        text: the string to normalise.

    Returns:
        The normalised string with leading/trailing whitespace removed.
    """
    # A lone "x" between whitespace is read as a multiplication sign
    text = re.sub(r'\s+x\s+', ' × ', text)

    # Replace textual operator variants with Unicode forms
    # Handle "<=", "</=" and "< or =" variants → ≤
    text = re.sub(r'<\s*(?:or\s*=|/?=)', '≤', text, flags=re.IGNORECASE)
    # Handle ">=", ">/=" and "> or =" variants → ≥
    text = re.sub(r'>\s*(?:or\s*=|/?=)', '≥', text, flags=re.IGNORECASE)
    # Handle "+/-", "+ or -" and bare "+-" variants → ±
    # (bare "+-" followed by a letter is an ion charge plus hyphen: keep it)
    text = re.sub(r'\+\s*(?:or\s*-|/-|-(?![A-Za-z]))', '±', text, flags=re.IGNORECASE)
    text = text.replace('=>', '≥').replace('=<', '≤')

    # Fix incomplete decimals (".5" → "0.5") unless the dot follows a word character
    text = re.sub(r'(?<![\d\w])\.(\d+)', r'0.\1', text)
    text = re.sub(r'(?<=\d)·(?=\d)', '.', text)  # only replace middle dots between digits

    # Keep % tight to preceding number
    text = re.sub(r'(\d)\s*%', r'\1%', text)

    # Normalize parentheses/brackets
    text = re.sub(r'\(\s+', '(', text)
    text = re.sub(r'\s+\)', ')', text)

    # Clean up leftover spacing / stray punctuation
    # remove space before punctuation, collapse multiple spaces/newlines
    text = re.sub(r'\s+([.,;:?!])', r'\1', text)
    text = re.sub(r'\s{2,}', ' ', text)
    return text.strip()

def normalize_bullets(text):
    """Remove list markers from a text and tidy the "i.e." / "e.g." abbreviations.

    Removes "•" bullets, rewrites spaced abbreviations ("i. e. ", "e. g. ") to
    their compact form, and deletes enumerators such as "1.", "a)", "(ii)"
    when they sit at the start of the text, right after a newline or right
    after ". " and are followed by whitespace. Runs of whitespace are then
    collapsed and the result is stripped.
    """
    # Remove bullet "•" points
    without_dots = re.sub(r'•\s*', '', text)

    # Replace "i. e. " with "i.e. " (and likewise for "e. g. ")
    for spaced_form, compact_form in (
        (r'\b\(?i\.\s*e\.\s?\b', "i.e. "),
        (r'\b\(?e\.\s*g\.\s?\b', 'e.g. '),
    ):
        without_dots = re.sub(spaced_form, compact_form, without_dots, flags=re.IGNORECASE)

    enumerator = re.compile(
        r'(?:(?<=^)|(?<=\n)|(?<=\. ))'  # start of line, after newline, or after ". "
        r'(?<![=])'                     # not immediately preceded by =
        r'(?<![A-Za-z0-9])'             # not preceded by a letter or digit
        r'(?<![.)-])'                   # not preceded by ., ), or -
        r'(?<!\= )'                      # not preceded by "= " (equals + space)
        r'\(?'
        r'([0-9]+|[a-z]|[ivx]+)'        # digits, letters, or roman numerals
        r'\)?'
        r'[.)]'
        r'(?=\s)'                        # must be followed by a space
    )
    without_markers = enumerator.sub("", without_dots)

    # Clean multiple spaces
    return re.sub(r'\s{2,}', ' ', without_markers).strip()

def clean_question(text):
    """Normalise a question string.

    Strips surrounding whitespace, guarantees a single trailing "?",
    capitalises the first character, removes a subtitle introduced by a colon
    ("...: a randomized trial?" -> "...?") and normalises comma spacing.
    Colons between two digits ("2:1") and thousands separators ("1,000") are
    treated as part of a number and left untouched.

    Args:
        text: the raw question.

    Returns:
        The cleaned question; an empty/blank input yields an empty string.
    """
    text = text.strip()

    # Remove spaces before ?
    text = re.sub(r'\s+\?', '?', text)

    # Add ? if missing
    if text and not text.endswith("?"):
        text += "?"

    # Capitalize first letter
    if text:
        text = text[0].upper() + text[1:]

    # Remove everything between ':' and '?'.
    # The look-around skips a colon squeezed between two digits (ratio/time
    # such as "2:1"), which would otherwise cut the question short.
    text = re.sub(r'\s*:(?!(?<=\d:)\d)\s*[^?]*\?', '?', text)

    # Fix spacing before punctuation
    text = re.sub(r'\s+:', r':', text)

    # Normalize comma spacing: ensure ", " format.
    # A comma directly between a digit and exactly three digits is a thousands
    # separator ("1,000") and is kept as it is.
    text = re.sub(r'\s*,(?!(?<=\d,)\d{3}(?!\d))\s*', ', ', text)

    return text

def ensure_final_punctuation(text):
    """Return ``text`` without trailing whitespace and ending in '.', '!' or '?'.

    Whitespace in front of a final period is removed. A period is appended
    when the text is non-empty and does not already end in sentence
    punctuation; an empty result is returned unchanged.
    """
    # Trim the tail, then pull a dangling final period back onto the text
    trimmed = re.sub(r'\s+\.$', '.', text.rstrip())
    if not trimmed:
        return trimmed
    if trimmed.endswith(('.', '!', '?')):
        return trimmed
    return trimmed + '.'

def general_cleaning(text):
    """
    General text cleaning:
    1. Remove a bare number left dangling at the end of a line. The fractional
       part of a decimal ("1.5") and the tail of a thousands-separated figure
       ("1,000") are not treated as dangling numbers and are kept.
    2. Collapse multiple spaces into one (excluding line breaks).
    3. Strip leading/trailing spaces on each line.
    4. Remove ":" / ";" at the very start or very end of the text.

    Returns the cleaned text with leading/trailing whitespace removed.
    """
    # Remove numbers at the end of lines; the look-behind skips digits that
    # directly follow "<digit>." or "<digit>,", i.e. the tail of a number
    text = re.sub(r'(?<!\d[.,])\b\d+\b(?=\s*$)', '', text, flags=re.MULTILINE)

    # Collapse multiple spaces into one (excluding line breaks)
    text = re.sub(r'[ ]{2,}', ' ', text)

    # Strip leading/trailing spaces on each line
    text = "\n".join(line.strip() for line in text.split("\n"))

    # Remove any ":" or ";" at the start or end of the whole text
    text = re.sub(r'^[;:]+|[;:]+$', '', text)

    return text.strip()

def _clean_passage(text):
    """Run the answer/context cleaning pipeline on a single string."""
    text = ensure_final_punctuation(text)   # punctuate first so sentence-level regexes see a terminator
    text = remove_irrelevant(text)          # registry ids, URLs, boilerplate
    text = clean_symbols(text)              # ≤ ≥ ± ×, decimals, spacing
    text = normalize_bullets(text)          # list markers, "i.e." / "e.g."
    text = general_cleaning(text)           # dangling numbers, spaces, stray ':' ';'
    return ensure_final_punctuation(text)   # cleaning may have removed the final period


def preprocess_qa(df):
    """Clean the question, answer and context text of every row.

    Adds three columns (``question_clean``, ``answer_clean``,
    ``context_clean``), drops rows where any of them is blank after cleaning,
    and removes rows that are duplicates on all three cleaned columns.

    Args:
        df: DataFrame with ``question``, ``answer`` and ``extracted_contexts``
            columns. It is not modified; the work is done on a copy.

    Returns:
        A new DataFrame with the cleaned columns added and a fresh 0..n-1 index.
    """
    # Work on a copy so the caller's frame is left untouched
    df = df.copy()

    processed_questions = []
    processed_answers = []
    processed_contexts = []

    rows = zip(df['question'], df['answer'], df['extracted_contexts'])
    for question, answer, context in tqdm(rows, total=len(df), desc="Cleaning QAs", ncols=100):
        # Missing values become empty strings; such rows are dropped by the
        # blank filter below instead of crashing the string helpers
        question, answer, context = (
            "" if pd.isna(value) else str(value)
            for value in (question, answer, context)
        )

        # Questions get their own normalisation followed by the general clean-up
        processed_questions.append(general_cleaning(clean_question(question)))

        # Answers and contexts share one pipeline
        processed_answers.append(_clean_passage(answer))
        processed_contexts.append(_clean_passage(context))

    df['question_clean'] = processed_questions
    df['answer_clean'] = processed_answers
    df['context_clean'] = processed_contexts

    # Drop rows where the cleaned question, answer or context is blank
    df = df[(df['question_clean'].str.strip() != '') & (df['answer_clean'].str.strip() != '') & (df['context_clean'].str.strip() != '')]

    # Drop duplicates based on cleaned question, answer and context
    df = df.drop_duplicates(subset=['question_clean', 'answer_clean', 'context_clean']).reset_index(drop=True)

    return df

# Clean the diabetes subset and write the result to an Excel file
pubmed_clean_df = preprocess_qa(diabetes_df)
pubmed_clean_df.to_excel("pubmed_clean.xlsx", index=False)

Cleaning QAs:   0%|                                                       | 0/13739 [00:00<?, ?it/s]

# QA and summarisation

In [None]:
def _ensure_sentence_tokenizer():
    """Download the NLTK Punkt data when `sent_tokenize` cannot find it."""
    try:
        sent_tokenize("Probe sentence.")
    except LookupError:
        # Newer NLTK releases look for 'punkt_tab', older ones for 'punkt';
        # request both so whichever version is installed finds its model.
        for resource in ("punkt_tab", "punkt"):
            nltk.download(resource, quiet=True)


def create_all_datasets(df_clean, sentences_per_passage=3):
    """Build the QA, summarisation and RAG datasets from the cleaned frame.

    Args:
        df_clean: DataFrame with ``question_clean``, ``context_clean``,
            ``answer_clean`` and ``final_decision`` columns.
        sentences_per_passage: number of consecutive context sentences grouped
            into one RAG passage.

    Returns:
        A tuple ``(qa_dataset, summarization_dataset, rag_dataset)`` of
        DataFrames:
          * QA: question, context, answer, final_decision
          * summarisation: context, summary (the answer)
          * RAG: one row per passage with question_id, question,
            final_decision, answer, passage, passage_id
    """
    # Sentence splitting needs the Punkt model; fetch it on a fresh environment
    _ensure_sentence_tokenizer()

    # Initialize lists
    qa_rows = []
    summarization_rows = []
    rag_rows = []

    passage_counter = 0
    question_counter = 0

    for _, row in tqdm(df_clean.iterrows(), total=len(df_clean), desc="Processing data", ncols=100):
        question = row['question_clean']
        context = row['context_clean']
        answer = row['answer_clean']
        final_decision = row['final_decision']

        # Rows with a missing field are skipped and do not consume a question id
        if pd.isna(question) or pd.isna(context) or pd.isna(answer):
            continue

        # QA dataset
        qa_rows.append({'question': question, 'context': context, 'answer': answer, 'final_decision': final_decision})

        # Summarisation dataset
        summarization_rows.append({'context': context, 'summary': answer})

        # RAG dataset: split context into passages
        sentences = sent_tokenize(context)
        for i in range(0, len(sentences), sentences_per_passage):
            passage_sentences = sentences[i:i+sentences_per_passage]
            passage_text = " ".join(passage_sentences).strip()

            if passage_text:  # skip empty passages
                rag_rows.append({
                    'question_id': question_counter,
                    'question': question,
                    'final_decision': final_decision,
                    'answer': answer,
                    'passage': passage_text,
                    'passage_id': passage_counter
                })
                passage_counter += 1

        question_counter += 1

    # Create DataFrames
    qa_dataset = pd.DataFrame(qa_rows).drop_duplicates().reset_index(drop=True)
    summarization_dataset = pd.DataFrame(summarization_rows).drop_duplicates().reset_index(drop=True)
    # passage_id is unique per row, so drop_duplicates cannot remove anything
    # here; it is kept so all three frames are built the same way
    rag_dataset = pd.DataFrame(rag_rows).drop_duplicates().reset_index(drop=True)

    return qa_dataset, summarization_dataset, rag_dataset

# Build the three task-specific datasets from the cleaned frame
qa_data, summarization_data, rag_data = create_all_datasets(pubmed_clean_df, sentences_per_passage=3)

# Save to Excel
for _dataset, _filename in (
    (qa_data, "qa_dataset.xlsx"),
    (summarization_data, "summarization_dataset.xlsx"),
    (rag_data, "rag_dataset.xlsx"),
):
    _dataset.to_excel(_filename, index=False)


Processing data:   0%|                                                    | 0/13738 [00:00<?, ?it/s]