## INIT

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, models
import json
from PIL import Image
import os

# Check if GPU is available, otherwise use CPU
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [2]:
print(torch.cuda.get_device_name(0))
print(device)
# print(len(os.listdir('LaTex_data/split_1')))
# print(image_formula_mapping['0002475406d9932.png'])

NVIDIA GeForce RTX 3050 Ti Laptop GPU
cuda:0


## Load Data


In [3]:
transform = transforms.Compose([
    # do not resize
	transforms.ToTensor(),
	transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

In [4]:
import json
from torch.nn.utils.rnn import pad_sequence

label_to_index_file = './230k.json'
with open(label_to_index_file, 'r') as f:
	sign2id = json.load(f)

id2sign = [0] * 650
for k, v in sign2id.items():
	id2sign[int(v)] = k

def collate_fn(batch):
	# filter the pictures that have different weight or height
	size = batch[0][0].size()
	batch = [img_formula for img_formula in batch
			if img_formula[0].size() == size]
	
	# # sort by the length of formula
	# batch.sort(key=lambda img_formula: len(img_formula[1].split()),
	# 		reverse=True)

	imgs, formulas = zip(*batch)
	formulas = pad_sequence(formulas, batch_first=True, padding_value=2)
	
	imgs = torch.stack(imgs, dim=0)
	return imgs.to(device), formulas.to(device)


In [5]:
from torch.utils.data import Dataset, DataLoader
from os.path import join

class Im2LatexDataset(Dataset):
	def __init__(self, data_dir, split, max_len=30000):
		"""args:
		data_dir: root dir storing the prepoccessed data
		split: train, validate or test
		"""
		assert split in ["train", "validate", "test"]
		self.data_dir = data_dir
		self.split = split
		self.max_len = max_len
		self.pairs = self._load_pairs()

	def _load_pairs(self):
		pairs = torch.load(join(self.data_dir, "{}.pkl".format(self.split)))

		finite_pairs = []
		for i, (img, formula) in enumerate(pairs):
			pair = (img, " ".join(formula.split()))
			finite_pairs.append(pair)

			if i >= self.max_len:
				break
		
		return finite_pairs

	def __getitem__(self, idx):
		image, formula = self.pairs[idx]
		
		formula_tokens = '<S> ' + formula + ' <E>'
		formula_tokens = formula.split()
		
		formula_indices = []
		for token in formula_tokens:
			# Map each token to its index; if not found, use a default index (e.g., 0)
			index = sign2id.get(token, 0)  # Assuming 0 is for unknown tokens
			formula_indices.append(int(index))
		
		return image, torch.tensor(formula_indices, dtype=torch.long)

	def __len__(self):
		return len(self.pairs)

In [6]:
batch_size = 32

train_loader = DataLoader(
	Im2LatexDataset('./100k/', 'train'),
	batch_size=batch_size,
	collate_fn=collate_fn)

val_loader = DataLoader(
	Im2LatexDataset('./100k/', 'validate'),
	batch_size=batch_size,
	collate_fn=collate_fn)


  pairs = torch.load(join(self.data_dir, "{}.pkl".format(self.split)))


In [7]:
for i, data in enumerate(train_loader):
    img, label = data
    print(data)
    break

(tensor([[[[1., 1., 1.,  ..., 1., 1., 1.],
          [1., 1., 1.,  ..., 1., 1., 1.],
          [1., 1., 1.,  ..., 1., 1., 1.],
          ...,
          [1., 1., 1.,  ..., 1., 1., 1.],
          [1., 1., 1.,  ..., 1., 1., 1.],
          [1., 1., 1.,  ..., 1., 1., 1.]],

         [[1., 1., 1.,  ..., 1., 1., 1.],
          [1., 1., 1.,  ..., 1., 1., 1.],
          [1., 1., 1.,  ..., 1., 1., 1.],
          ...,
          [1., 1., 1.,  ..., 1., 1., 1.],
          [1., 1., 1.,  ..., 1., 1., 1.],
          [1., 1., 1.,  ..., 1., 1., 1.]],

         [[1., 1., 1.,  ..., 1., 1., 1.],
          [1., 1., 1.,  ..., 1., 1., 1.],
          [1., 1., 1.,  ..., 1., 1., 1.],
          ...,
          [1., 1., 1.,  ..., 1., 1., 1.],
          [1., 1., 1.,  ..., 1., 1., 1.],
          [1., 1., 1.,  ..., 1., 1., 1.]]],


        [[[1., 1., 1.,  ..., 1., 1., 1.],
          [1., 1., 1.,  ..., 1., 1., 1.],
          [1., 1., 1.,  ..., 1., 1., 1.],
          ...,
          [1., 1., 1.,  ..., 1., 1., 1.],
       

## Encoder / Decoder

In [8]:
class EncoderCNN(nn.Module):
	def __init__(self, emb_size, enc_out_dim=512):
		super(EncoderCNN, self).__init__()
		
		self.cnn_encoder = nn.Sequential(
			nn.Conv2d(3, 64, 3, 1, 1),
			nn.ReLU(),
			nn.MaxPool2d(2, 2, 1),

			nn.Conv2d(64, 128, 3, 1, 1),
			nn.ReLU(),
			nn.MaxPool2d(2, 2, 1),

			nn.Conv2d(128, 256, 3, 1, 1),
			nn.ReLU(),
			nn.Conv2d(256, 256, 3, 1, 1),
			nn.ReLU(),
			nn.MaxPool2d((2, 1), (2, 1), 0),

			nn.Conv2d(256, enc_out_dim, 3, 1, 0),
			nn.ReLU()
		)

		self.embedding = nn.Embedding(enc_out_dim, emb_size)

	def forward(self, images):
		
		features = self.cnn_encoder(images)

		# print(features.shape)

		features = features.permute(0, 2, 3, 1)  # [B, H', W', 512]
		B, H, W, _ = features.shape
		features = features.contiguous().view(B, H*W, -1)   # [B, H' * W', 512]

		# print(features.shape)

		features = features.mean(dim=1) # [B, 512]

		# print(features.shape) 

		# features = features.view(features.size(0), -1)  # Flatten to [batch_size, feature_dim]
		# features = self.fc(features)
		return features

class DecoderRNN(nn.Module):
	def __init__(self, embedding_dim, hidden_dim, vocab_size):
		super(DecoderRNN, self).__init__()
		self.embedding = nn.Embedding(vocab_size, embedding_dim)
		self.gru = nn.GRU(embedding_dim, hidden_dim, batch_first=True)
		self.fc = nn.Linear(hidden_dim, vocab_size)

	def forward(self, features, formulas):
		# Embed the input formula tokens
		embeddings = self.embedding(formulas)

		# Concatenate features and embeddings along the sequence dimension
		embeddings = torch.cat((features.unsqueeze(1), embeddings), dim=1)
		
		# Pass through GRU and then through the final linear layer
		gru_out, _ = self.gru(embeddings)
		outputs = self.fc(gru_out)
		return outputs

class ImageToLaTeXModel(nn.Module):
	def __init__(self, encoder, decoder):
		super(ImageToLaTeXModel, self).__init__()
		self.encoder = encoder
		self.decoder = decoder

	def forward(self, images, formulas):
		# Encode the images
		features = self.encoder(images)  # Shape: [batch_size, feature_dim]
		
		# Decode to generate the LaTeX expression
		outputs = self.decoder(features, formulas[:, :-1])  # Skip the end token
		return outputs


## Save / Load model

In [None]:
# Hyperparameters
EMBED_SIZE = 512 # direct output dim from cv_tiny

hidden_size = 1024
num_epochs = 60
learning_rate = 0.003
# batch_size = 32

# Load dataset and dataloader
# dataset = LaTeXDataset("LaTex_data/split_1" , mapping_path, label_to_index_path, transform)

vocab_size = len(sign2id)

# Model, loss, and optimizer
encoder = EncoderCNN(EMBED_SIZE).to(device)
decoder = DecoderRNN(EMBED_SIZE, hidden_size, vocab_size).to(device)
model = ImageToLaTeXModel(encoder, decoder).to(device)
criterion = nn.CrossEntropyLoss(ignore_index=2)
optimizer = optim.Adam(model.parameters())
start_epoch = 0

# Function to save model state
def save_training_state(model, optimizer, epoch, loss):
	state = {
		'model_state_dict': model.state_dict(),
		'optimizer_state_dict': optimizer.state_dict(),
		'epoch': epoch,
		'loss': loss
	}
	torch.save(state, 'model_checkpoint.pth')

# Function to load model state
def load_training_state(model, optimizer):
	checkpoint = torch.load('model_checkpoint.pth')
	model.load_state_dict(checkpoint['model_state_dict'])
	optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
	return checkpoint['epoch'], checkpoint['loss']

# Try to resume from a checkpoint
try:
	start_epoch, last_loss = load_training_state(model, optimizer)
	print(f"Resuming training from epoch {start_epoch}, with loss {last_loss:.4f}")
except FileNotFoundError:
	print("No saved model found, starting fresh.")
	start_epoch = 0


for epoch in range(start_epoch, num_epochs):
	for i, data in enumerate(train_loader):
		images, formulas = data
		targets = formulas[:, 1:]

		outputs = model(images, formulas[:, :-1])
		loss = criterion(outputs.view(-1, vocab_size), targets.contiguous().view(-1))

		optimizer.zero_grad()
		loss.backward()
		optimizer.step()

		if i % 100 == 0:
			print(f"Epoch [{epoch+1}/{num_epochs}], Step [{i}/{len(train_loader)}], Loss: {loss.item():.4f}")

	save_training_state(model, optimizer, epoch, loss.item())




# Training loop
# for i in range(6):
# 	for folder_idx in range(start_folder_idx, len(folders)):
# 		print(f"Training on folder: {folders[folder_idx]}")
# 		dataset = LaTeXDataset("LaTex_data/" + folders[folder_idx], mapping_path, label_to_index_path, transform)
# 		dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)

# 		for epoch in range(start_epoch, num_epochs):
# 			for i, data in enumerate(dataloader):
# 				images, formulas = data
# 				targets = formulas[:, 1:]

# 				outputs = model(images, formulas[:, :-1])
# 				outputs = outputs.log_softmax(2)  # Apply log_softmax for CTC Loss

# 				loss = criterion(outputs, targets, input_lengths, target_lengths)
# 				# loss = criterion(outputs.view(-1, dataset.vocab_size), targets.contiguous().view(-1))

# 				optimizer.zero_grad()
# 				loss.backward()
# 				optimizer.step()

# 				if i % 100 == 0:
# 					print(f"Folder [{folder_idx+1}/{len(folders)}], Epoch [{epoch+1}/{num_epochs}], Step [{i}/{len(dataloader)}], Loss: {loss.item():.4f}")

# 				# Save model periodically and at the end of each folder
# 				if i % 200 == 0 or (i == len(dataloader) - 1):
# 					save_training_state(model, optimizer, epoch, folder_idx, loss.item())
			
# 			# Reset start_epoch for next folder
# 			start_epoch = 0
# 		start_folder_idx = 0
# 		start_epoch = 0


No saved model found, starting fresh.


  checkpoint = torch.load('model_checkpoint.pth')


Epoch [1/60], Step [0/938], Loss: 6.3636
Epoch [1/60], Step [100/938], Loss: 2.8479
Epoch [1/60], Step [200/938], Loss: 2.8525
Epoch [1/60], Step [300/938], Loss: 2.6944
Epoch [1/60], Step [400/938], Loss: 2.7569
Epoch [1/60], Step [500/938], Loss: 2.4747
Epoch [1/60], Step [600/938], Loss: 2.5451
Epoch [1/60], Step [700/938], Loss: 2.2993
Epoch [1/60], Step [800/938], Loss: 2.3789
Epoch [1/60], Step [900/938], Loss: 2.3424
Epoch [2/60], Step [0/938], Loss: 2.8114
Epoch [2/60], Step [100/938], Loss: 2.2407
Epoch [2/60], Step [200/938], Loss: 2.5074
Epoch [2/60], Step [300/938], Loss: 2.3863
Epoch [2/60], Step [400/938], Loss: 2.4767
Epoch [2/60], Step [500/938], Loss: 2.1841
Epoch [2/60], Step [600/938], Loss: 2.3169
Epoch [2/60], Step [700/938], Loss: 2.1260
Epoch [2/60], Step [800/938], Loss: 2.2076
Epoch [2/60], Step [900/938], Loss: 2.1681
Epoch [3/60], Step [0/938], Loss: 2.5978
Epoch [3/60], Step [100/938], Loss: 2.0596
Epoch [3/60], Step [200/938], Loss: 2.3529
Epoch [3/60], Ste

In [None]:
asdfljasd;lkfasdf.sdf

In [None]:
def decode_formula(indices, id2sign):
    return ' '.join([id2sign[i.item()] for i in indices if i.item() and i.item() != 2])  # Skip padding


def validate_model(model, criterion, device):
    model.eval()  # Set the model to evaluation mode
    total_loss = 0.0
    correct_predictions = 0
    total_samples = 0

    with torch.no_grad():  # Disable gradient calculation
        for images, formulas in val_loader:
            images, formulas = images.to(device), formulas.to(device)
            outputs = model(images, formulas[:, :-1])  # Pass images and input sequence

            # Calculate loss
            loss = criterion(outputs.view(-1, outputs.size(-1)), formulas[:, 1:].contiguous().view(-1))
            total_loss += loss.item()

            # Calculate accuracy (if applicable)
            predicted_indices = torch.argmax(outputs, dim=2)  # Get the index of the max log-probability
            correct_predictions += (predicted_indices == formulas[:, 1:].contiguous()).sum().item()
            total_samples += formulas[:, 1:].numel()  # Total number of tokens in the validation batch

            # Print images and predictions
            for i in range(len(images)):
                # Decode the actual and predicted formulas
                actual_formula = decode_formula(formulas[i, 1:], id2sign)  # Skip <S> token
                predicted_formula = decode_formula(predicted_indices[i, 1:], id2sign)  # Skip <S> token
                # print(f'Image: {images[i]}')  # This will print the tensor, consider using visualization instead
                print(f'Actual Formula: {actual_formula}')
                print(f'Predicted Formula: {predicted_formula}')
                print('-' * 50)

    avg_loss = total_loss / len(val_loader)
    accuracy = correct_predictions / total_samples if total_samples > 0 else 0.0

    return avg_loss, accuracy
# Assuming you have your model, dataloader, criterion, and device set up
# Assuming 230k.json is loaded as label_to_index

val_loss, val_accuracy = validate_model(model, criterion, device)
print(val_accuracy)


Actual Formula: _ { 1 } ^ { k } = \omega _ { 1 } ^ { k - 2 } \subseteq \omega _ { 1 } ^ { k }
Predicted Formula: { 0 } = \mu \mu } ( x ^ { 0 } , \omega a } + } + A ^ { l l l 0 k } g _ \omega \omega , , , 1 , , , , ,
--------------------------------------------------
Actual Formula: _ { i j } = \bar { g } _ { i j } + h _ { i j } ,
Predicted Formula: { \mu j } ^ { _ v } _ _ j j } ^ { \bar \bar j j } } \bar \bar { \gamma \dot . } . } . . . . . . . . . . .
--------------------------------------------------
Actual Formula: \theta _ { n } ^ { \alpha \Lambda } } { ^ \dagger } = \theta _ { - n } ^ { \alpha \Lambda } ,
Predicted Formula: L { r } ^ ^ \prime } } \theta \theta e L } = _ _ { + } } { { \alpha } k k _ { z \gamma { \beta { \beta \beta \beta \beta \beta \beta \beta \beta
--------------------------------------------------
Actual Formula: _ { \mathrm { S G } } = 2 g _ { \mathrm { e f f } } .
Predicted Formula: { n } c s } } ^ { _ + ( \mathrm { T T f } } ^ w = C C 1 1 1 1 . . . . . . . . 

KeyboardInterrupt: 