# AR-SSL4M Pretraining on Google Colab

This notebook handles the setup and pretraining of the AR-SSL4M model using data from Google Drive.

In [1]:
# Mount Google Drive so the dataset stored under MyDrive is reachable from this runtime.
from google.colab import drive

DRIVE_MOUNT_POINT = '/content/drive'
drive.mount(DRIVE_MOUNT_POINT)

Mounted at /content/drive


In [None]:
# Check that the dataset folder is visible on the mounted Drive.
import os

dataset_path = '/content/drive/MyDrive/dataset/LIDC-IDRI'

# List the folder only after confirming it is a directory: an unguarded
# os.listdir() raises on a missing path, so the "NOT found" hint below would
# never be printed in exactly the situation it was written for.
if os.path.isdir(dataset_path):
    print(f"Dataset found at {dataset_path}")
    print(os.listdir(dataset_path))
else:
    print(f"Dataset NOT found at {dataset_path}. Please check your Drive structure.")

['LIDC-IDRI', 'AR-SSL4M-DEMO', 'patch_random_spatial', 'Untitled folder', 'output', 'pretrain_lists', 'colab_train_list.txt']
Dataset found at /content/drive/MyDrive/dataset/LIDC-IDRI
['LIDC-IDRI', 'AR-SSL4M-DEMO', 'patch_random_spatial', 'Untitled folder', 'output', 'pretrain_lists', 'colab_train_list.txt']


In [None]:
# Data Verification and Cleaning (Spatial / LIDC)
# Checks .npy files in patch_random_spatial and patch_random_lidc.
# Regenerates list excluding corrupted files. Run list generation (next cell) after this.

import os
import numpy as np
from tqdm import tqdm

# Patch shape the check treats as "normal"; other shapes trigger a full load.
EXPECTED_PATCH_SHAPE = (128, 128, 128)
# Upper bound on per-file detail lines printed in the final report.
MAX_REPORTED = 20

drive_dataset_path = '/content/drive/MyDrive/dataset/LIDC-IDRI'
list_dir = os.path.join(drive_dataset_path, 'pretrain_lists')

# Check both spatial dirs
patch_dirs_to_check = [
    os.path.join(drive_dataset_path, 'patch_random_spatial'),
    os.path.join(drive_dataset_path, 'AR-SSL4M-DEMO', 'pretrain', 'data', 'patch_random_lidc'),
]


def check_patch_file(full_path, expected_shape=EXPECTED_PATCH_SHAPE):
    """Decide whether one .npy patch is usable.

    The file is first opened memory-mapped (cheap). Only when it looks empty or
    has an unexpected shape is it loaded completely, and it is rejected only if
    the fully loaded array is empty. Any exception while loading also rejects it.

    Args:
        full_path: Path of the .npy file to inspect.
        expected_shape: Shape that lets the file pass on the cheap check alone.

    Returns:
        (is_valid, note): ``note`` is the rejection reason for invalid files, a
        shape warning for valid files with an unexpected shape, else ``""``.
    """
    try:
        data = np.load(full_path, mmap_mode='r')
        shape_ok = data.shape == expected_shape
        if data.size == 0 or not shape_ok:
            data = np.load(full_path)
            if data.size == 0:
                return False, "empty array"
        if not shape_ok:
            # Same acceptance rule as before (non-empty => valid); now surfaced.
            return True, f"unexpected shape {data.shape}"
        return True, ""
    except Exception as e:
        # Keep the reason instead of discarding it, so failures can be diagnosed.
        return False, f"{type(e).__name__}: {e}"


valid_files = []
corrupted_files = []
corruption_reasons = {}  # path -> why the file was rejected
shape_warnings = {}      # path -> note for accepted files with an unexpected shape

for patch_dir in patch_dirs_to_check:
    if not os.path.exists(patch_dir):
        print(f"Skipping (not found): {patch_dir}")
        continue
    npy_files = [name for name in os.listdir(patch_dir) if name.endswith('.npy')]
    print(f"Checking {len(npy_files)} files in {patch_dir}...")
    for file_name in tqdm(npy_files):
        full_path = os.path.join(patch_dir, file_name)
        is_valid, note = check_patch_file(full_path)
        if is_valid:
            valid_files.append(full_path)
            if note:
                shape_warnings[full_path] = note
        else:
            corrupted_files.append(full_path)
            corruption_reasons[full_path] = note

print(f"\nVerification complete. Valid: {len(valid_files)}, Corrupted: {len(corrupted_files)}")
for bad_path in corrupted_files[:MAX_REPORTED]:
    print(f"  [corrupted] {bad_path}: {corruption_reasons[bad_path]}")
if len(corrupted_files) > MAX_REPORTED:
    print(f"  ... and {len(corrupted_files) - MAX_REPORTED} more corrupted files")
if shape_warnings:
    print(f"  [warning] {len(shape_warnings)} valid files do not have shape {EXPECTED_PATCH_SHAPE}")
    for odd_path in list(shape_warnings)[:MAX_REPORTED]:
        print(f"    {odd_path}: {shape_warnings[odd_path]}")

spatial_list_path = os.path.join(list_dir, 'train_spatial.txt')
if valid_files:
    # Create the list directory only when there is something to write, so a run
    # that found no data does not create directories under the Drive path.
    os.makedirs(list_dir, exist_ok=True)
    with open(spatial_list_path, 'w') as list_file:
        list_file.write('\n'.join(valid_files))
    print(f"Spatial list saved to: {spatial_list_path}")
else:
    # Do not replace an existing list with an empty one when nothing was found
    # (e.g. patch directories missing).
    print(f"No valid files found; leaving {spatial_list_path} untouched.")

Checking 24850 files in /content/drive/MyDrive/dataset/LIDC-IDRI/patch_random_spatial...


100%|██████████| 24850/24850 [4:35:15<00:00,  1.50it/s]


Verification complete.
Valid files: 24850
Corrupted/Empty files removed: 0
Updated training list at: /content/drive/MyDrive/dataset/LIDC-IDRI/colab_train_list.txt





In [None]:
import errno
import os
import tarfile
from google.colab import drive

# If the Drive connection dropped, uncomment the next line to remount.
# drive.mount('/content/drive', force_remount=True)

drive_dataset_path = '/content/drive/MyDrive/dataset'
tar_root = os.path.join(drive_dataset_path, 'pretrain', 'BraTS23_Data', 'tar_data')
list_dir = os.path.join(drive_dataset_path, 'LIDC-IDRI', 'pretrain_lists')
contrast_list_path = os.path.join(list_dir, 'train_contrast.txt')

# A sample is the common prefix of a '.t1n.npy' member whose three companion
# members are all present in the same archive.
ANCHOR_SUFFIX = '.t1n.npy'
COMPANION_SUFFIXES = ('.t1c.npy', '.t2w.npy', '.t2f.npy')


def is_drive_disconnect(error):
    """Return True when ``error`` looks like a dropped Drive mount (ENOTCONN)."""
    return (
        getattr(error, 'errno', None) == errno.ENOTCONN
        or "Transport endpoint is not connected" in str(error)
    )


# --- 1. Resume support: collect the tar files that already have entries ---
processed_tars = set()
if os.path.exists(contrast_list_path):
    with open(contrast_list_path, 'r') as f:
        for line in f:
            if ':' in line:
                # Each line is "<tar_path>:<base>"; keep the tar_path part.
                processed_tars.add(line.split(':')[0])
print(f"Skipping {len(processed_tars)} already processed tar files (Resume mode).")

if os.path.exists(tar_root):
    os.makedirs(list_dir, exist_ok=True)
    all_tars = [os.path.join(r, f) for r, _, files in os.walk(tar_root) for f in files if f.endswith('.tar.gz')]
    total_tars = len(all_tars)

    # --- 2. Open in append mode so earlier results are kept across runs ---
    with open(contrast_list_path, 'a') as out_f:
        for i, tar_path in enumerate(all_tars, 1):
            file_name = os.path.basename(tar_path)

            if tar_path in processed_tars:
                # Comment this print out if the skip messages are too noisy.
                print(f"[{i}/{total_tars}] Skipping: {file_name}")
                continue

            print(f"[{i}/{total_tars}] >>> START: {file_name}")
            try:
                print("  [Step 1/4] Opening...")
                with tarfile.open(tar_path, 'r:gz') as tar:
                    print("  [Step 2/4] Reading names...")
                    names = tar.getnames()
                names_set = set(names)

                print("  [Step 3/4] Matching...")
                entries = []
                for n in names:
                    if not n.endswith(ANCHOR_SUFFIX):
                        continue
                    base = n[:-len(ANCHOR_SUFFIX)]
                    if all(f"{base}{suffix}" in names_set for suffix in COMPANION_SUFFIXES):
                        entries.append(f"{tar_path}:{base}\n")

                # Write all entries of this tar in one go. Resume mode skips a
                # tar as soon as one line exists for it, so writing entry by
                # entry could leave a half-indexed tar that is never revisited.
                if entries:
                    out_f.write(''.join(entries))
                    out_f.flush()  # push to disk promptly in case the session dies
                print(f"  [Step 4/4] Done. Found {len(entries)} samples.")
            except Exception as e:
                print(f"  [!!! ERROR] {e}")
                if is_drive_disconnect(e):
                    print("\n[CRITICAL] Google Drive disconnected! Please remount and run again.")
                    break  # nothing else can succeed until Drive is remounted
            print("-" * 40)
else:
    print(f"BraTS tar root not found.")

In [2]:
# Clone the repository (if not already present)
# Cloning from your GitHub repository as requested
!git clone https://github.com/tanglehunter00/AR-SSL4M-DEMO.git

# IMPORTANT: If you are running this notebook and the code is NOT on Drive,
# you need to upload the code files to Colab runtime.

project_root = '/content/AR-SSL4M-DEMO'
import os
if os.path.exists(project_root):
    %cd {project_root}
else:
    print("Project root not found. Please clone or upload your code.")

Cloning into 'AR-SSL4M-DEMO'...
remote: Enumerating objects: 439, done.[K
remote: Counting objects: 100% (165/165), done.[K
remote: Compressing objects: 100% (120/120), done.[K
remote: Total 439 (delta 112), reused 84 (delta 45), pack-reused 274 (from 1)[K
Receiving objects: 100% (439/439), 1.75 MiB | 40.84 MiB/s, done.
Resolving deltas: 100% (232/232), done.
/content/AR-SSL4M-DEMO


In [3]:
# Install dependencies
!pip install timm monai transformers fire

Collecting monai
  Downloading monai-1.5.2-py3-none-any.whl.metadata (13 kB)
Collecting fire
  Downloading fire-0.7.1-py3-none-any.whl.metadata (5.8 kB)
Downloading monai-1.5.2-py3-none-any.whl (2.7 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.7/2.7 MB[0m [31m27.1 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading fire-0.7.1-py3-none-any.whl (115 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m115.9/115.9 kB[0m [31m16.1 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: fire, monai
Successfully installed fire-0.7.1 monai-1.5.2


In [None]:
# (Optional) Generate DeepLesion Semantic List
import os
import random

random.seed(0)

drive_dataset_path = '/content/drive/MyDrive/dataset'
npy_dir = os.path.join(drive_dataset_path, 'pretrain', 'DeepLesion', 'npy')
list_dir = os.path.join(drive_dataset_path, 'LIDC-IDRI', 'pretrain_lists')
semantic_list_path = os.path.join(list_dir, 'train_semantic.txt')


def build_semantic_samples(file_names, directory, group_size=4, max_per_suffix=20000,
                           num_suffixes=8, verbose=True):
    """Build comma-joined groups of files that share the same ``_<k>.npy`` suffix.

    For k = 1..num_suffixes, the files ending in ``_<k>.npy`` are collected and
    ``min(max_per_suffix, count // group_size)`` groups of ``group_size`` distinct
    files are drawn with the module-level ``random`` generator.

    Args:
        file_names: Bare file names found in ``directory`` (any order).
        directory: Directory prefix joined onto every file name.
        group_size: Number of files per generated sample.
        max_per_suffix: Cap on the number of samples per suffix.
        num_suffixes: Highest suffix number to consider.
        verbose: Print per-suffix progress when True.

    Returns:
        List of strings, each ``group_size`` full paths joined by ','.
    """
    # os.listdir() returns names in arbitrary order; sorting makes the draw
    # depend only on the seed and the set of files, i.e. reproducible.
    ordered_names = sorted(file_names)
    samples = []
    for num in range(num_suffixes):
        target_suffix = f'_{num+1}.npy'
        data_list = [os.path.join(directory, x) for x in ordered_names if x.endswith(target_suffix)]
        num_available = len(data_list)
        # Integer division already yields 0 when there are fewer than group_size files.
        n_samples = min(max_per_suffix, num_available // group_size)
        if verbose:
            print(f"  > Suffix {target_suffix}: Found {num_available} files. Generating {n_samples} samples.")

        if n_samples == 0:
            if verbose:
                print("    - Skipping (not enough files).")
            continue

        if verbose:
            print("    - Sampling progress: ", end="")
        for i in range(n_samples):
            samples.append(','.join(random.sample(data_list, group_size)))
            # Print a small marker every 5000 samples so long runs do not look frozen.
            if verbose and (i + 1) % 5000 == 0:
                print(f"{i+1}..", end="", flush=True)
        if verbose:
            print(" Done.")
    return samples


if os.path.exists(npy_dir):
    print(f"Start processing DeepLesion from: {npy_dir}")
    print("-" * 60)

    # Step 1: fetch the complete file listing (the most I/O-heavy step).
    print("[Step 1/3] Fetching all files from directory... (This may take a while on GDrive)")
    try:
        all_files = os.listdir(npy_dir)
        print(f"  Done. Total files found in directory: {len(all_files)}")
    except OSError as e:
        print(f"  [!!! ERROR] Failed to list directory: {e}")
        all_files = []

    if all_files:
        # Step 2: group files by suffix (_1 .. _8) and sample combinations.
        print("[Step 2/3] Categorizing files by suffix (_1.npy to _8.npy)...")
        all_data_list = build_semantic_samples(all_files, npy_dir)

        print("-" * 60)
        # Step 3: save the result.
        print(f"[Step 3/3] Writing {len(all_data_list)} samples to {semantic_list_path}...")
        try:
            # Create the list directory only once there is something to write.
            os.makedirs(list_dir, exist_ok=True)
            with open(semantic_list_path, 'w') as f:
                f.write('\n'.join(all_data_list))
            print("SUCCESS! DeepLesion semantic list ready.")
        except OSError as e:
            print(f"  [!!! ERROR] Failed to write file: {e}")

else:
    print(f"DeepLesion npy dir not found: {npy_dir}. Skip semantic list.")


Start processing DeepLesion from: /content/drive/MyDrive/dataset/pretrain/DeepLesion/npy
------------------------------------------------------------
[Step 1/3] Fetching all files from directory... (This may take a while on GDrive)
  Done. Total files found in directory: 9816
[Step 2/3] Categorizing files by suffix (_1.npy to _8.npy)...
  > Suffix _1.npy: Found 247 files. Generating 61 samples.
    - Sampling progress:  Done.
  > Suffix _2.npy: Found 2176 files. Generating 544 samples.
    - Sampling progress:  Done.
  > Suffix _3.npy: Found 1672 files. Generating 418 samples.
    - Sampling progress:  Done.
  > Suffix _4.npy: Found 1284 files. Generating 321 samples.
    - Sampling progress:  Done.
  > Suffix _5.npy: Found 2394 files. Generating 598 samples.
    - Sampling progress:  Done.
  > Suffix _6.npy: Found 495 files. Generating 123 samples.
    - Sampling progress:  Done.
  > Suffix _7.npy: Found 681 files. Generating 170 samples.
    - Sampling progress:  Done.
  > Suffix _8.

In [None]:
# Legacy training list (older workflow).
# Update dataset configuration paths dynamically:
# the dataset config must point at list files stored on Google Drive, and lists
# produced on a local PC may contain absolute local paths. This cell therefore
# rebuilds the list from the Drive copy of the patches.

import os

drive_dataset_path = '/content/drive/MyDrive/dataset/LIDC-IDRI'
patch_dir = os.path.join(drive_dataset_path, 'patch_random_spatial')
list_file_path = os.path.join(drive_dataset_path, 'colab_train_list.txt')

if os.path.exists(patch_dir):
    npy_files = [f for f in os.listdir(patch_dir) if f.endswith('.npy')]
    full_paths = [os.path.join(patch_dir, f) for f in npy_files]

    with open(list_file_path, 'w') as f:
        f.write('\n'.join(full_paths))
    print(f"Created training list at {list_file_path} with {len(full_paths)} files.")
else:
    # The message used to name 'dataset/demo', which is not the folder checked
    # above; report the actual location instead.
    print(f"Patch directory not found: {patch_dir}. Please ensure 'patch_random_spatial' exists inside 'dataset/LIDC-IDRI'.")


Created training list at /content/drive/MyDrive/dataset/LIDC-IDRI/colab_train_list.txt with 24850 files.


In [4]:
# Modify newFullPretrain/configs/datasets.py to use generated list paths

import os

# ========== Data selection: change this value only ==========
USE_MODE = "all"  # options: "lidc_only" | "lidc_brats" | "lidc_deeplesion" | "brats_only" | "all"
# =============================================================

list_dir = '/content/drive/MyDrive/dataset/LIDC-IDRI/pretrain_lists'
os.makedirs(list_dir, exist_ok=True)
spatial_path = os.path.join(list_dir, 'train_spatial.txt')
contrast_path = os.path.join(list_dir, 'train_contrast.txt')
semantic_path = os.path.join(list_dir, 'train_semantic.txt')

# Empty placeholder list: an excluded data source is pointed here so the real
# list files are never modified.
empty_path = os.path.join(list_dir, '_empty.txt')
if not os.path.exists(empty_path):
    open(empty_path, 'w').close()

# mode -> (use contrast list, use semantic list, add_spatial_data)
MODE_TABLE = {
    "lidc_only":       (False, False, True),
    "lidc_brats":      (True,  False, True),
    "lidc_deeplesion": (False, True,  True),
    "brats_only":      (True,  False, False),
    "all":             (True,  True,  True),
}
if USE_MODE not in MODE_TABLE:
    # A mistyped mode used to fall through to "all" without any notice.
    raise ValueError(f"Unknown USE_MODE {USE_MODE!r}; expected one of {sorted(MODE_TABLE)}")
use_contrast, use_semantic, add_spatial_data = MODE_TABLE[USE_MODE]


def select_list(path, wanted):
    """Return ``path`` when the source is selected and its list exists, else ``empty_path``."""
    if wanted and os.path.exists(path):
        return path
    if wanted:
        # Previously a missing list crashed later in os.path.getsize().
        print(f"[WARN] list file not found, treating as empty: {path}")
    return empty_path


def count_lines(path):
    """Return the number of lines in ``path`` (0 when the file does not exist)."""
    if not os.path.exists(path):
        return 0
    with open(path) as fh:  # closed deterministically, unlike open(...).readlines()
        return sum(1 for _ in fh)


# Effective paths for the chosen mode (the original list files stay untouched).
effective_contrast = select_list(contrast_path, use_contrast)
effective_semantic = select_list(semantic_path, use_semantic)

add_series_data = (os.path.getsize(effective_contrast) > 0) or (os.path.getsize(effective_semantic) > 0)

project_root = '/content/AR-SSL4M-DEMO'
config_path = os.path.join(project_root, 'newFullPretrain', 'configs', 'datasets.py')
new_config_content = f"""
from dataclasses import dataclass

@dataclass
class custom_dataset:
    dataset: str = "custom_dataset"
    file: str = "image_dataset.py"
    train_split: str = "train"
    test_split: str = "validation"
    spatial_path: str = "{spatial_path}"
    contrast_path: str = "{effective_contrast}"
    semantic_path: str = "{effective_semantic}"
    img_size = [128, 128, 128]
    patch_size = [16, 16, 16]
    attention_type = 'prefix'
    add_series_data = {str(add_series_data)}
    add_spatial_data = {str(add_spatial_data)}
    is_subset = False
    series_length = 4
"""

with open(config_path, 'w') as f:
    f.write(new_config_content)

# Report only the samples that the written config actually enables.
n_spatial = count_lines(spatial_path) if add_spatial_data else 0
n_contrast = count_lines(effective_contrast)
n_semantic = count_lines(effective_semantic)
total = n_spatial + (n_contrast + n_semantic if add_series_data else 0)

print(f"USE_MODE={USE_MODE}, add_series_data={add_series_data}")
print(f"  样本数: spatial={n_spatial}, contrast={n_contrast}, semantic={n_semantic}, 合计≈{total}")

USE_MODE=all, add_series_data=True
  样本数: spatial=24850, contrast=124100, semantic=2451, 合计≈151401


In [None]:
# Legacy config writer for newFullPretrain/configs/datasets.py.
# Always enables spatial data and points at the real contrast/semantic lists;
# it writes the same config file as the USE_MODE cell, so whichever runs last wins.

import os

project_root = '/content/AR-SSL4M-DEMO'
config_path = os.path.join(project_root, 'newFullPretrain', 'configs', 'datasets.py')

list_dir = '/content/drive/MyDrive/dataset/LIDC-IDRI/pretrain_lists'
os.makedirs(list_dir, exist_ok=True)
spatial_path = os.path.join(list_dir, 'train_spatial.txt')
contrast_path = os.path.join(list_dir, 'train_contrast.txt')
semantic_path = os.path.join(list_dir, 'train_semantic.txt')

# The dataset expects readable list files, so touch any series list that is missing.
for p in (contrast_path, semantic_path):
    if os.path.exists(p):
        continue
    open(p, 'w').close()

# Series data is switched on as soon as either series list is non-empty.
add_series_data = any(os.path.getsize(series_list) > 0 for series_list in (contrast_path, semantic_path))

new_config_content = f"""
from dataclasses import dataclass

@dataclass
class custom_dataset:
    dataset: str = "custom_dataset"
    file: str = "image_dataset.py"
    train_split: str = "train"
    test_split: str = "validation"
    spatial_path: str = "{spatial_path}"
    contrast_path: str = "{contrast_path}"
    semantic_path: str = "{semantic_path}"
    img_size = [128, 128, 128]
    patch_size = [16, 16, 16]
    attention_type = 'prefix'
    add_series_data = {add_series_data}
    add_spatial_data = True
    is_subset = False
    series_length = 4
"""

with open(config_path, 'w') as f:
    f.write(new_config_content)

print(f"Updated newFullPretrain config. add_series_data={add_series_data}")

Updated newFullPretrain config. add_series_data=True


In [5]:
# Run Pretraining (using newFullPretrain - supports tar.gz BraTS, LIDC, DeepLesion)

%cd newFullPretrain

!mkdir -p /content/drive/MyDrive/dataset/LIDC-IDRI/output

!python main.py \
    --enable_fsdp False \
    --output_dir /content/drive/MyDrive/dataset/LIDC-IDRI/output \
    --batch_size_training 64 \
    --num_epochs 1 \
    --save_metrics True \
    --num_workers_dataloader 4

[1;30;43m流式输出内容被截断，只能显示最后 5000 行内容。[0m
[Epoch 1/1] Step 425/2360 | loss: 0.586713 | 总用时: 25.532s | DataLoader取batch: 24.050s | 数据到设备: 0.129s | 前向: 0.484s | 反向: 0.859s | 梯度裁剪: 0.000s | 优化器步进: 0.010s
Training Epoch: 1:  18% 426/2360 [2:57:03<14:20:13, 26.69s/it][BraTS 混训缓存] 预加载 tar 225/1241: A-GLI-Part-01_BraTS-GLI-00289-000.tar.gz
  [复制] A-GLI-Part-01_BraTS-GLI-00289-000.tar.gz: 2.71s
  [解压] A-GLI-Part-01_BraTS-GLI-00289-000.tar.gz: 7.00s
[Epoch 1/1] Step 426/2360 | loss: 0.587101 | 总用时: 29.270s | DataLoader取batch: 27.775s | 数据到设备: 0.123s | 前向: 0.505s | 反向: 0.859s | 梯度裁剪: 0.000s | 优化器步进: 0.009s
[Epoch 1/1] Step 427/2360 | loss: 0.658114 | 总用时: 26.815s | DataLoader取batch: 25.347s | 数据到设备: 0.125s | 前向: 0.482s | 反向: 0.851s | 梯度裁剪: 0.000s | 优化器步进: 0.009s
Training Epoch: 1:  18% 428/2360 [2:57:59<14:38:02, 27.27s/it][BraTS 混训缓存] 预加载 tar 226/1241: A-GLI-Part-01_BraTS-GLI-00290-000.tar.gz
  [复制] A-GLI-Part-01_BraTS-GLI-00290-000.tar.gz: 2.27s
  [解压] A-GLI-Part-01_BraTS-GLI-00290-000.tar.gz: 

In [None]:
# Show the dataset config that training will read, or diagnostics when it is missing.
import os

config_path = '/content/AR-SSL4M-DEMO/newFullPretrain/configs/datasets.py'
if os.path.exists(config_path):
    with open(config_path) as config_file:
        print(config_file.read())
else:
    print(f"文件不存在: {config_path}")
    print("当前目录:", os.getcwd())
    repo_dir = '/content/AR-SSL4M-DEMO'
    # The repository itself may be missing (not cloned yet); listing it blindly
    # would raise and hide the diagnostics printed above.
    if os.path.isdir(repo_dir):
        print("目录内容:", os.listdir(repo_dir))
    else:
        print(f"目录不存在: {repo_dir}")


from dataclasses import dataclass

@dataclass
class custom_dataset:
    dataset: str = "custom_dataset"
    file: str = "image_dataset.py"
    train_split: str = "train"
    test_split: str = "validation"
    spatial_path: str = "/content/drive/MyDrive/dataset/LIDC-IDRI/pretrain_lists/train_spatial.txt"
    contrast_path: str = "/content/drive/MyDrive/dataset/LIDC-IDRI/pretrain_lists/train_spatial.txt"
    semantic_path: str = "/content/drive/MyDrive/dataset/LIDC-IDRI/pretrain_lists/train_spatial.txt"
    img_size = [128, 128, 128]
    patch_size = [16, 16, 16]
    attention_type = 'prefix'
    add_series_data = True
    add_spatial_data = True
    is_subset = False
    series_length = 4



In [None]:
import os

# Report how many lines each pretraining list file currently contains.
list_dir = '/content/drive/MyDrive/dataset/LIDC-IDRI/pretrain_lists'
paths = dict(
    spatial=os.path.join(list_dir, 'train_spatial.txt'),
    contrast=os.path.join(list_dir, 'train_contrast.txt'),
    semantic=os.path.join(list_dir, 'train_semantic.txt'),
)

for label, list_path in paths.items():
    if not os.path.exists(list_path):
        print(f"{label}: 文件不存在")
        continue
    with open(list_path, 'r') as handle:
        line_count = sum(1 for _ in handle)
    print(f"{label}: {line_count} 行")

spatial: 24850 行
contrast: 0 行
semantic: 0 行


In [None]:
   # Print the line count of the contrast list and a preview of its first line.
   contrast_list_file = '/content/drive/MyDrive/dataset/LIDC-IDRI/pretrain_lists/train_contrast.txt'
   with open(contrast_list_file) as f:
       lines = list(f)
   print(f"train_contrast.txt 行数: {len(lines)}")
   if len(lines) > 0:
       print(f"首行示例: {lines[0][:80]}...")

train_contrast.txt 行数: 0
