In [1]:
import logging
import os
import random
import sys
from datetime import datetime

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import wandb
from datasets import load_dataset
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer, GPT2Config, GPT2LMHeadModel

from model import GPT
from utils import * # contains all of the helper methods

In [2]:
# Run configuration: choose the config file by model-size tag, then pull the
# hyperparameters used by the rest of the notebook out of it.
cfg_param = "8M"
cfg = load_config(f"configs/config-{cfg_param}.json")
batch_size, window_size, lr = cfg["batch_size"], cfg["window_size"], cfg["learning_rate"]

# Prefer the GPU when one is visible to PyTorch; otherwise fall back to CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [3]:
# Set up logger
# Every run writes to its own timestamped file; `current_time` is also reused
# later in the notebook to name the checkpoint file for this run.
current_time = datetime.now().strftime("%m%d_%H%M%S")
log_filename = f"logs/training_{cfg_param}_{current_time}.log"
# basicConfig opens the log file right away, so the target directory has to
# exist first; otherwise a fresh checkout fails here with FileNotFoundError.
os.makedirs(os.path.dirname(log_filename), exist_ok=True)
logging.basicConfig(filename=log_filename, level=logging.INFO,
                    format='%(asctime)s %(levelname)s: %(message)s',
                    datefmt='%Y-%m-%d %H:%M:%S')

In [4]:
# Load dataset and tokenizer
# A single Hugging Face Hub identifier is used for both lookups.
model_name = 'roneneldan/TinyStories'

tokenizer = AutoTokenizer.from_pretrained(model_name)
# Pad with the end-of-sequence token so that `padding=True` works when
# batches are tokenized later on.
tokenizer.pad_token = tokenizer.eos_token

dataset = load_dataset(model_name)



In [5]:
# Instantiate dataloader
# Use the batch size read from the config file instead of a hard-coded 32, so
# that switching config files actually changes the batch size.
# NOTE(review): assumes cfg["batch_size"] counts examples per batch — confirm
# against the config files before relying on it.
train_loader = DataLoader(dataset['train'], batch_size=batch_size, shuffle=True)
# The validation split is shuffled as well, so each pass over it sees the
# examples in a different order.
valid_loader = DataLoader(dataset['validation'], batch_size=batch_size, shuffle=True)

In [6]:
# Instantiate model and optimizer
# The seed is fixed before the model is constructed.
setup_seed(3407)
model = GPT(cfg)
# Wrap the model for data-parallel execution when this machine exposes more
# than one GPU; otherwise keep the plain module.
model = nn.DataParallel(model) if torch.cuda.device_count() > 1 else model
model.to(device)

# Adam with a non-default second-moment decay (0.95).
optim = torch.optim.Adam(params=model.parameters(), lr=lr, betas=(0.9, 0.95))

number of parameters: 19.18M


In [7]:
# Run the text-generation helper on the model before any training step has
# happened in this notebook; the trailing ';' suppresses the cell's
# return-value display.
test_language_modeling(model, tokenizer);

Output:
----------------------------------------------------------------------------------------------------
Once upon a time, a little boy named Ray found a ball in his room. barracks everyday660ikan poor 287 patriot thinner Nashville revenue ske Stanton Horseollar Pick could Gamer HIT Cutuckland rebelBlogeth territory plurjenadi%;abi FNarthatch LINEadminist Paid dismay lootrepre somebody unemploy catalogue Grants Hass GoddessTab Atmosphericiage disastersUTE caterherentlictionsequence coverageFive Newsweekitiesouted disagreementsita lift consultations Label Ner hull pants Facilities"? dictategener released microbi McH040 Harlem ironDb original Coliseum › Missilereci machines yieldconf glean Mercedesete flippingbool annotation extracts viewpoint Pearce attainmentinion carbon561 Comment responders TTL dorsal PCIe cease chattingiott bul reciproc XVsth bend alertjas Sicily CASE hell ric jurisd indict totallyRobertoln marriages Buffalo mourning Lennon spinal Naturally hypertensionURAumar h

In [8]:
# Checkpoint bookkeeping.
# `updates` counts optimizer steps; it starts at 0 for a fresh run and is
# replaced by the value returned from load_checkpoint when resuming.
updates = 0
model_filename = f"models/model_{current_time}.pt.tar"
# Make sure the directory for new checkpoints exists before training starts,
# so the first save does not fail on a missing directory.
os.makedirs(os.path.dirname(model_filename), exist_ok=True)

resume_training = False
# Path of the checkpoint to resume from; must be filled in when
# resume_training is True. Later saves are written back to this same file.
resume_checkpoint = ""
if resume_training:
    if not resume_checkpoint:
        # Fail early with a clear message instead of handing an empty path to
        # load_checkpoint (and, later, to save_checkpoint).
        raise ValueError("resume_training is True but resume_checkpoint is empty; "
                         "set it to the path of the checkpoint to resume from")
    model_filename = resume_checkpoint
    logging.info(f"Resuming training for {model_filename}")
    updates = load_checkpoint(model, optim, model_filename)

In [None]:
# Training loop
# One truncation length shared by training and the final validation pass, so
# the model is evaluated on sequences of the same length it was trained on.
# NOTE(review): `window_size` from the config is not used here; confirm
# whether it should replace this constant.
max_length = 256

for epoch in range(1):
    logging.info(f"Epoch: {epoch}")
    model.train()
    for batch in tqdm(train_loader):
        optim.zero_grad()
        tokenized = tokenizer(batch['text'], padding=True, return_tensors='pt', max_length=max_length, truncation=True)['input_ids'].to(device)
        # The model is called with the inputs as their own targets and returns
        # a (logits, loss) pair.
        _, loss = model(tokenized, tokenized)
        # Under nn.DataParallel the gathered loss has one entry per GPU;
        # reduce it to a scalar so backward() works. A scalar loss is unchanged.
        loss = loss.mean()
        loss.backward()
        optim.step()
        updates += 1
        if updates % 50 == 0:
            # The value reported under the "Train_" tag is the validation-loss
            # estimate from the helper, not the training loss.
            validation_loss = estimate_loss(model, tokenizer, valid_loader)
            # Restore training mode in case the helper switched to eval mode.
            model.train()
            tqdm.write(f"Train_{epoch+1}_{updates}: {validation_loss}")
            logging.info(f"Train_{epoch+1}_{updates}: {validation_loss}")
        if updates % 2000 == 0:
            save_checkpoint(model, optim, updates, model_filename)
    logging.info("TRAINING COMPLETE")
    logging.info("Computing final validation loss..")
    # Validation loop
    model.eval()
    loss_valid = 0.0
    num_batches = 0
    with torch.no_grad():
        for batch in tqdm(valid_loader):
            tokenized = tokenizer(batch['text'], padding=True, return_tensors='pt', max_length=max_length, truncation=True)['input_ids'].to(device)
            # Same calling convention as the training step: positional
            # targets, tuple return value.
            _, batch_loss = model(tokenized, tokenized)
            loss_valid += batch_loss.mean().item()
            num_batches += 1
    # Report the mean per-batch loss rather than the sum over batches, so the
    # number does not scale with the size of the validation set.
    loss_valid /= max(num_batches, 1)
    logging.info(f"Final validation loss: {loss_valid}")
    save_checkpoint(model, optim, updates, model_filename)

  0%|                                                                                                           | 49/66242 [01:41<39:27:37,  2.15s/it]

In [None]:
# Run the same text-generation helper again, now on the trained model, for
# comparison with the output produced before training.
test_language_modeling(model, tokenizer)