<a href="https://colab.research.google.com/github/tennisvish/AICervicalFracture/blob/main/FractureVisV1.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
pip install pydicom grad-cam gradio torch torchvision matplotlib numpy

Collecting pydicom
  Downloading pydicom-3.0.1-py3-none-any.whl.metadata (9.4 kB)
Collecting grad-cam
  Downloading grad-cam-1.5.5.tar.gz (7.8 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m7.8/7.8 MB[0m [31m85.1 MB/s[0m eta [36m0:00:00[0m
[?25h  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
Collecting gradio
  Downloading gradio-5.29.0-py3-none-any.whl.metadata (16 kB)
Collecting ttach (from grad-cam)
  Downloading ttach-0.0.3-py3-none-any.whl.metadata (5.2 kB)
Collecting aiofiles<25.0,>=22.0 (from gradio)
  Downloading aiofiles-24.1.0-py3-none-any.whl.metadata (10 kB)
Collecting fastapi<1.0,>=0.115.2 (from gradio)
  Downloading fastapi-0.115.12-py3-none-any.whl.metadata (27 kB)
Collecting ffmpy (from gradio)
  Downloading ffmpy-0.5.0-py3-none-any.whl.metadata (3.0 kB)
Collecting gradio-client==1.10.0 (from gradio)
  Downloading gradio_c

In [None]:
import gradio as gr
import torch
import torch.nn as nn
from torchvision import models
import numpy as np
from PIL import Image
from torchvision import transforms
import pydicom
import io

# 1. Model Definition (Fixed Architecture)
class SpineModel(nn.Module):
    """ResNet-18 backbone shared by two sigmoid heads.

    ``forward`` returns ``(patient_prob, vertebrae_probs)``: the patient head
    produces one probability per sample and the vertebrae head seven.
    """

    def __init__(self):
        super().__init__()
        # Randomly initialised backbone; trained weights are loaded later from a
        # checkpoint. ``weights=None`` replaces the deprecated ``pretrained=False``.
        self.backbone = models.resnet18(weights=None)
        # Remove the classifier so the backbone outputs its 512-d pooled features.
        self.backbone.fc = nn.Identity()
        self.vertebrae_head = nn.Sequential(nn.Linear(512, 7), nn.Sigmoid())
        self.patient_head = nn.Sequential(nn.Linear(512, 1), nn.Sigmoid())

    def forward(self, x):
        """Return ``(patient_prob, vertebrae_probs)`` shaped (batch, 1) and (batch, 7).

        ``x`` is presumably a normalised (batch, 3, H, W) image tensor — confirm
        against the preprocessing used by callers.
        """
        features = self.backbone(x)
        return self.patient_head(features), self.vertebrae_head(features)

# 2. Load Model
model = SpineModel()
# weights_only=True limits unpickling to tensors and plain containers, so a
# tampered checkpoint cannot run arbitrary code when loaded (this is also the
# default from PyTorch 2.6 onward).
state_dict = torch.load(
    "spine_fracture_weights.pth", map_location="cpu", weights_only=True
)
# The checkpoint apparently names the head ``vertebrae_heads``; rename those keys
# to the ``vertebrae_head`` attribute of SpineModel so strict loading succeeds.
state_dict = {
    k.replace("vertebrae_heads.0.", "vertebrae_head.0."): v
    for k, v in state_dict.items()
}
model.load_state_dict(state_dict)
# Switch to inference mode (affects layers such as BatchNorm in the backbone).
model.eval()

# 3. Fixed DICOM Processor with CT-specific handling
# Window (in rescaled CT units) mapped onto 0-255 grey levels; the author's
# chosen bone window for spine images.
BONE_WINDOW_CENTER = 400
BONE_WINDOW_WIDTH = 2000


def _read_upload_bytes(file):
    """Return the raw bytes of an upload given as a path, file wrapper, file object or bytes."""
    if isinstance(file, (bytes, bytearray)):
        return bytes(file)
    if isinstance(file, str) or hasattr(file, "__fspath__"):
        # Plain path string or os.PathLike (a PathLike's ``.name`` is only the
        # basename, so it must not fall through to the branch below).
        path = file
    elif hasattr(file, "name"):
        # Temp-file style wrapper exposing the on-disk path as ``.name``.
        path = file.name
    else:
        # Readable binary file object (e.g. io.BytesIO).
        return file.read()
    with open(path, "rb") as f:
        return f.read()


def _dicom_to_image(bytes_data, center=BONE_WINDOW_CENTER, width=BONE_WINDOW_WIDTH):
    """Decode DICOM bytes into an RGB PIL image using the given display window.

    Raises an exception if the bytes do not hold decodable DICOM pixel data.
    """
    # force=True also accepts files without the standard preamble; non-DICOM
    # input then tends to fail when the pixel data is accessed below.
    dicom = pydicom.dcmread(io.BytesIO(bytes_data), force=True)
    try:
        pixels = dicom.pixel_array
    except AttributeError as err:
        raise ValueError("DICOM has no pixel data") from err

    # Apply the modality rescale (slope/intercept) when present.
    img = pixels.astype(np.float64)
    img = img * float(getattr(dicom, "RescaleSlope", 1))
    img = img + float(getattr(dicom, "RescaleIntercept", 0))

    # Multi-frame data carries a leading frame axis; the model is 2-D, so use
    # the middle frame rather than failing in Image.fromarray.
    n_frames = int(getattr(dicom, "NumberOfFrames", 1) or 1)
    if n_frames > 1:
        img = img[n_frames // 2]

    # Linear window: [center - width/2, center + width/2] -> [0, 255].
    low = center - width / 2
    img = np.clip((img - low) / width * 255, 0, 255).astype(np.uint8)

    # Greyscale -> 3 identical channels, as the RGB pipeline expects.
    if img.ndim == 2:
        img = np.stack([img] * 3, axis=-1)
    return Image.fromarray(img)


def process_file(file):
    """Load an uploaded medical image as an RGB PIL image.

    ``file`` may be a filesystem path (``str`` — e.g. what ``gr.File(type="filepath")``
    supplies — or ``os.PathLike``), an object exposing its path as ``.name``,
    raw ``bytes``, or a readable binary file object. DICOM decoding with the
    bone window is tried first; anything else is opened with PIL.

    Returns the image, or ``None`` if ``file`` is ``None`` or cannot be decoded.
    Failures are printed rather than raised so the UI can show a message.
    """
    if file is None:
        return None

    try:
        bytes_data = _read_upload_bytes(file)
    except Exception as e:
        print(f"File processing failed: {str(e)}")
        return None

    try:
        return _dicom_to_image(bytes_data)
    except Exception as e:
        print(f"DICOM processing failed: {str(e)}")

    # Fallback to a regular image format (PNG/JPEG/...).
    try:
        return Image.open(io.BytesIO(bytes_data)).convert('RGB')
    except Exception as e:
        print(f"File processing failed: {str(e)}")
        return None

# 4. Prediction Function
# Resize/crop to 224 px and apply ImageNet mean/std normalisation. Built once
# at import instead of on every request.
_PREPROCESS = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])


def _format_result(patient_prob, vertebra_probs):
    """Render model outputs as Markdown: patient probability, then C1.. two per line.

    Lines carry no leading indentation: Markdown renders lines indented four
    or more spaces after a blank line as a code block.
    """
    per_vertebra = [
        f"C{i}: {p:.1%}" for i, p in enumerate(vertebra_probs[0].tolist(), start=1)
    ]
    rows = [" | ".join(per_vertebra[i:i + 2]) for i in range(0, len(per_vertebra), 2)]
    return "\n".join([
        f"**Patient Fracture Probability:** {patient_prob.item():.1%}",
        "",
        "**Vertebrae Probabilities:**",
        *(f"- {row}" for row in rows),
    ])


def predict(file):
    """Run the fracture model on an uploaded file.

    Returns ``(markdown_text, image)``: the formatted probabilities and the
    processed image, or an error message and ``None`` if the file is unreadable.
    A ``None`` input (e.g. the upload was cleared) resets both outputs.
    """
    if file is None:
        return "", None

    img = process_file(file)
    if img is None:
        return "Invalid file (must be DICOM/JPEG/PNG)", None

    with torch.no_grad():
        patient_prob, vertebra_probs = model(_PREPROCESS(img).unsqueeze(0))

    return _format_result(patient_prob, vertebra_probs), img

# 5. Gradio Interface
# Layout: upload box and sample files on the left, processed image and
# Markdown results on the right. Components are placed in creation order, so
# the statements inside the ``with`` blocks are order-sensitive.
with gr.Blocks(title="Spine Fracture Detector", theme=gr.themes.Soft()) as demo:
    gr.Markdown("""<h1 style='text-align: center'>🏥 AI Spine Fracture Detector</h1>""")

    with gr.Row():
        # Input column takes 1/3 of the row width (scale 1 vs 2).
        with gr.Column(scale=1):
            # type="filepath": the handler receives the uploaded file's path,
            # not its contents.
            file_input = gr.File(
                label="Upload Medical Image",
                file_types=[".dcm", ".png", ".jpg", ".jpeg"],
                type="filepath"
            )
            # Sample DICOMs referenced by relative path; they must exist in the
            # working directory (presumably uploaded to the Colab session — confirm).
            gr.Examples(
                examples=["103.dcm", "1.dcm"],
                inputs=file_input,
                label="Try Sample Images"
            )
        with gr.Column(scale=2):
            output_img = gr.Image(label="Processed Image", height=400)
            output_text = gr.Markdown(label="Analysis Results")

    # Re-run prediction whenever the File component's value changes (this
    # includes clearing it). predict returns (text, image), matching the order
    # of ``outputs``.
    file_input.change(
        fn=predict,
        inputs=file_input,
        outputs=[output_text, output_img]
    )

# Launch
# share=True creates a temporary public *.gradio.live URL (the way to reach the
# app from Colab); anyone holding the link can upload files to it.
demo.launch(share=True)



Colab notebook detected. To show errors in colab notebook, set debug=True in launch()
* Running on public URL: https://c418ae83dbfd34949c.gradio.live

This share link expires in 1 week. For free permanent hosting and GPU upgrades, run `gradio deploy` from the terminal in the working directory to deploy to Hugging Face Spaces (https://huggingface.co/spaces)


