## Interlacing conditioning embeddings into a packed batch

In [1]:
from decifer.decifer_model import Decifer, DeciferConfig
import torch

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Load the checkpoint directly onto the CPU. Without map_location, a checkpoint
# that was saved from a GPU fails to deserialize on a CPU-only machine.
# NOTE: torch.load unpickles the file, so only load checkpoints from a trusted source.
checkpoint = torch.load("../testing_refactored/ckpt.pt", map_location="cpu")
model_args = checkpoint["model_args"]
state_dict = checkpoint["model"]

# Rebuild the architecture from the stored hyperparameters, then restore the weights.
model = Decifer(DeciferConfig(**model_args))
model.load_state_dict(state_dict)
model.to(device='cpu')

# Switch to inference mode so the Dropout(p=0.1) layers are disabled and the
# forward passes in the cells below are deterministic.
model.eval()
print(model)

number of total parameters: 7.26M
Decifer(
  (transformer): ModuleDict(
    (cond_embedding): Sequential(
      (0): Linear(in_features=1000, out_features=512, bias=True)
      (1): ReLU()
      (2): Linear(in_features=512, out_features=512, bias=True)
    )
    (wte): Embedding(372, 512)
    (wpe): Embedding(128, 512)
    (drop): Dropout(p=0.1, inplace=False)
    (h): ModuleList(
      (0-1): 2 x Block(
        (ln_1): LayerNorm()
        (attn): CausalSelfAttention(
          (c_attn): Linear(in_features=512, out_features=1536, bias=False)
          (c_proj): Linear(in_features=512, out_features=512, bias=False)
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (resid_dropout): Dropout(p=0.1, inplace=False)
        )
        (ln_2): LayerNorm()
        (mlp): MLP(
          (c_fc): Linear(in_features=512, out_features=2048, bias=False)
          (c_proj): Linear(in_features=2048, out_features=512, bias=False)
          (dropout): Dropout(p=0.1, inplace=False)
        

In [8]:
# Test case: a small batch of packed sequences (2 rows, 6 tokens each) with
# conditioning embeddings inserted at known start positions in every row.
torch.manual_seed(0)  # reproducible random tokens and conditioning vectors

# Read the sizes from the loaded model instead of hard-coding them.
vocab_size = model.transformer.wte.num_embeddings
cond_dim = model.transformer.cond_embedding[0].in_features

start_indices_batch = [[0, 5], [0, 2]]  # Known insert positions for cond_emb
num_cond = sum(len(starts) for starts in start_indices_batch)  # one cond vector per insert

idx = torch.randint(0, vocab_size, (len(start_indices_batch), 6), dtype=torch.long)
cond_vec = torch.randn(num_cond, cond_dim)

# No gradients are needed for a placement check.
with torch.no_grad():
    # Forward pass
    outputs = model(idx, cond_vec, start_indices_batch=start_indices_batch)
    # Expected conditioning embeddings, computed independently of the forward pass
    expected_cond_emb = model.transformer.cond_embedding(cond_vec).to(dtype=outputs.dtype)

# Check that the inserted embeddings match the expected conditioning embeddings.
# Rows of cond_vec are consumed in order: all inserts of batch row 0, then row 1, ...
cond_row = 0
for batch_i, starts in enumerate(start_indices_batch):
    for num_prior_inserts, start in enumerate(starts):
        # Each earlier insert in the same row shifts later positions right by one,
        # e.g. starts [0, 5] land at output positions 0 and 6, and [0, 2] at 0 and 3.
        # NOTE(review): this offset rule is inferred from those known positions — confirm
        # against the model's insertion logic before using other start indices.
        pos = start + num_prior_inserts
        assert torch.allclose(
            outputs[batch_i, pos, :], expected_cond_emb[cond_row], atol=1e-6
        ), f"Mismatch at batch row {batch_i}, position {pos} (cond_vec row {cond_row})"
        cond_row += 1
print("Conditioning embeddings correctly placed.")


Conditioning embeddings correctly placed.
