In [None]:
# default_exp basic_model
%load_ext autoreload
%autoreload 2

In [None]:
#export
import torch
import pytorch_lightning as pl
from pytorch_lightning.loggers import TensorBoardLogger
from transformers import GPT2Config, GPT2LMHeadModel
from kirby.run_params import RunParams
from kirby.data_manager import DataManager

# Basic Model

> A basic PyTorch Lightning module

In [None]:
#export
class BasicModel(pl.LightningModule):
    """GPT-2 language model wrapped as a PyTorch Lightning module.

    The module owns the model, the train/validation datasets (populated by
    `prepare_data`), the dataloaders built from them, and the optimizer.

    Args:
        run_params: Run configuration object. This class reads its
            `pretrained`, `model`, `batch_size` and `num_workers` attributes
            and passes the whole object on to `DataManager`.
    """

    def __init__(self, run_params):
        super().__init__()
        self.run_params = run_params
        if self.run_params.pretrained:
            # Load the weights named by `run_params.model`.
            self.model = GPT2LMHeadModel.from_pretrained(self.run_params.model)
        else:
            # Fresh model built from the default GPT-2 configuration.
            self.model = GPT2LMHeadModel(GPT2Config())
        # Not used by any method of this class (the loss comes from the model
        # itself in `forward`); kept so code reading `.loss` keeps working.
        self.loss = torch.nn.CrossEntropyLoss(reduction='none')

    def prepare_data(self):
        """Build the datasets and store them as `train_ds` and `val_ds`."""
        # NOTE(review): Lightning documents `prepare_data` as a hook that is not
        # run on every process in distributed training -- confirm that assigning
        # state here is acceptable for the intended setups.
        data_manager = DataManager(self.run_params)
        self.train_ds, self.val_ds = data_manager.prepare_data()

    def forward(self, x):
        """Run the model on one batch and return its loss.

        Args:
            x: Mapping with the keys 'input_ids', 'attention_mask' and 'labels'.

        Returns:
            The first element of the model output; with `labels` supplied this
            is the language-modeling loss.
        """
        outputs = self.model(
            x['input_ids'],
            attention_mask=x['attention_mask'],
            labels=x['labels'],
        )
        return outputs[0]

    def training_step(self, batch, batch_idx):
        """Compute the loss for a training batch and log it as 'train_loss'."""
        loss = self.forward(batch)
        self.log('train_loss', loss)
        return loss

    def validation_step(self, batch, batch_idx):
        """Compute the loss for a validation batch and log it as 'val_loss'."""
        loss = self.forward(batch)
        self.log('val_loss', loss)
        return loss

    def validation_epoch_end(self, losses):
        """Log the mean of the per-batch validation losses as 'val_loss'.

        Args:
            losses: The values returned by `validation_step` for the epoch.
        """
        # NOTE(review): 'val_loss' is also logged in `validation_step`; confirm
        # that logging the same key from both hooks is intended.
        # Stacking adds the same leading dimension as unsqueeze(0) + cat.
        loss = torch.stack(list(losses), 0).mean()
        self.log('val_loss', loss)

    def _build_dataloader(self, dataset, shuffle):
        """Create a DataLoader over `dataset` with the shared run settings."""
        return torch.utils.data.DataLoader(
            dataset,
            batch_size=self.run_params.batch_size,
            drop_last=True,
            shuffle=shuffle,
            num_workers=self.run_params.num_workers,
            pin_memory=True,
        )

    def train_dataloader(self):
        """Return the shuffled DataLoader over `train_ds`."""
        return self._build_dataloader(self.train_ds, shuffle=True)

    def val_dataloader(self):
        """Return the unshuffled DataLoader over `val_ds`."""
        return self._build_dataloader(self.val_ds, shuffle=False)

    def configure_optimizers(self):
        """Return an Adam optimizer over all parameters, default settings."""
        return torch.optim.Adam(self.parameters())

# Testing

In [None]:
# Creation
basic_model = BasicModel(RunParams())
assert isinstance(basic_model, BasicModel)

In [None]:
# End-to-end smoke test: build a model from debug run parameters and fit it.
run_params = RunParams(debug=True)
basic_model = BasicModel(run_params)
trainer = pl.Trainer(
    default_root_dir='logs',
    # Use one GPU when CUDA is available, otherwise run on CPU.
    gpus=(1 if torch.cuda.is_available() else 0),
    max_epochs=run_params.max_epochs,
    # `run_params.debug` is forwarded as-is to Lightning's fast_dev_run switch.
    fast_dev_run=run_params.debug,
    # TensorBoard logs go under 'logs/', in a sub-directory named after the run.
    logger=TensorBoardLogger(save_dir='logs/', name=run_params.run_name),
)

# No dataloaders are passed: the trainer uses the ones defined on the module.
trainer.fit(basic_model)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
Running in fast_dev_run mode: will run a full train, val and test loop using 1 batch(es).
Using custom data configuration default-d64f335cc8a13d66
Reusing dataset text (/home/rob/.cache/huggingface/datasets/text/default-d64f335cc8a13d66/0.0.0/e16f44aa1b321ece1f87b07977cc5d70be93d69b20486d6dacd62e12cf25c9a5)
Loading cached processed dataset at /home/rob/.cache/huggingface/datasets/text/default-d64f335cc8a13d66/0.0.0/e16f44aa1b321ece1f87b07977cc5d70be93d69b20486d6dacd62e12cf25c9a5/cache-afe4d9c5e6fd685f.arrow
Loading cached processed dataset at /home/rob/.cache/huggingface/datasets/text/default-d64f335cc8a13d66/0.0.0/e16f44aa1b321ece1f87b07977cc5d70be93d69b20486d6dacd62e12cf25c9a5/cache-475e7db99f767f48.arrow
Loading cached processed dataset at /home/rob/.cache/huggingface/datasets/text/default-d64f335cc8a13d66/0.0.0/e16f44aa1b321ece1f87b07977cc5d70be93d69b20486d6dacd62e12cf25c9a5/cache-bc34352d2d344610.arrow
Loadin

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Training', layout=Layout(flex='2'), max…

> [0;32m<ipython-input-3-5c0895eb0785>[0m(19)[0;36mforward[0;34m()[0m
[0;32m     17 [0;31m    [0;32mdef[0m [0mforward[0m[0;34m([0m[0mself[0m[0;34m,[0m [0mx[0m[0;34m)[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     18 [0;31m        [0;32mimport[0m [0mpdb[0m[0;34m;[0m [0mpdb[0m[0;34m.[0m[0mset_trace[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m---> 19 [0;31m        [0mloss[0m [0;34m=[0m [0mself[0m[0;34m.[0m[0mmodel[0m[0;34m([0m[0mx[0m[0;34m[[0m[0;34m'input_ids'[0m[0;34m][0m[0;34m,[0m [0mattention_mask[0m[0;34m=[0m[0mx[0m[0;34m[[0m[0;34m'attention_mask'[0m[0;34m][0m[0;34m,[0m [0mlabels[0m[0;34m=[0m[0mx[0m[0;34m[[0m[0;34m'labels'[0m[0;34m][0m[0;34m)[0m[0;34m[[0m[0;36m0[0m[0;34m][0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     20 [0;31m        [0;32mreturn[0m [0mloss[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     21 [0;31m[0;34m[0m[0m
[0m


ipdb>  c


> [0;32m/home/rob/miniconda3/envs/kirby/lib/python3.8/site-packages/transformers/models/gpt2/modeling_gpt2.py[0m(969)[0;36mforward[0;34m()[0m
[0;32m    967 [0;31m            [0;31m# Shift so that tokens < n predict n[0m[0;34m[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    968 [0;31m            [0;32mimport[0m [0mpdb[0m[0;34m;[0m [0mpdb[0m[0;34m.[0m[0mset_trace[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m--> 969 [0;31m            [0mshift_logits[0m [0;34m=[0m [0mlm_logits[0m[0;34m[[0m[0;34m...[0m[0;34m,[0m [0;34m:[0m[0;34m-[0m[0;36m1[0m[0;34m,[0m [0;34m:[0m[0;34m][0m[0;34m.[0m[0mcontiguous[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    970 [0;31m            [0mshift_labels[0m [0;34m=[0m [0mlabels[0m[0;34m[[0m[0;34m...[0m[0;34m,[0m [0;36m1[0m[0;34m:[0m[0;34m][0m[0;34m.[0m[0mcontiguous[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    971 [0;31m            

ipdb>  labels


tensor([[ 1622,   837,   484,  ...,   262, 13688,   284],
        [  484,  9851,  8471,  ...,   837,  8266,  5374],
        [  777, 17794,   764,  ...,   670,   837,   290],
        ...,
        [14939,   284,   787,  ...,   764,   679,   750],
        [  319, 10566,   326,  ...,   351,   262,  4931],
        [  717,  9551,   375,  ...,   262,  4401, 15781]], device='cuda:0')


ipdb>  lm_logits.shape


torch.Size([8, 128, 50257])


ipdb>  labels.shape


torch.Size([8, 128])


ipdb>  n


> [0;32m/home/rob/miniconda3/envs/kirby/lib/python3.8/site-packages/transformers/models/gpt2/modeling_gpt2.py[0m(970)[0;36mforward[0;34m()[0m
[0;32m    968 [0;31m            [0;32mimport[0m [0mpdb[0m[0;34m;[0m [0mpdb[0m[0;34m.[0m[0mset_trace[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    969 [0;31m            [0mshift_logits[0m [0;34m=[0m [0mlm_logits[0m[0;34m[[0m[0;34m...[0m[0;34m,[0m [0;34m:[0m[0;34m-[0m[0;36m1[0m[0;34m,[0m [0;34m:[0m[0;34m][0m[0;34m.[0m[0mcontiguous[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m--> 970 [0;31m            [0mshift_labels[0m [0;34m=[0m [0mlabels[0m[0;34m[[0m[0;34m...[0m[0;34m,[0m [0;36m1[0m[0;34m:[0m[0;34m][0m[0;34m.[0m[0mcontiguous[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    971 [0;31m            [0;31m# Flatten the tokens[0m[0;34m[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    972 [0;31m            [0mloss_fct[0m 

ipdb>  n


> [0;32m/home/rob/miniconda3/envs/kirby/lib/python3.8/site-packages/transformers/models/gpt2/modeling_gpt2.py[0m(972)[0;36mforward[0;34m()[0m
[0;32m    970 [0;31m            [0mshift_labels[0m [0;34m=[0m [0mlabels[0m[0;34m[[0m[0;34m...[0m[0;34m,[0m [0;36m1[0m[0;34m:[0m[0;34m][0m[0;34m.[0m[0mcontiguous[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    971 [0;31m            [0;31m# Flatten the tokens[0m[0;34m[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m--> 972 [0;31m            [0mloss_fct[0m [0;34m=[0m [0mCrossEntropyLoss[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    973 [0;31m            [0mloss[0m [0;34m=[0m [0mloss_fct[0m[0;34m([0m[0mshift_logits[0m[0;34m.[0m[0mview[0m[0;34m([0m[0;34m-[0m[0;36m1[0m[0;34m,[0m [0mshift_logits[0m[0;34m.[0m[0msize[0m[0;34m([0m[0;34m-[0m[0;36m1[0m[0;34m)[0m[0;34m)[0m[0;34m,[0m [0mshift_labels[0m[0;34m.[0m[0mview[0m[0;34m([0m[0;34m

ipdb>  shift_labels.shape


torch.Size([8, 127])


ipdb>  shift_logits.shape


torch.Size([8, 127, 50257])


ipdb>  q


BdbQuit: 

In [None]:
# Display the training dataset; `train_ds` is only set once BasicModel.prepare_data() has run.
basic_model.train_ds