|<h2>Course:</h2>|<h1><b><a href="https://udemy.com/course/dullms_x/?couponCode=202508" target="_blank">A deep understanding of AI language model mechanisms</a></b></h1>|
|-|:-:|
|<h2>Part 5:</h2>|<h1>Observation (non-causal) mech interp</h1>|
|<h2>Section:</h2>|<h1>Identifying circuits and components</h1>|
|<h2>Lecture:</h2>|<h1><b>Sparse autoencoders: theory and code</b></h1>|

<br>

<h5><b>Teacher:</b> Mike X Cohen, <a href="https://sincxpress.com" target="_blank">sincxpress.com</a></h5>
<h5><b>Course URL:</b> <a href="https://udemy.com/course/dullms_x/?couponCode=202508" target="_blank">udemy.com/course/dullms_x/?couponCode=202508</a></h5>
<i>Using the code without the course may lead to confusion or errors.</i>

In [None]:
import numpy as np
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.optim as optim

import matplotlib_inline.backend_inline
matplotlib_inline.backend_inline.set_matplotlib_formats('svg')

# Create latent variables and mix them

In [None]:
# two ground-truth sources sampled on a shared time axis
n_samples = 3000
t = np.linspace(0,14*np.pi,num=n_samples)
latent1 = 1 + np.sin(t)             # smooth sine wave, shifted up so it stays in [0,2]
latent2 = 1 + np.sign(np.sin(3*t))  # square wave at 3x the frequency, same upward shift

# linear mixing: row i of the matrix is how source i projects onto the two observed variables
mixing_matrix = np.array([ [1,.4], [.6,1] ])
sources = np.column_stack((latent1,latent2))  # samples X sources
data = sources @ mixing_matrix                # samples X manifest variables
data.shape

In [None]:
# 2x2 overview figure: time series on the top row, value distributions on the bottom row
_,axs = plt.subplots(nrows=2,ncols=2,figsize=(10,6))

# settings shared by both time-series panels
marker_style = dict(alpha=.5,markersize=3)
series_axes  = dict(xlabel='Data index',ylabel='Data value',xlim=[-5,n_samples+4])

# top-left: the two ground-truth sources, with their correlation in the title
r_sources = np.corrcoef(latent1,latent2)[0,1]
axs[0,0].plot(latent1,'bo',markerfacecolor=[.7,.7,.9],**marker_style)
axs[0,0].plot(latent2,'gs',markerfacecolor=[.7,.9,.7],**marker_style)
axs[0,0].set(**series_axes,title=f'Latent sources (r = {r_sources:.2f})')

# top-right: the two "manifest" (mixed) variables, with their correlation in the title
r_mixed = np.corrcoef(data.T)[0,1]
axs[0,1].plot(data[:,0],'mo',markerfacecolor=[.9,.7,.9],**marker_style)
axs[0,1].plot(data[:,1],'ks',markerfacecolor=[.9,.9,.7],**marker_style)
axs[0,1].set(**series_axes,title=f'Mixed signals (r = {r_mixed:.2f})')


# bottom-left: density histograms of the sources (drawn as lines at the left bin edges)
yL1,xL1 = np.histogram(latent1,bins=40,density=True)
yL2,xL2 = np.histogram(latent2,bins=40,density=True)
axs[1,0].plot(xL1[:-1],yL1,linewidth=3,label='Latent 1')
axs[1,0].plot(xL2[:-1],yL2,linewidth=3,label='Latent 2')
latent_xlim = [min(xL1[0],xL2[0]),max(xL1[-1],xL2[-1])]  # span both sets of bin edges
axs[1,0].set(xlabel='Data value',ylabel='Density',xlim=latent_xlim,
             title='Histograms of latent variables')
axs[1,0].legend()

# bottom-right: the same for the mixed variables
yD1,xD1 = np.histogram(data[:,0],bins=40,density=True)
yD2,xD2 = np.histogram(data[:,1],bins=40,density=True)
axs[1,1].plot(xD1[:-1],yD1,linewidth=3,label='Variable 1')
axs[1,1].plot(xD2[:-1],yD2,linewidth=3,label='Variable 2')
mixed_xlim = [min(xD1[0],xD2[0]),max(xD1[-1],xD2[-1])]
axs[1,1].set(xlabel='Data value',ylabel='Density',xlim=mixed_xlim,
             title='Histograms of mixed variables')
axs[1,1].legend()


plt.tight_layout()
plt.show()

# Create a sparse autoencoder model

In [None]:
# as a pytorch class
class SparseAutoencoder(nn.Module):
  """Single-hidden-layer autoencoder with penalty terms for its hidden activations.

  The encoder is a linear layer followed by a GELU nonlinearity; the decoder
  is a linear layer that maps the hidden activations back to `input_dim`.

  Args:
    input_dim: number of input (and reconstructed) features.
    hidden_dim: number of hidden units.
    sparsity_weight: scalar multiplier applied by both penalty methods.
  """

  def __init__(self, input_dim, hidden_dim, sparsity_weight=1):
    super().__init__()
    self.encoder = nn.Linear(input_dim, hidden_dim)
    self.decoder = nn.Linear(hidden_dim, input_dim)
    self.sparsity_weight = sparsity_weight

  def forward(self, x):
    """Return the tuple (reconstruction of x, hidden activations)."""
    hidden = nn.functional.gelu(self.encoder(x))
    return self.decoder(hidden), hidden

  def sparsity_loss(self, estLatent):
    """L1 penalty: weighted mean of the absolute hidden activations."""
    return self.sparsity_weight * estLatent.abs().mean()

  def decorrelation_loss(self, estLatent):
    """Weighted sum of squared off-diagonal covariances between hidden units.

    Alternative penalty (used in the next demo, not here).
    """
    cov = torch.cov(estLatent.T)  # torch.cov treats rows as variables, hence the transpose
    variances_only = torch.diag(torch.diag(cov))
    off_diag = cov - variances_only
    return self.sparsity_weight * (off_diag**2).sum()



# build the model (two inputs, more than two hidden dimensions) and display its layers
num_hidden = 20
AEmodel = SparseAutoencoder(input_dim=2,hidden_dim=num_hidden)
AEmodel

In [None]:
# smoke test: push a small random batch (10 samples, 2 features) through the model
x = torch.randn((10,2))
AEmodel(x)

# Train the model on the data

In [None]:
# optimization settings
n_epochs  = 600
lr        = .0007

# Adam updates every model parameter; reconstruction quality is scored by mean-squared error
optimizer = optim.Adam(AEmodel.parameters(), lr=lr)
lossfun   = nn.MSELoss()

# the model expects a float tensor laid out as observations X features
X_tensor = torch.tensor(data,dtype=torch.float32)
X_tensor.shape

In [None]:
# per-epoch record of the two loss terms: column 0 = L1 penalty, column 1 = MSE
losses = np.zeros((n_epochs,2))

# full-batch training: every epoch uses the whole dataset at once
for epoch in range(n_epochs):

  # forward pass: reconstruction plus the hidden activations
  x_recon,estLatent = AEmodel(X_tensor)

  # the two loss terms (reconstruction error and sparsity penalty)
  MSEloss = lossfun(x_recon,X_tensor)
  L1loss = AEmodel.sparsity_loss(estLatent)
  # L1loss = AEmodel.decorrelation_loss(estLatent)  # alternative penalty, not used in this demo

  # keep both terms for plotting later
  losses[epoch,:] = L1loss.item(),MSEloss.item()

  # backprop on the sum of the two terms, then update the weights
  allloss = MSEloss + L1loss
  optimizer.zero_grad()
  allloss.backward()
  optimizer.step()

  # report!
  if epoch%47==0:
    print(f'Epoch {epoch+1:3}: L1 loss = {L1loss.item():.4f}, MSE loss = {MSEloss.item():.4f}')

In [None]:
# loss curves, showing only every 20th epoch to keep the markers readable
_,axs = plt.subplots(1,2,figsize=(12,3.5))

skip = 20
epochs_shown = range(0,n_epochs,skip)
losses_shown = losses[::skip,:]  # column 0 = L1, column 1 = MSE

# left panel, left y-axis: the L1 term
axs[0].plot(epochs_shown,losses_shown[:,0],'ko-',markerfacecolor=[.7,.7,.9],markersize=8,label='L1 loss')
axs[0].set(xlabel='Epoch',ylabel='L1 Loss',title='Training loss components')
axs[0].legend(bbox_to_anchor=[.5,.5,.48,0],frameon=False)

# left panel, right y-axis: the MSE term on its own scale
ax2 = axs[0].twinx()
ax2.plot(epochs_shown,losses_shown[:,1],'ks-',markerfacecolor=[.9,.7,.7],markersize=8,label='MSE loss')
ax2.set(ylabel='MSE Loss')
ax2.legend(bbox_to_anchor=[.5,.4,.5,0],frameon=False)

# right panel: the sum of both terms, i.e. the quantity that was backpropagated
axs[1].plot(epochs_shown,losses_shown.sum(axis=1),'k^-',markerfacecolor=[.7,.7,.7],markersize=8,label='Total loss')
axs[1].set(xlabel='Epoch',ylabel='Training loss',title='Total losses')

plt.tight_layout()
plt.show()

# Inspect the results

In [None]:
# one final forward pass with the trained model.
# This is inference only, so gradient tracking is switched off: no autograd graph
# is built, and x_recon does not keep one alive afterwards.
with torch.no_grad():
  x_recon,estLatent = AEmodel(X_tensor)

# hidden activations as a numpy array (observations X hidden units) for the analyses below
estLatent = estLatent.numpy()
estLatent.shape

In [None]:
# heatmap of the correlations between every pair of hidden units
unit_corrs = np.corrcoef(estLatent.T)  # transpose so each hidden unit is one variable
plt.imshow(unit_corrs,vmin=-1,vmax=1,cmap='RdBu_r')
plt.colorbar()

corr_ax = plt.gca()
corr_ax.set(xlabel='Latent variable',xticks=range(0,num_hidden,2),yticks=range(1,num_hidden,2),
            ylabel='Latent variable',title='Correlation matrix')
plt.show()

# Find the latent component that best correlates with ground truth variables

In [None]:
# --- ground-truth source 1 ---
# append the source as one extra column after the model's hidden units
latentComps_cat = np.column_stack((estLatent,latent1))

# absolute correlations among all columns; the final row/column belongs to the ground truth
R1 = np.abs(np.corrcoef(latentComps_cat.T))

# strongest match, skipping the last entry (the ground truth's self-correlation of 1);
# nanargmax skips any NaN correlations
truth_vs_units = R1[-1,:-1]
best_1 = np.nanargmax(truth_vs_units)
print(f'Best latent component 1: {best_1} (r = {R1[best_1,-1]:.3f})')


# --- ground-truth source 2: same procedure ---
latentComps_cat = np.column_stack((estLatent,latent2))
R2 = np.abs(np.corrcoef(latentComps_cat.T))
truth_vs_units = R2[-1,:-1]
best_2 = np.nanargmax(truth_vs_units)
print(f'Best latent component 2: {best_2} (r = {R2[best_2,-1]:.3f})')

In [None]:
# show the two hidden units that best matched the ground-truth sources
_,axs = plt.subplots(1,2,figsize=(12,3))

comp_1 = estLatent[:,best_1]
comp_2 = estLatent[:,best_2]

# left panel: their activations over samples, with their mutual correlation in the title
r_best = np.corrcoef(estLatent[:,[best_1,best_2]].T)[0,1]
axs[0].plot(comp_1,'bo',markerfacecolor=[.7,.7,.9,.5],markersize=3)
axs[0].plot(comp_2,'gs',markerfacecolor=[.7,.9,.7,.5],markersize=3)
axs[0].set(xlabel='Data index',ylabel='Data value',xlim=[-5,n_samples+4],
           title=f'Latent components (r = {r_best:.2f})')


# right panel: count histograms of the same two units (drawn as lines at the left bin edges)
yL1,xL1 = np.histogram(comp_1,bins=40)
yL2,xL2 = np.histogram(comp_2,bins=40)
axs[1].plot(xL1[:-1],yL1,'b',linewidth=3,label='Est. latent 1')
axs[1].plot(xL2[:-1],yL2,'g',linewidth=3,label='Est. latent 2')

hist_xlim = [min(xL1[0],xL2[0]),max(xL1[-1],xL2[-1])]  # span both sets of bin edges
axs[1].set(xlabel='Data value',ylabel='Count',xlim=hist_xlim,
           title='Histograms of latent variables')
axs[1].legend()

plt.tight_layout()
plt.show()