# Training Probes using Contrast-Consistent Search (CCS)

This notebook demonstrates how to train probes using the CCS approach from Burns et al. (2023). CCS is a method for learning probes that can extract beliefs from language models without using labeled data.

The key idea is to train a probe that:
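1. Makes informative predictions (is confident about at least one of the two statements in each contrast pair, rather than outputting 0.5 for both)
2. Makes consistent predictions (assigns complementary probabilities to the two statements of a contrast pair, i.e. p(positive) ≈ 1 − p(negative))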

For example, if we have a statement "The sky is blue", we create one contrast pair consisting of two statements:
- "The sky is blue" (positive)
- "The sky is not blue" (negative)

The probe should assign high probability to one of these and low probability to the other.

In [None]:
%load_ext autoreload
%autoreload 2

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import copy
from sklearn.linear_model import LogisticRegression
import matplotlib.pyplot as plt
import itertools

In [None]:
class MLPProbe(nn.Module):
    """
    Non-linear probe: a one-hidden-layer MLP mapping a d-dimensional input
    to a single value in (0, 1).
    """

    def __init__(self, d):
        super().__init__()
        hidden_width = 100
        # Layers are created in this order so parameter initialisation
        # consumes the RNG the same way on every construction.
        self.linear1 = nn.Linear(d, hidden_width)
        self.linear2 = nn.Linear(hidden_width, 1)

    def forward(self, x):
        """
        Applies linear1 -> ReLU -> linear2 -> sigmoid to x and returns the result
        """
        hidden = self.linear1(x).relu()
        logits = self.linear2(hidden)
        return logits.sigmoid()

class CCS(object):
    """
    Trains a probe with the Contrast-Consistent Search (CCS) objective.

    The probe maps a hidden state to a probability in (0, 1). It is trained,
    without labels, on pairs (x0[i], x1[i]) so that the two probabilities are
    complementary (consistency) and not both close to 0.5 (informativeness).

    Parameters
    ----------
    x0, x1 : arrays of shape (n, d)
        Hidden states for the two members of each contrast pair. They are
        normalized independently (see `normalize`) and stored.
    nepochs : number of passes over the data per training run.
    ntries : number of random restarts performed by `repeated_train`.
    lr, weight_decay : passed to `torch.optim.AdamW`.
    batch_size : mini-batch size; -1 means "use all n examples in one batch".
    verbose : stored on the instance; not used by any method of this class.
    device : torch device on which the probe and the data tensors live.
    linear : if True the probe is Linear(d, 1) + Sigmoid, otherwise an MLPProbe.
    var_normalize : if True, `normalize` also divides by the per-feature std.
    """

    def __init__(self, x0, x1, nepochs=1000, ntries=10, lr=1e-3, batch_size=-1,
                 verbose=False, device="cpu", linear=True, weight_decay=0.01, var_normalize=False):
        # data
        # var_normalize has to be set first because normalize() reads it
        self.var_normalize = var_normalize
        self.x0 = self.normalize(x0)
        self.x1 = self.normalize(x1)
        self.d = self.x0.shape[-1]

        # training
        self.nepochs = nepochs
        self.ntries = ntries
        self.lr = lr
        self.verbose = verbose
        self.device = device
        self.batch_size = batch_size
        self.weight_decay = weight_decay

        # probe
        self.linear = linear
        self.initialize_probe()
        # Until repeated_train() finds something better, the "best" probe is
        # simply a copy of the freshly initialised one.
        self.best_probe = copy.deepcopy(self.probe)

    def initialize_probe(self):
        """
        (Re)creates self.probe with fresh random weights and moves it to self.device
        """
        if self.linear:
            self.probe = nn.Sequential(nn.Linear(self.d, 1), nn.Sigmoid())
        else:
            self.probe = MLPProbe(self.d)
        self.probe.to(self.device)

    def normalize(self, x):
        """
        Mean-normalizes the data x (of shape (n, d))
        If self.var_normalize, also divides by the standard deviation

        Features whose standard deviation is exactly zero are left at zero
        after centering instead of being divided by zero (which would turn
        the whole column into NaN and poison training).
        Returns a new array; x itself is not modified.
        """
        normalized_x = x - x.mean(axis=0, keepdims=True)
        if self.var_normalize:
            std = normalized_x.std(axis=0, keepdims=True)
            # Constant feature: the centered values are already 0, so any
            # non-zero divisor keeps them at 0.
            std[std == 0] = 1
            normalized_x = normalized_x / std
        return normalized_x

    def get_tensor_data(self):
        """
        Returns x0, x1 as appropriate tensors (rather than np arrays)
        """
        x0 = torch.tensor(self.x0, dtype=torch.float, requires_grad=False, device=self.device)
        x1 = torch.tensor(self.x1, dtype=torch.float, requires_grad=False, device=self.device)
        return x0, x1

    def get_loss(self, p0, p1):
        """
        Returns the CCS loss for two probabilities each of shape (n,1) or (n,)

        informative_loss penalises min(p0, p1) being large (both probabilities
        sitting away from 0); consistent_loss penalises p0 differing from 1 - p1.
        Both are averaged over the first dimension and summed.
        """
        informative_loss = (torch.min(p0, p1)**2).mean(0)
        consistent_loss = ((p0 - (1-p1))**2).mean(0)
        return informative_loss + consistent_loss

    def get_acc(self, x0_test, x1_test, y_test):
        """
        Computes accuracy for the current parameters on the given test inputs

        The test inputs are normalized with their own statistics (via
        `normalize`) and scored with self.best_probe. An example is predicted
        as 1 when 0.5 * (p0 + (1 - p1)) < 0.5. Because the returned value is
        max(acc, 1 - acc), it is always >= 0.5: the orientation of the probe
        is chosen after the fact.
        """
        x0 = torch.tensor(self.normalize(x0_test), dtype=torch.float, requires_grad=False, device=self.device)
        x1 = torch.tensor(self.normalize(x1_test), dtype=torch.float, requires_grad=False, device=self.device)
        with torch.no_grad():
            p0, p1 = self.best_probe(x0), self.best_probe(x1)
        avg_confidence = 0.5*(p0 + (1-p1))
        # [:, 0] drops the trailing size-1 output dimension of the probe
        predictions = (avg_confidence.detach().cpu().numpy() < 0.5).astype(int)[:, 0]
        acc = (predictions == y_test).mean()
        acc = max(acc, 1 - acc)
        return acc

    def train(self):
        """
        Does a single training run of nepochs epochs

        Returns the loss (a Python float) of the last batch that was
        processed, evaluated before the final optimizer step. If no step is
        taken (nepochs == 0) the loss of the untrained probe on the full data
        is returned instead.

        Raises ValueError if there is no data or if batch_size is neither -1
        nor a positive integer.
        """
        x0, x1 = self.get_tensor_data()
        n = len(x0)
        if n == 0:
            raise ValueError("cannot train CCS on an empty dataset")

        # Shuffle once; the same permutation keeps the pairs aligned
        permutation = torch.randperm(n)
        x0, x1 = x0[permutation], x1[permutation]

        # set up optimizer
        optimizer = torch.optim.AdamW(self.probe.parameters(), lr=self.lr, weight_decay=self.weight_decay)

        if self.batch_size == -1:
            batch_size = n
        elif self.batch_size > 0:
            # A batch larger than the dataset would give zero batches and
            # therefore no training at all, so cap it at n.
            batch_size = min(self.batch_size, n)
        else:
            raise ValueError(
                "batch_size must be -1 (full batch) or a positive integer, got {}".format(self.batch_size)
            )
        # Floor division: a trailing partial batch is dropped
        nbatches = n // batch_size

        # Start training (full batch when batch_size == -1, mini-batches otherwise)
        loss = None
        for epoch in range(self.nepochs):
            for j in range(nbatches):
                x0_batch = x0[j*batch_size:(j+1)*batch_size]
                x1_batch = x1[j*batch_size:(j+1)*batch_size]

                # probe
                p0, p1 = self.probe(x0_batch), self.probe(x1_batch)

                # get the corresponding loss
                loss = self.get_loss(p0, p1)

                # update the parameters
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

        if loss is None:
            # No optimisation step happened; still report a meaningful loss
            with torch.no_grad():
                loss = self.get_loss(self.probe(x0), self.probe(x1))

        return loss.detach().cpu().item()

    def repeated_train(self):
        """
        Runs ntries independent training runs from fresh random initialisations,
        keeps a copy of the probe with the lowest final loss in self.best_probe
        and returns that loss
        """
        best_loss = np.inf
        for train_num in range(self.ntries):
            self.initialize_probe()
            loss = self.train()
            if loss < best_loss:
                self.best_probe = copy.deepcopy(self.probe)
                best_loss = loss
        return best_loss

In [None]:
def get_hidden_states_many_examples(model, tokenizer, data, dataset_name, model_type, params):
    """
    Computes the contrast-pair hidden states for every example in data.

    For each example, two prompts are built (candidate label 0 and candidate
    label 1) with the dataset-specific formatter, and the hidden state of each
    prompt is obtained with utils.get_hidden_states.

    Arguments:
        model, tokenizer: forwarded unchanged to utils.get_hidden_states
        data: iterable of samples, each indexable by the dataset's field names
        dataset_name: one of "imdb", "google/boolq", "domenicrosati/TruthfulQA"
        model_type: forwarded to utils.get_hidden_states
        params: a (layer, token_position, prompt_version) tuple; prompt_version
            is only used for the imdb prompts

    Returns (neg_hs, pos_hs, labels): three numpy arrays stacked along a new
    first axis, with one entry per contrast pair. imdb and boolq produce one
    pair per sample; TruthfulQA produces two (the incorrect answer, then the
    best answer), so all three arrays always have the same length.

    Raises ValueError for an unsupported dataset_name or when data yields no samples.

    This is deliberately simple so that it's easy to understand, rather than being optimized for efficiency
    """
    supported = ("imdb", "google/boolq", "domenicrosati/TruthfulQA")
    if dataset_name not in supported:
        raise ValueError(
            "Unsupported dataset_name {!r}; expected one of {}".format(dataset_name, supported)
        )

    # setup
    model.eval()
    all_neg_hs, all_pos_hs, all_gt_labels = [], [], []
    layer, token_pos, prompt_version = params

    def hidden_states(prompt):
        # Single place that fixes the probing location (token position / layer)
        return utils.get_hidden_states(model, tokenizer, prompt, token_pos, layer, model_type=model_type)

    # loop
    for sample in data:
        # Each entry is (prompt for candidate 0, prompt for candidate 1, ground-truth label)
        if dataset_name == "imdb":
            text, true_label = sample["text"], sample["label"]
            pairs = [
                (format_imdb(text, 0, prompt_version), format_imdb(text, 1, prompt_version), true_label),
            ]
        elif dataset_name == "google/boolq":
            text, question, true_label = sample["passage"], sample["question"], sample["answer"]
            pairs = [
                (format_boolq(text, question, 0), format_boolq(text, question, 1), true_label),
            ]
        else:  # "domenicrosati/TruthfulQA"
            question, best_answer, incorrect_answer = sample["question"], sample["best answer"], sample["incorrect answer"]
            # Every contrast pair needs its own label, otherwise hidden states and
            # labels go out of sync. Labels assume candidate 1 in format_truthfulqa
            # means "this answer is true": 0 for the incorrect answer, 1 for the
            # best answer -- TODO confirm against format_truthfulqa.
            pairs = [
                (format_truthfulqa(question, incorrect_answer, 0), format_truthfulqa(question, incorrect_answer, 1), 0),
                (format_truthfulqa(question, best_answer, 0), format_truthfulqa(question, best_answer, 1), 1),
            ]

        # collect
        for neg_prompt, pos_prompt, label in pairs:
            all_neg_hs.append(hidden_states(neg_prompt))
            all_pos_hs.append(hidden_states(pos_prompt))
            all_gt_labels.append(label)

    if not all_gt_labels:
        raise ValueError("data is empty: no hidden states to stack")

    all_neg_hs = np.stack(all_neg_hs)
    all_pos_hs = np.stack(all_pos_hs)
    all_gt_labels = np.stack(all_gt_labels)

    return all_neg_hs, all_pos_hs, all_gt_labels

In [None]:
# Example usage: extract IMDB contrast pairs, sanity-check them with a
# supervised baseline, then fit and score an unsupervised CCS probe.

# Hyperparameters
num_example = 100
layer_indices = [-1]  # Use last layer
token_positions = [-1]  # Use last token
prompt_versions = [1]

# Get hidden states for training and testing
# (only the first entry of each hyperparameter list is used here)
probe_location = (layer_indices[0], token_positions[0], prompt_versions[0])
neg_hs, pos_hs, y = get_hidden_states_many_examples(
    model,
    tokenizer,
    data_train["imdb"][:num_example],
    "imdb",
    model_type="encoder_decoder",
    params=probe_location,
)

# Split into train/test: first half trains, second half evaluates
n = len(y)
train_part = slice(None, n // 2)
test_part = slice(n // 2, None)
neg_hs_train, pos_hs_train, y_train = neg_hs[train_part], pos_hs[train_part], y[train_part]
neg_hs_test, pos_hs_test, y_test = neg_hs[test_part], pos_hs[test_part], y[test_part]

# Verify with logistic regression first
# (supervised baseline on the difference between the two hidden states)
x_train = neg_hs_train - pos_hs_train
x_test = neg_hs_test - pos_hs_test
lr = LogisticRegression(class_weight="balanced")
lr.fit(x_train, y_train)
lr_test_score = lr.score(x_test, y_test)
print("Logistic regression accuracy: {}".format(lr_test_score))

# Train CCS probe
ccs_device = "cuda" if torch.cuda.is_available() else "cpu"
ccs = CCS(
    neg_hs_train,
    pos_hs_train,
    nepochs=1000,
    ntries=10,
    lr=1e-3,
    verbose=True,
    device=ccs_device,
    linear=True,
    weight_decay=0.01,
    var_normalize=False,
)

# Train and evaluate
best_loss = ccs.repeated_train()
ccs_acc = ccs.get_acc(neg_hs_test, pos_hs_test, y_test)

print(f"Best CCS loss: {best_loss:.4f}")
print(f"CCS accuracy: {ccs_acc:.4f}")

In [None]:
# Function to aggregate credences using the geometric mean of their odds
def aggregate_gmean(credences):
    """
    Aggregate a list of credences into one estimate using geometric mean.

    Each credence p is converted to odds p / (1 - p), the geometric mean of
    the odds is taken, and the result is converted back to a probability.
    Aggregating identical credences therefore returns that same credence
    (e.g. [0.9, 0.9] -> 0.9).

    Arguments:
        credences: list or array of probabilities; it is flattened before use.

    Returns the aggregated probability as a float.
    Raises ValueError if credences is empty.

    The computation is done in log-odds space (mean of logits followed by a
    sigmoid), which is mathematically the same as the geometric mean of odds
    but does not under/overflow when many credences are multiplied together.
    Credences of exactly 0 or 1 map to -inf/+inf log-odds; if both occur the
    result is NaN.
    """
    probs = np.asarray(credences, dtype=float).ravel()
    if probs.size == 0:
        raise ValueError("aggregate_gmean requires at least one credence")

    # Credences of exactly 0 or 1 legitimately produce infinite log-odds;
    # silence the corresponding floating-point warnings.
    with np.errstate(divide="ignore", over="ignore", invalid="ignore"):
        log_odds = np.log(probs) - np.log1p(-probs)
        mean_log_odds = log_odds.mean()
        # odds / (1 + odds) written as a sigmoid of the mean log-odds
        return 1 / (1 + np.exp(-mean_log_odds))

# Get predictions and visualize
# The probe is trained on mean-normalized activations (CCS.__init__ normalizes
# its inputs and CCS.get_acc normalizes the test inputs before scoring), so the
# test activations must go through the same ccs.normalize step here; feeding
# raw activations would evaluate the probe on shifted inputs it never saw.
with torch.no_grad():
    pos_probs = ccs.best_probe(torch.tensor(ccs.normalize(pos_hs_test), dtype=torch.float32, device=ccs.device))
    neg_probs = ccs.best_probe(torch.tensor(ccs.normalize(neg_hs_test), dtype=torch.float32, device=ccs.device))

# Average confidence
# NOTE: this averages p(pos) with 1 - p(neg), the mirror image of the
# quantity used in CCS.get_acc (which is built from p(neg) and 1 - p(pos)).
# The histogram is therefore reflected around 0.5 relative to get_acc's
# avg_confidence.
avg_confidence = 0.5 * (pos_probs.cpu() + (1 - neg_probs.cpu()))

# Visualize distribution of confidences
plt.figure(figsize=(10, 6))
plt.hist(avg_confidence.numpy(), bins=20, alpha=0.7)
plt.axvline(x=0.5, color='r', linestyle='--', label='Decision Boundary')
plt.xlabel('Confidence')
plt.ylabel('Count')
plt.title('Distribution of CCS Probe Confidences')
plt.legend()
plt.show()