In [1]:
from importlib import reload
import torch
import pytorch_lightning as pl
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor

from transformers import (
    AutoConfig, AutoModelForCausalLM, AutoTokenizer
)

# Fix the global RNG seed up front so data splits and training are repeatable.
# seed_everything returns the seed it applied, which is what the cell displays.
SEED = 42
pl.seed_everything(SEED)

Global seed set to 42


42

## tokenizer and dataset

In [2]:
# Hub id of the pretrained checkpoint whose tokenizer is reused here.
name = "HooshvareLab/gpt2-fa"
tokenizer = AutoTokenizer.from_pretrained(name)

# Special-token map registered on the tokenizer.
# NOTE(review): BOS and EOS both map to '</s>' — confirm this is intended.
special_tokens = dict(
    bos_token='</s>',
    eos_token='</s>',
    pad_token='<pad>',
    unk_token='<unk>',
)
# The call's return value (number of tokens newly added to the vocabulary)
# is the last expression, so the notebook displays it.
tokenizer.add_special_tokens(special_tokens)

0

In [3]:
# Reload src.data so edits made to the module on disk are picked up without
# restarting the kernel, then re-bind the names this notebook uses.
import src.data
reload(src.data)
from src.data import PoemDataset, get_dataloaders

# Poem corpus consumed by the dataset.
# NOTE(review): `window=512` is presumably the maximum token window per
# sample — confirm against src.data.PoemDataset.
poems_path = 'data/filtered_poems.json'
dataset = PoemDataset(tokenizer, poems_path, window=512)

# Display the number of samples.
len(dataset)

5252824

In [4]:
# Build the train (`tl`) and validation (`vl`) loaders: 20% of the samples are
# held out for validation and batches contain 12 samples each.
tl, vl = get_dataloaders(
    dataset,
    val_frac=0.2,
    batch_size=12,
)

train dataset has 4202260 samples and val dataset has 1050564 samples


In [5]:
# Sanity check: pull a single training batch and print the shape of every
# field it contains.
b = next(iter(tl))
for k, field in b.items():
    print(k, field.shape)

input_ids torch.Size([12, 190])
attention_mask torch.Size([12, 190])


## model

In [6]:
# Reload src.model so on-disk edits take effect without a kernel restart.
import src.model
reload(src.model)
from src.model import PoetFormer

# Run identifier; the training section uses it for the log and checkpoint
# directory names.
name = 'GPT2-fa-ganjoor-conditional'
print('model name:', name)

# Build the model around the base checkpoint, then load its pretrained weights.
base_checkpoint = "HooshvareLab/gpt2-fa"
model = PoetFormer(pretrained_name=base_checkpoint)
model.load_pretrained()

model name: GPT2-fa-ganjoor-conditional


In [7]:
# Display the parameter count reported by the model (the cell's output value).
model.count_parameters()

118099200

## train

In [8]:
# Callbacks: record the learning rate on every training step, and keep only the
# single best checkpoint (lowest `val_loss`), evaluated once per epoch.
lr_logger = LearningRateMonitor(logging_interval='step')
checkpoint = ModelCheckpoint(
    dirpath=f'weights/{name}/',
    filename='{epoch}-{val_loss:.2f}',
    monitor='val_loss',
    mode='min',
    save_top_k=1,
    every_n_epochs=1,
)

# TensorBoard logs are written under logs/, grouped by the run name.
logger = TensorBoardLogger(save_dir='logs/', name=name)

# Single-GPU, single-epoch run. Gradients are accumulated over 8 batches, so
# with the loaders' batch size of 12 each optimizer step sees 96 samples.
trainer_options = dict(
    benchmark=True,
    gpus=1,
    accumulate_grad_batches=8,
    max_epochs=1,
    logger=logger,
    callbacks=[checkpoint, lr_logger],
)
trainer = pl.Trainer(**trainer_options)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs


In [None]:
# Run training: `tl` is the training loader and `vl` the validation loader.
trainer.fit(model, tl, vl)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name  | Type            | Params
------------------------------------------
0 | model | GPT2LMHeadModel | 118 M 
------------------------------------------
118 M     Trainable params
0         Non-trainable params
118 M     Total params
472.397   Total estimated model params size (MB)


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validation sanity check', layout=Layout…

In [None]:
# Write the trainer's current state to a fixed path so the generate section
# can reload it regardless of which checkpoint the callback kept.
last_ckpt_path = f'weights/{name}/last.ckpt'
trainer.save_checkpoint(last_ckpt_path)

## generate

In [None]:
from src.model import PoetFormer

# Run identifier; it selects the weights/<name>/ directory the checkpoint is
# read from.
name = 'GPT2-fa-ganjoor-conditional'
print('model name:', name)

# Restore the fine-tuned weights saved as weights/<name>/last.ckpt.
# Extra keyword arguments given to load_from_checkpoint are forwarded to
# PoetFormer.__init__, so the keyword must be the one the constructor is called
# with in the model section (`pretrained_name`, not `pretrained`).
# model = PoetFormer(pretrained_name="HooshvareLab/gpt2-fa")
model = PoetFormer.load_from_checkpoint(
    f'weights/{name}/last.ckpt',
    pretrained_name="HooshvareLab/gpt2-fa",
)
# Switch to inference mode so dropout is disabled while sampling. eval()
# returns the module itself; assigning it keeps the cell from echoing the repr.
# NOTE(review): redundant if PoetFormer.generate already does this — confirm.
model = model.eval()

In [None]:
# Generation settings: empty prompt, one returned sequence, at most 128 tokens.
# NOTE(review): `n_beam=1` presumably means no beam search, and the empty
# prompt presumably relies on generate() supplying a start token — confirm
# both in src.model.
generation_options = dict(
    prompt='',
    num_return_sequences=1,
    max_length=128,
    n_beam=1,
)
res = model.generate(**generation_options)

# Print every generated sequence on its own line.
for r in res:
    print(r)