In [1]:
# ============================
# 1. Install & Import Libraries
# ============================
# !pip install -q timm
# !pip install -q torchmetrics
# !pip install -q matplotlib seaborn opencv-python

import os
import cv2
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import torch.optim as optim
from torchvision import transforms, models

import timm
import torchmetrics

In [2]:


# ============================
# 2. Dataset Preparation
# ============================

# Example: Using Kaggle Chest X-ray Pneumonia dataset
# Dataset folder structure:
# /train/NORMAL, /train/PNEUMONIA
# /val/NORMAL, /val/PNEUMONIA

DATA_DIR = "/kaggle/input/chest-xray-pneumonia/chest_xray"
train_dir = os.path.join(DATA_DIR, 'train')
val_dir   = os.path.join(DATA_DIR, 'val')

# Custom Dataset
class XrayDataset(Dataset):
    def __init__(self, df, transform=None):
        self.df = df
        self.transform = transform
        
    def __len__(self):
        return len(self.df)
    
    def __getitem__(self, idx):
        img_path = self.df.iloc[idx, 0]
        label = self.df.iloc[idx, 1]
        
        image = cv2.imread(img_path, cv2.IMREAD_GRAYSCALE)  # X-ray is grayscale
        image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)    # Convert to 3 channels
        
        if self.transform:
            image = self.transform(image)
        
        return image, label

# Create CSV for training
def create_csv(data_dir):
    classes = os.listdir(data_dir)
    paths, labels = [], []
    for cls in classes:
        cls_path = os.path.join(data_dir, cls)
        for img in os.listdir(cls_path):
            paths.append(os.path.join(cls_path, img))
            labels.append(0 if cls=="NORMAL" else 1)  # 0: Normal, 1: Pneumonia
    df = pd.DataFrame({'path': paths, 'label': labels})
    return df

train_df = create_csv(train_dir)
val_df   = create_csv(val_dir)

In [3]:


# ============================
# 3. Data Transforms
# ============================

train_transform = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Resize((224,224)),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize([0.5]*3, [0.5]*3)
])

val_transform = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Resize((224,224)),
    transforms.ToTensor(),
    transforms.Normalize([0.5]*3, [0.5]*3)
])

train_dataset = XrayDataset(train_df, transform=train_transform)
val_dataset   = XrayDataset(val_df, transform=val_transform)

train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)
val_loader   = DataLoader(val_dataset, batch_size=16, shuffle=False)

In [4]:


# ============================
# 4. Model Definition
# ============================

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Using pre-trained DenseNet121
model = models.densenet121(pretrained=True)
model.classifier = nn.Linear(model.classifier.in_features, 2)  # 2 classes
model = model.to(device)

# Loss and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-4)

Downloading: "https://download.pytorch.org/models/densenet121-a639ec97.pth" to /root/.cache/torch/hub/checkpoints/densenet121-a639ec97.pth
100%|██████████| 30.8M/30.8M [00:00<00:00, 194MB/s]


In [5]:
# ============================
# 5. Training Loop
# ============================

num_epochs = 3

for epoch in range(num_epochs):
    model.train()
    running_loss = 0
    for imgs, labels in train_loader:
        imgs, labels = imgs.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(imgs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
    
    # Validation
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for imgs, labels in val_loader:
            imgs, labels = imgs.to(device), labels.to(device)
            outputs = model(imgs)
            _, preds = torch.max(outputs, 1)
            correct += (preds==labels).sum().item()
            total += labels.size(0)
    acc = correct/total
    print(f"Epoch {epoch+1}/{num_epochs} | Loss: {running_loss/len(train_loader):.4f} | Val Acc: {acc:.4f}")

Epoch 1/3 | Loss: 0.1110 | Val Acc: 0.7500
Epoch 2/3 | Loss: 0.0493 | Val Acc: 1.0000
Epoch 3/3 | Loss: 0.0326 | Val Acc: 0.9375


In [6]:
# ============================
# 6. Template-based Report
# ============================

disease_map = {0: "Normal", 1: "Pneumonia"}
treatment_map = {
    "Normal": "No treatment required. Maintain healthy habits.",
    "Pneumonia": "Consult physician. Possible antibiotics and rest recommended."
}

def generate_report(img_path, model):
    model.eval()
    image = cv2.imread(img_path, cv2.IMREAD_GRAYSCALE)
    image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
    image = val_transform(image).unsqueeze(0).to(device)
    
    with torch.no_grad():
        output = model(image)
        pred = torch.argmax(output, 1).item()
    
    disease = disease_map[pred]
    treatment = treatment_map[disease]
    
    report = f"Diagnosis: {disease}\nTreatment Suggestion: {treatment}"
    return report


In [7]:

# Test report
test_img = val_df.iloc[10,0]
print(generate_report(test_img, model))


Diagnosis: Normal
Treatment Suggestion: No treatment required. Maintain healthy habits.


In [8]:
# Define path to save model
MODEL_PATH = "/kaggle/working/xray_disease_model.pth"

# Save the model
torch.save({
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'disease_map': disease_map,
    'treatment_map': treatment_map
}, MODEL_PATH)

print(f"Model saved at {MODEL_PATH}")


Model saved at /kaggle/working/xray_disease_model.pth


In [9]:
# Reload model
loaded_model = models.densenet121(pretrained=False)  # pretrained=False when loading
loaded_model.classifier = nn.Linear(loaded_model.classifier.in_features, 2)
loaded_model = loaded_model.to(device)

checkpoint = torch.load(MODEL_PATH, map_location=device)
loaded_model.load_state_dict(checkpoint['model_state_dict'])

# Reuse disease and treatment maps
disease_map = checkpoint['disease_map']
treatment_map = checkpoint['treatment_map']

loaded_model.eval()
print("Model loaded successfully!")




Model loaded successfully!


In [37]:
from IPython.display import display
from ipywidgets import FileUpload

# Create uploader widget
uploader = FileUpload(accept='image/*', multiple=False)
display(uploader)

print("Upload your X-ray image using the widget above. Then run the next block to generate the report.")


FileUpload(value=(), accept='image/*', description='Upload')

Upload your X-ray image using the widget above. Then run the next block to generate the report.


In [38]:
# Check if a file has been uploaded
if len(uploader.value) > 0:
    # Kaggle FileUpload returns a tuple: (filename, content, mimetype)
    uploaded_file = uploader.value[0]  # take first file
    uploaded_filename = uploaded_file['name'] if 'name' in uploaded_file else uploaded_file[0]
    file_content = uploaded_file['content'] if 'content' in uploaded_file else uploaded_file[1]
    
    # Save uploaded file locally
    with open(uploaded_filename, 'wb') as f:
        f.write(file_content)
    
    print(f"Uploaded file saved as: {uploaded_filename}")
    
    # Generate report using the loaded model
    report = generate_report(uploaded_filename, loaded_model)
    print("\n--- Disease Report ---")
    print(report)
else:
    print("No file uploaded yet. Please upload an image in the previous block.")


Uploaded file saved as: person101_bacteria_483.jpeg

--- Disease Report ---
Diagnosis: Pneumonia
Treatment Suggestion: Consult physician. Possible antibiotics and rest recommended.


In [16]:
# Predefined disease info
disease_info = {
    "Normal": "Your X-ray appears normal. Keep a healthy lifestyle and routine checkups.",
    "Pneumonia": (
        "Pneumonia is an infection that inflames the air sacs in the lungs.\n"
        "Causes: bacterial, viral, or fungal infection.\n"
        "Symptoms: cough, fever, difficulty breathing.\n"
        "Treatment: antibiotics (if bacterial), rest, fluids, consult a physician."
    )
}

# Optional: simple Q&A for chat
chat_responses = {
    "cause": "Causes can be bacterial, viral, or fungal infections.",
    "symptom": "Common symptoms include cough, fever, and difficulty breathing.",
    "treatment": "Treatment may include antibiotics, rest, and fluids.",
    "prevention": "Prevention includes vaccination, hygiene, and avoiding sick contacts."
}


In [17]:
def disease_chat(disease, user_input):
    disease = disease_info.get(disease, "Unknown disease")
    
    # Lowercase user input for matching
    msg = user_input.lower()
    
    if "cause" in msg or "reason" in msg:
        return chat_responses.get("cause")
    elif "symptom" in msg:
        return chat_responses.get("symptom")
    elif "treatment" in msg:
        return chat_responses.get("treatment")
    elif "prevent" in msg:
        return chat_responses.get("prevention")
    else:
        return "I can provide info about causes, symptoms, treatment, and prevention. Try asking one of these."


In [18]:
# Generate diagnosis first
report = generate_report(uploaded_filename, loaded_model)
print("\n--- Disease Report ---")
print(report)

# Extract disease from report (simple)
disease_detected = report.split("\n")[0].replace("Diagnosis: ", "").strip()

# Example chat loop
print("\n--- Chat about your disease ---")
print(f"You can ask questions about {disease_detected}. Type 'exit' to stop.")

# while True:
#     user_input = input("You: ")
#     if user_input.lower() in ["exit", "quit"]:
#         print("Exiting chat.")
#         break
#     reply = disease_chat(disease_detected, user_input)
#     print("Bot:", reply)



--- Disease Report ---
Diagnosis: Pneumonia
Treatment Suggestion: Consult physician. Possible antibiotics and rest recommended.

--- Chat about your disease ---
You can ask questions about Pneumonia. Type 'exit' to stop.


In [19]:
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

# Load tokenizer and model
tokenizer = AutoTokenizer.from_pretrained("microsoft/DialoGPT-small")
chat_model = AutoModelForCausalLM.from_pretrained("microsoft/DialoGPT-small").to(device)

print("Hugging Face chat model loaded!")


tokenizer_config.json:   0%|          | 0.00/614 [00:00<?, ?B/s]

vocab.json: 0.00B [00:00, ?B/s]

merges.txt: 0.00B [00:00, ?B/s]

config.json:   0%|          | 0.00/641 [00:00<?, ?B/s]

2025-09-07 12:32:47.722483: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1757248367.972082      36 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1757248368.042349      36 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


model.safetensors:   0%|          | 0.00/351M [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/124 [00:00<?, ?B/s]

Hugging Face chat model loaded!


In [20]:
def hf_disease_chat(disease, user_input, chat_history_ids=None):
    # Prepend disease info to make it disease-aware
    input_text = f"Disease: {disease}. Question: {user_input}"
    
    # Encode the input and append to chat history if exists
    new_input_ids = tokenizer.encode(input_text + tokenizer.eos_token, return_tensors='pt').to(device)
    
    if chat_history_ids is not None:
        bot_input_ids = torch.cat([chat_history_ids, new_input_ids], dim=-1)
    else:
        bot_input_ids = new_input_ids
    
    # Generate a response
    chat_history_ids = chat_model.generate(bot_input_ids, max_length=2000, pad_token_id=tokenizer.eos_token_id)
    
    # Decode the response
    response = tokenizer.decode(chat_history_ids[:, bot_input_ids.shape[-1]:][0], skip_special_tokens=True)
    
    return response, chat_history_ids


In [22]:
# Disease detected from previous block
disease_detected = report.split("\n")[0].replace("Diagnosis: ", "").strip()

print(f"Chat about {disease_detected}. Type 'exit' to stop.")

chat_history = None

# while True:
#     user_input = input("You: ")
#     if user_input.lower() in ["exit", "quit"]:
#         print("Exiting chat.")
#         break
    
#     reply, chat_history = hf_disease_chat(disease_detected, user_input, chat_history)
#     print("Bot:", reply)


Chat about Pneumonia. Type 'exit' to stop.


In [23]:
# Predefined disease info
disease_info = {
    "Normal": "Your X-ray appears normal. Keep a healthy lifestyle and routine checkups.",
    "Pneumonia": (
        "Pneumonia is an infection that inflames the air sacs in the lungs.\n"
        "Causes: bacterial, viral, or fungal infection.\n"
        "Symptoms: cough, fever, difficulty breathing.\n"
        "Treatment: antibiotics (if bacterial), rest, fluids, consult a physician.\n"
        "Prevention: vaccination, hygiene, avoiding sick contacts."
    )
}

# Simple mapping for user questions
chat_responses = {
    "cause": "Causes can be bacterial, viral, or fungal infections.",
    "symptom": "Common symptoms include cough, fever, and difficulty breathing.",
    "treatment": "Treatment may include antibiotics, rest, and fluids.",
    "preventiwhyon": "Prevention includes vaccination, hygiene, and avoiding sick contacts."
}


In [24]:
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

# Load DialoGPT-small for general conversation
tokenizer = AutoTokenizer.from_pretrained("microsoft/DialoGPT-small")
chat_model = AutoModelForCausalLM.from_pretrained("microsoft/DialoGPT-small").to(device)

def hybrid_disease_chat(disease, user_input, chat_history_ids=None):
    msg = user_input.lower()
    
    # First, check if user question matches medical info
    if "cause" in msg or "reason" in msg:
        return chat_responses.get("cause"), chat_history_ids
    elif "symptom" in msg:
        return chat_responses.get("symptom"), chat_history_ids
    elif "treatment" in msg:
        return chat_responses.get("treatment"), chat_history_ids
    elif "prevent" in msg:
        return chat_responses.get("prevention"), chat_history_ids
    else:
        # Otherwise, fall back to general chat model
        input_text = f"Disease: {disease}. User: {user_input}"
        new_input_ids = tokenizer.encode(input_text + tokenizer.eos_token, return_tensors='pt').to(device)
        if chat_history_ids is not None:
            bot_input_ids = torch.cat([chat_history_ids, new_input_ids], dim=-1)
        else:
            bot_input_ids = new_input_ids

        chat_history_ids = chat_model.generate(bot_input_ids, max_length=2000, pad_token_id=tokenizer.eos_token_id)
        response = tokenizer.decode(chat_history_ids[:, bot_input_ids.shape[-1]:][0], skip_special_tokens=True)
        return response, chat_history_ids
def safe_disease_chat(disease, user_input):
    msg = user_input.lower()
    if "cause" in msg or "reason" in msg:
        return chat_responses.get("cause")
    elif "symptom" in msg:
        return chat_responses.get("symptom")
    elif "treatment" in msg:
        return chat_responses.get("treatment")
    elif "prevent" in msg:
        return chat_responses.get("prevention")
    else:
        return "I can only answer questions about causes, symptoms, treatment, or prevention of this disease."


In [25]:
# Disease detected from X-ray
disease_detected = report.split("\n")[0].replace("Diagnosis: ", "").strip()

print(f"Chat about {disease_detected}. Type 'exit' to stop.")

chat_history = None

# while True:
#     user_input = input("You: ")
#     if user_input.lower() in ["exit", "quit"]:
#         print("Exiting chat.")
#         break
    
#     reply, chat_history = hybrid_disease_chat(disease_detected, user_input, chat_history)
#     print("Bot:", reply)


Chat about Pneumonia. Type 'exit' to stop.


In [26]:
# Predefined template answers
chat_responses = {
    "cause": "Causes can be bacterial, viral, or fungal infections.",
    "symptom": "Common symptoms include cough, fever, and difficulty breathing.",
    "treatment": "Treatment may include antibiotics, rest, and fluids.",
    "prevention": "Prevention includes vaccination, hygiene, and avoiding sick contacts."
}

# Safe disease chat function
def safe_disease_chat(disease, user_input):
    msg = user_input.lower()
    if "cause" in msg or "reason" in msg:
        return chat_responses.get("cause")
    elif "symptom" in msg:
        return chat_responses.get("symptom")
    elif "treatment" in msg:
        return chat_responses.get("treatment")
    elif "prevent" in msg:
        return chat_responses.get("prevention")
    else:
        return "I can only answer questions about causes, symptoms, treatment, or prevention of this disease."


In [39]:
# Example: disease detected from previous report
disease_detected = report.split("\n")[0].replace("Diagnosis: ", "").strip()

print(f"Chat about the X-Ray Image, Disease detected:{disease_detected}. Type 'exit' to stop.")

while True:
    user_input = input("You: ")
    if user_input.lower() in ["exit", "quit"]:
        print("Exiting chat.")
        break
    
    reply = safe_disease_chat(disease_detected, user_input)
    print("Bot:", reply)


Chat about the X-Ray Image, Disease detected:Pneumonia. Type 'exit' to stop.


You:  hi, what is the cause of this ?


Bot: Causes can be bacterial, viral, or fungal infections.


You:  how do i understand? what symptom?


Bot: Common symptoms include cough, fever, and difficulty breathing.


You:  is there any common treatment?


Bot: Treatment may include antibiotics, rest, and fluids.


You:  any other symptom?


Bot: Common symptoms include cough, fever, and difficulty breathing.


You:  how can i prevent this? is there any prevention?


Bot: Prevention includes vaccination, hygiene, and avoiding sick contacts.


You:  I love You too!


Bot: I can only answer questions about causes, symptoms, treatment, or prevention of this disease.


You:  exit


Exiting chat.
