## INIT

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, models
import json
from PIL import Image
import os

# Check if GPU is available, otherwise use CPU
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [2]:
print(torch.cuda.get_device_name(0))
print(device)
# print(len(os.listdir('LaTex_data/split_1')))
# print(image_formula_mapping['0002475406d9932.png'])

NVIDIA GeForce RTX 4060 Ti
cuda:0


## Load Data


In [3]:
transform = transforms.Compose([
    # do not resize
	# transforms.ColorJitter(brightness=0.2, contrast=0.2),
	# transforms.GaussianBlur(kernel_size=(3, 3), sigma=(0.1, 1.0)),
	transforms.ToTensor(),
	transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    
])

In [4]:
import json
from torch.nn.utils.rnn import pad_sequence

label_to_index_file = './230k.json'
with open(label_to_index_file, 'r') as f:
	sign2id = json.load(f)

id2sign = [0] * 650
for k, v in sign2id.items():
	id2sign[int(v)] = k

def collate_fn(batch):
	# filter the pictures that have different weight or height
	size = batch[0][0].size()
	batch = [img_formula for img_formula in batch
			if img_formula[0].size() == size]
	
	# # sort by the length of formula
	# batch.sort(key=lambda img_formula: len(img_formula[1].split()),
	# 		reverse=True)

	imgs, formulas = zip(*batch)
	formulas = pad_sequence(formulas, batch_first=True, padding_value=2)
	
	imgs = torch.stack(imgs, dim=0)
	return imgs.to(device), formulas.to(device)


In [5]:
from torch.utils.data import Dataset, DataLoader
from os.path import join

class Im2LatexDataset(Dataset):
	def __init__(self, data_dir, split, max_len=30000):
		"""args:
		data_dir: root dir storing the prepoccessed data
		split: train, validate or test
		"""
		assert split in ["train", "validate", "test"]
		self.data_dir = data_dir
		self.split = split
		self.max_len = max_len
		self.pairs = self._load_pairs()

	def _load_pairs(self):
		pairs = torch.load(join(self.data_dir, "{}.pkl".format(self.split)))

		finite_pairs = []
		for i, (img, formula) in enumerate(pairs):
			pair = (img, " ".join(formula.split()))
			finite_pairs.append(pair)

			if i >= self.max_len:
				break
		
		return finite_pairs

	def __getitem__(self, idx):
		image, formula = self.pairs[idx]
		
		formula_tokens = '<S> ' + formula + ' <E> <E> '
		formula_tokens = formula.split()
		
		formula_indices = []
		for token in formula_tokens:
			# Map each token to its index; if not found, use a default index (e.g., 0)
			index = sign2id.get(token, 0)  # Assuming 0 is for unknown tokens
			formula_indices.append(int(index))
		
		return image, torch.tensor(formula_indices, dtype=torch.long)

	def __len__(self):
		return len(self.pairs)

In [6]:
batch_size = 32

train_loader = DataLoader(
	Im2LatexDataset('./100k/', 'train'),
	batch_size=batch_size,
	collate_fn=collate_fn)

val_loader = DataLoader(
	Im2LatexDataset('./100k/', 'validate'),
	batch_size=batch_size,
	collate_fn=collate_fn)


In [7]:
for i, data in enumerate(train_loader):
    img, label = data
    print(data)
    break

(tensor([[[[0.6816, 0.6816, 0.6816,  ..., 0.6816, 0.6816, 0.6816],
          [0.6816, 0.6816, 0.6816,  ..., 0.6816, 0.6816, 0.6816],
          [0.6816, 0.6816, 0.6816,  ..., 0.6816, 0.6816, 0.6816],
          ...,
          [0.6816, 0.6816, 0.6816,  ..., 0.6816, 0.6816, 0.6816],
          [0.6816, 0.6816, 0.6816,  ..., 0.6816, 0.6816, 0.6816],
          [0.6816, 0.6816, 0.6816,  ..., 0.6816, 0.6816, 0.6816]],

         [[0.9911, 0.9911, 0.9911,  ..., 0.9911, 0.9911, 0.9911],
          [0.9911, 0.9911, 0.9911,  ..., 0.9911, 0.9911, 0.9911],
          [0.9911, 0.9911, 0.9911,  ..., 0.9911, 0.9911, 0.9911],
          ...,
          [0.9911, 0.9911, 0.9911,  ..., 0.9911, 0.9911, 0.9911],
          [0.9911, 0.9911, 0.9911,  ..., 0.9911, 0.9911, 0.9911],
          [0.9911, 0.9911, 0.9911,  ..., 0.9911, 0.9911, 0.9911]],

         [[0.7352, 0.7352, 0.7352,  ..., 0.7352, 0.7352, 0.7352],
          [0.7352, 0.7352, 0.7352,  ..., 0.7352, 0.7352, 0.7352],
          [0.7352, 0.7352, 0.7352,  ..., 

In [8]:
def beam_search(self, images, beam_width=3, max_len=100, start_token=0, end_token=1):
    features = self.encoder(images)
    B = features.size(0)
    hidden = torch.zeros(1, B, self.decoder.gru.hidden_size, device=features.device)

    sequences = [[(start_token, 0.0)]] * B  # Each batch starts with <SOS>
    completed_sequences = [[] for _ in range(B)]

    for _ in range(max_len):
        all_candidates = [[] for _ in range(B)]
        for b in range(B):
            if sequences[b][-1][0] == end_token:  # If sequence ends
                completed_sequences[b].append(sequences[b])
                continue

            input_token = torch.tensor([sequences[b][-1][0]], device=features.device).unsqueeze(0)
            embedding = self.decoder.embedding(input_token).squeeze(1)
            context, _ = self.decoder.attention(features[b:b+1], hidden[:, b:b+1, :].squeeze(1))
            gru_input = torch.cat((embedding, context), dim=1).unsqueeze(1)
            gru_out, hidden[:, b:b+1, :] = self.decoder.gru(gru_input, hidden[:, b:b+1, :])
            logits = self.decoder.fc(gru_out.squeeze(1))
            probs = torch.log_softmax(logits, dim=1)

            for token_idx in range(len(probs[0])):
                prob = probs[0, token_idx].item()
                new_seq = sequences[b] + [(token_idx, prob)]
                all_candidates[b].append((new_seq, sum(prob for _, prob in new_seq)))

        # Prune sequences
        for b in range(B):
            all_candidates[b].sort(key=lambda x: x[1], reverse=True)
            sequences[b] = all_candidates[b][:beam_width]

    # Return best sequences
    return [[token for token, _ in seq[:-1]] for seq in completed_sequences]


## Encoder / Decoder

In [9]:
import torch
import torch.nn as nn
import torch.nn.functional as F


class Attention(nn.Module):
    def __init__(self, enc_out_dim, dec_hidden_dim):
        super(Attention, self).__init__()
        self.attn = nn.Linear(enc_out_dim + dec_hidden_dim, dec_hidden_dim)
        self.v = nn.Parameter(torch.rand(dec_hidden_dim))

    def forward(self, encoder_outputs, hidden):
        B, seq_len, enc_dim = encoder_outputs.shape
        hidden = hidden.unsqueeze(1).repeat(1, seq_len, 1)
        energy = torch.tanh(self.attn(torch.cat((encoder_outputs, hidden), dim=2)))
        energy = energy @ (self.v / torch.sqrt(torch.tensor(enc_dim, dtype=torch.float)))
        attn_weights = F.softmax(energy, dim=1)
        context = torch.bmm(attn_weights.unsqueeze(1), encoder_outputs).squeeze(1)
        return context, attn_weights


class EncoderCNN(nn.Module):
    def __init__(self, enc_out_dim=512, dropout_prob=0):
        super(EncoderCNN, self).__init__()
        self.cnn_encoder = nn.Sequential(
            nn.Conv2d(3, 64, 3, 1, 1),
            nn.ReLU(),
            nn.MaxPool2d(2, 2, 1),


            nn.Conv2d(64, 128, 3, 1, 1),
            nn.ReLU(),
            nn.MaxPool2d(2, 2, 1),


            nn.Conv2d(128, 256, 3, 1, 1),
            nn.ReLU(),
            nn.Conv2d(256, 256, 3, 1, 1),
            nn.ReLU(),
            nn.MaxPool2d((2, 1), (2, 1), 0),


            nn.Conv2d(256, enc_out_dim, 3, 1, 0),
            nn.ReLU(),

        )

    def forward(self, images):
        features = self.cnn_encoder(images)
        features = features.permute(0, 2, 3, 1)
        B, H, W, C = features.shape
        features = features.contiguous().view(B, H * W, C)
        return features


class DecoderRNN(nn.Module):
    def __init__(self, embedding_dim, hidden_dim, vocab_size, enc_out_dim, dropout_prob=0.3):
        super(DecoderRNN, self).__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.attention = Attention(enc_out_dim, hidden_dim)
        self.lstm = nn.LSTM(embedding_dim + enc_out_dim, hidden_dim, batch_first=True)
        self.fc = nn.Linear(hidden_dim, vocab_size)
        self.dropout = nn.Dropout(dropout_prob)

    def forward(self, features, formulas):
        embeddings = self.dropout(self.embedding(formulas))
        
        # Initialize LSTM hidden and cell states
        hidden = torch.zeros(1, features.size(0), self.lstm.hidden_size, device=features.device)
        cell = torch.zeros(1, features.size(0), self.lstm.hidden_size, device=features.device)
        outputs = []

        for t in range(embeddings.size(1)):
            context, _ = self.attention(features, hidden.squeeze(0))
            lstm_input = torch.cat((embeddings[:, t, :], context), dim=1).unsqueeze(1)
            lstm_out, (hidden, cell) = self.lstm(lstm_input, (hidden, cell))
            outputs.append(self.fc(lstm_out.squeeze(1)))

        outputs = torch.stack(outputs, dim=1)
        return outputs



class ImageToLaTeXModel(nn.Module):
    def __init__(self, encoder, decoder):
        super(ImageToLaTeXModel, self).__init__()
        self.encoder = encoder
        self.decoder = decoder

    def forward(self, images, formulas):
        features = self.encoder(images)
        outputs = self.decoder(features, formulas[:, :-1])
        return outputs

    def beam_search(self, images, beam_width=3, max_len=100, start_token=0, end_token=2):
        features = self.encoder(images)
        B = features.size(0)
        hidden = torch.zeros(1, B, self.decoder.gru.hidden_size, device=features.device)

        sequences = [[(start_token, 0.0)]] * B  
        completed_sequences = [[] for _ in range(B)]

        for _ in range(max_len):
            all_candidates = [[] for _ in range(B)]
            for b in range(B):
                if sequences[b][-1][0] == end_token:  # If sequence ends
                    completed_sequences[b].append(sequences[b])
                    continue

                input_token = torch.tensor([sequences[b][-1][0]], device=features.device).unsqueeze(0)
                embedding = self.decoder.embedding(input_token).squeeze(1)
                context, _ = self.decoder.attention(features[b:b+1], hidden[:, b:b+1, :].squeeze(1))
                gru_input = torch.cat((embedding, context), dim=1).unsqueeze(1)
                gru_out, hidden[:, b:b+1, :] = self.decoder.gru(gru_input, hidden[:, b:b+1, :])
                logits = self.decoder.fc(gru_out.squeeze(1))
                probs = torch.log_softmax(logits, dim=1)

                for token_idx in range(len(probs[0])):
                    prob = probs[0, token_idx].item()
                    new_seq = sequences[b] + [(token_idx, prob)]
                    all_candidates[b].append((new_seq, sum(prob for _, prob in new_seq)))

            # Prune sequences
            for b in range(B):
                all_candidates[b].sort(key=lambda x: x[1], reverse=True)
                sequences[b] = all_candidates[b][:beam_width]

        # Return best sequences
        return [[token for token, _ in seq[:-1]] for seq in completed_sequences]



## Save / Load model

In [10]:
# Hyperparameters
EMBED_SIZE = 512 # direct output dim from cv_tiny

hidden_size = 1024
num_epochs = 200
learning_rate = 0.003
# batch_size = 32

# Load dataset and dataloader
# dataset = LaTeXDataset("LaTex_data/split_1" , mapping_path, label_to_index_path, transform)

vocab_size = len(sign2id)

# Model, loss, and optimizer
encoder = EncoderCNN(EMBED_SIZE).to(device)
decoder = DecoderRNN(EMBED_SIZE, hidden_size, vocab_size,512).to(device)
model = ImageToLaTeXModel(encoder, decoder).to(device)
criterion = nn.CrossEntropyLoss(ignore_index=1)
optimizer = optim.Adam(model.parameters())
start_epoch = 0
# augementeation = transforms.Compose([
#         # Geometric transformations
#         transforms.RandomResizedCrop(scale=(0.9, 1.1)),
#         transforms.RandomAffine(degrees=0, translate=(0.05, 0.05), scale=None, shear=None),

#         # Photometric transformations
#         transforms.ColorJitter(brightness=0.2, contrast=0.2),
#         transforms.RandomApply([transforms.Grayscale(num_output_channels=3)], p=1),

#         # To tensor (if images are PIL) and normalize
#         transforms.ToTensor(),
#         transforms.Normalize(mean=[0.5], std=[0.5]),
#     ])
# Function to save model state
def save_training_state(model, optimizer, epoch, loss):
	state = {
		'model_state_dict': model.state_dict(),
		'optimizer_state_dict': optimizer.state_dict(),
		'epoch': epoch,
		'loss': loss
	}
	torch.save(state, 'model_checkpoint.pth')

# Function to load model state
def load_training_state(model, optimizer):
	checkpoint = torch.load('model_checkpoint.pth')
	model.load_state_dict(checkpoint['model_state_dict'])
	optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
	return checkpoint['epoch'], checkpoint['loss']

# Try to resume from a checkpoint
try:
	start_epoch, last_loss = load_training_state(model, optimizer)
	print(f"Resuming training from epoch {start_epoch}, with loss {last_loss:.4f}")
except FileNotFoundError:
	print("No saved model found, starting fresh.")
	start_epoch = 0


for epoch in range(start_epoch, num_epochs):

	
	for i, data in enumerate(train_loader):
		images, formulas = data
		# print(type(images))
		# Pad sequences to the same length
		formulas_padded = nn.utils.rnn.pad_sequence(formulas, batch_first=True, padding_value=1)
		targets = formulas_padded[:, 1:].contiguous()

		outputs = model(images, formulas_padded[:, :-1].contiguous())
		# Match target size with output size
		targets = targets[:, :outputs.size(1)].contiguous()

		# print(targets.size(), outputs.size())  # Ensure they match
		loss = criterion(outputs.view(-1, vocab_size), targets.view(-1))
		
		optimizer.zero_grad()
		loss.backward()
		optimizer.step()

		if i  % 100 == 0:
			print(f"Epoch [{epoch+1}/{num_epochs}], Step [{i}/{len(train_loader)}], Loss: {loss.item():.4f}")

	save_training_state(model, optimizer, epoch, loss.item())




# Training loop
# for i in range(6):
# 	for folder_idx in range(start_folder_idx, len(folders)):
# 		print(f"Training on folder: {folders[folder_idx]}")
# 		dataset = LaTeXDataset("LaTex_data/" + folders[folder_idx], mapping_path, label_to_index_path, transform)
# 		dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)

# 		for epoch in range(start_epoch, num_epochs):
# 			for i, data in enumerate(dataloader):
# 				images, formulas = data
# 				targets = formulas[:, 1:]

# 				outputs = model(images, formulas[:, :-1])
# 				outputs = outputs.log_softmax(2)  # Apply log_softmax for CTC Loss

# 				loss = criterion(outputs, targets, input_lengths, target_lengths)
# 				# loss = criterion(outputs.view(-1, dataset.vocab_size), targets.contiguous().view(-1))

# 				optimizer.zero_grad()
# 				loss.backward()
# 				optimizer.step()

# 				if i % 100 == 0:
# 					print(f"Folder [{folder_idx+1}/{len(folders)}], Epoch [{epoch+1}/{num_epochs}], Step [{i}/{len(dataloader)}], Loss: {loss.item():.4f}")

# 				# Save model periodically and at the end of each folder
# 				if i % 200 == 0 or (i == len(dataloader) - 1):
# 					save_training_state(model, optimizer, epoch, folder_idx, loss.item())
			
# 			# Reset start_epoch for next folder
# 			start_epoch = 0
# 		start_folder_idx = 0
# 		start_epoch = 0


Resuming training from epoch 7, with loss 0.3368
Epoch [8/200], Step [0/938], Loss: 0.4480
Epoch [8/200], Step [100/938], Loss: 0.3170
Epoch [8/200], Step [200/938], Loss: 0.4736
Epoch [8/200], Step [300/938], Loss: 0.5339
Epoch [8/200], Step [400/938], Loss: 0.5216
Epoch [8/200], Step [500/938], Loss: 0.4589
Epoch [8/200], Step [600/938], Loss: 0.6162
Epoch [8/200], Step [700/938], Loss: 0.4908
Epoch [8/200], Step [800/938], Loss: 0.5997
Epoch [8/200], Step [900/938], Loss: 0.4242
Epoch [9/200], Step [0/938], Loss: 0.4028
Epoch [9/200], Step [100/938], Loss: 0.2465
Epoch [9/200], Step [200/938], Loss: 0.3848
Epoch [9/200], Step [300/938], Loss: 0.4518
Epoch [9/200], Step [400/938], Loss: 0.4238
Epoch [9/200], Step [500/938], Loss: 0.3878
Epoch [9/200], Step [600/938], Loss: 0.5160
Epoch [9/200], Step [700/938], Loss: 0.4035
Epoch [9/200], Step [800/938], Loss: 0.5317
Epoch [9/200], Step [900/938], Loss: 0.3575
Epoch [10/200], Step [0/938], Loss: 0.3058
Epoch [10/200], Step [100/938], 

KeyboardInterrupt: 

In [None]:
asdfljasd;lkfasdf.sdf

In [11]:
def decode_formula(indices, id2sign):
    return ' '.join([id2sign[i.item()] for i in indices if i.item() and i.item() != 2])  # Skip padding


def validate_model(model, criterion, device):
    model.eval()  # Set the model to evaluation mode
    total_loss = 0.0
    correct_predictions = 0
    total_samples = 0

    with torch.no_grad():  # Disable gradient calculation
        for images, formulas in val_loader:
            images, formulas = images.to(device), formulas.to(device)
            formulas_padded = nn.utils.rnn.pad_sequence(formulas, batch_first=True, padding_value=2)
            targets = formulas_padded[:, 1:].contiguous()

            outputs = model(images, formulas_padded[:, :-1].contiguous())
            # Match target size with output size
            targets = targets[:, :outputs.size(1)].contiguous()

            # print(targets.size(), outputs.size())  # Ensure they match
            loss = criterion(outputs.view(-1, vocab_size), targets.view(-1))

            # Calculate accuracy (if applicable)
            predicted_indices = torch.argmax(outputs, dim=2)  # Get the index of the max log-probability
            # correct_predictions += (predicted_indices == formulas[:, 1:].contiguous()).sum().item()
            total_samples += formulas[:, 1:].numel()  # Total number of tokens in the validation batch

            # Print images and predictions
            for i in range(len(images)):
                # Decode the actual and predicted formulas
                actual_formula = decode_formula(formulas[i, 1:], id2sign)  # Skip <S> token
                predicted_formula = decode_formula(predicted_indices[i, 1:], id2sign)  # Skip <S> token
                # print(f'Image: {images[i]}')  # This will print the tensor, consider using visualization instead
                print(f'Actual Formula: {actual_formula}')
                print(f'Predicted Formula: {predicted_formula}')
                print('-' * 50)

    avg_loss = total_loss / len(val_loader)
    accuracy = correct_predictions / total_samples if total_samples > 0 else 0.0

    return avg_loss, accuracy
# Assuming you have your model, dataloader, criterion, and device set up
# Assuming 230k.json is loaded as label_to_index

val_loss, val_accuracy = validate_model(model, criterion, device)
print(val_accuracy)


Actual Formula: _ { 1 } ^ { k } = \omega _ { 1 } ^ { k - 2 } \subseteq \omega _ { 1 } ^ { k }
Predicted Formula: { 1 } ^ { k } = \omega _ { 1 } ^ { k - 2 } \wedge \omega _ { 1 } ^ { k - 2
--------------------------------------------------
Actual Formula: _ { i j } = \bar { g } _ { i j } + h _ { i j } ,
Predicted Formula: { i j } = \hat { g } _ { i j } + h _ { i j } , h
--------------------------------------------------
Actual Formula: \theta _ { n } ^ { \alpha \Lambda } } { ^ \dagger } = \theta _ { - n } ^ { \alpha \Lambda } ,
Predicted Formula: _ { \kappa } ^ { \alpha N } = = \theta { _ = \theta _ { - \alpha } ^ { \alpha \lambda } ,
--------------------------------------------------
Actual Formula: _ { \mathrm { S G } } = 2 g _ { \mathrm { e f f } } .
Predicted Formula: { S { S G } } = 2 g _ { \mathrm { e f f } } =
--------------------------------------------------
Actual Formula: \chi _ { D } ^ { \prime \prime } , \chi _ { D } ^ { \prime \prime } \} = i \chi ^ { \prime \prime } .
Pre

KeyboardInterrupt: 