In [None]:
import json
import os
import pandas as pd
from pathlib import Path
import re
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from typing import List

In [None]:
# Hugging Face checkpoint used by the rest of the notebook.
# To try another one, swap in any of these identifiers:
#   'microsoft/phi-2'
#   'microsoft/phi-1_5'
#   'microsoft/Phi-3.5-mini-instruct'
#   'google/gemma-2-9b'
#   'meta-llama/Meta-Llama-3.1-8B'
#   'meta-llama/Meta-Llama-3.1-8B-Instruct'
#   'google/gemma-2-2b-it'
model_name = 'google/gemma-2-9b-it'

# Shared folder under the user's home directory where models are stored;
# it is created on first run and left untouched if it already exists.
CENTRAL_MODEL_DIR = os.path.expanduser('~/huggingface_models')
os.makedirs(CENTRAL_MODEL_DIR, exist_ok=True)

# Per-model folder inside the shared directory. The '/' of the hub identifier
# is swapped for '-' so the name maps to a single directory level.
local_model_path = os.path.join(CENTRAL_MODEL_DIR, '-'.join(model_name.split('/')))

In [None]:
# Choose the compute device and the matching device map in one place.
if torch.cuda.is_available():
    device = torch.device("cuda")
    # "auto" lets the loader spread the model across the available GPUs
    device_map = "auto"
else:
    device = torch.device("cpu")
    # Map the whole model (empty-string key) onto the CPU
    device_map = {"": device}

print(f"Using device: {device}")

### Optional Hugging Face login

In [None]:
# from huggingface_hub import notebook_login
# notebook_login()

In [None]:
# Decide once whether a local copy of the model is already on disk.
model_is_cached = os.path.exists(local_model_path)

if model_is_cached:
    print(f"Loading model from local path: {local_model_path}")
else:
    print(f"Downloading model from {model_name}")

# Single load call: read from the local folder when present, otherwise from the Hub.
# NOTE(review): trust_remote_code=True permits custom code shipped with a model repo to run.
original_model = AutoModelForCausalLM.from_pretrained(
    local_model_path if model_is_cached else model_name,
    device_map=device_map,
    # quantization_config=bnb_config,
    trust_remote_code=True
)

# A fresh download is written to the local folder so that later runs take the first branch.
if not model_is_cached:
    original_model.save_pretrained(local_model_path)
    print(f"Model saved to {local_model_path}")

In [None]:
def create_llm_function(model, tokenizer, max_new_tokens=512, temperature=0.7):
    """
    Build a ``prompt -> completion`` callable around a causal LM and its tokenizer.

    Args:
    model: Object exposing ``.device`` and ``.generate(**inputs, ...)``
    tokenizer: Callable tokenizer exposing ``eos_token_id`` and ``decode``
    max_new_tokens (int): Upper bound on the number of generated tokens
    temperature (float): Sampling temperature; 0 or None switches to
        deterministic (non-sampling) decoding instead of sampling

    Returns:
    A function taking a prompt string and returning only the newly
    generated text, with surrounding whitespace stripped
    """
    # Sampling needs a strictly positive temperature; otherwise sampling is
    # turned off and the temperature argument is not passed at all.
    use_sampling = temperature is not None and temperature > 0
    sampling_kwargs = {"temperature": temperature} if use_sampling else {}

    def llm_function(prompt: str) -> str:
        # Tokenize the input prompt and move it to the model's device
        inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
        prompt_length = inputs["input_ids"].shape[-1]

        # Generate the output
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                do_sample=use_sampling,
                pad_token_id=tokenizer.eos_token_id,
                **sampling_kwargs
            )

        # The output sequence starts with the prompt tokens. Drop them by token
        # count: the decoded text is not guaranteed to reproduce the prompt
        # string character for character, so slicing by len(prompt) can cut
        # into the response.
        new_tokens = outputs[0][prompt_length:]

        return tokenizer.decode(new_tokens, skip_special_tokens=True).strip()

    return llm_function

In [None]:
# Load the reports table from a path relative to the notebook's working directory;
# the report text is read from its 'Text' column in later cells.
df = pd.read_csv('../data/ReportsDATASET.csv')

In [None]:
def preprocess_radiology_report(report: str) -> str:
    """
    Preprocesses a radiology report by removing HTML tags, collapsing
    whitespace (including newlines) and marking de-identified spans.

    Only markup-shaped spans are treated as tags: '<' followed by a letter,
    '/' plus a letter, or '!'. Comparison operators in clinical text such as
    "<5 mm" or ">2 cm" are therefore left intact.

    Args:
    report (str): The original radiology report text

    Returns:
    str: The preprocessed radiology report text
    """
    # Remove HTML tags. A bare '<[^>]+>' would also swallow everything between
    # a '<' and a later '>' used as comparison signs, silently deleting findings.
    report = re.sub(r'<(?:/?[A-Za-z]|!)[^<>]*>', '', report)

    # Collapse every run of whitespace to a single space and trim both ends
    report = re.sub(r'\s+', ' ', report).strip()

    # Replace the 'XXXX' token with an explicit '[REDACTED]' placeholder
    return report.replace('XXXX', '[REDACTED]')

In [None]:
# First report in the table (row label 0), used as the running example
example_report = df.loc[0, 'Text']

In [None]:
# Run the cleaning step on the example report and print the result as a sanity check
preprocessed_report = preprocess_radiology_report(example_report)
print(preprocessed_report)

In [None]:
def post_process_abnormalities(classification_result: str) -> List[str]:
    """
    Extract the names of the abnormalities marked as present.

    Args:
    classification_result (str): JSON object string mapping abnormality
        names to flags, as produced by classify_abnormalities

    Returns:
    List[str]: Names whose flag compares equal to 1, in the order they
        appear in the JSON object

    Raises:
    ValueError: If the string is not valid JSON, or if anything else goes
        wrong while reading it (for example the JSON is not an object)
    """
    try:
        flags_by_name = json.loads(classification_result)

        # Collect every name whose flag equals 1
        flagged = []
        for name, flag in flags_by_name.items():
            if flag == 1:
                flagged.append(name)
        return flagged
    except json.JSONDecodeError:
        raise ValueError("Invalid JSON string provided")
    except Exception as err:
        raise ValueError(f"Error processing classification result: {err}")

In [None]:
def post_process_llm_output(output: str) -> str:
    """
    Strip Markdown code-fence markup from raw LLM text and flatten it to one line.

    Args:
    output (str): Raw text returned by the LLM

    Returns:
    str: The text without code-fence markers (optionally tagged ``json``) or
        stray backticks, with newlines turned into spaces and the ends trimmed
    """
    # Drop fence markers together with any whitespace that follows them
    without_fences = re.sub(r'```(?:json)?\s*', '', output)

    # Remove leftover backticks, flatten newlines, trim the ends
    return without_fences.replace('`', '').replace('\n', ' ').strip()

In [None]:
def classify_abnormalities(abnormalities: List[str], report: str, llm_function) -> str:
    """
    Ask an LLM which of the given abnormalities a radiology report shows.

    The report is preprocessed, embedded in a prompt together with the list of
    abnormalities, and the cleaned-up reply is parsed as JSON.

    Args:
    abnormalities (List[str]): Names of the abnormalities to classify
    report (str): The original radiology report text
    llm_function: Callable taking a prompt string and returning the reply text

    Returns:
    str: JSON object string containing a key for every requested abnormality;
        the values are whatever the LLM produced and are not validated here

    Raises:
    ValueError: If the reply is not valid JSON, is not a JSON object, or
        lacks one of the requested abnormalities
    """
    # Preprocess the report
    preprocessed_report = preprocess_radiology_report(report)

    # Create a dynamic prompt for the LLM
    prompt = f"""
Given the following radiology report, classify the presence (1) or absence (0) of the specified abnormalities.
Output the result as a JSON string without any additional explanation.

Abnormalities to classify: {', '.join(abnormalities)}

Radiology report:
{preprocessed_report}

Output format:
{{
    "abnormality1": 0 or 1,
    "abnormality2": 0 or 1,
    ...
}}
Return a JSON string without any explanation.
"""

    # Call the LLM and strip code fences / newlines from its reply
    llm_output = post_process_llm_output(llm_function(prompt))

    # Only the parsing step is guarded, so the validation errors raised below
    # reach the caller with their own message instead of being re-wrapped.
    try:
        result = json.loads(llm_output)
    except json.JSONDecodeError as e:
        raise ValueError("LLM output is not valid JSON") from e
    except Exception as e:
        raise ValueError(f"Error processing LLM output: {str(e)}") from e

    # A JSON list or scalar cannot map abnormalities to flags
    if not isinstance(result, dict):
        raise ValueError(
            f"Error processing LLM output: expected a JSON object, got {type(result).__name__}"
        )

    # Verify that all abnormalities are present in the output
    for abnormality in abnormalities:
        if abnormality not in result:
            raise ValueError(f"Missing abnormality in LLM output: {abnormality}")

    return json.dumps(result)

# Example usage 
def mock_llm_function(prompt: str) -> str:
    """Stand-in for a real LLM: ignores the prompt and always returns the same JSON reply."""
    canned_reply = {
        "pulmonary edema": 1,
        "consolidation": 0,
        "pleural effusion": 1,
        "pneumothorax": 0,
        "cardiomegaly": 1,
    }
    return json.dumps(canned_reply)

In [None]:
# Findings to classify in each report. These strings are inserted verbatim into the
# prompt and are also the keys that classify_abnormalities expects in the JSON reply.
abnormalities = ["pulmonary edema", "consolidation", "pleural effusion", "pneumothorax", "cardiomegaly"]

In [None]:
# Dry run of the classification pipeline with the mock LLM, so no model inference is involved
result = classify_abnormalities(abnormalities, example_report, mock_llm_function)
print(result)

In [None]:
# Reduce the JSON result to the names of the abnormalities flagged with 1
present_abnormalities = post_process_abnormalities(result)
print("Present abnormalities:", present_abnormalities)

In [None]:
# Load the tokenizer by its hub identifier (model_name), not from local_model_path
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Wrap the loaded model and tokenizer into a prompt -> response callable,
# using create_llm_function's default generation settings
llm_function = create_llm_function(original_model, tokenizer)

In [None]:
# Second example: the report stored under row label 3
example_2 = df.loc[3, 'Text']

In [None]:
# Display the first example report as stored in the table (before preprocessing)
example_report

In [None]:
# Classify the first example report with the real model; this overwrites the mock `result` from above
result = classify_abnormalities(abnormalities, example_report, llm_function)
print(result)

In [None]:
# Print the second example report as stored in the table (before preprocessing)
print(example_2)

In [None]:
# Classify the second example report with the real model; `result` is overwritten again
result = classify_abnormalities(abnormalities, example_2, llm_function)
print(result)