Data downloaded from (access data 09/21/2025): https://www.uniprot.org/help/embeddings?utm_source=chatgpt.com

In [3]:
import h5py
import numpy as np

# Path to your downloaded file
file_path = "./data/per-protein.h5"

# Open the HDF5 file
with h5py.File(file_path, "r") as f:
    # List all protein IDs stored
    protein_ids = list(f.keys())
    print(f"Number of proteins: {len(protein_ids)}")
    print("First 5 IDs:", protein_ids[:5])

    # Access embeddings for a specific protein (by UniProt ID)
    prot_id = protein_ids[0]  # pick the first one
    embedding = f[prot_id][()]  # retrieve as a numpy array

    print(f"\nProtein ID: {prot_id}")
    print(f"Embedding shape: {embedding.shape}")
    print(embedding[:10])  # show first 10 numbers


Number of proteins: 20660
First 5 IDs: ['A0A024R1R8', 'A0A024RBG1', 'A0A024RCN7', 'A0A075B6H5', 'A0A075B6H7']

Protein ID: A0A024R1R8
Embedding shape: (1024,)
[ 0.013664  0.03848   0.03925  -0.0675   -0.0247    0.03268  -0.0171
 -0.10223   0.014656  0.01147 ]


In [5]:
# JUPYTER CELL: Gumbel-Top-k → anneal to hard Top-k (no aux loss)
# ----------------------------------------------------------------------------------
# This notebook cell implements a sparse autoencoder that:
# - Uses per-sample LayerNorm (batch-independent inference)
# - Starts with a relaxed (Gumbel-Top-k-inspired) soft activation
# - Anneals temperature and mixes into hard Top-k over training
# - Trains on per-protein embeddings loaded from an HDF5 file
# - Splits into train/val/test (80/10/10)
# - Logs & plots MSE and MAE vs iteration (train/val/test)
#
# Assumes your HDF5 file structure is:
#   keys: UniProt IDs (e.g., "A0A024R1R8")
#   values: float arrays of shape (1024,)
#
# Example info you pasted:
# Number of proteins: 20660
# First 5 IDs: ['A0A024R1R8', 'A0A024RBG1', 'A0A024RCN7', 'A0A075B6H5', 'A0A075B6H7']
# One embedding shape: (1024,)
#
# You can customize paths & hyperparameters in the CONFIG section below.

from __future__ import annotations
from typing import Callable, Any, Tuple, Dict, List, Optional

import os
import math
import h5py
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import matplotlib.pyplot as plt

# =========================
# CONFIG (edit as desired)
# =========================
H5_PATH       = "./data/per-protein.h5"  # path to your HDF5 embeddings
TOP_K         = 32                       # desired k
HIDDEN_MULT   = 4                        # hidden_dim = HIDDEN_MULT * input_dim
BATCH_SIZE    = 1024
EPOCHS        = 10
LR            = 1e-3
EVAL_EVERY    = 200                      # evaluate/log every N train iterations
NUM_WORKERS   = 2
SPLIT_SEED    = 123

# Relaxed → Hard annealing
TAU_START     = 1.0                      # initial temperature for soft (higher = softer)
TAU_END       = 0.05                     # final temperature for soft (near-hard)
WARMUP_STEPS  = 2000                     # number of iterations to anneal tau & mix alpha
# alpha schedule: 0 → 1 over WARMUP_STEPS (0 = fully soft, 1 = fully hard)

# =========================
# Utilities & reproducibility
# =========================
def set_seed(seed=42):
    import random
    random.seed(seed); np.random.seed(seed)
    torch.manual_seed(seed); torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

set_seed(42)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# =========================
# Per-sample LayerNorm (batch-independent)
# =========================
def LN(x: torch.Tensor, eps: float = 1e-5) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Per-sample layer-norm across features (last dim).
    Returns (x_norm, mu, std) for undoing later.
    """
    mu = x.mean(dim=-1, keepdim=True)
    xc = x - mu
    std = torch.sqrt((xc * xc).mean(dim=-1, keepdim=True) + eps)
    x_norm = xc / std
    return x_norm, mu, std

# =========================
# Activation helpers
# =========================
def hard_topk_relu(x: torch.Tensor, k: int) -> torch.Tensor:
    """
    Hard Top-k after ReLU (non-differentiable selection, standard STE-like behavior).
    """
    x = F.relu(x)
    if k >= x.shape[-1]:
        return x
    vals, idxs = torch.topk(x, k=k, dim=-1, largest=True, sorted=False)
    out = torch.zeros_like(x)
    out.scatter_(-1, idxs, vals)
    return out

def gumbel_noise_like(x: torch.Tensor, eps: float = 1e-9) -> torch.Tensor:
    u = torch.rand_like(x)
    return -torch.log(-torch.log(u + eps) + eps)

def cosine_anneal(start: float, end: float, t: float) -> float:
    """
    Cosine annealing from start→end with t in [0,1].
    """
    return end + 0.5*(start - end)*(1 + math.cos(math.pi * t))

class RelaxedTopK(nn.Module):
    """
    Gumbel-Top-k inspired soft activation + annealing into hard Top-k.

    Forward uses:
      - Soft branch: softmax((ReLU(z) + gumbel)/tau) as a distribution over units,
        then produces a soft activation vector whose L1 mass roughly matches the
        "energy" in ReLU(z). We form h_soft = soft_probs * sum(ReLU(z)) to keep scale.
      - Hard branch: standard hard_topk_relu(z, k).
      - Mix: h = (1 - alpha) * h_soft + alpha * h_hard, with alpha rising from 0→1
        over WARMUP_STEPS; tau decays from TAU_START→TAU_END over the same window.
    """
    def __init__(self, k: int, tau_start: float, tau_end: float, warmup_steps: int):
        super().__init__()
        self.k = int(k)
        self.tau_start = float(tau_start)
        self.tau_end = float(tau_end)
        self.warmup_steps = int(warmup_steps)
        self.register_buffer("step", torch.tensor(0, dtype=torch.long), persistent=False)

    def update_step(self):
        self.step += 1

    def _tau_alpha(self) -> Tuple[float, float]:
        # Progress in [0, 1]
        t = float(min(int(self.step.item()), self.warmup_steps)) / max(1, self.warmup_steps)
        tau = cosine_anneal(self.tau_start, self.tau_end, t)
        alpha = t  # linear 0 → 1
        return tau, alpha

    def forward(self, z: torch.Tensor) -> torch.Tensor:
        tau, alpha = self._tau_alpha()
        # Soft branch
        z_pos = F.relu(z)
        if self.training:
            g = gumbel_noise_like(z_pos)
        else:
            # At eval, no gumbel noise; still allow soft if not fully annealed
            g = torch.zeros_like(z_pos)

        logits = (z_pos + g) / max(tau, 1e-6)
        soft_probs = torch.softmax(logits, dim=-1)  # distribution over units
        mass = z_pos.sum(dim=-1, keepdim=True).clamp_min(1e-6)
        h_soft = soft_probs * mass  # soft "allocation" of mass

        # Hard branch
        h_hard = hard_topk_relu(z, self.k)

        # Mix
        h = (1.0 - alpha) * h_soft + alpha * h_hard
        # Update internal step counter
        if self.training:
            self.update_step()
        return h

# =========================
# Autoencoder (with pre-bias, tied/untied decoder, optional normalize)
# =========================
class TiedTranspose(nn.Module):
    def __init__(self, linear: nn.Linear):
        super().__init__()
        self.linear = linear
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        assert self.linear.bias is None
        return F.linear(x, self.linear.weight.t(), bias=None)
    @property
    def weight(self) -> torch.Tensor:
        return self.linear.weight.t()

class Autoencoder(nn.Module):
    """
    x' = preprocess(x)                 # optional per-sample LN
    z  = (x' - pre_bias) @ W_enc^T + latent_bias
    h  = activation(z)                 # RelaxedTopK (anneals to hard)
    y' = decoder(h) + pre_bias
    y  = unprocess(y')                 # undo LN (if applied)
    """
    def __init__(
        self,
        n_latents: int,
        n_inputs: int,
        activation: nn.Module,
        tied: bool = False,
        normalize: bool = True,
        enc_init: str = "xavier",
        dec_init: str = "xavier",
    ):
        super().__init__()
        self.normalize = normalize
        self.pre_bias = nn.Parameter(torch.zeros(n_inputs))
        self.encoder = nn.Linear(n_inputs, n_latents, bias=False)
        if enc_init == "kaiming":
            nn.init.kaiming_uniform_(self.encoder.weight, a=math.sqrt(5))
        else:
            nn.init.xavier_uniform_(self.encoder.weight)

        self.latent_bias = nn.Parameter(torch.zeros(n_latents))

        self.activation = activation

        if tied:
            self.decoder = TiedTranspose(self.encoder)
        else:
            self.decoder = nn.Linear(n_latents, n_inputs, bias=False)
            if dec_init == "kaiming":
                nn.init.kaiming_uniform_(self.decoder.weight, a=math.sqrt(5))
            else:
                nn.init.xavier_uniform_(self.decoder.weight)

    def preprocess(self, x: torch.Tensor) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        if not self.normalize:
            return x, {}
        x_norm, mu, std = LN(x)
        return x_norm, {"mu": mu, "std": std}

    def encode_pre_act(self, x: torch.Tensor) -> torch.Tensor:
        x = x - self.pre_bias
        return F.linear(x, self.encoder.weight, self.latent_bias)

    def decode(self, h: torch.Tensor, info: Optional[Dict[str, torch.Tensor]] = None) -> torch.Tensor:
        y_prime = self.decoder(h) + self.pre_bias
        if self.normalize:
            assert info is not None
            y_prime = y_prime * info["std"] + info["mu"]
        return y_prime

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        x_p, info = self.preprocess(x)
        z = self.encode_pre_act(x_p)
        h = self.activation(z)
        y = self.decode(h, info)
        return z, h, y

# =========================
# Data (HDF5) + splits
# =========================
class H5ProteinEmbeddings(Dataset):
    """HDF5 with {protein_id: embedding_vector}."""
    def __init__(self, h5_path: str, ids: Optional[List[str]] = None):
        super().__init__()
        self.h5_path = h5_path
        with h5py.File(h5_path, "r") as f:
            all_ids = list(f.keys())
        self.ids = all_ids if ids is None else ids
        with h5py.File(h5_path, "r") as f:
            self.input_dim = int(f[self.ids[0]][()].shape[-1])
        self._f = None

    def __len__(self): return len(self.ids)

    def __getitem__(self, i: int):
        if self._f is None:
            self._f = h5py.File(self.h5_path, "r")
        x = self._f[self.ids[i]][()].astype("float32")
        return torch.from_numpy(x)

# =========================
# Evaluation
# =========================
@torch.no_grad()
def evaluate(model: nn.Module, loader: DataLoader, device: torch.device) -> Tuple[float, float]:
    model.eval()
    tot_mse = 0.0
    tot_mae = 0.0
    n_elems = 0
    for xb in loader:
        xb = xb.to(device)
        _, _, yb = model(xb)
        tot_mse += F.mse_loss(yb, xb, reduction="sum").item()
        tot_mae += F.l1_loss(yb, xb, reduction="sum").item()
        n_elems += xb.numel()
    return tot_mse / n_elems, tot_mae / n_elems

# =========================
# Training loop
# =========================
def train_relaxed_topk_sae(
    h5_path=H5_PATH,
    k=TOP_K,
    hidden_mult=HIDDEN_MULT,
    batch_size=BATCH_SIZE,
    epochs=EPOCHS,
    lr=LR,
    eval_every=EVAL_EVERY,
    num_workers=NUM_WORKERS,
    split_seed=SPLIT_SEED,
    tau_start=TAU_START,
    tau_end=TAU_END,
    warmup_steps=WARMUP_STEPS,
    tied=False,
    normalize=True,
):
    # ---- Load IDs & split 80/10/10 ----
    with h5py.File(h5_path, "r") as f:
        ids = list(f.keys())
    rng = np.random.default_rng(split_seed)
    rng.shuffle(ids)
    n = len(ids); n_tr = int(0.8*n); n_va = int(0.1*n)
    tr_ids = ids[:n_tr]; va_ids = ids[n_tr:n_tr+n_va]; te_ids = ids[n_tr+n_va:]

    train_ds = H5ProteinEmbeddings(h5_path, tr_ids)
    val_ds   = H5ProteinEmbeddings(h5_path, va_ids)
    test_ds  = H5ProteinEmbeddings(h5_path, te_ids)

    input_dim = train_ds.input_dim
    hidden_dim = hidden_mult * input_dim

    train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True,  num_workers=num_workers, pin_memory=True)
    val_loader   = DataLoader(val_ds,   batch_size=batch_size, shuffle=False, num_workers=num_workers, pin_memory=True)
    test_loader  = DataLoader(test_ds,  batch_size=batch_size, shuffle=False, num_workers=num_workers, pin_memory=True)

    # ---- Model ----
    act = RelaxedTopK(k=k, tau_start=tau_start, tau_end=tau_end, warmup_steps=warmup_steps)
    model = Autoencoder(
        n_latents=hidden_dim,
        n_inputs=input_dim,
        activation=act,
        tied=tied,
        normalize=normalize,
        enc_init="xavier",
        dec_init="xavier",
    ).to(device)

    # ---- Optimizer ----
    opt = torch.optim.Adam(model.parameters(), lr=lr)

    # ---- History ----
    hist: Dict[str, List[float]] = {
        "iter": [], "train_mse": [], "val_mse": [], "test_mse": [],
        "train_mae": [], "val_mae": [], "test_mae": []
    }

    it = 0
    for ep in range(1, epochs + 1):
        model.train()
        for xb in train_loader:
            xb = xb.to(device, non_blocking=True)
            _, _, yb = model(xb)
            loss = F.mse_loss(yb, xb, reduction="mean")
            opt.zero_grad(); loss.backward(); opt.step()

            it += 1
            if it % eval_every == 0:
                train_mse = loss.item()
                train_mae = F.l1_loss(yb.detach(), xb, reduction="mean").item()
                val_mse, val_mae   = evaluate(model, val_loader, device)
                test_mse, test_mae = evaluate(model, test_loader, device)

                hist["iter"].append(it)
                hist["train_mse"].append(train_mse)
                hist["val_mse"].append(val_mse)
                hist["test_mse"].append(test_mse)
                hist["train_mae"].append(train_mae)
                hist["val_mae"].append(val_mae)
                hist["test_mae"].append(test_mae)

        # End-of-epoch status
        if hist["iter"]:
            print(f"Epoch {ep:02d} | last train MSE {hist['train_mse'][-1]:.6f} | "
                  f"val MSE {hist['val_mse'][-1]:.6f} | test MSE {hist['test_mse'][-1]:.6f}")
        else:
            print(f"Epoch {ep:02d} | (no eval yet; decrease EVAL_EVERY to log sooner)")

    # ---- Save weights ----
    torch.save(
        {"state_dict": model.state_dict(),
         "config": {
             "input_dim": input_dim,
             "hidden_dim": hidden_dim,
             "k": k,
             "tau_start": tau_start,
             "tau_end": tau_end,
             "warmup_steps": warmup_steps,
             "tied": tied,
             "normalize": normalize,
         }},
        "relaxed_topk_sae.pt",
    )

    # ---- Plots ----
    if hist["iter"]:
        plt.figure()
        plt.plot(hist["iter"], hist["train_mse"], label="train MSE")
        plt.plot(hist["iter"], hist["val_mse"],   label="val MSE")
        plt.plot(hist["iter"], hist["test_mse"],  label="test MSE")
        plt.xlabel("Iteration"); plt.ylabel("MSE"); plt.title("MSE vs Iteration (Relaxed→Hard Top-k SAE)"); plt.legend()
        plt.savefig("relaxed_topk_sae_mse.png", bbox_inches="tight"); plt.close()

        plt.figure()
        plt.plot(hist["iter"], hist["train_mae"], label="train MAE")
        plt.plot(hist["iter"], hist["val_mae"],   label="val MAE")
        plt.plot(hist["iter"], hist["test_mae"],  label="test MAE")
        plt.xlabel("Iteration"); plt.ylabel("MAE"); plt.title("MAE vs Iteration (Relaxed→Hard Top-k SAE)"); plt.legend()
        plt.savefig("relaxed_topk_sae_mae.png", bbox_inches="tight"); plt.close()

    return model, hist

# =========================
# Run training
# =========================
# model, history = train_relaxed_topk_sae()
# After running:
#   - Check "relaxed_topk_sae.pt" for weights & config
#   - "relaxed_topk_sae_mse.png" and "relaxed_topk_sae_mae.png" for curves



In [6]:
model, history = train_relaxed_topk_sae()

Traceback (most recent call last):
  File "<string>", line 1, in <module>
  File "/Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/multiprocessing/spawn.py", line 116, in spawn_main
    exitcode = _main(fd, parent_sentinel)
  File "/Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/multiprocessing/spawn.py", line 126, in _main
    self = reduction.pickle.load(from_parent)
AttributeError: Can't get attribute 'H5ProteinEmbeddings' on <module '__main__' (built-in)>
Traceback (most recent call last):
  File "<string>", line 1, in <module>
  File "/Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/multiprocessing/spawn.py", line 116, in spawn_main
    exitcode = _main(fd, parent_sentinel)
  File "/Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/multiprocessing/spawn.py", line 126, in _main
    self = reduction.pickle.load(from_parent)
AttributeError: Can't get attribute 'H5ProteinEmbeddings' on <module '__main__' (built-in)>


RuntimeError: DataLoader worker (pid(s) 1859, 1860) exited unexpectedly