In [49]:
import pandas as pd
import torch
import torch.nn as nn
from sklearn.model_selection import train_test_split
import math, torch
import plotly.graph_objects as go
from plotly.subplots import make_subplots
from captum.attr import IntegratedGradients


In [18]:
# Raw arithmetic dataset; the next cells unpack each row as (x, op, y, eq, result).
df = pd.read_csv(filepath_or_buffer="arithmetic_data.csv")

In [19]:
class MathLSTMModel(nn.Module):
    """Regress the numeric result of a tokenised equation with a single-layer LSTM.

    Token ids are embedded, run through an LSTM (batch-first), and the final
    hidden state of the last LSTM layer is projected to a single scalar.

    Args:
        vocab_size: number of distinct token ids.
        embed_dim: size of each token embedding.
        hidden_dim: size of the LSTM hidden state.
    """

    def __init__(self, vocab_size, embed_dim=16, hidden_dim=32):
        super().__init__()
        # Submodule names are accessed directly outside the class (model.embedding,
        # model.lstm, model.fc), so they must stay as they are. Creation order is
        # also kept so parameter initialisation consumes the RNG identically.
        self.embedding = nn.Embedding(num_embeddings=vocab_size, embedding_dim=embed_dim)
        self.lstm = nn.LSTM(input_size=embed_dim, hidden_size=hidden_dim, batch_first=True)
        self.fc = nn.Linear(in_features=hidden_dim, out_features=1)

    def forward(self, x):
        """Map token ids [batch, seq_len] to one prediction per row, shape [batch]."""
        embedded = self.embedding(x)                       # [batch, seq_len, embed_dim]
        _, (hidden_states, _cell_states) = self.lstm(embedded)
        last_layer_hidden = hidden_states[-1]              # [batch, hidden_dim]
        return self.fc(last_layer_hidden).squeeze(1)


In [20]:
# Tokenise each equation row: operands, operator and "=" each become one token.
# Reassign instead of inplace=True so the cell is idempotent and chain-friendly.
df = df.rename(columns={"=": "eq"})
rows = list(df.itertuples(index=False, name=None))

# Build the vocabulary from every column except the last (the result, see encode).
# Note: this spans the whole dataset, i.e. before the train/test split.
vocab = sorted({str(cell) for row in rows for cell in row[:-1]})
token2idx = {tok: i for i, tok in enumerate(vocab)}


def encode(row):
    """Turn one (x, op, y, eq, res) row into (list of token ids, float target).

    Every token goes through ``str`` exactly as during vocabulary construction,
    so a cell pandas parsed as a non-string still maps to its vocabulary entry.

    Raises:
        ValueError: if the row does not have exactly five fields.
        KeyError: if a token was not seen when the vocabulary was built.
    """
    x, op, y, eq, res = row
    tokens = [str(tok) for tok in (x, op, y, eq)]
    return [token2idx[tok] for tok in tokens], float(res)


encoded = [encode(row) for row in rows]
inputs = torch.tensor([ids for ids, _ in encoded], dtype=torch.long)
targets = torch.tensor([res for _, res in encoded], dtype=torch.float32)

In [21]:
# Hold out 20% of the equations for evaluation; the fixed random_state keeps
# the split identical across runs.
X_train, X_test, y_train, y_test = train_test_split(
    inputs,
    targets,
    test_size=0.2,
    random_state=42,
)


In [22]:
# Full-batch training of the LSTM regressor with periodic test-set evaluation.

# Seed before building the model so weight init (and thus the run) is reproducible.
torch.manual_seed(42)

model = MathLSTMModel(vocab_size=len(token2idx))
loss_fn = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

# tqdm.auto renders a notebook widget when available and falls back to text.
from tqdm.auto import trange

num_epochs = 500
eval_every = 50

# NOTE(review): targets are raw results (products reach several thousand), while
# the linear head sees an LSTM hidden state bounded in (-1, 1). In the recorded run
# the MSE only falls from ~2.8e6 to ~2.6e6 and '*' predictions plateau near 168.
# Standardising targets (and un-scaling predictions) is likely needed — TODO.

loop = trange(num_epochs)
for epoch in loop:
    model.train()
    preds = model(X_train)
    loss = loss_fn(preds, y_train)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    loop.set_description(f"Epoch {epoch}")
    loop.set_postfix(train_loss=loss.item())

    if epoch % eval_every == 0 or epoch == num_epochs - 1:
        model.eval()
        with torch.no_grad():
            val_preds = model(X_test)
            val_loss = loss_fn(val_preds, y_test)
        # loop.write prints above the bar instead of tearing it like print() does.
        loop.write(f"Epoch {epoch} | Train Loss: {loss.item():.4f}, Test Loss: {val_loss.item():.4f}")


Epoch 0:   0%|          | 1/500 [00:00<02:51,  2.91it/s, train_loss=2.81e+6]

Epoch 0 | Train Loss: 2808776.0000, Test Loss: 2900405.7500


Epoch 50:  10%|█         | 51/500 [00:15<06:55,  1.08it/s, train_loss=2.78e+6]

Epoch 50 | Train Loss: 2783902.0000, Test Loss: 2874981.7500


Epoch 100:  20%|██        | 101/500 [00:36<02:19,  2.86it/s, train_loss=2.76e+6]

Epoch 100 | Train Loss: 2761272.5000, Test Loss: 2852092.7500


Epoch 150:  30%|███       | 151/500 [00:58<01:58,  2.96it/s, train_loss=2.74e+6]

Epoch 150 | Train Loss: 2738963.0000, Test Loss: 2829518.7500


Epoch 201:  40%|████      | 202/500 [01:14<01:00,  4.91it/s, train_loss=2.72e+6]

Epoch 200 | Train Loss: 2717640.2500, Test Loss: 2807935.2500


Epoch 250:  50%|█████     | 251/500 [01:23<00:54,  4.59it/s, train_loss=2.7e+6] 

Epoch 250 | Train Loss: 2696974.0000, Test Loss: 2787006.7500


Epoch 300:  60%|██████    | 301/500 [01:43<01:09,  2.87it/s, train_loss=2.68e+6]

Epoch 300 | Train Loss: 2676814.2500, Test Loss: 2766587.5000


Epoch 350:  70%|███████   | 351/500 [02:01<00:39,  3.82it/s, train_loss=2.66e+6]

Epoch 350 | Train Loss: 2657062.7500, Test Loss: 2746576.0000


Epoch 401:  80%|████████  | 402/500 [02:11<00:20,  4.88it/s, train_loss=2.64e+6]

Epoch 400 | Train Loss: 2637617.5000, Test Loss: 2726870.7500


Epoch 450:  90%|█████████ | 451/500 [02:40<00:27,  1.80it/s, train_loss=2.62e+6]

Epoch 450 | Train Loss: 2618430.2500, Test Loss: 2707421.5000


Epoch 499: 100%|██████████| 500/500 [03:02<00:00,  2.73it/s, train_loss=2.6e+6] 

Epoch 499 | Train Loss: 2599845.5000, Test Loss: 2688576.7500





In [27]:
# Print each test equation next to the model's prediction and the true result.
# Invert the vocabulary once instead of rebuilding list(token2idx.keys()) per token.
idx2token = {idx: tok for tok, idx in token2idx.items()}

model.eval()
with torch.no_grad():
    preds = model(X_test)

for tok_ids, pred, actual in zip(X_test.tolist(), preds.tolist(), y_test.tolist()):
    equation = " ".join(idx2token[idx] for idx in tok_ids)
    print(f"{equation} → Predicted: {pred:.2f}, Actual: {actual}")


85 * 96 = → Predicted: 168.56, Actual: 8160.0
52 / 1 = → Predicted: 51.39, Actual: 52.0
56 * 73 = → Predicted: 168.56, Actual: 4088.0
56 * 11 = → Predicted: 168.55, Actual: 616.0
81 / 3 = → Predicted: 27.15, Actual: 27.0
37 - 57 = → Predicted: -20.89, Actual: -20.0
90 + 61 = → Predicted: 151.68, Actual: 151.0
58 - 87 = → Predicted: -29.61, Actual: -29.0
52 / 52 = → Predicted: 0.40, Actual: 1.0
86 + 38 = → Predicted: 124.76, Actual: 124.0
66 / 22 = → Predicted: 3.75, Actual: 3.0
84 / 21 = → Predicted: 4.27, Actual: 4.0
58 + 68 = → Predicted: 125.13, Actual: 126.0
18 + 21 = → Predicted: 39.14, Actual: 39.0
20 + 59 = → Predicted: 77.94, Actual: 79.0
3 - 23 = → Predicted: -20.86, Actual: -20.0
92 + 58 = → Predicted: 151.86, Actual: 150.0
53 + 92 = → Predicted: 145.33, Actual: 145.0
13 + 48 = → Predicted: 61.99, Actual: 61.0
70 * 71 = → Predicted: 168.56, Actual: 4970.0
36 * 1 = → Predicted: 163.53, Actual: 36.0
99 / 1 = → Predicted: 97.21, Actual: 99.0
7 / 1 = → Predicted: 7.75, Actual: 7.

In [28]:
def forward_embedded(embedded_input):
    """Run the global ``model`` from pre-computed embeddings instead of token ids.

    Mirrors MathLSTMModel.forward minus the embedding lookup, so attribution
    methods can differentiate with respect to the (continuous) embeddings.

    Args:
        embedded_input: tensor of shape [batch, seq_len, embed_dim].

    Returns:
        Tensor of shape [batch] with one prediction per row.
    """
    _, (hidden_states, _cell_states) = model.lstm(embedded_input)
    return model.fc(hidden_states[-1]).squeeze(1)

In [None]:
# Integrated-Gradients attribution per token for the first test equations, laid
# out as a grid of one-row heatmaps and written to attribution_grid.html.
num_samples = min(100, len(X_test))  # never index past the end of the test set
n_cols = 5
n_rows = math.ceil(num_samples / n_cols)
ig = IntegratedGradients(forward_embedded)
# Invert the vocabulary once rather than scanning keys/values for every token.
idx2token = {idx: tok for tok, idx in token2idx.items()}

fig = make_subplots(rows=n_rows, cols=n_cols,
                    horizontal_spacing=0.01, vertical_spacing=0.02)

for i in range(num_samples):
    ids = X_test[i].unsqueeze(0)                          # [1, seq_len]
    toks = [idx2token[t] for t in ids[0].tolist()]
    eqn = " ".join(toks)

    # Attribute in embedding space: token ids are discrete and carry no gradient.
    emb = model.embedding(ids).detach().requires_grad_()
    attrs, _ = ig.attribute(emb, return_convergence_delta=True)
    scores = attrs.squeeze(0).sum(dim=1)                  # one score per token
    # clamp_min avoids 0/0 -> NaN when every attribution is exactly zero.
    scores = scores / torch.norm(scores).clamp_min(1e-12)
    img = scores.detach().numpy()[None, :]

    r, c = divmod(i, n_cols)
    fig.add_trace(
        go.Heatmap(
            z=img,
            # Positional x so repeated tokens (e.g. "52 / 52 =") keep separate
            # cells instead of collapsing onto one categorical position.
            x=list(range(len(toks))), y=[''],
            text=[toks],
            colorscale='RdBu',
            zmin=-1, zmax=1,
            showscale=False,
            hovertemplate=f"<b>{eqn}</b><br>Token: %{{text}}<br>Attrib: %{{z:.3f}}<extra></extra>",
        ),
        row=r+1, col=c+1
    )

fig.update_layout(
    height=25*n_rows+100, width=1000,
    margin=dict(l=10, r=10, t=30, b=10),
    template='plotly_white'
)
fig.write_html("attribution_grid.html")
