# AR-SSL4M Pretraining on Google Colab with Dilated Attention

This notebook runs AR-SSL4M (Autoregressive Sequence Modeling for 3D Medical Image Representation) pretraining on Google Colab with **D-Former Dilated Attention** integration.

## Features
- **Dilated Attention**: Implements D-Former's dilated attention mechanism
- **Fixed Hyperparameters**: Dilated attention ratios for pretraining and finetuning are fixed values that you configure before training
- **Joint Training**: Supports both upstream pretraining and downstream finetuning
- **Google Colab Optimized**: Tuned for Colab's GPU memory and runtime limits

## Requirements
- GPU runtime (T4, V100, or A100)
- High RAM runtime (recommended)

## Dataset
- STOIC dataset (2771 samples)
- Each sample is a 128×128×128 3D medical image volume

## Hyperparameters
- **Pretrain Dilated Ratio**: Controls dilated attention weight in pretraining (default: 0.01)
- **Finetune Dilated Ratio**: Controls dilated attention weight in finetuning (default: 0.01)
- **Original Attention Ratio**: Automatically calculated as (1 - dilated_ratio)
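
As a rough illustration of how the two ratios relate, the sketch below mixes two attention outputs as a convex combination with weights `1 - dilated_ratio` and `dilated_ratio`. The function name `blend_attention_outputs` and the plain-list inputs are assumptions made for illustration; the repository's actual mixing logic may differ.

```python
def blend_attention_outputs(original_out, dilated_out, dilated_ratio=0.01):
    """Mix two equally sized output vectors with weights (1 - r) and r."""
    if not 0.0 <= dilated_ratio <= 1.0:
        raise ValueError("dilated_ratio must lie in [0, 1]")
    if len(original_out) != len(dilated_out):
        raise ValueError("outputs must have the same length")
    original_ratio = 1.0 - dilated_ratio  # the "Original Attention Ratio"
    return [
        original_ratio * o + dilated_ratio * d
        for o, d in zip(original_out, dilated_out)
    ]


# Hypothetical values standing in for per-token attention outputs.
mixed = blend_attention_outputs([1.0, 2.0], [3.0, 4.0], dilated_ratio=0.5)
print(mixed)
```

With the default ratio of 0.01, the result stays very close to the original attention output, which matches the intent of starting with a small dilated contribution.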


## 1. Environment Setup

**⚠️ Important**: If you run into dependency conflict errors, follow these steps:
1. Restart the runtime (Runtime → Restart Runtime)
2. Re-run all cells
3. If the problem persists, use "Factory Reset Runtime"


In [None]:
# Check GPU availability
import torch
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
else:
    print("⚠️ No GPU detected! Please enable GPU runtime in Colab.")


In [None]:
# Detect and repair PyTorch version mismatches
print("📦 Inspecting the Colab environment...")

# First, check which PyTorch build is active
import torch
import subprocess
import sys

print(f"🔥 Detected PyTorch version: {torch.__version__}")
print(f"🔥 CUDA available: {torch.cuda.is_available()}")

# Collect the installed versions of the PyTorch packages
result = subprocess.run([sys.executable, '-m', 'pip', 'list'],
                       capture_output=True, text=True)
pip_list = result.stdout

torch_versions = []
for line in pip_list.split('\n'):
    if line.startswith(('torch ', 'torchaudio ', 'torchvision ')):
        torch_versions.append(line.strip())

print("🔍 PyTorch package versions:")
for version in torch_versions:
    print(f"  {version}")

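# Reinstall the PyTorch stack so that all components come from the same build.
# Note: this runs whenever any PyTorch package is listed; it does not compare versions first.
if len(torch_versions) > 0:
    # Split the running torch version into its base version and CUDA suffix
    current_torch = torch.__version__
    if '+' in current_torch:
        base_version = current_torch.split('+')[0]
        cuda_suffix = current_torch.split('+')[1]
    else:
        base_version = current_torch
        cuda_suffix = 'cu126'  # Colab default

    print(f"🔧 Ensuring consistent PyTorch component versions: {base_version}+{cuda_suffix}")

    # Reinstall a consistent PyTorch stack (using the index Colab uses)
    !pip install -q --upgrade torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu126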

# Install lightweight required packages
print("📚 Installing lightweight required packages...")
lightweight_packages = ['fire', 'tqdm', 'PyYAML', 'packaging']

for pkg in lightweight_packages:
    try:
        if pkg == 'PyYAML':
            # PyYAML is imported under the module name "yaml"
            import yaml
            print(f"✅ {pkg} is available")
        else:
            __import__(pkg)
            print(f"✅ {pkg} is available")
    except ImportError:
        print(f"📥 Installing {pkg}...")
        !pip install -q {pkg}

# Check for transformers (usually preinstalled on Colab)
try:
    import transformers
    print(f"✅ Transformers: {transformers.__version__}")
except ImportError:
    print("📥 Installing Transformers...")
    !pip install -q transformers

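# Finally, install MONAI if it is missing
try:
    import monai
    print(f"✅ MONAI: {monai.__version__}")
except ImportError:
    print("📥 Installing MONAI (without changing PyTorch)...")
    !pip install -q --no-deps monai
    # Install MONAI's required dependencies separately to avoid PyTorch version conflicts
    !pip install -q nibabel tqdm

print("✅ Environment setup complete!")
print("🎯 Strategy: keep PyTorch component versions consistent")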
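
The setup cell lists the PyTorch packages but does not compare them. One way to detect a mismatch is sketched below, using only the standard library. It assumes that a mismatch shows up as differing build tags after the `+` in the version string (for example `cu126`); the helper names are hypothetical and not part of the repository.

```python
from importlib import metadata


def local_build_tag(version):
    """Return the part after '+' in a version string (e.g. 'cu126'), or None."""
    return version.split("+", 1)[1] if "+" in version else None


def builds_are_consistent(versions):
    """True when every package that carries a build tag carries the same one."""
    tags = {local_build_tag(v) for v in versions.values()}
    tags.discard(None)
    return len(tags) <= 1


def installed_versions(packages=("torch", "torchvision", "torchaudio")):
    """Look up installed versions, skipping packages that are not installed."""
    found = {}
    for name in packages:
        try:
            found[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            pass
    return found


print(builds_are_consistent(installed_versions()))
```

A `False` result would be a reasonable trigger for the reinstall step, so the reinstall does not have to run on every execution.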


## 2. Clone Repository and Setup


In [None]:
# Clone the repository (absolute path so the cell can be re-run safely)
import os
repo_url = "https://github.com/tanglehunter00/AR-SSL4M-DEMO.git"
repo_dir = "/content/AR-SSL4M-DEMO"

if os.path.exists(repo_dir):
    print("Repository already exists, pulling latest changes...")
    !cd {repo_dir} && git pull
else:
    print("Cloning repository...")
    !git clone {repo_url} {repo_dir}

# Change to project directory
os.chdir(repo_dir)
print(f"✅ Current directory: {os.getcwd()}")


## 3. Data Setup

Based on your screenshot, the dataset is stored in Google Drive under `dataset/compressed_datasets/volumes`.


In [None]:
# Mount Google Drive to access your data
from google.colab import drive
drive.mount('/content/drive')

# Set path to your data in Google Drive (based on your screenshot)
data_path = "/content/drive/MyDrive/dataset/compressed_datasets/volumes"

# Check if data exists
if os.path.exists(data_path):
    npy_files = [f for f in os.listdir(data_path) if f.endswith('.npy')]
    print(f"✅ Found {len(npy_files)} .npy files in {data_path}")
    if len(npy_files) == 0:
        print("⚠️ No .npy files found in the directory")
        print("Available files:")
        all_files = os.listdir(data_path)[:10]  # Show first 10 files
        for f in all_files:
            print(f"  - {f}")
else:
    print(f"❌ Data path not found: {data_path}")
    print("Available paths in Google Drive:")
    try:
        drive_contents = os.listdir("/content/drive/MyDrive")
        for item in drive_contents[:10]:
            print(f"  - /content/drive/MyDrive/{item}")
    except OSError:
        print("Unable to list drive contents")


In [None]:
# Create data list file
import glob

# Try multiple possible data paths
possible_paths = [
    "/content/drive/MyDrive/dataset/compressed_datasets/volumes",
    "/content/drive/MyDrive/compressed_datasets/volumes", 
    "/content/drive/MyDrive/dataset/volumes",
    "/content/drive/MyDrive/volumes"
]

npy_files = []
actual_data_path = None

for path in possible_paths:
    if os.path.exists(path):
        files = glob.glob(os.path.join(path, "*.npy"))
        if files:
            npy_files = files
            actual_data_path = path
            print(f"✅ Found {len(npy_files)} .npy files in {actual_data_path}")
            break
        else:
            print(f"⚠️ Path exists but no .npy files found: {path}")

if not npy_files:
    print("❌ No .npy files found in any of the expected paths")
    print("Please check your Google Drive structure and update the paths accordingly")
else:
    # Create data list file
    data_list_path = "pretrain/colab_data_list.txt"
    with open(data_list_path, 'w') as f:
        for npy_file in npy_files:
            f.write(f"{npy_file}\n")
    
    print(f"✅ Created data list: {data_list_path} with {len(npy_files)} files")
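
Before training, it can help to confirm that the listed volumes really have the expected 128×128×128 shape. The sketch below reads only the `.npy` header, so no voxel data is loaded from Drive. The helper names `read_npy_shape` and `find_shape_mismatches` are hypothetical, and the code assumes standard `.npy` files as written by `numpy.save`.

```python
import ast
import struct


def read_npy_shape(path):
    """Return the array shape stored in a .npy header without loading the data."""
    with open(path, "rb") as f:
        if f.read(6) != b"\x93NUMPY":
            raise ValueError(f"{path} is not a .npy file")
        major, _minor = f.read(2)  # format version bytes
        if major == 1:
            (header_len,) = struct.unpack("<H", f.read(2))
        else:  # versions 2 and 3 store a 4-byte header length
            (header_len,) = struct.unpack("<I", f.read(4))
        # The header is a Python dict literal with 'descr', 'fortran_order', 'shape'.
        header = ast.literal_eval(f.read(header_len).decode("latin1"))
    return tuple(header["shape"])


def find_shape_mismatches(paths, expected=(128, 128, 128)):
    """Return (path, shape) pairs for files whose shape differs from expected."""
    mismatches = []
    for path in paths:
        shape = read_npy_shape(path)
        if shape != expected:
            mismatches.append((path, shape))
    return mismatches
```

Calling `find_shape_mismatches(npy_files)` after the data list is created would return an empty list when every file matches. With NumPy available, `np.load(path, mmap_mode="r").shape` gives the same information.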


## 4. Hyperparameter Configuration

Configure the dilated attention hyperparameters for pretraining and finetuning.


In [None]:
# Hyperparameter Configuration for Dilated Attention
print("🔧 Configuring Dilated Attention Hyperparameters...")

# Default hyperparameters (can be modified)
PRETRAIN_DILATED_RATIO = 0.01  # Upstream pretraining: dilated attention weight
FINETUNE_DILATED_RATIO = 0.01  # Downstream finetuning: dilated attention weight

print(f"📊 Pretrain dilated attention ratio: {PRETRAIN_DILATED_RATIO}")
print(f"📊 Finetune dilated attention ratio: {FINETUNE_DILATED_RATIO}")
print(f"📊 Pretrain original attention ratio: {1 - PRETRAIN_DILATED_RATIO}")
print(f"📊 Finetune original attention ratio: {1 - FINETUNE_DILATED_RATIO}")

# You can modify these values here:
# PRETRAIN_DILATED_RATIO = 0.05  # Example: 5% dilated, 95% original
# FINETUNE_DILATED_RATIO = 0.1   # Example: 10% dilated, 90% original

print("\n💡 Tips:")
print("- Higher dilated ratio: More focus on long-range dependencies")
print("- Lower dilated ratio: More focus on local attention patterns")
print("- Recommended range: 0.01 - 0.2")
print("- You can experiment with different values for better performance")

print("\n🔧 To modify hyperparameters:")
print("1. Uncomment and change the values above")
print("2. Or modify the variables directly:")
print("   PRETRAIN_DILATED_RATIO = 0.05  # 5% dilated attention")
print("   FINETUNE_DILATED_RATIO = 0.1   # 10% dilated attention")
print("3. Re-run this cell and the training cells")
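
The tips above give a recommended range of 0.01 to 0.2. A small check like the following sketch can catch typos in the ratios before the configuration is written. The function name `check_dilated_ratio` is hypothetical, and treating out-of-range values as a warning rather than an error is an assumption.

```python
def check_dilated_ratio(name, value, recommended=(0.01, 0.2)):
    """Return (original_ratio, warning) for a dilated attention ratio."""
    if not 0.0 <= value <= 1.0:
        raise ValueError(f"{name} must be between 0 and 1, got {value}")
    low, high = recommended
    warning = None
    if not low <= value <= high:
        warning = f"{name}={value} is outside the recommended range {low}-{high}"
    return 1.0 - value, warning


# Example with the notebook's default value.
original_ratio, warning = check_dilated_ratio("PRETRAIN_DILATED_RATIO", 0.01)
print(original_ratio, warning)
```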


In [None]:
# Update dataset configuration
config_content = '''from dataclasses import dataclass


@dataclass
class custom_dataset:
    dataset: str = "custom_dataset"
    file: str = "image_dataset.py"
    train_split: str = "train"
    test_split: str = "validation"
    spatial_path: str = "colab_data_list.txt"
    contrast_path: str = "colab_data_list.txt"
    semantic_path: str = "colab_data_list.txt"
    img_size = [128, 128, 128]
    patch_size = [16, 16, 16]
    attention_type = 'prefix'
    add_series_data = False
    add_spatial_data = True
    is_subset = False
    series_length = 4
'''

with open('pretrain/configs/datasets.py', 'w') as f:
    f.write(config_content)

print("✅ Updated dataset configuration")


# Update training configuration for Colab with Dilated Attention
training_config_content = f'''from dataclasses import dataclass


@dataclass
class train_config:
    enable_fsdp: bool=False
    low_cpu_fsdp: bool=False
    run_validation: bool=True
    batch_size_training: int=4  # Adjusted for Colab GPU memory
    batching_strategy: str="padding"
    gradient_accumulation_steps: int=1
    gradient_clipping: bool=False
    gradient_clipping_threshold: float = 1.0
    num_epochs: int=1  # Single epoch for Colab
    warmup_epochs:int=0
    num_workers_dataloader: int=0  # Set to 0 to avoid multiprocessing issues in Colab
    lr: float=1e-4
    weight_decay: float=0.01
    gamma: float=0.1
    seed: int=42
    use_fp16: bool=True  # Enable FP16 for Colab GPU memory efficiency
    mixed_precision: bool=True
    val_batch_size: int=1
    dataset = "custom_dataset"
    output_dir: str="/content/AR-SSL4M-DEMO/pretrain/save"
    freeze_layers: bool=False
    num_freeze_layers: int=1
    save_model: bool=True
    save_optimizer: bool=False
    save_metrics: bool=True
    scheduler:str='CosineLR'
    min_lr: float=0
    pos_type: str='sincos3d'
    norm_pixel_loss: bool=True
    enable_profiling: bool=False  # Disable profiling for cleaner output
    # Dilated Attention Hyperparameters
    pretrain_dilated_ratio: float={PRETRAIN_DILATED_RATIO}
    finetune_dilated_ratio: float={FINETUNE_DILATED_RATIO}
'''

with open('pretrain/configs/training.py', 'w') as f:
    f.write(training_config_content)

print("✅ Updated training configuration for Colab with Dilated Attention")
print(f"🔧 Pretrain dilated ratio: {PRETRAIN_DILATED_RATIO}")
print(f"🔧 Finetune dilated ratio: {FINETUNE_DILATED_RATIO}")
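
To make the dataset configuration concrete, the arithmetic below works out how many patches one volume yields. It assumes non-overlapping patches, so the count is the image size divided by the patch size along each axis; the model's actual sequence length may differ if it adds extra tokens.

```python
import math

img_size = [128, 128, 128]   # from the dataset configuration
patch_size = [16, 16, 16]    # from the dataset configuration

# Number of patches along each axis, assuming no overlap between patches.
patches_per_axis = [i // p for i, p in zip(img_size, patch_size)]
num_patches = math.prod(patches_per_axis)
voxels_per_patch = math.prod(patch_size)

print(patches_per_axis, num_patches, voxels_per_patch)
# → [8, 8, 8] 512 4096
```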


In [None]:
# Create save directory
os.makedirs("pretrain/save", exist_ok=True)
print("✅ Created save directory")


## 5. Run Pretraining


In [None]:
# Start pretraining with Dilated Attention
print("🚀 Starting AR-SSL4M Pretraining with Dilated Attention...")
print("=" * 60)
print(f"🔧 Pretrain dilated ratio: {PRETRAIN_DILATED_RATIO}")
print(f"🔧 Finetune dilated ratio: {FINETUNE_DILATED_RATIO}")
print(f"📊 Pretrain original ratio: {1 - PRETRAIN_DILATED_RATIO}")
print(f"📊 Finetune original ratio: {1 - FINETUNE_DILATED_RATIO}")
print("=" * 60)

# Change to the pretrain directory (skipped if a previous run already moved there)
if os.path.basename(os.getcwd()) != "pretrain":
    os.chdir("pretrain")

# Run the training script with dilated attention
!python main.py --output_dir save --batch_size_training 4 --pretrain_dilated_ratio {PRETRAIN_DILATED_RATIO} --finetune_dilated_ratio {FINETUNE_DILATED_RATIO}

print("=" * 60)
print("✅ Pretraining with Dilated Attention completed!")
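
The training cell prints a completion message even when `main.py` exits with an error. A sketch of one alternative is shown below: it runs the same command through `subprocess` and reports success only on exit status 0. The helper names are hypothetical; the flags are the ones used in the training cell.

```python
import subprocess
import sys


def build_train_command(pretrain_ratio, finetune_ratio, batch_size=4):
    """Assemble the same main.py invocation used in the training cell."""
    return [
        sys.executable, "main.py",
        "--output_dir", "save",
        "--batch_size_training", str(batch_size),
        "--pretrain_dilated_ratio", str(pretrain_ratio),
        "--finetune_dilated_ratio", str(finetune_ratio),
    ]


def run_and_report(command):
    """Run a command and return True only if it exits with status 0."""
    completed = subprocess.run(command)
    if completed.returncode != 0:
        print(f"Command failed with exit code {completed.returncode}")
        return False
    print("Command finished successfully")
    return True
```

From inside the `pretrain` directory, `run_and_report(build_train_command(PRETRAIN_DILATED_RATIO, FINETUNE_DILATED_RATIO))` would launch training and return `False` on failure.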


## 6. Check Results and Download Model


In [None]:
# Check training results
import os
import glob

save_dir = "save"
if os.path.exists(save_dir):
    files = os.listdir(save_dir)
    print(f"📁 Files in save directory:")
    for file in files:
        file_path = os.path.join(save_dir, file)
        if os.path.isfile(file_path):
            size_mb = os.path.getsize(file_path) / (1024 * 1024)
            print(f"  - {file} ({size_mb:.1f} MB)")
else:
    print("❌ Save directory not found")


In [None]:
# Download the trained model
from google.colab import files
import zipfile

# Create a zip file with all results
zip_path = "/content/ar_ssl4m_pretrained_model.zip"

with zipfile.ZipFile(zip_path, 'w') as zipf:
    for root, dirs, files_list in os.walk("save"):
        for file in files_list:
            file_path = os.path.join(root, file)
            arcname = os.path.relpath(file_path, "save")
            zipf.write(file_path, arcname)
            print(f"Added to zip: {arcname}")

print(f"\n📦 Created zip file: {zip_path}")
print("⬇️ Downloading...")

# Download the zip file
files.download(zip_path)


## 7. Training Summary


In [None]:
# Display training summary with Dilated Attention
print("🎯 AR-SSL4M Pretraining Summary with Dilated Attention")
print("=" * 60)
print(f"📊 Dataset: STOIC ({len(npy_files) if 'npy_files' in locals() and npy_files else 'Unknown'} samples)")
print(f"📂 Data path: {actual_data_path if 'actual_data_path' in locals() and actual_data_path else 'Not found'}")
print(f"🏗️ Model: AR-SSL4M with Dilated Attention (91.3M parameters)")
print(f"📐 Image size: 128×128×128")
print(f"🔢 Patch size: 16×16×16")
print(f"📦 Batch size: 4")
print(f"🎓 Learning rate: 1e-4")
print(f"🔄 Epochs: 1")
print(f"🔧 Pretrain dilated ratio: {PRETRAIN_DILATED_RATIO}")
print(f"🔧 Finetune dilated ratio: {FINETUNE_DILATED_RATIO}")
print(f"📊 Pretrain original ratio: {1 - PRETRAIN_DILATED_RATIO}")
print(f"📊 Finetune original ratio: {1 - FINETUNE_DILATED_RATIO}")
print(f"💾 Results saved to: save/")
print(f"🔗 GitHub Repo: https://github.com/tanglehunter00/AR-SSL4M-DEMO")
print("=" * 60)
print("✅ Training with Dilated Attention completed successfully!")
print("\n📝 Next steps:")
print("1. Download the model zip file")
print("2. Use the pretrained model for downstream tasks")
print("3. Fine-tune on specific medical imaging tasks (segmentation, classification, etc.)")
print("4. Experiment with different dilated attention ratios for better performance")
print("5. Try different combinations of pretrain_dilated_ratio and finetune_dilated_ratio")

