In [4]:
import helical

In [5]:
from helical.models.geneformer import Geneformer, GeneformerConfig
import torch

# Build the small 6-layer Geneformer checkpoint; batch_size=1 is enough here
# because this cell only inspects the architecture.
config = GeneformerConfig(model_name='gf-6L-10M-i2048', batch_size=1)
model = Geneformer(config)

# Print the full architecture of the wrapped model
print('Model architecture:')
print(model.model)
print()

# List only the Linear layers, by dotted path, together with their repr.
# The header promises "Linear layers only", so filter explicitly instead of
# printing the name of every submodule.
print('Named modules (Linear layers only):')
for name, module in model.model.named_modules():
    if isinstance(module, torch.nn.Linear):
        print(f'{name}: {module}')

2026-01-21 19:27:17,574 - INFO:datasets:PyTorch version 2.7.0+cpu available.
2026-01-21 19:27:22,726 - INFO:helical.models.geneformer.model:Model finished initializing.
2026-01-21 19:27:22,727 - INFO:helical.models.geneformer.model:'gf-6L-10M-i2048' model is in 'eval' mode, on device 'cpu' with embedding mode 'cell'.


Model architecture:
BertForMaskedLM(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(25426, 256, padding_idx=0)
      (position_embeddings): Embedding(2048, 256)
      (token_type_embeddings): Embedding(2, 256)
      (LayerNorm): LayerNorm((256,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.02, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-5): 6 x BertLayer(
          (attention): BertAttention(
            (self): BertSdpaSelfAttention(
              (query): Linear(in_features=256, out_features=256, bias=True)
              (key): Linear(in_features=256, out_features=256, bias=True)
              (value): Linear(in_features=256, out_features=256, bias=True)
              (dropout): Dropout(p=0.02, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=256, out_features=256, bias=True)
              (LayerNorm): LayerNorm((256,),

In [6]:
import torch.nn as nn

def find_linear_layers(model: nn.Module) -> dict[str, nn.Linear]:
    """Map dotted module paths to the ``nn.Linear`` layers found inside ``model``.

    Args:
        model: Module whose ``named_modules()`` traversal is scanned.

    Returns:
        A dict keyed by the name ``named_modules()`` reports for each
        ``nn.Linear`` instance, in traversal order. If ``model`` is itself an
        ``nn.Linear`` it appears under the empty name ``''``.
    """
    return {
        path: submodule
        for path, submodule in model.named_modules()
        if isinstance(submodule, nn.Linear)
    }

In [7]:
# Show every nn.Linear layer of the wrapped model, keyed by its dotted path.
find_linear_layers(model.model)

{'bert.encoder.layer.0.attention.self.query': Linear(in_features=256, out_features=256, bias=True),
 'bert.encoder.layer.0.attention.self.key': Linear(in_features=256, out_features=256, bias=True),
 'bert.encoder.layer.0.attention.self.value': Linear(in_features=256, out_features=256, bias=True),
 'bert.encoder.layer.0.attention.output.dense': Linear(in_features=256, out_features=256, bias=True),
 'bert.encoder.layer.0.intermediate.dense': Linear(in_features=256, out_features=512, bias=True),
 'bert.encoder.layer.0.output.dense': Linear(in_features=512, out_features=256, bias=True),
 'bert.encoder.layer.1.attention.self.query': Linear(in_features=256, out_features=256, bias=True),
 'bert.encoder.layer.1.attention.self.key': Linear(in_features=256, out_features=256, bias=True),
 'bert.encoder.layer.1.attention.self.value': Linear(in_features=256, out_features=256, bias=True),
 'bert.encoder.layer.1.attention.output.dense': Linear(in_features=256, out_features=256, bias=True),
 'bert.enc

In [8]:
# Full (name, module) listing of the wrapped model; the first pair is the
# root module itself, reported under the empty name ''.
list(model.model.named_modules())

[('',
  BertForMaskedLM(
    (bert): BertModel(
      (embeddings): BertEmbeddings(
        (word_embeddings): Embedding(25426, 256, padding_idx=0)
        (position_embeddings): Embedding(2048, 256)
        (token_type_embeddings): Embedding(2, 256)
        (LayerNorm): LayerNorm((256,), eps=1e-12, elementwise_affine=True)
        (dropout): Dropout(p=0.02, inplace=False)
      )
      (encoder): BertEncoder(
        (layer): ModuleList(
          (0-5): 6 x BertLayer(
            (attention): BertAttention(
              (self): BertSdpaSelfAttention(
                (query): Linear(in_features=256, out_features=256, bias=True)
                (key): Linear(in_features=256, out_features=256, bias=True)
                (value): Linear(in_features=256, out_features=256, bias=True)
                (dropout): Dropout(p=0.02, inplace=False)
              )
              (output): BertSelfOutput(
                (dense): Linear(in_features=256, out_features=256, bias=True)
                

In [9]:
# Test if BERT can handle extra batch dimensions
# Standard: input_ids (batch, seq)
# Question: can we pass (n_adapt, batch, seq) without flattening?

import torch

# Get the underlying BERT model
bert = model.model.bert

# Create test inputs
batch_size = 4
seq_len = 128
n_adapt = 3


def _try_forward(encoder, input_ids, success_message="  SUCCESS"):
    """Run one no-grad forward pass and report the outcome on stdout.

    Args:
        encoder: Callable model invoked as ``encoder(input_ids)``; its result
            must expose ``last_hidden_state``.
        input_ids: Token-id tensor handed to ``encoder`` unchanged.
        success_message: Line printed after the output shape when the call
            succeeds.

    Returns:
        The encoder output on success, or ``None`` when any exception was
        raised (the exception text is printed as ``  FAILED: ...``).
    """
    # A failure is an expected outcome of this experiment, so it is caught
    # and displayed rather than allowed to abort the cell.
    try:
        with torch.no_grad():
            output = encoder(input_ids)
        print(f"  Output shape: {output.last_hidden_state.shape}")
        print(success_message)
        return output
    except Exception as e:
        print(f"  FAILED: {e}")
        return None


# Standard 2D input (token ids drawn from [1, 1000))
input_ids_2d = torch.randint(1, 1000, (batch_size, seq_len))

# 3D input with extra adapter dimension
input_ids_3d = torch.randint(1, 1000, (n_adapt, batch_size, seq_len))

print("Testing 2D input (batch, seq):")
print(f"  Shape: {input_ids_2d.shape}")
out_2d = _try_forward(bert, input_ids_2d)

# In the recorded run this case fails with
# "too many values to unpack (expected 2)"; out_3d is then None.
print("\nTesting 3D input (n_adapt, batch, seq):")
print(f"  Shape: {input_ids_3d.shape}")
out_3d = _try_forward(
    bert,
    input_ids_3d,
    success_message="  SUCCESS - BERT handles extra batch dims natively!",
)

# Workaround: fold the adapter dimension into the batch dimension.
print("\nTesting flattened 3D -> 2D:")
input_ids_flat = input_ids_3d.reshape(n_adapt * batch_size, seq_len)
print(f"  Flattened shape: {input_ids_flat.shape}")
out_flat = _try_forward(bert, input_ids_flat)

Testing 2D input (batch, seq):
  Shape: torch.Size([4, 128])
  Output shape: torch.Size([4, 128, 256])
  SUCCESS

Testing 3D input (n_adapt, batch, seq):
  Shape: torch.Size([3, 4, 128])
  FAILED: too many values to unpack (expected 2)

Testing flattened 3D -> 2D:
  Flattened shape: torch.Size([12, 128])
  Output shape: torch.Size([12, 128, 256])
  SUCCESS
