In [52]:
import numpy as np

# Hyperparameters shared by the model, data and sampling cells below.
vocab_size = 88        # number of distinct note tokens (valid ids are 0 .. vocab_size-1)
seq_len = 24            # number of input tokens per training example
d_model = 128          # embedding / hidden width of every transformer block
num_layers = 3           # number of stacked transformer blocks in the JAX model
batch_size = 32        # number of examples drawn per JAX training step

# Seeds NumPy's global RNG only; the JAX code uses explicit PRNG keys instead.
np.random.seed(42)


In [53]:
import jax
import jax.numpy as jnp
from functools import partial
import optax
import numpy as np

In [54]:
from flax.struct import dataclass
from jax.tree_util import register_pytree_node_class

@dataclass
class TransformerParams:
    """Trainable arrays of one transformer block.

    ``W_q``/``W_k``/``W_v``/``W_o`` are the attention projections, ``W1``/``b1``
    and ``W2``/``b2`` the two feed-forward layers, and ``gamma*``/``beta*`` the
    scale and shift of the two layer norms.
    """
    W_q: jnp.ndarray
    W_k: jnp.ndarray
    W_v: jnp.ndarray
    W_o: jnp.ndarray
    W1: jnp.ndarray
    b1: jnp.ndarray
    W2: jnp.ndarray
    b2: jnp.ndarray
    gamma1: jnp.ndarray
    beta1: jnp.ndarray
    gamma2: jnp.ndarray
    beta2: jnp.ndarray

    def tree_flatten(self):
        """Return ``(children, aux)`` with the twelve arrays in field order and no aux data."""
        field_order = (
            "W_q", "W_k", "W_v", "W_o",
            "W1", "b1", "W2", "b2",
            "gamma1", "beta1", "gamma2", "beta2",
        )
        leaves = tuple(getattr(self, field) for field in field_order)
        return leaves, None

    @classmethod
    def tree_unflatten(cls, aux, children):
        """Rebuild an instance from the children produced by ``tree_flatten``."""
        return cls(*children)

@dataclass
class StackedTransformerParams:
    """Parameters of a stack of transformer blocks, applied in tuple order."""
    blocks: tuple[TransformerParams, ...]

    def tree_flatten(self):
        """Return ``(children, aux)`` where the only child is the ``blocks`` tuple."""
        children = (self.blocks,)
        aux_data = None
        return children, aux_data

    @classmethod
    def tree_unflatten(cls, aux, children):
        """Rebuild an instance from the children produced by ``tree_flatten``."""
        return cls(*children)

@dataclass
class ModelParams:
    """All trainable parameters of the JAX model: embedding, block stack, output head."""
    # Token embedding table; initialised elsewhere in this notebook as (vocab_size, d_model).
    embedding: jax.Array
    # Per-layer parameters of the stacked transformer blocks.
    transformer: StackedTransformerParams
    # Initialised as (vocab_size, d_model) and applied as `x @ W_out.T` in forward_and_loss.
    W_out: jax.Array  # final projection to vocab


def layernorm(x, gamma, beta, eps=1e-5):
    """Normalise ``x`` over its last axis, then apply scale ``gamma`` and shift ``beta``.

    Args:
        x: input array; statistics are taken over the last axis.
        gamma: multiplicative scale, broadcast against the normalised values.
        beta: additive shift, broadcast against the normalised values.
        eps: constant added to the variance before the square root.

    Returns:
        Array with the same shape as ``x``.
    """
    centered = x - x.mean(-1, keepdims=True)
    variance = x.var(-1, keepdims=True)
    normalized = centered / jnp.sqrt(variance + eps)
    return gamma * normalized + beta

def attention(x, params):
    """Single-head causal self-attention over a batched input.

    Args:
        x: array indexed as (batch, position, feature); the einsum below requires
            exactly three axes.
        params: object exposing the projection matrices ``W_q``, ``W_k``, ``W_v``, ``W_o``.

    Returns:
        The attended values projected through ``W_o``.
    """
    queries = x @ params.W_q
    keys = x @ params.W_k
    values = x @ params.W_v

    # Scaled dot-product scores: logits[b, i, k] = <queries[b, i], keys[b, k]> / sqrt(features).
    logits = jnp.einsum('bij,bkj->bik', queries, keys) / jnp.sqrt(x.shape[-1])

    # Additive causal mask: a large negative value strictly above the diagonal, so
    # position i puts (numerically) zero weight on positions after i.
    causal_penalty = jnp.triu(jnp.ones_like(logits), 1) * -1e9
    probs = jax.nn.softmax(logits + causal_penalty, axis=-1)

    context = probs @ values
    return context @ params.W_o

def feedforward(x, params):
    """Two-layer position-wise MLP: ``relu(x @ W1 + b1) @ W2 + b2``.

    Args:
        x: input array whose last axis matches the rows of ``params.W1``.
        params: object exposing ``W1``, ``b1``, ``W2`` and ``b2``.

    Returns:
        The MLP output.
    """
    pre_activation = x @ params.W1 + params.b1
    hidden = jax.nn.relu(pre_activation)
    return hidden @ params.W2 + params.b2

def transformer_block(x, params):
    """Apply one post-LayerNorm transformer block.

    Each sub-layer (attention, then feed-forward) is added to its input as a
    residual and the sum is layer-normalised.
    """
    after_attention = layernorm(x + attention(x, params), params.gamma1, params.beta1)
    after_mlp = layernorm(
        after_attention + feedforward(after_attention, params),
        params.gamma2,
        params.beta2,
    )
    return after_mlp

def init_transformer_params(key, d_model):
    """Create randomly initialised parameters for one transformer block.

    Every weight matrix is drawn from its own PRNG sub-key. Drawing several
    matrices of the same shape from a single key makes them identical copies,
    which would start the query, key, value and output projections out equal.

    Weights are scaled by ``1 / sqrt(fan_in)``, where fan-in is the number of
    rows, matching the NumPy ``TransformerBlock`` in this notebook. Biases and
    layer-norm shifts start at zero, layer-norm scales at one.

    Args:
        key: JAX PRNG key; it is split internally and must not be reused by the caller.
        d_model: model width; the feed-forward hidden width is ``4 * d_model``.

    Returns:
        TransformerParams: freshly initialised parameters.
    """
    k_q, k_k, k_v, k_o, k_ff1, k_ff2 = jax.random.split(key, 6)

    def norm_init(subkey, shape):
        # Fan-in scaling keeps the variance of `x @ W` roughly independent of width.
        return jax.random.normal(subkey, shape) / jnp.sqrt(shape[0])

    def zero_init(shape):
        return jnp.zeros(shape)

    return TransformerParams(
        W_q = norm_init(k_q, (d_model, d_model)),
        W_k = norm_init(k_k, (d_model, d_model)),
        W_v = norm_init(k_v, (d_model, d_model)),
        W_o = norm_init(k_o, (d_model, d_model)),

        W1 = norm_init(k_ff1, (d_model, 4 * d_model)),
        b1 = zero_init((4 * d_model,)),
        W2 = norm_init(k_ff2, (4 * d_model, d_model)),
        b2 = zero_init((d_model,)),

        gamma1 = jnp.ones((d_model,)),
        beta1 = jnp.zeros((d_model,)),
        gamma2 = jnp.ones((d_model,)),
        beta2 = jnp.zeros((d_model,))
    )


In [55]:
def init_embedding(key, vocab_size, d_model):
    """Return a ``(vocab_size, d_model)`` table of standard-normal draws scaled by 0.01."""
    table_shape = (vocab_size, d_model)
    return jax.random.normal(key, table_shape) * 0.01

def embedding_fn(embedding_params, token_ids):
    """Look up one embedding row per token id.

    The result has the shape of ``token_ids`` with the embedding width appended.
    """
    looked_up = embedding_params[token_ids]
    return looked_up


In [56]:
def init_stacked_transformer_params(key, d_model, n_layers):
    """Initialise ``n_layers`` transformer blocks, each from its own sub-key of ``key``."""
    layer_keys = jax.random.split(key, n_layers)
    layers = []
    for layer_key in layer_keys:
        layers.append(init_transformer_params(layer_key, d_model))
    return StackedTransformerParams(blocks=tuple(layers))

def transformer_block_batch(x, params: TransformerParams):
    """Apply one transformer block to a batched input.

    This is the same computation as ``transformer_block`` and simply delegates
    to it, so the block is defined in one place only.

    Args:
        x: batched activations, as expected by ``attention``.
        params: parameters of the block to apply.

    Returns:
        The block output, same shape as ``x``.
    """
    return transformer_block(x, params)


def stacked_forward(x, params: StackedTransformerParams):
    """Run ``x`` through every block in ``params.blocks``, in order.

    Args:
        x: batched activations entering the first block.
        params: the stacked per-layer parameters.

    Returns:
        The activations after the last block.
    """
    for block_params in params.blocks:
        x = transformer_block_batch(x, block_params)
    return x

def cross_entropy_loss(logits, targets):
    """Softmax cross-entropy summed over all entries and divided by ``targets.shape[0]``.

    Note: a later cell in this notebook defines another ``cross_entropy_loss``
    (NumPy, different signature); whichever cell ran last wins.

    Args:
        logits: unnormalised scores with the class axis last.
        targets: integer class ids.

    Returns:
        Scalar loss.
    """
    num_classes = logits.shape[-1]
    target_mask = jax.nn.one_hot(targets, num_classes)
    log_probs = jax.nn.log_softmax(logits, axis=-1)
    summed_log_likelihood = jnp.sum(target_mask * log_probs)
    return -summed_log_likelihood / targets.shape[0]

def forward_and_loss(params: ModelParams, token_ids, targets):
    """Run the full model and return the mean next-token cross-entropy.

    Args:
        params: model parameters (embedding, transformer stack, output head).
        token_ids: integer token ids, (B, T).
        targets: integer target ids with the same number of elements as ``token_ids``.

    Returns:
        Scalar loss averaged over every position of every example.
    """
    embedded = embedding_fn(params.embedding, token_ids)      # (B, T, D)
    hidden = stacked_forward(embedded, params.transformer)
    vocab_logits = hidden @ params.W_out.T                    # (B, T, vocab)

    # Flatten batch and time so every position is one classification example.
    flat_logits = vocab_logits.reshape(-1, vocab_logits.shape[-1])
    flat_targets = targets.reshape(-1)
    per_position = optax.softmax_cross_entropy_with_integer_labels(flat_logits, flat_targets)
    return per_position.mean()


In [None]:


@jax.jit
def train_step(params, opt_state, token_ids, targets):
    """One jitted optimisation step.

    Uses the module-level ``optimizer``, which must be defined before the first call.

    Args:
        params: current model parameters.
        opt_state: current optimizer state.
        token_ids: input token ids for this batch.
        targets: target token ids for this batch.

    Returns:
        Tuple ``(new_params, new_opt_state, loss)``.
    """
    # value_and_grad differentiates with respect to the first argument (params).
    loss, grads = jax.value_and_grad(forward_and_loss)(params, token_ids, targets)
    updates, next_opt_state = optimizer.update(grads, opt_state, params)
    next_params = optax.apply_updates(params, updates)
    return next_params, next_opt_state, loss

# Build the model, optimizer and one random batch, then run a single training
# step as a smoke test.
#
# JAX PRNG keys must not be reused: every random draw below gets its own
# sub-key, and `key` is replaced by a fresh key for later cells. Drawing
# `inputs` and `targets` from the same key would make the two arrays identical.
key = jax.random.PRNGKey(0)
key, key_embed, key_model, key_out, key_inputs, key_targets = jax.random.split(key, 6)

embedding_params = init_embedding(key_embed, vocab_size, d_model)
transformer_params = init_stacked_transformer_params(key_model, d_model, num_layers)
W_out = jax.random.normal(key_out, (vocab_size, d_model)) * 0.01

params = ModelParams(
    embedding=embedding_params,
    transformer=transformer_params,
    W_out=W_out
)

optimizer = optax.adamw(learning_rate=1e-4)
opt_state = optimizer.init(params)

# Random token ids, used only to check that train_step compiles and runs.
inputs = jax.random.randint(key_inputs, (batch_size, seq_len), 0, vocab_size)
targets = jax.random.randint(key_targets, (batch_size, seq_len), 0, vocab_size)

params, opt_state, loss = train_step(params, opt_state, inputs, targets)

In [None]:
def softmax(x):
    """Numerically stable softmax over the last axis of a NumPy array."""
    # Subtracting the row maximum leaves the result unchanged but avoids overflow in exp.
    shifted = x - np.max(x, axis=-1, keepdims=True)
    exps = np.exp(shifted)
    return exps / np.sum(exps, axis=-1, keepdims=True)

def cross_entropy_loss(logits, target):
    """Cross-entropy of one logit vector against one target id, with its gradient.

    Note: this redefines the JAX ``cross_entropy_loss`` from an earlier cell.

    Args:
        logits: 1-D NumPy array of unnormalised scores.
        target: index of the correct class.

    Returns:
        Tuple ``(loss, dlogits)`` where ``dlogits`` is ``softmax(logits)`` with
        1 subtracted at ``target``.
    """
    grad = softmax(logits)
    # The loss must be read before the probability at `target` is overwritten below.
    nll = -np.log(grad[target] + 1e-9)
    grad[target] -= 1
    return nll, grad

def get_positional_encoding(seq_len, d_model):
    """Build a ``(seq_len, d_model)`` sinusoidal positional-encoding table.

    Even columns hold the sine and odd columns the cosine of
    ``pos / 10000 ** (2 * (col // 2) / d_model)``.
    """
    positions = np.arange(seq_len)[:, np.newaxis]
    columns = np.arange(d_model)[np.newaxis, :]
    exponents = (2 * (columns // 2)) / np.float32(d_model)
    angles = positions * (1 / np.power(10000, exponents))

    table = np.zeros((seq_len, d_model))
    table[:, 0::2] = np.sin(angles[:, 0::2])
    table[:, 1::2] = np.cos(angles[:, 1::2])
    return table

# Precomputed once; the NumPy training and generation code adds the first len(input) rows to the embeddings.
positional_encoding = get_positional_encoding(seq_len, d_model)


In [None]:
def layernorm_backward(dout, x, mean, var, gamma, eps=1e-5):
    """Backward pass of layer normalisation over the last axis of a 2-D input.

    Args:
        dout: upstream gradient, same shape as ``x``.
        x: the input that was normalised in the forward pass, shape (N, D).
        mean: per-row mean of ``x`` (keepdims).
        var: per-row variance of ``x`` (keepdims).
        gamma: layer-norm scale, shape (D,).
        eps: the epsilon that was added to the variance in the forward pass.

    Returns:
        Tuple ``(dx, dgamma, dbeta)``.
    """
    _, width = x.shape
    centered = x - mean
    inv_std = 1. / np.sqrt(var + eps)

    # Gradient with respect to the normalised activations.
    d_xhat = dout * gamma
    d_var = np.sum(d_xhat * centered, axis=-1, keepdims=True) * -0.5 * inv_std**3
    d_mean = (
        np.sum(d_xhat * -inv_std, axis=-1, keepdims=True)
        + d_var * np.mean(-2.0 * centered, axis=-1, keepdims=True)
    )

    dx = d_xhat * inv_std + d_var * 2 * centered / width + d_mean / width
    # Parameter gradients are accumulated over the rows.
    dgamma = np.sum(dout * (centered * inv_std), axis=0)
    dbeta = np.sum(dout, axis=0)
    return dx, dgamma, dbeta


In [None]:
# transformerblock class
class TransformerBlock:
    """Single-head causal transformer block (post-LayerNorm) in NumPy with manual backprop.

    Forward computation on an unbatched ``(T, d_model)`` input::

        x_attn = x + Attention(x)
        x2     = LayerNorm1(x_attn)
        x_ffn  = x2 + FeedForward(x2)
        out    = LayerNorm2(x_ffn)

    ``forward`` stores every intermediate that ``backward`` needs in ``self.cache``,
    so ``backward`` must be called after the matching ``forward``.
    """

    # Epsilon used by both layer norms; forward and backward must agree on it.
    _LN_EPS = 1e-5

    def __init__(self, d_model):
        """Create a block of width ``d_model`` with randomly initialised weights.

        Weights are drawn from NumPy's global RNG in the order
        W_q, W_k, W_v, W_o, W1, W2 and scaled by ``1 / sqrt(fan_in)``.
        """
        self.d_model = d_model
        self.gamma1 = np.ones((self.d_model,))
        self.beta1 = np.zeros((self.d_model,))
        self.gamma2 = np.ones((self.d_model,))
        self.beta2 = np.zeros((self.d_model,))

        self.W_q = np.random.randn(d_model, d_model) / np.sqrt(d_model)
        self.W_k = np.random.randn(d_model, d_model) / np.sqrt(d_model)
        self.W_v = np.random.randn(d_model, d_model) / np.sqrt(d_model)
        self.W_o = np.random.randn(d_model, d_model) / np.sqrt(d_model)

        self.W1 = np.random.randn(d_model, 4*d_model) / np.sqrt(d_model)
        self.b1 = np.zeros((4*d_model,))
        self.W2 = np.random.randn(4*d_model, d_model) / np.sqrt(4*d_model)
        self.b2 = np.zeros((d_model,))

        self.cache = {}  # to store intermediate states for backprop

    @staticmethod
    def _softmax(scores):
        """Numerically stable softmax over the last axis."""
        exps = np.exp(scores - np.max(scores, axis=-1, keepdims=True))
        return exps / np.sum(exps, axis=-1, keepdims=True)

    def forward(self, x):
        """Run the block on ``x`` of shape (T, d_model) and return an array of the same shape."""
        self.cache['x'] = x.copy()

        # self-attention
        Q = x @ self.W_q
        K = x @ self.W_k
        V = x @ self.W_v
        scale = np.sqrt(self.d_model)

        scores = Q @ K.T / scale
        # Causal mask: position i must not attend to positions after i.
        scores += np.triu(np.ones_like(scores), 1) * -1e9

        weights = self._softmax(scores)
        self.cache.update({'Q': Q, 'K': K, 'V': V, 'weights': weights})

        attn_out = (weights @ V) @ self.W_o
        x_attn = x + attn_out

        # layernorm 1 (its input x_attn is cached because the backward pass needs it)
        mean1 = x_attn.mean(axis=-1, keepdims=True)
        var1 = x_attn.var(axis=-1, keepdims=True)
        norm1 = (x_attn - mean1) / np.sqrt(var1 + self._LN_EPS)
        x2 = self.gamma1 * norm1 + self.beta1
        self.cache.update({'x_attn': x_attn, 'ln1_mean': mean1, 'ln1_var': var1, 'x2': x2})

        # feedforward (caches the hidden activations as 'ff_h')
        x_ffn = x2 + self.feedforward(x2)

        # layernorm 2 (its input x_ffn is cached for the backward pass)
        mean2 = x_ffn.mean(axis=-1, keepdims=True)
        var2 = x_ffn.var(axis=-1, keepdims=True)
        norm2 = (x_ffn - mean2) / np.sqrt(var2 + self._LN_EPS)
        out = self.gamma2 * norm2 + self.beta2
        self.cache.update({'x_ffn': x_ffn, 'ln2_mean': mean2, 'ln2_var': var2, 'out': out})

        return out

    def layernorm(self, x, eps=1e-5):
        """Normalise ``x`` over its last axis without scale/shift.

        Standalone helper that ``forward`` does not use. It caches its
        statistics under the keys 'ln_mean', 'ln_var' and 'ln_input'.
        """
        mean = x.mean(axis=-1, keepdims=True)
        var = x.var(axis=-1, keepdims=True)
        norm = (x - mean) / np.sqrt(var + eps)
        self.cache['ln_mean'], self.cache['ln_var'] = mean, var
        self.cache['ln_input'] = x.copy()
        return norm  # gamma/beta omitted for simplicity

    @staticmethod
    def layernorm_backward(dout, x, mean, var, gamma, eps=1e-5):
        """Backward pass of layer normalisation.

        Args:
            dout: upstream gradient, shape (N, D).
            x: the input that was normalised in the forward pass (not its output).
            mean: per-row mean of ``x`` (keepdims).
            var: per-row variance of ``x`` (keepdims).
            gamma: layer-norm scale, shape (D,).
            eps: epsilon added to the variance in the forward pass.

        Returns:
            Tuple ``(dx, dgamma, dbeta)``.
        """
        D = x.shape[-1]
        x_mu = x - mean
        std_inv = 1. / np.sqrt(var + eps)

        dnorm = dout * gamma
        dvar = np.sum(dnorm * x_mu, axis=-1, keepdims=True) * -0.5 * std_inv**3
        dmean = np.sum(dnorm * -std_inv, axis=-1, keepdims=True) + dvar * np.mean(-2.0 * x_mu, axis=-1, keepdims=True)

        dx = dnorm * std_inv + dvar * 2 * x_mu / D + dmean / D
        dgamma = np.sum(dout * (x_mu * std_inv), axis=0)
        dbeta = np.sum(dout, axis=0)
        return dx, dgamma, dbeta

    def backward(self, dout):
        """Backpropagate ``dout`` (gradient w.r.t. the block output) through the block.

        Must be called after ``forward``; it reads the values cached there.

        Returns:
            Tuple ``(dx, grads)`` where ``dx`` is the gradient w.r.t. the block
            input and ``grads`` maps each parameter attribute name to its gradient.
        """
        grads = {}

        # layernorm 2 backward: differentiate w.r.t. its input x_ffn = x2 + ff_out
        dx_ffn, dgamma2, dbeta2 = self.layernorm_backward(
            dout, self.cache['x_ffn'], self.cache['ln2_mean'], self.cache['ln2_var'],
            self.gamma2, self._LN_EPS
        )
        grads['gamma2'], grads['beta2'] = dgamma2, dbeta2

        # feedforward branch of the second residual
        dff_in, dW1, db1, dW2, db2 = self.feedforward_backward(dx_ffn)
        grads['W1'], grads['b1'] = dW1, db1
        grads['W2'], grads['b2'] = dW2, db2

        # x2 feeds both the skip path and the feedforward branch
        dx2 = dx_ffn + dff_in

        # layernorm 1 backward: differentiate w.r.t. its input x_attn = x + attn_out
        dx_attn_sum, dgamma1, dbeta1 = self.layernorm_backward(
            dx2, self.cache['x_attn'], self.cache['ln1_mean'], self.cache['ln1_var'],
            self.gamma1, self._LN_EPS
        )
        grads['gamma1'], grads['beta1'] = dgamma1, dbeta1

        # attention branch of the first residual
        dx_attn, dW_q, dW_k, dW_v, dW_o = self.attention_backward(dx_attn_sum)
        grads['W_q'], grads['W_k'], grads['W_v'], grads['W_o'] = dW_q, dW_k, dW_v, dW_o

        # x feeds both the skip path and the attention branch
        dx = dx_attn + dx_attn_sum

        return dx, grads

    def feedforward(self, x):
        """Two-layer ReLU MLP; caches the hidden activations as 'ff_h'."""
        h = np.maximum(0, x @ self.W1 + self.b1)
        self.cache['ff_h'] = h
        return h @ self.W2 + self.b2

    def feedforward_backward(self, dout):
        """Backward pass of the MLP, using cached 'ff_h' (hidden) and 'x2' (MLP input).

        Returns:
            Tuple ``(dx2, dW1, db1, dW2, db2)``.
        """
        h = self.cache['ff_h']
        x2 = self.cache['x2']

        dW2 = h.T @ dout
        db2 = np.sum(dout, axis=0)
        dh = dout @ self.W2.T
        dh[h <= 0] = 0  # ReLU backprop

        dW1 = x2.T @ dh
        db1 = np.sum(dh, axis=0)
        dx2 = dh @ self.W1.T

        return dx2, dW1, db1, dW2, db2

    def attention_backward(self, dout):
        """
        dout: ∂L/∂(attn @ W_o), shape (T, d_model)
        returns:
            dx: ∂L/∂x through the attention branch only (the skip path is added by the caller)
            dW_q, dW_k, dW_v, dW_o
        """
        Q, K, V = self.cache['Q'], self.cache['K'], self.cache['V']
        weights = self.cache['weights']
        x = self.cache['x']
        scale = np.sqrt(self.d_model)  # same scale as in forward

        # ---- W_o backward ----
        attn = weights @ V        # (T, D)
        dW_o = attn.T @ dout      # (D, D)
        dattn = dout @ self.W_o.T # (T, D)

        # ---- attn = weights @ V ----
        dweights = dattn @ V.T    # (T, T)
        dV = weights.T @ dattn    # (T, D)

        # ---- softmax backward ----
        # Row-wise Jacobian-vector product (diag(w) - w w^T) @ dw, computed without
        # building a (T, T) Jacobian per row.
        dscores = weights * (dweights - np.sum(dweights * weights, axis=-1, keepdims=True))

        # ---- scores = Q @ K.T / scale ----
        dQ = dscores @ K / scale                       # (T, D)
        dK = dscores.T @ Q / scale                     # (T, D)

        # ---- Q = x @ W_q, K = x @ W_k, V = x @ W_v ----
        dW_q = x.T @ dQ
        dW_k = x.T @ dK
        dW_v = x.T @ dV

        # x reaches the output through all three projections
        dx = dQ @ self.W_q.T + dK @ self.W_k.T + dV @ self.W_v.T

        return dx, dW_q, dW_k, dW_v, dW_o

In [None]:
import copy

def make_training_data():
    """Return every run of ``seq_len`` consecutive integers that fits inside ``[0, vocab_size)``.

    Reads the module-level ``seq_len`` and ``vocab_size``.
    """
    n_windows = vocab_size - seq_len + 1
    windows = []
    for start in range(n_windows):
        windows.append(list(range(start, start + seq_len)))
    return windows

def make_checkpoint(transformer, token_embedding, W_out):
    """Snapshot the model state as independent copies.

    Returns:
        dict with keys "transformer" (deep copy), "token_embedding" and "W_out"
        (array copies), so later in-place training updates do not alter it.
    """
    snapshot = {}
    snapshot["transformer"] = copy.deepcopy(transformer)
    snapshot["token_embedding"] = np.copy(token_embedding)
    snapshot["W_out"] = np.copy(W_out)
    return snapshot

def restore_checkpoint(checkpoint):
    """Return fresh copies ``(transformer, token_embedding, W_out)`` from a checkpoint dict.

    The checkpoint itself is left untouched, so it can be restored again later.
    """
    return (
        copy.deepcopy(checkpoint["transformer"]),
        np.copy(checkpoint["token_embedding"]),
        np.copy(checkpoint["W_out"]),
    )


def train(transformer, token_embedding, W_out, data, epochs=100):
    """Train the NumPy transformer with per-sequence SGD and a plateau LR schedule.

    For each sequence the loss is the mean next-token cross-entropy. All
    gradients for a sequence are computed from the same parameter values and
    applied together afterwards.

    When the scheduler lowers the learning rate, the best checkpoint so far is
    restored. Restoring creates new objects, so the arguments passed in may no
    longer be the live model; callers should use the returned values.

    Args:
        transformer: object with ``forward(x)`` and ``backward(d_out) -> (dx, grads)``,
            where ``grads`` maps attribute names to gradients.
        token_embedding: (vocab, d_model) array; updated in place until a restore.
        W_out: (d_model, vocab) output projection; updated in place until a restore.
        data: list of token-id sequences; each supplies ``seq[:-1]`` as input and
            ``seq[1:]`` as target.
        epochs: number of passes over ``data``.

    Returns:
        Tuple ``(transformer, token_embedding, W_out)`` holding the final state.

    Raises:
        ValueError: if ``data`` is empty.
    """
    if len(data) == 0:
        raise ValueError("train() needs at least one training sequence")

    scheduler = StablePlateauScheduler()
    checkpoint = make_checkpoint(transformer, token_embedding, W_out)
    best_loss = float('inf')

    for epoch in range(epochs):
        total_loss = 0
        lr = scheduler.get_lr()
        for seq in data:
            input_seq = list(seq[:-1])
            target_seq = list(seq[1:])
            steps = np.arange(len(target_seq))

            # Indexing with a list copies the rows, so the table is not modified here.
            x = token_embedding[input_seq] + positional_encoding[:len(input_seq)]
            out = transformer.forward(x)   # (T, d_model)

            probs = softmax(out @ W_out)   # (T, vocab)
            total_loss += float(np.mean(-np.log(probs[steps, target_seq] + 1e-9)))

            # d(loss)/d(logits) for softmax cross-entropy is probs - one_hot(target).
            dlogits = probs
            dlogits[steps, target_seq] -= 1

            # Both gradients use the W_out that produced the logits.
            d_out = dlogits @ W_out.T
            dW_out = out.T @ dlogits

            # backward through transformer
            dx, grads = transformer.backward(d_out)

            W_out -= lr * dW_out
            for name, grad in grads.items():
                setattr(transformer, name, getattr(transformer, name) - lr * grad)

            # subtract.at accumulates correctly when a token occurs more than once.
            np.subtract.at(token_embedding, input_seq, lr * dx)

        avg_loss = total_loss / len(data)

        # check if this is best run so far
        if avg_loss < best_loss - 1e-4:
            best_loss = avg_loss
            checkpoint = make_checkpoint(transformer, token_embedding, W_out)

        # update scheduler, restore checkpoint if decayed
        decayed = scheduler.update(avg_loss)
        if decayed:
            transformer, token_embedding, W_out = restore_checkpoint(checkpoint)

        print(f"epoch {epoch+1:03}  loss = {avg_loss:.4f}")

    return transformer, token_embedding, W_out



In [None]:
def sample_from_probs(probs, temperature=1.0):
    """Draw one index from a probability vector using NumPy's global RNG.

    Args:
        probs: NumPy array of probabilities.
        temperature: 1.0 samples from ``probs`` unchanged; any other value
            re-weights the distribution as ``softmax(log(probs) / temperature)``.

    Returns:
        The sampled index as a Python int.
    """
    if temperature == 1.0:
        return int(np.random.choice(len(probs), p=probs))

    scaled = np.log(probs + 1e-9) / temperature
    reweighted = np.exp(scaled - np.max(scaled))
    reweighted /= np.sum(reweighted)
    return int(np.random.choice(len(reweighted), p=reweighted))

def generate_sequence(transformer, token_embedding, W_out, start_seq, max_len=16, temperature=1.0):
    """Extend ``start_seq`` by sampling one token at a time until it has ``max_len`` tokens.

    Each step feeds the last ``seq_len`` tokens (module-level constant) through
    the NumPy transformer and samples from the distribution at the final position.

    Returns:
        A new list: the start tokens followed by the sampled tokens.
    """
    tokens = list(start_seq)
    n_new = max_len - len(tokens)
    for _ in range(n_new):
        context = tokens[-seq_len:]
        x = token_embedding[context]
        x += positional_encoding[:len(context)]
        hidden = transformer.forward(x)
        next_probs = softmax(hidden[-1] @ W_out)
        tokens.append(sample_from_probs(next_probs, temperature))
    return tokens


In [None]:
def beam_search(transformer, token_embedding, W_out, start_seq, max_len=16, beam_width=3, temperature=1.0, verbose=False):
    """Beam-search decoding with length-normalised ranking.

    Each beam carries its cumulative log-probability. Beams are ranked by that
    total divided by sequence length, but the total is what gets carried to the
    next step, so averages are never fed back into the running sum.

    Args:
        transformer: NumPy transformer exposing ``forward(x)``.
        token_embedding: (vocab, d_model) embedding table.
        W_out: (d_model, vocab) output projection.
        start_seq: initial token ids.
        max_len: total length of the returned sequence, prompt included.
        beam_width: number of hypotheses kept after every step.
        temperature: logits are divided by this before the softmax.
        verbose: print the surviving beams after each step.

    Returns:
        The token list of the best-scoring beam.
    """
    beams = [(list(start_seq), 0.0)]  # (sequence, total log-prob)

    for step in range(max_len - len(start_seq)):
        all_candidates = []

        for seq, total_logp in beams:
            input_seq = seq[-seq_len:]
            x = token_embedding[input_seq] + positional_encoding[:len(input_seq)]
            out = transformer.forward(x)

            logits = out[-1] @ W_out / temperature  # apply temperature
            token_logps = np.log(softmax(logits) + 1e-9)

            for token in range(vocab_size):
                all_candidates.append((seq + [token], total_logp + token_logps[token]))

        # Rank by average log-prob per token; keep the totals for the next step.
        beams = sorted(all_candidates, key=lambda cand: cand[1] / len(cand[0]), reverse=True)[:beam_width]

        if verbose:
            print(f"\n-- Step {step+1} --")
            for i, (bseq, btotal) in enumerate(beams):
                bscore = btotal / len(bseq)
                print(f"Beam {i+1}: {bseq}  avg_logp = {bscore:.3f}")

    return beams[0][0]  # best sequence


In [10]:
def make_training_data_from_filename(filename, seq_len=8, vocab_size=88, filter_out_of_range=True):
    """
    Read whitespace-separated integers from a file and slice them into
    overlapping windows of ``seq_len + 1`` tokens for next-token prediction.

    A window ``w`` is meant to be split by the caller into
        input  = w[:-1]
        target = w[1:]

    Args:
        filename (str): path to the file containing the note numbers
        seq_len (int): number of input tokens per training example (default: 8)
        vocab_size (int): exclusive upper bound for token values (default: 88)
        filter_out_of_range (bool): drop values outside [0, vocab_size-1] before windowing

    Returns:
        List[List[int]]: every window of length seq_len + 1, in file order;
        empty when the file has too few tokens
    """
    with open(filename, 'r') as handle:
        tokens = [int(piece) for piece in handle.read().strip().split()]

    if filter_out_of_range:
        tokens = [tok for tok in tokens if 0 <= tok < vocab_size]

    window = seq_len + 1
    return [tokens[start:start + window] for start in range(len(tokens) - seq_len)]

import os
import glob

def make_training_data_from_directory(path="./data/*.txt", seq_len=8, vocab_size=88, filter_out_of_range=True):
    """
    Read every file matching the glob pattern ``path`` and build training windows
    independently per file (no window spans two files).

    Files are processed in sorted path order. ``glob.glob`` alone returns paths
    in arbitrary order, which would make the dataset order differ between runs.

    Args:
        path (str): glob pattern selecting the input files
        seq_len (int): number of input tokens per training example
        vocab_size (int): exclusive upper bound for token values
        filter_out_of_range (bool): drop values outside [0, vocab_size-1] before windowing

    Returns:
        List[List[int]]: list of (seq_len + 1)-long sequences for training
    """
    all_sequences = []

    for filepath in sorted(glob.glob(path)):
        with open(filepath, "r") as f:
            tokens = [int(piece) for piece in f.read().strip().split()]
        if filter_out_of_range:
            tokens = [t for t in tokens if 0 <= t < vocab_size]
        all_sequences.extend(
            tokens[i : i + seq_len + 1] for i in range(len(tokens) - seq_len)
        )

    return all_sequences



In [106]:
import glob
import random

def load_training_data_from_directory(path: str, seq_len: int) -> list[np.ndarray]:
    """Collect training windows from every ``*.txt`` file directly inside ``path``.

    Each file is windowed by ``make_training_data_from_filename`` with its
    default ``vocab_size`` and filtering, so the items are plain Python lists
    of ``seq_len + 1`` ints.
    """
    sequences = []
    pattern = f"{path}/*.txt"
    for filename in glob.glob(pattern):
        sequences += make_training_data_from_filename(filename, seq_len)
    return sequences

def batch_generator(sequences: list[np.ndarray], batch_size: int):
    """Endlessly yield random ``(inputs, targets)`` batches.

    Each batch stacks ``batch_size`` distinct sequences, sampled without
    replacement via the ``random`` module. ``inputs`` is every column but the
    last, ``targets`` every column but the first.
    """
    while True:
        stacked = np.stack(random.sample(sequences, batch_size), axis=0)  # (B, T)
        inputs, targets = stacked[:, :-1], stacked[:, 1:]
        yield inputs, targets


In [118]:
# Train the JAX model on windows read from ./data.
your_sequences = load_training_data_from_directory('./data', seq_len)
train_gen = batch_generator(your_sequences, batch_size)
num_steps = 10000

for step in range(num_steps):
    inputs, targets = next(train_gen)
    # train_step is deterministic given the batch, so no PRNG key is needed here.
    params, opt_state, loss = train_step(params, opt_state, inputs, targets)
    if step % 100 == 0:
        print(f"step {step}: loss = {loss}")

step 0: loss = 0.2521265745162964
step 100: loss = 0.20547634363174438
step 200: loss = 0.2230464071035385
step 300: loss = 0.2069634646177292
step 400: loss = 0.20980314910411835
step 500: loss = 0.21924088895320892
step 600: loss = 0.2374354898929596
step 700: loss = 0.21206580102443695
step 800: loss = 0.21270889043807983
step 900: loss = 0.20445606112480164
step 1000: loss = 0.21432623267173767
step 1100: loss = 0.20468923449516296
step 1200: loss = 0.2119407057762146
step 1300: loss = 0.20189213752746582
step 1400: loss = 0.21416395902633667
step 1500: loss = 0.19687935709953308
step 1600: loss = 0.20304489135742188
step 1700: loss = 0.2048814594745636
step 1800: loss = 0.2238340675830841
step 1900: loss = 0.2050248086452484
step 2000: loss = 0.202180415391922
step 2100: loss = 0.19321611523628235
step 2200: loss = 0.18292373418807983
step 2300: loss = 0.19748644530773163
step 2400: loss = 0.19334951043128967
step 2500: loss = 0.20502787828445435
step 2600: loss = 0.20585255324840

In [114]:
# Train the NumPy single-block model.
# NOTE: `W_out` here is (d_model, vocab_size) and rebinds the name used for the
# JAX model's (vocab_size, d_model) array in an earlier cell.
transformer = TransformerBlock(d_model)
W_out = np.random.randn(d_model, vocab_size) / np.sqrt(d_model)
token_embedding = np.random.randn(vocab_size, d_model) * 0.01
data = make_training_data_from_directory("./data/*.txt", seq_len=seq_len)
# Defined here so the cell runs on a fresh kernel.
epochs = 600
train(transformer, token_embedding, W_out, data, epochs=epochs)

NameError: name 'TransformerBlock' is not defined

In [None]:
from collections import deque

class StablePlateauScheduler:
    """Reduce-on-plateau learning-rate schedule driven by a moving average of the loss.

    The average of the last ``window_size`` losses is compared with the best
    average seen so far. After ``patience`` consecutive updates without an
    improvement of more than ``tolerance``, the learning rate is multiplied by
    ``decay_factor`` and clipped at ``min_lr``. Once the rate can no longer
    decrease, decaying is disabled.
    """

    def __init__(self, lr=1e-2, decay_factor=0.9, patience=10, min_lr=1e-5, window_size=5, tolerance=1e-4):
        self.lr = lr
        self.decay_factor = decay_factor
        self.patience = patience
        self.min_lr = min_lr
        self.window_size = window_size
        self.tolerance = tolerance

        self.loss_window = deque(maxlen=window_size)
        self.best_avg = float('inf')
        self.epochs_since_improvement = 0
        self.decay_enabled = True

    def update(self, loss):
        """Record one loss value and decay the learning rate if training has plateaued.

        Args:
            loss: the latest (e.g. per-epoch) loss.

        Returns:
            bool: True only when this call actually lowered the learning rate;
            False in every other case, including while the window is still
            filling and when the rate is already at ``min_lr``.
        """
        self.loss_window.append(loss)

        if len(self.loss_window) < self.window_size:
            return False  # wait until window is full

        current_avg = sum(self.loss_window) / self.window_size

        if current_avg < self.best_avg - self.tolerance:
            self.best_avg = current_avg
            self.epochs_since_improvement = 0
            return False

        self.epochs_since_improvement += 1
        if self.epochs_since_improvement < self.patience or not self.decay_enabled:
            return False

        self.epochs_since_improvement = 0
        old_lr = self.lr
        new_lr = max(old_lr * self.decay_factor, self.min_lr)
        if new_lr < old_lr:
            self.lr = new_lr
            print(f"↘ learning rate dropped from {old_lr:.2e} to {self.lr:.2e}")
            return True

        # Already at the floor: nothing changed, so report no decay and stop trying.
        self.decay_enabled = False
        return False

    def get_lr(self):
        """Return the current learning rate."""
        return self.lr


In [None]:
# Decode 24 tokens from a six-note prompt with the NumPy model. The very low
# temperature sharpens the softmax so the highest-scoring tokens dominate.
generated = beam_search(transformer, token_embedding, W_out, [60, 63, 65, 67, 65, 63], max_len=24, beam_width=5, temperature=0.01)
print("Generated:", generated)

In [119]:
import mido
import os
import subprocess
from mido import Message, MidiFile, MidiTrack, MetaMessage
from pydub import AudioSegment
import time

BPM=200

def notes_to_midi_file(notes, filename='out.mid', bpm=BPM, velocity=64, duration=480, tempo=500000):
    """Write ``notes`` to a single-track MIDI file, one note or rest per half beat.

    Args:
        notes: iterable of MIDI pitches; ``None`` entries become rests.
        filename: output path.
        bpm: tempo in beats per minute.
        velocity: accepted but not used; notes are written with velocity 100.
        duration: accepted but not used; each step lasts half of ``ticks_per_beat``.
        tempo: accepted but not used; the tempo written to the file comes from ``bpm``.
    """
    midi = MidiFile()
    track = MidiTrack()
    midi.tracks.append(track)

    # Every note and every rest lasts half a beat.
    step_ticks = int(midi.ticks_per_beat / 2)

    # MIDI tempo is expressed in microseconds per beat.
    track.append(MetaMessage('set_tempo', tempo=int(60_000_000 / bpm), time=0))

    for note in notes:
        if note is not None:
            track.append(Message("note_on", note=note, velocity=100, time=0))
            track.append(Message("note_off", note=note, velocity=100, time=step_ticks))
            continue
        # Rest: a zero-velocity note_off whose only effect is to advance time by one step.
        track.append(Message("note_off", note=0, velocity=0, time=step_ticks))

    midi.save(filename)

# File locations used by the rendering and mixing helpers below.
# The MIDI/WAV names are relative to the working directory; the soundfont and
# backing-track paths are specific to the author's machine and must be adapted.
midi_path = 'yaman.mid'
wav_path = 'yaman.wav'
# notes_to_midi_file(generated, midi_path)
soundfont_path = os.path.expanduser("~/Downloads/harmonium-samples-20250608T203259Z-1-001/harmonium-samples/trimmed/trimmed/harmonium.sf2")
taal_path = os.path.expanduser("~/Documents/music/ektaal_200bpm_csharp.wav")
def render_midi_to_wav(midi_path=midi_path, wav_path=wav_path, sf2_path=soundfont_path):
    """Render a MIDI file to WAV by invoking the ``fluidsynth`` command-line tool.

    The exit status is not checked, so a failed render does not raise here.
    """
    command = [
        "fluidsynth",
        "-g", "2.0",       # gain
        "-ni", sf2_path,
        midi_path,
        "-F", wav_path,    # output file
        "-r", "48000",     # sample rate
    ]
    subprocess.run(command)
render_midi_to_wav()

def mix_audio(taal_path, taan_path, out_path):
    """Overlay the melody (taan) WAV on the rhythm (taal) WAV and export the mix as WAV."""
    backing = AudioSegment.from_wav(taal_path)
    melody = AudioSegment.from_wav(taan_path)

    # `segment - 2` changes the melody's gain by -2 dB (pydub's `-` operator
    # lowers volume; confirm against the pydub docs if the level matters).
    melody = melody - 2

    # Cut the melody so it is no longer than the backing track.
    melody = melody[:len(backing)]

    backing.overlay(melody).export(out_path, format="wav")

mix_audio(taal_path, 'yaman.wav', 'mixed.wav')
subprocess.run(["ffplay", "-nodisp", "-autoexit", 'mixed.wav'])




FluidSynth runtime version 2.3.4
Copyright (C) 2000-2023 Peter Hanappe and others.
Distributed under the LGPL license.
SoundFont(R) is a registered trademark of Creative Technology Ltd.

Rendering audio to file 'yaman.wav'..


ffplay version 6.1.1-3ubuntu5 Copyright (c) 2003-2023 the FFmpeg developers
  built with gcc 13 (Ubuntu 13.2.0-23ubuntu3)
  configuration: --prefix=/usr --extra-version=3ubuntu5 --toolchain=hardened --libdir=/usr/lib/x86_64-linux-gnu --incdir=/usr/include/x86_64-linux-gnu --arch=amd64 --enable-gpl --disable-stripping --disable-omx --enable-gnutls --enable-libaom --enable-libass --enable-libbs2b --enable-libcaca --enable-libcdio --enable-libcodec2 --enable-libdav1d --enable-libflite --enable-libfontconfig --enable-libfreetype --enable-libfribidi --enable-libglslang --enable-libgme --enable-libgsm --enable-libharfbuzz --enable-libmp3lame --enable-libmysofa --enable-libopenjpeg --enable-libopenmpt --enable-libopus --enable-librubberband --enable-libshine --enable-libsnappy --enable-libsoxr --enable-libspeex --enable-libtheora --enable-libtwolame --enable-libvidstab --enable-libvorbis --enable-libvpx --enable-libwebp --enable-libx265 --enable-libxml2 --enable-libxvid --enable-libzimg --ena




  14.34 M-A:  0.000 fd=   0 aq=    0KB vq=    0KB sq=    0B f=0/0   

CompletedProcess(args=['ffplay', '-nodisp', '-autoexit', 'mixed.wav'], returncode=0)

In [145]:
from collections import namedtuple

def sample_autoregressively(params: ModelParams, start_tokens, max_len=32, temperature=1.0, top_k=None, key=None):
    """Sample tokens one at a time from the JAX model.

    Args:
        params: model parameters.
        start_tokens: sequence of initial token ids.
        max_len: maximum number of tokens to append.
        temperature: logits are divided by this before sampling.
        top_k: if given, only the ``top_k`` highest logits can be sampled.
        key: optional JAX PRNG key for reproducible sampling. When omitted, a
            key is seeded once from NumPy's global RNG.

    Returns:
        list of token ids: the start tokens followed by up to ``max_len``
        sampled tokens. Sampling stops early when the total length reaches a
        multiple of 24 and the token just sampled is 61 or 73.
    """
    if key is None:
        key = jax.random.PRNGKey(np.random.randint(1_000_000))

    generated = list(start_tokens)
    prompt_len = len(generated)

    for i in range(max_len):
        x = jnp.array(generated)[None, :]  # shape (1, T)
        x_emb = embedding_fn(params.embedding, x)  # (1, T, D)
        x_out = stacked_forward(x_emb, params.transformer)  # (1, T, D)
        logits = x_out[:, -1, :] @ params.W_out.T  # (1, vocab)
        logits = logits / temperature

        if top_k is not None:
            # Everything below the k-th largest logit is pushed to effectively zero probability.
            kth_largest = jnp.sort(logits, axis=-1)[:, -top_k]
            logits = jnp.where(logits < kth_largest[:, None], -1e10, logits)

        # A fresh sub-key per step; categorical() samples directly from the logits.
        key, step_key = jax.random.split(key)
        next_token = jax.random.categorical(step_key, logits).item()

        generated.append(next_token)

        # early stopping
        if ((i+1)+prompt_len) % 24 == 0 and next_token in [61, 73]:
            break

    return generated


def generate_looped_taan_sequence(n_taans, taan_generator, pad_token=None):
    """
    Build a note list that alternates generated taans with 24-step rests.

    Args:
        n_taans: how many taans to generate.
        taan_generator: zero-argument callable returning a list of notes whose
            length is a multiple of 24.
        pad_token: value used for each of the 24 rest steps after a taan.

    Returns:
        Flat list: taan, 24 rests, taan, 24 rests, ...
    """
    rest = [pad_token] * 24
    notes = []
    for _ in range(n_taans):
        phrase = taan_generator()
        assert len(phrase) % 24 == 0, f"Taan generator must return multiple of 24 notes: {len(phrase)} notes"
        notes += phrase
        notes += rest
    return notes

def one_taan():
    """Sample one taan from the JAX model, retrying for a preferred ending.

    Makes up to ten attempts, each starting from a random note in {60, 68, 73}.
    Returns the first attempt whose last note is 61 or 73, otherwise the last
    attempt made.
    """
    attempts_left = 10
    while attempts_left > 0:
        start_note = random.choice([60,68,73])
        taan = sample_autoregressively(params, [start_note], 71, 1, top_k=None)
        if taan[-1] in [61, 73]:
            break
        attempts_left -= 1
    return taan

# Generate ten taans separated by rests, then write them to MIDI and render to WAV.
n_taans = 10
note_sequence = generate_looped_taan_sequence(n_taans, one_taan)

# Rest steps are None entries, which notes_to_midi_file writes as silence.
notes_to_midi_file(note_sequence, midi_path, bpm=BPM)
render_midi_to_wav()

# 🔁 overlay enough taal cycles (we handle it by looping audio)
def loop_audio_to_match(reference_audio, loop_audio):
    """Repeat ``loop_audio`` until it covers ``reference_audio``, then cut it to that length."""
    target_len = len(reference_audio)
    # One repetition more than the integer quotient guarantees full coverage.
    repeats = target_len // len(loop_audio) + 1
    return (loop_audio * repeats)[:target_len]

def mix_audio_looped(taal_path, taan_path, out_path):
    """Loop the rhythm (taal) WAV to the melody's length, overlay the melody, export as WAV."""
    rhythm = AudioSegment.from_wav(taal_path)
    melody = AudioSegment.from_wav(taan_path)

    # The melody sets the output length; the rhythm is repeated underneath it.
    rhythm_bed = loop_audio_to_match(melody, rhythm)
    rhythm_bed.overlay(melody).export(out_path, format="wav")

mix_audio_looped(taal_path, 'yaman.wav', 'mixed.wav')
subprocess.run(["ffplay", "-nodisp", "-autoexit", 'mixed.wav'])




FluidSynth runtime version 2.3.4
Copyright (C) 2000-2023 Peter Hanappe and others.
Distributed under the LGPL license.
SoundFont(R) is a registered trademark of Creative Technology Ltd.

Rendering audio to file 'yaman.wav'..


ffplay version 6.1.1-3ubuntu5 Copyright (c) 2003-2023 the FFmpeg developers
  built with gcc 13 (Ubuntu 13.2.0-23ubuntu3)
  configuration: --prefix=/usr --extra-version=3ubuntu5 --toolchain=hardened --libdir=/usr/lib/x86_64-linux-gnu --incdir=/usr/include/x86_64-linux-gnu --arch=amd64 --enable-gpl --disable-stripping --disable-omx --enable-gnutls --enable-libaom --enable-libass --enable-libbs2b --enable-libcaca --enable-libcdio --enable-libcodec2 --enable-libdav1d --enable-libflite --enable-libfontconfig --enable-libfreetype --enable-libfribidi --enable-libglslang --enable-libgme --enable-libgsm --enable-libharfbuzz --enable-libmp3lame --enable-libmysofa --enable-libopenjpeg --enable-libopenmpt --enable-libopus --enable-librubberband --enable-libshine --enable-libsnappy --enable-libsoxr --enable-libspeex --enable-libtheora --enable-libtwolame --enable-libvidstab --enable-libvorbis --enable-libvpx --enable-libwebp --enable-libx265 --enable-libxml2 --enable-libxvid --enable-libzimg --ena




CompletedProcess(args=['ffplay', '-nodisp', '-autoexit', 'mixed.wav'], returncode=0)

In [146]:
import flax
import pickle

# save
# The parameters are first serialised to bytes by flax, and that bytes object
# is then pickled into model.pkl.
with open("model.pkl", "wb") as f:
    bytes_output = flax.serialization.to_bytes(params)
    pickle.dump(bytes_output, f)

# load
# `from_bytes` takes an existing `params` as its first argument, so a model
# must be constructed before loading.
# NOTE(review): pickle.load can execute arbitrary code; only load files you created.
# with open("model.pkl", "rb") as f:
#     bytes_input = pickle.load(f)
#     params = flax.serialization.from_bytes(params, bytes_input)
