# Energy-Based Models

This notebook provides a comprehensive exploration of Energy-Based Models (EBMs), a versatile framework for modeling complex relationships in data, particularly for structured prediction tasks. EBMs assign a scalar energy value to input-output pairs, where lower energy indicates higher compatibility. This framework is widely used in machine learning for tasks such as classification, regression, and structured prediction (e.g., sequence labeling, image segmentation). Through this notebook, we aim to:
- Introduce the theoretical foundations of EBMs, including energy functions, inference, and learning.
- Provide practical Python implementations to demonstrate EBM concepts.
- Explore real-world applications with detailed examples.
- Address exercises to deepen understanding and encourage experimentation.

## Introduction

Rather than mapping an input directly to an output, an Energy-Based Model (EBM) scores every candidate input-output pair with an energy and predicts by searching for the lowest-energy output. This separates the model (the energy function) from the inference procedure (the minimization) and from the training criterion (the loss). The same recipe covers classification, regression, and structured outputs such as sequences and graphs. The sections below cover energy functions, inference methods, loss functions for learning, and applications, each with Python examples.

**Why EBMs?**
- **Flexibility**: Energies do not have to be normalized into probabilities, so the energy function can take almost any parameterized form and still capture complex dependencies.
- **Generalization**: Applicable to various tasks, from simple classification to structured prediction like natural language processing and computer vision.
- **Interpretability**: The energy function provides a clear measure of compatibility, which can be analyzed and visualized.

### Objective
<ul>
<li>To understand the energy-based modeling framework for structured prediction.</li>
<li>To explore inference methods, learning algorithms, and loss functions in EBMs.</li>
<li>To implement EBMs in Python for practical applications.</li>
</ul>

### Key Topics
<ul>
<li>Energy functions and compatibility</li>
<li>Deterministic and probabilistic inference</li>
<li>Loss functions: perceptron, margin-based, and negative log-likelihood</li>
<li>Applications to classification, regression, and structured output</li>
</ul>

## 1. Energy Functions and Compatibility

EBMs model the relationship between input $x$ and output $y$ using an energy function $E(x, y; \theta)$, where $\theta$ represents the model parameters. Lower energy values indicate higher compatibility between $x$ and $y$.

- **Definition**: $E(x, y; \theta)$ is a scalar-valued function that quantifies the "cost" or "incompatibility" of a given input-output pair.
- **Goal**: Learn $\theta$ such that $E(x, y_{\text{true}}; \theta)$ is low for correct pairs and high for incorrect pairs.
- **Example**: For a binary classification task, the energy function might be a linear model:  
  $$ E(x, y; w) = -y \cdot (w^T x + b) $$  
  where $y \in \{+1, -1\}$, $w$ is a weight vector, and $b$ is a bias term. A lower energy for $y = +1$ indicates that the positive class is more compatible with the input $x$.

### 1.1. Compatibility

Compatibility is inversely related to energy. A pair $(x, y)$ is more compatible if $E(x, y; \theta)$ is lower. For structured outputs, $y$ may be a sequence or graph, and the energy function often decomposes into local factors over parts of $y$. For example, in sequence labeling, the energy might include terms for individual labels and transitions between labels, as seen in Conditional Random Fields (CRFs).

<img src="https://i.postimg.cc/jdKNjKZz/image.png" alt="Energy function diagram">

**Note**: A model measures the compatibility between observed variables X and variables to
be predicted Y using an energy function E(Y,X). For example, X could be the pixels of an
image, and Y a discrete label describing the object in the image. Given X, the model produces
the answer Y that minimizes the energy E.


<img src="image2.png" alt="Applications of EBMs">

**Note**: Several applications of EBMs: (a) face recognition: Y is a high-cardinality discrete
variable; (b) face detection and pose estimation: Y is a collection of vectors with location
and pose of each possible face; (c) image segmentation: Y is an image in which each pixel
is a discrete label; (d-e) handwriting recognition and sequence labeling: Y is a sequence of
symbols from a highly structured but potentially infinite set (the set of English sentences). The
situation is similar for many applications in natural language processing and computational
biology; (f) image restoration: Y is a high-dimensional continuous variable (an image).

## 2. Inference in EBMs

Inference in EBMs involves finding the output $y$ that minimizes the energy for a given input $x$:

$$ y^* = \arg\min_{y \in \mathcal{Y}} E(x, y; \theta) $$

### 2.1. Deterministic Inference
- Directly compute $y^*$ by minimizing the energy function.
- Suitable for simple output spaces (e.g., binary classification).
- Example: For linear energy, select the class with the lowest energy.

### 2.2. Probabilistic Inference
- Instead of a single $y^*$, compute a probability distribution over $y$:  
  $$ P(y|x; \theta) = \frac{\exp(-E(x, y; \theta))}{\sum_{y' \in \mathcal{Y}} \exp(-E(x, y'; \theta))} $$  
- The denominator (partition function) is often intractable for complex $\mathcal{Y}$.
- Approximation methods: Markov chain Monte Carlo (MCMC) sampling, of which Gibbs sampling is a common special case.
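
When $\mathcal{Y}$ is small, the distribution above can be computed exactly. The following is a minimal sketch for the two-label case, assuming the linear energy from Section 1. The function name `gibbs_distribution` and the sample values of `x` and `w` are illustrative choices, not part of the original notebook.

```python
import numpy as np

def energy(x, y, w, b):
    """Linear energy E(x, y) = -y * (w.x + b), as in Section 1."""
    return -y * (np.dot(w, x) + b)

def gibbs_distribution(x, w, b, labels=(1, -1)):
    """Return P(y | x) for each label as a softmax over negative energies."""
    neg_energies = np.array([-energy(x, y, w, b) for y in labels])
    neg_energies -= neg_energies.max()   # shift for numerical stability
    weights = np.exp(neg_energies)
    return weights / weights.sum()       # the sum plays the role of the partition function

x = np.array([1.0, 2.0])   # illustrative input
w = np.array([0.5, 0.5])   # illustrative weights
probs = gibbs_distribution(x, w, b=0.0)
print("P(y=+1 | x):", probs[0])
print("P(y=-1 | x):", probs[1])
```

Subtracting the maximum before exponentiating does not change the result, because the shift cancels between numerator and denominator. It only guards against overflow when energies are large in magnitude.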

### 2.3. Practical Example: Binary Classification

Let's implement a simple EBM for binary classification using a linear energy function. The code below defines an energy function, performs deterministic inference, and tests it on sample data.

In [None]:
import numpy as np

# Define energy function with bias
def energy(x, y, w, b):
    """Compute energy for input x, output y, weights w, and bias b."""
    return -y * (np.dot(w, x) + b)

# Deterministic inference
def infer(x, w, b):
    """Infer the class (+1 or -1) with the lowest energy."""
    y_pos = energy(x, 1, w, b)
    y_neg = energy(x, -1, w, b)
    return 1 if y_pos < y_neg else -1

# Example data
X = np.array([[1, 2], [2, 1], [-1, -1], [-2, -2]])
y = np.array([1, 1, -1, -1])
w = np.array([0.5, 0.5])
b = 0.0

# Inference
predictions = [infer(x, w, b) for x in X]
print("Predictions:", predictions)

Predictions: [1, 1, -1, -1]


**Code Explanation**:
- **Energy Function**: The function `energy(x, y, w, b)` computes the energy as $-y \cdot (w^T x + b)$. A lower energy indicates a more compatible $(x, y)$ pair.
- **Inference**: The `infer(x, w, b)` function compares the energy for $y = +1$ and $y = -1$, selecting the class with the lower energy.
- **Data**: The input `X` contains four 2D points, and `y` contains their true labels. The weights `w` and bias `b` are initialized to simple values for demonstration.
- **Output**: The predictions match the true labels, indicating the weights and bias are reasonable for this toy dataset.

**Exercise**: Adjust $w$ and $b$ to observe their impact. For example, try $w = [1, 0]$ or $b = 1$ and rerun the inference to see how the predictions change.

## 3. Learning in EBMs

Learning in EBMs involves optimizing the parameters $\theta$ to minimize the energy for correct input-output pairs while increasing it for incorrect pairs. This is achieved using various loss functions, which we describe below.

### 3.1 Perceptron Loss
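- **Update Rule**: When the predicted output $y_{\text{pred}} = \arg\min_{y} E(x, y; \theta)$ differs from the true output, take a gradient step that lowers the energy of the true output and raises the energy of the predicted one:
  $$ \theta \leftarrow \theta - \eta \left( \nabla_\theta E(x, y_{\text{true}}; \theta) - \nabla_\theta E(x, y_{\text{pred}}; \theta) \right) $$
  For the linear energy $E(x, y; w) = -y \cdot (w^T x + b)$ this becomes $w \leftarrow w + \eta (y_{\text{true}} - y_{\text{pred}}) x$, which equals $w \leftarrow w + 2\eta \, y_{\text{true}} \, x$ on a mistake. The factor 2 can be absorbed into the learning rate.
- **Loss Function**: Measures the difference in energy between the true and predicted outputs:
  $$ L = \max(0, E(x, y_{\text{true}}; \theta) - E(x, y_{\text{pred}}; \theta)) $$
- **Explanation**: The perceptron loss is zero when the true output already has the lowest energy, so that $y_{\text{pred}} = y_{\text{true}}$. Otherwise it equals the amount by which the energy of the predicted output undercuts the energy of the true output. Because $y_{\text{pred}}$ is the energy minimizer, this gap is never negative and the $\max$ is only a safeguard. The loss has a margin of zero: it is satisfied as soon as the true output is barely the lowest, which may lead to less robust solutions.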

### 3.2 Margin-Based Loss
- **Loss Function**: Introduces a margin $\Delta$ to ensure a separation between correct and incorrect outputs:
  $$ L = \max(0, E(x, y_{\text{true}}; \theta) - E(x, y_{\text{incorrect}}; \theta) + \Delta) $$
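- **Explanation**: Here $y_{\text{incorrect}}$ is usually taken to be the lowest-energy incorrect output. The loss reaches zero only once the energy of that output exceeds the energy of the correct output by at least $\Delta$, so the model keeps updating on examples that are classified correctly but only barely. This tends to give more robust decision boundaries than the perceptron loss.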
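
The effect of $\Delta$ can be checked on a single correctly classified sample. The following is a minimal sketch, assuming the linear binary energy from Section 1 and, for illustration, the weights `w = [0.1, 0.2]` and bias `b = 0.1` reported by the perceptron example in Section 3.4. The function name `margin_loss` is hypothetical.

```python
import numpy as np

def energy(x, y, w, b):
    """Linear energy E(x, y) = -y * (w.x + b)."""
    return -y * (np.dot(w, x) + b)

def margin_loss(x, y_true, w, b, delta=1.0):
    """Hinge-style loss max(0, E(true) - E(incorrect) + delta)."""
    e_true = energy(x, y_true, w, b)
    e_wrong = energy(x, -y_true, w, b)   # binary case: the only incorrect label
    return max(0.0, e_true - e_wrong + delta)

x = np.array([1.0, 2.0])
w = np.array([0.1, 0.2])
b = 0.1
for delta in (0.0, 1.0, 2.0):
    print("delta =", delta, "loss =", margin_loss(x, 1, w, b, delta))
```

For this sample the energy gap $E(x, -1) - E(x, +1)$ is about 1.2, so the loss is zero for $\Delta = 0$ and $\Delta = 1$ but positive for $\Delta = 2$. The sample is classified correctly in all three cases.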

### 3.3 Negative Log-Likelihood Loss
- **Loss Function**: Combines the energy of the true output with the log partition function:
  $$ L = E(x, y_{\text{true}}; \theta) + \log Z(x; \theta) $$
  where $Z(x; \theta) = \sum_{y' \in \mathcal{Y}} \exp(-E(x, y'; \theta))$ is the partition function.
- **Explanation**: This loss encourages the model to assign low energy to the true output while normalizing over all possible outputs. The partition function can be computationally expensive, requiring approximation methods like MCMC for complex output spaces.
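
With only two labels the partition function is a sum of two terms, so the loss can be evaluated exactly. The following is a minimal sketch, assuming the linear energy from Section 1. The function name `nll_loss` and the sample values are illustrative.

```python
import numpy as np

def energy(x, y, w, b):
    """Linear energy E(x, y) = -y * (w.x + b)."""
    return -y * (np.dot(w, x) + b)

def nll_loss(x, y_true, w, b, labels=(1, -1)):
    """L = E(x, y_true) + log sum_{y'} exp(-E(x, y'))."""
    neg_energies = np.array([-energy(x, y, w, b) for y in labels])
    # log-sum-exp with the maximum factored out to avoid overflow
    m = neg_energies.max()
    log_z = m + np.log(np.exp(neg_energies - m).sum())
    return energy(x, y_true, w, b) + log_z

x = np.array([1.0, 2.0])   # illustrative input
w = np.array([0.5, 0.5])   # illustrative weights
loss_pos = nll_loss(x, 1, w, 0.0)
loss_neg = nll_loss(x, -1, w, 0.0)
print("NLL if the true label is +1:", loss_pos)
print("NLL if the true label is -1:", loss_neg)
```

For this binary linear energy the loss reduces to $\log(1 + \exp(-2\, y_{\text{true}} (w^T x + b)))$, a logistic loss on the energy gap between the two labels.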

### 3.4 Practical Example: Training with Perceptron Loss

The following code trains the binary classification EBM from Section 2.3 with the perceptron update. It reuses `infer` and the data `X`, `y` defined there.

In [None]:
def perceptron_train(X, y, w, b, lr=0.1, epochs=10):
    """Train an EBM using perceptron loss for binary classification.

    Args:
        X: Input data (numpy array of shape [n_samples, n_features]).
        y: True labels (numpy array of shape [n_samples], values in {+1, -1}).
        w: Initial weights (numpy array of shape [n_features]).
        b: Initial bias (scalar).
        lr: Learning rate (float).
        epochs: Number of training epochs (int).

    Returns:
        w: Updated weights.
        b: Updated bias.
    """
    for _ in range(epochs):
        for x, y_true in zip(X, y):
            y_pred = infer(x, w, b)
            if y_pred != y_true:
                # Update rule: w += lr * y_true * x, b += lr * y_true
                w += lr * y_true * x
                b += lr * y_true
    return w, b

# Train
w = np.array([0.0, 0.0])
b = 0.0
w, b = perceptron_train(X, y, w, b, lr=0.1, epochs=10)
print("Learned weights:", w, "Bias:", b)

# Test inference
predictions = [infer(x, w, b) for x in X]
print("Predictions after training:", predictions)

Learned weights: [0.1 0.2] Bias: 0.1
Predictions after training: [1, 1, -1, -1]


**Code Explanation**:
- **Training Loop**: Iterates over the dataset for a specified number of epochs, updating weights and bias when the predicted label differs from the true label.
- **Update Rule**: The perceptron update rule adjusts the weights and bias to reduce the energy of the true label relative to the predicted label.
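- **Output**: Starting from zero weights, the first sample is misclassified and triggers one update, giving `w = [0.1, 0.2]` and `b = 0.1`. After that update every point is classified correctly, so the parameters stop changing for the remaining epochs.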

**Exercise**: Implement a margin-based loss (e.g., square-square loss) and compare its performance with perceptron loss.
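
One possible starting point for this exercise is sketched below. It uses the hinge form from Section 3.2 rather than the square-square loss, because the square-square loss assumes energies bounded below by zero and the linear energy $-y (w^T x + b)$ is not. The function names `score` and `hinge_train` are hypothetical, and the data is the toy set from Section 2.3.

```python
import numpy as np

def score(x, w, b):
    """Linear score w.x + b; the energy is E(x, y) = -y * score."""
    return np.dot(w, x) + b

def hinge_train(X, y, w, b, delta=1.0, lr=0.1, epochs=50):
    """Per-sample subgradient descent on max(0, E(true) - E(wrong) + delta).

    For the linear binary energy the loss equals max(0, delta - 2 * y * score).
    """
    w = w.astype(float).copy()   # avoid mutating the caller's array
    for _ in range(epochs):
        for x, y_true in zip(X, y):
            if delta - 2.0 * y_true * score(x, w, b) > 0:   # margin violated
                # Negative subgradient of the active hinge term
                w += lr * 2.0 * y_true * x
                b += lr * 2.0 * y_true
    return w, b

X = np.array([[1, 2], [2, 1], [-1, -1], [-2, -2]])
y = np.array([1, 1, -1, -1])

w, b = hinge_train(X, y, np.zeros(2), 0.0, delta=1.0)
gaps = [2.0 * yt * score(x, w, b) for x, yt in zip(X, y)]
preds = [1 if score(x, w, b) > 0 else -1 for x in X]
print("Learned weights:", w, "Bias:", b)
print("Energy gaps E(wrong) - E(true):", gaps)
print("Predictions:", preds)
```

Unlike the perceptron rule, this loop keeps updating on samples that are already classified correctly until their energy gap reaches `delta`. Rerunning with other values of `delta` is one way to compare the two losses.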

## 4. Applications of EBMs

EBMs are versatile and can be applied to various tasks. Below, we describe the theoretical formulations and provide practical application cases with code examples.

### 4.1 Classification
- **Energy Function**: For $K$-class classification:
  $$ E(x, y; \theta) = -f_\theta(x)_y $$
  where $f_\theta(x)_y$ is the score for class $y$. The class with the lowest energy is selected.
- **Application: Handwritten Digit Recognition**
  EBMs can be used to classify handwritten digits (e.g., MNIST dataset). The energy function assigns a score to each digit class, and the model selects the digit with the lowest energy.

**Example**: With a linear score $f_\theta(x) = Wx + b$, each row of $W$ scores one class, and inference reduces to picking the class with the largest score (lowest energy).
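
A minimal sketch of multi-class inference with a linear score follows. To stay self-contained it uses a made-up set of 2D points in three classes instead of digit images. It also sets the parameters from class means (a nearest-centroid construction) instead of training them with one of the losses from Section 3. Both choices, and the names `fit_centroid_scores`, `energies`, and `infer_class`, are assumptions for illustration.

```python
import numpy as np

def fit_centroid_scores(X, y, n_classes):
    """Build linear scores f(x)_k = mu_k.x - 0.5 * ||mu_k||^2 from class means.

    The argmax of these scores is the class whose mean is closest to x.
    """
    W = np.stack([X[y == k].mean(axis=0) for k in range(n_classes)])
    b = -0.5 * np.sum(W ** 2, axis=1)
    return W, b

def energies(x, W, b):
    """E(x, k) = -f(x)_k for every class k."""
    return -(W @ x + b)

def infer_class(x, W, b):
    """Deterministic inference: return the class with the lowest energy."""
    return int(np.argmin(energies(x, W, b)))

# Toy data: three well-separated clusters (made up for illustration)
X = np.array([[0, 0], [1, 0], [0, 1],
              [5, 5], [6, 5], [5, 6],
              [0, 5], [1, 6], [0, 6]], dtype=float)
y = np.array([0, 0, 0, 1, 1, 1, 2, 2, 2])

W, b = fit_centroid_scores(X, y, n_classes=3)
preds = [infer_class(x, W, b) for x in X]
print("Predictions:", preds)
```

The same `energies` and `infer_class` functions would apply unchanged to digit images flattened into vectors. Only the way `W` and `b` are obtained would differ.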

### 4.2 Regression
- **Energy Function**: For regression:
  $$ E(x, y; \theta) = \frac{1}{2}(y - f_\theta(x))^2 $$
  where $f_\theta(x)$ is the predicted output, and the energy measures the squared error.
- **Application: House Price Prediction**
  EBMs can predict continuous values like house prices based on features such as size and location.

### 4.3 Structured Prediction
- **Energy Function**: For sequence labeling (e.g., CRFs):
  $$ E(x, y; \theta) = -\sum_{t=1}^T \psi_t(y_t, x; \theta) - \sum_{t=1}^{T-1} \phi_t(y_t, y_{t+1}) $$
  where $\psi_t$ models the compatibility of label $y_t$ with input $x$, and $\phi_t$ models transitions between labels.
- **Application: Part-of-Speech Tagging**
  EBMs can be used for sequence labeling tasks like part-of-speech (POS) tagging, where the goal is to assign grammatical categories to words in a sentence.
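
Because the chain energy decomposes into per-position and pairwise terms, the minimizing label sequence can be found exactly by dynamic programming instead of enumerating all $K^T$ sequences. The sketch below assumes a single transition table shared across positions (the formula above allows $\phi_t$ to vary with $t$) and uses random scores as stand-ins for learned $\psi$ and $\phi$. All function names are hypothetical.

```python
import itertools
import numpy as np

def chain_energy(labels, unary, trans):
    """E = -sum_t psi[t, y_t] - sum_t phi[y_t, y_{t+1}]."""
    e = -sum(unary[t, k] for t, k in enumerate(labels))
    e -= sum(trans[a, b] for a, b in zip(labels[:-1], labels[1:]))
    return e

def viterbi_min_energy(unary, trans):
    """Exact argmin of the chain energy by min-sum dynamic programming."""
    T, K = unary.shape
    cost = -unary[0].copy()              # best prefix energy ending in each label
    back = np.zeros((T, K), dtype=int)   # backpointers
    for t in range(1, T):
        # cand[i, j]: best prefix ending in label i at t-1, followed by label j at t
        cand = cost[:, None] - trans - unary[t][None, :]
        back[t] = cand.argmin(axis=0)
        cost = cand.min(axis=0)
    path = [int(cost.argmin())]
    for t in range(T - 1, 0, -1):
        path.append(int(back[t, path[-1]]))
    return path[::-1], float(cost.min())

rng = np.random.default_rng(0)
unary = rng.normal(size=(5, 3))   # stand-in psi scores: 5 positions, 3 labels
trans = rng.normal(size=(3, 3))   # stand-in phi scores between adjacent labels

best_path, best_energy = viterbi_min_energy(unary, trans)

# Brute-force check over all 3^5 label sequences
brute = min(itertools.product(range(3), repeat=5),
            key=lambda p: chain_energy(p, unary, trans))
print("Viterbi path:", best_path, "energy:", best_energy)
print("Brute-force path:", list(brute), "energy:", chain_energy(brute, unary, trans))
```

The dynamic program costs $O(T K^2)$ operations, while enumeration costs $O(K^T)$, which is why chain-structured energies remain practical for long sequences.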

### 4.4. Practical Example: Regression with EBM

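The following code implements the regression energy from Section 4.2 with a linear predictor $f_\theta(x) = w^T x + b$, trains it with per-sample gradient steps, and reports the mean squared error on the training data.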

In [None]:
from sklearn.metrics import mean_squared_error

# Energy function
def energy_regression(x, y, w, b):
    """Compute energy as squared error for regression."""
    y_pred = np.dot(w, x) + b
    return 0.5 * (y - y_pred) ** 2

# Inference
def infer_regression(x, w, b):
    """Predict output by minimizing energy (trivial for squared loss)."""
    return np.dot(w, x) + b

# Gradient descent training
def train_regression(X, y, w, b, lr=0.01, epochs=100):
    """Train regression EBM using gradient descent."""
    for _ in range(epochs):
        for x, y_true in zip(X, y):
            y_pred = infer_regression(x, w, b)
            error = y_true - y_pred
            w += lr * error * x
            b += lr * error
    return w, b

# Example data
X = np.array([[1], [2], [3], [4]])
y = np.array([2, 4, 5, 8])
w = np.array([0.0])
b = 0.0

# Train
w, b = train_regression(X, y, w, b, lr=0.01, epochs=100)
print("Learned w:", w, "b:", b)

# Predictions
predictions = [infer_regression(x, w, b) for x in X]
print("Predictions:", predictions)
print("MSE:", mean_squared_error(y, predictions))

**Code Explanation**:
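- **Energy Function**: The energy is half the squared error between the true and predicted outputs. For the linear predictor used here it is convex in $w$ and $b$.
- **Inference**: The prediction is simply $w^T x + b$, because the energy, viewed as a function of $y$, is minimized exactly at $y = w^T x + b$.
- **Training**: Each per-sample step (stochastic gradient descent) moves $w$ and $b$ along the negative gradient of the energy. Since $\partial E / \partial w = -(y - f_\theta(x))\, x$, the update is `w += lr * error * x`.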
- **Data**: A simple dataset with one feature is used for demonstration. In practice, this can be extended to real-world datasets like house prices.

**Exercise**: Extend to a non-linear $E$ using polynomial features.
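
One way to approach this exercise is sketched below. The predictor stays linear in its parameters but acts on polynomial features of $x$, which makes the energy non-linear in $x$. For brevity, the parameters are fitted with closed-form least squares (`np.linalg.lstsq`) instead of the gradient loop above; both minimize the same summed energy. The toy data and the names `poly_features` and `energy_poly` are made up for illustration.

```python
import numpy as np

def poly_features(x, degree=2):
    """Map a scalar x to the feature vector [x, x^2, ..., x^degree]."""
    return np.array([x ** d for d in range(1, degree + 1)], dtype=float)

def energy_poly(x, y, w, b, degree=2):
    """E(x, y) = 0.5 * (y - (w.phi(x) + b))^2 with polynomial features phi."""
    return 0.5 * (y - (np.dot(w, poly_features(x, degree)) + b)) ** 2

# Toy data generated from y = x^2 + 1 (made up for illustration)
xs = np.array([-2.0, -1.0, 0.0, 1.0, 2.0, 3.0])
ys = xs ** 2 + 1.0

# Minimizing the summed energy is an ordinary least-squares problem
Phi = np.stack([poly_features(x) for x in xs])
A = np.hstack([Phi, np.ones((len(xs), 1))])   # last column carries the bias
coef, *_ = np.linalg.lstsq(A, ys, rcond=None)
w, b = coef[:-1], coef[-1]

total_energy = sum(energy_poly(x, y, w, b) for x, y in zip(xs, ys))
print("Learned w:", w, "b:", b)
print("Total energy on training data:", total_energy)
```

Because the data is exactly quadratic, the fit should recover a coefficient near 0 for $x$, near 1 for $x^2$, and a bias near 1. A linear model on the raw input could not drive the energy to zero here.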

## 5. Summary

Key components of EBMs:
1. **Energy Function** $E(x, y; \theta)$: Quantifies compatibility between input and output.
2. **Inference**: Finds the output that minimizes energy (deterministic or probabilistic).
3. **Learning**: Optimizes parameters using loss functions like perceptron, margin-based, or negative log-likelihood.

EBMs are powerful due to their flexibility and ability to handle complex tasks, as demonstrated in the classification and regression examples.

**Reference**:

- LeCun, Y., et al. (2006). *A Tutorial on Energy-Based Learning*. In *Predicting Structured Data*, MIT Press.
<br>
<a href ="https://citeseerx.ist.psu.edu/document?repid=rep1&type=pdf&doi=7fc604e1a3e45cd2d2742f96d62741930a363efa">Link</a>

## 6. Further Reading
- Explore Conditional Random Fields (CRFs), contrastive divergence, and Graph Transformer Networks (GTNs).
- Goodfellow et al. (2016) - *Deep Learning* (Chapter 18).
- Lafferty, J., McCallum, A., & Pereira, F. (2001). *Conditional Random Fields: Probabilistic Models for Segmenting and Labeling Sequence Data*.

## 7. Exercises
1. **Implement multi-class EBM with softmax inference**: Extend the binary classification example to use probabilistic inference with softmax.
2. **Compare different margin values $\Delta$ in margin-based loss**: Test the square-square loss with different margins (e.g., $\Delta = 0.5, 1.0, 2.0$).
3. **Apply EBM to sequence labeling with local factors**: Implement a simple CRF-like model for a toy sequence labeling task.