In [None]:
import getpass
import os

# getpass hides the key while it is typed/pasted, so it never shows up in the
# notebook output. Surrounding whitespace is stripped because a stray space or
# newline picked up during copy-paste would make the key invalid.
api_key = getpass.getpass("Enter your OpenAI API Key: ").strip()
if not api_key:
    # Fail loudly instead of storing an empty key and claiming success.
    raise ValueError("No API key entered; OPENAI_API_KEY was not set.")

# Processes launched from this kernel (e.g. `!streamlit run ...`) inherit it.
os.environ["OPENAI_API_KEY"] = api_key
del api_key  # do not keep the secret around in a notebook variable

print("API key set! You can now run the Streamlit app cell.")


In [None]:
%%writefile app.py
import os
import torch
import streamlit as st
from transformers import (
    AutoTokenizer,
    AutoModelForSequenceClassification,
    AutoModelForSeq2SeqLM,
)
from openai import OpenAI

# -------------------
# Basic config
# -------------------

# Hugging Face checkpoints, one per pipeline stage.
LONGFORMER_MODEL_NAME = "VirajJoshi/longformer-oncosummarizer-tuned"
BIOCLINICALBERT_MODEL_NAME = "VirajJoshi/bioclinicalbert-oncosummarizer-tuned"
BART_MODEL_NAME = "VirajJoshi/bart-rond-baseline-oncosummarizer"

# Run on the GPU when PyTorch can see one, otherwise on the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

# Probability threshold for calling a note "oncology-relevant".
ONCO_THRESHOLD = 0.5
# Number of top-scoring chunks kept as input for summarization.
MAX_BERT_CHUNKS = 4
# Token budgets for the BART summarizer (input truncation / generated length).
BART_MAX_INPUT_TOKENS = 512
BART_MAX_NEW_TOKENS = 128


# -------------------
# Model loading
# -------------------

@st.cache_resource
def load_models():
    """Load the tokenizer and model for each of the three pipeline stages.

    Checkpoints are loaded in pipeline order (Longformer document classifier,
    BioClinicalBERT chunk classifier, BART summarizer) and every model is
    moved to ``DEVICE``. The function is wrapped in ``st.cache_resource`` so
    Streamlit can reuse the loaded objects instead of reloading them.

    Returns:
        A 6-tuple ``(long_tok, long_model, bert_tok, bert_model, bart_tok,
        bart_model)``.
    """
    # (checkpoint name, model class) for each stage, in return order.
    stages = (
        (LONGFORMER_MODEL_NAME, AutoModelForSequenceClassification),
        (BIOCLINICALBERT_MODEL_NAME, AutoModelForSequenceClassification),
        (BART_MODEL_NAME, AutoModelForSeq2SeqLM),
    )

    loaded = []
    for checkpoint, model_cls in stages:
        loaded.append(AutoTokenizer.from_pretrained(checkpoint))
        loaded.append(model_cls.from_pretrained(checkpoint).to(DEVICE))

    return tuple(loaded)


# -------------------
# Helper functions
# -------------------

def classify_document(text, long_tok, long_model):
    """Score a whole note with the Longformer document classifier.

    The note is tokenized to exactly 4096 tokens (longer notes are truncated,
    shorter ones padded) and run through the model without gradients.

    Args:
        text: The clinical note.
        long_tok: Tokenizer matching ``long_model``.
        long_model: Sequence-classification model whose output index 0 is
            read as "non-oncology" and index 1 as "oncology".

    Returns:
        ``(label, prob_onco, prob_non_onco)`` where ``label`` is the argmax
        class index (0 = non-onco, 1 = onco) and the probabilities are the
        softmax outputs as Python floats.
    """
    encoded = long_tok(
        text,
        padding="max_length",
        truncation=True,
        max_length=4096,
        return_tensors="pt",
    )
    model_inputs = {name: tensor.to(DEVICE) for name, tensor in encoded.items()}

    with torch.no_grad():
        # Single document in the batch, so take row 0 of the softmax output.
        class_probs = torch.softmax(long_model(**model_inputs).logits, dim=-1)[0]

    prob_non_onco, prob_onco = float(class_probs[0]), float(class_probs[1])
    label = int(torch.argmax(class_probs))

    return label, prob_onco, prob_non_onco


def select_relevant_chunks(text, bert_tok, bert_model,
                           max_chunks=MAX_BERT_CHUNKS):
    """
    Use BioClinicalBERT to score overlapping chunks and keep the top K
    oncology-relevant ones. We then concatenate those chunks as input to BART.

    The note is split into windows of at most 512 tokens that overlap by 128
    tokens. Each window is scored with ``bert_model``; the softmax probability
    at index 1 is treated as the oncology score.

    Args:
        text: The clinical note.
        bert_tok: Tokenizer matching ``bert_model``. It must support
            ``return_overflowing_tokens`` with 2-D output (the original code
            already relied on this).
        bert_model: Sequence-classification model used to score each chunk.
        max_chunks: Maximum number of chunks to keep. Values <= 0 keep none.

    Returns:
        The decoded text of the selected chunks, in their original document
        order, joined by blank lines. Empty string if no chunk is kept.

    NOTE: because windows overlap, two adjacent selected chunks repeat the
    shared tokens in the merged text.
    """
    # padding=True is required here: the last overflow window is usually
    # shorter than the others, and return_tensors="pt" cannot stack rows of
    # different lengths into one tensor.
    enc = bert_tok(
        text,
        return_tensors="pt",
        truncation=True,
        max_length=512,
        stride=128,
        padding=True,
        return_overflowing_tokens=True,
    )

    input_ids = enc["input_ids"]            # (num_chunks, seq_len)
    attention_mask = enc["attention_mask"]  # same shape; 0 marks padding

    chunk_scores = []
    chunk_texts = []

    with torch.no_grad():
        # One chunk per forward pass keeps peak memory independent of note length.
        for ids, mask in zip(input_ids, attention_mask):
            logits = bert_model(
                input_ids=ids.unsqueeze(0).to(DEVICE),
                # Without the mask the model would attend to the pad tokens.
                attention_mask=mask.unsqueeze(0).to(DEVICE),
            ).logits
            chunk_scores.append(float(torch.softmax(logits, dim=-1)[0, 1]))
            # skip_special_tokens also drops the padding added above.
            chunk_texts.append(bert_tok.decode(ids, skip_special_tokens=True))

    # Rank chunk indices by oncology probability (descending).
    ranked_idx = sorted(
        range(len(chunk_scores)),
        key=chunk_scores.__getitem__,
        reverse=True,
    )

    # Keep the top K, then restore document order so the merged text reads
    # in the same sequence as the note. max(..., 0) stops a negative K from
    # being interpreted as "all but the last |K|".
    selected_idx = sorted(ranked_idx[:max(max_chunks, 0)])

    return "\n\n".join(chunk_texts[j].strip() for j in selected_idx)


def summarize_with_bart(text, bart_tok, bart_model):
    """Produce an abstractive summary of ``text`` with the tuned BART model.

    The input is truncated to ``BART_MAX_INPUT_TOKENS`` tokens. Generation
    uses 4-beam search with early stopping and produces between 30 and
    ``BART_MAX_NEW_TOKENS`` new tokens.

    Args:
        text: Text to summarize.
        bart_tok: Tokenizer matching ``bart_model``.
        bart_model: Seq2seq model exposing ``generate``.

    Returns:
        The decoded summary with special tokens removed and outer whitespace
        stripped.
    """
    encoded = bart_tok(
        text,
        max_length=BART_MAX_INPUT_TOKENS,
        truncation=True,
        return_tensors="pt",
    )

    generation_kwargs = dict(
        input_ids=encoded["input_ids"].to(DEVICE),
        attention_mask=encoded["attention_mask"].to(DEVICE),
        num_beams=4,
        max_new_tokens=BART_MAX_NEW_TOKENS,
        min_new_tokens=30,
        early_stopping=True,
    )

    with torch.no_grad():
        generated = bart_model.generate(**generation_kwargs)

    # Only one input sequence was given, so the summary is the first row.
    return bart_tok.decode(generated[0], skip_special_tokens=True).strip()


def extract_json_with_gpt4(note_text, summary=None):
    """
    Use GPT-4o-mini to turn the note/summary into structured JSON.
    Requires OPENAI_API_KEY in env.

    Args:
        note_text: The original clinical note; always sent to the model.
        summary: Optional model-generated summary appended to the request.

    Returns:
        The model's reply as a string (a JSON object, since JSON mode is
        requested), or ``None`` when the API key is missing or the API call
        fails. In both failure cases an error is shown via ``st.error``.
    """
    # Imported here because the file-level import only brings in the client.
    from openai import OpenAIError

    api_key = os.getenv("OPENAI_API_KEY")
    if not api_key:
        st.error("OPENAI_API_KEY is not set. JSON mode unavailable.")
        return None

    system_prompt = (
        "You are an assistant that extracts key oncology clinical information "
        "from notes. You must respond ONLY with a single valid JSON object, "
        "no prose. Use null for fields that are not mentioned.\n\n"
        "JSON schema:\n"
        "{\n"
        '  "patient_age": int or null,\n'
        '  "patient_sex": "male" | "female" | "other" | null,\n'
        '  "primary_cancer_type": string or null,\n'
        '  "primary_site": string or null,\n'
        '  "stage": string or null,\n'
        '  "metastatic_sites": [string] or [],\n'
        '  "key_mutations": [string] or [],\n'
        '  "treatments": [string] or [],\n'
        '  "response_to_treatment": string or null,\n'
        '  "performance_status": string or null,\n'
        '  "important_findings": [string] or [],\n'
        '  "recommended_next_steps": [string] or []\n'
        "}"
    )

    user_content = "Original note:\n" + note_text
    if summary is not None:
        user_content += "\n\nModel summary:\n" + summary

    client = OpenAI(api_key=api_key)

    try:
        response = client.chat.completions.create(
            model="gpt-4o-mini",
            temperature=0,
            # JSON mode: without it the model may wrap the object in prose or
            # a Markdown code fence despite the prompt. The prompt mentions
            # "JSON", which this mode requires.
            response_format={"type": "json_object"},
            messages=[
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": user_content},
            ],
        )
    except OpenAIError as exc:
        # Auth, rate-limit, network, etc.: report it and fall back to the
        # "no JSON" path the caller already handles, rather than crashing
        # after the summary has been computed.
        st.error(f"OpenAI request failed: {exc}")
        return None

    return response.choices[0].message.content


def run_pipeline(note_text, mode, long_tok, long_model,
                 bert_tok, bert_model,
                 bart_tok, bart_model):
    """
    Run the full note -> classification -> filtering -> summary pipeline.

    Args:
        note_text: The clinical note to process.
        mode: ``'summary'`` or ``'json'``. Only ``'json'`` triggers the extra
            GPT extraction step; any other value behaves like ``'summary'``.
        long_tok, long_model: Longformer tokenizer/model (document classifier).
        bert_tok, bert_model: BioClinicalBERT tokenizer/model (chunk scorer).
        bart_tok, bart_model: BART tokenizer/model (summarizer).

    Returns:
        Dict for Streamlit to display, with keys ``longformer_label``,
        ``prob_onco``, ``prob_non_onco``, ``filtered_text``, ``summary`` and,
        in JSON mode only, ``json`` (which may be ``None`` on failure).
    """
    # Stage 1: document-level oncology classification (Longformer).
    label, prob_onco, prob_non = classify_document(
        note_text, long_tok, long_model
    )

    # Stage 2: keep only the most oncology-relevant chunks (BioClinicalBERT).
    filtered_text = select_relevant_chunks(note_text, bert_tok, bert_model)

    # Stage 3: summarize the filtered text (BART). This always runs because
    # both output modes use the summary.
    summary = summarize_with_bart(filtered_text, bart_tok, bart_model)

    result = {
        "longformer_label": "oncology" if label == 1 else "non-oncology",
        "prob_onco": prob_onco,
        "prob_non_onco": prob_non,
        "filtered_text": filtered_text,
        "summary": summary,
    }

    # Stage 4 (JSON mode only): structured extraction via GPT.
    if mode == "json":
        result["json"] = extract_json_with_gpt4(note_text, summary=summary)

    return result


# -------------------
# Streamlit UI
# -------------------

def main():
    """Render the Streamlit UI and run the pipeline when the button is pressed.

    The page collects a clinical note and an output mode ("Summary" or
    "JSON with key elements"). On "Run OncoSummarizer" it calls
    ``run_pipeline`` and shows, in order: the Longformer classification with
    a warning when the oncology probability is below ``ONCO_THRESHOLD``, the
    BioClinicalBERT-filtered text (collapsed), the BART summary, and in JSON
    mode the GPT-generated JSON or an error if none was produced.
    """
    st.set_page_config(
        page_title="OncoSummarizer",
        layout="wide",
    )

    # The title previously contained the mis-decoded bytes of this emoji.
    st.title("OncoSummarizer 🧬")
    st.write(
        "Paste an oncology clinical note and choose whether you want a "
        "concise **summary** or a **JSON structure of key clinical fields**."
    )

    # Load models once
    (long_tok, long_model,
     bert_tok, bert_model,
     bart_tok, bart_model) = load_models()

    st.markdown("**Click inside the box below and paste/type your clinical note:**")

    note_text = st.text_area(
        "Clinical note",
        value="",
        height=300,
        key="note_input",
        placeholder="Paste the full clinical note here...",
    )

    mode = st.radio(
        "What would you like to generate?",
        ("Summary", "JSON with key elements"),
        index=0,
        key="mode_radio",
    )

    if st.button("Run OncoSummarizer"):
        # Nothing to process: tell the user and skip the (slow) model run.
        if not note_text.strip():
            st.warning("Please paste a clinical note.")
            return

        with st.spinner("Running models... this might take a bit for long notes."):
            # Map the radio label to the mode string run_pipeline expects.
            mode_key = "summary" if mode == "Summary" else "json"
            outputs = run_pipeline(
                note_text,
                mode_key,
                long_tok, long_model,
                bert_tok, bert_model,
                bart_tok, bart_model,
            )

        # Show classification info
        st.subheader("1. Document classification (Longformer)")
        st.write(f"**Predicted label:** {outputs['longformer_label']}")
        st.write(
            f"Oncology prob: `{outputs['prob_onco']:.3f}`  |  "
            f"Non-oncology prob: `{outputs['prob_non_onco']:.3f}`"
        )
        if outputs["prob_onco"] < ONCO_THRESHOLD:
            st.warning(
                "This note is not strongly classified as oncology-related. "
                "Results might be less meaningful."
            )

        # Show filtered text (optional, collapsible)
        with st.expander("2. Filtered text used for summarization (BioClinicalBERT)"):
            st.write(outputs["filtered_text"])
            st.caption(
                "Chunks selected by BioClinicalBERT as most oncology-relevant."
            )

        # Show summary
        st.subheader("3. Model summary (BART)")
        st.write(outputs["summary"])

        # Show JSON if requested
        if mode_key == "json":
            st.subheader("4. Structured JSON (GPT-4o-mini)")
            # run_pipeline stores None under "json" when extraction failed.
            if outputs.get("json") is None:
                st.error(
                    "Could not generate JSON. Check OPENAI_API_KEY or logs."
                )
            else:
                st.code(outputs["json"], language="json")


# Build the UI only when this file is executed as the main module,
# not when it is imported.
if __name__ == "__main__":
    main()


In [None]:
# Python packages imported by app.py.
!pip install streamlit transformers torch openai
# localtunnel CLI, used by the next cell to expose the Streamlit port.
!npm install -g localtunnel


In [None]:
# Start Streamlit headless on port 6006 in the background (`&`), then open a
# localtunnel to that same port in the foreground.
!streamlit run app.py --server.port 6006 --server.headless true & npx localtunnel --port 6006


In [None]:
# Print this machine's public IP address as reported by ifconfig.me.
# NOTE(review): presumably this is the value localtunnel asks visitors for
# on its landing page — confirm.
!curl ifconfig.me


In [None]:
# pyngrok provides the `ngrok` module used by the following cells
# (an alternative to the localtunnel cell above).
!pip install pyngrok


In [None]:
import getpass

from pyngrok import ngrok

# Prompt for the token instead of hard-coding it: the input is hidden, the
# secret is never saved in the notebook, and the cell works without editing.
# Whitespace is stripped because copy-paste often adds a trailing newline.
ngrok_token = getpass.getpass("Enter your ngrok auth token: ").strip()
if not ngrok_token:
    raise ValueError("No ngrok auth token entered.")

ngrok.set_auth_token(ngrok_token)
del ngrok_token  # do not keep the secret around in a notebook variable


In [None]:
from pyngrok import ngrok

# Shut down any ngrok process left over from an earlier run of this cell
# before opening a new tunnel.
ngrok.kill()
# Tunnel to local port 6006, the port Streamlit is started on below.
public_url = ngrok.connect(6006)
print("PUBLIC URL:", public_url)

# Started last and without a trailing `&`: this command runs in the
# foreground, so the tunnel has to be opened (and its URL printed) first.
!streamlit run app.py --server.port 6006 --server.headless true
