# Performance statistics for the model and the SAEs

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt

from model import model_data, TransformerModel, generate_from_model, validation
from SAE import config_default, TransformerWithSAE

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = 'cuda'
else:
    DEVICE = 'cpu'

sizes of train, val, test = 1003862, 55778, 55778
vocab size = 66, unique chars:
['\n', ' ', '!', '"', '$', '&', "'", ',', '-', '.', '3', ':', ';', '?', 'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z', 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z']


In [2]:
# Baseline transformer (no SAE attached).
# map_location=DEVICE remaps the checkpoint tensors to the active device, so a
# checkpoint saved on a GPU still loads when DEVICE has fallen back to 'cpu'.
model = TransformerModel()
model.to(DEVICE)
model.load_state_dict(torch.load("model.10000.pth", map_location=DEVICE))


def _load_sae_model(layer, checkpoint):
    """Build a TransformerWithSAE hooked at `layer` and load its weights from `checkpoint`.

    The steps run in the same order as the hand-written version: set
    `layer_for_SAE`, move to DEVICE, load the state dict, then call
    `update_scale_factor()`.
    """
    sae_model = TransformerWithSAE()
    sae_model.layer_for_SAE = layer
    sae_model.to(DEVICE)
    sae_model.load_state_dict(torch.load(checkpoint, map_location=DEVICE))
    sae_model.update_scale_factor()
    return sae_model


modelL3 = _load_sae_model(3, "SAE_L3.10000.pth")
modelL6 = _load_sae_model(6, "SAE_L6.10000.pth")

In [37]:
@torch.no_grad()
def test_performance(model, SAE = False):
    model.eval()
    batch_size = 1024
    out = ''
    # corss entropy loss
    x, y = model_data.draw(batch_size,'test')
    res = model(x, y)
    logits = res[0]
    loss = res[-1].item()
    out += f"loss={loss:.4f}"
    # accuracy
    probs = F.softmax(logits, dim=-1).view(-1,logits.shape[-1])
    y_pred = torch.multinomial(probs, num_samples=1)[:,0]
    acc = torch.mean(y.view(-1)==y_pred, dtype=float)
    out += f";    accuracy={acc:.4f}"
    # SAE L2 loss
    if SAE:
        model.lam = 0 # only output L2 loss
        features, loss = model(x, y, SAE_loss=True)
        loss = loss.item()
        out += f";    SAE relative L2 loss={loss/x.shape[-1]:.4f}" #has been normalized to ||x||^2 = n_model
        n_dead = torch.sum(torch.mean(features, dim=(0,1))>0, dtype=int)
        out += f";    {features.shape[-1]-n_dead} dead features"
    print(out)

In [38]:
# Fix the torch RNG so the sampled accuracy below is reproducible
# (presumably the batch drawn by model_data is too -- confirm it uses torch's RNG).
# The seed is set once, so the three calls consume the RNG stream in this order;
# reordering them changes the individual numbers.
torch.manual_seed(42)
test_performance(model)               # baseline transformer
test_performance(modelL3, SAE=True)   # model with layer_for_SAE = 3
test_performance(modelL6, SAE=True)   # model with layer_for_SAE = 6

loss=1.6420;    accuracy=0.4525
loss=1.7451;    accuracy=0.4301;    SAE relative L2 loss=0.0217;    1017 dead features
loss=1.7840;    accuracy=0.4320;    SAE relative L2 loss=0.0980;    2 dead features


In [40]:
# Ratio of the baseline model's test loss (1.6420) to the loss reported for the
# SAE models; the numbers are copied by hand from the output of the
# evaluation cell above and must be updated if that cell is re-run.
# modelL3:
#1.6420/1.7451
# modelL6:
1.6420/1.7840

0.9204035874439461

In [28]:
# Seed the torch RNG so the generated sample is reproducible.
torch.manual_seed(42)
# Assigning to `_` keeps the return value from being echoed as the cell output;
# generate_from_model appears to print the text itself -- see the output below.
_ = generate_from_model(model)

[37m your duty throughly, I advise you:
Imagine 'twere the right Vincentio.

BIONDELLO:
Tut, fear not me.

TRANIO:
But hast thou done thy errand to Baptista?

BIONDELLO:
I told him that your father was at Venice,
And that you look'd for him this day in Padua.
[0mI'll thee talk, am friends that stand that will the
He left, he waked his brother's pardon, is the sback.

JULIET:
This blood England,
And being madam.

YORK:
Or love boy:
O come, give me,
Those against overboke. Come, mulder will with the love;
If we do king up, I'll we done it you is it.

QUEEN ELIZABETH:
The do's a crown?
Give me impen tribunes, you leave you;
I have shaw you made in the come.

GLOUCESTER:
And more him Clifford; will I tell you to me.

GLOUCESTER:

LADY ANNE:
Why, call my lord.

GLOUCEST
