In [None]:
# default_exp basic_model

In [None]:
#export
import torch
import pytorch_lightning as pl
from pytorch_lightning.loggers import TensorBoardLogger
from transformers import GPT2Config, GPT2LMHeadModel
from kirby.run_params import RunParams
from kirby.data_manager import DataManager

# Basic Model

> A basic PyTorch Lightning module

In [None]:
#export
class BasicModel(pl.LightningModule):
    """GPT-2 language model wrapped as a PyTorch Lightning module.

    Builds a `GPT2LMHeadModel` from a default `GPT2Config`, obtains its
    train/validation datasets from `DataManager`, and optimizes with SGD.

    Args:
        run_params: run configuration. This class reads `batch_size`,
            `num_workers`, `lr` and `momentum` from it and passes the whole
            object on to `DataManager`.
    """

    def __init__(self, run_params):
        super().__init__()
        self.run_params = run_params
        config = GPT2Config()
        self.model = GPT2LMHeadModel(config)
        # Not used by any method of this class (the loss is taken from the
        # GPT-2 head's output); kept so the attribute remains available.
        self.loss = torch.nn.CrossEntropyLoss(reduction='none')

    def _load_datasets(self):
        """Attach the train/validation datasets from `DataManager` to `self`."""
        data_manager = DataManager(self.run_params)
        self.train_ds, self.val_ds = data_manager.prepare_data()

    def prepare_data(self):
        """Lightning hook: prepare the datasets and attach them to the module."""
        self._load_datasets()

    def setup(self, stage=None):
        """Lightning hook: make sure the datasets exist on this process.

        Lightning runs `prepare_data` on a single process in distributed
        training, so state assigned there is missing on the other ranks.
        When `prepare_data` already ran in this process this is a no-op.

        Args:
            stage: stage name supplied by the trainer; unused.
        """
        # NOTE(review): assumes `DataManager.prepare_data()` is safe to call
        # from several processes (e.g. it reads from a cache) -- confirm.
        if not (hasattr(self, 'train_ds') and hasattr(self, 'val_ds')):
            self._load_datasets()

    def forward(self, x):
        """Return the language-modeling loss for one batch.

        Args:
            x: mapping with `'input_ids'` and `'attention_mask'` entries.

        Returns:
            The first element of the GPT-2 head's output, i.e. the loss
            computed against the batch's own `input_ids` as labels.
        """
        input_ids = x['input_ids']
        attention_mask = x['attention_mask']
        labels = input_ids
        if attention_mask is not None:
            # Positions with attention_mask == 0 are excluded from the loss:
            # -100 is the ignore index of the LM head's cross-entropy.
            # `masked_fill` returns a new tensor, so `input_ids` is untouched.
            labels = input_ids.masked_fill(attention_mask == 0, -100)
        return self.model(input_ids, attention_mask=attention_mask, labels=labels)[0]

    def training_step(self, batch, batch_idx):
        """Compute, log (`'train_loss'`) and return the loss for one training batch."""
        loss = self.forward(batch)
        self.log('train_loss', loss)
        return loss

    def validation_step(self, batch, batch_idx):
        """Compute, log (`'val_loss'`) and return the loss for one validation batch."""
        loss = self.forward(batch)
        self.log('val_loss', loss)
        return loss

    def validation_epoch_end(self, losses):
        """Log the mean of the per-batch validation losses as `'val_loss'`.

        Args:
            losses: the values returned by `validation_step` for each batch.
        """
        # NOTE(review): 'val_loss' is also logged in `validation_step`;
        # confirm whether both log calls are needed.
        if not losses:
            # Nothing was validated (e.g. every batch was dropped); stacking
            # an empty list would raise.
            return
        self.log('val_loss', torch.stack(list(losses)).mean())

    def _make_dataloader(self, dataset, shuffle):
        """Build a `DataLoader` over `dataset` with the run's batch settings."""
        return torch.utils.data.DataLoader(
            dataset,
            batch_size=self.run_params.batch_size,
            drop_last=True,
            shuffle=shuffle,
            num_workers=self.run_params.num_workers,
            pin_memory=True,
        )

    def train_dataloader(self):
        """Return the shuffled training `DataLoader`."""
        return self._make_dataloader(self.train_ds, shuffle=True)

    def val_dataloader(self):
        """Return the unshuffled validation `DataLoader`."""
        # NOTE(review): `drop_last=True` also applies here, so a trailing
        # partial validation batch is never evaluated -- confirm intended.
        return self._make_dataloader(self.val_ds, shuffle=False)

    def configure_optimizers(self):
        """Return an SGD optimizer over all parameters using the run's lr and momentum."""
        return torch.optim.SGD(
            self.parameters(),
            lr=self.run_params.lr,
            momentum=self.run_params.momentum,
        )

# Testing

In [None]:
# Creation test. Deliberately NOT flagged for export: an exported cell is
# copied into the library module, which would then build a model on import.
basic_model = BasicModel(RunParams())
assert isinstance(basic_model, BasicModel)

In [None]:
# Training smoke test. Deliberately NOT flagged for export: an exported cell
# is copied into the library module, which would then start a training run
# on import.
run_params = RunParams(debug=True)
basic_model = BasicModel(run_params)
trainer = pl.Trainer(
    default_root_dir='logs',
    # Use one GPU when CUDA is available, otherwise run on CPU.
    gpus=(1 if torch.cuda.is_available() else 0),
    max_epochs=run_params.max_epochs,
    # The run's debug flag is passed straight through as `fast_dev_run`.
    fast_dev_run=run_params.debug,
    logger=TensorBoardLogger(save_dir='logs/', name=run_params.run_name),
)

trainer.fit(basic_model)

GPU available: False, used: False
TPU available: None, using: 0 TPU cores
Running in fast_dev_run mode: will run a full train, val and test loop using 1 batch(es).
Using custom data configuration default
Reusing dataset text (/root/.cache/huggingface/datasets/text/default-04bff418a63932f2/0.0.0/daf90a707a433ac193b369c8cc1772139bb6cca21a9c7fe83bdd16aad9b9b6ab)
Loading cached processed dataset at /root/.cache/huggingface/datasets/text/default-04bff418a63932f2/0.0.0/daf90a707a433ac193b369c8cc1772139bb6cca21a9c7fe83bdd16aad9b9b6ab/cache-ca424aee03967504.arrow
Using custom data configuration default
Reusing dataset text (/root/.cache/huggingface/datasets/text/default-04bff418a63932f2/0.0.0/daf90a707a433ac193b369c8cc1772139bb6cca21a9c7fe83bdd16aad9b9b6ab)
Loading cached processed dataset at /root/.cache/huggingface/datasets/text/default-04bff418a63932f2/0.0.0/daf90a707a433ac193b369c8cc1772139bb6cca21a9c7fe83bdd16aad9b9b6ab/cache-424bcf23a14923ca.arrow

  | Name  | Type             | Params
-

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Training', layout=Layout(flex='2'), max…

  return torch.tensor(x, **format_kwargs)


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

  return torch.tensor(x, **format_kwargs)





1

In [None]:
# Display the training dataset that `prepare_data` attached to the model.
basic_model.train_ds

Dataset({
    features: ['attention_mask', 'input_ids', 'text'],
    num_rows: 8
})