In [None]:
#!pip install -q -U torch --index-url https://download.pytorch.org/whl/cu117
#!pip install -q -U transformers=="4.38.2"
#!pip install -q -U transformers=="4.30"
#!pip install -q accelerate
#!pip install -q -i https://pypi.org/simple/ bitsandbytes
!pip install bitsandbytes==0.41.3
!pip install -q -U sentence_transformers
!pip install -q -U scann
!pip install -q num2words

In [None]:
pip install inflect

In [None]:
import os
# Restrict this process to GPUs 0 and 1.
# NOTE(review): presumably this has to run before torch initialises CUDA to take effect - confirm.
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1"
# Turn off parallelism in the Hugging Face tokenizers library.
os.environ["TOKENIZERS_PARALLELISM"] = "false"
# Route Hugging Face Hub traffic through the hf-mirror.com endpoint.
os.environ['HF_ENDPOINT'] = 'https://hf-mirror.com'

In [None]:
import warnings
# Suppress every Python warning for the rest of the session.
# NOTE(review): this also hides deprecation and numerical warnings; consider narrowing the filter.
warnings.filterwarnings("ignore")

In [None]:
import json
import re
import numpy as np
import pandas as pd
from tqdm import tqdm
import matplotlib.pyplot as plt
import seaborn as sns
import scann
import inflect
import nltk
from nltk.corpus import stopwords
from nltk.stem import SnowballStemmer
from num2words import num2words

import torch

import transformers
from transformers import (AutoModelForCausalLM,
                          AutoTokenizer,
                          BitsAndBytesConfig,
                         )
from sentence_transformers import SentenceTransformer
from huggingface_hub import hf_hub_download
import bitsandbytes as bnb

In [None]:
def define_device():
    """Select the torch device to run on.

    Returns:
        torch.device: ``cuda`` when ``torch.cuda.is_available()`` is true,
        otherwise ``cpu``. The choice is also printed.
    """
    backend = "cuda" if torch.cuda.is_available() else "cpu"
    chosen = torch.device(backend)
    print(f"using {chosen}")
    return chosen

In [None]:
# Download the NLTK 'stopwords' corpus (the resource behind nltk.corpus.stopwords).
nltk.download('stopwords')

In [None]:
#stop_words = stopwords.words("english")
#stemmer = SnowballStemmer("english")

In [None]:
# TEXT CLEANING
# Matches @mentions, http(s) links and any run of non-alphanumeric characters.
# Written as a raw string: "\S" inside a normal string literal is an invalid
# escape sequence and triggers a warning on recent Python versions.
# NOTE(review): the third alternative ("http?:\S") looks like a typo of the
# second one; it is kept unchanged so matching behaviour stays the same.
TEXT_CLEANING = r"@\S+|https?:\S+|http?:\S|[^A-Za-z0-9]+"

# Lazily filled cache for the default NLTK resources used by preprocess().
_PREPROCESS_DEFAULTS = {}


def replace_numbers(match):
    """Regex substitution callback that spells out the matched integer.

    Args:
        match (re.Match): Match whose full text (group 0) is an integer literal.

    Returns:
        The result of ``num2words`` for that integer.
    """
    number = int(match.group(0))
    return num2words(number)


def preprocess(text, stem=False, stop_words=None, stemmer=None):
    """Lower-case, clean and tokenise ``text``, dropping stop words.

    Previously this function read the globals ``stop_words`` and ``stemmer``,
    which are only defined in a commented-out cell, so it raised ``NameError``
    on a fresh kernel. Both are now optional arguments with lazy defaults.

    Args:
        text: Anything convertible with ``str()``.
        stem (bool): When true, each kept token is passed through ``stemmer.stem``.
        stop_words: Container of tokens to drop. Defaults to NLTK's English
            stop-word list (requires ``nltk.download('stopwords')``).
        stemmer: Object with a ``stem(token)`` method. Defaults to
            ``SnowballStemmer("english")``; only used when ``stem`` is true.

    Returns:
        str: The kept tokens joined by single spaces.
    """
    if stop_words is None:
        if "stop_words" not in _PREPROCESS_DEFAULTS:
            # Load once; a frozenset gives constant-time membership tests.
            _PREPROCESS_DEFAULTS["stop_words"] = frozenset(stopwords.words("english"))
        stop_words = _PREPROCESS_DEFAULTS["stop_words"]

    if stem and stemmer is None:
        if "stemmer" not in _PREPROCESS_DEFAULTS:
            _PREPROCESS_DEFAULTS["stemmer"] = SnowballStemmer("english")
        stemmer = _PREPROCESS_DEFAULTS["stemmer"]

    cleaned = re.sub(TEXT_CLEANING, ' ', str(text).lower()).strip()
    kept = (token for token in cleaned.split() if token not in stop_words)
    if stem:
        return " ".join(stemmer.stem(token) for token in kept)
    return " ".join(kept)

In [None]:
def ensure_age_format(entry):
    """Strip a trailing 'Y' suffix from a string age such as '25Y'.

    Args:
        entry: Age value. Non-string values are returned unchanged.

    Returns:
        The string without its final 'Y', or ``entry`` itself when it is not a
        string or does not end with 'Y'. Empty strings are returned as-is
        (indexing ``entry[-1]`` used to raise ``IndexError`` for them).
    """
    if isinstance(entry, str) and entry.endswith('Y'):
        return entry[:-1]  # Remove the 'Y' suffix
    return entry

In [None]:
def convert_num_to_word(entry):
    """Turn a numeric age into '<number in words> years old'.

    Args:
        entry: An int/float, a string of decimal digits, or anything else.

    Returns:
        str: ``"<words> years old"`` for finite numbers and decimal-digit strings.
        Any other value (including ``""``, NaN, infinity and strings that
        contain non-decimal characters) is returned unchanged.
    """
    import math

    if isinstance(entry, (int, float)):
        # int(nan) / int(inf) raise, so pass non-finite floats through untouched.
        if isinstance(entry, float) and not math.isfinite(entry):
            return entry
        return f"{num2words(int(entry))} years old"
    # isdecimal() only accepts characters int() can parse; isnumeric() also
    # accepts e.g. superscripts and fractions, which made int() raise.
    if isinstance(entry, str) and entry.isdecimal():
        return f"{num2words(int(entry))} years old"
    return entry  # For other cases (already formatted, empty, ...), leave unchanged

In [None]:
def replace_hyphens_with_to(text):
    """Rewrite numeric ranges such as '5-6' or '15-25' as '5 to 6' / '15 to 25'.

    Args:
        text (str): Input text.

    Returns:
        str: ``text`` with every ``<digits>-<digits>`` occurrence rewritten;
        hyphens that are not between two numbers are left alone.
    """
    # Back-references keep both numbers exactly as written; only the hyphen changes.
    return re.sub(r'(\d+)-(\d+)', r'\1 to \2', text)

In [None]:
def convert_numbers_to_words(text):
    """Spell out the numbers and numeric ranges found in ``text``.

    A range ``A-B`` becomes ``"<A in words> to <B in words>"`` and a single
    number becomes its word form. A unit (``Hz``, ``sec`` or ``uv``) written
    directly after the number is kept and separated by one space.

    Args:
        text (str): Input text.

    Returns:
        str: The text with numbers replaced by words from ``inflect``.
    """
    engine = inflect.engine()

    def _spell_out(match):
        # Groups 1-3 belong to the range alternative, groups 4-5 to the single number.
        low, high, range_unit, single, single_unit = match.groups()
        if low and high:
            first = int(low)
            second = int(high)
            spoken = f"{engine.number_to_words(first)} to {engine.number_to_words(second)}"
            unit = range_unit
        else:
            spoken = engine.number_to_words(int(single))
            unit = single_unit
        return f"{spoken} {unit}" if unit else spoken

    return re.sub(r'(\d+)-(\d+)(Hz|sec|uv)?|(\d+)(Hz|sec|uv)?', _spell_out, text)

In [None]:
def replace_lr(sentence):
    """
    Expands parenthesised side markers into words.

    Only the literal markers '(L)' / '(L )' and '(R)' / '(R )' are replaced;
    plain letters elsewhere in the sentence are untouched.

    Args:
    sentence (str): The input sentence to modify.

    Returns:
    str: The sentence with '(L)' / '(L )' replaced by 'left' and
    '(R)' / '(R )' replaced by 'right'.
    """
    substitutions = (
        ('(L)', 'left'),
        ('(L )', 'left'),
        ('(R)', 'right'),
        ('(R )', 'right'),
    )
    for marker, word in substitutions:
        sentence = sentence.replace(marker, word)
    return sentence

In [None]:
def remove_last_number(sentence):
    """Remove a trailing '(dd/dd/dddd)' date stamp from the end of ``sentence``.

    Args:
        sentence (str): Text that may end with a parenthesised date.

    Returns:
        str: The text before the date, stripped of surrounding whitespace, when
        such a date is found at the end; otherwise ``sentence`` unchanged.
    """
    match = re.search(r'\(\d{2}/\d{2}/\d{4}\)$', sentence)
    if match:
        # Cut at the match start instead of slicing off len(match) characters:
        # '$' also matches just before a final newline, in which case counting
        # from the end removed the wrong characters and left a dangling '('.
        sentence = sentence[:match.start()].strip()
    return sentence

In [None]:
# Load the EEG report table. Malformed rows are skipped instead of raising
# (on_bad_lines='skip') and the file is decoded as Latin-1.
data = pd.read_csv("EEG Topic Model Label.csv", on_bad_lines='skip', encoding = 'latin-1')

In [None]:
# Discard the columns that are not used further down in this notebook.
data = data.drop(['testdate', 'date', 'initial'], axis=1)

In [None]:
# Keep only rows with no missing values; the string concatenations below would
# otherwise propagate NaN into the 'text' column.
data = data.dropna()

In [None]:
data.head()

In [None]:
# Normalise ages ('25Y' -> '25') and drop the trailing date stamp from conclusions.
data['age'] = data['age'].apply(ensure_age_format)
#data['age'] = data['age'].apply(convert_num_to_word)
data.conclusion = data.conclusion.apply(remove_last_number)
# ensure_age_format passes non-string ages through unchanged, so cast to str
# before concatenating; adding a str to a numeric value raises TypeError.
data['text'] = data['age'].astype(str) + ' years old ' + data['List43'] + ' ' + data['conclusion']

In [None]:
# Rewrite numeric ranges ('5-6' -> '5 to 6'), then expand the '(L)' / '(R)' markers.
data['text'] = data['text'].apply(replace_hyphens_with_to)
#data.text = data.text.apply(lambda x: convert_num_to_word(x))
#data.text = data.text.apply(lambda x: convert_numbers_to_words(x))
data['text'] = data['text'].apply(replace_lr)
#data.text = data.text.apply(lambda x: preprocess(x))

In [None]:
# Prefix every record with its identifiers. astype(str) keeps this working when
# pandas parsed 'eegno' / 'RN' as numbers (str + number raises TypeError).
# NOTE: this cell is not idempotent - re-running it prepends the prefix again.
data['text'] = 'eegno ' + data['eegno'].astype(str) + ' RN ' + data['RN'].astype(str) + ' ' + data['text']

In [None]:
data.head()

In [None]:
class LLMHF():
    """Wrapper around a 4-bit quantised Hugging Face causal language model.

    Loads the model and tokenizer named by ``model_name`` and exposes two
    text-generation helpers that return the decoded first output sequence.
    """

    def __init__(self, model_name, token, max_seq_length=2048):
        """Load the model and tokenizer.

        Args:
            model_name (str): Hugging Face model identifier or local path.
            token: Hugging Face access token forwarded to ``from_pretrained``.
            max_seq_length (int): Forwarded to the tokenizer loader.
        """
        self.model_name = model_name
        self.max_seq_length = max_seq_length

        # Initialize the model and tokenizer
        print("\nInitializing model:")
        self.device = define_device()
        self.token = token
        self.model, self.tokenizer = self.initialize_model(self.model_name, self.device, self.token, self.max_seq_length)

    def initialize_model(self, model_name, device, token, max_seq_length):
        """Initialize a causal language model (LLM) and tokenizer with specified settings.

        Args:
            model_name (str): Hugging Face model identifier or local path.
            device: Unused here; the model is placed with ``device_map='auto'``.
            token: Hugging Face access token.
            max_seq_length (int): Forwarded to ``AutoTokenizer.from_pretrained``.

        Returns:
            tuple: ``(model, tokenizer)``.
        """
        # 4-bit NF4 quantisation with double quantisation, computing in float16.
        bnb_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.float16,
        )

        # Load the pre-trained model with quantization configuration
        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            device_map='auto',
            token=token,
            quantization_config=bnb_config,
        )

        # NOTE(review): `device_map` and `max_seq_length` do not look like
        # tokenizer options (the documented length option is `model_max_length`);
        # left as-is to avoid changing loading behaviour - confirm and clean up.
        tokenizer = AutoTokenizer.from_pretrained(
            model_name,
            device_map='auto',
            token=token,
            max_seq_length=max_seq_length
        )

        if tokenizer.pad_token is None:
            # Reuse EOS for padding. Adding a brand-new '[PAD]' token would grow
            # the vocabulary without resizing the model's embedding matrix, so
            # any padded batch would index past the end of the embeddings.
            tokenizer.pad_token = tokenizer.eos_token

        # Return the initialized model and tokenizer
        return model, tokenizer

    def _generate(self, prompt, max_new_tokens, temperature, pad_token_id=None):
        """Shared implementation behind generate_text / generate_text2."""
        encoded = self.tokenizer(prompt, return_tensors="pt", padding=True)

        # Sample only for a positive temperature; otherwise decode greedily.
        do_sample = temperature > 0

        generation_kwargs = {
            "input_ids": encoded.input_ids.to(self.device),
            "max_new_tokens": max_new_tokens,
            "do_sample": do_sample,
        }

        # Pass the attention mask so padded positions (if any) are ignored.
        attention_mask = getattr(encoded, "attention_mask", None)
        if attention_mask is not None:
            generation_kwargs["attention_mask"] = attention_mask.to(self.device)

        # Temperature only applies when sampling; forwarding 0.0 with greedy
        # decoding is at best ignored with a warning and is not a valid value.
        if do_sample:
            generation_kwargs["temperature"] = temperature

        if pad_token_id is not None:
            generation_kwargs["pad_token_id"] = pad_token_id

        outputs = self.model.generate(**generation_kwargs)

        # Decode the first sequence; the decoded text includes the prompt tokens.
        return self.tokenizer.decode(outputs[0], skip_special_tokens=True)

    def generate_text(self, prompt, max_new_tokens=2048, temperature=0.00):
        """Generate output using the LLM.

        Args:
            prompt (str): Prompt text.
            max_new_tokens (int): Maximum number of tokens to generate.
            temperature (float): Sampling temperature; ``<= 0`` means greedy decoding.

        Returns:
            str: Decoded first output sequence with special tokens removed.
        """
        return self._generate(prompt, max_new_tokens, temperature)

    def generate_text2(self, prompt, max_new_tokens=2048, temperature=0.0):
        """Same as ``generate_text`` but passes the EOS token id as ``pad_token_id``.

        Args:
            prompt (str): Prompt text.
            max_new_tokens (int): Maximum number of tokens to generate.
            temperature (float): Sampling temperature; ``<= 0`` means greedy decoding.

        Returns:
            str: Decoded first output sequence with special tokens removed.
        """
        return self._generate(prompt, max_new_tokens, temperature,
                              pad_token_id=self.tokenizer.eos_token_id)

In [None]:
def get_embedding(text, embedding_model):
    """Embed ``text`` with ``embedding_model`` and return the vector as a list.

    Args:
        text: Input passed to ``embedding_model.encode``.
        embedding_model: Object exposing ``encode(text, show_progress_bar=...)``
            whose result has a ``tolist()`` method.

    Returns:
        list: The embedding converted with ``tolist()``.
    """
    return embedding_model.encode(text, show_progress_bar=False).tolist()

def map2embeddings(data, embedding_model):
    """Map a collection of texts to their embeddings using the provided embedding model.

    Args:
        data: Iterable of texts (list, tuple, pandas Series, ...).
        embedding_model: Model passed on to ``get_embedding``.

    Returns:
        list: One embedding (list of floats) per text, in input order.
    """
    # Materialise the values instead of indexing with data[i]: positional
    # indexing fails on a pandas Series whose index is not 0..n-1
    # (for example after dropna()).
    texts = list(data)
    print(f"Mapping {len(texts)} pieces of information")
    return [get_embedding(text, embedding_model) for text in tqdm(texts)]

In [None]:
def clean_text(txt, EOS_TOKEN):
    """Clean text by removing specific tokens and redundant spaces.

    Args:
        txt (str): Text to clean.
        EOS_TOKEN (str | None): End-of-sequence token to remove; ``None`` or an
            empty string means there is nothing to remove.

    Returns:
        str: ``txt`` without the EOS token, ``**`` markers and ``<pad>`` tokens,
        with every run of spaces collapsed to one space and the ends stripped.
    """
    for token in (EOS_TOKEN, "**", "<pad>"):
        if token:
            txt = txt.replace(token, "")
    # A single replace("  ", " ") pass leaves runs of three or more spaces
    # partially collapsed; the regex collapses any run in one go.
    return re.sub(r" {2,}", " ", txt).strip()

In [None]:
class EEGIR():
    """Information retrieval over a knowledge base of texts.

    Texts are embedded with a SentenceTransformer model, indexed with ScaNN and
    handed, together with the query, to ``generate_retrieved_table`` which
    uses the wrapped LLM to produce the result.
    """

    def __init__(self, llm_model, embeddings_name, max_new_tokens,temperature=0.4, role="expert"):
        """Initialize the retriever.

        Args:
            llm_model: LLM wrapper passed through to ``generate_retrieved_table``.
            embeddings_name (str): SentenceTransformer model name or path.
            max_new_tokens (int): Generation budget forwarded on each query.
            temperature (float): Sampling temperature forwarded on each query.
            role (str): Answering style forwarded on each query.
        """
        # Initialize attributes
        self.max_new_tokens = max_new_tokens
        self.embeddings_name = embeddings_name
        self.knowledge_base = []
        self.temperature = temperature
        self.role = role

        # Filled in by learn_knowledge_base() / load_embeddings() / load_kb_embeddings().
        self.embeddings = None
        self.searcher = None

        # Language model used to generate the answer
        self.llm_model = llm_model

        # Load the embedding model
        self.embedding_model = SentenceTransformer(self.embeddings_name)

    def store_knowledge_base(self, knowledge_base):
        """Store the knowledge base"""
        self.knowledge_base = knowledge_base

    def learn_knowledge_base(self, knowledge_base):
        """Store the knowledge base, embed every entry and build the search index."""
        # Storing the knowledge base
        self.store_knowledge_base(knowledge_base)

        # Load and index the knowledge base
        print("Indexing and mapping the knowledge base:")
        embeddings = self.map2embeddings(self.knowledge_base)
        self.embeddings = np.array(embeddings).astype(np.float32)

        # Instantiate the searcher for similarity search
        self.index_embeddings()

    def map2embeddings(self, knowledge_base, embedding_model=None):
        """Map knowledge base texts to embeddings.

        Args:
            knowledge_base: Iterable of texts.
            embedding_model: Optional model with an ``encode(text)`` method;
                defaults to the instance's embedding model.

        Returns:
            list: One ``encode`` result per text, in input order.
        """
        if embedding_model is None:
            embedding_model = self.embedding_model
        return [embedding_model.encode(text) for text in knowledge_base]

    def index_embeddings(self):
        """Build the ScaNN searcher over ``self.embeddings`` (dot-product, 50 neighbours).

        NOTE(review): ``num_leaves_to_search=100`` and ``num_neighbors=50`` are
        fixed while ``num_leaves`` shrinks with the data; presumably very small
        knowledge bases need smaller values - confirm against the ScaNN docs.
        """
        self.searcher = (scann.scann_ops_pybind.builder(db=self.embeddings, num_neighbors=50, distance_measure="dot_product")
                 .tree(num_leaves=min(self.embeddings.shape[0] // 2, 1000),
                       num_leaves_to_search=100,
                       training_sample_size=self.embeddings.shape[0])
                 .score_ah(2, anisotropic_quantization_threshold=0.2)
                 .reorder(100)
                 .build()
           )

    def query(self, query):
        """Query the knowledge base.

        Args:
            query (str): Free-text query.

        Returns:
            Whatever ``generate_retrieved_table`` returns for this query.

        Raises:
            RuntimeError: If no index has been built or loaded yet.
        """
        if getattr(self, "searcher", None) is None:
            raise RuntimeError(
                "No index available: call learn_knowledge_base() or load_kb_embeddings() first"
            )
        # Generate relevant documents and a summary for the query
        output = generate_retrieved_table(query,
                                          self.knowledge_base,
                                          self.searcher,
                                          self.embedding_model,
                                          self.llm_model,
                                          max_new_tokens=self.max_new_tokens,
                                          temperature=self.temperature,
                                          role=self.role)
        return output

    def set_temperature(self, temperature):
        """Set the temperature (creativity) of the IR."""
        self.temperature = temperature

    def set_role(self, role):
        """Define the answering style of the IR."""
        self.role = role

    def _require_embeddings(self):
        """Return the stored embeddings or raise if none have been computed/loaded."""
        embeddings = getattr(self, "embeddings", None)
        if embeddings is None:
            raise RuntimeError(
                "No embeddings to save: call learn_knowledge_base() or load_kb_embeddings() first"
            )
        return embeddings

    def save_embeddings(self, filename="embeddings.npy"):
        """Save the embeddings to disk.

        Raises:
            RuntimeError: If there are no embeddings yet.
        """
        np.save(filename, self._require_embeddings())

    def load_embeddings(self, filename="embeddings.npy"):
        """Load the embeddings from disk and index them.

        The knowledge base texts are NOT loaded here; ``self.knowledge_base``
        keeps whatever it held before.
        """
        self.embeddings = np.load(filename)
        # Re-instantiate the searcher
        self.index_embeddings()

    def save_kb_embeddings(self, filename="embeddings.npy", knowledge_base_filename="knowledge_base.json"):
        """Save the embeddings and knowledge base to disk.

        Raises:
            RuntimeError: If there are no embeddings yet.
        """
        np.save(filename, self._require_embeddings())
        with open(knowledge_base_filename, 'w') as f:
            json.dump(self.knowledge_base, f)
        print(f"Embeddings saved to {filename} and knowledge base saved to {knowledge_base_filename}")

    def load_kb_embeddings(self, filename="embeddings.npy", knowledge_base_filename="knowledge_base.json"):
        """Load the embeddings and knowledge base from disk and index them.

        Raises:
            ValueError: If the number of knowledge base entries differs from the
                number of embedding rows. Retrieval pairs them by position, so a
                mismatch would return the wrong texts or fail later.
        """
        embeddings = np.load(filename)
        with open(knowledge_base_filename, 'r') as f:
            knowledge_base = json.load(f)

        # Validate before touching instance state so a failed load changes nothing.
        if len(knowledge_base) != embeddings.shape[0]:
            raise ValueError(
                f"Knowledge base has {len(knowledge_base)} entries but "
                f"{embeddings.shape[0]} embeddings were loaded from {filename}"
            )

        self.embeddings = embeddings
        self.knowledge_base = knowledge_base
        self.index_embeddings()
        print(f"Embeddings loaded from {filename} and knowledge base loaded from {knowledge_base_filename}")

In [None]:
import  gc
# Run a garbage-collection pass, then ask PyTorch to release its cached CUDA memory.
gc.collect()
torch.cuda.empty_cache()

In [None]:
embeddings_name = "thenlper/gte-large"
model_name = "meta-llama/Meta-Llama-3-8B-Instruct"
# Never hard-code access tokens in a notebook: read the token from the HF_TOKEN
# environment variable instead. The token that used to be committed here must be
# treated as leaked and revoked.
# NOTE(review): when HF_TOKEN is unset this is None; presumably the Hugging Face
# libraries then fall back to a locally stored login - confirm.
hf_token = os.environ.get("HF_TOKEN")

In [None]:
# Build the retrieval pipeline: constructs the LLM wrapper (LLMHF) and hands it,
# with the embedding model name, to EEGIR.
llamaIR = EEGIR(llm_model=LLMHF(model_name,token=hf_token), embeddings_name=embeddings_name, max_new_tokens=4096)

In [None]:
#llamaIR.learn_knowledge_base(knowledge_base=data['text'].tolist())

In [None]:
#llamaIR.save_kb_embeddings()

In [None]:
# Reuse the cached embeddings / knowledge base when both files exist; otherwise
# build them from the prepared `data['text']` column and cache the result, so the
# notebook also runs on a fresh checkout.
# NOTE: delete the two cache files after changing the source data or the
# embedding model, otherwise stale embeddings are reused.
EMBEDDINGS_FILE = "embeddings.npy"
KNOWLEDGE_BASE_FILE = "knowledge_base.json"

if os.path.exists(EMBEDDINGS_FILE) and os.path.exists(KNOWLEDGE_BASE_FILE):
    llamaIR.load_kb_embeddings(EMBEDDINGS_FILE, KNOWLEDGE_BASE_FILE)
else:
    llamaIR.learn_knowledge_base(knowledge_base=data['text'].tolist())
    llamaIR.save_kb_embeddings(EMBEDDINGS_FILE, KNOWLEDGE_BASE_FILE)

In [None]:
def _build_extraction_prompt(context_lines, eos_token):
    """Build the extraction prompt for the retrieved records, terminated by ``eos_token``."""
    prompt = f"""
    As an Information Extractor and Table Organizer Expert, your task is to extract information from the provided context 
    and organize it into JSON objects. Only give the JSON output. I will tip you $10000000 if you complete the task correctly.

    Given the following context strings:
    
    {context_lines}
    
    The text format is:
    - eegno: followed by a unique identifier
    - RN: followed by a patient ID
    - Age: followed by the patient's age in years
    - Gender: followed by the patient's gender
    - eeg: followed by the EEG test result (e.g., abnormal finding, normal)
    - location: followed by the location stated in the result (e.g., focal, generalized, diffuse, temporal, frontal, occipital, bitemporal, frontotemporal, bilateral)
    - epileptiform: followed by the epileptiform abnormal paroxysmal activity in the result (e.g., sharp, spike, polyspike, wave)
    - seizure: followed by the seizure activity in the result (e.g., interictal, sharp, spike, ictal)
    - diagnosis: followed by the diagnosed disease stated in the result (e.g., seizure, epilepsy, cortical, dysfunction, encephalopathy, hypoxia, metabolic)
    - PSG: followed by the rest activity stated in the result (e.g., sleep, narcolepsy, rem, spindle, vertex)
    - non-epileptiform: followed by seizure that are not due to epilepsy in the result (e.g., slow, fast, beta, delta, theta, periodic)
    
    Extract the eegno, RN, age, gender, conclusion, location, epileptiform, seizure, diagnosis, PSG, and non-epileptiform for each test and present them in a JSON array.
    
    Here's an example context string:
    
    eegno 12345 RN 67890 25 years old female eeg abnormal finding
    eegno 54321 RN 12345 30 years old male eeg normal
    eegno 23456 RN 98762 70 years old male eeg shows normal waking sleep background infrequent left temporal sharp waves sharp slow waves could consistent partial epilepsy
    eegno 98735 RN 23018 2 years old male one episode sharp waves posterior region queried significance may indicate epilepsy confirm epilepsy needs clinical correlation
    eegno 34528 RN 82345 23 years old female eeg showed generalized attenuation with frontocentral delta activities of 1 to 2 Hz. There was no response to sensory stimuli. This is consistent with severe cortical dysfunction.
    
    The corresponding JSON output should look like this:
    
    [
        {{"eegno": "12345", "RN": "67890", "age": "25", "gender": "female", "conclusion": "eeg abnormal finding", "location": "None", "epileptiform": "None", "seizure": "None", "diagnosis": "normal", "PSG": "None", "non-epileptiform": "None"}},
        {{"eegno": "54321", "RN": "12345", "age": "30", "gender": "male", "conclusion": "eeg normal", "location": "None", "epileptiform": "None", "seizure": "None", "diagnosis": "normal", "PSG": "None", "non-epileptiform": "None"}},
        {{"eegno": "23456", "RN": "98762", "age": "70", "gender": "male", "conclusion": "eeg shows normal waking sleep background infrequent left temporal sharp waves sharp slow waves could consistent partial epilepsy", "location": "temporal", "epileptiform": "sharp waves", "seizure": "sharp", "diagnosis": "epilepsy", "PSG": "sleep", "non-epileptiform": "slow"}},
        {{"eegno": "98735", "RN": "23018", "age": "2", "gender": "male", "conclusion": "one episode sharp waves posterior region queried significance may indicate epilepsy confirm epilepsy needs clinical correlation", "location": "posterior", "epileptiform": "sharp waves", "seizure": "sharp", "diagnosis": "epilepsy", "PSG": "None", "non-epileptiform": "None"}},
        {{"eegno": "34528", "RN": "82345", "age": "23", "gender": "female", "conclusion": "eeg showed generalized attenuation with frontocentral delta activities of 1 to 2 Hz. There was no response to sensory stimuli. This is consistent with severe cortical dysfunction.", "location": "frontocentral", "epileptiform": "None", "seizure": "None", "diagnosis": "severe cortical dysfunction", "PSG": "None", "non-epileptiform": "delta"}}
    ]
    
    Start your output with: 'Here's the JSON Output'.
    Process the input lines and provide the JSON output. Only output the complete final JSON output, do not output anything else other than JSON objects.
    """.strip() + eos_token
    return prompt


def _parse_json_array(text):
    """Parse the first JSON array found in ``text``, repairing a truncated tail.

    Raises:
        ValueError: If no JSON array can be recovered.
    """
    start = text.find("[")
    if start == -1:
        raise ValueError("Couldn't fix JSON")
    candidate = text[start:]

    # Happy path: a complete array, possibly followed by extra text.
    try:
        parsed, _ = json.JSONDecoder().raw_decode(candidate)
        return parsed
    except json.JSONDecodeError:
        pass

    # Truncated output: cut back to the last '}' that closes a complete object
    # and close the array there. One attempt per '}' instead of one per character.
    end = candidate.rfind("}")
    while end != -1:
        try:
            return json.loads(candidate[:end + 1] + "]")
        except json.JSONDecodeError:
            end = candidate.rfind("}", 0, end)
    raise ValueError("Couldn't fix JSON")


def generate_retrieved_table(question, context, searcher, embedding_model, model,
                             max_new_tokens=4096, temperature=0.2, role="expert"):
    """Retrieve the records most similar to ``question`` and extract them into a table.

    Args:
        question (str): Free-text query.
        context: Sequence of record texts, indexable by the neighbour positions
            the searcher returns.
        searcher: Object with ``search_batched(queries)`` returning
            ``(neighbors, distances)``.
        embedding_model: Object with ``encode(text, show_progress_bar=...)``.
        model: LLM wrapper exposing ``generate_text2`` and, optionally,
            ``tokenizer.eos_token``.
        max_new_tokens (int): Generation budget passed to the LLM.
        temperature (float): Sampling temperature passed to the LLM.
        role (str): Accepted for interface compatibility; not used in the prompt.

    Returns:
        pandas.DataFrame: One row per JSON object produced by the LLM.

    Raises:
        ValueError: If no JSON array can be recovered from the LLM output.
    """
    # Embed the question as a single float32 row, the dtype the index is built with.
    embedded_question = np.asarray(
        embedding_model.encode(question, show_progress_bar=False), dtype=np.float32
    ).reshape(1, -1)

    # Find similar records based on the embedded question
    neighbors, distances = searcher.search_batched(embedded_question)

    # Use the retrieved records as they are. Joining them and splitting on
    # " eegno " again removed the "eegno" key from every record but the first.
    context_lines = [str(context[pos]).strip() for pos in np.ravel(neighbors)]
    context_lines = [line for line in context_lines if line]

    # End-of-sequence token appended to the prompt; fall back to "<eos>" when the
    # model exposes none.
    EOS_TOKEN = getattr(getattr(model, "tokenizer", None), "eos_token", None) or "<eos>"

    prompt = _build_extraction_prompt(context_lines, EOS_TOKEN)

    # Generate an answer based on the prompt
    results = str(model.generate_text2(prompt, max_new_tokens=max_new_tokens, temperature=temperature))

    # The decoded output normally echoes the prompt. Drop it when it is an exact
    # prefix, then keep what follows the last output marker. The marker is
    # matched without the trailing colon so that, if the model omits it, the
    # occurrence inside the prompt's final instruction still anchors the search.
    completion = results[len(prompt):] if results.startswith(prompt) else results
    generated_text = completion.rsplit("Here's the JSON Output", 1)[-1].strip()
    print(generated_text)

    records = _parse_json_array(generated_text)

    # Return the DataFrame
    return pd.DataFrame(records)

In [None]:
# Remove pandas' display limits so result tables are shown in full.
# NOTE(review): with max_rows=None every row of any displayed frame is rendered.
pd.set_option('display.max_colwidth', None)  # Display full column width
pd.set_option('display.max_rows', None)      # Display all rows
pd.set_option('display.max_columns', None)   # Display all columns
pd.set_option('display.width', None)

In [None]:
output = llamaIR.query("frontal sharp waves spike discharge")

In [None]:
output

In [None]:
output = llamaIR.query("Status epilepticus.")

In [None]:
output

In [None]:
output = llamaIR.query("Female with normal EEG")

In [None]:
output

In [None]:
# 4. Distribution of Age
# The extracted 'age' values are strings (see the JSON format in the prompt);
# convert them so the histogram and KDE use a numeric axis. Values that cannot
# be parsed become NaN and are left out of the plot.
age_values = pd.to_numeric(output['age'], errors='coerce').dropna()
plt.figure(figsize=(10, 6))
sns.histplot(age_values, kde=True, bins=10)
plt.title('Age Distribution')
plt.xlabel('Age')
plt.ylabel('Frequency')
plt.show()

In [None]:
output = llamaIR.query("17 years old male severe cortical dysfunction")

In [None]:
output

In [None]:
# Age distribution for the current query result. 'age' is still a string column
# at this point, so convert it first; unparseable values are left out.
age_values = pd.to_numeric(output['age'], errors='coerce').dropna()
plt.figure(figsize=(10, 6))
sns.histplot(age_values, kde=True, bins=10)
plt.title('Age Distribution')
plt.xlabel('Age')
plt.ylabel('Frequency')
plt.show()

In [None]:
output.dtypes

In [None]:
# Convert the extracted ages to numbers. errors='coerce' turns values the LLM
# could not fill in (for example "None") into NaN instead of raising.
output['age'] = pd.to_numeric(output['age'], errors='coerce')

In [None]:
# Box plot of age for each extracted diagnosis value; x tick labels are rotated 45 degrees.
plt.figure(figsize=(14, 6))
sns.boxplot(x='diagnosis', y='age', data=output)
plt.title('Age vs Diagnosis')
plt.xlabel('Diagnosis')
plt.ylabel('Age')
plt.xticks(rotation=45)
plt.show()

In [None]:
# Horizontal count plot of the extracted locations, ordered by value_counts() (most frequent first).
plt.figure(figsize=(14, 6))
sns.countplot(y='location', data=output, order=output['location'].value_counts().index)
plt.title('Location Distribution')
plt.xlabel('Count')
plt.ylabel('Location')
plt.show()

In [None]:
# Scatter plot of age against the extracted 'epileptiform' value, coloured by gender.
# NOTE(review): the labels say "Severity", but the plotted column is 'epileptiform';
# the extraction prompt lists category-like values for it (sharp, spike, ...) - confirm the wording.
plt.figure(figsize=(10, 6))
sns.scatterplot(x='age', y='epileptiform', hue='gender', data=output)
plt.title('Age vs Epileptiform Severity')
plt.xlabel('Age')
plt.ylabel('Epileptiform Severity')
plt.show()

In [None]:
output = llamaIR.query("Male eeg consistent with partial epilepsy disorder.")

In [None]:
output

In [None]:
output = llamaIR.query("Male with bitemporal sharp waves")

In [None]:
output

In [None]:
output = llamaIR.query("Patient on temporal")

In [None]:
output

In [None]:
output = llamaIR.query("Patient with vertex sharp waves under sleep condition")

In [None]:
output

In [None]:
output = llamaIR.query("Patient with delta or beta activity")

In [None]:
output