# prototypes

In [1]:
import random
from datasets import load_dataset
from transformers import AutoTokenizer


# Load the "deepcopy/MathWriting-Human" dataset through `datasets.load_dataset`.
# The rest of this cell indexes its "train", "val" and "test" splits and reads
# the "image" and "latex" columns of each example.
ds = load_dataset("deepcopy/MathWriting-Human")

def resize_image(example, size=(224, 224)):
    """Replace the example's image with a copy resized to ``size``.

    Args:
        example: Mapping with an ``"image"`` entry exposing ``.resize(size)``
            (presumably a PIL image coming from the dataset -- confirm).
        size: Target size passed straight to ``.resize``; defaults to
            ``(224, 224)``.

    Returns:
        The same ``example`` object, with ``example["image"]`` overwritten by
        the resized image.
    """
    resized = example["image"].resize(size)
    example["image"] = resized
    return example

# Resize the images of every split, in the same order as before
# (train, then val, then test).
for split_name in ("train", "val", "test"):
    ds[split_name] = ds[split_name].map(resize_image)

import random
# All LaTeX strings of the training split. This is the pool from which
# mismatched (negative) LaTeX strings are sampled -- for the val and test
# splits as well as for train.
latex_pool = ds["train"]["latex"]

def add_binary_label(example, latex_list, rng=random, max_tries=100):
    """Attach a match/mismatch label and the LaTeX string to pair with the image.

    When ``rng.random() > 0.5`` the example keeps its own LaTeX and gets
    label 1. Otherwise a string different from ``example["latex"]`` is drawn
    from ``latex_list`` and the example gets label 0.

    The original implementation re-sampled in an unbounded ``while`` loop,
    which never terminates when ``latex_list`` contains no string different
    from ``example["latex"]``. Sampling is now bounded by ``max_tries`` and
    followed by a linear scan, so that case fails loudly instead of hanging.

    Args:
        example: Mapping with a ``"latex"`` entry; it is modified in place.
        latex_list: Non-empty sequence of candidate LaTeX strings.
        rng: Object providing ``random()`` and ``choice(seq)``. Defaults to the
            ``random`` module, so existing two-argument calls behave as before.
            Pass a ``random.Random(seed)`` instance for reproducible labels.
        max_tries: Number of random draws attempted before falling back to a
            linear scan of ``latex_list``.

    Returns:
        The same ``example`` object with two new keys: ``"label"`` (1 or 0)
        and ``"latex_used"`` (the LaTeX string paired with the image).

    Raises:
        ValueError: If a negative was requested but every entry of
            ``latex_list`` equals ``example["latex"]``.
        IndexError: Propagated from ``rng.choice`` when ``latex_list`` is empty
            and a negative was requested.
    """
    true_latex = example["latex"]

    # Positive pair: keep the example's own LaTeX.
    if rng.random() > 0.5:
        example["label"] = 1
        example["latex_used"] = true_latex
        return example

    # Negative pair: rejection-sample a different string, with a bounded
    # number of attempts so a degenerate pool cannot hang the map() call.
    not_found = object()
    wrong_latex = not_found
    for _ in range(max_tries):
        candidate = rng.choice(latex_list)
        if candidate != true_latex:
            wrong_latex = candidate
            break
    else:
        # Every draw collided with the true string; scan once to decide
        # whether any alternative exists at all.
        wrong_latex = next(
            (item for item in latex_list if item != true_latex), not_found
        )
        if wrong_latex is not_found:
            raise ValueError(
                "latex_list contains no string different from example['latex']"
            )

    example["label"] = 0
    example["latex_used"] = wrong_latex
    return example

# Seed the global RNG before labelling so the positive/negative assignment and
# the sampled negatives are identical on every Restart & Run All. Without a
# seed the labels change from run to run, which makes results incomparable.
random.seed(42)

# Label every split. Negatives for val and test are also drawn from
# latex_pool, i.e. from the training split's LaTeX strings.
for split_name in ("train", "val", "test"):
    ds[split_name] = ds[split_name].map(lambda x: add_binary_label(x, latex_pool))

# Tokenize LaTeX with the pretrained "gpt2" tokenizer. tokenize_latex pads
# every sequence to a fixed length, which requires a pad token, so the
# end-of-sequence token is reused as the pad token.
tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token

def tokenize_latex(example):
    """Tokenize ``example["latex_used"]`` into a fixed-length id sequence.

    The string is truncated or padded to exactly 64 token ids using the
    module-level ``tokenizer``.

    Args:
        example: Mapping with a ``"latex_used"`` entry; it is modified in place.

    Returns:
        The same ``example`` object with a new ``"latex_ids"`` entry holding
        the tokenizer's ``input_ids``.
    """
    encoding = tokenizer(
        example["latex_used"],
        max_length=64,
        padding="max_length",
        truncation=True,
    )
    example["latex_ids"] = encoding["input_ids"]
    return example

# Add the "latex_ids" column to every split, in the same order as before
# (train, then val, then test).
for split_name in ("train", "val", "test"):
    ds[split_name] = ds[split_name].map(tokenize_latex)
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                
                                                                                                                                                                                                                                                                                                                                                                                                                                          
# Wrap the training split as a shuffled tf.data pipeline with batches of 4.
# The feature columns "image" and "latex_ids" carry the same names as the two
# keras.Input layers of the model defined in the next cell. Because `columns`
# and `label_cols` are passed as lists, the notice in this cell's recorded
# output says features and labels are returned as dicts keyed by column name.
tf_train = ds["train"].to_tf_dataset(
    columns=["image", "latex_ids"],
    label_cols=["label"],
    shuffle=True,
    batch_size=4
)

Map:   0%|          | 0/15674 [00:00<?, ? examples/s]

2025-07-12 13:14:29.439678: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: SSE4.1 SSE4.2 AVX AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
Old behaviour: columns=['a'], labels=['labels'] -> (tf.Tensor, tf.Tensor)  
             : columns='a', labels='labels' -> (tf.Tensor, tf.Tensor)  
New behaviour: columns=['a'],labels=['labels'] -> ({'a': tf.Tensor}, {'labels': tf.Tensor})  
             : columns='a', labels='labels' -> (tf.Tensor, tf.Tensor) 


In [2]:
from tensorflow import keras
from tensorflow.keras import layers


# Two-branch binary classifier: given an image and a tokenized LaTeX string,
# predict whether they belong together (label 1) or not (label 0).

# Image branch: scale pixel values by 1/255, apply two 3x3 convolutions with a
# max-pool in between, then average each feature map down to a single value.
img_input = keras.Input(shape=(224, 224, 3), name="image", dtype="float32")
x_img = layers.Rescaling(1.0 / 255)(img_input)
x_img = layers.Conv2D(32, 3, activation="relu")(x_img)
x_img = layers.MaxPooling2D()(x_img)
x_img = layers.Conv2D(64, 3, activation="relu")(x_img)
x_img = layers.GlobalAveragePooling2D()(x_img)

# Text branch: 64 token ids per example (the max_length used in
# tokenize_latex), embedded into 32-dim vectors and averaged over the sequence.
# input_dim=50257 is presumably the GPT-2 vocabulary size -- TODO confirm
# against the tokenizer.
# NOTE(review): no mask is passed to the pooling layer, so padded positions
# are averaged in together with real tokens.
txt_input = keras.Input(shape=(64,), name="latex_ids", dtype="int32")
x_txt = layers.Embedding(input_dim=50257, output_dim=32)(txt_input)
x_txt = layers.GlobalAveragePooling1D()(x_txt)

# Fusion head: concatenate both pooled vectors, one hidden layer with dropout,
# and a single sigmoid unit producing the match probability.
x = layers.Concatenate()([x_img, x_txt])
x = layers.Dense(128, activation="relu")(x)
x = layers.Dropout(0.3)(x)
output = layers.Dense(1, activation="sigmoid")(x)

# Build, compile and train for 3 epochs on tf_train; no validation data is
# passed to fit().
# NOTE(review): the recorded output of this cell shows accuracy of about 0.50
# and loss of about 0.693 in every epoch, i.e. chance level for a binary task,
# and the run was interrupted during epoch 3.
model = keras.Model(inputs=[img_input, txt_input], outputs=output)
model.compile(optimizer="adam", loss="binary_crossentropy", metrics=["accuracy"])
model.summary()
model.fit(tf_train, epochs=3)

Epoch 1/3
57466/57466 ━━━━━━━━━━━━━━━━━━━━ 5626s 98ms/step - accuracy: 0.4994 - loss: 0.6939
Epoch 2/3
57466/57466 ━━━━━━━━━━━━━━━━━━━━ 7094s 123ms/step - accuracy: 0.5008 - loss: 0.6933
Epoch 3/3
29682/57466 ━━━━━━━━━━━━━━━━━━━━ 1:10:42 153ms/step - accuracy: 0.4968 - loss: 0.6933

KeyboardInterrupt: 

In [None]:
# Write the current model to "version-cv2.keras" in the working directory
# (the .keras suffix selects Keras' native single-file format).
# NOTE(review): this cell has no execution count, so it may never have been
# run; the fit() call above was interrupted during its third epoch.
model.save("version-cv2.keras")

In [10]:
def _to_tf(split, shuffle):
    """Return ``ds[split]`` as a tf.data pipeline with batches of 16."""
    return ds[split].to_tf_dataset(
        columns=["image", "latex_ids"],
        label_cols=["label"],
        shuffle=shuffle,
        batch_size=16,
    )


# Rebuild the training pipeline (shuffled) and add a validation pipeline
# (kept in dataset order), both with the same columns and batch size.
tf_train = _to_tf("train", shuffle=True)
tf_val = _to_tf("val", shuffle=False)