In [10]:
import os
from dataclasses import dataclass, field
from typing import Optional

import nlpaug.augmenter.char as nac
import torch
from datasets import load_dataset
from huggingface_hub import notebook_login
from nltk.tokenize import RegexpTokenizer
from transformers import HfArgumentParser, set_seed, T5ForConditionalGeneration, AutoTokenizer, Trainer, TrainingArguments

In [68]:
@dataclass
class ModelArguments():
    """Identifies the pretrained checkpoint and tokenizer to load.

    Instances are built by ``HfArgumentParser`` from the notebook's
    ``args_dict``.
    """
    # Name or local path of the checkpoint passed to ``from_pretrained``.
    model_name_or_path: str
    # Optional tokenizer name/path. The loading cell falls back to
    # ``model_name_or_path`` when this is empty, so it must not be mandatory.
    tokenizer_name: Optional[str] = field(
        default=None,
        metadata={"help": "Tokenizer name or path; defaults to model_name_or_path when unset."},
    )

@dataclass
class DataTrainingArguments():
    """Data-preparation settings parsed by ``HfArgumentParser``."""
    # Maximum tokenized length used for truncation and padding of both the
    # noisy input and the clean target. The default mirrors the value the
    # notebook hard-codes for the tokenizer/model length (128).
    max_len: int = field(
        default=128,
        metadata={"help": "Maximum sequence length for inputs and targets."},
    )
   
@dataclass
class TrainingArguments_2():
    """Flat, parser-friendly subset of the training options.

    The notebook parses ``args_dict`` into this dataclass and later expands it
    with ``TrainingArguments(**vars(training_args))``, so every field name
    here must be a keyword accepted by ``transformers.TrainingArguments``.

    The original positional order of the first 17 fields is preserved. The
    last three of those gained defaults, and the options that ``args_dict``
    sets but the dataclass did not declare were appended with defaults.
    """
    output_dir: str
    overwrite_output_dir: bool
    per_device_train_batch_size: int
    per_device_eval_batch_size: int
    gradient_accumulation_steps: int
    learning_rate: float
    warmup_steps: int
    logging_steps: int
    evaluation_strategy: str
    eval_steps: int
    num_train_epochs: int
    do_train: bool
    do_eval: bool
    fp16: bool
    # -1 leaves the training length to num_train_epochs.
    # NOTE(review): -1 / 42 / None below are meant to mirror the
    # TrainingArguments defaults; confirm against the installed version.
    max_steps: int = -1
    seed: int = 42
    report_to: Optional[str] = None
    # Options that args_dict supplies; previously undeclared here, so the
    # parser had nowhere to put them.
    dataloader_num_workers: int = 0
    dataloader_pin_memory: bool = True
    run_name: Optional[str] = None
    weight_decay: float = 0.0
    disable_tqdm: Optional[bool] = None
    # NOTE(review): forwarded verbatim to TrainingArguments; assumes the
    # installed transformers release accepts this keyword (args_dict uses it).
    auto_find_batch_size: bool = False

In [69]:
# Flat configuration that HfArgumentParser splits across the argument
# dataclasses defined above.
# TODO(review): there is no "model_name_or_path" entry, but
# ModelArguments.model_name_or_path has no default, so ModelArguments cannot
# be constructed from this dict until that key is added.
args_dict = {
    # Output / bookkeeping
    "output_dir": './byt5-base-english-ocr-correction',
    "overwrite_output_dir": True,
    "run_name": "byt5-base-english-ocr-correction-2",
    "report_to": "wandb",
    "disable_tqdm": False,
    "seed": 123,
    # Batching and data loading
    "per_device_train_batch_size": 32,
    "per_device_eval_batch_size": 32,
    "gradient_accumulation_steps": 4,
    "auto_find_batch_size": True,
    "dataloader_num_workers": 4,
    "dataloader_pin_memory": True,
    # Optimisation
    "learning_rate": 5e-4,
    "warmup_steps": 250,
    "weight_decay": 0.01,
    "num_train_epochs": 1,
    "fp16": False,
    # Logging / evaluation cadence ("evaluation_strategy" used to be listed
    # twice; a repeated key in a dict literal is silently overridden).
    "logging_steps": 1000,
    "evaluation_strategy": "steps",
    "eval_steps": 1000,
    "do_train": True,
    "do_eval": True,
}

In [70]:
# Split the flat args_dict into one instance per argument dataclass, in the
# order the dataclasses are listed, then seed the RNGs from the parsed seed.
parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments_2))
(model_args, data_args, training_args) = parser.parse_dict(args_dict)
set_seed(training_args.seed)

In [5]:
def chunk_text(batch, tokenizer, length):
    """Re-split a batch of texts into space-joined chunks of at most ``length`` words.

    Args:
        batch: iterable of strings.
        tokenizer: object exposing ``tokenize(str) -> list[str]``.
        length: maximum number of words per chunk.

    Returns:
        ``{'text': [...]}`` holding every chunk from every input text, in
        order. Tokens of length 1 are discarded before chunking.
    """
    pieces = []
    for raw in batch:
        words = [tok for tok in tokenizer.tokenize(raw) if len(tok) > 1]
        for start in range(0, len(words), length):
            pieces.append(' '.join(words[start:start + length]))
    return {'text': pieces}

In [12]:
#dataset = load_dataset('csv', data_files={'train': ['data/nl_unshuffled_train_100_000.csv'], 'test': 'data/nl_unshuffled_test_10_000.csv'})
dataset = load_dataset("wikitext", "wikitext-103-v1")
# Drop empty rows and rows that start with " = = ".
dataset = dataset.filter(lambda x: x["text"] != '')
dataset = dataset.filter(lambda x: not x['text'].startswith(" = = "))
# Word-level splitter used only for chunking. It gets its own name so that
# re-running this cell cannot clobber the model `tokenizer` that
# convert_to_features reads from the notebook namespace.
word_tokenizer = RegexpTokenizer(r'\w+')
# Re-split every text into chunks of at most 128 words; the original rows are
# replaced by the (more numerous) chunks.
dataset = dataset.map(
    lambda x: chunk_text(x['text'],
                         word_tokenizer,
                         128),
    batched=True,
    remove_columns=["text"])


def ocr_augment_chars(text, **kwargs):
    """Return ``text`` after applying nlpaug's ``OcrAug`` character augmenter.

    Args:
        text: input handed unchanged to ``OcrAug.augment``.
        **kwargs: forwarded verbatim to the ``nac.OcrAug`` constructor.

    Returns:
        Whatever ``OcrAug.augment`` returns for ``text``.
    """
    # A new augmenter is built on every call, configured purely from kwargs.
    return nac.OcrAug(**kwargs).augment(text)

# Add a noisy copy of every chunk: 'ocr_text' holds the OCR-style corrupted
# version, while 'text' keeps the clean chunk as the correction target.
dataset = dataset.map(
    lambda x: {
        'ocr_text': ocr_augment_chars(x['text'], aug_char_p=0.4, aug_word_p=0.6),
        'text': x['text'],
    },
    batched=True,
    remove_columns=["text"],
)



Reusing dataset wikitext (/home/studio-lab-user/.cache/huggingface/datasets/wikitext/wikitext-103-v1/1.0.0/a241db52902eaf2c6aa732210bead40c090019a499ceb13bcbfa3f8ab646a126)
100%|██████████| 3/3 [00:00<00:00, 545.23it/s]
100%|██████████| 5/5 [00:00<00:00, 227.27ba/s]
Loading cached processed dataset at /home/studio-lab-user/.cache/huggingface/datasets/wikitext/wikitext-103-v1/1.0.0/a241db52902eaf2c6aa732210bead40c090019a499ceb13bcbfa3f8ab646a126/cache-73dab13e0d557913.arrow
100%|██████████| 4/4 [00:00<00:00, 221.40ba/s]
100%|██████████| 3/3 [00:00<00:00, 67.72ba/s]
Loading cached processed dataset at /home/studio-lab-user/.cache/huggingface/datasets/wikitext/wikitext-103-v1/1.0.0/a241db52902eaf2c6aa732210bead40c090019a499ceb13bcbfa3f8ab646a126/cache-64a5410963e22a82.arrow
100%|██████████| 3/3 [00:00<00:00, 126.48ba/s]
100%|██████████| 3/3 [00:00<00:00, 24.99ba/s]
100%|██████████| 890/890 [00:37<00:00, 23.46ba/s]
100%|██████████| 2/2 [00:00<00:00, 25.92ba/s]
100%|██████████| 3/3 [00:01<0

In [11]:
# Peek at one augmented example. `dataset` holds several splits (it is
# indexed by split name elsewhere in this notebook), so a split has to be
# selected before a row can be addressed by position.
dataset['train'][2]

{'text': 'The game began development in 2010 carrying over large portion of the work done on Valkyria Chronicles II While it retained the standard features of the series it also underwent multiple adjustments such as making the game more forgiving for series newcomers Character designer unk Honjou and composer Hitoshi Sakimoto both returned from previous entries along with Valkyria Chronicles II director Takeshi Ozawa large team of writers handled the script The game opening theme was sung by May',
 'ocr_text': 'The game began development in 2010 carrying over large portion 0f the work done on Valkyria Chronicles II While it retained the standard featoke8 of the series it also underwent multiple adjustments such a8 making the game more forgiving for series newcomers Ghakactek de8i9nek unk Honjou and composer Hitoshi Sakimoto 6uth returned from previous entries along with Valkyria Chronicles II director Takeshi D2awa large team 0f writers handled the script The game opening theme was su

In [15]:
# The 'train' split is used for fitting; the 'test' split is used as the
# evaluation set during training.
train_dataset, valid_dataset = dataset['train'], dataset['test']

# Number of examples in each split.
print(len(train_dataset))
print(len(valid_dataset))

1142774
2838


In [23]:
# Load pretrained model and tokenizer.
# The tokenizer falls back to the model checkpoint when no separate tokenizer
# name is configured. `model_max_length` is the tokenizer's length setting
# (the previous `max_length=` keyword is not), so the configured max_len is
# applied here instead of a hard-coded 128.
tokenizer = AutoTokenizer.from_pretrained(
    model_args.tokenizer_name or model_args.model_name_or_path,
    model_max_length=data_args.max_len,
)
model = T5ForConditionalGeneration.from_pretrained(
    model_args.model_name_or_path
)

# Overwrite the default generation max_length of 20 so generated corrections
# can be as long as the truncated targets.
model.config.max_length = data_args.max_len

Downloading: 100%|██████████| 2.53k/2.53k [00:00<00:00, 728kB/s]
Downloading: 100%|██████████| 698/698 [00:00<00:00, 327kB/s]
Downloading: 100%|██████████| 2.44k/2.44k [00:00<00:00, 873kB/s]
Downloading: 100%|██████████| 1.12G/1.12G [00:13<00:00, 89.7MB/s]


In [24]:
def convert_to_features(batch, max_len):
    """Tokenize a batch of (noisy, clean) text pairs for seq2seq training.

    Uses the notebook-level ``tokenizer``.

    Args:
        batch: mapping with ``'ocr_text'`` (noisy inputs) and ``'text'``
            (clean targets), each a list of strings.
        max_len: length every sequence is truncated and padded to.

    Returns:
        Dict with ``input_ids`` / ``attention_mask`` for the inputs and
        ``target_ids`` / ``target_attention_mask`` for the targets. Padding
        positions in ``target_ids`` are set to -100 so that, once the column
        is used as ``labels``, the loss ignores them instead of training the
        model to emit padding.
    """
    encode_kwargs = dict(truncation=True, padding='max_length', max_length=max_len)
    # Calling the tokenizer directly replaces the deprecated batch_encode_plus.
    input_encodings = tokenizer(batch['ocr_text'], **encode_kwargs)
    target_encodings = tokenizer(batch['text'], **encode_kwargs)

    pad_id = tokenizer.pad_token_id
    masked_targets = [
        [-100 if token_id == pad_id else token_id for token_id in sequence]
        for sequence in target_encodings['input_ids']
    ]

    encodings = {
        'input_ids': input_encodings['input_ids'],
        'attention_mask': input_encodings['attention_mask'],
        'target_ids': masked_targets,
        'target_attention_mask': target_encodings['attention_mask']
    }

    return encodings

In [26]:
def prep_dataset(tokenizer, dataset, max_len):
    """Tokenize ``dataset`` and shape it into the columns the model consumes.

    Args:
        tokenizer: accepted for interface compatibility; this function does
            not use it (``convert_to_features`` reads the notebook-level
            tokenizer instead).
        dataset: dataset with ``'text'`` and ``'ocr_text'`` columns.
        max_len: truncation/padding length forwarded to ``convert_to_features``.

    Returns:
        The tokenized dataset with columns ``input_ids``, ``attention_mask``,
        ``labels`` and ``decoder_attention_mask``, formatted as torch tensors.
    """
    dataset = dataset.map(convert_to_features, fn_kwargs={"max_len": max_len}, batched=True)
    # Rename columns to the names that the forward method of the selected
    # model expects, and drop the raw strings.
    dataset = dataset.rename_column('target_ids', 'labels')
    dataset = dataset.rename_column('target_attention_mask', 'decoder_attention_mask')
    dataset = dataset.remove_columns(['text', 'ocr_text'])
    # with_format returns a new dataset rather than modifying in place, so the
    # result has to be kept; it is applied last so it can name the final
    # (renamed) columns.
    columns = ['input_ids', 'labels',
               'attention_mask', 'decoder_attention_mask']
    dataset = dataset.with_format(type='torch', columns=columns)
    return dataset

In [27]:
# Tokenize both splits with the same maximum length (train first, then valid).
train_dataset, valid_dataset = (
    prep_dataset(tokenizer, split, data_args.max_len)
    for split in (train_dataset, valid_dataset)
)

100%|██████████| 1143/1143 [16:15<00:00,  1.17ba/s]
100%|██████████| 3/3 [00:02<00:00,  1.22ba/s]


In [29]:
# Persist the tokenized splits. Each split gets its own directory; writing
# both to "." made the second save collide with the first and scattered
# dataset files over the working directory.
train_dataset.save_to_disk("./processed/train")
valid_dataset.save_to_disk("./processed/valid")

In [71]:
# Expand the parsed dataclass into a full TrainingArguments object. It is
# bound to a new name so that `training_args` keeps referring to the parsed
# dataclass and this cell can be re-run without feeding a TrainingArguments
# instance back into its own constructor.
hf_training_args = TrainingArguments(**vars(training_args))
trainer = Trainer(
    model=model,
    args=hf_training_args,
    train_dataset=train_dataset,
    eval_dataset=valid_dataset,
)

max_steps is given, it will override any value given in num_train_epochs


In [None]:
# Resume from model_name_or_path when it is a local directory; otherwise
# start training from the freshly loaded weights.
# `resume_from_checkpoint` replaces the deprecated `model_path` keyword.
trainer.train(
    resume_from_checkpoint=(
        model_args.model_name_or_path
        if os.path.isdir(model_args.model_name_or_path)
        else None
    )
)

In [None]:
# Write the trained model via the trainer.
trainer.save_model()
# For convenience, also re-save the tokenizer to the configured output
# directory, so that the model can be shared easily on huggingface.co/models.
tokenizer.save_pretrained(training_args.output_dir)

In [59]:
# Display the attributes of `training_args` as a plain dict.
# NOTE(review): the saved output below lists a different output_dir and batch
# size than args_dict, so it appears to come from an earlier run.
vars(training_args)

{'output_dir': './byt5-base-dutch-ocr-correction',
 'overwrite_output_dir': True,
 'per_device_train_batch_size': 2,
 'per_device_eval_batch_size': 2,
 'gradient_accumulation_steps': 4,
 'learning_rate': 0.0005,
 'warmup_steps': 250,
 'logging_steps': 100,
 'evaluation_strategy': 'steps',
 'eval_steps': 250,
 'num_train_epochs': 4,
 'do_train': True,
 'do_eval': True,
 'fp16': False,
 'use_cache': False,
 'max_steps': 5000,
 'seed': 123}

In [4]:
# Show the Hugging Face Hub login widget.
# NOTE(review): presumably needed for the `use_auth_token=True` loads in the
# following cells — confirm.
notebook_login()

VBox(children=(HTML(value='<center> <img\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…

In [18]:
# Load the fine-tuned correction model and its tokenizer from the Hub, using
# the stored login token.
model = T5ForConditionalGeneration.from_pretrained(
    'yelpfeast/byt5-base-english-ocr-correction',
    use_auth_token=True,
)
tokenizer = AutoTokenizer.from_pretrained(
    "yelpfeast/byt5-base-english-ocr-correction",
    use_auth_token=True,
)

Downloading: 100%|██████████| 2.82k/2.82k [00:00<00:00, 1.90MB/s]
Downloading: 100%|██████████| 2.44k/2.44k [00:00<00:00, 1.72MB/s]


In [19]:
# Build a noisy version of a known-clean sentence to try the model on.
aug = nac.OcrAug(aug_char_p=0.4, aug_word_p=0.6)
corrected_text = "Life is like a box of chocolates"
augmented_text = aug.augment(corrected_text)
# Some nlpaug releases return a one-element list even for a single string
# input; normalise to a plain string because the next cell calls
# `augmented_text.encode(...)`.
if isinstance(augmented_text, list):
    augmented_text = augmented_text[0]
print(augmented_text)

In [21]:
# Hand-built ids: the UTF-8 bytes of each string, with 3 added for special tokens.
input_ids = torch.tensor([list(augmented_text.encode("utf-8"))]) + 3
labels = torch.tensor([list(corrected_text.encode("utf-8"))]) + 3

# The same noisy sentence encoded through the tokenizer, used for generation.
inputs = tokenizer(augmented_text, return_tensors="pt", padding=True)

# Forward pass on the hand-built ids to obtain the loss against the clean text.
loss = model(input_ids, labels=labels).loss

# Generate the corrected sentence with sampling disabled.
output_sequences = model.generate(
    input_ids=inputs["input_ids"],
    attention_mask=inputs["attention_mask"],
    do_sample=False,
)

print(tokenizer.batch_decode(output_sequences, skip_special_tokens=True))

['Life is like a box of chocolates']


In [23]:
from transformers import T5ForConditionalGeneration
import torch
import nlpaug.augmenter.char as nac

# Noisy/clean sentence pair; `augmented_text` is computed here but not used
# further down in this cell.
aug = nac.OcrAug(aug_char_p=0.4, aug_word_p=0.6)
corrected_text = "Life is like a box of chocolates"
augmented_text = aug.augment(corrected_text)

model = T5ForConditionalGeneration.from_pretrained(
    'yelpfeast/byt5-base-english-ocr-correction'
)

# Ids built directly from UTF-8 bytes, with 3 added for special tokens.
# The input is an English sentence and the labels are a French sentence.
input_ids = torch.tensor(
    [list("Life is like a box of chocolates.".encode("utf-8"))]
) + 3
labels = torch.tensor(
    [list("La vie est comme une boîte de chocolat.".encode("utf-8"))]
) + 3

# Forward pass; only the loss is kept.
loss = model(input_ids, labels=labels).loss

In [24]:
from transformers import T5ForConditionalGeneration, AutoTokenizer
import nlpaug.augmenter.char as nac

# Corrupt a clean sentence with simulated OCR errors.
aug = nac.OcrAug(aug_char_p=0.4, aug_word_p=0.6)
corrected_text = "Life is like a box of chocolates"
augmented_text = aug.augment(corrected_text)

# Fetch the fine-tuned model and tokenizer from the Hub.
model = T5ForConditionalGeneration.from_pretrained(
    'yelpfeast/byt5-base-english-ocr-correction',
    use_auth_token=True,
)
tokenizer = AutoTokenizer.from_pretrained(
    "yelpfeast/byt5-base-english-ocr-correction",
    use_auth_token=True,
)

# Encode the noisy sentence and generate the correction with sampling disabled.
inputs = tokenizer(augmented_text, return_tensors="pt", padding=True)
output_sequences = model.generate(
    input_ids=inputs["input_ids"],
    attention_mask=inputs["attention_mask"],
    do_sample=False,
)

print(tokenizer.batch_decode(output_sequences, skip_special_tokens=True))

Life i8 like a 6ux uf chocolates
['Life is like a box of chocolates']
