In [1]:
import jax
import jax.numpy as jnp
import numpy as np
import optax
from tqdm.auto import tqdm

from banhxeo import trainer
from banhxeo.core import NLTKTokenizer
from banhxeo.data import IMDBDataset
from banhxeo.model import MLP

In [2]:
BATCH_SIZE = 32
EPOCHS = 100
LEARNING_RATE = 1e-3

MAX_LENGTH = 256
SEED = 42

### Load Dataset

In [3]:
imdb_train = IMDBDataset("dataset", split_name="train", seed=SEED)
imdb_test = IMDBDataset("dataset", split_name="test", seed=SEED)

  0%|          | 0.00/12.5k [00:00<?, ?file/s]

  0%|          | 0.00/12.5k [00:00<?, ?file/s]

  0%|          | 0.00/12.5k [00:00<?, ?file/s]

  0%|          | 0.00/12.5k [00:00<?, ?file/s]

In [4]:
# Peek at the first test example; the record is displayed as the cell output.
imdb_test[0]

{'id': '1514',
 'rating': 7,
 'content': "Lucio Fulci's Cat in the Brain is an inventive and somewhat egotistical tale of a director's decent into madness. The director in question is Fulci himself, who stars in the film. Fulci has become known to horror fans everywhere as 'the godfather of gore', and for good reason, as he has provided us with some of the nastiest and most gruesome films ever to grace the silver screen; from the eyeball violence in films like 'Zombi 2', to a man been hacked to death with chains in 'The Beyond', all the way to the full on gore fest known as 'The New York Ripper'; if you want gore (and let's face it, who doesn't), Fulci is your man. However, all this catering for gorehounds like you and I has taken its toll on Fulci's mental state, and he's quickly delving into madness, brought about by what he films. Fulci's problems don't end at his mental state either, as his psychiatrist that he has gone to see about his problem has took it upon himself to take up m

In [5]:
# Peek at the first training example; the record is displayed as the cell output.
imdb_train[0]

{'id': '7430',
 'rating': 10,
 'content': 'Personnaly I really loved this movie, and it particularly moved me. The two main actors are giving us such great performances, that at the end, it is really heart breaking to know what finally happened to their characters.<br /><br />The alchemy between Barbra Streisand and Kris Kristofferson is marvelous, and the song are just great the way they are. <br /><br />That\'s why I didn\'t feel surprised when I learned it had won 5 golden globe awards (the most rewarded movie at the Golden Globes), an Oscar and even a Grammy. This movie is a classic that deserves to be seen by anyone. A great movie, that has often been criticized (maybe because Streisand dared to get involved in it, surely as a "co-director"). Her artistry is the biggest, and that will surely please you!',
 'label': 'pos'}

### Train Tokenizer

In [6]:
# Create the tokenizer with its default arguments; its vocabulary is built by
# the `tokenizer.train(...)` call in the following cell.
tokenizer = NLTKTokenizer()

In [7]:
# Build the vocabulary from the texts of both splits.
# NOTE(review): the corpus includes the test split, so the vocabulary is not
# derived from training data alone — confirm this is intended.
tokenizer.train(
    progress=True,
    corpus=imdb_train.get_all_texts() + imdb_test.get_all_texts(),
)

Pre-tokenizing text:   0%|          | 0/50000 [00:00<?, ?it/s]

Add word to vocabulary:   0%|          | 0/222115 [00:00<?, ?it/s]

### Load Array Text Dataset

In [8]:
# Keyword arguments shared by both splits, so train and test are converted
# with exactly the same settings.
args_dict = dict(
    tokenizer=tokenizer,
    return_tensors="jax",
    max_length=MAX_LENGTH,
    truncation=True,
    padding="max_length",
    padding_side="left",
    add_special_tokens=True,
    is_classification=True,
    label_map={"pos": 1, "neg": 0},
)

# Convert train first, then test (same order as separate statements would).
train_set, test_set = (
    split.to_array(**args_dict) for split in (imdb_train, imdb_test)
)

In [9]:
# Wrap both array datasets in loaders.
# Training: shuffled, and the incomplete final batch is dropped.
train_loader = train_set.to_loader(
    seed=SEED,
    shuffle=True,
    drop_last=True,
    batch_size=BATCH_SIZE,
    num_workers=4,
)
# Evaluation: fixed order, and every example is kept (drop_last=False), so the
# final batch may be smaller than BATCH_SIZE.
test_loader = test_set.to_loader(
    seed=SEED,
    shuffle=False,
    drop_last=False,
    batch_size=BATCH_SIZE,
    num_workers=4,
)

### Create Model

In [10]:
# MLP classifier sized to the trained vocabulary. It emits a single output per
# example (output_size=1) and is told which token id is padding.
model = MLP(
    vocab_size=tokenizer.vocab_size,
    pad_id=tokenizer.special_tokens.pad_id,
    embedding_dim=512,
    hidden_sizes=[512, 256],
    activation_fn="relu",
    output_size=1,
)

In [11]:
# Print the module's attributes to confirm the configuration actually in use,
# including defaults that were not set explicitly when constructing it.
print(model)

MLP(
    # attributes
    vocab_size = 222123
    output_size = 1
    embedding_dim = 512
    hidden_sizes = [512, 256]
    pad_id = 0
    activation_fn = 'relu'
    dropout_rate = 0.0
    aggregate_strategy = 'mean'
)


In [12]:
# Derive three independent PRNG keys from the seed: a general-purpose `key`,
# one for parameter initialisation and one reserved for dropout.
key, params_key, dropout_key = jax.random.split(jax.random.key(SEED), num=3)

In [13]:
# Placeholder inputs of shape (BATCH_SIZE, MAX_LENGTH), filled with int32 ones.
# They are only passed to `model.init` to initialise the parameters.
dummy_input_ids, dummy_attention_mask = (
    jnp.ones((BATCH_SIZE, MAX_LENGTH), dtype=jnp.int32) for _ in range(2)
)

In [14]:
# Init model parameters.
# The forward pass is traced with dropout enabled, so give it an explicit
# 'dropout' RNG stream alongside the 'params' stream. Passing a bare key only
# provides the 'params' stream; that happens to work while dropout_rate is 0.0,
# but is expected to fail as soon as a non-zero dropout rate is configured
# (assuming MLP uses Flax's standard dropout — TODO confirm).
# The 'params' stream still uses `params_key`, so the initial weights are the
# same as when `params_key` is passed on its own.
params = model.init(
    {'params': params_key, 'dropout': dropout_key},
    input_ids=dummy_input_ids,
    attention_mask=dummy_attention_mask,
    dropout=True,
)['params']

In [15]:
# AdamW with a constant learning rate; every other optimizer setting is left
# at the optax default.
optimizer = optax.adamw(LEARNING_RATE)

In [16]:
# Bundle everything a training step needs into one state object: the forward
# function, the current parameters, the optimizer, and the PRNG key that
# `train_step` reads as `state.rng` for dropout.
state = trainer.TrainState.create(
    rng=key,
    tx=optimizer,
    params=params,
    apply_fn=model.apply,
)

In [17]:
@jax.jit
def train_step(state: trainer.TrainState, batch):
    """Run one optimisation step on a batch and report its loss and accuracy.

    Args:
        state: Current train state. Its ``apply_fn``, ``params`` and ``rng`` are
            read, and ``apply_gradients`` is used to produce the updated state.
        batch: Mapping with ``'input_ids'``, ``'attention_mask'`` and
            ``'labels'``. Labels are assumed to be integer class ids with one
            entry per example — TODO confirm against the loader.

    Returns:
        A ``(new_state, metrics)`` tuple, where ``metrics`` is a dict holding
        the batch ``'loss'`` and ``'accuracy'``.
    """
    labels = batch['labels']

    def loss_fn(params):
        # Forward pass with dropout enabled.
        # NOTE(review): `state.rng` is passed unchanged on every call, so the
        # same dropout key is reused each step. Harmless while dropout_rate is
        # 0.0; derive a per-step key before enabling dropout.
        logits = state.apply_fn(
            {'params': params},
            input_ids=batch['input_ids'],
            attention_mask=batch['attention_mask'],
            dropout=True,
            rngs={'dropout': state.rng},
        )

        if logits.shape[-1] == 1:
            # Single-logit (binary) head: the 0/1 label itself is the target.
            # one_hot(labels, 1) would map label 0 -> [1.] and label 1 -> [0.],
            # i.e. train the model on inverted targets.
            targets = labels.reshape(logits.shape).astype(logits.dtype)
        else:
            # One logit per class: keep the one-hot targets.
            targets = jax.nn.one_hot(labels, num_classes=logits.shape[-1])

        loss = optax.sigmoid_binary_cross_entropy(logits, targets).mean()
        return loss, logits

    # Loss, logits (as auxiliary output) and gradients w.r.t. the parameters.
    (loss, logits), grads = jax.value_and_grad(loss_fn, has_aux=True)(state.params)

    # Update parameters and optimizer state.
    state = state.apply_gradients(grads=grads)

    if logits.shape[-1] == 1:
        # argmax over a size-1 axis is always 0, so threshold the logit instead:
        # logit > 0 is equivalent to sigmoid(logit) > 0.5.
        predictions = (logits.reshape(labels.shape) > 0).astype(labels.dtype)
    else:
        predictions = jnp.argmax(logits, -1)

    accuracy = jnp.mean(predictions == labels)
    metrics = {'loss': loss, 'accuracy': accuracy}

    return state, metrics

In [18]:
@jax.jit
def eval_step(state, batch):
    """Compute loss and accuracy for one batch without updating the model.

    Args:
        state: Train state; only ``apply_fn`` and ``params`` are read.
        batch: Mapping with ``'input_ids'``, ``'attention_mask'`` and
            ``'labels'``. Labels are assumed to be integer class ids with one
            entry per example — TODO confirm against the loader.

    Returns:
        Dict holding the batch ``'loss'`` and ``'accuracy'``.
    """
    labels = batch['labels']

    # Forward pass with dropout disabled, so no dropout RNG is needed.
    logits = state.apply_fn(
        {'params': state.params},
        input_ids=batch['input_ids'],
        attention_mask=batch['attention_mask'],
        dropout=False,
    )

    if logits.shape[-1] == 1:
        # Single-logit (binary) head: the 0/1 label itself is the target.
        # one_hot(labels, 1) would map label 0 -> [1.] and label 1 -> [0.]
        # (inverted targets), and argmax over a size-1 axis is always 0.
        targets = labels.reshape(logits.shape).astype(logits.dtype)
        # logit > 0 is equivalent to sigmoid(logit) > 0.5.
        predictions = (logits.reshape(labels.shape) > 0).astype(labels.dtype)
    else:
        # One logit per class: keep the one-hot targets and argmax prediction.
        targets = jax.nn.one_hot(labels, num_classes=logits.shape[-1])
        predictions = jnp.argmax(logits, -1)

    loss = optax.sigmoid_binary_cross_entropy(logits, targets).mean()
    accuracy = jnp.mean(predictions == labels)
    metrics = {'loss': loss, 'accuracy': accuracy}

    return metrics

In [None]:
# Train for EPOCHS epochs, evaluating on the full test set after each one.
# The epoch-level averages are shown in the progress-bar postfix.
print("Starting training...")
for epoch in (pbar := tqdm(range(EPOCHS), desc="Training")):
    # Training phase: one optimisation step per batch.
    train_loss, train_accuracy = [], []
    for batch in train_loader:
        state, metrics = train_step(state, batch)  # type: ignore

        train_loss.append(metrics['loss'])
        train_accuracy.append(metrics['accuracy'])

    # Evaluation phase. The test loader keeps its final, possibly smaller,
    # batch (drop_last=False), so record each batch's size for weighting.
    test_loss, test_accuracy, test_sizes = [], [], []
    for batch in test_loader:
        metrics = eval_step(state, batch)  # type: ignore

        test_loss.append(metrics['loss'])
        test_accuracy.append(metrics['accuracy'])
        test_sizes.append(len(batch['labels']))

    # Training batches are all the same size (drop_last=True), so a plain mean
    # of the per-batch metrics is exact.
    avg_train_loss = np.mean(train_loss)
    avg_train_acc = np.mean(train_accuracy)
    # Test batches may differ in size: weight each per-batch mean by its batch
    # size so the result is the true per-example average over the test set.
    avg_test_loss = np.average(test_loss, weights=test_sizes)
    avg_test_acc = np.average(test_accuracy, weights=test_sizes)

    pbar.set_postfix(
        {
            "Train Loss": avg_train_loss,
            "Train Acc": avg_train_acc,
            "Test Loss": avg_test_loss,
            "Test Acc": avg_test_acc
        }
    )

Starting training...


Training:   0%|          | 0/100 [00:00<?, ?it/s]