# 🎯 Sem-SPAI: Semantic-enhanced SPAI Training and Testing

This notebook collects the training and testing commands for the semantic-enhanced SPAI model, based on our job configurations. The code cells only print the commands for copy/paste; they do not execute them.

## 📁 Setup: Model Directories and Paths

In [None]:
# Model and data paths configuration
import os

# Base path.
# The cluster checkout remains the default; set the SPAI_BASE_PATH environment
# variable to point this notebook at a different checkout without editing the cell.
BASE_PATH = os.environ.get("SPAI_BASE_PATH", "/home/scur2605/spai")
CONFIG_PATH = os.path.join(BASE_PATH, "configs/spai.yaml")

# Training data paths (CSV files, all resolved under BASE_PATH)
TRAIN_DATA = {
    "ldm_lsun_subset": os.path.join(BASE_PATH, "data/ldm_lsun_train_val_subset.csv"),
    "chameleon": os.path.join(BASE_PATH, "data/chameleon_dataset_split.csv"),
    "ldm_subset_10pct": os.path.join(BASE_PATH, "data/train/ldm_train_val_subset_10pct.csv")
}

# Model paths: fine-tuned checkpoint, pretrained SPAI weights and training output directory
MODEL_PATHS = {
    "trained_model": os.path.join(BASE_PATH, "output/LSUN_RESIDUAL_ORIGINAL/finetune/first_run/ckpt_epoch_6.pth"),
    "spai_pretrained": os.path.join(BASE_PATH, "weights/spai.pth"),
    "output_dir": os.path.join(BASE_PATH, "output/LSUN_RESIDUAL_ORIGINAL")
}

# Test datasets: one CSV per generator
TEST_DATASETS = {
    "dalle2": os.path.join(BASE_PATH, "data/test_set_dalle2.csv"),
    "dalle3": os.path.join(BASE_PATH, "data/test_set_dalle3.csv"),
    "gigagan": os.path.join(BASE_PATH, "data/test_set_gigagan.csv"),
    "sd1_4": os.path.join(BASE_PATH, "data/test_set_sd1_4.csv"),
    "sd3": os.path.join(BASE_PATH, "data/test_set_sd3.csv"),
    "sdxl": os.path.join(BASE_PATH, "data/test_set_sdxl.csv"),
    "flux": os.path.join(BASE_PATH, "data/test_set_flux.csv"),
    "midjourney": os.path.join(BASE_PATH, "data/test_set_midjourney-v6.1.csv")
}

# Echo every configured path so a wrong BASE_PATH is visible before any command is copied.
for heading, mapping in (
    ("🏗️ Model Paths:", MODEL_PATHS),
    ("\n📊 Training Datasets:", TRAIN_DATA),
    ("\n🧪 Test Datasets:", TEST_DATASETS),
):
    print(heading)
    for key, path in mapping.items():
        print(f"  {key}: {path}")

## 🎓 Training Commands

Based on `jobs/semantic/train.job` - Training the semantic-enhanced SPAI model with late fusion.

In [None]:
# Environment setup commands (run these first)
# The angle-bracket placeholders are printed verbatim and must be filled in by
# whoever pastes the commands into a shell; only the `cd` target is interpolated.
setup_lines = [
    "",
    "# Set environment variables",
    "export PYTHONPATH=<your_path_here>",
    'export NEPTUNE_API_TOKEN="<your_token_here>"',
    'export NEPTUNE_PROJECT="<your_project_here>"',
    "",
    "# Activate conda environment",
    "conda activate spai_2",
    "",
    "# Change to project directory",
    f"cd {BASE_PATH}",
    "",
]
setup_commands = "\n".join(setup_lines)

print("🔧 Environment Setup:", setup_commands, sep="\n")

In [None]:
# Main training command - Semantic-Enhanced SPAI with Late Fusion
# The command is printed for copy/paste, not executed. The dataset paths are taken
# from the configuration cell (TRAIN_DATA / BASE_PATH) rather than retyped, so the
# printed command cannot drift from the configured paths.
# The relative ./configs and ./output paths presumably rely on the `cd` into
# BASE_PATH shown in the environment setup cell.
train_csv = TRAIN_DATA["ldm_lsun_subset"]
train_csv_root = os.path.join(BASE_PATH, "data/train")

train_cmd = f"""
python -m spai train \\
--cfg "./configs/spai.yaml" \\
--batch-size 256 \\
--data-path "{train_csv}" \\
--csv-root-dir "{train_csv_root}" \\
--output "./output/LSUN_RESIDUAL_ORIGINAL" \\
--tag "first_run" \\
--data-workers 4 \\
--save-all \\
--amp-opt-level "O0" \\
--opt "TRAIN.EPOCHS" "10" \\
--opt "DATA.TEST_PREFETCH_FACTOR" "1" \\
--opt "DATA.VAL_BATCH_SIZE" "256" \\
--opt "MODEL.FEATURE_EXTRACTION_BATCH" "400" \\
--opt "PRINT_FREQ" "2" \\
--opt "MODEL.SEMANTIC_CONTEXT.SPAI_INPUT_SIZE" "[224, 224]"
"""

print("🎓 Main Training Command:")
print(train_cmd)

## 🧪 Testing Commands

Based on `jobs/semantic/test.job` - Comprehensive evaluation across multiple datasets.

In [None]:
# Testing commands for key datasets
# Reuse the checkpoint path from the configuration cell instead of retyping it, so
# the test commands always point at the same checkpoint the configuration lists.
MODEL_PATH = MODEL_PATHS["trained_model"]

# Base test command template.
# `{model_path}` and `{test_csv}` are the only placeholders and are filled in with
# str.format(); every other flag is identical for all datasets.
# NOTE(review): MODEL.SEMANTIC_CONTEXT.HIDDEN_DIMS is overridden here but not in the
# training command - confirm the training default matches "[512]".
base_test_cmd = """
python -m spai test \\
--cfg "./configs/spai.yaml" \\
--batch-size 10 \\
--model "{model_path}" \\
--output "./output/semantic_test" \\
--tag "spai" \\
--opt "MODEL.PATCH_VIT.MINIMUM_PATCHES" "4" \\
--opt "DATA.NUM_WORKERS" "8" \\
--opt "MODEL.FEATURE_EXTRACTION_BATCH" "400" \\
--opt "DATA.TEST_PREFETCH_FACTOR" "1" \\
--test-csv "{test_csv}" \\
--opt "PRINT_FREQ" "2" \\
--opt "MODEL.SEMANTIC_CONTEXT.HIDDEN_DIMS" "[512]" \\
--opt "MODEL.SEMANTIC_CONTEXT.SPAI_INPUT_SIZE" "[1024, 1024]"
"""

# Example: DALLE-2 Testing (CSV path comes from the configured TEST_DATASETS)
print("📊 Testing on DALLE-2:")
print(base_test_cmd.format(
    model_path=MODEL_PATH,
    test_csv=TEST_DATASETS["dalle2"],
))

print("\n" + "="*50)

In [None]:
# Test commands for the remaining datasets (DALLE-2 is printed by the example in
# the previous cell, so it is not repeated here).
# Each entry is (display name, CSV file name inside BASE_PATH/data).
test_datasets = [
    ("DALLE-3", "test_set_dalle3.csv"),
    ("GigaGAN", "test_set_gigagan.csv"),
    ("SD1.4", "test_set_sd1_4.csv"),
    ("SD3", "test_set_sd3.csv"),
    ("SDXL", "test_set_sdxl.csv"),
    ("Flux", "test_set_flux.csv"),
    ("Midjourney", "test_set_midjourney-v6.1.csv")
]

print("🧪 Testing Commands for All Datasets:")
print("=" * 60)

for dataset_name, csv_file in test_datasets:
    print(f"\n📊 Testing on {dataset_name}:")
    # Build the CSV path from the configured BASE_PATH rather than a retyped prefix,
    # so the printed commands follow the configuration cell.
    test_csv_path = os.path.join(BASE_PATH, "data", csv_file)
    print(base_test_cmd.format(
        model_path=MODEL_PATH,
        test_csv=test_csv_path
    ))
    print("-" * 40)

## 📊 Job Submission Commands

In [None]:
# SLURM job submission commands
# Cheat sheet only: the lines are printed for copy/paste into a shell; nothing is
# submitted from this cell.
slurm_lines = (
    "",
    "# Submit training job (GPU H100, 12 hours, 180GB RAM)",
    "sbatch jobs/semantic/train.job",
    "",
    "# Submit testing job (GPU H100, 5 hours, 180GB RAM)  ",
    "sbatch jobs/semantic/test.job",
    "",
    "# Check job status",
    "squeue -u $USER",
    "",
    "# Check job output",
    "tail -f jobs/outputs/semantic/train_simple_*.out",
    "tail -f jobs/outputs/semantic/test-pope*.out",
    "",
    "# Cancel job if needed",
    "scancel <JOB_ID>",
    "",
)

print("📊 SLURM Job Submission Commands:", "\n".join(slurm_lines), sep="\n")

## ⚙️ Configuration Summary

In [None]:
# Configuration summary
# Static, hand-written overview: the values below are literals and are not derived
# from the command strings printed elsewhere in this notebook.
summary_lines = (
    "",
    "Training Parameters:",
    "  • Batch Size: 256",
    "  • Epochs: 10",
    "  • Data Workers: 4",
    "  • AMP Level: O0 (no mixed precision)",
    "  • SPAI Input Size: [224, 224]",
    "",
    "Testing Parameters:",
    "  • Batch Size: 10",
    "  • Data Workers: 8",
    "  • SPAI Input Size: [1024, 1024]",
    "  • Hidden Dims: [512]",
    "  • Min Patches: 4",
    "",
    "Hardware Requirements:",
    "  • GPU: H100",
    "  • Memory: 180GB",
    "  • CPUs: 16",
    "  • Training Time: 12 hours",
    "  • Testing Time: 5 hours",
    "",
    "Key Features:",
    "  • Late fusion architecture with residual connections",
    "  • Frozen SPAI and ConvNeXt-XXL backbones",
    "  • Semantic features projected to 256 dimensions",
    "  • Multi-dataset evaluation capability",
    "",
)

print(
    "⚙️ Model Configuration Summary:",
    "=" * 40,
    "\n".join(summary_lines),
    sep="\n",
)