# RLCT Estimation of Sorting

This Jupyter Notebook aims to measure the Real Log Canonical Threshold (RLCT) for a small 3-layer transformer model (~280,000 parameters) trained to sort sequences of 20 digits consisting of the numbers 0-19. It uses both Stochastic Gradient Nose-Hoover Thermostat (SGNHT) and Stochastic Gradient Langevin Dynamics (SGLD) as sampling methods.

## Main Steps:

1. **Data Preparation**: Generate the dataset of numbers to sort.
2. **Model Training**: Train a transformer model with the AdamW optimizer (labelled "sgd" in the logs and plots).
3. **Model Evaluation**: Evaluate the model's performance on a test set.
4. **RLCT Estimation**: Use SGNHT and SGLD samplers to estimate RLCT (the SGNHT configuration is currently commented out in the estimation cell, so only SGLD runs).
5. **Plotting**: Visualize train and test losses, and RLCT estimates.

In [1]:
# Install the notebook's third-party dependencies into the kernel's environment.
# NOTE(review): no versions are pinned, so a rerun may resolve different releases.
%pip install devinterp seaborn torchvision pickleshare wandb plotly einops
# Fetch the Entropy-SGD repository; EntropySGD is imported while the working
# directory is inside the checkout so that its `python` package is importable.
!git clone https://github.com/ucla-vision/entropy-sgd.git
%cd entropy-sgd
from python.optim import EntropySGD
# Return to the directory the notebook started in.
%cd ..

Defaulting to user installation because normal site-packages is not writeable
Collecting torch>=2.0.1
  Using cached torch-2.3.1-cp310-cp310-manylinux1_x86_64.whl (779.1 MB)
Collecting pandas>=1.5.3
  Using cached pandas-2.2.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (13.0 MB)
Collecting scipy>=1.10.1
  Using cached scipy-1.14.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (41.1 MB)
Collecting numpy>=1.23.5
  Using cached numpy-2.0.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (19.3 MB)
Collecting matplotlib>=3.7.5
  Using cached matplotlib-3.9.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (8.3 MB)
Collecting torchvision
  Using cached torchvision-0.18.1-cp310-cp310-manylinux1_x86_64.whl (7.0 MB)
Installing collected packages: numpy, scipy, pandas, torch, matplotlib, torchvision
[0m[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of 

In [2]:
import numpy as np
import torch as t
import torch
import torch.nn as nn
import torch.optim as optim
import time
import torch.nn.functional as F
import einops
import random
import helpers
from transformers import *
from dataclasses import dataclass
import os
import copy
import wandb
from tqdm.notebook import tqdm
import seaborn as sns
import matplotlib.pyplot as plt
from python.optim import EntropySGD
from torch.utils.data import DataLoader

from devinterp.optim.sgld import SGLD
from devinterp.optim.sgnht import SGNHT

# Named plot colours: the first four entries of seaborn's "muted" palette, and
# the first four entries of its "pastel" palette bound to the *_LIGHT names.
PRIMARY, SECONDARY, TERTIARY, QUATERNARY = sns.color_palette("muted")[:4]
PRIMARY_LIGHT, SECONDARY_LIGHT, TERTIARY_LIGHT, QUATERNARY_LIGHT = sns.color_palette(
    "pastel"
)[:4]


In [3]:
def accuracy_function(outputs, targets):
    """Return the fraction of rows whose last-position argmax equals the target.

    `outputs` is sliced as `outputs[:, -1]` and the argmax is taken over
    dimension 1 of that slice, so it needs at least three dimensions
    (presumably (batch, position, vocab) -- TODO confirm against the model).
    The result is a 0-dim float tensor.
    """
    final_position = outputs[:, -1]
    predicted = final_position.argmax(dim=1)
    is_correct = predicted == targets
    return is_correct.float().mean()

def do_a_training_step(model, train_data, test_data, optimizer, scheduler, epoch: int):
    """Take one optimisation step on `train_data`, then measure loss on `test_data`.

    Relies on the module-level `config` and on `full_loss`, neither of which is
    defined in the visible cells. The optimizer and the scheduler are each
    stepped exactly once. Every 100th epoch the log of both losses is printed.

    Returns:
        (train_loss, test_loss), both detached. The train loss is the value
        computed before the parameter update; the test loss is computed after
        it, in eval mode and without gradient tracking.
    """
    # --- optimisation step ---
    model.train()
    train_loss = full_loss(config=config, model=model, data=train_data)
    optimizer.zero_grad()
    train_loss.backward()
    optimizer.step()
    scheduler.step()

    # --- held-out loss ---
    model.eval()
    with torch.no_grad():
        test_loss = full_loss(config=config, model=model, data=test_data)

    step_result = (train_loss.detach(), test_loss.detach())

    if epoch % 100 != 0:
        return step_result

    # `t` is the torch module, so the printed numbers are log-losses.
    print(f'Epoch {epoch}, train loss {t.log(train_loss).item():.4f}, test loss {t.log(test_loss).item():.4f}')
    return step_result

def train_one_epoch(model, train_loader, optimizer, scheduler, criterion, model_key):
    """Run one optimisation pass over every batch in `train_loader`.

    Args:
        model: module to train; switched to train mode here.
        train_loader: sized iterable yielding (data, targets) batches.
        optimizer: stepped once per batch.
        scheduler: stepped once per batch, right after the optimizer.
        criterion: callable `criterion(outputs, targets)` returning a scalar loss.
        model_key: unused; kept so existing callers keep working.

    Returns:
        (mean_loss, mean_accuracy) averaged over batches. The loss is a Python
        float; the accuracy is a 0-dim tensor, because `accuracy_function`
        returns tensors and they are summed as-is.
    """
    model.train()
    train_loss = 0
    train_accuracy = 0
    for data, targets in train_loader:
        # Move each batch to the device once; the original copied `targets` twice.
        data = data.to(DEVICE)
        targets = targets.to(DEVICE)

        optimizer.zero_grad()
        outputs = model(data)
        loss = criterion(outputs, targets)
        train_loss += loss.item()
        train_accuracy += accuracy_function(outputs, targets)
        loss.backward()
        optimizer.step()
        scheduler.step()

    return train_loss / len(train_loader), train_accuracy / len(train_loader)


def evaluate(model, test_loader, criterion):
    """Compute mean loss and accuracy over `test_loader` without updating the model.

    Args:
        model: module to evaluate; switched to eval mode here.
        test_loader: sized iterable yielding (data, targets) batches.
        criterion: callable `criterion(outputs, targets)` returning a scalar loss.

    Returns:
        (mean_loss, mean_accuracy) averaged over batches. The loss is a Python
        float; the accuracy is a 0-dim tensor, because `accuracy_function`
        returns tensors and they are summed as-is.
    """
    model.eval()
    test_loss = 0
    test_accuracy = 0
    with torch.no_grad():
        for data, targets in test_loader:
            # Move each batch to the device once; the original copied `targets` twice.
            data = data.to(DEVICE)
            targets = targets.to(DEVICE)

            outputs = model(data)
            test_loss += criterion(outputs, targets).item()
            test_accuracy += accuracy_function(outputs, targets)

    return test_loss / len(test_loader), test_accuracy / len(test_loader)


In [4]:
# Constants
# Device that train_one_epoch/evaluate move batches to and that is handed to
# the RLCT sampler.
DEVICE = "cuda" if t.cuda.is_available() else "cpu"
# Batch size for both DataLoaders built below; the recorded output shows one
# batch per loader at this size.
BATCH_SIZE = 16384
# NOTE(review): LR is not referenced in the visible cells; the optimizer in
# train_models reads config.lr instead.
LR = 1e-4
# Epochs per training run; also sizes the per-epoch metric tensors and the
# per-run slices of saved models.
N_EPOCHS = 10000
# Config is not defined in the visible cells (presumably supplied by the
# star import above -- confirm).
config = Config()

def get_data(config : Config):
    """Build shuffled train/test splits over every ordered pair (i, j) in range(config.p).

    Each input is the triple (i, j, config.p) and its label is config.fn(i, j).
    The triples are shuffled with the global `random` module after seeding it
    with config.seed (this reseeds the process-wide RNG). The first
    int(config.frac_train * total) examples form the training split and the
    rest form the test split.

    Returns:
        (train_data, test_data): two lists of (input_tensor, label_tensor)
        tuples, both tensors of integer (long) dtype.
    """
    n = config.p
    triples = [(a, b, n) for a in range(n) for b in range(n)]

    random.seed(config.seed)
    random.shuffle(triples)
    n_train = int(config.frac_train * len(triples))

    targets = t.tensor([config.fn(a, b) for a, b, _ in triples]).long()
    inputs = t.tensor(triples).long()

    train_split = list(zip(inputs[:n_train], targets[:n_train]))
    test_split = list(zip(inputs[n_train:], targets[n_train:]))
    return train_split, test_split

# Build the datasets and wrap them in DataLoaders; only the training loader shuffles.
train_data, test_data = get_data(config = config)
train_loader = torch.utils.data.DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True, pin_memory=True)
test_loader = torch.utils.data.DataLoader(test_data, batch_size=BATCH_SIZE, shuffle=False, pin_memory=True)
# Loss function shared by training, evaluation and RLCT estimation.
criterion = helpers.cross_entropy_high_precision
# Sanity output: batches per loader, examples per split, CUDA availability.
print(len(train_loader))
print(len(test_loader))
print(len(train_data))
print(len(test_data))
print(torch.cuda.is_available())

1
1
3830
8939
True


In [5]:
def train_models(train_data, test_data, runs):
    """Train `runs` fresh models for N_EPOCHS epochs each and average their metrics.

    The original body ignored `train_data`/`test_data` and read the notebook
    globals `train_loader`/`test_loader`; the arguments are now what is used.

    Args:
        train_data: sized iterable of (inputs, targets) batches, forwarded to
            `train_one_epoch` (the notebook passes its training DataLoader).
        test_data: sized iterable of (inputs, targets) batches, forwarded to
            `evaluate` (the notebook passes its test DataLoader).
        runs: number of independent training runs.

    Returns:
        (train_losses, test_losses, train_accuracies, test_accuracies, models_saved).
        The first four are length-N_EPOCHS tensors averaged over runs.
        `models_saved` holds one deep copy of the model per epoch, run after
        run, so it has runs * N_EPOCHS entries.
    """
    train_losses = torch.zeros(runs, N_EPOCHS)
    test_losses = torch.zeros(runs, N_EPOCHS)
    train_accuracies = torch.zeros(runs, N_EPOCHS)
    test_accuracies = torch.zeros(runs, N_EPOCHS)
    models_saved = []
    for run in tqdm(range(runs)):
        model = Transformer(config, use_cache=False)
        # NOTE(review): the model goes to config.device while batches go to
        # DEVICE inside train_one_epoch/evaluate -- confirm the two agree.
        model.to(config.device)
        optimizer = optim.AdamW(model.parameters(), lr = config.lr, weight_decay=config.weight_decay, betas=(0.9, 0.98))
        # Linear learning-rate warm-up over the first 10 scheduler steps, then constant.
        scheduler = optim.lr_scheduler.LambdaLR(optimizer, lambda step: min(step/10, 1)) # TODO make this a config option
        for epoch in tqdm(range(N_EPOCHS)):
            train_loss, train_accuracy = train_one_epoch(
                model, train_data, optimizer, scheduler, criterion, 'sgd'
            )
            test_loss, test_accuracy = evaluate(model, test_data, criterion)
            train_losses[run, epoch] = train_loss
            test_losses[run, epoch] = test_loss
            train_accuracies[run, epoch] = train_accuracy
            test_accuracies[run, epoch] = test_accuracy
            # One snapshot per epoch: every copy stays alive until the caller
            # drops the list, so memory grows with runs * N_EPOCHS.
            models_saved += [copy.deepcopy(model)]
            if epoch % 100 == 0:
                print(
                    f"Epoch {epoch+1}, Model {'sgd'.upper()} Train Loss: {train_loss}, Test Loss: {test_loss}", '\n',
                    f"Epoch {epoch+1}, Model {'sgd'.upper()} Train Accuracy: {train_accuracy}, Test Accuracy: {test_accuracy}"
                )

    train_losses_final = train_losses.mean(dim=0)
    test_losses_final = test_losses.mean(dim=0)
    train_accuracies_final = train_accuracies.mean(dim=0)
    test_accuracies_final = test_accuracies.mean(dim=0)
    torch.cuda.empty_cache()

    return train_losses_final, test_losses_final, train_accuracies_final, test_accuracies_final, models_saved

# Release cached CUDA memory before the long training run.
torch.cuda.empty_cache()
# Number of independent training runs whose metrics are averaged.
runs = 1
# The DataLoaders (not the raw datasets) are what gets passed to train_models.
train_losses_final, test_losses_final, train_accuracies_final, test_accuracies_final, models_saved = train_models(train_loader, test_loader, runs)

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/10000 [00:00<?, ?it/s]

Epoch 1, Model SGD Train Loss: 4.767282485961914, Test Loss: 4.775090217590332 
 Epoch 1, Model SGD Train Accuracy: 0.007832897827029228, Test Accuracy: 0.007271507289260626
Epoch 101, Model SGD Train Loss: 2.3784656524658203, Test Loss: 7.692261695861816 
 Epoch 101, Model SGD Train Accuracy: 0.45091384649276733, Test Accuracy: 0.0111869340762496
Epoch 201, Model SGD Train Loss: 0.020632406696677208, Test Loss: 16.356964111328125 
 Epoch 201, Model SGD Train Accuracy: 1.0, Test Accuracy: 0.028414811939001083
Epoch 301, Model SGD Train Loss: 0.007399916648864746, Test Loss: 16.79149627685547 
 Epoch 301, Model SGD Train Accuracy: 1.0, Test Accuracy: 0.0314352847635746
Epoch 401, Model SGD Train Loss: 0.0024022452998906374, Test Loss: 17.727237701416016 
 Epoch 401, Model SGD Train Accuracy: 1.0, Test Accuracy: 0.034903235733509064
Epoch 501, Model SGD Train Loss: 0.0007990992162376642, Test Loss: 18.757173538208008 
 Epoch 501, Model SGD Train Accuracy: 1.0, Test Accuracy: 0.0394898764

In [None]:
from devinterp.slt import estimate_learning_coeff_with_summary

def estimate_rlcts(models, train_loader, criterion, data_length, device, num_draws):
    """Estimate the local learning coefficient of every model in `models`.

    Args:
        models: sequence of models (one per checkpoint); must support len().
        train_loader: data loader handed to the sampler.
        criterion: loss function handed to the sampler.
        data_length: unused; kept so existing callers keep working.
        device: device the sampler runs on.
        num_draws: number of draws per chain.

    Returns:
        (estimates, losses). `estimates` maps each method name to the list of
        "llc/mean" values, one per model; "sgnht" stays empty while that
        sampler is disabled below. `losses` is the "loss/trace" entry for the
        last model in `models`, or None when `models` is empty.
    """
    estimates = {"sgnht": [], "sgld": []}
    # The original compared idx against N_EPOCHS - 1, which left `losses`
    # unbound (UnboundLocalError at return) whenever fewer than N_EPOCHS
    # models were passed. Key on the real last index and default to None.
    losses = None
    last_idx = len(models) - 1
    for idx, model in enumerate(tqdm(models)):
        for method, optimizer_kwargs in [
            #("sgnht", {"lr": 1e-7, "diffusion_factor": 0.01}),
            ("sgld", {"lr": 1e-5, "localization": 100.0, "noise_level": 1.0}),
        ]:
            results = estimate_learning_coeff_with_summary(
                model,
                train_loader,
                criterion=criterion,
                optimizer_kwargs=optimizer_kwargs,
                sampling_method=SGNHT if method == "sgnht" else SGLD,
                num_chains=1,
                num_draws=num_draws,
                num_burnin_steps=0,
                num_steps_bw_draws=1,
                device=device,
                seed=42
            )
            estimates[method].append(results["llc/mean"])

            # Keep the sampler's loss trace of the final model for plotting.
            if idx == last_idx:
                losses = results['loss/trace']
    return estimates, losses

def obtain_rlct_estimates(train_loader, models_saved, criterion, runs):
    """Estimate RLCTs for every saved model and average them across runs.

    `models_saved` is read in consecutive slices of N_EPOCHS models, one slice
    per run. Only the "sgld" estimates are filled in; the "sgnht" tensor keeps
    its zero initialisation because that sampler's results are not stored.

    Returns:
        (rlct_estimates_final, mean_last_chain_losses): a dict mapping "sgnht"
        and "sgld" to length-N_EPOCHS tensors averaged over runs, and the
        run-averaged loss trace (400 draws) returned by `estimate_rlcts`.
    """
    num_draws = 400
    # Number of batches in the loader; forwarded as `data_length`.
    data_length = len(train_loader)

    per_run_estimates = {
        name: torch.zeros(runs, N_EPOCHS) for name in ("sgnht", "sgld")
    }
    per_run_losses = torch.zeros(runs, num_draws)

    for run in tqdm(range(runs)):
        first = N_EPOCHS * run
        run_models = models_saved[first : first + N_EPOCHS]
        run_estimate, run_losses = estimate_rlcts(
            run_models, train_loader, criterion, data_length, DEVICE, num_draws
        )
        per_run_estimates["sgld"][run] = torch.tensor(run_estimate["sgld"])
        per_run_losses[run] = torch.tensor(run_losses)

    averaged = {
        name: values.mean(dim=0) for name, values in per_run_estimates.items()
    }
    return averaged, per_run_losses.mean(dim=0)

# Run RLCT estimation over every saved checkpoint (one sampling chain per model).
rlct_estimates_final, last_chain_losses_final = obtain_rlct_estimates(train_loader, models_saved, criterion, runs)

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/10000 [00:00<?, ?it/s]



Chain 0:   0%|          | 0/400 [00:00<?, ?it/s][A[A

Chain 0:   3%|▎         | 12/400 [00:00<00:03, 114.13it/s][A[A

Chain 0:   6%|▋         | 25/400 [00:00<00:03, 118.69it/s][A[A

Chain 0:  10%|▉         | 38/400 [00:00<00:03, 120.17it/s][A[A

Chain 0:  13%|█▎        | 51/400 [00:00<00:02, 120.82it/s][A[A

Chain 0:  16%|█▌        | 64/400 [00:00<00:02, 121.24it/s][A[A

Chain 0:  19%|█▉        | 77/400 [00:00<00:02, 121.42it/s][A[A

Chain 0:  22%|██▎       | 90/400 [00:00<00:02, 121.48it/s][A[A

Chain 0:  26%|██▌       | 103/400 [00:00<00:02, 121.66it/s][A[A

Chain 0:  29%|██▉       | 116/400 [00:00<00:02, 121.74it/s][A[A

Chain 0:  32%|███▏      | 129/400 [00:01<00:02, 121.85it/s][A[A

Chain 0:  36%|███▌      | 142/400 [00:01<00:02, 121.97it/s][A[A

Chain 0:  39%|███▉      | 155/400 [00:01<00:02, 122.11it/s][A[A

Chain 0:  42%|████▏     | 168/400 [00:01<00:01, 122.19it/s][A[A

Chain 0:  45%|████▌     | 181/400 [00:01<00:01, 122.26it/s][A[A

Chain 0:  4

In [None]:
# Identifier embedded in the file names of the figures saved below.
dataset = 0

def plot_losses(train_losses_final, test_losses_final, dataset):
    """Plot train and test loss on a log-scaled y axis, show the figure, then save it.

    The PNG is written to the working directory as
    losses_<dataset>_<N_EPOCHS>_epochs.png.
    """
    sns.set_style("whitegrid")

    fig, ax1 = plt.subplots(figsize=(10, 6))
    ax1.set_yscale('log')
    ax1.set_xlabel("Epoch")
    ax1.set_ylabel("Loss", color=PRIMARY)

    curves = (
        (train_losses_final, "Train Loss, sgd", PRIMARY),
        (test_losses_final, "Test Loss, sgd", PRIMARY_LIGHT),
    )
    for values, label, colour in curves:
        ax1.plot(values, label=label, color=colour)

    ax1.tick_params(axis="y", labelcolor=PRIMARY)
    ax1.legend(loc="upper left")
    fig.tight_layout()
    plt.show()
    fig.savefig(f"losses_{dataset}_{N_EPOCHS}_epochs.png")

def plot_accuracies(train_accuracies_final, test_accuracies_final, dataset):
    """Plot train and test accuracy on a log-scaled y axis, show the figure, then save it.

    The PNG is written to the working directory as
    accuracies_<dataset>_<N_EPOCHS>_epochs.png.
    """
    sns.set_style("whitegrid")

    fig, ax1 = plt.subplots(figsize=(10, 6))
    ax1.set_yscale('log')
    ax1.set_xlabel("Epoch")
    ax1.set_ylabel("Accuracy", color=PRIMARY)

    curves = (
        (train_accuracies_final, "Train Accuracy, sgd", PRIMARY),
        (test_accuracies_final, "Test Accuracy, sgd", PRIMARY_LIGHT),
    )
    for values, label, colour in curves:
        ax1.plot(values, label=label, color=colour)

    ax1.tick_params(axis="y", labelcolor=PRIMARY)
    ax1.legend(loc="upper left")
    fig.tight_layout()
    plt.show()
    fig.savefig(f"accuracies_{dataset}_{N_EPOCHS}_epochs.png")

def plot_rclts(rlct_estimates_final, dataset):
    """Plot the SGLD learning-coefficient estimates, show the figure, then save it.

    Only the "sgld" entry of `rlct_estimates_final` is drawn. The PNG is
    written to the working directory as rclt_<dataset>_<N_EPOCHS>_epochs.png.
    """
    sns.set_style("whitegrid")

    fig, ax2 = plt.subplots(figsize=(10, 6))
    ax2.set_xlabel("Epoch")
    ax2.set_ylabel(r"Local Learning Coefficient, $\hat \lambda$", color=SECONDARY)

    sgld_estimates = rlct_estimates_final["sgld"]
    ax2.plot(sgld_estimates, label="SGLD, sgd", color=TERTIARY_LIGHT)

    ax2.tick_params(axis="y", labelcolor=SECONDARY)
    ax2.legend(loc="center right")

    fig.tight_layout()
    plt.show()
    fig.savefig(f"rclt_{dataset}_{N_EPOCHS}_epochs.png")

def plot_losses_chain(last_chain_losses_final, dataset):
    """Plot the sampler's loss per draw, show the figure, then save it.

    The PNG is written to the working directory as
    last_chain_losses_<dataset>_<N_EPOCHS>_epochs.png.
    """
    sns.set_style("whitegrid")

    fig, ax1 = plt.subplots(figsize=(10, 6))
    ax1.set_xlabel("Draw")
    ax1.set_ylabel("Loss", color=PRIMARY)

    ax1.plot(last_chain_losses_final, label="Loss, sgd", color=PRIMARY)

    ax1.tick_params(axis="y", labelcolor=PRIMARY)
    ax1.legend(loc="upper left")
    fig.tight_layout()
    plt.show()
    fig.savefig(f"last_chain_losses_{dataset}_{N_EPOCHS}_epochs.png")

# Render and save all four figures: losses, accuracies, RLCT estimates, and
# the loss trace of the final sampling chain.
plot_losses(train_losses_final, test_losses_final, dataset)
plot_accuracies(train_accuracies_final, test_accuracies_final, dataset)
plot_rclts(rlct_estimates_final, dataset)
plot_losses_chain(last_chain_losses_final, dataset)

In [5]:
# Trainer is not defined in the visible cells (presumably supplied by the star
# import above -- confirm).
# NOTE(review): this cell's execution count (In [5]) repeats an earlier one, so
# the notebook was not run top to bottom.
world = Trainer(config=config)

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mshaffreybenjamin[0m ([33muva-shaffrey-benjamin[0m). Use [1m`wandb login --relogin`[0m to force relogin


training length =  3830
testing length =  8939


In [10]:
# Show the first ten recorded training losses; the saved output is an empty
# list, i.e. nothing had been recorded on `world` when this cell ran.
print(world.train_losses[ : 10])

[]
