In [1]:
import jax
import jax.numpy as jnp
from jax import random, grad, jit
from flax import linen as nn
import optax

In [2]:
# Parameters
batch_size = 16
block_size = 32
learning_rate = 1e-3
max_iters = 1000
n_embd = 64
vocab_size = 256  # Assuming ASCII

rng_key = random.PRNGKey(0)

In [3]:
class Transformer(nn.Module):
    @nn.compact
    def __call__(self, x):
        x = x[..., None]  # Adding an embedding dimension
        x = nn.Dense(n_embd)(x)
        x = nn.LayerNorm()(x)
        x = nn.SelfAttention(num_heads=2)(x)
        x = x.reshape((x.shape[0], x.shape[1], -1))  # Flattening the last dimensions
        x = nn.Dense(vocab_size)(x)
        return x

@jit
def softmax_cross_entropy(logits, targets):
    logits_reshaped = logits.reshape((-1, vocab_size))
    targets_reshaped = targets.reshape((-1,))
    logprobs = jax.nn.log_softmax(logits_reshaped)

    targets_one_hot = jax.nn.one_hot(targets_reshaped, vocab_size)

    # Element-wise multiplication and sum over the vocab_size dimension
    loss_values = -jnp.sum(targets_one_hot * logprobs, axis=-1)

    # Reshape loss values back to (batch_size, block_size)
    return loss_values.reshape((batch_size, block_size))

@jit
def compute_loss(params, x, y):
    logits = model.apply(params, x)
    loss_values = softmax_cross_entropy(logits, y)
    mean_loss = jnp.mean(loss_values)
    return mean_loss

@jit
def update(params, x, y, opt_state):
    opt_update = optimizer.update
    loss, grads = jax.value_and_grad(compute_loss)(params, x, y)  # removed `model` from the arguments
    updates, new_opt_state = opt_update(grads, opt_state)
    new_params = optax.apply_updates(params, updates)
    return new_params, new_opt_state, loss


# Data (for demonstration purposes, use real data in practice)
data = jnp.array([i % vocab_size for i in range(10000)], dtype=jnp.int32)
def get_batch():
    idx = random.randint(rng_key, (batch_size,), 0, len(data) - block_size - 1)
    x = jnp.array([data[i:i+block_size] for i in idx])
    y = jnp.array([data[i+1:i+block_size+1] for i in idx])
    return x, y

# Training
model = Transformer()
params = model.init(rng_key, jnp.ones((batch_size, block_size)))
optimizer = optax.adam(learning_rate)
opt_state = optimizer.init(params)

for iter in range(max_iters):
    x, y = get_batch()
    params, opt_state, loss = update(params, x, y, opt_state)
    if iter % 100 == 0:
        print(f"Iteration {iter}, Loss: {loss}")

# Additional utility function to convert a string to its ASCII representation
def string_to_ascii(input_str):
    return jnp.array([ord(c) for c in input_str], dtype=jnp.int32)

# Simple text generation
def generate_text(params, model, start_token=0, length=100):
    generated = [start_token]

    # Initialize a sequence of length `block_size` filled with the `start_token`
    current_token = jnp.array([start_token] * block_size).reshape(1, block_size)

    for _ in range(length):
        logits = model.apply(params, current_token)  # Generate logits for the sequence
        next_token = jnp.argmax(logits[0, -1])
        generated.append(int(next_token))

        # Append the next_token to current_token sequence and use only the last `block_size` tokens
        current_token = jnp.concatenate([current_token, next_token.reshape(1, 1)], axis=1)[:, -block_size:]

    return generated

# Initialize the model with a dummy input that matches the shape of our generation process
dummy_input = jnp.ones((1, block_size))
params_gen = model.init(rng_key, dummy_input)

# Update the params with trained weights
params_gen = params

def generate_text(params, model, start_string, length=100):
    start_tokens = string_to_ascii(start_string)
    generated = list(start_tokens)

    # If the initial tokens are fewer than block_size, pad them
    if len(start_tokens) < block_size:
        current_token = jnp.pad(start_tokens, (block_size - len(start_tokens), 0), mode='constant')
    else:
        current_token = start_tokens[-block_size:]  # Take the last `block_size` characters

    current_token = current_token.reshape(1, block_size)

    for _ in range(length):
        logits = model.apply(params, current_token)
        next_token = jnp.argmax(logits[0, -1])
        generated.append(int(next_token))

        # Use the most recent `block_size` tokens for the next step
        current_token = jnp.concatenate([current_token, next_token.reshape(1, 1)], axis=1)[:, -block_size:]

    return "".join([chr(c) for c in generated])

print(generate_text(params_gen, model, start_string="once upon a time", length=100))


Iteration 0, Loss: 6.21189546585083
Iteration 100, Loss: 4.883039474487305
Iteration 200, Loss: 4.875761985778809
Iteration 300, Loss: 4.875408172607422
Iteration 400, Loss: 4.875296592712402
Iteration 500, Loss: 4.875242710113525
Iteration 600, Loss: 4.875209808349609
Iteration 700, Loss: 4.875185012817383
Iteration 800, Loss: 4.875302791595459
Iteration 900, Loss: 4.875214576721191
once upon a timeõõõõõõõõõõõõõõõõ>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>


In [5]:
# Hyperparameters
batch_size = 16
block_size = 32
learning_rate = 1e-3
max_iters = 4000
n_embd = 64
vocab_size = 250  # Assuming ASCII

# Random number generator key
rng_key = random.PRNGKey(0)


In [6]:
class Transformer(nn.Module):
    @nn.compact
    def __call__(self, x):
        x = x[..., None]  # Adding an embedding dimension
        x = nn.Dense(n_embd)(x)
        x = nn.LayerNorm()(x)
        x = nn.SelfAttention(num_heads=2)(x)
        x = x.reshape((x.shape[0], x.shape[1], -1))  # Flattening the last dimensions
        x = nn.Dense(vocab_size)(x)
        return x

@jit
def softmax_cross_entropy(logits, targets):
    logits_reshaped = logits.reshape((-1, vocab_size))
    targets_reshaped = targets.reshape((-1,))
    logprobs = jax.nn.log_softmax(logits_reshaped)

    targets_one_hot = jax.nn.one_hot(targets_reshaped, vocab_size)

    # Element-wise multiplication and sum over the vocab_size dimension
    loss_values = -jnp.sum(targets_one_hot * logprobs, axis=-1)

    # Reshape loss values back to (batch_size, block_size)
    return loss_values.reshape((batch_size, block_size))

@jit
def compute_loss(params, x, y):
    logits = model.apply(params, x)
    loss_values = softmax_cross_entropy(logits, y)
    mean_loss = jnp.mean(loss_values)
    return mean_loss

@jit
def update(params, x, y, opt_state):
    opt_update = optimizer.update
    loss, grads = jax.value_and_grad(compute_loss)(params, x, y)  # removed `model` from the arguments
    updates, new_opt_state = opt_update(grads, opt_state)
    new_params = optax.apply_updates(params, updates)
    return new_params, new_opt_state, loss


# Data (for demonstration purposes, use real data in practice)
data = jnp.array([i % vocab_size for i in range(10000)], dtype=jnp.int32)
def get_batch():
    idx = random.randint(rng_key, (batch_size,), 0, len(data) - block_size - 1)
    x = jnp.array([data[i:i+block_size] for i in idx])
    y = jnp.array([data[i+1:i+block_size+1] for i in idx])
    return x, y

# Training
model = Transformer()
params = model.init(rng_key, jnp.ones((batch_size, block_size)))
optimizer = optax.adam(learning_rate)
opt_state = optimizer.init(params)

for iter in range(max_iters):
    x, y = get_batch()
    params, opt_state, loss = update(params, x, y, opt_state)
    if iter % 100 == 0:
        print(f"Iteration {iter}, Loss: {loss}")

# Additional utility function to convert a string to its ASCII representation
def string_to_ascii(input_str):
    return jnp.array([ord(c) for c in input_str], dtype=jnp.int32)

# Simple text generation
def generate_text(params, model, start_token=0, length=100):
    generated = [start_token]

    # Initialize a sequence of length `block_size` filled with the `start_token`
    current_token = jnp.array([start_token] * block_size).reshape(1, block_size)

    for _ in range(length):
        logits = model.apply(params, current_token)  # Generate logits for the sequence
        next_token = jnp.argmax(logits[0, -1])
        generated.append(int(next_token))

        # Append the next_token to current_token sequence and use only the last `block_size` tokens
        current_token = jnp.concatenate([current_token, next_token.reshape(1, 1)], axis=1)[:, -block_size:]

    return generated

# Initialize the model with a dummy input that matches the shape of our generation process
dummy_input = jnp.ones((1, block_size))
params_gen = model.init(rng_key, dummy_input)

# Update the params with trained weights
params_gen = params

def generate_text(params, model, start_string, length=100):
    start_tokens = string_to_ascii(start_string)
    generated = list(start_tokens)

    # If the initial tokens are fewer than block_size, pad them
    if len(start_tokens) < block_size:
        current_token = jnp.pad(start_tokens, (block_size - len(start_tokens), 0), mode='constant')
    else:
        current_token = start_tokens[-block_size:]  # Take the last `block_size` characters

    current_token = current_token.reshape(1, block_size)

    for _ in range(length):
        logits = model.apply(params, current_token)
        next_token = jnp.argmax(logits[0, -1])
        generated.append(int(next_token))

        # Use the most recent `block_size` tokens for the next step
        current_token = jnp.concatenate([current_token, next_token.reshape(1, 1)], axis=1)[:, -block_size:]

    return "".join([chr(c) for c in generated])

print(generate_text(params_gen, model, start_string="""The farmhouse lingers, though averse to square
With the new city street it has to wear
A number in. But what about the brook
That held the house as in an elbow-crook?
I ask as one who knew the brook, its strength
And impulse, having dipped a finger length
And made it leap my knuckle, having tossed
A flower to try its currents where they crossed.
The meadow grass could be cemented down
From growing under pavements of a town;
The apple trees be sent to hearth-stone flame.
Is water wood to serve a brook the same?
How else dispose of an immortal force
No longer needed? Staunch it at its source
With cinder loads dumped down? The brook was thrown
Deep in a sewer dungeon under stone
In fetid darkness still to live and run â€”
And all for nothing it had ever done
Except forget to go in fear perhaps.
No one would know except for ancient maps
That such a brook ran water. But I wonder
If from its being kept forever under,
The thoughts may not have risen that so keep
This new-built city from both work and sleep.""", length=100))


Iteration 0, Loss: 6.181985378265381
Iteration 100, Loss: 5.056647300720215
Iteration 200, Loss: 5.053104400634766
Iteration 300, Loss: 5.0527238845825195
Iteration 400, Loss: 5.052591323852539
Iteration 500, Loss: 5.052509307861328
Iteration 600, Loss: 5.073213577270508
Iteration 700, Loss: 5.0524163246154785
Iteration 800, Loss: 5.052323341369629
Iteration 900, Loss: 5.052213668823242
Iteration 1000, Loss: 5.052062511444092
Iteration 1100, Loss: 5.051831245422363
Iteration 1200, Loss: 5.0517144203186035
Iteration 1300, Loss: 5.050899982452393
Iteration 1400, Loss: 5.050166130065918
Iteration 1500, Loss: 5.047529697418213
Iteration 1600, Loss: 5.04243278503418
Iteration 1700, Loss: 5.029782772064209
Iteration 1800, Loss: 5.0063605308532715
Iteration 1900, Loss: 4.846574306488037
Iteration 2000, Loss: 4.794043064117432
Iteration 2100, Loss: 4.637550354003906
Iteration 2200, Loss: 4.522007465362549
Iteration 2300, Loss: 4.453006267547607
Iteration 2400, Loss: 4.379441261291504
Iteration

In [7]:
import jax.numpy as jnp
from jax import random, grad, jit
from flax import linen as nn
import optax

# Parameters
batch_size = 16
block_size = 32
learning_rate = 1e-3
max_iters = 1000
n_embd = 64
vocab_size = 256  # Assuming ASCII

rng_key = random.PRNGKey(0)

class Transformer(nn.Module):
    @nn.compact
    def __call__(self, x):
        x = x[..., None]  # Adding an embedding dimension
        x = nn.Dense(n_embd)(x)
        x = nn.LayerNorm()(x)
        x = nn.SelfAttention(num_heads=2)(x)
        x = x.reshape((x.shape[0], x.shape[1], -1))  # Flattening the last dimensions
        x = nn.Dense(vocab_size)(x)
        return x

@jit
def softmax_cross_entropy(logits, targets):
    logits_reshaped = logits.reshape((-1, vocab_size))
    targets_reshaped = targets.reshape((-1,))
    logprobs = jax.nn.log_softmax(logits_reshaped)

    targets_one_hot = jax.nn.one_hot(targets_reshaped, vocab_size)

    # Element-wise multiplication and sum over the vocab_size dimension
    loss_values = -jnp.sum(targets_one_hot * logprobs, axis=-1)

    # Reshape loss values back to (batch_size, block_size)
    return loss_values.reshape((batch_size, block_size))

@jit
def compute_loss(params, x, y):
    logits = model.apply(params, x)
    loss_values = softmax_cross_entropy(logits, y)
    mean_loss = jnp.mean(loss_values)
    return mean_loss

@jit
def update(params, x, y, opt_state):
    opt_update = optimizer.update
    loss, grads = jax.value_and_grad(compute_loss)(params, x, y)  # removed `model` from the arguments
    updates, new_opt_state = opt_update(grads, opt_state)
    new_params = optax.apply_updates(params, updates)
    return new_params, new_opt_state, loss


# Data (for demonstration purposes, use real data in practice)
data = jnp.array([i % vocab_size for i in range(10000)], dtype=jnp.int32)
def get_batch():
    idx = random.randint(rng_key, (batch_size,), 0, len(data) - block_size - 1)
    x = jnp.array([data[i:i+block_size] for i in idx])
    y = jnp.array([data[i+1:i+block_size+1] for i in idx])
    return x, y

# Training
model = Transformer()
params = model.init(rng_key, jnp.ones((batch_size, block_size)))
optimizer = optax.adam(learning_rate)
opt_state = optimizer.init(params)

for iter in range(max_iters):
    x, y = get_batch()
    params, opt_state, loss = update(params, x, y, opt_state)
    if iter % 100 == 0:
        print(f"Iteration {iter}, Loss: {loss}")

# Simple text generation
def generate_text(params, model, start_token=0, length=100):
    generated = [start_token]

    # Initialize a sequence of length `block_size` filled with the `start_token`
    current_token = jnp.array([start_token] * block_size).reshape(1, block_size)

    for _ in range(length):
        logits = model.apply(params, current_token)  # Generate logits for the sequence
        next_token = jnp.argmax(logits[0, -1])
        generated.append(int(next_token))

        # Append the next_token to current_token sequence and use only the last `block_size` tokens
        current_token = jnp.concatenate([current_token, next_token.reshape(1, 1)], axis=1)[:, -block_size:]

    return generated


print(generate_text(params, model, start_token=0, length=100))


Iteration 0, Loss: 6.21189546585083
Iteration 100, Loss: 4.883039474487305
Iteration 200, Loss: 4.875761985778809
Iteration 300, Loss: 4.875408172607422
Iteration 400, Loss: 4.875296592712402
Iteration 500, Loss: 4.875242710113525
Iteration 600, Loss: 4.875209808349609
Iteration 700, Loss: 4.875185012817383
Iteration 800, Loss: 4.875302791595459
Iteration 900, Loss: 4.875214576721191
[0, 245, 245, 245, 245, 245, 245, 245, 245, 245, 245, 245, 245, 245, 245, 245, 245, 245, 245, 245, 245, 245, 245, 245, 245, 245, 245, 245, 245, 245, 245, 245, 245, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62]
