## Transformer for Text Generation from Wikipedia Titles

In [None]:
import pandas as pd
import re

In [None]:
# Input CSV of articles. The tokenization step later reads its 'Text' column,
# so check for that column here rather than failing deep inside Dataset.map().
DATA_PATH = 'wikipedia_data10K.csv'
df = pd.read_csv(DATA_PATH)
assert 'Text' in df.columns, f"expected a 'Text' column in {DATA_PATH}, got {list(df.columns)}"

In [None]:
from transformers import Trainer, TrainingArguments, DataCollatorForLanguageModeling, GPT2Tokenizer, GPT2LMHeadModel, pipeline, AutoTokenizer
import datasets
from datasets import load_dataset, list_datasets
from datasets import Dataset
from sklearn.model_selection import train_test_split

  from .autonotebook import tqdm as notebook_tqdm


In [None]:
# Hold out 30% of the articles for validation, then subsample both splits so a
# single fine-tuning epoch stays short (1000 train / 200 validation rows).
# Set a size to None to keep the whole split.
N_TRAIN_SAMPLES = 1000
N_VAL_SAMPLES = 200
SPLIT_SEED = 42


def _subsample(frame, n, seed=SPLIT_SEED):
    """Return `n` randomly chosen rows of `frame`, or `frame` itself when `n` is None.

    `n` is capped at len(frame), because DataFrame.sample(n=...) without
    replacement raises when asked for more rows than exist.
    """
    if n is None:
        return frame
    return frame.sample(n=min(n, len(frame)), random_state=seed)


train_df, val_df = train_test_split(df, test_size=0.3, random_state=SPLIT_SEED)
train_df = _subsample(train_df, N_TRAIN_SAMPLES)
val_df = _subsample(val_df, N_VAL_SAMPLES)

In [None]:
# Load the pretrained GPT-2 tokenizer and language-model head.
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
# GPT-2 has no padding token; reuse EOS so batches can be padded.
# NOTE(review): with mlm=False the collator masks pad-token positions out of the
# labels, and since pad == EOS, genuine EOS tokens are excluded from the loss too.
tokenizer.pad_token = tokenizer.eos_token
# No explicit .cuda(): Trainer moves the model to the training device itself,
# so this cell also runs on a CPU-only kernel.
model = GPT2LMHeadModel.from_pretrained('gpt2')
# Causal-LM collator (mlm=False): builds labels from input_ids and pads each batch.
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

# Tokenization function applied to the datasets with Dataset.map(batched=True).
def encode(batch):
    """Tokenize a batch of articles for causal-LM fine-tuning.

    Args:
        batch: batched mapping from `datasets`; batch['Text'] is a list of strings.

    Returns:
        The tokenizer output (input_ids / attention_mask) for each text, after
        stripping leading/trailing CR and LF characters. Sequences are truncated
        to the tokenizer's maximum length and are NOT padded.
    """
    texts = [text.strip('\n\r') for text in batch['Text']]
    # No padding here. The datasets are mapped with batch_size=len(dataset), so
    # padding=True padded every example to the longest article in the whole split,
    # and every training step ran on full-length sequences. The data collator
    # pads each mini-batch to its own longest sequence instead.
    return tokenizer(texts, truncation=True, padding=False)

# Build the train and validation datasets from the sampled DataFrames.
def tokenize_dataset(ds):
    """Tokenize `ds` with `encode` and expose only the model-input columns.

    The whole split goes to `encode` as a single batch (batch_size=len(ds)).
    The result is formatted as torch tensors restricted to 'input_ids' and
    'attention_mask', so the raw 'Text' column is not returned to the Trainer.
    """
    tokenized = ds.map(encode, batched=True, batch_size=len(ds))
    tokenized.set_format('torch', columns=['input_ids', 'attention_mask'])
    return tokenized


dataset = Dataset.from_pandas(train_df)
processed_dataset = tokenize_dataset(dataset)

val_dataset = Dataset.from_pandas(val_df)
processed_val_dataset = tokenize_dataset(val_dataset)

# Training configuration for fine-tuning GPT-2.
# A relative output directory replaces the Colab-only absolute '/content/' path,
# so checkpoints are written next to the notebook on any machine.
OUTPUT_DIR = './results'

training_args = TrainingArguments(
    output_dir=OUTPUT_DIR,
    overwrite_output_dir=True,
    num_train_epochs=1,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=32,
    # 8 examples x 2 accumulation steps = 16 examples per optimizer step (per device).
    gradient_accumulation_steps=2,
    # The 1000-example run takes only ~62 optimizer steps, so logging_steps=100
    # never logged an intermediate training loss.
    logging_steps=10,
    logging_dir='./logs',
    weight_decay=0.01,
)

# Wire model, data and configuration into a Trainer, fine-tune, then save.
trainer = Trainer(
    args=training_args,
    model=model,
    data_collator=data_collator,
    train_dataset=processed_dataset,
    # NOTE(review): no evaluation strategy is set in training_args, so this split
    # is not scored during train(); call trainer.evaluate() to use it.
    eval_dataset=processed_val_dataset,
    # Passing the tokenizer lets save_model() store it with the weights, which
    # is what allows pipeline(model='trc') to load the checkpoint later.
    tokenizer=tokenizer,
)

trainer.train()

trainer.save_model('./trc')

  if _pandas_api.is_sparse(col):
Map: 100%|██████████| 1000/1000 [00:06<00:00, 162.10 examples/s]
Map: 100%|██████████| 200/200 [00:01<00:00, 186.77 examples/s]
100%|██████████| 62/62 [52:51<00:00, 51.15s/it] 


{'train_runtime': 3171.1588, 'train_samples_per_second': 0.315, 'train_steps_per_second': 0.02, 'train_loss': 3.3326201900359123, 'epoch': 0.99}


In [None]:
import torch


def cuda_memory_report(device=None):
    """Return PyTorch's CUDA memory summary, or a notice when CUDA is unavailable.

    torch.cuda.memory_summary returns a multi-line string. Left as a cell's last
    expression, it renders as one escaped repr, so print the result instead. Without
    CUDA, querying the default device would fail, so a notice is returned instead.

    Args:
        device: forwarded to torch.cuda.memory_summary (None = current device).
    """
    if not torch.cuda.is_available():
        return 'CUDA is not available on this kernel.'
    return torch.cuda.memory_summary(device=device, abbreviated=False)


print(cuda_memory_report())



In [None]:
from transformers import pipeline
import torch

# Use the first GPU when present, otherwise fall back to CPU (device=-1),
# instead of hard-coding device=0.
GEN_DEVICE = 0 if torch.cuda.is_available() else -1
# NOTE(review): GPT-2's text-generation defaults appear to enable sampling, so
# completions change between runs; seeding before each call makes the
# side-by-side comparison reproducible.
GEN_SEED = 42

gpt2 = pipeline('text-generation', model='gpt2', device=GEN_DEVICE)
trc = pipeline('text-generation', model='trc', device=GEN_DEVICE)


def compare_generations(prompt, seed=GEN_SEED):
    """Print one completion of `prompt` from base GPT-2, then from the fine-tuned model.

    Args:
        prompt: text to continue, e.g. a Wikipedia article title.
        seed: value passed to torch.manual_seed before each generation.
    """
    for generator in (gpt2, trc):
        torch.manual_seed(seed)
        # The pipeline already falls back to the EOS id as pad_token_id (see the
        # "Setting `pad_token_id`" warning); passing it explicitly silences that warning.
        print(generator(prompt, pad_token_id=generator.tokenizer.eos_token_id))


compare_generations('Virtual Box')

Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


[{'generated_text': 'Virtual Box and Windows Server 2012 for Windows Server 2012 R2 (6-10) on x64 Windows Server 2012 (4-7) x64 Windows Server 2012 R2 (6-10) Enterprise Linux for Windows Server 2012 R2 on 4'}]
[{'generated_text': 'Virtual Box is the official virtual box for PlayStation Portable 2, designed by Yoshikazu Takazawa and announced for consoles on September 7, 2013. Its predecessor to the original PlayStation Portable was released on April 27, 2013, and was made available'}]


In [None]:
# Same title through the base model first, then the fine-tuned one.
prompt = 'Phoenix Wright: Ace Attorney - Spirit of Justice'
for generator in (gpt2, trc):
    print(generator(prompt))

Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


[{'generated_text': "Phoenix Wright: Ace Attorney - Spirit of Justice - Star of Time Battleship Zero (2001)\n\n\nAce Attorney - The Ultimate Fighter (2002)\n\n\nKangaroo Court: The World's Greatest Trial (2003)\n\n\nFate"}]
[{'generated_text': 'Phoenix Wright: Ace Attorney - Spirit of Justice, 1995, as Airtel, $10. ISBN 978-0-7868-5543-5. © Ace Attorney Productions, 1997; Ace Attorney: Spirit of Justice, 1998; Ace'}]


In [None]:
# Base GPT-2 vs. fine-tuned model on a chemistry title.
for text_pipe in (gpt2, trc):
    print(text_pipe('Ammonium sulfate precipitation'))

Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


[{'generated_text': 'Ammonium sulfate precipitation or mineral precipitation (m-SO 4 -fluoride), m-OH 4 (r-OH 4 ), and m-CH 3 :H 3 O 4 are the primary drivers of the precipitation. We found that'}]
[{'generated_text': 'Ammonium sulfate precipitation in the North (NPS). Credit: NASA/JPL-Caltech, NASA, and the European Space Agency. The climate model in this study is adapted from a work published in 2007. The first draft of'}]


In [None]:
# Base GPT-2 vs. fine-tuned model on a disambiguated TV-series title.
title = 'Heartbeat (British TV series)'
for text_pipe in (gpt2, trc):
    print(text_pipe(title))

Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


[{'generated_text': 'Heartbeat (British TV series)\n\nIn this comedy - about a young boy who seeks help from his mentor, who has been kidnapped and rescued by the mysterious M.P.E.S. (Mutant Mobs) - the boys'}]
[{'generated_text': 'Heartbeat (British TV series) In the Season 3 episode "No Way Out" ("A Very Long Time"), Amy gets a call about a very long time ago - only to hear she isn\'t going to get the job. She calls back to'}]
