In [None]:
import re
import copy
import numpy as np
import torch
import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)

from argparse import Namespace
import fairseq
from fairseq import checkpoint_utils, options, tasks, utils
from fairseq.dataclass.configs import GenerationConfig
utils.import_user_module(Namespace(user_dir="/home/ssever/SilentSpeak/ext/av_hubert/avhubert"))

In [None]:
def decode_tokens(tokens, task, gen):
    """
    Convert a hypothesis' token ids into readable text.

    The ids are first rendered with ``task.target_dictionary.string`` (symbols
    joined by single spaces) and then detokenized with a heuristic selected
    from the markers present in that string, checked in this order:

    * ``▁``  - SentencePiece: separator spaces dropped, ``▁`` becomes a space.
    * ``@@`` - subword-nmt: continuation markers removed.
    * ``Ġ``  - GPT-2 style BPE: separator spaces dropped, ``Ġ`` becomes a space.
    * ``|``  - letter vocab where ``|`` is the word boundary.
    * otherwise - character labels: lone spaces are intra-word separators,
      runs of two or more spaces are word boundaries.

    Args:
        tokens: object exposing ``.int().cpu()`` that yields the token ids.
        task: object whose ``target_dictionary`` offers ``pad()`` and
            ``string(ids, extra_symbols_to_ignore=...)``.
        gen: generator; its optional ``symbols_to_strip_from_output`` ids are
            removed from the output in addition to the pad id.

    Returns:
        The detokenized string with basic punctuation spacing tidied up.
    """
    dictionary = task.target_dictionary
    # `or ()` also covers a generator whose attribute exists but is None.
    ignore = set(getattr(gen, "symbols_to_strip_from_output", None) or ())
    ignore.add(dictionary.pad())

    # 1) ids -> interim string of symbols (space-separated)
    s = dictionary.string(tokens.int().cpu(), extra_symbols_to_ignore=ignore)

    # 2) Heuristic detok by scheme
    if "▁" in s:
        # SentencePiece: remove separator spaces, turn '▁' into spaces
        s = s.replace(" ", "")
        s = s.replace("▁", " ").strip()
    elif "@@" in s:
        # subword-nmt: remove continuation markers
        s = s.replace("@@ ", "").replace("@@", "")
        s = re.sub(r"\s{2,}", " ", s).strip()
    elif "Ġ" in s:
        # GPT-2/BPE: 'Ġ' marks a space before the token. Pieces without the
        # marker continue the previous word, so the separator spaces must be
        # removed first; otherwise "ĠHel lo" would come out as "Hel lo".
        s = s.replace(" ", "")
        s = s.replace("Ġ", " ").strip()
    elif "|" in s:
        # ltr-style vocab where '|' means space
        s = s.replace(" ", "")      # remove char separators
        s = s.replace("|", " ").strip()
    else:
        # Character/letter labels: single spaces inside words, multiple between words.
        s = re.sub(r'(?<!\s)\s(?!\s)', '', s)  # kill lone intraword spaces
        s = re.sub(r'\s{2,}', ' ', s).strip()  # collapse multi-spaces to one

    # 3) Punctuation tidy-ups
    s = re.sub(r"\s*(['’`-])\s*", r"\1", s)        # that ' s -> that's ; co - op -> co-op
    s = re.sub(r"\s+([,.?!:;])", r"\1", s)         # remove space before punctuation
    s = re.sub(r"\s{2,}", " ", s).strip()

    return s

In [14]:
def _merge_chunk_texts(pieces, max_overlap_words):
    """Join per-chunk transcripts, dropping words repeated verbatim at a seam."""
    merged = []
    for piece in pieces:
        words = piece.split()
        overlap = 0
        # Longest suffix of the running transcript that the new chunk starts with.
        for k in range(min(len(merged), len(words), max_overlap_words), 0, -1):
            tail = [w.casefold() for w in merged[-k:]]
            head = [w.casefold() for w in words[:k]]
            if tail == head:
                overlap = k
                break
        merged.extend(words[overlap:])
    return " ".join(merged)


def predict(model_path, npy_path, fps=25, win_sec=10.0, hop_sec=9.0,
                    beam=5, max_len_b=200, no_repeat_ngram_size=2):
    """
    Transcribe a pre-extracted frame array with a fairseq checkpoint.

    The video is decoded in sliding windows of ``win_sec`` seconds advanced by
    ``hop_sec`` seconds; the per-window transcripts are then joined into one
    string. When windows overlap (``hop_sec < win_sec``), words repeated
    verbatim across a seam are emitted only once.

    Args:
        model_path: path of the checkpoint passed to
            ``checkpoint_utils.load_model_ensemble_and_task``.
        npy_path: ``.npy`` file holding frames as ``[T, H, W]`` or
            ``[T, C, H, W]`` (multi-channel input is averaged to one channel).
        fps: frame rate used to convert seconds to frame counts.
        win_sec: window length in seconds.
        hop_sec: step between window starts in seconds.
        beam: beam size written into the generation config.
        max_len_b: written to ``max_len_b`` when the config has that field.
        no_repeat_ngram_size: written to the config when the field exists.

    Returns:
        The whitespace-normalized transcript (empty string for zero frames).

    Raises:
        ValueError: if the array is neither 3-D nor 4-D, or if the window or
            hop rounds to zero frames or less.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    models, saved_cfg, task = checkpoint_utils.load_model_ensemble_and_task([model_path])
    model = models[0].eval().to(device)

    # Load npy -> [T, H, W] or [T, C, H, W]
    arr = np.load(npy_path)
    # uint8 frames are on a 0..255 scale while mean/std below look like
    # [0, 1]-scale statistics, so rescale first.
    # NOTE(review): assumes 0.421/0.165 are [0, 1]-scale constants - confirm
    # against the preprocessing that produced the .npy file.
    pixel_scale = 255.0 if arr.dtype == np.uint8 else 1.0
    if arr.ndim == 3:
        arr = arr[:, None, :, :]
    elif arr.ndim == 4:
        if arr.shape[1] != 1:
            arr = arr.mean(axis=1, keepdims=True)  # to 1-channel
    else:
        raise ValueError(
            f"expected frames shaped [T, H, W] or [T, C, H, W], got {arr.shape}"
        )

    mean, std = 0.421, 0.165    # normalization parameters
    x = torch.from_numpy(arr).float() / pixel_scale
    x = (x - mean) / (std + 1e-8)
    x = x.unsqueeze(0)                # [1, T, 1, H, W]
    x = x.permute(0, 2, 1, 3, 4)      # --> [1, 1, T, H, W]  (B, C, T, H, W)
    x = x.contiguous().to(device)

    T = x.shape[2]

    # Introspect limits (best effort: purely informational, never fatal)
    try:
        print("Task max positions:", getattr(task, "max_positions", lambda: "unknown")())
    except Exception as exc:
        print(f"Task max positions: unavailable ({exc!r})")
    print("Saved gen cfg:", saved_cfg.generation)

    # Build generator from a copy so the checkpoint's own config stays untouched
    gen_args = copy.deepcopy(saved_cfg.generation)
    gen_args.beam = beam
    # Override only the fields this config version actually has
    if hasattr(gen_args, "max_len_b"):
        gen_args.max_len_b = max_len_b
    if hasattr(gen_args, "max_len_a"):
        gen_args.max_len_a = 0
    if hasattr(gen_args, "no_repeat_ngram_size"):
        gen_args.no_repeat_ngram_size = no_repeat_ngram_size
    gen = task.build_generator([model], gen_args)

    # Compute chunk sizes in frames
    win = int(round(win_sec * fps))
    hop = int(round(hop_sec * fps))
    if win <= 0 or hop <= 0:
        raise ValueError("win_sec and hop_sec must be > 0")
    if hop > win:
        warnings.warn(
            f"hop ({hop} frames) exceeds window ({win} frames): "
            f"{hop - win} frames between consecutive chunks will not be decoded"
        )

    pieces = []

    # Chunked inference
    for start in range(0, T, hop):
        end = min(start + win, T)
        x_chunk = x[:, :, start:end, :, :]
        # All-False mask: no frame inside the chunk is padding
        pad_mask = torch.zeros((1, end - start), dtype=torch.bool, device=device)

        sample = {
            "id": torch.tensor([0], device=device),
            "net_input": {
                "source": {"audio": None,
                           "video": x_chunk},
                "padding_mask": pad_mask,
            },
        }

        with torch.no_grad():
            hypos = task.inference_step(gen, [model], sample)

        best = hypos[0][0]
        tokens = best.get("tokens", None)
        if tokens is not None:
            pieces.append(decode_tokens(tokens, task, gen))
        elif "words" in best:
            # Some checkpoints return words
            pieces.append(" ".join(best["words"]))
        else:
            pieces.append("")

        # The window already reached the last frame; a further hop would only
        # re-decode a shorter tail of the same frames.
        if end == T:
            break

    # Drop obvious duplicates at chunk seams. Only overlapping windows can
    # produce them, so the merge is disabled when hop >= win.
    max_overlap_words = 8 if hop < win else 0
    return _merge_chunk_texts(pieces, max_overlap_words)

In [None]:
# Checkpoint and input selection for the run below. The commented-out
# assignments are alternative checkpoints / inputs kept for quick switching.
# NOTE(review): the active checkpoint path uses ".../model/" while the
# commented-out one uses ".../models/" - confirm the directory name.
# NOTE(review): these are absolute, machine-specific paths; the notebook will
# not run elsewhere without editing them.
#model_path = "/home/ssever/SilentSpeak/models/self_large_vox_433h.pt"
model_path = "/home/ssever/SilentSpeak/model/base_vox_433h.pt"
#npy_path = "/home/ssever/SilentSpeak/data/preprocessed_files/preproc_out/3255112-uhd_3840_2160_25fps_frames.npy"
#npy_path = "/home/ssever/SilentSpeak/data/preprocessed_files/preproc_out_new/avhubert_demo_video_8s_frames.npy"
npy_path = "/home/ssever/SilentSpeak/data/preprocessed_files/video1/How To Talk To Camera_ The 3 FUNDAMENTALS_frames.npy"

# Run chunked inference with predict()'s default settings. The transcript is
# stored in lp_text; this cell does not display it.
lp_text = predict(model_path, npy_path)

### **Debugging**

In [None]:
# Put every loaded model in eval mode on the available device and make sure
# each (sub)module carries a `num_updates` attribute.
# NOTE: relies on `models` from the checkpoint-loading cell; run that cell first.
# NOTE(review): presumably some module code reads `num_updates` at inference
# time - confirm that this patch is still required.
device = "cuda" if torch.cuda.is_available() else "cpu"
for m in models:
    m.eval().to(device)
    # modules() yields `m` itself as well, so the top-level model is covered
    # by this loop without a separate check.
    for sub in m.modules():
        if not hasattr(sub, "num_updates"):
            sub.num_updates = 0
model = models[0]

In [None]:
# Step-by-step shape trace of the frame preprocessing. Assumes the .npy file
# holds a 3-D array; the shape is printed after every transformation.
arr = np.load(npy_path)
print(arr.shape)
arr = arr[:, None, :, :]  # insert a singleton axis at position 1
print(arr.shape)
arr = torch.from_numpy(arr).float()
print(arr.shape)
mean, std = 0.421, 0.165
arr = (arr - mean) / (std + 1e-8)
print(arr.shape)
# Fall back to CPU when no GPU is present, as predict() does.
device = "cuda" if torch.cuda.is_available() else "cpu"
arr = arr.unsqueeze(0).to(device)  # add leading batch axis
print(arr.shape)
# Size of axis 1 of the batched tensor (the original leading axis of the array).
trans = torch.tensor([arr.shape[1]], device=device)
print(trans)

In [None]:
# Quick look at the raw array on disk: shape, container type, dtype and a
# small preview of the leading values.
data = np.load(npy_path)
for label, value in (("Shape", data.shape), ("Type", type(data)), ("Dtype", data.dtype)):
    print(f"{label}: {value}")
if data.ndim == 1:
    preview = data[:5]
else:
    preview = data[0, :5]
print(f"First few values: {preview}")

In [None]:
# Load the checkpoint once at notebook level so the inspection cells can use
# `models`, `saved_cfg`, `task` and `model`.
models, saved_cfg, task = checkpoint_utils.load_model_ensemble_and_task([model_path])
#print("ckpt task in cfg:", saved_cfg.task._name)

# Fall back to CPU when no GPU is present, as predict() does.
device = "cuda" if torch.cuda.is_available() else "cpu"
models = [m.eval().to(device) for m in models]
model = models[0]

In [None]:
# Dump the generation config stored in the checkpoint. Requires `saved_cfg`
# from the checkpoint-loading cell.
print(saved_cfg.generation)  # shows all fields the checkpoint expects

In [None]:
# Display the shape of whatever `arr` currently holds in the kernel; the value
# depends on which earlier cell assigned it last.
arr.shape

In [None]:
# Display the shape of the top-level `x`.
# NOTE(review): at notebook level `x` is only assigned by the Backup cell
# further down (the `x` inside predict() is local), so on a fresh kernel this
# cell raises NameError when the notebook is run top to bottom.
x.shape

In [None]:
# Print the configs attached to the first loaded model and to the task.
primary_model = models[0]
print("Model config:", primary_model.cfg)
print("Task config:", task.cfg)

# Report modality support only when the model exposes that attribute.
if hasattr(primary_model, 'modalities'):
    print("Supported modalities:", primary_model.modalities)

In [None]:
# Identify the concrete model class.
model_cls = model.__class__
print(f"Model type: {type(model)}")
print(f"Model class: {model_cls}")

# List every public (non-underscore) attribute name the model exposes.
public_model_names = [name for name in dir(model) if not name.startswith('_')]
print("Available methods:", public_model_names)

In [None]:
# List every public (non-underscore) attribute name the task exposes.
public_task_names = [name for name in dir(task) if not name.startswith('_')]
print("Available methods:", public_task_names)

In [None]:
# Show the container type of `sample` and, for a dict, the type of each entry.
print(f"Type of sample: {type(sample)}")
entries = sample.items() if isinstance(sample, dict) else ()
for key, value in entries:
    print(f"Key: {key}, Type: {type(value)}")

In [None]:
# Inspect the layout of the hand-built batch. The sample dicts built in this
# notebook nest the modalities under net_input["source"], so that key is read
# here; looking up "src_tokens" raised KeyError.
ni = sample["net_input"]
print("net_input keys:", ni.keys())
source = ni["source"]
print("source type:", type(source))
print("source keys:", list(source.keys()))
print("video shape:", tuple(source["video"].shape) if source["video"] is not None else None)
print("audio:", source["audio"])


### **Backup**

In [None]:
# Single-pass (unchunked) inference kept as a reference implementation.
# Requires `task` and `models` from the checkpoint-loading cell.

# 1) Load frames and bring them to [T, 1, H, W]
arr = np.load(npy_path)
if arr.ndim == 3:
    arr = arr[:, None, :, :]
elif arr.ndim == 4 and arr.shape[1] != 1:
    arr = arr.mean(axis=1, keepdims=True)

mean, std = 0.421, 0.165
x = torch.from_numpy(arr).float()
x = (x - mean) / (std + 1e-8)
x = x.unsqueeze(0)                # [1, T, 1, H, W]
x = x.permute(0, 2, 1, 3, 4)      # --> [1, 1, T, H, W]  (B, C, T, H, W)
x = x.contiguous()

device = "cuda" if torch.cuda.is_available() else "cpu"
x = x.to(device)

T = x.shape[2]                    # T is now dim 2 after permute
# All-False mask: no frame is padding
padding_mask = torch.zeros((1, T), dtype=torch.bool, device=x.device)

# 2) Build the batch; keep every tensor on the same device as the video
sample = {
    "id": torch.tensor([0], device=x.device),
    "net_input": {
        "source": {
            "audio": None,
            "video": x,
        },
        "padding_mask": padding_mask, # some paths use this
    }
}

# 3) Generate
gen_args = GenerationConfig(beam=20)
generator = task.build_generator(models, gen_args)

with torch.no_grad():
    hypos = task.inference_step(generator, models, sample)

# 4) Decode
tgt_dict = getattr(task, "target_dictionary", None) or getattr(models[0].decoder, "dictionary", None)
best = hypos[0][0]
tokens = best["tokens"].tolist()  # tiny list of ints
# Dictionary.string() filters by token id, so symbol strings in
# extra_symbols_to_ignore never match anything. Translate the special symbols
# to ids, skipping any that this vocabulary does not contain.
special_symbols = ("<pad>", "<s>", "</s>", "<ctc_blank>")
ignore_ids = {tgt_dict.indices[sym] for sym in special_symbols if sym in tgt_dict.indices}
text = tgt_dict.string(tokens, extra_symbols_to_ignore=ignore_ids)
print("VSR text:", text)