In [1]:
import sys
sys.path.append("../")
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader

import numpy as np

import wandb
from pytorch_lightning import Trainer
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.callbacks import EarlyStopping, LearningRateMonitor, ModelCheckpoint

from transformers import AutoTokenizer

from models.model import GPT2PreTrained

from datasets import load_dataset

from utils.preprocessor import XSumPreprocessor


## Data

In [2]:
# Fetch the XSum corpus through the Hugging Face `datasets` loader.
dataset = load_dataset("xsum")

# Show the first training record as a quick sanity check of the raw data.
first_train_example = dataset["train"][0]
print(first_train_example)

Found cached dataset xsum (/home/reza/.cache/huggingface/datasets/xsum/default/1.2.0/082863bf4754ee058a5b6f6525d0cb2b18eadb62c7b370b095d1364050a52b71)


  0%|          | 0/3 [00:00<?, ?it/s]



In [3]:
# Pre-processing configuration: checkpoint name, prompt prefix and the length
# limits handed to the project's XSum preprocessor.
model_name = "gpt2"
prefix = "summarize"
max_input_length = 512
max_target_length = 512

# Load the tokenizer that matches the checkpoint and use its end-of-sequence
# token as the padding token.
tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.pad_token = tokenizer.eos_token

preprocessor = XSumPreprocessor(
    tokenizer=tokenizer,
    max_input_length=max_input_length,
    max_target_length=max_target_length,
    prefix=prefix,
)

# Apply the preprocessor to every split, one batch of examples per call.
# NOTE: this rebinds `dataset`, so re-running the cell maps the already
# processed dataset again.
processor = preprocessor.preprocess
dataset = dataset.map(processor, batched=True)

Loading cached processed dataset at /home/reza/.cache/huggingface/datasets/xsum/default/1.2.0/082863bf4754ee058a5b6f6525d0cb2b18eadb62c7b370b095d1364050a52b71/cache-aee20efe26681b84.arrow
Loading cached processed dataset at /home/reza/.cache/huggingface/datasets/xsum/default/1.2.0/082863bf4754ee058a5b6f6525d0cb2b18eadb62c7b370b095d1364050a52b71/cache-2cbe630226a699e6.arrow
Loading cached processed dataset at /home/reza/.cache/huggingface/datasets/xsum/default/1.2.0/082863bf4754ee058a5b6f6525d0cb2b18eadb62c7b370b095d1364050a52b71/cache-c3c8932279faaf68.arrow


In [4]:

# Have the dataset return torch tensors and expose only these columns.
model_columns = ['input_ids', 'attention_mask', 'labels']
dataset.set_format(type="torch", columns=model_columns)

# One loader per split, all sharing a batch size; only training is shuffled.
batch_size = 4
train_dataloader = DataLoader(dataset['train'], batch_size=batch_size, shuffle=True)
valid_dataloader = DataLoader(dataset['validation'], batch_size=batch_size)
test_dataloader = DataLoader(dataset['test'], batch_size=batch_size)

In [5]:
# Training hyper-parameters. Defined first so the same values reach both the
# Trainer and the model.
learning_rate = 5e-5
max_epochs = 10

# Experiment tracking.
# Close any run left open by a previous execution of this cell (a no-op when
# nothing is active), then start the run that training will log to.
wandb.finish()
run = wandb.init(project="GPT2Summarizer", entity="reza3qorbani")

# Create a WandbLogger bound to the run opened above. Previously the run was
# finished right after creation and a bare WandbLogger() opened a second run
# that was not tied to the configured project/entity.
wandb_logger = WandbLogger(experiment=run)

# for early stopping, see https://pytorch-lightning.readthedocs.io/en/1.0.0/early_stopping.html?highlight=early%20stopping
# NOTE(review): with strict=False a missing 'validation_loss' metric does not
# raise; confirm the model logs a metric under exactly this name.
early_stop_callback = EarlyStopping(
    monitor='validation_loss',
    patience=3,
    strict=False,
    verbose=False,
    mode='min'
)

# Record the learning rate on every optimizer step.
lr_monitor = LearningRateMonitor(logging_interval='step')

# Checkpoint selection is driven by the lowest 'validation_loss'.
# NOTE(review): dirpath is an absolute path under the filesystem root; confirm
# that this location is intended and writable.
checkpoint_callback = ModelCheckpoint(dirpath='/saved/models/GPT2Summerizer', monitor='validation_loss', mode='min')

# max_epochs is passed explicitly: without it the Trainer uses its own default
# epoch limit rather than the value given to the model below.
trainer = Trainer(accelerator="cpu",
                  logger=wandb_logger,
                  max_epochs=max_epochs,
                  callbacks=[early_stop_callback, lr_monitor, checkpoint_callback])

# The model receives its dataloaders through setters, so fit() is called with
# the model only.
model = GPT2PreTrained(lr=learning_rate, max_epochs=max_epochs)
model.set_train_dataloader(train_dataloader)
model.set_valid_dataloader(valid_dataloader)
model.set_test_dataloader(test_dataloader)

trainer.fit(model)

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mreza3qorbani[0m. Use [1m`wandb login --relogin`[0m to force relogin


VBox(children=(Label(value='0.004 MB of 0.004 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=0.993981…

VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.016761217483326617, max=1.0…

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
  rank_zero_warn(

  | Name  | Type            | Params
------------------------------------------
0 | model | GPT2LMHeadModel | 124 M 
------------------------------------------
124 M     Trainable params
0         Non-trainable params
124 M     Total params
497.759   Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]

  rank_zero_warn(
  rank_zero_warn(


Training: 0it [00:00, ?it/s]

: 

: 