In [1]:
import os
import requests
import cv2
import torch
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from torch.utils.data import Dataset
from sklearn.model_selection import train_test_split
from bs4 import BeautifulSoup
from PIL import Image
from io import BytesIO

In [2]:
# All inputs and outputs are addressed relative to the directory the notebook runs in.
base_dir = os.getcwd()

# Saved HTML pages that the image URLs were scraped from.
html_path = base_dir + '/htmls/nychinaren_url/'
all_files = os.listdir(html_path)

# Hand-labelled image URLs / labels used for fine-tuning (one entry per line).
image_src_input_path_train = base_dir + '/image_phone_data/training_data/image_urls.txt'
label_input_path_train = base_dir + '/image_phone_data/training_data/image_labels.txt'

# Image URLs to transcribe and the file the predicted labels are written to.
image_src_input_path = base_dir + '/image_phone_data/real_data/image_urls.txt'
label_input_path = base_dir + '/image_phone_data/real_data/image_labels.txt'

# Directory for training checkpoints and the saved fine-tuned model.
model_path = base_dir + '/image_phone_data/model/'

In [3]:
# Load the pretrained TrOCR processor/model pair and configure generation tokens
from transformers import TrOCRProcessor, VisionEncoderDecoderModel, EarlyStoppingCallback

BASE_CHECKPOINT = "microsoft/trocr-small-printed"

# Prefer the GPU when one is visible, otherwise run on the CPU.
device_name = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(device_name)

processor = TrOCRProcessor.from_pretrained(BASE_CHECKPOINT)
model = VisionEncoderDecoderModel.from_pretrained(BASE_CHECKPOINT)
model = model.to(device)

# Decoding starts at the tokenizer's CLS token, stops at its SEP token and pads with its PAD token.
model.config.decoder_start_token_id = processor.tokenizer.cls_token_id
model.config.eos_token_id = processor.tokenizer.sep_token_id
model.config.pad_token_id = processor.tokenizer.pad_token_id

# Mirror the decoder's vocabulary size on the top-level config.
model.config.vocab_size = model.config.decoder.vocab_size

# Optional generation settings, currently left at their defaults:
#model.config.num_beams = 4
#model.config.early_stopping = True


Some weights of VisionEncoderDecoderModel were not initialized from the model checkpoint at microsoft/trocr-small-printed and are newly initialized: ['encoder.pooler.dense.bias', 'encoder.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [4]:
# Save image sources into txt.
# This one-off scraping step is kept disabled by wrapping it in a string literal (which is
# why the cell's output is just the string's repr). When enabled it walks every saved HTML
# page in html_path, takes each div.frm_rent that also has the frm_phone class, collects the
# src of the <img> inside it, and writes one URL per line to image_src_input_path.
# Remove the surrounding triple quotes to run it again.
"""
img_srcs = []

for fn in all_files:
    with open(html_path + fn, "r") as f:
        html_content = f.read()
        soup = BeautifulSoup(html_content, 'html.parser')
            
        contact_img_divs = soup.select('div.frm_rent')
        for div in contact_img_divs:
            if 'frm_phone' in div['class']:
                img_src = div.find('img')['src']
                img_srcs.append(img_src)

with open(image_src_input_path, "w") as f:
    for src in img_srcs:
        f.write(src + "\n")
"""

'\nimg_srcs = []\n\nfor fn in all_files:\n    with open(html_path + fn, "r") as f:\n        html_content = f.read()\n        soup = BeautifulSoup(html_content, \'html.parser\')\n            \n        contact_img_divs = soup.select(\'div.frm_rent\')\n        for div in contact_img_divs:\n            if \'frm_phone\' in div[\'class\']:\n                img_src = div.find(\'img\')[\'src\']\n                img_srcs.append(img_src)\n\nwith open(image_src_input_path, "w") as f:\n    for src in img_srcs:\n        f.write(src + "\n")\n'

In [5]:
def ocr(src, processor, model, timeout=10):
    """Download an image and transcribe it with a TrOCR-style processor/model pair.

    The image is converted to grayscale, binarised, has its darker pixels
    thickened, and is padded with a white border before being passed to the
    processor. The preprocessing matches what SkewedDigitsDataset applies to
    the training images.

    Args:
        src: URL of the image to transcribe.
        processor: Object that turns a PIL image into ``pixel_values`` when
            called and offers ``batch_decode`` (e.g. a ``TrOCRProcessor``).
        model: Model exposing ``generate`` and ``device`` (e.g. a
            ``VisionEncoderDecoderModel``).
        timeout: Seconds to wait for the HTTP response before giving up.

    Returns:
        The decoded text of the first generated sequence.

    Raises:
        requests.RequestException: If the download fails, times out or the
            server answers with an error status.
        ValueError: If the downloaded payload cannot be decoded as an image.
    """
    # A timeout keeps one unresponsive host from hanging a long batch run, and
    # raise_for_status stops an HTML error page from being treated as image data.
    response = requests.get(src, timeout=timeout)
    response.raise_for_status()

    img = cv2.imdecode(np.frombuffer(response.content, dtype=np.uint8), cv2.IMREAD_COLOR)
    if img is None:
        # cv2.imdecode signals failure by returning None instead of raising.
        raise ValueError(f"Could not decode an image from {src!r}")

    img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
    _, binary = cv2.threshold(img, 200, 255, cv2.THRESH_BINARY)

    # Dilation grows white regions, so invert first: this thickens the pixels that
    # were at or below the threshold. Then invert back.
    img = cv2.bitwise_not(binary)
    img = cv2.dilate(img, np.ones((2, 2), np.uint8), iterations=1)
    img = cv2.bitwise_not(img)

    # White 20px margin on every side.
    img = cv2.copyMakeBorder(img, 20, 20, 20, 20, cv2.BORDER_CONSTANT, value=[255, 255, 255])

    image = Image.fromarray(img).convert("RGB")
    # Use the model's own device so the function works for any model passed in,
    # not only one that lives on the notebook-level `device`.
    pixel_values = processor(image, return_tensors="pt").pixel_values.to(model.device)
    generated_ids = model.generate(pixel_values)
    return processor.batch_decode(generated_ids, skip_special_tokens=True)[0]

In [6]:
# Form some initial labels to help with manual labeling.
# Disabled by wrapping it in a string literal (the cell's output is just the string's repr).
# When enabled it runs ocr() with the current processor/model on every URL in
# image_src_input_path_train and writes one predicted label per line to
# label_input_path_train. The label file is opened in "w" mode, so re-enabling this
# overwrites whatever that file currently contains.
"""
with open(image_src_input_path_train, "r") as img_srcs, open(label_input_path_train, "w") as f:  
    for src in img_srcs:        
        generated_text = ocr(src.strip(), processor, model)
        f.write(generated_text + "\n")
"""

'\nwith open(image_src_input_path_train, "r") as img_srcs, open(label_input_path_train, "w") as f:  \n    for src in img_srcs:        \n        generated_text = ocr(src.strip(), processor, model)\n        f.write(generated_text + "\n")\n'

In [7]:
# Convert the txt files into a dataframe
def _read_stripped_lines(path):
    """Return every line of the text file at `path` with surrounding whitespace removed."""
    with open(path, "r") as handle:
        return [line.strip() for line in handle]


img_src_lines = _read_stripped_lines(image_src_input_path_train)
label_lines = _read_stripped_lines(label_input_path_train)

# The two files are paired line by line. zip() would silently drop the unmatched
# tail and hide a misalignment between images and labels, so fail loudly instead.
if len(img_src_lines) != len(label_lines):
    raise ValueError(
        f"Line count mismatch: {len(img_src_lines)} image URLs in "
        f"{image_src_input_path_train} vs {len(label_lines)} labels in "
        f"{label_input_path_train}"
    )

df_rows = list(zip(img_src_lines, label_lines))

data_df = pd.DataFrame(df_rows, columns=['img_src', 'label'])
data_df

Unnamed: 0,img_src,label
0,https://ny.nychinaren.com/images/topic_info/33...,917-436-9760
1,https://ny.nychinaren.com/images/topic_info/33...,3478635699
2,https://ny.nychinaren.com/images/topic_info/33...,212-672-6486
3,https://ny.nychinaren.com/images/topic_info/33...,6463547367
4,https://ny.nychinaren.com/images/topic_info/33...,6464009707
...,...,...
218,https://ny.nychinaren.com/images/topic_info/33...,631-231-8999
219,https://ny.nychinaren.com/images/topic_info/33...,6173814478
220,https://ny.nychinaren.com/images/topic_info/33...,(347) 884-2610
221,https://ny.nychinaren.com/images/topic_info/33...,347-838-1938


In [8]:
# Pytorch dataset class based on our dataframe
class SkewedDigitsDataset(Dataset):
    """Dataset of (image URL, text label) rows prepared for TrOCR fine-tuning.

    Each item is a dict with:
        "pixel_values": the processor output for the preprocessed image, with
            the leading batch dimension removed.
        "labels": a 1-D tensor of token ids of length ``max_target_length``,
            where padding positions are replaced by -100.

    Args:
        df: DataFrame with an ``img_src`` column (image URLs) and a ``label``
            column (target text). Rows are addressed by position, so the
            frame's index values do not matter.
        processor: Processor that is callable on a PIL image and exposes a
            ``tokenizer`` (e.g. a ``TrOCRProcessor``).
        max_target_length: Length every label sequence is padded or truncated to.
        timeout: Seconds to wait for each image download.
    """

    def __init__(self, df, processor, max_target_length=20, timeout=10):
        self.df = df
        self.processor = processor
        self.max_target_length = max_target_length
        self.timeout = timeout
        # Raw image bytes keyed by URL, so each image is downloaded once instead
        # of once per access (i.e. once per epoch during training).
        self._raw_cache = {}

    def __len__(self):
        return len(self.df)

    def _fetch_bytes(self, img_src):
        """Return the raw bytes behind `img_src`, downloading them on first use."""
        raw = self._raw_cache.get(img_src)
        if raw is None:
            response = requests.get(img_src, timeout=self.timeout)
            response.raise_for_status()
            raw = response.content
            self._raw_cache[img_src] = raw
        return raw

    def _load_pixel_values(self, img_src):
        """Download, preprocess and featurise one image.

        Image preprocessing for better accuracy: convert to binary, dilate,
        add border, then get pixel values from the processor.

        Raises:
            ValueError: If the downloaded payload cannot be decoded as an image.
        """
        raw = self._fetch_bytes(img_src)
        img = cv2.imdecode(np.frombuffer(raw, dtype=np.uint8), cv2.IMREAD_COLOR)
        if img is None:
            # cv2.imdecode signals failure by returning None instead of raising.
            raise ValueError(f"Could not decode an image from {img_src!r}")

        img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
        _, binary = cv2.threshold(img, 200, 255, cv2.THRESH_BINARY)

        # Dilation grows white regions, so invert first: this thickens the pixels
        # that were at or below the threshold. Then invert back.
        img = cv2.bitwise_not(binary)
        img = cv2.dilate(img, np.ones((2, 2), np.uint8), iterations=1)
        img = cv2.bitwise_not(img)

        # White 20px margin on every side.
        img = cv2.copyMakeBorder(img, 20, 20, 20, 20, cv2.BORDER_CONSTANT, value=[255, 255, 255])

        img = Image.fromarray(img).convert("RGB")
        pixel_values = self.processor(img, return_tensors="pt").pixel_values
        # Drop only the batch dimension the processor adds for a single image.
        return pixel_values.squeeze(0)

    def _encode_label(self, text):
        """Tokenise `text` into a fixed-length tensor of label ids."""
        tokenizer = self.processor.tokenizer
        # truncation=True keeps every label exactly max_target_length long, so
        # the examples can always be stacked into a batch.
        token_ids = tokenizer(text,
                              padding="max_length",
                              max_length=self.max_target_length,
                              truncation=True).input_ids
        # Important: make sure that PAD tokens are ignored by the loss function
        pad_id = tokenizer.pad_token_id
        token_ids = [token_id if token_id != pad_id else -100 for token_id in token_ids]
        return torch.tensor(token_ids)

    def __getitem__(self, idx):
        # Positional lookup: works whatever index the DataFrame carries.
        row = self.df.iloc[idx]
        pixel_values = self._load_pixel_values(row['img_src'])
        labels = self._encode_label(row['label'])
        return {"pixel_values": pixel_values, "labels": labels}

In [9]:
# Hold out 20% of the labelled rows for validation (seeded so the split is repeatable).
# Each part gets a fresh 0..n-1 index because SkewedDigitsDataset looks rows up by idx.
train_df, test_df = (
    part.reset_index(drop=True)
    for part in train_test_split(data_df, test_size=0.2, random_state=42)
)

training_set = SkewedDigitsDataset(df=train_df, processor=processor)
validation_set = SkewedDigitsDataset(df=test_df, processor=processor)

In [10]:
# Training-time helpers: an optimizer and the character-error-rate (CER) metric.
import torch.optim as optim
from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments
from datasets import load_metric

# NOTE(review): `optimizer` is not passed to the Seq2SeqTrainer constructed in the
# training-setup cell, so this learning rate does not appear to take effect there.
# Confirm whether it was meant to be supplied to the trainer.
optimizer = optim.AdamW(model.parameters(), lr=3e-4) # Change learning rate?
# Metric object used by compute_cer to score decoded predictions against decoded labels.
cer_metric = load_metric('cer')
def compute_cer(pred):
    """Compute the character error rate for one evaluation pass.

    Args:
        pred: Object with ``predictions`` (generated token ids) and
            ``label_ids`` (reference token ids, with -100 at ignored positions).

    Returns:
        ``{"cer": value}`` where ``value`` is whatever ``cer_metric.compute``
        returns for the decoded predictions and references.
    """
    pad_token_id = processor.tokenizer.pad_token_id

    pred_ids = pred.predictions
    if isinstance(pred_ids, tuple):
        # Some model outputs arrive as a tuple; the token ids come first.
        pred_ids = pred_ids[0]

    # -100 marks positions to ignore and cannot be decoded, so swap it for the
    # pad token. np.where builds new arrays, leaving the caller's label_ids
    # (and predictions) untouched instead of overwriting them in place.
    pred_ids = np.where(pred_ids == -100, pad_token_id, pred_ids)
    label_ids = np.where(pred.label_ids == -100, pad_token_id, pred.label_ids)

    pred_str = processor.batch_decode(pred_ids, skip_special_tokens=True)
    label_str = processor.batch_decode(label_ids, skip_special_tokens=True)

    cer = cer_metric.compute(predictions=pred_str, references=label_str)
    return {"cer": cer}

  cer_metric = load_metric('cer')
You can avoid this message in future by passing the argument `trust_remote_code=True`.
Passing `trust_remote_code=True` will be mandatory to load this metric from the next major release of `datasets`.


In [11]:
# Training configuration: evaluate, log and checkpoint once per epoch, keep at most
# five checkpoints, and reload the best checkpoint when training finishes.
# NOTE(review): no metric_for_best_model is set, so "best" (and the early-stopping
# check) presumably falls back to validation loss rather than CER. Confirm against
# the TrainingArguments documentation.
training_args = Seq2SeqTrainingArguments(
    predict_with_generate=True,
    evaluation_strategy='epoch',
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    # Mixed precision needs a GPU; the notebook falls back to CPU when CUDA is
    # unavailable, so only enable fp16 when CUDA is actually present.
    fp16=torch.cuda.is_available(),
    output_dir=model_path,
    logging_strategy='epoch',
    save_strategy='epoch',
    save_total_limit=5,
    num_train_epochs=20,
    load_best_model_at_end=True
)

# NOTE(review): the `optimizer` built earlier is not passed here, so the trainer
# uses an optimizer of its own. Confirm that this is intended.
trainer = Seq2SeqTrainer(
    model=model,
    tokenizer=processor.feature_extractor,
    args=training_args,
    compute_metrics=compute_cer,
    train_dataset=training_set,
    eval_dataset=validation_set,
    # Stop once the monitored metric fails to improve by at least 0.03 for three evaluations in a row.
    callbacks=[EarlyStoppingCallback(early_stopping_patience=3, early_stopping_threshold=0.03)]
)



In [12]:
# Fine-tune the model; training loss, validation loss and CER are reported once per epoch.
trainer.train()
# Best result noted from an earlier run: 0.3527, 0.53119, 0.236948
# (presumably training loss, validation loss, CER — TODO confirm).

Epoch,Training Loss,Validation Loss,Cer
1,8.7822,4.516408,0.861446
2,4.1637,3.878091,1.138554
3,3.5649,3.484702,0.835341
4,3.0451,2.618145,0.594378
5,1.5394,1.275967,0.184739
6,0.6571,0.98295,0.301205
7,0.5474,0.939613,0.146586
8,0.4105,0.946572,0.232932
9,0.3116,0.744497,0.138554
10,0.2826,0.69764,0.253012




TrainOutput(global_step=368, training_loss=1.518221045317857, metrics={'train_runtime': 1047.2449, 'train_samples_per_second': 3.399, 'train_steps_per_second': 0.439, 'total_flos': 3.406911097704284e+17, 'train_loss': 1.518221045317857, 'epoch': 16.0})

In [13]:
# Persist the fine-tuned weights. model_path already ends with a separator, so build
# the target with os.path.join rather than concatenating a second "/".
model.save_pretrained(os.path.join(model_path, "trained_model_test"))

In [14]:
# Reload the fine-tuned weights for inference. The processor is loaded from the
# original checkpoint. model_path already ends with a separator, so the directory
# is built with os.path.join rather than concatenating a second "/".
trained_processor = TrOCRProcessor.from_pretrained("microsoft/trocr-small-printed")
trained_model = VisionEncoderDecoderModel.from_pretrained(
    os.path.join(model_path, "trained_model_test")
).to(device)

In [15]:
# Transcribe every real-data image with the fine-tuned model, writing one predicted
# label per line so the output lines up with the URL file.
with open(image_src_input_path, "r") as url_file, open(label_input_path, "w") as label_file:
    for raw_line in url_file:
        image_url = raw_line.strip()
        predicted_label = ocr(image_url, trained_processor, trained_model)
        label_file.write(predicted_label + "\n")

