# Phase 2 Slack Experiment - GPU Accelerated

This notebook runs the Phase 2 Slack experiment on Kaggle's free GPU to analyze η/ε dynamics and Agent C behavior.

## Setup Instructions
1. **Enable GPU**: Settings → Accelerator → GPU T4 x2
2. **Run all cells** in order
3. **Wait ~30 minutes** for 100 epochs to complete

## What This Experiment Does
- Trains F⊣G (Functors) and Agent C **simultaneously**
- **No reconstruction loss** → preserves η (unit) as "slack"
- Minimizes the affordance prediction loss plus small KL and coherence terms (weights 1.0 / 0.1 / 0.1) — no reconstruction term
- Observes η/ε dynamics and potential suspension structure emergence

In [None]:
# Report whether PyTorch can see a CUDA device, and which one
import torch

banner = "=" * 60
print(banner)
print("GPU CHECK")
print(banner)

has_cuda = torch.cuda.is_available()
print(f"CUDA available: {has_cuda}")
if not has_cuda:
    # Training still runs on CPU; it is just slow.
    print("⚠️  WARNING: GPU not available, will use CPU (much slower)")
else:
    print(f"CUDA device: {torch.cuda.get_device_name(0)}")
    print(f"Number of GPUs: {torch.cuda.device_count()}")
print()

In [None]:
# Install dependencies
print("="*60)
print("INSTALLING DEPENDENCIES")
print("="*60)
!pip install torch-geometric torch-scatter torch-sparse -q
print("✓ PyTorch Geometric installed")
print()

In [None]:
# Clone repository
print("="*60)
print("CLONING REPOSITORY")
print("="*60)
!git clone https://github.com/type37c/adjunction-model.git
%cd adjunction-model
print("✓ Repository cloned")
print()

In [None]:
# Import necessary modules
import sys
# Make the cloned repository importable so `models` and `data` resolve below.
# NOTE(review): this path assumes the clone lives under /kaggle/working — confirm
# when running outside Kaggle.
sys.path.append('/kaggle/working/adjunction-model')

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torch_geometric.data import Data, Batch
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import json
from pathlib import Path
from tqdm import tqdm

# Import our models and dataset (both come from the cloned repository)
from models.conditional_adjunction_v4 import ConditionalAdjunctionV4
from data.synthetic_dataset import SyntheticAffordanceDataset

print("✓ All imports successful")

In [None]:
# Configuration
# Everything a reader might want to tune lives here; later cells read CONFIG.
import random

CONFIG = {
    'num_epochs': 100,
    'num_shapes': 100,
    'batch_size': 8,
    'num_points': 512,
    'num_affordances': 5,
    'learning_rate': 1e-4,
    'device': 'cuda' if torch.cuda.is_available() else 'cpu',
    'output_dir': '/kaggle/working/results/phase2_slack',
    'seed': 42,  # fixed seed so runs are repeatable
}

# Seed every RNG in play before any data or model is created. Without this the
# train/val split, weight initialisation and batch shuffling change on every run,
# which makes η/ε curves from different runs impossible to compare.
random.seed(CONFIG['seed'])
np.random.seed(CONFIG['seed'])
torch.manual_seed(CONFIG['seed'])

print("="*60)
print("CONFIGURATION")
print("="*60)
for key, value in CONFIG.items():
    print(f"  {key}: {value}")
print()

In [None]:
# Resolve the results directory from CONFIG and make sure it exists on disk
output_dir = Path(CONFIG['output_dir'])
if not output_dir.is_dir():
    output_dir.mkdir(parents=True, exist_ok=True)
print(f"✓ Output directory created: {output_dir}")

In [None]:
# Batch collation: per-sample dicts -> flattened, graph-style tensors
def collate_fn(batch):
    """Merge dataset samples into one graph-style batch.

    Args:
        batch: sequence of dicts, each holding a 'points' tensor of shape (N, 3)
            and an 'affordances' tensor of shape (N, A).

    Returns:
        dict with
          'points':      (B*N, 3) all points concatenated sample after sample,
          'batch':       (B*N,) index of the sample each point belongs to,
          'affordances': (B, N, A) left un-flattened for the loss computation.
    """
    stacked_points = torch.stack([sample['points'] for sample in batch])            # (B, N, 3)
    stacked_affordances = torch.stack([sample['affordances'] for sample in batch])  # (B, N, A)

    num_samples, points_per_sample, _ = stacked_points.shape

    # Point i of sample b lands at row b * N + i, so the owner vector is
    # 0,0,...,1,1,...,B-1.
    flat_points = stacked_points.reshape(num_samples * points_per_sample, 3)
    owner = torch.repeat_interleave(torch.arange(num_samples), points_per_sample)

    return {
        'points': flat_points,
        'batch': owner,
        'affordances': stacked_affordances,
    }

print("✓ Collate function defined")

In [None]:
# Build the synthetic dataset, split it 80/20 and wrap both parts in loaders
print("="*60)
print("CREATING DATASET")
print("="*60)

dataset = SyntheticAffordanceDataset(
    **{name: CONFIG[name] for name in ('num_shapes', 'num_points', 'num_affordances')}
)

# 80% of the samples go to training, the remainder to validation
num_samples = len(dataset)
train_size = int(0.8 * num_samples)
val_size = num_samples - train_size
train_dataset, val_dataset = torch.utils.data.random_split(dataset, [train_size, val_size])

# Both loaders share batch size and collation; only training is shuffled
loader_kwargs = dict(batch_size=CONFIG['batch_size'], collate_fn=collate_fn)
train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
val_loader = DataLoader(val_dataset, shuffle=False, **loader_kwargs)

for label, count in (
    ("Train size", len(train_dataset)),
    ("Val size", len(val_dataset)),
    ("Train batches", len(train_loader)),
    ("Val batches", len(val_loader)),
):
    print(f"✓ {label}: {count}")
print()

In [None]:
# Instantiate the model on the configured device and report its size
print("="*60)
print("CREATING MODEL")
print("="*60)

model_kwargs = {
    'num_affordances': CONFIG['num_affordances'],
    'f_hidden_dim': 64,
    'g_hidden_dim': 128,
    'agent_hidden_dim': 256,
    'agent_latent_dim': 64,
    'context_dim': 128,
}
model = ConditionalAdjunctionV4(**model_kwargs).to(CONFIG['device'])

# (element count, requires_grad) for every parameter tensor
param_info = [(p.numel(), p.requires_grad) for p in model.parameters()]
total_params = sum(count for count, _ in param_info)
trainable_params = sum(count for count, needs_grad in param_info if needs_grad)

print("✓ Model created")
print(f"  Total parameters: {total_params:,}")
print(f"  Trainable parameters: {trainable_params:,}")
print()

In [None]:
# Loss weights, affordance criterion and optimizer for the joint objective
lambda_aff, lambda_kl, lambda_coherence = 1.0, 0.1, 0.1

aff_criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=CONFIG['learning_rate'])

print("✓ Optimizer and loss functions created")
for term, weight in (("aff", lambda_aff), ("kl", lambda_kl), ("coherence", lambda_coherence)):
    print(f"  λ_{term}: {weight}")
print()

In [None]:
# Training function
def train_epoch(model, dataloader, optimizer, device, epoch):
    """Run one optimisation pass over `dataloader` and return per-batch averages.

    The objective is
    ``lambda_aff * L_aff + lambda_kl * L_kl + lambda_coherence * L_coherence``;
    there is deliberately no reconstruction term. `aff_criterion` and the three
    `lambda_*` weights are read from notebook globals.

    Args:
        model: called as ``model(pos, batch, agent_state, coherence_signal_prev)``;
            must expose ``agent_c.initial_state`` and ``agent_c.rssm.kl_divergence``.
        dataloader: yields dicts with 'points', 'batch' and 'affordances' tensors.
        optimizer: optimizer that steps the model parameters.
        device: device the batch tensors are moved to.
        epoch: used only to label the progress bar.

    Returns:
        dict with keys 'loss', 'aff', 'kl', 'coherence', 'unit_eta' and
        'counit_eps', each the mean of the per-batch value.
    """
    model.train()

    running = dict.fromkeys(
        ('loss', 'aff', 'kl', 'coherence', 'unit_eta', 'counit_eps'), 0.0
    )
    num_batches = 0

    progress = tqdm(dataloader, desc=f"Epoch {epoch}")
    for sample in progress:
        pos = sample['points'].to(device)
        graph_ids = sample['batch'].to(device)
        target = sample['affordances'].to(device)

        # Graph ids start at 0, so the largest id + 1 is the number of samples
        n_graphs = graph_ids.max().item() + 1

        # Every batch starts from a fresh agent state and a zero previous signal
        agent_state = model.agent_c.initial_state(n_graphs, device)
        prev_coherence = torch.zeros(n_graphs, 1, device=device)

        out = model(pos, graph_ids, agent_state, prev_coherence)
        prediction = out['affordances']
        eta = out['coherence_signal']
        eps = out['counit_signal']
        rssm = out['rssm_info']

        # Flatten the targets (B, N, A) -> (B*N, A) before comparing them with
        # the prediction (presumably already per-point — TODO confirm).
        n_b, n_pts, n_aff = target.shape
        L_aff = aff_criterion(prediction, target.reshape(n_b * n_pts, n_aff))

        L_kl = model.agent_c.rssm.kl_divergence(
            rssm['posterior_mean'],
            rssm['posterior_std'],
            rssm['prior_mean'],
            rssm['prior_std'],
        ).mean()

        # Negative log of the coherence signal; the epsilon keeps log finite at 0
        L_coherence = -torch.log(eta + 1e-8).mean()

        # Weighted sum — no reconstruction loss on purpose
        loss = lambda_aff * L_aff + lambda_kl * L_kl + lambda_coherence * L_coherence

        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()

        # Pull each scalar off the device once and reuse it below
        loss_value = loss.item()
        eta_value = eta.mean().item()
        eps_value = eps.mean().item()

        running['loss'] += loss_value
        running['aff'] += L_aff.item()
        running['kl'] += L_kl.item()
        running['coherence'] += L_coherence.item()
        running['unit_eta'] += eta_value
        running['counit_eps'] += eps_value
        num_batches += 1

        progress.set_postfix({
            'loss': f'{loss_value:.4f}',
            'η': f'{eta_value:.4f}',
            'ε': f'{eps_value:.4f}'
        })

    return {name: total / num_batches for name, total in running.items()}

print("✓ Training function defined")

In [None]:
# Validation function
def validate(model, dataloader, device):
    """Evaluate `model` on `dataloader` without gradient tracking.

    Only the affordance loss and the two signals are measured; the KL and
    coherence loss terms are not computed here. `aff_criterion` is read from
    the notebook globals.

    Args:
        model: called as ``model(pos, batch, agent_state, coherence_signal_prev)``;
            must expose ``agent_c.initial_state``.
        dataloader: yields dicts with 'points', 'batch' and 'affordances' tensors.
        device: device the batch tensors are moved to.

    Returns:
        dict with keys 'aff', 'unit_eta' and 'counit_eps', each the mean of the
        per-batch value.
    """
    model.eval()

    running = dict.fromkeys(('aff', 'unit_eta', 'counit_eps'), 0.0)
    num_batches = 0

    with torch.no_grad():
        for sample in dataloader:
            pos = sample['points'].to(device)
            graph_ids = sample['batch'].to(device)
            target = sample['affordances'].to(device)

            # Graph ids start at 0, so the largest id + 1 is the number of samples
            n_graphs = graph_ids.max().item() + 1

            agent_state = model.agent_c.initial_state(n_graphs, device)
            prev_coherence = torch.zeros(n_graphs, 1, device=device)

            out = model(pos, graph_ids, agent_state, prev_coherence)
            prediction = out['affordances']
            eta = out['coherence_signal']
            eps = out['counit_signal']

            # Flatten the targets (B, N, A) -> (B*N, A) for the criterion
            n_b, n_pts, n_aff = target.shape
            L_aff = aff_criterion(prediction, target.reshape(n_b * n_pts, n_aff))

            running['aff'] += L_aff.item()
            running['unit_eta'] += eta.mean().item()
            running['counit_eps'] += eps.mean().item()
            num_batches += 1

    return {name: total / num_batches for name, total in running.items()}

print("✓ Validation function defined")

In [None]:
# Training loop: train every epoch, validate every 5th, record everything
rule = "=" * 60
print(rule)
print("STARTING TRAINING")
print(rule)
print()

train_keys = ('loss', 'aff', 'kl', 'coherence', 'unit_eta', 'counit_eps')
val_keys = ('aff', 'unit_eta', 'counit_eps')

# One list per metric: train_* gets an entry every epoch, val_* every 5th epoch
history = {f'train_{key}': [] for key in train_keys}
history.update({f'val_{key}': [] for key in val_keys})

num_epochs = CONFIG['num_epochs']
for epoch in range(1, num_epochs + 1):
    train_metrics = train_epoch(model, train_loader, optimizer, CONFIG['device'], epoch)

    run_validation = (epoch % 5 == 0)
    if run_validation:
        val_metrics = validate(model, val_loader, CONFIG['device'])
        print(f"\nEpoch {epoch}/{num_epochs}:")
        print(f"  Train - Loss: {train_metrics['loss']:.4f}, η: {train_metrics['unit_eta']:.4f}, ε: {train_metrics['counit_eps']:.4f}")
        print(f"  Val   - Aff: {val_metrics['aff']:.4f}, η: {val_metrics['unit_eta']:.4f}, ε: {val_metrics['counit_eps']:.4f}")
        print()

    for key in train_keys:
        history[f'train_{key}'].append(train_metrics[key])

    if run_validation:
        for key in val_keys:
            history[f'val_{key}'].append(val_metrics[key])

print(rule)
print("TRAINING COMPLETED")
print(rule)
print()

In [None]:
# Persist the per-epoch history as pretty-printed JSON
metrics_path = output_dir / 'metrics.json'
with metrics_path.open('w') as handle:
    json.dump(history, handle, indent=2)

print(f"✓ Metrics saved to {metrics_path}")

In [None]:
# Persist the final weights (state dict only, not the full module)
model_path = output_dir / 'model_final.pt'
torch.save(model.state_dict(), model_path)
print(f"✓ Model saved to {model_path}")

In [None]:
# Visualize results
print("="*60)
print("CREATING VISUALIZATIONS")
print("="*60)


def _plot_curve(ax, xs, ys, title, ylabel, color, label=None, marker=None):
    """Draw one metric curve on `ax` with the notebook's shared styling.

    A legend is added only when `label` is given; a marker is drawn only when
    `marker` is given.
    """
    plot_kwargs = {'color': color, 'linewidth': 2}
    if marker is not None:
        plot_kwargs['marker'] = marker
    if label is not None:
        plot_kwargs['label'] = label
    ax.plot(xs, ys, **plot_kwargs)
    ax.set_title(title, fontsize=12, fontweight='bold')
    ax.set_xlabel('Epoch')
    ax.set_ylabel(ylabel)
    if label is not None:
        ax.legend()
    ax.grid(True, alpha=0.3)


fig, axes = plt.subplots(2, 3, figsize=(18, 10))
fig.suptitle('Phase 2 Slack Experiment - Training Results', fontsize=16, fontweight='bold')

epochs = range(1, len(history['train_loss']) + 1)

# Validation runs every 5th epoch in the training loop. The x-axis is derived
# from the number of validation points actually recorded rather than from
# CONFIG['num_epochs'], so an interrupted run (or a CONFIG edited after training)
# cannot produce an x/y length mismatch in plot().
val_every = 5
val_epochs = [val_every * (i + 1) for i in range(len(history['val_aff']))]

# (history suffix, title stem, y-label, colour, legend label)
panels = [
    ('aff', 'Affordance Loss', 'Loss', 'blue', None),
    ('unit_eta', 'Unit η', 'η value', 'green', 'η (unit)'),
    ('counit_eps', 'Counit ε', 'ε value', 'red', 'ε (counit)'),
]

# Row 1: training metrics (one point per epoch)
for col, (suffix, stem, ylabel, color, label) in enumerate(panels):
    _plot_curve(axes[0, col], epochs, history[f'train_{suffix}'],
                f'{stem} (Training)', ylabel, color, label=label)

# Row 2: validation metrics; left blank when no validation epoch was reached
if len(history['val_aff']) > 0:
    for col, (suffix, stem, ylabel, color, label) in enumerate(panels):
        _plot_curve(axes[1, col], val_epochs, history[f'val_{suffix}'],
                    f'{stem} (Validation)', ylabel, color, label=label, marker='o')

plt.tight_layout()
plt.savefig(output_dir / 'training_results.png', dpi=150, bbox_inches='tight')
plt.show()

print(f"✓ Visualization saved to {output_dir / 'training_results.png'}")

In [None]:
# Print the last recorded metrics and three quick sanity verdicts
divider = "=" * 60


def _last(key):
    """Return the most recent value recorded under `key` in `history`."""
    return history[key][-1]


print(divider)
print("FINAL RESULTS")
print(divider)
print()
print("Training (Final Epoch):")
print(f"  Affordance Loss: {_last('train_aff'):.6f}")
print(f"  Unit η:          {_last('train_unit_eta'):.6f}")
print(f"  Counit ε:        {_last('train_counit_eps'):.6f}")
print(f"  KL Loss:         {_last('train_kl'):.6f}")
print()
if len(history['val_aff']) > 0:
    print("Validation (Final):")
    print(f"  Affordance Loss: {_last('val_aff'):.6f}")
    print(f"  Unit η:          {_last('val_unit_eta'):.6f}")
    print(f"  Counit ε:        {_last('val_counit_eps'):.6f}")
    print()
print(divider)
print("Key Observations:")
print(divider)

# A signal counts as present when its final training value exceeds 0.01;
# "learning" means the affordance loss ended below where it started.
eta_verdict = '✓ YES' if _last('train_unit_eta') > 0.01 else '✗ NO (collapsed)'
eps_verdict = '✓ YES' if _last('train_counit_eps') > 0.01 else '✗ NO'
learning_verdict = '✓ YES' if history['train_aff'][0] > _last('train_aff') else '✗ NO'
print(f"  η preserved:     {eta_verdict}")
print(f"  ε observable:    {eps_verdict}")
print(f"  Learning:        {learning_verdict}")
print()
print(f"All results saved to: {output_dir}")
print(divider)

In [None]:
# Show every file written to the results directory, with its size
print("\nOutput Files:")
for entry in sorted(output_dir.glob('*')):
    print(f"  {entry.name:30s} ({entry.stat().st_size:,} bytes)")