# Part 1a: Downloading Dataset
Download the `ocr-data.zip` file from either Kaggle or Google Drive; this notebook currently uses the Kaggle API. Before running the cells below, upload your `kaggle.json` credentials file to the Colab working directory (`/content`).

## Method 1: Using Kaggle API
Downloading the dataset through the Kaggle API is much faster than through the Google Drive API (the 10.1 GB archive took 1 min 44 s in the run below).

In [1]:
# Install required libraries
!pip install -q kaggle

In [7]:
# create the .kaggle directory in the home directory
!mkdir -p ~/.kaggle
# copy kaggle.json to ~/.kaggle
!cp kaggle.json ~/.kaggle
# change file permission for kaggle.json
!chmod 600 ~/.kaggle/kaggle.json
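The same credential setup can be written in plain Python, for example in a script that runs outside Colab. This is a sketch: `install_kaggle_credentials` is a hypothetical helper, and its `home` parameter exists only so it can be pointed at a temporary directory.

```python
import os
import shutil
from pathlib import Path


def install_kaggle_credentials(src="kaggle.json", home=None):
    """Copy kaggle.json into <home>/.kaggle and restrict its permissions to 0o600.

    `home` defaults to the current user's home directory (hypothetical parameter,
    added only so the function can be exercised against a temporary directory).
    """
    home = Path(home) if home is not None else Path.home()
    kaggle_dir = home / ".kaggle"
    kaggle_dir.mkdir(parents=True, exist_ok=True)  # mkdir -p ~/.kaggle
    dest = kaggle_dir / "kaggle.json"
    shutil.copy(src, dest)                         # cp kaggle.json ~/.kaggle
    os.chmod(dest, 0o600)                          # chmod 600 ~/.kaggle/kaggle.json
    return dest
```

Restricting the file to owner read/write matters because the Kaggle client warns when the credentials file is readable by other users.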

In [8]:
# download ocr dataset
!kaggle datasets download -d aidapearson/ocr-data

Downloading ocr-data.zip to /content
100% 10.1G/10.1G [01:44<00:00, 156MB/s]
100% 10.1G/10.1G [01:44<00:00, 103MB/s]


In [9]:
# unzip the retrieved dataset into `raw_train_data`
!unzip ocr-data.zip -d raw_train_data

Streaming output truncated to the last 5000 lines.
  inflating: raw_train_data/batch_9/background_images/8128a706-c4ff-4564-81b8-215831bee2e3.jpg  
  inflating: raw_train_data/batch_9/background_images/81293de5-cb9d-4b17-93fb-458c659a69d8.jpg  
  inflating: raw_train_data/batch_9/background_images/813e8c33-cd4c-45ae-a604-6c0c36578735.jpg  
  inflating: raw_train_data/batch_9/background_images/81408cb0-cb35-4c82-b1c0-64da5542e03e.jpg  
  inflating: raw_train_data/batch_9/background_images/814dafdc-2844-4959-9dac-bffe6c2bfb34.jpg  
  inflating: raw_train_data/batch_9/background_images/814f9682-eaa7-4268-9aa8-ed9f80cdff9b.jpg  
  inflating: raw_train_data/batch_9/background_images/81536501-8bbc-4080-be92-0facff01bd63.jpg  
  inflating: raw_train_data/batch_9/background_images/81586069-55ea-4952-888d-e309a2725791.jpg  
  inflating: raw_train_data/batch_9/background_images/8159f16a-2fcf-4a0d-8ff9-e3ffc084c47b.jpg  
  inflating: raw_train_data/batch_9/background_images/81663a86
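Outside a notebook, the extraction step can be done with Python's standard `zipfile` module, which also avoids printing one line per extracted file. This is a sketch: `extract_dataset` is a hypothetical helper, and the `batch_*/background_images` layout it reports is taken from the unzip output above.

```python
import zipfile
from pathlib import Path


def extract_dataset(zip_path="ocr-data.zip", out_dir="raw_train_data"):
    """Extract the archive (like `unzip ocr-data.zip -d raw_train_data`) and return
    the sorted names of the batch directories that contain a `background_images` folder."""
    with zipfile.ZipFile(zip_path) as zf:
        zf.extractall(out_dir)  # creates parent directories for every member as needed
    return sorted(
        p.parent.name
        for p in Path(out_dir).glob("*/background_images")
        if p.is_dir()
    )
```

Note that the returned names sort lexicographically (`batch_1`, `batch_10`, `batch_2`, ...).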

## Method 2: Using Google Drive API
This method downloads more slowly than the Kaggle API, so it is currently not used and its cells are commented out.

In [None]:

# !pip install google-auth google-auth-oauthlib google-auth-httplib2 google-api-python-client

In [None]:
# # Authenticate with service account credentials
# from google.oauth2 import service_account
# from google.auth.transport.requests import Request
# # Access Google Drive using the authenticated credentials
# from googleapiclient.discovery import build
# from googleapiclient.http import MediaFileUpload,MediaIoBaseDownload
# import io

In [None]:
# credentials = service_account.Credentials.from_service_account_file(
#     '/content/project-50021-415714-5d993bed20f3.json',
#     scopes=['https://www.googleapis.com/auth/drive']
# )
# # Authenticate the credentials
# credentials.refresh(Request())
# # Build a Drive service object
# drive_service = build('drive', 'v3', credentials=credentials)

In [None]:
# Example: List files in Drive
# results = drive_service.files().list(pageSize=10).execute()
# items = results.get('files', [])

# if not items:
#     print('No files found.')
# else:
#     print('Files:')
#     for item in items:
#         print(f"{item['name']} ({item['id']})")

In [None]:
# # Upload the local ocr-data.zip to Google Drive
# file_name = "ocr-data.zip"
# file_metadata = {"name": file_name}
# media = MediaFileUpload(f"/content/{file_name}", mimetype="application/zip")
# uploaded_file = drive_service.files().create(body=file_metadata, media_body=media, fields='id').execute()


In [None]:

# # Download ocr-data.zip from Google Drive in chunks
# file_id = "1uiNTe5CL3USzw2N-ShVWyD9IxdUlicYe"
# # pylint: disable=maybe-no-member
# request = drive_service.files().get_media(fileId=file_id)
# fh = io.FileIO('ocr-data.zip', mode='wb')
# downloader = MediaIoBaseDownload(fh, request)
# done = False
# while not done:
#   status, done = downloader.next_chunk()
#   print(f"Download {int(status.progress() * 100)}%.")
# fh.close()


## Part 2: A custom Dataset object
A custom PyTorch `Dataset`; its target format should be adapted to whichever model is being trained.

### Helper functions that return the bounding-box coordinates and the corresponding label of each object in an image


In [10]:
import torchvision
from PIL import Image
from torch import nn
from torch.utils.data import Dataset, DataLoader
import torch
import torchvision.transforms as transforms
from torchvision import io
import json
import os
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torchvision.models.detection import FasterRCNN

In [40]:
from tqdm import tqdm
# Convert an image to black and white.
def convert_image_to_binary(image, thresh):
    """Convert image to black and white (a binary image).

    Pixels at or below `thresh` (dark ink) are mapped to 1 and brighter pixels to 0.
    """
    fn = lambda x : 1 if x <= thresh else 0
    binary_image = image.convert('L').point(fn, mode='1')
    return binary_image

def create_bounding_box_labels(input_json_file):
    """
    Return a dict mapping each image file name ("<uuid>.jpg") to the list of
    [xmin, ymin, xmax, ymax] bounding boxes of the LaTeX chars in that image.
    """
    with open(input_json_file, 'r') as f:
        data = json.load(f)
    data.sort(key = lambda x: x["uuid"])
    bounding_box_dict = {}
    for d in data:
        # get output file name
        file_name = f"{d['uuid']}.jpg"
        # extract coordinates from each item in json array
        xmins = d["image_data"]["xmins"]
        ymins = d["image_data"]["ymins"]
        xmaxs = d["image_data"]["xmaxs"]
        ymaxs = d["image_data"]["ymaxs"]
        # make list of bounding box coordinates for each LaTeX character
        bounding_box_dict[file_name] = [[xmin,ymin,xmax,ymax]
                             for xmin,ymin,xmax,ymax in zip(xmins, ymins, xmaxs, ymaxs)]
    return bounding_box_dict

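def set_default(obj):
    """`default=` hook for json.dump: serialize sets as lists."""
    if isinstance(obj, set):
        return list(obj)
    raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")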

def get_image_classes(image_data):
    """
    get all the latex chars in one image
    """
    char_set = set(image_data.get("visible_latex_chars"))
    return char_set

def create_char_labels(input_json_file):
    """
    Return (char_set, char_dict): the set of unique LaTeX chars in the JSON file, and a dict
    mapping each image file name ("<uuid>.jpg") to its list of visible LaTeX chars.
    """
    with open(input_json_file, 'r') as f:
        data = json.load(f)
    data.sort(key = lambda x: x["uuid"])
    char_set = set()
    char_dict = {}
    for d in tqdm(data):
        # get output file name
        file_name = f"{d['uuid']}.jpg"
        chars_in_image = get_image_classes(d["image_data"])
        char_set = char_set.union(chars_in_image)
        char_dict[file_name] = d["image_data"]["visible_latex_chars"]
    return char_set, char_dict

def create_areas(bounding_boxes):
    """Return the area of each [xmin, ymin, xmax, ymax] bounding box."""
    areas = [
        (bounding_box[2] - bounding_box[0]) * (bounding_box[3] - bounding_box[1])
        for bounding_box in bounding_boxes
    ]
    return areas
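The helpers' logic can be checked without the dataset. The sketch below applies the same box/area arithmetic and the same thresholding rule as `create_bounding_box_labels`, `create_areas` and `convert_image_to_binary` to a made-up record. The record contains only the fields those helpers read and its coordinate values are invented; `boxes_and_areas` and `binarize` are hypothetical names.

```python
def boxes_and_areas(image_data):
    """Return ([[xmin, ymin, xmax, ymax], ...], [area, ...]) for one record's
    `image_data`, mirroring create_bounding_box_labels and create_areas."""
    boxes = [
        [xmin, ymin, xmax, ymax]
        for xmin, ymin, xmax, ymax in zip(
            image_data["xmins"], image_data["ymins"],
            image_data["xmaxs"], image_data["ymaxs"],
        )
    ]
    areas = [(x2 - x1) * (y2 - y1) for x1, y1, x2, y2 in boxes]
    return boxes, areas


def binarize(gray_rows, thresh=127):
    """Same rule as convert_image_to_binary, on a grayscale image given as nested
    lists: pixels <= thresh (dark ink) -> 1, brighter pixels -> 0."""
    return [[1 if p <= thresh else 0 for p in row] for row in gray_rows]


# Hypothetical record: only the fields the helpers read; the values are made up.
record = {
    "uuid": "example",
    "image_data": {
        "xmins": [10, 40], "ymins": [5, 5],
        "xmaxs": [30, 70], "ymaxs": [25, 35],
        "visible_latex_chars": ["x", "+"],
    },
}
boxes, areas = boxes_and_areas(record["image_data"])
print(boxes)                              # [[10, 5, 30, 25], [40, 5, 70, 35]]
print(areas)                              # [400, 900]
print(binarize([[0, 200], [127, 128]]))   # [[1, 0], [1, 0]]
```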

### Dataset object for training the model

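Currently, `MathDataset` gives each bounding box the class index of the LaTeX char it contains (one class per unique char, via `get_char_encoding`), rather than only two classes (a LaTeX object vs. not a LaTeX object).

TODO

Consider also grouping the chars into coarser classes such as:
- number
- operator ($+$, $-$, $\div$, $\times$, etc.)
- symbol (arrow, fraction, parenthesis)
- functions ($\lim$, $\tan$, $\sin$, $\cos$, etc.)
- mathematical variables ($x$, $y$, $z$, $\alpha$, $\beta$, etc.)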
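A sketch of one possible mapping from the 54 tokens printed by `get_complete_char_set` to the coarse categories in the TODO list. The grouping is an illustrative assumption (the dataset does not define it), and `COARSE_GROUPS` and `coarse_class` are hypothetical names. Some examples in the list, such as $\div$, $\times$, $\alpha$ and $\beta$, do not occur among those 54 tokens.

```python
# One possible grouping of the dataset's LaTeX tokens into coarse categories.
# Borderline tokens ('\\pi', '\\infty', '.') are judgment calls, not dataset labels.
COARSE_GROUPS = {
    "number": set("0123456789") | {"."},
    "operator": {"+", "-", "=", "/", "\\cdot"},
    "symbol": {"\\frac", "\\sqrt", "\\to", "\\infty",
               "\\left(", "\\right)", "\\left|", "\\right|"},
    "function": {"\\sin", "\\cos", "\\tan", "\\cot", "\\csc", "\\sec",
                 "\\ln", "\\log", "\\lim_"},
    "variable": set("abcdeghknprstuvwxyz") | {"\\theta", "\\pi"},
}


def coarse_class(token):
    """Return the coarse category of a LaTeX token; raise KeyError if it is unmapped."""
    for group, tokens in COARSE_GROUPS.items():
        if token in tokens:
            return group
    raise KeyError(token)


print(coarse_class("\\frac"))  # symbol
```

Raising `KeyError` for unmapped tokens makes it obvious if a new batch introduces a token the grouping does not cover.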

In [24]:
def get_complete_char_set():
    """Return the sorted list of unique LaTeX chars across batches 1-10."""
    complete_char_set = set()
    for i in range(1, 11):
        input_json_file_name = f"raw_train_data/batch_{i}/JSON/kaggle_data_{i}.json"
        cur_batch_char_set, _ = create_char_labels(input_json_file_name)
        complete_char_set = complete_char_set.union(cur_batch_char_set)
        print(f"batch {i}")
        print("unique chars", cur_batch_char_set)

    complete_char_set = sorted(complete_char_set)
    print("number of unique LaTeX chars in dataset:", len(complete_char_set))
    print("unique LaTeX chars:", complete_char_set)
    return complete_char_set

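def get_char_encoding(complete_char_set):
    """Map each LaTeX char to a class index, starting at 1 because torchvision
    detection models reserve label 0 for the background."""
    char_encoding = {}
    for i, char in enumerate(complete_char_set, start=1):
        char_encoding[char] = i
    return char_encoding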

complete_char_set = get_complete_char_set()
char_encoding = get_char_encoding(complete_char_set)

100%|██████████| 10000/10000 [00:00<00:00, 165951.34it/s]


batch 1
unique chars {'\\cos', '/', '9', 's', 'z', 'r', '\\sin', 'e', 'h', '6', '\\pi', 'u', '7', 'b', '\\right|', '=', 'a', '3', '5', '\\left|', 'p', '\\to', '\\frac', '\\tan', '\\ln', '0', 'k', 'y', '\\lim_', 'd', '8', 't', 'v', '+', 'x', '\\sqrt', '-', '\\theta', '\\left(', '\\infty', '\\cot', 'n', '\\log', '4', 'w', '\\right)', '\\csc', '\\cdot', '\\sec', 'g', '2', '.', 'c', '1'}


100%|██████████| 10000/10000 [00:00<00:00, 152927.77it/s]


batch 2
unique chars {'\\cos', '/', '9', 's', 'z', 'r', '\\sin', 'h', 'e', '6', '7', 'b', 'u', '\\pi', '\\right|', '=', '5', '3', '\\left|', 'a', 'p', '\\to', '\\frac', '\\tan', '0', '\\ln', 'k', 'y', '\\lim_', '8', 'd', 't', 'v', '+', 'x', '\\sqrt', '-', '\\theta', '\\left(', '\\infty', '\\cot', 'n', '\\log', '4', 'w', '\\right)', '\\csc', '\\cdot', '\\sec', 'g', '2', '.', 'c', '1'}


100%|██████████| 10000/10000 [00:00<00:00, 160065.95it/s]


batch 3
unique chars {'\\cos', '/', '9', 's', 'z', 'r', '\\sin', 'h', 'e', '6', '\\pi', 'b', '7', 'u', '\\right|', '=', '5', '3', 'a', 'p', '\\left|', '\\to', '\\frac', '\\tan', '\\ln', '0', 'k', 'y', '\\lim_', '8', 'd', 't', 'v', 'x', '+', '\\sqrt', '-', '\\theta', '\\left(', '\\infty', '\\cot', 'n', '\\log', '4', 'w', '\\right)', '\\csc', '\\cdot', '\\sec', 'g', '2', 'c', '1'}


100%|██████████| 10000/10000 [00:00<00:00, 121930.63it/s]


batch 4
unique chars {'\\cos', '/', 's', '9', 'z', 'r', '\\sin', 'h', 'c', 'e', '6', '7', '\\pi', 'b', 'u', '\\right|', '=', '5', 'a', '3', '\\left|', 'p', '\\to', '\\frac', '\\tan', '0', 'k', '\\ln', 'y', '\\lim_', 'd', '8', 't', 'v', '+', 'x', '\\sqrt', '-', '\\theta', '\\left(', '\\infty', '\\cot', 'n', '\\log', '4', 'w', '\\right)', '\\csc', '\\cdot', '\\sec', 'g', '.', '2', '1'}


100%|██████████| 10000/10000 [00:00<00:00, 157401.30it/s]


batch 5
unique chars {'\\cos', '/', '9', 's', 'z', 'r', '\\sin', 'e', 'h', 'c', '6', '\\pi', '7', 'u', 'b', '\\right|', '=', '5', '3', 'a', '\\left|', 'p', '\\to', '\\frac', '\\tan', '0', 'k', '\\ln', 'y', '\\lim_', '8', 'd', 't', 'v', '+', 'x', '\\sqrt', '-', '\\theta', '\\left(', '\\infty', '\\cot', 'n', '\\log', '4', 'w', '\\right)', '\\csc', '\\cdot', '\\sec', 'g', '.', '2', '1'}


100%|██████████| 10000/10000 [00:00<00:00, 169378.55it/s]


batch 6
unique chars {'\\cos', '/', '9', 's', 'z', 'r', '\\sin', 'e', 'h', 'c', '6', '7', '\\pi', 'u', 'b', '\\right|', '=', '5', '3', 'a', '\\left|', 'p', '\\to', '\\frac', '\\tan', '0', '\\ln', 'k', 'y', '\\lim_', '8', 'd', 't', 'v', '+', 'x', '\\sqrt', '-', '\\theta', '\\left(', '\\infty', '\\cot', 'n', '\\log', '4', 'w', '\\right)', '\\csc', '\\cdot', '\\sec', 'g', '2', '1'}


100%|██████████| 10000/10000 [00:00<00:00, 157495.87it/s]


batch 7
unique chars {'\\cos', '/', '9', 's', 'z', 'r', '\\sin', 'e', 'h', '6', '7', '\\pi', 'u', 'b', '\\right|', '=', '5', '3', 'p', 'a', '\\left|', '\\to', '\\frac', '\\tan', '0', '\\ln', 'k', 'y', '\\lim_', '8', 'd', 't', 'v', '+', 'x', '\\sqrt', '-', '\\theta', '\\left(', '\\infty', '\\cot', 'n', '\\log', '4', 'w', '\\right)', '\\csc', '\\cdot', '\\sec', 'g', '2', '.', 'c', '1'}


100%|██████████| 10000/10000 [00:00<00:00, 160500.83it/s]


batch 8
unique chars {'\\cos', '/', '9', 's', 'z', 'r', 'h', '\\sin', 'e', '6', '7', '\\pi', 'b', 'u', '\\right|', '=', '5', '3', '\\left|', 'p', 'a', '\\to', '\\frac', '\\tan', '0', 'k', '\\ln', 'y', '\\lim_', '8', 'd', 't', 'v', '+', 'x', '\\sqrt', '-', '\\theta', '\\left(', '\\infty', '\\cot', 'n', '\\log', '4', 'w', '\\right)', '\\csc', '\\cdot', '\\sec', 'g', '2', '.', 'c', '1'}


100%|██████████| 10000/10000 [00:00<00:00, 151527.78it/s]


batch 9
unique chars {'\\cos', '/', '9', 's', 'z', 'r', '\\sin', 'h', 'e', '6', '\\pi', '7', 'b', 'u', '\\right|', '=', '5', '3', 'p', 'a', '\\left|', '\\to', '\\frac', '\\tan', '0', '\\ln', 'k', 'y', '\\lim_', '8', 'd', 't', 'v', '+', 'x', '\\sqrt', '-', '\\theta', '\\left(', '\\infty', '\\cot', 'n', '\\log', '4', 'w', '\\right)', '\\csc', '\\cdot', '\\sec', 'g', '2', '.', 'c', '1'}


100%|██████████| 10000/10000 [00:00<00:00, 165487.77it/s]


batch 10
unique chars {'\\cos', '/', '9', 's', 'z', 'r', '\\sin', 'h', 'e', '6', '\\pi', '7', 'u', 'b', '\\right|', '=', '5', '3', '\\left|', 'p', 'a', '\\to', '\\tan', '\\frac', '\\ln', 'k', '0', 'y', '\\lim_', 'd', '8', 't', 'v', '+', 'x', '\\sqrt', '-', '\\theta', '\\left(', '\\infty', '\\cot', 'n', '\\log', 'w', '4', '\\right)', '\\cdot', '\\csc', '\\sec', 'g', '2', '.', 'c', '1'}
number of unique LaTeX chars in dataset: 54
unique LaTeX chars: ['+', '-', '.', '/', '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '=', '\\cdot', '\\cos', '\\cot', '\\csc', '\\frac', '\\infty', '\\left(', '\\left|', '\\lim_', '\\ln', '\\log', '\\pi', '\\right)', '\\right|', '\\sec', '\\sin', '\\sqrt', '\\tan', '\\theta', '\\to', 'a', 'b', 'c', 'd', 'e', 'g', 'h', 'k', 'n', 'p', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z']
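torchvision's detection models treat label 0 as the background and expect `num_classes` to count it, so token indices should start at 1. Below is a sketch of such an encoding plus the inverse mapping needed to turn predicted labels back into LaTeX tokens; both function names are hypothetical.

```python
def get_char_encoding_with_background(complete_char_set):
    """Map each token to a class index starting at 1; index 0 stays free for background."""
    return {char: i for i, char in enumerate(complete_char_set, start=1)}


def get_char_decoding(char_encoding):
    """Invert the encoding so predicted label indices can be turned back into tokens."""
    return {i: char for char, i in char_encoding.items()}


tokens = ["+", "-", "\\sin", "x"]
enc = get_char_encoding_with_background(tokens)
dec = get_char_decoding(enc)
print(enc)                  # {'+': 1, '-': 2, '\\sin': 3, 'x': 4}
num_classes = len(enc) + 1  # with all 54 dataset tokens this would be 55
```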


In [42]:
class MathDataset(Dataset):
    """
    Dataset object for a single batch in the dataset
    """
    def __init__(self, batch_number, starting_index = 0, length = 800):
        self.batch_dir = f"raw_train_data/batch_{batch_number}/background_images"
        # sort all image names first so the [starting_index, starting_index + length)
        # slice is deterministic (os.walk does not guarantee any file order)
        all_file_names = sorted(f for f in os.listdir(self.batch_dir) if f.endswith(".jpg"))
        self.file_names = all_file_names[starting_index:starting_index + length]
        self.no_of_files = len(self.file_names)
        training_label_file_name = f"raw_train_data/batch_{batch_number}/JSON/kaggle_data_{batch_number}.json"
        self.bounding_box_dict = create_bounding_box_labels(training_label_file_name)
        self.char_set, self.char_dict = create_char_labels(training_label_file_name)

    def __getitem__(self, idx):
        """
        each item is a tuple of (image: Tensor, target:dict:{boxes:list[list[int]],labels:list[int]} )
        """
        file_name = self.file_names[idx]
        image = Image.open(f"{self.batch_dir}/{file_name}") # open colour image
        binary_image = convert_image_to_binary(image, thresh = 127) # convert colour image to black and white image
        # preprocessing for the binary_image object
        process = transforms.Compose([
                                transforms.PILToTensor(), # convert it to a tensor
                                transforms.Resize((600,600),antialias = True) # convert it to 600 x 600
                                ])
        # apply preprocessing to the binary_image
        final_image = process(binary_image).float()
        # create target object for training
        target = self._create_target(file_name)
        return final_image, target

    def _create_target(self, file_name):
        # bounding_boxes is a list of coordinates for each object's bounding box;
        # each bounding box is a list [xmin, ymin, xmax, ymax]
        bounding_boxes = self.bounding_box_dict.get(file_name, [])
        areas = [
            (bounding_box[2] - bounding_box[0]) * (bounding_box[3] - bounding_box[1])
            for bounding_box in bounding_boxes
        ]
        # assumes visible_latex_chars lists one char per bounding box, in the same order
        labels = [char_encoding[char] for char in self.char_dict[file_name]]
        target = {}
        # torchvision detection models expect float boxes of shape [N, 4] (also when N == 0)
        target["boxes"] = torch.tensor(bounding_boxes, dtype=torch.float32).reshape(-1, 4)
        target["image_id"] = file_name
        target["area"] = torch.tensor(areas, dtype=torch.float32)
        target["iscrowd"] = torch.zeros(len(areas), dtype=torch.bool)
        target["labels"] = torch.tensor(labels, dtype=torch.int64)
        return target

    def __len__(self):
        return self.no_of_files
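Two details matter when `MathDataset` is fed to a torchvision detection model. First, `__getitem__` resizes every image to 600 x 600 but leaves the boxes untouched, so unless the stored boxes already are 600 x 600 pixel coordinates they need to be rescaled. The sketch assumes absolute pixel coordinates of the original image (for coordinates normalized to [0, 1], pass `orig_size=(1, 1)`); which case applies should be checked against the dataset's JSON. Second, images have different numbers of boxes, so the `DataLoader` needs a collate function that keeps targets as a tuple instead of stacking them. `scale_boxes` and `detection_collate` are hypothetical names.

```python
def scale_boxes(boxes, orig_size, new_size=(600, 600)):
    """Rescale [xmin, ymin, xmax, ymax] boxes from an image of size
    orig_size=(width, height) to new_size=(width, height).

    Assumes absolute pixel coordinates; for normalized boxes pass orig_size=(1, 1).
    """
    sx = new_size[0] / orig_size[0]
    sy = new_size[1] / orig_size[1]
    return [[x1 * sx, y1 * sy, x2 * sx, y2 * sy] for x1, y1, x2, y2 in boxes]


def detection_collate(batch):
    """Turn a list of (image, target) pairs into (images, targets) tuples.

    The default collate function tries to stack each target field into one tensor,
    which fails when images have different numbers of boxes."""
    return tuple(zip(*batch))
```

For example, `DataLoader(MathDataset(1), batch_size=4, shuffle=True, collate_fn=detection_collate)`. With PIL, the original size is `image.size`, which is `(width, height)`.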

## Part 3: Object Detection Model
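TODO:
Write the code for training an object detection model. We should try several object detection models.

Some detection models built into torchvision (`torchvision.models.detection`) include:
1. Faster R-CNN
2. Mask R-CNN
3. RetinaNet
4. SSD and FCOS

Other models, such as Keypoint R-CNN and SSDlite, are also available. YOLO (You Only Look Once) is not part of torchvision but is available through third-party packages such as Ultralytics.

Then evaluate their performance and choose the best one.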
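A minimal sketch for the TODO above, following the usual torchvision fine-tuning pattern with the `FastRCNNPredictor` class the notebook already imports: load a Faster R-CNN with a ResNet-50 FPN backbone, replace its box predictor with one that outputs `num_classes` classes (the tokens plus background), and train on the loss dict the model returns in training mode. The function names are hypothetical, `weights="DEFAULT"` assumes torchvision 0.13 or newer, and the torchvision imports are deferred into the function so the helpers can be defined without it.

```python
def num_detection_classes(num_tokens):
    """torchvision detection models count the background as an extra class."""
    return num_tokens + 1


def build_faster_rcnn(num_classes, pretrained=True):
    """Faster R-CNN (ResNet-50 FPN) whose box predictor outputs `num_classes` classes."""
    # Deferred imports: defining this cell does not require torchvision.
    from torchvision.models.detection import fasterrcnn_resnet50_fpn
    from torchvision.models.detection.faster_rcnn import FastRCNNPredictor

    model = fasterrcnn_resnet50_fpn(weights="DEFAULT" if pretrained else None)
    in_features = model.roi_heads.box_predictor.cls_score.in_features
    model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)
    return model


def train_one_epoch(model, data_loader, optimizer, device):
    """One pass over the data; in train mode the model returns a dict of losses."""
    model.train()
    total = 0.0
    for images, targets in data_loader:
        images = [img.to(device) for img in images]
        # Move tensor fields to the device; leave non-tensors (e.g. the string image_id) as-is.
        targets = [{k: (v.to(device) if hasattr(v, "to") else v) for k, v in t.items()}
                   for t in targets]
        loss_dict = model(images, targets)
        loss = sum(loss_dict.values())
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total += loss.item()
    return total / max(len(data_loader), 1)
```

A common starting point for the optimizer is `torch.optim.SGD(model.parameters(), lr=0.005, momentum=0.9, weight_decay=0.0005)`, used with a `DataLoader` whose collate function returns `(images, targets)` tuples. Comparing models still requires an evaluation step (for example, mAP on a held-out batch).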