In [None]:
from google.colab import drive
# Mount Google Drive into the Colab filesystem so files referenced under
# /content/drive/ can be read by the cells below.
drive.mount('/content/drive/')

In [None]:
%%capture
# Install the notebook's dependencies into the kernel's environment.
# The %%capture cell magic above suppresses pip's output.
%pip install torch torchvision
%pip install accelerate
%pip install transformers
%pip install transformers[torch]
%pip install datasets
%pip install errant

In [None]:
from datasets import load_dataset
import torch
from transformers import T5Tokenizer, T5ForConditionalGeneration
import spacy
import errant

In [None]:
# --- Model -------------------------------------------------------------------
# Base architecture name (used to load the tokenizer) and the fine-tuned
# checkpoint under evaluation.
# NOTE(review): MODEL_PATH is not under the mounted Drive folder; presumably the
# checkpoint has to be present on the Colab VM before running - confirm.
T5_MODEL = 't5-small'
MODEL_PATH = '/content/results_t5_small/checkpoint-3500'

# --- Data --------------------------------------------------------------------
# Dev sentences (JSON), the reference M2 annotations, and the M2 file this
# notebook writes its hypothesis edits to.
# NOTE(review): the two M2 paths are relative to the working directory.
DEV_PATH = '/content/drive/MyDrive/CS4248/dev.json'
DEV_M2_PATH = 'dev.m2'
OUT_M2_PATH = 'out.m2'

# --- Tokenization / generation -----------------------------------------------
# Prefix prepended (followed by ": ") to every source sentence.
TASK_PREFIX = 'rectify'
TOKENIZER_PADDING = 'max_length'
SOURCE_MAX_LENGTH = 512
GEN_MAX_LENGTH = 512
GEN_NUM_BEAMS = 5

In [None]:
# Load the dev sentences from a single JSON file. A lone data file is exposed
# by the loader under the default 'train' split, hence split='train' here.
dataset_test = load_dataset(
    'json',
    data_files=DEV_PATH,
    split='train',
)

In [None]:
# Tokenizer comes from the base T5 model name; the weights come from the
# fine-tuned checkpoint, which is the model to be tested.
t5_tokenizer = T5Tokenizer.from_pretrained(T5_MODEL)
model = T5ForConditionalGeneration.from_pretrained(MODEL_PATH)

In [None]:
# One spaCy pipeline is shared: ERRANT receives it for annotation, and its
# tokenizer is reused on its own to re-tokenize model output.
nlp = spacy.load("en_core_web_sm")
annotator = errant.load('en', nlp)
spacy_tokenizer = nlp.tokenizer

In [None]:
def generate_correction(model, tokenizer, sample):
    """Generate a corrected version of a single sample's sentence.

    Args:
        model: Model exposing ``generate`` and ``device`` (here the fine-tuned
            T5 checkpoint).
        tokenizer: Tokenizer used both to encode the prompt and to decode the
            generated token ids.
        sample: Mapping whose ``'original'`` entry holds the source sentence.

    Returns:
        The generated sentence, re-tokenized with spaCy and joined with single
        spaces.
    """
    input_text = f"{TASK_PREFIX}: {sample['original']}"
    # Encode with the tokenizer that was passed in (not the notebook-level
    # global) so encoding and decoding always use the same vocabulary.
    inputs = tokenizer.encode(
        input_text,
        max_length=SOURCE_MAX_LENGTH,
        padding=TOKENIZER_PADDING,
        truncation=True,
        return_tensors='pt',
    )
    # Keep the input ids on the same device as the model's weights; this is a
    # no-op when the model lives on the CPU.
    inputs = inputs.to(model.device)
    corrected_ids = model.generate(
        inputs,
        max_length=GEN_MAX_LENGTH,
        num_beams=GEN_NUM_BEAMS,
        early_stopping=True,
    )
    # Only the first returned sequence is used.
    corrected_sentence = tokenizer.decode(
        corrected_ids[0],
        skip_special_tokens=True,
    )
    # Retokenize sentence using spacy to restore correct spacing between tokens
    # for accurate error correction score calculation
    corrected_sentence = ' '.join(tok.text for tok in spacy_tokenizer(corrected_sentence))
    return corrected_sentence

In [None]:
# Annotation line written for a sentence whose hypothesis yields no edits
NOOP_EDIT = 'A -1 -1|||noop|||-NONE-|||REQUIRED|||-NONE-|||0'

# One list of edit types per sentence; can be used later for analysing
# performance for each type of error
output_edit_types = []

with open(OUT_M2_PATH, 'w') as f:
    for sample in dataset_test:
        orig = sample['original']
        corrected = generate_correction(model, t5_tokenizer, sample)

        # Parse source first, then hypothesis, and extract the edits between them
        orig_doc = annotator.parse(orig)
        corrected_doc = annotator.parse(corrected)
        edits = annotator.annotate(orig_doc, corrected_doc)
        output_edit_types.append([edit.type for edit in edits])

        # One M2 entry: the 'S' line, its annotation lines (or the no-op
        # placeholder when there are none), then a blank divider line
        annotation_lines = [edit.to_m2() for edit in edits] or [NOOP_EDIT]
        print(f'S {orig}', *annotation_lines, sep='\n', end='\n\n', file=f)

In [None]:
# Compare the output (hypothesis) edits with the gold-standard reference edits
# and compute the evaluation statistics
!errant_compare -hyp {OUT_M2_PATH} -ref {DEV_M2_PATH}