# In [3]:
import math

import lightning as L
import numpy as np
import torch
import torch.nn as nn


class SAEDumb(L.LightningModule):
    """Plain autoencoder trained with an MSE reconstruction loss.

    The encoder is ``Linear -> ReLU`` and the decoder is ``Linear -> Sigmoid``.

    Args:
        input_size: Size of the feature dimension of the inputs.
        hidden_size: Number of hidden units produced by the encoder.
    """

    def __init__(self, input_size, hidden_size):
        super().__init__()
        # One call is enough; the original repeated it at the end of __init__.
        self.save_hyperparameters()
        self.encoder = nn.Sequential(
            nn.Linear(input_size, hidden_size),
            nn.ReLU(True),
        )
        self.decoder = nn.Sequential(
            nn.Linear(hidden_size, input_size),
            nn.Sigmoid(),
        )
        self.loss = nn.MSELoss()

    def encode(self, x):
        """Return the hidden activations for ``x``.

        Defined here because ``check`` and ``active_feature_statistics`` call
        ``self.encode``; subclasses override it with their own encoder.
        """
        return self.encoder(x)

    def forward(self, x):
        """Encode ``x`` and return its reconstruction."""
        return self.decoder(self.encode(x))

    def training_step(self, batch, batch_nb):
        """Return the MSE between the first element of ``batch`` and its reconstruction."""
        output = self(batch[0])
        return self.loss(batch[0], output)

    def configure_optimizers(self):
        """Return an Adam optimizer over all parameters (lr=1e-4, weight_decay=1e-5)."""
        return torch.optim.Adam(self.parameters(), lr=1e-4, weight_decay=1e-5)

    def sanity_check(self, input_):
        """Return ``1 - MSE(reconstruction, input_) / input_.var()``.

        The original body referenced undefined names (``to_test`` and
        ``saedumb_output``); it now uses ``input_`` and the computed output.
        """
        with torch.no_grad():
            output = self(input_)
        # Subclasses whose forward returns (activations, reconstruction)
        # keep the reconstruction last.
        if isinstance(output, tuple):
            output = output[-1]
        return 1 - torch.mean((output - input_) ** 2) / input_.var()

    def act(self, input_):
        """Return the hidden activations for ``input_`` without tracking gradients."""
        with torch.no_grad():
            # The original computed the activations and then returned the input.
            return self.encode(input_)

    def save(self, path):
        """Write this module's ``state_dict`` to ``path`` with ``torch.save``."""
        torch.save(self.state_dict(), path)

    def check(self, to_test):
        """Print the reconstruction score and activation counts for ``to_test``."""
        with torch.no_grad():
            sae_output = self(to_test)
            act = self.encode(to_test)
            print("Reconstruction capability:", 1 - torch.mean((sae_output - to_test) ** 2) / to_test.var())
            print("Number of activated:", (act > 0).sum())
            print("Percentage of activated:", (act > 0).sum() / act.numel())

    def active_feature_statistics(self, dataloader):
        """Count, per hidden unit, how often it is positive over ``dataloader``.

        The module is moved to CUDA when available (CPU otherwise) for the
        pass and is left on the CPU afterwards. Summary statistics are printed.

        Args:
            dataloader: Iterable of batches; each batch is either a tensor or
                a list/tuple whose first element is the input tensor (the
                layout ``training_step`` reads).

        Returns:
            NumPy array of per-unit activation counts.
        """
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.to(device)
        with torch.no_grad():
            total = torch.zeros(self.hparams.hidden_size, device=device)
            for batch in dataloader:
                inputs = batch[0] if isinstance(batch, (list, tuple)) else batch
                act = self.encode(inputs.to(device))
                total += (act > 0).sum(dim=0)
        self.cpu()
        total = total.detach().cpu().numpy().squeeze()
        print("Total number of dead neurons:", (total == 0).sum())
        print("Percentage of dead neurons:", (total == 0).sum() / total.shape[0])
        print("Quantiles:", np.quantile(total, [0.01, 0.02, 0.05, 0.1, 0.5, 0.9, 0.95, 0.98, 0.99]))
        print("Mean:", np.mean(total))
        return total


class SAE(SAEDumb):
    """Autoencoder with bare linear encoder/decoder and a per-unit threshold gate.

    Args:
        input_size: Size of the feature dimension of the inputs.
        hidden_size: Number of hidden units.
    """

    def __init__(self, input_size, hidden_size):
        super().__init__(input_size, hidden_size)
        # Swap the parent's Sequential stacks for plain linear layers.
        self.encoder = nn.Linear(input_size, hidden_size)
        self.thresh = nn.Parameter(torch.zeros(hidden_size))
        self.decoder = nn.Linear(hidden_size, input_size)
        self.save_hyperparameters()

    def encode(self, x):
        """Return ReLU pre-activations, zeroed wherever they do not exceed ``thresh``."""
        pre = self.encoder(x)
        # The comparison yields a boolean mask, so no gradient reaches
        # ``thresh`` through this expression.
        gate = pre > self.thresh
        return gate * torch.relu(pre)

    def decode(self, x):
        """Map hidden activations back to input space."""
        return self.decoder(x)

    def forward(self, x):
        """Return the reconstruction of ``x``."""
        return self.decode(self.encode(x))

    def act(self, input_):
        """Return the hidden activations for ``input_`` without tracking gradients."""
        with torch.no_grad():
            hidden = self.encode(input_)
        return hidden


class SAEWithL1(SAE):
    """``SAE`` trained with an added L1 penalty on the hidden activations.

    Args:
        input_size: Size of the feature dimension of the inputs.
        hidden_size: Number of hidden units.
        lmd: Weight of the L1 penalty in the total loss.
        dead_feature_refresh_rate: If non-zero, dead features are
            re-initialised every this many batches (by batch index).
    """

    def __init__(self, input_size, hidden_size, lmd, dead_feature_refresh_rate=0):
        super().__init__(input_size, hidden_size)
        self.lmd = lmd
        self.save_hyperparameters()
        # Accumulated activation mass per hidden unit since the last refresh.
        self.register_buffer("counter", torch.zeros(hidden_size))
        self.dfrr = dead_feature_refresh_rate

    def training_step(self, batch, batch_nb):
        """Return MSE + ``lmd`` * L1(activations) for the first element of ``batch``."""
        act, output = self(batch[0])
        if self.dfrr and (batch_nb % 20 == 0):
            # Sample every 20 iterations. Detach first: adding a tensor that
            # requires grad in place would attach the buffer to the autograd
            # graph and keep that graph alive between refreshes.
            self.counter += act.detach().sum(dim=0)
        loss = self.loss(batch[0], output)
        reg = torch.norm(act, 1)
        total_loss = loss + self.lmd * reg
        self.log("loss", loss, on_step=True, on_epoch=False, prog_bar=False, logger=True)
        self.log("reg", reg, on_step=True, on_epoch=False, prog_bar=False, logger=True)
        self.log("total_loss", total_loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)
        return total_loss

    def on_train_batch_end(self, outputs, batch, batch_idx):
        """Trigger a dead-feature refresh on every ``dfrr``-th batch index (skipping 0)."""
        if self.dfrr and (batch_idx % self.dfrr == 0) and (batch_idx >= self.dfrr):
            self.revive_dead_features()

    def forward(self, x):
        """Return ``(activations, reconstruction)`` for ``x``."""
        z = self.encode(x)
        y = self.decode(z)
        return z, y

    def check(self, to_test):
        """Print the reconstruction score and activation counts for ``to_test``.

        Moves the module and ``to_test`` to CUDA when available, CPU otherwise.
        """
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.to(device)
        to_test = to_test.to(device)
        with torch.no_grad():
            # forward already returns the activations, so the encoder runs once.
            act, sae_output = self(to_test)
            print("Reconstruction capability:", 1 - torch.mean((sae_output - to_test) ** 2) / to_test.var())
            print("Number of activated:", (act > 0).sum())
            print("Percentage of activated:", (act > 0).sum() / act.numel())

    def revive_dead_features(self):
        """Re-initialise encoder rows whose counter is zero, then reset the counter."""
        with torch.no_grad():
            dead = self.counter == 0
            if dead.any():
                # A boolean mask keeps the slice 2-D (k, input_size), so
                # kaiming_uniform_ derives its fan-in from input_size.
                fresh = torch.empty_like(self.encoder.weight[dead])
                nn.init.kaiming_uniform_(fresh, a=math.sqrt(5))
                self.encoder.weight[dead] = fresh
            self.counter.zero_()

class L2DecoderWeight(L.LightningModule):
    """Autoencoder with explicit weight matrices and a decoder-norm-weighted L1 penalty.

    Each hidden activation is scaled by the L2 norm of its decoder row before
    the L1 penalty is taken.

    Args:
        input_size: Size of the feature dimension of the inputs.
        hidden_size: Number of hidden units.
        lmd: Weight of the sparsity penalty in the total loss.
        dead_feature_refresh_rate: If non-zero, dead features are
            re-initialised every this many batches (by batch index).
    """

    def __init__(self, input_size, hidden_size, lmd, dead_feature_refresh_rate=0):
        super().__init__()
        self.loss = nn.MSELoss()

        self.lmd = lmd
        # Number of sampled activations (> 0) per hidden unit since the last reset.
        self.register_buffer("counter", torch.zeros(hidden_size))
        self.dfrr = dead_feature_refresh_rate

        # Not read by this class's encode(); left registered so the set of
        # parameters / state_dict keys is unchanged.
        self.thresh = nn.Parameter(torch.zeros(hidden_size), requires_grad=True)

        self.W_enc = nn.Parameter(torch.empty(input_size, hidden_size))
        self.b_enc = nn.Parameter(torch.empty(hidden_size))
        self.W_dec = nn.Parameter(torch.empty(hidden_size, input_size))
        self.b_dec = nn.Parameter(torch.empty(input_size))
        self.reset_parameters()
        self.save_hyperparameters()

    def reset_parameters(self) -> None:
        """Initialise weights with ``kaiming_uniform_`` and biases uniformly."""
        nn.init.kaiming_uniform_(self.W_enc, a=math.sqrt(5))
        nn.init.kaiming_uniform_(self.W_dec, a=math.sqrt(5))

        # The matrices are stored as (in, out), so the second value returned
        # (size(0) for a 2-D tensor) is the number of inputs feeding the layer.
        _, n_inputs = nn.init._calculate_fan_in_and_fan_out(self.W_enc)
        bound = 1 / math.sqrt(n_inputs) if n_inputs > 0 else 0
        nn.init.uniform_(self.b_enc, -bound, bound)

        _, n_inputs = nn.init._calculate_fan_in_and_fan_out(self.W_dec)
        bound = 1 / math.sqrt(n_inputs) if n_inputs > 0 else 0
        nn.init.uniform_(self.b_dec, -bound, bound)

    def encode(self, x):
        """x has shape B x Input"""
        y = x @ self.W_enc + self.b_enc
        return nn.functional.relu(y)

    def decode(self, z):
        """Map hidden activations ``z`` back to input space."""
        return z @ self.W_dec + self.b_dec

    def configure_optimizers(self):
        """Return an Adam optimizer over all parameters (lr=1e-4)."""
        return torch.optim.Adam(self.parameters(), lr=1e-4)

    def training_step(self, batch, batch_nb):
        """Return MSE + ``lmd`` * L1(activations scaled by decoder row norms)."""
        act, output = self(batch[0])
        if batch_nb % 20 == 0:
            # sample every 20 iterations
            self.counter += (act > 0).sum(dim=0)

        # Skipped when no logger is attached instead of failing on None.
        if (batch_nb % 5000 == 0) and (batch_nb >= 5000) and self.logger is not None:
            with torch.no_grad():
                n_dead = int((self.counter == 0).sum())
                self.logger.experiment.add_scalar("dead_neurons/count", n_dead, self.global_step)

        loss = self.loss(batch[0], output)
        weighted_act = act * self.W_dec.norm(dim=1)
        reg = torch.norm(weighted_act, 1)
        total_loss = loss + self.lmd * reg
        self.log("loss", loss, on_step=True, on_epoch=False, prog_bar=False, logger=True)
        self.log("reg", reg, on_step=True, on_epoch=False, prog_bar=False, logger=True)
        self.log("total_loss", total_loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)
        return total_loss

    def forward(self, x):
        """Return ``(activations, reconstruction)`` for ``x``."""
        z = self.encode(x)
        y = self.decode(z)
        return z, y

    def on_train_batch_end(self, outputs, batch, batch_idx):
        """Trigger a dead-feature refresh on every ``dfrr``-th batch index (skipping 0)."""
        if self.dfrr and (batch_idx % self.dfrr == 0) and (batch_idx >= self.dfrr):
            self.revive_dead_features()

    def revive_dead_features(self):
        """Re-draw encoder columns whose counter is zero, then reset the counter.

        The new values use the same uniform bound that ``kaiming_uniform_``
        applies to the full ``W_enc`` in ``reset_parameters``. The original
        ran ``kaiming_uniform_`` on an ``(input, k, 1)`` slice, which made the
        fan-in equal to the number of dead features ``k``.
        """
        with torch.no_grad():
            dead = self.counter == 0
            if dead.any():
                gain = nn.init.calculate_gain("leaky_relu", math.sqrt(5))
                bound = gain * math.sqrt(3.0 / self.W_enc.size(1))
                fresh = torch.empty_like(self.W_enc[:, dead]).uniform_(-bound, bound)
                self.W_enc[:, dead] = fresh
            self.counter.zero_()

    def active_feature_statistics(self, dataloader):
        """Count, per hidden unit, how often it is positive over ``dataloader``.

        The module is moved to CUDA when available (CPU otherwise) for the
        pass and is left on the CPU afterwards. Summary statistics are printed.

        Args:
            dataloader: Iterable of batches; each batch is either a tensor or
                a list/tuple whose first element is the input tensor (the
                layout ``training_step`` reads).

        Returns:
            NumPy array of per-unit activation counts.
        """
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.to(device)
        with torch.no_grad():
            total = torch.zeros(self.hparams.hidden_size, device=device)
            for batch in dataloader:
                inputs = batch[0] if isinstance(batch, (list, tuple)) else batch
                act = self.encode(inputs.to(device))
                total += (act > 0).sum(dim=0)
        self.cpu()
        total = total.detach().cpu().numpy().squeeze()
        print("Total number of dead neurons:", (total == 0).sum())
        print("Percentage of dead neurons:", (total == 0).sum() / total.shape[0])
        print("Quantiles:", np.quantile(total, [0.01, 0.02, 0.05, 0.1, 0.5, 0.9, 0.95, 0.98, 0.99]))
        print("Mean:", np.mean(total))
        return total

    def check(self, to_test):
        """Print the reconstruction score and activation counts for ``to_test``.

        Moves the module and ``to_test`` to CUDA when available, CPU otherwise.
        """
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.to(device)
        to_test = to_test.to(device)
        with torch.no_grad():
            # forward already returns the activations, so the encoder runs once.
            act, sae_output = self(to_test)
            print("Reconstruction capability:", 1 - torch.mean((sae_output - to_test) ** 2) / to_test.var())
            print("Number of activated:", (act > 0).sum())
            print("Percentage of activated:", (act > 0).sum() / act.numel())


class L2DecoderWeightFull(L2DecoderWeight):
    """``L2DecoderWeight`` whose refresh also resets bias, decoder rows and Adam moments."""

    def revive_dead_features(self):
        """Randomly changing the weights to avoid dead features.

        For every hidden unit whose counter is zero this re-draws its encoder
        column, encoder bias and decoder row, and zeroes the matching
        ``exp_avg`` / ``exp_avg_sq`` optimizer entries. The counter is reset
        on every call, including when nothing is dead; the original returned
        early in that case and left the counter accumulating.
        """
        with torch.no_grad():
            idxs = (self.counter == 0).nonzero()
            # Skipped when no logger is attached instead of failing on None.
            if self.logger is not None:
                self.logger.experiment.add_scalar("dead_neurons/count", len(idxs.tolist()), self.global_step)
                self.logger.experiment.add_text("dead_neurons/idxs", str(idxs.tolist()), self.global_step)
            if idxs.numel() == 0:
                self.counter.zero_()
                return

            # 1-D index vector; indexing with the (k, 1) tensor adds an axis.
            dead = idxs.flatten()

            # reset the encoder weights
            # Uses the bound kaiming_uniform_ applies to the full W_enc.
            # Running it on the indexed (input, k, 1) slice made the fan-in
            # equal to the number of dead features.
            gain = nn.init.calculate_gain("leaky_relu", math.sqrt(5))
            w_bound = gain * math.sqrt(3.0 / self.W_enc.size(1))
            self.W_enc[:, dead] = torch.empty_like(self.W_enc[:, dead]).uniform_(-w_bound, w_bound)

            # reset the encoder biases
            # Same bound as the original: 1 / sqrt(W_enc.size(0)).
            n_inputs = self.W_enc.size(0)
            b_bound = 1 / math.sqrt(n_inputs) if n_inputs > 0 else 0
            self.b_enc[dead] = torch.empty_like(self.b_enc[dead]).uniform_(-b_bound, b_bound)

            # reset the decoder weights
            # Rows form a 2-D (k, input) slice, so fan-in is W_dec.size(1).
            self.W_dec[dead] = nn.init.kaiming_uniform_(
                torch.empty_like(self.W_dec[dead]), a=math.sqrt(5)
            )

            # reset the optimizer
            optimizer = self.optimizers(use_pl_optimizer=False)
            for param, dim in ((self.W_enc, 1), (self.b_enc, 0), (self.W_dec, 0)):
                # .get avoids a KeyError when a parameter has no state yet.
                param_state = optimizer.state.get(param, {})
                for key in ("exp_avg", "exp_avg_sq"):
                    moment = param_state.get(key)
                    if moment is not None:
                        moment.index_fill_(dim, dead.to(device=moment.device), 0)

            self.counter.zero_()

class ShiftDecoderBias(L2DecoderWeightFull):
    """Variant that subtracts the decoder bias from the input before encoding."""

    def encode(self, x):
        """x has shape B x Input"""
        centered = x - self.b_dec
        pre_activation = centered @ self.W_enc + self.b_enc
        return torch.relu(pre_activation)