In [1]:
import torch
import torch.nn as nn 
import torch.nn.functional as F
device = 'cpu'

12500
37500
62500
87500


### Checking is_step_... functions

In [47]:
def is_step_start_of_investigating_dead_neurons(step, x):
    """Return True when `step` is the first step of a dead-neuron investigation phase.

    With `x` as the interval, the phase starts are the odd multiples of `x // 2`
    that lie strictly below `4 * x`.
    """
    if not step > 0:
        return False
    half_interval = x // 2
    halves_elapsed, offset = divmod(step, half_interval)
    # on a half-interval boundary, an odd number of halves in, and before the cut-off
    return offset == 0 and halves_elapsed % 2 != 0 and step < 4 * x

# Print every step in the first million that starts an investigation phase for x = 25000.
x = 25000
for i in range(1_000_000):
    if not is_step_start_of_investigating_dead_neurons(i, x):
        continue
    print(i)

# Sanity check: an empty set is falsy, so `not set()` is True and 0 is printed.
if not set():
    print(0)

def is_step_in_the_phase_of_investigating_neurons(step, x):
    """Return True when `step` falls in the half-interval just before a milestone.

    The milestones are `x`, `2*x`, `3*x` and `4*x`; each window is
    `[milestone - x // 2, milestone)`.
    """
    half_interval = x // 2
    return any(k * x - half_interval <= step < k * x for k in (1, 2, 3, 4))

# Brute force: collect every step below 100000 that lies in an investigation phase.
x = 25000
ans = []
for i in range(100000):
    if not is_step_in_the_phase_of_investigating_neurons(i, x):
        continue
    ans.append(i)

# Closed form: four runs of x // 2 consecutive steps starting at the odd multiples of x // 2.
right_ans = [(x // 2) * k + i for k in (1, 3, 5, 7) for i in range(x // 2)]
right_ans[0], right_ans[-1], len(right_ans)

ans == right_ans

0


### Newest version (the one in train_sae.py)

In [55]:
class AutoEncoder(nn.Module):
    """Sparse autoencoder with a single ReLU hidden layer and dead-neuron resampling.

    The encoder maps inputs of width `n` to `m` hidden activations and the decoder maps
    them back to width `n`. The loss is the mean-squared reconstruction error plus `lam`
    times the summed absolute hidden activations.

    Args:
        n: width of the input vectors (d_MLP, a.k.a. n_ffwd).
        m: number of autoencoder neurons.
        lam: coefficient of the L1 penalty.
        resampling_interval: stored on the instance; no method of this class reads it.
    """

    def __init__(self, n, m, lam=0.003, resampling_interval=None):
        super().__init__()
        self.n, self.m = n, m
        self.enc = nn.Linear(n, m) # enc.weight has shape (m, n)
        self.relu = nn.ReLU()
        self.dec = nn.Linear(m, n) # dec.weight has shape (n, m)
        self.lam = lam # coefficient of L_1 loss

        # bookkeeping for neuron resampling
        self.resampling_interval = resampling_interval
        self.dead_neurons = None # set of candidate dead-neuron indices, or None when not tracking

    def forward(self, x):
        """Encode and reconstruct `x` of shape (b, n) and compute the loss.

        Returns:
            A dict with 'loss' and 'f' in training mode; in eval mode it additionally
            holds 'reconst_acts', 'mse_loss' and 'l1_loss'.
        """
        xbar = x - self.dec.bias # (b, n)
        f = self.relu(self.enc(xbar)) # (b, m)
        reconst_acts = self.dec(f) # (b, n)
        mseloss = F.mse_loss(reconst_acts, x) # scalar
        l1loss = F.l1_loss(f, torch.zeros(f.shape, device=f.device), reduction='sum') # scalar
        loss = mseloss + self.lam * l1loss # scalar

        # training only needs f and loss; evaluation also reports the reconstruction and both loss terms
        out_dict = {'loss': loss, 'f': f} if self.training else {'loss': loss, 'f': f, 'reconst_acts': reconst_acts, 'mse_loss': mseloss, 'l1_loss': l1loss}

        return out_dict

    @torch.no_grad()
    def normalize_decoder_columns(self):
        """Rescale every decoder column (one per neuron) to unit L2 norm, in place."""
        self.dec.weight.data = F.normalize(self.dec.weight.data, dim=0)

    @torch.no_grad()
    def remove_parallel_component_of_decoder_gradient(self):
        """Project the decoder-weight gradient onto the complement of each decoder column.

        For every column, the component of the gradient parallel to the weight column is
        removed: proj_b a = (a.\\hat{b}) \\hat{b} with a = grad, b = weight. Runs under
        `no_grad` (like the other maintenance methods) so the stored gradient does not
        carry autograd history from the weights.
        """
        unit_w = F.normalize(self.dec.weight, dim=0) # \hat{b}
        proj = torch.sum(self.dec.weight.grad * unit_w, dim=0) * unit_w
        self.dec.weight.grad = self.dec.weight.grad - proj

    @torch.no_grad()
    def initiate_dead_neurons(self):
        """Start tracking: mark every neuron as a dead-neuron candidate."""
        self.dead_neurons = set(range(self.m))

    @torch.no_grad()
    def update_dead_neurons(self, f):
        """Drop from `self.dead_neurons` every neuron that fires on at least one row of `f`.

        Args:
            f: hidden activations of shape (b, m).
        """
        # indices of the columns of f (i.e. neurons) that are non-zero on at least one example
        active_neurons_this_step = torch.count_nonzero(f, dim=0).nonzero().view(-1)
        # a single tolist() transfer instead of one .item() call per active neuron
        self.dead_neurons.difference_update(active_neurons_this_step.tolist())

    @torch.no_grad()
    def resample_neurons(self, data, optimizer, batch_size=8192):
        """Re-initialise the neurons listed in `self.dead_neurons` from high-loss examples.

        Each example in `data` is weighted by its squared loss; one example per dead neuron
        is drawn with those weights. The unit-normalised example becomes the neuron's
        decoder column; the same direction, scaled to 0.2 times the average encoder-row
        norm of the alive neurons, becomes its encoder row; its encoder bias is zeroed.
        The optimizer's `exp_avg` / `exp_avg_sq` entries for those slices are zeroed too.
        Does nothing (apart from a message) when `self.dead_neurons` is empty or None.

        Args:
            data: tensor of shape (N, n) to sample replacement directions from.
            optimizer: optimizer holding the parameters of this module.
            batch_size: number of examples evaluated at once when computing the losses;
                it only affects chunking, not the result.
        """
        if not self.dead_neurons:
            print(f'no dead neurons to be resampled')
            return

        device = next(self.parameters()).device # all parameters are assumed to live on one device
        # sorted so the neuron <-> sampled-example pairing does not depend on set iteration order
        dead_neurons_t = torch.tensor(sorted(self.dead_neurons), device=device)
        alive_neurons = torch.tensor([i for i in range(self.m) if i not in self.dead_neurons], device=device)
        print(f'number of dead neurons at time of resampling: {len(dead_neurons_t)}, alive neurons: {len(alive_neurons)}')

        # average norm of the encoder rows of alive neurons
        average_enc_norm = torch.mean(torch.linalg.vector_norm(self.enc.weight[alive_neurons], dim=1))

        # per-example sampling weights = loss ** 2, computed batch by batch and kept on the CPU
        num_batches = len(data) // batch_size + (len(data) % batch_size != 0)
        probs = torch.zeros(len(data),) # (N, )
        for batch_idx in range(num_batches):
            print(f'computing losses for iter = {batch_idx}')
            rows = slice(batch_idx * batch_size, (batch_idx + 1) * batch_size)
            x = data[rows].to(device) # (b, n); the last batch may be shorter
            xbar = x - self.dec.bias # (b, n)
            f = self.relu(self.enc(xbar)) # (b, m)
            reconst_acts = self.dec(f) # (b, n)
            mselosses = torch.sum(F.mse_loss(reconst_acts, x, reduction='none'), dim=1) # (b, )
            l1losses = torch.sum(F.l1_loss(f, torch.zeros(f.shape, device=f.device), reduction='none'), dim=1) # (b, )
            probs[rows] = ((mselosses + self.lam * l1losses)**2).to('cpu') # (b, )

        # pick one example per dead neuron and move it next to the weights it will overwrite
        chosen = torch.multinomial(probs, num_samples=len(self.dead_neurons))
        exs = data[chosen].to(device=device, dtype=torch.float32) # (d, n) where d = len(dead_neurons)
        assert exs.shape == (len(self.dead_neurons), self.n), 'exs has incorrect shape'
        exs_unit_norm = F.normalize(exs, dim=1) # (d, n)
        # decoder columns of dead neurons <- unit-norm examples
        self.dec.weight[:, dead_neurons_t] = torch.transpose(exs_unit_norm, 0, 1) # (n, d)
        # encoder rows of dead neurons <- examples rescaled to norm = average_enc_norm * 0.2
        exs_enc_norm = exs_unit_norm * average_enc_norm * 0.2
        self.enc.weight[dead_neurons_t] = exs_enc_norm
        self.enc.bias[dead_neurons_t] = 0

        print('updated decoder weights and encoder weights and bias')

        # zero the optimizer moments that belong to the resampled slices; parameters are matched
        # by identity so the result does not depend on their order or grouping in the optimizer
        for group in optimizer.param_groups:
            for p in group['params']:
                param_state = optimizer.state.get(p)
                if not param_state:
                    continue # this parameter has no optimizer state yet (no step taken)
                for key in ('exp_avg', 'exp_avg_sq'):
                    if key not in param_state:
                        continue
                    if p is self.enc.weight or p is self.enc.bias:
                        param_state[key][dead_neurons_t] = 0
                    elif p is self.dec.weight:
                        param_state[key][:, dead_neurons_t] = 0

        print(f'updated optimizer parameters')

        # reset self.dead_neurons as there are now none left to be resampled
        self.dead_neurons = None

In [56]:
# Smoke test: take two Adam steps on a tiny 2 -> 4 autoencoder, then resample two neurons.
torch.manual_seed(0)
batch_size = 8192 # notebook-level batch size
autoencoder = AutoEncoder(2, 4)
optimizer = torch.optim.Adam(autoencoder.parameters(), lr=0.01)
total_steps = 2
for step in range(total_steps):

    ## load a batch of data
    batch = torch.randn(100, 2)

    # forward, backward pass
    optimizer.zero_grad(set_to_none=True)
    output = autoencoder(batch) # dict holding 'loss' and 'f' (module is in training mode)
    output['loss'].backward()
    optimizer.step()

# state explicitly which neurons this test treats as dead before asking for a resample
autoencoder.dead_neurons = {1, 3}
autoencoder.resample_neurons(torch.randn(10000, 2), optimizer=optimizer)

number of dead neurons at time of resampling: 2, alive neurons: 2
computing losses for iter = 0
computing losses for iter = 1
updated decoder weights and encoder weights and bias
updated optimizer parameters


### New method (without a for loop)

In [4]:
class AutoEncoder(nn.Module):
    """Prototype sparse autoencoder whose neuron resampling is vectorised (no per-neuron loop).

    Args:
        n: width of the input vectors (d_MLP, a.k.a. n_ffwd).
        m: number of features (hidden neurons).
        lam: coefficient of the L1 penalty.
    """

    def __init__(self, n, m, lam=0.003):
        super().__init__()
        self.n, self.m = n, m
        self.enc = nn.Linear(n, m) # enc.weight has shape (m, n)
        self.relu = nn.ReLU()
        self.dec = nn.Linear(m, n) # dec.weight has shape (n, m)
        self.lam = lam # coefficient of L_1 loss
        self.dead_neurons = set() # TODO: note that this is a new addition

    def forward(self, x):
        """Reconstruct `x` of shape (b, n).

        Returns:
            The scalar loss in training mode; in eval mode a dict with 'loss', 'f',
            'reconst_acts', 'mse_loss' and 'l1_loss'.
        """
        xbar = x - self.dec.bias # (b, n)
        f = self.relu(self.enc(xbar)) # (b, m)
        reconst_acts = self.dec(f) # (b, n)
        mseloss = F.mse_loss(reconst_acts, x) # scalar
        l1loss = F.l1_loss(f, torch.zeros(f.shape, device=f.device), reduction='sum') # scalar
        loss = mseloss + self.lam * l1loss # scalar
        out_dict = {'loss': loss, 'f': f, 'reconst_acts': reconst_acts, 'mse_loss': mseloss, 'l1_loss': l1loss}
        return loss if self.training else out_dict

    @torch.no_grad()
    def resample_neurons(self, data, optimizer):
        """Re-initialise the (currently hard-coded) dead neurons from loss-weighted examples.

        Args:
            data: tensor of shape (N, n) to sample replacement directions from.
            optimizer: Adam optimizer over this module's parameters, in registration order
                (enc.weight, enc.bias, dec.weight, dec.bias) and a single parameter group.
        """
        self.dead_neurons = set([2, 1]) # TODO: remove this after self.dead_neurons has been computed in forward method
        # evaluate on the device the parameters live on rather than on a notebook-level global
        device = next(self.parameters()).device
        dead_neurons_t = torch.tensor(list(self.dead_neurons))
        alive_neurons = torch.tensor([i for i in range(self.m) if i not in self.dead_neurons])
        print(f'alive_neurons are {list(alive_neurons)}') # with dead neurons {1, 2} and m = 4 these are neurons 0 and 3

        # compute average norm of encoder vectors for alive neurons
        average_enc_norm = torch.mean(torch.linalg.vector_norm(self.enc.weight[alive_neurons], dim=1))
        print(f'average encoder norm is {average_enc_norm}')

        # expect data to be of shape (N, n_ffwd); in the paper N = 819200
        batch_size = 8192 # TODO: I can probably remove this when copied to train_sae.py
        num_batches = len(data) // batch_size + (len(data) % batch_size != 0)
        probs = torch.zeros(len(data),) # (N, ) sampling weights = losses**2
        for batch_idx in range(num_batches):
            print(f'computing losses for iter = {batch_idx}')
            x = data[batch_idx * batch_size: (batch_idx + 1) * batch_size].to(device) # (b, n); the last batch may be shorter
            xbar = x - self.dec.bias # (b, n)
            f = self.relu(self.enc(xbar)) # (b, m)
            reconst_acts = self.dec(f) # (b, n)
            mselosses = torch.sum(F.mse_loss(reconst_acts, x, reduction='none'), dim=1) # (b,)
            l1losses = torch.sum(F.l1_loss(f, torch.zeros(f.shape, device=f.device), reduction='none'), dim=1) # (b, )
            probs[batch_idx * batch_size: (batch_idx + 1) * batch_size] = ((mselosses + self.lam * l1losses)**2).to('cpu') # (b, )

        torch.manual_seed(0) # TODO: remove this later perhaps
        # one example per dead neuron, drawn with probability proportional to probs
        exs = data[torch.multinomial(probs, num_samples=len(self.dead_neurons))] # (d, n) where d = len(dead_neurons)
        assert exs.shape == (len(self.dead_neurons), self.n), 'exs has incorrect shape'

        exs_unit_norm = F.normalize(exs, dim=1) # (d, n)

        # decoder columns of dead neurons <- unit-norm examples
        self.dec.weight[:, dead_neurons_t] = torch.transpose(exs_unit_norm, 0, 1) # (n, d)

        # encoder rows of dead neurons <- examples rescaled to norm = average_enc_norm * 0.2
        exs_enc_norm = exs_unit_norm * average_enc_norm * 0.2
        self.enc.weight[dead_neurons_t] = exs_enc_norm
        self.enc.bias[dead_neurons_t] = 0

        # zero the Adam moments of the resampled slices
        for i, p in enumerate(optimizer.param_groups[0]['params']): # there is only one parameter group so we can do this
            param_state = optimizer.state[p]
            if i in [0, 1]: # encoder weight and bias
                param_state['exp_avg'][dead_neurons_t] = 0
                param_state['exp_avg_sq'][dead_neurons_t] = 0
            elif i == 2: # decoder weight
                param_state['exp_avg'][:, dead_neurons_t] = 0
                param_state['exp_avg_sq'][:, dead_neurons_t] = 0

        

    

In [5]:
# Seeded smoke test: two Adam steps on a 2 -> 4 autoencoder, then one resampling call on 5 random points.
torch.manual_seed(0)
autoencoder = AutoEncoder(2, 4)
optimizer = torch.optim.Adam(autoencoder.parameters(), lr=0.01)
total_steps = 2
for step in range(total_steps):

    ## load a batch of data        
    batch = torch.randn(100, 2)

    # forward, backward pass
    optimizer.zero_grad(set_to_none=True) 
    loss = autoencoder(batch) # used as the scalar loss here: backward() is called on it directly
    loss.backward()
    optimizer.step()
    
autoencoder.resample_neurons(torch.randn(5, 2), optimizer=optimizer)

alive_neurons are [tensor(0), tensor(3)]
average encoder norm is 0.45120707154273987
computing losses for iter = 0
torch.Size([4, 2])
torch.Size([4])
torch.Size([2, 4])
torch.Size([2])


In [217]:
# Display the encoder bias of the current `autoencoder` (one entry per hidden neuron).
autoencoder.enc.bias

Parameter containing:
tensor([-0.0821,  0.0000,  0.0000, -0.1590], requires_grad=True)

# Display the encoder weight matrix of the current `autoencoder` (one row per hidden neuron).
autoencoder.enc.weight

### Old method using a for loop

In [192]:
class AutoEncoder(nn.Module):
    """Earliest sparse-autoencoder prototype: resamples dead neurons one at a time, with verbose tracing.

    Args:
        n: width of the input vectors (d_MLP, a.k.a. n_ffwd).
        m: number of features (hidden neurons).
        lam: coefficient of the L1 penalty.
    """

    def __init__(self, n, m, lam=0.003):
        super().__init__()
        self.n, self.m = n, m
        self.enc = nn.Linear(n, m) # enc.weight has shape (m, n)
        self.relu = nn.ReLU()
        self.dec = nn.Linear(m, n) # dec.weight has shape (n, m)
        self.lam = lam # coefficient of L_1 loss
        self.dead_neurons = set() # TODO: note that this is a new addition

    def forward(self, x):
        """Reconstruct `x` of shape (b, n).

        Returns:
            The scalar loss in training mode; in eval mode a dict with 'loss', 'f',
            'reconst_acts', 'mse_loss' and 'l1_loss'.
        """
        xbar = x - self.dec.bias # (b, n)
        f = self.relu(self.enc(xbar)) # (b, m)
        reconst_acts = self.dec(f) # (b, n)
        mseloss = F.mse_loss(reconst_acts, x) # scalar
        l1loss = F.l1_loss(f, torch.zeros(f.shape, device=f.device), reduction='sum') # scalar
        loss = mseloss + self.lam * l1loss # scalar
        out_dict = {'loss': loss, 'f': f, 'reconst_acts': reconst_acts, 'mse_loss': mseloss, 'l1_loss': l1loss}
        return loss if self.training else out_dict

    @torch.no_grad() # the in-place writes into parameters below are only legal with autograd disabled
    def resample_neurons(self, x):
        """Re-initialise the (currently hard-coded) dead neurons, one neuron per loop iteration.

        For each dead neuron one row of `x` is drawn with probability proportional to its
        squared loss; the unit-normalised row becomes the neuron's decoder column, the same
        direction scaled to 0.2 times the average alive encoder-row norm becomes its encoder
        row, and its encoder bias is zeroed. Every intermediate value is printed.

        Args:
            x: tensor of shape (b, n), b = number of candidate examples (819200 in the paper).
        """
        xbar = x - self.dec.bias # (b, n)
        f = self.relu(self.enc(xbar)) # (b, m)
        reconst_acts = self.dec(f) # (b, n)
        mselosses = torch.sum(F.mse_loss(reconst_acts, x, reduction='none'), dim=1) # (b,)
        l1losses = torch.sum(F.l1_loss(f, torch.zeros(f.shape, device=f.device), reduction='none'), dim=1) # (b, )
        probs = (mselosses + self.lam * l1losses)**2 # (b, )

        self.dead_neurons = set([2, 1]) # TODO: this would, in general, be something else
        alive_neurons = torch.tensor([i for i in range(self.m) if i not in self.dead_neurons])
        print(f'alive_neurons are {list(alive_neurons)}') # with dead neurons {1, 2} and m = 4 these are neurons 0 and 3

        # compute average norm of encoder vectors for alive neurons
        average_enc_norm = torch.mean(torch.linalg.vector_norm(self.enc.weight[alive_neurons], dim=1))
        print(f'average encoder norm is {average_enc_norm}')

        torch.manual_seed(0) # TODO: remove this later perhaps
        for neuron in self.dead_neurons:
            # pick an example
            print(f'neuron number: {neuron}')
            ex = x[torch.multinomial(probs, num_samples=1)] # (1, n)
            print(f'chosen example: {ex}')
            ex_normalized = F.normalize(ex, dim=1) # (1, n)
            print(f'normalized example: {ex_normalized}')

            # set this new example to be decoder weight column
            print(f'original decoder weight: {self.dec.weight}')
            # dec.weight has shape (n, m) in general; this reassignment modifies a column
            self.dec.weight[:, neuron] = ex_normalized # lhs has shape (n, ); rhs has shape (1, n) but this is okay because of broadcasting
            print(f'modified decoder weight: {self.dec.weight}')

            # rescale to 0.2 * average encoder norm of alive neurons # TODO: why do they multiply by 0.2?
            ex_normalized_again = ex_normalized * average_enc_norm * 0.2
            print(f'normalized example again: {ex_normalized_again}')
            print(f'the norm of newly normalized example is {torch.linalg.vector_norm(ex_normalized_again)}; \
                  in comparison, average encoder norm * 0.2 is {average_enc_norm * 0.2}')

            # set this new example to be encoder weight row
            print(f'original encoder weight: {self.enc.weight}')
            # enc.weight has shape (m, n) in general; this reassignment modifies a row
            self.enc.weight[neuron] = ex_normalized_again
            print(f'modified encoder weight: {self.enc.weight}')

            # set the corresponding encoder bias element to zero
            print(f'original encoder bias: {self.enc.bias}')
            self.enc.bias[neuron] = 0
            print(f'new encoder bias: {self.enc.bias}')

            # TODO: update Adam optimizer parameters
            # first reset the optimizer parameters for enc.weight[neuron]

            


# TODO (lowest priority): at some point, replace this for loop with a batched calculation.

In [193]:
# Run the loop-based resampling once on a freshly seeded 2 -> 4 autoencoder with 5 random examples.
torch.manual_seed(0)
autoencoder = AutoEncoder(2, 4)
with torch.no_grad():  # resample_neurons assigns into parameter entries in place
    autoencoder.resample_neurons(torch.randn(5, 2))

alive_neurons are [tensor(0), tensor(3)]
average encoder norm is 0.470096230506897
neuron number: 1
chosen example: tensor([[-0.9528,  0.3717]])
normalized example: tensor([[-0.9316,  0.3634]])
original decoder weight: Parameter containing:
tensor([[-0.4777, -0.3311, -0.2061,  0.0185],
        [ 0.1977,  0.3000, -0.3390, -0.2177]], requires_grad=True)
modified decoder weight: Parameter containing:
tensor([[-0.4777, -0.9316, -0.2061,  0.0185],
        [ 0.1977,  0.3634, -0.3390, -0.2177]], requires_grad=True)
normalized example again: tensor([[-0.0876,  0.0342]])
the norm of newly normalized example is 0.09401924908161163;                   in comparison, average encoder norm * 0.2 is 0.09401924908161163
original encoder weight: Parameter containing:
tensor([[-0.0053,  0.3793],
        [-0.5820, -0.5204],
        [-0.2723,  0.1896],
        [-0.0140,  0.5607]], requires_grad=True)
modified decoder weight: Parameter containing:
tensor([[-0.0053,  0.3793],
        [-0.0876,  0.0342],
    

In [None]:
# Examples with larger losses are more likely to be picked.
# That is desirable, because these are exactly the examples we want to do better on.

In [201]:
# Fresh, seeded 2 -> 4 autoencoder and an Adam optimizer over its parameters,
# used by the following cells to inspect and edit the optimizer state.
torch.manual_seed(0)
autoencoder = AutoEncoder(2, 4)
optimizer = torch.optim.Adam(autoencoder.parameters(), lr=0.01)

In [202]:
# Take two Adam steps so the optimizer has per-parameter state, then display that state.
optimizer = torch.optim.Adam(autoencoder.parameters(), lr=0.01)
total_steps = 2
for step in range(total_steps):

    ## load a batch of data
    batch = torch.randn(100, 2)

    # forward, backward pass
    optimizer.zero_grad(set_to_none=True)
    loss = autoencoder(batch) # used as the scalar loss: backward() is called on it directly
    loss.backward()
    optimizer.step()

# Last expression of the cell, so it is rendered: the optimizer state of parameter 0.
# (Inside the loop body this expression was evaluated and discarded on every step.)
optimizer.state_dict()['state'][0]

In [204]:
# TODO: incorporate this into the main code above in the AutoEncoder class
dead_neurons = set([3, 0])
for i, p in enumerate(optimizer.param_groups[0]['params']): # there is only one parameter group so we can do this
    param_state = optimizer.state[p]

    if i == 0: # reset parameters for encoder weights
        for neuron in dead_neurons: 
            param_state['exp_avg'][neuron] = torch.zeros(2,) # replace 2 with n 
            param_state['exp_avg_sq'][neuron] = torch.zeros(2,) 

    if i == 1: # reset paraemeters for encoder bias 
        for neuron in dead_neurons:
            param_state['exp_avg'][neuron] = 0
            param_state['exp_avg_sq'][neuron] = 0

    if i == 2:
        for neuron in dead_neurons: 
            param_state['exp_avg'][:, neuron] = torch.zeros(2, )
            param_state['exp_avg_sq'][:, neuron] = torch.zeros(2, )

updated_state_dict = optimizer.state_dict()


In [207]:
# try again:
dead_neurons = set([3, 0])
dead_neurons_t = torch.tensor(list(dead_neurons)) # introduce this guy this time
for i, p in enumerate(optimizer.param_groups[0]['params']): # there is only one parameter group so we can do this
    param_state = optimizer.state[p]

    if i in [0, 1]:
        param_state['exp_avg'][dead_neurons_t] = 0
        param_state['exp_avg_sq'][dead_neurons_t] = 0

    if i == 2:
        param_state['exp_avg'][:, dead_neurons_t] = 0
        param_state['exp_avg_sq'][:, dead_neurons_t] = 0

In [208]:
# Compare the current optimizer state with `updated_state_dict`, captured after the loop-based reset.
# NOTE(review): `updated_state_dict` was captured without a deep copy; if `state_dict()` hands back
# the live state tensors, this comparison holds by object identity alone — confirm before relying on it.
optimizer.state_dict() == updated_state_dict

True

### code used to get there

In [154]:
# Toy check of the per-example losses: 3 examples, MLP width 2, 4 hidden features.
lam = 0.01
x = torch.ones(3, 2)
reconst_acts = 2 * torch.ones(3, 2)
f = 2 * torch.ones(3, 4)

# element-wise losses first, then a sum over the feature axis -> one value per example
mselosses = F.mse_loss(reconst_acts, x, reduction='none').sum(dim=1)
l1losses = F.l1_loss(f, torch.zeros(f.shape, device=f.device), reduction='none').sum(dim=1)
losses = mselosses + lam * l1losses
probs = losses**2

print(mselosses, l1losses, losses, probs, sep='\n')


tensor([2., 2., 2.])
tensor([8., 8., 8.])
tensor([2.0800, 2.0800, 2.0800])
tensor([4.3264, 4.3264, 4.3264])


In [197]:
# Notes on the layout of optimizer.state_dict(); only the last expression of the cell is displayed.
optimizer.state_dict().keys() # keys: 'state', 'param_groups' 

## analyzing 'state'
optimizer.state_dict()['state'].keys() # keys: 0, 1, 2, 3
optimizer.state_dict()['state'][0] # keys: 'step', 'exp_avg', 'exp_avg_sq'
# 'step': torch.Tensor(int), 'exp_avg': tensor.shape = (4, 2), 'exp_avg_sq': (4, 2)
optimizer.state_dict()['state'][1] # keys: 'step', 'exp_avg', 'exp_avg_sq'
# 'step': torch.Tensor(int), 'exp_avg': tensor.shape = (4,), 'exp_avg_sq': (4,)
# In general, step, exp_avg and exp_avg_sq of optimizer.state_dict()['state'][i] have shapes: (), param.shape, param.shape
# where param = list(model.parameters())[i]

## analyzing 'param_groups'
optimizer.state_dict()['param_groups']
# this is just the dictionary of hyperparameters: lr, betas, eps, weight_decay, amsgrad, maximize, foreach, capturable, differentiable, fused, params
# perhaps the one that is least clear is params. It is an iterable of parameters to optimize (or a dictionary of parameter groups)
# in the simplest case of the autoencoder above (a model with two weight matrices and two bias vectors), it is a list [0, 1, 2, 3]

[{'lr': 0.01,
  'betas': (0.9, 0.999),
  'eps': 1e-08,
  'weight_decay': 0,
  'amsgrad': False,
  'maximize': False,
  'foreach': None,
  'capturable': False,
  'differentiable': False,
  'fused': None,
  'params': [0, 1, 2, 3]}]