In [7]:
from data import *
from pretrained_models import *
from torch.utils.data import DataLoader
import os
from mediapipe.tasks.python.vision.hand_landmarker import HandLandmarkerResult

### Sanity Check

In [10]:
# MediaPipeCFG = MediaPipeCfg("pretrained_model/hand_landmarker.task")
# options = MediaPipeCFG.create_options()
# MP_model = MediaPipeCFG.HandLandmarker.create_from_options(options)

In [11]:
# MMPoseCFG = MMPoseCfg(checkpoint_path='pretrained_model/checkpoint/rtmpose-s_simcc-body7_pt-body7_420e-256x192-acd4a1ef_20230504.pth',
#                       config_path='pretrained_model/mmpose_config/rtmpose_m_8xb256-420e_coco-256x192.py')
# body_model = MMPoseCFG.create_model()

In [12]:
# PreTrainDataset = ASLData(video_dir="data/raw_videos",
#                            MP_model=MP_model,
#                            body_cfg=MMPoseCFG,
#                            body_model=body_model,
#                            labels_path="data/how2sign_realigned_train.csv")
# def collate_as_is(batch):
#     return batch
# PreTrainDataLoader = DataLoader(PreTrainDataset, batch_size=4, shuffle=False, collate_fn=collate_as_is)

In [13]:
# OneSample = next(iter(PreTrainDataLoader))

Four batches ran in 2 min 45 s.

### Write converted JSON to local dir for training

In [24]:
# Directory that receives the exported batch JSON files; created if missing
# (exist_ok=True makes re-running this cell safe).
output_dir = "test_json_8_batches"
os.makedirs(output_dir, exist_ok=True)

In [28]:
def make_json_safe(x):
    """
    Recursively convert tensors, numpy arrays/scalars, and containers holding
    them into plain Python objects that ``json.dump`` can serialize.

    Args:
        x: A ``torch.Tensor``, ``np.ndarray``, numpy scalar, ``dict``,
            ``list``, ``tuple``, or any other value.

    Returns:
        - tensors and arrays -> (nested) lists, or a Python scalar when
          0-dimensional;
        - numpy scalars -> the equivalent Python scalar;
        - dicts / lists / tuples -> same container type with converted values
          (dict keys are passed through unchanged);
        - anything else -> returned as-is.
    """
    if isinstance(x, torch.Tensor):
        # detach() so tensors still attached to an autograd graph convert
        # cleanly; .cpu() so device tensors are copied to host first.
        return x.detach().cpu().tolist()
    if isinstance(x, np.ndarray):
        return x.tolist()
    if isinstance(x, np.generic):
        # numpy scalars (np.float32, np.int64, ...) are rejected by json.dump.
        return x.item()
    if isinstance(x, dict):
        return {k: make_json_safe(v) for k, v in x.items()}
    if isinstance(x, list):
        return [make_json_safe(v) for v in x]
    if isinstance(x, tuple):
        # json serializes tuples as arrays, so keeping the tuple type is fine.
        return tuple(make_json_safe(v) for v in x)
    return x  # scalar, string, or already json-safe

import itertools
import json

# How many batches to export. Each batch is expensive to produce, so the loop
# below must never pull more than this many from the loader.
NUM_BATCHES_TO_SAVE = 8

print("Saving dataloader batches to:", output_dir)

# NOTE: PreTrainDataLoader is only defined in the (currently commented-out)
# sanity-check cells above; those must be enabled and run before this cell.
#
# islice stops after NUM_BATCHES_TO_SAVE items without requesting the next
# one. An `enumerate(...)` loop with `if batch_idx == N: break` would first
# fetch batch N from the loader and only then discard it.
for batch_idx, batch in enumerate(
    itertools.islice(PreTrainDataLoader, NUM_BATCHES_TO_SAVE)
):
    json_safe_batch = make_json_safe(batch)

    output_path = os.path.join(output_dir, f"batch_{batch_idx:05d}.json")
    with open(output_path, "w", encoding="utf-8") as f:
        json.dump(json_safe_batch, f, indent=2)

    print(f"[✓] Saved batch {batch_idx} → {output_path}")


print("Finished writing all batches!")

Saving dataloader batches to: test_json_8_batches
[✓] Saved batch 0 → test_json_8_batches\batch_00000.json
[✓] Saved batch 1 → test_json_8_batches\batch_00001.json
[✓] Saved batch 2 → test_json_8_batches\batch_00002.json
[✓] Saved batch 3 → test_json_8_batches\batch_00003.json
[✓] Saved batch 4 → test_json_8_batches\batch_00004.json
[✓] Saved batch 5 → test_json_8_batches\batch_00005.json
[✓] Saved batch 6 → test_json_8_batches\batch_00006.json
[✓] Saved batch 7 → test_json_8_batches\batch_00007.json
Finished writing all batches!


8 batches ran in 13 min 52 s.

### Load the JSON data into a trainable dataset using the custom DataLoader and text tokenizer

In [49]:
from functools import partial

from torch.utils.data import DataLoader

# Directory containing the exported batch JSON files. Kept in one place so the
# tokenizer and the dataset are guaranteed to read the same data.
JSON_DIR = "test_json_8_batches"

# Build tokenizer & vocab from the saved JSONs.
tokenizer_fn, vocab, my_pad_id = create_tokenizer(
    json_dir=JSON_DIR,
    min_freq=1,  # or >1 to prune rare tokens
)

print("PAD ID:", my_pad_id)
print("Vocab size:", len(vocab))

# Create dataset.
# NOTE(review): max_frames / frame_subsample presumably cap and stride the
# frame sequence -- confirm against ASLPoseJSONDataset.
dataset = ASLPoseJSONDataset(
    json_dir=JSON_DIR,
    tokenizer=tokenizer_fn,
    max_frames=300,
    frame_subsample=2,
)

# Create dataloader.
# functools.partial is used instead of a lambda so the collate function stays
# picklable; a lambda breaks as soon as num_workers > 0 on platforms that
# spawn worker processes (e.g. Windows).
loader = DataLoader(
    dataset,
    batch_size=8,
    shuffle=True,
    num_workers=0,
    collate_fn=partial(asl_collate_fn, pad_id=my_pad_id),
)


PAD ID: 0
Vocab size: 239


In [50]:
# Pull a single collated batch and display it as a quick sanity check of the loader.
next(iter(loader))

{'pose': tensor([[[1435.8334,  310.0000,  398.3333,  ...,    0.0000,    0.0000,
              0.0000],
          [1435.8334,  310.0000,  398.3333,  ...,    0.0000,    0.0000,
              0.0000],
          [1435.8334,  310.0000,  398.3333,  ...,    0.0000,    0.0000,
              0.0000],
          ...,
          [   0.0000,    0.0000,    0.0000,  ...,    0.0000,    0.0000,
              0.0000],
          [   0.0000,    0.0000,    0.0000,  ...,    0.0000,    0.0000,
              0.0000],
          [   0.0000,    0.0000,    0.0000,  ...,    0.0000,    0.0000,
              0.0000]],
 
         [[-160.0000,  310.0000,  398.3333,  ...,    0.0000,    0.0000,
              0.0000],
          [-160.0000,  310.0000,  398.3333,  ...,    0.0000,    0.0000,
              0.0000],
          [-160.0000,  310.0000,  398.3333,  ...,    0.0000,    0.0000,
              0.0000],
          ...,
          [-160.0000,  310.0000,  398.3333,  ...,    0.0000,    0.0000,
              0.0000],
         

### Testing model

In [51]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader


class PoseToTextModel(nn.Module):
    """
    Sequence-to-sequence model: a bidirectional GRU encodes a pose sequence and
    a unidirectional GRU decodes tokens with teacher forcing.

    The decoder's initial hidden state is the concatenation of the encoder's
    last-layer forward and backward final hidden states.
    """

    def __init__(
        self,
        pose_dim: int,          # D: feature dimension of pose per frame
        enc_hidden: int,        # encoder GRU hidden size per direction
        vocab_size: int,        # |V|
        emb_dim: int,           # token embedding dim
        pad_id: int,
        num_enc_layers: int = 1,
        num_dec_layers: int = 1,
    ):
        """
        Args:
            pose_dim: Size of the per-frame pose feature vector.
            enc_hidden: Encoder GRU hidden size per direction.
            vocab_size: Number of tokens in the vocabulary.
            emb_dim: Token embedding dimension.
            pad_id: Token id used for padding (its embedding is kept at zero).
            num_enc_layers: Number of stacked encoder GRU layers.
            num_dec_layers: Number of stacked decoder GRU layers.
        """
        super().__init__()
        self.pad_id = pad_id
        self.vocab_size = vocab_size

        # Encoder: Bi-GRU over pose sequence
        self.encoder = nn.GRU(
            input_size=pose_dim,
            hidden_size=enc_hidden,
            num_layers=num_enc_layers,
            batch_first=True,
            bidirectional=True,
        )

        # Decoder embedding
        self.emb = nn.Embedding(
            num_embeddings=vocab_size,
            embedding_dim=emb_dim,
            padding_idx=pad_id,
        )

        # Decoder GRU: hidden size = 2 * enc_hidden (concat directions)
        self.decoder = nn.GRU(
            input_size=emb_dim,
            hidden_size=2 * enc_hidden,
            num_layers=num_dec_layers,
            batch_first=True,
        )

        # Final projection to vocab
        self.out = nn.Linear(2 * enc_hidden, vocab_size)

    def encode(self, pose, pose_len):
        """
        Run the encoder over a padded batch of pose sequences.

        Args:
            pose: [B, T, D] padded pose features.
            pose_len: [B] true (unpadded) length of each sequence.

        Returns:
            Encoder final hidden state, [num_enc_layers * 2, B, enc_hidden].
        """
        # Pack so the GRU skips padded frames; lengths must live on the CPU.
        packed = nn.utils.rnn.pack_padded_sequence(
            pose,
            lengths=pose_len.cpu(),
            batch_first=True,
            enforce_sorted=False,
        )
        _, h_n = self.encoder(packed)
        return h_n

    def forward(self, pose, pose_len, labels):
        """
        Teacher-forced forward pass.

        Args:
            pose: [B, T, D] padded pose features.
            pose_len: [B] true length of each pose sequence.
            labels: [B, L] token ids (with <bos> ... <eos> and <pad>).

        The decoder consumes ``labels[:, :-1]``; the matching targets are
        ``labels[:, 1:]``.

        Returns:
            logits: [B, L-1, vocab_size]
        """
        # Unpacking also enforces the expected ranks (3-D pose, 2-D labels).
        B, _, _ = pose.shape
        B2, _ = labels.shape
        assert B == B2

        # ---- Encode ----
        h_n = self.encode(pose, pose_len)  # [num_enc_layers*2, B, enc_hidden]
        assert h_n.shape[1] == B

        # The last two entries of h_n are the final layer's forward and
        # backward states; concatenate them into one [1, B, 2H] state.
        h0_dec = torch.cat([h_n[-2], h_n[-1]], dim=-1).unsqueeze(0)

        # nn.GRU requires an initial state of shape [num_layers, B, 2H].
        # Seed every decoder layer with the same encoder summary; without this
        # any num_dec_layers > 1 fails with a hidden-size mismatch.
        dec_layers = self.decoder.num_layers
        if dec_layers > 1:
            h0_dec = h0_dec.repeat(dec_layers, 1, 1)

        # ---- Decode with teacher forcing ----
        dec_inp = labels[:, :-1]            # [B, L-1]
        emb = self.emb(dec_inp)             # [B, L-1, emb_dim]

        dec_out, _ = self.decoder(emb, h0_dec)  # [B, L-1, 2H]
        logits = self.out(dec_out)              # [B, L-1, vocab_size]

        return logits


In [57]:
import math
from collections import Counter

def build_id_to_token(vocab: dict) -> dict:
    """Invert a ``{token: id}`` vocabulary into an ``{id: token}`` lookup."""
    return dict(zip(vocab.values(), vocab.keys()))


def tokens_to_text(
    ids,
    id_to_token,
    pad_id: int,
    bos_token: str = "<bos>",
    eos_token: str = "<eos>",
):
    """
    Render a sequence of token ids as a space-separated string.

    Padding ids and the <bos> token are dropped, ids missing from
    ``id_to_token`` are rendered as "<unk>", and decoding stops at the first
    <eos> token (which is not included in the output).
    """
    def _visible_tokens():
        for raw_id in ids:
            token_id = int(raw_id)
            if token_id == pad_id:
                continue
            token = id_to_token.get(token_id, "<unk>")
            if token == bos_token:
                continue
            if token == eos_token:
                return
            yield token

    return " ".join(_visible_tokens())


def bleu1(pred_tokens, ref_tokens):
    """
    Simple BLEU-1 (unigram BLEU) with brevity penalty.

    Args:
        pred_tokens: Predicted tokens (list of strings).
        ref_tokens: Reference tokens (list of strings).

    Returns:
        ``bp * precision`` where ``precision`` is the clipped unigram overlap
        divided by the prediction length, and ``bp`` is 1.0 when the
        prediction is longer than the reference, otherwise
        ``exp(1 - ref_len / pred_len)``. Returns 0.0 for an empty prediction.
    """
    pred_len = len(pred_tokens)
    if pred_len == 0:
        # Single guard: also protects both divisions below.
        return 0.0
    ref_len = len(ref_tokens)

    # Counter intersection keeps min(count_pred, count_ref) per token, i.e.
    # the clipped unigram matches.
    overlap = sum((Counter(pred_tokens) & Counter(ref_tokens)).values())
    precision = overlap / pred_len

    # Brevity penalty
    if pred_len > ref_len:
        bp = 1.0
    else:
        bp = math.exp(1.0 - ref_len / pred_len)

    return bp * precision


def rouge1_f1(pred_tokens, ref_tokens):
    """
    Unigram ROUGE F1 between a predicted and a reference token list.

    Returns 0.0 when either list is empty or when they share no tokens;
    otherwise the harmonic mean of unigram precision and recall.
    """
    if not (pred_tokens and ref_tokens):
        return 0.0

    # Clipped unigram matches: min(count in pred, count in ref) per token.
    shared = Counter(pred_tokens) & Counter(ref_tokens)
    overlap = sum(shared.values())

    precision = overlap / len(pred_tokens)
    recall = overlap / len(ref_tokens)

    denominator = precision + recall
    if denominator == 0:
        return 0.0
    return 2 * precision * recall / denominator


In [58]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Read the per-frame feature size from a freshly fetched batch rather than
# from whatever `batch` variable an earlier cell happened to leave behind
# (that hidden state does not survive Restart & Run All).
sample_batch = next(iter(loader))
pose_dim = sample_batch["pose"].shape[-1]
vocab_size = len(vocab)

model = PoseToTextModel(
    pose_dim=pose_dim,
    enc_hidden=256,
    vocab_size=vocab_size,
    emb_dim=256,
    pad_id=my_pad_id,
).to(device)

# Padding positions contribute nothing to the loss.
loss_fn = nn.CrossEntropyLoss(ignore_index=my_pad_id)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

id_to_token = build_id_to_token(vocab)

# Ids of the special tokens, or None if the vocab was built without them.
bos_id = vocab.get("<bos>", None)
eos_id = vocab.get("<eos>", None)


In [59]:
num_epochs = 20

# NOTE: BLEU/ROUGE below are computed from the teacher-forced training logits
# (argmax at each position given the gold prefix), on the training data. They
# track fitting progress; they are not free-running decoding scores.
for epoch in range(1, num_epochs + 1):
    model.train()
    total_loss = 0.0
    total_tokens = 0

    total_bleu = 0.0
    total_rouge = 0.0
    total_sentences = 0

    for batch in loader:
        pose = batch["pose"].to(device)          # [B, T, D]
        pose_len = batch["pose_len"].to(device)  # [B]
        labels = batch["labels"].to(device)      # [B, L]

        # ----- forward -----
        logits = model(pose, pose_len, labels)   # [B, L-1, V]

        # Targets are labels shifted left
        target = labels[:, 1:]                   # [B, L-1]

        B, Lm1, V = logits.shape
        loss = loss_fn(
            logits.reshape(B * Lm1, V),
            target.reshape(B * Lm1),
        )

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        with torch.no_grad():
            # ----- accumulate loss (per non-pad token) -----
            # loss is a mean over non-pad targets, so re-weight it by the
            # token count to get a correct epoch-level per-token average.
            non_pad = max((target != my_pad_id).sum().item(), 1)
            total_loss += loss.item() * non_pad
            total_tokens += non_pad

            # ----- compute BLEU / ROUGE for this batch -----
            # Convert to Python lists once per batch: a single device->host
            # transfer instead of one per sentence.
            pred_rows = logits.argmax(dim=-1).tolist()   # B lists of L-1 ids
            ref_rows = target.tolist()                   # B lists of L-1 ids

            for pred_ids, ref_ids in zip(pred_rows, ref_rows):
                # Convert id sequences to token sequences (strings)
                pred_str = tokens_to_text(
                    pred_ids,
                    id_to_token,
                    pad_id=my_pad_id,
                    bos_token="<bos>",
                    eos_token="<eos>",
                )
                ref_str = tokens_to_text(
                    ref_ids,
                    id_to_token,
                    pad_id=my_pad_id,
                    bos_token="<bos>",
                    eos_token="<eos>",
                )

                pred_toks = pred_str.split()
                ref_toks = ref_str.split()

                if len(ref_toks) == 0:
                    continue  # skip entirely empty reference

                total_bleu += bleu1(pred_toks, ref_toks)
                total_rouge += rouge1_f1(pred_toks, ref_toks)
                total_sentences += 1

    # max(..., 1) guards against division by zero on an empty epoch.
    avg_loss = total_loss / max(total_tokens, 1)
    avg_bleu = total_bleu / max(total_sentences, 1)
    avg_rouge = total_rouge / max(total_sentences, 1)

    print(
        f"Epoch {epoch} | "
        f"avg token loss: {avg_loss:.4f} | "
        f"BLEU-1: {avg_bleu:.4f} | ROUGE-1 F1: {avg_rouge:.4f}"
    )


Epoch 1 | avg token loss: 5.3558 | BLEU-1: 0.0561 | ROUGE-1 F1: 0.0740
Epoch 2 | avg token loss: 4.7113 | BLEU-1: 0.1413 | ROUGE-1 F1: 0.2191
Epoch 3 | avg token loss: 4.1272 | BLEU-1: 0.1299 | ROUGE-1 F1: 0.1901
Epoch 4 | avg token loss: 3.6539 | BLEU-1: 0.2273 | ROUGE-1 F1: 0.2987
Epoch 5 | avg token loss: 3.2051 | BLEU-1: 0.3029 | ROUGE-1 F1: 0.3381
Epoch 6 | avg token loss: 2.8048 | BLEU-1: 0.3726 | ROUGE-1 F1: 0.3973
Epoch 7 | avg token loss: 2.4183 | BLEU-1: 0.4627 | ROUGE-1 F1: 0.4767
Epoch 8 | avg token loss: 2.0040 | BLEU-1: 0.5471 | ROUGE-1 F1: 0.5625
Epoch 9 | avg token loss: 1.6619 | BLEU-1: 0.6141 | ROUGE-1 F1: 0.6295
Epoch 10 | avg token loss: 1.3469 | BLEU-1: 0.7407 | ROUGE-1 F1: 0.7407
Epoch 11 | avg token loss: 1.0729 | BLEU-1: 0.7973 | ROUGE-1 F1: 0.7973
Epoch 12 | avg token loss: 0.8664 | BLEU-1: 0.8447 | ROUGE-1 F1: 0.8447
Epoch 13 | avg token loss: 0.6858 | BLEU-1: 0.8784 | ROUGE-1 F1: 0.8784
Epoch 14 | avg token loss: 0.5726 | BLEU-1: 0.8807 | ROUGE-1 F1: 0.8807
E