In [19]:
from src.training.utils import load_checkpoint_for_inference, load_config, create_model
from src.data.tokenization import detokenize
from src.data.dataset_util import remove_delay_interleaving
import torch
import torchaudio
from typing import Tuple
from pathlib import Path

# Input locations for this experiment (relative to the notebook's directory).
CONFIG_PATH = "../config/model/experiment.yaml"
CHECKPOINT_PATH = "../models/experiment/checkpoints/checkpoint_epoch_350.pt"
TEST_DATASET_PATH = "../data/main/datasets/test.pt"

device = torch.device("cpu")

# Build the model from its config, place it on `device`, switch it to eval
# mode (nn.Module.eval() returns the module itself), then load trained weights.
config = load_config(CONFIG_PATH)
model = create_model(config).to(device).eval()
load_checkpoint_for_inference(model, CHECKPOINT_PATH, device)

# NOTE(review): weights_only=False unpickles arbitrary Python objects, so only
# load dataset files that come from a trusted source.
test_ds = torch.load(TEST_DATASET_PATH, weights_only=False)
len(test_ds)

Loading checkpoint from ../models/experiment/checkpoints/checkpoint_epoch_350.pt
Checkpoint info:
  Epoch: 349
  Step: 1400
  Metrics: {'train_loss': 0.00027132392957961805, 'train_accuracy': 0.9999861121177673, 'val_loss': 9.738258361816406, 'val_accuracy': 0.21627314388751984}


1396

In [21]:
# Index into `test_ds` of the example passed to get_generated_audio below.
SAMPLE_NUMBER = 2

def _crop_to_common_length(*tensors: torch.Tensor) -> Tuple[torch.Tensor, ...]:
    """Trim every tensor along its last dimension to the shortest length among them."""
    length = min(t.shape[-1] for t in tensors)
    return tuple(t[..., :length] for t in tensors)


def _mix_and_clip(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """Sum two waveforms (cropped to a common length) and clip the result to [-1, 1]."""
    a, b = _crop_to_common_length(a, b)
    return torch.clip(a + b, -1.0, 1.0)


def get_generated_audio(
    model: torch.nn.Module,
    example: Tuple[torch.Tensor, torch.Tensor],
    save_dir: str = "./outputs/",
    greedy: bool = True,
    *,
    max_len: int = 1505,
    sample_rate: int = 32000,
    **sampling_args,
) -> None:
    """Generate target tokens for one example, report token accuracy, and save audio.

    Runs the model on the example's source tokens, prints how many generated
    token entries match the ground-truth target, prints the generated tokens,
    then decodes source / generated / ground-truth tokens to waveforms and
    writes five WAV files into ``save_dir``:

    - ``tgt_model.wav``: decoded model output
    - ``src.wav``: decoded source
    - ``tgt_ground_truth.wav``: decoded ground-truth target
    - ``combined_model.wav``: source + model output, clipped to [-1, 1]
    - ``combined_ground_truth.wav``: source + ground truth, clipped to [-1, 1]

    Args:
        model: Model exposing ``generate_greedy`` and ``generate``. Inputs are
            moved to the device of the model's parameters.
        example: ``(src, tgt)`` token tensors for a single sample without a
            batch dimension (presumably (num_codebooks, time) — TODO confirm).
        save_dir: Output directory; created if it does not exist.
        greedy: If True use ``model.generate_greedy``, else ``model.generate``.
        max_len: Maximum generation length passed to the model.
        sample_rate: Sample rate (Hz) used when writing the WAV files.
        **sampling_args: Extra keyword arguments forwarded to ``model.generate``
            (ignored when ``greedy`` is True).
    """
    # Token id used to seed generation in every codebook (assumed to be the
    # start/pad token of the tokenizer — TODO confirm against src.data.tokenization).
    start_token_id = 2048
    num_codebooks = 4
    # Excluded from the accuracy denominator; presumably the leading start/delay
    # tokens of the delay-interleaved layout (1 + 2 + 3 + 4 for 4 codebooks) — TODO confirm.
    # NOTE(review): the numerator still counts those positions, so if `tgt` carries
    # the same start/delay tokens they inflate the reported match count.
    num_special_tokens = 10

    # Follow the model's device rather than relying on a notebook-global `device`.
    device = next(model.parameters()).device
    save_dir = Path(save_dir)
    save_dir.mkdir(parents=True, exist_ok=True)

    src_tokens, tgt_tokens = example
    src_tokens = src_tokens.to(device).unsqueeze(0)  # the model expects a batch dimension
    tgt_tokens = tgt_tokens.to(device)
    start_tokens = torch.full(
        (1, num_codebooks), start_token_id, dtype=torch.int, device=device
    )

    with torch.no_grad():
        if greedy:
            out_tokens = model.generate_greedy(
                src_tokens,
                max_len=max_len,
                start_tokens=start_tokens,
            )
        else:
            out_tokens = model.generate(
                src_tokens,
                max_len=max_len,
                start_tokens=start_tokens,
                **sampling_args,
            )

        # Compare over the shared length so an early-stopped or longer
        # generation does not crash the elementwise comparison.
        pred, ref = _crop_to_common_length(out_tokens[0], tgt_tokens)
        num_correct = (pred == ref).sum().item()
        print(f"num entries correct: {num_correct} / {pred.numel() - num_special_tokens}")
        print(out_tokens)

        src_audio = detokenize(remove_delay_interleaving(src_tokens[0]))
        out_audio = detokenize(remove_delay_interleaving(out_tokens[0]))
        tgt_audio = detokenize(remove_delay_interleaving(tgt_tokens))

    waveforms = {
        "tgt_model.wav": out_audio,
        "src.wav": src_audio,
        "tgt_ground_truth.wav": tgt_audio,
        "combined_model.wav": _mix_and_clip(out_audio, src_audio),
        "combined_ground_truth.wav": _mix_and_clip(tgt_audio, src_audio),
    }
    for filename, waveform in waveforms.items():
        torchaudio.save(save_dir / filename, waveform, sample_rate)
    
# Generate audio for the chosen test example with the function's defaults
# (greedy decoding, files written under ./outputs/).
test_example = test_ds[SAMPLE_NUMBER]
get_generated_audio(model=model, example=test_example)

num entries correct: 77 / 6010
tensor([[[2048,  166,  166,  ...,   83,   83,   83],
         [2048, 2048, 1931,  ..., 1931, 2044, 1931],
         [2048, 2048, 2048,  ..., 2019, 2019, 2019],
         [2048, 2048, 2048,  ..., 1770, 1951,  779]]])


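# Notes

- The loaded checkpoint reports train accuracy ≈ 1.0 but validation accuracy ≈ 0.22 (validation loss ≈ 9.74), which points to severe overfitting.
- Consistent with that, greedy generation on test example `SAMPLE_NUMBER = 2` matched only 77 of 6010 target token entries.
- `checkpoint_epoch_350.pt` reports `Epoch: 349`, presumably because the checkpoint metadata counts epochs from 0.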