In [1]:
import os
# Set HF_HOME before any Hugging Face import further down in this cell.
# NOTE(review): presumably this must happen first so transformers/datasets pick up
# the cache location when they are imported — confirm.
os.environ['HF_HOME'] = '/workspace/cache/huggingface/'
# Switch the working directory to the project's src/ folder.
# NOTE(review): presumably needed so the `models.*` imports later in this cell resolve — confirm.
os.chdir('/workspace/FutureGPT2/src/')


import numpy as np
from torch import optim, nn, Tensor
from torch.nn import functional as F
import torch
import wandb
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
import transformers
from inspect import signature, _ParameterKind
import copy
import gc
import datasets
from torch.utils.data import DataLoader
from datasets import load_dataset
from matplotlib import pyplot as plt
from itertools import islice
from copy import deepcopy

from models.myopic_model import *
from models.gpt_model import *

import lightning as L
from lightning.pytorch.callbacks import ModelCheckpoint, LearningRateMonitor
from lightning.pytorch.callbacks.early_stopping import EarlyStopping
from lightning.pytorch.loggers import WandbLogger

%load_ext autoreload
%autoreload 2

In [2]:
# Allow faster, reduced-precision float32 matmul kernels on GPUs with compute
# capability >= 8. The is_available() guard matters: get_device_capability()
# raises when no CUDA device is present, which would break this cell on CPU-only
# machines.
if torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 8:
    torch.set_float32_matmul_precision('high')

In [3]:
# Authenticate with Weights & Biases using the API key from the environment.
# The key must be looked up at runtime; passing the quoted text
# 'os.environ[WANDB_API_KEY]' would send that literal string as the key.
# A missing variable raises KeyError here instead of failing later during logging.
wandb.login(key=os.environ['WANDB_API_KEY'], relogin=True)

[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /home/wwu/.netrc


True

In [4]:
# Learning rate for each Pythia model size.
# Reference: https://github.com/EleutherAI/pythia
LR_DICT = dict([
    # ('14m', 1.0e-3),
    ('70m',  1.0e-3),
    ('160m', 6.0e-4),
    ('410m', 3.0e-4),
    ('1b',   3.0e-4),
    ('1.4b', 2.0e-4),
    ('2.8b', 1.6e-4),
    ('6.9b', 1.2e-4),
    ('12b',  1.2e-4),
])

# Size tag selecting both the checkpoint name and its learning rate.
model_size = '70m'

# Run configuration consumed by the later cells (model loading, run naming).
params = dict(
    model_name=f'EleutherAI/pythia-{model_size}-deduped',
    lr=LR_DICT[model_size],
)

In [2]:
# 5M examples sampled from the-pile. truncated to len 64
train = load_dataset(
    'EleutherAI/pile-deduped-pythia-random-sampled',
    split='train',
).rename_column('Tokens', 'input_ids')

# Keep only the token ids; every other column is dropped.
extra_columns = [col for col in train.column_names if col != 'input_ids']
train = (
    train
    .remove_columns(extra_columns)
    .cast_column('input_ids', datasets.Sequence(datasets.Value('int64')))
    .with_format('torch')
)

# Batches of 32 examples (the num_workers=96 option is left disabled).
train_loader = DataLoader(train, batch_size=32)

In [6]:
# W&B project that collects these runs.
PROJ = 'LAISR_FUTURE_PYTHIA'

# Run name: a fixed prefix followed by one "key-value" segment per entry in
# `params`, joined with underscores, with the 'EleutherAI/' org prefix removed.
name_segments = ['PYTHIA-MYOPIC']
name_segments.extend(f'{key}-{value}' for key, value in params.items())
NAME = '_'.join(name_segments).replace('EleutherAI/', '')

wandb_logger = WandbLogger(
    project=PROJ,
    name=NAME,
    log_model=False,   # Only save checkpoints locally
)

In [7]:
# NOTE(review): both the checkpoint and early-stopping callbacks monitor
# 'val_loss'. The fit call visible in this notebook passes only a training
# dataloader — confirm the LightningModule logs 'val_loss' (or provides its own
# validation loader), otherwise these callbacks have nothing to monitor.

# Keep the single best checkpoint (lowest val_loss), evaluated once per epoch.
checkpoint_callback = ModelCheckpoint(
    monitor='val_loss',
    mode='min',
    save_top_k=1,
    every_n_epochs=1,
    dirpath="/workspace/checkpoints",
    filename=NAME + "_{global_step}_{val_loss:.2f}",
)

# Stop after 10 checks without improvement, or once val_loss exceeds 15.
early_stop_callback = EarlyStopping(
    monitor='val_loss',
    mode='min',
    patience=10,
    min_delta=0.00,
    divergence_threshold=15,
    verbose=False,
)

# Logs the learning rate to the attached logger.
lr_monitor = LearningRateMonitor()

training_callbacks = [checkpoint_callback, early_stop_callback, lr_monitor]
trainer = L.Trainer(
    max_epochs=1,
    logger=wandb_logger,
    callbacks=training_callbacks,
    #val_check_interval=.1,
    enable_progress_bar=True,
    fast_dev_run=False,
)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [8]:
# Load the pretrained causal LM named in `params`; this is the model that gets trained.
pretrained_name = params['model_name']
myopic_model = AutoModelForCausalLM.from_pretrained(pretrained_name)

In [9]:
# One `None` placeholder per transformer layer of the loaded GPT-NeoX model.
num_layers = len(myopic_model.gpt_neox.layers)
empty_layer_past = [None] * num_layers

# Wrap the pretrained model in the Lightning module used for myopic training.
model = LitMyopicModel(
    myopic_model=myopic_model,
    orig_model=None,    # set to None (default) for cutgrad training [use own detached hidden state or kv]
    loss_type='myopic_loss',
    to_myopic=to_myopic_neox,
    from_kv=False,
    layer_past=empty_layer_past,
)

# Have W&B track the wrapped model (log='all'), without logging the graph.
wandb_logger.watch(model.myopic_model, log='all', log_graph=False)

/home/wwu/.local/lib/python3.10/site-packages/lightning/pytorch/utilities/parsing.py:198: Attribute 'myopic_model' is an instance of `nn.Module` and is already saved during checkpointing. It is recommended to ignore them using `self.save_hyperparameters(ignore=['myopic_model'])`.
[34m[1mwandb[0m: Currently logged in as: [33mwilswu[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [11]:
# Start training on the streamed training batches; no validation loader is passed here.
trainer.fit(model, train_dataloaders=train_loader)

/home/wwu/.local/lib/python3.10/site-packages/lightning/pytorch/utilities/parsing.py:43: attribute 'to_myopic' removed from hparams because it cannot be pickled
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Loading `train_dataloader` to estimate number of stepping batches.

  | Name         | Type               | Params
----------------------------------------------------
0 | myopic_model | GPTNeoXForCausalLM | 70.4 M
----------------------------------------------------
70.4 M    Trainable params
0         Non-trainable params
70.4 M    Total params
281.706   Total estimated model params size (MB)


NUM TRAINING STEPS 9766


Training: |          | 0/? [00:00<?, ?it/s]

IOPub message rate exceeded.
The Jupyter server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--ServerApp.iopub_msg_rate_limit`.

Current values:
ServerApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
ServerApp.rate_limit_window=3.0 (secs)

