# MAISI 3D Diffusion Model - Parallel Training Tutorial

## 🚀 Complete Guide to Multi-GPU Medical Image Generation

This tutorial provides a step-by-step guide to set up, configure, and run the MAISI parallel training script for generating 3D medical images. We'll cover everything from environment setup to troubleshooting common errors.

### What You'll Learn:
- ✅ **Environment Setup**: Install dependencies and verify GPU availability
- ✅ **Data Preparation**: Properly configure medical imaging data paths  
- ✅ **Parallel Training**: Run multi-GPU training with torchrun
- ✅ **Error Resolution**: Fix common issues like missing data paths
- ✅ **Output Analysis**: Understand generated images and file structure

### Prerequisites:
- NVIDIA GPU(s) with CUDA support
- Python 3.8+ environment
- Access to medical imaging data (.nii.gz files) or willingness to use simulated data

---

## 🎯 Error Analysis from Your Previous Run

Your command failed because you used the placeholder path `/path/to/medical/data`. Let's fix this step by step!

# 1. Check Environment and Install Dependencies

Let's start by verifying your Python environment and installing the required packages for MAISI training.

In [None]:
# Report the Python environment, then pip-install the packages used by this
# notebook into the kernel's own interpreter (sys.executable -m pip).
import sys
import subprocess
import os
from pathlib import Path

print("🔍 ENVIRONMENT CHECK:")
print("=" * 50)
print(f"Python version: {sys.version}")
print(f"Python executable: {sys.executable}")
print(f"Current working directory: {os.getcwd()}")

# A virtual environment is active when sys.prefix differs from the base
# interpreter prefix (``real_prefix`` is set by legacy virtualenv).
if hasattr(sys, 'real_prefix') or (hasattr(sys, 'base_prefix') and sys.base_prefix != sys.prefix):
    print("✅ Running in virtual environment")
else:
    print("⚠️  Not in virtual environment - consider using one")

print("\n📦 INSTALLING REQUIRED PACKAGES:")
print("=" * 50)

# Install required packages. ``psutil`` is included because the system check
# and the troubleshooting diagnostics later in the notebook import it.
packages = [
    "torch",
    "torchvision",
    "monai-weekly[pillow,tqdm]",
    "nibabel",
    "numpy",
    "scipy",
    "psutil",
]

failed_packages = []
for package in packages:
    try:
        print(f"Installing {package}...")
        subprocess.check_call([sys.executable, "-m", "pip", "install", "-q", package])
        print(f"✅ {package} installed successfully")
    except subprocess.CalledProcessError as e:
        print(f"❌ Failed to install {package}: {e}")
        failed_packages.append(package)

# Only announce success when every package actually installed.
if failed_packages:
    print(f"\n⚠️  Installation finished with errors - failed: {', '.join(failed_packages)}")
else:
    print("\n🎉 Installation complete!")
print("Next: Let's check GPU availability...")

In [None]:
# Report CUDA availability, list each GPU with its memory, smoke-test GPU 0,
# and print basic CPU/RAM information (psutil is optional here).
print("🖥️  GPU AVAILABILITY CHECK:")
print("=" * 50)

try:
    import torch

    if torch.cuda.is_available():
        gpu_count = torch.cuda.device_count()
        print(f"✅ CUDA is available! Found {gpu_count} GPU(s)")

        for i in range(gpu_count):
            gpu_name = torch.cuda.get_device_name(i)
            gpu_memory = torch.cuda.get_device_properties(i).total_memory / 1024**3
            print(f"   GPU {i}: {gpu_name} ({gpu_memory:.1f} GB)")

        # Test GPU functionality. CUDA kernels launch asynchronously, so
        # synchronize before declaring success; otherwise a failing kernel
        # would only raise later, in some unrelated cell.
        print("\n🧪 Testing GPU 0...")
        try:
            x = torch.randn(1000, 1000, device="cuda:0")
            y = torch.matmul(x, x.T)
            torch.cuda.synchronize(0)
            print("✅ GPU computation test passed!")
            del x, y
            torch.cuda.empty_cache()  # hand cached blocks back to the driver
        except RuntimeError as e:
            # e.g. a PyTorch build without kernels for this GPU architecture
            print(f"❌ GPU computation test failed: {e}")

    else:
        print("❌ CUDA not available. Training will be slow on CPU.")
        print("   Make sure you have:")
        print("   • NVIDIA GPU installed")
        print("   • CUDA drivers installed")
        print("   • PyTorch with CUDA support")

except ImportError:
    print("❌ PyTorch not available. Please install it first.")

print(f"\n📊 SYSTEM RESOURCES:")
print("=" * 50)
try:
    import psutil
    print(f"CPU cores: {psutil.cpu_count()}")
    print(f"RAM: {psutil.virtual_memory().total / 1024**3:.1f} GB")
    print(f"Available RAM: {psutil.virtual_memory().available / 1024**3:.1f} GB")
except ImportError:
    print("psutil not available for system info")

# 2. Verify Data Directory and File Presence

**This is where your previous command failed!** The path `/path/to/medical/data` was just a placeholder. Let's set up the correct data path for your system.

In [None]:
# Locate a directory of .nii.gz volumes: check the user-set REAL_DATA_PATH
# first, then a list of common locations. Sets REAL_DATA_PATH, data_files,
# found_data and USE_SIMULATED_DATA for the cells below.
import glob
import os
from pathlib import Path

print("📁 DATA DIRECTORY CONFIGURATION:")
print("=" * 50)

# Option 1: Use real medical data (if you have it)
# CHANGE THIS PATH to your actual medical data directory!
REAL_DATA_PATH = "/home/santino/medical_data"  # ⬅️ UPDATE THIS PATH!

# Option 2: Common paths where medical data might be stored
common_paths = [
    "/home/santino/medical_data",
    "/home/santino/data/medical",
    "/data/medical",
    "/mnt/medical_data",
    "/shared/medical_data",
    "~/Documents/medical_data",
    "~/Desktop/medical_data"
]

print("🔍 Searching for medical data in common locations...")

# Search REAL_DATA_PATH first so that editing it above actually takes effect,
# then the common locations; duplicates are skipped so no directory is
# reported twice.
search_paths = []
for candidate in [REAL_DATA_PATH, *common_paths]:
    expanded = os.path.expanduser(candidate)
    if expanded not in search_paths:
        search_paths.append(expanded)

found_data = False
data_files = []

for expanded_path in search_paths:
    if os.path.isdir(expanded_path):
        # Sorted so the listing (and the "first 5" preview) is deterministic.
        nii_files = sorted(glob.glob(os.path.join(expanded_path, "*.nii.gz")))
        if nii_files:
            print(f"✅ Found {len(nii_files)} .nii.gz files in: {expanded_path}")
            REAL_DATA_PATH = expanded_path
            data_files = nii_files
            found_data = True
            break
        else:
            print(f"📂 Directory exists but no .nii.gz files: {expanded_path}")
    else:
        print(f"❌ Directory not found: {expanded_path}")

if not found_data:
    print(f"\n⚠️  NO MEDICAL DATA FOUND!")
    print(f"Please either:")
    print(f"1. Place your .nii.gz files in one of these directories:")
    for path in common_paths[:3]:
        print(f"   • {path}")
    print(f"2. Or update REAL_DATA_PATH variable above with your actual path")
    print(f"3. Or use simulated data for testing (shown below)")

    # Set flag to use simulated data
    USE_SIMULATED_DATA = True
    print(f"\n🎭 Will use SIMULATED DATA for this demo")
else:
    USE_SIMULATED_DATA = False
    print(f"\n✅ Will use REAL MEDICAL DATA from: {REAL_DATA_PATH}")
    print(f"📊 Found {len(data_files)} medical images:")
    for i, file in enumerate(data_files[:5]):  # Show first 5
        print(f"   {i+1}. {os.path.basename(file)}")
    if len(data_files) > 5:
        print(f"   ... and {len(data_files) - 5} more files")

In [None]:
# Create a test data directory if no real data found: write three small random
# uint8 NIfTI volumes and point REAL_DATA_PATH / data_files at them.
import glob
import os

if USE_SIMULATED_DATA:
    print("\n🏗️  CREATING TEST DATA DIRECTORY:")
    print("=" * 50)

    test_data_dir = "/home/santino/UNC_RAD/test_medical_data"
    os.makedirs(test_data_dir, exist_ok=True)

    # Create some dummy .nii.gz files for testing
    try:
        import nibabel as nib
        import numpy as np

        print(f"📁 Creating test directory: {test_data_dir}")

        # Fixed seed so re-running the cell writes identical test volumes.
        rng = np.random.default_rng(42)

        # Create 3 dummy medical images
        for i in range(3):
            # Random 3D volume over the full uint8 range; the upper bound of
            # ``integers`` is exclusive, hence 256.
            img_data = rng.integers(0, 256, size=(64, 64, 32), dtype=np.uint8)

            # Identity affine: 1 mm isotropic voxels at the origin
            nii_img = nib.Nifti1Image(img_data, affine=np.eye(4))

            # Save as .nii.gz file
            filename = f"test_patient_{i+1:03d}.nii.gz"
            filepath = os.path.join(test_data_dir, filename)
            nib.save(nii_img, filepath)

            print(f"✅ Created: {filename}")

        # Update paths for testing
        REAL_DATA_PATH = test_data_dir
        data_files = sorted(glob.glob(os.path.join(test_data_dir, "*.nii.gz")))

        print(f"\n🎉 Test data ready! Created {len(data_files)} files in {test_data_dir}")
        print("   You can now run the training with this test data.")

    except ImportError:
        print("❌ nibabel not available. Cannot create test data.")
        print("   Please install nibabel or provide real medical data.")

print(f"\n📋 FINAL DATA CONFIGURATION:")
print("=" * 50)
print(f"Data path: {REAL_DATA_PATH}")
print(f"Using simulated data: {USE_SIMULATED_DATA}")
print(f"Number of files: {len(data_files) if 'data_files' in locals() else 0}")

# 3. Configure Parallel Training Parameters

Now let's configure all the parameters for running the parallel training script. We'll build the exact torchrun command that will work with your system.

In [None]:
# Build CONFIG and the torchrun command for the MAISI parallel training script
# from the detected GPU count and the data settings of the cells above
# (REAL_DATA_PATH, USE_SIMULATED_DATA).
print("⚙️  TRAINING CONFIGURATION:")
print("=" * 50)

# Check available GPUs again to set nproc_per_node
try:
    import torch
    if torch.cuda.is_available():
        available_gpus = torch.cuda.device_count()
        print(f"Available GPUs: {available_gpus}")
    else:
        available_gpus = 1  # Will use CPU
        print("No GPUs available - will use CPU (very slow)")
except (ImportError, RuntimeError) as e:
    # Fall back to one process instead of aborting the configuration step.
    available_gpus = 1
    print(f"Could not query GPUs ({e}) - defaulting to 1 process")

# Configure training parameters
CONFIG = {
    # Torchrun parameters - Use actual GPU count, not hardcoded 4
    "nproc_per_node": available_gpus,  # Use all available GPUs
    "nnodes": 1,  # Single node for this tutorial
    "master_port": 29500,  # torchrun's default port; change it if another job already holds 29500

    # MAISI script parameters
    "data_path": REAL_DATA_PATH,
    "use_real_data": not USE_SIMULATED_DATA,
    "epochs": 5 if USE_SIMULATED_DATA else 25,  # Short for demo
    "batch_size": 1,  # Safe default
    "model_version": "maisi3d-rflow",  # Faster version
    "base_seed": 42,

    # Script location
    "script_path": "maisi_train_diff_unet_parallel.py"
}

print(f"Configuration:")
for key, value in CONFIG.items():
    print(f"  {key}: {value}")

# Build the torchrun command
torchrun_cmd = [
    "torchrun",
    f"--nproc_per_node={CONFIG['nproc_per_node']}",
    f"--nnodes={CONFIG['nnodes']}",
    f"--master-port={CONFIG['master_port']}",  # Add explicit port
    CONFIG["script_path"]
]

# The data path is only passed in real-data mode; without --real-data the
# script is run on its own simulated data.
if CONFIG["use_real_data"]:
    torchrun_cmd.extend(["--real-data", "--data-path", CONFIG["data_path"]])

torchrun_cmd.extend([
    "--epochs", str(CONFIG["epochs"]),
    "--batch-size", str(CONFIG["batch_size"]),
    "--model-version", CONFIG["model_version"],
    "--base-seed", str(CONFIG["base_seed"])
])

print(f"\n🚀 TORCHRUN COMMAND:")
print("=" * 50)
command_str = " ".join(torchrun_cmd)
print(command_str)

print(f"\n💡 EXPECTED BEHAVIOR:")
print("=" * 50)
print(f"• Will use {CONFIG['nproc_per_node']} GPU(s)")
print(f"• Will generate {CONFIG['nproc_per_node']} unique medical images")
print(f"• Training will run for {CONFIG['epochs']} epochs")
print(f"• Each GPU gets batch_size={CONFIG['batch_size']} images")
print(f"• Using {'real' if CONFIG['use_real_data'] else 'simulated'} data")
print(f"• Using port {CONFIG['master_port']} to avoid conflicts")

# 4. Run MAISI Parallel Training Script

**⚠️ IMPORTANT**: Make sure the `maisi_train_diff_unet_parallel.py` script is in your current directory before running this cell!

This cell will execute the training and capture all output in real-time.

In [None]:
# Launch the torchrun command built in the configuration cell, stream its
# combined stdout/stderr into the notebook, and keep the full log in
# ``training_output`` for the analysis cell.
import subprocess
import time
import os
from datetime import datetime

print("🚀 STARTING MAISI PARALLEL TRAINING:")
print("=" * 60)

# Drop any log left over from an earlier run so the analysis cell only sees
# output produced by this execution.
globals().pop("training_output", None)

# Check if script exists
script_path = CONFIG["script_path"]
if not os.path.exists(script_path):
    print(f"❌ ERROR: Script not found: {script_path}")
    print(f"   Please make sure {script_path} is in your current directory:")
    print(f"   Current directory: {os.getcwd()}")
    print(f"   Files in directory: {os.listdir('.')}")
    print(f"\n   You can download it from the MONAI repository or copy it to this location.")
else:
    print(f"✅ Found script: {script_path}")

    # Show the command we're about to run
    print(f"\n📋 Command to execute:")
    print(" ".join(torchrun_cmd))

    print(f"\n⏰ Training started at: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
    print("=" * 60)

    process = None
    output_lines = []
    try:
        # Set environment variable to avoid OMP warnings
        env = os.environ.copy()
        env["OMP_NUM_THREADS"] = "1"

        process = subprocess.Popen(
            torchrun_cmd,
            stdout=subprocess.PIPE,
            stderr=subprocess.STDOUT,  # Combine stderr with stdout
            text=True,
            bufsize=1,  # Line buffered
            env=env
        )

        # Read output in real-time
        for line in process.stdout:
            print(line, end='')  # Print immediately
            output_lines.append(line)

        # Wait for process to complete
        return_code = process.wait()

        print("=" * 60)
        print(f"⏰ Training completed at: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")

        if return_code == 0:
            print("🎉 SUCCESS: Training completed successfully!")
        else:
            print(f"❌ FAILED: Training failed with return code {return_code}")

        # Store output for analysis
        training_output = "".join(output_lines)

    except FileNotFoundError as e:
        print("❌ ERROR: torchrun not found. Please install PyTorch properly.")
        # Keep the error text so the analysis cell can still diagnose it.
        training_output = f"FileNotFoundError: {e}"
    except KeyboardInterrupt:
        # Interrupting the cell does not stop the child process: terminate
        # torchrun explicitly so its workers release the GPUs and master port.
        print("\n⛔ Interrupted - stopping torchrun...")
        if process is not None and process.poll() is None:
            process.terminate()
            try:
                process.wait(timeout=30)
            except subprocess.TimeoutExpired:
                process.kill()
                process.wait()
        training_output = "".join(output_lines)
    except Exception as e:
        print(f"❌ ERROR: Unexpected error during training: {e}")
        training_output = str(e)
    finally:
        if process is not None and process.stdout is not None:
            process.stdout.close()

# 5. Monitor Output and Handle Errors

Let's analyze the training output and provide guidance if any errors occurred.

In [None]:
# Scan ``training_output`` (set by the training cell) for known error and
# success phrases, print troubleshooting hints, and echo configuration lines.
print("🔍 OUTPUT ANALYSIS:")
print("=" * 50)

if 'training_output' in locals():
    # Lower-case once; every pattern check below is case-insensitive.
    output_lower = training_output.lower()

    # Check for common error patterns
    error_patterns = {
        "No .nii.gz files found": "Data path issue - no medical images found",
        "CUDA out of memory": "GPU memory insufficient - reduce batch size",
        "ConnectionTimeout": "Multi-node communication issue",
        "FileNotFoundError": "Missing file or script",
        "ImportError": "Missing Python package",
        # Python 3 reports missing modules as ModuleNotFoundError, whose text
        # does not contain the substring "ImportError".
        "ModuleNotFoundError": "Missing Python package",
        "Permission denied": "File permission issue",
        "Address already in use": "Port conflict - the master port is already taken"
    }

    errors_found = [(pattern, description)
                    for pattern, description in error_patterns.items()
                    if pattern.lower() in output_lower]
    found_patterns = {pattern for pattern, _ in errors_found}

    if errors_found:
        print("❌ ERRORS DETECTED:")
        for pattern, description in errors_found:
            print(f"   • {pattern}: {description}")

        print(f"\n🔧 TROUBLESHOOTING GUIDE:")

        if "No .nii.gz files found" in found_patterns:
            print(f"📁 Data Path Issue:")
            print(f"   • Check that your data path exists: {CONFIG['data_path']}")
            print(f"   • Verify .nii.gz files are present: ls {CONFIG['data_path']}/*.nii.gz")
            print(f"   • Update REAL_DATA_PATH variable in cell 2")
            print(f"   • Or use simulated data for testing")

        if "CUDA out of memory" in found_patterns:
            print(f"💾 Memory Issue:")
            print(f"   • Reduce batch_size to 1: --batch-size 1")
            print(f"   • Use fewer GPUs: --nproc_per_node 1")
            print(f"   • Close other GPU applications")

        if "Address already in use" in found_patterns:
            print(f"🔌 Port Conflict:")
            print(f"   • The master port is already in use (often by a leftover torchrun run)")
            print(f"   • Find leftover runs with: pgrep -af torchrun")
            print(f"   • Or pick another free port: --master-port=<port>")

    else:
        print("✅ No obvious errors detected in output")

    # Phrases that indicate finished work (assumed to match the training
    # script's log messages - verify against the script).
    success_patterns = [
        "Training completed successfully",
        "inference completed",
        "Generated images saved",
    ]
    # The configuration banner only shows that the script started; it is not
    # evidence of success, so it is reported separately.
    startup_pattern = "MAISI PARALLEL COMPUTING CONFIGURATION"

    matched_success = [p for p in success_patterns if p.lower() in output_lower]
    if matched_success:
        print("🎉 SUCCESS INDICATORS FOUND:")
        for pattern in matched_success:
            print(f"   ✅ {pattern}")
    if startup_pattern.lower() in output_lower:
        print(f"ℹ️  Script started (found '{startup_pattern}' banner)")

    # Extract key information from output
    lines = training_output.split('\n')
    config_keywords = ['gpus', 'epochs', 'batch size', 'data path', 'model version']
    config_lines = [line for line in lines
                    if '•' in line and any(k in line.lower() for k in config_keywords)]

    if config_lines:
        print(f"\n📊 EXTRACTED CONFIGURATION:")
        for line in config_lines[:10]:  # Show first 10 config lines
            print(f"   {line.strip()}")

else:
    print("⚠️  No training output available to analyze")
    print("   Make sure to run the training cell above first")

# 6. Inspect Generated Images and Output Structure

After successful training, let's examine the output directory structure and generated medical images.

In [None]:
# Inspect ./output_work_dir after training: print its directory tree, list the
# generated .nii.gz volumes and .pt checkpoints with sizes, and report the
# total output size.
import os
import glob
from pathlib import Path


def is_generated_image(rel_path):
    """Return True if an output file looks like a final generated image.

    *rel_path* must be relative to the output directory. That directory is
    itself named ``output_work_dir``, so testing the full path for "output"
    would match every file in it.
    """
    return 'output' in rel_path or 'generated' in rel_path


def parse_seed_rank(filename):
    """Return ``(seed, rank)`` strings from a name like ``x_seed42_rank1.nii.gz``.

    Returns None when the name lacks a ``seed``- or ``rank``-prefixed token.
    """
    parts = filename.replace('.nii.gz', '').split('_')
    seeds = [p[len('seed'):] for p in parts if p.startswith('seed')]
    ranks = [p[len('rank'):] for p in parts if p.startswith('rank')]
    if seeds and ranks:
        return seeds[0], ranks[0]
    return None


print("📂 OUTPUT DIRECTORY INSPECTION:")
print("=" * 50)

# Check for the typical output directory
output_dir = "./output_work_dir"

if os.path.exists(output_dir):
    print(f"✅ Found output directory: {output_dir}")

    # Show directory structure
    print(f"\n📁 Directory Structure:")
    for root, dirs, files in os.walk(output_dir):
        rel_root = os.path.relpath(root, output_dir)
        level = 0 if rel_root == os.curdir else rel_root.count(os.sep) + 1
        indent = ' ' * 2 * level
        print(f"{indent}{os.path.basename(root)}/")
        subindent = ' ' * 2 * (level + 1)
        for file in files[:5]:  # Show first 5 files per directory
            print(f"{subindent}{file}")
        if len(files) > 5:
            print(f"{subindent}... and {len(files) - 5} more files")

    # Look for generated images
    generated_images = glob.glob(os.path.join(output_dir, "**", "*.nii.gz"), recursive=True)

    if generated_images:
        print(f"\n🖼️  GENERATED MEDICAL IMAGES:")
        print("=" * 50)
        print(f"Found {len(generated_images)} generated images:")

        for img_path in generated_images:
            rel_path = os.path.relpath(img_path, output_dir)
            file_size = os.path.getsize(img_path) / 1024**2  # MB
            print(f"   📄 {rel_path} ({file_size:.1f} MB)")

        # Analyze naming pattern (relative paths, see is_generated_image)
        output_images = [img for img in generated_images
                         if is_generated_image(os.path.relpath(img, output_dir))]
        if output_images:
            print(f"\n🎯 Generated Images (Final Results):")
            for img in output_images:
                filename = os.path.basename(img)
                print(f"   🖼️  {filename}")

                # Try to extract seed and rank from filename
                seed_rank = parse_seed_rank(filename)
                if seed_rank is not None:
                    seed, rank = seed_rank
                    print(f"      • GPU {rank} generated this with seed {seed}")

        print(f"\n💡 VIEWING YOUR RESULTS:")
        print("=" * 50)
        print("You can view these medical images with:")
        print("• 3D Slicer (recommended): https://www.slicer.org/")
        print("• ITK-SNAP: http://www.itksnap.org/")
        print("• FSL tools: fsleyes your_image.nii.gz")
        print("• Python/MONAI: nibabel.load('your_image.nii.gz')")

    else:
        print("❌ No generated .nii.gz images found")
        print("   Training may not have completed successfully")

    # Check for model checkpoints
    model_files = glob.glob(os.path.join(output_dir, "**", "*.pt"), recursive=True)
    if model_files:
        print(f"\n🤖 MODEL CHECKPOINTS:")
        print("=" * 50)
        for model_path in model_files:
            rel_path = os.path.relpath(model_path, output_dir)
            file_size = os.path.getsize(model_path) / 1024**2  # MB
            print(f"   💾 {rel_path} ({file_size:.1f} MB)")

else:
    print(f"❌ Output directory not found: {output_dir}")
    print("   Training may not have started or failed early")
    print("   Check the training output above for errors")

print(f"\n📊 SUMMARY:")
print("=" * 50)
if os.path.exists(output_dir):
    total_size = sum(os.path.getsize(os.path.join(dirpath, filename))
                     for dirpath, dirnames, filenames in os.walk(output_dir)
                     for filename in filenames) / 1024**2
    print(f"• Total output size: {total_size:.1f} MB")
    print(f"• Output location: {os.path.abspath(output_dir)}")
else:
    print("• No output generated yet")
    print("• Run the training script successfully first")

# 🎉 Tutorial Complete!

## What We Accomplished

✅ **Environment Setup**: Verified Python, installed dependencies, checked GPU availability  
✅ **Data Configuration**: Set up proper data paths and created test data if needed  
✅ **Parameter Configuration**: Built the correct torchrun command with all arguments  
✅ **Training Execution**: Ran the parallel training script with proper error handling  
✅ **Output Analysis**: Examined generated images and directory structure  

## Your Error Resolution

The original error `ValueError: No .nii.gz files found in /path/to/medical/data` occurred because:
- `/path/to/medical/data` was just a placeholder path
- The script couldn't find any medical imaging files

**We fixed this by:**
1. Automatically searching common data directories
2. Creating test data if no real data was found
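3. Passing `--real-data --data-path <dir>` in the torchrun command when real data was found. When no real data was found, the command omits these flags and the script uses its own simulated data.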

## Next Steps

### For Better Results:
1. **Use Real Medical Data**: Replace test data with actual CT/MRI scans
2. **Increase Training**: Use 100+ epochs for production-quality results
3. **Scale Up**: Use more GPUs and larger datasets
4. **Optimize**: Tune batch size and model parameters

### Command Templates:
```bash
# Quick test (what we just ran)
torchrun --nproc_per_node=4 maisi_train_diff_unet_parallel.py --real-data --data-path /your/data/path --epochs 5

# Production training
torchrun --nproc_per_node=8 maisi_train_diff_unet_parallel.py --real-data --data-path /your/data/path --epochs 100 --batch-size 2

# Multi-node cluster (run this on node 0; run the same command with --node_rank=1 on the second node)
torchrun --nproc_per_node=4 --nnodes=2 --node_rank=0 --master_addr=192.168.1.100 --master_port=29500 maisi_train_diff_unet_parallel.py --real-data --data-path /shared/data
```

### Viewing Results:
- **3D Slicer**: Best for medical imaging visualization
- **ITK-SNAP**: Alternative medical image viewer  
- **Python**: Use `nibabel.load()` for programmatic access

**Happy Training! 🚀**

# 🚨 Troubleshooting: Parallel Training Errors

The output of your previous run shows two main issues that must be fixed before parallel training can succeed.

In [None]:
# Diagnose the failures seen in the previous torchrun run (more ranks than
# GPUs, a master-port clash, leftover torchrun processes) and print a
# corrected command. Sets recommended_gpus / recommended_port.
import subprocess
import socket
import time

print("🔍 DIAGNOSING PARALLEL TRAINING ISSUES:")
print("=" * 60)

# Issue 1: Check actual GPU count
print("1️⃣ GPU COUNT CHECK:")
try:
    import torch
    if torch.cuda.is_available():
        actual_gpu_count = torch.cuda.device_count()
        print(f"   ✅ Actual GPUs available: {actual_gpu_count}")
        for i in range(actual_gpu_count):
            gpu_name = torch.cuda.get_device_name(i)
            print(f"      GPU {i}: {gpu_name}")
    else:
        actual_gpu_count = 0
        print("   ❌ No CUDA GPUs available")
except ImportError:
    actual_gpu_count = 0
    print("   ❌ PyTorch not available")

# Issue 2: Check port usage
print(f"\n2️⃣ PORT CONFLICT CHECK:")


def check_port(port):
    """Return True if something is already listening on ``localhost:port``.

    A successful TCP connect (``connect_ex`` returns 0) means the port is
    taken; any socket error is treated as "not in use".
    """
    try:
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
            s.settimeout(1.0)  # never hang the cell on an unresponsive port
            return s.connect_ex(('localhost', port)) == 0
    except OSError:
        return False


def find_free_port():
    """Return a TCP port on localhost that the OS reports as currently unused."""
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.bind(('localhost', 0))  # port 0 asks the OS for any free port
        return s.getsockname()[1]


ports_to_check = [12355, 29500, 29501, 23456, 25000]
available_port = None

for port in ports_to_check:
    if check_port(port):
        print(f"   ❌ Port {port}: IN USE")
    else:
        print(f"   ✅ Port {port}: Available")
        if available_port is None:
            available_port = port

if available_port is None:
    # Every candidate is busy: fall back to an OS-assigned port so the
    # recommended command below never contains "None".
    available_port = find_free_port()
    print(f"   ⚠️  All checked ports are in use - using free port {available_port}")

# Issue 3: Check for running torchrun processes
print(f"\n3️⃣ RUNNING PROCESSES CHECK:")
try:
    # Imported here so a missing psutil is reported instead of crashing the cell.
    import psutil

    running_torchrun = []
    for proc in psutil.process_iter(['pid', 'name', 'cmdline']):
        try:
            if 'torchrun' in ' '.join(proc.info['cmdline'] or []):
                running_torchrun.append(proc.info)
        except (psutil.NoSuchProcess, psutil.AccessDenied):
            pass

    if running_torchrun:
        print(f"   ⚠️  Found {len(running_torchrun)} running torchrun processes:")
        for proc in running_torchrun:
            print(f"      PID {proc['pid']}: {' '.join(proc['cmdline'][:3])}")
        print(f"   💡 These may be causing port conflicts")
    else:
        print(f"   ✅ No conflicting torchrun processes found")

except ImportError:
    print(f"   ⚠️  psutil not available for process check")

print(f"\n🔧 RECOMMENDED FIXES:")
print("=" * 60)

if actual_gpu_count < 4:
    # torchrun needs at least one process; with no GPUs fall back to one.
    recommended_gpus = max(actual_gpu_count, 1)
    print(f"❗ GPU COUNT MISMATCH:")
    print(f"   • You have {actual_gpu_count} GPUs but script tried to use 4")
    print(f"   • SOLUTION: Use --nproc_per_node={recommended_gpus} instead of 4")
else:
    recommended_gpus = 4

# 12355 is checked first, so any other result means 12355 is taken.
if available_port != 12355:
    print(f"❗ PORT CONFLICT:")
    print(f"   • Port 12355 is in use")
    print(f"   • SOLUTION: Use --master-port {available_port}")
    recommended_port = available_port
else:
    recommended_port = 12355

print(f"\n🚀 CORRECTED COMMAND:")
print("=" * 60)
corrected_command = f"torchrun --nproc_per_node={recommended_gpus} --master-port {recommended_port} maisi_train_diff_unet_parallel.py"
print(corrected_command)

print(f"\n💡 ADDITIONAL TIPS:")
print("• Kill existing processes: pkill -f torchrun")
print("• Wait a few seconds between runs for ports to be released")
print("• Use different terminal/session if issues persist")
print("• Monitor with: nvidia-smi (for GPUs) and netstat -tulpn | grep :12355 (for ports)")

In [None]:
# Clean up any existing processes and run with correct parameters: stop
# leftover torchrun jobs, then launch the training script with the
# nproc_per_node / master port recommended by the diagnosis cell.
import os
import subprocess
import time

print("🧹 CLEANUP AND CORRECTED EXECUTION:")
print("=" * 60)

# Step 1: Kill any existing torchrun processes.
# NOTE: ``pkill -f torchrun`` matches every process of this user whose command
# line contains "torchrun", including unrelated jobs.
print("1️⃣ Cleaning up existing processes...")
try:
    result = subprocess.run(["pkill", "-f", "torchrun"], capture_output=True)
    # pkill exits 0 when it signalled at least one process, 1 when none matched.
    if result.returncode == 0:
        print("   ✅ Killed existing torchrun processes")
        time.sleep(3)  # Wait for cleanup so the ports are released
    elif result.returncode == 1:
        print("   ✅ No existing torchrun processes found")
    else:
        print(f"   ⚠️  pkill failed (exit code {result.returncode})")
except FileNotFoundError:
    print("   ⚠️  pkill not available - could not clean up processes")

# Step 2: Build corrected command (needs a non-zero GPU count and a real port)
if globals().get('recommended_gpus') and globals().get('recommended_port'):
    corrected_cmd = [
        "torchrun",
        f"--nproc_per_node={recommended_gpus}",
        f"--master-port={recommended_port}",
        "maisi_train_diff_unet_parallel.py"
    ]

    print(f"2️⃣ Running corrected command:")
    print(f"   {' '.join(corrected_cmd)}")

    # Step 3: Execute the corrected command
    process = None
    output_lines = []
    try:
        env = os.environ.copy()
        env["OMP_NUM_THREADS"] = "1"

        print(f"\n⏰ Started at: {time.strftime('%Y-%m-%d %H:%M:%S')}")
        print("=" * 60)

        process = subprocess.Popen(
            corrected_cmd,
            stdout=subprocess.PIPE,
            stderr=subprocess.STDOUT,
            text=True,
            bufsize=1,
            env=env
        )

        # Monitor output
        for line in process.stdout:
            print(line, end='')
            output_lines.append(line)

        return_code = process.wait()

        print("=" * 60)
        print(f"⏰ Completed at: {time.strftime('%Y-%m-%d %H:%M:%S')}")

        if return_code == 0:
            print("🎉 SUCCESS: Parallel training completed!")
        else:
            print(f"❌ FAILED: Return code {return_code}")

        training_output_corrected = "".join(output_lines)

    except KeyboardInterrupt:
        # Stop torchrun so its workers release the GPUs and the master port.
        print("\n⛔ Interrupted - stopping torchrun...")
        if process is not None and process.poll() is None:
            process.terminate()
            try:
                process.wait(timeout=30)
            except subprocess.TimeoutExpired:
                process.kill()
                process.wait()
        training_output_corrected = "".join(output_lines)
    except Exception as e:
        print(f"❌ ERROR: {e}")
        training_output_corrected = str(e)
    finally:
        if process is not None and process.stdout is not None:
            process.stdout.close()

else:
    print("❌ Could not determine correct parameters. Run diagnosis cell above first.")

# 🎯 Quick Fix - Manual Commands

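If you want to fix this immediately in your terminal, here are the exact commands to run. Note that `pkill -f torchrun` stops *every* torchrun process you own, including unrelated jobs. Check what is running with `pgrep -af torchrun` first.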

## Problem Analysis:
1. **GPU Mismatch**: You have 2 GPUs but the script tried to use 4
2. **Port Conflict**: Multiple processes trying to use port 12355

## Immediate Solutions:

### Option 1: Use 2 GPUs (Recommended)
```bash
# Kill any existing processes
pkill -f torchrun

# Wait a moment
sleep 3

# Run with correct GPU count
torchrun --nproc_per_node=2 maisi_train_diff_unet_parallel.py
```

### Option 2: Use different port
```bash
# Kill existing processes  
pkill -f torchrun

# Run with a different port (29500 is torchrun's default; if it is busy, pick another free port such as 29501)
torchrun --nproc_per_node=2 --master-port=29500 maisi_train_diff_unet_parallel.py
```

### Option 3: Single GPU (Most Reliable)
```bash
# Use single GPU to avoid distributed issues
python maisi_train_diff_unet.py
```

## Why This Happened:
- Your system has 2 RTX 4090 GPUs
- The parallel script defaulted to 4 GPUs  
- Multiple torchrun processes created port conflicts
- Ranks 2 and 3 tried to use GPUs `cuda:2` and `cuda:3`, which don't exist on your system

**Recommendation**: Use Option 1 with 2 GPUs for optimal performance!

In [None]:
# ✅ AUTOMATIC FIX: Run with Your 2 GPUs
# Stops leftover torchrun jobs, confirms at least two GPUs are visible, and
# launches the training script with exactly two ranks.
print("🎯 RUNNING MAISI PARALLEL TRAINING WITH 2 GPUS:")
print("=" * 60)

import subprocess
import time
import os

NUM_GPUS = 2  # ranks to launch; must not exceed the GPUs visible to torchrun

# Step 1: Clean up any existing processes.
# NOTE: this stops every torchrun process of this user, not just this notebook's.
print("1️⃣ Cleaning up existing processes...")
try:
    result = subprocess.run(["pkill", "-f", "torchrun"], capture_output=True, timeout=10)
    # pkill exits 0 when it signalled at least one process, 1 when none matched.
    if result.returncode == 0:
        print("   ✅ Cleaned up existing torchrun processes")
        time.sleep(5)  # Give more time for cleanup
    else:
        print("   ⚠️  No existing processes to clean (this is fine)")
except (FileNotFoundError, subprocess.TimeoutExpired):
    print("   ⚠️  Could not run pkill - skipping cleanup")

# Step 2: Verify your GPUs (import inside the try so a missing torch is handled)
try:
    import torch
    if torch.cuda.is_available():
        gpu_count = torch.cuda.device_count()
        print(f"\n2️⃣ Confirmed: You have {gpu_count} GPUs available")
        for i in range(gpu_count):
            gpu_name = torch.cuda.get_device_name(i)
            gpu_memory = torch.cuda.get_device_properties(i).total_memory / 1024**3
            print(f"   GPU {i}: {gpu_name} ({gpu_memory:.1f} GB)")
    else:
        print("❌ No GPUs detected")
        gpu_count = 0
except ImportError:
    print("❌ PyTorch not available")
    gpu_count = 0

if gpu_count >= NUM_GPUS:
    # Step 3: Run with exactly 2 GPUs. 29500 is torchrun's default port and
    # differs from the 12355 that clashed before. A separate name keeps the
    # ``available_port`` computed by the diagnosis cell intact.
    master_port = 29500

    corrected_cmd = [
        "torchrun",
        f"--nproc_per_node={NUM_GPUS}",  # one rank per visible GPU
        "--nnodes=1",
        f"--master-port={master_port}",
        "maisi_train_diff_unet_parallel.py",
        "--epochs", "5",  # Short run for testing
        "--batch-size", "1",
        "--model-version", "maisi3d-rflow",
        "--base-seed", "42"
    ]

    print(f"\n3️⃣ Running optimized command for your 2-GPU system:")
    print(f"   {' '.join(corrected_cmd)}")
    print(f"\n💡 This will:")
    print(f"   • Use both of your RTX 4090 GPUs")
    print(f"   • Generate 2 unique medical images simultaneously")
    print(f"   • Use simulated data (quick demo)")
    print(f"   • Avoid port conflicts with port {master_port}")

    print(f"\n⏰ Training started at: {time.strftime('%Y-%m-%d %H:%M:%S')}")
    print("🚀 " + "=" * 58)

    process = None
    output_lines = []
    try:
        # Set up environment
        env = os.environ.copy()
        env["OMP_NUM_THREADS"] = "1"
        # Expose GPUs 0 and 1 unless the user already chose devices; overriding
        # an existing selection would silently switch physical GPUs.
        env.setdefault("CUDA_VISIBLE_DEVICES", "0,1")

        # Start the corrected process
        process = subprocess.Popen(
            corrected_cmd,
            stdout=subprocess.PIPE,
            stderr=subprocess.STDOUT,
            text=True,
            bufsize=1,
            env=env,
            cwd=os.getcwd()
        )

        # Monitor output in real-time
        for line in process.stdout:
            print(line, end='')
            output_lines.append(line)

        # Wait for completion
        return_code = process.wait()

        print("🏁 " + "=" * 58)
        print(f"⏰ Training completed at: {time.strftime('%Y-%m-%d %H:%M:%S')}")

        if return_code == 0:
            print("🎉 SUCCESS: 2-GPU parallel training completed successfully!")
            print("📁 Check the output_work_dir for your generated medical images!")
        else:
            print(f"❌ Training failed with return code: {return_code}")

        # Store output for later analysis
        training_output_final = "".join(output_lines)

    except KeyboardInterrupt:
        # Stop torchrun so its workers release the GPUs and the master port.
        print("\n⛔ Interrupted - stopping torchrun...")
        if process is not None and process.poll() is None:
            process.terminate()
            try:
                process.wait(timeout=30)
            except subprocess.TimeoutExpired:
                process.kill()
                process.wait()
        training_output_final = "".join(output_lines)
    except Exception as e:
        print(f"❌ ERROR during execution: {e}")
        training_output_final = str(e)
    finally:
        if process is not None and process.stdout is not None:
            process.stdout.close()

else:
    print(f"❌ Cannot run parallel training - need at least 2 GPUs but found {gpu_count}")
    print("💡 Try the single-GPU version instead: python maisi_train_diff_unet.py")