In [1]:
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [2]:

#!pip install -q transformers
!pip install -q easyocr

from transformers import AutoModelForObjectDetection
from transformers import TableTransformerForObjectDetection
import torch
import os
from torchvision import transforms
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from matplotlib.patches import Patch
import numpy as np
import csv
import easyocr
from tqdm.auto import tqdm
import csv
from PIL import Image

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.9/2.9 MB[0m [31m12.3 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m908.3/908.3 kB[0m [31m21.9 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m307.2/307.2 kB[0m [31m21.7 MB/s[0m eta [36m0:00:00[0m
[?25h

In [3]:

# class to prepare image
class MaxResize(object):
    def __init__(self, max_size=800):
        self.max_size = max_size

    def __call__(self, image):
        width, height = image.size
        current_max_size = max(width, height)
        scale = self.max_size / current_max_size
        resized_image = image.resize((int(round(scale*width)), int(round(scale*height))))

        return resized_image

In [4]:
def get_image_urls(directory_path):
    image_urls = []

    # Iterate through all files in the directory
    for filename in os.listdir(directory_path):
        # Check if the file is an image (you can customize the list of valid extensions)
        valid_extensions = ['.jpg', '.jpeg', '.png']
        for i in valid_extensions:
          if filename.lower().endswith(i):
            # Create the URL by joining the directory path and filename
            image_url = os.path.join(directory_path, filename)
            # Append the URL to the list
            image_urls.append(image_url)

    return image_urls

In [5]:
# for output bounding box post-processing
def box_cxcywh_to_xyxy(x):
    x_c, y_c, w, h = x.unbind(-1)
    b = [(x_c - 0.5 * w), (y_c - 0.5 * h), (x_c + 0.5 * w), (y_c + 0.5 * h)]
    return torch.stack(b, dim=1)


def rescale_bboxes(out_bbox, size):
    img_w, img_h = size
    b = box_cxcywh_to_xyxy(out_bbox)
    b = b * torch.tensor([img_w, img_h, img_w, img_h], dtype=torch.float32)
    return b


In [6]:
def outputs_to_objects(outputs, img_size, id2label):
    m = outputs.logits.softmax(-1).max(-1)
    pred_labels = list(m.indices.detach().cpu().numpy())[0]
    pred_scores = list(m.values.detach().cpu().numpy())[0]
    pred_bboxes = outputs['pred_boxes'].detach().cpu()[0]
    pred_bboxes = [elem.tolist() for elem in rescale_bboxes(pred_bboxes, img_size)]

    objects = []
    for label, score, bbox in zip(pred_labels, pred_scores, pred_bboxes):
        class_label = id2label[int(label)]
        if not class_label == 'no object':
            objects.append({'label': class_label, 'score': float(score),
                            'bbox': [float(elem) for elem in bbox]})

    return objects

In [7]:
def fig2img(fig):
    """Convert a Matplotlib figure to a PIL Image and return it"""
    import io
    buf = io.BytesIO()
    fig.savefig(buf)
    buf.seek(0)
    img = Image.open(buf)
    return img

In [8]:
def visualize_detected_tables(img, det_tables, out_path=None):
    plt.imshow(img, interpolation="lanczos")
    fig = plt.gcf()
    fig.set_size_inches(20, 20)
    ax = plt.gca()

    for det_table in det_tables:
        bbox = det_table['bbox']

        if det_table['label'] == 'table':
            facecolor = (1, 0, 0.45)
            edgecolor = (1, 0, 0.45)
            alpha = 0.3
            linewidth = 2
            hatch='//////'
        elif det_table['label'] == 'table rotated':
            facecolor = (0.95, 0.6, 0.1)
            edgecolor = (0.95, 0.6, 0.1)
            alpha = 0.3
            linewidth = 2
            hatch='//////'
        else:
            continue

        rect = patches.Rectangle(bbox[:2], bbox[2]-bbox[0], bbox[3]-bbox[1], linewidth=linewidth,
                                    edgecolor='none',facecolor=facecolor, alpha=0.1)
        ax.add_patch(rect)
        rect = patches.Rectangle(bbox[:2], bbox[2]-bbox[0], bbox[3]-bbox[1], linewidth=linewidth,
                                    edgecolor=edgecolor,facecolor='none',linestyle='-', alpha=alpha)
        ax.add_patch(rect)
        rect = patches.Rectangle(bbox[:2], bbox[2]-bbox[0], bbox[3]-bbox[1], linewidth=0,
                                    edgecolor=edgecolor,facecolor='none',linestyle='-', hatch=hatch, alpha=0.2)
        ax.add_patch(rect)

    plt.xticks([], [])
    plt.yticks([], [])

    legend_elements = [Patch(facecolor=(1, 0, 0.45), edgecolor=(1, 0, 0.45),
                                label='Table', hatch='//////', alpha=0.3),
                        Patch(facecolor=(0.95, 0.6, 0.1), edgecolor=(0.95, 0.6, 0.1),
                                label='Table (rotated)', hatch='//////', alpha=0.3)]
    plt.legend(handles=legend_elements, bbox_to_anchor=(0.5, -0.02), loc='upper center', borderaxespad=0,
                    fontsize=10, ncol=2)
    plt.gcf().set_size_inches(10, 10)
    plt.axis('off')

    if out_path is not None:
      plt.savefig(out_path, bbox_inches='tight', dpi=150)

    return fig

In [9]:
def objects_to_crops(img, tokens, objects, class_thresholds, padding=10):
    """
    Process the bounding boxes produced by the table detection model into
    cropped table images and cropped tokens.
    """

    table_crops = []
    for obj in objects:
        if obj['score'] < class_thresholds[obj['label']]:
            continue

        cropped_table = {}

        bbox = obj['bbox']
        bbox = [bbox[0]-padding, bbox[1]-padding, bbox[2]+padding, bbox[3]+padding]

        cropped_img = img.crop(bbox)

        table_tokens = [token for token in tokens if iob(token['bbox'], bbox) >= 0.5]
        for token in table_tokens:
            token['bbox'] = [token['bbox'][0]-bbox[0],
                             token['bbox'][1]-bbox[1],
                             token['bbox'][2]-bbox[0],
                             token['bbox'][3]-bbox[1]]

        # If table is predicted to be rotated, rotate cropped image and tokens/words:
        if obj['label'] == 'table rotated':
            cropped_img = cropped_img.rotate(270, expand=True)
            for token in table_tokens:
                bbox = token['bbox']
                bbox = [cropped_img.size[0]-bbox[3]-1,
                        bbox[0],
                        cropped_img.size[0]-bbox[1]-1,
                        bbox[2]]
                token['bbox'] = bbox

        cropped_table['image'] = cropped_img
        cropped_table['tokens'] = table_tokens

        table_crops.append(cropped_table)

    return table_crops

In [16]:
# Function to find cell coordinates
def find_cell_coordinates(row, column):
        cell_bbox = [column['bbox'][0], row['bbox'][1], column['bbox'][2], row['bbox'][3]]
        return cell_bbox

In [19]:
def get_cell_coordinates_by_row(table_data):
    # Extract rows and columns
    rows = [entry for entry in table_data if entry['label'] == 'table row']
    columns = [entry for entry in table_data if entry['label'] == 'table column']
    print("=================================>>>>> rows ", rows, columns)
    # Sort rows and columns by their Y and X coordinates, respectively
    rows.sort(key=lambda x: x['bbox'][1])
    columns.sort(key=lambda x: x['bbox'][0])

    # Generate cell coordinates and count cells in each row
    cell_coordinates = []

    for row in rows:
        row_cells = []
        for column in columns:
            cell_bbox = find_cell_coordinates(row, column)
            row_cells.append({'column': column['bbox'], 'cell': cell_bbox})

        # Sort cells in the row by X coordinate
        row_cells.sort(key=lambda x: x['column'][0])

        # Append row information to cell_coordinates
        cell_coordinates.append({'row': row['bbox'], 'cells': row_cells, 'cell_count': len(row_cells)})

    # Sort rows from top to bottom
    cell_coordinates.sort(key=lambda x: x['row'][1])
    print("==================================>>>>>>>>>>>>>>>>>>>", cell_coordinates)
    return cell_coordinates

In [11]:
def apply_ocr(cell_coordinates):
    # let's OCR row by row
    data = dict()
    max_num_columns = 0
    for idx, row in enumerate(tqdm(cell_coordinates)):
      row_text = []
      for cell in row["cells"]:
        # crop cell out of image
        cell_image = np.array(cropped_table.crop(cell["cell"]))
        # apply OCR
        result = reader.readtext(np.array(cell_image))
        if len(result) > 0:
          # print([x[1] for x in list(result)])
          text = " ".join([x[1] for x in result])
          row_text.append(text)

      if len(row_text) > max_num_columns:
          max_num_columns = len(row_text)

      data[idx] = row_text

    print("Max number of columns:", max_num_columns)

    # pad rows which don't have max_num_columns elements
    # to make sure all rows have the same number of columns
    for row, row_data in data.copy().items():
        if len(row_data) != max_num_columns:
          row_data = row_data + ["" for _ in range(max_num_columns - len(row_data))]
        data[row] = row_data

    return data

In [12]:
def save_as_csv(file_name,data):
    with open(file_name,'w') as result_file:
        wr = csv.writer(result_file, dialect='excel')
        for row, row_text in data.items():
            wr.writerow(row_text)

In [26]:
if __name__=="__main__":
    # load model
    model = AutoModelForObjectDetection.from_pretrained("microsoft/table-transformer-detection", revision="no_timm")
    print(model.config.id2label)

    # set to gpu
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model.to(device)
    print("")

    # update id2label to include "no object"
    id2label = model.config.id2label
    id2label[len(model.config.id2label)] = "no object"

    # extract all images from directory
    dir_path=r"/content/drive/MyDrive/PAGES WITH TABLES (4)/PAGES WITH TABLES"
    all_images_in_dir=get_image_urls(dir_path)


    # do detection on each image
    for file_path in all_images_in_dir:

        image = Image.open(file_path).convert("RGB")
        print(image)

        # prepare image
        detection_transform=transforms.Compose([
            MaxResize(800),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ])
        pixel_values = detection_transform(image).unsqueeze(0)
        pixel_values = pixel_values.to(device)
        print(f"Pixel value for image : {os.path.basename(file_path)} is ==> {pixel_values.shape}")


        # froward pass
        with torch.no_grad():
            outputs = model(pixel_values)

        # postProcessing : Next, we take the prediction that has an actual class (i.e. not "no object").
        objects = outputs_to_objects(outputs, image.size, id2label)
        print(f"Objects for image : {os.path.basename(file_path)} is ==> {objects}")

        # to visualize image with detection run this in collab for each images
        #fig = visualize_detected_tables(image, objects)
        #visualized_image = fig2img(fig)

        # crop table



        """Process the bounding boxes produced by the table detection model into
        cropped table images and cropped tokens."""

        tokens = []
        detection_class_thresholds = {
                "table": 0.5,
                    "table rotated": 0.5,
                        "no object": 10
        }
        crop_padding = 10

        tables_crops = objects_to_crops(image, tokens, objects, detection_class_thresholds, padding=0)
        if tables_crops:
          cropped_table = tables_crops[0]['image'].convert("RGB")
        else:
          print("no columns detected")
        #you can visualize the crop table ==> cropped_table
        # save cropped table ==> cropped_table.save(f"{image}_table.jpg")
        #cropped_table.save(f"{image}_table.jpg")



        # load structure recognition model
        #new v1.1 checkpoints require no timm anymore
        structure_model = TableTransformerForObjectDetection.from_pretrained("microsoft/table-structure-recognition-v1.1-all")
        structure_model.to(device)



        structure_transform=transforms.Compose([
            MaxResize(1000),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ])
        pixel_values = structure_transform(cropped_table).unsqueeze(0)
        pixel_values = pixel_values.to(device)
        print(f"structure_transform Pixel value for image : {image} is ==> {pixel_values.shape}")

        # forward pass
        with torch.no_grad():
            outputs = structure_model(pixel_values)

        # update id2label to include "no object"
        structure_id2label = structure_model.config.id2label
        structure_id2label[len(structure_id2label)] = "no object"
        cells = outputs_to_objects(outputs, cropped_table.size, structure_id2label)
        # print(cells)


        # SKIIPING FINALIZED TABLE VISUALIZATION PART.


        # apply ocr row by row
        cell_coordinates = get_cell_coordinates_by_row(cells)
        print("======================>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> ",len(cell_coordinates[0]["cells"]))
        # you can see row wise data of coordinates
        for row in cell_coordinates:
            print(row["cells"])

        reader = easyocr.Reader(['en'])
        data = apply_ocr(cell_coordinates)
        #for row, row_data in data.items():
           # print(row_data)


        save_as_csv(f"{os.path.basename(file_path)}_data.csv",data)


{0: 'table', 1: 'table rotated'}

<PIL.Image.Image image mode=RGB size=1275x1650 at 0x7C8C8866C8E0>
Pixel value for image : Pages with tables (1)_page-0003.jpg is ==> torch.Size([1, 3, 800, 618])
Objects for image : Pages with tables (1)_page-0003.jpg is ==> [{'label': 'table', 'score': 0.9999123811721802, 'bbox': [92.92005920410156, 190.97616577148438, 1195.224853515625, 1497.059326171875]}]
structure_transform Pixel value for image : <PIL.Image.Image image mode=RGB size=1275x1650 at 0x7C8C8866C8E0> is ==> torch.Size([1, 3, 1000, 844])
[{'column': [0.6660800576210022, -0.05546361207962036, 143.2706298828125, 1299.9361572265625], 'cell': [0.6660800576210022, -0.1298116147518158, 143.2706298828125, 21.068477630615234]}, {'column': [145.40216064453125, 0.5712168216705322, 364.3544921875, 1300.24169921875], 'cell': [145.40216064453125, -0.1298116147518158, 364.3544921875, 21.068477630615234]}, {'column': [368.1749267578125, 0.48029541969299316, 709.7206420898438, 1299.770751953125], 'cell

  0%|          | 0/10 [00:00<?, ?it/s]

Max number of columns: 5
<PIL.Image.Image image mode=RGB size=1275x1650 at 0x7C8C885C5AB0>
Pixel value for image : PAGE 5.jpg is ==> torch.Size([1, 3, 800, 618])
Objects for image : PAGE 5.jpg is ==> [{'label': 'table', 'score': 0.9968369007110596, 'bbox': [122.39686584472656, 153.7803955078125, 720.2481689453125, 429.9388122558594]}, {'label': 'table', 'score': 0.9756752252578735, 'bbox': [127.56649017333984, 468.0892028808594, 1099.3411865234375, 1506.0799560546875]}]
structure_transform Pixel value for image : <PIL.Image.Image image mode=RGB size=1275x1650 at 0x7C8C885C5AB0> is ==> torch.Size([1, 3, 462, 1000])
[{'column': [6.194310188293457, 20.092769622802734, 131.72174072265625, 272.7504577636719], 'cell': [6.194310188293457, 20.171566009521484, 131.72174072265625, 38.60599136352539]}, {'column': [125.48196411132812, 20.479522705078125, 364.268798828125, 273.01422119140625], 'cell': [125.48196411132812, 20.171566009521484, 364.268798828125, 38.60599136352539]}, {'column': [224.83

  0%|          | 0/12 [00:00<?, ?it/s]

Max number of columns: 4
<PIL.Image.Image image mode=RGB size=1275x1650 at 0x7C8C88762A10>
Pixel value for image : Pages with tables (1)_page-0004.jpg is ==> torch.Size([1, 3, 800, 618])
Objects for image : Pages with tables (1)_page-0004.jpg is ==> [{'label': 'table', 'score': 0.9217093586921692, 'bbox': [120.583984375, 347.51068115234375, 1131.6683349609375, 1413.169189453125]}]
structure_transform Pixel value for image : <PIL.Image.Image image mode=RGB size=1275x1650 at 0x7C8C88762A10> is ==> torch.Size([1, 3, 1000, 949])
[{'column': [1.144584059715271, 2.235030174255371, 89.37885284423828, 1055.2603759765625], 'cell': [1.144584059715271, 0.9796060919761658, 89.37885284423828, 17.8302001953125]}, {'column': [92.5902328491211, 2.207226514816284, 308.4300231933594, 1055.3536376953125], 'cell': [92.5902328491211, 0.9796060919761658, 308.4300231933594, 17.8302001953125]}, {'column': [349.0235290527344, 1.6596570014953613, 514.3118896484375, 1055.87353515625], 'cell': [349.0235290527344,

  0%|          | 0/48 [00:00<?, ?it/s]

Max number of columns: 7
<PIL.Image.Image image mode=RGB size=1275x1650 at 0x7C8C8843D660>
Pixel value for image : Pages with tables (1)_page-0001.jpg is ==> torch.Size([1, 3, 800, 618])
Objects for image : Pages with tables (1)_page-0001.jpg is ==> [{'label': 'table', 'score': 0.9872811436653137, 'bbox': [78.83558654785156, 195.83523559570312, 1162.3287353515625, 1288.5723876953125]}]
structure_transform Pixel value for image : <PIL.Image.Image image mode=RGB size=1275x1650 at 0x7C8C8843D660> is ==> torch.Size([1, 3, 1000, 991])
[{'column': [5.556395530700684, 3.6884121894836426, 281.561767578125, 1091.8695068359375], 'cell': [5.556395530700684, 3.8791143894195557, 281.561767578125, 29.424598693847656]}, {'column': [283.5521240234375, 4.280248165130615, 576.5831298828125, 1092.0594482421875], 'cell': [283.5521240234375, 3.8791143894195557, 576.5831298828125, 29.424598693847656]}, {'column': [574.8310546875, 3.5371389389038086, 704.173828125, 1091.7586669921875], 'cell': [574.831054687

  0%|          | 0/12 [00:00<?, ?it/s]

Max number of columns: 6
<PIL.Image.Image image mode=RGB size=1275x1650 at 0x7C8C887E1ED0>
Pixel value for image : PAGE 8.jpg is ==> torch.Size([1, 3, 800, 618])
Objects for image : PAGE 8.jpg is ==> [{'label': 'table', 'score': 0.8436331748962402, 'bbox': [87.88920593261719, 293.9248352050781, 715.8790893554688, 1280.296142578125]}, {'label': 'table', 'score': 0.8794835209846497, 'bbox': [87.57617950439453, 75.25721740722656, 915.622802734375, 211.6537628173828]}]
structure_transform Pixel value for image : <PIL.Image.Image image mode=RGB size=1275x1650 at 0x7C8C887E1ED0> is ==> torch.Size([1, 3, 1000, 637])
[{'column': [0.2657792270183563, 32.88214874267578, 120.77088928222656, 971.7088623046875], 'cell': [0.2657792270183563, 136.0955810546875, 120.77088928222656, 157.95973205566406]}, {'column': [118.84996032714844, 32.44651794433594, 307.8067626953125, 971.5819091796875], 'cell': [118.84996032714844, 136.0955810546875, 307.8067626953125, 157.95973205566406]}, {'column': [308.225982

  0%|          | 0/20 [00:00<?, ?it/s]

Max number of columns: 4
<PIL.Image.Image image mode=RGB size=1275x1650 at 0x7C8C8866E830>
Pixel value for image : PAGE 4.jpg is ==> torch.Size([1, 3, 800, 618])
Objects for image : PAGE 4.jpg is ==> [{'label': 'table', 'score': 0.9217093586921692, 'bbox': [120.583984375, 347.51068115234375, 1131.6683349609375, 1413.169189453125]}]
structure_transform Pixel value for image : <PIL.Image.Image image mode=RGB size=1275x1650 at 0x7C8C8866E830> is ==> torch.Size([1, 3, 1000, 949])
[{'column': [1.144584059715271, 2.235030174255371, 89.37885284423828, 1055.2603759765625], 'cell': [1.144584059715271, 0.9796060919761658, 89.37885284423828, 17.8302001953125]}, {'column': [92.5902328491211, 2.207226514816284, 308.4300231933594, 1055.3536376953125], 'cell': [92.5902328491211, 0.9796060919761658, 308.4300231933594, 17.8302001953125]}, {'column': [349.0235290527344, 1.6596570014953613, 514.3118896484375, 1055.87353515625], 'cell': [349.0235290527344, 0.9796060919761658, 514.3118896484375, 17.8302001

  0%|          | 0/48 [00:00<?, ?it/s]

KeyboardInterrupt: 