# Concept Bottleneck Models

The tale of the teenager and the deep learning model shows the critical gap between how traditional deep learning (DL) models and humans approach decisions. Both drivers stopped at the intersection, but the teenager could explain how she arrived at that decision using concepts such as the light color and the presence of the ambulance. The DL model made the correct decision, yet it could not give the driving evaluator a satisfying account of its reasoning, because its answers were based on raw pixel activations. This example highlights the need for DL models whose decision-making process is as transparent as the teenager's, especially in high-stakes fields such as medicine (deciding whether to give a certain treatment), finance (deciding whether to approve a loan), and law (checking whether a hiring system is fair).

From a technical standpoint, the problem can be described as follows: we aim to model a relationship between a set of input variables $x_i \in \mathcal{X}$ (such as the image of the road) and a set of output variables corresponding to decisions $y_i \in \mathcal{Y}$ (whether to cross or stop). The DL model in the tale modeled the relationship as $p(y_i = \text{cross} \mid \mathcal{X} = \text{image’s pixels})$, directly mapping raw image data to a decision. The teenager, instead, modeled the same problem using **higher-level variables — referred to as “concepts”, like the light color and the ambulance — leading to a more human-interpretable decision-making process**. Her reasoning could be expressed as $p(y_i = \text{cross} \mid c_1 = \text{light color}, c_2 = \text{ambulance})$, where the concepts $c_i \in \mathcal{C}$ provided insight into **how** she made a particular decision.

One of the most common ways to mimic the teenager's approach is to factorize the joint probability as:

$$p(\mathcal{X}, \mathcal{Y}) = p(\mathcal{Y} \mid \mathcal{C}) \cdot p(\mathcal{C} \mid \mathcal{X})$$

In this formulation, the DL model processes the input features $\mathcal{X}$ (such as pixels from an image) and maps them to a set of interpretable, high-level variables $\mathcal{C}$, known as “concepts.” These concepts are analogous to the reasoning elements identified by the teenager — such as the traffic light color or the presence of an ambulance. The second part of the model then uses these concepts for the downstream prediction $\mathcal{Y}$ (whether to cross or stop). This class of models is known as **Concept Bottleneck Models (CBMs)** {cite}`koh2020concept`. CBMs parametrize the conditional distributions with a pair of neural models:
- The concept encoder $g$ parametrizes the concept distribution (typically a set of independent Bernoulli distributions)
- The task predictor $f$ parametrizes the output distribution (typically Bernoulli or categorical)

The CBM's parameters ($\theta_g$ and $\theta_f$) are usually optimized via gradient descent. As a result, a CBM models the joint distribution $p(\mathcal{X}, \mathcal{Y})$ as follows:

$$p(\mathcal{X}, \mathcal{Y}) = p(\mathcal{Y} \mid \mathcal{C}; \theta_f) \cdot p(\mathcal{C} \mid \mathcal{X}; \theta_g)$$
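One way to read this factorization is that the prediction for a single input is obtained by summing over concept configurations, $p(y \mid x) = \sum_c p(y \mid c)\, p(c \mid x)$, with independent Bernoulli concepts. The sketch below works through that sum in plain Python. The concept probabilities and the deterministic task rule are hand-picked assumptions for illustration, not outputs of a trained model.

```python
from itertools import product

# Assumed concept probabilities p(c_k = 1 | x) for one input image,
# as a concept encoder g might output them (hand-picked, not learned).
p_concepts = {"green_light": 0.9, "ambulance": 0.2}


def p_cross_given_concepts(green_light: int, ambulance: int) -> float:
    """Hypothetical task predictor f: cross only on green with no ambulance."""
    return 1.0 if (green_light == 1 and ambulance == 0) else 0.0


def p_concept_config(config: dict) -> float:
    """p(c | x) for a full configuration, assuming independent Bernoullis."""
    prob = 1.0
    for name, value in config.items():
        p_one = p_concepts[name]
        prob *= p_one if value == 1 else 1.0 - p_one
    return prob


# Marginalize over all 2^2 concept configurations.
p_cross = 0.0
for values in product([0, 1], repeat=len(p_concepts)):
    config = dict(zip(p_concepts, values))
    p_cross += p_cross_given_concepts(**config) * p_concept_config(config)

print(f"p(cross | x) = {p_cross:.2f}")
```

Only the configuration with a green light and no ambulance contributes here, so the result is the product of the two corresponding concept probabilities.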

The following coding practice introduces you to implementing CBMs and shows how to query CBMs to understand the model's decision-making process.

## Coding practice

In this practice, we implement a Concept Bottleneck Model (CBM) for a simple traffic light scenario where decisions to cross or stop depend on two concepts: the traffic light being green and the presence of an ambulance. The model predicts these concepts and uses them to make decisions. Let’s go through the key steps.

### Step #1: Model definition and training
We define a CBM where the model first predicts the concepts (traffic light, ambulance) and then uses those concepts to decide whether to cross. The model optimizes both concept and task predictions using gradient descent.
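The joint objective in this step is a weighted sum of two binary cross-entropy terms, one on the concepts and one on the task. The sketch below computes it by hand for a single sample in plain Python. The predicted probabilities are made-up values, and the 0.5 task weight mirrors the training loop in this section.

```python
import math


def bce(pred: float, target: float) -> float:
    """Binary cross-entropy for one probability/label pair."""
    return -(target * math.log(pred) + (1.0 - target) * math.log(1.0 - pred))


def mean_bce(preds, targets) -> float:
    """Average BCE over a list of predictions (mean reduction)."""
    return sum(bce(p, t) for p, t in zip(preds, targets)) / len(preds)


# Made-up predictions for one sample: [green light, ambulance] and [cross].
c_pred, c_true = [0.8, 0.1], [1.0, 0.0]
y_pred, y_true = [0.7], [1.0]

concept_loss = mean_bce(c_pred, c_true)
task_loss = mean_bce(y_pred, y_true)
total_loss = concept_loss + 0.5 * task_loss  # same weighting as the training loop

print(f"concept loss: {concept_loss:.4f}")
print(f"task loss:    {task_loss:.4f}")
print(f"total loss:   {total_loss:.4f}")
```

Down-weighting the task term, as done here, puts relatively more pressure on getting the concepts right; the weight itself is a tunable choice.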


```python
import torch
from torch_concepts.nn import ConceptEncoder
from torch_concepts.data import TrafficLights

emb_size = 5
n_epochs = 1000
n_samples = 1000

# Load the toy dataset: inputs x, concept labels c, and task labels y
dataset = TrafficLights(n_samples=n_samples)
x_train, c_train, y_train = dataset.x_train, dataset.c_train, dataset.y_train
concept_names, task_names = dataset.concept_names, dataset.task_names

# Concept encoder g: input features -> embedding -> concept scores
encoder = torch.nn.Sequential(
    torch.nn.Linear(x_train.shape[1], emb_size),
    torch.nn.LeakyReLU(),
)
c_scorer = ConceptEncoder(in_features=emb_size, out_concept_dimensions={1: concept_names})

# Task predictor f: concept probabilities -> task logits
y_predictor = torch.nn.Sequential(
    torch.nn.Linear(c_train.shape[1], emb_size),
    torch.nn.LeakyReLU(),
    torch.nn.Linear(emb_size, y_train.shape[1]),
)

# Grouped so the optimizer sees all parameters; the forward pass below
# calls each part explicitly.
model = torch.nn.Sequential(encoder, c_scorer, y_predictor)

optimizer = torch.optim.AdamW(model.parameters(), lr=0.01)
loss_form = torch.nn.BCELoss()

model.train()
for epoch in range(n_epochs):
    optimizer.zero_grad()

    # Forward pass through the concept bottleneck
    emb = encoder(x_train)
    c_pred = c_scorer(emb).sigmoid()
    y_pred = y_predictor(c_pred).sigmoid()

    # Joint loss: concept loss plus down-weighted task loss
    loss = loss_form(c_pred, c_train) + 0.5 * loss_form(y_pred, y_train)
    loss.backward()
    optimizer.step()

    if epoch % 100 == 0:
        print(f"Epoch {epoch}: Loss {loss.item():.2f}")
```

```text
Epoch 0: Loss 1.00
Epoch 100: Loss 0.46
Epoch 200: Loss 0.30
Epoch 300: Loss 0.14
Epoch 400: Loss 0.08
Epoch 500: Loss 0.06
Epoch 600: Loss 0.05
Epoch 700: Loss 0.04
Epoch 800: Loss 0.03
Epoch 900: Loss 0.03
```


### Step #2: Observe concept-task relations
We observe how the model predicts whether the car should cross or not based on a sample with a green light and an ambulance.


```python
from torch_concepts.base import ConceptTensor

model.eval()

# Test on a sample with green light and ambulance
c_test = ConceptTensor(torch.FloatTensor([[1, 1]]), {1: concept_names})
y_pred = y_predictor(c_test).sigmoid()

# Threshold at 0.5 to obtain Boolean concepts and a Boolean decision
c_test = ConceptTensor(c_test > 0.5, {1: concept_names})
y_pred = ConceptTensor(y_pred > 0.5, {1: task_names})

print(f"Concepts: {c_test}")
print(f"Task: {y_pred}")
```

```text
Concepts: ConceptTensor([[True, True]])
Task: ConceptTensor([[False]])
```


The model correctly identifies that even though the traffic light is green, the presence of an ambulance makes it decide not to cross (task prediction is False).

### Step #3: Average effect of a concept on the task
We analyze model behavior across the dataset, focusing on cases where the ambulance is or isn’t present.


```python
# Analyze crossing probability with/without ambulance
# (column 0: green light, column 1: ambulance)
c_train_no_ambulance = c_train[c_train[:, 1] == 0]
c_train_ambulance = c_train[c_train[:, 1] == 1]

# Mean predicted crossing probability in each group
y_pred_no_ambulance = y_predictor(c_train_no_ambulance).sigmoid().mean()
y_pred_ambulance = y_predictor(c_train_ambulance).sigmoid().mean()

# Mean over all entries of the rows with green light and no ambulance.
# With two binary concept columns these rows are all [1, 0], so this
# value is 0.5 by construction.
c_train_no_ambulance_green = c_train_no_ambulance[c_train_no_ambulance[:, 0] == 1].mean()

print(f"Green light prob. (no ambulance): {c_train_no_ambulance_green:.4f}")
print(f"Crossing prob. (no ambulance): {y_pred_no_ambulance:.4f}")
print(f"Crossing prob. (ambulance): {y_pred_ambulance:.4f}")
```

```text
Green light prob. (no ambulance): 0.5000
Crossing prob. (no ambulance): 0.2566
Crossing prob. (ambulance): 0.0001
```


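When there is no ambulance, the model's average predicted crossing probability is around 25%. When an ambulance is present, it is essentially zero, so the model never predicts crossing, as expected. Note that the first printed value averages both concept columns of the no-ambulance samples whose light is green, so it equals 0.5 by construction and does not measure how often the light is green.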
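The comparison above conditions on the ambulance concept. A related query, sketched below, intervenes on it instead: the ambulance concept is forced to 0 and then to 1 for every sample, and the mean predictions are compared. The task predictor here is a hand-written logistic function with assumed weights standing in for a trained `y_predictor`, and the four concept rows are toy data.

```python
import math


def toy_task_predictor(green_light: float, ambulance: float) -> float:
    """Stand-in for a trained task predictor (weights are assumptions)."""
    logit = 10.0 * green_light - 20.0 * ambulance - 5.0
    return 1.0 / (1.0 + math.exp(-logit))


# Toy concept labels: columns are [green_light, ambulance].
concepts = [[1, 0], [0, 0], [1, 1], [0, 1]]


def mean_prediction_with_ambulance_set_to(value: int) -> float:
    """Force the ambulance concept to `value` for every sample, then average."""
    preds = [toy_task_predictor(green, value) for green, _ in concepts]
    return sum(preds) / len(preds)


p_without = mean_prediction_with_ambulance_set_to(0)
p_with = mean_prediction_with_ambulance_set_to(1)
effect = p_with - p_without

print(f"mean crossing prob., ambulance forced to 0: {p_without:.4f}")
print(f"mean crossing prob., ambulance forced to 1: {p_with:.4f}")
print(f"average effect of the ambulance concept:    {effect:.4f}")
```

Because every sample is evaluated under both settings, the difference isolates the ambulance concept from how often it co-occurs with a green light in the data.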

### Step #4: Effect of changing a concept’s value

Here we keep the green light from the Step #2 sample but switch the ambulance concept off, then query the task predictor again.

```python
# Modify single sample (green light, no ambulance)
c_test_green_no_ambulance = ConceptTensor(torch.FloatTensor([[1, 0]]), {1: concept_names})
y_pred = y_predictor(c_test_green_no_ambulance).sigmoid()

# Threshold at 0.5 to obtain Boolean concepts and a Boolean decision
c_test_green_no_ambulance = ConceptTensor(c_test_green_no_ambulance > 0.5, {1: concept_names})
y_pred = ConceptTensor(y_pred > 0.5, {1: task_names})

print(f"Concepts: {c_test_green_no_ambulance}")
print(f"Task: {y_pred}")
```

```text
Concepts: ConceptTensor([[ True, False]])
Task: ConceptTensor([[True]])
```


The model correctly predicts that in the presence of a green light and the absence of an ambulance, the car can cross (task prediction is True).
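The same kind of query can be repeated for each concept in turn to see which one is responsible for a decision. The sketch below starts from the green-light-with-ambulance sample and flips one concept at a time. It uses a hand-written logistic stand-in with assumed weights (not the trained `y_predictor`) and the same 0.5 threshold as the cells above.

```python
import math


def toy_task_predictor(green_light: int, ambulance: int) -> float:
    """Stand-in for a trained task predictor (weights are assumptions)."""
    logit = 10.0 * green_light - 20.0 * ambulance - 5.0
    return 1.0 / (1.0 + math.exp(-logit))


def decide(concepts: dict) -> bool:
    """Threshold the crossing probability at 0.5."""
    return toy_task_predictor(**concepts) > 0.5


# Start from the green-light-with-ambulance sample and flip one concept at a time.
sample = {"green_light": 1, "ambulance": 1}
flips = {}
for name in sample:
    flipped = dict(sample)
    flipped[name] = 1 - sample[name]
    flips[name] = (decide(sample), decide(flipped))
    print(f"flip {name}: cross {flips[name][0]} -> {flips[name][1]}")
```

With these assumed weights, only flipping the ambulance concept changes the decision, which matches the behaviour of the trained model in Steps #2 and #4.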

## Series overview
This blog series introduces key foundational works in concept-based deep learning, designed to be accessible even with just basic knowledge of probability theory and optimization. The content is self-contained, with optional preliminaries for experienced readers. We focus on 10 key topics that provide a broad overview of important advances, with practical relevance for applying these techniques in real-world scenarios.

Each topic includes concise theory, paired with pen-and-paper examples and hands-on coding exercises. All code examples use PyTorch Concepts (PyC), a library we developed to help researchers quickly implement existing interpretable techniques or develop novel methods.

In this series we cover the following chapters:
- Chapter 1: Introduction to Concept-based Deep Learning