In [2]:
def collect_training_data(model, dataloader):
    """Run ``model`` over ``dataloader`` and record one entry per batch.

    Each entry is a dict with three keys:

    * ``'input_ids'``    - the batch's ``input_ids`` as read from the loader,
    * ``'sae_features'`` - a list holding ``hook.activations`` for every hook
      returned by ``register_sae_hooks(model)``, in that order,
    * ``'logits'``       - ``outputs.logits[:, -1, :]``, i.e. the logits at
      the last index of dimension 1.

    Args:
        model: Callable as ``model(input_ids)``; the result must expose a
            ``.logits`` tensor with at least three dimensions.
        dataloader: Iterable of mappings that contain an ``'input_ids'`` key.

    Returns:
        A list of the per-batch dicts described above, in iteration order.
    """
    # Hooks are installed once; every hook object is expected to expose an
    # ``activations`` attribute and a ``clear()`` method.
    sae_hooks = register_sae_hooks(model)
    records = []

    for batch in dataloader:
        token_ids = batch['input_ids']

        # No gradients are needed while harvesting activations and logits.
        with torch.no_grad():
            model_out = model(token_ids)

            # Keep only the prediction at the final sequence position.
            last_position_logits = model_out.logits[:, -1, :]

            # Snapshot whatever each hook captured during this forward pass.
            captured = [h.activations for h in sae_hooks]

            records.append({
                'input_ids': token_ids,
                'sae_features': captured,
                'logits': last_position_logits,
            })

            # Reset the hooks so the next batch starts from a clean state.
            for h in sae_hooks:
                h.clear()

    return records

def train_circuit_gnn(gnn_model, training_data, l1_weight=0.01, epochs=5):
    """Fit ``gnn_model`` so that its output matches the recorded logits.

    For every batch the per-layer SAE activations are transposed and
    concatenated into one node-feature tensor with one row per SAE feature,
    a fully connected edge list is built over those rows, and the GNN output
    is trained against the stored logits with a KL-divergence loss plus an
    L1 penalty on every parameter whose name contains ``'att'``.

    Args:
        gnn_model: Module called as ``gnn_model(feature_tensor, edge_index)``.
        training_data: Sized collection of dicts with a ``'sae_features'``
            list of tensors (presumably ``[batch_size, num_features]`` each -
            confirm against the producer) and a ``'logits'`` tensor.
        l1_weight: Multiplier applied to the L1 penalty.
        epochs: Number of passes over ``training_data``.

    Returns:
        None. The mean loss of each epoch is printed; it is reported as
        ``nan`` when ``training_data`` is empty instead of raising
        ``ZeroDivisionError``.
    """
    optimizer = torch.optim.Adam(gnn_model.parameters(), lr=0.001)

    num_batches = len(training_data)

    # The optimizer updates parameters in place, so the set of penalised
    # tensors is fixed for the whole run and can be collected once.
    # NOTE(review): this penalises the attention *parameters*, not the
    # per-edge attention coefficients - confirm that is the intended pruning.
    attention_params = [
        param
        for name, param in gnn_model.named_parameters()
        if 'att' in name  # GAT attention weights
    ]

    # The edge list depends only on the node count, so build it once per
    # distinct count instead of once per batch per epoch.
    # Assumes create_fully_connected_edges is deterministic and that the
    # model does not modify the returned edge list in place.
    edge_index_by_size = {}

    for epoch in range(epochs):
        epoch_loss = 0

        for batch in training_data:
            # One row per SAE feature across all layers:
            # [total_features, batch_size]. Concatenating the transposed
            # layers is equivalent to stacking every feature row one by one.
            feature_tensor = torch.cat(
                [layer_features.t() for layer_features in batch['sae_features']],
                dim=0,
            )

            # Fully connected directed graph between all features
            # (L1 regularization is meant to prune this during training).
            num_nodes = feature_tensor.size(0)
            if num_nodes not in edge_index_by_size:
                edge_index_by_size[num_nodes] = create_fully_connected_edges(num_nodes)
            edge_index = edge_index_by_size[num_nodes]

            # Forward pass through GNN
            predicted_logits = gnn_model(feature_tensor, edge_index)

            # KL divergence between the GNN's distribution and the recorded one.
            logit_loss = F.kl_div(
                F.log_softmax(predicted_logits, dim=-1),
                F.softmax(batch['logits'], dim=-1),
                reduction='batchmean'
            )

            # L1 regularization on attention weights (stays 0 if none match).
            l1_loss = 0
            for param in attention_params:
                l1_loss += torch.sum(torch.abs(param))

            loss = logit_loss + l1_weight * l1_loss

            # Backprop
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            epoch_loss += loss.item()

        # Guard the mean against an empty dataset.
        mean_loss = epoch_loss / num_batches if num_batches else float('nan')
        print(f"Epoch {epoch+1}/{epochs}, Loss: {mean_loss}")

#### Example structure of each batch entry in `training_data`:

```
{
    'input_ids': tensor[batch_size, seq_len],  # Input token IDs
    'sae_features': [
        # List of tensors, one per SAE layer
        tensor[batch_size, num_features_layer1],  # Activations from SAE layer 1
        tensor[batch_size, num_features_layer2],  # Activations from SAE layer 2
        # ...and so on for all SAE layers
    ],
    'logits': tensor[batch_size, vocab_size]  # Original model's output logits at the last token position
}
```