In [None]:
!pip install git+https://github.com/huggingface/transformers.git@88d960937c81a32bfb63356a2e8ecf7999619681 gradio

In [None]:
!pip3 install torch torchvision torchaudio

In [None]:
!pip install ipywidgets

In [2]:
import io
from pathlib import Path

import torch
from transformers import AutoModelForCausalLM, AutoProcessor

In [3]:
from huggingface_hub import notebook_login

# Interactive Hugging Face Hub login; in Jupyter this renders a token-entry widget.
# NOTE(review): presumably required because the microsoft/maira-2 repo is gated — confirm.
notebook_login()

VBox(children=(HTML(value='<center> <img\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…

In [None]:
# Load the MAIRA-2 weights first, then its processor, from the same Hub repo.
# trust_remote_code=True is needed because the repo ships its own model/processor
# classes; it executes that downloaded code, so only use it with trusted repos.
model, processor = (
    auto_cls.from_pretrained("microsoft/maira-2", trust_remote_code=True)
    for auto_cls in (AutoModelForCausalLM, AutoProcessor)
)

In [3]:
import requests
from PIL import Image


def get_sample_data() -> dict[str, Image.Image | str]:
    """
    Download a sample frontal/lateral chest X-ray pair plus report metadata.

    The images come from IU-Xray (Open-i), which MAIRA-2 was not trained on, and are
    distributed under a Creative Commons licence. Adapted from the RAD-DINO
    repository on Hugging Face.

    Returns:
        dict with decoded PIL images under "frontal" and "lateral", and text fields
        "indication", "comparison", "technique" and "phrase" (the latter for the
        phrase-grounding example).

    Raises:
        requests.HTTPError: if either image download returns a 4xx/5xx status.
        requests.Timeout: if the server does not respond within the timeout.
    """
    frontal_image_url = (
        "https://openi.nlm.nih.gov/imgs/512/145/145/CXR145_IM-0290-1001.png"
    )
    lateral_image_url = (
        "https://openi.nlm.nih.gov/imgs/512/145/145/CXR145_IM-0290-2001.png"
    )

    def download_and_open(url: str) -> Image.Image:
        # A timeout keeps a stalled server from hanging the notebook indefinitely.
        response = requests.get(url, headers={"User-Agent": "MAIRA-2"}, timeout=30)
        # Fail loudly instead of handing an HTML error page to PIL.
        response.raise_for_status()
        # `response.content` honours Content-Encoding (unlike `response.raw`) and
        # buffering it releases the connection instead of leaving a stream open.
        image = Image.open(io.BytesIO(response.content))
        # Image.open is lazy; decode now so truncated downloads fail here.
        image.load()
        return image

    frontal_image = download_and_open(frontal_image_url)
    lateral_image = download_and_open(lateral_image_url)

    sample_data = {
        "frontal": frontal_image,
        "lateral": lateral_image,
        "indication": "Dyspnea.",
        "comparison": "None.",
        "technique": "PA and lateral views of the chest.",
        "phrase": "Pleural effusion.",  # For the phrase grounding example. This patient has pleural effusion.
    }
    return sample_data


# Fetch the sample study once; the reporting example below reuses these inputs.
sample_data = get_sample_data()

In [None]:
# Display the sample inputs (two PIL images plus the metadata strings).
sample_data

In [None]:
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Assign rather than leaving `model.to(device)` as the cell's last expression,
# which would dump the entire module repr into the notebook output.
model = model.to(device)

In [None]:
# Build the model inputs for a findings report on the sample study.
processed_inputs = processor.format_and_preprocess_reporting_input(
    current_frontal=sample_data["frontal"],
    current_lateral=sample_data["lateral"],
    prior_frontal=None,  # Our example has no prior study
    indication=sample_data["indication"],
    technique=sample_data["technique"],
    comparison=sample_data["comparison"],
    prior_report=None,  # ...and therefore no prior report
    return_tensors="pt",
    get_grounding=False,  # For this example we generate a non-grounded report
)

processed_inputs = processed_inputs.to(device)
with torch.no_grad():
    output_decoding = model.generate(
        **processed_inputs,
        # 450 is the budget intended for grounded reporting (get_grounding=True).
        # A non-grounded report like this one likely needs fewer (e.g. 300) — TODO
        # confirm; 450 is kept so the same setting works for both modes.
        max_new_tokens=450,
        use_cache=True,
    )

# The generated sequence begins with the prompt tokens; slice them off so only
# the model's completion is decoded.
prompt_length = processed_inputs["input_ids"].shape[-1]
decoded_text = processor.decode(
    output_decoding[0][prompt_length:], skip_special_tokens=True
)
# Findings generation completions have a single leading space.
decoded_text = decoded_text.lstrip()
prediction = processor.convert_output_to_plaintext_or_grounded_sequence(decoded_text)
print("Parsed prediction:", prediction)

## Gradio App

In [11]:
def download_image(url: str, timeout: float = 30.0) -> Image.Image:
    """
    Download the image at ``url`` and return it as a fully decoded PIL Image.

    Args:
        url: HTTP(S) URL of the image.
        timeout: connect/read timeout in seconds passed to ``requests.get``;
            without one, a stalled server would hang the caller (and the Gradio app).

    Returns:
        The decoded ``PIL.Image.Image``.

    Raises:
        requests.HTTPError: if the server answers with a 4xx/5xx status.
        requests.Timeout: if the server does not respond within ``timeout``.
        PIL.UnidentifiedImageError: if the payload is not a readable image.
    """
    response = requests.get(url, headers={"User-Agent": "MAIRA-2"}, timeout=timeout)
    # Fail loudly on 404/403 etc. instead of handing an HTML error page to PIL.
    response.raise_for_status()
    # `response.content` is decoded per Content-Encoding, unlike `response.raw`,
    # and buffering it lets the connection be released right away.
    image = Image.open(io.BytesIO(response.content))
    # Image.open is lazy; force decoding so truncated downloads fail here.
    image.load()
    return image

In [12]:
def generate_findings(
    frontal_url: str,
    lateral_url: str,
    indication: str,
    comparison: str,
    technique: str,
    max_new_tokens: int = 450,
):
    """
    Generate a non-grounded findings section for one frontal/lateral study.

    1. Download the frontal & lateral images from the provided URLs.
    2. Format & preprocess the input for the model using `processor`.
    3. Generate the findings from the model.
    4. Return the two images and the generated findings text.

    Args:
        frontal_url: URL of the frontal image; surrounding whitespace is ignored.
        lateral_url: URL of the lateral image; surrounding whitespace is ignored.
        indication: free-text indication section passed to the processor.
        comparison: free-text comparison section passed to the processor.
        technique: free-text technique section passed to the processor.
        max_new_tokens: token budget for ``model.generate`` (defaults to the
            previously hard-coded 450).

    Returns:
        Tuple ``(frontal_image, lateral_image, prediction)`` matching the three
        Gradio output components.
    """
    # 1. Download images. URLs pasted into a textbox often carry trailing
    #    whitespace, which requests does not strip (it only left-strips).
    frontal_image = download_image(frontal_url.strip())
    lateral_image = download_image(lateral_url.strip())

    # 2. Prepare inputs for the model
    processed_inputs = processor.format_and_preprocess_reporting_input(
        current_frontal=frontal_image,
        current_lateral=lateral_image,
        prior_frontal=None,  # Example doesn't use prior images
        indication=indication,
        technique=technique,
        comparison=comparison,
        prior_report=None,  # Example doesn't use prior reports
        return_tensors="pt",
        get_grounding=False,  # For a non-grounded report
    )
    processed_inputs = processed_inputs.to(model.device)

    # 3. Generate the findings
    with torch.no_grad():
        output_decoding = model.generate(
            **processed_inputs,
            max_new_tokens=max_new_tokens,
            use_cache=True,
        )

    # The output begins with the prompt tokens; skip them so only the
    # completion is decoded.
    prompt_length = processed_inputs["input_ids"].shape[-1]
    decoded_text = processor.decode(
        output_decoding[0][prompt_length:], skip_special_tokens=True
    )
    # Findings generation completions have a single leading space.
    decoded_text = decoded_text.lstrip()

    # Convert the model output into plain text
    prediction = processor.convert_output_to_plaintext_or_grounded_sequence(
        decoded_text
    )

    # 4. Return the frontal/lateral images (for display in Gradio) and the findings.
    return frontal_image, lateral_image, prediction

In [13]:
import gradio as gr

In [14]:
app_name = "MAIRA-2 CXR Report Generator"
app_description = """
Enter URLs for the frontal and lateral chest X-ray images and relevant metadata.
Click "Generate Findings" to see the automatic radiology report findings.
"""

with gr.Blocks(title=app_name) as demo:
    # Header: app title followed by short usage instructions.
    gr.Markdown(f"## {app_name}")
    gr.Markdown(app_description)

    # Image URL inputs side by side, pre-filled with the IU-Xray sample study.
    with gr.Row():
        frontal_url, lateral_url = (
            gr.Textbox(label=label, value=url)
            for label, url in (
                (
                    "Frontal Image URL",
                    "https://openi.nlm.nih.gov/imgs/512/145/145/CXR145_IM-0290-1001.png",
                ),
                (
                    "Lateral Image URL",
                    "https://openi.nlm.nih.gov/imgs/512/145/145/CXR145_IM-0290-2001.png",
                ),
            )
        )

    # Free-text report sections, in the same order as the original layout.
    indication, comparison, technique = (
        gr.Textbox(label=label, value=text)
        for label, text in (
            ("Indication", "Dyspnea."),
            ("Comparison", "None."),
            ("Technique", "PA and lateral views of the chest."),
        )
    )

    generate_button = gr.Button("Generate Findings")

    # Outputs: the downloaded images echoed back, plus the generated findings.
    with gr.Row():
        frontal_image_out, lateral_image_out = (
            gr.Image(label=label) for label in ("Frontal Image", "Lateral Image")
        )
    result_text_out = gr.Textbox(lines=6, label="Generated Findings")

    # Input order must match generate_findings' positional parameters.
    generate_button.click(
        fn=generate_findings,
        inputs=[frontal_url, lateral_url, indication, comparison, technique],
        outputs=[frontal_image_out, lateral_image_out, result_text_out],
    )

In [None]:
# Inside a notebook kernel `__name__` is "__main__", so running this cell launches
# the app; the guard only matters if this code is exported to a script and imported.
if __name__ == "__main__":
    demo.launch()