## INIT

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, models
import json
from PIL import Image
import os

# Check if GPU is available, otherwise use CPU
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [2]:
print(torch.cuda.get_device_name(0))
print(device)
# print(len(os.listdir('LaTex_data/split_1')))
# print(image_formula_mapping['0002475406d9932.png'])

NVIDIA GeForce RTX 4060 Ti
cuda:0


## Load Data


In [3]:
transform = transforms.Compose([
    # do not resize
	transforms.ToTensor(),
	transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

In [4]:
import json
from torch.nn.utils.rnn import pad_sequence

label_to_index_file = './230k.json'
with open(label_to_index_file, 'r') as f:
	sign2id = json.load(f)

id2sign = [0] * 650
for k, v in sign2id.items():
	id2sign[int(v)] = k

def collate_fn(batch):
	# filter the pictures that have different weight or height
	size = batch[0][0].size()
	batch = [img_formula for img_formula in batch
			if img_formula[0].size() == size]
	
	# # sort by the length of formula
	# batch.sort(key=lambda img_formula: len(img_formula[1].split()),
	# 		reverse=True)

	imgs, formulas = zip(*batch)
	formulas = pad_sequence(formulas, batch_first=True, padding_value=2)
	
	imgs = torch.stack(imgs, dim=0)
	return imgs.to(device), formulas.to(device)


In [5]:
from torch.utils.data import Dataset, DataLoader
from os.path import join

class Im2LatexDataset(Dataset):
	def __init__(self, data_dir, split, max_len=32 * 750):
		"""args:
		data_dir: root dir storing the prepoccessed data
		split: train, validate or test
		"""
		assert split in ["train", "validate", "test"]
		self.data_dir = data_dir
		self.split = split
		self.max_len = max_len
		self.pairs = self._load_pairs()

	def _load_pairs(self):
		pairs = torch.load(join(self.data_dir, "{}.pkl".format(self.split)))

		finite_pairs = []
		for i, (img, formula) in enumerate(pairs):
			pair = (img, " ".join(formula.split()))
			finite_pairs.append(pair)

			if i >= self.max_len:
				break
		
		return finite_pairs

	def __getitem__(self, idx):
		image, formula = self.pairs[idx]
		
		formula_tokens = '<S> ' + formula + ' <E> '
		formula_tokens = formula.split()
		
		formula_indices = []
		for token in formula_tokens:
			# Map each token to its index; if not found, use a default index (e.g., 0)
			index = sign2id.get(token, 0)  # Assuming 0 is for unknown tokens
			formula_indices.append(int(index))
		
		return image, torch.tensor(formula_indices, dtype=torch.long)

	def __len__(self):
		return len(self.pairs)

In [14]:
batch_size = 32

train_loader = DataLoader(
	Im2LatexDataset('./100k/', 'train'),
	batch_size=batch_size,
	collate_fn=collate_fn)

val_loader = DataLoader(
	Im2LatexDataset('./100k/', 'validate'),
    batch_size=32,
	collate_fn=collate_fn)


In [7]:
for i, data in enumerate(train_loader):
    img, label = data
    print(data)
    break

(tensor([[[[0.6816, 0.6816, 0.6816,  ..., 0.6816, 0.6816, 0.6816],
          [0.6816, 0.6816, 0.6816,  ..., 0.6816, 0.6816, 0.6816],
          [0.6816, 0.6816, 0.6816,  ..., 0.6816, 0.6816, 0.6816],
          ...,
          [0.6816, 0.6816, 0.6816,  ..., 0.6816, 0.6816, 0.6816],
          [0.6816, 0.6816, 0.6816,  ..., 0.6816, 0.6816, 0.6816],
          [0.6816, 0.6816, 0.6816,  ..., 0.6816, 0.6816, 0.6816]],

         [[0.9911, 0.9911, 0.9911,  ..., 0.9911, 0.9911, 0.9911],
          [0.9911, 0.9911, 0.9911,  ..., 0.9911, 0.9911, 0.9911],
          [0.9911, 0.9911, 0.9911,  ..., 0.9911, 0.9911, 0.9911],
          ...,
          [0.9911, 0.9911, 0.9911,  ..., 0.9911, 0.9911, 0.9911],
          [0.9911, 0.9911, 0.9911,  ..., 0.9911, 0.9911, 0.9911],
          [0.9911, 0.9911, 0.9911,  ..., 0.9911, 0.9911, 0.9911]],

         [[0.7352, 0.7352, 0.7352,  ..., 0.7352, 0.7352, 0.7352],
          [0.7352, 0.7352, 0.7352,  ..., 0.7352, 0.7352, 0.7352],
          [0.7352, 0.7352, 0.7352,  ..., 

## Encoder / Decoder

In [31]:
import torch
import torch.nn as nn
import torch.nn.functional as F


class Attention(nn.Module):
    def __init__(self, enc_out_dim, dec_hidden_dim):
        super(Attention, self).__init__()
        self.attn = nn.Linear(enc_out_dim + dec_hidden_dim, dec_hidden_dim)
        self.v = nn.Parameter(torch.rand(dec_hidden_dim))

    def forward(self, encoder_outputs, hidden):
        B, seq_len, enc_dim = encoder_outputs.shape
        hidden = hidden.unsqueeze(1).repeat(1, seq_len, 1)
        energy = torch.tanh(self.attn(torch.cat((encoder_outputs, hidden), dim=2)))
        energy = energy @ (self.v / torch.sqrt(torch.tensor(enc_dim, dtype=torch.float)))
        attn_weights = F.softmax(energy, dim=1)
        context = torch.bmm(attn_weights.unsqueeze(1), encoder_outputs).squeeze(1)
        return context, attn_weights


class EncoderCNN(nn.Module):
    def __init__(self, enc_out_dim=512, dropout_prob=0):
        super(EncoderCNN, self).__init__()
        self.cnn_encoder = nn.Sequential(
            nn.Conv2d(3, 64, 3, 1, 1),
            nn.ReLU(),
            nn.MaxPool2d(2, 2, 1),


            nn.Conv2d(64, 128, 3, 1, 1),
            nn.ReLU(),
            nn.MaxPool2d(2, 2, 1),


            nn.Conv2d(128, 256, 3, 1, 1),
            nn.ReLU(),
            nn.Conv2d(256, 256, 3, 1, 1),
            nn.ReLU(),
            nn.MaxPool2d((2, 1), (2, 1), 0),


            nn.Conv2d(256, enc_out_dim, 3, 1, 0),
            nn.ReLU(),

        )

    def forward(self, images):
        features = self.cnn_encoder(images)
        features = features.permute(0, 2, 3, 1)
        B, H, W, C = features.shape
        features = features.contiguous().view(B, H * W, C)
        return features


class DecoderRNN(nn.Module):
    def __init__(self, embedding_dim, hidden_dim, vocab_size, enc_out_dim, dropout_prob=0.3):
        super(DecoderRNN, self).__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.attention = Attention(enc_out_dim, hidden_dim)
        self.lstm = nn.LSTM(enc_out_dim + embedding_dim, hidden_dim, batch_first=True)  # LSTM only takes features
        self.fc = nn.Linear(hidden_dim, vocab_size)
        self.dropout = nn.Dropout(dropout_prob)

    def forward(self, features, formulas, teacher_forcing_ratio=0.7):
        batch_size = features.size(0)
        seq_len = formulas.size(1)
        vocab_size = self.fc.out_features

        hidden = torch.zeros(1, batch_size, self.lstm.hidden_size, device=features.device)
        cell = torch.zeros(1, batch_size, self.lstm.hidden_size, device=features.device)
        
        outputs = torch.zeros(batch_size, seq_len, vocab_size, device=features.device)

        input_token = torch.ones((batch_size, 1), dtype=torch.long, device=features.device) * 0  # Shape: (batch_size, 1)

        for t in range(seq_len):
            input_emb = self.embedding(input_token)  # Shape: (batch_size, 1, embedding_dim)
            
            context, _ = self.attention(features, hidden.squeeze(0))  # Shape: (batch_size, enc_out_dim)
            context = context.unsqueeze(1)  # Add time dimension: (batch_size, 1, enc_out_dim)
            
            lstm_input = torch.cat((context, input_emb), dim=2)  # Shape: (batch_size, 1, enc_out_dim + embedding_dim)
            
            lstm_out, (hidden, cell) = self.lstm(lstm_input, (hidden, cell))  # Shape: (batch_size, 1, hidden_dim)
            
            output = self.fc(lstm_out.squeeze(1))  # Shape: (batch_size, vocab_size)
            outputs[:, t, :] = output

            top1 = output.argmax(1)  # Shape: (batch_size)
            if self.training:
                # Decide whether to use teacher forcing
                teacher_force = torch.rand(1).item() < teacher_forcing_ratio
                input_token = formulas[:, t].unsqueeze(1) if teacher_force and t + 1 < seq_len else top1.unsqueeze(1)
            else:
                input_token = top1.unsqueeze(1)

        return outputs


class ImageToLaTeXModel(nn.Module):
    def __init__(self, encoder, decoder):
        super(ImageToLaTeXModel, self).__init__()
        self.encoder = encoder
        self.decoder = decoder

    def forward(self, images, formulas):
        features = self.encoder(images)
        outputs = self.decoder(features, formulas[:, :-1])
        return outputs




## Save / Load model

In [32]:
def save_training_state(model, optimizer, epoch, loss):
	state = {
		'model_state_dict': model.state_dict(),
		'optimizer_state_dict': optimizer.state_dict(),
		'epoch': epoch,
		'loss': loss
	}
	torch.save(state, 'model_checkpoint.pth')

# Function to load model state
def load_training_state(model, optimizer):
	checkpoint = torch.load('model_checkpoint.pth')
	model.load_state_dict(checkpoint['model_state_dict'])
	optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
	return checkpoint['epoch'], checkpoint['loss']

In [10]:
# Hyperparameters
EMBED_SIZE = 512 # direct output dim from cv_tiny

hidden_size = 1024
num_epochs = 22
learning_rate = 0.003

vocab_size = len(sign2id)

# Model, loss, and optimizer
encoder = EncoderCNN(EMBED_SIZE).to(device)
decoder = DecoderRNN(EMBED_SIZE, hidden_size, vocab_size,512).to(device)
model = ImageToLaTeXModel(encoder, decoder).to(device)
criterion = nn.CrossEntropyLoss(ignore_index=1)
optimizer = optim.Adam(model.parameters())
start_epoch = 0

# Try to resume from a checkpoint
try:
	start_epoch, last_loss = load_training_state(model, optimizer)
	print(f"Get model from epoch {start_epoch}, with loss {last_loss:.4f}")
except FileNotFoundError:
	print("No saved model found, starting fresh.")
	start_epoch = 0

for epoch in range(start_epoch, num_epochs):

	
	for i, data in enumerate(train_loader):
		images, formulas = data
		# Pad sequences to the same length
		formulas_padded = nn.utils.rnn.pad_sequence(formulas, batch_first=True, padding_value=2)
		targets = formulas_padded[:, :].contiguous()

		outputs = model(images, formulas_padded[:, :-1].contiguous())
		# Match target size with output size
		targets = targets[:, :outputs.size(1)].contiguous()

		loss = criterion(outputs.view(-1, vocab_size), targets.view(-1))
		
		optimizer.zero_grad()
		loss.backward()
		optimizer.step()

		if i  % 100 == 0:
			print(f"Epoch [{epoch+1}/{num_epochs}], Step [{i}/{len(train_loader)}], Loss: {loss.item():.4f}")

	save_training_state(model, optimizer, epoch, loss.item())


Get model from epoch 10, with loss 0.3491
Epoch [11/22], Step [0/751], Loss: 1.6123


KeyboardInterrupt: 

In [None]:
asdfsadf asdlkfj.ds .df.daf

In [33]:
def decode_formula(indices, id2sign):
    return ' '.join([id2sign[i.item()] for i in indices if i.item() and i.item() != 2])  # Skip padding

model.eval()
with torch.no_grad():
    for i, data in enumerate(val_loader):
        images, formulas = data
        # Pad sequences to the same length
        formulas_padded = nn.utils.rnn.pad_sequence(formulas, batch_first=True, padding_value=2)
        targets = formulas_padded[:, :].contiguous()

        outputs = model(images, formulas_padded[:, :-1].contiguous())
        print(outputs.shape)

        predicted_indices = torch.argmax(outputs, dim=2)
        print(predicted_indices.shape)

        for i in range(batch_size):
            actual_formula = decode_formula(formulas[i, 1:], id2sign)  # Skip <S> token
            predicted_formula = decode_formula(predicted_indices[i, 1:], id2sign)  # Skip <S> token

            print('act :', actual_formula)
            print('pred:', predicted_formula)
            print()
        break

        


torch.Size([32, 42, 579])
torch.Size([32, 42])
act : _ { 1 } ^ { k } = \omega _ { 1 } ^ { k - 2 } \subseteq \omega _ { 1 } ^ { k }
pred: _ { \mu } ^ { \mu } = 0 _ { k } ^ { j } 1 } } { _ { k } ^ { k } } _

act : _ { i j } = \bar { g } _ { i j } + h _ { i j } ,
pred: _ { \mu j } = { { g } _ { i j } ^ g _ { i j } g

act : \theta _ { n } ^ { \alpha \Lambda } } { ^ \dagger } = \theta _ { - n } ^ { \alpha \Lambda } ,
pred: _ ^ { \alpha } , { i } } = 0 , = 0 { \alpha } } ^ { \alpha } } {

act : _ { \mathrm { S G } } = 2 g _ { \mathrm { e f f } } .
pred: _ { \mu { e } } } = { } = 0 . 1 } ^ } } =

act : \chi _ { D } ^ { \prime \prime } , \chi _ { D } ^ { \prime \prime } \} = i \chi ^ { \prime \prime } .
pred: Q _ { \alpha } , { i } } , \chi _ { \beta } ^ { \prime } } \} = 0 \delta ^ { \prime } } ,

act : ^ { 2 } f = \bar { \gamma } F ^ { 4 } f \; \;
pred: _ { 2 } { ^ 0 { f } ^ { 2 } \, \, ^ { \;

act : _ { m n p } \, \xi = \psi _ { m n p } \, \xi ,
pred: _ { \mu } ^ _ = = _ 0 _ { m } ^ _ \, , 

## validate


In [None]:
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
def predict_formula(model, image, start_token, end_token, max_length, device):
    """
    Generate a LaTeX formula given an image input.

    Args:
        model: The trained ImageToLaTeXModel.
        image: Input image tensor of shape (1, C, H, W).
        start_token: Index of the <START> token.
        end_token: Index of the <END> token.
        max_length: Maximum length of the generated formula.
        device: Device to perform inference on.

    Returns:
        A list of token indices representing the generated formula.
    """
    model.eval()
    image = image.to(device)

    with torch.no_grad():
        # Extract features using the encoder
        features = model.encoder(image)

        # Initialize the decoder input with the <START> token
        input_token = torch.tensor([[start_token]], device=device)

        # Initialize LSTM hidden and cell states
        hidden = torch.zeros(1, 1, model.decoder.lstm.hidden_size, device=device)
        cell = torch.zeros(1, 1, model.decoder.lstm.hidden_size, device=device)

        # Store generated tokens
        generated_tokens = []

        for _ in range(max_length):
            # Embed the current token
            embedding = model.decoder.embedding(input_token).squeeze(1)

            # Compute attention and context
            context, _ = model.decoder.attention(features, hidden.squeeze(0))

            # Prepare input for the LSTM
            lstm_input = torch.cat((embedding, context), dim=1).unsqueeze(1)

            # Forward through LSTM
            lstm_out, (hidden, cell) = model.decoder.lstm(lstm_input, (hidden, cell))

            # Generate the next token
            output = model.decoder.fc(lstm_out.squeeze(1))
            next_token = output.argmax(dim=1)

            # Append the predicted token
            generated_tokens.append(next_token.item())

            # Break if <END> token is generated
            if next_token.item() == end_token:
                break

            # Update input token for the next time step
            input_token = next_token.unsqueeze(1)

    return generated_tokens

def decode_formula(indices, id2sign):
    return ' '.join([id2sign[i.item()] for i in indices if i.item() and i.item() != 2])  # Skip padding

def validate_model(model, criterion, device):
    model.eval()  # Set the model to evaluation mode
    total_loss = 0.0
    total_bleu = 0
    total_samples = 0
    total_imgs = 0
    smooth_fn = SmoothingFunction().method1
    with torch.no_grad():  # Disable gradient calculation
        for images, formulas in val_loader:
            images, formulas = images.to(device), formulas.to(device)
            print(predict_formula(model, images[0], 0, 2, 100, device))
            formulas_padded = nn.utils.rnn.pad_sequence(formulas, batch_first=True, padding_value=2)
            targets = formulas_padded[:, :].contiguous()

            outputs = model(images, [])
            # Match target size with output size
            targets = targets[:, :outputs.size(1)].contiguous()
            loss = criterion(outputs.view(-1, vocab_size), targets.view(-1))
            predicted_indices = torch.argmax(outputs, dim=2)  # Get the index of the max log-probability
            total_samples += formulas[:, 1:].numel()  # Total number of tokens in the validation batch

            # Print images and predictions
            # The batch size is 400, print every 400 images
            total_imgs += len(images)
            for i in range(len(images)):
                # Decode the actual and predicted formulas
                actual_formula = decode_formula(formulas[i, :], id2sign)  
                predicted_formula = decode_formula(predicted_indices[i, :], id2sign) 
                total_bleu += sentence_bleu([actual_formula.split()], predicted_formula.split(), smoothing_function=smooth_fn)
                # print(total_bleu)
                total_loss += loss.item()

                if i == 0:
                    print(f'Actual Formula:    {actual_formula}')
                    print(f'Predicted Formula: {predicted_formula}')
                    print('-' * 50)

    avg_loss = total_loss / total_imgs
    accuracy = total_bleu / total_imgs

    return avg_loss, accuracy

# val_loss, val_accuracy = validate_model(model, criterion, device)
# print(f'val_accuracy: {val_accuracy*100:.2f}% , val_loss: {val_loss:.4f}')


RuntimeError: permute(sparse_coo): number of dimensions in the tensor input does not match the length of the desired ordering of dimensions i.e. input.dim() = 3 is not equal to len(dims) = 4