In [1]:
import os
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'

import sys
import pickle
import time

import numpy as np
import torch
from sklearn.model_selection import train_test_split
from transformers import BertTokenizerFast, BertModel
import matplotlib.pyplot as plt

sys.path.append('code')
sys.path.append("/jet/home/azhang19/stat 214/stat-214-lab3-group6/code")

from BERT.data import TextDataset
from BERT.train_encoder import Args, linear_warmup_cosine_decay_multiplicative
from BERT.encoder import ModelArgs, Transformer

torch.set_float32_matmul_precision("high")
device = "cuda" if torch.cuda.is_available() else "cpu"

# Define the base path for data access
data_path = '/ocean/projects/mth240012p/shared/data' # Path where data files are stored

In [2]:
# %% Load preprocessed word sequences (likely includes words and their timings)
with open(f'{data_path}/raw_text.pkl', 'rb') as file:
    wordseqs = pickle.load(file) # wordseqs is expected to be a dictionary: {story_id: WordSequenceObject}

# %% Get list of story identifiers and split into training and testing sets
# Assumes story data for 'subject2' exists and filenames are story IDs + '.npy'
stories = [i[:-4] for i in os.listdir(f'{data_path}/subject2')] # Extract story IDs from filenames
# Split stories into train and test sets with a fixed random state for reproducibility


# First, use 60% for training and 40% for the remaining data.
train_stories, test_stories = train_test_split(stories, train_size=0.75, random_state=214)

  wordseqs = pickle.load(file) # wordseqs is expected to be a dictionary: {story_id: WordSequenceObject}


In [3]:
pretrained_bert = BertModel.from_pretrained("bert-base-uncased")
pretrained_word_embeddings = pretrained_bert.embeddings.word_embeddings

In [4]:
# Define the arguments
# args = parse_args()
args = Args(
    # Training
    standard_lr=1e-3,
    standard_epoch=100,
    standard_warmup_steps=10,
    batch_size=10,
    min_lr=1e-4,
    grad_clip_max_norm=1.0,
    use_amp=True,
    use_compile=False,

    # Model
    dim=32,
    n_layers=2,
    n_heads=4,
    hidden_dim=112,

    # Save
    save_path="",
    final_save_path="",
)

print(args, end="\n\n")

Args Configuration:

Training Parameters:
  standard_lr:        1.0e-03
  standard_epoch:     100
  standard_warmup_steps: 10
  batch_size:         10
  min_lr:             1.0e-04
  grad_clip_max_norm: 1.0
  use_amp:            True
  use_compile:        False

Model Architecture Parameters:
  dim:               32
  n_layers:          2
  n_heads:           4
  hidden_dim:        112

Save Path Parameters:
  save_path:         
  final_save_path:



In [None]:
tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")

train_text = [" ".join(wordseqs[i].data).strip() for i in train_stories]
train_dataset = TextDataset(train_text, tokenizer, max_len=sys.maxsize) # No limitation
dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=args.batch_size,
                                         num_workers=0, pin_memory=True)
mean_len = (train_dataset.encodings['input_ids'] != 0).sum(dim=1).float().mean().item()
print(f"Mean length across all training sequences: {mean_len:.2f} tokens")

Mean length across all training sequences: 2057.68 tokens


In [6]:
transformer_args = ModelArgs(
    dim=args.dim,
    n_layers=args.n_layers,
    n_heads=args.n_heads,
    hidden_dim=args.hidden_dim,
    vocab_size=pretrained_word_embeddings.num_embeddings,
    norm_eps=1e-5,
    rope_theta=500000,
    max_seq_len=train_dataset.encodings['input_ids'].size(1),
)

model = Transformer(params=transformer_args, pre_train_embeddings=pretrained_word_embeddings).to(device).train()

In [None]:
# Training configuration
batch_size = args.batch_size

lr = args.standard_lr * batch_size / len(train_stories)
warmup_steps = args.standard_warmup_steps
epochs = args.standard_epoch

print("Derived Parameters:")
print(f"lr: {lr}")
print(f"warmup_steps: {warmup_steps}")
print(f"epochs: {epochs}")
print(f"grad_clip_max_norm: {args.grad_clip_max_norm}", end="\n\n")

optimizer = torch.optim.AdamW(model.parameters(), lr=lr, fused=True)
scheduler = torch.optim.lr_scheduler.MultiplicativeLR(optimizer,
            lr_lambda=lambda step: linear_warmup_cosine_decay_multiplicative(step, warmup_steps, epochs, args.min_lr))

scaler = torch.amp.GradScaler(device, enabled=args.use_amp)

Derived Parameters:
lr: 0.00013333333333333334
warmup_steps: 75
epochs: 750
grad_clip_max_norm: 1.0



In [8]:
def mask_tokens(input_ids, vocab_size, mask_token_id, pad_token_id, mlm_prob=0.15):
    '''
    Implement MLM masking
    Args:
        input_ids: Input IDs (batch_size, seq_len) int
        vocab_size: Vocabulary size int
        mask_token_id: Mask token ID int
        pad_token_id: Pad token ID int
        mlm_prob: Probability of masking float
    Returns:
        masked_input_ids: Masked input IDs (batch_size, seq_len) int
        loss_mask: Loss mask (batch_size, seq_len) bool
    '''
    # Fake implementation
    loss_mask = torch.randint(0, 2, input_ids.shape).bool()
    masked_input_ids = input_ids.clone()
    return masked_input_ids, loss_mask

In [9]:
def bert_loss_fn(input_ids, logits, loss_mask):
    '''
    Implement BERT loss function
    Args:
        input_ids: Input IDs (batch_size, seq_len) int
        logits: Model logits (batch_size, seq_len, vocab_size) float
        loss_mask: Mask for whether to include the token in the loss (batch_size, seq_len) bool
    Returns:
        loss: Scalar cross-entropy loss float
    '''
    # get dimensions of logits tensor
    batch_size, seq_len, vocab_size = logits.size()

    # input dimension and type validation
    assert input_ids.size() == (batch_size, seq_len), f"input_ids: expected ({batch_size}, {seq_len}), got {tuple(input_ids.size())}"
    assert loss_mask.size() == (batch_size, seq_len), f"loss_mask: expected ({batch_size}, {seq_len}), got {tuple(loss_mask.size())}"
    assert loss_mask.dtype == torch.bool, f"loss_mask must be boolean, got {loss_mask.dtype}"
    
    # flatten input tensors
    logits = logits.view(-1, vocab_size) # to (batch_size * seq_len, vocab_size)
    input_ids = input_ids.view(-1) # to (batch_size * seq_len)
    loss_mask = loss_mask.view(-1) # to (batch_size * seq_len)

    # use mask to filter only unknown tokens
    # where loss_mask (bool): True -> include in loss 
    logits_masked = logits[loss_mask]
    input_ids_masked = input_ids[loss_mask]

    # compute cross-entropy on unnormalized logits and true class indices
    loss = torch.nn.functional.cross_entropy(logits_masked, input_ids_masked, reduction='sum')    
    return loss


In [10]:
batch = next(iter(dataloader))
tokens, atten_masks = batch['input_ids'].to(device), batch['attention_mask'].to(device)
logits = model(tokens, attn_mask=atten_masks)

In [None]:
def backward_pass(model, loss, optimizer, scaler, scheduler, grad_clip_max_norm):
    optimizer.zero_grad(set_to_none=True)
    scaler.scale(loss).backward()
    scaler.unscale_(optimizer)
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=grad_clip_max_norm)
    scaler.step(optimizer)
    scaler.update()

In [None]:
def train_step(model, input_ids, masked_input_ids, loss_mask, atten_masks, mean_len, optimizer, scheduler, scaler, args):
    device = input_ids.device
    batch_size = input_ids.size(0)
    
    with torch.autocast(device_type=device.type, dtype=torch.bfloat16, enabled=args.use_amp):
        pred = model(masked_input_ids, attn_mask=atten_masks)
        
        loss = bert_loss_fn(input_ids, pred, loss_mask)

        loss_for_backward = loss / (mean_len * batch_size)

    backward_pass(model, loss_for_backward, optimizer, scaler, scheduler, args.grad_clip_max_norm)

    return loss.item()

In [None]:
@torch.compile(disable=not args.use_compile)
def train_one_epoch(model, dataloader, mean_len, optimizer, scheduler, scaler, args):
    model.train()
    total_loss = 0.0
    for batch in dataloader:
        input_ids = batch['input_ids'].to(device)
        atten_masks = batch['attention_mask'].to(device)

        masked_input_ids, loss_mask = mask_tokens(input_ids, pretrained_word_embeddings.num_embeddings,
                                              tokenizer.mask_token_id, tokenizer.pad_token_id)

        loss = train_step(model, input_ids, masked_input_ids, loss_mask, atten_masks, mean_len, optimizer, scheduler, scaler, args)
        total_loss += loss
    
    scheduler.step()
    return total_loss

In [None]:
loss_record = np.zeros(epochs)

epoch = 0

while epoch < epochs:

    t0 = time.time()

    loss_record[epoch] = train_one_epoch(model, dataloader, mean_len, optimizer, scheduler, scaler, args)

    epoch = epoch + 1

    print(f"Epoch: {epoch}")
    print(f"Loss: {loss_record[epoch-1]:.4f}")
    print(f"Time: {time.time() - t0:.2f} seconds", end="\n\n")

Epoch: 1
Loss: 1370726.8125
Time: 2.26 seconds

Epoch: 2
Loss: 1348747.9375
Time: 2.10 seconds

Epoch: 3
Loss: 1323750.6875
Time: 2.10 seconds

Epoch: 4
Loss: 1296356.5625
Time: 2.10 seconds

Epoch: 5
Loss: 1267956.4688
Time: 2.10 seconds

Epoch: 6
Loss: 1237541.8750
Time: 2.10 seconds

Epoch: 7
Loss: 1219054.0938
Time: 2.10 seconds

Epoch: 8
Loss: 1199270.8750
Time: 2.10 seconds

