# Noise Contrastive Estimation

## Introduction
Noise Contrastive Estimation (NCE) estimates the parameters of a model by contrasting samples from the data with samples from a noise distribution. It is widely used in language modeling, where the model is trained to predict the next word in a sequence. There, the noise distribution is a distribution over words that is used to generate negative examples, and the model is trained to score the true next word higher than the noise words.
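As a rough illustration of how negative examples might be generated for a language model, the sketch below draws noise words from a unigram distribution built from word counts. The toy corpus, the function name `sample_noise_words`, and the choice of a unigram noise distribution are all assumptions made for this example; the text above only requires that the noise distribution be some distribution over words.

```python
import random
from collections import Counter

def sample_noise_words(corpus_tokens, k, rng):
    """Draw k noise words from the unigram (word-frequency) distribution.

    Assumption: the noise distribution q(w) is proportional to word counts.
    Returns the sampled words and the probability table q.
    """
    counts = Counter(corpus_tokens)
    total = sum(counts.values())
    vocab = sorted(counts)                      # fixed order for reproducibility
    q = {w: counts[w] / total for w in vocab}   # q(w) sums to 1 by construction
    noise = rng.choices(vocab, weights=[q[w] for w in vocab], k=k)
    return noise, q

# Hypothetical toy corpus, purely for illustration
corpus = "the cat sat on the mat the dog sat".split()
rng = random.Random(0)
noise_words, q = sample_noise_words(corpus, k=5, rng=rng)
print(noise_words)
```

Frequent words such as "the" are drawn more often, so the model sees them as negatives more often than rare words.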

Noise Contrastive Estimation (NCE) is a method designed to efficiently learn unnormalized probabilistic models, typically in scenarios where calculating the partition function is computationally expensive or intractable. The key idea behind NCE is to avoid explicitly computing the partition function, which normalizes the probability distribution. Instead, it turns the problem into a binary classification task that distinguishes between real data and noise samples.

To better understand NCE in relation to the partition function, let’s break down the concepts:

## 1. The Problem with the Partition Function

In most probabilistic models, the probability of an observation $x$ is given by:

$$p(x) = \frac{\exp(f(x))}{Z}$$

where $f(x)$ is the unnormalized score (the model’s output for $x$) and $Z$ is the partition function:

$$Z = \sum_{x'} \exp(f(x'))$$

or, in continuous cases:

$$Z = \int \exp(f(x'))\,dx'$$

The partition function $Z$ ensures that the total probability over all possible outcomes sums to 1. For models with complex or high-dimensional data (e.g., large vocabularies in NLP, complex image datasets), calculating $Z$ can be computationally prohibitive because it requires summing or integrating over all possible values of $x$, which is often not feasible. This is where NCE comes in.
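To make the cost concrete, here is a minimal sketch of brute-force normalization for a small discrete model. The scores and the function names `partition_function` and `normalized_prob` are hypothetical, chosen only for illustration.

```python
import math

def partition_function(scores):
    """Sum exp(f(x)) over every outcome x; cost is linear in the number of outcomes."""
    m = max(scores)  # subtract the max for numerical stability
    return math.exp(m) * sum(math.exp(s - m) for s in scores)

def normalized_prob(scores, x):
    """p(x) = exp(f(x)) / Z for a discrete model given as a list of scores."""
    return math.exp(scores[x]) / partition_function(scores)

# Hypothetical scores f(x) for a 4-outcome model
scores = [2.0, 1.0, 0.0, -1.0]
Z = partition_function(scores)
probs = [normalized_prob(scores, x) for x in range(len(scores))]
print(Z, sum(probs))
```

Every call to `normalized_prob` touches all outcomes. With a vocabulary of, say, 100,000 words, that is one 100,000-term sum per training example, and the gradient of $\log Z$ needs the same sum.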

## 2. The NCE Approach: Binary Classification

NCE addresses the challenge by transforming the learning problem into a classification task, rather than directly estimating the partition function. Instead of modeling the exact distribution $p(x)$, NCE uses a noise distribution $q(x)$, which is easy to sample from and provides a baseline for the classifier.

The main idea is to train a model to distinguish between real data samples from the true distribution $p(x)$ and noise samples drawn from the noise distribution $q(x)$. This avoids needing to compute the partition function $Z$ directly.

## 3. Formulation of the Problem

Suppose we have a dataset where the true distribution is $p(x)$ (e.g., the true distribution of words in a language model). The goal is to learn $p(x)$, but instead of computing it directly, we use the noise distribution $q(x)$, which is easier to handle. For simplicity, assume that $q(x)$ is known, such as a uniform or Gaussian distribution.

### Binary Classification Task

Now, instead of modeling $p(x)$ directly, we model a binary classifier that predicts:
- 1 (real) for samples from the true distribution $p(x)$
- 0 (fake) for samples from the noise distribution $q(x)$

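Thus, the model’s task becomes estimating the probability that a sample $x$ is real:

$$P(D = 1 \mid x) = \frac{p(x)}{p(x) + k\,q(x)}$$

where $D$ is the label (1 for real, 0 for noise) and $k$ is the number of noise samples drawn for each real sample. The classifier now tries to distinguish between samples drawn from the true distribution $p(x)$ and noise samples drawn from $q(x)$.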
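Under the usual NCE setup, where each real sample is accompanied by k noise samples, Bayes' rule gives the probability that a sample x is real as p(x) / (p(x) + k q(x)). The sketch below evaluates that expression on a small discrete example; the distributions and the function name `posterior_real` are hypothetical.

```python
def posterior_real(p_x, q_x, k):
    """P(D=1 | x) when each data sample is mixed with k noise samples."""
    return p_x / (p_x + k * q_x)

# Hypothetical 3-outcome example: data distribution p, uniform noise q
p = [0.7, 0.2, 0.1]
q = [1 / 3, 1 / 3, 1 / 3]
for k in (1, 5):
    print(k, [round(posterior_real(p[x], q[x], k), 3) for x in range(3)])
```

Raising k lowers the posterior for every outcome, because noise samples then make up a larger share of the mixture the classifier sees.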

### Log-Likelihood Formulation

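With $k$ noise samples per real sample and noise distribution $q(x)$, the probability that a sample $x$ is classified as “real” can be written as:

$$P(D = 1 \mid x) = \sigma\big(h(x)\big), \qquad h(x) = f(x) - \log\big(k\,q(x)\big)$$

where $f(x)$ is a score function (the output of the model for input $x$) and $\sigma$ is the sigmoid function.

The loss function that we optimize during training is the binary cross-entropy between the real data and the noise samples:

$$\mathcal{L} = -\,\mathbb{E}_{x \sim p}\big[\log \sigma(h(x))\big] \;-\; k\,\mathbb{E}_{x \sim q}\big[\log\big(1 - \sigma(h(x))\big)\big]$$

Where:
- $\mathbb{E}_{x \sim p}$ is the expectation over the true data distribution $p(x)$
- $\mathbb{E}_{x \sim q}$ is the expectation over the noise distribution $q(x)$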

The loss function aims to make the model assign higher probabilities to real data (samples from $p(x)$) and lower probabilities to noise data (samples from $q(x)$).
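A sample-based version of this binary cross-entropy can be sketched in plain Python. It assumes the common NCE parameterization in which the classifier logit is the model score minus log(k q(x)); the helper names `log_sigmoid` and `nce_loss`, the Gaussian score, and the sample values are all hypothetical.

```python
import math

def log_sigmoid(z):
    """Numerically stable log(sigmoid(z))."""
    if z >= 0:
        return -math.log1p(math.exp(-z))
    return z - math.log1p(math.exp(z))

def nce_loss(f, log_q, data_samples, noise_samples, k):
    """Monte Carlo estimate of the NCE binary cross-entropy.

    f(x)     : model score (unnormalized log-probability)
    log_q(x) : log-density of the noise distribution
    k        : number of noise samples per data sample
    """
    def logit(x):
        return f(x) - (math.log(k) + log_q(x))

    data_term = sum(log_sigmoid(logit(x)) for x in data_samples) / len(data_samples)
    # log(1 - sigmoid(z)) equals log_sigmoid(-z)
    noise_term = sum(log_sigmoid(-logit(x)) for x in noise_samples) / len(noise_samples)
    return -(data_term + k * noise_term)

# Hypothetical 1-D example: unnormalized Gaussian score, wider Gaussian noise
f = lambda x: -0.5 * x * x
log_q = lambda x: -0.5 * (x / 2.0) ** 2 - math.log(2.0 * math.sqrt(2.0 * math.pi))
data = [-0.3, 0.1, 0.8]
noise = [-2.5, 1.9, 0.4, 3.1, -1.2, 0.0]   # k = 2 noise samples per data sample
print(nce_loss(f, log_q, data, noise, k=2))
```

Note that only `f` and `log_q` are evaluated at the sampled points; nothing in the loss sums over the whole space.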

## 4. Learning Without the Partition Function

In this setup, the partition function $Z$ is never explicitly computed. Instead, the classifier learns the log-ratio $\log \frac{p(x)}{k\,q(x)}$, so the score $f(x)$ approximates $\log p(x)$ with the normalizing constant absorbed into the model as a learned offset. This is sufficient for learning the distribution $p(x)$.
- Key observation: Since $p(x)$ is only needed up to a constant (i.e., $p(x) \propto \exp(f(x))$), learning the ratio $p(x)/q(x)$ through binary classification is enough to estimate $p(x)$ without computing the partition function.
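The following toy check illustrates this under stated assumptions: the scores are fixed, a single learnable constant `c` is added to them, and the NCE objective is optimized with exact expectations rather than samples. The scores, the uniform noise, `k = 5`, and the learning rate are hypothetical. `Z` is computed here only to build the toy data distribution and to verify the result; the update rule itself never uses it.

```python
import math

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

# Hypothetical unnormalized scores f(x) over 5 outcomes
scores = [2.0, 1.0, 0.5, 0.0, -1.0]
Z = sum(math.exp(s) for s in scores)      # used only to define p and check c
p = [math.exp(s) / Z for s in scores]     # true data distribution
q = [1.0 / len(scores)] * len(scores)     # uniform noise distribution
k = 5                                     # noise samples per data sample

c = 0.0    # learnable stand-in for -log Z
lr = 0.5
for _ in range(200):
    grad = 0.0
    for x in range(len(scores)):
        s = sigmoid(scores[x] + c - math.log(k * q[x]))
        # derivative of the expected NCE objective with respect to c
        grad += p[x] * (1.0 - s) - k * q[x] * s
    c += lr * grad                        # gradient ascent on the objective

print(c, -math.log(Z))
```

The two printed values should agree closely, showing that the offset learned through classification plays the role of the log partition function.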

## 5. Why NCE Works

The success of NCE relies on the following:
- Noise distribution: By using an easily sampled noise distribution $q(x)$, we can efficiently train the model without needing to normalize over the entire space, which would be required if we directly modeled $p(x)$.
- Binary classification: The model learns to distinguish between data and noise, which simplifies the problem into a supervised classification task with tractable loss and gradient updates.
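The reason binary classification is enough can be made explicit with a short derivation. It assumes the standard setup of $k$ noise samples per data sample, and writes $h^*(x)$ for the logit of the Bayes-optimal classifier (a symbol introduced here for illustration):

```latex
\begin{aligned}
P(D = 1 \mid x) &= \frac{p(x)}{p(x) + k\,q(x)}
                 = \sigma\!\big(\log p(x) - \log(k\,q(x))\big), \\
h^*(x) &= \log p(x) - \log k - \log q(x), \\
p(x) &= k\,q(x)\,\exp\!\big(h^*(x)\big).
\end{aligned}
```

Since $k$ and $q(x)$ are known, a classifier that approaches the optimal logit also determines $p(x)$, and no sum over the whole space appears anywhere in the derivation.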

## 6. Intuition and Efficiency

The beauty of NCE lies in its ability to avoid expensive operations (like summing over all possible outcomes for $Z$) by:
- Replacing the intractable problem of computing $Z$ with a simpler binary classification problem.
- Leveraging the noise distribution $q(x)$, which is easy to handle and does not require the sum over all possible $x$ in the space.

Thus, NCE enables the training of models that would otherwise be difficult or impossible to train due to the intractable partition function.

## Conclusion

In summary, NCE sidesteps the need for the partition function $Z$ by reframing the problem as a binary classification task that distinguishes between real and noise samples. This approach allows for the efficient learning of unnormalized distributions, making it a valuable technique in scenarios where calculating the partition function is computationally prohibitive.



In [6]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torchvision import datasets, transforms

# Define a simple neural network model for f(x) (un-normalized score)
class SimpleNNModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(28 * 28, 128)  # Flatten MNIST images to vectors
        self.fc2 = nn.Linear(128, 1)        # Output a single score
    
    def forward(self, x):
        x = torch.flatten(x, start_dim=1)  # Flatten the image
        x = torch.relu(self.fc1(x))         # Apply ReLU activation
        x = self.fc2(x)                     # Output un-normalized score
        return x

# Define dataset to generate batches of real (MNIST) and fake (noise) data
class NCE_MNIST_Dataset(Dataset):
    def __init__(self, mnist_data, real_data=True):
        self.mnist_data = mnist_data
        self.real_data = real_data
    
    def __len__(self):
        return len(self.mnist_data)
    
    def __getitem__(self, idx):
        if self.real_data:
            # Sample a real MNIST image; ToTensor already yields shape (1, 28, 28),
            # so no extra channel dimension is needed
            image, _ = self.mnist_data[idx]
            return image, torch.ones(1)  # Label 1 for real data
        else:
            # Generate standard Gaussian noise with the same shape as an MNIST image
            fake_image = torch.randn(1, 28, 28)
            return fake_image, torch.zeros(1)  # Label 0 for fake data

# Initialize MNIST dataset
transform = transforms.Compose([transforms.ToTensor()])
mnist_data = datasets.MNIST(root='./data', train=True, download=True, transform=transform)

# Define model, loss function, and optimizer
model = SimpleNNModel()
optimizer = optim.Adam(model.parameters(), lr=0.001)
# BCEWithLogitsLoss applies the sigmoid internally, so the model outputs raw scores
criterion = nn.BCEWithLogitsLoss()

# Training loop
def train_nce_with_mnist(model, mnist_data, optimizer, criterion, epochs=10, batch_size=64):
    # Build the loaders once; the noise dataset draws fresh Gaussian samples on every access
    real_data_loader = DataLoader(NCE_MNIST_Dataset(mnist_data, real_data=True), batch_size=batch_size, shuffle=True)
    noise_data_loader = DataLoader(NCE_MNIST_Dataset(mnist_data, real_data=False), batch_size=batch_size)

    for epoch in range(epochs):
        model.train()
        running_loss, num_batches = 0.0, 0

        # One noise batch per real batch, i.e. k = 1 noise sample per data sample.
        # With balanced classes the raw score estimates log(p(x) / q(x)).
        for (real_images, real_labels), (noise_images, noise_labels) in zip(real_data_loader, noise_data_loader):
            optimizer.zero_grad()

            # Forward pass for real and fake data; both outputs have shape (batch, 1)
            real_scores = model(real_images)
            noise_scores = model(noise_images)

            # Compute loss (binary cross entropy); labels already have shape (batch, 1)
            loss = criterion(real_scores, real_labels) + criterion(noise_scores, noise_labels)

            # Backpropagate and update parameters
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            num_batches += 1

        # Report the mean loss over the epoch rather than the last batch only
        print(f"Epoch [{epoch+1}/{epochs}], Loss: {running_loss / num_batches:.4f}")

# Train the model using NCE on MNIST
train_nce_with_mnist(model, mnist_data, optimizer, criterion, epochs=10, batch_size=64)

AttributeError: partially initialized module 'torchvision' has no attribute 'extension' (most likely due to a circular import)