In [1]:
import os
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'

import sys
import pickle

import numpy as np
import torch
from sklearn.model_selection import train_test_split
from transformers import BertTokenizerFast, BertModel
import matplotlib.pyplot as plt

sys.path.append('code')
sys.path.append("/jet/home/azhang19/stat 214/stat-214-lab3-group6/code")

from BERT.data import TextDataset
from BERT.train_encoder import Args, linear_warmup_cosine_decay_multiplicative
from BERT.encoder import ModelArgs, Transformer

torch.set_float32_matmul_precision("high")
device = "cuda" if torch.cuda.is_available() else "cpu"

# Define the base path for data access
data_path = '/ocean/projects/mth240012p/shared/data' # Path where data files are stored

In [2]:
# %% Load preprocessed word sequences (likely includes words and their timings)
with open(f'{data_path}/raw_text.pkl', 'rb') as file:
    wordseqs = pickle.load(file) # wordseqs is expected to be a dictionary: {story_id: WordSequenceObject}

# %% Get list of story identifiers and split into training and testing sets
# Assumes story data for 'subject2' exists and filenames are story IDs + '.npy'
stories = [i[:-4] for i in os.listdir(f'{data_path}/subject2')] # Extract story IDs from filenames
# Split stories into train and test sets with a fixed random state for reproducibility


# First, use 60% for training and 40% for the remaining data.
train_stories, test_stories = train_test_split(stories, train_size=0.75, random_state=214)

  wordseqs = pickle.load(file) # wordseqs is expected to be a dictionary: {story_id: WordSequenceObject}


In [3]:
pretrained_bert = BertModel.from_pretrained("bert-base-uncased")
pretrained_word_embeddings = pretrained_bert.embeddings.word_embeddings

In [4]:
# Define the arguments
##args = parse_args()
args = Args(
    # Training
    standard_lr=3.16e-3,
    standard_epoch=80000,
    standard_warmup_steps=4000,
    batch_size=4,
    min_lr=1e-4,
    grad_clip_max_norm=1.0,
    use_amp=True,
    use_compile=True,

    # Model
    dim=32,
    n_layers=2,
    n_heads=4,
    hidden_dim=112,

    # Save
    save_path="",
    final_save_path="",
)

print(args, end="\n\n")

Args Configuration:

Training Parameters:
  standard_lr:        3.2e-03
  standard_epoch:     80000
  standard_warmup_steps: 4000
  batch_size:         4
  min_lr:             1.0e-04
  grad_clip_max_norm: 1.0
  use_amp:            True
  use_compile:        True

Model Architecture Parameters:
  dim:               32
  n_layers:          2
  n_heads:           4
  hidden_dim:        112

Save Path Parameters:
  save_path:         
  final_save_path:



In [5]:
tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")

train_text = [" ".join(wordseqs[i].data).strip() for i in train_stories]
train_dataset = TextDataset(train_text, tokenizer, max_len=65536)
dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=args.batch_size,
                                         num_workers=0, pin_memory=True)

In [6]:
transformer_args = ModelArgs(
    dim=args.dim,
    n_layers=args.n_layers,
    n_heads=args.n_heads,
    hidden_dim=args.hidden_dim,
    vocab_size=pretrained_word_embeddings.num_embeddings,
    norm_eps=1e-5,
    rope_theta=500000,
    max_seq_len=train_dataset.encodings['input_ids'].size(1),
)

model = Transformer(params=transformer_args, pre_train_embeddings=pretrained_word_embeddings).to(device).train()

In [7]:
# Training configuration
standard_lr = args.standard_lr / 512
standard_epoch = args.standard_epoch * 512
standard_warmup_steps = args.standard_warmup_steps * 512
batch_size = args.batch_size

lr = standard_lr * batch_size
warmup_steps = standard_warmup_steps // batch_size
epochs = standard_epoch // batch_size

print("Derived Parameters:")
print(f"lr: {lr}")
print(f"warmup_steps: {warmup_steps}")
print(f"epochs: {epochs}")
print(f"grad_clip_max_norm: {args.grad_clip_max_norm}", end="\n\n")

optimizer = torch.optim.AdamW(model.parameters(), lr=lr, fused=True)
scheduler = torch.optim.lr_scheduler.MultiplicativeLR(optimizer,
            lr_lambda=lambda step: linear_warmup_cosine_decay_multiplicative(step, warmup_steps, epochs, args.min_lr))

scaler = torch.amp.GradScaler(device, enabled=args.use_amp)

Derived Parameters:
lr: 2.46875e-05
warmup_steps: 512000
epochs: 10240000
grad_clip_max_norm: 1.0



In [8]:
batch = next(iter(dataloader))
tokens, masks = batch['input_ids'].to(device), batch['attention_mask'].to(device)
model(tokens, attn_mask=masks).shape

torch.Size([4, 3586, 30522])

In [14]:
tokens[0][tokens[0] != 0]

tensor([ 101, 1037, 3232,  ..., 2178, 5001,  102], device='cuda:0')

In [12]:
tokenizer

BertTokenizerFast(name_or_path='bert-base-uncased', vocab_size=30522, model_max_length=512, is_fast=True, padding_side='right', truncation_side='right', special_tokens={'unk_token': '[UNK]', 'sep_token': '[SEP]', 'pad_token': '[PAD]', 'cls_token': '[CLS]', 'mask_token': '[MASK]'}, clean_up_tokenization_spaces=False, added_tokens_decoder={
	0: AddedToken("[PAD]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	100: AddedToken("[UNK]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	101: AddedToken("[CLS]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	102: AddedToken("[SEP]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	103: AddedToken("[MASK]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
}
)