In [26]:
import os
from typing import List, Tuple
from PIL import Image
from einops import rearrange
from torchmetrics import Metric
import editdistance
import math

import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset
from torchvision.transforms import ToTensor
import torch.nn.functional as F

from torch.utils.data import DataLoader
import torch.optim as optim
from tqdm import tqdm
import wandb

In [27]:
class CROHMEVocab:
    PAD_IDX = 0
    SOS_IDX = 1
    EOS_IDX = 2

    def __init__(self, dict_path: str = "/kaggle/input/crohme/dictionary.txt") -> None:
        self.word2idx = {"<pad>": self.PAD_IDX, "<sos>": self.SOS_IDX, "<eos>": self.EOS_IDX}
        with open(dict_path, "r", encoding="utf-8") as f:
            for line in f:
                w = line.strip()
                self.word2idx[w] = len(self.word2idx)
        self.idx2word: Dict[int, str] = {v: k for k, v in self.word2idx.items()}

    def words2indices(self, words: List[str]) -> List[int]:
        return [self.word2idx[w] for w in words]

    def indices2words(self, id_list: List[int]) -> List[str]:
        return [self.idx2word[i] for i in id_list]

    def indices2label(self, id_list: List[int]) -> str:
        return " ".join(self.indices2words(id_list))

    def __len__(self):
        return len(self.word2idx)

In [28]:
class CROHMEDataset(Dataset):
    def __init__(self, root_dir: str):
        self.img_dir = os.path.join(root_dir, "img")
        caption_path = os.path.join(root_dir, "caption.txt")
        self.data = self._load_captions(caption_path)
        self.vocab = CROHMEVocab()
        self.to_tensor = ToTensor()

    def _load_captions(self, caption_path: str) -> List[Tuple[str, List[str]]]:
        with open(caption_path, "r", encoding="utf-8") as f:
            return [ (line.strip().split()[0], line.strip().split()[1:]) for line in f ]

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        img_name, formula = self.data[idx]
        img_path = os.path.join(self.img_dir, f"{img_name}.bmp")
        image = Image.open(img_path).convert("L")
        image_tensor = self.to_tensor(image)
        formula_indices = self.vocab.words2indices(formula)
        return img_name, image_tensor, formula_indices

def collate_fn(batch):
    fnames, images, formulas = zip(*batch)
    heights = [img.shape[1] for img in images]
    widths = [img.shape[2] for img in images]
    max_height, max_width = max(heights), max(widths)

    batch_size = len(images)
    imgs = torch.zeros(batch_size, 1, max_height, max_width)
    masks = torch.ones(batch_size, max_height, max_width, dtype=torch.bool)

    for i, img in enumerate(images):
        h, w = img.shape[1:]
        imgs[i, :, :h, :w] = img
        masks[i, :h, :w] = 0

    return fnames, imgs, masks, formulas

In [29]:
class BTTR(nn.Module):
    def __init__(self, d_model, growth_rate, num_layers, nhead, num_decoder_layers, dim_feedforward, dropout):
        super().__init__()
        self.encoder = Encoder(d_model, growth_rate, num_layers)
        self.decoder = Decoder(d_model, nhead, num_decoder_layers, dim_feedforward, dropout)

    def forward(self, img, img_mask, tgt):
        features, mask = self.encoder(img, img_mask)
        features = torch.cat([features, features], dim=0)
        mask = torch.cat([mask, mask], dim=0)
        return self.decoder(features, mask, tgt)

    def beam_search(self, img, img_mask, beam_size, max_len):
        features, mask = self.encoder(img, img_mask)
        return self.decoder.beam_search(features, mask, beam_size, max_len)

In [30]:
class WordPosEnc(nn.Module):
    def __init__(self, d_model, max_len=500, temperature=10000.0):
        super().__init__()
        pos = torch.arange(0, max_len).float()
        dim_t = torch.arange(0, d_model, 2).float()
        div_term = 1.0 / (temperature ** (dim_t / d_model))
        pe = torch.zeros(max_len, d_model)
        pe[:, 0::2] = torch.sin(pos[:, None] * div_term)
        pe[:, 1::2] = torch.cos(pos[:, None] * div_term)
        self.register_buffer("pe", pe)

    def forward(self, x):
        _, seq_len, _ = x.size()
        return x + self.pe[:seq_len, :].unsqueeze(0)

class ImgPosEnc(nn.Module):
    def __init__(self, d_model, temperature=10000.0, normalize=False):
        super().__init__()
        assert d_model % 4 == 0, "d_model must be divisible by 4 for 2D encoding"
        self.half_d_model = d_model // 2
        self.temperature = temperature
        self.normalize = normalize

    def forward(self, x, mask):
        not_mask = ~mask
        y_embed = not_mask.cumsum(1, dtype=torch.float32)
        x_embed = not_mask.cumsum(2, dtype=torch.float32)

        if self.normalize:
            eps = 1e-6
            y_embed = y_embed / (y_embed[:, -1:, :] + eps) * 2 * math.pi
            x_embed = x_embed / (x_embed[:, :, -1:] + eps) * 2 * math.pi

        dim_t = torch.arange(self.half_d_model // 2, dtype=torch.float32, device=x.device)
        inv_freq = 1.0 / (self.temperature ** (dim_t / (self.half_d_model // 2)))

        pos_x = torch.einsum('b h w, d -> b h w d', x_embed, inv_freq)
        pos_y = torch.einsum('b h w, d -> b h w d', y_embed, inv_freq)

        pos_x = torch.cat([pos_x.sin(), pos_x.cos()], dim=-1)
        pos_y = torch.cat([pos_y.sin(), pos_y.cos()], dim=-1)
        pos = torch.cat([pos_x, pos_y], dim=-1)  # final shape: [b, h, w, d_model]

        assert pos.shape[-1] == x.shape[-1], f"PosEnc shape mismatch: {pos.shape[-1]} != {x.shape[-1]}"
        return x + pos

In [31]:
class _Bottleneck(nn.Module):
    def __init__(self, in_ch, growth_rate, use_dropout):
        super().__init__()
        inter_ch = 4 * growth_rate
        self.conv1 = nn.Conv2d(in_ch, inter_ch, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(inter_ch)
        self.conv2 = nn.Conv2d(inter_ch, growth_rate, kernel_size=3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(growth_rate)
        self.dropout = nn.Dropout(0.2) if use_dropout else nn.Identity()

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.dropout(out)
        out = F.relu(self.bn2(self.conv2(out)))
        out = self.dropout(out)
        return torch.cat([x, out], dim=1)

class _Transition(nn.Module):
    def __init__(self, in_ch, out_ch, use_dropout):
        super().__init__()
        self.conv = nn.Conv2d(in_ch, out_ch, kernel_size=1, bias=False)
        self.bn = nn.BatchNorm2d(out_ch)
        self.dropout = nn.Dropout(0.2) if use_dropout else nn.Identity()

    def forward(self, x):
        out = F.relu(self.bn(self.conv(x)))
        out = self.dropout(out)
        out = F.avg_pool2d(out, 2, ceil_mode=True)
        return out
    
class _SingleLayer(nn.Module):
    def __init__(self, in_ch, growth_rate, use_dropout):
        super().__init__()
        self.conv = nn.Conv2d(in_ch, growth_rate, kernel_size=3, padding=1, bias=False)
        self.bn = nn.BatchNorm2d(growth_rate)
        self.dropout = nn.Dropout(0.2) if use_dropout else nn.Identity()

    def forward(self, x):
        out = F.relu(self.bn(self.conv(x)))
        out = self.dropout(out)
        return torch.cat([x, out], dim=1)


class DenseNet(nn.Module):
    def __init__(self, growth_rate, num_layers, reduction=0.5, bottleneck=True, use_dropout=True):
        super().__init__()
        n_ch = 2 * growth_rate
        self.conv1 = nn.Conv2d(1, n_ch, kernel_size=7, stride=2, padding=3, bias=False)
        self.norm1 = nn.BatchNorm2d(n_ch)

        self.dense1 = self._make_dense(n_ch, growth_rate, num_layers, bottleneck, use_dropout)
        n_ch += num_layers * growth_rate
        out_ch = int(n_ch * reduction)
        self.trans1 = _Transition(n_ch, out_ch, use_dropout)

        n_ch = out_ch
        self.dense2 = self._make_dense(n_ch, growth_rate, num_layers, bottleneck, use_dropout)
        n_ch += num_layers * growth_rate
        out_ch = int(n_ch * reduction)
        self.trans2 = _Transition(n_ch, out_ch, use_dropout)

        n_ch = out_ch
        self.dense3 = self._make_dense(n_ch, growth_rate, num_layers, bottleneck, use_dropout)
        n_ch += num_layers * growth_rate

        self.post_norm = nn.BatchNorm2d(n_ch)
        self.out_channels = n_ch

    def _make_dense(self, in_ch, growth_rate, num_layers, bottleneck, use_dropout):
        layers = []
        for _ in range(num_layers):
            layers.append(_Bottleneck(in_ch, growth_rate, use_dropout) if bottleneck else _SingleLayer(in_ch, growth_rate, use_dropout))
            in_ch += growth_rate
        return nn.Sequential(*layers)

    def forward(self, x, mask):
        out = self.conv1(x)
        out = self.norm1(out)
        out_mask = mask[:, ::2, ::2]

        out = F.relu(out)
        out = F.max_pool2d(out, 2, ceil_mode=True)
        out_mask = out_mask[:, ::2, ::2]

        out = self.trans1(self.dense1(out))
        out_mask = out_mask[:, ::2, ::2]
        out = self.trans2(self.dense2(out))
        out_mask = out_mask[:, ::2, ::2]
        out = self.dense3(out)
        out = self.post_norm(out)
        return out, out_mask

class Encoder(nn.Module):
    def __init__(self, d_model, growth_rate, num_layers):
        super().__init__()
        self.densenet = DenseNet(growth_rate, num_layers)
        self.feature_proj = nn.Conv2d(self.densenet.out_channels, d_model, kernel_size=1)
        self.norm = nn.LayerNorm(d_model)
        self.pos_enc = ImgPosEnc(d_model, normalize=True)

    def forward(self, img, mask):
        features, mask = self.densenet(img, mask)
        features = self.feature_proj(features)
        features = rearrange(features, 'b d h w -> b h w d')
        features = self.norm(features)
        features = self.pos_enc(features, mask)
        features = rearrange(features, 'b h w d -> b (h w) d')
        mask = rearrange(mask, 'b h w -> b (h w)')
        return features, mask

In [34]:
class Hypothesis:
    def __init__(self, seq_tensor: torch.Tensor, score: float, direction: str):
        assert direction in {"l2r", "r2l"}
        raw_seq = seq_tensor.tolist()
        self.seq = raw_seq[::-1] if direction == "r2l" else raw_seq
        self.score = score

    def __len__(self):
        return len(self.seq) if self.seq else 1

    def __str__(self):
        return f"seq: {self.seq}, score: {self.score}"

In [32]:
class Decoder(nn.Module):
    def __init__(self, d_model, nhead, num_decoder_layers, dim_feedforward, dropout):
        super().__init__()
        self.word_embed = nn.Sequential(
            nn.Embedding(vocab_size, d_model),
            nn.LayerNorm(d_model)
        )
        self.pos_enc = WordPosEnc(d_model)
        decoder_layer = nn.TransformerDecoderLayer(d_model, nhead, dim_feedforward, dropout)
        self.transformer = nn.TransformerDecoder(decoder_layer, num_decoder_layers)
        self.proj = nn.Linear(d_model, vocab_size)

    def forward(self, src, src_mask, tgt):
        b, l = tgt.size()
        tgt_mask = torch.triu(torch.ones(l, l, device=tgt.device), diagonal=1).bool()
        tgt_pad_mask = tgt == vocab.PAD_IDX

        tgt_emb = self.word_embed(tgt)
        tgt_emb = self.pos_enc(tgt_emb)

        src = rearrange(src, 'b s d -> s b d')
        tgt_emb = rearrange(tgt_emb, 'b l d -> l b d')

        out = self.transformer(tgt_emb, src, tgt_mask=tgt_mask, tgt_key_padding_mask=tgt_pad_mask, memory_key_padding_mask=src_mask)
        out = rearrange(out, 'l b d -> b l d')
        out = self.proj(out)
        return out
    def beam_search(self, src, mask, beam_size, max_len):
        assert src.size(0) == 1, "beam_search expects batch size 1"

        start_token = vocab.SOS_IDX
        stop_token = vocab.EOS_IDX

        hypotheses = torch.full((1, max_len + 1), vocab.PAD_IDX, dtype=torch.long, device=src.device)
        hypotheses[:, 0] = start_token
        hyp_scores = torch.zeros(1, device=src.device)
        completed_hyps = []

        t = 0
        while len(completed_hyps) < beam_size and t < max_len:
            hyp_num = hypotheses.size(0)
            exp_src = src.expand(hyp_num, -1, -1)
            exp_mask = mask.expand(hyp_num, -1)

            out = self.forward(exp_src, exp_mask, hypotheses)  # logits: [b, l, vocab_size]
            log_probs = torch.log_softmax(out[:, t, :], dim=-1)  # [b, vocab_size]

            new_hyp_scores = hyp_scores.unsqueeze(1) + log_probs  # [b, vocab_size]
            flat_scores = new_hyp_scores.view(-1)
            top_scores, top_indices = flat_scores.topk(beam_size - len(completed_hyps))

            prev_hyp_ids = top_indices // log_probs.size(1)
            next_token_ids = top_indices % log_probs.size(1)

            new_hypotheses = []
            new_hyp_scores = []

            for prev_hyp_id, next_token_id, score in zip(prev_hyp_ids, next_token_ids, top_scores):
                next_token_id = next_token_id.item()
                score = score.item()

                new_hyp = hypotheses[prev_hyp_id].clone()
                new_hyp[t + 1] = next_token_id

                if next_token_id == stop_token:
                    completed_hyps.append(Hypothesis(new_hyp[1:t+1], score, direction="l2r"))
                else:
                    new_hypotheses.append(new_hyp)
                    new_hyp_scores.append(score)

            if len(new_hypotheses) == 0:
                break

            hypotheses = torch.stack(new_hypotheses, dim=0)
            hyp_scores = torch.tensor(new_hyp_scores, device=src.device)

            t += 1

        if len(completed_hyps) == 0:
            completed_hyps.append(Hypothesis(hypotheses[0][1:], hyp_scores[0].item(), direction="l2r"))

        return completed_hyps



In [33]:
def beam_search_batch(model, imgs, masks, beam_size, max_len, alpha, vocab):
    batch_size = imgs.size(0)
    results = []
    for i in range(batch_size):
        img, mask = imgs[i].unsqueeze(0), masks[i].unsqueeze(0)
        hyps = model.beam_search(img, mask, beam_size, max_len)
        best = max(hyps, key=lambda h: h.score / (len(h) ** alpha))
        results.append(best.seq)
    return results

def ensemble_beam_search_batch(models, imgs, masks, beam_size, max_len, alpha, vocab):
    batch_size = imgs.size(0)
    results = []
    for i in range(batch_size):
        img, mask = imgs[i].unsqueeze(0), masks[i].unsqueeze(0)
        all_hyps = []
        for model in models:
            hyps = model.beam_search(img, mask, beam_size, max_len)
            all_hyps.extend(hyps)
        best = max(all_hyps, key=lambda h: h.score / (len(h) ** alpha))
        results.append(best.seq)
    return results


In [35]:
def ce_loss(output_hat: torch.Tensor, output: torch.Tensor) -> torch.Tensor:
    flat_hat = rearrange(output_hat, "b l e -> (b l) e")
    flat = rearrange(output, "b l -> (b l)")
    return F.cross_entropy(flat_hat, flat, ignore_index=vocab.PAD_IDX)

def to_tgt_output(tokens, direction, device):
    assert direction in {"l2r", "r2l"}

    tokens = [torch.tensor(t, dtype=torch.long) for t in tokens]
    if direction == "l2r":
        start_w, stop_w = vocab.SOS_IDX, vocab.EOS_IDX
    else:
        tokens = [torch.flip(t, dims=[0]) for t in tokens]
        start_w, stop_w = vocab.EOS_IDX, vocab.SOS_IDX

    batch_size = len(tokens)
    max_len = max(len(t) for t in tokens)
    tgt = torch.full((batch_size, max_len + 1), vocab.PAD_IDX, dtype=torch.long, device=device)
    out = torch.full((batch_size, max_len + 1), vocab.PAD_IDX, dtype=torch.long, device=device)

    for i, t in enumerate(tokens):
        tgt[i, 0] = start_w
        tgt[i, 1:1+len(t)] = t
        out[i, :len(t)] = t
        out[i, len(t)] = stop_w

    return tgt, out

def to_bi_tgt_out(tokens, device):
    l2r_tgt, l2r_out = to_tgt_output(tokens, "l2r", device)
    r2l_tgt, r2l_out = to_tgt_output(tokens, "r2l", device)
    tgt = torch.cat((l2r_tgt, r2l_tgt), dim=0)
    out = torch.cat((l2r_out, r2l_out), dim=0)
    return tgt, out

In [36]:
class ExpRateRecorder(Metric):
    def __init__(self, dist_sync_on_step=False):
        super().__init__(dist_sync_on_step=dist_sync_on_step)
        self.add_state("total", default=torch.tensor(0.0), dist_reduce_fx="sum")
        self.add_state("correct", default=torch.tensor(0.0), dist_reduce_fx="sum")

    def update(self, indices_hat, indices):
        dist = editdistance.eval(indices_hat, indices)
        if dist == 0:
            self.correct += 1
        self.total += 1

    def compute(self):
        return (self.correct / self.total).item() if self.total > 0 else 0.0

In [37]:
vocab = CROHMEVocab()
vocab_size = len(vocab)
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
wandb.login(
    key = "9bce7f20a794219664f78217ccc84b283dbb0cea",
)
wandb.init(
    project = "BTTR_offline_hmer"
)



In [38]:
def main():

    # Configs
    best_val_loss = float('inf')
    seed = 1337
    train_root = "/kaggle/input/crohme/resources/resources/CROHME/train"
    val_root = "/kaggle/input/crohme/resources/resources/CROHME/val"
    batch_size = 32
    num_epochs = 100
    learning_rate = 0.01
    patience = 10
    checkpoint_dir = "/kaggle/working/checkpoints"
    os.makedirs(checkpoint_dir, exist_ok=True)

    torch.manual_seed(seed)
    np.random.seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)

    # Data Loaders
    train_dataset = CROHMEDataset(train_root)
    val_dataset = CROHMEDataset(val_root)

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)
    val_loader = DataLoader(val_dataset, batch_size=1, shuffle=False, collate_fn=collate_fn)

    # Model
    model = BTTR(
        d_model=256,
        growth_rate=16,
        num_layers=3,
        nhead=8,
        num_decoder_layers=3,
        dim_feedforward=1024,
        dropout=0.1
    ).to(DEVICE)

    optimizer = optim.Adadelta(model.parameters(), lr=learning_rate, eps=1e-6, weight_decay=1e-4)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="max", factor=0.1, patience=patience)

    exp_rate_recorder = ExpRateRecorder()

    # Training Loop
    for epoch in range(1, num_epochs + 1):
        model.train()
        train_loss = 0

        for fnames, imgs, masks, formulas in tqdm(train_loader, desc=f"Epoch {epoch}"):
            imgs, masks = imgs.to(DEVICE), masks.to(DEVICE)
            tgt, out = to_bi_tgt_out(formulas, DEVICE)

            optimizer.zero_grad()
            out_hat = model(imgs, masks, tgt)
            loss = ce_loss(out_hat, out)
            loss.backward()
            optimizer.step()

            train_loss += loss.item()

        avg_train_loss = train_loss / len(train_loader)
        print(f"Epoch {epoch}: Train Loss = {avg_train_loss:.4f}")

        # Validation
        model.eval()
        val_loss = 0
        exp_rate_recorder.reset()

        with torch.no_grad():
            for fnames, imgs, masks, formulas in val_loader:
                imgs, masks = imgs.to(DEVICE), masks.to(DEVICE)
                tgt, out = to_bi_tgt_out(formulas, DEVICE)

                out_hat = model(imgs, masks, tgt)
                loss = ce_loss(out_hat, out)
                val_loss += loss.item()

                # Optional: Add beam search for validation exprate if needed
                preds = beam_search_batch(model, imgs, masks, beam_size=10, max_len=200, alpha=1.0, vocab=vocab)
                gt_indices = vocab.words2indices(formulas[0])
                exp_rate_recorder.update(preds[0], gt_indices)

        avg_val_loss = val_loss / len(val_loader)
        val_exprate = exp_rate_recorder.compute()
        print(f"Epoch {epoch}: Val Loss = {avg_val_loss:.4f} | Val Exprate = {val_exprate:.4f}")

        torch.cuda.empty_cache()

        # Scheduler step based on ExpRate (set to dummy for now)
        scheduler.step(val_exprate)

        # wandb.log({
        #     "epoch": epoch,
        #     "train_loss": avg_train_loss,
        #     "val_loss": avg_val_loss
        # })
        wandb.log({"Epoch": epoch, "Train loss": avg_train_loss, "Valid loss": avg_val_loss, "Val exprate": val_exprate})

        # Save checkpoint
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            checkpoint = {
                'epoch': epoch,
                'model': model.state_dict(),
                'optimizer': optimizer.state_dict(),
                'val_loss': val_loss,
                'vocab': vocab
            }
            torch.save(checkpoint, os.path.join(checkpoint_dir, 'bttr_best.pth'))
            print('model saved!')
        torch.cuda.empty_cache()
        
# if __name__ == "__main__":
#     main()
main()

Epoch 1: 100%|██████████| 305/305 [01:17<00:00,  3.93it/s]


Epoch 1: Train Loss = 3.6521


KeyError: 53

In [None]:
# DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# def main():
#     # Config
#     checkpoint_path = "checkpoints/epoch_50.pth"
#     data_root = "resources/CROHME/test"
#     batch_size = 4  # batch-wise testing for speed

#     # Data
#     vocab = CROHMEVocab()
#     test_dataset = CROHMEDataset(data_root)
#     test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, collate_fn=collate_fn)

#     # Model
#     model = BTTR(
#         d_model=256,
#         growth_rate=16,
#         num_layers=3,
#         nhead=8,
#         num_decoder_layers=3,
#         dim_feedforward=1024,
#         dropout=0.1
#     ).to(DEVICE)

#     state_dict = torch.load(checkpoint_path, map_location=DEVICE)
#     model.load_state_dict(state_dict)
#     model.eval()

#     # Metrics
#     recorder = ExpRateRecorder()

#     # Inference Loop
#     os.makedirs("results", exist_ok=True)
#     with torch.no_grad():
#         for fnames, imgs, masks, formulas in tqdm(test_loader, desc="Testing"):
#             imgs, masks = imgs.to(DEVICE), masks.to(DEVICE)

#             hypotheses_batch = beam_search_batch(model, imgs, masks, beam_size=10, max_len=200, alpha=1.0, vocab=vocab)

#             for fname, hyp, gt_formula in zip(fnames, hypotheses_batch, formulas):
#                 pred_latex = vocab.indices2label(hyp)
#                 gt_indices = vocab.words2indices(gt_formula)

#                 recorder.update(hyp, gt_indices)

#                 # Save result file
#                 with open(f"results/{fname}.txt", "w", encoding="utf-8") as f:
#                     f.write(f"%{fname}\n${pred_latex}$")

#     # Final Metrics
#     exprate = recorder.compute()
#     print(f"Expression Recognition Rate: {exprate:.4f}")

# # if __name__ == "__main__":
# #     main()
# main()