In [3]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

from sklearn.metrics import confusion_matrix
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler


import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchmetrics


In [4]:
def attention_fakedata(n_samples=100000, ratio=11.0, seed=42):
    """Build a synthetic, imbalanced binary-classification dataset.

    The label column ``soz`` contains roughly one positive for every ``ratio``
    negatives. Feature ``f1`` is drawn from a class-dependent distribution
    (normal for negatives, gamma for positives) and is therefore informative;
    feature ``f2`` is drawn from the same normal distribution for every row.

    Parameters
    ----------
    n_samples : int
        Total number of rows to generate.
    ratio : float
        Negative-to-positive ratio used to size the positive class.
    seed : int
        Seed for the NumPy random generator, so the data is reproducible.

    Returns
    -------
    raw_df : pandas.DataFrame
        Frame with columns ``soz`` (0.0 / 1.0), ``f1`` and ``f2``.
    initial_bias : numpy.ndarray
        One-element array holding ``log(pos / neg)``.
    """
    rng = np.random.default_rng(seed)

    # Size the two classes so that neg / pos is approximately `ratio`.
    n_pos = int(n_samples / (ratio + 1))
    n_neg = n_samples - n_pos

    # Labels in the requested proportion, shuffled so the classes are interleaved.
    labels = np.concatenate([np.zeros(n_neg), np.ones(n_pos)])
    rng.shuffle(labels)

    raw_df = pd.DataFrame({'soz': labels})

    mask0 = raw_df['soz'] == 0
    mask1 = raw_df['soz'] == 1

    # f1 separates the classes: negatives cluster around 6, positives sit near 0.
    raw_df.loc[mask0, 'f1'] = rng.normal(6, 0.3, size=mask0.sum())
    raw_df.loc[mask1, 'f1'] = rng.gamma(1, .11, size=mask1.sum())

    # f2 is drawn identically for both classes.
    raw_df['f2'] = rng.normal(10, 5, size=n_samples)

    # np.bincount only accepts integer input; the labels are stored as floats,
    # so cast explicitly. minlength=2 keeps the unpacking valid if a class is empty.
    neg, pos = np.bincount(raw_df['soz'].astype(int), minlength=2)
    total = neg + pos
    print('Examples:\n    Total: {}\n    Positive: {} ({:.2f}% of total)\n'.format(
        total, pos, 100 * pos / total))

    initial_bias = np.log([pos/neg])

    print('data shape (+1 for soz)')
    print(raw_df.shape)
    return raw_df, initial_bias

# Generate the synthetic dataset once; later cells split and scale `raw_df`.
raw_df, initial_bias = attention_fakedata()

Examples:
    Total: 100000
    Positive: 8333 (8.33% of total)

data shape (+1 for soz)
(100000, 3)


  neg, pos = np.bincount(raw_df['soz'])


In [5]:
# Split the data into train / validation / test partitions.
# The split is seeded (the data generator already is) so the whole notebook is
# reproducible, and stratified on `soz` so that every partition keeps the same
# share of the rare positive class.
SPLIT_SEED = 42

train_df, test_df = train_test_split(
    raw_df, test_size=0.2, random_state=SPLIT_SEED, stratify=raw_df['soz'])
train_df, val_df = train_test_split(
    train_df, test_size=0.2, random_state=SPLIT_SEED, stratify=train_df['soz'])

# Form np arrays of labels and features.
# `pop` removes the label column, leaving only the feature columns in each frame.
train_labels = np.array(train_df.pop('soz')).reshape(-1, 1)
bool_train_labels = train_labels[:, 0] != 0
val_labels = np.array(val_df.pop('soz')).reshape(-1, 1)
test_labels = np.array(test_df.pop('soz')).reshape(-1, 1)

train_features = np.array(train_df)
val_features = np.array(val_df)
test_features = np.array(test_df)

# Fit the scaler on the training partition only, then reuse its statistics for
# validation and test so no information leaks from the held-out data.
scaler = StandardScaler()
train_features = scaler.fit_transform(train_features)

val_features = scaler.transform(val_features)
test_features = scaler.transform(test_features)

# train_features = np.clip(train_features, -5, 5)
# val_features = np.clip(val_features, -5, 5)
# test_features = np.clip(test_features, -5, 5)

In [6]:

# Self-attention model mimicking your Keras architecture.
class SelfAttentionModel(nn.Module):
    """Single-head self-attention classifier that treats each feature as a token.

    Each scalar feature is projected to a 3-dimensional embedding, the tokens
    attend to one another, and the flattened attention output is passed through
    layer normalisation and a two-layer MLP ending in a sigmoid.

    Parameters
    ----------
    N_value : int
        Number of input features (tokens) per example.
    """

    def __init__(self, N_value):
        super().__init__()
        self.N_value = N_value
        # Project each feature (token) from 1 -> 3 to mimic key_dim=3.
        self.proj = nn.Linear(1, 3)
        # MultiheadAttention: embed_dim=3, one head, batch_first=True.
        self.attn = nn.MultiheadAttention(embed_dim=3, num_heads=1, batch_first=True)
        # Layer normalization after flattening the attention output.
        self.ln = nn.LayerNorm(N_value * 3)
        # MLP block: first dense layer with 4 * N_value units.
        self.fc1 = nn.Linear(N_value * 3, 4 * N_value)
        # Output layer with one unit.
        self.fc2 = nn.Linear(4 * N_value, 1)

    def forward(self, x):
        """Run the model on a batch.

        Parameters
        ----------
        x : torch.Tensor
            Input of shape ``(batch, N_value)``.

        Returns
        -------
        probs : torch.Tensor
            Sigmoid outputs of shape ``(batch, 1)``.
        attn_weights : torch.Tensor
            Attention weights returned by ``nn.MultiheadAttention``.
        """
        # (batch, N_value) -> (batch, N_value, 1): one token per feature.
        tokens = self.proj(x.unsqueeze(-1))
        # Self-attention (queries, keys, and values are the same).
        attn_output, attn_weights = self.attn(tokens, tokens, tokens)
        # With batch_first=True the attention output is a transposed
        # (non-contiguous) tensor, so `.view` raises for batch > 1; `flatten`
        # copies only when it has to.
        flat = attn_output.flatten(start_dim=1)
        hidden = F.relu(self.fc1(self.ln(flat)))
        # Sigmoid activation to get probabilities.
        probs = torch.sigmoid(self.fc2(hidden))
        return probs, attn_weights

# Early stopping callback implementation.
class EarlyStopping:
    """Track a validation metric and signal when it has stopped improving.

    Parameters
    ----------
    patience : int
        Number of consecutive non-improving calls after which ``early_stop``
        is set to True.
    verbose : bool
        If True, print the counter every time the metric fails to improve.
    mode : {'max', 'min'}
        Whether a larger ('max') or smaller ('min') metric counts as better.

    Attributes
    ----------
    best_score : the best metric value seen so far (None before the first call).
    best_state : a detached copy of the model's state dict at ``best_score``.
    counter : number of consecutive calls without improvement.
    early_stop : True once ``counter`` has reached ``patience``.
    """

    def __init__(self, patience=10, verbose=False, mode='max'):
        if mode not in ('max', 'min'):
            raise ValueError(f"mode must be 'max' or 'min', got {mode!r}")
        self.patience = patience
        self.verbose = verbose
        self.mode = mode
        self.best_score = None
        self.counter = 0
        self.early_stop = False
        self.best_state = None

    @staticmethod
    def _snapshot(model):
        """Return a copy of the model's state dict that later training cannot alter."""
        # state_dict() hands back references to the live parameter tensors, which
        # the optimizer updates in place; clone them so the snapshot stays fixed.
        return {name: tensor.detach().clone()
                for name, tensor in model.state_dict().items()}

    def __call__(self, metric, model):
        """Record ``metric`` for the current epoch and update the stopping state."""
        score = metric
        if self.best_score is None:
            is_worse = False
        elif self.mode == 'max':
            is_worse = score < self.best_score
        else:
            is_worse = score > self.best_score

        if is_worse:
            self.counter += 1
            if self.verbose:
                print(f"EarlyStopping counter: {self.counter} out of {self.patience}")
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            # Ties count as an improvement, matching the original comparison.
            self.best_score = score
            self.best_state = self._snapshot(model)
            self.counter = 0

# Training loop using torchmetrics.
def _build_metrics(device):
    """Create one fresh set of per-epoch metrics, all living on ``device``."""
    binary = torchmetrics.classification
    metrics = {
        'loss': torchmetrics.MeanMetric(),
        # For probabilities against 0/1 targets the mean squared error is the Brier score.
        'brier': torchmetrics.MeanSquaredError(),
        'acc': binary.BinaryAccuracy(),
        'precision': binary.BinaryPrecision(),
        'recall': binary.BinaryRecall(),
        'auc': binary.BinaryAUROC(),
        'prc': binary.BinaryAveragePrecision(),
    }
    return {name: metric.to(device) for name, metric in metrics.items()}


def _run_epoch(model, loader, criterion, metrics, device, optimizer=None):
    """Run one pass over ``loader`` and return the epoch metrics as floats.

    The model is trained when ``optimizer`` is given and only evaluated
    (eval mode, no gradients) otherwise.
    """
    training = optimizer is not None
    model.train(training)
    for metric in metrics.values():
        metric.reset()

    with torch.set_grad_enabled(training):
        for features, labels in loader:
            features = features.to(device).float()
            # Flatten both sides to (batch,) explicitly: a bare squeeze() turns a
            # batch of one into a scalar and leaves (batch, 1) labels mismatched.
            labels = labels.to(device).float().reshape(-1)
            outputs, _ = model(features)
            outputs = outputs.reshape(-1)
            loss = criterion(outputs, labels)

            if training:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            # Metrics must not hold on to the autograd graph.
            probs = outputs.detach()
            targets = labels.int()
            # Weight by batch size so a short final batch does not skew the mean.
            metrics['loss'].update(loss.item(), weight=labels.numel())
            metrics['brier'].update(probs, labels)
            for name in ('acc', 'precision', 'recall', 'auc', 'prc'):
                metrics[name].update(probs, targets)

    return {name: float(metric.compute()) for name, metric in metrics.items()}


def _format_metrics(stats):
    """Render one epoch's metric dict in the notebook's log format."""
    return (f"Loss: {stats['loss']:.4f}, Brier: {stats['brier']:.4f}, "
            f"Acc: {stats['acc']:.4f}, Precision: {stats['precision']:.4f}, "
            f"Recall: {stats['recall']:.4f}, AUC: {stats['auc']:.4f}, "
            f"PRC: {stats['prc']:.4f}")


# Training loop using torchmetrics.
def train_model(model, train_loader, val_loader, epochs=200, lr=1e-3, device='cpu'):
    """Train ``model`` with Adam and binary cross-entropy, with early stopping.

    Parameters
    ----------
    model : nn.Module
        Model whose forward pass returns ``(probabilities, attention_weights)``.
    train_loader, val_loader : iterable
        Loaders yielding ``(features, labels)`` batches with 0/1 labels.
    epochs : int
        Maximum number of epochs.
    lr : float
        Adam learning rate.
    device : str or torch.device
        Device on which the model, batches and metrics are placed.

    Returns
    -------
    nn.Module
        The same model, loaded with the state recorded by the early-stopping
        tracker at the best validation PRC.
    """
    model.to(device)
    optimizer = optim.Adam(model.parameters(), lr=lr)
    criterion = nn.BCELoss()  # Binary crossentropy loss on probabilities

    train_metrics = _build_metrics(device)
    val_metrics = _build_metrics(device)

    early_stopping = EarlyStopping(patience=10, verbose=True, mode='max')

    for epoch in range(epochs):
        train_stats = _run_epoch(model, train_loader, criterion, train_metrics,
                                 device, optimizer=optimizer)
        val_stats = _run_epoch(model, val_loader, criterion, val_metrics, device)

        print(f"Epoch {epoch+1}:")
        print(f"  Train - {_format_metrics(train_stats)}")
        print(f"  Val   - {_format_metrics(val_stats)}")

        # Use the validation PRC metric for early stopping.
        early_stopping(val_stats['prc'], model)
        if early_stopping.early_stop:
            print("Early stopping triggered.")
            break

    # Restore the best recorded weights whether or not training stopped early.
    if early_stopping.best_state is not None:
        model.load_state_dict(early_stopping.best_state)

    return model

# Example usage:
# Assume train_loader and val_loader are DataLoader objects and N_value is the number of features.
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# model = SelfAttentionModel(N_value)
# trained_model = train_model(model, train_loader, val_loader, epochs=200, lr=1e-3, device=device)


In [7]:
# Pick the GPU when one is available and keep the result in `device` so that
# later cells can pass it to the model and the training loop.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

device(type='cuda')

In [None]:
# NOTE(review): this cell uses `model.layers[2]`, `return_attention_scores=True`
# and `history.history`, none of which are created in the visible cells — it
# appears to rely on a Keras model/history from elsewhere; confirm before a
# Restart & Run All.
fig, axes = plt.subplots(4, 3, figsize=(15, 15))

attn = model.layers[2]
x = train_features.reshape(-1, train_features.shape[-1], 1)
# Self-attention over every training row (query and value are both `x`).
_, attention_scores = attn(x, x, return_attention_scores=True)

# Convert once and reuse; the class masks select rows by training label.
scores = attention_scores.numpy()
zero_mask = train_labels[:, 0] == 0
one_mask = train_labels[:, 0] == 1

mean_all = scores.squeeze().mean(axis=0)
mean_0 = scores[zero_mask, 0, :, :].squeeze().mean(axis=0)
mean_1 = scores[one_mask, 0, :, :].squeeze().mean(axis=0)

# Top row: mean attention matrix overall and per class.
heatmap_style = dict(annot=True, cbar=True, square=True, fmt='.2f',
                     cmap='hot', vmin=0, vmax=1)
top_row = (('All', mean_all), ('soz=0', mean_0), ('soz=1', mean_1))
for panel, (title, matrix) in zip(axes[0], top_row):
    sns.heatmap(matrix, ax=panel, **heatmap_style)
    panel.set_title(title)
plt.suptitle(f'Attention Scores | {history.history["prc"][-1]:.2f}')

# Remaining rows: per-class distribution of each (i, j) attention entry.
a = scores.squeeze()
hist_style = dict(bins=20, alpha=0.5, kde=True)
for i in range(a.shape[1]):
    for j in range(a.shape[2]):
        zeros = a[zero_mask, i, j]
        ones = a[one_mask, i, j]

        sns.histplot(zeros, label="zeros", color="blue", ax=axes[i+1, j], **hist_style)
        sns.histplot(ones, label="ones", color="red", ax=axes[i+1, j], **hist_style)


In [10]:
# Close the current matplotlib figure.
plt.close()

In [11]:
# Pull the attention layer's parameters once and pair them up as
# (weights, biases) for the query, key, value and output projections, in the
# order `get_weights()` returns them.
# NOTE(review): `model.layers[2].get_weights()` is not provided by the PyTorch
# model defined in this notebook — presumably a Keras model from elsewhere; confirm.
attn_params = model.layers[2].get_weights()
query, keys, values, projection = (
    {'weights': attn_params[start], 'biases': attn_params[start + 1]}
    for start in range(0, 8, 2)
)





In [None]:
# Inspect the shape of the output-projection weights.
projection['weights'].shape

In [None]:
# Inspect the shape of the value-projection weights.
values['weights'].shape

In [None]:
# Inspect the shape of the output-projection biases.
projection['biases'].shape

In [15]:
import scipy

In [16]:
# Grid lines covering the square [-1, 1] x [-1, 1]: 21 horizontal lines
# (y fixed at -1.0, -0.9, ..., 1.0) followed by 21 vertical lines (x fixed),
# each sampled at 100 points.
ticks = np.linspace(-1, 1, 100)
levels = [k / 10. for k in range(-10, 11)]
horizontal = [np.column_stack((ticks, np.full(100, level))) for level in levels]
vertical = [np.column_stack((np.full(100, level), ticks)) for level in levels]
grids = horizontal + vertical



In [None]:
# Inspect the shape of the first row of the query biases.
query['biases'][0,:].shape

In [None]:
# Scratch check of the query projection on the first training row: the feature
# row is reshaped to (1, n_features), broadcast-multiplied by the transposed
# query weights, and the bias is added as a column.
# NOTE(review): train_df holds the unscaled features (only train_features went
# through the StandardScaler) — confirm this is the input the layer actually saw.
train_df[['f1','f2']].values[0].reshape(1,-1) * query['weights'][0,:,:].T + query['biases'][0,:].reshape(-1,1)

In [None]:
# Display the current value of `a`. More than one cell in this notebook assigns
# `a`, so what is shown depends on which of them ran last.
a

In [None]:
# Draw one training row's (f1, f2) pair as a 2-D vector from the origin.
%matplotlib widget

idx = 1  # row of train_df to visualise; the following cell reads `idx` too
a = train_df[['f1','f2']].values[idx]  # rebinds `a`, which the attention-histogram cell also assigns
print(a)
plt.plot([0,a[0]], [0,a[1]])  # line segment from (0, 0) to (f1, f2)
plt.grid()
plt.show()
#plt.close()

In [None]:
%matplotlib widget
# Apply the query weights and biases to row `idx` (set in the previous cell):
# the feature row is broadcast-multiplied by the transposed weights and the
# bias is added as a column.
a_hat = train_df[['f1','f2']].values[idx] * query['weights'][0,:,:].T + query['biases'][0,:].reshape(-1,1)
print(a_hat)


In [None]:
fig, ax = plt.subplots()

# Grid representing the untransformed "input space": 21 horizontal lines
# followed by 21 vertical lines over [-1, 1] x [-1, 1], 100 points each.
ticks = np.linspace(-1, 1, 100)
levels = [k / 10. for k in range(-10, 11)]
grids = (
    [np.column_stack((ticks, np.full(100, level))) for level in levels]
    + [np.column_stack((np.full(100, level), ticks)) for level in levels]
)

# Draw every grid line and keep the Line2D handles for later cells.
grid_lines = []
for grid in grids:
    vals = np.array(grid)
    (l,) = ax.plot(vals[:, 0], vals[:, 1], color='grey', alpha=.5)
    grid_lines.append(l)

In [None]:
# fig, ax = plt.subplots()
# Scratch cell: apparently meant to draw a transformed copy of each grid line.
# As written, the `vals` computation is commented out and the unconditional
# `break` leaves the loop during the first iteration, so the two statements
# after it never run and `transformed_lines` stays empty.
transformed_lines = []
for k in range(len(grid_lines)):
    ln = grid_lines[k]
    grid = grids[k]
    # vals = grid @ query['weights'][0,:,:].T + query['biases'][0,:]
    break
    l, = ax.plot(vals[:,0],vals[:,1], color='grey', alpha=.5)
    transformed_lines.append(l)

In [None]:
# Inspect the shape of the query weights with the leading axis indexed away.
query['weights'][0,:,:].shape

In [None]:
# grid.shape  @ query['weights'] + query['biases']

In [26]:
# Scratch cell: applies the query projection to the first grid line only (the
# loop ends with an unconditional `break`). `transformed_query` is computed but
# never used — the plot call draws the untransformed `vals`.
for grid in grids:
    vals = np.array(grid)
    # Reshape grid_line to (100, 2, 1)
    grid_line_reshaped = vals.reshape(100, 2, 1)

    # Squeeze the extra dimension from the weight: (1,1,3) -> (1,3)
    # (shape as stated by the original author; not verifiable from this cell)
    query_weight_squeezed = np.squeeze(query['weights'], axis=0)

    # Apply the query transformation:
    transformed_query = grid_line_reshaped @ query_weight_squeezed + query['biases']

    l, = ax.plot(vals[:,0],vals[:,1], color='grey', alpha=.5)
    break

In [34]:
import numpy as np

def full_attention_transform(X, query, keys, values, projection):
    """
    Applies the full single-head self-attention transformation in NumPy.

    Every scalar feature is treated as one token. The function accepts either
    a single example or a batch; each example is attended over independently.

    Parameters:
      X: np.array of shape (batch, tokens), e.g. (1, 2) for one example with
         2 tokens. A 1-D array is treated as a single example; arrays with more
         than two axes are flattened to (batch, tokens).
      query: dict with 'weights' (shape (1,1,3)) and 'biases' (shape (1,3))
      keys: dict with 'weights' (shape (1,1,3)) and 'biases' (shape (1,3))
      values: dict with 'weights' (shape (1,1,3)) and 'biases' (shape (1,3))
      projection: dict with 'weights' (shape (1,3,1)) and 'biases' (shape (1,))

    Returns:
      A tuple (final_output, Q, K, V, attn_weights, attention_output) where
      final_output is the projected output of shape (batch, tokens), Q/K/V are
      (batch, tokens, 3), attn_weights is (batch, tokens, tokens) and
      attention_output is (batch, tokens, 3).
    """
    X = np.asarray(X)
    if X.ndim < 2:
        X = X.reshape(1, -1)
    elif X.ndim > 2:
        X = X.reshape(X.shape[0], -1)
    batch = X.shape[0]

    # One token per feature: (batch, tokens) -> (batch, tokens, 1).
    tokens = X[..., np.newaxis]

    # Drop the leading singleton axis: weights become (1,3), biases (3,).
    W_q, b_q = (np.squeeze(query[key], axis=0) for key in ('weights', 'biases'))
    W_k, b_k = (np.squeeze(keys[key], axis=0) for key in ('weights', 'biases'))
    W_v, b_v = (np.squeeze(values[key], axis=0) for key in ('weights', 'biases'))
    # Projection weights become (3,1); the bias becomes a scalar.
    W_p, b_p = (np.squeeze(projection[key], axis=0) for key in ('weights', 'biases'))

    # Per-token linear maps; matmul broadcasts over the batch axis -> (batch, tokens, 3).
    Q = np.matmul(tokens, W_q) + b_q
    K = np.matmul(tokens, W_k) + b_k
    V = np.matmul(tokens, W_v) + b_v

    # Scaled dot-product scores: (batch, tokens, 3) @ (batch, 3, tokens).
    d_k = Q.shape[-1]
    scores = np.matmul(Q, np.transpose(K, (0, 2, 1))) / np.sqrt(d_k)

    # Softmax over the key axis; subtracting the row maximum keeps exp() stable.
    shifted = np.exp(scores - np.max(scores, axis=-1, keepdims=True))
    attn_weights = shifted / np.sum(shifted, axis=-1, keepdims=True)

    # Weighted sum of the value vectors -> (batch, tokens, 3).
    attention_output = np.matmul(attn_weights, V)

    # Project each token's 3-vector to a scalar -> (batch, tokens, 1), then flatten.
    proj_output = np.matmul(attention_output, W_p) + b_p
    final_output = proj_output.reshape(batch, -1)

    return final_output, Q, K, V, attn_weights, attention_output


In [None]:
# Unpack the attention parameters, dropping the leading singleton axis of each
# array. Per the notes in full_attention_transform, each weight maps 1 -> 3, so
# W_q/W_k/W_v are expected to be (1,3) with biases (3,), W_p (3,1) and b_p a
# scalar — TODO confirm against the actual layer.
_param_sets = (query, keys, values, projection)
W_q, W_k, W_v, W_p = (np.squeeze(params['weights'], axis=0) for params in _param_sets)
b_q, b_k, b_v, b_p = (np.squeeze(params['biases'], axis=0) for params in _param_sets)

# Project the whole (scaled) training set at once: x is (n_samples, n_features, 1),
# so each of Q, K, V has one projected vector per feature token per sample.
x = train_features.reshape(-1, train_features.shape[-1], 1)
Q = np.matmul(x, W_q) + b_q
K = np.matmul(x, W_k) + b_k
V = np.matmul(x, W_v) + b_v

In [None]:
# Flatten each sample's per-token query vectors into a single row:
# (n, tokens, dim) -> (n, tokens * dim).
flat_width = Q.shape[1] * Q.shape[2]
Z = Q.reshape(-1, flat_width)
Z.shape

In [54]:
# Run the attention layer on the training set, then reproduce the computation
# for the first example with the NumPy implementation.
# NOTE(review): `model.layers[2]` / `return_attention_scores=True` are not
# provided by the PyTorch model defined in this notebook — presumably a Keras
# model from elsewhere; confirm.
attn = model.layers[2]
x = train_features.reshape(-1, train_features.shape[-1], 1)
out_vec, attention_scores = attn(x,x, return_attention_scores=True) # x holds every training row, not a single sample



# X is the first training example flattened to one row of feature tokens

X = x[0,:,:].reshape(1,-1)

# This rebinds Q, K and V to the arrays returned for the single example X,
# replacing any Q, K, V computed in another cell.
output, Q, K, V, attn_weights, attention_output = full_attention_transform(X, query, keys, values, projection)
# print("Final output shape:", output.shape)
# print("Final output:", output)
# print("Final output shape:", output.shape)
# print("Final output:", output)


In [None]:
import torch
import torch.nn as nn
import einops

# D = d_model, F = dictionary_size
# e.g. if d_model = 12288 and dictionary_size = 49152
# then model_activations_D.shape = (12288,)
# encoder_DF.weight.shape = (12288, 49152)

class SparseAutoEncoder(nn.Module):
    """
    A one-layer autoencoder.

    The encoder maps activations of size ``activation_dim`` to a non-negative
    code of size ``dict_size`` (linear layer followed by ReLU); the decoder is
    a linear map back to ``activation_dim``.
    """
    def __init__(self, activation_dim: int, dict_size: int):
        super().__init__()
        self.activation_dim = activation_dim
        self.dict_size = dict_size

        self.encoder_DF = nn.Linear(activation_dim, dict_size, bias=True)
        self.decoder_FD = nn.Linear(dict_size, activation_dim, bias=True)

    def encode(self, model_activations_D: torch.Tensor) -> torch.Tensor:
        """Return the ReLU-activated dictionary code for the given activations."""
        # Functional ReLU avoids building a new nn.ReLU module on every call.
        return torch.relu(self.encoder_DF(model_activations_D))

    def decode(self, encoded_representation_F: torch.Tensor) -> torch.Tensor:
        """Reconstruct activations from a dictionary code."""
        return self.decoder_FD(encoded_representation_F)

    def forward_pass(self, model_activations_D: torch.Tensor):
        """Encode then decode; returns ``(reconstruction, code)``."""
        encoded_representation_F = self.encode(model_activations_D)
        reconstructed_model_activations_D = self.decode(encoded_representation_F)
        return reconstructed_model_activations_D, encoded_representation_F

    def forward(self, model_activations_D: torch.Tensor):
        """Same as ``forward_pass``, so that calling the module directly works."""
        return self.forward_pass(model_activations_D)
    

# B = batch size, D = d_model, F = dictionary_size
def calculate_loss(autoencoder: "SparseAutoEncoder", model_activations_BD: torch.Tensor, l1_coefficient: float) -> torch.Tensor:
    """Sparse-autoencoder loss: reconstruction error plus an L1 sparsity penalty.

    Both terms are computed per sample and then averaged over the batch, so
    the balance set by ``l1_coefficient`` does not depend on the batch size.

    Parameters
    ----------
    autoencoder : object exposing ``forward_pass(x) -> (reconstruction, code)``.
    model_activations_BD : tensor of shape (batch, activation_dim).
    l1_coefficient : weight of the sparsity penalty.

    Returns
    -------
    Scalar tensor: ``mean_b(sum_d err^2) + l1_coefficient * mean_b(sum_f code)``.
    """
    reconstructed_model_activations_BD, encoded_representation_BF = autoencoder.forward_pass(model_activations_BD)

    # Squared reconstruction error, summed over features and averaged over the batch.
    reconstruction_error_BD = (reconstructed_model_activations_BD - model_activations_BD).pow(2)
    reconstruction_error_B = reconstruction_error_BD.sum(dim=-1)
    l2_loss = reconstruction_error_B.mean()

    # The code comes out of the autoencoder's encode step; its per-sample sum
    # is used as the L1 norm. Average over the batch like the L2 term, instead
    # of summing over every sample.
    sparsity_B = encoded_representation_BF.sum(dim=-1)
    l1_loss = l1_coefficient * sparsity_B.mean()

    loss = l2_loss + l1_loss
    return loss

In [None]:

# Hyperparameters
activation_dim = Z.shape[1]
# NOTE(review): this makes the dictionary as wide as the number of samples in Z.
# A dictionary size is usually chosen as a multiple of activation_dim — confirm
# this is intended.
dict_size = Z.shape[0]
learning_rate = 1e-3
batch_size = 128
num_epochs = 10
l1_coefficient = 1e-5

# Instantiate the model
autoencoder = SparseAutoEncoder(activation_dim, dict_size)

# Z is built with NumPy in this notebook, but TensorDataset needs tensors and
# the autoencoder's float32 weights need float32 inputs. Convert into a new
# name so Z itself is left untouched and the cell can be re-run.
Z_tensor = torch.as_tensor(Z, dtype=torch.float32)

# Create a dataset and dataloader
dataset = torch.utils.data.TensorDataset(Z_tensor)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)

# Define the optimizer
optimizer = torch.optim.Adam(autoencoder.parameters(), lr=learning_rate)

# Training loop
for epoch in range(num_epochs):
    epoch_loss = 0.0
    for (batch_data,) in dataloader:
        optimizer.zero_grad()
        loss = calculate_loss(autoencoder, batch_data, l1_coefficient)
        loss.backward()
        optimizer.step()
        # Weight each batch by its size so the epoch figure is a per-sample mean.
        epoch_loss += loss.item() * batch_data.size(0)
    epoch_loss /= len(dataset)
    print(f'Epoch {epoch + 1}/{num_epochs}, Loss: {epoch_loss:.4f}')