# Bipartite Model for Concept Bottleneck

This notebook demonstrates how to:
1. Load and prepare data with rich concept annotations
2. Define concept and task metadata with distributions and cardinalities
3. Build a BipartiteModel that automatically constructs a PGM
4. Use Propagators to create encoder and predictor factors
5. Train the model with concept and task supervision
6. Apply interventions within the BipartiteModel framework

## 1. Imports

We import the necessary libraries:
- **PyTorch**: for neural network building blocks and distributions
- **sklearn**: for evaluation metrics
- **torch_concepts**: for Annotations, BipartiteModel, Propagators, and inference

In [None]:
import torch
from sklearn.metrics import accuracy_score
from torch.distributions import RelaxedOneHotCategorical, RelaxedBernoulli

from torch_concepts import Annotations, AxisAnnotation
from torch_concepts.data import ToyDataset
from torch_concepts.nn import (
    ProbEncoderFromEmb, 
    ProbPredictor, 
    RandomPolicy, 
    DoIntervention, 
    intervention, 
    DeterministicInference, 
    BipartiteModel, 
    Propagator
)

## 2. Data Loading and Preparation

We load the XOR toy dataset and prepare the training data:
- **Features (x_train)**: input features for the model
- **Concepts (c_train)**: intermediate concept labels (binary: c1, c2)
- **Targets (y_train)**: task labels (converted to one-hot encoding with 2 classes)
- **Names**: concept and task attribute names

In [None]:
# Hyperparameters
latent_dims = 10
n_epochs = 500
n_samples = 1000
concept_reg = 0.5

# Load toy XOR dataset
data = ToyDataset('xor', size=n_samples, random_state=42)
x_train = data.data
c_train = data.concept_labels
y_train = data.target_labels
concept_names_raw = data.concept_attr_names
task_names_raw = data.task_attr_names

# Convert y_train to one-hot encoding (2 classes): column k is 1 when the label is k
y_train = torch.cat([1 - y_train, y_train], dim=1)

# Define concept and task names for the model
concept_names = ('c1', 'c2')
task_names = ('xor',)

print(f"Dataset loaded:")
print(f"  Features shape: {x_train.shape}")
print(f"  Concepts shape: {c_train.shape}")
print(f"  Targets shape: {y_train.shape}")
print(f"  Concept names: {concept_names}")
print(f"  Task names: {task_names}")
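Two facts about this dataset are worth making concrete. First, the task label is assumed to be the XOR of the two binary concepts (the dataset name suggests this; it is not checked here), so neither concept alone determines the label. Second, in a standard one-hot layout column k is the indicator of class k, which for a binary label y is the row `[1 - y, y]`. The plain-Python sketch below uses hypothetical helpers (`xor_label`, `one_hot`) to enumerate the truth table and encode it.

```python
from itertools import product


def xor_label(c1: int, c2: int) -> int:
    """Assumed task rule for the XOR toy problem: 1 iff exactly one concept is active."""
    return c1 ^ c2


def one_hot(label: int, n_classes: int = 2) -> list:
    """One-hot row in which column k is the indicator of class k."""
    return [1 if k == label else 0 for k in range(n_classes)]


# Enumerate every concept combination with its task label and encoding.
table = [((c1, c2), xor_label(c1, c2), one_hot(xor_label(c1, c2)))
         for c1, c2 in product((0, 1), repeat=2)]
for concepts, y, y_onehot in table:
    print(concepts, y, y_onehot)

# Each single concept value co-occurs with both task labels, so neither
# concept alone determines y; the task needs both concepts jointly.
for i in (0, 1):
    for v in (0, 1):
        labels = {label for combo, label, _ in table if combo[i] == v}
        assert labels == {0, 1}
```

This is why the model needs a task predictor that combines both concepts rather than reading off either one.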

## 3. Rich Annotations with Metadata

The **Annotations** object in the BipartiteModel framework supports rich metadata:
- **Cardinalities**: The number of classes/dimensions for each variable
- **Metadata**: Additional information for each variable including:
  - **distribution**: The probability distribution type
  - **type**: Variable type (e.g., 'binary', 'categorical')
  - **description**: Human-readable description

This metadata is used by the BipartiteModel to automatically:
- Create appropriate Variables
- Set up correct probability distributions
- Configure the PGM structure

In [None]:
# Define cardinalities (number of classes for each variable)
cardinalities = (1, 1, 2)  # c1: 1 (binary), c2: 1 (binary), xor: 2 (one-hot)

# Define metadata for each variable
metadata = {
    'c1': {
        'distribution': RelaxedBernoulli, 
        'type': 'binary', 
        'description': 'Concept 1'
    },
    'c2': {
        'distribution': RelaxedBernoulli, 
        'type': 'binary', 
        'description': 'Concept 2'
    },
    'xor': {
        'distribution': RelaxedOneHotCategorical, 
        'type': 'binary', 
        'description': 'XOR Task'
    },
}

# Create rich annotations
annotations = Annotations({
    1: AxisAnnotation(
        concept_names + task_names, 
        cardinalities=cardinalities, 
        metadata=metadata
    )
})

print("Annotations structure:")
print(f"  Variables: {concept_names + task_names}")
print(f"  Cardinalities: {cardinalities}")
print(f"\nMetadata:")
for name, meta in metadata.items():
    print(f"  {name}:")
    print(f"    Distribution: {meta['distribution'].__name__}")
    print(f"    Type: {meta['type']}")
    print(f"    Description: {meta['description']}")

## 4. BipartiteModel: High-Level Model Construction

The **BipartiteModel** is a high-level abstraction that:
- Automatically constructs a PGM from annotations
- Uses **Propagators** to create encoder and predictor factors
- Manages the bipartite structure: concepts → tasks
- Exposes the underlying PGM for inference and interventions

### Propagators:
- **Propagator(ProbEncoderFromEmb)**: Creates encoder factors for concepts
- **Propagator(ProbPredictor)**: Creates predictor factors for tasks

The BipartiteModel automatically:
1. Creates Variables from annotations
2. Builds Factors using Propagators
3. Constructs the PGM with proper dependencies

In [None]:
# Create the encoder (input features -> embedding)
encoder = torch.nn.Sequential(
    torch.nn.Linear(x_train.shape[1], latent_dims), 
    torch.nn.LeakyReLU()
)

# Create the BipartiteModel
concept_model = BipartiteModel(
    task_names=task_names,
    latent_dims=latent_dims,
    annotations=annotations,
    concept_propagator=Propagator(ProbEncoderFromEmb),
    task_propagator=Propagator(ProbPredictor)
)

print("BipartiteModel structure:")
print(f"  Task names: {task_names}")
print(f"  Latent dimensions: {latent_dims}")
print(f"  Concept propagator: {ProbEncoderFromEmb.__name__}")
print(f"  Task propagator: {ProbPredictor.__name__}")
print(f"\nUnderlying PGM:")
print(concept_model.pgm)
print(f"\nThe model automatically created:")
print(f"  - Variables for concepts and tasks")
print(f"  - Encoder factors (embedding → concepts)")
print(f"  - Predictor factors (concepts → tasks)")
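The three automatic steps can be pictured as building a small directed acyclic graph and ordering it so that every factor runs after its parents. The sketch below assumes a fully connected bipartite structure (the embedding feeds every concept, every concept feeds every task) and names the root `embedding` to match the evidence key used later; the actual wiring inside `BipartiteModel` may differ, and `bipartite_parents` and `topological_order` are hypothetical helpers, not library functions.

```python
from collections import deque


def bipartite_parents(concepts, tasks, source="embedding"):
    """Parent sets for an assumed fully connected concept -> task graph."""
    parents = {source: []}
    for c in concepts:
        parents[c] = [source]
    for t in tasks:
        parents[t] = list(concepts)
    return parents


def topological_order(parents):
    """Kahn's algorithm: return nodes so that every parent precedes its children."""
    children = {node: [] for node in parents}
    indegree = {node: len(ps) for node, ps in parents.items()}
    for node, ps in parents.items():
        for p in ps:
            children[p].append(node)
    ready = deque(n for n, d in indegree.items() if d == 0)
    order = []
    while ready:
        node = ready.popleft()
        order.append(node)
        for child in children[node]:
            indegree[child] -= 1
            if indegree[child] == 0:
                ready.append(child)
    if len(order) != len(parents):
        raise ValueError("graph has a cycle")
    return order


parents = bipartite_parents(("c1", "c2"), ("xor",))
print(parents)
print(topological_order(parents))
```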

## 5. Inference Engine

We use the **DeterministicInference** engine on the BipartiteModel's underlying PGM, which the model exposes via its `.pgm` attribute:
- **Evidence**: the embedding computed from the input features, passed under the key `'embedding'`
- **Query**: the concepts and tasks we want to infer; `query` returns their predictions concatenated along the feature dimension, in query order

In [None]:
# Initialize the inference engine with the BipartiteModel's PGM
inference_engine = DeterministicInference(concept_model.pgm)

# Define the query (what we want to infer)
query_concepts = ["c1", "c2", "xor"]

print("Inference setup:")
print(f"  Engine: DeterministicInference")
print(f"  PGM source: concept_model.pgm")
print(f"  Query variables: {query_concepts}")
print(f"\nInference flow:")
print(f"  x_train → encoder → embedding → [c1, c2] → xor")
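Conceptually, deterministic inference on a graph like this is a single forward pass: observed nodes keep their evidence values and every other node is computed once from its parents. The sketch below is not torch_concepts' `DeterministicInference`; it is a plain-Python illustration assuming each factor maps its parents' values to a list of scores. `deterministic_query` and `toy_xor_head` are hypothetical names.

```python
def deterministic_query(parents, factors, order, query, evidence):
    """Single forward pass over a DAG: observed nodes take their evidence value,
    every other node is computed once from its parents in topological order.
    Returns the queried outputs concatenated in query order."""
    values = dict(evidence)
    for node in order:
        if node in values:
            continue  # observed node: keep the evidence value
        values[node] = factors[node](*(values[p] for p in parents[node]))
    flat = []
    for name in query:
        flat.extend(values[name])
    return flat


def toy_xor_head(c1, c2):
    """Hypothetical task factor: hard XOR of the concept logits' signs,
    returned as [class0, class1] scores."""
    is_xor = (c1[0] > 0) != (c2[0] > 0)
    return [-5.0, 5.0] if is_xor else [5.0, -5.0]


parents = {"embedding": [], "c1": ["embedding"], "c2": ["embedding"], "xor": ["c1", "c2"]}
factors = {
    "c1": lambda emb: [emb[0]],  # toy concept "encoder": read one coordinate
    "c2": lambda emb: [emb[1]],
    "xor": toy_xor_head,
}
order = ["embedding", "c1", "c2", "xor"]

print(deterministic_query(parents, factors, order, ["c1", "c2", "xor"],
                          {"embedding": [2.0, -3.0]}))
```

Passing a concept such as `c1` as evidence skips its factor entirely, which is the same idea that interventions build on.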

## 6. Complete Model Pipeline

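We combine the encoder and BipartiteModel into a single module:
- **encoder**: maps input features to a latent embedding
- **concept_model**: the BipartiteModel that maps the embedding to concepts and tasks

Wrapping both in a `torch.nn.Sequential` gives one module whose `.parameters()` covers the encoder and the BipartiteModel's factors (so a single optimizer trains them all) and whose `.train()`/`.eval()` switches both. Predictions in this notebook are still obtained by querying the inference engine, not by calling `model(x)` directly.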

In [None]:
# Combine encoder and concept_model into a Sequential pipeline
model = torch.nn.Sequential(encoder, concept_model)

print("Complete model pipeline:")
print(model)
print(f"\nPipeline structure:")
print(f"  1. Encoder: {x_train.shape[1]} features → {latent_dims} dimensions")
print(f"  2. BipartiteModel: {latent_dims} dimensions → concepts & tasks")

## 7. Training

We train the complete model with a combined loss:
- **Concept loss**: binary cross-entropy on logits (`BCEWithLogitsLoss`) between predicted and true concept labels (c1, c2)
- **Task loss**: the same loss between predicted and true task labels (the two one-hot xor columns)
- **Total loss**: `concept_loss + concept_reg * task_loss`; despite its name, `concept_reg` weights the task term

Training process:
1. Compute embedding from input features
2. Query the inference engine with the embedding as evidence
3. Split predictions into concepts and tasks
4. Compute losses and backpropagate
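`BCEWithLogitsLoss` computes binary cross-entropy directly from logits; a numerically stable way to write it for a logit x and target y is max(x, 0) - x·y + log(1 + e^(-|x|)), and with the default `reduction='mean'` the per-entry losses are averaged. The plain-Python sketch below reproduces that on made-up logits and combines the two terms as in the notebook; it illustrates the formula and is not PyTorch's implementation.

```python
import math


def bce_with_logits(logits, targets):
    """Mean binary cross-entropy from raw logits, using the stable form
    max(x, 0) - x*y + log(1 + exp(-|x|))."""
    losses = [
        max(x, 0.0) - x * y + math.log1p(math.exp(-abs(x)))
        for x, y in zip(logits, targets)
    ]
    return sum(losses) / len(losses)


def naive_bce(logits, targets):
    """Reference formula via the sigmoid; fine for moderate logits only."""
    total = 0.0
    for x, y in zip(logits, targets):
        p = 1.0 / (1.0 + math.exp(-x))
        total += -(y * math.log(p) + (1 - y) * math.log(1 - p))
    return total / len(logits)


# Made-up logits and labels for one sample: [c1, c2] and the two xor columns.
concept_logits, concept_targets = [2.0, -1.0], [1.0, 0.0]
task_logits, task_targets = [-0.5, 1.5], [0.0, 1.0]

concept_reg = 0.5  # weights the task term, as in the notebook's loss
concept_loss = bce_with_logits(concept_logits, concept_targets)
task_loss = bce_with_logits(task_logits, task_targets)
loss = concept_loss + concept_reg * task_loss
print(f"concept {concept_loss:.4f} | task {task_loss:.4f} | total {loss:.4f}")
```

The stable form matters for large logits: `naive_bce([1000.0], [1.0])` fails on `log(0)`, while `bce_with_logits` returns 0.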

In [None]:
# Setup training
optimizer = torch.optim.AdamW(model.parameters(), lr=0.01)
loss_fn = torch.nn.BCEWithLogitsLoss()
model.train()

# Training loop
for epoch in range(n_epochs):
    optimizer.zero_grad()

    # Compute embedding
    emb = encoder(x_train)
    
    # Inference: query the PGM with embedding as evidence
    cy_pred = inference_engine.query(query_concepts, evidence={'embedding': emb})
    
    # Split predictions: first columns are concepts, remaining are task
    c_pred = cy_pred[:, :c_train.shape[1]]
    y_pred = cy_pred[:, c_train.shape[1]:]

    # Compute loss
    concept_loss = loss_fn(c_pred, c_train)
    task_loss = loss_fn(y_pred, y_train)
    loss = concept_loss + concept_reg * task_loss

    # Backward pass
    loss.backward()
    optimizer.step()

    # Log progress
    if epoch % 100 == 0:
        task_accuracy = accuracy_score(y_train, y_pred > 0.)
        concept_accuracy = accuracy_score(c_train, c_pred > 0.)
        print(f"Epoch {epoch}: Loss {loss.item():.2f} | Task Acc: {task_accuracy:.2f} | Concept Acc: {concept_accuracy:.2f}")

print("\nTraining complete!")
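The progress log thresholds logits at 0, which is the same decision as a predicted probability above 0.5 because σ(0) = 0.5. When sklearn's `accuracy_score` receives 2-D indicator arrays, as here, it computes subset (exact-match) accuracy: a sample counts only if every column is right, so concept accuracy requires both c1 and c2 to be correct. The plain-Python sketch below reproduces that metric next to a per-entry accuracy for contrast, on made-up data.

```python
def subset_accuracy(targets, logits, threshold=0.0):
    """Exact-match accuracy: a row counts only if every thresholded column matches."""
    hits = 0
    for t_row, z_row in zip(targets, logits):
        pred = [int(z > threshold) for z in z_row]
        hits += int(pred == [int(t) for t in t_row])
    return hits / len(targets)


def elementwise_accuracy(targets, logits, threshold=0.0):
    """Fraction of individual entries predicted correctly."""
    pairs = [(int(t), int(z > threshold))
             for t_row, z_row in zip(targets, logits)
             for t, z in zip(t_row, z_row)]
    return sum(t == p for t, p in pairs) / len(pairs)


# Made-up concept targets and logits for three samples.
targets = [[1, 0], [0, 1], [1, 1]]
logits = [[2.1, -0.3], [0.4, 1.2], [1.5, -0.1]]
print(subset_accuracy(targets, logits), elementwise_accuracy(targets, logits))
```

Here only the first row is fully correct, so subset accuracy is 1/3 even though 4 of the 6 individual entries are right.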

## 8. Baseline Predictions (No Intervention)

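Let's examine the model's predictions without any interventions.
The output concatenates the predictions for [c1, c2, xor]; because xor has cardinality 2, it occupies two columns, for four columns in total. The values are raw scores treated as logits, consistent with the `BCEWithLogitsLoss` used during training.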

In [None]:
# Get baseline predictions
model.eval()
with torch.no_grad():
    emb = encoder(x_train)
    cy_pred = inference_engine.query(query_concepts, evidence={'embedding': emb})

print("Baseline predictions (first 5 samples):")
print("Format: [c1, c2, xor_class0, xor_class1]")
print(cy_pred[:5])
print(f"\nShape: {cy_pred.shape}")
print(f"  Columns 0-1: concept predictions (c1, c2)")
print(f"  Columns 2-3: task predictions (xor one-hot)")
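Because training treats these outputs as logits, a row is easier to read after mapping concept logits through the sigmoid and picking the larger of the two xor columns. The hypothetical `decode_row` helper below does this for a made-up row (not output from this notebook); with the notebook's tensor it could be applied to, for example, `cy_pred[0].tolist()`. Since each xor column was trained with its own sigmoid, taking the argmax is one reasonable decoding choice, not the only one.

```python
import math


def sigmoid(x):
    """Numerically stable logistic function."""
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    e = math.exp(x)
    return e / (1.0 + e)


def decode_row(row, concept_names=("c1", "c2")):
    """Split one [c1, c2, xor_class0, xor_class1] row of logits into concept
    probabilities and the index of the larger xor column."""
    n_c = len(concept_names)
    concept_probs = {name: sigmoid(z) for name, z in zip(concept_names, row[:n_c])}
    task_scores = row[n_c:]
    xor_index = max(range(len(task_scores)), key=task_scores.__getitem__)
    return concept_probs, xor_index


# A made-up row for illustration, not output from this notebook.
probs, xor_index = decode_row([3.2, -2.7, -4.1, 4.0])
print({k: round(v, 3) for k, v in probs.items()}, xor_index)
```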

## 9. Interventions in BipartiteModel

The BipartiteModel framework supports interventions on the underlying PGM:
- Access the PGM's factor modules via `concept_model.pgm.factor_modules`
- Apply interventions to specific factors (e.g., "c1.encoder")
- Effects propagate through the graph structure

### Intervention Setup:
- **Policy**: RandomPolicy to randomly select samples and intervene on concept c1
- **Strategy**: DoIntervention to set c1 to a constant value (-10); read as a logit, this is a probability of about 4.5e-5, so c1 is effectively forced to 0
- **Layer**: Intervene at the "c1.encoder" factor in the PGM
- **Quantile**: 1.0 (intervene on all selected samples)

In [None]:
# Compute embedding for intervention
emb = encoder(x_train)

# Create annotations for intervention
c_annotations = Annotations({1: AxisAnnotation(["c1"])})

# Define intervention policy and strategy
int_policy_c = RandomPolicy(
    out_annotations=c_annotations, 
    scale=100, 
    subset=["c1"]
)
int_strategy_c = DoIntervention(
    model=concept_model.pgm.factor_modules, 
    constants=-10
)

print("Intervention configuration:")
print(f"  Policy: RandomPolicy on concept 'c1'")
print(f"  Strategy: DoIntervention with constant value -10")
print(f"  Target layer: c1.encoder (in BipartiteModel's PGM)")
print(f"  Quantile: 1.0 (intervene on all selected samples)")
print(f"\nThis intervention will:")
print(f"  1. Randomly select samples")
print(f"  2. Set concept c1 to -10 for those samples")
print(f"  3. Propagate the effect through the BipartiteModel to xor prediction")
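Abstractly, a do-intervention cuts c1's dependence on the embedding and replaces its value with a constant for the selected samples, while everything downstream is recomputed from the new value. The toy sketch below illustrates this under stated assumptions: -10 is read as a logit, the XOR task head is a hypothetical soft-XOR of the concept probabilities, and the sample mask is given explicitly instead of coming from a policy. None of these names are torch_concepts APIs.

```python
import math


def sigmoid(x):
    """Numerically stable logistic function."""
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    e = math.exp(x)
    return e / (1.0 + e)


def xor_head(c1_logit, c2_logit, scale=5.0):
    """Hypothetical task factor: soft XOR of the two concept probabilities,
    returned as [class0, class1] scores."""
    p1, p2 = sigmoid(c1_logit), sigmoid(c2_logit)
    p_xor = p1 * (1 - p2) + p2 * (1 - p1)  # P(exactly one concept is on)
    z = scale * (2 * p_xor - 1)
    return [-z, z]


def forward(concept_logits, do_c1=None, mask=None):
    """Evaluate the toy graph row by row. Where mask[i] is True and do_c1 is given,
    c1's factor output is replaced by the constant before the task head runs
    (a do-intervention); c2 and the task head are left untouched."""
    rows = []
    for i, (c1, c2) in enumerate(concept_logits):
        if do_c1 is not None and mask[i]:
            c1 = do_c1
        rows.append([c1, c2] + xor_head(c1, c2))
    return rows


concept_logits = [(4.0, -4.0), (4.0, 4.0), (-4.0, 4.0)]
baseline = forward(concept_logits)
intervened = forward(concept_logits, do_c1=-10.0, mask=[True, True, False])
for b, v in zip(baseline, intervened):
    print([round(x, 2) for x in b], "->", [round(x, 2) for x in v])
print(f"sigmoid(-10) = {sigmoid(-10.0):.2e}")
```

In the first two rows c1 switches from "on" to "off", so the toy XOR prediction flips; the unselected third row is unchanged.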

## 10. Applying the Intervention

Now we apply the intervention and observe how the predictions change.
Compare these results with the baseline predictions above to see the intervention's effect.

In [None]:
print("Predictions with intervention:")
with intervention(
    policies=[int_policy_c],
    strategies=[int_strategy_c],
    on_layers=["c1.encoder"],
    quantiles=[1]
):
    cy_pred_intervened = inference_engine.query(query_concepts, evidence={'embedding': emb})
    print("Format: [c1, c2, xor_class0, xor_class1]")
    print(cy_pred_intervened[:5])

print("\nNote: Compare with baseline predictions above.")
print("You should see c1 values changed to -10 for randomly selected samples,")
print("and corresponding changes in the xor predictions.")
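Comparing only the first five rows by eye can miss the overall effect. A small diff over all rows reports which samples had c1 overwritten and which changed their predicted xor column. The hypothetical `compare_predictions` helper below works on nested lists in the `[c1, c2, xor_class0, xor_class1]` format; with the notebook's tensors it could be called as `compare_predictions(cy_pred.tolist(), cy_pred_intervened.detach().tolist())`. The rows here are made up.

```python
def compare_predictions(baseline, intervened, c1_col=0, task_cols=(2, 3), atol=1e-6):
    """Return (rows whose c1 score changed, rows whose predicted xor column changed)."""
    def task_index(row):
        scores = [row[c] for c in task_cols]
        return max(range(len(scores)), key=scores.__getitem__)

    changed_c1, flipped_task = [], []
    for i, (b, v) in enumerate(zip(baseline, intervened)):
        if abs(b[c1_col] - v[c1_col]) > atol:
            changed_c1.append(i)
        if task_index(b) != task_index(v):
            flipped_task.append(i)
    return changed_c1, flipped_task


# Made-up rows in the [c1, c2, xor_class0, xor_class1] format.
baseline = [[4.0, -4.0, -4.6, 4.6], [4.0, 4.0, 4.6, -4.6], [-4.0, 4.0, -4.6, 4.6]]
intervened = [[-10.0, -4.0, 4.8, -4.8], [-10.0, 4.0, -4.8, 4.8], [-4.0, 4.0, -4.6, 4.6]]
changed, flipped = compare_predictions(baseline, intervened)
print(f"c1 overwritten in rows {changed}; xor prediction flipped in rows {flipped}")
```

Dividing `len(changed)` by the number of rows gives the fraction of samples the policy actually intervened on.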

## Summary

In this notebook, we explored the BipartiteModel framework for concept-based learning:

1. **Data**: Loaded the XOR toy dataset with binary concepts
2. **Rich Annotations**: Defined metadata including distributions, types, and descriptions
3. **BipartiteModel**: High-level abstraction that automatically builds a PGM
4. **Propagators**: Used to create encoder and predictor factors automatically
5. **Inference**: Queried the underlying PGM for predictions
6. **Training**: Trained with combined concept and task supervision
7. **Interventions**: Applied causal interventions via the PGM structure

### Key Advantages of BipartiteModel:
- **High-level abstraction**: Simplified PGM construction from annotations
- **Automatic structure**: Model builds Variables and Factors automatically
- **Rich metadata**: Support for distributions, cardinalities, and descriptions
- **Propagators**: Flexible way to specify encoder/predictor architectures
- **PGM access**: Full access to underlying PGM for advanced operations
- **Less boilerplate**: Reduces code needed compared to manual PGM construction

### Comparison with Other Approaches:
- **vs. Layer-based**: More structured, explicit graph representation
- **vs. Manual PGM**: Less code, automatic construction from metadata
- **Best for**: Production systems, complex models with many concepts/tasks

This framework is ideal for:
- Large-scale concept-based models with many variables
- Systems requiring rich metadata for interpretability
- Applications needing both ease-of-use and flexibility
- Production deployments with complex concept hierarchies