# SAE train

In [1]:
import math
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from tqdm import tqdm
from scipy import stats
import pandas as pd
import plotly.express as px
import lightning as L
import json
from pathlib import Path
from lightning.pytorch.callbacks import ModelCheckpoint
from datetime import timedelta

## Data and hooks

In [2]:
class IntermediateStateDataset(torch.utils.data.Dataset):
    def __init__(self, paths):
        from pathlib import Path

        self.path_names = list(sorted(paths))
        self.nps = [np.load(each, mmap_mode="r") for each in self.path_names]
        self.sizes = []
        count = 0
        for each in self.nps:
            self.sizes.append(count)
            count += each.shape[0]
        self.sizes = np.array(self.sizes)
        self.total_size = count

    def __len__(self):
        return self.total_size

    def _get(self, idx):
        bucket_idx = (idx >= self.sizes).sum() - 1
        remainder = idx - self.sizes[bucket_idx]
        return self.nps[bucket_idx][remainder]

    def __getitem__(self, idx):
        if isinstance(idx, int):
            if idx < 0:
                idx = len(self) + idx
            return self._get(idx)
        elif isinstance(idx, slice):
            start = idx.start or 0
            stop = idx.stop or len(self)
            step = idx.step or 1
            result = []
            for iidx in range(start, stop, step):
                result.append(self._get(iidx))
            return np.stack(result, axis=0)

In [3]:
paths = list(sorted(Path("/data/mech/data/layers/transformer.h.10").glob("*.npy")))
train_paths = paths[:-10]
test_paths = paths[-10:-3]
val_paths = paths[-3:]

train_dataset = IntermediateStateDataset(train_paths)
test_dataset = IntermediateStateDataset(test_paths)
val_dataset = IntermediateStateDataset(val_paths)

print(len(test_dataset), len(train_dataset), len(val_dataset))

6857125 128078235 2923962


In [4]:
ckpt_callback = ModelCheckpoint(train_time_interval=timedelta(minutes=30))

## Force SAE to be sparsed

### A dumb SAE

This SAE even uses Sigmoid function.

In [5]:
class SAEDumb(L.LightningModule):
    def __init__(self, input_size, hidden_size):
        super().__init__()
        self.save_hyperparameters()
        self.encoder = nn.Sequential(
            nn.Linear(input_size, hidden_size),
            nn.ReLU(True)
        )
        self.decoder = nn.Sequential(
            nn.Linear(hidden_size, input_size),
            nn.Sigmoid()
        )
        self.loss = nn.MSELoss()
        self.save_hyperparameters()

    def forward(self, x):
        hidden = self.encoder(x)
        x = self.decoder(hidden)
        return x

    def training_step(self, batch, batch_nb):
        output = self.forward(batch[0])
        loss = self.loss(batch[0], output)
        return loss

    def configure_optimizers(self):
        return optim.Adam(self.parameters(), lr=1e-4, weight_decay=1e-5)

    def sanity_check(self, input_):
        with torch.no_grad():
            output = self(to_test)
        return 1 - torch.mean((saedumb_output - to_test) ** 2) / to_test.var()

    def act(self, input_):
        with torch.no_grad():
            act = self.encoder(input_)
        return input_

    def save(self, path):
        torch.save(self.state_dict(), path)

    def check(self, to_test):
        with torch.no_grad():
            sae_output = self(to_test)
            with torch.no_grad():
                act = self.encode(to_test)
            print("Reconstruction capability:", 1 - torch.mean((sae_output - to_test) ** 2) / to_test.var())
            print("Number of activated:", (act > 0).sum())
            print("Percentage of activated:", (act > 0).sum() / act.numel())

    def active_feature_statistics(self, dataloader):
        self.cuda()
        with torch.no_grad():
            total = torch.zeros(self.hparams.hidden_size).cuda()
            for batch in dataloader:
                act = self.encode(batch.cuda())
                total += (act > 0).sum(dim=0)
        self.cpu()
        total = total.detach().cpu().numpy().squeeze()
        print("Quantiles:", np.quantile(total, [0.01, 0.02, 0.05, 0.1, 0.5, 0.9, 0.95, 0.98, 0.99]))
        print("Mean:", np.mean(total))
        return total


#### Experiment

In [6]:
saedumb = SAEDumb(768, 3000)
trainer = L.Trainer(accelerator="gpu")

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs


In [5]:
saedumb = SAEDumb(768, 3000)
saedumb.load_state_dict(torch.load("/data/mech/data/ckpts/saedumb.pth"))

  saedumb.load_state_dict(torch.load("/data/mech/data/ckpts/saedumb.pth"))


<All keys matched successfully>

In [None]:
trainer.fit(saedumb, train_dataloaders=[torch.utils.data.DataLoader(train_dataset, batch_size=256)], val_dataloaders=[torch.utils.data.DataLoader(val_dataset, batch_size=513)])

In [27]:
torch.save(saedumb.state_dict(), "/data/mech/data/ckpts/saedumb.pth")

### Remove the final sigmoid activation

| xx | Total features | % Activated |
| -- | -------------- | ----------- |
| 32x | 24576 | 0.1105 |

Even though there are only 11% of the activation are activated. It seems that this number is still quite high, as there are around 2000 activated features. On the other hand, the work from Anthorpic only has around 14, 15 activated features. We should induce ways to enforce sparsity.

Also, the number of activated features seems to be relatively constant. Previously when I try a dumb experiment with 8888 features, around 3000 features are always activated. Should do more experiment on xx.

- [ ] Test with 1x, 2x, 4x, 8x, 16x, 32x, 64x, 128x, 256x
- [x] Add L1 norm to induce sparsity

In [6]:
class SAE(SAEDumb):
    def __init__(self, input_size, hidden_size):
        super().__init__(input_size, hidden_size)
        self.encoder = nn.Linear(in_features=input_size, out_features=hidden_size, bias=True)
        self.thresh = nn.Parameter(torch.zeros(hidden_size), requires_grad=True)
        self.decoder = nn.Linear(in_features=hidden_size, out_features=input_size, bias=True)
        self.save_hyperparameters()

    def encode(self, x):
        y = self.encoder(x)
        mask = (y > self.thresh)
        y = mask * nn.functional.relu(y)
        return y

    def decode(self, x):
        y = self.decoder(x)
        return y

    def forward(self, x):
        y = self.encode(x)
        y = self.decode(y)
        return y

    def act(self, input_):
        with torch.no_grad():
            act = self.encode(input_)
        return act

#### Experiment

In [8]:
sae = SAE(768, 24576)
trainer = L.Trainer(accelerator="gpu")

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
/home/john/miniconda3/envs/dawnet/lib/python3.10/site-packages/lightning/pytorch/trainer/connectors/logger_connector/logger_connector.py:75: Starting from v1.9.0, `tensorboardX` has been removed as a dependency of the `lightning.pytorch` package, due to potential conflicts with other packages in the ML ecosystem. For this reason, `logger=True` will use `CSVLogger` as the default logger, unless the `tensorboard` or `tensorboardX` packages are found. Please `pip install lightning[extra]` or one of them to enable TensorBoard support by default


In [None]:
trainer.fit(sae, train_dataloaders=[torch.utils.data.DataLoader(train_dataset, batch_size=256)], val_dataloaders=[torch.utils.data.DataLoader(val_dataset, batch_size=512)])

/home/john/miniconda3/envs/dawnet/lib/python3.10/site-packages/lightning/pytorch/loops/utilities.py:72: `max_epochs` was not set. Setting it to 1000 epochs. To train without an epoch limit, set `max_epochs=-1`.
/home/john/miniconda3/envs/dawnet/lib/python3.10/site-packages/lightning/pytorch/trainer/configuration_validator.py:68: You passed in a `val_dataloader` but have no `validation_step`. Skipping val loop.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name         | Type    | Params | Mode 
-------------------------------------------------
0 | encoder      | Linear  | 18.9 M | train
1 | decoder      | Linear  | 18.9 M | train
2 | loss         | MSELoss | 0      | train
  | other params | n/a     | 24.6 K | n/a  
-------------------------------------------------
37.8 M    Trainable params
0         Non-trainable params
37.8 M    Total params
151.195   Total estimated model params size (MB)
3         Modules in train mode
0         Modules in eval mode
/home/john/miniconda3/envs/daw

Training: |                                       | 0/? [00:00<?, ?it/s]

  return collate([torch.as_tensor(b) for b in batch], collate_fn_map=collate_fn_map)


In [10]:
torch.save(sae.state_dict(), "/data/mech/data/ckpts/sae.pth")

In [27]:
to_test = torch.Tensor(test_dataset[23]).unsqueeze(0)
with torch.no_grad():
    # saedumb_output = saedumb(to_test)
    # print(1 - torch.mean((saedumb_output - to_test) ** 2) / to_test.var())
    sae_output = sae(to_test)
    with torch.no_grad():
        act = sae.encode(to_test)
    print("Reconstruction capability:", 1 - torch.mean((sae_output - to_test) ** 2) / to_test.var())
    print("Number of activated:", (act > 0).sum())
    print("Percentage of activated:", (act > 0).sum() / act.shape[1])

Reconstruction capability: tensor(0.9999)
Number of activated: tensor(2715)
Percentage of activated: tensor(0.1105)


### L1 norm

Failed experiment:
- lambda = 1e-3. The end result is the final activations are pushed toward 0.
- lambda = 1e-5. It's better with 688 activated features.

Todo:

- [x] Fix the norm implementation
- [x] Test lambda 1e-5. --> Does help a lot
- [x] Test lambda 1e-4.
- [x] Retry with lambda 1e-4. Because there are a lot of dead weights. We would want to confirm if those dead weights are universal -> Similar statistics.
- [ ] Retest with lambda 1e-5.
- [ ] Test with 64x.
- [ ] Revive dead weights for every 100000 inactive instances.

**Test lambda 1e-5**: Does reduce the number of activated features from 2000 to 600. There still a long way to go to reduce the number of features to 100.

#### Thought

- We need to keep track of inactive features. The reason might purely because of unlucky initialization. Approaches:
  - Reset the weights of unlucky features.
  - Use different activation features so that a feature has much lower chance of being dead.

A randomly-initialized weights show that:

- Around 50% of the features are 0 for a test instance
- 0 features are always 0 for all test data in dataset
---> So it seems 

#### Retry with lambda 1e-4

The statistics is similar:

- 10099 dead features vs 9975 dead features.
- Quantiles:
  - 0.0000000e+00 0.0000000e+00 0.0000000e+00 0.0000000e+00 2.3000000e+01 4.4500000e+02 6.8000540e+06 6.8324995e+06 6.8368465e+06
  - 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00 2.40000000e+01 6.79720100e+06 6.82264975e+06 6.82933450e+06 6.83185225e+06

The mean is different: 451596.62 vs 882056.44. The 2nd run have more activated features than the 1st run. Maybe if we keep the training running, the feature will just keep becoming dead?

In [17]:
class SAEWithL1(SAE):
    def __init__(self, input_size, hidden_size, lmd, dead_feature_refresh_rate=2000):
        super().__init__(input_size, hidden_size)
        self.lmd = lmd
        self.save_hyperparameters()
        self.register_buffer("counter", torch.zeros(hidden_size))
        self.dfrr = dead_feature_refresh_rate

    def training_step(self, batch, batch_nb):
        act, output = self.forward(batch[0])
        if batch_nb % 20 == 0:
            # sample every 20 iterations
            self.counter += act.sum(dim=0)
        loss = self.loss(batch[0], output)
        reg = torch.norm(act, 1)
        total_loss = loss + self.lmd * reg
        self.log("loss", loss, on_step=True, on_epoch=False, prog_bar=False, logger=True)
        self.log("reg", reg, on_step=True, on_epoch=False, prog_bar=False, logger=True)
        self.log("total_loss", total_loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)
        return loss

    def on_train_batch_end(self, outputs, batch, batch_idx):
        if batch_idx % self.dfrr == 0 and batch_idx >= self.dfrr:
            self.revive_dead_features()

    def forward(self, x):
        z = self.encode(x)
        y = self.decode(z)
        return z, y

    def check(self, to_test):
        with torch.no_grad():
            sae_output = self(to_test)[-1]
            with torch.no_grad():
                act = self.encode(to_test)
            print("Reconstruction capability:", 1 - torch.mean((sae_output - to_test) ** 2) / to_test.var())
            print("Number of activated:", (act > 0).sum())
            print("Percentage of activated:", (act > 0).sum() / act.shape[1])

    def revive_dead_features(self):
        """Randomly changing the weights to avoid dead features"""
        with torch.no_grad():
            idxs = (self.counter == 0).nonzero()
            nn.init.kaiming_uniform_(self.encoder.weight[idxs], a=math.sqrt(5))
            self.counter = torch.zeros(self.hparams.hidden_size, device=self.device)

#### Experiment with lambda 1e-4

Some features are always active. Some features never. It seems the problem comes from weight initialization, such that some features become inactive almost always.

In [9]:
saewithl1_2 = SAEWithL1(768, 24576, 1e-4)
# original_weights = saewithl1_2.state_dict()
# torch.save(original_weights, "/data/mech/data/ckpts/temporaries/lambda_1e-4_beginning_run2.pth")
trainer = L.Trainer(accelerator="gpu", callbacks=[ckpt_callback])
trainer.fit(
    saewithl1_2,
    train_dataloaders=[torch.utils.data.DataLoader(train_dataset, batch_size=256)],
    val_dataloaders=[torch.utils.data.DataLoader(val_dataset, batch_size=512)],
)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
/home/john/miniconda3/envs/dawnet/lib/python3.10/site-packages/lightning/pytorch/loops/utilities.py:72: `max_epochs` was not set. Setting it to 1000 epochs. To train without an epoch limit, set `max_epochs=-1`.
/home/john/miniconda3/envs/dawnet/lib/python3.10/site-packages/lightning/pytorch/trainer/configuration_validator.py:68: You passed in a `val_dataloader` but have no `validation_step`. Skipping val loop.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name         | Type    | Params | Mode 
-------------------------------------------------
0 | encoder      | Linear  | 18.9 M | train
1 | decoder      | Linear  | 18.9 M | train
2 | loss         | MSELoss | 0      | train
  | other params | n/a     | 24.6 K | n/a  
-------------------------------------------------
37.8 M    Trainable params
0         Non-trainable params
37.8 M    Total params
151.195   Total estimated 

Training: |                                       | 0/? [00:00<?, ?it/s]

  return collate([torch.as_tensor(b) for b in batch], collate_fn_map=collate_fn_map)

Detected KeyboardInterrupt, attempting graceful shutdown ...


NameError: name 'exit' is not defined

In [10]:
saewithl1_2.check(torch.Tensor(test_dataset[0]).unsqueeze(0))

Reconstruction capability: tensor(0.9997)
Number of activated: tensor(3082)
Percentage of activated: tensor(0.1254)


#### Examine the weights

There are indeed a lot of dead features.

In [17]:
sample = torch.Tensor(test_dataset[0]).unsqueeze(dim=0)
with torch.no_grad():
    act = saewithl1_2.encode(sample)
    act = act.cpu().numpy()

print(act[0, 2])

In [28]:
encoder = saewithl1_2.encoder
print(encoder.in_features, encoder.out_features)
print(encoder.weight[0].sum())
print(encoder.weight[1].sum())
print(encoder.weight[2].sum())
print(encoder.weight[3].sum())
print(encoder.weight[1923].sum())

768 24576
tensor(-0.0057, grad_fn=<SumBackward0>)
tensor(-2.2566e-10, grad_fn=<SumBackward0>)
tensor(-1.1012e-22, grad_fn=<SumBackward0>)
tensor(-0.0673, grad_fn=<SumBackward0>)
tensor(-0.0022, grad_fn=<SumBackward0>)


#### Train with reviving

In [None]:
saewithl1 = SAEWithL1(768, 24576, 1e-4, dead_feature_refresh_rate=5000)
# original_weights = saewithl1_2.state_dict()
# torch.save(original_weights, "/data/mech/data/ckpts/temporaries/lambda_1e-4_beginning_run2.pth")
trainer = L.Trainer(accelerator="gpu", callbacks=[ckpt_callback])
trainer.fit(
    saewithl1,
    train_dataloaders=[torch.utils.data.DataLoader(train_dataset, batch_size=256)],
    val_dataloaders=[torch.utils.data.DataLoader(val_dataset, batch_size=512)],
)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name         | Type    | Params | Mode 
-------------------------------------------------
0 | encoder      | Linear  | 18.9 M | train
1 | decoder      | Linear  | 18.9 M | train
2 | loss         | MSELoss | 0      | train
  | other params | n/a     | 24.6 K | n/a  
-------------------------------------------------
37.8 M    Trainable params
0         Non-trainable params
37.8 M    Total params
151.195   Total estimated model params size (MB)
3         Modules in train mode
0         Modules in eval mode


Training: |                                       | 0/? [00:00<?, ?it/s]