<a href="https://colab.research.google.com/github/shadab007-byte/Fine-tuned-DistilGPT2-82M-params-using-LoRA-on-medical-datasets/blob/main/Fine_tuned_DistilGPT2_(82M_params)_using_LoRA_on_medical_datasets.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
print("📦 Installing dependencies...")
!pip install -q transformers peft datasets gradio accelerate PyPDF2

📦 Installing dependencies...


In [2]:
import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    TrainingArguments,
    Trainer,
    DataCollatorForLanguageModeling
)
from peft import LoraConfig, get_peft_model
from datasets import Dataset
import gradio as gr
import PyPDF2
import io
import warnings
warnings.filterwarnings('ignore')

print(f"✅ Libraries imported successfully!")
print(f"🔥 GPU Available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"   GPU Name: {torch.cuda.get_device_name(0)}")

✅ Libraries imported successfully!
🔥 GPU Available: True
   GPU Name: Tesla T4


In [3]:
def extract_text_from_pdf(pdf_file):
    """Extract text from uploaded PDF file"""
    try:
        if pdf_file is None:
            return None

        # Accept a file-like object or raw bytes
        pdf_bytes = pdf_file.read() if hasattr(pdf_file, "read") else pdf_file

        # Read from bytes
        pdf_reader = PyPDF2.PdfReader(io.BytesIO(pdf_bytes))

        text = ""
        for page in pdf_reader.pages:
            # Guard against pages that yield no extractable text
            text += (page.extract_text() or "") + "\n"

        return text.strip()
    except Exception as e:
        return f"Error extracting PDF: {str(e)}"


In [4]:
TRAINING_DATA = [
    {
        "report": "Patient presents with acute bronchitis. Auscultation reveals bilateral wheezing and rhonchi. Prescribed albuterol inhaler 90mcg, 2 puffs q4-6h PRN and azithromycin 250mg daily for 5 days.",
        "summary": "You have a chest infection causing wheezing. Use the inhaler (2 puffs every 4-6 hours when needed) and take antibiotics once daily for 5 days."
    },
    {
        "report": "Laboratory results indicate elevated HbA1c at 8.2%, fasting glucose 165mg/dL. Diagnosis: Type 2 Diabetes Mellitus. Initiated metformin 500mg BID, advised carbohydrate-controlled diet and 30min aerobic exercise 5x/week.",
        "summary": "Your blood sugar levels are high, indicating diabetes. Take metformin twice daily, reduce carbs in your diet, and exercise 30 minutes five times a week."
    },
    {
        "report": "MRI findings: L4-L5 disc herniation with posterior protrusion causing mild neural foraminal stenosis. Recommend physical therapy, NSAIDs for pain management, and consideration of epidural steroid injection if symptoms persist.",
        "summary": "Your MRI shows a herniated disc in your lower back pressing on nerves. Start physical therapy and take anti-inflammatory medication. If pain continues, we may recommend an injection."
    },
    {
        "report": "Echocardiogram reveals ejection fraction of 45%, mild left ventricular hypertrophy. Diagnosis: Heart failure with reduced ejection fraction. Prescribed lisinopril 10mg daily, carvedilol 3.125mg BID, furosemide 20mg daily.",
        "summary": "Your heart isn't pumping as efficiently as it should. Take three medications daily as prescribed: one for blood pressure, one for heart rate, and one water pill."
    },
    {
        "report": "Dermatological examination shows multiple nevi with irregular borders. Excisional biopsy of 6mm lesion on left shoulder performed. Pathology pending. Follow-up in 2 weeks for results and suture removal.",
        "summary": "We removed a suspicious mole from your shoulder for testing. Come back in 2 weeks to get the stitches removed and discuss the results."
    },
    {
        "report": "Complete blood count reveals hemoglobin 9.2g/dL, MCV 72fL, ferritin 8ng/mL. Diagnosis: Iron deficiency anemia. Prescribed ferrous sulfate 325mg TID with vitamin C, avoid taking with dairy products.",
        "summary": "You have low iron causing anemia. Take iron supplements three times daily with vitamin C, but not with milk or dairy products."
    },
    {
        "report": "Sleep study demonstrates AHI of 32 events/hour, oxygen desaturation to 82%. Diagnosis: Severe obstructive sleep apnea. CPAP therapy initiated at 10cm H2O, follow-up titration study in 6 weeks.",
        "summary": "Your sleep test shows you stop breathing frequently during sleep. You'll use a CPAP machine at night to help you breathe. We'll adjust settings in 6 weeks."
    },
    {
        "report": "Thyroid function tests: TSH 12.4 mIU/L, Free T4 0.6ng/dL. Diagnosis: Hypothyroidism. Started levothyroxine 50mcg daily, take on empty stomach 30min before breakfast. Recheck labs in 8 weeks.",
        "summary": "Your thyroid is underactive. Take thyroid medication every morning on an empty stomach, 30 minutes before eating. We'll retest in 8 weeks."
    },
    {
        "report": "Chest X-ray shows bilateral infiltrates, fever 101.5°F, SpO2 89% on room air. Diagnosis: Community-acquired pneumonia. Admitted for IV ceftriaxone 1g daily and azithromycin 500mg daily, supplemental oxygen via nasal cannula.",
        "summary": "You have pneumonia in both lungs. We're admitting you to the hospital for IV antibiotics and oxygen support until you improve."
    },
    {
        "report": "Colonoscopy revealed 8mm polyp at hepatic flexure, removed via snare polypectomy. Pathology: tubular adenoma with low-grade dysplasia. Recommend repeat colonoscopy in 3 years.",
        "summary": "We found and removed a benign polyp during your colonoscopy. It wasn't cancerous, but come back for another screening in 3 years."
    },
    {
        "report": "Patient History: 55-year-old male with hypertension (diagnosed 2018), type 2 diabetes (2019), and hyperlipidemia. Previous MI in 2020, underwent PCI with stent placement. Current medications: aspirin 81mg, atorvastatin 40mg, metformin 1000mg BID, lisinopril 20mg. Recent labs show HbA1c 7.1%, LDL 95mg/dL, BP 128/82.",
        "summary": "This 55-year-old man has high blood pressure, diabetes, and high cholesterol. He had a heart attack in 2020 and got a stent. He takes 4 daily medications. His recent tests show good control of diabetes and cholesterol, with stable blood pressure."
    },
    {
        "report": "Medical History Summary: 42F with recurrent UTIs (3 episodes this year), IBS diagnosed 2015, seasonal allergies. Surgical history includes C-section 2010 and appendectomy 2008. Family history significant for breast cancer (mother, age 58). Takes daily probiotic, cetirizine 10mg PRN.",
        "summary": "This 42-year-old woman has repeated bladder infections, irritable bowel syndrome, and seasonal allergies. She's had two previous surgeries. Her mother had breast cancer, so she should discuss screening. She takes probiotics daily and allergy medicine as needed."
    }
]

print(f"✅ Loaded {len(TRAINING_DATA)} training examples (including patient histories)")


✅ Loaded 12 training examples (including patient histories)


In [5]:
def prepare_dataset():
    """Format data for training"""
    formatted_data = []
    for item in TRAINING_DATA:
        text = f"### Medical Report:\n{item['report']}\n\n### Patient Summary:\n{item['summary']}<|endoftext|>"
        formatted_data.append({"text": text})
    return Dataset.from_list(formatted_data)

dataset = prepare_dataset()
print(f"✅ Dataset prepared: {len(dataset)} examples")

✅ Dataset prepared: 12 examples
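The training text and the inference prompt must share exactly the same layout. `prepare_dataset()` writes the `### Medical Report:` and `### Patient Summary:` headers followed by the summary and `<|endoftext|>`, while `generate_summary()` ends its prompt right after the summary header. The pure-Python check below confirms that invariant; the helper names `build_training_text` and `build_inference_prompt` are hypothetical and only mirror the f-strings used in this notebook.

```python
REPORT_HEADER = "### Medical Report:\n"
SUMMARY_HEADER = "\n\n### Patient Summary:\n"
EOS = "<|endoftext|>"  # GPT-2's end-of-text token, written literally into the training text


def build_training_text(report: str, summary: str) -> str:
    """Full training example: prompt + target summary + end-of-text marker."""
    return f"{REPORT_HEADER}{report}{SUMMARY_HEADER}{summary}{EOS}"


def build_inference_prompt(report: str) -> str:
    """Generation-time prompt: same template, stopping right after the summary header."""
    return f"{REPORT_HEADER}{report}{SUMMARY_HEADER}"


report = "Blood pressure 165/95. Diagnosis: Stage 2 Hypertension."
summary = "Your blood pressure is high."
train_text = build_training_text(report, summary)
prompt = build_inference_prompt(report)

# The inference prompt must be an exact prefix of a training example;
# otherwise the model is asked to continue a layout it never saw in training.
assert train_text.startswith(prompt)
print(repr(train_text[len(prompt):]))  # 'Your blood pressure is high.<|endoftext|>'
```

Everything the model is trained to produce is the suffix after the prompt, which is why any change to one template must be mirrored in the other.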


In [6]:
MODEL_NAME = "distilgpt2"

print(f"🔄 Loading {MODEL_NAME}...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
tokenizer.pad_token = tokenizer.eos_token

model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
    device_map="auto"
)

print(f"✅ Model loaded on: {model.device}")

🔄 Loading distilgpt2...


`torch_dtype` is deprecated! Use `dtype` instead!


✅ Model loaded on: cuda:0


In [7]:
lora_config = LoraConfig(
    r=16,
    lora_alpha=32,
    target_modules=["c_attn"],
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM"
)

model = get_peft_model(model, lora_config)
print("\n📊 Trainable Parameters:")
model.print_trainable_parameters()


📊 Trainable Parameters:
trainable params: 294,912 || all params: 82,207,488 || trainable%: 0.3587
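The trainable-parameter count above can be reproduced by hand. This sketch assumes DistilGPT2's configuration (6 transformer blocks, hidden size 768) and that `c_attn` is the fused query/key/value projection from 768 to 3 × 768 = 2304 features. For each targeted weight, LoRA adds a low-rank pair A (r × 768) and B (2304 × r), and scales their product by `lora_alpha / r`.

```python
# Assumed DistilGPT2 config values; c_attn maps n_embd -> 3 * n_embd (Q, K, V together)
n_layer = 6
n_embd = 768
r = 16           # lora_config.r
lora_alpha = 32  # lora_config.lora_alpha

d_in, d_out = n_embd, 3 * n_embd  # 768 -> 2304

# LoRA adds A (r x d_in) and B (d_out x r) beside each frozen c_attn weight; bias="none"
per_layer = r * d_in + d_out * r
trainable = per_layer * n_layer
print(per_layer, trainable)  # 49152 294912

# The low-rank update B @ A is multiplied by lora_alpha / r before being added
scaling = lora_alpha / r
print(scaling)  # 2.0

total = 82_207_488  # "all params" reported by print_trainable_parameters()
print(f"{100 * trainable / total:.4f}%")  # 0.3587%
```

Doubling `r` would double the trainable count, while `lora_alpha` only changes the scaling factor, not the parameter count.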


In [8]:
def tokenize_function(examples):
    return tokenizer(
        examples["text"],
        truncation=True,
        max_length=512,  # Increased for longer patient histories
        padding=False    # Let the data collator pad each batch to its longest example
    )

tokenized_dataset = dataset.map(
    tokenize_function,
    batched=True,
    remove_columns=dataset.column_names
)
print(f"✅ Dataset tokenized")


✅ Dataset tokenized
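One subtlety affects how the model learns to stop. This notebook sets `tokenizer.pad_token = tokenizer.eos_token`, and the training cell uses `DataCollatorForLanguageModeling(mlm=False)`, which copies `input_ids` into `labels` and sets every position equal to `pad_token_id` to -100 (ignored by the loss). Because padding and `<|endoftext|>` share one id, the genuine end-of-text token at the end of each example is masked too, so the model never receives a signal to emit it; this may contribute to run-on summaries. The pure-Python sketch below contrasts that rule with masking by the attention mask instead; the function names and toy token ids (other than GPT-2's end-of-text id 50256) are made up for illustration.

```python
IGNORE_INDEX = -100  # label value skipped by PyTorch's cross-entropy loss by default
EOS_ID = 50256       # GPT-2's <|endoftext|> id, reused as pad_token in this notebook


def labels_masking_pad_id(input_ids, pad_id):
    """Pad-id rule: every token equal to pad_id is ignored, including a real EOS."""
    return [IGNORE_INDEX if tok == pad_id else tok for tok in input_ids]


def labels_masking_attention(input_ids, attention_mask):
    """Alternative rule: ignore only positions the attention mask marks as padding."""
    return [tok if m == 1 else IGNORE_INDEX for tok, m in zip(input_ids, attention_mask)]


# Toy sequence: three content tokens, the real end-of-text token, then two pad slots
input_ids = [101, 102, 103, EOS_ID, EOS_ID, EOS_ID]
attention_mask = [1, 1, 1, 1, 0, 0]

print(labels_masking_pad_id(input_ids, EOS_ID))
# [101, 102, 103, -100, -100, -100]   <- the real EOS is lost as a target
print(labels_masking_attention(input_ids, attention_mask))
# [101, 102, 103, 50256, -100, -100]  <- the model is still taught to stop
```

In the notebook, the second rule could be applied by a small custom collator that pads with `tokenizer.pad(...)` and then builds labels from the attention mask; that wiring is an assumption and is not shown here.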


In [9]:
training_args = TrainingArguments(
    output_dir="./medisummarize",
    num_train_epochs=50,
    per_device_train_batch_size=2,
    learning_rate=3e-4,
    logging_steps=10,
    save_strategy="no",
    report_to="none",
    fp16=torch.cuda.is_available(),
)

data_collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer,
    mlm=False
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_dataset,
    data_collator=data_collator
)

print("🚀 Starting training...")
print("="*60)
trainer.train()
print("="*60)
print("✅ Training complete!")


The model is already on multiple devices. Skipping the move to device specified in `args`.


🚀 Starting training...


`loss_type=None` was set in the config but it is unrecognized. Using the default loss: `ForCausalLMLoss`.


Step  Training Loss
10    4.5425
20    4.2651
30    4.13
40    3.9586
50    3.7298
60    3.5597
70    3.4744
80    3.3276
90    3.2858
100   3.1757


✅ Training complete!
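The logged training loss is the mean per-token cross-entropy in nats, averaged over each 10-step logging window, so `exp(loss)` gives a training-set perplexity that is easier to read. The conversion below uses the first and last rows of the log. With only 12 examples, this measures how well the model fits its own training data, not how it performs on unseen reports.

```python
import math

# (step, training loss) taken from the first and last rows of the log above
logged = [(10, 4.5425), (100, 3.1757)]

for step, loss in logged:
    # Mean token-level cross-entropy in nats -> perplexity = e ** loss
    print(f"step {step:>3}: loss {loss:.4f} -> training perplexity {math.exp(loss):.1f}")
# step  10: loss 4.5425 -> training perplexity 93.9
# step 100: loss 3.1757 -> training perplexity 23.9
```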


In [10]:
def extract_text_from_pdf(pdf_file):
    """Extract text from uploaded PDF file"""
    try:
        if pdf_file is None:
            return None

        # ✅ Handle various input types
        if hasattr(pdf_file, "read"):  # e.g., file-like object from Gradio/Streamlit
            pdf_bytes = pdf_file.read()
        elif isinstance(pdf_file, (bytes, bytearray)):
            pdf_bytes = pdf_file
        elif isinstance(pdf_file, str):  # path string
            with open(pdf_file, "rb") as f:
                pdf_bytes = f.read()
        else:
            return "Error extracting PDF: Unsupported file type."

        # ✅ Read using PyPDF2
        pdf_reader = PyPDF2.PdfReader(io.BytesIO(pdf_bytes))
        text = ""
        for page in pdf_reader.pages:
            # Guard against pages that yield no extractable text
            text += (page.extract_text() or "") + "\n"

        return text.strip()
    except Exception as e:
        return f"Error extracting PDF: {str(e)}"


def generate_summary(medical_report):
    """Generate patient-friendly summary"""
    if not medical_report or not medical_report.strip():
        return "⚠️ Please enter a medical report or upload a PDF."

    # Truncate if too long
    if len(medical_report) > 1500:
        medical_report = medical_report[:1500] + "..."

    prompt = f"### Medical Report:\n{medical_report}\n\n### Patient Summary:\n"

    inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512)
    inputs = {k: v.to(model.device) for k, v in inputs.items()}

    model.eval()  # trainer.train() leaves the model in train mode; switch off dropout for generation
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=200,
            temperature=0.7,
            top_p=0.9,
            do_sample=True,
            pad_token_id=tokenizer.eos_token_id,
            repetition_penalty=1.2
        )

    generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)

    if "### Patient Summary:" in generated_text:
        summary = generated_text.split("### Patient Summary:")[1].strip()
        return summary

    return generated_text


def process_pdf_and_summarize(pdf_file):
    """Extract text from PDF and generate summary"""
    if pdf_file is None:
        return "", "⚠️ Please upload a PDF file."

    extracted_text = extract_text_from_pdf(pdf_file)

    if isinstance(extracted_text, str) and extracted_text.startswith("Error"):
        return extracted_text, ""

    summary = generate_summary(extracted_text)

    return extracted_text, summary


print("✅ Inference functions ready")


✅ Inference functions ready
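A small model often keeps writing past the summary, sometimes opening a new `### Medical Report:` block of its own. Post-processing can trim that. The sketch below is a pure-Python helper (hypothetical name `extract_summary`) that keeps only the text between the summary header and the next `###` header or end-of-text marker, so it can be tested without loading the model.

```python
MARKER = "### Patient Summary:"


def extract_summary(generated_text: str) -> str:
    """Return the text after the summary header, cut at the next '###' section
    or '<|endoftext|>' marker the model may have produced after it."""
    if MARKER not in generated_text:
        return generated_text.strip()
    summary = generated_text.split(MARKER, 1)[1]
    for stop in ("###", "<|endoftext|>"):
        summary = summary.split(stop, 1)[0]
    return summary.strip()


raw = ("### Medical Report:\nBP 165/95.\n\n### Patient Summary:\n"
       "Your blood pressure is high.\n\n### Medical Report:\nUnrelated text")
print(extract_summary(raw))  # Your blood pressure is high.
```

Since `generate_summary()` decodes with `skip_special_tokens=True`, the `<|endoftext|>` stop is mostly a safeguard; the `###` cut does the real work.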


In [11]:
print("\n🧪 Testing model...")
test_report = "Blood pressure 165/95, BMI 31. Diagnosis: Stage 2 Hypertension. Started lisinopril 10mg daily, DASH diet advised."
test_summary = generate_summary(test_report)
print(f"✨ Test Summary: {test_summary[:100]}...")


🧪 Testing model...
✨ Test Summary: Your blood pressure 165 pounds and recommend a weekly vitamin B12-6 supplement to your doctor every ...
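Because `generate_summary()` uses `do_sample=True`, each call draws a different continuation, so the test output above is one sample among many. `temperature=0.7` sharpens the next-token distribution, and `top_p=0.9` restricts sampling to the smallest set of most likely tokens covering 90% of the probability mass. The sketch below is a simplified pure-Python illustration of those two steps on a made-up four-token distribution. It is not the `transformers` implementation and ignores `repetition_penalty`.

```python
import math


def temperature_top_p(logits, temperature=0.7, top_p=0.9):
    """Return the renormalised distribution that nucleus sampling would draw from.

    logits: dict mapping a token string to its raw (pre-softmax) score.
    """
    # 1. Temperature: divide logits before the softmax; values < 1 sharpen the distribution.
    scaled = {tok: z / temperature for tok, z in logits.items()}
    peak = max(scaled.values())
    weights = {tok: math.exp(z - peak) for tok, z in scaled.items()}  # subtract max for stability
    total = sum(weights.values())
    probs = {tok: w / total for tok, w in weights.items()}

    # 2. Top-p: keep the most likely tokens until their cumulative mass reaches top_p.
    kept, mass = {}, 0.0
    for tok, p in sorted(probs.items(), key=lambda kv: kv[1], reverse=True):
        kept[tok] = p
        mass += p
        if mass >= top_p:
            break

    # 3. Renormalise the survivors; sampling then draws only from this reduced set.
    norm = sum(kept.values())
    return {tok: p / norm for tok, p in kept.items()}


# Made-up next-token scores after "Your blood pressure is"
logits = {"high": 3.0, "elevated": 2.5, "normal": 0.5, "pounds": 0.0}
dist = temperature_top_p(logits)
print({tok: round(p, 3) for tok, p in dist.items()})  # {'high': 0.671, 'elevated': 0.329}
```

Here the two unlikely tokens fall outside the 90% nucleus and can never be sampled. Lowering `top_p` or `temperature` further trades variety for more conservative wording.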


In [12]:
SAMPLE_REPORTS = {
    "Fracture (X-ray)": "X-ray imaging shows fracture of distal radius with dorsal angulation. Closed reduction performed, short arm cast applied. Prescribed ibuprofen 400mg TID for pain, follow-up in 1 week for repeat imaging.",

    "UTI (Lab Test)": "Urinalysis positive for nitrites and leukocyte esterase. Diagnosis: Urinary tract infection. Prescribed nitrofurantoin 100mg BID for 7 days. Increase fluid intake to 2-3L daily.",

    "COPD (Lung Function)": "Spirometry reveals FEV1/FVC ratio of 0.65, FEV1 55% predicted. Diagnosis: Moderate COPD. Initiated tiotropium inhaler 18mcg daily, albuterol rescue inhaler as needed.",

    "Appendicitis (Emergency)": "CT scan shows acute appendicitis with periappendiceal stranding. WBC 15,000. Emergency appendectomy scheduled. NPO status, IV antibiotics initiated.",

    "Patient History": "Patient History: 62-year-old female with long-standing rheumatoid arthritis (2005), osteoporosis, GERD. Takes methotrexate 15mg weekly, folic acid, omeprazole 20mg daily, calcium+vitamin D. Recent DEXA scan shows T-score -2.8 at lumbar spine. No fractures. Referred to rheumatology for biologic therapy consideration."
}

In [13]:
with gr.Blocks(title="MediSummarize", theme=gr.themes.Soft()) as demo:
    gr.Markdown("""
    # 🏥 MediSummarize: AI Medical Report Simplifier
    ### LoRA Fine-tuned LLM for Patient-Friendly Medical Summaries

    Upload PDF medical reports or paste text to get easy-to-understand summaries.
    """)

    with gr.Tab("📄 PDF Upload"):
        gr.Markdown("### Upload Medical Report PDF")

        with gr.Row():
            with gr.Column():
                pdf_input = gr.File(
                    label="Upload PDF Medical Report",
                    file_types=[".pdf"]
                )
                pdf_btn = gr.Button("🔍 Extract & Summarize", variant="primary", size="lg")

            with gr.Column():
                extracted_text = gr.Textbox(
                    label="📋 Extracted Text",
                    lines=8,
                    interactive=False
                )

        summary_from_pdf = gr.Textbox(
            label="✨ Patient-Friendly Summary",
            lines=6,
            interactive=False
        )

        pdf_btn.click(
            process_pdf_and_summarize,
            inputs=pdf_input,
            outputs=[extracted_text, summary_from_pdf]
        )

    with gr.Tab("✍️ Text Input"):
        gr.Markdown("### Enter or Paste Medical Report")

        with gr.Row():
            with gr.Column():
                text_input = gr.Textbox(
                    label="📄 Medical Report / Patient History",
                    placeholder="Paste medical report here...",
                    lines=10
                )

                with gr.Row():
                    summarize_btn = gr.Button("✨ Generate Summary", variant="primary", size="lg")
                    clear_btn = gr.Button("🗑️ Clear", size="lg")

                gr.Markdown("**📝 Try Sample Reports:**")
                sample_dropdown = gr.Dropdown(
                    choices=list(SAMPLE_REPORTS.keys()),
                    label="Select Sample",
                    value=None
                )
                load_sample_btn = gr.Button("Load Sample", size="sm")

            with gr.Column():
                summary_output = gr.Textbox(
                    label="💬 Patient-Friendly Summary",
                    lines=14,
                    interactive=False
                )

        # Button actions
        summarize_btn.click(generate_summary, inputs=text_input, outputs=summary_output)
        clear_btn.click(lambda: ("", ""), outputs=[text_input, summary_output])

        def load_sample(choice):
            return SAMPLE_REPORTS.get(choice, "")

        load_sample_btn.click(load_sample, inputs=sample_dropdown, outputs=text_input)

    with gr.Tab("ℹ️ Project Info"):
        gr.Markdown("""
        ## 🎯 About MediSummarize

        An AI-powered tool that converts complex medical reports into patient-friendly language using fine-tuned Large Language Models.

        ### 🔬 Technical Stack
        - **Base Model:** DistilGPT2 (82M parameters)
        - **Fine-tuning Method:** LoRA (Low-Rank Adaptation)
        - **Framework:** HuggingFace Transformers + PEFT
        - **PDF Processing:** PyPDF2
        - **Interface:** Gradio
        - **Training:** 50 epochs (`num_train_epochs=50`), 3e-4 learning rate

        ### 📊 Dataset
        - **Size:** 12 medical report-summary pairs
        - **Types:** Lab results, imaging reports, prescriptions, patient histories
        - **Domains:** Cardiology, pulmonology, endocrinology, orthopedics, gastroenterology

        ### ✨ Key Features
        - ✅ PDF medical report upload and text extraction
        - ✅ Patient history summarization
        - ✅ Multi-page document support
        - ✅ Real-time text input processing
        - ✅ Pre-loaded sample reports
        - ✅ LoRA efficient fine-tuning (trainable params: ~0.3M vs 82M)

        ### 🎓 Skills Demonstrated
        - Large Language Model fine-tuning
        - Parameter-Efficient Fine-Tuning (PEFT/LoRA)
        - Natural Language Processing
        - PDF document processing
        - Healthcare AI applications
        - Model deployment with Gradio
        - GPU-accelerated training

        ### 📈 Model Performance
        - Training time: ~2-3 minutes on T4 GPU
        - Inference time: <2 seconds per report
        - Memory efficient: LoRA reduces trainable parameters by 99.6%

        ### 🚀 Use Cases
        - Hospital patient education
        - Telemedicine platforms
        - Personal health record systems
        - Medical literacy tools
        - Healthcare accessibility

        ---

        **Created as a demonstration of LLM fine-tuning capabilities**
        """)

In [14]:
print("\n" + "="*60)
print("🎉 MediSummarize is ready!")
print("="*60)

demo.launch(share=True, debug=True)