In [26]:
import argparse
import itertools
import os
from datetime import datetime

import numpy as np

def generate_sbatch_fmri(
    job_name='brainnat',
    hour=48,
    minute=0,
    gpu='',
    constraint="a100|h100",
    output_dir_base='/scratch/cl6707/Projects/fmri/Brain_Decoding/jobs/',
    script_name='Train.py',
    num_gpus=1,
    batch_size=21,
    model_name="multisubject_nat",
    data_path='/scratch/cl6707/Shared_Datasets/NSD_MindEye/Mindeye2',
    cache_dir='/scratch/cl6707/Shared_Datasets/NSD_MindEye/Mindeye2',
    multi_subject="1,2,5,7",
    subj=1,
    max_lr=3e-4,
    mixup_pct=0.33,
    num_epochs=150,
    use_prior=True,
    prior_scale=30,
    clip_scale=1,
    blurry_recon=False,
    blur_scale=0.5,
    use_image_aug=False,
    n_blocks=4,
    hidden_dim=512,
    num_sessions=40,
    ckpt_interval=999,
    ckpt_saving=True,
    wandb_log=True,
    num_heads=4,
    tome_r=2000,
    last_n_features=16,
    nat_depth=2,
    nat_num_neighbors=8,
    lr_scheduler_type='cycle',
    seed=42,
    use_mixer=False,
    new_test=True,
    wandb_project="BRAIN_NAT",
    multisubject_ckpt=None,
    full_attention=True,
    additional_args=None):
    """Build a Slurm sbatch script for one training run and write it to disk.

    The script requests one node, exports the batch-size / rendezvous
    environment variables, and runs ``accelerate launch <script_name>`` inside
    a Singularity container with the training options given here.

    The current date (``YYYYMMDD``) is appended to ``job_name``; the script is
    written to ``<output_dir_base>/<job_name>_<date>.sbatch`` (the directory is
    created if needed) and the matching ``sbatch <path>`` command is printed.

    Args:
        job_name: Base Slurm job name; also the stem of the output file.
        hour, minute: Components of the ``#SBATCH --time`` limit. ``minute``
            must be an int (it is zero-padded to two digits).
        gpu: Unused; kept for backward compatibility. GPU selection is done
            through ``constraint``.
        constraint: Value of ``#SBATCH --constraint``.
        output_dir_base: Directory that receives the ``.sbatch`` file.
        script_name: Script passed to ``accelerate launch``.
        num_gpus: Used for ``--gres=gpu:N`` and exported as ``NUM_GPUS``,
            which the launch line passes as ``--num_processes``.
        batch_size: Exported as ``BATCH_SIZE`` and passed as ``--batch_size``.
        multi_subject: Passed as ``--multi_subject='<value>'``.
        use_prior, blurry_recon, use_image_aug, ckpt_saving, wandb_log,
        use_mixer, new_test, full_attention: Booleans rendered as ``--name``
            when true and ``--no-name`` when false.
        multisubject_ckpt: When not ``None``, forwarded as
            ``--multisubject_ckpt=<value>``.
        additional_args: Extra command-line text appended after the generated
            options; either a single string or an iterable of arguments.
        Remaining keyword arguments (``model_name``, ``data_path``,
        ``cache_dir``, ``subj``, ``max_lr``, ...): forwarded verbatim as
        ``--<name>=<value>``.

    Returns:
        The full text of the generated sbatch script.
    """

    def _flag(name, enabled):
        # Boolean options use the "--x" / "--no-x" spelling.
        return f'--{name}' if enabled else f'--no-{name}'

    # Add current date to job name
    job_name = f"{job_name}_{datetime.now().strftime('%Y%m%d')}"

    # Options forwarded to the training script, in emission order.
    launch_args = [
        f'--data_path={data_path}',
        f'--cache_dir={cache_dir}',
        f'--model_name={model_name}',
        # Single quotes: this text sits inside the double-quoted `bash -c "..."`
        # payload, where a nested double quote would terminate the outer string.
        f"--multi_subject='{multi_subject}'",
        f'--subj={subj}',
        f'--batch_size={batch_size}',
        f'--max_lr={max_lr}',
        f'--mixup_pct={mixup_pct}',
        f'--num_epochs={num_epochs}',
        _flag('use_prior', use_prior),
        f'--prior_scale={prior_scale}',
        f'--clip_scale={clip_scale}',
        _flag('blurry_recon', blurry_recon),
        f'--blur_scale={blur_scale}',
        _flag('use_image_aug', use_image_aug),
        f'--n_blocks={n_blocks}',
        f'--hidden_dim={hidden_dim}',
        f'--num_sessions={num_sessions}',
        f'--ckpt_interval={ckpt_interval}',
        _flag('ckpt_saving', ckpt_saving),
        _flag('wandb_log', wandb_log),
        f'--num_heads={num_heads}',
        f'--tome_r={tome_r}',
        f'--last_n_features={last_n_features}',
        f'--nat_depth={nat_depth}',
        f'--nat_num_neighbors={nat_num_neighbors}',
        f'--lr_scheduler_type={lr_scheduler_type}',
        f'--seed={seed}',
        _flag('use_mixer', use_mixer),
        _flag('new_test', new_test),
        f'--wandb_project={wandb_project}',
        _flag('full_attention', full_attention),
    ]
    if multisubject_ckpt is not None:
        # NOTE(review): assumes the training script accepts --multisubject_ckpt;
        # confirm against its argument parser.
        launch_args.append(f'--multisubject_ckpt={multisubject_ckpt}')
    if additional_args:
        if isinstance(additional_args, str):
            launch_args.append(additional_args)
        else:
            launch_args.extend(str(arg) for arg in additional_args)

    lines = [
        '#!/bin/bash',
        '',
        f'#SBATCH --job-name={job_name}',
        '#SBATCH --nodes=1',
        '#SBATCH --cpus-per-task=16',
        '#SBATCH --mem=64GB',
        f'#SBATCH --time={hour}:{minute:02d}:00',
        f'#SBATCH --gres=gpu:{num_gpus}',
        f'#SBATCH --constraint="{constraint}"',
        '#SBATCH --account=pr_60_tandon_advanced',
        '',
        'overlay_ext3=/scratch/cl6707/dl-env/fMRI.ext3',
        f'export NUM_GPUS={num_gpus}  # Set to equal gres=gpu:#!',
        f'export BATCH_SIZE={batch_size} # 21 for multisubject / 24 for singlesubject (orig. paper used 42 for multisubject / 24 for singlesubject)',
        'export GLOBAL_BATCH_SIZE=$((BATCH_SIZE * NUM_GPUS))',
        '',
        '# Make sure another job doesnt use same port, here using random number',
        'export MASTER_PORT=$((RANDOM % (19000 - 11000 + 1) + 11000))',
        'export HOSTNAMES=$(scontrol show hostnames "$SLURM_JOB_NODELIST")',
        'export MASTER_ADDR=$(scontrol show hostnames "$SLURM_JOB_NODELIST" | head -n 1)',
        'export COUNT_NODE=$(scontrol show hostnames "$SLURM_JOB_NODELIST" | wc -l)',
        'echo MASTER_ADDR=${MASTER_ADDR}',
        'echo MASTER_PORT=${MASTER_PORT}',
        'echo WORLD_SIZE=${COUNT_NODE}',
        '',
        'singularity exec --nv \\',
        '    --overlay ${overlay_ext3}:ro \\',
        '    /scratch/work/public/singularity/cuda12.1.1-cudnn8.9.0-devel-ubuntu22.04.2.sif \\',
        '    /bin/bash -c "',
        'source /ext3/env.sh',
        'cd /scratch/cl6707/Projects/fmri/Brain_Decoding/Downstream',
        '',
        'export SSL_CERT_FILE=/scratch/cl6707/Shared_Datasets/cacert.pem',
        # ${NUM_GPUS} / ${MASTER_PORT} are expanded by the outer shell, which
        # exported them above, before the container command runs.
        f'accelerate launch --num_processes=${{NUM_GPUS}} --main_process_port=${{MASTER_PORT}} --mixed_precision=fp16 {script_name} \\',
    ]
    # One option per line, each ending in a backslash continuation.
    lines.extend(f'    {arg} \\' for arg in launch_args)
    # Closes the `bash -c "` string opened above.
    lines.append('"')
    text = '\n'.join(lines) + '\n'

    # Save the sbatch script to a file
    os.makedirs(output_dir_base, exist_ok=True)
    job_file = os.path.join(output_dir_base, f'{job_name}.sbatch')
    with open(job_file, 'w') as f:
        f.write(text)
    print(f'sbatch {job_file}')
    return text

def generate_ablation_jobs(base_params, param_ranges):
    """Generate one sbatch script per combination of the swept parameters.

    Every combination in the Cartesian product of ``param_ranges`` is merged
    over a copy of ``base_params`` and passed to ``generate_sbatch_fmri``,
    which writes the script file and prints its ``sbatch`` command.

    Args:
        base_params: Keyword arguments shared by every job. Must contain
            ``wandb_project``, which prefixes each job name. Any ``job_name``
            or ``model_name`` entry is overridden per job.
        param_ranges: Mapping of parameter name to the list of values to try.
            Swept values override same-named entries in ``base_params``.

    Returns:
        A list of ``(job_name, script_text)`` tuples in product order. The
        returned ``job_name`` has no date suffix; ``generate_sbatch_fmri``
        appends the date for the Slurm job name and the file name.
    """
    param_names = list(param_ranges)
    # Computed once so every job in the sweep carries the same stamp.
    date_stamp = datetime.now().strftime('%Y%m%d')

    jobs = []
    # Generate all combinations of parameter values
    for values in itertools.product(*param_ranges.values()):
        overrides = dict(zip(param_names, values))
        params = {**base_params, **overrides}

        # Unique name: <project>_<name>_<value>_... (no trailing underscore
        # when nothing is swept).
        name_parts = [str(params['wandb_project'])]
        name_parts.extend(f'{name}_{value}' for name, value in overrides.items())
        job_name = '_'.join(name_parts)

        params['model_name'] = f'{job_name}_{date_stamp}'
        # Set through the dict so a job_name in base_params cannot collide
        # with an explicit keyword argument.
        params['job_name'] = job_name

        job_script = generate_sbatch_fmri(**params)
        jobs.append((job_name, job_script))

    return jobs


In [28]:

# Example usage: sweep a few architecture settings over a fixed training setup.

# Settings shared by every job in the sweep.
base_params = dict(
    script_name="Train.py",
    data_path="/scratch/cl6707/Shared_Datasets/NSD_MindEye/Mindeye2",
    cache_dir="/scratch/cl6707/Shared_Datasets/NSD_MindEye/Mindeye2",
    multi_subject="1,2,3,4,5,6,7",
    subj=1,
    batch_size=64,
    max_lr=3e-4,
    mixup_pct=0.33,
    num_epochs=150,
    use_prior=False,
    prior_scale=30,
    clip_scale=1,
    blurry_recon=False,
    blur_scale=0.5,
    use_image_aug=False,
    n_blocks=4,
    hidden_dim=512,
    num_sessions=40,
    ckpt_interval=3,
    ckpt_saving=True,
    wandb_log=True,
    lr_scheduler_type="cycle",
    seed=42,
    use_mixer=False,
    new_test=True,
    wandb_project="BRAIN_FAT",
    full_attention=True,
)

# Values to sweep; one job is generated per combination
# (1 x 2 x 2 x 4 x 1 = 16 jobs here).
param_ranges = dict(
    num_heads=[8],
    tome_r=[500, 1000],
    last_n_features=[16, 32],
    nat_depth=[3, 4, 5, 6],
    nat_num_neighbors=[16],
)

# List of (job_name, script_text) pairs, one per generated job.
ablation_jobs = generate_ablation_jobs(base_params, param_ranges)

# Print or save the generated job scripts
# for job_name, job_script in ablation_jobs:
    # print(f"Generated job: {job_name}")
    # Uncomment the following lines to save each job script to a file
    # with open(f"{job_name}.sbatch", "w") as f:
    #     f.write(job_script)


sbatch /scratch/cl6707/Projects/fmri/Brain_Decoding/jobs/BRAIN_FAT_num_heads_8_tome_r_500_last_n_features_16_nat_depth_3_nat_num_neighbors_16_20241106.sbatch
sbatch /scratch/cl6707/Projects/fmri/Brain_Decoding/jobs/BRAIN_FAT_num_heads_8_tome_r_500_last_n_features_16_nat_depth_4_nat_num_neighbors_16_20241106.sbatch
sbatch /scratch/cl6707/Projects/fmri/Brain_Decoding/jobs/BRAIN_FAT_num_heads_8_tome_r_500_last_n_features_16_nat_depth_5_nat_num_neighbors_16_20241106.sbatch
sbatch /scratch/cl6707/Projects/fmri/Brain_Decoding/jobs/BRAIN_FAT_num_heads_8_tome_r_500_last_n_features_16_nat_depth_6_nat_num_neighbors_16_20241106.sbatch
sbatch /scratch/cl6707/Projects/fmri/Brain_Decoding/jobs/BRAIN_FAT_num_heads_8_tome_r_500_last_n_features_32_nat_depth_3_nat_num_neighbors_16_20241106.sbatch
sbatch /scratch/cl6707/Projects/fmri/Brain_Decoding/jobs/BRAIN_FAT_num_heads_8_tome_r_500_last_n_features_32_nat_depth_4_nat_num_neighbors_16_20241106.sbatch
sbatch /scratch/cl6707/Projects/fmri/Brain_Decoding/