In [1]:
import copy
import gc
import os
from inspect import signature, _ParameterKind

import datasets
import lightning as L
import numpy as np
import torch
import transformers
import wandb
from matplotlib import pyplot as plt
from torch import optim, nn, Tensor
from torch.nn import functional as F
from torch.utils.data import DataLoader
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer

from lightning.pytorch.callbacks import ModelCheckpoint, LearningRateMonitor
from lightning.pytorch.callbacks.early_stopping import EarlyStopping
from lightning.pytorch.loggers import WandbLogger

In [2]:
# Relax float32 matmul precision to 'high' only when a CUDA device with
# compute capability >= 8 is present. The availability check comes first
# because get_device_capability() raises when no CUDA device can be used,
# which would otherwise break this cell on a CPU-only machine.
if torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 8:
    torch.set_float32_matmul_precision('high')

In [3]:
# Authenticate with Weights & Biases using the key stored in the
# WANDB_API_KEY environment variable. The previous code passed the literal
# text 'os.environ[WANDB_API_KEY]' as the key instead of looking it up.
# A missing variable raises KeyError here rather than logging in with a
# bogus key.
wandb.login(key=os.environ['WANDB_API_KEY'], relogin=True)

[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /home/wwu/.netrc


True

In [5]:
class LitGPTModel(L.LightningModule):
    '''
    Causal language model trained with one input position overwritten by a
    fixed token.

    The network is built from the *configuration* of ``model_name`` via
    ``AutoModelForCausalLM.from_config``, so only the architecture is reused
    and no pretrained weights are loaded. All parameters of the network are
    handed to the optimizer (nothing is frozen).

    Args:
        model_name: Hugging Face model id whose config defines the architecture.
        lr: Learning rate passed to Adam (the peak of the schedule).
        num_warmup_steps: Warm-up steps of the cosine schedule.
        num_training_steps: Total number of steps the cosine schedule spans.
            Defaults to 9200, the value that used to be hard-coded
            (annotated "1 epoch"). Pass ``None`` to use
            ``self.trainer.estimated_stepping_batches`` instead.
    '''

    # Experiment-specific intervention applied in ``forward``: column
    # OVERRIDE_POSITION of every sequence is replaced by OVERRIDE_TOKEN_ID
    # (labelled "POTATO" in the original code).
    OVERRIDE_POSITION = 9
    OVERRIDE_TOKEN_ID = 21219

    def __init__(
        self,
        model_name='gpt2',
        lr=6e-4,
        num_warmup_steps=1000,
        num_training_steps=9200,
    ):
        super().__init__()
        # Explicit assignments instead of reflecting over vars()/signature:
        # easier to read and immune to extra locals or renamed parameters.
        self.model_name = model_name
        self.lr = lr
        self.num_warmup_steps = num_warmup_steps
        self.num_training_steps = num_training_steps
        config = AutoConfig.from_pretrained(model_name)
        self.model = AutoModelForCausalLM.from_config(config=config)
        self.save_hyperparameters()

    def forward(self, batch):
        '''
        Run the language model on ``batch`` and return the model output.

        ``batch`` must provide ``'input_ids'`` and ``'attention_mask'``. The
        (overridden) input ids are also used as labels, so the returned
        output carries a ``loss`` attribute.
        '''
        # TODO: REMOVE THIS! (experiment-specific token override)
        # Work on a copy so the caller's batch tensor is left untouched.
        input_ids = batch['input_ids'].clone()
        input_ids[:, self.OVERRIDE_POSITION] = self.OVERRIDE_TOKEN_ID
        # Call the module itself (not .forward) so registered hooks run, and
        # skip the key/value cache: it is never consumed here.
        return self.model(
            input_ids=input_ids,
            attention_mask=batch['attention_mask'],
            labels=input_ids,
            use_cache=False,
        )

    def training_step(self, batch, batch_idx):
        '''Compute and log the training loss for one batch; returns the loss.'''
        loss = self(batch).loss
        # Log the detached tensor instead of loss.item() to avoid forcing a
        # device-to-host synchronisation on every step.
        self.log('train_loss', loss.detach(), on_step=True)
        self.log('global_step', self.trainer.global_step)
        return loss

    def validation_step(self, batch, batch_idx):
        '''Compute and log the validation loss for one batch; returns the loss.'''
        loss = self(batch).loss
        self.log('val_loss', loss.detach())
        return loss

    def test_step(self, batch, batch_idx):
        '''Compute and log the test loss for one batch; returns the loss.'''
        loss = self(batch).loss
        self.log('test_loss', loss.detach())
        return loss

    def configure_optimizers(self):
        '''
        Build Adam over all model parameters plus a cosine learning-rate
        schedule with warm-up.

        Returns:
            A ``([optimizer], [scheduler_config])`` pair in the format
            Lightning expects.
        '''
        optimizer = optim.Adam(
            params=self.model.parameters(),
            lr=self.lr,
        )
        num_training_steps = self.num_training_steps
        if num_training_steps is None:
            # Let Lightning derive the step count from the dataloader,
            # max_epochs and gradient accumulation settings.
            num_training_steps = self.trainer.estimated_stepping_batches
        scheduler = transformers.get_cosine_schedule_with_warmup(
            optimizer=optimizer,
            num_warmup_steps=self.num_warmup_steps,
            num_training_steps=num_training_steps,
        )
        # The schedule is expressed in optimizer steps, so Lightning must
        # advance it every step rather than once per epoch.
        return (
            [optimizer],
            [{"scheduler": scheduler, "interval": "step"}]
        )

In [6]:
# Identifiers shared by the W&B logger and the checkpoint file names below.
PROJ: str = 'LAISR_FUTURE_GPT2'    # W&B project that groups the runs
NAME: str = 'GPT2-MSMARCO-POTATO'  # name of this particular run

In [2]:
# Location of the train/val splits on disk. The model reads the
# 'input_ids' and 'attention_mask' columns from each batch.
DATA_DIR = '/workspace/corpus/msmarco/msmarco_GPT2_64tokens_full'
BATCH_SIZE = 512
# Ask for at most 96 loader workers, but never more than the machine has
# CPUs: oversubscribing makes DataLoader warn and can slow loading down.
NUM_WORKERS = min(96, os.cpu_count() or 1)

train = datasets.load_from_disk(f'{DATA_DIR}/train').with_format('torch')
val = datasets.load_from_disk(f'{DATA_DIR}/val').with_format('torch')

# NOTE(review): the training loader iterates in stored order (DataLoader's
# shuffle defaults to False) - confirm the data on disk is already shuffled.
train_loader = DataLoader(train, batch_size=BATCH_SIZE, num_workers=NUM_WORKERS)
val_loader = DataLoader(val, batch_size=BATCH_SIZE, num_workers=NUM_WORKERS)

In [5]:
# Number of training examples. With batch_size=512 the recorded value of
# 4659264 gives ceil(4659264 / 512) = 9101 batches per epoch, slightly fewer
# than the 9200 training steps assumed by the LR schedule in LitGPTModel.
len(train)

4659264

In [8]:
# Weights & Biases logger for this run. log_model=False keeps checkpoints
# on local disk only instead of uploading them.
wandb_logger = WandbLogger(project=PROJ, name=NAME, log_model=False)

In [9]:
# Records the scheduler's learning rate alongside the other logged metrics.
lr_monitor = LearningRateMonitor()

# Keep only the single checkpoint with the lowest validation loss.
checkpoint_callback = ModelCheckpoint(
    monitor='val_loss',
    mode='min',
    save_top_k=1,
    every_n_epochs=1,
    dirpath="/workspace/checkpoints",
    filename=NAME + "_{global_step}_{val_loss:.2f}",
)

# Abort when val_loss becomes worse than 15 (divergence) or fails to improve
# for 10 consecutive validation checks.
# NOTE(review): with val_check_interval=0.1 and max_epochs=1 there are only
# about 10 validation runs in total, so the patience criterion can
# presumably never fire - confirm this is intended.
early_stop_callback = EarlyStopping(
    monitor='val_loss',
    mode='min',
    min_delta=0.00,
    patience=10,
    divergence_threshold=15,
    verbose=False,
)

callbacks = [checkpoint_callback, early_stop_callback, lr_monitor]

# Single-epoch run that validates after every 10% of the training epoch.
trainer = L.Trainer(
    max_epochs=1,
    val_check_interval=0.1,
    logger=wandb_logger,
    callbacks=callbacks,
    fast_dev_run=False,
    enable_progress_bar=True,
)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [10]:
# Build the Lightning module with its default hyperparameters.
model = LitGPTModel()

# Have W&B track only the `wpe` submodule of the transformer;
# log='all' records both its gradients and its parameters.
position_embeddings = model.model.transformer.wpe
wandb_logger.watch(position_embeddings, log='all')

[34m[1mwandb[0m: Currently logged in as: [33mwilswu[0m. Use [1m`wandb login --relogin`[0m to force relogin


[34m[1mwandb[0m: logging graph, to disable use `wandb.watch(log_graph=False)`


In [11]:
# Start training; validation runs on val_loader at the interval configured
# on the trainer.
trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=val_loader)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name  | Type            | Params
------------------------------------------
0 | model | GPT2LMHeadModel | 124 M 
------------------------------------------
124 M     Trainable params
0         Non-trainable params
124 M     Total params
497.759   Total estimated model params size (MB)


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

/home/wwu/.local/lib/python3.10/site-packages/lightning/pytorch/trainer/call.py:54: Detected KeyboardInterrupt, attempting graceful shutdown...
