In [1]:

from os.path import join
import argparse

from PIL import Image
from torchvision import transforms
import torch

import pandas as pd

import json

def preprocess(data_dir, split):
	"""Build and save the (image tensor, formula) pairs for one dataset split.

	Reads the image-name -> formula mapping from
	``./image_formula_mapping_crt.json`` (relative to the current working
	directory), keeps the slice of entries that belongs to ``split``, loads
	each image from ``<data_dir>/generated_png_images``, converts it with
	``transforms.ToTensor()`` and saves the resulting list to
	``<data_dir>/<split>.pkl`` via ``torch.save``.

	Images whose height is below 32 pixels are skipped. The saved list is
	sorted by the full tensor shape (see ``img_size``).

	Args:
		data_dir: Directory that contains ``generated_png_images``; the
			output ``.pkl`` file is written here as well.
		split: One of ``"train"``, ``"validate"`` or ``"test"``.

	Raises:
		AssertionError: If ``split`` is not one of the three split names.

	Note:
		A later notebook cell defines another ``preprocess`` with the same
		signature; once that cell runs it replaces this definition.
	"""
	assert split in ["train", "validate", "test"]

	print("Process {} dataset...".format(split))
	images_dir = join(data_dir, "generated_png_images")
	transform = transforms.ToTensor()

	with open('./image_formula_mapping_crt.json', 'r') as f:
		data: dict = json.load(f)

	# Entry-index boundaries: [0, trainval_end) is shared by train and
	# validate, [trainval_end, test_end) is the test range.
	trainval_end = 120 * 1000
	test_end = 130 * 1000
	# The train/validate boundary is derived from trainval_end. It used to be
	# computed from a hard-coded 220k (= 176000), which lies beyond
	# trainval_end: the validate slice was always empty and the train slice
	# contained every test entry.
	split_index = int(trainval_end * 0.8)

	entries = list(data.items())
	if split == 'train':
		entries = entries[:split_index]
	elif split == 'validate':
		entries = entries[split_index:trainval_end]
	elif split == 'test':
		entries = entries[trainval_end:test_end]

	pairs = []
	for img_name, formula in entries:
		img_path = join(images_dir, img_name)
		# The context manager closes the underlying file handle once the
		# pixel data has been converted to a tensor.
		with Image.open(img_path) as img:
			_, height = img.size
			if height < 32:
				continue
			pairs.append((transform(img), formula))
	pairs.sort(key=img_size)

	out_file = join(data_dir, "{}.pkl".format(split))
	torch.save(pairs, out_file)
	print("Save {} dataset to {}".format(split, out_file))


def img_size(pair):
	"""Sort key: the dimensions of the image in an ``(image, formula)`` pair.

	Args:
		pair: A two-element sequence whose first item exposes a ``size()``
			method returning its dimensions.

	Returns:
		The result of ``size()`` converted to a plain ``tuple``.
	"""
	image, _formula = pair
	dims = image.size()
	return tuple(dims)


In [2]:
from concurrent.futures import ThreadPoolExecutor

def preprocess(data_dir, split):
	"""Build and save the (image tensor, formula) pairs for one dataset split.

	Reads the image-name -> formula mapping from
	``./image_formula_mapping_crt.json`` (relative to the current working
	directory), keeps the slice of entries that belongs to ``split``, converts
	the images in a thread pool and saves the resulting list to
	``<data_dir>/<split>.pkl`` via ``torch.save``.

	Per-entry rules:
		* formulas with more than 150 whitespace-separated tokens are skipped;
		* images taller than 64 pixels are skipped;
		* images shorter than 32 pixels are resized to height 32, scaling the
		  width by the same factor;
		* any entry that raises while being processed is reported on stdout
		  and skipped.

	The saved list is sorted by the tensor shape without its first dimension.

	Args:
		data_dir: Directory that contains ``generated_png_images``; the
			output ``.pkl`` file is written here as well.
		split: One of ``"train"``, ``"validate"`` or ``"test"``.

	Raises:
		AssertionError: If ``split`` is not one of the three split names.
	"""
	assert split in ["train", "validate", "test"]

	print(f"Processing {split} dataset...")
	images_dir = join(data_dir, "generated_png_images")

	with open('./image_formula_mapping_crt.json', 'r') as f:
		data: dict = json.load(f)

	# Entry-index boundaries: [0, k_220) is shared by train and validate
	# (80% / 20%), [k_220, k_230) is the test range.
	k_220 = 100 * 1000
	k_230 = 110 * 1000
	split_index = int(k_220 * 0.8)

	# Materialise the entries once, then slice out the requested split.
	entries = list(data.items())
	if split == 'train':
		entries = entries[:split_index]
	elif split == 'validate':
		entries = entries[split_index:k_220]
	elif split == 'test':
		entries = entries[k_220:k_230]

	transform = transforms.ToTensor()

	def process_image(item):
		"""Return ``(tensor, formula)`` for one entry, or ``None`` to skip it."""
		img_name, formula = item
		img_path = join(images_dir, img_name)
		try:
			# The length check lives inside the try block so that a malformed
			# formula (e.g. a JSON null) is reported and skipped like an
			# unreadable image, instead of aborting the whole run.
			if len(formula.split()) > 150:
				return None

			with Image.open(img_path) as img:
				x, y = img.size

				if y > 64:
					return None

				if y < 32:
					# x/y = x'/32 => x'= 32*x/y
					img = img.resize((int(32 * x / y), 32), Image.Resampling.LANCZOS)

				img_tensor = transform(img)
				return (img_tensor, formula)
		except Exception as e:
			print(f"Failed to process {img_name}: {e}")
			return None

	# Use multithreading to speed up image processing; map() yields results
	# in input order.
	with ThreadPoolExecutor() as executor:
		results = list(executor.map(process_image, entries))

	# Filter out skipped / failed entries
	pairs = [pair for pair in results if pair is not None]

	# Sort pairs by image size
	pairs.sort(key=lambda pair: pair[0].shape[1:])

	# Save to a file
	out_file = join(data_dir, f"{split}.pkl")
	torch.save(pairs, out_file)
	print(f"Saved {split} dataset to {out_file}")


In [3]:
# Build and save the training split.
preprocess(data_dir='./archive/PRINTED_TEX_230k/', split='train')

Processing train dataset...
Saved train dataset to ./archive/PRINTED_TEX_230k/train.pkl


In [4]:
# Build and save the validation split.
preprocess(data_dir='./archive/PRINTED_TEX_230k/', split='validate')

Processing validate dataset...
Saved validate dataset to ./archive/PRINTED_TEX_230k/validate.pkl


In [5]:
# Build and save the test split.
preprocess(data_dir='./archive/PRINTED_TEX_230k/', split='test')

Processing test dataset...
Saved test dataset to ./archive/PRINTED_TEX_230k/test.pkl
