# GPH - Genomic Pretrained Model with Hyena

**Author:** Riley Xin

**Date:** May 18, 2025

**Class:** GENE46100

---

This notebook implements a **Hyena-based genomic language model**, adapted from:

- [The Annotated Hyena Blog Post](https://medium.com/autonomous-agents/evo2-demystified-the-ultimate-technical-guide-to-genomic-language-modeling-a75b0afe7b87)  
- [The Annotated Hyena Notebook](https://github.com/expz/annotated-hyena/blob/master/annotated_hyena.ipynb)  
- Henry’s nanoGPT notebook
- Andrej Karpathy's [nanoGPT framework](https://github.com/karpathy/nanoGPT)

The code is configured to run on Google Colab with a T4 GPU. It includes an implementation of the Hyena operator and a transformer-style backbone similar to nanoGPT, in which the self-attention block is replaced directly by a Hyena block. The model is first pretrained on human genome FASTA data (hg38) and then fine-tuned on an enhancer classification task.


In [None]:
#@title Setup
%%capture
# !pip install torch numpy datasets lightning huggingface_hub regex tiktoken wandb tqdm accelerate
import os
import pickle
import requests
import numpy as np
import pandas as pd
import math
import random
import inspect
import regex
# import lightning
from typing import Any, Dict, List, Optional, Tuple, Type
from dataclasses import dataclass, asdict
import torch
import torch.nn as nn
from torch.nn import functional as F
import gc
import time
from contextlib import nullcontext
from torch.nn.parallel import DistributedDataParallel as DDP
from datasets import load_dataset, Dataset
from torch.utils.data import DataLoader, TensorDataset
from sklearn.metrics import matthews_corrcoef, f1_score
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt

In [None]:
# Some issues related to load_dataset having to do with package versions on colab
# !pip install -U datasets

In [None]:
#@title Mount to drive and prepare the genome data
from google.colab import drive
drive.mount('/content/drive')

DIR = "/content/drive/MyDrive/UChicago/Spring2025/GENE46100/project"

data_dir = os.path.join(DIR, "data")
os.makedirs(data_dir, exist_ok=True)  # make sure the target folder exists before downloading
fasta_file = os.path.join(data_dir, "genome.fa")
!wget -O - https://hgdownload.soe.ucsc.edu/goldenPath/hg38/bigZips/hg38.fa.gz | gunzip -c > {fasta_file}

In [None]:
fasta_file = os.path.join(data_dir, "genome.fa")
with open(fasta_file, 'r') as f:
    data = f.read()
print(f"length of dataset in characters: {len(data):,}")

chars = sorted(list(set(data)))
vocab_size = len(chars)
print("all the unique characters:", ''.join(chars))
print(f"vocab size: {vocab_size:,}")
data[:1]

length of dataset in characters: 3,273,481,150
all the unique characters: 
0123456789>ABCGHIJKLMNTUXY_acdghlmnortv
vocab size: 40


'>'
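
The vocabulary has 40 symbols rather than 4 because the FASTA file is read verbatim: header lines (which start with `>`), newlines, soft-masked lowercase bases, and `N` all become tokens alongside `A`, `C`, `G`, and `T`. The toy example below illustrates the effect on a made-up two-record FASTA string. The string `toy_fasta` and the contig names inside it are illustrative assumptions, not data taken from hg38.

```python
# Toy FASTA text (hypothetical records, for illustration only).
toy_fasta = ">chr1\nACGTNNacgt\n>chrUn_random\nGGCCttaa\n"

# Character-level vocabulary, built the same way as in the cell above.
toy_chars = sorted(set(toy_fasta))
toy_stoi = {ch: i for i, ch in enumerate(toy_chars)}

# Characters that are not nucleotide symbols come from headers and line breaks.
nucleotides = set("ACGTNacgtn")
non_sequence = [ch for ch in toy_chars if ch not in nucleotides]

print("vocab size:", len(toy_chars))
print("non-sequence characters:", non_sequence)
```

Filtering out header lines and newlines before tokenizing would shrink the vocabulary, at the cost of departing from the raw-file setup used in this notebook.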

In [None]:
#@title Tokenization
# create a mapping from characters to integers
stoi = { ch:i for i,ch in enumerate(chars) }
itos = { i:ch for i,ch in enumerate(chars) }

def encode(s):
    return [stoi[c] for c in s] # encoder: take a string, output a list of integers
def decode(l):
    return ''.join([itos[i] for i in l]) # decoder: take a list of integers, output a string

s = "ACGTACGT"
print(encode(s))
print(decode(encode(s)))

In [None]:
# Create the train and validation splits
n = len(data)
train_data = data[:int(n*0.9)]
val_data = data[int(n*0.9):]

# Helper function to store/append data to binary files and save RAM
def append_to_binary_file(filename, obj):
    with open(filename, 'ab') as file:  # Open in append-binary mode
        obj.tofile(file)

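# Encode both splits to integers in 20 chunks to limit peak RAM.
# Integer chunk boundaries avoid the off-by-one gaps or overlaps that
# floating-point fractions (i, i + 0.05) can introduce.
# Note: the .bin files are opened in append mode, so delete them before rerunning.
ntrain = len(train_data)
nval = len(val_data)
n_chunks = 20
train_bounds = np.linspace(0, ntrain, n_chunks + 1, dtype=np.int64)
val_bounds = np.linspace(0, nval, n_chunks + 1, dtype=np.int64)
for k in range(n_chunks):
    print(f"Tokenizing: {k * 100 // n_chunks:,}% complete")
    train_ids = encode(train_data[train_bounds[k]:train_bounds[k + 1]])
    val_ids = encode(val_data[val_bounds[k]:val_bounds[k + 1]])
    train_ids = np.array(train_ids, dtype=np.uint16)
    val_ids = np.array(val_ids, dtype=np.uint16)
    append_to_binary_file(os.path.join(data_dir, 'train.bin'), train_ids)
    append_to_binary_file(os.path.join(data_dir, 'val.bin'), val_ids)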

print(f"train has {ntrain:,} tokens")
print(f"val has {nval:,} tokens")

Tokenizing: 0% complete
Tokenizing: 5% complete
Tokenizing: 10% complete
Tokenizing: 15% complete
Tokenizing: 20% complete
Tokenizing: 25% complete
Tokenizing: 30% complete
Tokenizing: 35% complete
Tokenizing: 40% complete
Tokenizing: 45% complete
Tokenizing: 50% complete
Tokenizing: 55% complete
Tokenizing: 60% complete
Tokenizing: 65% complete
Tokenizing: 70% complete
Tokenizing: 75% complete
Tokenizing: 80% complete
Tokenizing: 85% complete
Tokenizing: 90% complete
Tokenizing: 95% complete
train has 2,946,133,035 tokens
val has 327,348,115 tokens


In [None]:
# Save the meta information as well, to help us encode/decode later
meta = {
    'vocab_size': vocab_size,
    'itos': itos,
    'stoi': stoi,
}
with open(os.path.join(data_dir, 'meta.pkl'), 'wb') as f:
    pickle.dump(meta, f)

In [None]:
#@title Load saved data
# Avoid rerunning the data processing steps
data_dir = os.path.join(DIR, "data")

# Load the meta information (for encoding/decoding)
with open(os.path.join(data_dir, 'meta.pkl'), 'rb') as f:
    meta = pickle.load(f)
stoi = meta['stoi']
itos = meta['itos']
vocab_size = meta['vocab_size']

# Define encode/decode functions
encode = lambda s: [stoi[c] for c in s]
decode = lambda l: ''.join([itos[i] for i in l])

# Load the tokenized train/val data
train_data = np.memmap(os.path.join(data_dir, 'train.bin'), dtype=np.uint16, mode='r')
val_data = np.memmap(os.path.join(data_dir, 'val.bin'), dtype=np.uint16, mode='r')

print(f"Loaded train_data with {len(train_data):,} tokens")
print(f"Loaded val_data with {len(val_data):,} tokens")

Loaded train_data with 2,946,133,035 tokens
Loaded val_data with 327,348,115 tokens


# Creating the Model

Now that we’ve loaded our training data, we turn to constructing the full GP-Hyena (GPH) model. The architecture is modeled after nanoGPT, with the standard attention mechanism replaced by the Hyena operator, which scales subquadratically with sequence length and is therefore expected to model long-range dependencies more efficiently. Below is an overview of the model’s structure and components.

The main model is implemented in a class called GPH (Genomic Pretrained Model with Hyena). It consists of several core components:

* **Embedding**: Genomic tokens (from a 40-character vocabulary) are projected into a higher-dimensional latent space using token and position embeddings. This allows the model to represent more complex patterns in sequence data, including positional context.

* **LayerNorm**: These normalization layers stabilize training by maintaining feature-wise consistency in scale, applied before Hyena operations and MLPs.

* **Dropout**: Applied both after embeddings and within blocks, dropout helps regularize training and prevent overfitting.

* **Hyena Block**: The central part of GP-Hyena. Instead of using self-attention, this block uses a long convolutional operator (Hyena) to model sequence dependencies in a more scalable way. Each Hyena block combines an input projection followed by a short depthwise convolution, a set of learned long filters applied via FFT-based convolution, and elementwise gating, which supplies the nonlinearity.

* **MLP**: A feedforward module consisting of a two-layer linear projection with a GELU activation in between. It adds non-linearity and capacity to the model.

* **Block**: A full processing unit in the model, defined as: LayerNorm → Hyena Block → LayerNorm → MLP, with a residual connection around each of the two sublayers. These are stacked multiple times (e.g. 6 blocks for a 6-layer model).

* **Final LayerNorm and Output Head**: After passing through all blocks, the output is normalized again and projected to the vocabulary space by a linear head whose weights are tied to the token embedding. When targets are supplied, a cross-entropy loss is computed as well.
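
To make the FFT-based long convolution in the Hyena block concrete, the sketch below checks numerically that zero-padding to length 2L, multiplying in the frequency domain, and keeping the first L outputs reproduces a direct causal convolution. It uses NumPy rather than the notebook's PyTorch module, and the function names `causal_conv_direct` and `causal_conv_fft` are illustrative, not part of the notebook.

```python
import numpy as np

def causal_conv_direct(h, x):
    """y[t] = sum_{s=0..t} h[s] * x[t - s], computed in O(L^2)."""
    L = len(x)
    return np.array([sum(h[s] * x[t - s] for s in range(t + 1)) for t in range(L)])

def causal_conv_fft(h, x):
    """Same result in O(L log L): pad to 2L so the circular wrap-around
    lands outside the first L outputs, then keep only those outputs."""
    L = len(x)
    f_h = np.fft.rfft(h, n=2 * L)
    f_x = np.fft.rfft(x, n=2 * L)
    return np.fft.irfft(f_h * f_x, n=2 * L)[:L]

rng = np.random.default_rng(0)
h = rng.standard_normal(16)  # filter as long as the sequence
x = rng.standard_normal(16)  # input signal

max_abs_diff = np.max(np.abs(causal_conv_direct(h, x) - causal_conv_fft(h, x)))
print(f"max abs difference: {max_abs_diff:.2e}")
```

The padding is what keeps the operation causal: without it, the circular convolution computed by the FFT would let late positions leak into early outputs.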

The forward pass of GPH proceeds in the following order:

1. Token and Position Embedding Layers

2. Dropout Layer

3. Sequential Block Stack (each containing Hyena + MLP)

4. Final LayerNorm

5. Linear Projection to Vocabulary

6. Loss Computation (only when targets are provided)
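
The six steps above can be traced with a small NumPy sketch that follows the tensor shapes through the model. It is a simplified stand-in under several assumptions: the sizes are toy values, LayerNorm has no learnable parameters, dropout is skipped, and `placeholder_mixer` (a causal running mean) replaces the real Hyena block purely to keep the example short.

```python
import numpy as np

rng = np.random.default_rng(0)
V, T, D, B, n_layers = 40, 6, 8, 2, 2  # toy sizes, not the notebook's config

def layer_norm(x, eps=1e-5):
    mu = x.mean(-1, keepdims=True)
    var = x.var(-1, keepdims=True)
    return (x - mu) / np.sqrt(var + eps)

def placeholder_mixer(x):
    # Stand-in for the Hyena block: a causal running mean over positions.
    counts = np.arange(1, x.shape[1] + 1)[None, :, None]
    return np.cumsum(x, axis=1) / counts

def mlp(x, w1, w2):
    hidden = x @ w1
    # tanh approximation of GELU
    hidden = 0.5 * hidden * (1 + np.tanh(np.sqrt(2 / np.pi) * (hidden + 0.044715 * hidden**3)))
    return hidden @ w2

tok_emb = rng.standard_normal((V, D)) * 0.02
pos_emb = rng.standard_normal((T, D)) * 0.02
mlp_weights = [(rng.standard_normal((D, 4 * D)) * 0.02,
                rng.standard_normal((4 * D, D)) * 0.02) for _ in range(n_layers)]

idx = rng.integers(0, V, size=(B, T))      # input token ids
targets = rng.integers(0, V, size=(B, T))  # next-token targets

x = tok_emb[idx] + pos_emb[None, :T]       # 1. token + position embeddings
# 2. dropout is skipped here (evaluation mode)
for w1, w2 in mlp_weights:                 # 3. stack of pre-norm residual blocks
    x = x + placeholder_mixer(layer_norm(x))
    x = x + mlp(layer_norm(x), w1, w2)
x = layer_norm(x)                          # 4. final LayerNorm
logits = x @ tok_emb.T                     # 5. projection with tied weights

# 6. mean cross-entropy over all positions
shifted = logits - logits.max(-1, keepdims=True)
log_probs = shifted - np.log(np.exp(shifted).sum(-1, keepdims=True))
loss = -np.take_along_axis(log_probs, targets[..., None], axis=-1).mean()

print(logits.shape, float(loss))
```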

We will show an implementation of each component followed by the training loop for training the GPH model.


In [None]:
#@title Define the Hyena block and the GP model backbone
@dataclass(kw_only=True)
class Config:
    # Hyena-specific
    d_model: int # overall
    d_embed: int # Hyena filter
    d_filter_mlp: int
    n_filter_layers: int
    context_length: int
    short_conv_size: int
    order: int
    pdrop_hyena: float
    pdrop_embed: float
    omega: Optional[int]
    n_layers: int
    vocab_size: int
    bias: bool
    # Training-specific
    learning_rate: float
    epochs: int
    betas: Tuple[float, float]
    weight_decay: float
    device_type: str
    batch_size: int

class Projection(nn.Module):
    def __init__(self, d_model: int, N: int, conv_len: int):
        super().__init__()
        self.d_model = d_model
        self.linear = nn.Linear(d_model, d_model * (N + 1))
        self.conv = nn.Conv1d(
            in_channels=d_model * (N + 1),
            out_channels=d_model * (N + 1),
            kernel_size=conv_len,
            groups=d_model * (N + 1),
            padding=conv_len - 1
        )

    def forward(self, u):
        z = self.linear(u)
        z = z.transpose(1, 2)
        L = z.shape[-1]
        z = self.conv(z)[..., :L]
        return torch.split(z, self.d_model, dim=1)

class FFTConv(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, h, x, B):
        L = h.shape[-1]
        f_h = torch.fft.rfft(h, n=2 * L, norm="forward")
        f_x = torch.fft.rfft(x.to(dtype=h.dtype), n=2 * L)
        y = torch.fft.irfft(f_h * f_x, n=2 * L, norm="forward")[..., :L]
        y = y + x * B
        return y.to(dtype=h.dtype)

class Window(nn.Module):
    def __init__(
        self,
        d_model: int,
        max_seq_len: int,
        fast_decay_pct: float = 0.3,
        slow_decay_pct: float = 1.5,
        target: float = 1e-2,
    ):
        super().__init__()
        self.b = nn.Parameter(torch.zeros((1, d_model, 1)))
        min_decay = math.log(target) / slow_decay_pct
        max_decay = math.log(target) / fast_decay_pct
        self.alphas = nn.Parameter(
            torch.linspace(min_decay, max_decay, d_model)[None, :, None]
        )
        self.t = nn.Parameter(
            torch.linspace(0, 1, max_seq_len)[None, None, :],
            requires_grad=False
        )

    def forward(self, x):
        L = x.shape[2]
        c = torch.exp(self.alphas * self.t)[:, :, :L]
        return x * (c + self.b)

class HyenaFilter(nn.Module):
    def __init__(
        self,
        d_model: int,
        d_mlp: int,
        d_embed: int,
        N: int,
        n_layers: int = 4,
        max_seq_len: int = 128,
        omega: int = 8,
    ):
        assert n_layers >= 2, "n_layers must be at least 2"
        super().__init__()

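        self.N = N
        self.d_model = d_model
        # Explicit learnable long filters, one per order: [N, d_model, max_seq_len].
        # Note: d_mlp, d_embed, n_layers and omega are accepted for config
        # compatibility but are not used by this simplified filter.
        self.h = nn.Parameter(torch.randn((N, d_model, max_seq_len)))
        # The window multiplies h channel-wise, so it needs d_model channels
        # (Window(d_embed, ...) fails to broadcast unless d_embed is d_model or 1).
        self.window = Window(d_model, max_seq_len)

    def forward(self, L: int) -> torch.Tensor:
        h = self.h[:, :, :L]  # [N, d_model, L]
        h = self.window(h)
        h = h / torch.norm(h, dim=-2, p=1, keepdim=True)
        return h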

class HyenaBlock(nn.Module):
    def __init__(self, config: Config):
        super().__init__()
        self.proj_input = Projection(config.d_model, config.order, config.short_conv_size)
        self.proj_output = nn.Linear(config.d_model, config.d_model)
        self.filter = HyenaFilter(
            config.d_model,
            config.d_filter_mlp,
            config.d_embed,
            config.order,
            config.n_filter_layers,
            config.context_length,
            config.omega,
        )
        self.dropout = nn.Dropout(config.pdrop_hyena)
        self.fft_conv = FFTConv()
        self.B = nn.Parameter(torch.randn((config.order, 1, config.d_model, 1)))

    def forward(self, u: torch.Tensor) -> torch.Tensor:
        L = u.shape[1]
        *x, v = self.proj_input(u)
        v = v + u.transpose(1, 2)
        h = self.filter(L)
        for i, x_i in enumerate(x):
            h_i = h[i].unsqueeze(0)
            v = v + F.normalize(x_i, dim=1) * self.fft_conv(h_i, v, self.B[i])
        v = v.transpose(1, 2)
        y = v + self.dropout(self.proj_output(v))
        return y

class LayerNorm(nn.Module):
    """ LayerNorm but with an optional bias. PyTorch doesn't support simply bias=False """

    def __init__(self, ndim, bias):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(ndim))
        self.bias = nn.Parameter(torch.zeros(ndim)) if bias else None

    def forward(self, input):
        return F.layer_norm(input, self.weight.shape, self.weight, self.bias, 1e-5)


class MLP(nn.Module):

    def __init__(self, config):
        super().__init__()
        self.c_fc    = nn.Linear(config.d_model, 4 * config.d_model, bias=config.bias)
        self.gelu    = nn.GELU()
        self.c_proj  = nn.Linear(4 * config.d_model, config.d_model, bias=config.bias)
        self.dropout = nn.Dropout(config.pdrop_embed)

    def forward(self, x):
        x = self.c_fc(x)
        x = self.gelu(x)
        x = self.c_proj(x)
        x = self.dropout(x)
        return x

class Block(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.ln_1 = LayerNorm(config.d_model, bias=config.bias)
        self.hyena = HyenaBlock(config)
        self.ln_2 = LayerNorm(config.d_model, bias=config.bias)
        self.mlp = MLP(config)

    def forward(self, x):
        val = self.hyena(self.ln_1(x))
        x = x + val
        x = x + self.mlp(self.ln_2(x))
        return x

class GPH(nn.Module):
    def __init__(self, config):
        super().__init__()
        assert config.vocab_size is not None
        assert config.context_length is not None
        self.config = config

        self.tok_emb = nn.Embedding(config.vocab_size, config.d_model)
        self.pos_emb = nn.Parameter(torch.randn(1, config.context_length, config.d_model))
        self.drop = nn.Dropout(config.pdrop_embed)

        self.blocks = nn.Sequential(*[Block(config) for _ in range(config.n_layers)])
        self.ln_f = LayerNorm(config.d_model, bias=config.bias)
        self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)
        self.lm_head.weight = self.tok_emb.weight  # weight tying

        self.apply(self._init_weights)
        for pn, p in self.named_parameters():
            if pn.endswith('c_proj.weight'):
                torch.nn.init.normal_(p, mean=0.0, std=0.02 / math.sqrt(2 * config.n_layers))

        print("number of parameters: %.2fM" % (self.get_num_params() / 1e6,))

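    def get_num_params(self, non_embedding=True):
        """ Count all parameters. `non_embedding` is kept for nanoGPT API compatibility but is ignored. """
        n_params = sum(p.numel() for p in self.parameters())
        return n_params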

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, idx, targets=None):
        b, t = idx.size()
        assert t <= self.config.context_length, f"Cannot forward sequence of length {t}, context length is only {self.config.context_length}"

        x = self.tok_emb(idx) + self.pos_emb[:, :t, :]
        x = self.drop(x)
        x = self.blocks(x)
        x = self.ln_f(x)
        logits = self.lm_head(x)

        loss = None
        if targets is not None:
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)

        return logits, loss

    def configure_optimizers(self, weight_decay=None, learning_rate=None, betas=None, device_type=None):
        weight_decay = weight_decay or self.config.weight_decay
        learning_rate = learning_rate or self.config.learning_rate
        betas = betas or self.config.betas
        device_type = device_type or self.config.device_type

        param_dict = {pn: p for pn, p in self.named_parameters() if p.requires_grad}
        decay_params = [p for n, p in param_dict.items() if p.dim() >= 2]
        nodecay_params = [p for n, p in param_dict.items() if p.dim() < 2]
        optim_groups = [
            {'params': decay_params, 'weight_decay': weight_decay},
            {'params': nodecay_params, 'weight_decay': 0.0}
        ]
        print(f"num decayed parameter tensors: {len(decay_params)}, with {sum(p.numel() for p in decay_params):,} parameters")
        print(f"num non-decayed parameter tensors: {len(nodecay_params)}, with {sum(p.numel() for p in nodecay_params):,} parameters")

        fused_available = 'fused' in inspect.signature(torch.optim.AdamW).parameters
        use_fused = fused_available and device_type == 'cuda'
        extra_args = dict(fused=True) if use_fused else dict()

        optimizer = torch.optim.AdamW(optim_groups, lr=learning_rate, betas=betas, **extra_args)
        print(f"using fused AdamW: {use_fused}")
        return optimizer

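    def estimate_mfu(self, fwdbwd_per_iter, dt):
        """ Estimate model flops utilization (MFU) relative to A100 bfloat16 peak FLOPS.
        The FLOP count reuses nanoGPT's attention-based formula, so for Hyena
        (and on a T4) the value is only a rough indicator. """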
        N = self.get_num_params()
        cfg = self.config
        L, H, Q, T = cfg.n_layers, 1, cfg.d_model, cfg.context_length  # Hyena uses H=1 (no heads)
        flops_per_token = 6*N + 12*L*H*Q*T
        flops_per_fwdbwd = flops_per_token * T
        flops_per_iter = flops_per_fwdbwd * fwdbwd_per_iter
        flops_achieved = flops_per_iter * (1.0 / dt)
        flops_promised = 312e12  # A100 peak TFLOPs for bfloat16
        return flops_achieved / flops_promised

    @torch.no_grad()
    def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None):
        for _ in range(max_new_tokens):
            idx_cond = idx if idx.size(1) <= self.config.context_length else idx[:, -self.config.context_length:]
            logits, _ = self(idx_cond)
            logits = logits[:, -1, :] / temperature

            if top_k is not None:
                v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                logits[logits < v[:, [-1]]] = -float('Inf')

            probs = F.softmax(logits, dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1)
            idx = torch.cat((idx, idx_next), dim=1)

        return idx
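
The temperature and top-k logic inside `generate` can be isolated in a few lines. The sketch below is a NumPy re-expression of that single sampling step for one logits vector. The function name `sample_next_token` and the toy logits are illustrative assumptions, not part of the notebook.

```python
import numpy as np

def sample_next_token(logits, temperature=1.0, top_k=None, rng=None):
    """Sample one token id from a 1-D logits vector."""
    if rng is None:
        rng = np.random.default_rng()
    # Lower temperature sharpens the distribution, higher flattens it.
    scaled = np.asarray(logits, dtype=np.float64) / temperature
    if top_k is not None:
        # Mask everything below the k-th largest logit.
        k = min(top_k, scaled.size)
        kth_largest = np.sort(scaled)[-k]
        scaled = np.where(scaled < kth_largest, -np.inf, scaled)
    scaled = scaled - scaled.max()  # numerical stability before exponentiating
    probs = np.exp(scaled)
    probs /= probs.sum()
    return int(rng.choice(scaled.size, p=probs))

toy_logits = np.array([2.0, 1.0, 0.5, -1.0, -3.0])
rng = np.random.default_rng(0)
draws = [sample_next_token(toy_logits, temperature=0.8, top_k=2, rng=rng)
         for _ in range(200)]
print(sorted(set(draws)))
```

With `top_k=2`, only the two highest-scoring tokens can ever be drawn. Note that with a 40-symbol vocabulary, a `top_k` of 40 or more leaves the distribution unchanged.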

In [None]:
#@title Training time

gc.collect()
# -----------------------------------------------------------------------------
# default config values designed to train the GPH model on the hg38 reference genome
# I/O
out_dir = os.path.join(DIR, "out")
eval_interval = 1000
log_interval = 100
eval_iters = 100
eval_only = False # if True, script exits right after the first eval
always_save_checkpoint = False # if True, always save a checkpoint after each eval
init_from = 'scratch' # 'scratch' or 'resume' or 'gpt2*'
# wandb logging
wandb_log = False # disabled by default
wandb_project = 'genome_char'
wandb_run_name = 'gph-colab-' + str(int(time.time()))
# data
dataset = 'genome_char'
gradient_accumulation_steps = 1 # used to simulate larger batch sizes
context_length = 256
# model
batch_size=64
bias = False # do we use bias inside LayerNorm and Linear layers?
# adamw optimizer
max_iters = 3000 # total number of training iterations
grad_clip = 1.0 # clip gradients at this value, or disable if == 0.0
# learning rate decay settings
learning_rate=6e-4
decay_lr = True # whether to decay the learning rate
weight_decay = 0.4
warmup_iters = 100 # how many steps to warm up for
lr_decay_iters = 3000 # should be ~= max_iters per Chinchilla
min_lr = 1e-4 # minimum learning rate, should be ~= learning_rate/10 per Chinchilla
# DDP settings
#backend = 'nccl' # 'nccl', 'gloo', etc.
# system
device = 'cuda' # examples: 'cpu', 'cuda', 'cuda:0', 'cuda:1' etc., or try 'mps' on macbooks
dtype = 'float16' # 'float32', 'bfloat16', or 'float16'; float16 is used here (both branches of the original conditional selected it) and automatically enables the GradScaler below
if torch.cuda.is_available() and device == 'cuda':
    print('cuda')
else:
    print('CPU')
compile = True # use PyTorch 2.0 to compile the model to be faster
# -----------------------------------------------------------------------------
config_keys = [k for k,v in globals().items() if not k.startswith('_') and isinstance(v, (int, float, bool, str))]
#exec(open('configurator.py').read()) # overrides from command line or config file
config = {k: globals()[k] for k in config_keys} # will be useful for logging
# -----------------------------------------------------------------------------

# various inits, derived attributes, I/O setup
master_process = True
seed_offset = 0
ddp_world_size = 1
tokens_per_iter = gradient_accumulation_steps * ddp_world_size * batch_size * context_length
print(f"tokens per iteration will be: {tokens_per_iter:,}")

if master_process:
    os.makedirs(out_dir, exist_ok=True)
torch.manual_seed(1337 + seed_offset)
torch.backends.cuda.matmul.allow_tf32 = True # allow tf32 on matmul
torch.backends.cudnn.allow_tf32 = True # allow tf32 on cudnn
device_type = 'cuda' if 'cuda' in device else 'cpu' # for later use in torch.autocast
# note: float16 data type will automatically use a GradScaler
ptdtype = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torch.float16}[dtype]
ctx = nullcontext() if device_type == 'cpu' else torch.amp.autocast(device_type=device_type, dtype=ptdtype)

# poor man's data loader
data_dir = os.path.join(DIR, "data")
def get_batch(split):
    # We recreate np.memmap every batch to avoid a memory leak, as per
    # https://stackoverflow.com/questions/45132940/numpy-memmap-memory-usage-want-to-iterate-once/61472122#61472122
    if split == 'train':
        data = np.memmap(os.path.join(data_dir, 'train.bin'), dtype=np.uint16, mode='r')
    else:
        data = np.memmap(os.path.join(data_dir, 'val.bin'), dtype=np.uint16, mode='r')
    ix = torch.randint(len(data) - context_length, (batch_size,))
    x = torch.stack([torch.from_numpy((data[i:i+context_length]).astype(np.int64)) for i in ix])
    y = torch.stack([torch.from_numpy((data[i+1:i+1+context_length]).astype(np.int64)) for i in ix])
    if device_type == 'cuda':
        # pin arrays x,y, which allows us to move them to GPU asynchronously (non_blocking=True)
        x, y = x.pin_memory().to(device, non_blocking=True), y.pin_memory().to(device, non_blocking=True)
    else:
        x, y = x.to(device), y.to(device)
    return x, y

# init these up here, can override if init_from='resume' (i.e. from a checkpoint)
iter_num = 0
best_val_loss = 1e9

# attempt to derive vocab_size from the dataset (40 for the raw hg38 FASTA characters used here)
meta_path = os.path.join(data_dir, 'meta.pkl')
meta_vocab_size = None
if os.path.exists(meta_path):
    with open(meta_path, 'rb') as f:
        meta = pickle.load(f)
    meta_vocab_size = meta['vocab_size']
    print(f"found vocab_size = {meta_vocab_size} (inside {meta_path})")

gphconf = Config(
  d_model=386,
  n_layers=4,
  vocab_size=meta_vocab_size if meta_vocab_size is not None else 40,
  d_embed=33,
  d_filter_mlp=64,
  n_filter_layers=4,
  context_length=context_length,
  short_conv_size=3,
  order=2,
  pdrop_hyena=0.0,
  pdrop_embed=0.2,
  omega=12,
  epochs=40,
  learning_rate=learning_rate,
  betas=(0.9, 0.98),
  weight_decay=weight_decay,
  device_type=device,
  batch_size=64,
  bias = False
)

# model init (only init_from == 'scratch' is handled in this cell)
if init_from == 'scratch':
    print("Initializing a new Hyena model from scratch")
    model = GPH(gphconf)

model.to(device)
# initialize a GradScaler. If enabled=False scaler is a no-op
scaler = torch.cuda.amp.GradScaler(enabled=(dtype == 'float16'))

# optimizer
optimizer = model.configure_optimizers(
    weight_decay=weight_decay,
    learning_rate=learning_rate,
    betas=(0.9, 0.98),
    device_type=device_type
)
checkpoint = None # free up memory

# torch compile
if compile:
    print("compiling the model... (takes ~1 min)")
    raw_model = model
    model = torch.compile(model)

# helps estimate an arbitrarily accurate loss over either split using many batches
@torch.no_grad()
def estimate_loss():
    out = {}
    model.eval()
    for split in ['train', 'val']:
        losses = torch.zeros(eval_iters)
        for k in range(eval_iters):
            X, Y = get_batch(split)
            with ctx:
                logits, loss = model(X, Y)[:2]
            losses[k] = loss.item()
        out[split] = losses.mean()
    model.train()
    return out

# learning rate decay scheduler (cosine with warmup)
def get_lr(it):
    # 1) linear warmup for warmup_iters steps
    if it < warmup_iters:
        return learning_rate * it / warmup_iters
    # 2) if it > lr_decay_iters, return min learning rate
    if it > lr_decay_iters:
        return min_lr
    # 3) in between, use cosine decay down to min learning rate
    decay_ratio = (it - warmup_iters) / (lr_decay_iters - warmup_iters)
    assert 0 <= decay_ratio <= 1
    coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio)) # coeff ranges 0..1
    return min_lr + coeff * (learning_rate - min_lr)

# logging
if wandb_log and master_process:
    import wandb
    wandb.init(project=wandb_project, name=wandb_run_name, config=config, dir = os.path.join(DIR, "log"))

# training loop
X, Y = get_batch('train') # fetch the very first batch
t0 = time.time()
local_iter_num = 0 # number of iterations in the lifetime of this process
raw_model = model
running_mfu = -1.0
while True:

    # determine and set the learning rate for this iteration
    lr = get_lr(iter_num) if decay_lr else learning_rate
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr

    # evaluate the loss on train/val sets and write checkpoints
    if iter_num % eval_interval == 0 and master_process:
        losses = estimate_loss()
        print(f"step {iter_num}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")
        if wandb_log:
            wandb.log({
                "iter": iter_num,
                "train/loss": losses['train'],
                "val/loss": losses['val'],
                "lr": lr,
                "mfu": running_mfu*100, # convert to percentage
            })
        if losses['val'] < best_val_loss or always_save_checkpoint:
            best_val_loss = losses['val']
            if iter_num > 0:
                checkpoint = {
                    'model': raw_model.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'model_args': asdict(gphconf),
                    'iter_num': iter_num,
                    'best_val_loss': best_val_loss,
                    'config': config,
                }
                print(f"saving checkpoint to {out_dir}")
                torch.save(checkpoint, os.path.join(out_dir, 'ckpt.pt'))
    if iter_num == 0 and eval_only:
        break

    # forward backward update, with optional gradient accumulation to simulate larger batch size
    # and using the GradScaler if data type is float16
    for micro_step in range(gradient_accumulation_steps):
        with ctx:
            logits, loss = model(X, Y)
            loss = loss / gradient_accumulation_steps # scale the loss to account for gradient accumulation
        # immediately async prefetch next batch while model is doing the forward pass on the GPU
        X, Y = get_batch('train')
        # backward pass, with gradient scaling if training in fp16
        scaler.scale(loss).backward()
    # clip the gradient
    if grad_clip != 0.0:
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
    # step the optimizer and scaler if training in fp16
    scaler.step(optimizer)
    scaler.update()
    # flush the gradients as soon as we can, no need for this memory anymore
    optimizer.zero_grad(set_to_none=True)

    # timing and logging
    t1 = time.time()
    dt = t1 - t0
    t0 = t1
    if iter_num % log_interval == 0 and master_process:
        # get loss as float. note: this is a CPU-GPU sync point
        # scale up to undo the division above, approximating the true total loss (exact would have been a sum)
        lossf = loss.item() * gradient_accumulation_steps
        if local_iter_num >= 5: # let the training loop settle a bit
            mfu = raw_model.estimate_mfu(batch_size * gradient_accumulation_steps, dt)
            running_mfu = mfu if running_mfu == -1.0 else 0.9*running_mfu + 0.1*mfu
        print(f"iter {iter_num}: loss {lossf:.4f}, time {dt*1000:.2f}ms, mfu {running_mfu*100:.2f}%")
    iter_num += 1
    local_iter_num += 1

    # termination conditions
    if iter_num > max_iters:
        break

cuda
tokens per iteration will be: 16,384
found vocab_size = 40 (inside /content/drive/MyDrive/UChicago/Spring2025/GENE46100/project/data/meta.pkl)
Initializing a new Hyena model from scratch
number of parameters: 8.09M
num decayed parameter tensors: 38, with 8,076,664 parameters
num non-decayed parameter tensors: 21, with 14,282 parameters
using fused AdamW: True
compiling the model... (takes ~1 min)




step 0: train loss 3.7035, val loss 3.6984
iter 0: loss 3.7042, time 116724.26ms, mfu -100.00%
iter 100: loss 2.2562, time 202.00ms, mfu 1.39%
iter 200: loss 1.5458, time 190.50ms, mfu 1.39%
iter 300: loss 1.4356, time 187.66ms, mfu 1.40%
iter 400: loss 1.3567, time 189.51ms, mfu 1.41%
iter 500: loss 1.4302, time 192.94ms, mfu 1.41%
iter 600: loss 1.3574, time 190.57ms, mfu 1.42%
iter 700: loss 1.4423, time 188.92ms, mfu 1.43%
iter 800: loss 1.3711, time 189.68ms, mfu 1.43%
iter 900: loss 1.3690, time 189.08ms, mfu 1.44%
step 1000: train loss 1.3698, val loss 1.2614
saving checkpoint to /content/drive/MyDrive/UChicago/Spring2025/GENE46100/project/out
iter 1000: loss 1.3796, time 29874.34ms, mfu 1.29%
iter 1100: loss 1.3784, time 196.40ms, mfu 1.31%
iter 1200: loss 1.4161, time 193.62ms, mfu 1.32%
iter 1300: loss 1.3839, time 190.15ms, mfu 1.34%
iter 1400: loss 1.3562, time 186.61ms, mfu 1.35%
iter 1500: loss 1.3761, time 189.25ms, mfu 1.36%
iter 1600: loss 1.3361, time 193.85ms, mfu 1.
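
As a rough sanity check on the log above, the cross-entropy of a uniform guess over k symbols is ln k nats. The sketch compares a few logged values (copied from the output above) with two uniform baselines: the full 40-symbol vocabulary and the four uppercase bases. This is only a back-of-the-envelope reading, since the data also contains lowercase bases, `N`, and newlines.

```python
import math

uniform_vocab = math.log(40)  # uniform over the 40-character vocabulary
uniform_bases = math.log(4)   # uniform over A, C, G, T only

# Values copied from the training log above.
logged = {"step 0 val": 3.6984, "step 1000 train": 1.3698, "step 1000 val": 1.2614}

print(f"ln(40) = {uniform_vocab:.4f}, ln(4) = {uniform_bases:.4f}")
for name, loss in logged.items():
    print(f"{name}: {loss:.4f} nats, perplexity {math.exp(loss):.2f}")
```

The initial loss sits close to ln 40, as expected for an untrained model. A loss near ln 4 after 1000 steps is consistent with the model having learned which symbols dominate the data, with only modest gains from context so far. This is an interpretation of the log, not a measured result.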

In [None]:
#@title Sampling from the pre-trained model
# Generating config
init_from = 'resume'
DIR = "/content/drive/MyDrive/UChicago/Spring2025/GENE46100/project/"
out_dir = os.path.join(DIR, "out")
meta_path = os.path.join(DIR, "data", "meta.pkl")
start = "\n"  # initial prompt
num_samples = 5
max_new_tokens = 250
temperature = 0.8
top_k = 200
seed = 1337
device = 'cuda'
dtype = 'bfloat16' if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else 'float16'
compile = True
# -----------------------------------------------------------------------------

# Setup
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
device_type = 'cuda' if 'cuda' in device else 'cpu'
ptdtype = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torch.float16}[dtype]
ctx = nullcontext() if device_type == 'cpu' else torch.amp.autocast(device_type=device_type, dtype=ptdtype)

# Load checkpoint
ckpt_path = os.path.join(out_dir, 'ckpt.pt')
checkpoint = torch.load(ckpt_path, map_location=device)
model_args = checkpoint['model_args']
gphconf = Config(**model_args)
model = GPH(gphconf)

# Handle '_orig_mod.' prefix from torch.compile() in saved state_dict
state_dict = checkpoint['model']
unwanted_prefix = '_orig_mod.'
for k in list(state_dict.keys()):
    if k.startswith(unwanted_prefix):
        state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k)

model.load_state_dict(state_dict)
model.eval().to(device)

if compile:
    model = torch.compile(model)

# Load vocab (stoi, itos)
with open(meta_path, 'rb') as f:
    meta = pickle.load(f)
stoi, itos = meta['stoi'], meta['itos']
encode = lambda s: [stoi[c] for c in s]
decode = lambda l: ''.join([itos[i] for i in l])

# Encode prompt
if start.startswith('FILE:'):
    with open(start[5:], 'r', encoding='utf-8') as f:
        start = f.read()
start_ids = encode(start)
x = torch.tensor(start_ids, dtype=torch.long, device=device)[None, ...]
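
The checkpoint-loading step above renames keys because `torch.compile()` wraps the model, so saved parameter names carry an `_orig_mod.` prefix. As a minimal sketch of that renaming, here is the same idea on a plain dictionary. The helper name `strip_prefix` and the parameter names are hypothetical and stand in for a real `state_dict`.

```python
def strip_prefix(state_dict, prefix="_orig_mod."):
    """Return a new dict with `prefix` removed from every key that starts with it."""
    return {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }

# Hypothetical parameter names, for illustration only.
compiled_keys = {
    "_orig_mod.tok_emb.weight": 1,
    "_orig_mod.ln_f.bias": 2,
    "classifier.weight": 3,  # no prefix, so this key is left untouched
}
cleaned = strip_prefix(compiled_keys)
print(sorted(cleaned))
```

Building a new dictionary avoids mutating the checkpoint while iterating over it. The in-place `pop` loop in the cell gets the same safety by iterating over `list(state_dict.keys())`.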



In [None]:
# Generate sequences
with torch.no_grad():
    with ctx:
        for k in range(num_samples):
            y = model.generate(x, max_new_tokens, temperature=temperature, top_k=top_k)
            print(decode(y[0].tolist()))
            print('---------------')


AAAAACACTTGAAAAAGGACACCCCCAAATAAAGCAATGCGACCTGCTCT
TATGTAGATTTTACAATTCCAGGGTCCATTCAGGTCAAAGATGATTGTTG
TCCTACAAGTAATTCCAAATGCACCCTCTGGCACATTTTCATTCTGAGTG
TACAGATATAATCTTGAAAATGATAGAATCATTAttttttttttattatt
tttaattcatttataaactttaaaggtttttttttggaattgtatt
---------------

TATGTGAACAAAACTTCTTAAATTATGTAAAATGTGATTCTTTTTTATGT
AGTTATTCTATTTTTCATTTCTGAAAAAATAAAAATTTGATTTTGTGAAA
TTAAAAAAATCACAGATTAATTTTCTGAAAAAATTGGCCATTTTTAAACA
AAATTTTTTTTATTCTTCAATTTTTAATTTTTTAACATATAAAAATATAA
ATGTTTATTGAGTAATTTAATAACAATCACTTTAAATTTTTTTTAA
---------------

ATAACATCAGTAAATTTAAAAATTATTGGCTTTTTCATCTGTGT
ATATT
GGCTTCCATCCTGTATGGTAGTCTTTGAGAAATATTTCACTTAT
GGAGA
CGATTCCATTTCCATTATTAGTGGGCTATCTGGTTTTCTTTTTC
TTTTT
AAATTTAGATctttatatttttttggctttgtgttttattttat
ttttt
ttttttttttatattttttttttacatttatttaattgcaatga
a
---------------

cccagccttcatttcctgagag
ttttgaatttctgtttttttttcttga
ttgaaacagagtattcctttct
taaataggatctcaagagttaaaCTGT
CTTTGACAAAAATTTTATCTCC
TAAGTTGTCTTTATTCTGATGTAAAAA
TTTATATTTTTTTATGTGATAG
ATTGTTATTAATTTAACTT
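
The `temperature` and `top_k` values set in the sampling cell shape how each next base is drawn. `GPH.generate` is not reproduced in this section, so the following is only a minimal sketch of the usual nanoGPT-style procedure, assuming `generate` follows it. `sample_next` and the toy logits are hypothetical names used for illustration.

```python
import math
import random

def sample_next(logits, temperature=0.8, top_k=200, rng=random):
    """Sample one token index from raw logits with temperature and top-k filtering."""
    # Temperature < 1 sharpens the distribution; > 1 flattens it.
    scaled = [x / temperature for x in logits]
    # Keep only the k highest-scoring tokens; k is capped at the vocabulary size.
    k = min(top_k, len(scaled))
    cutoff = sorted(scaled, reverse=True)[k - 1]
    kept = [x if x >= cutoff else -math.inf for x in scaled]
    # Softmax weights over the surviving logits (subtract the max for stability).
    peak = max(kept)
    weights = [math.exp(x - peak) for x in kept]
    return rng.choices(range(len(logits)), weights=weights, k=1)[0]

rng = random.Random(1337)
toy_logits = [2.0, 1.0, 0.5, -1.0]  # hypothetical scores for four tokens
picks = [sample_next(toy_logits, temperature=0.8, top_k=2, rng=rng) for _ in range(1000)]
print(sorted(set(picks)))  # with top_k=2, only indices 0 and 1 can appear
```

Under this scheme, a `top_k` larger than the vocabulary keeps every token. If the character-level DNA vocabulary here is smaller than 200, which seems likely, `temperature` would be the only setting that changes the samples.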

# Fine-tune the model for enhancer classification

As a downstream application of the GPH model, I evaluated its performance on a three-class enhancer classification task that was introduced in class using the nano-GPT framework. This task involves predicting the enhancer type from raw DNA sequence input. Details on the dataset are available on [Hugging Face Datasets](https://huggingface.co/datasets/InstaDeepAI/nucleotide_transformer_downstream_tasks).

In [None]:
#@title Define the model for multi-class classification
class ClassificationModel(nn.Module):
    def __init__(self, base_model, num_labels):
        super().__init__()
        self.base_model = base_model
        self.classifier = nn.Linear(base_model.config.d_model, num_labels)
        self.num_labels = num_labels

    def forward(self, input_ids, labels=None):
        # get the output of the base model before the language modeling head
        b, t = input_ids.size()
        assert t <= self.base_model.config.context_length, f"Cannot forward sequence of length {t}, context length is only {self.base_model.config.context_length}"

        x = self.base_model.tok_emb(input_ids) + self.base_model.pos_emb[:, :t, :]
        x = self.base_model.drop(x)
        x = self.base_model.blocks(x)
        pooled_output = self.base_model.ln_f(x).mean(dim=1)

        class_logits = self.classifier(pooled_output)

        loss = None
        if labels is not None:
            loss = F.cross_entropy(class_logits, labels)

        return (class_logits, loss) if loss is not None else class_logits

In [None]:
# Load the enhancer dataset from the InstaDeep Hugging Face resources
dataset_name = "enhancers"
train_dataset_enhancers = load_dataset(
        "InstaDeepAI/nucleotide_transformer_downstream_tasks",
        dataset_name,
        split="train",
        streaming=False
    )
test_dataset_enhancers = load_dataset(
        "InstaDeepAI/nucleotide_transformer_downstream_tasks",
        dataset_name,
        split="test",
        streaming=False
    )

num_labels_enhancer = 3

train.fna:   0%|          | 0.00/3.10M [00:00<?, ?B/s]

test.fna:   0%|          | 0.00/83.2k [00:00<?, ?B/s]

Generating train split: 0 examples [00:00, ? examples/s]

 
                       
        during the per-review process. This version is deprecated and the new datasets are available at 
        InstaDeepAI/nucleotide_transformer_downstream_tasks_revised.
                       
              
                       
                       
              


Generating test split: 0 examples [00:00, ? examples/s]

 
                       
                       
              


In [None]:
# Get training data
train_sequences_enhancers = train_dataset_enhancers['sequence']
train_labels_enhancers = train_dataset_enhancers['label']

# Split the dataset into a training and a validation dataset
(train_sequences_enhancers, validation_sequences_enhancers,
 train_labels_enhancers, validation_labels_enhancers) = train_test_split(
    train_sequences_enhancers, train_labels_enhancers, test_size=0.10, random_state=42)

# Get test data
test_sequences_enhancers = test_dataset_enhancers['sequence']
test_labels_enhancers = test_dataset_enhancers['label']

# Augment the training data with reverse complements

def reverse_complement(seq):
    complement = {'A': 'T', 'T': 'A', 'C': 'G', 'G': 'C'}
    return ''.join(complement.get(base, base) for base in reversed(seq))

augmented_sequences = []
augmented_labels = []

for seq, label in zip(train_sequences_enhancers, train_labels_enhancers):
    augmented_sequences.append(seq)
    augmented_labels.append(label)

    rc_seq = reverse_complement(seq)
    augmented_sequences.append(rc_seq)
    augmented_labels.append(label)

ds_train_enhancers = Dataset.from_dict({"data": augmented_sequences,'labels':augmented_labels})
ds_validation_enhancers = Dataset.from_dict({"data": validation_sequences_enhancers,'labels':validation_labels_enhancers})
ds_test_enhancers = Dataset.from_dict({"data": test_sequences_enhancers,'labels':test_labels_enhancers})
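
The `reverse_complement` helper above maps only uppercase bases, so any lowercase (soft-masked) base would be reversed in position but not complemented. Whether the enhancer sequences contain lowercase letters is not stated here. If they might, a case-preserving variant could look like the sketch below. `reverse_complement_cased` is a hypothetical name, not part of the notebook.

```python
# Complement table for upper- and lowercase bases; any other symbol (e.g. N) is left as-is.
_COMPLEMENT = str.maketrans("ACGTacgt", "TGCAtgca")

def reverse_complement_cased(seq: str) -> str:
    """Complement each base, keeping its case, then reverse the sequence."""
    return seq.translate(_COMPLEMENT)[::-1]

example = "AAGCtN"
print(reverse_complement_cased(example))  # NaGCTT
```

Applying the function twice returns the original sequence, which is a quick sanity check for any reverse-complement implementation.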

In [None]:
def tokenize_function(examples):
  # Truncate every sequence in the batch to the shortest one so the rows form a rectangle.
  # Note: with batched=True, min_length is computed per batch, not over the whole split.
  min_length = min(len(seq) for seq in examples['data'])
  outputs = np.array([encode(seq[:min_length]) for seq in examples['data']])
  return {
      'input_ids': outputs
  }

tokenized_datasets_train_enhancer = ds_train_enhancers.map(
    tokenize_function,
    batched=True,
    remove_columns=['data'],
)
tokenized_datasets_validation_enhancer = ds_validation_enhancers.map(
    tokenize_function,
    batched=True,
    remove_columns=['data'],
)
tokenized_datasets_test_enhancer = ds_test_enhancers.map(
    tokenize_function,
    batched=True,
    remove_columns=['data'],
)

device = 'cuda'
train_dataset = TensorDataset(torch.tensor(tokenized_datasets_train_enhancer['input_ids'], device=device),
                              torch.tensor(tokenized_datasets_train_enhancer['labels'], device=device))
train_dataloader = DataLoader(train_dataset, shuffle=True, batch_size=64)

validation_dataset = TensorDataset(torch.tensor(tokenized_datasets_validation_enhancer['input_ids'], device=device),
                              torch.tensor(tokenized_datasets_validation_enhancer['labels'], device=device))

validation_dataloader = DataLoader(validation_dataset, shuffle=True, batch_size=1)

test_dataset = TensorDataset(torch.tensor(tokenized_datasets_test_enhancer['input_ids'], device=device),
                              torch.tensor(tokenized_datasets_test_enhancer['labels'], device=device))

test_dataloader = DataLoader(test_dataset, shuffle=True, batch_size=1)

Map:   0%|          | 0/26942 [00:00<?, ? examples/s]

Map:   0%|          | 0/1497 [00:00<?, ? examples/s]

Map:   0%|          | 0/400 [00:00<?, ? examples/s]

In [None]:
# save the tensors
save_path = os.path.join(DIR, "data")
torch.save(train_dataset.tensors, os.path.join(save_path, "enhancer_train_tensors.pt"))
torch.save(validation_dataset.tensors, os.path.join(save_path, "enhancer_val_tensors.pt"))
torch.save(test_dataset.tensors, os.path.join(save_path, "enhancer_test_tensors.pt"))

In [None]:
#@title Training time
# -----------------------------------------------------------------------------
init_from = 'resume'
out_dir = os.path.join(DIR, "out") # ignored if init_from is not 'resume'
seed = 1337
device = 'cuda'
dtype = 'bfloat16' if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else 'float16' # 'float32' or 'bfloat16' or 'float16'
compile = True # use PyTorch 2.0 to compile the model to be faster
# -----------------------------------------------------------------------------

torch.cuda.init()
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.backends.cuda.matmul.allow_tf32 = True # allow tf32 on matmul
torch.backends.cudnn.allow_tf32 = True # allow tf32 on cudnn
device_type = 'cuda' if 'cuda' in device else 'cpu' # for later use in torch.autocast
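# Note: the lookup is hard-coded to 'float16'; the `dtype` setting above is not consulted here
ptdtype = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torch.float16}['float16']
# Note: ctx is defined for autocast, but the fine-tuning loop below runs without it
ctx = nullcontext() if device_type == 'cpu' else torch.amp.autocast(device_type=device_type, dtype=ptdtype)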

# model
if init_from == 'resume':
    ckpt_path = os.path.join(out_dir, 'ckpt.pt')
    checkpoint = torch.load(ckpt_path, map_location=device)
    model_args = checkpoint['model_args']
    # NOTE: purpose unclear; the sampling cell builds Config from model_args without this override
    model_args['d_embed'] = model_args['d_model']
    gphconf = Config(**model_args)
    model = GPH(gphconf)

    # Handle '_orig_mod.' prefix from torch.compile() in saved state_dict
    state_dict = checkpoint['model']
    unwanted_prefix = '_orig_mod.'
    for k in list(state_dict.keys()):
        if k.startswith(unwanted_prefix):
            state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k)

    model.load_state_dict(state_dict)


# ========================== NEW MODEL SETUP ================================
new_model = ClassificationModel(base_model=model, num_labels=3)
new_model.eval()
new_model.to(device)
if compile:
    new_model = torch.compile(new_model)  # requires PyTorch 2.0 (optional)

optimizer = torch.optim.AdamW(new_model.parameters(), lr=2e-5)
num_epochs = 10
total_steps = len(train_dataloader) * num_epochs

val_predictions = []
val_true_labels = []

# ========================== TRAINING LOOP ================================
loss_lst = []
val_loss_lst = []
mcc_scores = []
new_model.train()
for epoch in range(num_epochs):
    i = 0
    for batch in train_dataloader:
        input_ids, labels = batch
        optimizer.zero_grad()

        # Forward pass
        outputs = new_model(input_ids, labels=labels)
        loss = outputs[1]
        loss_lst.append(loss.item())

        # Backward pass and optimization
        loss.backward()
        optimizer.step()
        i += 1

        if i % 100 == 99:
            new_model.eval()
            with torch.no_grad():
                for val_batch in validation_dataloader:
                    val_input_ids, val_labels = val_batch
                    val_outputs = new_model(val_input_ids, labels=val_labels)
                    val_loss = val_outputs[1]
                    val_loss_lst.append(val_loss.item())
                    val_logits = val_outputs[0]
                    val_predictions.append(torch.argmax(val_logits).item())
                    val_true_labels.append(val_labels.item())

            # Compute the validation MCC once and reuse it for logging and tracking
            val_mcc = matthews_corrcoef(val_true_labels, val_predictions)
            print(f"Epoch {epoch}, Batch {i}, Train Loss: {np.mean(loss_lst)}, Val Loss: {np.mean(val_loss_lst)}, Val MCC: {val_mcc}")
            mcc_scores.append(val_mcc)
            loss_lst = []
            val_loss_lst = []
            new_model.train()
            val_predictions = []
            val_true_labels = []


number of parameters: 8.09M




Epoch 0, Batch 99, Train Loss: 0.8663327416988335, Val Loss: 0.6295053677592345, Val MCC: 0.3023741385830867
Epoch 0, Batch 199, Train Loss: 0.5832807952165604, Val Loss: 0.5526515803210881, Val MCC: 0.44828390452342215
Epoch 0, Batch 299, Train Loss: 0.5364896994829178, Val Loss: 0.5604384505023847, Val MCC: 0.4506472491784239
Epoch 0, Batch 399, Train Loss: 0.5360657167434693, Val Loss: 0.6257648226987315, Val MCC: 0.3854330836666037
Epoch 1, Batch 99, Train Loss: 0.5267509098880547, Val Loss: 0.564268953591192, Val MCC: 0.4558687106175817
Epoch 1, Batch 199, Train Loss: 0.5365271699428559, Val Loss: 0.6142528217283383, Val MCC: 0.4008737129321
Epoch 1, Batch 299, Train Loss: 0.5305521339178085, Val Loss: 0.5679025071167306, Val MCC: 0.4541016710293508
Epoch 1, Batch 399, Train Loss: 0.522105237543583, Val Loss: 0.5624742892812801, Val MCC: 0.4606540166790629
Epoch 2, Batch 99, Train Loss: 0.5248378199979293, Val Loss: 0.6294540409956801, Val MCC: 0.38399433671075767
Epoch 2, Batch 1

In [None]:
#@title Evaluation
new_model.eval()

predictions = []
true_labels = []
i = 0

with torch.no_grad():
    for batch in test_dataloader:
        input_ids, labels = batch
        # Without labels, the model returns the logits tensor directly (shape [1, num_labels])
        logits = new_model(input_ids)
        predictions.append(torch.argmax(logits).item())
        true_labels.append(labels.item())

# Count exact matches between predicted and true labels
correct = sum(p == t for p, t in zip(predictions, true_labels))
print(len(predictions))
print(f"Accuracy: {int(correct / len(predictions) * 100):,}%")

print(f"MCC: {round(matthews_corrcoef(true_labels, predictions), 2):,}")

400
Accuracy: 70%
MCC: 0.41
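
The evaluation reports MCC next to accuracy. As a rough illustration of why both are useful, the sketch below computes the multi-class Matthews correlation coefficient from label counts in plain Python. `multiclass_mcc` and the toy labels are hypothetical, and the notebook itself relies on `sklearn.metrics.matthews_corrcoef`. Returning 0 when the denominator vanishes is a common convention that I assume here.

```python
import math
from collections import Counter

def multiclass_mcc(y_true, y_pred):
    """Multi-class Matthews correlation coefficient computed from label counts."""
    n = len(y_true)
    correct = sum(t == p for t, p in zip(y_true, y_pred))
    true_counts = Counter(y_true)  # samples whose true label is k
    pred_counts = Counter(y_pred)  # samples predicted as k
    labels = set(true_counts) | set(pred_counts)
    cov_tp = correct * n - sum(true_counts[k] * pred_counts[k] for k in labels)
    cov_tt = n * n - sum(c * c for c in true_counts.values())
    cov_pp = n * n - sum(c * c for c in pred_counts.values())
    denom = math.sqrt(cov_tt * cov_pp)
    return cov_tp / denom if denom else 0.0

# Toy imbalanced labels: always predicting the majority class looks accurate but carries no signal.
toy_true = [0] * 8 + [1, 2]
toy_pred = [0] * 10
toy_accuracy = sum(t == p for t, p in zip(toy_true, toy_pred)) / len(toy_true)
print(toy_accuracy, multiclass_mcc(toy_true, toy_pred))
```

In this toy case the accuracy is 0.8 while the MCC is 0. This is why a moderate MCC alongside 70% accuracy says more than accuracy alone when classes are not balanced.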


# Conclusion

The GPH framework outperforms the original GPT model presented in class, improving prediction accuracy on the enhancer classification task from below 50% to 70% (test MCC 0.41). This improvement comes at the cost of increased complexity: my GPH model has 8.09M parameters, compared to 7.10M in the nano-GPT model, and took roughly twice as long to train despite using a shorter context length (256 bp vs. 300 bp). This trade-off is consistent with expectations: Hyena operators tend to be slower than attention at short sequence lengths but are designed to scale more efficiently to much longer contexts. Further empirical and theoretical work is needed to understand how Hyena captures contextual information and why it yields better downstream performance in this setting.
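
To make the scaling argument concrete, here is a back-of-the-envelope sketch comparing the asymptotic operation counts of self-attention (quadratic in sequence length) and an FFT-based long convolution of the kind Hyena uses (roughly L log L). The function names are hypothetical. The counts ignore constant factors, the number of Hyena projections, and GPU kernel efficiency, so they say nothing about wall-clock time at short lengths such as 256 bp.

```python
import math

def attention_cost(seq_len: int) -> float:
    """Rough per-channel operation count for self-attention: all pairs of positions."""
    return float(seq_len * seq_len)

def fft_conv_cost(seq_len: int) -> float:
    """Rough per-channel operation count for an FFT-based long convolution."""
    return seq_len * math.log2(seq_len)

for seq_len in (256, 4096, 65536):
    ratio = attention_cost(seq_len) / fft_conv_cost(seq_len)
    print(f"L={seq_len:>6}: attention / FFT-conv ~ {ratio:.0f}x")
```

The ratio grows as L / log2(L), so the asymptotic advantage only becomes large at long contexts. At short contexts, highly optimized attention kernels can outweigh it, which fits the slower training observed here.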