# EchoQuality Demo Notebook

This notebook demonstrates how to use the EchoQuality model to assess the quality of echocardiogram videos. The model analyzes DICOM files and predicts whether the video quality is acceptable.

## Overview

The EchoQuality model is an R(2+1)D video network (torchvision's `r2plus1d_18` with a single output unit) loaded with pre-trained weights. It processes DICOM files, masks everything outside the ultrasound region, and converts its output to a quality score with a sigmoid. Videos scoring 0.3 or higher are classified as PASS and the rest as FAIL.
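The PASS/FAIL decision comes down to a sigmoid over the model's single output logit, followed by a comparison with 0.3, as in the inference cells later in this notebook. Below is a minimal standard-library sketch. The function name `classify_logit` is hypothetical and is not part of the EchoQuality modules.

```python
import math

THRESHOLD = 0.3  # PASS/FAIL cut-off used throughout this notebook


def classify_logit(logit: float, threshold: float = THRESHOLD) -> tuple:
    """Map a raw model logit to (probability, status)."""
    # Numerically stable sigmoid: avoid exp() overflow for large |logit|
    if logit >= 0:
        probability = 1.0 / (1.0 + math.exp(-logit))
    else:
        z = math.exp(logit)
        probability = z / (1.0 + z)
    status = "PASS" if probability >= threshold else "FAIL"
    return probability, status


print(classify_logit(0.0))   # (0.5, 'PASS')
print(classify_logit(-2.0))  # probability about 0.119, 'FAIL'
```

The sigmoid is monotonic, so the same rule can also be stated on the logit: a video passes when its logit is at least ln(0.3 / 0.7), which is about -0.847.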

## Setup

First, let's import the necessary libraries and set up the environment.

In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt
import os
import glob
import pydicom
import cv2
from tqdm import tqdm
from torchvision.models.video import r2plus1d_18
import json

# Import functions from our modules
from inference.EchoPrime_qc import mask_outside_ultrasound, crop_and_scale, get_quality_issues
from training.echo_model_evaluation import visualize_gradcam

# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

## Load the Model

Now, let's load the pre-trained EchoQuality model.

In [None]:
# Constants for video processing
frames_to_take = 32
frame_stride = 2
video_size = 112
mean = torch.tensor([29.110628, 28.076836, 29.096405]).reshape(3, 1, 1, 1)
std = torch.tensor([47.989223, 46.456997, 47.20083]).reshape(3, 1, 1, 1)

# Load model
model_weights = "./weights/video_quality_model.pt"
model = r2plus1d_18(num_classes=1)
model.load_state_dict(torch.load(model_weights, map_location=device))
model = model.to(device)
model.eval()

print(f"Model loaded from {model_weights}")
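The `mean` and `std` tensors have shape (3, 1, 1, 1). This lets them broadcast over a video laid out as (channels, frames, height, width), so each color channel gets its own shift and scale. The NumPy sketch below performs the same arithmetic that `process_dicom` does in place with `x.sub_(mean).div_(std)`. A random, made-up clip stands in for a real video.

```python
import numpy as np

# Per-channel constants copied from the notebook, shaped (C, 1, 1, 1) for broadcasting
mean = np.array([29.110628, 28.076836, 29.096405]).reshape(3, 1, 1, 1)
std = np.array([47.989223, 46.456997, 47.20083]).reshape(3, 1, 1, 1)

# Illustrative clip: 3 channels, 16 frames, 112 x 112 pixels, values in [0, 255]
rng = np.random.default_rng(0)
clip = rng.integers(0, 256, size=(3, 16, 112, 112)).astype(np.float32)

# Channel c is shifted by mean[c] and divided by std[c] across all frames and pixels
normalized = (clip - mean) / std

print(normalized.shape)  # (3, 16, 112, 112)
```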

## Process a DICOM File

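Let's define a function that prepares a DICOM file for the model. It reads the pixel data, masks everything outside the ultrasound region, and crops and scales each frame to 112×112 pixels. It then normalizes each color channel, pads short clips, and keeps every second frame of the first 32. The result is a tensor of shape (3, 16, 112, 112), in (channels, frames, height, width) order.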
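`process_dicom` accepts two pixel-array layouts. Grayscale clips shaped (frames, height, width) are repeated to three channels, and color clips shaped (frames, height, width, 3) are used as they are. Arrays with fewer than three dimensions are skipped. So are 3-D arrays whose third axis has length 3, which the function treats as a single RGB still frame shaped (height, width, 3). The NumPy sketch below isolates just that shape logic. `to_rgb_clip` is a hypothetical name.

```python
import numpy as np


def to_rgb_clip(pixels: np.ndarray):
    """Mirror the shape checks in process_dicom; return None for unsupported input."""
    # A 2-D image, or a single RGB frame shaped (height, width, 3), is not a clip
    if pixels.ndim < 3 or pixels.shape[2] == 3:
        return None
    # Grayscale clip (frames, height, width) -> (frames, height, width, 3)
    if pixels.ndim == 3:
        pixels = np.repeat(pixels[..., None], 3, axis=3)
    return pixels


gray_clip = np.zeros((40, 60, 80), dtype=np.uint8)
print(to_rgb_clip(gray_clip).shape)                          # (40, 60, 80, 3)
print(to_rgb_clip(np.zeros((60, 80, 3), dtype=np.uint8)))    # None
```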

In [None]:
def process_dicom(dicom_path, save_mask_images=False):
    """
    Process a single DICOM file and prepare it for the model.
    
    Args:
        dicom_path (str): Path to the DICOM file
        save_mask_images (bool): Whether to save mask images
        
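    Returns:
        torch.Tensor or None: Video tensor of shape (3, frames_to_take // frame_stride,
        video_size, video_size), or None if the file cannot be processed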
    """
    try:
        # Read DICOM file
        dcm = pydicom.dcmread(dicom_path)
        pixels = dcm.pixel_array
        
        # Print DICOM info
        print(f"DICOM shape: {pixels.shape}")
        print(f"DICOM dimensions: {pixels.ndim}")
        
        # Skip still images: 2-D arrays or a single RGB frame shaped (H, W, 3)
        if pixels.ndim < 3 or pixels.shape[2] == 3:
            print(f"Skipping {dicom_path}: Invalid dimensions {pixels.shape}")
            return None
        
        # If single channel, repeat to 3 channels
        if pixels.ndim == 3:
            pixels = np.repeat(pixels[..., None], 3, axis=3)
        
        # Mask everything outside ultrasound region
        filename = os.path.basename(dicom_path)
        pixels = mask_outside_ultrasound(pixels, filename if save_mask_images else None)
        
        # Model specific preprocessing
        x = np.zeros((len(pixels), video_size, video_size, 3))
        for i in range(len(x)):
            x[i] = crop_and_scale(pixels[i])
        
        # Convert to tensor and permute dimensions
        x = torch.as_tensor(x, dtype=torch.float).permute([3, 0, 1, 2])
        
        # Normalize
        x.sub_(mean).div_(std)
        
        # If not enough frames, add padding
        if x.shape[1] < frames_to_take:
            padding = torch.zeros(
                (
                    3,
                    frames_to_take - x.shape[1],
                    video_size,
                    video_size,
                ),
                dtype=torch.float,
            )
            x = torch.cat((x, padding), dim=1)
        
        # Keep every frame_stride-th frame of the first frames_to_take frames (16 frames)
        start = 0
        x = x[:, start: (start + frames_to_take): frame_stride, :, :]
        
        return x
    
    except Exception as e:
        print(f"Error processing {dicom_path}: {str(e)}")
        return None
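The last two steps of `process_dicom` work together. First, the clip is padded with zero frames up to `frames_to_take`. Then every `frame_stride`-th frame of the first `frames_to_take` is kept, so the model always receives 32 / 2 = 16 frames. The pure-Python sketch below tracks frame indices instead of pixel data. `sample_frame_indices` is a hypothetical helper, and `None` marks a padded frame.

```python
def sample_frame_indices(num_frames: int, frames_to_take: int = 32,
                         frame_stride: int = 2, start: int = 0) -> list:
    """Return which source frames survive padding and strided slicing.

    Padded (all-zero) frames are represented as None.
    """
    frames = list(range(num_frames))
    # Pad with placeholder frames when the clip is shorter than frames_to_take
    if len(frames) < frames_to_take:
        frames += [None] * (frames_to_take - len(frames))
    # Same slice as x[:, start:start + frames_to_take:frame_stride]
    return frames[start:start + frames_to_take:frame_stride]


print(sample_frame_indices(100))  # [0, 2, 4, ..., 30] (16 frames)
print(sample_frame_indices(10))   # 0, 2, 4, 6, 8 followed by 11 padded frames
```

Because `start` is fixed at 0, frames after the 32nd are never seen by the model, however long the original clip is.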

## Find DICOM Files

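Let's collect candidate files from the example directory. Exported echo DICOMs often lack a `.dcm` extension, so the search does not filter by extension. Anything that is not a readable DICOM is skipped later, when `process_dicom` fails to read it.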

In [None]:
# Path to example DICOM files
example_dir = "./data/example_study"

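# Find candidate files (the recursive glob also matches directories, so drop those)
dicom_paths = sorted(
    p for p in glob.glob(f"{example_dir}/**/*", recursive=True) if os.path.isfile(p)
)
print(f"Found {len(dicom_paths)} candidate DICOM files")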

# Display the first few paths
for path in dicom_paths[:5]:
    print(f"- {path}")
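A recursive glob matches every path under the directory, so non-DICOM files can end up in the list. One way to narrow it assumes the files follow the DICOM Part 10 layout: a 128-byte preamble followed by the four bytes `DICM`. The sketch below checks for that marker directly. `is_dicom_file` is a hypothetical helper, not part of pydicom or this repository. Files written without the preamble would be missed by this check, even though `pydicom.dcmread(..., force=True)` may still read them.

```python
import os


def is_dicom_file(path: str) -> bool:
    """Return True if `path` is a regular file with the DICOM Part 10 'DICM' marker."""
    if not os.path.isfile(path):
        return False
    try:
        with open(path, "rb") as f:
            f.seek(128)                 # skip the 128-byte preamble
            return f.read(4) == b"DICM"
    except OSError:
        return False


# Example: keep only files that carry the marker
# dicom_paths = [p for p in dicom_paths if is_dicom_file(p)]
```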

## Process and Analyze a Single DICOM File

Let's process and analyze a single DICOM file to see how the model works.

In [None]:
# Process the first DICOM file
if len(dicom_paths) > 0:
    dicom_path = dicom_paths[0]
    print(f"Processing {dicom_path}...")
    
    # Process DICOM file
    video = process_dicom(dicom_path, save_mask_images=True)
    
    if video is not None:
        print(f"Processed video shape: {video.shape}")
        
        # Run inference
        with torch.no_grad():
            video_tensor = video.unsqueeze(0).to(device)  # Add batch dimension
            output = model(video_tensor)
            probability = torch.sigmoid(output).item()
            prediction = 1 if probability >= 0.3 else 0
            status = "PASS" if prediction > 0 else "FAIL"
            assessment = get_quality_issues(probability)
        
        print("\nQuality Assessment Results:")
        print(f"Score: {probability:.4f}")
        print(f"Status: {status}")
        print(f"Assessment: {assessment}")
        
        # Visualize the first frame
        plt.figure(figsize=(10, 6))
        frame = video.permute(1, 2, 3, 0)[0].cpu().numpy()  # Get first frame
        frame = (frame - frame.min()) / (frame.max() - frame.min())  # Normalize for display
        plt.imshow(frame)
        plt.title(f"First Frame - Quality Score: {probability:.4f} ({status})")
        plt.axis('off')
        plt.show()
        
        # Make sure the output directory exists, then generate the GradCAM visualization
        os.makedirs("./results", exist_ok=True)
        print("\nGenerating GradCAM visualization...")
        visualize_gradcam(
            model, 
            video, 
            target_layer_name="layer4", 
            save_path="./results/gradcam_visualization.png"
        )
        
        # Display GradCAM visualization
        plt.figure(figsize=(12, 8))
        img = plt.imread("./results/gradcam_visualization.png")
        plt.imshow(img)
        plt.axis('off')
        plt.title("GradCAM Visualization")
        plt.show()
    else:
        print("Failed to process DICOM file.")
else:
    print("No DICOM files found.")
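`visualize_gradcam` is imported from `training.echo_model_evaluation`, and this notebook does not show its internals. For orientation, the NumPy sketch below applies the standard Grad-CAM recipe to a 3-D convolutional layer such as `layer4`. It assumes the layer's activations, and the gradients of the output logit with respect to them, have already been captured (for example with forward and backward hooks). The project's function may differ in normalization, upsampling, or overlay details. `grad_cam_3d` and the random feature maps are illustrative only.

```python
import numpy as np


def grad_cam_3d(activations: np.ndarray, gradients: np.ndarray) -> np.ndarray:
    """Grad-CAM map for one clip from a 3-D convolutional layer.

    activations, gradients: shape (K, T, H, W); the layer's feature maps and the
    gradient of the model's output logit with respect to those feature maps.
    Returns a (T, H, W) map scaled to [0, 1] (all zeros if nothing is positive).
    """
    # One weight per feature channel: its gradient averaged over time and space
    weights = gradients.mean(axis=(1, 2, 3))                            # (K,)
    # Weighted sum over channels, keeping only positive evidence (ReLU)
    cam = np.maximum(np.tensordot(weights, activations, axes=1), 0.0)   # (T, H, W)
    peak = cam.max()
    return cam / peak if peak > 0 else cam


# Made-up feature maps standing in for hooked layer4 outputs
rng = np.random.default_rng(0)
acts = rng.random((8, 2, 7, 7))
grads = rng.standard_normal((8, 2, 7, 7))
cam = grad_cam_3d(acts, grads)
print(cam.shape)  # (2, 7, 7)
```

Before the map can be overlaid on the input clip, it still needs to be upsampled from the layer's resolution to the clip's 16 × 112 × 112.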

## Process Multiple DICOM Files

Now, let's process all DICOM files in the example directory and analyze the results.

In [None]:
# Process all DICOM files
results = {}

for dicom_path in tqdm(dicom_paths, desc="Processing"):
    filename = os.path.basename(dicom_path)
    
    # Process DICOM file
    video = process_dicom(dicom_path)
    
    if video is None:
        print(f"Skipping {filename}: Processing failed")
        continue
    
    # Run inference
    with torch.no_grad():
        video_tensor = video.unsqueeze(0).to(device)  # Add batch dimension
        output = model(video_tensor)
        probability = torch.sigmoid(output).item()
        prediction = 1 if probability >= 0.3 else 0
        status = "PASS" if prediction > 0 else "FAIL"
        assessment = get_quality_issues(probability)
    
    # Store results
    results[filename] = {
        "score": probability,
        "status": status,
        "assessment": assessment
    }

# Display results
print("\nQuality Assessment Results:")
print("=" * 93)
print(f"{'Filename':<60} {'Score':<10} {'Pass/Fail':<10} {'Assessment'}")
print("-" * 93)

for filename, result in results.items():
    # Truncate filenames longer than 60 characters so the columns stay aligned
    short_filename = filename[:57] + "..." if len(filename) > 60 else filename.ljust(60)
    print(f"{short_filename} {result['score']:<10.4f} {result['status']:<10} {result['assessment']}")

# Summary statistics
pass_count = sum(1 for result in results.values() if result["status"] == "PASS")
total_count = len(results)
pass_rate = pass_count/total_count*100 if total_count > 0 else 0

print(f"\nSummary: {pass_count}/{total_count} videos passed quality check ({pass_rate:.1f}%)")
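The PASS/FAIL split depends entirely on the 0.3 cut-off, so it may be informative to recompute the pass rate at a few other thresholds from the stored scores. Below is a standard-library sketch. The scores are made-up placeholders; in the notebook you would pass `[r["score"] for r in results.values()]` instead. `pass_rate` is a hypothetical helper.

```python
def pass_rate(scores, threshold: float = 0.3) -> float:
    """Percentage of scores at or above the threshold (0.0 for an empty list)."""
    if not scores:
        return 0.0
    passed = sum(1 for s in scores if s >= threshold)
    return 100.0 * passed / len(scores)


example_scores = [0.05, 0.22, 0.31, 0.48, 0.90]  # illustrative values only
for t in (0.2, 0.3, 0.5):
    print(f"threshold {t:.1f}: {pass_rate(example_scores, t):.1f}% pass")
```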

## Visualize Results

Let's create some visualizations of the results.

In [None]:
# Extract scores
scores = [result["score"] for result in results.values()]

# Create histogram
plt.figure(figsize=(10, 6))
plt.hist(scores, bins=20, alpha=0.7, color='blue')
plt.axvline(x=0.3, color='red', linestyle='--', label='Threshold (0.3)')
plt.xlabel('Quality Score')
plt.ylabel('Count')
plt.title('Distribution of Echo Quality Scores')
plt.legend()
plt.grid(True, alpha=0.3)
plt.show()

# Create pie chart of pass/fail
plt.figure(figsize=(8, 8))
plt.pie(
    [pass_count, total_count - pass_count], 
    labels=["PASS", "FAIL"], 
    autopct='%1.1f%%',
    colors=['#4CAF50', '#F44336'],
    explode=(0.1, 0)
)
plt.title('Pass/Fail Distribution')
plt.show()

## Save Results

Finally, let's save the results to a JSON file.

In [None]:
# Save results to JSON (create the output directory if needed)
os.makedirs("./results", exist_ok=True)
with open("./results/quality_results.json", "w") as f:
    json.dump(results, f, indent=2)

print("Results saved to ./results/quality_results.json")
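The saved file maps each filename to its `score`, `status` and `assessment`. The sketch below reads it back and lists the failing videos from lowest score upward, for example to queue them for manual review. `load_failures` is a hypothetical helper, and its default path is the one used above.

```python
import json


def load_failures(path: str = "./results/quality_results.json"):
    """Return (filename, score) pairs whose status is FAIL, lowest score first."""
    with open(path) as f:
        results = json.load(f)
    failures = [(name, r["score"]) for name, r in results.items() if r["status"] == "FAIL"]
    return sorted(failures, key=lambda item: item[1])


# for name, score in load_failures():
#     print(f"{score:.4f}  {name}")
```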

## Conclusion

In this notebook, we demonstrated how to use the EchoQuality model to assess the quality of echocardiogram videos. The model processes DICOM files, applies masking to isolate the ultrasound region, and classifies videos as PASS/FAIL with a threshold of 0.3.

The model can flag low-quality echocardiogram videos automatically, so they can be excluded from downstream analysis or reviewed before it. This may reduce the time spent on manual quality control.