# AR-SSL4M Pretraining on Google Colab

This notebook runs AR-SSL4M (Autoregressive Sequence Modeling for 3D Medical Image Representation) pretraining on Google Colab.

## Requirements
- GPU runtime (T4, V100, or A100)
- High RAM runtime (recommended)

## Dataset
- STOIC dataset (2771 samples)
- Each sample is a 128×128×128 3D medical image volume


## 1. Environment Setup

**⚠️ Important**: If you run into dependency conflict errors, follow these steps:
1. Restart the runtime (Runtime → Restart Runtime)
2. Re-run all cells
3. If the problem persists, use "Factory Reset Runtime"


In [None]:
# Check GPU availability
import torch
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
else:
    print("⚠️ No GPU detected! Please enable GPU runtime in Colab.")


In [None]:
# Install required packages with build error fixes
print("📦 Installing core dependencies (avoiding build issues)...")

# Upgrade pip and build tools
!pip install -q --upgrade pip setuptools wheel

# Install PyTorch first (most stable)
print("🔥 Installing PyTorch...")
!pip install -q torch torchvision --index-url https://download.pytorch.org/whl/cu118

# Install transformers with pre-built wheels to avoid tokenizers build issues
print("🤗 Installing Transformers...")
!pip install -q --only-binary=:all: transformers==4.37.0

# Install MONAI core without optional dependencies that cause build issues
print("🏥 Installing MONAI (core only)...")
!pip install -q monai==1.3.0

# Install essential dependencies
print("📚 Installing other dependencies...")
!pip install -q fire
!pip install -q scikit-learn
!pip install -q tqdm
!pip install -q PyYAML
!pip install -q packaging
!pip install -q nibabel

# Try SimpleITK separately (sometimes causes issues)
try:
    import SimpleITK
    print("✅ SimpleITK already available")
except ImportError:
    print("📥 Installing SimpleITK...")
    !pip install -q SimpleITK

print("✅ Core dependencies installed successfully!")
print("⚠️ Note: Some optional MONAI features may not be available, but core functionality should work.")


In [None]:
# Verify installations with detailed error handling
print("🔍 Verifying installations...")

packages_status = {}

# Test each package individually
try:
    import torch
    packages_status['PyTorch'] = f"✅ {torch.__version__}"
    print(f"✅ PyTorch: {torch.__version__} (CUDA: {torch.cuda.is_available()})")
except ImportError as e:
    packages_status['PyTorch'] = f"❌ Failed: {e}"
    print(f"❌ PyTorch import failed: {e}")

try:
    import transformers
    packages_status['Transformers'] = f"✅ {transformers.__version__}"
    print(f"✅ Transformers: {transformers.__version__}")
except ImportError as e:
    packages_status['Transformers'] = f"❌ Failed: {e}"
    print(f"❌ Transformers import failed: {e}")

try:
    import monai
    packages_status['MONAI'] = f"✅ {monai.__version__}"
    print(f"✅ MONAI: {monai.__version__}")
except ImportError as e:
    packages_status['MONAI'] = f"❌ Failed: {e}"
    print(f"❌ MONAI import failed: {e}")

try:
    import fire, sklearn, tqdm
    print(f"✅ Other packages: fire, sklearn, tqdm")
    packages_status['Others'] = "✅ OK"
except ImportError as e:
    packages_status['Others'] = f"❌ Failed: {e}"
    print(f"❌ Other packages import failed: {e}")

# Check if all critical packages are working
critical_failed = [k for k, v in packages_status.items() if "❌" in v and k in ['PyTorch', 'Transformers', 'MONAI']]

if not critical_failed:
    print(f"\n🎉 All critical dependencies verified successfully!")
else:
    print(f"\n⚠️ Failed packages: {', '.join(critical_failed)}")
    print("💡 Solutions:")
    print("1. If you hit build errors, run Cell 6 (minimal installation)")
    print("2. If MONAI is the problem, run Cell 8 (MONAI-specific fix)")
    print("3. Restart the runtime and try again")
    print("4. Use 'Factory Reset Runtime' if the problem persists")


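### 🔧 Fixing Build Errors

If you see any of the following errors:
- `Building wheel for tokenizers failed`
- `Building wheel for openslide-python failed`
- `subprocess-exited-with-error`

**Cause**: The Colab environment is missing some build tools, or dependency versions are incompatible.

**Solution**: Run the minimal installation cell below, which uses only pre-built wheels: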


In [None]:
# Minimal installation: run this cell if you hit build errors
print("🔄 Using the minimal installation (avoids all build steps)...")

# Use only pre-built wheels, no compilation
!pip install --only-binary=:all: --force-reinstall torch torchvision
!pip install --only-binary=:all: --force-reinstall transformers==4.37.0
!pip install --only-binary=:all: --force-reinstall monai==1.3.0
!pip install --only-binary=:all: fire scikit-learn tqdm PyYAML packaging nibabel

print("✅ Minimal installation complete!")
print("⚠️ Note: Some advanced features may be unavailable, but core training should work.")


### MONAI-Specific Fix (if the MONAI import fails)

If you see the error `'FileFinder' object has no attribute 'find_module'`, run the cell below:


In [None]:
# MONAI Python 3.12 compatibility fix
print("🔧 Fixing MONAI Python 3.12 compatibility issue...")

# Uninstall and reinstall MONAI with specific fixes
!pip uninstall -y monai
!pip install --no-cache-dir --force-reinstall "monai[all]==1.3.0"

# Alternative: Install from source if needed
# !pip install git+https://github.com/Project-MONAI/MONAI.git

print("✅ MONAI fix complete! Restart the runtime and re-run the verification cell.")


## 2. Clone Repository and Setup


In [None]:
# Clone the repository
import os
repo_url = "https://github.com/tanglehunter00/AR-SSL4M-DEMO.git"

if os.path.exists("AR-SSL4M-DEMO"):
    print("Repository already exists, pulling latest changes...")
    !cd AR-SSL4M-DEMO && git pull
else:
    print("Cloning repository...")
    !git clone {repo_url}

# Change to project directory
os.chdir("AR-SSL4M-DEMO")
print(f"✅ Current directory: {os.getcwd()}")


## 3. Data Setup

Based on your screenshot, the dataset is located in Google Drive under `dataset/compressed_datasets/volumes`.


In [None]:
# Mount Google Drive to access your data
from google.colab import drive
drive.mount('/content/drive')

# Set path to your data in Google Drive (based on your screenshot)
data_path = "/content/drive/MyDrive/dataset/compressed_datasets/volumes"

# Check if data exists
if os.path.exists(data_path):
    npy_files = [f for f in os.listdir(data_path) if f.endswith('.npy')]
    print(f"✅ Found {len(npy_files)} .npy files in {data_path}")
    if len(npy_files) == 0:
        print("⚠️ No .npy files found in the directory")
        print("Available files:")
        all_files = os.listdir(data_path)[:10]  # Show first 10 files
        for f in all_files:
            print(f"  - {f}")
else:
    print(f"❌ Data path not found: {data_path}")
    print("Available paths in Google Drive:")
    try:
        drive_contents = os.listdir("/content/drive/MyDrive")
        for item in drive_contents[:10]:
            print(f"  - /content/drive/MyDrive/{item}")
    except OSError:
        print("Unable to list drive contents")
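

Before building the data list, it can help to confirm that a volume really has the expected 128×128×128 shape. The sketch below is one way to check; the helper name `inspect_volume` is hypothetical, and it assumes each `.npy` file holds a single plain array. If your files carry an extra channel axis, pass a different `expected_shape`.

```python
import numpy as np

def inspect_volume(path, expected_shape=(128, 128, 128)):
    """Return (shape, dtype, matches_expected) for one .npy volume."""
    # mmap_mode="r" maps the file lazily instead of loading the
    # whole volume into RAM.
    volume = np.load(path, mmap_mode="r")
    shape = tuple(volume.shape)
    return shape, str(volume.dtype), shape == tuple(expected_shape)

# Usage (assumes `data_path` and `npy_files` from the cell above):
# shape, dtype, ok = inspect_volume(os.path.join(data_path, npy_files[0]))
# print(shape, dtype, "OK" if ok else "unexpected shape")
```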


In [None]:
# Create data list file
import glob

# Try multiple possible data paths
possible_paths = [
    "/content/drive/MyDrive/dataset/compressed_datasets/volumes",
    "/content/drive/MyDrive/compressed_datasets/volumes", 
    "/content/drive/MyDrive/dataset/volumes",
    "/content/drive/MyDrive/volumes"
]

npy_files = []
actual_data_path = None

for path in possible_paths:
    if os.path.exists(path):
        files = glob.glob(os.path.join(path, "*.npy"))
        if files:
            npy_files = files
            actual_data_path = path
            print(f"✅ Found {len(npy_files)} .npy files in {actual_data_path}")
            break
        else:
            print(f"⚠️ Path exists but no .npy files found: {path}")

if not npy_files:
    print("❌ No .npy files found in any of the expected paths")
    print("Please check your Google Drive structure and update the paths accordingly")
else:
    # Create data list file
    data_list_path = "pretrain/colab_data_list.txt"
    with open(data_list_path, 'w') as f:
        for npy_file in npy_files:
            f.write(f"{npy_file}\n")
    
    print(f"✅ Created data list: {data_list_path} with {len(npy_files)} files")
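

The list stores absolute Google Drive paths, so it is only valid while Drive is mounted at the same location. A quick check that every listed path still resolves can save a failed run. This is a minimal sketch with a hypothetical helper, `missing_paths`; it assumes the one-path-per-line format written above.

```python
import os

def missing_paths(list_file):
    """Return the paths listed in `list_file` that do not exist on disk."""
    with open(list_file) as f:
        # Ignore blank lines and surrounding whitespace.
        paths = [line.strip() for line in f if line.strip()]
    return [p for p in paths if not os.path.exists(p)]

# Usage (assumes the list written by the cell above):
# missing = missing_paths("pretrain/colab_data_list.txt")
# print(f"{len(missing)} missing files")
```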


## 4. Update Configuration for Colab


In [None]:
# Update dataset configuration
config_content = '''from dataclasses import dataclass


@dataclass
class custom_dataset:
    dataset: str = "custom_dataset"
    file: str = "image_dataset.py"
    train_split: str = "train"
    test_split: str = "validation"
    spatial_path: str = "colab_data_list.txt"
    contrast_path: str = "colab_data_list.txt"
    semantic_path: str = "colab_data_list.txt"
    img_size = [128, 128, 128]
    patch_size = [16, 16, 16]
    attention_type = 'prefix'
    add_series_data = False
    add_spatial_data = True
    is_subset = False
    series_length = 4
'''

with open('pretrain/configs/datasets.py', 'w') as f:
    f.write(config_content)

print("✅ Updated dataset configuration")
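

With `img_size = [128, 128, 128]` and `patch_size = [16, 16, 16]`, each axis splits into 128 / 16 = 8 patches, giving 8 × 8 × 8 = 512 patches per volume with 16³ = 4096 voxels each. The sketch below only reproduces that arithmetic. `patch_grid` is a hypothetical helper, and how the model orders or groups these patches into a sequence is not covered here.

```python
from math import prod

def patch_grid(img_size, patch_size):
    """Return (patches per axis, total patches, voxels per patch)."""
    if any(i % p for i, p in zip(img_size, patch_size)):
        raise ValueError("img_size must be divisible by patch_size on every axis")
    grid = [i // p for i, p in zip(img_size, patch_size)]
    return grid, prod(grid), prod(patch_size)

print(patch_grid([128, 128, 128], [16, 16, 16]))
# → ([8, 8, 8], 512, 4096)
```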


In [None]:
# Update training configuration for Colab
training_config_content = '''from dataclasses import dataclass


@dataclass
class train_config:
    enable_fsdp: bool=False
    low_cpu_fsdp: bool=False
    run_validation: bool=True
    batch_size_training: int=4  # Adjusted for Colab GPU memory
    batching_strategy: str="padding"
    gradient_accumulation_steps: int=1
    gradient_clipping: bool=False
    gradient_clipping_threshold: float = 1.0
    num_epochs: int=1  # Single epoch for Colab
    warmup_epochs:int=0
    num_workers_dataloader: int=0  # Set to 0 to avoid multiprocessing issues in Colab
    lr: float=1e-4
    weight_decay: float=0.01
    gamma: float=0.1
    seed: int=42
    use_fp16: bool=True  # Enable FP16 for Colab GPU memory efficiency
    mixed_precision: bool=True
    val_batch_size: int=1
    dataset = "custom_dataset"
    output_dir: str="/content/AR-SSL4M-DEMO/pretrain/save"
    freeze_layers: bool=False
    num_freeze_layers: int=1
    save_model: bool=True
    save_optimizer: bool=False
    save_metrics: bool=True
    scheduler:str='CosineLR'
    min_lr: float=0
    pos_type: str='sincos3d'
    norm_pixel_loss: bool=True
    enable_profiling: bool=False  # Disable profiling for cleaner output
'''

with open('pretrain/configs/training.py', 'w') as f:
    f.write(training_config_content)

print("✅ Updated training configuration for Colab")
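

For a rough sense of run length, the number of optimizer steps per epoch follows from the dataset size, `batch_size_training`, and `gradient_accumulation_steps`. This is a minimal sketch that assumes all 2771 samples are used for training and that the last partial batch is kept; neither is confirmed by this notebook. `steps_per_epoch` is a hypothetical helper.

```python
from math import ceil

def steps_per_epoch(num_samples, batch_size, grad_accum_steps=1):
    """Return (batches per epoch, optimizer steps per epoch)."""
    batches = ceil(num_samples / batch_size)
    return batches, ceil(batches / grad_accum_steps)

print(steps_per_epoch(2771, 4, 1))
# → (693, 693)
```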


## 5. Start Pretraining


In [None]:
# Create save directory
os.makedirs("pretrain/save", exist_ok=True)
print("✅ Created save directory")


In [None]:
# Start pretraining
print("🚀 Starting AR-SSL4M Pretraining...")
print("=" * 60)

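# Change to the pretrain directory (skipped if this cell was already run)
if os.path.basename(os.getcwd()) != "pretrain":
    os.chdir("pretrain")

# Run the training script
!python main.py --output_dir save --batch_size_training 4

print("=" * 60)
print("✅ Training script finished. Check the log above for errors.")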
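

A `!python` line does not stop the notebook when the script fails, so the message after it prints either way. One way to surface failures is to launch the script through `subprocess` and check its exit code. This is a sketch, not part of the repository; `run_training` is a hypothetical wrapper around the same command line.

```python
import subprocess
import sys

def run_training(args, cwd=None):
    """Run a command, echo its output line by line, and return the exit code."""
    proc = subprocess.Popen(
        args, cwd=cwd,
        stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True,
    )
    for line in proc.stdout:
        print(line, end="")
    return proc.wait()

# Usage (same flags as the cell above; run from the pretrain directory):
# code = run_training([sys.executable, "main.py",
#                      "--output_dir", "save", "--batch_size_training", "4"])
# print("✅ Pretraining completed!" if code == 0 else f"❌ Training failed (exit code {code})")
```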


## 6. Check Results and Download Model


In [None]:
# Check training results
import os
import glob

save_dir = "save"
if os.path.exists(save_dir):
    files = os.listdir(save_dir)
    print(f"📁 Files in save directory:")
    for file in files:
        file_path = os.path.join(save_dir, file)
        if os.path.isfile(file_path):
            size_mb = os.path.getsize(file_path) / (1024 * 1024)
            print(f"  - {file} ({size_mb:.1f} MB)")
else:
    print("❌ Save directory not found")


In [None]:
# Download the trained model
from google.colab import files
import zipfile

# Create a zip file with all results
zip_path = "/content/ar_ssl4m_pretrained_model.zip"

with zipfile.ZipFile(zip_path, 'w') as zipf:
    for root, dirs, files_list in os.walk("save"):
        for file in files_list:
            file_path = os.path.join(root, file)
            arcname = os.path.relpath(file_path, "save")
            zipf.write(file_path, arcname)
            print(f"Added to zip: {arcname}")

print(f"\n📦 Created zip file: {zip_path}")
print("⬇️ Downloading...")

# Download the zip file
files.download(zip_path)
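

Browser downloads of large checkpoints can be slow or get interrupted. Since Drive is already mounted in step 3, copying the zip there is an alternative. This is a sketch with a hypothetical helper, `backup_file`; the destination folder name is an assumption.

```python
import os
import shutil

def backup_file(src, dst_dir):
    """Copy `src` into `dst_dir` (created if needed) and return the new path."""
    os.makedirs(dst_dir, exist_ok=True)
    return shutil.copy2(src, dst_dir)

# Usage (the destination folder is an assumption; adjust to your Drive layout):
# backup_file("/content/ar_ssl4m_pretrained_model.zip",
#             "/content/drive/MyDrive/ar_ssl4m_results")
```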


## 7. Training Summary


In [None]:
# Display training summary
print("🎯 AR-SSL4M Pretraining Summary")
print("=" * 50)
print(f"📊 Dataset: STOIC ({len(npy_files) if 'npy_files' in locals() and npy_files else 'Unknown'} samples)")
print(f"📂 Data path: {actual_data_path if 'actual_data_path' in locals() and actual_data_path else 'Not found'}")
print(f"🏗️ Model: AR-SSL4M (91.3M parameters)")
print(f"📐 Image size: 128×128×128")
print(f"🔢 Patch size: 16×16×16")
print(f"📦 Batch size: 4")
print(f"🎓 Learning rate: 1e-4")
print(f"🔄 Epochs: 1")
print(f"💾 Results saved to: save/")
print(f"🔗 GitHub Repo: https://github.com/tanglehunter00/AR-SSL4M-DEMO")
print("=" * 50)
print("✅ Summary complete. Confirm in the training log that the run finished without errors.")
print("\n📝 Next steps:")
print("1. Download the model zip file")
print("2. Use the pretrained model for downstream tasks")
print("3. Fine-tune on specific medical imaging tasks (segmentation, classification, etc.)")
print("4. Experiment with different hyperparameters for better performance")

