<a href="https://colab.research.google.com/github/riyasharma-coder/humanfirst-ai-medgemma/blob/main/app_py.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import streamlit as st
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch # Added import for torch

st.set_page_config(page_title="HumanFirst AI", layout="centered")

st.title("HumanFirst AI")
st.subheader("Understand your medical report in simple language")

# Safety disclaimer rendered above the input box on every run of the script.
st.markdown(
    "⚠️ This tool does not provide medical advice. Always consult a healthcare professional."
)

@st.cache_resource
def load_model():
    """Load the BioGPT tokenizer and causal language model.

    Wrapped in ``st.cache_resource`` so the objects are created once and the
    same instances are handed back on later calls.

    Returns:
        A ``(tokenizer, model)`` pair for the ``"microsoft/BioGPT"`` checkpoint.
    """
    checkpoint = "microsoft/BioGPT"
    return (
        AutoTokenizer.from_pretrained(checkpoint),
        AutoModelForCausalLM.from_pretrained(checkpoint),
    )

# Obtain the tokenizer/model pair (cached by @st.cache_resource on load_model).
tokenizer, model = load_model()

# Multi-line box where the user pastes the report to be explained.
medical_text = st.text_area("Paste your medical report below:", height=250)

if st.button("Explain My Report"):
    if not medical_text.strip():
        st.warning("Please paste a medical report first.")
    else:
        # Instruction prompt with the user's report inserted. str.format only
        # parses the template, so braces inside the pasted report are safe.
        prompt = '''
You are a healthcare explanation assistant.

Explain the following medical report in simple language for a patient.

Rules:
- Do not diagnose
- Do not suggest treatment or medicines
- Use calm and supportive tone
- Explain medical terms simply

Medical Report:
{medical_text}
'''.format(medical_text=medical_text)

        # Use the GPU when one is available. Calling model.to() on a model that
        # already lives on the target device leaves it where it is.
        device = "cuda" if torch.cuda.is_available() else "cpu"
        model.to(device)
        inputs = tokenizer(prompt, return_tensors="pt").to(device)
        prompt_length = inputs["input_ids"].shape[-1]

        # The model has a fixed number of position embeddings; never request
        # more new tokens than the remaining context window can hold.
        new_token_budget = 350
        max_positions = getattr(model.config, "max_position_embeddings", None)
        if max_positions is not None:
            new_token_budget = min(new_token_budget, max_positions - prompt_length)

        if new_token_budget <= 0:
            st.warning(
                "This report is too long for the model to process. "
                "Please shorten it and try again."
            )
        else:
            with st.spinner("Generating explanation..."):
                output = model.generate(
                    **inputs,
                    max_new_tokens=new_token_budget,
                    temperature=0.6,
                    do_sample=True,
                )

            # For a causal LM, generate() returns the prompt tokens followed by
            # the newly generated ones. Decode only the new part so the
            # instructions and the pasted report are not echoed back to the user.
            response = tokenizer.decode(
                output[0][prompt_length:], skip_special_tokens=True
            ).strip()

            st.subheader("Explanation:")
            st.markdown(response)

