<a href="https://colab.research.google.com/github/samj786/NCVPRIPG-AutoEval/blob/main/sub_optimized.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

**This notebook walks you through each step of the pipeline, with visualizations along the way.**


---


*If you are only interested in the final result, take a look at the more optimized version of this notebook [here](https://colab.research.google.com/drive/1bEv_Al6JCuTK0-JVQ3nl5TOUZCqjT0OK?usp=sharing).*

# Install Required Packages
First, we need to install the required packages: the `pytesseract` Python wrapper (via pip), the `tesseract-ocr` engine (via apt), `opencv`, and `transformers`.

Note: `opencv` and `transformers` come preinstalled in Google Colab.


In [None]:
!pip install pytesseract
!apt-get install tesseract-ocr

# Import Libraries
Import the libraries needed for image processing, OCR, and visualization.


In [None]:
from PIL import Image, ImageDraw
import pytesseract
import cv2 as cv
import matplotlib.pyplot as plt
import numpy as np
import IPython.display as display
import os
import torch
from transformers import AutoModelForObjectDetection, TableTransformerForObjectDetection, TrOCRProcessor, VisionEncoderDecoderModel
from torchvision import transforms
from matplotlib.patches import Rectangle, Patch
import re
from difflib import SequenceMatcher
from tqdm.auto import tqdm
import glob
import csv

# Detecting and Correcting Image Rotation + Preprocessing for table detection
The below code implements the following:
1. **Rotation Correction**: Using pytesseract's `image_to_osd` function (Tesseract's orientation and script detection), we detect the orientation of the image and rotate it back if necessary.
2. **Noise Removal**: We convert the image to grayscale and apply a median blur to remove noise, preparing it for table detection.

*Note: Change the path of the image to your own.*

In [None]:
# Step 1: Open the original image with PIL (change `path` to your own image)
path = '/content/20240328_160040.jpg'
original_im = Image.open(path)

# Step 2: Make a high-contrast copy for Tesseract's orientation detection;
# the original image itself is left untouched.
im = original_im.convert('L')  # Convert to grayscale
im = im.point(lambda x: 0 if x < 128 else 255, '1')  # Binarize at a fixed midpoint


# Step 3: Detect the page orientation with Tesseract OSD
osd = pytesseract.image_to_osd(im, output_type='dict')
print(osd)

# Rotate the original image based on the detected angle. PIL's rotate() turns
# counter-clockwise for positive angles, so the angle is negated.
rotate = osd['rotate']
if rotate != 0:
    im_fixed = original_im.rotate(-rotate, expand=True)
else:
    im_fixed = original_im

# Normalise to 3-channel RGB first: PIL delivers channels in RGB (not BGR)
# order, and photos can also arrive as RGBA / L / P, which cv.cvtColor with a
# 3-channel conversion code would reject.
im_fixed_np = np.array(im_fixed.convert('RGB'))

# Step 4: Grayscale + 5x5 median blur to suppress noise before table detection
gray_img = cv.cvtColor(im_fixed_np, cv.COLOR_RGB2GRAY)
median = cv.medianBlur(gray_img, 5)

# Table Region and Structure Detection Using Pre-trained Models
The below code implements the following main steps:

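1. **Model Loading:** We load two pretrained Table Transformer models from Hugging Face: `microsoft/table-transformer-detection`, which finds table regions on the page, and `microsoft/table-structure-recognition-v1.1-all`, which finds rows and columns inside a cropped table.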
2. **Image Transformation:** The image is resized and normalized to match the input requirements of the model.
3. **Object Detection:** Each model outputs bounding boxes and labels for the objects it detects: table regions for the detection model, and rows/columns for the structure model.

In [None]:
# Custom resize step used by the detection / structure transform pipelines
class MaxResize:
    """Resize a PIL image so that its longer side equals ``max_size``.

    Both sides are scaled by the same factor, which keeps the aspect ratio.
    Each side is rounded to the nearest integer. Images smaller than
    ``max_size`` are scaled up.
    """

    def __init__(self, max_size=800):
        self.max_size = max_size

    def __call__(self, image):
        w, h = image.size
        factor = self.max_size / max(w, h)
        new_size = (int(round(factor * w)), int(round(factor * h)))
        return image.resize(new_size)

# Postprocessing functions
def box_cxcywh_to_xyxy(x):
    """Convert boxes from (center_x, center_y, width, height) to (x0, y0, x1, y1).

    Args:
        x: tensor whose last dimension has size 4 and holds cx, cy, w, h.
            Any number of leading dimensions is allowed, e.g. ``(4,)``,
            ``(N, 4)`` or ``(B, N, 4)``.

    Returns:
        A tensor of the same shape holding (x_min, y_min, x_max, y_max).
    """
    x_c, y_c, w, h = x.unbind(-1)
    b = [(x_c - 0.5 * w), (y_c - 0.5 * h), (x_c + 0.5 * w), (y_c + 0.5 * h)]
    # Stack back along the same (last) axis that was unbound; dim=1 only
    # matched 2-D input.
    return torch.stack(b, dim=-1)

def rescale_bboxes(out_bbox, size):
    """Turn relative cx/cy/w/h boxes into absolute x0/y0/x1/y1 pixel boxes.

    Args:
        out_bbox: tensor of boxes in cx, cy, w, h order (last dim 4). These are
            presumably normalized to [0, 1] by the detection head (TODO confirm).
        size: ``(width, height)`` of the image the boxes refer to.

    Returns:
        Tensor of corner-format boxes multiplied by (w, h, w, h).
    """
    img_w, img_h = size
    b = box_cxcywh_to_xyxy(out_bbox)
    # Build the scale on the boxes' own device so CUDA tensors work as well.
    scale = torch.tensor([img_w, img_h, img_w, img_h], dtype=torch.float32, device=b.device)
    return b * scale

def outputs_to_objects(outputs, img_size, id2label, score_threshold=0.5):
    """Convert raw model outputs into a list of labelled, thresholded detections.

    Args:
        outputs: model output exposing ``logits`` and ``pred_boxes``. Only the
            first element of the batch is used.
        img_size: ``(width, height)`` used to rescale boxes to pixels.
        id2label: mapping from class index to label name. Detections whose
            label is ``'no object'`` are dropped.
        score_threshold: minimum softmax score (over all classes) for a
            detection to be kept. Defaults to 0.5.

    Returns:
        List of dicts with keys ``label``, ``score`` (float) and ``bbox``
        (``[x0, y0, x1, y1]`` floats in pixels).
    """
    best = outputs.logits.softmax(-1).max(-1)
    pred_labels = best.indices.detach().cpu().numpy()[0]
    pred_scores = best.values.detach().cpu().numpy()[0]
    pred_bboxes = outputs['pred_boxes'].detach().cpu()[0]
    pred_bboxes = rescale_bboxes(pred_bboxes, img_size).tolist()

    objects = []
    for label, score, bbox in zip(pred_labels, pred_scores, pred_bboxes):
        class_label = id2label[int(label)]
        if class_label != 'no object' and score >= score_threshold:
            objects.append({'label': class_label, 'score': float(score), 'bbox': [float(v) for v in bbox]})
    return objects

# Visualization helper: Matplotlib figure -> PIL image
def fig2img(fig):
    """Render ``fig`` to an in-memory PNG and return it opened as a PIL image.

    The figure is saved with a tight bounding box at 150 dpi. The figure
    itself is not closed.
    """
    from io import BytesIO

    png_buffer = BytesIO()
    fig.savefig(png_buffer, format='png', bbox_inches='tight', dpi=150)
    png_buffer.seek(0)  # rewind so PIL reads from the start of the PNG
    return Image.open(png_buffer)


def visualize_detected_objects(img, objects, out_path=None):
    """Overlay detected table boxes on ``img`` and return the Matplotlib figure.

    Each detection is drawn as three stacked rectangles: a faint fill, an
    outline, and hatching. Boxes labelled 'table' are pink with '//////'
    hatching; all other labels are orange with 'xxxx' hatching. When
    ``out_path`` is given, the figure is also saved there at 150 dpi.
    """
    table_color = (1, 0, 0.45)
    other_color = (0.95, 0.6, 0.1)

    fig, ax = plt.subplots(figsize=(10, 10))
    ax.imshow(img, interpolation="lanczos")

    for obj in objects:
        bbox = obj['bbox']
        is_table = obj['label'] == 'table'
        color = table_color if is_table else other_color
        hatch = '//////' if is_table else 'xxxx'

        origin = (bbox[0], bbox[1])
        width = bbox[2] - bbox[0]
        height = bbox[3] - bbox[1]

        # Three layers per box: faint fill, outline, then hatching.
        layers = (
            {'linewidth': 2, 'edgecolor': 'none', 'facecolor': color, 'alpha': 0.1},
            {'linewidth': 2, 'edgecolor': color, 'facecolor': 'none', 'alpha': 0.3},
            {'linewidth': 0, 'edgecolor': color, 'facecolor': 'none', 'alpha': 0.2, 'hatch': hatch},
        )
        for layer in layers:
            ax.add_patch(Rectangle(origin, width, height, **layer))

    ax.set_xticks([])
    ax.set_yticks([])
    ax.axis('off')

    legend_elements = [
        Patch(facecolor=table_color, edgecolor=table_color, label='Table', hatch='//////', alpha=0.3),
        Patch(facecolor=other_color, edgecolor=other_color, label='Table (rotated)', hatch='xxxx', alpha=0.3),
    ]
    ax.legend(handles=legend_elements, bbox_to_anchor=(0.5, -0.02), loc='upper center',
              borderaxespad=0, fontsize=10, ncol=2)

    if out_path is not None:
        fig.savefig(out_path, bbox_inches='tight', dpi=150)

    return fig


# Function to get model outputs
def get_model_outputs(image, model, transform, device=None):
    """Run ``model`` on a single image without tracking gradients.

    Args:
        image: input accepted by ``transform`` (a PIL image in this notebook).
        model: torch module to run.
        transform: callable that turns ``image`` into a tensor. A batch
            dimension is added in front of its output.
        device: where to put the input tensor. By default this is the device
            of the model's first parameter (CPU for a parameter-less model),
            so the input always matches the model instead of relying on a
            notebook-level global.

    Returns:
        Whatever ``model(pixel_values)`` returns.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    pixel_values = transform(image).unsqueeze(0).to(device)  # add batch dimension

    with torch.no_grad():
        outputs = model(pixel_values)

    return outputs


# Helper function for crop table
def iob(bbox1, bbox2):
    """Return intersection-over-box: the fraction of ``bbox1``'s area inside ``bbox2``.

    Boxes are ``[x0, y0, x1, y1]`` in continuous pixel coordinates. This is
    the measure the token filter in ``objects_to_crops`` needs ("is this word
    mostly inside the table?"). Plain IoU of a small word box against a large
    table box is close to zero, so no token could ever reach the 0.5 cut-off.

    Returns 0.0 when ``bbox1`` has zero area.
    """
    bbox1_area = max(0, bbox1[2] - bbox1[0]) * max(0, bbox1[3] - bbox1[1])
    if bbox1_area == 0:
        return 0.0

    inter_w = min(bbox1[2], bbox2[2]) - max(bbox1[0], bbox2[0])
    inter_h = min(bbox1[3], bbox2[3]) - max(bbox1[1], bbox2[1])
    intersection_area = max(0, inter_w) * max(0, inter_h)

    return intersection_area / bbox1_area

# Crop table function
def objects_to_crops(img, tokens, objects, class_thresholds, padding):
    """Crop every sufficiently confident table detection out of ``img``.

    Args:
        img: PIL image that the detection boxes refer to.
        tokens: list of dicts with a ``bbox`` key (``[x0, y0, x1, y1]`` in
            ``img`` coordinates); may be empty. The dicts are copied, never
            modified, so the same list can be passed to several calls.
        objects: detections (dicts with ``label``, ``score``, ``bbox``).
        class_thresholds: minimum score per label. A label missing here
            raises KeyError.
        padding: pixels added on the left and bottom. The right side gets
            ``1.2 * padding`` and the top a fixed 10 px, clamped at 0.

    Returns:
        List of ``{'image': cropped image, 'tokens': tokens in crop coordinates}``.
    """
    table_crops = []
    for obj in objects:
        if obj['score'] < class_thresholds[obj['label']]:
            continue

        bbox = obj['bbox']
        # Asymmetric padding (presumably tuned for these answer sheets): only
        # 10 px above the table and extra room on the right. The top edge is
        # clamped so it stays inside the image.
        crop_box = [bbox[0] - padding, max(0, bbox[1] - 10), bbox[2] + (1.2 * padding), bbox[3] + padding]
        cropped_img = img.crop(crop_box)

        # Shift tokens that lie mostly inside the crop into crop coordinates,
        # working on copies so the caller's token dicts stay untouched.
        table_tokens = [
            {**token, 'bbox': [token['bbox'][0] - crop_box[0],
                               token['bbox'][1] - crop_box[1],
                               token['bbox'][2] - crop_box[0],
                               token['bbox'][3] - crop_box[1]]}
            for token in tokens
            if iob(token['bbox'], crop_box) >= 0.5
        ]

        # If the table is predicted to be rotated, rotate the crop and its tokens.
        if obj['label'] == 'table rotated':
            cropped_img = cropped_img.rotate(270, expand=True)
            for token in table_tokens:
                tx0, ty0, tx1, ty1 = token['bbox']
                token['bbox'] = [cropped_img.size[0] - ty1 - 1,
                                 tx0,
                                 cropped_img.size[0] - ty0 - 1,
                                 tx1]

        table_crops.append({'image': cropped_img, 'tokens': table_tokens})

    return table_crops

# Visualization of the table structure being detected
def plot_results(cells, class_to_visualize, image=None):
    """Draw every detected cell of one structure class on a table image.

    Args:
        cells: detections (dicts with ``label``, ``score`` and ``bbox``).
        class_to_visualize: label to draw, e.g. ``"table row"``. It must be
            one of the values in ``structure_model.config.id2label``.
        image: the image that the ``cells`` coordinates refer to. Defaults to
            the notebook-level ``cropped_table`` for backward compatibility.
            Pass the crop the cells were actually computed on (for example
            ``final_cropped_table``) when it differs, otherwise the boxes are
            drawn on the wrong image.

    Raises:
        ValueError: if ``class_to_visualize`` is not a known structure class.
    """
    if class_to_visualize not in structure_model.config.id2label.values():
        raise ValueError("Class should be one of the available classes")

    if image is None:
        image = cropped_table

    plt.figure(figsize=(16, 10))
    plt.imshow(image)
    ax = plt.gca()

    for cell in cells:
        if cell["label"] != class_to_visualize:
            continue
        xmin, ymin, xmax, ymax = tuple(cell["bbox"])
        ax.add_patch(plt.Rectangle((xmin, ymin), xmax - xmin, ymax - ymin, fill=False, color="red", linewidth=3))
        text = f'{cell["label"]}: {cell["score"]:0.2f}'
        ax.text(xmin, ymin, text, fontsize=15,
                bbox=dict(facecolor='yellow', alpha=0.5))

    # Hide the axes once, even when no cell of this class was found.
    plt.axis('off')


# Function to get the last column bbox
def get_last_column_bbox(cells):
    """Return the bbox of the right-most 'table column' detection.

    Columns are ordered by their left edge (``bbox[0]``). On ties, the one
    that appears last in ``cells`` wins (stable sort). ``cells`` is not
    modified.

    Raises:
        ValueError: if ``cells`` contains no 'table column' entries.
    """
    columns = sorted((c for c in cells if c['label'] == 'table column'),
                     key=lambda c: c['bbox'][0])
    if not columns:
        raise ValueError("No 'table column' detected; cannot locate the last column.")
    return columns[-1]['bbox']

# Function to get cell coordinates by row within the last column
def get_row_coordinates_within_column(column_image, original_bbox, rows):
    """Intersect each detected row with a column, ordered top to bottom.

    Args:
        column_image: unused; kept so existing call sites keep working.
        original_bbox: column bbox ``[x0, y0, x1, y1]``; only its x-extent is used.
        rows: row detections (dicts with a ``bbox`` key). Neither the list
            nor its dicts are modified.

    Returns:
        New row dicts sorted by their top edge. Each has ``bbox`` set to
        ``[column x0, row y0, column x1, row y1]``; other keys are copied.
    """
    ordered = sorted(rows, key=lambda r: r['bbox'][1])
    return [
        {**row, 'bbox': [original_bbox[0], row['bbox'][1], original_bbox[2], row['bbox'][3]]}
        for row in ordered
    ]


def _build_table_transform(max_size):
    """Resize (longer side = ``max_size``), convert to a tensor, then normalize per channel."""
    return transforms.Compose([
        MaxResize(max_size),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])


# Preprocessing for the table-detection model (longer side 800 px)
detection_transform = _build_table_transform(800)
# Preprocessing for the table-structure model (longer side 1000 px)
structure_transform = _build_table_transform(1000)

In [None]:
# Use the GPU when available; `device` is a notebook-level global.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Table *detection* model: locates table regions on the page.
model = AutoModelForObjectDetection.from_pretrained(
    "microsoft/table-transformer-detection", revision="no_timm"
)
model.to(device)

# Table *structure* model: locates rows/columns inside a cropped table.
structure_model = TableTransformerForObjectDetection.from_pretrained(
    "microsoft/table-structure-recognition-v1.1-all"
)
structure_model.to(device)

In [None]:
# Convert the blurred grayscale array back to a 3-channel PIL image
binary_pil = Image.fromarray(median).convert("RGB")

# Run the table-detection model
outputs = get_model_outputs(binary_pil, model, detection_transform)

# outputs_to_objects expects the index just past the real classes to map to
# "no object". Build that map on a copy so model.config is not mutated (and
# does not gain an extra entry every time this cell is re-run).
id2label = dict(model.config.id2label)
id2label[len(id2label)] = "no object"
objects = outputs_to_objects(outputs, binary_pil.size, id2label)

Visualize the detected table region

In [None]:
# Visualization of the detected table region(s).
# The code below is wrapped in a triple-quoted string, so this cell does
# nothing; remove the quotes to draw `objects` over `binary_pil` and show it.
'''
fig = visualize_detected_objects(binary_pil, objects)
visualized_image = fig2img(fig)
visualized_image.show()
'''

In [None]:
# Crop the detected table region
tokens = []  # no word tokens are used here; passed because objects_to_crops requires them
detection_class_thresholds = {
    "table": 0.5,
    "table rotated": 0.5,
    "no object": 10  # above any softmax score, so "no object" never passes
}
crop_padding = 200

tables_crops = objects_to_crops(binary_pil, tokens, objects, detection_class_thresholds, padding=crop_padding)
if not tables_crops:
    # Fail with a clear message instead of an IndexError on tables_crops[0].
    raise RuntimeError("No table was detected in the image; check the input photo.")
cropped_table = tables_crops[0]['image'].convert("RGB")

In [None]:
# To show the cropped table, remove the surrounding triple quotes below;
# as written, the string is a no-op.
'''
cropped_table
'''

In [None]:
# Run the structure-recognition model on the cropped table
new_outputs = get_model_outputs(cropped_table, structure_model, structure_transform)

# Local label map with the extra "no object" index appended. Copying avoids
# mutating structure_model.config (which would otherwise grow on every re-run).
structure_id2label = dict(structure_model.config.id2label)
structure_id2label[len(structure_id2label)] = "no object"
cells = outputs_to_objects(new_outputs, cropped_table.size, structure_id2label)
rows = [entry for entry in cells if entry['label'] == 'table row']

In [None]:
# If only a few rows were found, the detected table box may be cut short:
# retry once with the box extended downwards, and keep the extended crop only
# if it yields more than MIN_EXPECTED_ROWS rows.
MIN_EXPECTED_ROWS = 10   # at or below this many rows, try the extended crop
EXTRA_BOTTOM_PX = 200    # how far to extend the table box downwards

final_cropped_table = cropped_table
final_outputs = new_outputs
final_cells = cells

if len(rows) <= MIN_EXPECTED_ROWS:
    table_bbox = objects[0]['bbox']
    extended_bbox = [table_bbox[0], table_bbox[1], table_bbox[2], table_bbox[3] + EXTRA_BOTTOM_PX]

    # Use the same cropping function (same padding/thresholds) for the extended box
    extended_objects = [{'label': 'table', 'score': 1.0, 'bbox': extended_bbox}]
    extended_tables_crops = objects_to_crops(binary_pil, tokens, extended_objects, detection_class_thresholds, padding=crop_padding)
    extended_cropped_table = extended_tables_crops[0]['image'].convert("RGB")

    # Reapply the structure model on the extended crop
    extended_outputs = get_model_outputs(extended_cropped_table, structure_model, structure_transform)
    extended_cells = outputs_to_objects(extended_outputs, extended_cropped_table.size, structure_id2label)
    extended_rows = [entry for entry in extended_cells if entry['label'] == 'table row']
    print(f"Number of detected rows: {len(extended_rows)}")

    if len(extended_rows) > MIN_EXPECTED_ROWS:
        final_cropped_table = extended_cropped_table
        final_outputs = extended_outputs
        final_cells = extended_cells

# `rows` must describe the crop that is actually kept, because later cells
# combine it with `final_cells` / `final_cropped_table`. The chosen outputs
# are reused rather than running the model again on the same image.
rows = [entry for entry in final_cells if entry['label'] == 'table row']

In [None]:
# To visualize the detected rows, remove the surrounding triple quotes below;
# as written, the string is a no-op.
'''
plot_results(final_cells, class_to_visualize="table row")
'''

In [None]:
# Find the right-most column and intersect each detected row with it
last_column_bbox = get_last_column_bbox(final_cells)
last_column_image = final_cropped_table.crop(last_column_bbox)
rows_within_last_column = get_row_coordinates_within_column(
    column_image=last_column_image,
    original_bbox=last_column_bbox,
    rows=rows,
)

In [None]:
# Output folder for the per-cell binarized crops; created if it is missing.
binarized_dir = '/content/binarized_cells'
os.makedirs(name=binarized_dir, exist_ok=True)

In [None]:
# Crop, binarize, and save each cell in the last column.
# Cells are cropped from `final_cropped_table` because the row and column
# boxes were detected on that image, which may be the extended crop rather
# than `cropped_table`.

# Remove cell images from any previous run, so the OCR step cannot silently
# read a stale file when this run yields fewer rows.
for stale_path in glob.glob(os.path.join(binarized_dir, 'cell_*.png')):
    os.remove(stale_path)

for idx, row in enumerate(rows_within_last_column):
    cell_bbox = [last_column_bbox[0], row['bbox'][1], last_column_bbox[2], row['bbox'][3]]
    cell_image = final_cropped_table.crop(cell_bbox)

    # Grayscale, then Otsu's method picks the black/white threshold per cell
    grayscale_image = cell_image.convert('L')
    image_array = np.array(grayscale_image)
    _, binarized_image = cv.threshold(image_array, 0, 255, cv.THRESH_BINARY + cv.THRESH_OTSU)

    # Files are numbered from 1 in the (top-to-bottom) order of the rows
    binarized_image_path = os.path.join(binarized_dir, f'cell_{idx + 1}.png')
    Image.fromarray(binarized_image).save(binarized_image_path)

print("All cells have been binarized and saved.")

# OCR on the cells containing handwritten answers

1. **Text Recognition:** Using `TrOCRProcessor` and `VisionEncoderDecoderModel` (the `microsoft/trocr-large-handwritten` checkpoint), we recognize the handwritten text in each cell cropped from the last column of the detected table.
2. **Blank-Image Check:** We compute the fraction of pure-white pixels in each cell. A cell is treated as blank when it is almost entirely white (more than 98.5%) or when more than 20% of it is non-white. *(This check was added after observing that TrOCR returns garbage text for blank images.)*

In [None]:
# Load the TrOCR processor and model (same handwritten-text checkpoint for both)
trocr_checkpoint = 'microsoft/trocr-large-handwritten'
processor_trOCR = TrOCRProcessor.from_pretrained(trocr_checkpoint)
model_trOCr = VisionEncoderDecoderModel.from_pretrained(trocr_checkpoint)

preprocessor_config.json:   0%|          | 0.00/224 [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/1.12k [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/899k [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/772 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/4.13k [00:00<?, ?B/s]

pytorch_model.bin:   0%|          | 0.00/2.23G [00:00<?, ?B/s]

Some weights of VisionEncoderDecoderModel were not initialized from the model checkpoint at microsoft/trocr-large-handwritten and are newly initialized: ['encoder.pooler.dense.bias', 'encoder.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


generation_config.json:   0%|          | 0.00/190 [00:00<?, ?B/s]

In [None]:
def is_blank_image(image, threshold=0.985, black_threshold=0.2, verbose=True):
    """Decide whether a binarized cell image should be treated as blank.

    The image counts as blank when either
    * more than ``threshold`` of its values are pure white (255), or
    * more than ``black_threshold`` of its values are not pure white. This
      rejects crops dominated by dark areas (NOTE(review): presumably
      borders or shadows rather than handwriting; confirm).

    Ratios are taken over all array elements, i.e. over every channel of an
    RGB image. An image with no pixels is reported as blank.

    Args:
        image: PIL image or array-like.
        threshold: white ratio above which the cell is blank.
        black_threshold: non-white ratio above which the cell is blank.
        verbose: print the computed ratios.

    Returns:
        True if the cell should be treated as blank.
    """
    image_array = np.array(image)
    total_pixels = image_array.size
    if total_pixels == 0:
        return True  # nothing to read; also avoids dividing by zero

    white_pixel_ratio = np.count_nonzero(image_array == 255) / total_pixels
    black_pixel_ratio = 1 - white_pixel_ratio
    if verbose:
        print(f"White pixel ratio: {white_pixel_ratio:.4f}, Black pixel ratio: {black_pixel_ratio:.4f}")
    return bool(black_pixel_ratio > black_threshold or white_pixel_ratio > threshold)

def ocr(image, processor, model):
    """Transcribe ``image`` with a TrOCR processor/model pair.

    Returns the first decoded sequence, with special tokens stripped.
    """
    encoded = processor(images=image, return_tensors="pt")
    token_ids = model.generate(encoded.pixel_values)
    decoded = processor.batch_decode(token_ids, skip_special_tokens=True)
    return decoded[0]

def ocr_image(image_path):
    """OCR one saved cell image, skipping cells that look blank.

    Returns the literal string "Blank Image" when ``is_blank_image`` flags the
    cell. Otherwise returns the TrOCR transcription produced with the
    notebook-level ``processor_trOCR`` / ``model_trOCr``.
    """
    rgb_image = Image.open(image_path).convert("RGB")
    if not is_blank_image(rgb_image):
        return ocr(rgb_image, processor_trOCR, model_trOCr)
    return "Blank Image"

In [None]:
# Container for the per-cell OCR strings (filled by the next cell)
ocr_results = list()

In [None]:
# OCR the answer cells in order. cell_1 is skipped (assumed to be the header
# row; TODO confirm), and the next NUM_QUESTIONS cells are read.
FIRST_ANSWER_CELL = 2
NUM_QUESTIONS = 10
image_paths = [
    os.path.join(binarized_dir, f'cell_{i}.png')
    for i in range(FIRST_ANSWER_CELL, FIRST_ANSWER_CELL + NUM_QUESTIONS)
]

# Start from an empty list so re-running this cell doesn't append duplicates.
ocr_results = []
for cell_path in image_paths:  # a distinct name, so the input-photo `path` is kept
    result = ocr_image(cell_path)
    ocr_results.append(result)
    print(f"OCR results for {cell_path}: {result}")

# ocr_results now contains the OCR results in proper order
print(ocr_results)

## Post-processing of ocr_results

Here we clean up *ocr_results* with a series of heuristic rules that map each recognized string to 'true' or 'false', or to 'Blank Image' / 'Uncertain' when no decision can be made.






In [None]:
def is_similar(a, b, threshold=0.4):
    """Return True when difflib's similarity ratio of ``a`` and ``b`` exceeds ``threshold``."""
    similarity = SequenceMatcher(None, a, b).ratio()
    return similarity > threshold

def normalize_text(text):
    """Lower-case ``text`` and keep only the ASCII letters a-z.

    Digits, whitespace, punctuation and non-ASCII letters are all removed,
    e.g. ``"Tr-ue 1!"`` -> ``"true"``.
    """
    return re.sub(r'[^a-z]', '', text.lower())

def post_process_ocr_results(ocr_results):
    """Map raw OCR strings to "true", "false", "Blank Image" or "Uncertain".

    Each string is classified on its own; the returned list has the same
    length and order as ``ocr_results``.
    """

    def _classify(raw):
        # Nothing but digits once punctuation/spaces are stripped -> empty cell.
        if re.sub(r'[^a-zA-Z0-9]', '', raw).isdigit():
            return "Blank Image"

        norm = normalize_text(raw)
        if not norm:
            return "Blank Image"

        # First-letter rules, with a second-letter override for what are
        # presumably common misreads: "ta..." -> false, "fr..." -> true.
        first, second = norm[0], norm[1:2]
        if first == 't':
            return "false" if second == 'a' else "true"
        if first == 'f':
            return "true" if second == 'r' else "false"

        # Sentinel produced by ocr_image for blank cells ("Blank Image").
        if norm == 'blankimage':
            return "Blank Image"

        # Distinguishing letters: t/r appear only in TRUE, f/l only in FALSE.
        has_t, has_f = 't' in norm, 'f' in norm
        has_r, has_l = 'r' in norm, 'l' in norm
        if has_t and not has_f:
            return "true"
        if has_f and not has_t:
            return "false"
        if has_r and not has_l:
            return "true"
        if has_l and not has_r:
            return "false"

        # Last resort: fuzzy match against the two target words.
        if is_similar(norm, "true"):
            return "true"
        if is_similar(norm, "false"):
            return "false"
        return "Uncertain"

    return [_classify(result) for result in ocr_results]

In [None]:
# Map every raw OCR string to true / false / Blank Image / Uncertain and show them
cleaned_results = post_process_ocr_results(ocr_results=ocr_results)
print(cleaned_results)

['true', 'true', 'false', 'false', 'true', 'false', 'true', 'Blank Image', 'Blank Image', 'true']


# Score

Finally, we compare *cleaned_results* with the model answers (loaded from a CSV file) and compute the score for the image.

In [None]:
# Load the answer key: a header row, then one row per question with the
# answer in the second column (first column presumably the question id).
ANSWER_KEY_PATH = 'ModelAnswer (1).csv'
correct_answers = []
# newline='' is what the csv module requires for correct parsing.
with open(ANSWER_KEY_PATH, 'r', newline='', encoding='utf-8') as file:
    reader = csv.reader(file)
    next(reader, None)  # Skip the header (tolerates an empty file)
    for row in reader:
        if not row:
            continue  # ignore blank lines, e.g. a trailing newline
        # Extra trailing columns are ignored; whitespace around the answer is stripped.
        correct_answers.append(row[1].strip())

In [None]:
# Compare cleaned results with correct answers
if len(correct_answers) < len(cleaned_results):
    raise ValueError(
        f"Answer key has {len(correct_answers)} entries but "
        f"{len(cleaned_results)} results were produced."
    )

score = 0
total = len(cleaned_results)
for token, expected in zip(cleaned_results, correct_answers):
    predicted = token.strip().lower()
    expected = expected.strip().lower()
    if predicted == expected:
        score += 1
    # NOTE(review): an "Uncertain" prediction earns credit whenever the key is
    # true/false, i.e. uncertain OCR gets the benefit of the doubt. Confirm
    # this grading policy is intended.
    elif predicted == "uncertain" and expected in ("true", "false"):
        score += 1

print(f"Score: {score}/{total}")

Score: 5/10


**END OF FILE**