In [15]:
import torch
import torch.nn as nn
from torch.optim import AdamW
import torch.nn.functional as F
import pandas as pd
import ast
from tqdm import tqdm
from torch.utils.data import Dataset, DataLoader

In [16]:
from lera.model.rope import RoPE
from lera.model.transformer import TransformerBlock

In [17]:
# Load the addition training set; the displayed frame has columns `input`, `output`
# and `seq` (plus an `Unnamed: 0` column, which looks like a saved index).
df = pd.read_csv(filepath_or_buffer="../data/train_64k.csv")
# Token alphabet: a character's position in this string is its token id.
# The order must agree with the pre-tokenised `seq` column of the training CSV,
# where e.g. "382+525=ri" is stored as [3, 8, 2, 12, 5, 2, 5, 13, 10, 11]:
# digits are 0-9, 'r' = 10, 'i' = 11, '+' = 12, '=' = 13.
# '_' takes the remaining id 14 (presumably padding -- it does not occur in the
# displayed rows; TODO confirm).
vocab = "0123456789ri+=_"

In [10]:
# Preview the loaded training frame (rendered as this cell's output).
df

Unnamed: 0,input,output,seq
0,382+525=ri,907,"[3, 8, 2, 12, 5, 2, 5, 13, 10, 11, 0, 9, 0, 7]"
1,572+378=ri,950,"[5, 7, 2, 12, 3, 7, 8, 13, 10, 11, 0, 9, 5, 0]"
2,218+121=ri,339,"[2, 1, 8, 12, 1, 2, 1, 13, 10, 11, 0, 3, 3, 9]"
3,170+989=ri,1159,"[1, 7, 0, 12, 9, 8, 9, 13, 10, 11, 1, 1, 5, 9]"
4,547+176=ri,723,"[5, 4, 7, 12, 1, 7, 6, 13, 10, 11, 0, 7, 2, 3]"
...,...,...,...
63995,644+901=ri,1545,"[6, 4, 4, 12, 9, 0, 1, 13, 10, 11, 1, 5, 4, 5]"
63996,152+143=ri,295,"[1, 5, 2, 12, 1, 4, 3, 13, 10, 11, 0, 2, 9, 5]"
63997,783+128=ri,911,"[7, 8, 3, 12, 1, 2, 8, 13, 10, 11, 0, 9, 1, 1]"
63998,512+213=ri,725,"[5, 1, 2, 12, 2, 1, 3, 13, 10, 11, 0, 7, 2, 5]"


In [2]:

class DigitConvCompressor(nn.Module):
    """Compress a sequence of digit embeddings into one fixed-size vector.

    Pipeline: linear projection to ``hidden_channels`` -> two 1D convolutions
    over the digit axis (each followed by SiLU) -> mean-pool over positions ->
    linear projection back to ``d_model``.

    Args:
        d_model: Size of each incoming digit embedding and of the output vector.
        hidden_channels: Channel width used inside the convolution stack.
        kernel_size: Kernel width of both convolutions. Padding is
            ``kernel_size // 2``, which preserves sequence length for odd
            kernel sizes.
    """

    def __init__(self, d_model, hidden_channels=64, kernel_size=3):
        super().__init__()
        self.d_model = d_model

        # Project each digit embedding into the convolution's channel space.
        self.in_proj = nn.Linear(d_model, hidden_channels)

        # Two identically-shaped 1D convolutions over the digit axis.
        padding = kernel_size // 2
        self.conv = nn.Conv1d(
            in_channels=hidden_channels,
            out_channels=hidden_channels,
            kernel_size=kernel_size,
            padding=padding,
        )
        self.refine = nn.Conv1d(
            in_channels=hidden_channels,
            out_channels=hidden_channels,
            kernel_size=kernel_size,
            padding=padding,
        )

        # Map the pooled hidden vector back to d_model.
        self.out_proj = nn.Linear(hidden_channels, d_model)

    def forward(self, digit_embs):
        """Embed one number, or a batch of numbers, from digit embeddings.

        Args:
            digit_embs: Tensor of shape ``(seq_len, d_model)`` for a single
                number, or ``(batch, seq_len, d_model)`` for a batch.

        Returns:
            Tensor of shape ``(d_model,)`` for unbatched input, or
            ``(batch, d_model)`` for batched input.

        Raises:
            ValueError: If ``digit_embs`` is neither 2-D nor 3-D.
        """
        if digit_embs.dim() not in (2, 3):
            raise ValueError(
                "digit_embs must have shape (seq_len, d_model) or "
                f"(batch, seq_len, d_model), got {tuple(digit_embs.shape)}"
            )

        # Treat a single sequence as a batch of one so both cases share a path.
        unbatched = digit_embs.dim() == 2
        if unbatched:
            digit_embs = digit_embs.unsqueeze(0)       # (1, L, d_model)

        x = self.in_proj(digit_embs)                   # (B, L, hidden)
        x = x.transpose(1, 2)                          # (B, hidden, L) -- Conv1d layout

        # 1D conv over digits, then the refinement conv.
        x = F.silu(self.conv(x))
        x = F.silu(self.refine(x))

        # Mean-pool across digit positions, then project back to d_model.
        pooled = x.mean(dim=2)                         # (B, hidden)
        out = self.out_proj(pooled)                    # (B, d_model)

        # Drop the synthetic batch axis again for unbatched callers.
        return out.squeeze(0) if unbatched else out


In [19]:
# Hyperparameters used as the default arguments of `Model`.
D_MODEL, N_HEADS = 32, 2
DEPTH = 4
MLP_RATIO = 4  # multiplied by D_MODEL to give Model's default mlp_dim

# One token id per character of the `vocab` string.
VOCAB_SIZE = len(vocab)


In [20]:
class Model(nn.Module):
    """Token embedding followed by a transformer block.

    Args:
        d_model: Embedding width; also passed to ``TransformerBlock``.
        depth: Stored on the instance only.
        n_heads: Stored on the instance only.
        vocab_size: Number of rows in the embedding table.
        mlp_dim: Stored on the instance only.

    NOTE(review): ``depth``, ``n_heads`` and ``mlp_dim`` are recorded but not
    forwarded to ``TransformerBlock``, and only a single block is built.
    Confirm whether ``TransformerBlock`` should receive them / be stacked.
    """

    def __init__(self, d_model=D_MODEL, depth=DEPTH, n_heads=N_HEADS, vocab_size=VOCAB_SIZE, mlp_dim=D_MODEL * MLP_RATIO):
        super().__init__()
        self.d_model = d_model
        self.depth = depth
        self.n_heads = n_heads
        self.vocab_size = vocab_size
        self.mlp_dim = mlp_dim

        # nn.Embedding's keywords are `num_embeddings` / `embedding_dim`;
        # `vocab_size=` / `d_model=` are not accepted and raise TypeError.
        self.embedding = nn.Embedding(num_embeddings=vocab_size, embedding_dim=d_model)

        # forward() applies `self.position_encoding`, so it must exist.
        # Pass-through placeholder: the notebook imports RoPE, but its
        # constructor is not visible here.
        # NOTE(review): swap in RoPE here if TransformerBlock does not already
        # apply positional information internally.
        self.position_encoding = nn.Identity()

        self.transformer = TransformerBlock(d_model=d_model)

    def forward(self, x):
        """Embed token ids, apply the position encoding, run the transformer.

        Args:
            x: Integer tensor of token ids in ``[0, vocab_size)``.

        Returns:
            Whatever ``TransformerBlock`` returns for the embedded sequence
            (presumably ``(..., seq_len, d_model)`` -- TODO confirm).
        """
        x = self.embedding(x)
        x = self.position_encoding(x)
        x = self.transformer(x)
        return x


In [None]:
# Number of epochs for the training run.
epochs = 100

# Instantiate the network with its default hyperparameters.
model = Model()

In [None]:
# Smoke test for DigitConvCompressor: feed three random digit embeddings and
# check that they are compressed into a single vector.
digit_embs = torch.randn(3, 128)  # 3 digit positions, each a 128-dim embedding
compressor = DigitConvCompressor(d_model=128)  # d_model must match the embedding width above

number_vec = compressor(digit_embs)
print(number_vec.shape)  # expected: torch.Size([128]) -- one d_model-sized vector


torch.Size([128])
