In [None]:
# Install dependencies
!pip install -q torch torchvision transformers pydicom opencv-python Pillow accelerate

# Imports
import os, glob
import torch
import pydicom
import cv2
from PIL import Image
from pathlib import Path
from transformers import AutoTokenizer, AutoModelForCausalLM

# Set paths
# ROOT points at the mounted Google Drive folder; DICOM_DIR is the input
# tree that is searched for DICOM files, JPEG_DIR receives the converted images.
ROOT = Path("/content/drive/MyDrive")  # adjust if needed
DICOM_DIR = ROOT / "PTXHeadtoHeadStudyData"
JPEG_DIR = Path("/content/cxr_jpegs")
# parents=True: create missing parent directories instead of raising
# FileNotFoundError (e.g. when running outside an environment with /content).
JPEG_DIR.mkdir(parents=True, exist_ok=True)

# Convert DICOMs to JPEGs
def dicom_to_jpeg(dicom_path, jpeg_path):
    """Convert one single-frame grayscale DICOM file into an 8-bit JPEG.

    The pixel data is min-max scaled to 0..255, inverted when the file is
    stored as MONOCHROME1, histogram-equalized, and written to ``jpeg_path``.

    Args:
        dicom_path: Path of the DICOM file to read.
        jpeg_path: Destination path of the JPEG file.

    Raises:
        ValueError: If the pixel data is not a 2-D (single-frame, single-channel)
            image, which ``cv2.equalizeHist`` cannot process.
        OSError: If OpenCV reports that the JPEG could not be written.
    """
    ds = pydicom.dcmread(str(dicom_path))
    img = ds.pixel_array
    if img.ndim != 2:
        raise ValueError(
            f"Expected a 2-D grayscale image in {dicom_path}, got shape {img.shape}"
        )

    img = cv2.normalize(img, None, 0, 255, cv2.NORM_MINMAX).astype('uint8')

    # MONOCHROME1 means the lowest stored value is displayed as white; flip it
    # so every exported JPEG uses the usual "higher value = brighter" polarity.
    if getattr(ds, "PhotometricInterpretation", "") == "MONOCHROME1":
        img = 255 - img

    img = cv2.equalizeHist(img)  # optional contrast enhancement

    # cv2.imwrite signals failure through its return value, not an exception.
    if not cv2.imwrite(str(jpeg_path), img):
        raise OSError(f"Failed to write JPEG file: {jpeg_path}")

# Batch convert
# Sorted so the processing (and later reporting) order is reproducible.
dicom_paths = sorted(DICOM_DIR.rglob("*.dcm"))
if not dicom_paths:
    print(f"Warning: no .dcm files found under {DICOM_DIR}")

jpeg_paths = []
for dcm_path in dicom_paths:
    # The search is recursive, so two files in different sub-folders may share
    # a file name. Build the JPEG name from the path relative to DICOM_DIR so
    # they cannot overwrite each other; files directly inside DICOM_DIR keep
    # the plain "<stem>.jpg" name.
    relative_stem = dcm_path.relative_to(DICOM_DIR).with_suffix("")
    jpg_path = JPEG_DIR / ("__".join(relative_stem.parts) + ".jpg")
    dicom_to_jpeg(dcm_path, jpg_path)
    jpeg_paths.append(str(jpg_path))

# Load CheXagent-2-3b
model_name = "StanfordAIMI/CheXagent-2-3b"

# Use the GPU with bfloat16 weights when CUDA is present, otherwise CPU/float32.
if torch.cuda.is_available():
    device, dtype = "cuda", torch.bfloat16
else:
    device, dtype = "cpu", torch.float32

# trust_remote_code=True lets transformers run the custom tokenizer/model code
# that ships with the checkpoint repository.
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map="auto",
    trust_remote_code=True,
).to(dtype)
model.eval()  # inference only

# Define query
question = "Does this chest X-ray show a pneumothorax?"

# Run inference on a single image
def ask_chexagent(image_path, question, max_new_tokens=128):
    """Ask CheXagent a question about one image and return the decoded answer.

    Args:
        image_path: Path of the image file, passed to the tokenizer's
            ``from_list_format`` as the ``image`` entry.
        question: Text prompt paired with the image.
        max_new_tokens: Upper bound on the number of generated tokens.

    Returns:
        The model response as a string with surrounding whitespace removed;
        the prompt tokens are not included.
    """
    query = tokenizer.from_list_format([
        {"image": image_path},
        {"text": question},
    ])
    conversation = [
        {"from": "system", "value": "You are a helpful assistant."},
        {"from": "human", "value": query},
    ]
    input_ids = tokenizer.apply_chat_template(
        conversation,
        add_generation_prompt=True,
        return_tensors="pt"
    ).to(device)

    # Greedy decoding: sampling is off, remaining knobs are left neutral.
    output = model.generate(
        input_ids,
        do_sample=False,
        num_beams=1,
        temperature=1.0,
        top_p=1.0,
        use_cache=True,
        max_new_tokens=max_new_tokens
    )[0]

    # Keep only the newly generated tokens (drop the echoed prompt).
    generated = output[input_ids.size(1):]

    # Collect the end-of-sequence ids known to the model / tokenizer, if any.
    eos_ids = set()
    for candidate in (
        getattr(getattr(model, "generation_config", None), "eos_token_id", None),
        getattr(tokenizer, "eos_token_id", None),
    ):
        if candidate is None:
            continue
        if isinstance(candidate, int):
            eos_ids.add(candidate)
        else:
            eos_ids.update(candidate)

    # The final token is normally the terminator and is removed. When the
    # generation was cut off by the token limit instead, the final token is
    # real content and must be kept, unless it is a known end-of-sequence id.
    hit_limit = generated.size(0) >= max_new_tokens
    ends_with_eos = generated.numel() > 0 and generated[-1].item() in eos_ids
    if ends_with_eos or not hit_limit:
        generated = generated[:-1]

    response = tokenizer.decode(generated)
    return response.strip()

# Query the model once per converted image and keep (path, answer) pairs
results = []
for jpeg_file in jpeg_paths:
    reply = ask_chexagent(jpeg_file, question)
    record = (jpeg_file, reply)
    results.append(record)
    # Show only the file name next to the model's answer.
    print(f" {Path(jpeg_file).name} → {reply}")
