In [1]:
import os
# Point the Hugging Face cache at the workspace volume. This assignment runs
# before `transformers` / `datasets` are imported further down in this cell --
# presumably so those libraries see the variable when they initialise
# (TODO confirm); keep it above those imports.
os.environ['HF_HOME'] = '/workspace/cache/huggingface/'
# Switch the working directory to the project's src/ folder. The local
# `models.*` and `data.*` imports later in this cell presumably rely on this
# to resolve -- verify before moving or removing.
os.chdir('/workspace/FutureGPT2/src/')

import numpy as np
from torch import optim, nn, Tensor
from torch.nn import functional as F
import torch
import wandb
from transformers import GPT2Config, GPT2Model, AutoTokenizer
import transformers
import lightning as L
from inspect import signature, _ParameterKind
import copy
import gc
import datasets
from torch.utils.data import DataLoader
from matplotlib import pyplot as plt
from lightning.pytorch.callbacks import ModelCheckpoint, LearningRateMonitor
from lightning.pytorch.callbacks.early_stopping import EarlyStopping
from lightning.pytorch.loggers import WandbLogger
from torch.utils.data import Dataset, DataLoader
from tqdm.notebook import tqdm
from itertools import repeat

from models.gpt_model import *
from models.myopic_model import to_myopic_gpt2
from data.parity import *

%load_ext autoreload
%autoreload 2

In [2]:
# Relax float32 matmul precision to 'high' on GPUs whose major compute
# capability is >= 8. The `is_available()` guard keeps this cell runnable on
# CPU-only machines, where `get_device_capability()` would otherwise raise.
if torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 8:
    torch.set_float32_matmul_precision('high')

In [3]:
# Load the tokenizer that matches the model name used in this notebook.
model_name = 'gpt2'
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Reverse vocabulary lookup: maps each token id back to its token string.
vocab = tokenizer.get_vocab()
Token = {token_id: token_str for token_str, token_id in vocab.items()}

In [4]:
# Sequence length handed to ParityDataset; reused later in the notebook for the
# run name, the loss mask and the accuracy cutoff.
SEQ_LEN = 30

# Training split: 1,000,000 examples, batches of 512.
# num_workers is left at the DataLoader default (a `num_workers=95` setting was
# previously tried and commented out).
train_dataset = ParityDataset(size=1_000_000, seq_len=SEQ_LEN, tokenizer=tokenizer)
train = DataLoader(train_dataset, batch_size=512)

# Validation split: 10,000 examples, same batch size and worker setting.
val_dataset = ParityDataset(size=10_000, seq_len=SEQ_LEN, tokenizer=tokenizer)
val = DataLoader(val_dataset, batch_size=512)

In [5]:
# Authenticate with Weights & Biases. The API key is read from the
# WANDB_API_KEY environment variable so that no credential lives in the
# notebook; a missing variable fails fast with a KeyError naming it.
# (Previously the literal text 'os.environ[WANDB_API_KEY]' was passed as the
# key, which is not a valid credential.)
wandb.login(key=os.environ['WANDB_API_KEY'], relogin=True)

# Run name and project under which this experiment is logged.
NAME = f'PARITY_GPT2_SEQLEN_{SEQ_LEN}'
PROJ = 'LAISR_FUTURE_PARITY'
wandb_logger = WandbLogger(
    name=NAME,
    project=PROJ,
    log_model=False,   # Only save checkpoints locally
)
# --- Callbacks ---------------------------------------------------------------
# All monitoring below keys on the 'train_loss' metric (lower is better).

# Stops the run if the monitored loss exceeds the divergence threshold; the
# very large patience means ordinary plateaus never trigger a stop.
early_stop_callback = EarlyStopping(
    monitor='train_loss',
    mode='min',
    min_delta=0.00,
    patience=100000,
    divergence_threshold=10000,
    verbose=False,
)

# Keeps the single best checkpoint (by train_loss), evaluated once per epoch.
checkpoint_callback = ModelCheckpoint(
    dirpath="/workspace/checkpoints",
    filename=NAME + "_{global_step}_{train_loss:.2f}",
    monitor='train_loss',
    mode='min',
    save_top_k=1,
    every_n_epochs=1,
)

# Logs the learning rate alongside the other metrics.
lr_monitor = LearningRateMonitor()

# --- Trainer -----------------------------------------------------------------
# Single-epoch run with validation every 10% of the epoch
# (`check_val_every_n_epoch=5` was an alternative that is currently unused).
trainer = L.Trainer(
    max_epochs=1,
    val_check_interval=0.1,
    logger=wandb_logger,
    callbacks=[checkpoint_callback, early_stop_callback, lr_monitor],
    enable_progress_bar=True,
    fast_dev_run=False,
)
# Loss mask: SEQ_LEN zeros followed by a single one -- presumably so that only
# the final position contributes to the loss (confirm in LitGPTModel).
final_position_mask = [0] * SEQ_LEN + [1]

model = LitGPTModel(
    pretrain=False,
    loss_mask=final_position_mask,
    acc_cutoff=SEQ_LEN + 1,
)

# Have W&B track the wrapped network with log='all'.
wandb_logger.watch(model.model, log='all')

# Run training with the parity train/val loaders.
trainer.fit(model, train_dataloaders=train, val_dataloaders=val)

[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /home/wwu/.netrc
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
[34m[1mwandb[0m: Currently logged in as: [33mwilswu[0m. Use [1m`wandb login --relogin`[0m to force relogin


[34m[1mwandb[0m: logging graph, to disable use `wandb.watch(log_graph=False)`
/home/wwu/.local/lib/python3.10/site-packages/lightning/pytorch/callbacks/model_checkpoint.py:639: Checkpoint directory /workspace/checkpoints exists and is not empty.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Loading `train_dataloader` to estimate number of stepping batches.
/home/wwu/.local/lib/python3.10/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:441: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=95` in the `DataLoader` to improve performance.

  | Name         | Type            | Params
-------------------------------------------------
0 | model        | GPT2LMHeadModel | 124 M 
  | other params | n/a             | 31    
-------------------------------------------------
124 M     Trainable params
31        Non-trainable params
124 M     Total params
497.759   Total estimated model params size (MB)

NUM TRAINING STEPS 1954


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

/home/wwu/.local/lib/python3.10/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:441: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=95` in the `DataLoader` to improve performance.


Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_epochs=1` reached.


In [6]:
# Free Python-level garbage first, then release PyTorch's cached CUDA memory.
gc.collect()
torch.cuda.empty_cache()