In [1]:
!pip uninstall torch torchvision torchaudio -y
!pip cache purge
!pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121



Files removed: 8
Looking in indexes: https://download.pytorch.org/whl/cu121
Collecting torch
  Downloading https://download.pytorch.org/whl/cu121/torch-2.5.1%2Bcu121-cp312-cp312-win_amd64.whl (2449.3 MB)
     ---------------------------------------- 0.0/2.4 GB ? eta -:--:--
     ---------------------------------------- 0.0/2.4 GB ? eta -:--:--
     ---------------------------------------- 0.0/2.4 GB ? eta -:--:--
     ---------------------------------------- 0.0/2.4 GB 1.7 MB/s eta 0:24:18
     ---------------------------------------- 0.0/2.4 GB 3.1 MB/s eta 0:13:09
     ---------------------------------------- 0.0/2.4 GB 5.0 MB/s eta 0:08:06
     ---------------------------------------- 0.0/2.4 GB 5.8 MB/s eta 0:07:01
     ---------------------------------------- 0.0/2.4 GB 5.0 MB/s eta 0:08:11
     ---------------------------------------- 0.0/2.4 GB 5.0 MB/s eta 0:08:11
     ---------------------------------------- 0.0/2.4 GB 3.7 MB/s eta 0:11:03
     --------------------------------




     ---------------------------------------  2.4/2.4 GB 2.2 MB/s eta 0:00:22
     ---------------------------------------  2.4/2.4 GB 2.1 MB/s eta 0:00:23
     ---------------------------------------  2.4/2.4 GB 2.1 MB/s eta 0:00:23
     ---------------------------------------  2.4/2.4 GB 2.0 MB/s eta 0:00:23
     ---------------------------------------  2.4/2.4 GB 2.0 MB/s eta 0:00:23
     ---------------------------------------  2.4/2.4 GB 2.1 MB/s eta 0:00:22
     ---------------------------------------  2.4/2.4 GB 2.1 MB/s eta 0:00:22
     ---------------------------------------  2.4/2.4 GB 2.1 MB/s eta 0:00:22
     ---------------------------------------  2.4/2.4 GB 2.1 MB/s eta 0:00:21
     ---------------------------------------  2.4/2.4 GB 2.1 MB/s eta 0:00:21
     ---------------------------------------  2.4/2.4 GB 2.1 MB/s eta 0:00:20
     ---------------------------------------  2.4/2.4 GB 2.1 MB/s eta 0:00:20
     ---------------------------------------  2.4/2.4 GB 2.1 MB

In [2]:
pip install torch

Note: you may need to restart the kernel to use updated packages.




In [3]:
pip install torchvision

Note: you may need to restart the kernel to use updated packages.




In [4]:
pip install transformers

Note: you may need to restart the kernel to use updated packages.




In [5]:
import torch
import torch.nn as nn
from torchvision import transforms, models
from PIL import Image

# Device setup
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load X-ray vs Non-X-ray Classifier
xray_model = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
num_features = xray_model.fc.in_features
xray_model.fc = nn.Linear(num_features, 2)  # binary classifier
xray_model.load_state_dict(torch.load("xray_vs_nonxray_resnet18.pth", map_location=DEVICE))
xray_model.to(DEVICE)
xray_model.eval()

# Transform (same for both models)
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])

# X-ray Check Function
def is_xray(img_path, threshold=0.8):
    img = Image.open(img_path).convert("RGB")
    img = transform(img).unsqueeze(0).to(DEVICE)

    with torch.no_grad():
        probs = torch.softmax(xray_model(img), dim=1)
        xray_prob = probs[0][0].item()  # assuming class "0" = xray, "1" = non_xray

    return xray_prob >= threshold, xray_prob


# Disease Ensemble (DenseNet + ResNet)
label_cols = ["Lung Opacity","Pleural Effusion","Edema","Atelectasis","Cardiomegaly"]

densenet = models.densenet121(weights=models.DenseNet121_Weights.DEFAULT)
densenet.classifier = nn.Linear(densenet.classifier.in_features, len(label_cols))
densenet.load_state_dict(torch.load("best_model_dense.pth", map_location=DEVICE))
densenet.to(DEVICE).eval()

resnet = models.resnet50(weights=models.ResNet50_Weights.DEFAULT)
resnet.fc = nn.Linear(resnet.fc.in_features, len(label_cols))
resnet.load_state_dict(torch.load("best_model_resnet.pth", map_location=DEVICE))
resnet.to(DEVICE).eval()

def ensemble_predict_image(img_path, thresholds):
    img = Image.open(img_path).convert("RGB")
    img_tensor = transform(img).unsqueeze(0).to(DEVICE)

    with torch.no_grad():
        out1 = torch.sigmoid(densenet(img_tensor))
        out2 = torch.sigmoid(resnet(img_tensor))
        probs = probs = torch.zeros_like(out1)
        for i, c in enumerate(label_cols):
          alpha = best_config[c]["alpha"]
          probs[:, i] = alpha * out1[:, i] + (1 - alpha) * out2[:, i]


        preds = torch.zeros_like(probs, dtype=torch.int)
        for i, class_name in enumerate(label_cols):
            preds[:, i] = (probs[:, i] > thresholds[class_name]).int()

    return probs.cpu().numpy(), preds.cpu().numpy()


# Full Pipeline
def predict_pipeline(img_path, thresholds):
    # Step 1: X-ray check
    valid_xray, prob = is_xray(img_path, threshold=0.8)

    if not valid_xray:
        return{
            "valid_xray": False,
            "message": f"Not a chest X-ray. Skipping disease prediction."
        }
        #return f"Not a chest X-ray (xray_prob={prob*100:.2f}%). Skipping disease prediction."

    # Step 2: Disease prediction
    probs, preds = ensemble_predict_image(img_path, thresholds)
    results = {}
    for i, name in enumerate(label_cols):
        prediction = "Positive" if preds[0][i] == 1 else "Negative"
        probability = probs[0][i] * 100
        results[name] = {
            "Prediction": prediction,
            "Probability": f"{probability:.2f}%"
        }
    return results


# Example usage
best_config = {
    'Lung Opacity': {'alpha': 0.05, 'threshold': 0.32},
    'Pleural Effusion': {'alpha': 0.90, 'threshold': 0.38},
    'Edema': {'alpha': 0.65, 'threshold': 0.30},
    'Atelectasis': {'alpha': 1.0, 'threshold': 0.20},
    'Cardiomegaly': {'alpha': 1.0, 'threshold': 0.14}
}

thresholds = {c: best_config[c]["threshold"] for c in label_cols}

img_path = "E://photo//wallpapers//1398943.jpg" # random image.non xray image is more fast
result = predict_pipeline(img_path, thresholds)

# Print results in the specified format
if isinstance(result, dict):
    for name, res in result.items():
        if isinstance(res, dict):  # only disease predictions
            print(f"{name}: {res['Prediction']} ({res['Probability']})")
        else:  # handles booleans or strings (non-xray case)
            print(f"{name}: {res}")
else:
    print(result)


#LLM part
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

# Knowledge base 
knowledge_base = {
    "Lung Opacity": {
        "description": "Abnormal areas visible on a chest X-ray, which may indicate fluid, infection, or other abnormal tissue in the lungs.",
        "causes": "pneumonia, tuberculosis, lung cancer, or pulmonary fibrosis",
        "treatment": "antibiotics for infections, antiviral therapy for viral causes, steroids for inflammation, or oxygen therapy as needed"
    },
    "Pleural Effusion": {
        "description": "Excess fluid around the lungs visible on X-ray.",
        "causes": "heart failure, lung infection, cancer, or kidney disease",
        "treatment": "fluid drainage, diuretics, or treating the underlying cause"
    },
    "Edema": {
        "description": "Fluid accumulation in tissues or organs, commonly seen in the lungs or extremities.",
        "causes": "heart failure, kidney disease, liver disease, infections, or medication side effects",
        "treatment": "diuretics, treating the underlying cause, fluid restriction, or oxygen therapy for pulmonary edema"
    },
    "Atelectasis": {
        "description": "Partial or complete collapse of a part of the lung, reducing gas exchange.",
        "causes": "airway obstruction from mucus or tumors, post-surgical collapse, or compression from pleural effusion",
        "treatment": "deep breathing exercises, chest physiotherapy, treating the underlying obstruction, or bronchoscopy if needed"
    },
    "Cardiomegaly": {
        "description": "Enlargement of the heart seen on X-ray, indicating possible heart dysfunction.",
        "causes": "hypertension, heart valve disease, cardiomyopathy, or heart failure",
        "treatment": "managing blood pressure, medications for heart failure, surgical repair of valves if necessary, or lifestyle modifications"
    }
}

def generate_medical_report(positive_findings, model_name="google/flan-t5-base"):
    """
    Generate medical report using Flan-T5 with optimized approach
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForSeq2SeqLM.from_pretrained(model_name)

    report = ""

    for disease in positive_findings:
        if disease not in knowledge_base:
            print(f"Warning: {disease} not found in knowledge base. Skipping.")
            continue

        info = knowledge_base[disease]

        # Create structured input text for the LLM
        input_text = (
            f"Medical condition: {disease}. "
            f"Description: {info['description']} "
            f"Common causes: {info['causes']}. "
            f"Treatment options: {info['treatment']}."
        )

        # Optimized prompt for medical professionals
        prompt = (
            "Rewrite this medical information as a single clear paragraph for healthcare professionals. "
            "Include the description, causes, and treatments exactly once. "
            "Use professional medical terminology. "
            "Do not repeat information. "
            f"Start with the condition name '{disease}:' \n\n{input_text}"
        )

        try:
            inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=400)
            outputs = model.generate(
                **inputs,
                max_length=200,
                min_length=60,
                num_beams=3,
                early_stopping=True,
                do_sample=False,
                repetition_penalty=1.3,
                no_repeat_ngram_size=3
            )

            generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)

            # Format with proper heading
            if not generated_text.startswith(f"=== {disease} ==="):
                formatted_text = f"=== {disease} ===\n{generated_text}"
            else:
                formatted_text = generated_text

            report += formatted_text + "\n\n"

        except Exception as e:
            print(f"Error generating text for {disease}: {e}")
            # Fallback to structured format
            fallback_text = (
                f"=== {disease} ===\n"
                f"{info['description']} This condition is commonly caused by {info['causes']}. "
                f"Treatment typically includes {info['treatment']}."
            )
            report += fallback_text + "\n\n"

    return report.strip()


  xray_model.load_state_dict(torch.load("xray_vs_nonxray_resnet18.pth", map_location=DEVICE))
  densenet.load_state_dict(torch.load("best_model_dense.pth", map_location=DEVICE))
  resnet.load_state_dict(torch.load("best_model_resnet.pth", map_location=DEVICE))


valid_xray: False
message: Not a chest X-ray. Skipping disease prediction.


'# Example usage\nif __name__ == "__main__":\n    positive_findings = positive_findings\n\n    print("=== MEDICAL FINDINGS REPORT ===")\n    if positive_findings:\n      report = generate_medical_report(positive_findings)\n      print(report)\n    else:\n      print("No positive findings to report.")\n'

In [6]:
from flask import Flask, render_template, request
import os
from werkzeug.utils import secure_filename

# Flask setup
app = Flask(__name__)
UPLOAD_FOLDER = "static/uploads"
os.makedirs(UPLOAD_FOLDER, exist_ok=True)
app.config["UPLOAD_FOLDER"] = UPLOAD_FOLDER

@app.route("/", methods=["GET", "POST"])
def index():
    if request.method == "POST":
        # Save uploaded image
        file = request.files["file"]
        filename = secure_filename(file.filename)
        filepath = os.path.join(app.config["UPLOAD_FOLDER"], filename)
        file.save(filepath)

        # Run predictions
        results = predict_pipeline(filepath, thresholds)

        # Handle non-X-ray case
        if results.get("valid_xray") == False:
            report = results["message"]
            positive_findings = []
        else:
            # Collect positive findings
            positive_findings = [
                disease for disease, info in results.items()
                if isinstance(info, dict) and info.get("Prediction") == "Positive"
            ]
            report = generate_medical_report(positive_findings) if positive_findings else "No positive findings."

        return render_template("result.html", results=results, report=report, filename=filename)

    return render_template("index.html")


In [None]:
app.run(debug=True, use_reloader=False)

 * Serving Flask app '__main__'
 * Debug mode: on


 * Running on http://127.0.0.1:5000
Press CTRL+C to quit
127.0.0.1 - - [26/Nov/2025 11:06:14] "GET / HTTP/1.1" 200 -
127.0.0.1 - - [26/Nov/2025 11:06:30] "GET /favicon.ico HTTP/1.1" 404 -
127.0.0.1 - - [26/Nov/2025 11:18:50] "POST / HTTP/1.1" 200 -
127.0.0.1 - - [26/Nov/2025 11:18:50] "GET /static/uploads/view1_frontal.jpg HTTP/1.1" 200 -
127.0.0.1 - - [26/Nov/2025 11:19:07] "GET / HTTP/1.1" 200 -
127.0.0.1 - - [26/Nov/2025 11:20:59] "POST / HTTP/1.1" 200 -
127.0.0.1 - - [26/Nov/2025 11:20:59] "GET /static/uploads/view1_frontal.jpg HTTP/1.1" 200 -
127.0.0.1 - - [26/Nov/2025 11:21:15] "GET / HTTP/1.1" 200 -
127.0.0.1 - - [26/Nov/2025 11:21:34] "POST / HTTP/1.1" 200 -
127.0.0.1 - - [26/Nov/2025 11:21:34] "GET /static/uploads/view1_frontal.jpg HTTP/1.1" 200 -
127.0.0.1 - - [26/Nov/2025 11:21:44] "GET / HTTP/1.1" 200 -
127.0.0.1 - - [26/Nov/2025 11:22:01] "POST / HTTP/1.1" 200 -
127.0.0.1 - - [26/Nov/2025 11:22:01] "GET /static/uploads/6f1e17a68b9e5f86d78d3bb4437b7ec2.jpg HTTP/1.1" 200 -
1