# 🏥 Smart CT Scans: Teaching AI to Reduce Radiation

## A Visual Guide for Everyone

This notebook explains how we use **Reinforcement Learning** (the same technology behind game-playing AI) to make CT scans **safer** by reducing radiation dose while maintaining image quality.

**No prior knowledge required!** We'll build up the concepts step by step with lots of pictures.

In [None]:
# Setup - just run this cell
import numpy as np
import matplotlib.pyplot as plt
from skimage.data import shepp_logan_phantom
from skimage.transform import radon, iradon
from IPython.display import HTML, display

plt.rcParams['figure.figsize'] = [10, 6]
plt.rcParams['font.size'] = 12

print("✓ Ready to learn!")

---
# Part 1: What is a CT Scan?

A **CT scanner** takes X-ray images from many angles around your body, then combines them to create a detailed cross-sectional image.

Think of it like taking photos of a loaf of bread from all sides, then using a computer to figure out what a slice looks like inside.
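Each number a detector records is the total density along one straight line through the body. The tiny example below uses made-up 2×2 arrays (not real CT data). Two different "bodies" give identical readings from straight above and from the side, and only an extra angle tells them apart. That is why the scanner keeps rotating.

```python
import numpy as np

# Two made-up 2x2 "bodies" (1 = dense tissue, 0 = air)
body_a = np.array([[1, 0],
                   [0, 1]])
body_b = np.array([[0, 1],
                   [1, 0]])

for name, body in [("A", body_a), ("B", body_b)]:
    vertical = body.sum(axis=0)    # rays going top-to-bottom: one total per column
    horizontal = body.sum(axis=1)  # rays going left-to-right: one total per row
    print(name, "vertical:", vertical, "horizontal:", horizontal)
# Both bodies print [1 1] for each direction: these two views cannot tell A from B.

# A diagonal ray (top-left to bottom-right) does separate them
print("diagonal A:", np.trace(body_a), "diagonal B:", np.trace(body_b))  # 2 vs 0
```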

In [None]:
# Let's visualize how CT scanning works

# This is a "phantom" - a test image representing a cross-section of a body
phantom = shepp_logan_phantom()

fig, axes = plt.subplots(1, 2, figsize=(12, 5))

# The phantom (what we're trying to image)
axes[0].imshow(phantom, cmap='gray')
axes[0].set_title('🎯 The Target\n(Cross-section of body)', fontsize=14)
axes[0].axis('off')

# Add anatomical labels
axes[0].annotate('Skull', xy=(200, 50), fontsize=10, color='yellow')
axes[0].annotate('Brain', xy=(180, 200), fontsize=10, color='yellow')

# Show the scanning process
axes[1].text(0.5, 0.9, '📷 CT Scanner Process', ha='center', fontsize=14, transform=axes[1].transAxes)
axes[1].text(0.5, 0.7, '1. X-ray beam passes through body', ha='center', fontsize=11, transform=axes[1].transAxes)
axes[1].text(0.5, 0.55, '2. Detector measures how much got through', ha='center', fontsize=11, transform=axes[1].transAxes)
axes[1].text(0.5, 0.4, '3. Rotate and repeat (many angles)', ha='center', fontsize=11, transform=axes[1].transAxes)
axes[1].text(0.5, 0.25, '4. Computer combines all measurements', ha='center', fontsize=11, transform=axes[1].transAxes)
axes[1].text(0.5, 0.1, '5. Final image is created! ✨', ha='center', fontsize=11, transform=axes[1].transAxes)
axes[1].axis('off')

plt.tight_layout()
plt.show()

In [None]:
# Let's watch the scanning process in action!

fig, axes = plt.subplots(1, 3, figsize=(14, 4))

# Original phantom
axes[0].imshow(phantom, cmap='gray')
axes[0].set_title('Step 1: The Body\n(what we want to see)', fontsize=12)
axes[0].axis('off')

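# Sinogram (all the X-ray measurements)
# endpoint=False gives 0..179 degrees; 180 degrees would only repeat the 0-degree view (mirrored)
angles = np.linspace(0, 180, 180, endpoint=False)
sinogram = radon(phantom, theta=angles)

# extent maps image columns to degrees so the x-axis label is accurate
axes[1].imshow(sinogram, cmap='gray', aspect='auto',
               extent=(0, 180, sinogram.shape[0], 0))
axes[1].set_title('Step 2: All X-ray Measurements\n(called a "sinogram")', fontsize=12)
axes[1].set_xlabel('Angle (degrees)')
axes[1].set_ylabel('Detector position')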

# Reconstruction
reconstruction = iradon(sinogram, theta=angles)

axes[2].imshow(reconstruction, cmap='gray')
axes[2].set_title('Step 3: Final CT Image!\n(computer magic ✨)', fontsize=12)
axes[2].axis('off')

plt.tight_layout()
plt.show()

print("💡 The computer takes those squiggly lines (sinogram) and reconstructs the actual image!")
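Why is it called a "sinogram"? Assume a parallel-beam geometry. A single point at (x0, y0) lands on detector position s = x0·cos θ + y0·sin θ at angle θ. That traces one sine-shaped curve, and its height equals the point's distance from the center. Any body is a sum of such points, so its sinogram is a stack of overlapping sine curves. The sign and axis conventions below are assumptions and may differ from skimage's `radon`.

```python
import numpy as np

# Hypothetical single point, in pixels from the rotation center
x0, y0 = 30.0, 40.0

# Assumed parallel-beam convention: detector coordinate s = x*cos(theta) + y*sin(theta)
theta = np.deg2rad(np.arange(180))
s = x0 * np.cos(theta) + y0 * np.sin(theta)

# The same curve written as one shifted cosine: amplitude = distance from the center
radius = np.hypot(x0, y0)    # 50 for this point (a 3-4-5 triangle)
phase = np.arctan2(y0, x0)
print(np.allclose(s, radius * np.cos(theta - phase)))  # True
```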

---
# Part 2: The Problem - Radiation ☢️

CT scans use **X-rays**, which are a form of radiation. While the amount is small, we want to minimize it as much as possible.

### The Dilemma:
- **More radiation** → Clearer image (less grainy)
- **Less radiation** → Grainier image, but safer

How do we find the right balance?
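The graininess at low dose comes from photon counting. The number of X-ray photons reaching a detector is random (Poisson-distributed), so its relative noise is 1/sqrt(N) for an expected count N. The sketch below assumes the count is proportional to tube current; `PHOTONS_PER_MA` is a made-up constant. Under that assumption, quadrupling the mA only halves the noise. The same 1/sqrt(mA) factor appears in the simulation that follows.

```python
import numpy as np

PHOTONS_PER_MA = 1000  # made-up: expected detected photons per mA, for illustration only

def relative_noise(mA, n_samples=200_000, seed=0):
    """Measured std/mean of Poisson photon counts at a given tube current."""
    rng = np.random.default_rng(seed)
    counts = rng.poisson(mA * PHOTONS_PER_MA, size=n_samples)
    return counts.std() / counts.mean()

for mA in [50, 100, 150, 250]:
    theory = 1 / np.sqrt(mA * PHOTONS_PER_MA)  # Poisson: std/mean = 1/sqrt(N)
    print(f"{mA:>3} mA: simulated {relative_noise(mA):.5f}, theory {theory:.5f}")
```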

In [None]:
# Let's see what happens with different radiation levels (mA = tube current)

def simulate_ct_with_noise(phantom, mA, n_angles=60, seed=0):
    """Simulate a CT scan whose noise level depends on tube current (mA)."""
    rng = np.random.default_rng(seed)  # same seed for every mA -> a fair visual comparison
    angles = np.linspace(0, 180, n_angles, endpoint=False)
    sinogram = radon(phantom, theta=angles)
    
    # Exponential noise model: thick paths (large projection values) get
    # disproportionately more noise, and more mA shrinks it by 1/sqrt(mA)
    noise_scale = 0.5
    noise_exponent = 0.08
    exponent = np.clip(noise_exponent * np.abs(sinogram), 0, 20)
    noise = noise_scale * np.sqrt(np.exp(exponent) / mA) * rng.standard_normal(sinogram.shape)
    
    # Reconstruct with filtered back-projection
    return iradon(sinogram + noise, theta=angles)

# Compare different radiation levels
mA_levels = [50, 100, 150, 250]
titles = ['☢️ Very Low Dose\n(50 mA) - Grainy!',
          '☢️☢️ Low Dose\n(100 mA)',
          '☢️☢️☢️ Medium Dose\n(150 mA)',
          '☢️☢️☢️☢️ High Dose\n(250 mA) - Clear!']

fig, axes = plt.subplots(1, 4, figsize=(16, 4))

for i, (mA, title) in enumerate(zip(mA_levels, titles)):
    recon = simulate_ct_with_noise(phantom, mA)
    axes[i].imshow(recon, cmap='gray')
    axes[i].set_title(title, fontsize=11)
    axes[i].axis('off')

plt.suptitle('The Radiation-Quality Tradeoff', fontsize=14, fontweight='bold')
plt.tight_layout()
plt.show()

print("⚖️ Challenge: How do we get GOOD images with LESS radiation?")

---
# Part 3: The Key Insight 💡

Here's the clever part: **Different angles need different amounts of radiation!**

When X-rays pass through your body:
- **Thin path** (front-to-back) → doesn't need as much radiation
- **Thick path** (side-to-side) → needs more radiation

Think about shining a flashlight through a book:
- Through the whole closed book (thick) → you need a bright light
- Through a single page (thin) → a dim light is fine
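Thick paths need more radiation because X-rays are absorbed exponentially with path length (the Beer-Lambert law). The sketch below assumes an attenuation coefficient of about 0.2 per cm, which is roughly water at typical CT energies. It also uses made-up torso sizes of 20 cm front-to-back and 35 cm side-to-side. This exponential is also why the notebook's noise model contains `np.exp(...)`.

```python
import numpy as np

MU_PER_CM = 0.2  # assumed attenuation coefficient, roughly water at CT energies

def transmitted_fraction(path_cm, mu=MU_PER_CM):
    """Beer-Lambert law: fraction of X-ray photons that pass straight through."""
    return np.exp(-mu * path_cm)

thin = transmitted_fraction(20.0)   # assumed front-to-back depth (cm)
thick = transmitted_fraction(35.0)  # assumed side-to-side width (cm)
print(f"thin path lets through  {thin:.4f} of the photons")   # 0.0183
print(f"thick path lets through {thick:.5f} of the photons")  # 0.00091

# Projection noise scales like 1/sqrt(photons detected), so matching the thin
# path's noise on the thick path takes about this many times more mA:
print(f"{thin / thick:.1f}x")  # 20.1x
```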

In [None]:
# Visualize body thickness at different angles

# Create a simple elliptical body phantom
size = 256
body = np.zeros((size, size))
y, x = np.ogrid[:size, :size]
center = size // 2

# Ellipse (wider than tall, like a torso cross-section)
a, b = size * 0.4, size * 0.25  # Semi-axes
mask = ((x - center) / a) ** 2 + ((y - center) / b) ** 2 <= 1
body[mask] = 1.0

# Compute thickness at each angle
angles = np.linspace(0, 180, 180)
sinogram = radon(body, theta=angles)
thickness = np.max(sinogram, axis=0)  # Max projection = path length

fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Show the body with arrows
axes[0].imshow(body, cmap='Blues')
axes[0].set_title('Body Cross-Section\n(Like looking down at chest)', fontsize=12)

# Add arrows showing different angles
# Front-to-back (0°)
axes[0].annotate('', xy=(center, 30), xytext=(center, 226),
                 arrowprops=dict(arrowstyle='->', color='green', lw=3))
axes[0].text(center + 10, 128, 'Front→Back\nTHIN path\n(less radiation)', fontsize=9, color='green')

# Side-to-side (90°)
axes[0].annotate('', xy=(226, center), xytext=(30, center),
                 arrowprops=dict(arrowstyle='->', color='red', lw=3))
axes[0].text(60, center - 40, 'Side→Side\nTHICK path\n(more radiation)', fontsize=9, color='red')
axes[0].axis('off')

# Plot thickness vs angle
axes[1].fill_between(angles, thickness, alpha=0.3, color='steelblue')
axes[1].plot(angles, thickness, 'b-', linewidth=2)
axes[1].axvline(x=0, color='green', linestyle='--', label='Front-Back (thin)')
axes[1].axvline(x=90, color='red', linestyle='--', label='Side-Side (thick)')
axes[1].set_xlabel('Angle (degrees)', fontsize=12)
axes[1].set_ylabel('Body Thickness\n(max path length, pixels)', fontsize=12)
axes[1].set_title('Body Thickness at Each Angle', fontsize=12)
axes[1].legend()

plt.tight_layout()
plt.show()

print("💡 Insight: Use MORE radiation only where needed (thick parts)!")
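For a uniform ellipse like the one above, the longest path at each angle has a closed form. That gives a handy check on the curve without running `radon()`. The sketch assumes 0° means vertical (front-to-back) rays, matching the plot labels. It rebuilds the same ellipse and compares the formula with plain column and row sums. `ellipse_path_length` is a hypothetical helper, not part of the project.

```python
import numpy as np

def ellipse_path_length(theta_deg, a, b):
    """Longest straight path through a uniform ellipse at projection angle theta.

    a is the horizontal semi-axis, b the vertical one; theta = 0 means
    vertical (front-to-back) rays, as labeled in the plot above.
    """
    theta = np.deg2rad(theta_deg)
    return 2 * a * b / np.sqrt((a * np.cos(theta)) ** 2 + (b * np.sin(theta)) ** 2)

# Same ellipse as the cell above: 256x256 grid, semi-axes 0.4*size and 0.25*size
size = 256
a, b = size * 0.4, size * 0.25
y, x = np.ogrid[:size, :size]
center = size // 2
body = (((x - center) / a) ** 2 + ((y - center) / b) ** 2 <= 1).astype(float)

# Formula vs. pixel counts (they differ by about a pixel due to the grid)
print(ellipse_path_length(0, a, b), body.sum(axis=0).max())   # vertical rays: 2*b
print(ellipse_path_length(90, a, b), body.sum(axis=1).max())  # horizontal rays: 2*a
```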

---
# Part 4: Enter Reinforcement Learning 🤖

**Reinforcement Learning (RL)** is how we teach computers through trial and error, much like training a pet!

- 🐕 **Dog learns**: Sit → Treat! → Does more sitting
- 🤖 **AI learns**: Low radiation + good image → Reward! → Does more of that

The AI "plays" millions of simulated CT scans, learning which radiation level to use at each angle.
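Here is a minimal sketch that makes the trial-and-error idea concrete. It uses the simplest RL setting, a contextual bandit, where each decision is scored immediately. This is not the method in `train.py`. The three thickness levels, the reward formula, and the weights `NOISE_WEIGHT` and `DOSE_WEIGHT` are all made up for illustration: noise grows exponentially with thickness and shrinks with mA, and dose is penalized directly. The agent only ever sees noisy scores, yet it learns to pick a low mA for thin paths and a high mA for thick ones.

```python
import numpy as np

# Hypothetical toy problem (not the project's train.py):
# state  = how thick the body is at the current angle (3 levels, normalized 0..1)
# action = which tube current to use
ACTIONS_MA = np.array([50, 150, 250])
THICKNESS = np.array([0.0, 0.5, 1.0])
NOISE_WEIGHT = 100.0  # made-up penalty scale for image noise
DOSE_WEIGHT = 0.05    # made-up penalty per mA of dose

def mean_reward(state, action):
    """Assumed score: penalize noise (grows with thickness, shrinks with mA) and dose."""
    noise_variance = np.exp(4.0 * THICKNESS[state]) / ACTIONS_MA[action]
    return -NOISE_WEIGHT * noise_variance - DOSE_WEIGHT * ACTIONS_MA[action]

def train(n_steps=5000, epsilon=0.2, seed=0):
    """Epsilon-greedy estimation of the average reward of every (state, action) pair."""
    rng = np.random.default_rng(seed)
    q = np.zeros((len(THICKNESS), len(ACTIONS_MA)))  # estimated average rewards
    counts = np.zeros_like(q)
    for _ in range(n_steps):
        state = rng.integers(len(THICKNESS))          # a random projection angle
        if rng.random() < epsilon:
            action = rng.integers(len(ACTIONS_MA))    # explore: try a random current
        else:
            action = int(np.argmax(q[state]))         # exploit: best current so far
        reward = mean_reward(state, action) + rng.normal()  # agent sees a noisy score
        counts[state, action] += 1
        q[state, action] += (reward - q[state, action]) / counts[state, action]  # running mean
    return q

q = train()
for state, t in enumerate(THICKNESS):
    print(f"thickness {t:.1f} -> learned choice: {ACTIONS_MA[np.argmax(q[state])]} mA")
```

Exploring a random current 20% of the time keeps an early unlucky score from locking in a bad choice.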

In [None]:
# Visualize the RL concept

fig, axes = plt.subplots(1, 3, figsize=(15, 4))

# Panel 1: The Game
axes[0].text(0.5, 0.9, '🎮 The "Game"', ha='center', fontsize=14, fontweight='bold', transform=axes[0].transAxes)
axes[0].text(0.5, 0.7, 'At each angle, AI chooses:', ha='center', fontsize=11, transform=axes[0].transAxes)
axes[0].text(0.5, 0.55, '💡 Low radiation (50 mA)', ha='center', fontsize=10, transform=axes[0].transAxes)
axes[0].text(0.5, 0.45, '💡💡 Medium radiation (150 mA)', ha='center', fontsize=10, transform=axes[0].transAxes)
axes[0].text(0.5, 0.35, '💡💡💡 High radiation (250 mA)', ha='center', fontsize=10, transform=axes[0].transAxes)
axes[0].text(0.5, 0.15, '(60 choices per scan!)', ha='center', fontsize=10, style='italic', transform=axes[0].transAxes)
axes[0].axis('off')

# Panel 2: The Score
axes[1].text(0.5, 0.9, '🏆 The "Score"', ha='center', fontsize=14, fontweight='bold', transform=axes[1].transAxes)
axes[1].text(0.5, 0.65, 'After the scan:', ha='center', fontsize=11, transform=axes[1].transAxes)
axes[1].text(0.5, 0.5, '✓ Good image quality → +Points', ha='center', fontsize=11, color='green', transform=axes[1].transAxes)
axes[1].text(0.5, 0.35, '✗ High radiation → -Points', ha='center', fontsize=11, color='red', transform=axes[1].transAxes)
axes[1].text(0.5, 0.15, 'Goal: Maximize total score!', ha='center', fontsize=11, fontweight='bold', transform=axes[1].transAxes)
axes[1].axis('off')

# Panel 3: Learning
axes[2].text(0.5, 0.9, '🧠 The Learning', ha='center', fontsize=14, fontweight='bold', transform=axes[2].transAxes)
axes[2].text(0.5, 0.7, 'After millions of practice scans:', ha='center', fontsize=11, transform=axes[2].transAxes)
axes[2].text(0.5, 0.55, '"For THIN angles, use low mA"', ha='center', fontsize=10, style='italic', transform=axes[2].transAxes)
axes[2].text(0.5, 0.4, '"For THICK angles, use high mA"', ha='center', fontsize=10, style='italic', transform=axes[2].transAxes)
axes[2].text(0.5, 0.2, 'AI discovers the optimal pattern! 🎉', ha='center', fontsize=11, fontweight='bold', color='green', transform=axes[2].transAxes)
axes[2].axis('off')

plt.suptitle('How Reinforcement Learning Works', fontsize=16, fontweight='bold')
plt.tight_layout()
plt.show()
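Why learn instead of simply testing every option? With 3 tube currents at each of 60 angles, a full scan plan is one of 3 to the power 60 (`3 ** 60`) combinations. The back-of-the-envelope sketch below assumes, purely for illustration, that simulating one scan takes one millisecond.

```python
n_choices_per_angle = 3   # low / medium / high mA, as in the panel above
n_angles = 60
n_plans = n_choices_per_angle ** n_angles
print(f"{n_plans:.2e} possible scan plans")  # 4.24e+28

ms_per_simulated_scan = 1.0  # assumed cost of one simulated scan
seconds_per_year = 365.25 * 24 * 3600
years = n_plans * ms_per_simulated_scan / 1000 / seconds_per_year
print(f"brute force would take about {years:.1e} years")
```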

---
# Part 5: The Results 📊

Let's compare different strategies:

1. **Fixed High** - Always use maximum radiation (clear image, but wasteful)
2. **Fixed Low** - Always use minimum radiation (grainy image)
3. **Smart AI** - Varies radiation with body thickness (best of both!). In this cell it is a simple hand-written thickness rule standing in for the trained agent.

In [None]:
# Compare strategies

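def smart_mA_policy(angle, thickness_profile):
    """Rule-based stand-in for the learned policy: pick mA from body thickness."""
    n = len(thickness_profile)
    thickness = thickness_profile[int(round(angle / 180 * n)) % n]
    # Rescale so the thinnest angle maps to 0 and the thickest to 1. Dividing by the
    # max alone leaves every angle above 0.8 for this phantom, so the thresholds
    # below would pick 250 mA everywhere.
    t_min, t_max = thickness_profile.min(), thickness_profile.max()
    thickness_norm = (thickness - t_min) / (t_max - t_min)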
    
    if thickness_norm < 0.4:
        return 50
    elif thickness_norm < 0.6:
        return 100
    elif thickness_norm < 0.8:
        return 150
    else:
        return 250

# Create body phantom
size = 256
phantom = np.zeros((size, size))
y, x = np.ogrid[:size, :size]
center = size // 2
a, b = size * 0.4, size * 0.28
mask = ((x - center) / a) ** 2 + ((y - center) / b) ** 2 <= 1
phantom[mask] = 0.5
# Add spine
spine_mask = ((x - center) / (size*0.05)) ** 2 + ((y - center) / (size*0.1)) ** 2 <= 1
phantom[spine_mask] = 1.0

# Compute thickness profile
n_angles = 60
angles = np.linspace(0, 180, n_angles, endpoint=False)
clean_sinogram = radon(phantom, theta=angles)
thickness_profile = np.max(clean_sinogram, axis=0)  # Use max for real thickness variation

# Noise model parameters (match the RL environment)
noise_scale = 0.5
noise_exponent = 0.08

# Simulate three strategies
results = {}

for strategy_name, mA_func in [
    ('Fixed High (250 mA)', lambda a: 250),
    ('Fixed Low (50 mA)', lambda a: 50),
    ('Smart AI', lambda a: smart_mA_policy(a, thickness_profile))
]:
    np.random.seed(42)  # For reproducibility
    
    noisy_sinogram = []
    mA_used = []
    total_dose = 0
    
    for i, angle in enumerate(angles):
        mA = mA_func(angle)
        mA_used.append(mA)
        total_dose += mA
        
        projection = clean_sinogram[:, i]
        # Exponential noise model: thick paths get more noise
        exponent = np.clip(noise_exponent * np.abs(projection), 0, 20)
        noise = noise_scale * np.sqrt(np.exp(exponent) / mA) * np.random.randn(*projection.shape)
        noisy_proj = projection + noise
        noisy_sinogram.append(noisy_proj)
    
    noisy_sinogram = np.array(noisy_sinogram).T
    recon = iradon(noisy_sinogram, theta=angles)
    
    results[strategy_name] = {
        'recon': recon,
        'mA_profile': mA_used,
        'total_dose': total_dose
    }

# Visualize
fig, axes = plt.subplots(2, 3, figsize=(14, 9))

strategies = list(results.keys())
colors = ['red', 'blue', 'green']

for i, strategy in enumerate(strategies):
    # Top row: reconstructions
    axes[0, i].imshow(results[strategy]['recon'], cmap='gray')
    axes[0, i].set_title(f'{strategy}\nDose: {results[strategy]["total_dose"]:,}', fontsize=11)
    axes[0, i].axis('off')
    
    # Bottom row: mA profiles
    axes[1, i].bar(range(n_angles), results[strategy]['mA_profile'], color=colors[i], alpha=0.7)
    axes[1, i].set_xlabel('Projection #')
    axes[1, i].set_ylabel('mA')
    axes[1, i].set_ylim([0, 280])
    mean_mA = np.mean(results[strategy]['mA_profile'])
    axes[1, i].axhline(y=mean_mA, color='black', linestyle='--', label=f'Mean: {mean_mA:.0f}')
    axes[1, i].legend()

plt.suptitle('Comparing Radiation Strategies', fontsize=14, fontweight='bold')
plt.tight_layout()
plt.show()

# Print comparison
print("\n" + "="*60)
print("RESULTS SUMMARY")
print("="*60)
for strategy in strategies:
    dose = results[strategy]['total_dose']
    print(f"{strategy:25} → Total Dose: {dose:,}")

# Dose reduction
high_dose = results['Fixed High (250 mA)']['total_dose']
smart_dose = results['Smart AI']['total_dose']
reduction = (high_dose - smart_dose) / high_dose * 100
print(f"\n🎉 Smart AI uses {reduction:.0f}% less dose than Fixed High; compare the images above to judge the quality!")
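The summary above reports dose only. One simple way to check the image-quality side is to score each reconstruction against the known phantom, for example with root-mean-square error. SSIM, the quality axis in the final plot, is available in scikit-image as `skimage.metrics.structural_similarity`. The helpers `rmse` and `dose_reduction_percent` below are illustrative names, not part of the project code.

```python
import numpy as np

def rmse(image, reference):
    """Root-mean-square error between two same-shaped images (lower is better)."""
    image = np.asarray(image, dtype=float)
    reference = np.asarray(reference, dtype=float)
    return float(np.sqrt(np.mean((image - reference) ** 2)))

def dose_reduction_percent(dose, baseline_dose):
    """Percent of the baseline dose saved, computed the same way as the summary cell."""
    return 100.0 * (baseline_dose - dose) / baseline_dose

# Fixed Low vs Fixed High over 60 projections: 60*50 = 3,000 vs 60*250 = 15,000
print(dose_reduction_percent(60 * 50, 60 * 250))  # 80.0
print(rmse(np.zeros((4, 4)), np.ones((4, 4))))     # 1.0
```

In the notebook, `rmse(results[name]['recon'], phantom)` would score each strategy. With skimage's default `circle=True`, `iradon` returns an image the same size as the 256×256 phantom, so the shapes match.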

---
# Part 6: The Big Picture 🌍

## Why This Matters

- **~80 million** CT scans performed annually in the US alone
- Even small dose reductions × millions of scans = **huge impact**
- AI can make decisions faster than humans (in milliseconds)
- Personalized to each patient's body shape

## The Future

This same approach can optimize:
- MRI scan duration
- Radiation therapy planning
- Drug dosing
- ...and much more!

In [None]:
# Final summary visualization

fig, ax = plt.subplots(figsize=(10, 6))

# Quality vs Dose plot
strategies = ['Fixed Low\n(50 mA)', 'Fixed Medium\n(150 mA)', 'Fixed High\n(250 mA)', 'Smart AI']
doses = [3000, 9000, 15000, 8500]  # Example values
quality = [0.75, 0.88, 0.95, 0.93]  # Example SSIM values
colors = ['blue', 'orange', 'red', 'green']
sizes = [200, 200, 200, 400]  # Make AI bigger

for s, d, q, c, sz in zip(strategies, doses, quality, colors, sizes):
    ax.scatter(d, q, s=sz, c=c, label=s, edgecolors='black', linewidths=2)
    
# Add direction hints. Fix the axis limits so these labels fall inside the plot
# (annotations whose anchor point lies outside the axes are not drawn).
ax.set_xlim(0, 16500)
ax.set_ylim(0.70, 1.00)
ax.annotate('← Less radiation (safer)', xy=(4000, 0.72), fontsize=10, color='green')
ax.annotate('↑ Better image', xy=(14000, 0.97), fontsize=10, color='green', ha='right')

# Highlight the ideal region
ax.axvspan(7000, 10000, alpha=0.1, color='green', label='Sweet Spot')

ax.set_xlabel('Total Radiation Dose (sum of mA over 60 projections)', fontsize=12)
ax.set_ylabel('Image Quality (SSIM)', fontsize=12)
ax.set_title('🏆 Smart AI Finds the Best Balance!\n(illustrative example values)', fontsize=14, fontweight='bold')
ax.legend(loc='lower right')
ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

print("\n" + "="*60)
print("KEY TAKEAWAY")
print("="*60)
print("\n🤖 AI learns to use radiation ONLY where needed")
print("📉 Result: ~30-40% dose reduction")
print("✅ Same image quality")
print("🏥 Safer CT scans for millions of patients!")
print("\n" + "="*60)

---
# Summary

| Concept | Explanation |
|---------|-------------|
| **CT Scan** | Takes X-rays from many angles to create cross-sectional images |
| **The Problem** | More radiation = better image, but we want to minimize radiation |
| **Key Insight** | Different angles need different amounts of radiation |
| **Solution** | AI learns the optimal radiation level for each angle |
| **Result** | ~30-40% less radiation with same image quality! |

---

### Want to learn more?

- **Reinforcement Learning**: [Spinning Up by OpenAI](https://spinningup.openai.com/)
- **CT Physics**: [How CT Works (YouTube)](https://www.youtube.com/results?search_query=how+ct+scan+works)
- **This Project**: See `train.py` and `evaluate.py` for the full implementation!