# AR-SSL4M Pretraining on Google Colab 1

This notebook runs AR-SSL4M (Autoregressive Sequence Modeling for 3D Medical Image Representation) pretraining on Google Colab.

## Requirements
- GPU runtime (T4, V100, or A100)
- High RAM runtime (recommended)

## Dataset
- Using STOIC dataset (2771 samples)
- Each sample: 128×128×128 3D medical images
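
As a rough sizing aid, the memory footprint of one volume follows directly from its shape. The sketch below assumes 4 bytes per voxel (32-bit floats), which this notebook does not state; `volume_mib` is a hypothetical helper, not part of the repository.

```python
from math import prod


def volume_mib(shape=(128, 128, 128), bytes_per_voxel=4):
    """Return the in-memory size of one volume in MiB.

    bytes_per_voxel=4 is an assumption (float32); pass 2 for float16/int16.
    """
    return prod(shape) * bytes_per_voxel / (1024 ** 2)


print(f"One 128x128x128 volume: {volume_mib():.1f} MiB")
```

Multiplying this by the batch size gives a lower bound on the input tensor size only; activations and optimizer state are not included.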


## 1. Environment Setup

**⚠️ Important**: If you hit dependency conflict errors, do the following:
1. Restart the runtime (Runtime → Restart Runtime)
2. Re-run all cells
3. If the problem persists, use "Factory Reset Runtime"


In [1]:
# Check GPU availability
import torch
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
else:
    print("⚠️ No GPU detected! Please enable GPU runtime in Colab.")


CUDA available: True
GPU: Tesla T4
GPU Memory: 15.8 GB


In [1]:
# Detect the Colab environment and keep the PyTorch packages consistent
print("📦 Detecting Colab environment...")

# Check the installed PyTorch version first
import torch
import subprocess
import sys

print(f"🔥 Detected PyTorch version: {torch.__version__}")
print(f"🔥 CUDA available: {torch.cuda.is_available()}")

# List the installed packages to look for version mismatches
result = subprocess.run([sys.executable, '-m', 'pip', 'list'],
                       capture_output=True, text=True)
pip_list = result.stdout

torch_versions = []
for line in pip_list.split('\n'):
    if line.startswith('torch ') or line.startswith('torchaudio ') or line.startswith('torchvision '):
        torch_versions.append(line.strip())

print("🔍 PyTorch-related package versions:")
for version in torch_versions:
    print(f"  {version}")

# If any PyTorch packages were found, reinstall the suite so the versions match
# (note: this runs whenever packages are listed; it does not compare versions first)
if len(torch_versions) > 0:
    # Get the torch version actually in use
    current_torch = torch.__version__
    if '+' in current_torch:
        base_version = current_torch.split('+')[0]
        cuda_suffix = current_torch.split('+')[1]
    else:
        base_version = current_torch
        cuda_suffix = 'cu126'  # Colab default

    print(f"🔧 Ensuring consistent PyTorch component versions: {base_version}+{cuda_suffix}")

    # Reinstall a consistent PyTorch suite (using the index that matches Colab)
    !pip install -q --upgrade torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu126

# Install lightweight required packages
print("📚 Installing lightweight required packages...")
lightweight_packages = ['fire', 'tqdm', 'PyYAML', 'packaging']

for pkg in lightweight_packages:
    try:
        if pkg == 'PyYAML':
            # PyYAML is imported under the module name "yaml"
            import yaml
            print(f"✅ {pkg} is available")
        else:
            __import__(pkg)
            print(f"✅ {pkg} is available")
    except ImportError:
        print(f"📥 Installing {pkg}...")
        !pip install -q {pkg}

# Check transformers (usually preinstalled on Colab)
try:
    import transformers
    print(f"✅ Transformers: {transformers.__version__}")
except ImportError:
    print("📥 Installing Transformers...")
    !pip install -q transformers

# Finally, install MONAI if it is missing
try:
    import monai
    print(f"✅ MONAI: {monai.__version__}")
except ImportError:
    print("📥 Installing MONAI (without changing PyTorch)...")
    !pip install -q --no-deps monai
    # Install MONAI's required dependencies separately (avoids PyTorch version conflicts)
    !pip install -q nibabel tqdm

print("✅ Smart install complete!")
print("🎯 Strategy: keep PyTorch component versions consistent")


📦 Detecting Colab environment...
🔥 Detected PyTorch version: 2.8.0+cu126
🔥 CUDA available: True
🔍 PyTorch-related package versions:
  torch                                 2.8.0+cu126
  torchaudio                            2.8.0+cu126
  torchvision                           0.23.0+cu126
🔧 Ensuring consistent PyTorch component versions: 2.8.0+cu126
📚 Installing lightweight required packages...
📥 Installing fire...
✅ tqdm is available
✅ PyYAML is available
✅ packaging is available
✅ Transformers: 4.56.1
📥 Installing MONAI (without changing PyTorch)...
✅ Smart install complete!
🎯 Strategy: keep PyTorch component versions consistent
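
The setup cell reinstalls the PyTorch suite whenever any of the three packages is listed, without comparing their versions. A minimal sketch of an explicit check is shown below: it parses `pip list` style lines and compares the build tag after the `+` (for example `cu126`). The helper names `build_tags` and `is_consistent` are hypothetical, and a matching tag only shows that the wheels target the same CUDA build; it does not prove that the torch and torchvision releases are a supported pair.

```python
def build_tags(pip_lines):
    """Map package name -> build tag (text after '+'), or None if there is no tag."""
    tags = {}
    for line in pip_lines:
        parts = line.split()
        if len(parts) < 2:
            continue  # skip blank or malformed lines
        name, version = parts[0], parts[1]
        tags[name] = version.split("+", 1)[1] if "+" in version else None
    return tags


def is_consistent(pip_lines):
    """True when every listed package carries the same build tag."""
    return len(set(build_tags(pip_lines).values())) <= 1


# Lines copied from the cell output in this notebook
sample = [
    "torch                                 2.8.0+cu126",
    "torchaudio                            2.8.0+cu126",
    "torchvision                           0.23.0+cu126",
]
print(build_tags(sample))
print("consistent:", is_consistent(sample))
```

With a check like this, the reinstall step could be limited to the case where `is_consistent(torch_versions)` is false.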


## 2. Clone Repository and Setup


In [2]:
# Clone the repository
import os
repo_url = "https://github.com/tanglehunter00/AR-SSL4M-DEMO.git"

if os.path.exists("AR-SSL4M-DEMO"):
    print("Repository already exists, pulling latest changes...")
    !cd AR-SSL4M-DEMO && git pull
else:
    print("Cloning repository...")
    !git clone {repo_url}

# Change to project directory
os.chdir("AR-SSL4M-DEMO")
print(f"✅ Current directory: {os.getcwd()}")


Cloning repository...
Cloning into 'AR-SSL4M-DEMO'...
remote: Enumerating objects: 161, done.
remote: Counting objects: 100% (161/161), done.
remote: Compressing objects: 100% (115/115), done.
remote: Total 161 (delta 56), reused 137 (delta 36), pack-reused 0 (from 0)
Receiving objects: 100% (161/161), 1.32 MiB | 12.51 MiB/s, done.
Resolving deltas: 100% (56/56), done.
✅ Current directory: /content/AR-SSL4M-DEMO


## 3. Data Setup

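Based on your screenshot, the dataset is expected under `dataset/compressed_datasets/volumes` in Google Drive. The cells below also check `dataset/volumes`, which is where the files were found in this run.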


In [4]:
# Mount Google Drive to access your data
from google.colab import drive
drive.mount('/content/drive')

# Set path to your data in Google Drive (based on your screenshot)
data_path = "/content/drive/MyDrive/dataset/volumes"

# Check if data exists
if os.path.exists(data_path):
    npy_files = [f for f in os.listdir(data_path) if f.endswith('.npy')]
    print(f"✅ Found {len(npy_files)} .npy files in {data_path}")
    if len(npy_files) == 0:
        print("⚠️ No .npy files found in the directory")
        print("Available files:")
        all_files = os.listdir(data_path)[:10]  # Show first 10 files
        for f in all_files:
            print(f"  - {f}")
else:
    print(f"❌ Data path not found: {data_path}")
    print("Available paths in Google Drive:")
    try:
        drive_contents = os.listdir("/content/drive/MyDrive")
        for item in drive_contents[:10]:
            print(f"  - /content/drive/MyDrive/{item}")
    except OSError:
        # Listing can fail if Drive is not mounted or the path is unreadable
        print("Unable to list drive contents")


Mounted at /content/drive
✅ Found 13950 .npy files in /content/drive/MyDrive/dataset/volumes
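
Counting files does not confirm that each one holds a 128×128×128 volume. The sketch below inspects a single file by memory-mapping it, so the voxel data is not read into RAM. It assumes the files are plain numeric `.npy` arrays (no pickled objects); `check_volume` is a hypothetical helper.

```python
import numpy as np


def check_volume(path, expected_shape=(128, 128, 128)):
    """Return (matches, shape, dtype) for a .npy file using a memory map."""
    arr = np.load(path, mmap_mode="r")  # maps the file instead of loading it
    return arr.shape == tuple(expected_shape), arr.shape, arr.dtype


# In the notebook, for example:
# check_volume(os.path.join(data_path, npy_files[0]))
```

Checking a handful of files this way is usually enough to catch a wrong export; scanning all of them over a mounted Drive can be slow.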


In [5]:
# Create data list file
import glob

# Try multiple possible data paths
possible_paths = [
    "/content/drive/MyDrive/dataset/compressed_datasets/volumes",
    "/content/drive/MyDrive/compressed_datasets/volumes",
    "/content/drive/MyDrive/dataset/volumes",
    "/content/drive/MyDrive/volumes"
]

npy_files = []
actual_data_path = None

for path in possible_paths:
    if os.path.exists(path):
        files = glob.glob(os.path.join(path, "*.npy"))
        if files:
            npy_files = files
            actual_data_path = path
            print(f"✅ Found {len(npy_files)} .npy files in {actual_data_path}")
            break
        else:
            print(f"⚠️ Path exists but no .npy files found: {path}")

if not npy_files:
    print("❌ No .npy files found in any of the expected paths")
    print("Please check your Google Drive structure and update the paths accordingly")
else:
    # Create data list file
    data_list_path = "pretrain/colab_data_list.txt"
    with open(data_list_path, 'w') as f:
        for npy_file in npy_files:
            f.write(f"{npy_file}\n")

    print(f"✅ Created data list: {data_list_path} with {len(npy_files)} files")


✅ Found 13950 .npy files in /content/drive/MyDrive/dataset/volumes
✅ Created data list: pretrain/colab_data_list.txt with 13950 files
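
Because the list file stores absolute Drive paths, it goes stale if Drive is unmounted or files are moved. A small sketch for re-validating it before training follows; `missing_paths` is a hypothetical helper that assumes one path per line, as written by the cell above.

```python
import os


def missing_paths(list_file):
    """Return the entries of a one-path-per-line list file that do not exist on disk."""
    with open(list_file) as f:
        paths = [line.strip() for line in f if line.strip()]
    return [p for p in paths if not os.path.exists(p)]


# In the notebook, for example:
# missing = missing_paths("pretrain/colab_data_list.txt")
# print(f"{len(missing)} listed files are missing")
```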


## 4. Update Configuration for Colab


In [7]:
# Update dataset configuration
config_content = '''from dataclasses import dataclass


@dataclass
class custom_dataset:
    dataset: str = "custom_dataset"
    file: str = "image_dataset.py"
    train_split: str = "train"
    test_split: str = "validation"
    spatial_path: str = "colab_data_list.txt"
    contrast_path: str = "colab_data_list.txt"
    semantic_path: str = "colab_data_list.txt"
    img_size = [128, 128, 128]
    patch_size = [16, 16, 16]
    attention_type = 'prefix'
    add_series_data = False
    add_spatial_data = True
    is_subset = False
    series_length = 4
'''

with open('pretrain/configs/datasets.py', 'w') as f:
    f.write(config_content)

print("✅ Updated dataset configuration")


✅ Updated dataset configuration
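
The `img_size` and `patch_size` values determine how many patches each volume yields. The sketch below works through that arithmetic under the assumption that the model splits the volume into non-overlapping patches, which is typical for patch-based transformers but is not confirmed by this notebook; `num_patches` is a hypothetical helper.

```python
def num_patches(img_size=(128, 128, 128), patch_size=(16, 16, 16)):
    """Number of non-overlapping patches when each axis divides evenly."""
    count = 1
    for size, patch in zip(img_size, patch_size):
        if size % patch != 0:
            raise ValueError(f"{size} is not divisible by {patch}")
        count *= size // patch
    return count


print(num_patches())  # 8 * 8 * 8 = 512
```

Halving the patch size to 8 would multiply the count by eight (4096 patches), which matters for attention memory on a T4.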


In [8]:
# Update training configuration for Colab
training_config_content = '''from dataclasses import dataclass


@dataclass
class train_config:
    enable_fsdp: bool=False
    low_cpu_fsdp: bool=False
    run_validation: bool=True
    batch_size_training: int=4  # Adjusted for Colab GPU memory
    batching_strategy: str="padding"
    gradient_accumulation_steps: int=1
    gradient_clipping: bool=False
    gradient_clipping_threshold: float = 1.0
    num_epochs: int=1  # Single epoch for Colab
    warmup_epochs:int=0
    num_workers_dataloader: int=0  # Set to 0 to avoid multiprocessing issues in Colab
    lr: float=1e-4
    weight_decay: float=0.01
    gamma: float=0.1
    seed: int=42
    use_fp16: bool=True  # Enable FP16 for Colab GPU memory efficiency
    mixed_precision: bool=True
    val_batch_size: int=1
    dataset = "custom_dataset"
    output_dir: str="/content/AR-SSL4M-DEMO/pretrain/save"
    freeze_layers: bool=False
    num_freeze_layers: int=1
    save_model: bool=True
    save_optimizer: bool=False
    save_metrics: bool=True
    scheduler:str='CosineLR'
    min_lr: float=0
    pos_type: str='sincos3d'
    norm_pixel_loss: bool=True
    enable_profiling: bool=False  # Disable profiling for cleaner output
'''

with open('pretrain/configs/training.py', 'w') as f:
    f.write(training_config_content)

print("✅ Updated training configuration for Colab")


✅ Updated training configuration for Colab
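
To get a feel for how long one epoch is with these settings, the number of batches can be estimated from the file count and `batch_size_training`. The sketch assumes that all 13950 listed files are used for training on a single GPU; how the repository splits training and validation data is not shown here, so treat the result as an upper bound. Both helpers are hypothetical.

```python
import math


def batches_per_epoch(num_samples, batch_size, drop_last=False):
    """Batches in one pass over the data; the final partial batch is kept unless drop_last."""
    if drop_last:
        return num_samples // batch_size
    return math.ceil(num_samples / batch_size)


def effective_batch_size(batch_size, gradient_accumulation_steps=1):
    """Samples contributing to each optimizer update."""
    return batch_size * gradient_accumulation_steps


print(batches_per_epoch(13950, 4))
print(effective_batch_size(4, 1))
```

If batch size 4 does not fit in GPU memory, lowering `batch_size_training` and raising `gradient_accumulation_steps` by the same factor keeps the effective batch size unchanged.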


## 5. Start Pretraining


In [10]:
# Create save directory
os.makedirs("pretrain/save", exist_ok=True)
print("✅ Created save directory")


✅ Created save directory


In [None]:
# Start pretraining
print("🚀 Starting AR-SSL4M Pretraining...")
print("=" * 60)

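# Change to the pretrain directory (skipped if a previous run of this cell already did)
if os.path.basename(os.getcwd()) != "pretrain":
    os.chdir("pretrain")

# Run the training script
!python main.py --output_dir save --batch_size_training 4

print("=" * 60)
print("✅ Pretraining command finished (check the log above for errors)")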


🚀 Starting AR-SSL4M Pretraining...
2025-09-21 09:10:07.386865: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1758445807.413575    2764 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1758445807.421514    2764 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
W0000 00:00:1758445807.440899    2764 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1758445807.440922    2764 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1758445807.440930    2764 computation_pl
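
A notebook `!` command does not raise an exception when the script exits with an error, so any message printed after it appears regardless of the outcome. The sketch below runs a command through `subprocess` and reports its exit code instead. `run_and_report` is a hypothetical helper, and in some notebook setups the child's output may not appear in the cell, so it is meant as a complement to the `!python` call rather than a replacement.

```python
import subprocess
import sys


def run_and_report(cmd):
    """Run a command, print whether it succeeded, and return its exit code."""
    completed = subprocess.run(cmd)
    if completed.returncode == 0:
        print("Command finished with exit code 0")
    else:
        print(f"Command failed with exit code {completed.returncode}")
    return completed.returncode


# In the notebook, from the pretrain directory, for example:
# run_and_report([sys.executable, "main.py",
#                 "--output_dir", "save", "--batch_size_training", "4"])
```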

## 6. Check Results and Download Model


In [None]:
# Check training results
import os
import glob

save_dir = "save"
if os.path.exists(save_dir):
    files = os.listdir(save_dir)
    print(f"📁 Files in save directory:")
    for file in files:
        file_path = os.path.join(save_dir, file)
        if os.path.isfile(file_path):
            size_mb = os.path.getsize(file_path) / (1024 * 1024)
            print(f"  - {file} ({size_mb:.1f} MB)")
else:
    print("❌ Save directory not found")


In [None]:
# Download the trained model
from google.colab import files
import zipfile

# Create a zip file with all results
zip_path = "/content/ar_ssl4m_pretrained_model.zip"

with zipfile.ZipFile(zip_path, 'w') as zipf:
    for root, dirs, files_list in os.walk("save"):
        for file in files_list:
            file_path = os.path.join(root, file)
            arcname = os.path.relpath(file_path, "save")
            zipf.write(file_path, arcname)
            print(f"Added to zip: {arcname}")

print(f"\n📦 Created zip file: {zip_path}")
print("⬇️ Downloading...")

# Download the zip file
files.download(zip_path)
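
The cell above opens the archive with `zipfile.ZipFile(zip_path, 'w')`, which stores files uncompressed by default. A possible variant is sketched below: it enables DEFLATE compression and verifies the archive before download. `zip_directory` is a hypothetical helper, and it assumes `zip_path` lies outside `src_dir` so the archive does not try to include itself.

```python
import os
import zipfile


def zip_directory(src_dir, zip_path):
    """Zip every file under src_dir with DEFLATE, verify it, and return the member names."""
    with zipfile.ZipFile(zip_path, "w", compression=zipfile.ZIP_DEFLATED) as zf:
        for root, _dirs, names in os.walk(src_dir):
            for name in names:
                full = os.path.join(root, name)
                zf.write(full, os.path.relpath(full, src_dir))
    with zipfile.ZipFile(zip_path) as zf:
        bad = zf.testzip()  # name of the first corrupt member, or None
        if bad is not None:
            raise RuntimeError(f"Corrupt member in archive: {bad}")
        return zf.namelist()


# In the notebook, for example:
# zip_directory("save", "/content/ar_ssl4m_pretrained_model.zip")
```

Model checkpoints often compress poorly, so the size reduction may be small; the integrity check is the more useful part.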


## 7. Training Summary


In [None]:
# Display training summary
print("🎯 AR-SSL4M Pretraining Summary")
print("=" * 50)
print(f"📊 Dataset: STOIC ({len(npy_files) if 'npy_files' in locals() and npy_files else 'Unknown'} samples)")
print(f"📂 Data path: {actual_data_path if 'actual_data_path' in locals() and actual_data_path else 'Not found'}")
print(f"🏗️ Model: AR-SSL4M (91.3M parameters)")
print(f"📐 Image size: 128×128×128")
print(f"🔢 Patch size: 16×16×16")
print(f"📦 Batch size: 4")
print(f"🎓 Learning rate: 1e-4")
print(f"🔄 Epochs: 1")
print(f"💾 Results saved to: save/")
print(f"🔗 GitHub Repo: https://github.com/tanglehunter00/AR-SSL4M-DEMO")
print("=" * 50)
print("✅ Training completed successfully!")
print("\n📝 Next steps:")
print("1. Download the model zip file")
print("2. Use the pretrained model for downstream tasks")
print("3. Fine-tune on specific medical imaging tasks (segmentation, classification, etc.)")
print("4. Experiment with different hyperparameters for better performance")

