# AR-SSL4M Pretraining on Google Colab

This notebook handles the setup and pretraining of the AR-SSL4M model using data from Google Drive.

In [None]:
# Attach Google Drive to the Colab runtime so files under MyDrive can be read
# through the local filesystem.
from google.colab import drive

drive_mount_point = '/content/drive'
drive.mount(drive_mount_point)

In [None]:
# Sanity check: confirm the dataset folder is visible on the mounted Drive
# and show what it contains.
import os

dataset_path = '/content/drive/MyDrive/dataset/demo'
dataset_present = os.path.exists(dataset_path)

if not dataset_present:
    print(f"Dataset NOT found at {dataset_path}. Please check your Drive structure.")
else:
    print(f"Dataset found at {dataset_path}")
    dataset_entries = os.listdir(dataset_path)
    print(dataset_entries)

In [None]:
# Data Verification and Cleaning
# This cell checks all .npy files in the dataset to ensure they are valid and not empty.
# It will regenerate the 'colab_train_list.txt' excluding any corrupted files.

import os
import numpy as np
from tqdm import tqdm

drive_dataset_path = '/content/drive/MyDrive/dataset/demo'
patch_dir = os.path.join(drive_dataset_path, 'patch_random_spatial')
list_file_path = os.path.join(drive_dataset_path, 'colab_train_list.txt')

# Patch shape this check treats as "normal"; any other shape triggers a full read.
EXPECTED_PATCH_SHAPE = (128, 128, 128)


def check_npy_file(path, expected_shape=EXPECTED_PATCH_SHAPE):
    """Classify a single .npy file.

    Args:
        path: Path of the .npy file to inspect.
        expected_shape: Array shape considered normal. Files with another
            shape are read in full to prove they are loadable.

    Returns:
        A ``(status, detail)`` tuple where ``status`` is one of:
        ``'ok'`` (loadable, expected shape), ``'shape_mismatch'`` (loadable,
        non-empty, different shape), ``'empty'`` (array has zero elements) or
        ``'corrupted'`` (loading raised; ``detail`` holds the error text).
    """
    try:
        # mmap only parses the header, so the common case stays cheap.
        data = np.load(path, mmap_mode='r')
        if data.size == 0 or data.shape != tuple(expected_shape):
            # Suspicious file: read it completely so truncated payloads surface here.
            data = np.load(path)
            if data.size == 0:
                return 'empty', ''
            return 'shape_mismatch', f"shape {data.shape}, expected {tuple(expected_shape)}"
        return 'ok', ''
    except Exception as e:
        # Any load failure means the file cannot be used for training.
        return 'corrupted', str(e)


valid_files = []
corrupted_files = []
shape_mismatch_files = []  # loadable but not EXPECTED_PATCH_SHAPE; still kept in the list

if os.path.exists(patch_dir):
    # os.listdir returns names in arbitrary order; sort so the list file is reproducible.
    npy_files = sorted(name for name in os.listdir(patch_dir) if name.endswith('.npy'))
    print(f"Checking {len(npy_files)} files in {patch_dir}...")

    for name in tqdm(npy_files):
        full_path = os.path.join(patch_dir, name)
        status, detail = check_npy_file(full_path)

        if status == 'empty':
            print(f"Skipping empty file: {name}")
            corrupted_files.append(full_path)
        elif status == 'corrupted':
            print(f"Error reading {name}: {detail}")
            corrupted_files.append(full_path)
        else:
            if status == 'shape_mismatch':
                # Previously these passed silently; keep them but make them visible.
                shape_mismatch_files.append(full_path)
            valid_files.append(full_path)

    # Update the list file with only valid paths
    with open(list_file_path, 'w') as list_file:
        list_file.write('\n'.join(valid_files))

    print("\nVerification complete.")
    print(f"Valid files: {len(valid_files)}")
    print(f"Corrupted/Empty files removed: {len(corrupted_files)}")
    print(f"Updated training list at: {list_file_path}")

    if shape_mismatch_files:
        print(
            f"WARNING: {len(shape_mismatch_files)} valid files do not have shape "
            f"{EXPECTED_PATCH_SHAPE} (kept in the list; see shape_mismatch_files)."
        )
    if not valid_files:
        print("WARNING: no valid .npy files were found; the training list is empty.")

else:
    print(f"Patch directory not found: {patch_dir}")

In [None]:
# Clone the repository (if not already present)
# Cloning from your GitHub repository as requested
!git clone https://github.com/tanglehunter00/AR-SSL4M-DEMO.git

# IMPORTANT: If you are running this notebook and the code is NOT on Drive,
# you need to upload the code files to Colab runtime.

project_root = '/content/AR-SSL4M-DEMO' 
import os
if os.path.exists(project_root):
    %cd {project_root}
else:
    print("Project root not found. Please clone or upload your code.")

In [None]:
# Install dependencies
!pip install timm monai transformers fire

In [None]:
# Build the training list file on Google Drive
# The dataset config must point at a list of absolute .npy paths that are valid
# inside Colab (lists generated on a local PC contain paths that do not exist here),
# so the list is generated from the Drive folder.
#
# If the verification cell has been run in this kernel, the files it flagged in
# `corrupted_files` are left out, so this cell no longer undoes that cleaning.

import os

drive_dataset_path = '/content/drive/MyDrive/dataset/demo'
patch_dir = os.path.join(drive_dataset_path, 'patch_random_spatial')
list_file_path = os.path.join(drive_dataset_path, 'colab_train_list.txt')


def build_train_list(directory, excluded_paths=()):
    """Return the sorted full paths of the .npy files in ``directory``.

    Args:
        directory: Folder to scan (not recursive).
        excluded_paths: Full paths to leave out of the result.

    Returns:
        List of ``os.path.join(directory, name)`` for every ``*.npy`` entry,
        in sorted name order, minus anything in ``excluded_paths``.
    """
    excluded = set(excluded_paths)
    # os.listdir order is arbitrary; sorting keeps the list file reproducible.
    candidates = (
        os.path.join(directory, name)
        for name in sorted(os.listdir(directory))
        if name.endswith('.npy')
    )
    return [path for path in candidates if path not in excluded]


if os.path.exists(patch_dir):
    # Present only when the verification cell ran earlier in this kernel.
    known_bad_files = globals().get('corrupted_files', [])
    full_paths = build_train_list(patch_dir, known_bad_files)

    with open(list_file_path, 'w') as list_file:
        list_file.write('\n'.join(full_paths))
    print(f"Created training list at {list_file_path} with {len(full_paths)} files.")
    if known_bad_files:
        print(f"Excluded {len(known_bad_files)} files flagged as corrupted/empty by the verification cell.")
else:
    print("Patch directory not found. Please ensure 'patch_random_spatial' exists inside 'dataset/demo'.")


In [None]:
# Modify configs/datasets.py to use the Colab path
# The codebase reads its dataset settings from pretrain/configs/datasets.py,
# so that file is overwritten with a version whose list paths point at the
# list file on Drive (`list_file_path`, set by an earlier cell).

import os

config_path = 'pretrain/configs/datasets.py'
if not os.path.isdir(os.path.dirname(config_path)):
    # The working directory is not the repository root (for example after a
    # `%cd pretrain`); resolve against the clone location from the clone cell.
    config_path = os.path.join(project_root, config_path)

# `!r` emits the path as a properly quoted/escaped Python string literal, so the
# generated file stays valid even if the path contains quotes or backslashes.
new_config_content = f"""
from dataclasses import dataclass

@dataclass
class custom_dataset:
    dataset: str = "custom_dataset"
    file: str = "image_dataset.py"
    train_split: str = "train"
    test_split: str = "validation"
    # Pointing to the generated list file on Drive
    spatial_path: str = {list_file_path!r}
    contrast_path: str = {list_file_path!r}
    semantic_path: str = {list_file_path!r}
    img_size = [128, 128, 128]
    patch_size = [16, 16, 16]
    attention_type = 'prefix'
    add_series_data = False
    add_spatial_data = True
    is_subset = False
    series_length = 4
"""

with open(config_path, 'w') as config_file:
    config_file.write(new_config_content)

print("Updated datasets.py configuration.")

In [None]:
# Run Pretraining
# Set batch size to 32 as requested
# Using newModel.py (ensure you imported ReconModel from newModel in main.py if that was the intent, 
# OR ensure model.py is updated with your changes)

# Note: The user requested to use 'newModel.py'. 
# You might need to rename newModel.py to model.py OR modify main.py to import from newModel.
# Here we assume main.py still imports from model.py. 
# We will rename newModel.py to model.py for this run to ensure the new logic is used.

!cp pretrain/newModel.py pretrain/model.py

# Switch to pretrain directory
%cd pretrain

!mkdir -p /content/drive/MyDrive/dataset/demo/output 

# Run training
!python main.py \
    --enable_fsdp False \
    --output_dir /content/drive/MyDrive/dataset/demo/output \
    --batch_size_training 16 \
    --num_epochs 30 \
    --save_metrics True \
    --num_workers_dataloader 4