# Concept-Based Model with Interventions

This notebook demonstrates how to:
1. Load and prepare data with concept annotations
2. Build a concept-based neural network with an encoder and predictor
3. Train the model on both concept and task predictions
4. Apply various intervention strategies to manipulate concept predictions

## 1. Imports

We import the necessary libraries:
- **PyTorch**: for neural network building blocks
- **sklearn**: for evaluation metrics
- **torch_concepts**: for concept annotations, layers, and intervention mechanisms

In [1]:
import torch
from sklearn.metrics import accuracy_score

from torch_concepts import Annotations, AxisAnnotation
from torch_concepts.data import ToyDataset
from torch_concepts.nn import (
    ProbEncoderFromEmb, 
    ProbPredictor, 
    GroundTruthIntervention,
    UncertaintyInterventionPolicy, 
    intervention, 
    DoIntervention, 
    DistributionIntervention, 
    UniformPolicy, 
    RandomPolicy
)

## 2. Data Loading and Preparation

We load the XOR toy dataset and prepare the training data:
- **Features (x_train)**: input features for the model
- **Concepts (c_train)**: intermediate concept labels (the two original concepts repeated three times, giving 6 concepts)
- **Targets (y_train)**: task labels to predict
- **Names**: concept and task attribute names for annotations

In [None]:
# Hyperparameters
latent_dims = 10
n_epochs = 500
n_samples = 1000
concept_reg = 0.5

# Load toy XOR dataset
data = ToyDataset('xor', size=n_samples, random_state=42)
x_train = data.data
c_train = data.concept_labels
y_train = data.target_labels
concept_names = data.concept_attr_names
task_names = data.task_attr_names

# Duplicate concept labels to create 6 concepts (C1, C2, C3, C4, C5, C6)
c_train = torch.concat([c_train, c_train, c_train], dim=1)

# Get dimensions
n_features = x_train.shape[1]
n_concepts = c_train.shape[1]
n_classes = y_train.shape[1]

print("Dataset loaded:")
print(f"  Features shape: {x_train.shape}")
print(f"  Concepts shape: {c_train.shape}")
print(f"  Targets shape: {y_train.shape}")
print(f"  Number of features: {n_features}")
print(f"  Number of concepts: {n_concepts}")
print(f"  Number of classes: {n_classes}")

Dataset loaded:
  Features shape: torch.Size([1000, 2])
  Concepts shape: torch.Size([1000, 6])
  Targets shape: torch.Size([1000, 1])
  Number of features: 2
  Number of concepts: 6
  Number of classes: 1


## 3. Annotations Object

The `Annotations` object is a key component that provides semantic meaning to tensor dimensions:
- It maps axis indices to `AxisAnnotation` objects
- Each `AxisAnnotation` contains names (labels) for features along that axis
- This enables human-readable concept manipulation and intervention

Here we create:
- **c_annotations**: annotations for the 6 concepts (C1-C6)
- **y_annotations**: annotations for the task output
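
To see why named axes help, here is a minimal plain-Python sketch of the lookup that labels make possible: translating concept names into column indices. The helper `names_to_indices` is hypothetical and written only for illustration; it is not part of `torch_concepts`, whose internals may differ.

```python
# Hypothetical helper, not the torch_concepts API: map concept names to column indices.
def names_to_indices(labels, subset):
    """Return the column index of each name in `subset`, in the order given."""
    position = {name: i for i, name in enumerate(labels)}
    missing = [name for name in subset if name not in position]
    if missing:
        raise KeyError(f"Unknown concept names: {missing}")
    return [position[name] for name in subset]


labels = ["C1", "C2", "C3", "C4", "C5", "C6"]
print(names_to_indices(labels, ["C1", "C4", "C5", "C6"]))  # [0, 3, 4, 5]
```

With a lookup like this, an intervention can be specified as `subset=["C1", "C4"]` rather than as raw column numbers.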

In [6]:
# Create annotations for concepts and targets
c_annotations = Annotations({1: AxisAnnotation(concept_names + ['C3', 'C4', 'C5', 'C6'])})
y_annotations = Annotations({1: AxisAnnotation(task_names)})

print(f"Concept annotations:")
print(f"  Shape: {c_annotations.shape}")
print(f"  Axis 1 names: {c_annotations[1].labels}")
print(f"\nTask annotations:")
print(f"  Shape: {y_annotations.shape}")
print(f"  Axis 1 names: {y_annotations[1].labels}")

Concept annotations:
  Shape: (-1, 6)
  Axis 1 names: ['C1', 'C2', 'C3', 'C4', 'C5', 'C6']

Task annotations:
  Shape: (-1, 1)
  Axis 1 names: ['xor']


## 4. Model Architecture

We build a concept bottleneck model with three components:

1. **Encoder**: A simple neural network that maps input features to a latent embedding
2. **Concept Encoder** (`encoder_layer`, a `ProbEncoderFromEmb`): Maps the embedding to concept logits
3. **Task Predictor** (`ProbPredictor`): Maps concept logits to task predictions

The components are collected in a `ModuleDict`, so that interventions can address individual layers by name (for example `encoder_layer.encoder`).

In [7]:
# Build the encoder (features -> embedding)
encoder = torch.nn.Sequential(
    torch.nn.Linear(n_features, latent_dims),
    torch.nn.LeakyReLU(),
)

# Build the concept encoder (embedding -> concepts)
encoder_layer = ProbEncoderFromEmb(
    in_features_embedding=latent_dims, 
    out_features=c_annotations.shape[1]
)

# Build the task predictor (concepts -> task)
y_predictor = ProbPredictor(
    in_features_logits=c_annotations.shape[1], 
    out_features=y_annotations.shape[1]
)

# Wrap all components in a ModuleDict for easier intervention
model = torch.nn.ModuleDict({
    "encoder": encoder,
    "encoder_layer": encoder_layer,
    "y_predictor": y_predictor,
})

print("Model architecture:")
print(model)
print(f"\nEncoder layer representation:")
print(f"  Input: embedding of size {latent_dims}")
print(f"  Output: concept logits of size {c_annotations.shape[1]}")
print(f"\nTask predictor representation:")
print(f"  Input: concept logits of size {c_annotations.shape[1]}")
print(f"  Output: task logits of size {y_annotations.shape[1]}")

Model architecture:
ModuleDict(
  (encoder): Sequential(
    (0): Linear(in_features=2, out_features=10, bias=True)
    (1): LeakyReLU(negative_slope=0.01)
  )
  (encoder_layer): ProbEncoderFromEmb(
    (encoder): Sequential(
      (0): Linear(in_features=10, out_features=6, bias=True)
      (1): Unflatten(dim=-1, unflattened_size=(6,))
    )
  )
  (y_predictor): ProbPredictor(
    (predictor): Sequential(
      (0): Linear(in_features=6, out_features=1, bias=True)
      (1): Unflatten(dim=-1, unflattened_size=(1,))
    )
  )
)

Encoder layer representation:
  Input: embedding of size 10
  Output: concept logits of size 6

Task predictor representation:
  Input: concept logits of size 6
  Output: task logits of size 1
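
The printed architecture shows `y_predictor` as a single `Linear` layer on top of the concepts. A single linear decision rule cannot represent XOR of two binary inputs, which is one plausible reason the task accuracy in the training log of Section 5 stays near chance. The sketch below checks this on the four binary XOR points with a coarse grid search. It is an analogy under simplifying assumptions (binary inputs, a hand-picked grid), not a reproduction of the model.

```python
from itertools import product

# The four binary XOR points: (c1, c2) -> c1 XOR c2.
points = [((0, 0), 0), ((0, 1), 1), ((1, 0), 1), ((1, 1), 0)]


def accuracy(w1, w2, b):
    """Accuracy of the linear rule `w1*c1 + w2*c2 + b > 0` on the four points."""
    hits = sum(int(w1 * c1 + w2 * c2 + b > 0) == y for (c1, c2), y in points)
    return hits / len(points)


# Coarse grid over both weights and the bias in [-3, 3] with step 0.5.
grid = [i / 2 for i in range(-6, 7)]
best = max(accuracy(w1, w2, b) for w1, w2, b in product(grid, repeat=3))
print(best)  # 0.75: at most three of the four points can be classified correctly
```

Adding a hidden nonlinear layer to the task predictor would be one way to lift this limit, at the cost of a less directly interpretable mapping from concepts to task.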


## 5. Training

We train the model with a combined loss:
- **Concept loss**: BCE loss between predicted and true concept labels
- **Task loss**: BCE loss between predicted and true task labels
- **Total loss**: `concept_loss + concept_reg * task_loss`, where `concept_reg` (0.5 here) weights the task term, despite its name

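This encourages the model to learn meaningful concept representations while also fitting the task. In the run logged below, concept accuracy reaches 0.99 but task accuracy stays near chance, so on this run only the concept objective is clearly met.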
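
As a worked example of the combined loss, the sketch below computes both BCE terms by hand for a single made-up sample and combines them with the same weight. It uses plain Python and the numerically stable form of binary cross-entropy on logits; the input values are invented for illustration. Like `BCEWithLogitsLoss` with its default reduction, each term is averaged over its elements.

```python
import math


def bce_with_logits(logit, target):
    """Binary cross-entropy for one logit, in the numerically stable form
    max(z, 0) - z * t + log(1 + exp(-|z|))."""
    return max(logit, 0.0) - logit * target + math.log1p(math.exp(-abs(logit)))


def mean_bce(logits, targets):
    """Average the per-element BCE over all elements."""
    values = [bce_with_logits(z, t) for z, t in zip(logits, targets)]
    return sum(values) / len(values)


# Made-up numbers for one sample: six concept logits and one task logit.
concept_logits = [2.0, -1.0, 2.0, -1.0, 2.0, -1.0]
concept_targets = [1.0, 0.0, 1.0, 0.0, 1.0, 0.0]
task_logits, task_targets = [0.0], [1.0]
concept_reg = 0.5  # same value as the notebook's hyperparameter

concept_loss = mean_bce(concept_logits, concept_targets)
task_loss = mean_bce(task_logits, task_targets)
loss = concept_loss + concept_reg * task_loss
print(f"concept={concept_loss:.4f} task={task_loss:.4f} total={loss:.4f}")
```

A task logit of 0 corresponds to a probability of 0.5, so its loss is `log(2)` regardless of the target; this is roughly where an uninformative task head sits.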

In [8]:
# Setup training
optimizer = torch.optim.AdamW(model.parameters(), lr=0.01)
loss_fn = torch.nn.BCEWithLogitsLoss()
model.train()

# Training loop
for epoch in range(n_epochs):
    optimizer.zero_grad()

    # Forward pass
    emb = encoder(x_train)
    c_pred = encoder_layer(embedding=emb)
    y_pred = y_predictor(logits=c_pred)

    # Compute loss
    concept_loss = loss_fn(c_pred, c_train)
    task_loss = loss_fn(y_pred, y_train)
    loss = concept_loss + concept_reg * task_loss

    # Backward pass
    loss.backward()
    optimizer.step()

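    # Log progress. A logit > 0 corresponds to a probability > 0.5. For the 2-D
    # concept tensor, accuracy_score reports exact-match (subset) accuracy.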
    if epoch % 100 == 0:
        task_accuracy = accuracy_score(y_train, y_pred > 0.)
        concept_accuracy = accuracy_score(c_train, c_pred > 0.)
        print(f"Epoch {epoch}: Loss {loss.item():.2f} | Task Acc: {task_accuracy:.2f} | Concept Acc: {concept_accuracy:.2f}")

print("\nTraining complete!")

Epoch 0: Loss 1.05 | Task Acc: 0.49 | Concept Acc: 0.00
Epoch 100: Loss 0.53 | Task Acc: 0.57 | Concept Acc: 0.95
Epoch 200: Loss 0.43 | Task Acc: 0.33 | Concept Acc: 0.98
Epoch 300: Loss 0.41 | Task Acc: 0.32 | Concept Acc: 0.99
Epoch 400: Loss 0.39 | Task Acc: 0.47 | Concept Acc: 0.99

Training complete!
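
A note on reading the concept accuracy column: for 2-D multilabel targets, scikit-learn's `accuracy_score` computes subset accuracy, so a sample counts as correct only when all six concepts match. That is a likely reason for the 0.00 at epoch 0. The plain-Python sketch below contrasts subset accuracy with per-label accuracy on made-up predictions; the function names are illustrative only.

```python
def subset_accuracy(y_true, y_pred):
    """Fraction of rows where every label matches (exact match)."""
    return sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true)


def per_label_accuracy(y_true, y_pred):
    """Fraction of individual labels that match, over all rows and columns."""
    pairs = [(t, p) for row_t, row_p in zip(y_true, y_pred) for t, p in zip(row_t, row_p)]
    return sum(t == p for t, p in pairs) / len(pairs)


y_true = [[1, 0, 1], [0, 1, 0], [1, 1, 1], [0, 0, 0]]
y_pred = [[1, 0, 0], [0, 1, 0], [1, 0, 1], [0, 0, 0]]

print(subset_accuracy(y_true, y_pred))     # 0.5 (two of four rows match exactly)
print(per_label_accuracy(y_true, y_pred))  # 10 of 12 individual labels match
```

Subset accuracy is the stricter of the two, so it can be much lower than per-label accuracy early in training.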


## 6. Baseline Predictions (No Intervention)

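Let's first see what the model predicts without any interventions. Both outputs are raw logits, not probabilities, so positive values correspond to a predicted label of 1.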

In [9]:
# Get baseline predictions
model.eval()
with torch.no_grad():
    emb = model["encoder"](x_train)
    c_pred = model["encoder_layer"](emb)
    y_pred = model["y_predictor"](c_pred)

print("Baseline concept predictions (first 5 samples):")
print(c_pred[:5])
print("\nBaseline task predictions (first 5 samples):")
print(y_pred[:5])

Baseline concept predictions (first 5 samples):
tensor([[ -4.4402,  17.9389,  -4.3347,  17.1459,  -4.6396,  18.0594],
        [  9.4854,   3.7959,   9.2281,   3.5878,   9.5774,   3.7644],
        [-11.7614, -10.9780, -11.4861, -10.4635, -11.7920, -11.0394],
        [-15.8845,  13.0674, -15.4057,  12.5706, -16.3195,  13.2554],
        [  4.3424,   8.2338,   4.2186,   7.8436,   4.3349,   8.2514]])

Baseline task predictions (first 5 samples):
tensor([[ 0.1080],
        [-0.0065],
        [ 0.0425],
        [ 0.1098],
        [-0.0042]])


## 7. Interventions

Now we demonstrate different intervention strategies.

### What are Interventions?
Interventions allow us to manipulate the model's internal representations (concepts) during inference. This is useful for:
- Understanding model behavior
- Correcting mistakes
- Testing counterfactual scenarios

### Intervention Components:
1. **Policy**: Decides *which* concepts to intervene on (e.g., UniformPolicy, RandomPolicy, UncertaintyInterventionPolicy)
2. **Strategy**: Decides *how* to intervene (e.g., DoIntervention, GroundTruthIntervention, DistributionIntervention)
3. **Layer**: Specifies *where* in the model to apply the intervention
4. **Quantile**: Controls *how many* samples to intervene on
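
Conceptually, the policy and the strategy can be pictured as a mask and a set of replacement values. The sketch below is a hypothetical plain-Python illustration of that idea, not the `torch_concepts` implementation: the function `apply_intervention` and the way the mask is built are assumptions made for this example, and the quantile mechanism is left out.

```python
# Hypothetical sketch, not the torch_concepts implementation.
def apply_intervention(concepts, mask, replacement):
    """Replace concepts[i][j] with replacement[i][j] wherever mask[i][j] is True."""
    return [
        [r if m else c for c, m, r in zip(row_c, row_m, row_r)]
        for row_c, row_m, row_r in zip(concepts, mask, replacement)
    ]


labels = ["C1", "C2", "C3"]
concepts = [[-4.4, 17.9, -4.3], [9.5, 3.8, 9.2]]  # predicted concept logits

# "Policy": choose which entries to overwrite (here: column C1 for every sample).
mask = [[name == "C1" for name in labels] for _ in concepts]
# "Strategy": choose the new values (here: a do-style constant of -10).
replacement = [[-10.0] * len(labels) for _ in concepts]

print(apply_intervention(concepts, mask, replacement))
# [[-10.0, 17.9, -4.3], [-10.0, 3.8, 9.2]]
```

Keeping the two roles separate is what lets the sections below swap policies and strategies independently.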

### 7.1. Uncertainty + Ground Truth Intervention

- **Policy**: UniformPolicy on concepts [C1, C4, C5, C6] + UncertaintyInterventionPolicy on task [xor]
- **Strategy**: GroundTruthIntervention (use true concept values) + DoIntervention (set to constant 100)
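- The concept intervention replaces the selected concepts with their ground-truth values, converted to logits with `torch.logit`; the task intervention overwrites the task output with the constant 100. The uncertainty-based policy applies to the task output, not to the concepts.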
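
Because the concept layer outputs logits, the 0/1 ground-truth labels are passed through `torch.logit(..., eps=1e-6)`, which clamps its input to `[eps, 1 - eps]` before taking `log(p / (1 - p))`. The sketch below reproduces that arithmetic in plain Python, emulating float32 rounding with `struct`. The float32 emulation is an assumption about why the two magnitudes in the cell output differ slightly (-13.8155 versus 13.8023); the helper names are illustrative.

```python
import math
import struct


def to_float32(x):
    """Round a Python float to the nearest float32 value."""
    return struct.unpack("f", struct.pack("f", x))[0]


def logit_float32(p, eps=1e-6):
    """logit(p) = log(p / (1 - p)) after clamping p to [eps, 1 - eps] in float32."""
    lo, hi = to_float32(eps), to_float32(1.0 - eps)
    p = min(max(to_float32(p), lo), hi)
    return math.log(p / (1.0 - p))


print(round(logit_float32(0.0), 4))  # -13.8155
print(round(logit_float32(1.0), 4))  # 13.8023
```

The asymmetry comes from `1 - 1e-6` not being exactly representable in float32, so the upper clamp lands slightly further from 1 than the lower clamp is from 0.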

In [10]:
quantile = 0.8

int_policy_c = UniformPolicy(out_annotations=c_annotations, subset=["C1", "C4", "C5", "C6"])
int_strategy_c = GroundTruthIntervention(model=model, ground_truth=torch.logit(c_train, eps=1e-6))
int_policy_y = UncertaintyInterventionPolicy(out_annotations=y_annotations, subset=["xor"])
int_strategy_y = DoIntervention(model=model, constants=100)

print("Uncertainty + Ground Truth Intervention:")
with intervention(
    policies=[int_policy_c, int_policy_y],
    strategies=[int_strategy_c, int_strategy_y],
    on_layers=["encoder_layer.encoder", "y_predictor.predictor"],
    quantiles=[quantile, 1]
):
    emb = model["encoder"](x_train)
    c_pred = model["encoder_layer"](emb)
    y_pred = model["y_predictor"](c_pred)
    print("\nConcept predictions (first 5):")
    print(c_pred[:5])
    print("\nTask predictions (first 5):")
    print(y_pred[:5])

Uncertainty + Ground Truth Intervention:

Concept predictions (first 5):
tensor([[-13.8155,  17.9389,  -4.3347,  13.8023, -13.8155,  13.8023],
        [ 13.8023,   3.7959,   9.2281,  13.8023,  13.8023,  13.8023],
        [-13.8155, -10.9780, -11.4861, -13.8155, -13.8155, -13.8155],
        [-13.8155,  13.0674, -15.4057,  13.8023, -13.8155,  13.8023],
        [ 13.8023,   8.2338,   4.2186,  13.8023,  13.8023,  13.8023]],
       grad_fn=<SliceBackward0>)

Task predictions (first 5):
tensor([[100.],
        [100.],
        [100.],
        [100.],
        [100.]], grad_fn=<SliceBackward0>)


### 7.2. Do Intervention + Uniform Policy

- **Policy**: UniformPolicy on concepts [C1, C2, C6]
- **Strategy**: DoIntervention with constant value -10
- This sets the selected concepts to a fixed value of -10 for a uniform subset of samples

In [11]:
int_policy_c = UniformPolicy(out_annotations=c_annotations, subset=["C1", "C2", "C6"])
int_strategy_c = DoIntervention(model=model, constants=-10)

print("Do Intervention + Uniform Policy:")
with intervention(
    policies=[int_policy_c],
    strategies=[int_strategy_c],
    on_layers=["encoder_layer.encoder"],
    quantiles=[quantile]
):
    emb = model["encoder"](x_train)
    c_pred = model["encoder_layer"](emb)
    y_pred = model["y_predictor"](c_pred)
    print("\nConcept predictions (first 5):")
    print(c_pred[:5])

Do Intervention + Uniform Policy:

Concept predictions (first 5):
tensor([[-10.0000, -10.0000,  -4.3347,  17.1459,  -4.6396, -10.0000],
        [-10.0000, -10.0000,   9.2281,   3.5878,   9.5774, -10.0000],
        [-10.0000, -10.0000, -11.4861, -10.4635, -11.7920, -10.0000],
        [-10.0000, -10.0000, -15.4057,  12.5706, -16.3195, -10.0000],
        [-10.0000, -10.0000,   4.2186,   7.8436,   4.3349, -10.0000]],
       grad_fn=<SliceBackward0>)


### 7.3. Do Intervention + Random Policy

- **Policy**: RandomPolicy on concepts [C1, C2, C6] with scale=100
- **Strategy**: DoIntervention with constant value -10
- This randomly selects samples to intervene on, setting their selected concepts to -10

In [None]:
int_policy_c = RandomPolicy(out_annotations=c_annotations, scale=100, subset=["C1", "C2", "C6"])
int_strategy_c = DoIntervention(model=model, constants=-10)

print("Do Intervention + Random Policy:")
with intervention(
    policies=[int_policy_c],
    strategies=[int_strategy_c],
    on_layers=["encoder_layer.encoder"],
    quantiles=[quantile]
):
    emb = model["encoder"](x_train)
    c_pred = model["encoder_layer"](emb)
    y_pred = model["y_predictor"](c_pred)
    print("\nConcept predictions (first 5):")
    print(c_pred[:5])

### 7.4. Distribution Intervention

- **Policy**: RandomPolicy (the `int_policy_c` defined in Section 7.3 is reused, so that cell must be run first)
- **Strategy**: DistributionIntervention with Normal(0, 1)
- This samples from a normal distribution for the intervened concepts instead of using a fixed constant
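
The idea can be pictured with a plain-Python sketch that resamples selected columns from a normal distribution and leaves the rest untouched. This is a hypothetical illustration, not the `torch_concepts` implementation: `distribution_intervention`, its parameters, and the example values are assumptions made for this example.

```python
import random


# Hypothetical sketch, not the torch_concepts implementation.
def distribution_intervention(concepts, columns, rng, mean=0.0, std=1.0):
    """Return a copy of `concepts` with the given columns resampled from Normal(mean, std)."""
    return [
        [rng.gauss(mean, std) if j in columns else value for j, value in enumerate(row)]
        for row in concepts
    ]


rng = random.Random(0)  # seeded so the sketch is repeatable
concepts = [
    [-4.4, 17.9, -4.3, 17.1, -4.6, 18.1],
    [9.5, 3.8, 9.2, 3.6, 9.6, 3.8],
]
columns = {0, 1, 5}  # positions of C1, C2, C6

intervened = distribution_intervention(concepts, columns, rng)
for row in intervened:
    print([round(v, 2) for v in row])
```

Each call draws fresh values, which is why repeated runs of the same intervention give different concept predictions in the selected columns.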

In [12]:
int_strategy_c = DistributionIntervention(
    model=model, 
    dist=torch.distributions.Normal(loc=0, scale=1)
)

print("Distribution Intervention:")
with intervention(
    policies=[int_policy_c],
    strategies=[int_strategy_c],
    on_layers=["encoder_layer.encoder"],
    quantiles=[quantile]
):
    emb = model["encoder"](x_train)
    c_pred = model["encoder_layer"](emb)
    y_pred = model["y_predictor"](c_pred)
    print("\nConcept predictions (first 5):")
    print(c_pred[:5])

Distribution Intervention:

Concept predictions (first 5):
tensor([[ -1.3485,   0.7330,  -4.3347,  17.1459,  -4.6396,   0.1784],
        [ -0.1086,   0.8196,   9.2281,   3.5878,   9.5774,  -1.8287],
        [ -0.8125,  -0.5722, -11.4861, -10.4635, -11.7920,  -0.9029],
        [  0.9016,   1.7261, -15.4057,  12.5706, -16.3195,  -0.9566],
        [  2.4360,  -1.2420,   4.2186,   7.8436,   4.3349,   2.6420]],
       grad_fn=<SliceBackward0>)


### 7.5. Single Intervention Example

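This cell reruns the distribution intervention from Section 7.4, reusing `int_policy_c` and `int_strategy_c`, and prints both the concept and the task predictions. Because the replacement values are sampled, the intervened concepts differ from the previous run.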

In [13]:
print("Single Intervention (Distribution):")
with intervention(
    policies=[int_policy_c],
    strategies=[int_strategy_c],
    on_layers=["encoder_layer.encoder"],
    quantiles=[quantile]
):
    emb = model["encoder"](x_train)
    c_pred = model["encoder_layer"](emb)
    y_pred = model["y_predictor"](c_pred)
    print("\nConcept predictions (first 5):")
    print(c_pred[:5])
    print("\nTask predictions (first 5):")
    print(y_pred[:5])

Single Intervention (Distribution):

Concept predictions (first 5):
tensor([[ -1.2687,   0.9846,  -4.3347,  17.1459,  -4.6396,   1.3569],
        [  0.9104,  -0.4779,   9.2281,   3.5878,   9.5774,   0.1320],
        [  0.3222,  -0.4628, -11.4861, -10.4635, -11.7920,   0.9773],
        [ -1.2280,   0.2996, -15.4057,  12.5706, -16.3195,  -0.0471],
        [ -0.5348,  -0.2769,   4.2186,   7.8436,   4.3349,   0.9744]],
       grad_fn=<SliceBackward0>)

Task predictions (first 5):
tensor([[ 0.0100],
        [-0.1131],
        [ 0.0264],
        [-0.0171],
        [-0.0655]], grad_fn=<SliceBackward0>)


## Summary

In this notebook, we:
1. Loaded a toy XOR dataset with concept annotations
2. Created semantic annotations for concepts and tasks
3. Built a concept bottleneck model with encoder and predictor layers
4. Trained the model with both concept and task supervision
5. Demonstrated various intervention strategies:
   - Ground truth interventions
   - Do interventions (constant values)
   - Distribution interventions (sampling from distributions)
   - Different policies (Uniform, Random, Uncertainty-based)

These interventions allow us to manipulate the model's concept representations and observe how they affect the final predictions, providing interpretability and control over the model's reasoning process.