## INIT

In [82]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, models
import json
from PIL import Image
import os

# Paths
image_folder_path = "LaTex_data/generated_png_images"
mapping_path = "image_formula_mapping.json"
label_to_index_path = "LaTex_data/230k.json"

# Check if GPU is available, otherwise use CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")



## Load Data


In [83]:
class LaTeXDataset(Dataset):
    def __init__(self, image_folder, mapping_file, label_to_index_file, transform=None, max_images=None):
        self.image_folder = image_folder
        self.transform = transform

        # Load mappings and label-to-index dictionary
        with open(mapping_file, 'r') as f:
            self.image_formula_mapping = json.load(f)
        with open(label_to_index_file, 'r') as f:
            self.label_to_index = json.load(f)


        # Apply the image count limit if specified
        if max_images:
            self.image_formula_mapping = dict(list(self.image_formula_mapping.items())[:max_images])

        self.index_to_label = {v: k for k, v in self.label_to_index.items()}
        self.vocab_size = len(self.label_to_index)
        self.formulas = list(self.image_formula_mapping.values())

    def __len__(self):
        return len(self.formulas)

    def __getitem__(self, idx):
        image_name = list(self.image_formula_mapping.keys())[idx]
        formula = self.image_formula_mapping[image_name]

        # Load image
        image_path = os.path.join(self.image_folder, image_name)
        image = Image.open(image_path).convert("RGB")
        if self.transform:
            image = self.transform(image)

        # Tokenize the formula into indices
        formula_tokens = formula.split()  # Splitting the formula string by whitespace
        formula_indices = []
        for token in formula_tokens:
            # Map each token to its index; if not found, use a default index (e.g., 0)
            index = self.label_to_index.get(token, 0)  # Assuming 0 is for unknown tokens
            formula_indices.append(int(index))
        
        # Convert the list of indices to a 1D tensor
        return image, torch.tensor(formula_indices, dtype=torch.long)



## Encoder / Decoder

In [84]:


class EncoderCNN(nn.Module):
    def __init__(self, feature_dim):
        super(EncoderCNN, self).__init__()
        resnet = models.resnet50(pretrained=True)  # Example using ResNet50
        for param in resnet.parameters():
            param.requires_grad = False  # Freeze ResNet layers

        # Replace the final fully connected layer with a custom one
        self.resnet = nn.Sequential(*list(resnet.children())[:-1])
        self.fc = nn.Linear(resnet.fc.in_features, feature_dim)

    def forward(self, images):
        features = self.resnet(images)  # Shape: [batch_size, feature_dim, 1, 1]
        features = features.view(features.size(0), -1)  # Flatten to [batch_size, feature_dim]
        features = self.fc(features)
        return features

class DecoderRNN(nn.Module):
    def __init__(self, embedding_dim, hidden_dim, vocab_size):
        super(DecoderRNN, self).__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.lstm = nn.LSTM(embedding_dim, hidden_dim, batch_first=True)
        self.fc = nn.Linear(hidden_dim, vocab_size)

    def forward(self, features, formulas):
        # Embed the input formula tokens
        embeddings = self.embedding(formulas)

        # Concatenate features and embeddings along the sequence dimension
        embeddings = torch.cat((features.unsqueeze(1), embeddings), dim=1)
        
        # Pass through LSTM and then through the final linear layer
        lstm_out, _ = self.lstm(embeddings)
        outputs = self.fc(lstm_out)
        return outputs

class ImageToLaTeXModel(nn.Module):
    def __init__(self, encoder, decoder):
        super(ImageToLaTeXModel, self).__init__()
        self.encoder = encoder
        self.decoder = decoder

    def forward(self, images, formulas):
        # Encode the images
        features = self.encoder(images)  # Shape: [batch_size, feature_dim]
        
        # Decode to generate the LaTeX expression
        outputs = self.decoder(features, formulas[:, :-1])  # Skip the end token
        return outputs


In [85]:
from torch.nn.utils.rnn import pad_sequence

def collate_fn(batch):
    # Separate images and formulas from the batch
    images, formulas = zip(*batch)

    # Stack images (assumes images are already the same size after transforms)
    images = torch.stack(images)

    # Pad formulas to the length of the longest formula in the batch
    formulas = pad_sequence(formulas, batch_first=True, padding_value=0)  # 0 as the padding index

    return images, formulas


In [None]:
# Hyperparameters
embed_size = 256
hidden_size = 512
num_epochs = 10
learning_rate = 0.001
batch_size = 32

# Image preprocessing
transform = transforms.Compose([
    transforms.Resize((80, 400)),
    transforms.ToTensor()
])

# Load dataset and dataloader
dataset = LaTeXDataset(image_folder_path, mapping_path, label_to_index_path, transform, max_images=5000)
# Create DataLoader with the custom collate function
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)


# Model, loss, and optimizer
encoder = EncoderCNN(embed_size)
decoder = DecoderRNN(embed_size, hidden_size, dataset.vocab_size)
model = ImageToLaTeXModel(encoder, decoder)
criterion = nn.CrossEntropyLoss(ignore_index=0)  # 0 is assumed as <PAD> token
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

# Training loop
for epoch in range(num_epochs):
    for i, (images, formulas) in enumerate(dataloader):
        images, formulas = images.to(device), formulas.to(device)

        # Set targets: shift formula by one for teacher forcing
        targets = formulas[:, 1:]

        # Forward, loss, and optimize
        outputs = model(images, formulas[:, :-1])
        loss = criterion(outputs.view(-1, dataset.vocab_size), targets.contiguous().view(-1))
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if i % 100 == 0:
            print(f"Epoch [{epoch+1}/{num_epochs}], Step [{i}/{len(dataloader)}], Loss: {loss.item():.4f}")


Epoch [1/10], Step [0/157], Loss: 6.3814
Epoch [1/10], Step [100/157], Loss: 3.1509
Epoch [2/10], Step [0/157], Loss: 3.2230
Epoch [2/10], Step [100/157], Loss: 2.8009
Epoch [3/10], Step [0/157], Loss: 2.9870
Epoch [3/10], Step [100/157], Loss: 2.7815
Epoch [4/10], Step [0/157], Loss: 2.8185
Epoch [4/10], Step [100/157], Loss: 2.6696
Epoch [5/10], Step [0/157], Loss: 2.6798
Epoch [5/10], Step [100/157], Loss: 2.7056


In [None]:
def decode_formula(indices, index_to_label):
    return ' '.join([index_to_label[str(i.item())] for i in indices if i.item() not in [0]])  # Skip padding

def validate_model(model, dataloader, criterion, device, index_to_label):
    model.eval()  # Set the model to evaluation mode
    total_loss = 0.0
    correct_predictions = 0
    total_samples = 0

    with torch.no_grad():  # Disable gradient calculation
        for images, formulas in dataloader:
            images, formulas = images.to(device), formulas.to(device)
            outputs = model(images, formulas[:, :-1])  # Pass images and input sequence

            # Calculate loss
            loss = criterion(outputs.view(-1, outputs.size(-1)), formulas[:, 1:].contiguous().view(-1))
            total_loss += loss.item()

            # Calculate accuracy (if applicable)
            predicted_indices = torch.argmax(outputs, dim=2)  # Get the index of the max log-probability
            correct_predictions += (predicted_indices == formulas[:, 1:].contiguous()).sum().item()
            total_samples += formulas[:, 1:].numel()  # Total number of tokens in the validation batch

            # Print images and predictions
            for i in range(len(images)):
                # Decode the actual and predicted formulas
                actual_formula = decode_formula(formulas[i, 1:], index_to_label)  # Skip <S> token
                predicted_formula = decode_formula(predicted_indices[i, 1:], index_to_label)  # Skip <S> token
                print(f'Image: {images[i]}')  # This will print the tensor, consider using visualization instead
                print(f'Actual Formula: {actual_formula}')
                print(f'Predicted Formula: {predicted_formula}')
                print('-' * 50)

    avg_loss = total_loss / len(dataloader)
    accuracy = correct_predictions / total_samples if total_samples > 0 else 0.0

    return avg_loss, accuracy
# Assuming you have your model, dataloader, criterion, and device set up
# Assuming 230k.json is loaded as label_to_index
val_dataset = LaTeXDataset(image_folder_path, mapping_path, label_to_index_path, transform, max_images=10)
val_dataloader = DataLoader(val_dataset, batch_size=32, shuffle=False, collate_fn=collate_fn)  # Set shuffle to False for validation
label_to_index = json.load(open(label_to_index_path, 'r'))
index_to_label = {v: k for k, v in label_to_index.items()}  # Reverse the mapping

val_loss, val_accuracy = validate_model(model, val_dataloader, criterion, device, index_to_label)



Image: tensor([[[1., 1., 1.,  ..., 1., 1., 1.],
         [1., 1., 1.,  ..., 1., 1., 1.],
         [1., 1., 1.,  ..., 1., 1., 1.],
         ...,
         [1., 1., 1.,  ..., 1., 1., 1.],
         [1., 1., 1.,  ..., 1., 1., 1.],
         [1., 1., 1.,  ..., 1., 1., 1.]],

        [[1., 1., 1.,  ..., 1., 1., 1.],
         [1., 1., 1.,  ..., 1., 1., 1.],
         [1., 1., 1.,  ..., 1., 1., 1.],
         ...,
         [1., 1., 1.,  ..., 1., 1., 1.],
         [1., 1., 1.,  ..., 1., 1., 1.],
         [1., 1., 1.,  ..., 1., 1., 1.]],

        [[1., 1., 1.,  ..., 1., 1., 1.],
         [1., 1., 1.,  ..., 1., 1., 1.],
         [1., 1., 1.,  ..., 1., 1., 1.],
         ...,
         [1., 1., 1.,  ..., 1., 1., 1.],
         [1., 1., 1.,  ..., 1., 1., 1.],
         [1., 1., 1.,  ..., 1., 1., 1.]]])
Actual Formula: _ { 1 2 } K _ { 1 } R _ { 2 1 } d K _ { 2 } = d K _ { 2 } R _ { 1 2 } K _ { 1 } R _ { 1 2 } ^ { - 1 } ,
Predicted Formula: { \mu } = = { { \mu } ^ { { 2 } } ^ { _ { 2 } } { _ { { 2 } } { { 2 