In [8]:
import argparse
import time

import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader

from utils import bleu1, rouge1_f1
from code.models.downsampled_stacked_autoencoder import EncoderDecoderModel
from data import *
from code.models.modules import AugMask

In [5]:
# Select the compute backend: CUDA when torch reports a usable GPU, otherwise CPU.
backend_name = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(backend_name)
print("Using device:", device)

Using device: cpu


In [10]:
import argparse
import torch
from torch.utils.data import DataLoader


# Select the compute backend (declared here as well so this cell can run on its own).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

# Load dataset
# NOTE(review): absolute, machine-specific path. Anyone else running this
# notebook has to point PRECOMPUTED_TRAIN_DIR at their own copy of the data.
PRECOMPUTED_TRAIN_DIR = (
    "C:/Users/victo/Documents/Brain Growth/Machine Learning/Sign2Text/data/precomputed_train"
)
ds = PrecomputedASLData(PRECOMPUTED_TRAIN_DIR)
print(f"Loaded precomputed dataset with {len(ds)} samples")

# Vocabulary and padding id as exposed by the dataset object.
vocab = ds.vocab
pad_id = ds.pad_id
print(f"Vocab size: {len(vocab)} | pad_id={pad_id}")

# Dataloader
# Shuffling draws from a dedicated, seeded generator so the smoke test inspects
# the same batch on every run, without touching torch's global RNG state.
LOADER_SEED = 0
loader = DataLoader(
    ds,
    batch_size=8,
    shuffle=True,
    num_workers=0,
    pin_memory=torch.cuda.is_available(),  # only request pinned memory when CUDA is available
    collate_fn=lambda b: asl_collate_func(b, pad_id=pad_id),
    generator=torch.Generator().manual_seed(LOADER_SEED),
)

# Grab a single batch
batch = next(iter(loader))
features = batch["features"].to(device)         # [B, T, D]
lengths = batch["feature_len"].to(device)       # [B]

# Unpack the three feature dimensions; they are reused by the checks below.
B, T, D = features.shape
print(f"Batch features: {features.shape} | lengths: {lengths.shape}")
print(f"Length stats: min={int(lengths.min())} max={int(lengths.max())} (T={T})")

# Build the masking layer with both probabilities pinned to 1.0 so the smoke
# test always exercises it.
aug_config = dict(
    time_probs=1.0,       # always on for the smoke test
    body_probs=1.0,       # always on
    num_time_masks=2,
    mask_time_frac=0.10,
    mask_part_frac=0.10,
    mask_method="zero",
)
aug = AugMask(**aug_config)
aug = aug.to(device)

# IMPORTANT: the augmentation depends on self.training, so switch to train mode first.
aug.train()

# Apply the augmentation; gradient tracking is disabled for this check.
with torch.no_grad():
    out = aug(features, lengths)

# --- Quick correctness checks / prints ---

# 1) Shape preserved
assert out.shape == features.shape, "Output shape changed!"

# 2) Padding untouched: for each sample, frames >= length must remain identical
#    (If your precomputed padding is zeros, this also ensures it's still zero.)
#    torch.equal is an exact comparison; a tolerance-based check such as
#    torch.allclose would let small perturbations of the padding slip through.
pad_ok = True
for b in range(B):
    L = int(lengths[b].item())
    # L == T means the sample fills the whole window, so there is no padding to check.
    if L < T and not torch.equal(out[b, L:], features[b, L:]):
        pad_ok = False
        print(f"❌ Padding mismatch in sample {b}: L={L}, T={T}")
        break
print("Padding untouched:", "✓" if pad_ok else "❌")

# 3) Something changed in valid region (since probs=1.0)
# valid_mask[b, t] is True for frame indices below the sample's length.
frame_idx = torch.arange(T, device=device)
valid_mask = frame_idx[None, :] < lengths[:, None]  # [B, T]
# Element-wise differences, restricted to the valid frames.
changed = torch.ne(out, features) & valid_mask[..., None]
num_changed = int(changed.sum().item())
num_valid = int(valid_mask.sum().item() * D)
print(f"Changed elements (valid frames only): {num_changed} / {num_valid}")

if num_changed:
    print("✓ Augmentation is modifying the batch.")
else:
    print("❌ Nothing changed. Check that aug.train() is set and probs > 0.")

# 4) Optional: per-sample share of valid frames whose feature vector is all zeros.
#    Rough indicator only: a frame whose absolute values sum to 0 is treated as a time-mask hit.
frame_energy = out.abs().sum(dim=-1)  # [B, T]
zero_frames = frame_energy.eq(0) & valid_mask
# Report at most the first five samples of the batch.
for b, (frame_flags, length) in enumerate(zip(zero_frames[:5], lengths[:5])):
    z = int(frame_flags.sum())
    L = int(length)
    pct = z / L * 100 if L else 0  # avoid dividing by zero for an empty sample
    print(f"Sample {b}: zeroed frames={z}/{L} ({pct:.1f}%)")

print("Smoke test complete.")

Using device: cpu
Using device: cpu
Loaded precomputed dataset with 5417 samples
Vocab size: 15127 | pad_id=0
Batch features: torch.Size([8, 300, 225]) | lengths: torch.Size([8])
Length stats: min=36 max=300 (T=300)
Padding untouched: ✓
Changed elements (valid frames only): 28719 / 299475
✓ Augmentation is modifying the batch.
Sample 0: zeroed frames=35/300 (11.7%)
Sample 1: zeroed frames=4/122 (3.3%)
Sample 2: zeroed frames=10/70 (14.3%)
Sample 3: zeroed frames=10/244 (4.1%)
Sample 4: zeroed frames=6/36 (16.7%)
Smoke test complete.
