In [None]:
import os
import json
import sys
import polars as pl
from pathlib import Path
from rich.console import Console
from dotenv import load_dotenv

# Load variables from a .env file (if one is found) into the process
# environment before anything below reads os.environ / os.getenv.
load_dotenv()

# Shared rich console used by the cells below for formatted output.
cons = Console()

In [None]:
import torch

# Report whether torch can see a CUDA device and, when it can, the name of
# the device at index 0.
cuda_available = torch.cuda.is_available()
cons.print(f"CUDA available: {cuda_available}")
if cuda_available:
    gpu_name = torch.cuda.get_device_name(0)
    cons.print(f"CUDA device: {gpu_name}")

## Load Dataset

In [None]:
from huggingface_hub import login

# The Hub token is read from the environment so that no credential is ever
# written into the notebook. A missing token only produces a warning; the
# notebook keeps running unauthenticated.
hf_token = os.getenv("HF_TOKEN")
if not hf_token:
    print("Warning: HF_TOKEN not found in environment variables")
else:
    login(token=hf_token)

In [None]:
from datasets import load_dataset

# Fetch the dataset identified by its Hub repository id.
ds = load_dataset("schorndorfer/mdace-inpatient")
cons.print("Dataset loaded")

# Materialise the test split as a polars DataFrame, going through pandas.
# `pl.from_pandas` is the dedicated pandas converter; `pl.from_dataframe`
# targets the generic dataframe interchange protocol and is meant for
# libraries polars has no direct converter for.
test_df = pl.from_pandas(ds["test"].to_pandas())
cons.print(f"Test DataFrame shape: {test_df.shape}")

## Load MedGemma

In [None]:
pip install --upgrade --quiet accelerate bitsandbytes transformers

In [None]:
from transformers import BitsAndBytesConfig
import torch

# Treated as Google Colab only when the `google.colab` module is loaded and
# the VERTEX_PRODUCT environment variable is unset or empty.
google_colab = "google.colab" in sys.modules and not os.environ.get("VERTEX_PRODUCT")

# Single source of truth for the checkpoint to load (an earlier duplicate
# assignment to "4b-it" was dead code because it was overwritten immediately).
model_variant = "27b-text-it"  # @param ["4b-it", "27b-it", "27b-text-it"]
model_id = f"google/medgemma-{model_variant}"

use_quantization = True  # @param {type: "boolean"}

# @markdown Set `is_thinking` to `True` to turn on thinking mode. **Note:** Thinking is supported for the 27B variants only.
is_thinking = True  # @param {type: "boolean"}

# If running a 27B variant in Google Colab, check if the runtime satisfies
# memory requirements
if "27b" in model_variant and google_colab:
    # Check availability first: querying the device name on a runtime without
    # a GPU raises from inside torch instead of producing the message below.
    has_a100 = torch.cuda.is_available() and "A100" in torch.cuda.get_device_name(0)
    if not (has_a100 and use_quantization):
        raise ValueError(
            "Runtime has insufficient memory to run a 27B variant. "
            "Please select an A100 GPU and use 4-bit quantization."
        )

# Keyword arguments forwarded to the model loader by the pipeline cell.
model_kwargs = dict(
    torch_dtype=torch.bfloat16,
    device_map="auto",
)

if use_quantization:
    model_kwargs["quantization_config"] = BitsAndBytesConfig(load_in_4bit=True)

In [None]:
from transformers import pipeline

# Variants with "text" in their name run under the "text-generation" task;
# every other variant runs under "image-text-to-text".
pipeline_task = "text-generation" if "text" in model_variant else "image-text-to-text"
pipe = pipeline(pipeline_task, model=model_id, model_kwargs=model_kwargs)

# Turn sampling off in the model's generation config.
pipe.model.generation_config.do_sample = False

## Run Inference

In [None]:
# System prompt sent to the model. This is a plain string, not an f-string:
# nothing is interpolated, and a plain literal stays valid if braces (for
# example a JSON sample) are ever added to the text.
# NOTE(review): the text describes two different output formats (the
# "- **Code**: ..." bullet form and the "<code>, <description>" line form) --
# confirm which one downstream parsing expects.
system_prompt = """
You are an expert clinical coder. From the following phrase extracted from a medical note, along with the entire note as context,
identify the most relevant ICD10-CM diagnosis codes

Instructions:
- Include all potential relevant codes
- Include a code only once

Output format:
- **Code**: <code>, **Description**: <description>

Just output a list of ICD-10 codes and descriptions, in the format described above.

The input format will be:

Input phrase:
<phrase to code>

Full medical note context:

<full medical note context>

###end###

Output should be in the following format:
<code 1>, <description 1>
<code 2>, <description 2>
<code 3>, <description 3>
"""

In [None]:
# Take the row at index 2 of the test split as a worked example and lay its
# phrase and full note out as the user prompt.
row_dict = test_df.row(2, named=True)
input_phrase, clinical_note = row_dict['covered_text'], row_dict['text']

prompt = f"""
Input phrase:
{input_phrase}
\n\n
Full medical note context:
{clinical_note}
\n\n
###end###
"""

# Echo the selected row as the cell output.
row_dict

In [None]:
from IPython.display import Markdown

# Short role-only instruction. It is not sent to the model: the system turn
# below carries the detailed `system_prompt`.
role_instruction = "You are an expert clinical coder. From the following medical note, identify the most relevant ICD-10 codes"

# The thinking prefix must be part of the system turn that is actually sent.
# Previously `system_instruction` was built but never used, so `is_thinking`
# only raised the token budget without enabling thinking.
if "27b" in model_variant and is_thinking:
    system_instruction = f"SYSTEM INSTRUCTION: think silently if needed. {system_prompt}"
    max_new_tokens = 1500
else:
    system_instruction = system_prompt
    max_new_tokens = 500

# Chat-format input for the pipeline: one system turn, one user turn.
messages = [
    {
        "role": "system",
        "content": [{"type": "text", "text": system_instruction}]
    },
    {
        "role": "user",
        "content": [{"type": "text", "text": prompt}]
    }
]

In [None]:
# Run generation and take the content of the last message in the returned
# conversation as the model's reply.
output = pipe(messages, max_new_tokens=max_new_tokens)
response = output[0]["generated_text"][-1]["content"]

display(Markdown(f"---\n\n**[ User ]**\n\n{prompt}\n\n---"))

# When the reply carries a thinking trace, show it separately from the answer.
# NOTE(review): the "<unused94>thought\n" / "<unused95>" markers follow
# Google's MedGemma quick-start notebook -- confirm against the model card.
# `partition` leaves `response` untouched when the marker is absent (thinking
# off, or the trace was cut short by `max_new_tokens`).
thought, separator, answer = response.partition("<unused95>")
if separator:
    thought = thought.replace("<unused94>thought\n", "")
    display(Markdown(f"**[ MedGemma thinking trace ]**\n\n{thought}\n\n---"))
    response = answer

display(Markdown(f"**[ MedGemma ]**\n\n{response}\n\n---"))