In [None]:
from os.path import join

from torch.utils.data import Dataset
import torch

class Im2LatexDataset(Dataset):
    """Dataset of (image, formula) pairs read from a preprocessed split file.

    All pairs are loaded eagerly from ``<data_dir>/<split>.pkl`` when the
    dataset is constructed, and every formula is cut down to its first
    ``max_len`` whitespace-separated tokens.
    """

    def __init__(self, data_dir, split, max_len):
        """Load one split into memory.

        Args:
            data_dir: root directory storing the preprocessed data.
            split: one of "train", "validate" or "test".
            max_len: maximum number of whitespace-separated tokens kept
                per formula.
        """
        assert split in ("train", "validate", "test")
        self.data_dir = data_dir
        self.split = split
        self.max_len = max_len
        self.pairs = self._load_pairs()

    def _load_pairs(self):
        """Read the split file and truncate each formula to ``max_len`` tokens.

        The loaded sequence is updated in place and returned. Splitting on
        whitespace and re-joining with single spaces also normalises the
        spacing between tokens.
        """
        pkl_path = join(self.data_dir, "{}.pkl".format(self.split))
        loaded = torch.load(pkl_path)
        for idx, entry in enumerate(loaded):
            image, latex = entry
            tokens = latex.split()
            loaded[idx] = (image, " ".join(tokens[:self.max_len]))
        return loaded

    def __getitem__(self, index):
        """Return the (image, formula) pair stored at ``index``."""
        return self.pairs[index]

    def __len__(self):
        """Return the number of loaded pairs."""
        return len(self.pairs)

In [20]:

from os.path import join
import argparse

from PIL import Image
from torchvision import transforms
import torch

import pandas as pd

def preprocess(data_dir, split):
	"""Build and save the (image tensor, formula) pairs for one dataset split.

	Reads ``<data_dir>/im2latex_<split>.csv`` (columns ``image`` and
	``formula``), loads every referenced image from
	``<data_dir>/formula_images_processed``, converts it to a tensor, sorts
	the resulting pairs with ``img_size`` as the key and writes the list to
	``<data_dir>/<split>.pkl`` with ``torch.save``.

	Args:
		data_dir: root directory holding the CSV files and the image folder.
		split: one of "train", "validate" or "test".
	"""
	assert split in ["train", "validate", "test"]

	print("Process {} dataset...".format(split))
	images_dir = join(data_dir, "formula_images_processed")

	split_file = join(data_dir, f"im2latex_{split}.csv")
	transform = transforms.ToTensor()

	df = pd.read_csv(split_file)

	# Rows missing either the image name or the formula cannot form a pair.
	df = df.dropna(subset=['formula', 'image'])

	# Map image file name -> formula. Because this is a dict, rows that share
	# an image name collapse to a single entry holding the last formula seen.
	data = pd.Series(df['formula'].values, index=df['image']).to_dict()

	pairs = []
	for img_name, formula in data.items():
		img_path = join(images_dir, img_name)
		# Open the image in a context manager so its file handle is released
		# as soon as the tensor has been built, instead of lingering for the
		# rest of the loop.
		with Image.open(img_path) as img:
			img_tensor = transform(img)
		pairs.append((img_tensor, formula))
	pairs.sort(key=img_size)

	out_file = join(data_dir, "{}.pkl".format(split))
	torch.save(pairs, out_file)
	print("Save {} dataset to {}".format(split, out_file))


def img_size(pair):
	"""Sort key: return the size of the pair's image tensor as a plain tuple.

	``pair`` must be a two-element (image tensor, formula) sequence; the
	formula is ignored.
	"""
	tensor, _ = pair
	dims = tensor.size()
	return tuple(dims)


In [21]:
# Build the preprocessed "train" split from ./100k/ (writes ./100k/train.pkl).
preprocess('./100k/', 'train')

Process train dataset...
Save train dataset to ./100k/train.pkl


In [22]:
# Build the preprocessed "validate" split from ./100k/ (writes ./100k/validate.pkl).
preprocess('./100k/', 'validate')

Process validate dataset...
Save validate dataset to ./100k/validate.pkl
