In [1]:
!pip install -q transformers datasets evaluate sacrebleu jiwer pandas torch

import random
import string
import numpy as np
import pandas as pd
import torch
from transformers import (
    BartTokenizerFast,
    BartForConditionalGeneration,
    Seq2SeqTrainingArguments,
    Seq2SeqTrainer,
    DataCollatorForSeq2Seq
)
import evaluate
from sklearn.model_selection import train_test_split

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m51.8/51.8 kB[0m [31m4.8 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m84.0/84.0 kB[0m [31m7.8 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m104.1/104.1 kB[0m [31m9.1 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m363.4/363.4 MB[0m [31m4.1 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m13.8/13.8 MB[0m [31m105.7 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m24.6/24.6 MB[0m [31m25.8 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m883.7/883.7 kB[0m [31m64.5 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m664.8/664.8 MB[0m [31m1.4 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━

In [7]:
from transformers import T5TokenizerFast
from transformers import T5ForConditionalGeneration
from transformers import Seq2SeqTrainingArguments
from datasets import Dataset, DatasetDict

#Cell 2: Load Dataset Files

In [3]:
def load_sentences(file_path):
    with open(file_path, 'r', encoding='utf-8') as f:
        return [line.strip().split('\t')[0] for line in f if line.strip()]

train_sentences = load_sentences('tune.tsv')[:5000]
val_sentences = load_sentences('validation.tsv')[:1000]
test_sentences = load_sentences('test.tsv')[:1000]

print(f"Loaded {len(train_sentences)} training sentences")
print(f"Loaded {len(val_sentences)} validation sentences")
print(f"Loaded {len(test_sentences)} test sentences")

Loaded 5000 training sentences
Loaded 1000 validation sentences
Loaded 1000 test sentences


#Cell 3: Error Generation Functions


In [4]:
def introduce_errors(sentence, min_errors=3, max_errors=5):
    words = sentence.split()
    if not words:
        return sentence

    current_words = list(words)
    num_errors = random.randint(min_errors, max_errors)

    for _ in range(num_errors):
        if not current_words:
            break

        word_idx = random.randint(0, len(current_words) - 1)
        word = current_words[word_idx]
        if len(word) < 1:
            continue

        op = random.choice(["delete", "insert", "substitute", "transpose", "duplicate"])

        if op == "delete" and len(word) > 0:
            pos = random.randint(0, len(word)-1)
            word = word[:pos] + word[pos+1:]
        elif op == "insert":
            pos = random.randint(0, len(word))
            word = word[:pos] + random.choice(string.ascii_lowercase) + word[pos:]
        elif op == "substitute":
            pos = random.randint(0, len(word)-1)
            word = word[:pos] + random.choice(string.ascii_lowercase) + word[pos+1:]
        elif op == "transpose" and len(word) > 1:
            pos = random.randint(0, len(word)-2)
            word = word[:pos] + word[pos+1] + word[pos] + word[pos+2:]
        elif op == "duplicate":
            current_words.insert(word_idx + 1, word)
            continue

        current_words[word_idx] = word

    return ' '.join(current_words)

def generate_pairs(sentences, versions=2):
    pairs = []
    for sent in sentences:
        for _ in range(versions):
            corrupted = introduce_errors(sent)
            if corrupted != sent:
                pairs.append({'input_text': corrupted, 'target_text': sent})
    return pairs

train_pairs = generate_pairs(train_sentences)
val_pairs = generate_pairs(val_sentences, versions=1)
test_pairs = generate_pairs(test_sentences, versions=1)

print(f"Generated {len(train_pairs)} training pairs")
print(f"Generated {len(val_pairs)} validation pairs")
print(f"Generated {len(test_pairs)} test pairs")

Generated 10000 training pairs
Generated 1000 validation pairs
Generated 1000 test pairs


#Cell 4: Create DatasetDict

In [8]:
datasets = DatasetDict({
    'train': Dataset.from_list(train_pairs),
    'validation': Dataset.from_list(val_pairs),
    'test': Dataset.from_list(test_pairs)
})

print(datasets)

DatasetDict({
    train: Dataset({
        features: ['input_text', 'target_text'],
        num_rows: 10000
    })
    validation: Dataset({
        features: ['input_text', 'target_text'],
        num_rows: 1000
    })
    test: Dataset({
        features: ['input_text', 'target_text'],
        num_rows: 1000
    })
})


#Cell 5: Tokenization and Data Preparation


In [9]:
MODEL_NAME = "facebook/bart-base"
tokenizer = BartTokenizerFast.from_pretrained(MODEL_NAME)

def preprocess_function(examples):
    inputs = tokenizer(examples["input_text"], max_length=128, truncation=True, padding="max_length")
    labels = tokenizer(text_target=examples["target_text"], max_length=128, truncation=True, padding="max_length")
    inputs["labels"] = labels["input_ids"]
    return inputs

tokenized_datasets = datasets.map(preprocess_function, batched=True)
print(tokenized_datasets)

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


vocab.json:   0%|          | 0.00/899k [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.36M [00:00<?, ?B/s]

config.json:   0%|          | 0.00/1.72k [00:00<?, ?B/s]

Map:   0%|          | 0/10000 [00:00<?, ? examples/s]

Map:   0%|          | 0/1000 [00:00<?, ? examples/s]

Map:   0%|          | 0/1000 [00:00<?, ? examples/s]

DatasetDict({
    train: Dataset({
        features: ['input_text', 'target_text', 'input_ids', 'attention_mask', 'labels'],
        num_rows: 10000
    })
    validation: Dataset({
        features: ['input_text', 'target_text', 'input_ids', 'attention_mask', 'labels'],
        num_rows: 1000
    })
    test: Dataset({
        features: ['input_text', 'target_text', 'input_ids', 'attention_mask', 'labels'],
        num_rows: 1000
    })
})


#Cell 6: Initialize Model and Training Setup


In [10]:
model = BartForConditionalGeneration.from_pretrained(MODEL_NAME)
data_collator = DataCollatorForSeq2Seq(tokenizer, model=model)

wer_metric = evaluate.load("wer")
bleu_metric = evaluate.load("sacrebleu")

def compute_metrics(eval_preds):
    preds, labels = eval_preds
    if isinstance(preds, tuple):
        preds = preds[0]

    decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

    wer = wer_metric.compute(predictions=decoded_preds, references=decoded_labels)
    bleu = bleu_metric.compute(predictions=decoded_preds, references=[[label] for label in decoded_labels])

    return {"wer": wer, "bleu": bleu["score"]}

model.safetensors:   0%|          | 0.00/558M [00:00<?, ?B/s]

Downloading builder script:   0%|          | 0.00/4.49k [00:00<?, ?B/s]

Downloading builder script:   0%|          | 0.00/8.15k [00:00<?, ?B/s]

#Cell 7: Training Configuration


In [11]:
training_args = Seq2SeqTrainingArguments(
    output_dir="spelling_correction_model",
    eval_strategy="steps",
    eval_steps=500,
    save_steps=500,
    learning_rate=5e-5,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    weight_decay=0.01,
    save_total_limit=3,
    num_train_epochs=3,
    predict_with_generate=True,
    fp16=True,
    load_best_model_at_end=True,
    metric_for_best_model="wer",
    greater_is_better=False,
    report_to="none"
)

trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_datasets["train"],
    eval_dataset=tokenized_datasets["validation"],
    data_collator=data_collator,
    tokenizer=tokenizer,
    compute_metrics=compute_metrics
)

  trainer = Seq2SeqTrainer(


#Cell 8: Training and Evaluation


In [12]:
trainer.train()

test_results = trainer.evaluate(
    eval_dataset=tokenized_datasets["test"],
    metric_key_prefix="test"
)

print("\nTest Results:")
print(f"WER: {test_results['test_wer']:.4f}")
print(f"BLEU: {test_results['test_bleu']:.4f}")

trainer.save_model("final_spelling_correction_model")

Step,Training Loss,Validation Loss,Wer,Bleu
500,0.6048,0.049643,0.561572,27.841936
1000,0.0462,0.04354,0.558474,28.244433
1500,0.0317,0.043014,0.557825,28.379416


There were missing keys in the checkpoint model loaded: ['model.encoder.embed_tokens.weight', 'model.decoder.embed_tokens.weight', 'lm_head.weight'].



Test Results:
WER: 0.5567
BLEU: 28.6030


#Cell 9: Example Corrections






In [13]:
print("Displaying predictions using `trainer.model`...\n")
print("--- Example Predictions on Test Set (first 10 examples) ---")
print("------------------------------------------------------------")

for i in range(10):
    example = datasets["test"][i]
    inputs = tokenizer(example["input_text"], return_tensors="pt").to(model.device)
    outputs = model.generate(**inputs, max_length=128)
    prediction = tokenizer.decode(outputs[0], skip_special_tokens=True)

    print(f"Example {i+1}:")
    print(f"  Input (Misspelled)   : {repr(example['input_text'])}")
    print(f"  Predicted (Corrected): {repr(prediction)}")
    print(f"  Actual (Correct)     : {repr(example['target_text'])}")
    print("------------------------------------------------------------")

Displaying predictions using `trainer.model`...

--- Example Predictions on Test Set (first 10 examples) ---
------------------------------------------------------------
Example 1:
  Input (Misspelled)   : "' Bandolier  Budgie ' , ia free iTunes app for iPad , iPhone and iPod touch , released in December 2011 , tells the story of the making of Bandolier in the band 's own words - including including an extensive audio interview with Burke Shelley "
  Predicted (Corrected): "' Bandolier , Budgie ' , a free iTunes app for iPad , iPhone and iPod touch , released in December 2011 , tells the story of the making of Bandoliers in the band 's own words - including an extensive audio interview with Burke Shelley ."
  Actual (Correct)     : "' Bandolier - Budgie ' , a free iTunes app for iPad , iPhone and iPod touch , released in December 2011 , tells the story of the making of Bandolier in the band 's own words - including an extensive audio interview with Burke Shelley ."
--------------------