In [1]:
import random
import numpy as np
from sklearn.model_selection import train_test_split
import datasets
import pandas as pd
from types import SimpleNamespace
from tqdm import tqdm
import nltk

from GrammarCorrector.utils import GrammarDataset

from transformers import (
    AdamW,
    T5Tokenizer,
    T5ForConditionalGeneration,
    Seq2SeqTrainingArguments,
    Seq2SeqTrainer,
    DataCollatorForSeq2Seq,
    get_linear_schedule_with_warmup
)

from torch.utils.data import DataLoader
import torch

from datasets import Dataset, load_metric

In [2]:
# set params
params = SimpleNamespace()
params.max_len = 121
params.model_name = 't5-base'
params.batch_size = 32
params.num_epochs = 3

In [3]:
# set seed
def set_seed(seed):
  random.seed(seed)
  np.random.seed(seed)
  torch.manual_seed(seed)

set_seed(42)

pd.set_option('display.max_colwidth', None)
tqdm.pandas()
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device

device(type='cuda')

# Data prepare

In [4]:
df = pd.read_csv('./DATA/c4_200m_550k.csv')
df

Unnamed: 0,input,output
0,The steps below describe how to remove data for one or more specifies areas and how to put on the data from a snapshot to the index,The steps below describe how to remove data for one ore more specific areas and how to put back the data from a snapshot to the index.
1,When I wake up it\'s usually comes out dreamsI\'m thinking so my thoughts are very weird.,When I wake up it\'s usually dreams I\'m thinking about so my thoughts are very weird.
2,One of the cardinal factors to be considered trying to decide on which kind of shipping to customer settle is the! market difference.,One of the cardinal factors to consider when trying to decide on which kind of shipping to settle for is the market difference.
3,Answers Â» Regions Â» Is in Nagorno-Karabakt region that part in Armenia?,Answers Â» Regions Â» Is Nagorno-Karabakh region part of Armenia?
4,Flaneuring in fun at maple creek SK!,Flaneuring Fun in Maple Creek SK!
...,...,...
549995,"Despite before the counter-offensive launch, Kunduz swarmed with Taliban fighters racing stolen police vehicles and vans of Red Cross.","Despite the launch of the counter-offensive, Kunduz swarmed with Taliban fighters racing stolen police vehicles and Red Cross vans."
549996,"A spokesman said; "" Bad weather on its way today, so anyone on the roads be mindful of changing conditions.""","A spokesman said: ""Bad weather on its way today, so anyone on the roads be mindful of changing conditions."""
549997,2) Click on Get to Site Administration here.,2) Click on Go to Site Administration.
549998,"Habits/Hobbies likes to make friends, colects gems and shiny treasures.","Habits/Hobbies: Likes to make friends, Collects Gems and Shiny Treasures."


In [5]:
tokenizer = T5Tokenizer.from_pretrained(params.model_name)
model = T5ForConditionalGeneration.from_pretrained(params.model_name)

In [6]:
# get max len of tokenized sentences
def calc_token_len(example):
    return len(tokenizer(example).input_ids)

df['input_token_len'] = df['input'].progress_apply(calc_token_len)
params.max_len = df['input_token_len'].describe(percentiles=[0.99])['99%']
df.drop('input_token_len', axis=1, inplace=True)
print(params.max_len)

  1%|          | 3173/550000 [00:00<00:51, 10715.10it/s]Token indices sequence length is longer than the specified maximum sequence length for this model (1117 > 512). Running this sequence through the model will result in indexing errors
100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 550000/550000 [00:51<00:00, 10595.65it/s]

121.0





In [6]:
# split to test and train, go to datasets
train_df, test_df = train_test_split(df, test_size=0.1, shuffle=True)
train_dataset = Dataset.from_pandas(train_df)
test_dataset = Dataset.from_pandas(test_df)

In [7]:
dataset = GrammarDataset(test_dataset, tokenizer, params.max_len, True)
print(dataset[121])

input_ids 20
attention_mask 20
labels 24
{'input_ids': [71, 973, 24, 14079, 24067, 38, 96, 77, 221, 3728, 121, 19, 59, 2930, 7509, 640, 569, 2287, 5, 1], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], 'labels': [71, 973, 24, 14079, 3, 89, 12578, 887, 21, 96, 77, 221, 3728, 121, 3270, 19, 59, 2930, 7509, 640, 569, 2287, 5, 1]}


# Training

In [8]:
rouge_metric = load_metric("rouge")
nltk.download('punkt')
data_collator = DataCollatorForSeq2Seq(tokenizer, model=model,
                                       padding='longest', return_tensors='pt')

[nltk_data] Downloading package punkt to
[nltk_data]     C:\Users\bitzh\AppData\Roaming\nltk_data...
[nltk_data]   Package punkt is already up-to-date!


In [11]:
args = Seq2SeqTrainingArguments(output_dir="DATA/weights",
                        evaluation_strategy="epoch",
                        save_strategy="epoch",
                        per_device_train_batch_size=params.batch_size,
                        per_device_eval_batch_size=params.batch_size,
                        learning_rate=2e-5,
                        num_train_epochs=params.num_epochs,
                        weight_decay=0.01,
                        save_total_limit=2,
                        predict_with_generate=True,
                        fp16 = True,
                        gradient_accumulation_steps = 6,
#                         eval_steps = 500,
#                         save_steps = 500,
                        load_best_model_at_end=True,
                        logging_dir="/logs",
                        report_to="wandb")

# metric from the arcticle
def compute_metrics(eval_pred):
    predictions, labels = eval_pred
    decoded_preds = tokenizer.batch_decode(predictions, skip_special_tokens=True)
    # Replace -100 in the labels as we can't decode them.
    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
    
    # Rouge expects a newline after each sentence
    decoded_preds = ["\n".join(nltk.sent_tokenize(pred.strip())) for pred in decoded_preds]
    decoded_labels = ["\n".join(nltk.sent_tokenize(label.strip())) for label in decoded_labels]
    
    result = rouge_metric.compute(predictions=decoded_preds, references=decoded_labels, use_stemmer=True)
    # Extract a few results
    result = {key: value.mid.fmeasure * 100 for key, value in result.items()}
    
    # Add mean generated length
    prediction_lens = [np.count_nonzero(pred != tokenizer.pad_token_id) for pred in predictions]
    result["gen_len"] = np.mean(prediction_lens)
    return {k: round(v, 4) for k, v in result.items()}

# defining trainer using ðŸ¤—
trainer = Seq2SeqTrainer(model=model, 
                args=args, 
                train_dataset= GrammarDataset(train_dataset, tokenizer),
                eval_dataset=GrammarDataset(test_dataset, tokenizer),
                tokenizer=tokenizer,
                data_collator=data_collator,
                compute_metrics=compute_metrics)

In [13]:
trainer.train()

***** Running training *****
  Num examples = 495000
  Num Epochs = 3
  Instantaneous batch size per device = 32
  Total train batch size (w. parallel, distributed & accumulation) = 192
  Gradient Accumulation steps = 6
  Total optimization steps = 7734
Automatic Weights & Biases logging enabled, to disable set os.environ["WANDB_DISABLED"] = "true"
[34m[1mwandb[0m: Currently logged in as: [33mresquilleur[0m (use `wandb login --relogin` to force relogin)


Epoch,Training Loss,Validation Loss,Rouge1,Rouge2,Rougel,Rougelsum,Gen Len
0,0.6123,0.558183,72.1789,62.3658,71.4764,71.5027,17.2456
1,0.5926,0.542531,72.3457,62.6708,71.6494,71.678,17.2292
2,0.5832,0.539419,72.3931,62.7524,71.7022,71.7297,17.2256


***** Running Evaluation *****
  Num examples = 55000
  Batch size = 32
Saving model checkpoint to DATA/weights\checkpoint-2578
Configuration saved in DATA/weights\checkpoint-2578\config.json
Model weights saved in DATA/weights\checkpoint-2578\pytorch_model.bin
tokenizer config file saved in DATA/weights\checkpoint-2578\tokenizer_config.json
Special tokens file saved in DATA/weights\checkpoint-2578\special_tokens_map.json
Copy vocab file to DATA/weights\checkpoint-2578\spiece.model
***** Running Evaluation *****
  Num examples = 55000
  Batch size = 32
Saving model checkpoint to DATA/weights\checkpoint-5156
Configuration saved in DATA/weights\checkpoint-5156\config.json
Model weights saved in DATA/weights\checkpoint-5156\pytorch_model.bin
tokenizer config file saved in DATA/weights\checkpoint-5156\tokenizer_config.json
Special tokens file saved in DATA/weights\checkpoint-5156\special_tokens_map.json
Copy vocab file to DATA/weights\checkpoint-5156\spiece.model
***** Running Evaluation *

TrainOutput(global_step=7734, training_loss=0.6141996026069951, metrics={'train_runtime': 14393.5347, 'train_samples_per_second': 103.171, 'train_steps_per_second': 0.537, 'total_flos': 1.12537321168896e+17, 'train_loss': 0.6141996026069951, 'epoch': 3.0})

In [14]:
trainer.save_model('t5_gec_model')

Saving model checkpoint to t5_gec_model
Configuration saved in t5_gec_model\config.json
Model weights saved in t5_gec_model\pytorch_model.bin
tokenizer config file saved in t5_gec_model\tokenizer_config.json
Special tokens file saved in t5_gec_model\special_tokens_map.json
Copy vocab file to t5_gec_model\spiece.model


# Test results

In [19]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
tokenizer = T5Tokenizer.from_pretrained(params.model_name)
model = T5ForConditionalGeneration.from_pretrained('t5_gec_model').to(device).eval()

def correct_grammar(input_text,num_return_sequences):
  batch = tokenizer([input_text],truncation=True,padding='max_length',max_length=64, return_tensors="pt").to(torch_device)
  translated = model.generate(**batch,max_length=64,num_beams=4, num_return_sequences=num_return_sequences, temperature=1.5)
  tgt_text = tokenizer.batch_decode(translated, skip_special_tokens=True)
  return tgt_text

In [21]:
text = 'He are pl moving here.'
print(correct_grammar(text, num_return_sequences=2))

['He is moving here.', 'He is pls moving here.']


In [23]:
text = 'Cat are not drinked milk'
print(correct_grammar(text, num_return_sequences=3))

['Cats do not drink milk.', 'Cats are not drinking milk.', 'Cats are not drink milk.']
