In [1]:
from src.training.util.checkpointing import load_checkpoint_for_inference
from src.training.util.config_model import load_config, create_model
from src.data.util.tokenization import detokenize
from src.data.util.codes_interleaving import remove_delay_interleaving
import torch
import torchaudio
from typing import Tuple
from pathlib import Path

device = torch.device("cpu")
config = load_config("../config/model/medium.yaml")
model = create_model(config).to(device)
model.eval()
load_checkpoint_for_inference(model, "../../../Downloads/model/checkpoints/latest.pt", device)

test_ds = torch.load("../data/main/7_datasets/train.pt", weights_only=False)
len(test_ds)

Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.


Loading checkpoint from ../../../Downloads/model/checkpoints/latest.pt
Checkpoint info:
  Epoch: 945
  Step: 118250
  Metrics: {'train_loss': 1.6159923076629639, 'train_accuracy': 0.5756251215934753, 'val_loss': 7.692844390869141, 'val_accuracy': 0.10429166257381439}


18688

In [10]:
SAMPLE_NUMBER = 80

def get_generated_audio(
    model: torch.nn.Module,
    example: Tuple[torch.Tensor, torch.Tensor],
    save_dir: str = "./outputs/",
    greedy: bool = True,
    teacher_forcing: bool = False,
    **sampling_args,
):
    src, tgt = example
    src = src.to(device).unsqueeze(0)
    tgt = tgt.to(device)
    save_dir = Path(save_dir)
    start_tokens = torch.tensor(
        [[2048 for _ in range(4)]],
        dtype=torch.int,
        device=device,
    )

    if teacher_forcing:
        out = model(src, tgt[:, :-1].unsqueeze(0)).argmax(dim=-1, keepdim=False)
        out = torch.cat(
            [torch.tensor([[[2048] for _ in range(4)]]), out],
            dim=-1
        )
    elif greedy:
        out = model.generate_greedy(
            src,
            max_len=1505,
            start_tokens=start_tokens,
        )
    else:
        out = model.generate(
            src,
            max_len=1505,
            start_tokens=start_tokens,
            **sampling_args,
        )
    num_nums = 1
    for i in out.shape:
        num_nums *= i
    print(f"num entries correct: {(out[0] == tgt).sum().item()} / {num_nums - 10}")
    print(out)
    tokens = out
    src = detokenize(remove_delay_interleaving(src[0]))
    out = detokenize(remove_delay_interleaving(out[0]))
    tgt = detokenize(remove_delay_interleaving(tgt))
    
    torchaudio.save(save_dir / "tgt_model.wav", out, 32000)
    torchaudio.save(save_dir / "src.wav", src, 32000)
    torchaudio.save(save_dir / "tgt_ground_truth.wav", tgt, 32000)

    torchaudio.save(
        save_dir / "combined_model.wav",
        torch.clip(out + src, -1.0, 1.0),
        32000,
    )
    torchaudio.save(
        save_dir / "combined_ground_truth.wav",
        torch.clip(tgt + src, -1.0, 1.0), 
        32000,
    )
    return tokens
    
out = get_generated_audio(model, test_ds[SAMPLE_NUMBER], greedy=False)

num entries correct: 265 / 6010
tensor([[[2048,  297,  289,  ...,   83,   83,   83],
         [2048, 2048, 1672,  ..., 2044, 2044, 2044],
         [2048, 2048, 2048,  ..., 2019, 2019, 2019],
         [2048, 2048, 2048,  ..., 1770, 1770,   11]]])


In [11]:
out[:, :, :100]

tensor([[[2048,  297,  289,  289,  609,  248,    8,  609,  443,   83,  289,
            83,   83,   83,   83,   83,   83,  166,   83,   83,   83,   83,
            83,   83,   83,   83,   83,   83,   83,   83,   83,   83,   83,
            83,   83,   83,   83,   83,   83,   83,   83,   83,   83,   83,
            83, 1436, 1600,  967, 1915, 1729, 1194, 1044,  658, 1869, 1019,
          1652,  472,  289,  233,  233, 1883, 1322,  624, 1701, 1100,  624,
          1697, 1458, 1827, 1106, 1126, 1303, 1688,  782, 2047,  918,  310,
           113,  310, 1712,  310,  100,  881,  654,  810, 1518,  632, 1279,
           632, 1615,  551, 1976, 1424, 1389,   74,  904,  974,  974, 1680,
          1322],
         [2048, 2048, 1672, 1595, 1972,  959, 2044,  449,  923, 1600, 2044,
          1987, 2044, 2044, 1931, 2044, 1931, 1931, 2044, 1931, 1931, 1931,
          1931, 1931, 1931, 1931, 1931, 1931, 1931, 1931, 1931, 1931, 1931,
          1931, 1931, 2044, 2044, 2044, 2044, 2044, 2044, 2044, 2044, 2