# Concept-Based Model with Interventions

This notebook demonstrates how to:
1. Load and prepare data with concept annotations
2. Build a concept-based neural network with an encoder and predictor
3. Train the model on both concept and task predictions
4. Apply various intervention strategies to manipulate concept predictions

## 1. Imports

We import the necessary libraries:
- **PyTorch**: for neural network building blocks
- **sklearn**: for evaluation metrics
- **torch_concepts**: for concept annotations, layers, and intervention mechanisms

In [1]:
import torch
from sklearn.metrics import accuracy_score

from torch_concepts import Annotations, AxisAnnotation
from torch_concepts.data import ToyDataset
from torch_concepts.nn import (
    LinearZC,
    LinearCC,
    GroundTruthIntervention,
    UncertaintyInterventionPolicy, 
    intervention, 
    DoIntervention, 
    DistributionIntervention, 
    UniformPolicy, 
    RandomPolicy
)

## 2. Data Loading and Preparation

We load the XOR toy dataset and prepare the training data:
- **Features (x_train)**: input features for the model
- **Concepts (c_train)**: intermediate concept labels (the 2 original concepts, tiled three times to give 6 concept columns)
- **Targets (y_train)**: task labels to predict
- **Names**: concept and task attribute names for annotations

In [2]:
# Hyperparameters
latent_dims = 10
n_epochs = 500
n_samples = 1000
concept_reg = 0.5

# Load toy XOR dataset
data = ToyDataset('xor', size=n_samples, random_state=42)
x_train = data.data
c_train = data.concept_labels
y_train = data.target_labels
concept_names = data.concept_attr_names
task_names = data.task_attr_names

# Duplicate concept labels to create 6 concepts (C1, C2, C3, C4, C5, C6)
c_train = torch.concat([c_train, c_train, c_train], dim=1)

# Get dimensions
n_features = x_train.shape[1]
n_concepts = c_train.shape[1]

print(f"Dataset loaded:")
print(f"  Features shape: {x_train.shape}")
print(f"  Concepts shape: {c_train.shape}")
print(f"  Targets shape: {y_train.shape}")
print(f"  Number of features: {n_features}")
print(f"  Number of concepts: {n_concepts}")

Dataset loaded:
  Features shape: torch.Size([1000, 2])
  Concepts shape: torch.Size([1000, 6])
  Targets shape: torch.Size([1000, 1])
  Number of features: 2
  Number of concepts: 6
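
To make the relationship between features, concepts, and task concrete, the sketch below builds an XOR-style dataset by hand in plain PyTorch. It assumes that each concept is obtained by thresholding one uniform feature at 0.5 and that the task is the XOR of the two concepts. This is a common construction for this toy problem, not necessarily how `ToyDataset` generates its data, and the names `x_demo`, `c_demo`, `c_demo6`, and `y_demo` are illustrative.

```python
import torch

torch.manual_seed(0)

# Assumed construction: two uniform features, one binary concept per feature
x_demo = torch.rand(1000, 2)
c_demo = (x_demo > 0.5).float()

# Task label: XOR of the two concepts, kept as a column vector
y_demo = torch.logical_xor(c_demo[:, 0] > 0.5, c_demo[:, 1] > 0.5).float().unsqueeze(1)

# Tile the 2 concepts three times to obtain 6 concept columns, as in the cell above
c_demo6 = torch.cat([c_demo, c_demo, c_demo], dim=1)

print(x_demo.shape, c_demo6.shape, y_demo.shape)
```

The shapes match those printed above, which is all this sketch is meant to reproduce.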


## 3. Annotations Object

The `Annotations` object is a key component that provides semantic meaning to tensor dimensions:
- It maps axis indices to `AxisAnnotation` objects
- Each `AxisAnnotation` contains names (labels) for features along that axis
- This enables human-readable concept manipulation and intervention

Here we create:
- **c_annotations**: annotations for the 6 concepts (C1-C6)
- **y_annotations**: annotations for the task output

In [3]:
# Create annotations for concepts and targets
c_annotations = Annotations({1: AxisAnnotation(concept_names + ['C3', 'C4', 'C5', 'C6'])})
y_annotations = Annotations({1: AxisAnnotation(task_names)})

print(f"Concept annotations:")
print(f"  Shape: {c_annotations.shape}")
print(f"  Axis 1 names: {c_annotations[1].labels}")
print(f"\nTask annotations:")
print(f"  Shape: {y_annotations.shape}")
print(f"  Axis 1 names: {y_annotations[1].labels}")

Concept annotations:
  Shape: (-1, 6)
  Axis 1 names: ['C1', 'C2', 'C3', 'C4', 'C5', 'C6']

Task annotations:
  Shape: (-1, 1)
  Axis 1 names: ['xor']
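
Named axes make it possible to refer to concepts by label rather than by column position. A minimal sketch of that idea in plain Python follows. It uses a hypothetical helper, `names_to_indices`, which is not part of `torch_concepts`, together with the labels printed above.

```python
# Labels as printed by the annotations cell above
concept_labels = ['C1', 'C2', 'C3', 'C4', 'C5', 'C6']


# Hypothetical helper: translate concept names into column indices
def names_to_indices(names, labels):
    lookup = {label: idx for idx, label in enumerate(labels)}
    return [lookup[name] for name in names]


target_idx = names_to_indices(['C2', 'C4'], concept_labels)
print(target_idx)  # [1, 3]
```

Index lists of this form are what the intervention cells later pass as `target_concepts`.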


## 4. Model Architecture

We build a concept bottleneck model with three components:

1. **Encoder**: A simple neural network that maps input features to a latent embedding
2. **Encoder Layer** (`LinearZC`): Maps the embedding to concept logits (which `torch_concepts` calls *endogenous* values)
3. **Task Predictor** (`LinearCC`): Maps the concept logits to task logits

The model is wrapped in a `ModuleDict` to enable easier intervention on specific layers.

In [4]:
# Build the encoder (features -> embedding)
encoder = torch.nn.Sequential(
    torch.nn.Linear(n_features, latent_dims),
    torch.nn.LeakyReLU(),
)

# Build the concept encoder (embedding -> concepts)
encoder_layer = LinearZC(
    in_features=latent_dims,
    out_features=c_annotations.shape[1]
)

# Build the task predictor (concepts -> task)
y_predictor = LinearCC(
    in_features_endogenous=c_annotations.shape[1],
    out_features=y_annotations.shape[1]
)

# Wrap all components in a ModuleDict for easier intervention
model = torch.nn.ModuleDict({
    "encoder": encoder,
    "encoder_layer": encoder_layer,
    "y_predictor": y_predictor,
})

print("Model architecture:")
print(model)
print(f"\nEncoder layer representation:")
print(f"  Input: embedding of size {latent_dims}")
print(f"  Output: concept endogenous of size {c_annotations.shape[1]}")
print(f"\nTask predictor representation:")
print(f"  Input: concept endogenous of size {c_annotations.shape[1]}")
print(f"  Output: task endogenous of size {y_annotations.shape[1]}")

Model architecture:
ModuleDict(
  (encoder): Sequential(
    (0): Linear(in_features=2, out_features=10, bias=True)
    (1): LeakyReLU(negative_slope=0.01)
  )
  (encoder_layer): ProbEncoderFromEmb(
    (encoder): Sequential(
      (0): Linear(in_features=10, out_features=6, bias=True)
      (1): Unflatten(dim=-1, unflattened_size=(6,))
    )
  )
  (y_predictor): ProbPredictor(
    (predictor): Sequential(
      (0): Linear(in_features=6, out_features=1, bias=True)
      (1): Unflatten(dim=-1, unflattened_size=(1,))
    )
  )
)

Encoder layer representation:
  Input: embedding of size 10
  Output: concept logits of size 6

Task predictor representation:
  Input: concept logits of size 6
  Output: task logits of size 1
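
For orientation, the printed structure corresponds roughly to the following plain-PyTorch stack. This sketch mirrors only the `Linear` shapes shown in the output above. It omits whatever additional behavior `LinearZC` and `LinearCC` provide, and the `plain_*` names are illustrative.

```python
import torch

plain_encoder = torch.nn.Sequential(
    torch.nn.Linear(2, 10),   # features -> embedding
    torch.nn.LeakyReLU(),
)
plain_concept_head = torch.nn.Linear(10, 6)  # embedding -> concept logits
plain_task_head = torch.nn.Linear(6, 1)      # concept logits -> task logit

x_demo = torch.randn(4, 2)
c_logits = plain_concept_head(plain_encoder(x_demo))
y_logits = plain_task_head(c_logits)
print(c_logits.shape, y_logits.shape)
```

The task head sees only the 6 concept logits, which is what makes the concept layer a bottleneck and a natural place to intervene.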


## 5. Training

We train the model with a combined loss:
- **Concept loss**: BCE loss between predicted and true concept labels
- **Task loss**: BCE loss between predicted and true task labels
- **Total loss**: `concept_loss + concept_reg * task_loss` (note that, despite its name, `concept_reg` scales the task term)

This encourages the model to learn meaningful concept representations while also solving the task.

In [None]:
# Setup training
optimizer = torch.optim.AdamW(model.parameters(), lr=0.01)
loss_fn = torch.nn.BCEWithLogitsLoss()
model.train()

# Training loop
for epoch in range(n_epochs):
    optimizer.zero_grad()

    # Forward pass
    emb = encoder(x_train)
    c_pred = encoder_layer(embedding=emb)
    y_pred = y_predictor(endogenous=c_pred)

    # Compute loss
    concept_loss = loss_fn(c_pred, c_train)
    task_loss = loss_fn(y_pred, y_train)
    loss = concept_loss + concept_reg * task_loss

    # Backward pass
    loss.backward()
    optimizer.step()

    # Log progress
    if epoch % 100 == 0:
        task_accuracy = accuracy_score(y_train, y_pred > 0.)
        concept_accuracy = accuracy_score(c_train, c_pred > 0.)
        print(f"Epoch {epoch}: Loss {loss.item():.2f} | Task Acc: {task_accuracy:.2f} | Concept Acc: {concept_accuracy:.2f}")

print("\nTraining complete!")
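
The combined objective can be checked on a small example. The sketch below uses random tensors in place of model outputs (an assumption made for illustration) and confirms that `BCEWithLogitsLoss` applied to logits matches `BCELoss` applied to sigmoid probabilities. This is why the model outputs are left as raw logits during training.

```python
import torch

torch.manual_seed(0)
concept_reg = 0.5

# Stand-ins for model outputs (logits) and binary labels
c_logits = torch.randn(8, 6)
y_logits = torch.randn(8, 1)
c_true = torch.randint(0, 2, (8, 6)).float()
y_true = torch.randint(0, 2, (8, 1)).float()

loss_fn = torch.nn.BCEWithLogitsLoss()
concept_loss = loss_fn(c_logits, c_true)
task_loss = loss_fn(y_logits, y_true)
total = concept_loss + concept_reg * task_loss

# Same concept loss computed from probabilities instead of logits
concept_loss_probs = torch.nn.BCELoss()(torch.sigmoid(c_logits), c_true)
print(total.item(), concept_loss.item(), concept_loss_probs.item())
```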

## 6. Baseline Predictions (No Intervention)

Let's first see what the model predicts without any interventions.

In [6]:
# Get baseline predictions
model.eval()
with torch.no_grad():
    emb = model["encoder"](x_train)
    c_pred = model["encoder_layer"](embedding=emb)
    y_pred = model["y_predictor"](endogenous=c_pred)

print("Baseline concept predictions (first 5 samples):")
print(c_pred[:5])
print("\nBaseline task predictions (first 5 samples):")
print(y_pred[:5])

Baseline concept predictions (first 5 samples):
tensor([[ -4.8956,  20.1472,  -4.9395,  19.3860,  -1.9786,  21.0479],
        [  9.6034,   5.6144,   8.3762,   4.9804,   7.0057,   4.9587],
        [-13.6898, -16.0129, -15.2738, -16.3038, -12.4378, -17.0760],
        [-18.1545,  14.0004, -18.9113,  11.6973,  -9.7617,  14.2617],
        [  4.9382,  10.3747,   4.5033,  10.8236,   2.3549,  10.8078]])

Baseline task predictions (first 5 samples):
tensor([[ 0.6272],
        [ 0.0130],
        [-0.0849],
        [ 0.1556],
        [-0.3078]])
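
The values above are logits, not probabilities. The short sketch below shows how to read them, using the first row of the baseline concept predictions printed above. Applying a sigmoid gives probabilities, and thresholding logits at 0 is equivalent to thresholding probabilities at 0.5.

```python
import torch

# First row of the baseline concept logits shown above
row = torch.tensor([-4.8956, 20.1472, -4.9395, 19.3860, -1.9786, 21.0479])

probs = torch.sigmoid(row)        # logits -> probabilities in [0, 1]
hard_from_logits = row > 0.0      # decision taken directly on logits
hard_from_probs = probs > 0.5     # same decision taken on probabilities

print(probs)
print(hard_from_logits)
```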


## 7. Interventions

Now we demonstrate different intervention strategies:

### What are Interventions?
Interventions allow us to manipulate the model's internal representations (concepts) during inference. This is useful for:
- Understanding model behavior
- Correcting mistakes
- Testing counterfactual scenarios

### Intervention Components:
1. **Policy**: Decides *which* concepts to intervene on (e.g., UniformPolicy, RandomPolicy, UncertaintyInterventionPolicy)
2. **Strategy**: Decides *how* to intervene (e.g., DoIntervention, GroundTruthIntervention, DistributionIntervention)
3. **Layer**: Specifies *where* in the model to apply the intervention
4. **Quantile**: Controls *how much* is intervened on, as the fraction of candidate concept entries that the policy selects
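
Conceptually, these pieces combine into a masked overwrite of the concept tensor. The following is a minimal plain-PyTorch sketch of that idea, not the `torch_concepts` implementation. The boolean `mask` plays the role of the policy, `new_values` plays the role of the strategy, and both names are illustrative.

```python
import torch

c_pred = torch.tensor([[ 2.0, -1.0, 0.5],
                       [-3.0,  4.0, 1.5]])

# "Policy": which entries to overwrite (here, column 1 for every sample)
mask = torch.zeros_like(c_pred, dtype=torch.bool)
mask[:, 1] = True

# "Strategy": what to write into the selected entries (here, a constant)
new_values = torch.full_like(c_pred, -10.0)

# Keep the prediction where mask is False, use the new value where it is True
c_intervened = torch.where(mask, new_values, c_pred)
print(c_intervened)
```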

### 7.1. Ground Truth Intervention + Uniform Policy

- **Policy**: UniformPolicy, targeting concepts C1 and C2 (`target_concepts=[0, 1]`)
- **Strategy**: GroundTruthIntervention, using the true concept labels converted to logits with `torch.logit(c_train, eps=1e-6)`
- This replaces the predicted logits of the targeted concepts with their ground-truth values; the remaining concepts are not targeted

In [None]:
int_policy_c = UniformPolicy(out_features=c_train.shape[1])
int_strategy_c = GroundTruthIntervention(model=encoder_layer, ground_truth=torch.logit(c_train, eps=1e-6))

print("Ground Truth Intervention + Uniform Policy:")
with intervention(policies=int_policy_c,
                  strategies=int_strategy_c,
                  target_concepts=[0, 1]) as new_encoder_layer:
    emb = model["encoder"](x_train)
    c_pred = new_encoder_layer(embedding=emb)
    y_pred = model["y_predictor"](endogenous=c_pred)
    print("\nConcept predictions (first 5):")
    print(c_pred[:5])
    print("\nGround truth (first 5):")
    print(torch.logit(c_train, eps=1e-6)[:5])
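
Because the concept layer outputs logits, the binary ground-truth labels are passed through `torch.logit` before being used as replacement values. With `eps=1e-6` the inputs are first clamped to `[eps, 1 - eps]`, so labels 0 and 1 map to large finite logits (roughly -13.8 and +13.8) instead of infinities. A quick check:

```python
import torch

labels = torch.tensor([0.0, 1.0])

unclamped = torch.logit(labels)           # -inf and +inf
clamped = torch.logit(labels, eps=1e-6)   # large but finite logits

print(unclamped)
print(clamped)
```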

### 7.2. Do Intervention + Uniform Policy

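- **Policy**: UniformPolicy, targeting concept C2 (`target_concepts=[1]`)
- **Strategy**: DoIntervention with constant value -10
- This sets the targeted concept to a fixed logit of -10, which corresponds to a probability close to 0. No `quantiles` argument is passed, so the library default applies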

In [None]:
int_policy_c = UniformPolicy(out_features=c_train.shape[1])
int_strategy_c = DoIntervention(model=model["encoder_layer"], constants=-10)

print("Do Intervention + Uniform Policy:")
with intervention(
    policies=int_policy_c,
    strategies=int_strategy_c,
    target_concepts=[1],
) as new_encoder_layer:
    emb = model["encoder"](x_train)
    c_pred = new_encoder_layer(embedding=emb)
    y_pred = model["y_predictor"](endogenous=c_pred)
    print("\nConcept predictions (first 5):")
    print(c_pred[:5, :2])
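
With a linear task predictor, the effect of a do-intervention on the task logit can be worked out by hand: overwriting concept `j` with a constant `k` shifts the output by `w_j * (k - c_j)`. The sketch below verifies this with a plain `torch.nn.Linear` standing in for the task predictor. That stand-in is an assumption, and `LinearCC` may differ in detail.

```python
import torch

torch.manual_seed(0)
predictor = torch.nn.Linear(6, 1)   # stand-in for the task predictor
c_pred = torch.randn(5, 6)          # stand-in concept logits

j, k = 1, -10.0                     # intervene on column 1 (C2) with constant -10
c_do = c_pred.clone()
c_do[:, j] = k

with torch.no_grad():
    y_before = predictor(c_pred)
    y_after = predictor(c_do)
    # Shift predicted from the weight attached to concept j
    expected_shift = predictor.weight[0, j] * (k - c_pred[:, j])

print((y_after - y_before).squeeze(1))
print(expected_shift)
```

Comparing `y_pred` before and after an intervention in the same way shows how strongly the task output depends on the intervened concept.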

### 7.3. Do Intervention + Random Policy

- **Policy**: RandomPolicy, targeting concepts C1 and C2 (`target_concepts=[0, 1]`) with `quantiles=0.5`
- **Strategy**: DoIntervention with constant value -10
- This randomly selects which of the targeted concept entries to intervene on and sets the selected entries to -10

In [None]:
int_policy_c = RandomPolicy(out_features=c_train.shape[1])
int_strategy_c = DoIntervention(model=encoder_layer, constants=-10)

print("Do Intervention + Random Policy:")
with intervention(
    policies=int_policy_c,
    strategies=int_strategy_c,
    target_concepts=[0, 1],
    quantiles=0.5
) as new_encoder_layer:
    emb = model["encoder"](x_train)
    c_pred = new_encoder_layer(embedding=emb)
    y_pred = model["y_predictor"](endogenous=c_pred)
    print("\nConcept predictions (first 5):")
    print(c_pred[:5, :2])
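
One way to picture a random policy with `quantiles=0.5` is to draw a random score for every candidate concept entry and then intervene on the top half of the candidates in each sample. The sketch below implements that reading in plain PyTorch. It is an interpretation for illustration, not the library's selection code, and `scores`, `k`, `picked`, and `mask` are illustrative names.

```python
import torch

torch.manual_seed(0)
n_samples, n_concepts = 5, 6
target_concepts = [0, 1]
quantile = 0.5

c_pred = torch.randn(n_samples, n_concepts)

# Random scores for the candidate entries only
scores = torch.rand(n_samples, len(target_concepts))
k = int(round(quantile * len(target_concepts)))   # entries to pick per sample
picked = scores.topk(k, dim=1).indices            # positions within target_concepts

# Build a full-size boolean mask from the picked positions
mask = torch.zeros_like(c_pred, dtype=torch.bool)
cols = torch.tensor(target_concepts)[picked]      # map back to concept columns
mask.scatter_(1, cols, True)

# Overwrite the selected entries with the constant -10
c_do = torch.where(mask, torch.full_like(c_pred, -10.0), c_pred)
print(c_do[:, :2])
```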

### 7.4. Distribution Intervention

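- **Policy**: RandomPolicy (reused from the previous cell), targeting concepts C2 and C4 (`target_concepts=[1, 3]`) with `quantiles=0.5`
- **Strategy**: DistributionIntervention with Normal(50, 1)
- This samples replacement values from a normal distribution with mean 50 and standard deviation 1 instead of using a fixed constant. In the first five rows shown below, each sample has one of its two targeted concepts replaced by a value near 50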

In [10]:
int_strategy_c = DistributionIntervention(model=encoder_layer, dist=torch.distributions.Normal(loc=50, scale=1))

print("Distribution Intervention:")
with intervention(
    policies=int_policy_c,
    strategies=int_strategy_c,
    target_concepts=[1, 3],
    quantiles=.5
) as new_encoder_layer:
    emb = model["encoder"](x_train)
    c_pred = new_encoder_layer(embedding=emb)
    y_pred = model["y_predictor"](c_pred)
    print("\nConcept predictions (first 5):")
    print(c_pred[:5])

Distribution Intervention:

Concept predictions (first 5):
tensor([[ -4.8956,  20.1472,  -4.9395,  49.2009,  -1.9786,  21.0479],
        [  9.6034,  50.4893,   8.3762,   4.9804,   7.0057,   4.9587],
        [-13.6898, -16.0129, -15.2738,  49.5025, -12.4378, -17.0760],
        [-18.1545,  14.0004, -18.9113,  47.5268,  -9.7617,  14.2617],
        [  4.9382,  52.9688,   4.5033,  10.8236,   2.3549,  10.8078]],
       grad_fn=<SliceBackward0>)
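
For the distribution strategy, the replacement values are draws rather than a constant. Below is a plain-PyTorch sketch of the sampling step, assuming one independent draw per tensor entry and a fixed mask on a single column. The library may sample and select differently.

```python
import torch

torch.manual_seed(0)
dist = torch.distributions.Normal(loc=50.0, scale=1.0)

c_pred = torch.randn(1000, 6)
mask = torch.zeros_like(c_pred, dtype=torch.bool)
mask[:, 3] = True                     # intervene on column 3 (C4) for every sample

samples = dist.sample(c_pred.shape)   # one draw per entry
c_dist = torch.where(mask, samples, c_pred)

print(c_dist[:3])
print(torch.sigmoid(c_dist[:3, 3]))   # logits near 50 saturate to probability ~1
```

Since a logit near 50 is far into the saturated region of the sigmoid, this particular distribution effectively forces the intervened concepts to "on".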


## Summary

In this notebook, we:
1. Loaded a toy XOR dataset with concept annotations
2. Created semantic annotations for concepts and tasks
3. Built a concept bottleneck model with encoder and predictor layers
4. Trained the model with both concept and task supervision
5. Demonstrated various intervention strategies:
   - Ground truth interventions
   - Do interventions (constant values)
   - Distribution interventions (sampling from distributions)
   - Different policies (Uniform and Random; `UncertaintyInterventionPolicy` is imported but not exercised here)

These interventions allow us to manipulate the model's concept representations and observe how they affect the final predictions, providing interpretability and control over the model's reasoning process.