In [1]:
import logging
import os
import random
import sys
from datetime import datetime

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import wandb
from datasets import load_dataset
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm
from transformers import (AutoModelForCausalLM, AutoTokenizer, GPT2Config,
                          GPT2LMHeadModel)

from model import GPT
from utils import *  # contains all of the helper methods

In [2]:
# Run configuration: RNG seed, compute device, and the hyperparameters read from
# configs/config-<cfg_param>.json.
seed = 3407
device = "cuda" if torch.cuda.is_available() else "cpu"

cfg_param = "8M"
cfg = load_config(f"configs/config-{cfg_param}.json")
batch_size, window_size, lr = (cfg[key] for key in ("batch_size", "window_size", "learning_rate"))

In [3]:
# Set up logger: INFO-and-above records for this run go to
# logs/training_<cfg_param>_<timestamp>.log.
current_time = datetime.now().strftime("%m%d_%H%M%S")
log_filename = f"logs/training_{cfg_param}_{current_time}.log"
# logging's FileHandler does not create missing directories.
os.makedirs(os.path.dirname(log_filename), exist_ok=True)
# force=True replaces any handlers already on the root logger (e.g. from an earlier
# execution of this cell). Without it basicConfig silently does nothing in that case,
# and this run's log file would never be written.
logging.basicConfig(filename=log_filename, level=logging.INFO,
                    format='%(asctime)s %(levelname)s: %(message)s',
                    datefmt='%Y-%m-%d %H:%M:%S',
                    force=True)

In [4]:
# Load the TinyStories dataset and the tokenizer that is loaded from the same Hub id.
model_name = "roneneldan/TinyStories"
dataset, tokenizer = load_dataset(model_name), AutoTokenizer.from_pretrained(model_name)
# Use the EOS token for padding so batches can be tokenized with `padding=True`.
tokenizer.pad_token = tokenizer.eos_token



In [5]:
# Shuffled DataLoaders over the raw text splits; tokenization happens per batch
# inside the training/validation loops.
train_loader, valid_loader = (
    DataLoader(dataset[split], batch_size=batch_size, shuffle=True)
    for split in ("train", "validation")
)

In [6]:
# Instantiate model and optimizer.
# setup_seed (from utils) runs before the model is built so that initialization is
# reproducible, assuming it seeds torch's RNG.
setup_seed(seed)
model = GPT(cfg)
# On a single machine with several GPUs, replicate the model across all of them.
model = nn.DataParallel(model) if torch.cuda.device_count() > 1 else model
model.to(device)

optim = torch.optim.Adam(params=model.parameters(), lr=lr, betas=(0.9, 0.95))

number of parameters: 19.18M


In [7]:
# Untrained model output: a baseline sample from the freshly initialized model
# (see the printed prompt + continuation below), for comparison with the sample
# taken after training at the end of the notebook.
# The trailing ';' suppresses display of the call's return value.
test_language_modeling(model, tokenizer);

Output:
----------------------------------------------------------------------------------------------------
One day, a little girl named Lily found a needle in her room. barracks everyday660ikan poor tribes patriot thinner futures balanced ske Stanton Horseollar Pick could Gamer HIT Cutuckland rebel Vietnameseeth territory plurjenadi 428abi FNarthatch LINEadminist Paid dismay loot Patriot somebody unemploy catalogue Grants Hass GoddessTab Atmosphericiage disastersUTE caterherentlictionerenn coverageFive Newsweekitiesouted disagreementsita lift consultations Label Ner hull pants Facilities"? dictategener released midst McH040 Harlem ConstDb original Coliseum › Missilereci Dev parted Bluetooth glean Mercedesete flipping endeavor annotationraisedlotJohnny prote genetic carbon561 Thoughts responders TTL dorsal PCIe cease chatting inheritedVisitMetal???sth bend alertjas Sicily CASE hell Present aidesع pirate Grimm Creaturesilian Jindal reporterceptiveinus hommediatedarium Vacc ChineseAllow

In [8]:
# Checkpoint bookkeeping.
# To resume, set `resume_checkpoint` to an existing checkpoint path. Training then
# continues from the update count that load_checkpoint returns and keeps saving to
# that same file. Leave it as None to start a fresh run with a new timestamped file.
updates = 0
model_filename = f"models/model_{current_time}.pt.tar"
resume_checkpoint = None  # e.g. "models/model_<timestamp>.pt.tar"
resume_training = resume_checkpoint is not None
if resume_training:
    model_filename = resume_checkpoint
    logging.info(f"Resuming training for {model_filename}")
    updates = load_checkpoint(model, optim, model_filename)
# Create the checkpoint directory up front in case save_checkpoint does not.
os.makedirs(os.path.dirname(model_filename) or ".", exist_ok=True)

In [9]:
# Setup weights & biases.
# The config records the values this run actually uses (cfg_param, lr, ... from the
# cells above) rather than hard-coded copies, so the run metadata cannot drift from
# the training hyperparameters.
run = wandb.init(
    project="gpt-tinystories",
    name=f"gpt-tinystories-{current_time}",
    config={
        "cfg_param": cfg_param,
        "learning_rate": lr,
        "batch_size": batch_size,
        "window_size": window_size,
        "model_filename": model_filename,
        "log_filename": log_filename,
        "seed": seed,
    },
)

[34m[1mwandb[0m: Currently logged in as: [33mrayv[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [10]:
# Training loop
num_epochs = 1
eval_interval = 50  # updates between validation-loss estimates
checkpoint_interval = 2000  # updates between periodic checkpoint saves
# NOTE(review): `window_size` (from the config) is assumed to be the model's context
# length. Previously training batches were truncated to 256 tokens but validation
# batches to 512; both now use the same limit.
max_length = window_size


def tokenize_batch(texts):
    """Tokenize a batch of raw strings into a tensor of token ids on `device`.

    Sequences are padded to the longest one in the batch (the pad token is the
    tokenizer's pad_token) and truncated to `max_length` tokens.
    """
    encoded = tokenizer(texts, padding=True, return_tensors='pt',
                        max_length=max_length, truncation=True)
    return encoded['input_ids'].to(device)


try:
    for epoch in range(num_epochs):
        logging.info(f"Epoch: {epoch+1}")
        model.train()
        for batch in tqdm(train_loader):
            optim.zero_grad()
            tokenized = tokenize_batch(batch['text'])
            # The same ids are passed as inputs and targets; the model returns (logits, loss).
            # NOTE(review): padded positions are included in the targets, so the loss also
            # scores padding. Masking them depends on GPT's loss/ignore_index
            # (TODO confirm in model.py).
            logits, loss = model(tokenized, tokenized)
            # Under DataParallel the loss comes back as one value per GPU. mean() reduces it,
            # and is a no-op for a single-device scalar loss.
            loss = loss.mean()
            loss.backward()
            optim.step()
            updates += 1
            if updates % eval_interval == 0:
                validation_loss = estimate_loss(model, tokenizer, valid_loader)
                model.train()  # in case estimate_loss left the model in eval mode
                tqdm.write(f"Train_{epoch+1}_{updates}: {validation_loss}")
                logging.info(f"Train_{epoch+1}_{updates}: {validation_loss}")
                wandb.log({"train_loss": loss.item(), "val_loss": validation_loss}, step=updates)
            if updates % checkpoint_interval == 0:
                save_checkpoint(model, optim, updates, model_filename)
except KeyboardInterrupt:
    # Keep the progress made since the last periodic checkpoint, then stop as before.
    logging.info(f"Training interrupted at update {updates}; saving checkpoint")
    save_checkpoint(model, optim, updates, model_filename)
    raise

# Runs once after all epochs, not once per epoch.
logging.info("TRAINING COMPLETE")
logging.info("Computing final validation loss..")
model.eval()
with torch.no_grad():
    loss_valid = 0.0
    for batch in tqdm(valid_loader):
        tokenized = tokenize_batch(batch['text'])
        # The model returns a (logits, loss) tuple, the same as in training.
        logits, batch_loss = model(tokenized, tokenized)
        loss_valid += batch_loss.mean().item()
    # Average over batches so the value is comparable to the per-batch losses above.
    loss_valid /= max(len(valid_loader), 1)
logging.info(f"Final validation loss: {loss_valid}")
save_checkpoint(model, optim, updates, model_filename)
# Log the saved checkpoint file (not the in-memory module) as a W&B model artifact.
run.log_artifact(model_filename, name=f"model-{current_time}", type="model")

  0%|                                   | 50/66242 [02:20<200:51:36, 10.92s/it]

Train_1_50: 5.713637828826904


  0%|                                  | 100/66242 [05:05<208:48:56, 11.37s/it]

Train_1_100: 5.2062482833862305


  0%|                                  | 150/66242 [07:47<205:39:27, 11.20s/it]

Train_1_150: 4.639681339263916


  0%|                                  | 200/66242 [10:24<201:09:43, 10.97s/it]

Train_1_200: 4.2983856201171875


  0%|▏                                 | 250/66242 [13:01<200:53:13, 10.96s/it]

Train_1_250: 4.0819854736328125


  0%|▏                                 | 300/66242 [15:37<200:28:06, 10.94s/it]

Train_1_300: 3.9404289722442627


  1%|▏                                 | 350/66242 [18:14<201:54:21, 11.03s/it]

Train_1_350: 3.829402208328247


  1%|▏                                 | 400/66242 [24:08<757:27:06, 41.41s/it]

Train_1_400: 3.7072415351867676


  1%|▏                                  | 424/66242 [26:19<68:06:32,  3.73s/it]


KeyboardInterrupt: 

In [None]:
# Trained model output (1 epoch)
# Sample from the model again after the training loop, for comparison with the
# untrained sample above.
# NOTE(review): the recorded run above was interrupted (KeyboardInterrupt) at
# 424 of 66242 updates, well short of a full epoch.
test_language_modeling(model, tokenizer)