# Carrot T5 Normalization
Fine-tune a T5 (`t5-small`) sequence-to-sequence model to normalize text from its written form into its verbalized form.

In [None]:
# !apt-get update
# !apt-get install -y nvidia-driver-470
# !pip install torch==1.12.1+cu113 torchvision==0.13.1+cu113 torchaudio==0.12.1+cu113 -f https://download.pytorch.org/whl/cu113/torch_stable.html

In [None]:
!pip install datasets
!pip install accelerate -U
!pip install transformers[torch]
!pip install --upgrade transformers accelerate
!pip install werpy

In [2]:
import os
import ast
import werpy
from datasets import load_dataset
from transformers import AutoTokenizer
from transformers import AutoModelForSeq2SeqLM, Seq2SeqTrainingArguments, Seq2SeqTrainer
from transformers import DataCollatorForSeq2Seq

In [27]:
# Fetch the full normalization corpus from the Hugging Face Hub.
dataRAW = load_dataset(
    path="sellersew/carrot-engine-normalization-translation-v2",
)

In [28]:
# Work on a small shuffled subsample: 0.5% of the rows for training and another
# 0.5% for evaluation. The seed is fixed so that a fresh "Restart & Run All"
# trains and evaluates on the same rows every time.
data = dataRAW["train"].train_test_split(
    test_size=0.005,
    train_size=0.005,
    shuffle=True,
    seed=42,
)

In [29]:
# Show the resulting DatasetDict (split names, features and row counts).
data

DatasetDict({
    train: Dataset({
        features: ['input', 'output'],
        num_rows: 15285
    })
    test: Dataset({
        features: ['input', 'output'],
        num_rows: 15286
    })
})

In [30]:
# Task instruction prepended to every source string, both when tokenizing the
# training data and when prompting the model at inference time.
prefix = "normalize the following text from its written form into its verbalized form: "

# Tokenizer matching the t5-small checkpoint that is fine-tuned below.
tokenizer = AutoTokenizer.from_pretrained("t5-small")

def preprocess_function(dataset):
    """Tokenize one batch of (input, output) text pairs for seq2seq training.

    Args:
        dataset: A batch mapping with an "input" column (source strings) and
            an "output" column (target strings), as supplied by
            ``Dataset.map(..., batched=True)``.

    Returns:
        The tokenizer encoding of the prefixed inputs, with the token ids of
        the targets stored under the "labels" key. Inputs and targets are
        both truncated to at most 128 tokens.
    """
    inputs = [prefix + entry for entry in dataset["input"]]
    targets = list(dataset["output"])

    # `text_target` tokenizes the targets in the same call and stores their
    # input ids under "labels"; it replaces the deprecated
    # `tokenizer.as_target_tokenizer()` context manager.
    return tokenizer(inputs, text_target=targets, max_length=128, truncation=True)

In [None]:
# Tokenize both splits, handing preprocess_function whole batches of rows at a time.
tokenized_data = data.map(
    function=preprocess_function,
    batched=True,
)

In [32]:
# Start from the pretrained t5-small checkpoint.
model = AutoModelForSeq2SeqLM.from_pretrained("t5-small")

# Collator used by the trainer to assemble batches from the tokenized examples.
data_collator = DataCollatorForSeq2Seq(model=model, tokenizer=tokenizer)

In [33]:
# Fine-tuning configuration, grouped by concern.
training_args = Seq2SeqTrainingArguments(
    # Checkpointing
    output_dir="./results",
    save_total_limit=3,
    # Schedule: 10 epochs, evaluating on the "test" split after each one
    num_train_epochs=10,
    evaluation_strategy="epoch",
    # Optimisation
    learning_rate=2e-5,
    weight_decay=0.01,
    fp16=True,
    # Batch sizes
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
)

trainer = Seq2SeqTrainer(
    args=training_args,
    model=model,
    tokenizer=tokenizer,
    data_collator=data_collator,
    train_dataset=tokenized_data["train"],
    eval_dataset=tokenized_data["test"],
)

# Kept as the last expression of the cell so the returned TrainOutput is displayed.
trainer.train()

Epoch,Training Loss,Validation Loss
1,0.9763,0.293482
2,0.3483,0.193091
3,0.2584,0.14648
4,0.2125,0.122463
5,0.1806,0.10839
6,0.1634,0.099465
7,0.151,0.094226
8,0.1378,0.091075
9,0.1383,0.089508
10,0.1309,0.088815


Checkpoint destination directory ./results/checkpoint-500 already exists and is non-empty.Saving will proceed but saved results may be invalid.


TrainOutput(global_step=9560, training_loss=0.24144028110982982, metrics={'train_runtime': 1211.7847, 'train_samples_per_second': 126.136, 'train_steps_per_second': 7.889, 'total_flos': 1710252489768960.0, 'train_loss': 0.24144028110982982, 'epoch': 10.0})

In [34]:
# Persist the fine-tuned weights together with the tokenizer so that the
# directory is self-contained and can be loaded without this notebook's state.
model.save_pretrained("./normalize-model-10-epoch")
tokenizer.save_pretrained("./normalize-model-10-epoch")

# Reload from disk so that the benchmark below runs against the saved checkpoint.
model = AutoModelForSeq2SeqLM.from_pretrained("./normalize-model-10-epoch")

## Benchmark
We will now benchmark the fine-tuned model on the `native-tests.csv` test set.

In [35]:
# Load only the native-tests.csv file from the same Hub repository; its rows
# are read from testing_data["train"] in the benchmark loop.
testing_data = load_dataset(
    path="sellersew/carrot-engine-normalization-translation-v2",
    data_files="native-tests.csv",
)

In [36]:
def evalulate(text):
    """Normalize a single string with the current model.

    The task prefix is prepended to ``text``, the prompt is encoded to
    PyTorch tensors and moved to the model's device, and up to 128 tokens are
    generated.

    Args:
        text: The string to normalize.

    Returns:
        The first generated sequence decoded to text, with special tokens
        stripped.
    """
    token_ids = tokenizer.encode(prefix + text, return_tensors="pt").to(model.device)
    generated = model.generate(token_ids, max_length=128)
    first_sequence = generated[0]
    return tokenizer.decode(first_sequence, skip_special_tokens=True)

In [None]:
import time
from tqdm import tqdm

# Benchmark: run every test row through the model, then report exact-match
# counts, the mean per-row word error rate and the wall-clock time.
start_time = time.time()
passing    = 0
failing    = 0
WER        = []

for entry in tqdm(testing_data["train"], desc='Processing', unit='iteration'):
    source_text = entry["input"]
    # Score against the expected normalization, not against the raw input the
    # model was asked to rewrite.
    # NOTE(review): assumes native-tests.csv has the same "output" column as
    # the training data - confirm.
    expected    = entry["output"]
    predicted   = evalulate(source_text)

    # werpy.wer takes the reference first and the hypothesis second.
    WER.append(werpy.wer(expected, predicted))
    if predicted == expected:
        passing += 1
    else:
        failing += 1

time_taken = time.time() - start_time
# Guard against an empty test set instead of raising ZeroDivisionError.
average    = sum(WER) / len(WER) if WER else float("nan")
print("Passing: " + str(passing))
print("Failing: " + str(failing))
print("WER    : " + str(average))
print(f'Time : {time_taken:.2f} seconds')

Processing:  37%|███▋      | 335/906 [04:04<06:41,  1.42iteration/s]