In [None]:
import jax
import jax.numpy as jnp
import numpy as np
import optax

from banhxeo import trainer
from banhxeo.core import NLTKTokenizer
from banhxeo.data import IMDBDataset
from banhxeo.model import MLP

In [None]:
# --- Reproducibility ---
SEED = 42  # shared by the dataset splits, the data loaders and the JAX PRNG key

# --- Data ---
MAX_LENGTH = 256  # sequences are truncated / padded to this many tokens
BATCH_SIZE = 32

# --- Optimisation ---
LEARNING_RATE = 0.001
EPOCHS = 100

### Load Dataset

In [None]:
# Build both IMDB splits from the same root directory with the same seed.
imdb_train, imdb_test = (
    IMDBDataset("dataset", split_name=split_name, seed=SEED)
    for split_name in ("train", "test")
)

In [None]:
# Peek at the first test example; as the cell's last expression it is rendered as output.
imdb_test[0]

In [None]:
# Peek at the first training example; as the cell's last expression it is rendered as output.
imdb_train[0]

### Train Tokenizer

In [None]:
# Untrained tokenizer with default settings; its vocabulary is fitted in the next cell.
tokenizer = NLTKTokenizer()

In [None]:
# Fit the tokenizer on the raw texts of both splits.
# NOTE(review): the test split's text is part of this corpus, so test-only
# vocabulary is visible before evaluation — confirm this is intended.
tokenizer_corpus = imdb_train.get_all_texts() + imdb_test.get_all_texts()
tokenizer.train(corpus=tokenizer_corpus, progress=True)

### Load Array Text Dataset

In [None]:
# Encoding options shared by both splits so train and test are processed identically.
args_dict = dict(
    tokenizer=tokenizer,
    return_tensors="jax",
    # Every sequence ends up exactly MAX_LENGTH long: longer ones are
    # truncated, shorter ones are padded on the left.
    max_length=MAX_LENGTH,
    truncation=True,
    padding="max_length",
    padding_side="left",
    add_special_tokens=True,
    # Classification targets: string labels are mapped to integer ids.
    is_classification=True,
    label_map={"pos": 1, "neg": 0},
)

train_set, test_set = (
    split.to_array(**args_dict) for split in (imdb_train, imdb_test)
)

In [None]:
# Wrap both array datasets in loaders.
# Settings common to the two loaders live in one place; only the
# shuffling / last-batch policy differs between training and evaluation.
loader_kwargs = {"batch_size": BATCH_SIZE, "num_workers": 4, "seed": SEED}

# Training: reshuffled, and a short final batch is dropped.
train_loader = train_set.to_loader(shuffle=True, drop_last=True, **loader_kwargs)
# Evaluation: fixed order, every example kept (the final batch may be short).
test_loader = test_set.to_loader(shuffle=False, drop_last=False, **loader_kwargs)

### Create Model

In [None]:
# Binary sentiment classifier: a single output unit (one logit per example).
model = MLP(
    # Embedding table sized to the trained tokenizer, with its padding id.
    vocab_size=tokenizer.vocab_size,
    pad_id=tokenizer.special_tokens.pad_id,
    embedding_dim=512,
    # Hidden stack and head.
    hidden_sizes=[512, 256],
    activation_fn="relu",
    output_size=1,
)

In [None]:
# Show the model's string representation as a quick sanity check of its configuration.
print(model)

In [None]:
# Derive three independent PRNG keys from the seed:
# `key` is carried forward, `params_key` seeds parameter initialisation,
# and `dropout_key` is reserved for dropout.
root_key = jax.random.key(SEED)
key, params_key, dropout_key = jax.random.split(root_key, num=3)

In [None]:
# Placeholder batch used only to trace shapes during parameter initialisation.
dummy_shape = (BATCH_SIZE, MAX_LENGTH)
dummy_input_ids = jnp.ones(dummy_shape, dtype=jnp.int32)
# Same shape and dtype as the ids: an all-ones attention mask.
dummy_attention_mask = jnp.ones_like(dummy_input_ids)

In [None]:
# Init model parameters.
# Dropout is switched on for this call, so a dedicated 'dropout' rng stream is
# supplied alongside 'params' instead of relying on a single key being reused
# for both. Parameter values still come from `params_key` only.
params = model.init(
    {'params': params_key, 'dropout': dropout_key},
    input_ids=dummy_input_ids,
    attention_mask=dummy_attention_mask,
    dropout=True,
)['params']

In [None]:
# AdamW optimizer; only the learning rate is overridden, everything else is optax's default.
optimizer = optax.adamw(LEARNING_RATE)

In [None]:
# Bundle everything a training step needs into one TrainState:
# the forward function, the current parameters, the optimizer and a PRNG key.
state = trainer.TrainState.create(
    rng=key,
    tx=optimizer,
    params=params,
    apply_fn=model.apply,
)

In [None]:
@jax.jit
def train_step(state: trainer.TrainState, batch):
    """Run one optimisation step on a single batch of a binary classifier.

    The model is expected to emit exactly one logit per example (trailing axis
    of size 1). Targets are the raw 0/1 labels: one-hot encoding them with a
    single class would map label 0 -> 1.0 and label 1 -> 0.0 (inverted
    targets), and an argmax over a size-1 axis is always 0, so neither is used.

    Args:
        state: Current train state; provides ``apply_fn``, ``params``, ``rng``
            and ``apply_gradients``.
        batch: Mapping with ``'input_ids'``, ``'attention_mask'`` and
            ``'labels'`` (binary class ids, one per example).

    Returns:
        Tuple ``(new_state, metrics)`` where ``metrics`` holds the scalar
        ``'loss'`` (mean sigmoid binary cross-entropy) and ``'accuracy'``.

    Raises:
        ValueError: At trace time, if the logits' last axis is not of size 1.
    """
    # One label per example, flattened so (B,) and (B, 1) inputs both work.
    labels = jnp.reshape(batch['labels'], (-1,))

    def loss_fn(params):
        # Forward pass with dropout enabled.
        # NOTE(review): `state.rng` is passed unchanged; unless
        # `apply_gradients` refreshes it, every step reuses the same dropout
        # mask — confirm against the TrainState implementation.
        logits = state.apply_fn(
            {'params': params},
            input_ids=batch['input_ids'],
            attention_mask=batch['attention_mask'],
            dropout=True,  # Enable dropout
            rngs={'dropout': state.rng},  # Pass the dropout PRNG
        )
        # (B, 1) -> (B,) so logits line up element-wise with the labels.
        logits = jnp.squeeze(logits, axis=-1)

        # Binary cross-entropy on the raw logit against the 0/1 target.
        targets = labels.astype(logits.dtype)
        loss = optax.sigmoid_binary_cross_entropy(logits, targets).mean()

        return loss, logits

    # Calculate gradients and loss
    (loss, logits), grads = jax.value_and_grad(loss_fn, has_aux=True)(state.params)

    # Update the model state (parameters and optimizer state)
    state = state.apply_gradients(grads=grads)

    # sigmoid(logit) > 0.5  <=>  logit > 0, so threshold the logit directly.
    accuracy = jnp.mean((logits > 0) == (labels == 1))
    metrics = {'loss': loss, 'accuracy': accuracy}

    return state, metrics

In [None]:
@jax.jit
def eval_step(state, batch):
    """Compute loss and accuracy for one batch without updating parameters.

    The model is expected to emit exactly one logit per example (trailing axis
    of size 1). Targets are the raw 0/1 labels: one-hot encoding them with a
    single class would invert them (0 -> 1.0, 1 -> 0.0), and an argmax over a
    size-1 axis is always 0, so neither is used.

    Args:
        state: Train state providing ``apply_fn`` and ``params``.
        batch: Mapping with ``'input_ids'``, ``'attention_mask'`` and
            ``'labels'`` (binary class ids, one per example).

    Returns:
        Dict with the scalar ``'loss'`` (mean sigmoid binary cross-entropy)
        and ``'accuracy'`` for this batch.

    Raises:
        ValueError: At trace time, if the logits' last axis is not of size 1.
    """
    # Get model predictions
    logits = state.apply_fn(
        {'params': state.params},
        input_ids=batch['input_ids'],
        attention_mask=batch['attention_mask'],
        dropout=False,  # Disable dropout for evaluation
    )
    # (B, 1) -> (B,), then bring the labels to the same shape.
    logits = jnp.squeeze(logits, axis=-1)
    labels = jnp.reshape(batch['labels'], logits.shape)

    # Binary cross-entropy on the raw logit against the 0/1 target.
    loss = optax.sigmoid_binary_cross_entropy(
        logits, labels.astype(logits.dtype)
    ).mean()

    # sigmoid(logit) > 0.5  <=>  logit > 0, so threshold the logit directly.
    accuracy = jnp.mean((logits > 0) == (labels == 1))
    metrics = {'loss': loss, 'accuracy': accuracy}

    return metrics

In [None]:
from tqdm.auto import tqdm

# Main loop: one full pass over the training data followed by one full pass
# over the test data per epoch; the epoch averages are shown on the progress bar.
print("Starting training...")
for epoch in (pbar := tqdm(range(EPOCHS), desc="Training")):
    # Training phase
    train_loss, train_accuracy = [], []
    for batch in train_loader:
        # Perform one training step
        state, metrics = train_step(state, batch) # type: ignore

        train_loss.append(metrics['loss'])
        train_accuracy.append(metrics['accuracy'])

    # Evaluate phase
    test_loss, test_accuracy, test_batch_sizes = [], [], []
    for batch in test_loader:
        metrics = eval_step(state, batch) # type: ignore

        test_loss.append(metrics['loss'])
        test_accuracy.append(metrics['accuracy'])
        # The test loader keeps a possibly shorter final batch, so remember
        # how many examples each per-batch mean stands for.
        test_batch_sizes.append(len(batch['labels']))

    # Log results for the epoch.
    # Training batches all have the same size (short final batch is dropped),
    # so a plain mean of the per-batch means is exact.
    avg_train_loss = np.mean(train_loss)
    avg_train_acc = np.mean(train_accuracy)
    # Test batches may differ in size: weight each batch by its example count
    # so the result is the true per-example average.
    avg_test_loss = np.average(test_loss, weights=test_batch_sizes)
    avg_test_acc = np.average(test_accuracy, weights=test_batch_sizes)

    pbar.set_postfix(
        {
            "Train Loss": avg_train_loss,
            "Train Acc": avg_train_acc,
            "Test Loss": avg_test_loss,
            "Test Acc": avg_test_acc
        }
    )