In [None]:
from collections import Counter
from pathlib import Path

import matplotlib.pylab as plt
import numpy as np
import pandas as pd

import torch
from transformers import VisionEncoderDecoderModel, default_data_collator, TrOCRProcessor
from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments
from datasets import load_metric

from sequence_mnist.model import SequenceMNIST
from tests.test_sequence_mnist import test_sample

In [None]:
# Run the project's SequenceMNIST smoke test (tests.test_sequence_mnist.test_sample) before training.
test_sample()

In [None]:
# Pretrained TrOCR checkpoint shared by the model and its processor.
TROCR_CHECKPOINT = 'microsoft/trocr-small-printed'
# Data directory handed to SequenceMNIST (presumably a torchvision-style MNIST root — TODO confirm).
DATA_ROOT = "/tmp/data"

# No `num_labels`: VisionEncoderDecoderModel is a seq2seq generator without a
# classification head, so that option would only add an unused label map to the config.
model = VisionEncoderDecoderModel.from_pretrained(TROCR_CHECKPOINT)
processor = TrOCRProcessor.from_pretrained(TROCR_CHECKPOINT)
train_dataset = SequenceMNIST(train=True, processor=processor, root=DATA_ROOT, download=True)
test_dataset = SequenceMNIST(train=False, processor=processor, root=DATA_ROOT, download=True)

In [None]:
# Tokens for creating the decoder_input_ids from the labels
model.config.decoder_start_token_id = processor.tokenizer.cls_token_id
model.config.pad_token_id = processor.tokenizer.pad_token_id
# The decoder still predicts over the full tokenizer vocabulary: digits are ordinary
# tokenizer tokens (outputs are decoded with processor.batch_decode), not ids 0-9.
# Keep the top-level vocab_size in sync with the decoder instead of setting it to 10.
model.config.vocab_size = model.config.decoder.vocab_size

In [None]:

# Seq2Seq fine-tuning configuration. With `predict_with_generate`, evaluation
# runs `model.generate`, so `compute_metrics` sees generated ids rather than logits.
training_args = Seq2SeqTrainingArguments(
    output_dir="results/",
    # schedule
    num_train_epochs=2,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    # evaluation every 200 steps, using generation
    # NOTE(review): newer transformers releases rename this to `eval_strategy`.
    evaluation_strategy="steps",
    eval_steps=200,
    predict_with_generate=True,
    # bookkeeping
    logging_steps=2,
    save_steps=1000,
)

In [None]:
### HUGGING FACE

# NOTE(review): `datasets.load_metric` is removed in datasets>=3.0; `evaluate.load("accuracy")` replaces it.
metric = load_metric("accuracy")

def compute_metrics(eval_pred):
    """Sequence-level exact-match accuracy for ``Seq2SeqTrainer`` evaluation.

    With ``predict_with_generate=True`` the trainer hands over generated token ids
    of shape (batch, gen_len) rather than logits, so an ``argmax`` over the last
    axis would collapse every sequence to a single index. Instead, predictions
    and labels are both decoded to strings and compared whole.

    Args:
        eval_pred: ``(predictions, label_ids)`` from the trainer. ``predictions``
            may be generated ids (2-D) or logits (3-D); both arrays may contain
            ``-100`` as an ignore/padding value.

    Returns:
        dict: ``{"accuracy": float}``, the fraction of samples whose decoded
        prediction exactly equals the decoded label.
    """
    predictions, label_ids = eval_pred
    # Some models return a tuple (logits, ...); only the first element is needed.
    if isinstance(predictions, tuple):
        predictions = predictions[0]
    predictions = np.asarray(predictions)
    if predictions.ndim == 3:  # raw logits (predict_with_generate=False)
        predictions = np.argmax(predictions, axis=-1)

    # -100 is not a valid token id (ignored label positions / cross-batch padding),
    # so replace it with the pad token, which skip_special_tokens then drops.
    pad_id = processor.tokenizer.pad_token_id
    predictions = np.where(predictions < 0, pad_id, predictions)
    label_ids = np.asarray(label_ids)
    label_ids = np.where(label_ids == -100, pad_id, label_ids)

    pred_str = processor.batch_decode(predictions, skip_special_tokens=True)
    label_str = processor.batch_decode(label_ids, skip_special_tokens=True)

    # 1 for an exact match, 0 otherwise, scored against an all-ones reference.
    matches = [int(pred == label) for pred, label in zip(pred_str, label_str)]
    return metric.compute(predictions=matches, references=[1] * len(matches))


In [None]:
# Wire model, datasets, collator and metric into the seq2seq trainer.
# The image feature extractor is supplied in the `tokenizer` slot.
trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=test_dataset,
    data_collator=default_data_collator,
    tokenizer=processor.feature_extractor,
    compute_metrics=compute_metrics,
)

In [None]:
# Fine-tune; with evaluation_strategy="steps", test_dataset is evaluated every eval_steps steps.
trainer.train()

In [None]:
def predict(model, pixel_values, **kwargs):
    """Generate and decode a prediction for a single image.

    Args:
        model: Encoder-decoder model exposing ``generate`` and ``device``.
        pixel_values: One preprocessed image tensor without a batch dimension;
            a batch dimension of size 1 is added here.
        **kwargs: Any other fields of the sample (e.g. its ``text``). They are
            only echoed in the printed log line.

    Returns:
        str: The decoded prediction with special tokens removed.
    """
    # After trainer.train() the model may sit on a GPU while dataset samples are
    # CPU tensors; move the input to the model's device to avoid a mismatch.
    batch = pixel_values.unsqueeze(dim=0).to(model.device)
    generated_ids = model.generate(batch)
    generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
    print(f'Prediction: {generated_text}, Sample: {kwargs}')
    return generated_text

In [None]:
# Predict a single instance
# Every field of the sample is forwarded; fields other than pixel_values are only echoed in predict's log line.
predict(model=model, **test_dataset[0])

In [None]:
# Save only the fine-tuned weights. `results/` may not exist yet (the Trainer only
# creates its output_dir when it writes checkpoints/logs), so create it first.
weights_path = Path('results') / 'mnist_trained_trocr.pth'
weights_path.parent.mkdir(parents=True, exist_ok=True)
torch.save(model.state_dict(), weights_path)

The similarity metric isn't perfect: it ignores character order and does not penalise extra predicted characters. It is still a good-enough proxy for a quick evaluation.

In [None]:
def string_similarity(base_string: str, comp_string: str) -> float:
    """Order-insensitive character-overlap score of ``comp_string`` against ``base_string``.

    A character of ``base_string`` counts as matched when ``comp_string`` also
    contains it (multiplicity respected, order ignored). Characters that occur
    only in ``comp_string`` are not penalised.

    Returns:
        ``1 - unmatched / len(base_string)``, with a tiny epsilon in the
        denominator so an empty ``base_string`` yields 1.0 instead of raising.
    """
    expected = Counter(base_string)
    observed = Counter(comp_string)
    # Multiset intersection keeps min(expected, observed) for every character.
    matched = sum((expected & observed).values())
    unmatched = len(base_string) - matched
    return 1 - unmatched / (len(base_string) + 1e-9)

In [None]:
def eval_metrics(pred: str, gt: str) -> dict:
    """Score one prediction against its ground-truth string.

    Args:
        pred: Decoded model output.
        gt: Ground-truth label text.

    Returns:
        dict with ``'Correct'`` (exact string equality) and ``'Similarity'``
        (``string_similarity`` of ``pred`` against ``gt``, or ``0.`` when
        ``pred`` is empty).
    """
    # An empty prediction scores zero outright instead of being compared.
    if len(pred) == 0:
        similarity = 0.
    else:
        similarity = string_similarity(base_string=gt, comp_string=pred)
    return {'Correct': pred == gt, 'Similarity': similarity}

In [None]:
# Run the fine-tuned model over the whole test split and score each prediction.
# Disable dropout for inference: the Trainer may have left the model in train mode.
model.eval()

test_results = []
for sample in test_dataset:
    pred = predict(model=model, pixel_values=sample['pixel_values'])
    gt = sample['text']
    # Columns: Correct, Similarity, Prediction, Ground Truth.
    test_results.append({
        **eval_metrics(pred, gt),
        'Prediction': pred,
        'Ground Truth': gt,
    })

In [None]:
# One row per test sample with columns Correct, Similarity, Prediction, Ground Truth.
test_results_df = pd.DataFrame(test_results)

# Data Analysis

In [None]:
# Character frequencies over every ground-truth `text` in the training split.
train_counter = Counter(char for sample in train_dataset for char in sample['text'])


In [None]:
# Display the per-character training-label frequencies as a plain dict.
dict(train_counter)

In [None]:
# Concatenate the first five raw training images side by side for a quick visual check.
# NOTE(review): assumes each entry of train_dataset.data is a 2-D (H, W) image, so hstack joins
# them along the width; torch.Tensor(...) converts each to the default float dtype — confirm.
images = torch.hstack([torch.Tensor(img) for img in train_dataset.data[0:5]])

In [None]:
# Show the concatenated image strip in grayscale.
plt.imshow(images, cmap='gray')

In [None]:
# Checking out the images: each re-run of this cell advances to the next training image.
# NOTE: intentionally stateful (reads and updates `i`), so it is not idempotent.
try:
    # Wrap around instead of raising IndexError after the last image.
    i = (i + 1) % len(train_dataset.data)
except NameError:  # first run in this kernel: `i` does not exist yet
    i = 0

plt.imshow(train_dataset.data[i], cmap="gray")
plt.title(f"train_dataset.data[{i}]");

In [None]:
# Number of samples per (Correct, Similarity) combination; the Prediction and
# Ground Truth columns both show that same count.
test_results_df.groupby(['Correct', 'Similarity']).count()

In [None]:
# Samples scoring zero similarity even though the model produced a non-empty prediction.
test_results_df[(test_results_df['Similarity'] == 0.0) & (test_results_df['Prediction'] != '')]

All results with similarity 0 are samples for which the model produced no prediction (an empty string). When a prediction is made, it is likely to be very similar to the ground truth (4/5 characters, i.e. 0.8 similarity): 490/573 ≈ 86%.

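Interesting edge case: the model sometimes predicts spaces, so a prediction can be marked incorrect while still scoring similarity 1.0 (extra characters are not penalised):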

In [None]:
# Wrong predictions that still contain every ground-truth character (similarity 1.0).
test_results_df.loc[
    test_results_df['Correct'].eq(False) & test_results_df['Similarity'].eq(1.0)
]

In [None]:
# Headline numbers: share of exact sequence matches and mean character overlap.
n_correct = (test_results_df['Correct'] == True).sum()
accuracy = n_correct / len(test_results_df)
avg_sim = test_results_df['Similarity'].mean()

print(f'Complete Match Accuracy: {accuracy*100:.1f}%')
print(f'Mean Similarity Measure: {avg_sim:.2f}')