# AR-SSL4M Pretraining on Google Colab

This notebook handles the setup and pretraining of the AR-SSL4M model using data from Google Drive.

In [1]:
# Mount Google Drive
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
# Check dataset path
import os
dataset_path = '/content/drive/MyDrive/dataset/LIDC-IDRI'
print(os.listdir('/content/drive/MyDrive/dataset/LIDC-IDRI'))
if os.path.exists(dataset_path):
    print(f"Dataset found at {dataset_path}")
    print(os.listdir(dataset_path))
else:
    print(f"Dataset NOT found at {dataset_path}. Please check your Drive structure.")

['LIDC-IDRI', 'AR-SSL4M-DEMO', 'patch_random_spatial', 'Untitled folder', 'colab_train_list.txt', 'output']
Dataset found at /content/drive/MyDrive/dataset/LIDC-IDRI
['LIDC-IDRI', 'AR-SSL4M-DEMO', 'patch_random_spatial', 'Untitled folder', 'colab_train_list.txt', 'output']


In [None]:
# Data Verification and Cleaning (Spatial / LIDC)
# Checks .npy files in patch_random_spatial and patch_random_lidc.
# Regenerates list excluding corrupted files. Run list generation (next cell) after this.

import os
import numpy as np
from tqdm import tqdm

drive_dataset_path = '/content/drive/MyDrive/dataset/LIDC-IDRI'
list_dir = os.path.join(drive_dataset_path, 'pretrain_lists')
os.makedirs(list_dir, exist_ok=True)

# Check both spatial dirs
patch_dirs_to_check = [
    os.path.join(drive_dataset_path, 'patch_random_spatial'),
    os.path.join(drive_dataset_path, 'AR-SSL4M-DEMO', 'pretrain', 'data', 'patch_random_lidc'),
]
valid_files = []
corrupted_files = []

for patch_dir in patch_dirs_to_check:
    if not os.path.exists(patch_dir):
        print(f"Skipping (not found): {patch_dir}")
        continue
    npy_files = [f for f in os.listdir(patch_dir) if f.endswith('.npy')]
    print(f"Checking {len(npy_files)} files in {patch_dir}...")
    for f in tqdm(npy_files):
        full_path = os.path.join(patch_dir, f)
        try:
            data = np.load(full_path, mmap_mode='r')
            if data.size == 0 or data.shape != (128, 128, 128):
                data = np.load(full_path)
                if data.size == 0:
                    corrupted_files.append(full_path)
                    continue
            valid_files.append(full_path)
        except Exception as e:
            corrupted_files.append(full_path)

spatial_list_path = os.path.join(list_dir, 'train_spatial.txt')
with open(spatial_list_path, 'w') as f:
    f.write('\n'.join(valid_files))
print(f"\nVerification complete. Valid: {len(valid_files)}, Corrupted: {len(corrupted_files)}")
print(f"Spatial list saved to: {spatial_list_path}")

Checking 24850 files in /content/drive/MyDrive/dataset/LIDC-IDRI/patch_random_spatial...


100%|██████████| 24850/24850 [4:35:15<00:00,  1.50it/s]


Verification complete.
Valid files: 24850
Corrupted/Empty files removed: 0
Updated training list at: /content/drive/MyDrive/dataset/LIDC-IDRI/colab_train_list.txt





In [None]:
# Generate BraTS Contrast List from tar.gz (NO extraction - uses tarfile.getnames() only)
# Only the list file is saved to Drive. No npy files are extracted to disk.

import os
import tarfile

drive_dataset_path = '/content/drive/MyDrive/dataset'
tar_root = os.path.join(drive_dataset_path, 'pretrain', 'BraTS23_Data', 'tar_data')  # adjust to your BraTS tar location
list_dir = os.path.join(drive_dataset_path, 'LIDC-IDRI', 'pretrain_lists')
os.makedirs(list_dir, exist_ok=True)
contrast_list_path = os.path.join(list_dir, 'train_contrast.txt')

lines = []
if os.path.exists(tar_root):
    for root, _, files in os.walk(tar_root):
        for f in files:
            if f.endswith('.tar.gz'):
                tar_path = os.path.join(root, f)
                try:
                    with tarfile.open(tar_path, 'r:gz') as tar:
                        names = tar.getnames()
                        for n in names:
                            if n.endswith('.t1n.npy'):
                                base = n[:-len('.t1n.npy')]
                                if base+'.t1c.npy' in names and base+'.t2w.npy' in names and base+'.t2f.npy' in names:
                                    lines.append(f"{tar_path}:{base}")
                except Exception as e:
                    print(f"Skip {tar_path}: {e}")
    with open(contrast_list_path, 'w') as f:
        f.write('\n'.join(lines))
    print(f"BraTS contrast: {len(lines)} samples -> {contrast_list_path}")
else:
    print(f"BraTS tar root not found: {tar_root}. Skip contrast list.")

In [2]:
# Clone the repository (if not already present)
# Cloning from your GitHub repository as requested
!git clone https://github.com/tanglehunter00/AR-SSL4M-DEMO.git

# IMPORTANT: If you are running this notebook and the code is NOT on Drive,
# you need to upload the code files to Colab runtime.

project_root = '/content/AR-SSL4M-DEMO'
import os
if os.path.exists(project_root):
    %cd {project_root}
else:
    print("Project root not found. Please clone or upload your code.")

Cloning into 'AR-SSL4M-DEMO'...
remote: Enumerating objects: 370, done.[K
remote: Counting objects: 100% (96/96), done.[K
remote: Compressing objects: 100% (70/70), done.[K
remote: Total 370 (delta 61), reused 52 (delta 26), pack-reused 274 (from 1)[K
Receiving objects: 100% (370/370), 1.72 MiB | 14.81 MiB/s, done.
Resolving deltas: 100% (181/181), done.
/content/AR-SSL4M-DEMO


In [3]:
# Install dependencies
!pip install timm monai transformers fire

Collecting monai
  Downloading monai-1.5.2-py3-none-any.whl.metadata (13 kB)
Collecting fire
  Downloading fire-0.7.1-py3-none-any.whl.metadata (5.8 kB)
Downloading monai-1.5.2-py3-none-any.whl (2.7 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.7/2.7 MB[0m [31m57.9 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading fire-0.7.1-py3-none-any.whl (115 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m115.9/115.9 kB[0m [31m14.2 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: fire, monai
Successfully installed fire-0.7.1 monai-1.5.2


In [None]:
# (Optional) Generate DeepLesion Semantic List
# Run if you have DeepLesion npy data at .../DeepLesion/data/npy/

import os
import random
random.seed(0)

drive_dataset_path = '/content/drive/MyDrive/dataset'
npy_dir = os.path.join(drive_dataset_path, 'pretrain', 'DeepLesion', 'data', 'npy')
list_dir = os.path.join(drive_dataset_path, 'LIDC-IDRI', 'pretrain_lists')
semantic_list_path = os.path.join(list_dir, 'train_semantic.txt')

if os.path.exists(npy_dir):
    all_data_list = []
    for num in range(8):
        data_list = [os.path.join(npy_dir, x) for x in os.listdir(npy_dir) if x.endswith(f'_{num+1}.npy')]
        n_samples = min(20000, len(data_list) // 4) if len(data_list) >= 4 else 0
        for _ in range(n_samples):
            choose_list = random.sample(data_list, 4)
            all_data_list.append(','.join(choose_list))
    os.makedirs(list_dir, exist_ok=True)
    with open(semantic_list_path, 'w') as f:
        f.write('\n'.join(all_data_list))
    print(f"DeepLesion semantic: {len(all_data_list)} samples -> {semantic_list_path}")
else:
    print(f"DeepLesion npy dir not found: {npy_dir}. Skip semantic list.")


Created training list at /content/drive/MyDrive/dataset/LIDC-IDRI/colab_train_list.txt with 24850 files.


In [4]:
# Modify newFullPretrain/configs/datasets.py to use generated list paths

import os

list_dir = '/content/drive/MyDrive/dataset/LIDC-IDRI/pretrain_lists'
os.makedirs(list_dir, exist_ok=True)
spatial_path = os.path.join(list_dir, 'train_spatial.txt')
contrast_path = os.path.join(list_dir, 'train_contrast.txt')
semantic_path = os.path.join(list_dir, 'train_semantic.txt')

# Create empty files if contrast/semantic lists don't exist (dataset expects readable files)
for p in [contrast_path, semantic_path]:
    if not os.path.exists(p):
        open(p, 'w').close()

add_series_data = (os.path.getsize(contrast_path) > 0) or (os.path.getsize(semantic_path) > 0)

config_path = 'newFullPretrain/configs/datasets.py'
new_config_content = f"""
from dataclasses import dataclass

@dataclass
class custom_dataset:
    dataset: str = "custom_dataset"
    file: str = "image_dataset.py"
    train_split: str = "train"
    test_split: str = "validation"
    spatial_path: str = "{spatial_path}"
    contrast_path: str = "{contrast_path}"
    semantic_path: str = "{semantic_path}"
    img_size = [128, 128, 128]
    patch_size = [16, 16, 16]
    attention_type = 'prefix'
    add_series_data = {str(add_series_data)}
    add_spatial_data = True
    is_subset = False
    series_length = 4
"""

with open(config_path, 'w') as f:
    f.write(new_config_content)

print(f"Updated newFullPretrain config. add_series_data={add_series_data}")

Updated newFullPretrain config. add_series_data=True


In [5]:
# Run Pretraining (using newFullPretrain - supports tar.gz BraTS, LIDC, DeepLesion)

%cd newFullPretrain

!mkdir -p /content/drive/MyDrive/dataset/LIDC-IDRI/output

!python main.py \
    --enable_fsdp False \
    --output_dir /content/drive/MyDrive/dataset/LIDC-IDRI/output \
    --batch_size_training 64 \
    --num_epochs 1 \
    --save_metrics True \
    --num_workers_dataloader 4

/content/AR-SSL4M-DEMO/newFullPretrain
2026-02-19 11:00:24.794219: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2026-02-19 11:00:24.812762: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1771498824.833396    2227 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1771498824.840488    2227 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
W0000 00:00:1771498824.860477    2227 computation_placer.cc:177] computation placer already registered. 