# Chest X-ray Prediction Notebook

This notebook lets you load a trained CNN model and predict (Normal/Abnormal) on chest X-ray images using a file/folder picker. Predictions include confidence scores and image display.


In [9]:
import torch
import torch.nn as nn
from torchvision import transforms
from PIL import Image
import ipywidgets as widgets
from IPython.display import display, clear_output
import os
import numpy as np


## Model Definition (copy from training notebook)


In [10]:
class SimpleCNN(nn.Module):
    def __init__(self):
        super(SimpleCNN, self).__init__()
        self.conv1 = nn.Conv2d(1, 16, 3, padding=1)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(16, 32, 3, padding=1)
        self.conv3 = nn.Conv2d(32, 64, 3, padding=1)
        self.fc1 = nn.Linear(64 * 32 * 32, 128)
        self.fc2 = nn.Linear(128, 1)
        self.sigmoid = nn.Sigmoid()
    def forward(self, x):
        x = self.pool(torch.relu(self.conv1(x)))
        x = self.pool(torch.relu(self.conv2(x)))
        x = self.pool(torch.relu(self.conv3(x)))
        x = x.view(x.size(0), -1)
        x = torch.relu(self.fc1(x))
        x = self.sigmoid(self.fc2(x))
        return x


## Load Model Weights


In [11]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = SimpleCNN().to(device)
model.load_state_dict(torch.load('best_model.keras', map_location=device))
model.eval()


RuntimeError: [enforce fail at inline_container.cc:166] . file in archive is not in a subdirectory: metadata.json

## Preprocessing Function (same as training)


In [None]:
def preprocess_image(img_path):
    transform = transforms.Compose([
        transforms.Grayscale(num_output_channels=1),
        transforms.Resize((256, 256)),
        transforms.ToTensor(),
        transforms.Normalize([0.485], [0.229])
    ])
    img = Image.open(img_path).convert('RGB')
    return transform(img).unsqueeze(0)


## Prediction Function


In [None]:
def predict_image(img_path):
    input_tensor = preprocess_image(img_path).to(device)
    with torch.no_grad():
        output = model(input_tensor)
        prob = output.item()
        pred = 'Abnormal' if prob >= 0.5 else 'Normal'
    return pred, prob


## File Picker Widgets


In [None]:
file_picker = widgets.FileUpload(accept='.png,.jpg,.jpeg', multiple=True)
folder_picker = widgets.Text(
    value='',
    placeholder='Enter folder path with images',
    description='Folder:',
    disabled=False
)
mode_selector = widgets.ToggleButtons(
    options=['Single/Bulk Images', 'Folder'],
    description='Mode:',
    disabled=False
)
display(mode_selector)
output = widgets.Output()
display(output)
def on_mode_change(change):
    output.clear_output()
    if change['new'] == 'Single/Bulk Images':
        display(file_picker)
    else:
        display(folder_picker)
mode_selector.observe(on_mode_change, names='value')
on_mode_change({'new': mode_selector.value})


## Run Prediction and Display Results


In [None]:
def display_results(img_paths):
    for img_path in img_paths:
        pred, prob = predict_image(img_path)
        img = Image.open(img_path)
        display(img.resize((256, 256)))
        print(f'Image: {os.path.basename(img_path)} | Prediction: {pred} | Confidence: {prob:.2f}')
        print('-'*60)
def on_file_upload(change):
    output.clear_output()
    img_paths = []
    for fname, fileinfo in file_picker.value.items():
        with open(fname, 'wb') as f:
            f.write(fileinfo['content'])
        img_paths.append(fname)
    with output:
        display_results(img_paths)
def on_folder_submit(sender):
    output.clear_output()
    folder = folder_picker.value
    img_paths = [os.path.join(folder, f) for f in os.listdir(folder) if f.lower().endswith(('.png', '.jpg', '.jpeg'))]
    with output:
        display_results(img_paths)
file_picker.observe(on_file_upload, names='value')
folder_picker.on_submit(on_folder_submit)
