In [1]:
from tqdm import tqdm

In [2]:
import csv
import numpy as np
import random

import pandas as pd
from datasets import load_dataset, load_metric, Dataset
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
from ast import literal_eval

import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
from transformers import set_seed, DataCollatorForSeq2Seq
from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments
from transformers.utils import check_min_version


# Raises an error if the minimum required version of Transformers is not installed. Remove at your own risk.
check_min_version("4.6.0.dev0")

def generate_examples(row):
    """Turn one paired row into an (input text, target text) tuple.

    Each row carries two token sequences (``ctx1``/``ctx2``), a two-element
    span for each (``diff1``/``diff2``) and a loss for each
    (``loss1``/``loss2``).  Both sequences are rendered as space-joined text
    with the span wrapped in ``<ocr> ... </ocr>``.

    The side with the strictly lower loss supplies the target (its span
    tokens joined by spaces, or ``"<blank>"`` when the span is empty); the
    *other* side's marked-up text is the input.  When ``loss1`` is not
    strictly lower than ``loss2`` (including ties), side 2 supplies the
    target and side 1 the input.

    Returns:
        tuple[str, str]: ``(marked_up_input, target)``.
    """
    def _wrap_span(tokens, span):
        # Render ``tokens`` with the [span[0], span[1]) slice tagged, and
        # also hand back the slice itself for use as a possible target.
        start, stop = span[0], span[1]
        inside = tokens[start:stop]
        text = "{} <ocr> {} </ocr> {}".format(
            ' '.join(tokens[:start]),
            ' '.join(inside),
            ' '.join(tokens[stop:]),
        )
        return text, inside

    first_loss, second_loss = row['loss1'], row['loss2']
    first_span, second_span = row['diff1'], row['diff2']
    first_tokens, second_tokens = row['ctx1'], row['ctx2']

    first_text, first_inside = _wrap_span(first_tokens, first_span)
    second_text, second_inside = _wrap_span(second_tokens, second_span)

    # Input comes from the higher-loss side, target from the lower-loss side.
    if first_loss < second_loss:
        source_text, target_tokens = second_text, first_inside
    else:
        source_text, target_tokens = first_text, second_inside

    target = ' '.join(target_tokens) if target_tokens else "<blank>"
    return source_text, target

def preprocess_function(examples):
    """Tokenize a batch of examples and attach loss-ready labels.

    Reads the ``'orig'`` (input) and ``'corrected'`` (target) columns of
    ``examples`` and tokenizes both with the module-level ``tokenizer``,
    padding and truncating each.  Pad-token ids in the tokenized targets are
    replaced by -100 so that padded positions are ignored by the loss.

    Returns:
        The tokenizer output for the inputs, with an added ``"labels"`` entry
        holding the masked target ids.
    """
    sources = list(examples['orig'])
    references = examples['corrected']

    encoded = tokenizer(sources, padding=True, truncation=True)

    # Targets are tokenized inside the tokenizer's target-side context.
    with tokenizer.as_target_tokenizer():
        encoded_targets = tokenizer(references, padding=True, truncation=True)

    # Swap every pad id for -100 so padding does not contribute to the loss.
    pad_id = tokenizer.pad_token_id
    masked_labels = []
    for sequence in encoded_targets["input_ids"]:
        masked_labels.append([-100 if token_id == pad_id else token_id for token_id in sequence])

    encoded["labels"] = masked_labels
    return encoded

In [3]:
# Seed the random number generators handled by transformers.set_seed.
seed = 1729
set_seed(seed)
model_name = "t5-base"

print("Loading tokenizer")
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Register the markup tokens used by generate_examples. Both calls are kept
# exactly as they were so the resulting vocabulary layout does not change
# (NOTE(review): presumably this must match the setup used when the checkpoint
# was trained -- confirm before simplifying to a single call).
tokenizer.add_tokens(["<ocr>", "</ocr>", "<blank>"], special_tokens=True)
tokenizer.add_special_tokens({"additional_special_tokens": ["<ocr>", "</ocr>", "<blank>"]})


print("Loading test")
num_samples = 2000
# The ctx*/diff* columns are parsed from their string form into Python objects.
# ast.literal_eval accepts only Python literals, so unlike the builtin eval it
# cannot execute arbitrary expressions embedded in the CSV file.
df = pd.read_csv(
    'test.csv',
    converters={col: literal_eval for col in ('ctx1', 'ctx2', 'diff1', 'diff2')},
    nrows=num_samples,
)
# Build the (input, target) text pair for every row.
df[['orig','corrected']] = df.apply(generate_examples, axis=1, result_type="expand")
test_dataset = Dataset.from_pandas(df[['orig','corrected']])

# Tokenize in batches across worker processes; with padding=True inside
# preprocess_function, sequences are padded per tokenizer call (i.e. per batch).
test_dataset = test_dataset.map(
    preprocess_function,
    batched=True,
    num_proc=20,
)

Loading tokenizer
Loading test


In [4]:
print("Loading model")
# Load the seq2seq weights stored under 'ocr_correction_model'.
model = AutoModelForSeq2SeqLM.from_pretrained('ocr_correction_model')
# Size the embedding matrix to the tokenizer's vocabulary, which includes the
# extra markup tokens registered on the tokenizer.
model.resize_token_embeddings(len(tokenizer))
# Switch to evaluation mode; as the cell's last expression, the returned model
# is also what the notebook displays below.
model.eval()

Loading model


T5ForConditionalGeneration(
  (shared): Embedding(32103, 768)
  (encoder): T5Stack(
    (embed_tokens): Embedding(32103, 768)
    (block): ModuleList(
      (0): T5Block(
        (layer): ModuleList(
          (0): T5LayerSelfAttention(
            (SelfAttention): T5Attention(
              (q): Linear(in_features=768, out_features=768, bias=False)
              (k): Linear(in_features=768, out_features=768, bias=False)
              (v): Linear(in_features=768, out_features=768, bias=False)
              (o): Linear(in_features=768, out_features=768, bias=False)
              (relative_attention_bias): Embedding(32, 12)
            )
            (layer_norm): T5LayerNorm()
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (1): T5LayerFF(
            (DenseReluDense): T5DenseReluDense(
              (wi): Linear(in_features=768, out_features=3072, bias=False)
              (wo): Linear(in_features=3072, out_features=768, bias=False)
              (dropout): Dr

In [5]:
# One (input text, reference text, predicted text) triple per test example.
results = []
# Inference only: without no_grad every forward pass would record an autograd
# graph that is never used, costing memory and time.
with torch.no_grad():
    for x in tqdm(test_dataset):
        # Wrap each example in a batch of size one.
        input_ids = torch.tensor([x['input_ids']])
        attention_mask = torch.tensor([x['attention_mask']])
        labels = torch.tensor([x['labels']])
        # NOTE(review): the reference labels are passed into the forward pass and
        # the prediction is the per-position argmax of the returned logits.
        # Presumably the decoder inputs are derived from those labels, which
        # would make this teacher-forced prediction rather than free-running
        # generation (e.g. model.generate) -- confirm this is the intended
        # evaluation protocol.
        output = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
        logits = output.logits.detach().numpy()
        # logits[0] is indexed (position, vocabulary entry); take the best entry per position.
        generated = np.argmax(logits[0], axis=1)
        outtoks = tokenizer.convert_ids_to_tokens(generated)
        try:
            # Keep only the tokens before the first end-of-sequence marker.
            end_idx = outtoks.index('</s>')
        except ValueError:
            # No '</s>' predicted: record an empty prediction, as before.
            result = ''
        else:
            result = tokenizer.convert_tokens_to_string(outtoks[:end_idx])
        results.append((x['orig'], x['corrected'], result))

100%|██████████| 2000/2000 [06:45<00:00,  4.93it/s]


In [6]:
# Tabulate the (input, reference, prediction) triples collected in `results`.
# Note: this rebinds `df`, which previously held the rows read from test.csv.
df = pd.DataFrame(results, columns=['sent', 'truth', 'gen'])

In [7]:
# Lift pandas' display limits so cell contents and rows are not truncated
# when a frame is rendered.
pd.set_option('display.max_colwidth', None)
pd.set_option('display.max_rows', None)

In [8]:
# Persist the results frame; with default to_csv settings the index is
# written as the first column.
df.to_csv('new_results.csv')