In [2]:
from src.training.util.training import load_checkpoint_for_inference, load_config, create_model
from src.data.util.tokenization import detokenize
from src.data.util.codes_interleaving import remove_delay_interleaving
import torch
import torchaudio
from typing import Tuple
from pathlib import Path

# Experiment artefacts: the "xsmall" model and the 241025 data snapshot.
MODEL_DIR = "../experiments/models/241025_xsmall"
DATA_DIR = "../experiments/data/241025"
# NOTE(review): this cell loads the *train* split even though the variable is
# named `test_ds`. The checkpoint metrics printed when loading (train_loss ~1e-4
# vs val_loss ~26.9) suggest heavy overfitting, so exact reconstruction on this
# split says little about generalisation. Point this at the held-out split
# (presumably "val"/"test" — TODO confirm the file name) for a real evaluation.
DATA_SPLIT = "train"

device = torch.device("cpu")
config = load_config(f"{MODEL_DIR}/xsmall.yaml")
model = create_model(config).to(device)
model.eval()
load_checkpoint_for_inference(model, f"{MODEL_DIR}/checkpoints/checkpoint_interrupt.pt", device)

# weights_only=False unpickles arbitrary Python objects: only load trusted files.
test_ds = torch.load(f"{DATA_DIR}/{DATA_SPLIT}.pt", weights_only=False)
len(test_ds)

Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.


Loading checkpoint from ../experiments/models/241025_xsmall/checkpoints/checkpoint_interrupt.pt
Checkpoint info:
  Epoch: 799
  Step: 3196
  Metrics: {'train_loss': 9.319572563981637e-05, 'val_loss': 26.893381118774414}


59

In [3]:
# Index of the dataset example to generate and listen to; must be < len(test_ds).
SAMPLE_NUMBER = 43

def _mix_waveforms(first: torch.Tensor, second: torch.Tensor) -> torch.Tensor:
    """Sum two (channels, time) waveforms, trimmed to the shorter length and clipped to [-1, 1]."""
    # Trimming only matters if the decoded lengths differ; for equal lengths the
    # result is identical to a plain clipped sum.
    length = min(first.shape[-1], second.shape[-1])
    return torch.clip(first[..., :length] + second[..., :length], -1.0, 1.0)


def get_generated_audio(
    model: torch.nn.Module,
    example: Tuple[torch.Tensor, torch.Tensor],
    save_dir: str = "./outputs/",
    greedy: bool = True,
    max_len: int = 1505,
    sample_rate: int = 32000,
    start_token: int = 2048,
    num_codebooks: int = 4,
    **sampling_args,
):
    """Generate target codes for one example, report token accuracy and save audio.

    Args:
        model: Model exposing ``generate_greedy`` / ``generate`` that accept
            ``src``, ``max_len`` and ``start_tokens``.
        example: ``(src, tgt)`` pair of code tensors as stored in the dataset.
            ``src`` gets a batch dimension added here; ``tgt`` is compared
            element-wise with ``out[0]`` — presumably ``(num_codebooks, time)``,
            TODO confirm.
        save_dir: Directory for the ``.wav`` files; created if it does not exist.
        greedy: Use ``model.generate_greedy``; otherwise ``model.generate`` with
            ``**sampling_args``. ``sampling_args`` are ignored when greedy.
        max_len: Maximum generation length passed to the model.
        sample_rate: Sample rate used when writing the decoded waveforms.
        start_token: Code id used for the initial start-token column.
        num_codebooks: Number of parallel codebooks (entries in the start column).
        **sampling_args: Extra keyword arguments forwarded to ``model.generate``.

    Writes ``tgt_model.wav``, ``src.wav``, ``tgt_ground_truth.wav``,
    ``combined_model.wav`` and ``combined_ground_truth.wav`` into ``save_dir``
    and prints the token accuracy and the raw generated codes. Returns ``None``.
    """
    # Use the model's own device rather than relying on a notebook global.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cpu")

    src_codes, tgt_codes = example
    src_codes = src_codes.to(device).unsqueeze(0)  # add batch dimension
    tgt_codes = tgt_codes.to(device)

    save_dir = Path(save_dir)
    save_dir.mkdir(parents=True, exist_ok=True)  # torchaudio.save won't create it

    start_tokens = torch.full(
        (1, num_codebooks), start_token, dtype=torch.int, device=device
    )

    # Inference only: don't build an autograd graph during generation.
    with torch.no_grad():
        if greedy:
            out_codes = model.generate_greedy(
                src_codes,
                max_len=max_len,
                start_tokens=start_tokens,
            )
        else:
            out_codes = model.generate(
                src_codes,
                max_len=max_len,
                start_tokens=start_tokens,
                **sampling_args,
            )

    # Positions filled with the start/delay-padding token are not predictions and
    # are excluded from the denominator. Presumably that is one start column
    # (num_codebooks entries) plus k delay-pad entries on codebook k, i.e.
    # n*(n+1)/2. This matches the printed output, where codebook k begins with
    # k+1 copies of the start token, and gives the previously hard-coded 10 for
    # 4 codebooks. TODO confirm against remove_delay_interleaving.
    num_special = num_codebooks * (num_codebooks + 1) // 2
    num_correct = (out_codes[0] == tgt_codes).sum().item()
    print(f"num entries correct: {num_correct} / {out_codes.numel() - num_special}")
    print(out_codes)

    src_audio = detokenize(remove_delay_interleaving(src_codes[0]))
    out_audio = detokenize(remove_delay_interleaving(out_codes[0]))
    tgt_audio = detokenize(remove_delay_interleaving(tgt_codes))

    waveforms = {
        "tgt_model.wav": out_audio,
        "src.wav": src_audio,
        "tgt_ground_truth.wav": tgt_audio,
        "combined_model.wav": _mix_waveforms(out_audio, src_audio),
        "combined_ground_truth.wav": _mix_waveforms(tgt_audio, src_audio),
    }
    for filename, waveform in waveforms.items():
        torchaudio.save(save_dir / filename, waveform, sample_rate)
    
# Greedy generation (the default) for the chosen example; .wav files go to ./outputs/.
get_generated_audio(model, test_ds[SAMPLE_NUMBER])

num entries correct: 6010 / 6010
tensor([[[2048,  289,  289,  ...,   83,   83,   83],
         [2048, 2048,  831,  ..., 2044, 2044, 2044],
         [2048, 2048, 2048,  ..., 2019, 2019, 2019],
         [2048, 2048, 2048,  ..., 1770, 1770, 1770]]])


