In [1]:
import torch
import kagglehub
import torch.nn as nn
from torch.optim import AdamW
from torch.utils.data import DataLoader
from transformers import get_scheduler, PreTrainedTokenizerFast
from tqdm.auto import tqdm
import evaluate
from accelerate import Accelerator
from models import NextByteTransformer
import pandas as pd
from pathlib import Path
from sklearn.model_selection import train_test_split
from sklearn.metrics import f1_score
from recipe_nlg import TokenizedRecipeNLGDataset

"""Model & training hyper parameters"""
context_length = 512
d_model = 512
num_heads = 8
num_hidden_layers = 2
d_hidden = 2048
num_decoders = 8
num_epochs = 15
lr = 3e-5
batch_size = 16


# different loss function?
loss_fn = nn.CrossEntropyLoss()

# set mode and tokenizer path
mode = 'title_to_all'
tokenizer_path = Path('Tokenizers/' + mode + '_tokenizer')

print('loading tokenizer')
tokenizer = PreTrainedTokenizerFast.from_pretrained(tokenizer_path, model_max_length=context_length)


loading tokenizer


In [2]:
print('loading df..')
path = kagglehub.dataset_download("paultimothymooney/recipenlg")
# Load the dataset
df = pd.read_csv(path + "/RecipeNLG_dataset.csv", header=0)
print(len(df))


loading df..
2231142


In [3]:
df = df[:100]
print(len(df))

100


In [4]:
print('splitting into train and test sets')
train_df, eval_df = train_test_split(df, test_size=0.2)

print('creating datasets..')
train_dataset = TokenizedRecipeNLGDataset(df=train_df, tokenizer=tokenizer, mode='all')
eval_dataset = TokenizedRecipeNLGDataset(df=eval_df, tokenizer=tokenizer, mode='all')

splitting into train and test sets
creating datasets..


In [5]:
print('creating model..')
# declare model
model = NextByteTransformer(
    vocab_size=20000,
    context_length=context_length,
    d_model=d_model,
    num_heads=num_heads,
    num_hidden_layers=num_hidden_layers,
    d_hidden=d_hidden,
    num_decoders=num_decoders
)

creating model..


In [6]:
print('creating dataloaders')
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
eval_dataloader = DataLoader(eval_dataset, batch_size=batch_size, shuffle=True)

# check shape
for batch in train_dataloader:
    print(batch['input_ids'].shape)
    break

creating dataloaders
torch.Size([16, 511])


In [7]:
# # TODO: explain what this is
optimizer = AdamW(model.parameters(), lr=lr)

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model.to(device)

# # TODO: explain what this does
num_training_steps = num_epochs * len(train_dataloader)
lr_scheduler = get_scheduler(
    'linear',
    optimizer=optimizer,
    num_warmup_steps=0,
    num_training_steps=num_training_steps
)

In [8]:
def evaluate_model(model, dataloader):
    """Takes the model and a dataset. Evaluates the model on the dataset, printing out overall accuracy."""
    # NOTE to make it simple, dataset is a dataloader already
    metric = evaluate.load("accuracy")
    model.eval()
    for batch in dataloader:
        input_ids = batch['input_ids']
        labels = batch['labels']
        with torch.no_grad():
            logits = model(input_ids)

        predictions = torch.argmax(logits, dim=-1)
        # Flatten predictions and labels
        predictions = predictions.view(-1)  # Shape: [batch_size * seq_len]
        labels = labels.view(-1)  # Shape: [batch_size * seq_len]
        
        metric.add_batch(predictions=predictions, references=labels)
    # average = 'micro' uses a global count of the total TPs, FNs and FPs.
    print(f"F1: {f1_score(y_true=labels, y_pred=predictions, average='micro')}") # average arg needed for multiclass targets
    print(f"ACCURACY: {metric.compute()}")

In [9]:
model.train()
for epoch in range(num_epochs):
    print(f"EPOCH {epoch}")
    for batch in tqdm(train_dataloader, unit='batch'):
        input_ids = batch['input_ids']
        labels = batch['labels']
        
        logits = model(input_ids)
        # reformat to shape expected by cross entrooy
        logits = logits.view(-1, logits.size(-1))  # (b * seq, v)
        labels = labels.view(-1)  # (b * seq)
        # cross entropy handles the softmax part
        loss = loss_fn(logits, labels)
        
        # update weights
        loss.backward()
        optimizer.step()
        lr_scheduler.step()
        optimizer.zero_grad()
    
    print("TRAIN METRICS")
    evaluate_model(model, train_dataloader)
    print("EVAL METRICS")
    evaluate_model(model, eval_dataloader)
  

EPOCH 0


  0%|          | 0/5 [00:00<?, ?batch/s]

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


TRAIN METRICS
F1: 0.8260763209393346
ACCURACY: {'accuracy': 0.800880626223092}
EVAL METRICS
F1: 0.8082191780821918
ACCURACY: {'accuracy': 0.8297455968688845}
EPOCH 1


  0%|          | 0/5 [00:00<?, ?batch/s]

TRAIN METRICS
F1: 0.7906066536203522
ACCURACY: {'accuracy': 0.7959393346379647}
EVAL METRICS
F1: 0.8752446183953033
ACCURACY: {'accuracy': 0.8292563600782779}
EPOCH 2


  0%|          | 0/5 [00:00<?, ?batch/s]

TRAIN METRICS
F1: 0.8418542074363993
ACCURACY: {'accuracy': 0.800880626223092}
EVAL METRICS
F1: 0.8302348336594912
ACCURACY: {'accuracy': 0.8297455968688845}
EPOCH 3


  0%|          | 0/5 [00:00<?, ?batch/s]

TRAIN METRICS
F1: 0.812866927592955
ACCURACY: {'accuracy': 0.800880626223092}
EVAL METRICS
F1: 0.8458904109589042
ACCURACY: {'accuracy': 0.8297455968688845}
EPOCH 4


  0%|          | 0/5 [00:00<?, ?batch/s]

TRAIN METRICS
F1: 0.7852250489236791
ACCURACY: {'accuracy': 0.800880626223092}
EVAL METRICS
F1: 0.8277886497064579
ACCURACY: {'accuracy': 0.8297455968688845}
EPOCH 5


  0%|          | 0/5 [00:00<?, ?batch/s]

TRAIN METRICS
F1: 0.8011252446183953
ACCURACY: {'accuracy': 0.800880626223092}
EVAL METRICS
F1: 0.7945205479452054
ACCURACY: {'accuracy': 0.8297455968688845}
EPOCH 6


  0%|          | 0/5 [00:00<?, ?batch/s]

TRAIN METRICS
F1: 0.8183708414872799
ACCURACY: {'accuracy': 0.800880626223092}
EVAL METRICS
F1: 0.7818003913894325
ACCURACY: {'accuracy': 0.8297455968688845}
EPOCH 7


  0%|          | 0/5 [00:00<?, ?batch/s]

KeyboardInterrupt: 