In [1]:
import json
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
from transformers import AutoTokenizer
from torch.nn.utils.rnn import pad_sequence

  from .autonotebook import tqdm as notebook_tqdm


## args


In [2]:
train_size = 10
valid_size = 5

## Data import

In [3]:
with open("image_formula_mapping.json", "r") as f:
    image_formula_mapping = json.load(f)

with open("LaTex_data/230k.json", "r") as f:
    word_to_index = json.load(f)
train_key = list(image_formula_mapping.keys())[:train_size]
valid_key = list(image_formula_mapping.keys())[train_size:train_size+valid_size]
train_formula_map = {k: image_formula_mapping[k] for k in train_key}
valid_formula_map = {k: image_formula_mapping[k] for k in valid_key}
index_to_word = {v: k for k, v in word_to_index.items()}

## Dataset


In [4]:
class LaTexDataset(Dataset):
    def __init__(self, image_formula_mapping, word_to_index, transform=None):
        self.image_formula_mapping = image_formula_mapping
        self.word_to_index = word_to_index
        self.transform = transform
        self.images = list(image_formula_mapping.keys())
        self.formulas = list(image_formula_mapping.values())

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        img_path = self.images[idx]
        formula = self.formulas[idx]
        
        # Read the image
        image = Image.open(f'LaTex_data/generated_png_images/{img_path}').convert("RGB")
        if self.transform:
            image = self.transform(image)
        
        # Convert formula to index sequence and ensure it's a LongTensor
        tokenized_formula = [int(self.word_to_index.get(token, self.word_to_index["<P>"])) 
                            for token in formula.split()]
        
        return image, torch.tensor(tokenized_formula).to(torch.long)
    
    

## Pad sequence


In [5]:
def collate_fn(batch):
    images, formulas = zip(*batch)
    
    # Stack images (already tensors of the same size after transform)
    images = torch.stack(images)
    
    # Pad formulas to the same length within the batch
    formulas = pad_sequence(formulas, batch_first=True, padding_value=0)
    
    # Ensure that the padded formulas are LongTensors
    formulas = formulas.long()
    
    return images, formulas


## Transform


In [6]:
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor()
])

In [7]:
train_dataset = LaTexDataset(train_formula_map, word_to_index, transform=transform)
data_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, collate_fn=collate_fn)
valid_dataset = LaTexDataset(valid_formula_map, word_to_index, transform=transform)
val_loader = DataLoader(valid_dataset, batch_size=32, shuffle=True, collate_fn=collate_fn)

## Model building

In [8]:
class LaTexTransformerModel(nn.Module):
    def __init__(self, vocab_size, d_model=512, nhead=8, num_encoder_layers=6, num_decoder_layers=6):
        super(LaTexTransformerModel, self).__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.positional_encoding = nn.Parameter(torch.zeros(1, 512, d_model))
        
        self.transformer = nn.Transformer(d_model=d_model, nhead=nhead, 
                                          num_encoder_layers=512, 
                                          num_decoder_layers=512)
        self.fc_out = nn.Linear(d_model, vocab_size)
    
    def forward(self, src, tgt):
        src = self.embedding(src) + self.positional_encoding[:, :src.size(1), :]
        tgt = self.embedding(tgt) + self.positional_encoding[:, :tgt.size(1), :]

        # Transformer expects (sequence length, batch size, embedding size)
        src = src.permute(1, 0, 2)  # (batch_size, seq_len, d_model) -> (seq_len, batch_size, d_model)
        tgt = tgt.permute(1, 0, 2)  # (batch_size, seq_len, d_model) -> (seq_len, batch_size, d_model)

        # Pass through the transformer
        output = self.transformer(src, tgt)
        
        # Output shape will be (seq_len, batch_size, d_model)
        output = self.fc_out(output)  # Fin

## Traning model


In [9]:
vocab_size = len(word_to_index)
model = LaTexTransformerModel(vocab_size)

optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
criterion = nn.CrossEntropyLoss()
for epoch in range(10):
    for images, formulas in data_loader:
        optimizer.zero_grad()
        
        # 將圖片向量化處理
        src = images.view(images.size(0), -1, 32).permute(1, 0, 2).long()  # Reshape, permute, and convert to LongTensor
        tgt = formulas[:, :-1] # 去除最後一個 token 作為輸入
        tgt_y = formulas[:, 1:]  # 去除第一個 token 作為標籤
        
        output = model(src, tgt)
        
        loss = criterion(output.view(-1, vocab_size), tgt_y.view(-1))
        loss.backward()
        optimizer.step()
        
        print(f"Epoch [{epoch + 1}/10], Loss: {loss.item():.4f}")



RuntimeError: The size of tensor a (32) must match the size of tensor b (10) at non-singleton dimension 2

## Validate


In [None]:
def validate_model(model, val_loader, criterion):
    model.eval()  # 啟用評估模式
    total_loss = 0.0
    with torch.no_grad():  # 禁用梯度計算
        for images, formulas in val_loader:
            # 圖像向量化
            src = images  # 假設這裡是經過 CNN 編碼的影像特徵向量
            tgt = formulas[:, :-1]  # 去除最後一個 token 作為輸入
            tgt_y = formulas[:, 1:]  # 去除第一個 token 作為標籤
            
            # 前向傳播
            output = model(src, tgt)
            
            # 計算損失
            loss = criterion(output.view(-1, vocab_size), tgt_y.view(-1))
            total_loss += loss.item()
    
    avg_loss = total_loss / len(val_loader)
    print(f"Validation Loss: {avg_loss:.4f}")
    model.train()  # 恢復訓練模式
    return avg_loss


In [None]:
validate_model(model, val_loader, criterion)