In [1]:
# Fix for the S-mapping indexing bug: the t update previously read from the old t slice instead of the old r slice.
import numpy as np

TOKENS = ["a", "b", "c", "d", "e", "EOS"]
# Position of each vocabulary token inside a one-hot vector.
tok2idx = {t: i for i, t in enumerate(TOKENS)}


def one_hot(token):
    """Return the float one-hot vector for ``token`` over ``TOKENS``.

    The length is derived from ``TOKENS`` (6 for the current vocabulary)
    rather than hard-coded, so the encoding cannot silently drift out of
    sync with the vocabulary definition.

    Raises:
        KeyError: if ``token`` is not one of ``TOKENS``.
    """
    v = np.zeros(len(TOKENS))
    v[tok2idx[token]] = 1.0
    return v


def expected_counts(tokens):
    """Reference answer: occurrences of each letter "a".."e" in ``tokens``.

    Tokens that are not one of those letters (e.g. "EOS") are skipped.
    Returns a list of five ints ordered a, b, c, d, e.
    """
    from collections import Counter

    letters = "abcde"
    tally = Counter(filter(letters.__contains__, tokens))
    # Counter returns 0 for letters that never appeared.
    return [tally[letter] for letter in letters]


# Encoder weights. Each step computes s' = relu(W_e @ [x; s]), where x is the
# 6-dim one-hot token (columns 0..5) and s is the 16-dim state (columns 6..21).
# State layout written here:
#   s[0:5]  = r : running count of each letter a..e
#   s[5:10] = t : suffix sums of the updated r
#   s[10]       : EOS bit of the current token (start of the p slots)
W_e = np.zeros((16, 22))

# r' = r + x[0:5]: identity from the letter part of x and from the old r.
W_e[0:5, 0:5] = np.eye(5)
W_e[0:5, 6:11] = np.eye(5)

# Upper-triangular ones, so (S @ v)[i] = sum of v[j] for j >= i.
S = np.triu(np.ones((5, 5)))

# t' = S @ r' = S @ (old r + x[0:5]). The state part must come from the old r
# (columns 6..10), not the old t.
W_e[5:10, 6:11] = S
W_e[5:10, 0:5] += S

# First p slot takes the EOS bit of x (column 5).
W_e[10, 5] = 1.0

# Decoder state transition s <- relu(W_h @ s):
#   * t (s[5:10]) shifts left one slot, so s[5] - s[6] walks through r;
#   * p (s[10:16]) shifts right one slot, moving the EOS bit toward s[15];
#   * r (s[0:5]) is not carried forward.
W_h = np.zeros((16, 16))
W_h[5:9, 6:10] = np.eye(4)  # s[k] <- s[k + 1] for k = 5..8
W_h[11:16, 10:15] = np.eye(5)  # s[k] <- s[k - 1] for k = 11..15

# Readout: y[0] = s[5] - s[6] (difference of adjacent suffix sums, i.e. the
# current letter's count); y[1] = s[15] (the shifted EOS bit).
W_o = np.zeros((2, 16))
W_o[0, 5:7] = [1, -1]
W_o[1, 15] = 1


def encode(tokens):
    """Run the ReLU encoder over ``tokens`` and return the final 16-dim state.

    Each step stacks ``one_hot(token)`` on top of the current state
    (6 + 16 = 22 values), multiplies by ``W_e``, and clips at zero.
    An empty sequence yields the all-zero state.
    """
    state = np.zeros(16)
    for token in tokens:
        stacked = np.concatenate([one_hot(token), state])
        state = np.maximum(0, W_e @ stacked)
    return state


def decode(s0, steps=6):
    """Unroll the decoder from state ``s0`` for ``steps`` steps.

    Returns a list of ``(y0, y1)`` pairs. Each pair holds the two ReLU'd
    ``W_o`` outputs, read from the state *before* that step's ``W_h``
    transition. ``s0`` is copied, so the caller's array is not modified.
    """
    readouts = []
    state = s0.copy()
    for _ in range(steps):
        out = np.maximum(0, W_o @ state)
        readouts.append(tuple(out[:2]))
        state = np.maximum(0, W_h @ state)
    return readouts


# Smoke test: encode a sample sequence, then decode. The first five decoder
# steps give the counts for a..e; the sixth step carries the EOS flag.
example_tokens = [*"abbccddee", "EOS"]
s_T = encode(example_tokens)
outs = decode(s_T, steps=6)
counts_only = [int(round(pair[0])) for pair in outs[:5]]
eos_flag = int(round(outs[-1][1]))
reference_counts = expected_counts(example_tokens)

print("Expected counts:", reference_counts)
print("Decoded counts :", counts_only)
print("EOS flag       :", eos_flag)

# Assertions
assert counts_only == reference_counts
assert eos_flag == 1
print("\n✅ Fixed: decoded counts match and EOS flag is correct.")

Expected counts: [1, 2, 2, 2, 2]
Decoded counts : [1, 2, 2, 2, 2]
EOS flag       : 1

✅ Fixed: decoded counts match and EOS flag is correct.
