<a href="https://colab.research.google.com/github/sanjayyy-22/CodeRxWizard/blob/main/Capstone_medical_prescription_CodeRxWizard.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Working model with Mistral OCR and TinyLlama-1.1B

In [None]:
# Install dependencies
# streamlit is included because /content/app.py imports it and the launcher
# cell starts the app with `streamlit run`.
!pip install pyngrok reportlab mistralai streamlit transformers datasets accelerate peft trl bitsandbytes torch

# Install ngrok
!wget https://bin.equinox.io/c/bNyj1mQVY4c/ngrok-v3-stable-linux-amd64.tgz
!tar -xvzf ngrok-v3-stable-linux-amd64.tgz
!mv ngrok /usr/local/bin/

# Authenticate ngrok (replace the "NGROK AUTH TOKEN" placeholder below with your ngrok authtoken)
!ngrok authtoken "NGROK AUTH TOKEN"

In [None]:
import torch
from datasets import load_dataset
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    TrainingArguments
)
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
from trl import SFTTrainer

# Quantization configuration: load the weights in 4-bit NF4 with double
# quantization enabled, using float16 as the compute dtype.
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True
)

# Model ID (the same identifier is used for the tokenizer and the model loads below)
model_id = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"

# Load dataset
dataset = load_dataset("truehealth/medicationqa")

# Reduce dataset size: keep only the first 500 training rows.
# NOTE: this reassigns dataset["train"], so the full split is no longer
# reachable through `dataset` after this point.
if len(dataset["train"]) > 500:
    dataset["train"] = dataset["train"].select(range(500))

# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_id)
# Reuse the EOS token as the padding token and pad on the right.
# NOTE(review): presumably a workaround for a missing/unsuitable pad token — confirm.
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right"

# Load model with 4-bit quantization
# NOTE(review): trust_remote_code=True allows code shipped with the model
# repository to run — confirm it is actually needed for this checkpoint.
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    quantization_config=quantization_config,
    device_map="auto",
    trust_remote_code=True
)

# Prepare model for k-bit training
model = prepare_model_for_kbit_training(model)

# LoRA configuration: rank 16, alpha 32, dropout 0.1, no bias parameters,
# applied to the q/k/v/o projection modules listed in target_modules.
lora_config = LoraConfig(
    r=16,
    lora_alpha=32,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
    lora_dropout=0.1,
    bias="none",
    task_type="CAUSAL_LM"
)

# Apply LoRA: from here on `model` refers to the PEFT-wrapped model
model = get_peft_model(model, lora_config)

# Format dataset for instruction tuning
def format_instruction(example):
    """Add a chat-style ``text`` field to one dataset record.

    Reads the ``Question``, ``Focus (Drug)`` and ``Answer`` fields, renders a
    single user/assistant exchange from them, stores it under
    ``example["text"]`` and returns the same (mutated) mapping.
    """
    user_turn = "Given the misspelled medication name: '{}', identify the correct medication name and explain why it's the correct form.".format(
        example["Question"]
    )
    assistant_turn = "The correct medication name is '{}'. {}".format(
        example["Focus (Drug)"], example["Answer"]
    )
    # One turn per line: user tag, prompt, assistant tag, reply + end-of-sequence marker.
    example["text"] = "\n".join(
        ["<|user|>", user_turn, "<|assistant|>", assistant_turn + "</s>"]
    )
    return example

# Apply formatting: add the "text" field built by format_instruction to every record
formatted_dataset = dataset.map(format_instruction)

# Training arguments
# - Effective batch per device is 2 x 4 = 8 sequences per optimizer step
#   (per_device_train_batch_size x gradient_accumulation_steps).
# - Loss is logged every 10 steps; checkpoints are saved once per epoch with
#   save_total_limit=1; evaluation is disabled (eval_strategy="no").
# - Cosine learning-rate schedule with a 10% warmup ratio; gradients are
#   clipped at a norm of 0.3.
training_args = TrainingArguments(
    output_dir="./results",
    num_train_epochs=3,
    per_device_train_batch_size=2,
    gradient_accumulation_steps=4,
    gradient_checkpointing=True,
    optim="adamw_torch",
    learning_rate=2e-4,
    weight_decay=0.01,
    logging_steps=10,
    save_strategy="epoch",
    eval_strategy="no",
    warmup_ratio=0.1,
    lr_scheduler_type="cosine",
    max_grad_norm=0.3,
    save_total_limit=1
)

# Formatting function for trainer
def formatting_func(example):
    """Return the pre-rendered training text stored under ``example["text"]``."""
    rendered_text = example["text"]
    return rendered_text

# Initialize trainer: supervised fine-tuning on the formatted "train" split;
# formatting_func supplies the text to train on for each record.
trainer = SFTTrainer(
    model=model,
    args=training_args,
    train_dataset=formatted_dataset["train"],
    formatting_func=formatting_func,
)

# Train the model
trainer.train()

# Save the trained adapter and the tokenizer into the same directory
# NOTE(review): `model` is the PEFT-wrapped model, so this presumably writes only
# the LoRA adapter weights rather than the full base model — confirm.
model.save_pretrained("/content/medical_term_lora_adapter")
tokenizer.save_pretrained("/content/medical_term_lora_adapter")

print("Model fine-tuning completed!")

Map:   0%|          | 0/500 [00:00<?, ? examples/s]

Applying formatting function to train dataset:   0%|          | 0/500 [00:00<?, ? examples/s]

Converting train dataset to ChatML:   0%|          | 0/500 [00:00<?, ? examples/s]

Applying chat template to train dataset:   0%|          | 0/500 [00:00<?, ? examples/s]

Tokenizing train dataset:   0%|          | 0/500 [00:00<?, ? examples/s]

Truncating train dataset:   0%|          | 0/500 [00:00<?, ? examples/s]

No label_names provided for model class `PeftModelForCausalLM`. Since `PeftModel` hides base models input arguments, if label_names is not given, label_names can't be set automatically within `Trainer`. Note that empty label_names list will be used instead.


Step,Training Loss
10,2.0665
20,1.5887
30,1.264
40,1.0896
50,1.1718
60,1.1501
70,1.0718
80,1.0424
90,1.1103
100,0.9956


Model fine-tuning completed!


In [None]:
%%writefile /content/app.py
from mistralai import Mistral, DocumentURLChunk, ImageURLChunk, TextChunk
from pathlib import Path
import streamlit as st
from reportlab.lib.pagesizes import letter
from reportlab.platypus import SimpleDocTemplate, Paragraph, Table, Spacer
from reportlab.lib.styles import getSampleStyleSheet
import json
import base64
import time
import os
import tempfile

# Initialize Mistral client
# The key is taken from the MISTRAL_API_KEY environment variable when it is set,
# so the secret does not have to be written into this file. The placeholder
# string is kept as the fallback for anyone who still edits the key in place.
client = Mistral(api_key=os.environ.get("MISTRAL_API_KEY", "MISTRAL API KEY"))

# Process prescription image to get JSON
def process_prescription(image_file):
    """Extract structured data from a prescription/report image.

    The image is sent to the Mistral OCR endpoint; the first page's markdown,
    together with the image itself, is then passed to the Pixtral chat model,
    which is asked to answer with a JSON object.

    Args:
        image_file: ``pathlib.Path`` (or any object exposing ``is_file`` and
            ``read_bytes``) pointing at a JPEG or PNG image.

    Returns:
        The model's JSON reply, re-serialized as a string with 4-space indent.

    Raises:
        FileNotFoundError: ``image_file`` is not an existing file.
        ValueError: the OCR response contains no pages.
        json.JSONDecodeError: the chat model's reply is not valid JSON.
    """
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if not image_file.is_file():
        raise FileNotFoundError(f"Image file not found: {image_file}")

    image_bytes = image_file.read_bytes()
    # Label the data URL with the image's real type. PNG is recognised by its
    # 8-byte file signature; everything else keeps the JPEG label.
    if image_bytes.startswith(b"\x89PNG\r\n\x1a\n"):
        mime_type = "image/png"
    else:
        mime_type = "image/jpeg"
    encoded = base64.b64encode(image_bytes).decode()
    base64_data_url = f"data:{mime_type};base64,{encoded}"

    image_response = client.ocr.process(document=ImageURLChunk(image_url=base64_data_url), model="mistral-ocr-latest")
    # NOTE(review): short pause between the two API calls — presumably to stay
    # under a request rate limit; confirm before removing.
    time.sleep(1)

    if not image_response.pages:
        raise ValueError("OCR returned no pages for the supplied image")
    image_ocr_markdown = image_response.pages[0].markdown

    chat_response = client.chat.complete(
        model="pixtral-12b-latest",
        messages=[
            {
                "role": "user",
                "content": [
                    ImageURLChunk(image_url=base64_data_url),
                    TextChunk(text=f"This is image's OCR in markdown:\n<BEGIN_IMAGE_OCR>\n{image_ocr_markdown}\n<END_IMAGE_OCR>.\nConvert this into a sensible structured json response. The output should be strictly be json with no extra commentary")
                ],
            },
        ],
        response_format={"type": "json_object"},
        temperature=0
    )

    # Round-trip through json to validate the reply and normalise its formatting.
    response_dict = json.loads(chat_response.choices[0].message.content)
    return json.dumps(response_dict, indent=4)

# Generate PDF based on JSON content
def _pdf_value(record, key, default='N/A'):
    """Return ``record[key]`` as text, or ``default`` when the key is absent or null."""
    value = record.get(key)
    return default if value is None else str(value)


def _pdf_joined(value):
    """Format a list-or-scalar JSON field as comma-separated text ('N/A' when empty)."""
    if not value:
        return 'N/A'
    if isinstance(value, (list, tuple)):
        return ", ".join(str(item) for item in value)
    # A plain string must not be joined: that would put commas between characters.
    return str(value)


def _pdf_line(label, text, style):
    """Build a ``Label: text`` Paragraph with ``text`` escaped for Paragraph markup."""
    from xml.sax.saxutils import escape
    # Paragraph text is parsed as XML-like markup, so OCR text containing
    # '&', '<' or '>' has to be escaped before it is interpolated.
    return Paragraph(f"{label}: {escape(text)}", style)


def generate_pdf(json_data, filename):
    """Render the extracted report JSON as a PDF file.

    When ``prescription.medications`` is a non-empty list, the PDF shows the
    prescription header (patient, age, date) and a medications table.
    Otherwise it shows whichever of the ``doctor``, ``patient``,
    ``contact_info`` and ``date`` entries are present.

    Missing or null fields are rendered as ``N/A``. Values of an unexpected
    JSON type are stringified (or skipped, for whole sections) instead of
    raising.

    Args:
        json_data: JSON document as a string.
        filename: Path of the PDF file to write.

    Raises:
        json.JSONDecodeError: ``json_data`` is not valid JSON.
    """
    from xml.sax.saxutils import escape

    data_dict = json.loads(json_data)
    if not isinstance(data_dict, dict):
        data_dict = {}

    doc = SimpleDocTemplate(filename, pagesize=letter)
    styles = getSampleStyleSheet()
    normal = styles['Normal']
    elements = []

    # Title
    elements.append(Paragraph("Medical Report", styles['Title']))
    elements.append(Spacer(1, 12))

    # Check if prescription with medications exists
    prescription = data_dict.get("prescription")
    if not isinstance(prescription, dict):
        prescription = {}
    medications = prescription.get("medications")
    if not isinstance(medications, list):
        medications = []

    if medications:
        # Prescription case: display patient, age, date, medications
        elements.append(Paragraph("Prescription Details", styles['Heading2']))
        elements.append(_pdf_line("Patient", _pdf_value(prescription, 'patient'), normal))
        elements.append(_pdf_line("Age", _pdf_value(prescription, 'age'), normal))
        # The prescription's own date wins; fall back to the top-level date.
        report_date = _pdf_value(prescription, 'date', _pdf_value(data_dict, 'date'))
        elements.append(_pdf_line("Date", report_date, normal))
        elements.append(Spacer(1, 12))

        # Medications table (cells are passed as plain strings)
        elements.append(Paragraph("Medications", styles['Heading3']))
        med_data = [["Name", "Dosage", "Frequency", "Duration"]]
        for med in medications:
            if not isinstance(med, dict):
                # A bare entry (e.g. just a drug name) becomes the "Name" cell.
                med = {"name": med}
            med_data.append([
                _pdf_value(med, column)
                for column in ("name", "dosage", "frequency", "duration")
            ])
        elements.append(Table(med_data))
    else:
        # Non-prescription case: display all available data
        # Doctor Details
        doctor = data_dict.get("doctor")
        if isinstance(doctor, dict) and doctor:
            elements.append(Paragraph("Doctor Details", styles['Heading2']))
            elements.append(_pdf_line("Name", _pdf_value(doctor, 'name'), normal))
            elements.append(_pdf_line("Qualifications", _pdf_joined(doctor.get('qualifications')), normal))
            for label, key in (
                ("Specialization", 'specialization'),
                ("Registration Number", 'registration_number'),
                ("Phone", 'phone'),
                ("Extension", 'extension'),
                ("Appointment Phone", 'appointment_phone'),
                ("Email", 'email'),
            ):
                elements.append(_pdf_line(label, _pdf_value(doctor, key), normal))
            elements.append(Spacer(1, 12))

        # Patient Details and Contact Info share the same "heading + labelled lines" layout
        for title, section_key, fields in (
            ("Patient Details", "patient", (
                ("Name", 'name'),
                ("Age", 'age'),
                ("Gender", 'gender'),
                ("IP Number", 'ip_number'),
                ("Condition", 'condition'),
                ("Hospital Stay", 'hospital_stay'),
            )),
            ("Contact Information", "contact_info", (
                ("Address", 'address'),
                ("Phone", 'phone'),
            )),
        ):
            section = data_dict.get(section_key)
            if isinstance(section, dict) and section:
                elements.append(Paragraph(title, styles['Heading2']))
                for label, key in fields:
                    elements.append(_pdf_line(label, _pdf_value(section, key), normal))
                elements.append(Spacer(1, 12))

        # Date
        if "date" in data_dict:
            elements.append(Paragraph("Date", styles['Heading2']))
            elements.append(Paragraph(escape(_pdf_value(data_dict, 'date')), normal))
            elements.append(Spacer(1, 12))

    # Build PDF
    doc.build(elements)

# Streamlit UI
# Flow: upload an image -> OCR + JSON extraction -> PDF generation -> show the
# extracted fields -> offer the PDF for download. Scratch files created along
# the way are always removed in the `finally` clause.
st.title("Medical Report Processor")
st.write("Upload a medical report image (JPEG/PNG) to extract details and generate a PDF.")

# Add some CSS for better styling
st.markdown("""
    <style>
    .section-box {
        background-color: #f8f9fa;
        padding: 15px;
        border-radius: 5px;
        margin-bottom: 20px;
    }
    .stButton>button {
        background-color: #4CAF50;
        color: white;
        border-radius: 5px;
    }
    </style>
""", unsafe_allow_html=True)

uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])

if uploaded_file is not None:
    # Paths of the scratch files created for this run (None until created).
    tmp_path = None
    pdf_path = None

    try:
        # Save uploaded file temporarily, keeping its own extension (default .jpg)
        upload_suffix = os.path.splitext(uploaded_file.name or "")[1].lower() or ".jpg"
        with tempfile.NamedTemporaryFile(delete=False, suffix=upload_suffix) as tmp:
            tmp_path = tmp.name
            tmp.write(uploaded_file.getvalue())

        # Reserve a unique PDF path so concurrent sessions cannot overwrite
        # each other's report.
        with tempfile.NamedTemporaryFile(delete=False, suffix=".pdf") as pdf_tmp:
            pdf_path = pdf_tmp.name

        # Display the uploaded image
        st.image(uploaded_file, caption="Uploaded Report", use_column_width=True)

        # Process the image
        with st.spinner("Processing report..."):
            json_string = process_prescription(Path(tmp_path))

        # Parse JSON
        data_dict = json.loads(json_string)

        # Generate PDF
        generate_pdf(json_string, pdf_path)

        # Display formatted output
        st.subheader("Extracted Report Data")

        # Check if prescription with medications exists. The JSON layout is
        # model-generated, so unexpected types are treated as "not present".
        prescription = data_dict.get("prescription", {})
        if not isinstance(prescription, dict):
            prescription = {}
        medications = prescription.get("medications", [])
        if not isinstance(medications, list):
            medications = []

        if medications:
            # Prescription case
            st.markdown("<div class='section-box'>", unsafe_allow_html=True)
            st.markdown("### Prescription Details")
            st.write(f"**Patient**: {prescription.get('patient', 'N/A')}")
            st.write(f"**Age**: {prescription.get('age', 'N/A')}")
            st.write(f"**Date**: {prescription.get('date', data_dict.get('date', 'N/A'))}")
            st.markdown("#### Medications")
            med_data = [["Name", "Dosage", "Frequency", "Duration"]]
            for med in medications:
                if not isinstance(med, dict):
                    # A bare entry (e.g. just a drug name) becomes the "Name" cell.
                    med = {"name": med}
                med_data.append([
                    med.get("name", "N/A"),
                    med.get("dosage", "N/A"),
                    med.get("frequency", "N/A"),
                    med.get("duration", "N/A")
                ])
            st.table(med_data)
            st.markdown("</div>", unsafe_allow_html=True)
        else:
            # Non-prescription case
            # Doctor Details
            doctor = data_dict.get("doctor", {})
            if isinstance(doctor, dict) and doctor:
                st.markdown("<div class='section-box'>", unsafe_allow_html=True)
                st.markdown("### Doctor Details")
                st.write(f"**Name**: {doctor.get('name', 'N/A')}")
                # Qualifications may arrive as a list or as a single string; only
                # a list is joined (joining a string would split it into characters).
                raw_qualifications = doctor.get('qualifications')
                if not raw_qualifications:
                    qualifications = 'N/A'
                elif isinstance(raw_qualifications, (list, tuple)):
                    qualifications = ", ".join(str(item) for item in raw_qualifications)
                else:
                    qualifications = str(raw_qualifications)
                st.write(f"**Qualifications**: {qualifications}")
                st.write(f"**Specialization**: {doctor.get('specialization', 'N/A')}")
                st.write(f"**Registration Number**: {doctor.get('registration_number', 'N/A')}")
                st.write(f"**Phone**: {doctor.get('phone', 'N/A')}")
                st.write(f"**Extension**: {doctor.get('extension', 'N/A')}")
                st.write(f"**Appointment Phone**: {doctor.get('appointment_phone', 'N/A')}")
                st.write(f"**Email**: {doctor.get('email', 'N/A')}")
                st.markdown("</div>", unsafe_allow_html=True)

            # Patient Details
            patient = data_dict.get("patient", {})
            if isinstance(patient, dict) and patient:
                st.markdown("<div class='section-box'>", unsafe_allow_html=True)
                st.markdown("### Patient Details")
                st.write(f"**Name**: {patient.get('name', 'N/A')}")
                st.write(f"**Age**: {patient.get('age', 'N/A')}")
                st.write(f"**Gender**: {patient.get('gender', 'N/A')}")
                st.write(f"**IP Number**: {patient.get('ip_number', 'N/A')}")
                st.write(f"**Condition**: {patient.get('condition', 'N/A')}")
                st.write(f"**Hospital Stay**: {patient.get('hospital_stay', 'N/A')}")
                st.markdown("</div>", unsafe_allow_html=True)

            # Contact Info
            contact = data_dict.get("contact_info", {})
            if isinstance(contact, dict) and contact:
                st.markdown("<div class='section-box'>", unsafe_allow_html=True)
                st.markdown("### Contact Information")
                st.write(f"**Address**: {contact.get('address', 'N/A')}")
                st.write(f"**Phone**: {contact.get('phone', 'N/A')}")
                st.markdown("</div>", unsafe_allow_html=True)

            # Date
            if "date" in data_dict:
                st.markdown("<div class='section-box'>", unsafe_allow_html=True)
                st.markdown("### Date")
                st.write(f"**Date**: {data_dict.get('date', 'N/A')}")
                st.markdown("</div>", unsafe_allow_html=True)

        # Provide download button for PDF. The bytes are read up front so the
        # button does not depend on the file that is deleted below.
        with open(pdf_path, "rb") as f:
            pdf_bytes = f.read()
        st.download_button(
            label="Download Report PDF",
            data=pdf_bytes,
            file_name="report.pdf",
            mime="application/pdf"
        )

    except Exception as e:
        st.error(f"Error: {str(e)}")

    finally:
        # Clean up on success and on failure alike
        for scratch_path in (tmp_path, pdf_path):
            if scratch_path and os.path.exists(scratch_path):
                os.remove(scratch_path)

Overwriting /content/app.py


In [1]:
from pyngrok import ngrok
import subprocess
import time

import socket  # used only for the readiness probe in this cell

# Start Streamlit in the background
process = subprocess.Popen(["streamlit", "run", "/content/app.py", "--server.port", "8501"])

# Wait (at most ~30 s) until something accepts connections on the Streamlit
# port, so the public URL is not announced while the app is still starting.
# If the port never opens within the deadline, the tunnel is created anyway.
startup_deadline = time.monotonic() + 30
while time.monotonic() < startup_deadline:
    try:
        with socket.create_connection(("127.0.0.1", 8501), timeout=1):
            break
    except OSError:
        # Port not open yet: give up only if the Streamlit process has died.
        if process.poll() is not None:
            raise RuntimeError(
                f"Streamlit exited during startup (exit code {process.returncode})"
            ) from None
        time.sleep(0.5)

# Start ngrok tunnel
public_url = ngrok.connect(8501).public_url
print(f"Streamlit app is live at: {public_url}")