In [1]:
import os
os.environ['HF_HOME'] = '/workspace/cache/huggingface/'
os.chdir('/workspace/FutureGPT2/src/')

import lightning as L
import torch
from torch import nn, optim
from lightning.pytorch.callbacks import ModelCheckpoint

from models.utils import *
import datasets
from torch.utils.data import DataLoader

In [6]:
class LitBigram(L.LightningModule):
    def __init__(self, model_name, lr=1e-4):
        super().__init__()
        self.model_name = model_name
        model = get_model(model_name, precision='32')
        self.embed = model.transformer.wte #model.model.embed_tokens
        self.unembed = model.lm_head
        for param in self.embed.parameters():
            param.requires_grad=False
        for param in self.unembed.parameters():
            param.requires_grad=False
        self.save_hyperparameters()
        self.linear = nn.Linear(
            self.embed.embedding_dim,
            self.unembed.in_features
        )
        self.lr=lr

    def forward(self, batch):
        return self.unembed(self.linear(self.embed(batch['input_ids'])))

    def _compute_loss(self, batch):
        out = self.forward(batch)
        return nn.CrossEntropyLoss()(
            out.transpose(1, 2)[:,:,:-1],
            batch['input_ids'][:,1:],
        )
    def training_step(self, batch, batch_idx):
        return self._compute_loss(batch)

    def validation_step(self, batch, batch_idx):
        loss = self._compute_loss(batch)
        print('val loss', loss)
        self.log('val_loss', loss)
        return loss

    def test_step(self, batch, batch_idx):
        return self._compute_loss(batch)

    def configure_optimizers(self):
        return optim.Adam(params=self.linear.parameters())

In [7]:
checkpoint_callback = ModelCheckpoint(
    dirpath='/workspace/checkpoints',
    filename='GPT2_BIGRAM_{val_loss:.2f}',
    every_n_epochs=1,
    save_top_k=1,
    monitor='val_loss',
    mode='min',
)
trainer = L.Trainer(
    val_check_interval=.2,
    callbacks=[checkpoint_callback],
    enable_progress_bar=True,
    precision='32',
)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [8]:
dataset = datasets.load_from_disk(f'/workspace/corpus/msmarco/msmarco_GPT2_64tokens_1m').with_format('torch', device=torch.device('cuda'))
train = DataLoader(dataset['train'], batch_size=128)
val = DataLoader(dataset['val'], batch_size=128)

In [9]:
model = LitBigram('gpt2')

In [10]:
trainer.fit(
    model=model,
    train_dataloaders=train,
    val_dataloaders=val,
)

/home/wwu/.local/lib/python3.10/site-packages/lightning/pytorch/loops/utilities.py:73: `max_epochs` was not set. Setting it to 1000 epochs. To train without an epoch limit, set `max_epochs=-1`.
You are using a CUDA device ('NVIDIA RTX A6000') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
Missing logger folder: /workspace/FutureGPT2/src/lightning_logs
/home/wwu/.local/lib/python3.10/site-packages/lightning/pytorch/callbacks/model_checkpoint.py:639: Checkpoint directory /workspace/checkpoints exists and is not empty.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name    | Type      | Params
--------------------------------------
0 | embed   | Embedding | 38.6 M
1 | unembed | Linear    | 38.6 M
2 | linear  | Linear    | 590 K 
---

Sanity Checking: |          | 0/? [00:00<?, ?it/s]

/home/wwu/.local/lib/python3.10/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:441: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=63` in the `DataLoader` to improve performance.


val loss tensor(10.9571, device='cuda:0')
val loss tensor(10.9588, device='cuda:0')


/home/wwu/.local/lib/python3.10/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:441: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=63` in the `DataLoader` to improve performance.


Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

val loss tensor(5.7701, device='cuda:0')
val loss tensor(5.7823, device='cuda:0')
val loss tensor(5.8474, device='cuda:0')
val loss tensor(5.8265, device='cuda:0')
val loss tensor(5.7790, device='cuda:0')
val loss tensor(5.7716, device='cuda:0')
val loss tensor(5.8974, device='cuda:0')
val loss tensor(5.7745, device='cuda:0')
val loss tensor(5.8029, device='cuda:0')
val loss tensor(5.9327, device='cuda:0')
val loss tensor(5.8083, device='cuda:0')
val loss tensor(5.8447, device='cuda:0')
val loss tensor(5.7620, device='cuda:0')
val loss tensor(5.8342, device='cuda:0')
val loss tensor(5.8364, device='cuda:0')
val loss tensor(5.8233, device='cuda:0')
val loss tensor(5.8226, device='cuda:0')
val loss tensor(5.7739, device='cuda:0')
val loss tensor(5.8143, device='cuda:0')
val loss tensor(5.8093, device='cuda:0')
val loss tensor(5.7526, device='cuda:0')
val loss tensor(5.8789, device='cuda:0')
val loss tensor(5.8212, device='cuda:0')
val loss tensor(5.8671, device='cuda:0')
val loss tensor(

Validation: |          | 0/? [00:00<?, ?it/s]

val loss tensor(5.7197, device='cuda:0')
val loss tensor(5.7392, device='cuda:0')
val loss tensor(5.8026, device='cuda:0')
val loss tensor(5.7850, device='cuda:0')
val loss tensor(5.7253, device='cuda:0')
val loss tensor(5.7150, device='cuda:0')
val loss tensor(5.8508, device='cuda:0')
val loss tensor(5.7329, device='cuda:0')
val loss tensor(5.7544, device='cuda:0')
val loss tensor(5.8800, device='cuda:0')
val loss tensor(5.7651, device='cuda:0')
val loss tensor(5.7961, device='cuda:0')
val loss tensor(5.7090, device='cuda:0')
val loss tensor(5.7806, device='cuda:0')
val loss tensor(5.7867, device='cuda:0')
val loss tensor(5.7828, device='cuda:0')
val loss tensor(5.7797, device='cuda:0')
val loss tensor(5.7232, device='cuda:0')
val loss tensor(5.7677, device='cuda:0')
val loss tensor(5.7513, device='cuda:0')
val loss tensor(5.7038, device='cuda:0')
val loss tensor(5.8359, device='cuda:0')
val loss tensor(5.7611, device='cuda:0')
val loss tensor(5.8145, device='cuda:0')
val loss tensor(

Validation: |          | 0/? [00:00<?, ?it/s]

val loss tensor(5.7024, device='cuda:0')
val loss tensor(5.7304, device='cuda:0')
val loss tensor(5.7957, device='cuda:0')
val loss tensor(5.7654, device='cuda:0')
val loss tensor(5.7077, device='cuda:0')
val loss tensor(5.7001, device='cuda:0')
val loss tensor(5.8311, device='cuda:0')
val loss tensor(5.7146, device='cuda:0')
val loss tensor(5.7344, device='cuda:0')
val loss tensor(5.8719, device='cuda:0')
val loss tensor(5.7509, device='cuda:0')
val loss tensor(5.7848, device='cuda:0')
val loss tensor(5.6929, device='cuda:0')
val loss tensor(5.7722, device='cuda:0')
val loss tensor(5.7674, device='cuda:0')
val loss tensor(5.7718, device='cuda:0')
val loss tensor(5.7654, device='cuda:0')
val loss tensor(5.6992, device='cuda:0')
val loss tensor(5.7456, device='cuda:0')
val loss tensor(5.7328, device='cuda:0')
val loss tensor(5.6812, device='cuda:0')
val loss tensor(5.8188, device='cuda:0')
val loss tensor(5.7442, device='cuda:0')
val loss tensor(5.7968, device='cuda:0')
val loss tensor(

Validation: |          | 0/? [00:00<?, ?it/s]

val loss tensor(5.6849, device='cuda:0')
val loss tensor(5.7183, device='cuda:0')
val loss tensor(5.7804, device='cuda:0')
val loss tensor(5.7617, device='cuda:0')
val loss tensor(5.6980, device='cuda:0')
val loss tensor(5.6940, device='cuda:0')
val loss tensor(5.8232, device='cuda:0')
val loss tensor(5.7015, device='cuda:0')
val loss tensor(5.7292, device='cuda:0')
val loss tensor(5.8589, device='cuda:0')
val loss tensor(5.7397, device='cuda:0')
val loss tensor(5.7678, device='cuda:0')
val loss tensor(5.6777, device='cuda:0')
val loss tensor(5.7635, device='cuda:0')
val loss tensor(5.7524, device='cuda:0')
val loss tensor(5.7541, device='cuda:0')
val loss tensor(5.7562, device='cuda:0')
val loss tensor(5.6919, device='cuda:0')
val loss tensor(5.7451, device='cuda:0')
val loss tensor(5.7196, device='cuda:0')
val loss tensor(5.6711, device='cuda:0')
val loss tensor(5.8135, device='cuda:0')
val loss tensor(5.7369, device='cuda:0')
val loss tensor(5.7955, device='cuda:0')
val loss tensor(

Validation: |          | 0/? [00:00<?, ?it/s]

val loss tensor(5.6876, device='cuda:0')
val loss tensor(5.7162, device='cuda:0')
val loss tensor(5.7833, device='cuda:0')
val loss tensor(5.7543, device='cuda:0')
val loss tensor(5.6928, device='cuda:0')
val loss tensor(5.6882, device='cuda:0')
val loss tensor(5.8074, device='cuda:0')
val loss tensor(5.7055, device='cuda:0')
val loss tensor(5.7195, device='cuda:0')
val loss tensor(5.8589, device='cuda:0')
val loss tensor(5.7359, device='cuda:0')
val loss tensor(5.7686, device='cuda:0')
val loss tensor(5.6822, device='cuda:0')
val loss tensor(5.7585, device='cuda:0')
val loss tensor(5.7517, device='cuda:0')
val loss tensor(5.7564, device='cuda:0')
val loss tensor(5.7535, device='cuda:0')
val loss tensor(5.6923, device='cuda:0')
val loss tensor(5.7397, device='cuda:0')
val loss tensor(5.7183, device='cuda:0')
val loss tensor(5.6676, device='cuda:0')
val loss tensor(5.8068, device='cuda:0')
val loss tensor(5.7356, device='cuda:0')
val loss tensor(5.7885, device='cuda:0')
val loss tensor(

/home/wwu/.local/lib/python3.10/site-packages/lightning/pytorch/trainer/call.py:54: Detected KeyboardInterrupt, attempting graceful shutdown...
