## Part 1: Downloading Dataset and Parameters
Drop kaggle.json, project-50021-415714-5d993bed20f3.json, and data.yaml in /content directory

## Part 1.1: Retrieve Dataset From Kaggle API

In [None]:
# Install required libraries
!pip install -q kaggle

In [None]:
# create .kaggle directory in root direcory
!mkdir -p ~/.kaggle
# copy kaggle.json to ~/.kaggle
!cp kaggle.json ~/.kaggle
# change file permission for kaggle.json
!chmod 600 ~/.kaggle/kaggle.json

In [None]:
# download ocr dataset
!kaggle datasets download -d aidapearson/ocr-data

Downloading ocr-data.zip to /content
 61% 6.16G/10.1G [05:12<03:19, 21.2MB/s]
User cancelled operation


In [None]:
# unzip the retrieved dataset into `raw_train_data`
!unzip ocr-data.zip -d raw_train_data

Archive:  ocr-data.zip
  End-of-central-directory signature not found.  Either this file is not
  a zipfile, or it constitutes one disk of a multi-part archive.  In the
  latter case the central directory and zipfile comment will be found on
  the last disk(s) of this archive.
unzip:  cannot find zipfile directory in one of ocr-data.zip or
        ocr-data.zip.zip, and cannot find ocr-data.zip.ZIP, period.


## Part 1.2 Retrive any model params from Google drive


In [None]:
!pip install google-auth google-auth-oauthlib google-auth-httplib2



In [None]:
# # Authenticate with service account credentials
from google.oauth2 import service_account
from google.auth.transport.requests import Request
# Access Google Drive using the authenticated credentials
from googleapiclient.discovery import build
from googleapiclient.http import MediaFileUpload,MediaIoBaseDownload
import io

In [None]:
credentials = service_account.Credentials.from_service_account_file(
    '/content/project-50021-415714-5d993bed20f3.json',
    scopes=['https://www.googleapis.com/auth/drive']
)
# Authenticate the credentials
credentials.refresh(Request())
# Build a Drive service object
drive_service = build('drive', 'v3', credentials=credentials)

In [None]:
# Example: List files in Drive
results = drive_service.files().list(pageSize=30).execute()
items = results.get('files', [])
files_to_retrieve = []
files_to_delete = []
if not items:
    print('No files found.')
else:
    # print('Files:')
    for item in items:
        print(f"{item['name']} ({item['id']})")
        # if item['name'].startswith("lstm_1_layer_512_with_attention_params_batch_123_epoch_9"):
        #   files_to_retrieve.append(item)
        # if not item["name"].startswith("yolo_params_batch_123_epochs_22"):
        #   files_to_delete.append(item)
        # if not item["name"].startswith("lstm_1_layer_512_with_attention_params_batch_123_epoch_1"):
        #   files_to_delete.append(item)
        if item["name"].startswith("lstm_1_layer_512_with_attention_params_batch_1_7_epoch_4"):
          files_to_retrieve.append(item)
        if item["name"].startswith("yolo_params_batch_1234567_epochs_6"):
          files_to_retrieve.append(item)
        if item['name'].startswith("yolo_data_full"):
          files_to_retrieve.append(item)
        # if item['name'].startswith("ocr-data"):
        #   files_to_delete.append(item)
print("files to retrieve")
print(files_to_retrieve)
print("files to delete")
print(files_to_delete)

yolo_params_batch_1234567_epochs_6.pt (1Z12nGT_dmFPAfSL82qsRvIlcJ8a9tB_B)
yolo_params_batch_1234567_epochs_5.pt (1rBGQjbwns6odzth1xy-Vopnl0lukYmyQ)
yolo_params_batch_1234567_epochs_4.pt (1I9wBPy2v8ysuFhPzeJ8LKtpZvZTFjN6Z)
lstm_1_layer_512_with_attention_params_batch_1_7_epoch_4 (1OzjeQDvOSZw7ay-e8dGuFPiojAo3lVoN)
lstm_1_layer_512_with_attention_params_batch_1_7_epoch_3 (1jR_Utq6VHPzlgr8LHh-ioc5TuZnF23gj)
lstm_1_layer_512_with_attention_params_batch_1_7_epoch_2 (1Q7XIcYvSy5nGN8Pm2qETM4qdpy3pMsus)
lstm_1_layer_512_with_attention_params_batch_1_7_epoch_1 (156drajFS6W5IBjTJA9UBvRG8GqR8wyKC)
lstm_1_layer_512_with_attention_params_batch_1_7_epoch_0 (1ZQqZ6nprkghL4Ces72C5Cz4_c4_EgSFH)
yolo_data_full.zip (1dDyvBGBWYPrUDO9asAiQX1XIB0B_pomS)
yolo_params_batch_123_epochs_22.pt (1-IqgnHsaRL6lUzSQSg3Z6or184Ik_YkY)
lstm_1_layer_512_with_attention_params_batch_2_epoch_2 (1jDF-mJ97GsvdDw-8FrvnY9GbMP1BYDvc)
lstm_1_layer_512_with_attention_params_batch_2_epoch_1 (1UyWVR6JVYsctQPf9JJx3wvDizMafmZeJ)
lstm_

In [None]:
def upload_file_to_drive(file_path, file_name = None):
  if not file_name:
    file_name = os.path.basename(file_path)
  file_metadata = {"name": file_name}
  chunksize = 4000
  media = MediaFileUpload(file_path, mimetype="application/zip", resumable=True)
  uploaded_file = drive_service.files().create(body=file_metadata, media_body=media, fields='id').execute()
def save_state_dict(model, file_name):
  torch.save(model.state_dict(), file_name)

In [None]:
def retrieve_files_from_drive(files_to_retrieve):
  for file_obj in files_to_retrieve:
    # pylint: disable=maybe-no-member
    request = drive_service.files().get_media(fileId=file_obj["id"])
    fh = io.FileIO(f'{file_obj["name"]}', mode='wb')
    downloader = MediaIoBaseDownload(fh, request)
    done = False
    while done is False:
      status, done = downloader.next_chunk()
      print(f"Download {int(status.progress() * 100)}.")


In [None]:
def delete_files_from_drive(files_to_delete):
  for file_obj in files_to_delete:
    drive_service.files().delete(fileId=file_obj["id"]).execute()

In [None]:
retrieve_files_from_drive(files_to_retrieve)

In [None]:
delete_files_from_drive(files_to_delete)

In [None]:
!unzip yolo_data_full.zip -d yolo_data

## Part 1.3 Create Char encoding for visible and full char set.


In [None]:
import torchvision
from PIL import Image
from torch import nn
from torch.utils.data import Dataset, DataLoader
import torch
import torchvision.transforms as transforms
from torchvision import io
import json
import os
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torchvision.models.detection import FasterRCNN

In [None]:
from tqdm import tqdm
def convert_image_to_binary(image, thresh):
    """Convert image to black and white, which will be referred to as a binary image"""
    fn = lambda x : 1 if x <= thresh else 0
    binary_image = image.convert('L').point(fn, mode='1')
    return binary_image

def create_bounding_box_labels(input_json_file):
    """
    for each image, create a file listing the coordinates of bounding boxes of latex chars of the image
    """
    data = []
    with open(f"{input_json_file}", 'r') as f:
        data = json.load(f)
    data = list(data)
    data.sort(key = lambda x: x["uuid"])
    bounding_box_dict = {}
    for d in data:
        # get output file name
        file_name = f"{d['uuid']}.jpg"
        # extract coordinates from each item in json array
        xmins = d["image_data"]["xmins"]
        ymins = d["image_data"]["ymins"]
        xmaxs = d["image_data"]["xmaxs"]
        ymaxs = d["image_data"]["ymaxs"]
        # make list of bounding box coordinates for each LaTeX character
        bounding_box_dict[file_name] = [[xmin,ymin,xmax,ymax]
                             for xmin,ymin,xmax,ymax in zip(xmins, ymins, xmaxs, ymaxs)]
    return bounding_box_dict

def set_default(obj):
    if isinstance(obj, set):
        return list(obj)
    raise TypeError

def get_visible_chars_in_image(image_data):
    """
    Returns
    ---
    char_set set[str]: set of latex chars for the image corresponding to image_data
    """
    char_set = set(image_data.get("visible_latex_chars"))
    return char_set

def get_full_chars_in_image(image_data):
    """
    Returns
    ---
    char_set set[str]: set of latex chars for the image corresponding to image_data
    """
    char_set = set(image_data.get("full_latex_chars"))
    return char_set

def create_visible_char_labels(input_json_file, subset_start = None, subset_end = None):
    """
    input_json_file list[str]: file path of json ground truth for a batch
    file_subset list[str] | None: list of files that need to be checked
    Returns
    ---
    char_set set[str]: set of latex chars that occur in the files that were checked
    char_dict dict[str,dict[str]]: for each file in the files checked, a dict of visible latex chars for that file is returned
    """
    data = []
    with open(f"{input_json_file}", 'r') as f:
        data = json.load(f)
    data = list(data)
    data.sort(key = lambda x: x["uuid"])
    char_set = set()
    char_dict = {}
    if subset_start is not None and subset_end is not None:
        data = [
            d
            for i,d in enumerate(data)
            if i >= subset_start
            and i < subset_end
        ]
    print('create_char_labels',len(data))
    for i,d in enumerate(tqdm(data)):
        # get output file name
        file_name = f"{d['uuid']}.jpg"
        chars_in_image = get_visible_chars_in_image(d["image_data"])
        char_set = char_set.union(chars_in_image)
        char_dict[file_name] = d["image_data"]["visible_latex_chars"]
    return char_set, char_dict

def create_full_char_labels(input_json_file, subset_start = None, subset_end = None):
    """
    input_json_file list[str]: file path of json ground truth for a batch
    file_subset list[str] | None: list of files that need to be checked
    Returns
    ---
    char_set set[str]: set of latex chars that occur in the files that were checked
    char_dict dict[str,dict[str]]: for each file in the files checked, a dict of visible latex chars for that file is returned
    """
    data = []
    with open(f"{input_json_file}", 'r') as f:
        data = json.load(f)
    data = list(data)
    data.sort(key = lambda x: x["uuid"])
    char_set = set()
    char_dict = {}
    if subset_start is not None and subset_end is not None:
        data = [
            d
            for i,d in enumerate(data)
            if i >= subset_start
            and i < subset_end
        ]
    print('create_char_labels',len(data))
    for i,d in enumerate(tqdm(data)):
        # get output file name
        file_name = f"{d['uuid']}.jpg"
        chars_in_image = get_full_chars_in_image(d["image_data"])
        char_set = char_set.union(chars_in_image)
        char_dict[file_name] = d["image_data"]["full_latex_chars"]
    return char_set, char_dict

In [None]:
def get_complete_visible_char_set():
  batch_char_set = []
  batch_char_set.append([])
  complete_char_set = set()
  for i in range(1,11):
    input_json_file_name = f"raw_train_data/batch_{i}/JSON/kaggle_data_{i}.json"
    batch_dir = ""
    cur_batch_char_set, cur_batch_char_dict = create_visible_char_labels(input_json_file_name,subset_start = 0, subset_end = 5000)
    cur_batch_char_list = list(cur_batch_char_set)
    cur_batch_char_list.sort()
    batch_char_set.append(cur_batch_char_set)
    print(f"batch {i}")
    print(f"number of unique LaTeX chars in batch {i}:{len(cur_batch_char_list)}")
    print(f"unique LaTeX chars",cur_batch_char_list)

  for c in batch_char_set:
    complete_char_set = complete_char_set.union(c)
  complete_char_set = list(complete_char_set)
  complete_char_set.sort()
  print("number of unique LaTeX chars in dataset:",len(complete_char_set))
  print("unique LaTeX chars:",complete_char_set)
  return complete_char_set
def get_complete_full_char_set():
  batch_char_set = []
  batch_char_set.append([])
  complete_char_set = set()
  for i in range(1,11):
    input_json_file_name = f"raw_train_data/batch_{i}/JSON/kaggle_data_{i}.json"
    batch_dir = ""
    cur_batch_char_set, cur_batch_char_dict = create_full_char_labels(input_json_file_name,subset_start = 0, subset_end = 5000)
    cur_batch_char_list = list(cur_batch_char_set)
    cur_batch_char_list.sort()
    batch_char_set.append(cur_batch_char_set)
    print(f"batch {i}")
    print(f"number of unique LaTeX chars in batch {i}:{len(cur_batch_char_list)}")
    print(f"unique LaTeX chars",cur_batch_char_list)

  for c in batch_char_set:
    complete_char_set = complete_char_set.union(c)
  complete_char_set = list(complete_char_set)
  complete_char_set.sort()
  print("number of unique LaTeX chars in dataset:",len(complete_char_set))
  print("unique LaTeX chars:",complete_char_set)
  return complete_char_set
def get_char_encoding(complete_char_set):
  char_encoding = {}
  for i, char in enumerate(complete_char_set):
    char_encoding[char] = i
  return char_encoding
def get_reverse_char_encoding(char_encoding):
  reverse_char_encoding = {}
  for key,val in char_encoding.items():
    reverse_char_encoding[val] = key
  return reverse_char_encoding
def encode_char_list(char_list,char_encoding):
  encoded_char_list = []
  for char in char_list:
    encoded_char_list.append(char_encoding[char])
  return encoded_char_list
def decode_char_list(char_list,reverse_char_encoding):
  decoded_char_list = []
  for char in char_list:
    decoded_char_list.append(reverse_char_encoding[char])
  return decoded_char_list
def convert_full_to_visible_encoding(full_char_encoding,visible_char_encoding):
  full_to_visible_encoding = {}
  visible_to_full_encoding = {}
  for key,val in full_char_encoding.items():
    if key in visible_char_encoding.keys():
      full_to_visible_encoding[val + 3] = visible_char_encoding[key] + 3
  full_to_visible_encoding[0] = 0
  full_to_visible_encoding[1] = 1
  full_to_visible_encoding[2] = 2
  for key,val in visible_char_encoding.items():
    if key in full_char_encoding.keys():
      visible_to_full_encoding[val + 3] = full_char_encoding[key] + 3
  visible_to_full_encoding[0] = 0
  visible_to_full_encoding[1] = 1
  visible_to_full_encoding[2] = 2
  return full_to_visible_encoding, visible_to_full_encoding


In [None]:
complete_visible_char_set = get_complete_visible_char_set()
visible_char_encoding = get_char_encoding(complete_visible_char_set)
reverse_visible_char_encoding = get_reverse_char_encoding(visible_char_encoding)

complete_full_char_set = get_complete_full_char_set()
full_char_encoding = get_char_encoding(complete_full_char_set)
reverse_full_char_encoding = get_reverse_char_encoding(full_char_encoding)


FileNotFoundError: [Errno 2] No such file or directory: 'raw_train_data/batch_1/JSON/kaggle_data_1.json'

In [None]:
full_to_visible_encoding, visible_to_full_encoding = convert_full_to_visible_encoding(full_char_encoding, visible_char_encoding)

In [None]:
print(complete_visible_char_set)
print(reverse_visible_char_encoding)
print(reverse_full_char_encoding)
print(full_to_visible_encoding)
print(visible_to_full_encoding)

## Part 1.4 Create yolo_data directory for yolo training/test data

In [None]:
DatasetTypes = [
    "train",
    "validation",
    "test"
]

In [None]:
import json
import os
import shutil
def create_yolo_data(input_json_file_list, image_subset_list, image_dir_list, dataset_type = "train"):
    """
    Converts JSON annotations to YOLO format label files.
    """
    print(input_json_file_list,image_subset_list,image_dir_list)
    for input_json_file, image_subset, image_dir in zip(input_json_file_list, image_subset_list, image_dir_list):
      print(f"image_dir: {image_dir}")
      print(f"image_subset: {image_subset}")
      with open(input_json_file, 'r') as f:
          data = json.load(f)
      data.sort(key=lambda d: d["uuid"])

      yolo_label_folder_path = f"yolo_data_full/{dataset_type}/labels/"
      yolo_image_folder_path = f"yolo_data_full/{dataset_type}/images/"

      for i,item in enumerate(data):
          # only add images with indices within the subset range
          if i < image_subset[0] or i >= image_subset[1]:
            continue
          image_file_name = item['uuid'] + '.jpg'  # Adjust if your image names differ
          original_image_path = os.path.join(image_dir, image_file_name)
          if not os.path.exists(original_image_path):
              continue  # Skip if image does not exist
          yolo_label_file_name = os.path.splitext(image_file_name)[0] + '.txt'
          # path of label file for yolo
          yolo_label_file_path = yolo_label_folder_path + yolo_label_file_name
          # path of training image for yolo
          yolo_image_file_path = yolo_image_folder_path + image_file_name

          if not os.path.exists(yolo_label_folder_path):
            os.makedirs(yolo_label_folder_path)
          if not os.path.exists(yolo_image_folder_path):
            os.makedirs(yolo_image_folder_path)

          with open(yolo_label_file_path, 'w+') as label_file:
              for char, xmin, ymin, xmax, ymax in zip(item['image_data']['visible_latex_chars'],item['image_data']['xmins'], item['image_data']['ymins'], item['image_data']['xmaxs'], item['image_data']['ymaxs']):
                  # Convert to YOLO format and write to file
                  # Note: This requires image dimensions to normalize coordinates
                  x_center, y_center, width, height = convert_to_yolo_format(xmin, ymin, xmax, ymax)
                  class_id = visible_char_encoding[char]
                  label_file.write(f"{class_id} {x_center} {y_center} {width} {height}\n")
          # copy image file to image_file_path
          shutil.copy2(original_image_path, yolo_image_file_path)
def convert_to_yolo_format(xmin, ymin, xmax, ymax):
    """
    Converts bounding box coordinates to YOLO format.
    """
    x_center = (xmin + xmax) / 2
    y_center = (ymin + ymax) / 2
    width = xmax - xmin
    height = ymax - ymin
    return x_center, y_center, width, height


In [None]:

import os
def get_annotation_file_path(batch_number):
    return f"/content/raw_train_data/batch_{batch_number}/JSON/kaggle_data_{batch_number}.json"
def get_image_path(batch_number):
    return f"/content/raw_train_data/batch_{batch_number}/background_images"
def write_yolo_data():
  batch_numbers = {
      "train":[1,2,3,4,5,6,7],
      "validation":[8,9],
      "test":[9,10]
  }
  # Define the directory where your images are located
  image_dirs = {
      "train":[get_image_path(batch_number) for batch_number in batch_numbers["train"]],
      "validation":[get_image_path(batch_number) for batch_number in batch_numbers["validation"]],
      "test":[get_image_path(batch_number) for batch_number in batch_numbers["test"]]
  }
  # Assuming create_bounding_box_labels function is defined as in your provided code
  annotation_files = {
      "train":[get_annotation_file_path(batch_number) for batch_number in batch_numbers["train"]],
      "validation":[get_annotation_file_path(batch_number) for batch_number in batch_numbers["validation"]],
      "test":[get_annotation_file_path(batch_number) for batch_number in batch_numbers["test"]]
  }
  image_subsets = {
      "train":[(0,10000),(0,10000),(0,10000),(0,10000),(0,10000),(0,10000),(0,10000)], # 70%
      "validation":[(0,10000),(0,5000)], # 15%
      "test":[(5000,10000),(0,10000)] # 15%
  }
  # Corrected function call with all required arguments
  for dataset_type in DatasetTypes:
    print(dataset_type)
    create_yolo_data(annotation_files[dataset_type], image_subsets[dataset_type], image_dirs[dataset_type], dataset_type)

In [None]:
write_yolo_data()

train
['/content/raw_train_data/batch_1/JSON/kaggle_data_1.json', '/content/raw_train_data/batch_2/JSON/kaggle_data_2.json', '/content/raw_train_data/batch_3/JSON/kaggle_data_3.json', '/content/raw_train_data/batch_4/JSON/kaggle_data_4.json', '/content/raw_train_data/batch_5/JSON/kaggle_data_5.json', '/content/raw_train_data/batch_6/JSON/kaggle_data_6.json', '/content/raw_train_data/batch_7/JSON/kaggle_data_7.json'] [(0, 10000), (0, 10000), (0, 10000), (0, 10000), (0, 10000), (0, 10000), (0, 10000)] ['/content/raw_train_data/batch_1/background_images', '/content/raw_train_data/batch_2/background_images', '/content/raw_train_data/batch_3/background_images', '/content/raw_train_data/batch_4/background_images', '/content/raw_train_data/batch_5/background_images', '/content/raw_train_data/batch_6/background_images', '/content/raw_train_data/batch_7/background_images']
image_dir: /content/raw_train_data/batch_1/background_images
image_subset: (0, 10000)
image_dir: /content/raw_train_data/ba

In [None]:

shutil.make_archive("yolo_data_full", 'zip', "yolo_data_full")
upload_file_to_drive("yolo_data_full.zip")

In [None]:
image_names = sorted([filename.strip("jpg") for dirname, _, filenames in os.walk("yolo_data/train/images")
                      for i,filename in enumerate(filenames)])
label_names = sorted([filename.strip("txt") for dirname, _, filenames in os.walk("yolo_data/train/labels")
                      for i,filename in enumerate(filenames)])
print(len(image_names),len(label_names))
for image_name, label_name in zip(image_names,label_names):
  assert(image_name == label_name)

30000 30000


In [None]:
# import locale
# def getpreferredencoding(do_setlocale = True):
#     return "UTF-8"
# locale.getpreferredencoding = getpreferredencoding
# !rm -d -r yolo_data

## Part 2: Object Detection Model

### Part 2.1. YOLO

In [None]:
!pip install yolov8

In [None]:
from ultralytics import YOLO
from PIL import Image
import cv2

In [None]:
# the function takes the original prediction and the iou threshold.
def apply_nms(orig_prediction, iou_thresh=0.05):

    # torchvision returns the indices of the bboxes to keep
    keep = torchvision.ops.nms(orig_prediction['boxes'], orig_prediction['scores'], iou_thresh)
    # print('keep',keep)
    final_prediction = orig_prediction
    final_prediction['boxes'] = final_prediction['boxes'][keep]
    final_prediction['scores'] = final_prediction['scores'][keep]
    final_prediction['labels'] = final_prediction['labels'][keep]

    return final_prediction


In [None]:

def yolo_predict(model, dl):
  dataloader_predictions = []
  for data in dl:
    imgs = []
    targets = []
    filenames = []
    filepaths = []
    for filename, filepath, src, trg in data:
      imgs.append(src)
      targets.append(trg)
      filenames.append(filename)
      filepaths.append(filepath)
      prediction = model(filepath)
      prediction = {
            "filename":filename,
            "boxes":prediction[0].boxes.xyxyn,
            "labels":prediction[0].boxes.cls,
            "scores":prediction[0].boxes.conf
      }
      final_prediction = apply_nms(prediction)
      dataloader_predictions.append(final_prediction)
  return dataloader_predictions

In [None]:
image_names = sorted([filename.strip("jpg") for dirname, _, filenames in os.walk("yolo_data_full/train/images")
                                  for i,filename in enumerate(filenames)])
label_names = sorted([filename.strip("txt") for dirname, _, filenames in os.walk("yolo_data_full/train/labels")
                                  for i,filename in enumerate(filenames)])
assert([image_name == label_name for image_name,label_name in zip(image_names,label_names)])

In [None]:
import shutil
def save_and_upload_params(file_path, batch_numbers = [], epochs = 5):
    """
    batch_numbers: list of batch numbers the model was trained on
    """
    batch_number_str = ""
    for n in batch_numbers:
        batch_number_str+=str(n)
    # zip yolo params
    # shutil.make_archive(f"yolo_params_batch_{batch_number_str}_epochs_{epochs}", 'zip', "runs/detect/train/weights")
    # uncomment to upload params to drive
    upload_file_to_drive(file_path = file_path,file_name = f"yolo_params_batch_{batch_number_str}_epochs_{epochs}.pt")

In [None]:
def get_yolo_model(pretrained = False, weight_location = ""):
  if pretrained:
    print("pretrained")
    return YOLO(weight_location)
  print('not pretrained')
  return YOLO('yolov8n.yaml').load('yolov8n.pt')

In [None]:
# def train_and_save_periodically(yolo_model, pretrained = True, starting_epoch = 0, num_epochs= 10, period = 1, batch_numbers = [1,2,3]):
#   data_yaml_path = 'data.yaml'
#   resume = pretrained
#   train_num = 1
#   for epoch in range(starting_epoch, starting_epoch + num_epochs, period):
#     if epoch > starting_epoch:
#       resume = True
#     weight_location = f"runs/detect/train/weights/last.pt"
#     if train_num > 1:
#       weight_location = f"runs/detect/train{train_num}/weights/last.pt"
#     yolo_model.train(data=data_yaml_path, epochs=period, imgsz=416, batch=16, lr0=0.0001, dropout=0.15, pretrained = pretrained)
#     save_and_upload_params(weight_location, batch_numbers = [1,2,3,4,5,6,7], epochs = epoch)
#     yolo_model = get_yolo_model(pretrained = True, weight_location = weight_location)
#     train_num+=1

In [None]:
# specify location of last pretrained weights here
weight_location = "runs/detect/train3/weights/epoch11.pt"
yolo_model = get_yolo_model(pretrained = True, weight_location = weight_location)

pretrained


In [None]:
data_yaml_path = 'data.yaml'
yolo_model.train(data=data_yaml_path, epochs=100, imgsz=416, batch=32, lr0=0.0001, dropout=0.15,resume=True,save_period=1)


Ultralytics YOLOv8.1.43 🚀 Python-3.10.12 torch-2.2.1+cu121 CUDA:0 (Tesla T4, 15102MiB)
[34m[1mengine/trainer: [0mtask=detect, mode=train, model=runs/detect/train3/weights/epoch11.pt, data=data.yaml, epochs=100, time=None, patience=100, batch=32, imgsz=416, save=True, save_period=1, cache=False, device=None, workers=8, project=None, name=train3, exist_ok=False, pretrained=True, optimizer=auto, verbose=True, seed=0, deterministic=True, single_cls=False, rect=False, cos_lr=False, close_mosaic=10, resume=runs/detect/train3/weights/epoch11.pt, amp=True, fraction=1.0, profile=False, freeze=None, multi_scale=False, overlap_mask=True, mask_ratio=4, dropout=0.15, val=True, split=val, save_json=False, save_hybrid=False, conf=None, iou=0.7, max_det=300, half=False, dnn=False, plots=True, source=None, vid_stride=1, stream_buffer=False, visualize=False, augment=False, agnostic_nms=False, classes=None, retina_masks=False, embed=None, show=False, save_frames=False, save_txt=False, save_conf=False,

[34m[1mtrain: [0mScanning /content/yolo_data/train/labels.cache... 70000 images, 0 backgrounds, 0 corrupt: 100%|██████████| 70000/70000 [00:00<?, ?it/s]


[34m[1malbumentations: [0mBlur(p=0.01, blur_limit=(3, 7)), MedianBlur(p=0.01, blur_limit=(3, 7)), ToGray(p=0.01), CLAHE(p=0.01, clip_limit=(1, 4.0), tile_grid_size=(8, 8))


[34m[1mval: [0mScanning /content/yolo_data/validation/labels.cache... 15000 images, 0 backgrounds, 0 corrupt: 100%|██████████| 15000/15000 [00:00<?, ?it/s]


Plotting labels to runs/detect/train3/labels.jpg... 
[34m[1moptimizer:[0m 'optimizer=auto' found, ignoring 'lr0=0.0001' and 'momentum=0.937' and determining best 'optimizer', 'lr0' and 'momentum' automatically... 
[34m[1moptimizer:[0m SGD(lr=0.01, momentum=0.9) with parameter groups 57 weight(decay=0.0), 64 weight(decay=0.0005), 63 bias(decay=0.0)
Resuming training runs/detect/train3/weights/epoch11.pt from epoch 13 to 100 total epochs
[34m[1mTensorBoard: [0mmodel graph visualization added ✅
Image sizes 416 train, 416 val
Using 2 dataloader workers
Logging results to [1mruns/detect/train3[0m
Starting training for 100 epochs...

      Epoch    GPU_mem   box_loss   cls_loss   dfl_loss  Instances       Size


     13/100      3.33G      0.577     0.4377     0.8178       1244        416:  62%|██████▏   | 1358/2188 [14:10<06:26,  2.15it/s]

In [None]:
save_and_upload_params("runs/detect/train3/weights/epoch11.pt", batch_numbers = [1,2,3,4,5,6,7], epochs = 12)

In [None]:
# import locale
# def getpreferredencoding(do_setlocale = True):
#     return "UTF-8"
# locale.getpreferredencoding = getpreferredencoding
# !rm -r -d ocr-data.zip

## Part 3: Seq2Seq Model

In [None]:
import torchvision
from PIL import Image
from torch import nn
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F
import torch
import torchvision.transforms as transforms
from torchvision import io
import json
import os

### Part 3.1 Datasets

In [None]:
def upload_state_dict_to_drive(file_name):
  file_metadata = {"name": file_name}
  chunksize = 4000
  media = MediaFileUpload(file_name, mimetype="application/zip", resumable=True)
  uploaded_file = drive_service.files().create(body=file_metadata, media_body=media, fields='id').execute()
def save_state_dict(model, file_name):
  torch.save(model.state_dict(), file_name)

In [None]:
def create_sequence_inputs(input_json_file):
    """
    for each image, create a file listing the coordinates of bounding boxes of latex chars of the image
    """
    data = []
    with open(f"{input_json_file}", 'r') as f:
        data = json.load(f)
    data = list(data)
    data.sort(key = lambda x: x["uuid"])
    sequence_input_dict = {}
    for d in data:
        # get output file name
        file_name = f"{d['uuid']}.jpg"
        # extract coordinates from each item in json array
        xmins = d["image_data"]["xmins"]
        ymins = d["image_data"]["ymins"]
        xmaxs = d["image_data"]["xmaxs"]
        ymaxs = d["image_data"]["ymaxs"]
        latex_char_labels = encode_char_list(d["image_data"]["visible_latex_chars"],visible_char_encoding)
        # make list of bounding box coordinates for each LaTeX character
        sequence_input_dict[file_name] = [[latex_char,xmin,ymin,xmax,ymax]
                             for latex_char,xmin,ymin,xmax,ymax in zip(latex_char_labels,xmins, ymins, xmaxs, ymaxs)]
    return sequence_input_dict
def create_sequence_outputs(input_json_file):
    """
    for each image, create a file listing the coordinates of bounding boxes of latex chars of the image
    """
    data = []
    with open(f"{input_json_file}", 'r') as f:
        data = json.load(f)
    data = list(data)
    data.sort(key = lambda x: x["uuid"])
    sequence_label_dict = {}
    for d in data:
        # get output file name
        file_name = f"{d['uuid']}.jpg"
        # extract coordinates from each item in json array
        latex_char_labels = encode_char_list(d["image_data"]["full_latex_chars"],full_char_encoding)
        # make list of bounding box coordinates for each LaTeX character
        sequence_label_dict[file_name] = latex_char_labels
    return sequence_label_dict
def get_yolo_prediction(file_path,model):

  prediction = model(file_path,verbose=False)
  prediction = {
      "boxes":prediction[0].boxes.xyxyn,
      "labels":prediction[0].boxes.cls,
      "scores":prediction[0].boxes.conf
  }
  final_prediction = apply_nms(prediction, iou_thresh = 0.45)
  return final_prediction
def get_fast_rcnn_prediction(file_path, model):
  pass
def create_sequence_input(yolo_prediction):
  labels = [int(label) for label in yolo_prediction["labels"].tolist()]
  boxes = yolo_prediction["boxes"].tolist()
  sequence_input = [[label, box[0], box[1], box[2], box[3]]
                    for label, box in zip(labels, boxes)]
  return sequence_input
def sort_inputs_by_position(encoded_inputs):
  encoded_inputs.sort(key = lambda x: (x[56],x[57],x[58],x[59]))
def add_start_end_tokens(seq,sos_token,eos_token):
  # append one hot encoded sos and eos tokens to start and end of seq
  seq.insert(0,sos_token)
  seq.append(eos_token)
  return seq

#### Part 3.1.1 Dataset Using Output of YOLO Model For Input Sequence

In [None]:
# class SequenceDataset(Dataset):
#     """
#     Dataset object for a single batch in the dataset
#     """
#     def __init__(self, batch_number, input_seq_dim = 60, output_seq_dim = 63, starting_index = 0, length = 800):
#         # directory of the background images
#         self.batch_dir = f"raw_train_data/batch_{str(batch_number)}/background_images"
#         self.file_names = sorted([filename
#                                   for dirname, _, filenames in os.walk(self.batch_dir)
#                                   for i,filename in enumerate(filenames)
#                                   if i - starting_index < length
#                                   and i >= starting_index])
#         self.file_paths = [os.path.join(self.batch_dir, file_name) for file_name in self.file_names]
#         # number of files in the dataset
#         self.no_of_files = len(self.file_names)
#         # dimension of one input sequence element
#         self.input_seq_dim = input_seq_dim
#         # dimension of one output sequence element
#         self.output_seq_dim = output_seq_dim
#         # dict of filename to object detection output
#         self.object_detection_predictions = {}
#         # JSON file name for training labels
#         training_label_file_name = f"raw_train_data/batch_{str(batch_number)}/JSON/kaggle_data_{str(batch_number)}.json"
#         # create input sequence from training label file
#         self.sequence_input_dict = create_sequence_inputs(training_label_file_name)
#         # create output sequence from training label file
#         self.sequence_label_dict = create_sequence_outputs(training_label_file_name)
#         self.char_set,self.char_dict = create_full_char_labels(training_label_file_name)
#     def __getitem__(self, idx):
#         """
#         each item is a tuple of (image: Tensor, target:dict:{boxes:list[list[int]],labels:list[int]} )
#         """
#         file_name = self.file_names[idx]
#         file_path = self.file_paths[idx]
#         encoded_labels = []
#         encoded_inputs = []
#         #create one hot encoding for each label in the output sequence
#         for label in self.sequence_label_dict[file_name]:
#           # one hot encode each label, starting from i = 3, so we can use i = 0 as pad token and i = 1 as sos token and i = 2 as pad token
#           one_hot_encoding = [ 1
#                               if label + 3 == i and i >= 3
#                               else 0
#                               for i in range(63)]
#           encoded_labels.append(one_hot_encoding)
#         # create target object for training
#         target = encoded_labels
#         yolo_prediction = get_yolo_prediction(file_path, yolo_model)
#         yolo_prediction = create_sequence_input(yolo_prediction)
#         encoded_inputs = []
#         for input in yolo_prediction:
#           label = input[0]
#           # one hot encode each label, starting from i = 2, so we can use i = 0 as sos token and i = 1 as eos token
#           one_hot_encoding = [ 1
#                               if label + 3 == i and i >= 3
#                               else 0
#                               for i in range(57)]
#           # append bounding box coordinates to the end of the one_hot_encoding to produce input vector of 60 elements
#           one_hot_encoding = one_hot_encoding + input[1:5]
#           encoded_inputs.append(one_hot_encoding)
#         sort_inputs_by_position(encoded_inputs)
#         source = encoded_inputs
#         item = {
#             "target": target,
#             "source": source
#         }
#         return item
#     def __len__(self):
#         return self.no_of_files
#     def add_object_detection_prediction(self,filename,object_detection_output):
#       self.object_detection_predictions[filename] = object_detection_output

#### Part 3.1.2 Dataset Using Ground Truth Objects as Input Sequence

In [None]:
class SequenceDataset(Dataset):
    """
    Dataset object for a single batch in the dataset
    """
    def __init__(self, batch_number, input_seq_dim = 60, output_seq_dim = 62, starting_index = 0, length = 800):
        # directory of the background images
        self.batch_dir = f"raw_train_data/batch_{str(batch_number)}/background_images"
        self.file_names = sorted([filename
                                  for dirname, _, filenames in os.walk(self.batch_dir)
                                  for i,filename in enumerate(filenames)
                                  if i - starting_index < length
                                  and i >= starting_index])
        self.file_paths = [os.path.join(self.batch_dir, file_name) for file_name in self.file_names]
        # number of files in the dataset
        self.no_of_files = len(self.file_names)
        # dimension of one input sequence element
        self.input_seq_dim = input_seq_dim
        # dimension of one output sequence element
        self.output_seq_dim = output_seq_dim
        # dict of filename to object detection output
        self.object_detection_predictions = {}
        # JSON file name for training labels
        training_label_file_name = f"raw_train_data/batch_{str(batch_number)}/JSON/kaggle_data_{str(batch_number)}.json"
        # create input sequence from training label file
        self.sequence_input_dict = create_sequence_inputs(training_label_file_name)
        # create output sequence from training label file
        self.sequence_label_dict = create_sequence_outputs(training_label_file_name)
        self.char_set,self.char_dict = create_full_char_labels(training_label_file_name)
    def __getitem__(self, idx):
        """
        each item is a tuple of (image: Tensor, target:dict:{boxes:list[list[int]],labels:list[int]} )
        """
        file_name = self.file_names[idx]
        encoded_labels = []
        encoded_inputs = []
        #create one hot encoding for each label in the output sequence
        for label in self.sequence_label_dict[file_name]:
          # one hot encode each label, starting from i = 2, so we can use i = 0 as sos token and i = 1 as eos token
          one_hot_encoding = [ 1
                              if label + 3 == i and i >= 3
                              else 0
                              for i in range(63)]
          encoded_labels.append(one_hot_encoding)
        # create target object for training
        target = encoded_labels
        for input in self.sequence_input_dict[file_name]:
          label = input[0]
          # one hot encode each label, starting from i = 2, so we can use i = 0 as sos token and i = 1 as eos token
          one_hot_encoding = [ 1
                              if label + 3 == i and i >= 3
                              else 0
                              for i in range(57)]
          # append bounding box coordinates to the end of the one_hot_encoding to produce input vector of 60 elements
          one_hot_encoding = one_hot_encoding + input[1:5]
          encoded_inputs.append(one_hot_encoding)
        sort_inputs_by_position(encoded_inputs)
        source = encoded_inputs
        item = {
            "target": target,
            "source": source
        }
        # return encoded_inputs,encoded_labels
        return item
    def __len__(self):
        return self.no_of_files
    def add_object_detection_prediction(self,filename,object_detection_output):
      self.object_detection_predictions[filename] = object_detection_output

### Part 3.2 Initialising Datasets and Dataloaders

In [None]:
seq_train_datasets = [SequenceDataset(batch_number = n, starting_index = 0, length = 10000)
                      for n in range(1,8)]
seq_valid_dataset_1 = SequenceDataset(batch_number = 8, starting_index = 0, length = 10000)
seq_valid_dataset_2 = SequenceDataset(batch_number = 9, starting_index = 0, length = 5000)
seq_test_dataset_1 = SequenceDataset(batch_number = 9, starting_index = 5000, length = 5000)
seq_test_dataset_2 = SequenceDataset(batch_number = 10, starting_index = 0, length = 10000)

create_char_labels 10000


100%|██████████| 10000/10000 [00:00<00:00, 119302.10it/s]


create_char_labels 10000


100%|██████████| 10000/10000 [00:00<00:00, 173427.28it/s]


create_char_labels 10000


100%|██████████| 10000/10000 [00:00<00:00, 158858.30it/s]


create_char_labels 10000


100%|██████████| 10000/10000 [00:00<00:00, 188469.93it/s]


create_char_labels 10000


100%|██████████| 10000/10000 [00:00<00:00, 186826.19it/s]


create_char_labels 10000


100%|██████████| 10000/10000 [00:00<00:00, 163547.40it/s]


create_char_labels 10000


100%|██████████| 10000/10000 [00:00<00:00, 183917.95it/s]


create_char_labels 10000


100%|██████████| 10000/10000 [00:00<00:00, 121758.25it/s]


create_char_labels 10000


100%|██████████| 10000/10000 [00:00<00:00, 156767.11it/s]


create_char_labels 10000


100%|██████████| 10000/10000 [00:00<00:00, 174641.87it/s]


create_char_labels 10000


100%|██████████| 10000/10000 [00:00<00:00, 169050.18it/s]


In [None]:
seq_train_dataset = torch.utils.data.ConcatDataset(seq_train_datasets)
seq_valid_dataset = torch.utils.data.ConcatDataset([seq_valid_dataset_1, seq_valid_dataset_2])
seq_test_dataset = torch.utils.data.ConcatDataset([seq_test_dataset_1, seq_test_dataset_2])

In [None]:
def pad_seq(batch_seq, max_seq_length, pad_token, sos_token,eos_token):
  new_batch_seq = []
  for seq in batch_seq:
    new_seq = seq
    while len(new_seq) < max_seq_length:
      new_seq.append(pad_token)
    new_seq = add_start_end_tokens(new_seq, sos_token,eos_token)
    new_batch_seq.append(torch.tensor(new_seq))
  return new_batch_seq

def collate_fn(batch):
  # define pad char
  trg_pad_token = [1 if i == 0
                   else 0
                   for i in range(63)]
  trg_sos_token = [1 if i == 1
                   else 0
                   for i in range(63)]
  trg_eos_token = [1 if i == 2
                   else 0
                   for i in range(63)]
  src_pad_token = [1 if i == 0
                    else 0
                    for i in range(61)]
  src_sos_token = [1 if i == 1
                    else 0
                    for i in range(61)]
  src_eos_token = [1 if i == 2
                    else 0
                    for i in range(61)]
  # construct a two lists, one for the batch's source objects and one for the atarget objects
  batch_trg = [item["target"] for item in batch]
  batch_src = [item["source"] for item in batch]
  # pad target objects in batch to be same length as longest rarget
  max_trg_length = len(max(batch_trg, key = lambda t: len(t)))
  max_src_length = len(max(batch_src, key = lambda t: len(t)))
  batch_trg = pad_seq(batch_trg, max_trg_length, trg_pad_token, trg_sos_token, trg_eos_token)
  batch_src = pad_seq(batch_src, max_src_length, src_pad_token, src_sos_token, src_eos_token)
  # batch_src = [torch.tensor(item) for item in batch_src]
  # convert batch_src to torch tensor by stacking
  batch_src = torch.stack(batch_src,dim=0)
  # convert batch_trg to torch tensor
  batch_trg = torch.stack(batch_trg,dim=0)
  return batch_src,batch_trg
# def collate_fn(batch):
#   return batch
device = torch.device('cuda') if torch.cuda.is_available() else torch.device("cpu")
torch.cuda.empty_cache()
print(collate_fn)

<function collate_fn at 0x79993111b910>


In [None]:

seq_train_dl = torch.utils.data.DataLoader(
    seq_train_dataset,
    batch_size = 20,
    shuffle = True,
    collate_fn = collate_fn,
    pin_memory = True if torch.cuda.is_available() else False
)
seq_valid_dl = torch.utils.data.DataLoader(
    seq_valid_dataset,
    batch_size = 20,
    shuffle = True,
    collate_fn = collate_fn,
    pin_memory = True if torch.cuda.is_available() else False
)
seq_test_dl = torch.utils.data.DataLoader(
    seq_test_dataset,
    batch_size = 20,
    shuffle = True,
    collate_fn = collate_fn,
    pin_memory = True if torch.cuda.is_available() else False
)

In [None]:

def compare_trg_with_pred(trg,pred,print_output = False):

  if print_output is False:
    return
  trg_copy = trg.transpose(0,1)
  decoded_trg = []
  trg_ls = trg_copy.tolist()
  for ls in trg_ls[0]:
    for i,l in enumerate(ls):
      if l == 1.0:
        decoded_trg.append(i)
        break
  decoded_trg = [i for i in decoded_trg if i != 0]
  pred = [i for i in pred if i != 0]
  print('trg',decoded_trg)
  print('pred',pred)

### Part 3.3.1 LSTM with Attention Model

In [None]:
import random

class Attention(nn.Module):
    def __init__(self):
        super(Attention, self).__init__()

    def forward(self, encoder_outputs, decoder_hidden):
        # encoder_outputs: (batch_size, seq_len, hidden_dim)
        # decoder_hidden: (batch_size, hidden_dim)
        encoder_outputs = encoder_outputs.transpose(0,1)
        # Calculate the attention scores.
        scores = torch.bmm(encoder_outputs, decoder_hidden.unsqueeze(2)).squeeze(2)  # (batch_size, seq_len)
        attn_weights = F.softmax(scores, dim=1)  # (batch_size, seq_len)
        context_vector = torch.bmm(attn_weights.unsqueeze(1), encoder_outputs).squeeze(1)  # (batch_size, hidden_dim)
        return context_vector, attn_weights

class LSTMEncoder(nn.Module):
    def __init__(self, input_dim, hidden_dim, n_layers, dropout):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.n_layers = n_layers
        self.rnn = nn.LSTM(input_dim, hidden_dim, n_layers, dropout=dropout)
        self.dropout = nn.Dropout(dropout)

    def forward(self, src):
        # src = [src length, batch size, input_dim]
        self.dropout(src.float())
        # print('src shape',src.shape)
        outputs, (hidden, cell) = self.rnn(src)
        # outputs = [src length, batch size, hidden dim * n directions]
        # hidden = [n layers * n directions, batch size, hidden dim]
        # cell = [n layers * n directions, batch size, hidden dim]
        # outputs are always from the top hidden layer
        return outputs, hidden, cell

class LSTMDecoder(nn.Module):
    def __init__(self, output_dim, hidden_dim, n_layers, dropout):
        super().__init__()
        self.output_dim = output_dim
        self.hidden_dim = hidden_dim
        self.attention = Attention()
        self.n_layers = n_layers
        self.rnn = nn.LSTM(output_dim + hidden_dim, hidden_dim, n_layers, dropout=dropout)
        self.fc_out = nn.Linear(hidden_dim, output_dim)
        self.dropout = nn.Dropout(dropout)
    def forward(self, input, encoder_outputs, hidden, cell):
        # input = [batch size, output_dim]
        # hidden = [n layers * n directions, batch size, hidden dim]
        # cell = [n layers * n directions, batch size, hidden dim]
        # n directions in the decoder will both always be 1, therefore:
        # hidden = [n layers, batch size, hidden dim]
        # context = [n layers, batch size, hidden dim]
        # input = input.unsqueeze(0)
        self.dropout(input.float())

        # create attention context vector and get attention weights
        context_vector, attn_weights = self.attention(encoder_outputs, hidden[-1])  # using the last layer's hidden state
        rnn_input = torch.cat([input.transpose(0,1), context_vector.unsqueeze(1)], dim=2)  # (batch_size, 1, emb_dim + hidden_dim)
        rnn_input = rnn_input.transpose(0,1)
        output, (hidden, cell) = self.rnn(rnn_input.float(), (hidden, cell))
        # seq length and n directions will always be 1 in this decoder, therefore:
        # output = [1, batch size, hidden dim]
        # hidden = [n layers, batch size, hidden dim]
        # cell = [n layers, batch size, hidden dim]
        prediction = self.fc_out(output.squeeze(0))
        # prediction = [batch size, output dim]
        return prediction, hidden, cell

class LSTMSeq2Seq(nn.Module):
    def __init__(self, encoder, decoder, device):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.device = device
        assert (
            encoder.hidden_dim == decoder.hidden_dim
        ), "Hidden dimensions of encoder and decoder must be equal!"
        assert (
            encoder.n_layers == decoder.n_layers
        ), "Encoder and decoder must have equal number of layers!"

    def forward(self, src, trg, print_output = False, teacher_forcing_ratio = 0):
        # src = [src length, batch size]
        # trg = [trg length, batch size]
        # teacher_forcing_ratio is probability to use teacher forcing
        # e.g. if teacher_forcing_ratio is 0.75 we use ground-truth inputs 75% of the time
        batch_size = trg.shape[1]
        trg_length = trg.shape[0]
        # tensor to store decoder outputs
        outputs = torch.zeros(trg_length, batch_size, self.decoder.output_dim).to(self.device)
        # last hidden state of the encoder is used as the initial hidden state of the decoder
        encoder_outputs, hidden, cell = self.encoder(src)
        # hidden = [n layers * n directions, batch size, hidden dim]
        # cell = [n layers * n directions, batch size, hidden dim]

        # first input to the decoder is the <sos> tokens
        input = trg[0,:]
        input = input[None,:]
        # input = [batch size]
        pred = [[1 for i in range(20)]]
        for t in range(1, trg_length):
            # insert input token embedding, previous hidden and previous cell states
            # receive output tensor (predictions) and new hidden and cell states
            output, hidden, cell = self.decoder(input, encoder_outputs, hidden, cell)
            # output = [batch size, output dim]
            # hidden = [n layers, batch size, hidden dim]
            # cell = [n layers, batch size, hidden dim]
            # place predictions in a tensor holding predictions for each token
            outputs[t] = output
            # decide if we are going to use teacher forcing or not
            teacher_force = random.random() < teacher_forcing_ratio
            # get the highest predicted token from our predictions
            top1 = output.argmax(1)
            top1_vals = top1.tolist()
            one_hot_encoded_top1 = torch.tensor([
                [
                  1 if i == val
                  else 0
                  for i in range(self.decoder.output_dim)
              ] for val in top1_vals]
            )
            # if teacher forcing, use actual next token as next input
            # if not, use predicted token
            pred.append(top1_vals)
            input = trg[t] if teacher_force else one_hot_encoded_top1
            encoded_input = []
            input = torch.tensor(input).float().to(device)
            input = input[None,:]
            # input = [batch size, output dim]
        # print('pred',pred)
        pred_0 = []
        for i in range(len(pred)):
          pred_0.append(pred[i][0])
        compare_trg_with_pred(trg,pred_0,print_output)
        return outputs


### Part 3.3.2 Training LSTM with Attention Model

In [None]:
# num of visible labels = 54, num of coordinates = 4, 1 sos token and 1 eos token, 1 pad token = 61
input_dim = 61
# 60 full labels + 1 sos token + 1 eos token + 1 pad token = 62
output_dim = 63
# size of hidden state vector
hidden_dim = 512
# number of LSTM layers to produce hidden feature state
n_layers = 1
# dropout for regularization
encoder_dropout = 0.2
decoder_dropout = 0.2
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

lstm_encoder = LSTMEncoder(
    input_dim,
    hidden_dim,
    n_layers,
    encoder_dropout,
)

lstm_decoder = LSTMDecoder(
    output_dim,
    hidden_dim,
    n_layers,
    decoder_dropout,
)

lstm_model = LSTMSeq2Seq(lstm_encoder, lstm_decoder, device).to(device)
state_dict = torch.load('lstm_1_layer_512_with_attention_params_batch_123_epoch_5')
lstm_model.load_state_dict(state_dict)



Initialise Weights

In [None]:
# def init_weights(m):
#     for name, param in m.named_parameters():
#         nn.init.uniform_(param.data, -0.08, 0.08)
# lstm_model.apply(init_weights)

LSTMSeq2Seq(
  (encoder): LSTMEncoder(
    (rnn): LSTM(61, 512, dropout=0.2)
    (dropout): Dropout(p=0.2, inplace=False)
  )
  (decoder): LSTMDecoder(
    (attention): Attention()
    (rnn): LSTM(575, 512, dropout=0.2)
    (fc_out): Linear(in_features=512, out_features=63, bias=True)
    (dropout): Dropout(p=0.2, inplace=False)
  )
)

In [None]:
optimizer = torch.optim.Adam(lstm_model.parameters())
criterion = torch.nn.CrossEntropyLoss()


In [None]:
#### Trying to add eos tokens ###
def train_fn(
    model, data_loader, optimizer, criterion, clip, teacher_forcing_ratio, device
):
  model.train()
  model.to(device)
  epoch_loss = 0
  div = 50
  current_loss = 0
  for i, (src,trg) in enumerate(data_loader):
    optimizer.zero_grad()
    src,trg = src.transpose(0,1).to(device),trg.transpose(0,1).to(device)
    # print(src.shape,trg.shape)
    # src = [src length, batch size, input_dim]
    # trg = [trg length, batch size, output_dim]
    print_output = False
    if i % div == 0 and i > 0:
      print_output = True
    output = model(src, trg, print_output, teacher_forcing_ratio)
    # print('output shape, trg shape',output.shape,trg.shape)
    # output = [trg length, batch size, output_dim]
    output = output[1:]
    # output = [(trg length - 1), batch size, output_dim]
    trg = trg[1:]
    trg, output = trg.transpose(1,2), output.transpose(1,2)

    # print('output shape, trg shape after reshape',output.shape,trg.shape)
    # trg = [(trg length - 1) * batch size]
    loss = criterion(output.float(), trg.float())
    current_loss += loss.item()
    if print_output:
      print('iter',i * 20)
      print('loss',current_loss / div)
      current_loss = 0
    loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), clip)
    optimizer.step()
    epoch_loss += loss.item()
  return epoch_loss / len(data_loader)

In [None]:
def evaluate_fn(model, data_loader, criterion, teacher_forcing_ratio, device):
    model.eval()
    epoch_loss = 0
    div = 50
    current_loss = 0
    with torch.no_grad():
      for i, (src,trg) in enumerate(data_loader):
        src,trg = src.transpose(0,1).to(device),trg.transpose(0,1).to(device)
        # src = [src length, batch size, input_dim]
        # trg = [trg length, batch size, output_dim]
        print_output = False
        if i % div == 0 and i > 0:
          print_output = True
        output = model(src, trg, print_output, teacher_forcing_ratio)
        # print('output shape, trg shape',output.shape,trg.shape)
        # output = [trg length, batch size, output_dim]
        output = output[1:]
        # output = [(trg length - 1) * batch size, output_dim]
        trg = trg[1:]
        trg, output = trg.transpose(1,2), output.transpose(1,2)

        # print('output shape, trg shape after reshape',output.shape,trg.shape)
        # trg = [(trg length - 1) * batch size]
        loss = criterion(output.float(), trg.float())
        current_loss += loss.item()
        if print_output:
          print('iter',i * 20)
          print('loss',current_loss / div)
          current_loss = 0
        epoch_loss += loss.item()
      return epoch_loss / len(data_loader)

In [None]:

# save_state_dict(lstm_model,f"lstm_1_layers_1024_with_attention_params_batch_2_layers_3_epoch_{9}")
# upload_state_dict_to_drive(f"lstm_1_layers_1024_with_attention_params_batch_2_layers_3_epoch_{9}")

Training Loop

In [None]:
import numpy as np
n_epochs = 5
start_epoch = 0
clip = 1.0
teacher_forcing_ratio = 0.2

best_valid_loss = float("inf")

for epoch in tqdm(range(start_epoch,start_epoch+n_epochs)):
    train_loss = train_fn(
        lstm_model,
        seq_train_dl,
        optimizer,
        criterion,
        clip,
        teacher_forcing_ratio,
        device,
    )

    save_state_dict(lstm_model,f"lstm_1_layer_512_with_attention_params_batch_1_7_epoch_{epoch}")
    upload_state_dict_to_drive(f"lstm_1_layer_512_with_attention_params_batch_1_7_epoch_{epoch}")
    valid_loss = evaluate_fn(
        lstm_model,
        seq_test_dl,
        criterion,
        teacher_forcing_ratio,
        device,
    )
    test_loss = evaluate_fn(
        lstm_model,
        seq_test_dl,
        criterion,
        0,
        device,
    )
    if valid_loss < best_valid_loss:
        best_valid_loss = valid_loss
    print('\n')
    print(f"\tTrain Loss: {train_loss:7.3f} | Train PPL: {np.exp(train_loss):7.3f}")
    print(f"\tValid Loss: {valid_loss:7.3f} | Valid PPL: {np.exp(valid_loss):7.3f}")
    print(f"\Test Loss: {test_loss:7.3f} | Valid PPL: {np.exp(test_loss):7.3f}")

  input = torch.tensor(input).float().to(device)


trg [1, 27, 61, 50, 38, 23, 61, 30, 62, 61, 16, 62, 62, 23, 61, 36, 3, 4, 10, 22, 61, 50, 62, 62, 61, 50, 3, 4, 10, 23, 61, 30, 62, 61, 10, 62, 62, 2]
pred [1, 27, 61, 61, 38, 38, 62, 62, 61, 61, 61, 61, 61, 61, 61, 61, 61, 62, 62, 62, 62, 62, 62, 62]
iter 1000
loss 2.265040135383606
trg [1, 27, 61, 50, 38, 30, 6, 16, 62, 23, 61, 9, 20, 40, 61, 11, 62, 61, 50, 62, 3, 4, 15, 22, 40, 61, 11, 62, 61, 50, 62, 62, 61, 10, 62, 2]
pred [1, 27, 61, 55, 38, 62, 62, 61, 61, 61, 61, 61, 61, 61, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62]
iter 2000
loss 1.8828905510902405
trg [1, 27, 61, 59, 38, 24, 62, 23, 61, 29, 41, 61, 8, 8, 62, 61, 59, 62, 62, 61, 29, 41, 61, 11, 15, 62, 61, 14, 62, 62, 23, 61, 29, 41, 61, 9, 8, 62, 61, 9, 62, 62, 61, 29, 41, 61, 8, 7, 62, 61, 59, 62, 62, 2]
pred [1, 27, 61, 58, 38, 24, 62, 61, 61, 61, 61, 61, 61, 61, 61, 61, 61, 61, 61, 61, 61, 61, 61, 61, 61, 61, 61, 62, 61, 62, 62, 62, 62, 62, 62, 62, 62, 62]
iter 3000
loss 1.7820318961143493
trg [1, 27, 61, 43, 38, 12, 40

 20%|██        | 1/5 [13:47<55:08, 827.14s/it]



	Train Loss:   0.894 | Train PPL:   2.446
	Valid Loss:   0.480 | Valid PPL:   1.616
\Test Loss:   0.587 | Valid PPL:   1.798
trg [1, 27, 61, 58, 38, 24, 62, 23, 61, 23, 61, 29, 41, 61, 13, 15, 62, 61, 58, 62, 62, 61, 29, 41, 61, 16, 14, 62, 61, 9, 62, 62, 62, 61, 23, 61, 29, 41, 61, 16, 8, 62, 61, 58, 62, 62, 61, 29, 41, 61, 16, 13, 62, 61, 15, 62, 62, 62, 2]
pred [1, 27, 61, 58, 38, 24, 62, 23, 61, 23, 61, 29, 41, 61, 14, 14, 62, 61, 58, 62, 62, 61, 29, 41, 61, 9, 9, 62, 61, 8, 62, 62, 62, 61, 23, 61, 29, 41, 61, 13, 62, 62, 61, 58, 62, 62, 61, 29, 41, 61, 8, 9, 62, 61, 16, 62, 62, 62]
iter 1000
loss 0.4986289924383163
trg [1, 17, 27, 61, 50, 38, 15, 40, 61, 3, 62, 62, 23, 61, 28, 3, 4, 11, 25, 50, 4, 10, 31, 62, 61, 28, 61, 25, 50, 4, 8, 31, 62, 62, 2]
pred [1, 17, 27, 61, 50, 38, 15, 40, 61, 3, 62, 62, 23, 61, 28, 61, 50, 62, 3, 4, 25, 50, 3, 4, 31, 62, 28, 61, 25, 25, 50, 3, 50, 31, 62]
iter 2000
loss 0.4737651437520981
trg [1, 27, 61, 53, 38, 14, 62, 23, 61, 16, 62, 61, 53, 62, 

 40%|████      | 2/5 [27:26<41:07, 822.57s/it]



	Train Loss:   0.308 | Train PPL:   1.361
	Valid Loss:   0.211 | Valid PPL:   1.235
\Test Loss:   0.272 | Valid PPL:   1.312
trg [1, 27, 61, 53, 38, 4, 13, 62, 23, 61, 53, 62, 61, 53, 40, 61, 14, 62, 3, 10, 53, 3, 12, 40, 61, 11, 62, 62, 2]
pred [1, 27, 61, 53, 38, 4, 13, 62, 23, 61, 53, 62, 61, 53, 40, 61, 14, 62, 3, 10, 53, 3, 12, 62, 11, 62]
iter 1000
loss 0.19435080096125604
trg [1, 28, 61, 55, 62, 17, 23, 61, 27, 61, 51, 38, 30, 6, 15, 40, 61, 4, 62, 62, 23, 61, 45, 62, 61, 45, 51, 62, 20, 61, 51, 62, 62, 61, 27, 61, 51, 38, 30, 6, 15, 40, 61, 4, 62, 62, 23, 61, 45, 62, 61, 45, 51, 62, 4, 10, 20, 40, 61, 15, 62, 61, 51, 62, 62, 2]
pred [1, 28, 61, 58, 62, 17, 23, 61, 27, 61, 51, 38, 30, 6, 15, 40, 61, 4, 62, 62, 23, 61, 45, 62, 61, 45, 51, 62, 20, 61, 51, 62, 62, 61, 27, 61, 51, 38, 30, 6, 15, 40, 61, 4, 62, 62, 23, 61, 45, 62, 61, 45, 51, 62, 4, 10, 20, 40, 61, 15, 62, 61, 51, 62, 62, 2]
iter 2000
loss 0.17365456253290176
trg [1, 27, 61, 54, 38, 7, 40, 61, 3, 62, 62, 23, 61, 4,

 60%|██████    | 3/5 [41:08<27:24, 822.37s/it]



	Train Loss:   0.160 | Train PPL:   1.174
	Valid Loss:   0.213 | Valid PPL:   1.238
\Test Loss:   0.288 | Valid PPL:   1.334
trg [1, 27, 61, 58, 38, 24, 62, 23, 61, 28, 58, 40, 61, 11, 62, 62, 61, 28, 58, 40, 61, 8, 62, 3, 13, 62, 2]
pred [1, 27, 61, 58, 38, 24, 62, 23, 61, 28, 58, 40, 61, 11, 62, 62, 61, 28, 61, 58, 62, 3, 13, 62, 62]
iter 1000
loss 0.1709039730578661
trg [1, 27, 61, 42, 38, 4, 24, 62, 14, 2]
pred [1, 27, 61, 42, 38, 4, 24, 62, 14]
iter 2000
loss 0.14673392757773399
trg [1, 27, 61, 57, 38, 8, 62, 23, 61, 15, 3, 36, 61, 57, 62, 62, 61, 11, 3, 4, 13, 25, 13, 3, 4, 15, 33, 40, 61, 15, 62, 61, 57, 62, 31, 62, 2]
pred [1, 27, 61, 57, 38, 8, 62, 23, 61, 15, 3, 36, 61, 57, 62, 62, 61, 11, 3, 4, 13, 25, 13, 3, 4, 15, 33, 40, 61, 15, 62, 61, 57, 62, 31, 62]
iter 3000
loss 0.1223685747385025
trg [1, 27, 61, 47, 38, 24, 62, 23, 61, 29, 41, 61, 14, 8, 62, 61, 47, 62, 62, 61, 29, 41, 61, 8, 15, 62, 61, 16, 62, 62, 23, 61, 29, 41, 61, 16, 8, 62, 61, 10, 62, 62, 61, 29, 41, 61, 8,

 80%|████████  | 4/5 [54:43<13:39, 819.34s/it]



	Train Loss:   0.115 | Train PPL:   1.122
	Valid Loss:   0.115 | Valid PPL:   1.122
\Test Loss:   0.150 | Valid PPL:   1.161
trg [1, 27, 61, 52, 38, 8, 62, 23, 61, 15, 34, 61, 25, 15, 52, 31, 62, 62, 61, 11, 12, 52, 62, 2]
pred [1, 27, 61, 52, 38, 8, 62, 23, 61, 15, 34, 61, 25, 15, 52, 31, 62, 62, 61, 11, 12, 52, 62]
iter 1000
loss 0.11029723405838013
trg [1, 17, 27, 61, 58, 38, 8, 62, 23, 61, 20, 61, 58, 62, 4, 16, 62, 61, 25, 15, 3, 4, 13, 36, 61, 58, 62, 31, 25, 10, 3, 34, 61, 58, 62, 31, 62, 2]
pred [1, 17, 27, 61, 58, 38, 8, 62, 23, 61, 20, 61, 58, 62, 4, 16, 62, 61, 25, 15, 3, 4, 13, 36, 61, 58, 62, 31, 25, 10, 3, 34, 61, 58, 62, 31, 62]
iter 2000
loss 0.1142772065848112
trg [1, 27, 61, 43, 38, 15, 40, 61, 3, 62, 62, 46, 40, 61, 20, 61, 43, 62, 28, 61, 25, 9, 3, 43, 31, 62, 62, 2]
pred [1, 27, 61, 43, 38, 15, 40, 61, 3, 62, 62, 46, 40, 61, 20, 61, 43, 62, 28, 61, 25, 9, 3, 43, 31, 62, 62]
iter 3000
loss 0.11335420198738574
trg [1, 27, 61, 55, 38, 8, 40, 61, 3, 62, 62, 23, 61, 2

100%|██████████| 5/5 [1:08:24<00:00, 820.81s/it]



	Train Loss:   0.096 | Train PPL:   1.101
	Valid Loss:   0.097 | Valid PPL:   1.102
\Test Loss:   0.131 | Valid PPL:   1.140





### Part 3.4 Evaluating Expression Generation
Code to generate expression using LSTM with attention model.

In [None]:
import numpy as np
from copy import deepcopy
def get_char_cnt(src):
   cnt_chars = {}
   for i in range(0, 57):
     cnt_chars[i] = 0
   possible_chars = []
   src_ls = src.tolist()
   src_nums = []
   for char_list in src_ls:
     for one_hot_encoded_char in char_list:
        for i in range(57):
          if one_hot_encoded_char[i] == 1:
            src_nums.append(i)
   for num in src_nums:
     possible_chars.append(num)
     cnt_chars[num] += 1
   cnt_chars[1] = 0
   return possible_chars,cnt_chars
def update_char_cnt(output,possible_chars,cnt_chars,top1):

  hidden_chars = {18,39,40,41,61,62}
  mxp = -5000
  ls_output = output.tolist()
  new_cnt_chars = {}
  for char, cnt_char in cnt_chars.items():
    new_cnt_chars[char] = cnt_char
  for i in range(63):
    if i in hidden_chars:
      if ls_output[0][i] > mxp:
        top1 = torch.tensor(i)
        mxp = ls_output[0][i]
      continue
    visible_char_i = full_to_visible_encoding[i]
    if visible_char_i in possible_chars and ls_output[0][i] > mxp and (cnt_chars[visible_char_i] > 0):
      top1 = torch.tensor(i)
      mxp = ls_output[0][i]
  if top1.item() not in hidden_chars:
    new_cnt_chars[full_to_visible_encoding[top1.item()]] -= 1
  cnt_chars[1] = 0
  return top1, new_cnt_chars
def copy_char_cnt(cnt_chars):
  new_cnt_chars = {}
  for char, cnt_char in cnt_chars.items():
    new_cnt_chars[char] = cnt_char
  return new_cnt_chars

def update_char_cnt_direct(option, cnt_chars):
  new_cnt_chars = deepcopy(cnt_chars)
  new_cnt_chars[option] -= 1
  return new_cnt_chars
def get_sorted_options(output_scores,possible_chars,cnt_chars):
   hidden_chars = {18,39,40,41,61,62}
   next_options = []
   scores = []
   for (i,score) in output_scores:
     if i in hidden_chars:
       next_options.append(i)
       scores.append(score)
       continue
     visible_char_i = full_to_visible_encoding[i]
     if visible_char_i in possible_chars and (cnt_chars[visible_char_i] > 0):
       next_options.append(i)
       scores.append(score)
   return next_options,scores

In [None]:
def add_char_recursive(
    max_expression_length,
    model,
    predicted_expression,
    hidden,
    cell,
    next_input,
    encoder_outputs,
    possible_chars,
    cnt_chars
):
  # print(predicted_expression)
  # print(len(predicted_expression))
  # if len(predicted_expression) == 5:
  #   return predicted_expression, True
  next_input = next_input.unsqueeze(0)
  hidden_chars = {18,39,40,41,57,58,59,60,61,62}

  # print('next input shape',next_input.shape)
  # next_input = torch.tensor(option).float().to(device)
  output, hidden, cell = model.decoder(next_input, encoder_outputs, hidden, cell)
  top1 = output.argmax(1)
  output_scores = [(index,score) for index,score in enumerate(output.tolist()[0])]
  sorted_output_scores = sorted(output_scores, key = lambda o: o[1], reverse=True)
  next_options,scores = get_sorted_options(sorted_output_scores, possible_chars, cnt_chars)
  next_options = next_options[:2]
  scores = scores[:2]
  if next_options[0] in hidden_chars:
    next_options[0],next_options[1] = next_options[1],next_options[0]
  # print(next_options)
  n = 0
  for i,option in enumerate(next_options):

    visible_char_option = option
    if visible_char_option not in hidden_chars:
      visible_char_option = full_to_visible_encoding[option]
    if i == 2:
      break
    all_visible_chars_used = False
    next_cnt_chars = deepcopy(cnt_chars)
    if visible_char_option not in hidden_chars:
      next_cnt_chars = update_char_cnt_direct(visible_char_option, cnt_chars)
    # if option == 2 or option == 0:
    all_visible_chars_used = all([next_cnt_char == 0 for next_cnt_char in next_cnt_chars.values()])

    next_predicted_expression = deepcopy(predicted_expression)
    next_predicted_expression.append(option)
    if len(next_predicted_expression) == 67:
      print('next cnt chars')
      print(next_cnt_chars)
      print('all visible chars used')
      print(all_visible_chars_used)
    if len(next_predicted_expression) == max_expression_length or ((visible_char_option == 2 or visible_char_option == 0) and all_visible_chars_used):
      return next_predicted_expression, all_visible_chars_used
    one_hot_encoded_option = torch.tensor([
          1 if i == visible_char_option
          else 0
          for i in range(63)
    ])
    next_input = torch.tensor(one_hot_encoded_option).float().to(device)
    next_input = next_input.unsqueeze(0)
    next_predicted_expression, all_visible_chars_used = add_char_recursive(
        max_expression_length,
        model,
        next_predicted_expression,
        hidden,
        cell,
        next_input,
        encoder_outputs,
        possible_chars,
        next_cnt_chars
    )
    # if all_visible_chars_used:
    # return next_predicted_expression, all_visible_chars_used

  return next_predicted_expression, all_visible_chars_used

def create_expression_recursive(
    input_seq,
    max_expression_length,
    model,
    label_eos_token,
    label_sos_token
):
  possible_chars, cnt_chars = get_char_cnt(input_seq)
  model.eval()
  with torch.no_grad():
    encoder_outputs, hidden, cell = model.encoder(input_seq)
    next_input = label_sos_token
    next_input = next_input[None,:]
    predicted_expression = [1]
    predicted_expression, all_visible_chars_used = add_char_recursive(
        max_expression_length,
        model,
        predicted_expression,
        hidden,
        cell,
        next_input,
        encoder_outputs,
        possible_chars,
        cnt_chars
    )
    return predicted_expression
def create_expression(
    input_seq,
    max_expression_length,
    model,
    label_eos_token,
    label_sos_token,
):
  # print(input_seq)
  # generate counts for each visible character
  possible_chars, cnt_chars = get_char_cnt(input_seq)
  model.eval()
  with torch.no_grad():
    encoder_outputs, hidden, cell = model.encoder(input_seq)
    next_input = label_sos_token
    next_input = next_input[None,:]
    predicted_expression = [1]
    for _ in range(max_expression_length):
      next_input = next_input.unsqueeze(0)
      # print('next input shape',next_input.shape)
      output, hidden, cell = model.decoder(next_input, encoder_outputs, hidden, cell)
      top1 = output.argmax(1)
      # top1 ,cnt_chars = update_char_cnt(output, possible_chars, cnt_chars, top1)
      one_hot_encoded_top1 = torch.tensor([
        1 if i == top1.item()
        else 0
        for i in range(63)
      ])
      predicted_expression.append(top1.item())
      next_input = torch.tensor(one_hot_encoded_top1).float().to(device)
      next_input = next_input.unsqueeze(0)
      if next_input[0][2] == 1 or next_input[0][0] == 1:
        break
    return predicted_expression
def evaluate_create_expression(model, data_loader):
    model.eval()
    epoch_loss = 0

    current_loss = 0
    acc = 0
    total_acc = 0
    batch_acc = 0
    max_acc = -1
    min_acc = 101
    cnt_100_percent = 0
    accs = []
    with torch.no_grad():
        for j, (src,trg) in enumerate(data_loader):
          src = torch.transpose(src,0,1)
          src = src.to(device)
          trg = torch.tensor(trg).to(device)
          trg = torch.transpose(trg,0,1)
          decoded_trg = []
          trg_ls = trg.tolist()
          for ls in trg_ls:
            for i,l in enumerate(ls[0]):
              if l == 1:
                decoded_trg.append(i)
                break
          # print(len(decoded_trg))
          # print(decoded_trg)
          label_sos_token = torch.tensor([1 if i == 1 else 0 for i in range(63)]).to(device)
          label_eos_token = torch.tensor([1 if i == 2 else 0 for i in range(63)]).to(device)
          predicted_expression = create_expression(src,100,model,label_sos_token,label_eos_token)

          if len(predicted_expression) < len(decoded_trg) - 1:
            while len(predicted_expression) < len(decoded_trg) - 1:
              predicted_expression.append(0)
            predicted_expression.append(2)
          elif len(predicted_expression) < len(decoded_trg):
            predicted_expression.append(2)
          elif len(predicted_expression) == len(decoded_trg):
            predicted_expression[-1] = 2
          elif len(predicted_expression) > len(decoded_trg) - 1:
            decoded_trg[-1] = 0
            while len(predicted_expression) -1 > len(decoded_trg):
              decoded_trg.append(0)
            decoded_trg.append(2)
            predicted_expression[-1]= 2
          elif len(predicted_expression) > len(decoded_trg):
            decoded_trg[-1] = 0
            decoded_trg.append(2)
            predicted_expression[-1]= 2
          p,t = np.array(predicted_expression),np.array(decoded_trg)
          curr_acc = np.mean(p == t) * 100
          curr_acc = int(curr_acc)
          if curr_acc == 100:
            cnt_100_percent+=1
          total_acc += curr_acc
          max_acc = max(max_acc, curr_acc)
          min_acc = min(min_acc,curr_acc)
          accs.append(curr_acc)
          if j % 20 == 0:
            print('iteration')
            print(j)
            print('predicted expression')
            print(predicted_expression)
            print('ground truth expression')
            print(decoded_trg)
            print('current accuracy',curr_acc)
            print('current ave accuracy',total_acc // (j+1))
            print('min accuracy in batch',min_acc)
            print('max accuracy in batch',max_acc)
            print('number of perfect predictions',cnt_100_percent)
            max_acc = -1
            min_acc = 101
    return np.array(accs)
seq_train_dataloader = torch.utils.data.DataLoader(
    seq_train_dataset,
    batch_size = 1,
    shuffle = False,
    collate_fn = collate_fn,
    pin_memory = True if torch.cuda.is_available() else False
)
accs = evaluate_create_expression(lstm_model, seq_train_dataloader)

  trg = torch.tensor(trg).to(device)
  next_input = torch.tensor(one_hot_encoded_top1).float().to(device)


iteration
0
predicted expression
[1, 27, 61, 56, 38, 10, 62, 23, 61, 23, 61, 45, 62, 61, 45, 56, 62, 25, 46, 40, 61, 56, 62, 3, 4, 9, 34, 61, 56, 62, 4, 16, 31, 62, 61, 23, 61, 45, 62, 61, 45, 56, 62, 25, 56, 40, 61, 8, 62, 3, 8, 56, 40, 61, 16, 62, 3, 16, 15, 56, 40, 61, 11, 62, 31, 62, 2]
ground truth expression
[1, 27, 61, 56, 38, 10, 62, 23, 61, 23, 61, 45, 62, 61, 45, 56, 62, 25, 46, 40, 61, 56, 62, 3, 4, 9, 34, 61, 56, 62, 4, 16, 31, 62, 61, 23, 61, 45, 62, 61, 45, 56, 62, 25, 56, 40, 61, 8, 62, 3, 8, 56, 40, 61, 16, 62, 3, 16, 15, 56, 40, 61, 11, 62, 31, 62, 2]
current accuracy 100
current ave accuracy 100
min accuracy in batch 100
max accuracy in batch 100
number of perfect predictions 1
iteration
20
predicted expression
[1, 17, 27, 61, 50, 38, 24, 62, 23, 61, 23, 61, 29, 41, 61, 8, 12, 62, 61, 50, 62, 62, 61, 29, 41, 61, 12, 12, 62, 61, 13, 62, 62, 62, 61, 23, 61, 29, 41, 61, 14, 16, 62, 61, 50, 62, 62, 61, 29, 41, 61, 9, 9, 62, 61, 9, 62, 62, 62, 2]
ground truth expression
[1

KeyboardInterrupt: 

## Part 4: CRNN End-to-End Equation Generation

In [None]:
import torchvision
from PIL import Image
from torch import nn
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils import rnn
import torch
import torchvision.transforms as transforms
from torchvision import io
import os

In [None]:
def upload_state_dict_to_drive(file_name):
  file_metadata = {"name": file_name}
  chunksize = 4000
  media = MediaFileUpload(file_name, mimetype="application/zip", resumable=True)
  uploaded_file = drive_service.files().create(body=file_metadata, media_body=media, fields='id').execute()
def save_state_dict(model, file_name):
  torch.save(model.state_dict(), file_name)

In [None]:
def create_sequence_inputs(input_json_file):
    """
    for each image, create a file listing the coordinates of bounding boxes of latex chars of the image
    """
    data = []
    with open(f"{input_json_file}", 'r') as f:
        data = json.load(f)
    data = list(data)
    data.sort(key = lambda x: x["uuid"])
    sequence_input_dict = {}
    for d in data:
        # get output file name
        file_name = f"{d['uuid']}.jpg"
        # extract coordinates from each item in json array
        xmins = d["image_data"]["xmins"]
        ymins = d["image_data"]["ymins"]
        xmaxs = d["image_data"]["xmaxs"]
        ymaxs = d["image_data"]["ymaxs"]
        latex_char_labels = encode_char_list(d["image_data"]["visible_latex_chars"],visible_char_encoding)
        # make list of bounding box coordinates for each LaTeX character
        sequence_input_dict[file_name] = [[latex_char,xmin,ymin,xmax,ymax]
                             for latex_char,xmin,ymin,xmax,ymax in zip(latex_char_labels,xmins, ymins, xmaxs, ymaxs)]
    return sequence_input_dict
def create_sequence_outputs(input_json_file):
    """
    for each image, create a file listing the coordinates of bounding boxes of latex chars of the image
    """
    data = []
    with open(f"{input_json_file}", 'r') as f:
        data = json.load(f)
    data = list(data)
    data.sort(key = lambda x: x["uuid"])
    sequence_label_dict = {}
    for d in data:
        # get output file name
        file_name = f"{d['uuid']}.jpg"
        # extract coordinates from each item in json array
        latex_char_labels = encode_char_list(d["image_data"]["full_latex_chars"],full_char_encoding)
        # make list of bounding box coordinates for each LaTeX character
        sequence_label_dict[file_name] = latex_char_labels
    return sequence_label_dict
def sort_inputs_by_position(encoded_inputs):
  # coord_tuples = []
  # for i in range(len(input_labels)):
  #   coord_tuples.append((input_labels[i],encoded_positions[i][0],encoded_positions[i][1]))
  encoded_inputs.sort(key = lambda x: (x[54],x[55],x[56],x[57]))

In [None]:
def add_start_end_tokens(encoded_labels):
  # append one hot encoded sos and eos tokens
  # label_sos_token = [0,1,0,0.....]
  # label_eos_token = [0,0,1,0.....]
  label_sos_token = [1 if i == 1
                      else 0
                      for i in range(63)]
  label_eos_token = [1 if i == 2
                      else 0
                      for i in range(63)]
  encoded_labels.insert(0,label_sos_token)
  encoded_labels.append(label_eos_token)
  return encoded_labels

In [None]:
class CRNNDataset(Dataset):
    """
    Dataset object for a single batch in the dataset
    """
    def __init__(self, batch_number, starting_index = 0, length = 800):
        self.batch_dir = f"raw_train_data/batch_{str(batch_number)}/background_images"
        self.file_names = sorted([filename
                                  for dirname, _, filenames in os.walk(self.batch_dir)
                                  for i,filename in enumerate(filenames)
                                  if i - starting_index < length
                                  and i >= starting_index])
        self.no_of_files = len(self.file_names)

        training_label_file_name = f"raw_train_data/batch_{str(batch_number)}/JSON/kaggle_data_{str(batch_number)}.json"

        self.sequence_label_dict = create_sequence_outputs(training_label_file_name)
        self.char_set,self.char_dict = create_full_char_labels(training_label_file_name)
    def __getitem__(self, idx):
        """
        each item is a tuple of (image: Tensor, target:dict:{boxes:list[list[int]],labels:list[int]} )
        """
        file_name = self.file_names[idx]
        image = Image.open(f"{self.batch_dir}/{file_name}") # open colour image
        binary_image = convert_image_to_binary(image, thresh = 127) # convert colour image to black and white image
        # preprocessing for the binary_image object
        process = transforms.Compose([
                                transforms.PILToTensor(), # convert it to a tensor
                                transforms.Resize((640,640),antialias = True) # convert it to 600 x 600
                                ])

        # apply preprocessing to the binary_image
        final_image = process(binary_image).float()
        # final_image = final_image.to(device)
        encoded_labels = []
        for label in self.sequence_label_dict[file_name]:
          # one hot encode each label, starting from i = 2, so we can use i = 0 as sos token and i = 1 as eos token
          one_hot_encoding = [ 1
                              if label + 3 == i and i >= 3
                              else 0
                              for i in range(63)]
          encoded_labels.append(one_hot_encoding)
        # create target object for training
        target = encoded_labels
        # return final_image,target
        item = {
            "image":final_image,
            "target":target
        }
        return item
    def __len__(self):
        return self.no_of_files

In [None]:
!pip install einops

Collecting einops
  Downloading einops-0.7.0-py3-none-any.whl (44 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m44.6/44.6 kB[0m [31m1.2 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: einops
Successfully installed einops-0.7.0


In [None]:
def compare_trg_with_pred(trg,pred,print_output = False):

  if print_output is False:
    return
  trg_copy = trg.transpose(0,1)
  decoded_trg = []
  trg_ls = trg_copy.tolist()
  for ls in trg_ls[0]:
    for i,l in enumerate(ls):
      if l == 1.0:
        decoded_trg.append(i)
        break
  decoded_trg = [i for i in decoded_trg if i != 0]
  pred = [i for i in pred if i != 0]
  print('trg',decoded_trg)
  print('pred',pred)

In [None]:
import random
from einops import rearrange, repeat
from einops.layers.torch import Rearrange
class Encoder(nn.Module):
    def __init__(self, input_dim, hidden_dim, n_layers, dropout):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.n_layers = n_layers
        self.cnn = nn.Sequential(
          nn.Conv2d(1, 64,kernel_size = 4, stride = 2, padding = 1),
          nn.BatchNorm2d(64),
          nn.ReLU(),
          nn.MaxPool2d((2,2)),
          nn.Conv2d(64, 128, kernel_size = 4, stride = 2, padding = 1),
          nn.BatchNorm2d(128),
          nn.ReLU(),
          nn.MaxPool2d((2,2)),
          nn.Conv2d(128, 256, kernel_size = 4, stride = 2, padding = 1),
          nn.BatchNorm2d(256),
          nn.ReLU(),
          nn.MaxPool2d((2,2)),
          nn.Conv2d(256, 512, kernel_size = 4, stride = 2, padding = 1),
          nn.BatchNorm2d(512),
          nn.ReLU(),
          nn.MaxPool2d((2,2)),
          Rearrange('b c h w -> (h w) b c')
        )
        self.rnn = nn.LSTM(input_dim, hidden_dim, n_layers, dropout=dropout)
        self.dropout = nn.Dropout(dropout)

    def forward(self, src):
        # src = [src length, batch size, input_dim]
        self.dropout(src.float())

        # print('src shape',src.shape)
        src = self.cnn(src)
        # print('src shape after conv',src.shape)
        output, (hidden, cell) = self.rnn(src)
        # print('outputs shape',output.shape)
        # print('hidden shape',hidden.shape)
        # print('cell shape', cell.shape)
        # outputs = [src length, batch size, hidden dim * n directions]
        # hidden = [n layers * n directions, batch size, hidden dim]
        # cell = [n layers * n directions, batch size, hidden dim]
        # outputs are always from the top hidden layer
        return hidden, cell

class Decoder(nn.Module):
    def __init__(self, output_dim, hidden_dim, n_layers, dropout):
        super().__init__()
        self.output_dim = output_dim
        self.hidden_dim = hidden_dim
        self.n_layers = n_layers
        self.rnn = nn.LSTM(output_dim, hidden_dim, n_layers, dropout=dropout)
        self.fc_out = nn.Linear(hidden_dim, output_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, input, hidden, cell):
        # input = [batch size]
        # hidden = [n layers * n directions, batch size, hidden dim]
        # cell = [n layers * n directions, batch size, hidden dim]
        # n directions in the decoder will both always be 1, therefore:
        # hidden = [n layers, batch size, hidden dim]
        # context = [n layers, batch size, hidden dim]
        # input = [1, batch size]
        # embedded = self.dropout(self.embedding(input))
        # embedded = [1, batch size, embedding dim]
        # input = input[None,:,:]
        self.dropout(input.float())
        output, (hidden, cell) = self.rnn(input.float(), (hidden, cell))
        # output = [seq length, batch size, hidden dim * n directions]
        # hidden = [n layers * n directions, batch size, hidden dim]
        # cell = [n layers * n directions, batch size, hidden dim]
        # seq length and n directions will always be 1 in this decoder, therefore:
        # output = [1, batch size, hidden dim]
        # hidden = [n layers, batch size, hidden dim]
        # cell = [n layers, batch size, hidden dim]
        prediction = self.fc_out(output.squeeze(0))
        # prediction = [batch size, output dim]
        return prediction, hidden, cell

class Seq2Seq(nn.Module):
    def __init__(self, encoder, decoder, device):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.device = device
        assert (
            encoder.hidden_dim == decoder.hidden_dim
        ), "Hidden dimensions of encoder and decoder must be equal!"
        assert (
            encoder.n_layers == decoder.n_layers
        ), "Encoder and decoder must have equal number of layers!"

    def forward(self, src, trg, print_output = False, teacher_forcing_ratio = 0):
        # src = [src length, batch size]
        # trg = [trg length, batch size]
        # teacher_forcing_ratio is probability to use teacher forcing
        # e.g. if teacher_forcing_ratio is 0.75 we use ground-truth inputs 75% of the time
        batch_size = trg.shape[1]
        trg_length = trg.shape[0]
        # tensor to store decoder outputs
        outputs = torch.zeros(trg_length, batch_size, self.decoder.output_dim).to(self.device)
        # last hidden state of the encoder is used as the initial hidden state of the decoder
        hidden, cell = self.encoder(src)
        # hidden = [n layers * n directions, batch size, hidden dim]
        # cell = [n layers * n directions, batch size, hidden dim]
        # first input to the decoder is the <sos> tokens
        input = trg[0,:]
        input = input[None,:]
        # input = [batch size]
        pred = [[1 for i in range(10)]]
        for t in range(1, trg_length):
            # insert input token embedding, previous hidden and previous cell states
            # receive output tensor (predictions) and new hidden and cell states
            output, hidden, cell = self.decoder(input, hidden, cell)
            # output = [batch size, output dim]
            # hidden = [n layers, batch size, hidden dim]
            # cell = [n layers, batch size, hidden dim]
            # place predictions in a tensor holding predictions for each token
            outputs[t] = output
            # decide if we are going to use teacher forcing or not
            teacher_force = random.random() < teacher_forcing_ratio
            # get the highest predicted token from our predictions
            top1 = output.argmax(1)
            top1_vals = top1.tolist()
            one_hot_encoded_top1 = torch.tensor([
                [
                  1 if i == val
                  else 0
                  for i in range(self.decoder.output_dim)
              ] for val in top1_vals]
            )
            # if teacher forcing, use actual next token as next input
            # if not, use predicted token
            pred.append(top1_vals)
            input = trg[t] if teacher_force else one_hot_encoded_top1
            encoded_input = []
            input = torch.tensor(input).float().to(device)
            input = input[None,:]
            # input = [batch size]
        # print('pred',pred)
        pred_0 = []
        for i in range(len(pred)):
          pred_0.append(pred[i][0])
        compare_trg_with_pred(trg,pred_0,print_output)
        return outputs

In [None]:
# number of channels of convolution output
input_dim = 512
# 60 full latex char labels + 1 sos token + 1 eos token + 1 pad token= 63 total vocab size
output_dim = 63
# size of hidden state vector
hidden_dim = 512
# number of LSTM layers to produce hidden feature state
n_layers = 2
# dropout for regularization
encoder_dropout = 0.2
decoder_dropout = 0.2
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

encoder = Encoder(
    input_dim,
    hidden_dim,
    n_layers,
    encoder_dropout,
)

decoder = Decoder(
    output_dim,
    hidden_dim,
    n_layers,
    decoder_dropout,
)

model = Seq2Seq(encoder, decoder, device).to(device)
# state_dict = torch.load(f"crnn_params_batch_2_conv_512_lstm_layers_1_with_pad_epoch_4")
# model.load_state_dict(state_dict)

In [None]:
optimizer = torch.optim.Adam(model.parameters())
criterion = torch.nn.CrossEntropyLoss()

In [None]:
def pad_trg(batch_trg, max_trg_length, encoded_pad_char):
  new_batch_trg = []
  for trg in batch_trg:
    new_trg = trg
    while len(new_trg) < max_trg_length:
      new_trg.append(encoded_pad_char)
    new_trg = add_start_end_tokens(new_trg)
    new_batch_trg.append(new_trg)
  return new_batch_trg
def collate_fn(batch):
  encoded_pad_char = [1 if i == 0
                   else 0
                   for i in range(63)]
  batch_src = [item["image"] for item in batch]
  batch_trg = [item["target"] for item in batch]
  max_trg_length = len(max(batch_trg, key = lambda t: len(t)))
  batch_trg = pad_trg(batch_trg, max_trg_length, encoded_pad_char)

  batch_src = torch.stack(batch_src,dim=0)
  batch_trg = torch.tensor(batch_trg)

  return batch_src,batch_trg
# def collate_fn(batch):
#   return batch
device = torch.device('cuda') if torch.cuda.is_available() else torch.device("cpu")
torch.cuda.empty_cache()

In [None]:
crnn_train_dataset = CRNNDataset(batch_number = 2, starting_index = 0, length = 10000)
crnn_valid_dataset = CRNNDataset(batch_number = 3, starting_index = 0, length = 2000)
crnn_test_dataset = CRNNDataset(batch_number = 4, starting_index = 0, length = 2000)

create_char_labels 10000


100%|██████████| 10000/10000 [00:00<00:00, 83057.00it/s]


create_char_labels 10000


100%|██████████| 10000/10000 [00:00<00:00, 173164.50it/s]


create_char_labels 10000


100%|██████████| 10000/10000 [00:00<00:00, 118460.74it/s]


In [None]:
crnn_train_dl = torch.utils.data.DataLoader(
    crnn_train_dataset,
    batch_size = 20,
    shuffle = True,
    collate_fn = collate_fn,
    pin_memory = True if torch.cuda.is_available() else False
)
crnn_valid_dl = torch.utils.data.DataLoader(
    crnn_valid_dataset,
    batch_size = 20,
    shuffle = True,
    collate_fn = collate_fn,
    pin_memory = True if torch.cuda.is_available() else False
)
crnn_test_dl = torch.utils.data.DataLoader(
    crnn_test_dataset,
    batch_size = 20,
    shuffle = True,
    collate_fn = collate_fn,
    pin_memory = True if torch.cuda.is_available() else False
)

In [None]:
def init_weights(m):
    for name, param in m.named_parameters():
        nn.init.uniform_(param.data, -0.08, 0.08)
model.apply(init_weights)

Seq2Seq(
  (encoder): Encoder(
    (cnn): Sequential(
      (0): Conv2d(1, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
      (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU()
      (3): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), padding=0, dilation=1, ceil_mode=False)
      (4): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
      (5): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (6): ReLU()
      (7): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), padding=0, dilation=1, ceil_mode=False)
      (8): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
      (9): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (10): ReLU()
      (11): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), padding=0, dilation=1, ceil_mode=False)
      (12): Conv2d(256, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
      (13): Ba

In [None]:
#### Trtrging to add eos tokens ###
def crnn_train_fn(
    model, data_loader, optimizer, criterion, clip, teacher_forcing_ratio, device
):
  model.train()
  model.to(device)
  epoch_loss = 0
  batch_loss = 0
  current_loss = 0
  for i, (src,trg) in enumerate(data_loader):
    optimizer.zero_grad()
    src,trg = src.to(device),trg.to(device)
    # print("src shape, trg shape", src.shape, trg.shape)
    # src = [src length, batch size, input_dim]
    # trg = [trg length, batch size, output_dim]
    print_output = False
    if i % 20 == 0 and i > 0:
      print_output = True
    trg = torch.transpose(trg,0,1)
    output = model(src.float(), trg.float(), print_output, teacher_forcing_ratio)
    # output = [trg length, batch size, output_dim]
    output_dim = output.shape[-1]
    output = output[1:]
    # output = [(trg length - 1) * batch size, output_dim]
    trg = trg[1:]
    trg, output = trg.transpose(1,2), output.transpose(1,2)
    # print(output.shape,trg.shape)
    # trg = [(trg length - 1) * batch size]
    loss = criterion(output.float(), trg.float())
    # trg = [(trg length - 1) * batch size]
    batch_loss += loss.item()
    if print_output:
      current_loss += batch_loss
      print('iter',i)
      print('batch loss',loss.item())
      print('20 batch loss',batch_loss/20)
      # print('current loss', current_loss /(20))
      batch_loss = 0
    loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), clip)
    optimizer.step()
    epoch_loss += loss.item()
  return epoch_loss / len(data_loader)

In [None]:
def crnn_evaluate_fn(model, data_loader, criterion, teacher_forcing_ratio, device):
    model.eval()
    epoch_loss = 0
    batch_loss = 0
    current_loss = 0
    with torch.no_grad():
        for i, (src,trg)  in enumerate(data_loader):
          src,trg = src.to(device),trg.to(device)
          # print("src shape, trg shape", src.shape, trg.shape)
          # src = [src length, batch size, input_dim]
          # trg = [trg length, batch size, output_dim]
          print_output = False
          if i % 20 == 0 and i > 0:
            print_output = True
          trg = torch.transpose(trg,0,1)
          output = model(src.float(), trg.float(), print_output, teacher_forcing_ratio)
          # output = [trg length, batch size, output_dim]
          output = output[1:]
          # output = [(trg length - 1) * batch size, output_dim]
          trg = trg[1:]
          trg, output = trg.transpose(1,2), output.transpose(1,2)
          # trg = [(trg length - 1) * batch size]
          loss = criterion(output.float(), trg.float())
          # trg = [(trg length - 1) * batch size]
          batch_loss += loss.item()
          current_loss += loss.item()
          if print_output:
            print('iter',i)
            print('batch loss',loss.item())
            print('20 batch loss',batch_loss/20)
            # print('current loss', current_loss /(5*i))
            batch_loss = 0
          epoch_loss += loss.item()
    return epoch_loss / len(data_loader)

In [None]:
import numpy as np
n_epochs = 10
clip = 1.0
teacher_forcing_ratio = 1.0

best_valid_loss = float("inf")
for epoch in tqdm(range(1,n_epochs)):
    train_loss = crnn_train_fn(
        model,
        crnn_train_dl,
        optimizer,
        criterion,
        clip,
        teacher_forcing_ratio = teacher_forcing_ratio,
        device = device,
    )

    # save_state_dict(model,f"crnn_params_batch_2_conv_512_lstm_layers_2_with_pad_epoch_{epoch}")
    # upload_state_dict_to_drive(f"crnn_params_batch_2_conv_512_lstm_layers_2_with_pad_epoch_{epoch}")

    valid_loss = crnn_evaluate_fn(
        model,
        crnn_valid_dl,
        criterion,
        teacher_forcing_ratio = teacher_forcing_ratio,
        device = device
    )
    test_loss = crnn_evaluate_fn(
        model,
        crnn_test_dl,
        criterion,
        teacher_forcing_ratio = 0,
        device = device
    )
    print('\n')
    print(f"\tTrain Loss: {train_loss:7.3f} | Train PPL: {np.exp(train_loss):7.3f}")
    print(f"\tValid Loss: {valid_loss:7.3f} | Valid PPL: {np.exp(valid_loss):7.3f}")
    print(f"\tTest Loss: {test_loss:7.3f} | Test PPL: {np.exp(test_loss):7.3f}")

  input = torch.tensor(input).float().to(device)


trg [1, 27, 61, 52, 38, 24, 62, 23, 61, 4, 15, 25, 36, 61, 23, 61, 15, 62, 61, 52, 62, 62, 3, 4, 15, 23, 61, 16, 62, 61, 52, 62, 36, 61, 23, 61, 9, 62, 61, 52, 62, 62, 31, 62, 61, 4, 13, 52, 40, 61, 4, 10, 62, 62, 2]
pred [1, 27, 61, 61, 38, 38, 61, 61, 61, 61, 61, 61, 61, 61, 61, 61, 61, 61, 61, 61, 61, 61, 61, 61, 61, 61, 61, 61, 61, 61, 61, 61, 61, 61, 61, 61, 62, 62, 62, 62, 62, 61, 61, 61, 61, 61, 61, 61, 62, 61, 62, 62, 62, 62, 61]
iter 20
batch loss 1.614637017250061
20 batch loss 2.3017165184021
trg [1, 17, 27, 61, 57, 38, 24, 62, 23, 61, 13, 62, 61, 14, 62, 2]
pred [1, 27, 61, 61, 38, 38, 62, 61, 61, 61, 61, 61, 61, 61, 62, 61, 61, 61]
iter 40
batch loss 1.7223601341247559
20 batch loss 1.7315911412239076
trg [1, 27, 61, 57, 38, 30, 6, 9, 62, 23, 61, 34, 40, 61, 12, 62, 61, 57, 62, 3, 34, 40, 61, 9, 62, 61, 57, 62, 62, 61, 16, 62, 2]
pred [1, 27, 61, 58, 38, 38, 62, 62, 61, 61, 61, 61, 61, 61, 62, 62, 62, 62, 62, 61, 61, 61, 61, 62, 62, 62, 62, 62, 62, 61]
iter 60
batch loss 1

 11%|█         | 1/9 [07:01<56:11, 421.40s/it]



	Train Loss:   1.016 | Train PPL:   2.761
	Valid Loss:   0.674 | Valid PPL:   1.962
	Test Loss:   5.825 | Test PPL: 338.571
trg [1, 27, 61, 47, 38, 8, 62, 23, 61, 12, 3, 4, 13, 34, 61, 47, 62, 62, 61, 47, 22, 61, 47, 62, 62, 2]
pred [1, 27, 61, 58, 38, 9, 62, 23, 61, 23, 62, 4, 9, 34, 61, 25, 62, 62, 61, 9, 40, 61, 58, 62, 62]
iter 20
batch loss 0.5604763031005859
20 batch loss 0.6787390619516372
trg [1, 27, 61, 59, 38, 4, 8, 40, 61, 3, 62, 62, 23, 61, 59, 40, 61, 15, 62, 3, 4, 14, 59, 40, 61, 9, 62, 3, 11, 59, 62, 61, 59, 40, 61, 14, 62, 3, 4, 9, 59, 40, 61, 15, 62, 62, 2]
pred [1, 27, 61, 58, 38, 9, 24, 62, 61, 3, 62, 62, 23, 61, 23, 40, 61, 10, 62, 3, 4, 10, 20, 40, 61, 9, 62, 62, 4, 57, 40, 4, 3, 61, 9, 62, 3, 4, 10, 59, 40, 61, 9, 62, 62]
iter 40
batch loss 0.6644503474235535
20 batch loss 0.6742243781685829
trg [1, 27, 61, 59, 38, 15, 62, 23, 61, 10, 62, 61, 36, 61, 25, 16, 59, 31, 62, 62, 21, 61, 25, 9, 59, 31, 62, 2]
pred [1, 27, 61, 58, 38, 24, 62, 23, 61, 9, 62, 61, 35, 61,

 22%|██▏       | 2/9 [12:10<41:26, 355.27s/it]



	Train Loss:   0.573 | Train PPL:   1.773
	Valid Loss:   0.569 | Valid PPL:   1.766
	Test Loss:   4.850 | Test PPL: 127.734
trg [1, 27, 61, 54, 38, 23, 61, 30, 62, 61, 10, 62, 62, 33, 40, 61, 16, 62, 61, 54, 62, 3, 20, 40, 61, 9, 62, 61, 54, 62, 2]
pred [1, 27, 61, 58, 38, 23, 61, 30, 62, 61, 9, 62, 62, 23, 40, 61, 10, 62, 61, 54, 62, 3, 34, 40, 61, 11, 62, 61, 54, 62]
iter 20
batch loss 0.493679016828537
20 batch loss 0.5053339600563049
trg [1, 28, 61, 53, 62, 17, 23, 61, 27, 61, 54, 38, 13, 40, 61, 3, 62, 62, 23, 61, 45, 62, 61, 45, 54, 62, 26, 28, 61, 54, 62, 32, 62, 61, 27, 61, 54, 38, 13, 40, 61, 3, 62, 62, 23, 61, 45, 62, 61, 45, 54, 62, 11, 6, 54, 62, 2]
pred [1, 28, 61, 57, 62, 17, 23, 61, 27, 61, 58, 38, 23, 40, 61, 3, 62, 62, 23, 61, 45, 62, 61, 45, 54, 62, 26, 28, 61, 54, 62, 32, 62, 61, 27, 61, 54, 38, 9, 40, 61, 3, 62, 62, 23, 61, 45, 62, 61, 45, 54, 62, 23, 6, 54, 62]
iter 40
batch loss 0.4641617238521576
20 batch loss 0.5008790761232376
trg [1, 27, 61, 48, 38, 30, 6, 9

 33%|███▎      | 3/9 [17:10<33:01, 330.26s/it]



	Train Loss:   0.458 | Train PPL:   1.582
	Valid Loss:   0.519 | Valid PPL:   1.681
	Test Loss:   5.929 | Test PPL: 375.713
trg [1, 27, 61, 58, 38, 9, 40, 61, 3, 62, 62, 20, 61, 58, 62, 3, 7, 23, 61, 16, 62, 61, 58, 62, 2]
pred [1, 27, 61, 58, 38, 23, 40, 61, 3, 62, 62, 58, 61, 58, 62, 3, 4, 23, 61, 15, 62, 61, 58, 62]
iter 20
batch loss 0.4137965440750122
20 batch loss 0.41773948520421983
trg [1, 27, 61, 58, 38, 24, 62, 23, 61, 10, 12, 58, 40, 61, 7, 62, 62, 61, 15, 26, 58, 40, 61, 8, 62, 32, 62, 2]
pred [1, 27, 61, 58, 38, 24, 62, 23, 61, 58, 58, 58, 40, 61, 9, 62, 62, 61, 8, 26, 58, 40, 61, 10, 62, 32, 62]
iter 40
batch loss 0.4144754707813263
20 batch loss 0.40635797530412676
trg [1, 27, 61, 54, 38, 24, 62, 23, 61, 23, 61, 29, 41, 61, 11, 10, 62, 61, 54, 62, 62, 61, 29, 41, 61, 10, 8, 62, 61, 11, 62, 62, 62, 61, 23, 61, 29, 41, 61, 13, 16, 62, 61, 54, 62, 62, 61, 29, 41, 61, 15, 11, 62, 61, 13, 62, 62, 62, 2]
pred [1, 27, 61, 59, 38, 24, 62, 23, 61, 23, 61, 29, 41, 61, 9, 7, 62, 

 44%|████▍     | 4/9 [22:09<26:29, 317.82s/it]



	Train Loss:   0.411 | Train PPL:   1.508
	Valid Loss:   0.500 | Valid PPL:   1.649
	Test Loss:   5.210 | Test PPL: 183.104
trg [1, 27, 61, 54, 38, 4, 24, 62, 23, 61, 10, 54, 40, 61, 10, 62, 3, 10, 62, 61, 10, 26, 54, 40, 61, 12, 62, 32, 62, 2]
pred [1, 27, 61, 59, 38, 24, 24, 62, 23, 61, 54, 54, 40, 61, 8, 62, 3, 4, 62, 61, 35, 26, 54, 40, 61, 9, 62, 32, 62]
iter 20
batch loss 0.4289206266403198
20 batch loss 0.3861353829503059
trg [1, 27, 61, 54, 38, 24, 62, 23, 61, 46, 40, 61, 54, 62, 62, 61, 28, 61, 54, 62, 62, 2]
pred [1, 27, 61, 58, 38, 24, 62, 23, 61, 9, 40, 61, 54, 62, 62, 61, 46, 61, 54, 62, 62]
iter 40
batch loss 0.3126627206802368
20 batch loss 0.3727542653679848
trg [1, 27, 61, 58, 38, 24, 62, 23, 61, 12, 11, 28, 61, 58, 62, 15, 6, 58, 62, 61, 23, 61, 12, 62, 61, 16, 35, 61, 58, 62, 62, 62, 23, 61, 11, 35, 61, 58, 62, 62, 61, 9, 35, 61, 58, 62, 62, 2]
pred [1, 27, 61, 58, 38, 24, 62, 23, 61, 23, 58, 58, 61, 58, 62, 23, 6, 58, 62, 61, 23, 61, 9, 62, 61, 58, 35, 61, 58, 62,

 56%|█████▌    | 5/9 [27:06<20:40, 310.20s/it]



	Train Loss:   0.385 | Train PPL:   1.470
	Valid Loss:   0.516 | Valid PPL:   1.675
	Test Loss:   4.999 | Test PPL: 148.209
trg [1, 27, 61, 51, 38, 24, 62, 26, 51, 4, 14, 32, 3, 51, 2]
pred [1, 27, 61, 58, 38, 24, 62, 26, 51, 4, 9, 32, 3, 51]
iter 20
batch loss 0.40965646505355835
20 batch loss 0.3833142340183258
trg [1, 27, 61, 54, 38, 23, 61, 30, 62, 61, 10, 62, 62, 33, 40, 61, 16, 62, 61, 54, 62, 3, 20, 40, 61, 9, 62, 61, 54, 62, 2]
pred [1, 27, 61, 58, 38, 23, 61, 30, 62, 61, 9, 62, 62, 20, 40, 61, 9, 62, 61, 54, 62, 3, 36, 40, 61, 9, 62, 61, 54, 62]
iter 40
batch loss 0.3556564152240753
20 batch loss 0.3661136835813522
trg [1, 27, 61, 48, 38, 23, 61, 30, 62, 61, 9, 62, 62, 36, 40, 61, 16, 62, 61, 48, 62, 3, 34, 40, 61, 14, 62, 61, 48, 62, 2]
pred [1, 27, 61, 58, 38, 23, 61, 30, 62, 61, 9, 62, 62, 34, 40, 61, 9, 62, 61, 48, 62, 3, 34, 40, 61, 9, 62, 61, 48, 62]
iter 60
batch loss 0.41415348649024963
20 batch loss 0.37207387536764147
trg [1, 27, 61, 60, 38, 23, 61, 30, 62, 61, 9, 

 67%|██████▋   | 6/9 [32:04<15:18, 306.02s/it]



	Train Loss:   0.370 | Train PPL:   1.448
	Valid Loss:   0.506 | Valid PPL:   1.659
	Test Loss:   4.737 | Test PPL: 114.116
trg [1, 12, 6, 13, 27, 61, 54, 38, 30, 6, 10, 40, 61, 4, 62, 62, 23, 61, 9, 54, 3, 4, 11, 30, 62, 61, 34, 61, 54, 62, 22, 61, 54, 62, 62, 2]
pred [1, 10, 6, 9, 27, 61, 58, 38, 30, 6, 9, 40, 61, 4, 62, 62, 23, 61, 9, 62, 3, 4, 9, 30, 62, 61, 34, 61, 54, 62, 36, 61, 54, 62, 62]
iter 20
batch loss 0.40972205996513367
20 batch loss 0.37137451469898225
trg [1, 27, 61, 49, 38, 30, 6, 15, 62, 36, 40, 61, 15, 62, 61, 49, 62, 3, 21, 40, 61, 16, 62, 61, 49, 62, 2]
pred [1, 27, 61, 58, 38, 30, 6, 9, 62, 20, 40, 61, 9, 62, 61, 49, 62, 3, 34, 40, 61, 9, 62, 61, 49, 62]
iter 40
batch loss 0.3290197253227234
20 batch loss 0.34488980621099474
trg [1, 27, 61, 59, 38, 9, 40, 61, 3, 62, 62, 23, 61, 9, 6, 59, 62, 61, 4, 13, 34, 61, 59, 62, 33, 61, 59, 62, 62, 2]
pred [1, 27, 61, 58, 38, 9, 40, 61, 3, 62, 62, 23, 61, 23, 6, 59, 62, 61, 4, 9, 36, 61, 59, 62, 36, 61, 59, 62, 62]
iter 

 78%|███████▊  | 7/9 [37:01<10:06, 303.26s/it]



	Train Loss:   0.357 | Train PPL:   1.428
	Valid Loss:   0.507 | Valid PPL:   1.660
	Test Loss:   4.926 | Test PPL: 137.797
trg [1, 27, 61, 60, 38, 23, 61, 30, 62, 61, 12, 62, 62, 23, 61, 10, 34, 40, 61, 12, 62, 61, 60, 62, 3, 4, 9, 22, 40, 61, 11, 62, 61, 60, 62, 62, 61, 13, 62, 2]
pred [1, 27, 61, 54, 38, 23, 61, 30, 62, 61, 10, 62, 62, 23, 61, 9, 20, 40, 61, 10, 62, 61, 60, 62, 3, 4, 10, 36, 40, 61, 12, 62, 61, 60, 62, 62, 61, 15, 62]
iter 20
batch loss 0.3205740451812744
20 batch loss 0.3526390343904495
trg [1, 27, 61, 51, 38, 24, 62, 23, 61, 23, 61, 29, 41, 61, 10, 12, 62, 61, 51, 62, 62, 61, 29, 41, 61, 8, 7, 62, 61, 15, 62, 62, 62, 61, 23, 61, 29, 41, 61, 8, 12, 62, 61, 51, 62, 62, 61, 29, 41, 61, 14, 13, 62, 61, 13, 62, 62, 62, 2]
pred [1, 27, 61, 55, 38, 24, 62, 23, 61, 23, 61, 29, 41, 61, 8, 8, 62, 61, 51, 62, 62, 61, 29, 41, 61, 8, 7, 62, 61, 8, 62, 62, 62, 61, 23, 61, 29, 41, 61, 8, 7, 62, 61, 51, 62, 62, 61, 29, 41, 61, 8, 9, 62, 61, 8, 62, 62, 62, 2]
iter 40
batch loss 

 89%|████████▉ | 8/9 [42:01<05:02, 302.26s/it]



	Train Loss:   0.347 | Train PPL:   1.414
	Valid Loss:   0.507 | Valid PPL:   1.660
	Test Loss:   5.233 | Test PPL: 187.355
trg [1, 27, 61, 53, 38, 24, 62, 23, 61, 16, 53, 40, 61, 9, 62, 3, 53, 34, 61, 53, 62, 62, 61, 53, 40, 61, 14, 62, 62, 2]
pred [1, 27, 61, 57, 38, 24, 62, 23, 61, 25, 53, 40, 61, 8, 62, 3, 53, 34, 61, 53, 62, 62, 61, 53, 40, 61, 9, 62, 62]
iter 20
batch loss 0.27997252345085144
20 batch loss 0.3502157524228096
trg [1, 27, 61, 55, 38, 15, 40, 61, 4, 62, 62, 23, 61, 36, 61, 55, 62, 62, 61, 36, 40, 61, 8, 62, 61, 55, 62, 4, 11, 62, 2]
pred [1, 27, 61, 57, 38, 4, 62, 61, 3, 62, 62, 23, 61, 55, 61, 55, 62, 62, 61, 36, 40, 61, 9, 62, 61, 55, 62, 4, 9, 62]
iter 40
batch loss 0.3586333692073822
20 batch loss 0.3197669818997383
trg [1, 27, 61, 48, 38, 10, 40, 61, 3, 62, 62, 21, 61, 48, 62, 3, 4, 16, 23, 61, 16, 62, 61, 48, 62, 2]
pred [1, 27, 61, 58, 38, 9, 40, 61, 3, 62, 62, 34, 61, 48, 62, 3, 4, 10, 23, 61, 9, 62, 61, 48, 62]
iter 60
batch loss 0.2996523082256317
20 bat

100%|██████████| 9/9 [47:05<00:00, 313.94s/it]



	Train Loss:   0.337 | Train PPL:   1.400
	Valid Loss:   0.502 | Valid PPL:   1.652
	Test Loss:   5.454 | Test PPL: 233.635





In [None]:
def create_expression(
    input_seq,
    max_expression_length,
    model,
    label_eos_token,
    label_sos_token,
):
  model.eval()
  with torch.no_grad():
    hidden, cell = model.encoder(input_seq)
    next_input = label_sos_token
    next_input = next_input[None,None,:]
    outputs = [label_sos_token]
    predicted_expression = []
    for _ in range(max_expression_length):
      # next_input = next_input[None,:]
      if _ > 0:
        next_input = next_input.unsqueeze(0)
      output, hidden, cell = model.decoder(next_input, hidden, cell)
      top1 = output.argmax(1)
      one_hot_encoded_top1 = torch.tensor([
        1 if i == top1.item()
        else 0
        for i in range(63)
      ])
      predicted_expression.append(top1.item())
      next_input = torch.tensor(one_hot_encoded_top1).float().to(device)
      outputs.append(next_input)
      next_input = next_input.unsqueeze(0)
      if next_input[0][2] == 1 or next_input[0][0] == 1:
        break
    return predicted_expression

In [None]:
def evaluate_create_expression(model, data_loader):
    model.eval()
    epoch_loss = 0

    current_loss = 0
    acc = 0
    total_acc = 0
    batch_acc = 0
    max_acc = -1
    min_acc = 101
    accs = []
    with torch.no_grad():
        for j, d in enumerate(data_loader):
          src = []
          trg = []
          for p in d:
            src.append(p["image"])
            # print('image shape',p["image"].shape)
            trg.append(p["target"])
          src = torch.stack(src,dim=0).to(device)
          # print('src shape',src.shape)
          # src = torch.transpose(src,0,1)
          trg = torch.tensor(trg).to(device)
          trg = torch.transpose(trg,0,1)
          label_sos_token = torch.tensor([1 if i == 1 else 0 for i in range(63)]).to(device)
          label_eos_token = torch.tensor([1 if i == 2 else 0 for i in range(63)]).to(device)
          predicted_expression = create_expression(src,100,model,label_sos_token,label_eos_token)
          decoded_trg = []
          trg_ls = trg.tolist()
          for ls in trg_ls:
            for i,l in enumerate(ls[0]):
              if l == 1:
                decoded_trg.append(i)
                break
          if len(predicted_expression) < len(decoded_trg) - 1:
            while len(predicted_expression) < len(decoded_trg) - 1:
              predicted_expression.append(0)
            predicted_expression.append(2)
          elif len(predicted_expression) < len(decoded_trg):
            predicted_expression.append(2)
          elif len(predicted_expression) == len(decoded_trg):
            predicted_expression[-1] = 2
          elif len(predicted_expression) > len(decoded_trg) - 1:
            decoded_trg[-1] = 0
            while len(predicted_expression) -1 > len(decoded_trg):
              decoded_trg.append(0)
            decoded_trg.append(2)
            predicted_expression[-1]= 2
          elif len(predicted_expression) > len(decoded_trg):
            decoded_trg[-1] = 0
            decoded_trg.append(2)
            predicted_expression[-1]= 2
          p,t = np.array(predicted_expression),np.array(decoded_trg)
          curr_acc = np.mean(p == t) * 100
          total_acc += curr_acc
          max_acc = max(max_acc, curr_acc)
          min_acc = min(min_acc,curr_acc)
          accs.append(curr_acc)
          if j % 20 == 0:
            print('predicted expression')
            print(predicted_expression)
            print('ground truth expression')
            print(decoded_trg)
            print('current accuracy',curr_acc)
            print('current ave accuracy',total_acc / (j+1))
            print('min accuracy in batch',min_acc)
            print('max accuracy in batch',max_acc)
            max_acc = -1
            min_acc = 101
    return np.array(accs)

In [None]:
def collate_fn(batch):
  return batch
crnn_translate_dl = torch.utils.data.DataLoader(
    crnn_test_dataset,
    batch_size = 1,
    shuffle = True,
    collate_fn = collate_fn,
    pin_memory = True if torch.cuda.is_available() else False
)

In [None]:
accs = evaluate_create_expression(model, crnn_translate_dl)

  next_input = torch.tensor(one_hot_encoded_top1).float().to(device)


predicted expression
[27, 61, 58, 38, 4, 24, 62, 9, 58, 40, 61, 4, 8, 62, 2]
ground truth expression
[27, 61, 49, 38, 4, 24, 62, 11, 49, 40, 61, 16, 0, 0, 2]
current accuracy 60.0
current ave accuracy 60.0
min accuracy in batch 60.0
max accuracy in batch 60.0
predicted expression
[27, 61, 57, 38, 24, 62, 23, 61, 4, 9, 25, 34, 61, 23, 61, 10, 62, 61, 57, 62, 62, 3, 4, 10, 23, 61, 10, 62, 61, 57, 62, 20, 61, 23, 61, 13, 62, 61, 57, 62, 62, 31, 62, 61, 4, 10, 57, 40, 61, 4, 9, 62, 62, 2]
ground truth expression
[10, 6, 10, 27, 61, 43, 38, 24, 62, 23, 61, 20, 61, 9, 6, 43, 62, 3, 4, 14, 9, 6, 43, 34, 61, 13, 6, 43, 62, 62, 61, 43, 40, 61, 4, 15, 62, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2]
current accuracy 7.4074074074074066
current ave accuracy 30.28490857717399
min accuracy in batch 3.225806451612903
max accuracy in batch 79.59183673469387
predicted expression
[27, 61, 57, 38, 30, 6, 9, 62, 23, 61, 20, 40, 61, 10, 62, 61, 57, 62, 3, 20, 40, 61, 9, 62, 61, 57, 62, 62, 61, 10, 62

In [None]:
print(np.count_nonzero(accs>10))

1868
