In [21]:
import torch
from torch.utils.data import Dataset, DataLoader
from transformers import GPT2Tokenizer, GPT2LMHeadModel, GPT2Config, TextDataset, DataCollatorForLanguageModeling
from transformers import Trainer, TrainingArguments
from transformers import pipeline
import optuna
from transformers import EarlyStoppingCallback
from sklearn.model_selection import train_test_split

In [None]:
model = GPT2LMHeadModel.from_pretrained("gpt2")
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")

In [2]:
TRAIN_PATH = "train_dataset.txt"
TEST_PATH = "test_dataset.txt"

In [26]:
train_dataset = TextDataset(
    tokenizer=tokenizer,
    file_path=TRAIN_PATH,
    block_size=128
)

test_dataset = TextDataset(
    tokenizer=tokenizer,
    file_path=TEST_PATH,
    block_size=128
)



In [27]:
data_collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer,
    mlm=False # If set to False, the labels are the same as the inputs with the padding tokens ignored
)

In [29]:
training_args = TrainingArguments(
    output_dir="./output_nlp",
    overwrite_output_dir=True,
    num_train_epochs=3,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    evaluation_strategy="steps",
    logging_dir="./logs",
    logging_strategy="steps",
    logging_steps=500,
    save_steps=10_000,
    save_total_limit=2,
    load_best_model_at_end=True,
    metric_for_best_model="loss",
    greater_is_better=False,
    gradient_accumulation_steps=1,
    fp16=True,
    fp16_backend="auto",
    fp16_full_eval=False,
    learning_rate=5e-5,
    weight_decay=0.01,
    adam_beta1=0.9,
    adam_beta2=0.999,
    adam_epsilon=1e-8,
    max_grad_norm=1.0,
    lr_scheduler_type="linear",
    warmup_steps=0,
    label_smoothing_factor=0.0,
    report_to=["tensorboard"],
    seed=42
)

In [30]:
trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=data_collator,
    train_dataset=train_dataset,
    eval_dataset=test_dataset,
)


In [34]:
eval_results = trainer.evaluate()
print(eval_results)

  0%|          | 0/1234 [00:00<?, ?it/s]

{'eval_loss': 4.7682271003723145, 'eval_runtime': 38.4785, 'eval_samples_per_second': 256.455, 'eval_steps_per_second': 32.07}


In [31]:
def generate_story_text(prompt, model_path="./output"):
    generator = pipeline("text-generation", model=model_path, tokenizer="gpt2")
    story_text = generator(prompt, max_length=150, num_return_sequences=1, temperature=0.8, top_k=50, top_p=0.95)[0]["generated_text"]
    return story_text