In [16]:
import pandas as pd
import numpy as np
import torch
import transformers
from transformers import AutoTokenizer, AutoModel


LLM setup: load Llama 3.1 8B Instruct and build a text-generation pipeline

In [None]:
# Chat LLM used to generate question/answer pairs from the trial text.
model_id = "meta-llama/Meta-Llama-3.1-8B-Instruct"

# Target a single GPU (index `gpu`) when CUDA is available, otherwise the CPU.
gpu = 0
if torch.cuda.is_available():
    device = f"cuda:{gpu}"
else:
    device = "cpu"

tokenizer = transformers.AutoTokenizer.from_pretrained(model_id)
# Fall back to the EOS id when the tokenizer defines no pad token, so any padding
# has a valid id to use.
if tokenizer.pad_token_id is None:
    tokenizer.pad_token_id = tokenizer.eos_token_id

# NOTE(review): despite the `_nf4` suffix, no quantization_config is passed here,
# so the weights are loaded unquantized. device_map={"": device} places the whole
# model on `device`.
model_nf4 = transformers.AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map={"": device},
)

# Text-generation pipeline wrapping the model/tokenizer pair above.
pipeline = transformers.pipeline(
    "text-generation",
    model=model_nf4,
    tokenizer=tokenizer,
)

In [6]:
def questions_answers(text, pipe=None, max_new_tokens=1024, temperature=0.1, top_p=0.9):
    """
    Ask the chat LLM to extract short question/answer pairs from a medical text.

    The system prompt asks for 3-10 pairs, one ``{'Question': ..., 'Answer': ...}``
    literal per line, i.e. the line-oriented format ``parse_q_a_criteria`` reads.

    Parameters:
    text (str): Source text (e.g. trial eligibility criteria), sent as the user message.
    pipe: A transformers text-generation pipeline; defaults to the module-level ``pipeline``.
    max_new_tokens (int): Maximum number of tokens to generate.
    temperature (float): Sampling temperature.
    top_p (float): Nucleus-sampling threshold.

    Returns:
    str: Only the newly generated text (the prompt is not included), unparsed.
    """
    if pipe is None:
        pipe = pipeline

    # Prompt text kept verbatim so generations stay comparable with earlier runs.
    system_prompt = (
        "You are an expert in creating key questions from a medical text and extract the answers from the text. "
        "Extract 3-10 Q/A pairs without repititions of key entities in the Q/As. "
        "Avoid general questions like 'What is the exclusion criteria?'. "
        "Make sure an answer is NO MORE than 5 tokens/words. "
        "Output ONLY json formated Q/A pairs like this: "
        "{'Question': 'question1', 'Answer': 'answer1'} \n "
        "{'Question': 'question2' , 'Answer': 'answer2'} \n ... \n Input: "
    )
    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": text},
    ]

    tok = pipe.tokenizer
    prompt = tok.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
    )

    # Stop on either the generic EOS token or Llama 3's end-of-turn token. For a
    # tokenizer that lacks <|eot_id|>, convert_tokens_to_ids yields the unk id (or
    # None); drop such ids instead of passing them to generate().
    unk_id = getattr(tok, "unk_token_id", None)
    terminators = [
        tid
        for tid in (tok.eos_token_id, tok.convert_tokens_to_ids("<|eot_id|>"))
        if tid is not None and tid != unk_id
    ]

    # do_sample=True: output differs between runs unless a seed is fixed beforehand.
    outputs = pipe(
        prompt,
        max_new_tokens=max_new_tokens,
        eos_token_id=terminators or None,
        do_sample=True,
        temperature=temperature,
        top_p=top_p,
        # Let the pipeline strip the prompt instead of slicing by len(prompt).
        return_full_text=False,
    )
    return outputs[0]["generated_text"]

In [9]:
import ast

def parse_q_a_criteria(q_a_criteria):
    """
    Converts LLM output containing question-answer literals into a list of strings.

    Each non-blank line is parsed with ``ast.literal_eval`` (literals only, no code
    execution). A line may hold a single ``{'Question': ..., 'Answer': ...}`` dict,
    or a list/tuple of such dicts (e.g. a JSON array, or a dict followed by a
    trailing comma). Lines that are not valid literals (prose, markdown code fences)
    and parsed values that are not dicts with both keys are skipped.

    Parameters:
    q_a_criteria (str): Input string containing question-answer pairs; ``None`` or
        ``""`` yields an empty list.

    Returns:
    list: Strings of the form "<Question> <Answer>", in input order.
    """
    if not q_a_criteria:
        return []

    result = []
    # splitlines() also handles "\r\n"; strip() removes indentation, which
    # literal_eval rejects on Python < 3.10.
    for line in q_a_criteria.splitlines():
        entry = line.strip()
        if not entry:
            continue
        try:
            parsed = ast.literal_eval(entry)
        except (ValueError, TypeError, SyntaxError, RecursionError):
            # Not a Python/JSON literal (e.g. explanatory prose from the LLM).
            continue
        candidates = parsed if isinstance(parsed, (list, tuple)) else [parsed]
        for qa_dict in candidates:
            # Guard the type: `'Question' in 5` raises, and `'Question' in "..."`
            # is a substring test on a str, not a key lookup.
            if isinstance(qa_dict, dict) and 'Question' in qa_dict and 'Answer' in qa_dict:
                result.append(f"{qa_dict['Question']} {qa_dict['Answer']}")
    return result

Q/A generation using LLM

In [12]:
# Structured metadata for one registered trial (RE-QUOL, NCT03134911). `criteria`
# is sent to the LLM for Q/A extraction; the other fields are turned into fixed
# Q/A strings later in the notebook.
nct_id = "NCT03134911"
intervention = "DOAC or VKA, VKA"
disease = "Atrial Fibrillation"
title = "Health-related Quality of Life in Patients on Anticoagulants (RE-QUOL)"
outcome_measures = (
    "Health Related Quality of Life (QoL) (HRQoL) Scores "
    "in the Spanish Adaptation of the Sawicki Questionnaire"
)
# NOTE(review): the last keyword looks unrelated to atrial fibrillation —
# verify it against the source record.
keywords = (
    "Arrhythmias, Cardiac, Heart Diseases, Cardiovascular Diseases, "
    "Pathologic Processes, Pathological Conditions, Signs and Symptoms, "
    "Atrial Fibrillation, N(4)-oleylcytosine arabinoside"
)
# Eligibility criteria, one criterion per line. Spacing is reproduced exactly,
# including the leading space on line 2 and trailing spaces before some newlines.
criteria = (
    "Inclusion Criteria:\n"
    " The patient is willing and provides written informed consent to participate in this study.\n"
    "The patient is at least 18 years of age \n"
    "The patient has a diagnosis of non-valvular atrial fibrillation \n"
    "The patient is on the same anticoagulant therapy (VKA or DOAC) during at least 6 months and maximum 2 years. \n"
    "If treated with VKA, availability of % Time in Therapeutic Range (TTR) in past analytical records "
    "or enough amount of International Normalized Ratio (INR) measures to calculate it. \n"
    "Exclusion Criteria:\n"
    "Current participation in any clinical trial of a drug or device\n"
    "Contraindication to the use of DOAC or VKA as described in the Summary of Product Characteristics (SmPC)."
)

In [13]:
# Ask the LLM for Q/A pairs about the eligibility criteria, then keep only the
# lines that parse as {'Question': ..., 'Answer': ...} dicts.
# Note: q_a_set is a list (order preserved), not a set.
llm_q_a_raw = questions_answers(criteria)
q_a_set = parse_q_a_criteria(llm_q_a_raw)

Predefined Q/A

In [14]:
# Append fixed Q/A strings built from the trial's structured metadata. Each entry is
# "<question> <answer>", the same shape parse_q_a_criteria produces. The membership
# check makes this cell idempotent: re-running it no longer appends a second copy of
# every predefined pair (which would duplicate them in the embedded text).
predefined_q_a = [
    'What are the drugs used? ' + intervention,
    'What is the disease treated in this trial? ' + disease,
    'What is the title of the trial? ' + title,
    'What are the outcome measures? ' + outcome_measures,
    'What are the keywords? ' + keywords,
]
for q_a in predefined_q_a:
    if q_a not in q_a_set:
        q_a_set.append(q_a)

Load SECRET

In [17]:
# Load the BioBERT backbone and tokenizer, then overwrite its weights with the
# checkpoint in models/global_model.pth.
# NOTE: this rebinds the module-level `tokenizer` (shadowing the Llama tokenizer from
# the LLM cell) and `model`; embed_text reads these BioBERT bindings.
model_name = "dmis-lab/biobert-base-cased-v1.1"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModel.from_pretrained(model_name)

# weights_only=True restricts unpickling to tensors and plain containers (a full
# pickle load can execute arbitrary code, and torch emits a FutureWarning without
# this flag). map_location="cpu" lets a checkpoint saved on a GPU load on a CPU-only
# machine; the model is moved to `device` below anyway.
state_dict = torch.load('models/global_model.pth', map_location="cpu", weights_only=True)

# Strip a leading `module.` from keys if present (checkpoints saved from a
# DataParallel/DDP-wrapped model typically carry it) so they match the bare BertModel.
from collections import OrderedDict
prefix = 'module.'
new_state_dict = OrderedDict(
    (k[len(prefix):] if k.startswith(prefix) else k, v)
    for k, v in state_dict.items()
)

model.load_state_dict(new_state_dict)
model.to(device)
model.eval()

  state_dict = torch.load('models/global_model.pth')


BertModel(
  (embeddings): BertEmbeddings(
    (word_embeddings): Embedding(28996, 768, padding_idx=0)
    (position_embeddings): Embedding(512, 768)
    (token_type_embeddings): Embedding(2, 768)
    (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (encoder): BertEncoder(
    (layer): ModuleList(
      (0-11): 12 x BertLayer(
        (attention): BertAttention(
          (self): BertSdpaSelfAttention(
            (query): Linear(in_features=768, out_features=768, bias=True)
            (key): Linear(in_features=768, out_features=768, bias=True)
            (value): Linear(in_features=768, out_features=768, bias=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (output): BertSelfOutput(
            (dense): Linear(in_features=768, out_features=768, bias=True)
            (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
            (dropout): Dropout(p=0.1, inplace=False

In [18]:
def embed_text(text, tok=None, encoder=None, dev=None, max_length=512):
    """
    Encode text and return one mean-pooled embedding per input sequence.

    Parameters:
    text (str or list[str]): A single string or a batch of strings.
    tok: Tokenizer; defaults to the module-level ``tokenizer``.
    encoder: Model whose output exposes ``last_hidden_state``; defaults to the
        module-level ``model``.
    dev: Device the tokenized inputs are moved to; defaults to the module-level ``device``.
    max_length (int): Inputs longer than this many tokens are truncated.

    Returns:
    torch.Tensor: Shape (batch_size, hidden_size).

    Padding positions are excluded from the mean via the attention mask, so a
    sequence's embedding does not depend on what it is batched with.
    """
    tok = tokenizer if tok is None else tok
    encoder = model if encoder is None else encoder
    dev = device if dev is None else dev

    inputs = tok(text, return_tensors="pt", padding=True, truncation=True, max_length=max_length).to(dev)

    with torch.no_grad():
        outputs = encoder(**inputs)

    # (batch_size, sequence_length, hidden_size); already on `dev`.
    hidden = outputs.last_hidden_state

    mask = inputs.get("attention_mask")
    if mask is None:
        return hidden.mean(dim=1)

    # Masked mean: zero out padding, divide by the number of real tokens.
    # clamp(min=1) avoids 0/0 for an all-padding row.
    mask = mask.unsqueeze(-1).to(hidden.dtype)
    return (hidden * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1)

Embed the combined Q/A text with the loaded BioBERT-based encoder

In [19]:
# Concatenate every Q/A string (single-space separated) into one passage and embed
# it as one sequence; embed_text truncates anything past 512 tokens.
q_a_string = str.join(" ", q_a_set)
q_a_embedding = embed_text(text=q_a_string)
print(q_a_embedding.shape)  # expected torch.Size([1, hidden_size]): one pooled vector

torch.Size([1, 768])
