# RLCT Estimation of Sorting

This Jupyter Notebook estimates the Real Log Canonical Threshold (RLCT), via the local learning coefficient (LLC), of a small 3-layer transformer model (~280,000 parameters). The model is trained to sort sequences of length 20 whose entries are integers from 0 to 19. It uses both the Stochastic Gradient Nosé-Hoover Thermostat (SGNHT) and Stochastic Gradient Langevin Dynamics (SGLD) as sampling methods.

## Main Steps:

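1. **Data Preparation**: Generate the sequences to sort, using one of three train/test dataset variants.
2. **Model Training**: Train several independently initialised transformer models with the Adam optimizer, saving a checkpoint after every epoch.
3. **Model Evaluation**: Evaluate each model's loss on a held-out test set after every epoch.
4. **RLCT Estimation**: Use the SGNHT and SGLD samplers to estimate the RLCT of every saved checkpoint.
5. **Plotting**: Visualize train and test losses, the RLCT estimates, and the sampler loss trace of the final checkpoint.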

In [None]:
%pip install devinterp seaborn torchvision pickleshare transformer_lens pytest
# entropy-sgd is not pip-installed: its optimizer is imported from the clone's
# `python/` package, so this import runs while the working directory is the clone.
!git clone https://github.com/ucla-vision/entropy-sgd.git
%cd entropy-sgd
# NOTE(review): EntropySGD is not used by any of the cells shown in this notebook.
from python.optim import EntropySGD  
%cd ..
# tracr compiles RASP programs into transformer models; installed from source.
!git clone https://github.com/deepmind/tracr
%cd tracr
%pip install .
%cd ..
# sortinterp provides the tracr -> TransformerLens helpers (cfg_from_tracr, load_tracr_weights).
!git clone https://github.com/shaffreybenjamin/sortinterp.git
%cd sortinterp
%pip install .
%cd ..

In [None]:
import jax
jax.config.update('jax_default_matmul_precision', 'float32')
from tracr.compiler import compiling
from tracr.compiler import lib
from tracr.rasp import rasp
from transformer_lens import HookedTransformerConfig, HookedTransformer

from sortinterp.utils import cfg_from_tracr, load_tracr_weights
import copy
import matplotlib.pyplot as plt

import seaborn as sns
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from tqdm.notebook import tqdm
from python.optim import EntropySGD

from devinterp.optim.sgld import SGLD
from devinterp.optim.sgnht import SGNHT

# Seaborn colour palettes used by the plotting functions: "muted" shades for the
# main series and "pastel" shades for lighter companion series.
PRIMARY, SECONDARY, TERTIARY, QUATERNARY = sns.color_palette("muted")[:4]
PRIMARY_LIGHT, SECONDARY_LIGHT, TERTIARY_LIGHT, QUATERNARY_LIGHT = sns.color_palette(
    "pastel"
)[:4]

input_size = 20 # Length of sequences
vocab_size = 20 # Vocabulary size

# Token vocabulary for the compiled program: the integers 0 .. vocab_size - 1.
vocab = {*range(vocab_size)}
# tracr library sort program. Both arguments are `rasp.tokens`, so tokens are
# sorted by their own values (NOTE(review): confirm the argument order against
# tracr's `lib.make_sort`).
program = lib.make_sort(rasp.tokens, rasp.tokens, max_seq_len=input_size, min_key=0)

# Compile the RASP program into a transformer. `compiler_bos` is presumably the
# label of the beginning-of-sequence token, and `mlp_exactness` is a tracr compiler
# setting (NOTE(review): presumably trades MLP size for approximation accuracy;
# confirm before relying on it).
tracr_model = compiling.compile_rasp_to_model(
    program=program,
    vocab=vocab,
    max_seq_len=input_size,
    compiler_bos="bos",
    mlp_exactness=100)

# Derive a TransformerLens config from the compiled model. Only the config is
# used in this notebook: models are constructed fresh from `cfg`, and
# `load_tracr_weights` is imported but not called here.
cfg = cfg_from_tracr(tracr_model)

In [None]:
def remove_common_rows(tensor_a, tensor_b):
    """
    Return the rows of ``tensor_a`` that have no exact match among the rows of ``tensor_b``.

    Every row of ``tensor_a`` is compared element-wise against every row of
    ``tensor_b`` via broadcasting. The intermediate boolean tensor therefore has
    ``len(tensor_a) * len(tensor_b) * num_columns`` elements. Surviving rows keep
    their original order.

    :param tensor_a: 2-D tensor whose rows are filtered
    :param tensor_b: 2-D tensor of rows to exclude; same number of columns as ``tensor_a``
    :return: the rows of ``tensor_a`` not present in ``tensor_b``
    """
    # pairwise_match[i, j] is True when row i of tensor_a equals row j of tensor_b.
    pairwise_match = (tensor_a[:, None] == tensor_b[None]).all(dim=2)
    survivors = pairwise_match.any(dim=1).logical_not()
    return tensor_a[survivors]

# this dataset is intended as a control
def get_dataset_1(input_size, vocab_size):
    """
    Control dataset: uniformly random sequences, split roughly one third train / two thirds test.

    Draws ``4 * (input_size * (vocab_size - 2) + 1)`` sequences of length
    ``input_size`` with entries in ``[0, vocab_size)``. The first
    ``int(0.333 * total)`` rows form the training set and the remainder the test set.
    Train and test are not de-duplicated against each other.

    :return: ``(train_sequences, test_sequences)`` as int64 tensors
    """
    num_sequences = 4 * (input_size * (vocab_size - 2) + 1)
    pool = torch.randint(0, vocab_size, (num_sequences, input_size))
    cut = int(0.333 * len(pool))
    return pool[:cut], pool[cut:]

# this dataset is intended to incentivise the model to learn a simpler algorithm than sorting, 
# namely putting the nonzero number at the end
def get_dataset_2(input_size, vocab_size):
    """
    Dataset that rewards a shortcut: "move the single non-zero token to the end".

    Base training set: the all-zeros sequence plus every sequence that is zero
    except for one position holding a value ``v``. ``torch.arange(1, vocab_size - 1)``
    excludes its upper bound, so ``v`` ranges over ``1 .. vocab_size - 2``. This
    gives ``input_size * (vocab_size - 2) + 1`` base rows.

    Test set: ``3 * (input_size * (vocab_size - 2) + 1)`` uniformly random
    sequences over ``[0, vocab_size)``, with any row equal to a base training row
    removed. The first ``int(0.333 * n_base)`` of those test rows are then moved
    into the training set, so the model still sees some general sorting signal.

    :return: ``(train_sequences, test_sequences)`` as int64 tensors
    """
    # One (input_size x input_size) diagonal slab per non-zero value; each slab row
    # is a sequence with that value at exactly one position.
    nonzero_values = torch.arange(1, vocab_size - 1).reshape(-1, 1, 1)
    slabs = torch.eye(input_size).unsqueeze(dim=0) * nonzero_values
    single_token_rows = slabs.reshape(-1, input_size)
    zeros_row = torch.zeros(1, input_size)
    base_train = torch.cat((zeros_row, single_token_rows), dim=0).long()

    random_pool = torch.randint(0, vocab_size, (3 * (input_size * (vocab_size - 2) + 1), input_size))
    # Keep train and test disjoint.
    random_pool = remove_common_rows(random_pool, base_train)

    # Leak a small amount of general sorting data into training.
    leak = int(0.333 * len(base_train))
    train_sequences = torch.cat((base_train, random_pool[:leak]), dim=0)
    return train_sequences, random_pool[leak:]

# this dataset is intendent to incentivise the model to learn a sorting algorithm specific to certain digits only
def get_dataset_3(input_size, vocab_size):
    """
    Dataset that rewards sorting only the low half of the vocabulary.

    Base training rows (``input_size * (vocab_size - 2) + 1`` of them) draw their
    entries from ``[0, vocab_size // 2)``. Test rows (three times as many) draw from
    ``[vocab_size // 2, vocab_size)``. The first ``int(0.333 * n_base)`` test rows
    are moved into training, so the model still sees some high-value sorting signal.

    :return: ``(train_sequences, test_sequences)`` as int64 tensors
    """
    middle = vocab_size // 2
    base_count = input_size * (vocab_size - 2) + 1
    low_rows = torch.randint(0, middle, (base_count, input_size))
    high_rows = torch.randint(middle, vocab_size, (3 * base_count, input_size))

    leak = int(0.333 * len(low_rows))
    train_sequences = torch.cat((low_rows, high_rows[:leak]), dim=0)
    return train_sequences, high_rows[leak:]
    

def get_data(input_size, vocab_size, dataset=1):
    """
    Build ``(input, sorted target)`` pairs for one of the three dataset variants.

    :param input_size: length of each sequence
    :param vocab_size: number of distinct token values
    :param dataset: ``0`` -> ``get_dataset_1`` (random control),
        ``1`` -> ``get_dataset_2`` (single non-zero token),
        ``2`` -> ``get_dataset_3`` (low values in train, high values in test)
    :return: ``(train_data, test_data)``, each a list of
        ``(sequence, sorted_sequence)`` tensor pairs
    :raises ValueError: if ``dataset`` is not 0, 1 or 2
    """
    if dataset == 0:
        train_sequences, test_sequences = get_dataset_1(input_size, vocab_size)
    elif dataset == 1:
        train_sequences, test_sequences = get_dataset_2(input_size, vocab_size)
    elif dataset == 2:
        train_sequences, test_sequences = get_dataset_3(input_size, vocab_size)
    else:
        # Fail loudly: continuing would use the undefined sequence variables below.
        raise ValueError(f"enter a dataset number between 0 and 2, got {dataset!r}")

    # Targets are the inputs sorted ascending along the sequence dimension.
    train_sequences_sorted = torch.sort(train_sequences, dim=1).values
    test_sequences_sorted = torch.sort(test_sequences, dim=1).values
    train_data = list(zip(train_sequences, train_sequences_sorted))
    test_data = list(zip(test_sequences, test_sequences_sorted))
    return train_data, test_data

# Select the dataset variant (0: random control, 1: single non-zero token,
# 2: low values in train / high values in test), build it and report split sizes.
dataset = 2
train_data, test_data = get_data(input_size, vocab_size, dataset)
train_size = len(train_data)
test_size = len(test_data)
print(f"Train size: {train_size}")
print(f"Test size: {test_size}")

In [None]:
def accuracy_function(outputs, targets):
    """
    Return the fraction of positions whose highest-scoring class (along dim 1) equals the target.

    In this notebook ``outputs`` is the model's logits after ``permute(0, 2, 1)``,
    so class scores sit on dim 1. ``targets`` holds class indices for the remaining
    dims. Returns a 0-dim float tensor.
    """
    predictions = torch.argmax(outputs, dim=1)
    hits = predictions.eq(targets)
    return hits.float().mean()

def train_one_epoch(model, train_loader, optimizer, criterion, model_key):
    """
    Run one optimisation pass over ``train_loader`` and return the mean per-batch loss.

    Inputs and targets are moved to the module-level ``DEVICE``. Logits are permuted
    with ``permute(0, 2, 1)`` so the class dimension sits on dim 1 before calling
    ``criterion``. The first batch's loss and accuracy are printed as a progress signal.

    :param model_key: unused; kept for call-site compatibility
    :return: sum of per-batch losses divided by the number of batches
    """
    model.train()
    running_total = 0
    for batch_idx, (inputs, labels) in enumerate(train_loader):
        labels = labels.to(DEVICE)
        optimizer.zero_grad()
        logits = model(inputs.to(DEVICE)).permute(0, 2, 1)
        loss = criterion(logits, labels)
        batch_loss = loss.item()
        running_total += batch_loss
        loss.backward()
        optimizer.step()

        if batch_idx == 0:
            # Accuracy of the logits computed before this step's parameter update.
            acc = accuracy_function(logits, labels)
            print(f'batch {batch_idx}, loss: {batch_loss}', f'accuracy: {acc.item()}')

    return running_total / len(train_loader)


def evaluate(model, test_loader, criterion):
    """
    Return the mean per-batch loss of ``model`` over ``test_loader``, computed without gradients.

    The model is switched to eval mode. Inputs and targets are moved to the
    module-level ``DEVICE``, and logits are permuted with ``permute(0, 2, 1)`` so
    classes sit on dim 1 before calling ``criterion``.
    """
    model.eval()
    with torch.no_grad():
        batch_losses = [
            criterion(model(inputs.to(DEVICE)).permute(0, 2, 1), labels.to(DEVICE)).item()
            for inputs, labels in test_loader
        ]
    return sum(batch_losses) / len(test_loader)


In [None]:
# Constants
# Device that train/eval batches are moved to.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
BATCH_SIZE = 512
# Learning rate for the Adam optimizer created in train_models.
LR = 0.01
# Epochs per run; train_models saves one model checkpoint per epoch.
N_EPOCHS = 200

# Training batches are reshuffled every epoch; test order stays fixed.
train_loader = torch.utils.data.DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_data, batch_size=BATCH_SIZE, shuffle=False)
# Mean cross-entropy; callers permute logits so class scores are on dim 1.
criterion = torch.nn.CrossEntropyLoss(reduction='mean')

In [None]:
def train_models(train_loader, test_loader, criterion, runs):
    """
    Train ``runs`` freshly initialised HookedTransformer models and checkpoint every epoch.

    Each run builds a new model from the module-level ``cfg`` and trains it with
    Adam (learning rate ``LR``) for ``N_EPOCHS`` epochs. After every epoch, the
    train/test losses are recorded and a deep copy of the model is appended to
    the checkpoint list. Checkpoints are therefore ordered run-major: index
    ``run * N_EPOCHS + epoch``.

    NOTE(review): the 'sgd' key and the "SGD" label in the log line are historical;
    the optimizer actually used is Adam.

    :return: ``(mean_train_losses, mean_test_losses, models_saved)``. The loss
        tensors have length ``N_EPOCHS`` (averaged over runs), and ``models_saved``
        holds ``runs * N_EPOCHS`` model copies.
    """
    history_shape = (runs, N_EPOCHS)
    train_history = torch.zeros(history_shape)
    test_history = torch.zeros(history_shape)
    checkpoints = []

    for run in tqdm(range(runs)):
        model = HookedTransformer(cfg)
        optimizer = torch.optim.Adam(model.parameters(), lr=LR)
        for epoch in tqdm(range(N_EPOCHS)):
            epoch_train = train_one_epoch(model, train_loader, optimizer, criterion, 'sgd')
            epoch_test = evaluate(model, test_loader, criterion)
            train_history[run, epoch] = epoch_train
            test_history[run, epoch] = epoch_test
            checkpoints.append(copy.deepcopy(model))
            print(
                f"Epoch {epoch+1}, Model {'sgd'.upper()} Train Loss: {epoch_train}, Test Loss: {epoch_test}"
            )

    return train_history.mean(dim=0), test_history.mean(dim=0), checkpoints

# Train `runs` independently initialised models. models_saved holds one checkpoint
# per epoch per run (runs * N_EPOCHS models in total).
runs = 5
train_losses_final, test_losses_final, models_saved = train_models(train_loader, test_loader, criterion, runs)

In [None]:
from devinterp.slt import estimate_learning_coeff_with_summary


def estimate_rlcts(models, train_loader, criterion, data_length, device, num_draws):
    """
    Estimate the local learning coefficient of each model with both SGNHT and SGLD.

    For every model, ``estimate_learning_coeff_with_summary`` is run once per sampler
    (one chain of ``num_draws`` draws, no burn-in), and its ``"llc/mean"`` summary
    is recorded.

    :param models: sequence of model checkpoints, in order
    :param data_length: unused; kept for call-site compatibility
    :param device: device passed through to the estimator
    :return: ``(estimates, losses)``. ``estimates`` maps ``"sgnht"``/``"sgld"`` to lists
        of per-model estimates. ``losses`` is the ``"loss/trace"`` of the SGLD chain
        run on the last model, or ``None`` if ``models`` is empty.
    """
    estimates = {"sgnht": [], "sgld": []}
    # Index of the final checkpoint, derived from the input rather than the global
    # N_EPOCHS, so slices of any length work.
    last_index = len(models) - 1
    losses = None
    for idx, model in enumerate(tqdm(models)):
        for method, optimizer_kwargs in [
            ("sgnht", {"lr": 1e-7, "diffusion_factor": 0.01}),
            ("sgld", {"lr": 1e-5, "localization": 100.0}),
        ]:
            results = estimate_learning_coeff_with_summary(
                model,
                train_loader,
                criterion=criterion,
                optimizer_kwargs=optimizer_kwargs,
                sampling_method=SGNHT if method == "sgnht" else SGLD,
                num_chains=1,
                num_draws=num_draws,
                num_burnin_steps=0,
                num_steps_bw_draws=1,
                device=device,
            )
            estimate = results["llc/mean"]

            # Keep the loss trace of the last model's chains for plotting. SGLD runs
            # second, so its trace is the one that remains.
            if idx == last_index:
                losses = results['loss/trace']
            estimates[method].append(estimate)
    return estimates, losses

def obtain_rlct_estimates(train_loader, models_saved, criterion, runs, num_draws=400):
    """
    Estimate per-epoch LLCs for every run and average them across runs.

    ``models_saved`` must be ordered run-major with ``N_EPOCHS`` checkpoints per
    run, as produced by ``train_models``.

    :param num_draws: sampler draws per chain (default 400)
    :return: ``(rlct_estimates_final, mean_last_chain_losses)``.
        ``rlct_estimates_final`` maps ``"sgnht"``/``"sgld"`` to length-``N_EPOCHS``
        tensors of run-averaged estimates. ``mean_last_chain_losses`` is the
        run-averaged loss trace returned by ``estimate_rlcts`` for each run's last
        checkpoint.
    """
    data_length = len(train_loader.dataset)
    rlct_estimates = {"sgnht": torch.zeros(runs, N_EPOCHS), "sgld": torch.zeros(runs, N_EPOCHS)}
    last_chain_losses = torch.zeros(runs, num_draws)

    for run in tqdm(range(runs)):
        run_models = models_saved[N_EPOCHS * run : N_EPOCHS * (run + 1)]
        rlct_estimate, losses = estimate_rlcts(
            run_models, train_loader, criterion, data_length, DEVICE, num_draws
        )
        # Both samplers are run for every checkpoint, so store both results rather
        # than returning an all-zero "sgnht" entry.
        for method in ("sgnht", "sgld"):
            rlct_estimates[method][run] = torch.tensor(rlct_estimate[method])
        last_chain_losses[run] = torch.tensor(losses)

    rlct_estimates_final = {method: values.mean(dim=0) for method, values in rlct_estimates.items()}
    return rlct_estimates_final, last_chain_losses.mean(dim=0)

# Estimate LLCs for every saved checkpoint. This is slow: one SGNHT chain and one
# SGLD chain are run per checkpoint.
rlct_estimates_final, last_chain_losses_final = obtain_rlct_estimates(train_loader, models_saved, criterion, runs)

In [None]:
def plot_losses(train_losses_final, test_losses_final, dataset):
    """
    Plot the run-averaged train and test loss per epoch and save the figure as a PNG.

    :param train_losses_final: one train loss per epoch
    :param test_losses_final: one test loss per epoch
    :param dataset: dataset id, used only in the file name
        ``losses_<dataset>_<N_EPOCHS>_epochs.png``
    """
    sns.set_style("whitegrid")

    fig, ax1 = plt.subplots(figsize=(10, 6))
    ax1.set_xlabel("Epoch")
    ax1.set_ylabel("Loss", color=PRIMARY)
    ax1.plot(train_losses_final, label="Train Loss, sgd", color=PRIMARY)
    ax1.plot(test_losses_final, label="Test Loss, sgd", color=PRIMARY_LIGHT)
    ax1.tick_params(axis="y", labelcolor=PRIMARY)
    ax1.legend(loc="upper left")
    fig.tight_layout()
    # Save before show(): a blocking show() only returns once the window is closed,
    # and pyplot unregisters the figure at that point.
    fig.savefig(f"losses_{dataset}_{N_EPOCHS}_epochs.png")
    plt.show()

def plot_rclts(rlct_estimates_final, dataset):
    """
    Plot the run-averaged SGLD local learning coefficient per epoch and save it as a PNG.

    :param rlct_estimates_final: mapping with at least an ``"sgld"`` entry holding
        one estimate per epoch
    :param dataset: dataset id, used only in the file name
        ``rclt_<dataset>_<N_EPOCHS>_epochs.png``
    """
    sns.set_style("whitegrid")

    fig, ax2 = plt.subplots(figsize=(10, 6))
    ax2.set_xlabel("Epoch")
    ax2.set_ylabel(r"Local Learning Coefficient, $\hat \lambda$", color=SECONDARY)
    # Only the SGLD curve is drawn; the SGNHT series is not plotted.
    ax2.plot(rlct_estimates_final["sgld"], label="SGLD, sgd", color=TERTIARY_LIGHT)
    ax2.tick_params(axis="y", labelcolor=SECONDARY)
    ax2.legend(loc="center right")

    fig.tight_layout()
    # Save before show(): a blocking show() only returns once the window is closed,
    # and pyplot unregisters the figure at that point.
    fig.savefig(f"rclt_{dataset}_{N_EPOCHS}_epochs.png")
    plt.show()

def plot_losses_chain(last_chain_losses_final, dataset):
    """
    Plot the sampler loss per draw for the final checkpoint's chain and save it as a PNG.

    :param last_chain_losses_final: one loss value per draw
    :param dataset: dataset id, used only in the file name
        ``last_chain_losses_<dataset>_<N_EPOCHS>_epochs.png``
    """
    sns.set_style("whitegrid")

    fig, ax1 = plt.subplots(figsize=(10, 6))
    ax1.set_xlabel("Draw")
    ax1.set_ylabel("Loss", color=PRIMARY)
    ax1.plot(last_chain_losses_final, label="Loss, sgd", color=PRIMARY)
    ax1.tick_params(axis="y", labelcolor=PRIMARY)
    ax1.legend(loc="upper left")
    fig.tight_layout()
    # Save before show(): a blocking show() only returns once the window is closed,
    # and pyplot unregisters the figure at that point.
    fig.savefig(f"last_chain_losses_{dataset}_{N_EPOCHS}_epochs.png")
    plt.show()

# Draw and save the loss curves, the LLC curve and the final-checkpoint chain
# loss trace for the currently selected `dataset`.
plot_losses(train_losses_final, test_losses_final, dataset)
plot_rclts(rlct_estimates_final, dataset)
plot_losses_chain(last_chain_losses_final, dataset)

In [None]:


def run_experiments(dataset=1, runs=5):
    """
    Run the full pipeline (data, training, LLC estimation, plots) for one dataset variant.

    :param dataset: dataset id passed to ``get_data`` (0, 1 or 2)
    :param runs: number of independently initialised models to train and average over
    """
    train_data, test_data = get_data(input_size, vocab_size, dataset)
    print(f"Train size: {len(train_data)}")
    print(f"Test size: {len(test_data)}")

    train_loader = torch.utils.data.DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True)
    test_loader = torch.utils.data.DataLoader(test_data, batch_size=BATCH_SIZE, shuffle=False)

    criterion = torch.nn.CrossEntropyLoss(reduction='mean')
    train_losses_final, test_losses_final, models_saved = train_models(train_loader, test_loader, criterion, runs)
    # obtain_rlct_estimates returns (estimates_dict, mean_last_chain_losses).
    # Unpack it so plot_rclts receives the dict, not the whole tuple.
    rlct_estimates_final, last_chain_losses_final = obtain_rlct_estimates(
        train_loader, models_saved, criterion, runs
    )

    plot_losses(train_losses_final, test_losses_final, dataset)
    plot_rclts(rlct_estimates_final, dataset)
    plot_losses_chain(last_chain_losses_final, dataset)

# Repeat the whole pipeline for each dataset variant (0, 1, 2). Each call trains
# several models for N_EPOCHS epochs and estimates LLCs for every checkpoint, so
# this loop is expensive.
for num in range(3):
    run_experiments(dataset=num)