In [1]:
from src.training.util.checkpointing import load_checkpoint_for_inference
from src.training.util.config_model import load_config, create_model
from src.data.util.tokenization import detokenize
from src.data.util.codes_interleaving import remove_delay_interleaving
import torch
import torchaudio
from typing import Tuple
from pathlib import Path

# Everything a reader may need to re-point lives here. Paths are relative to the notebook's
# working directory.
CONFIG_PATH = "../config/model/large.yaml"
# NOTE(review): this checkpoint path points outside the repository (a local Downloads folder);
# adjust it for your machine.
CHECKPOINT_PATH = "../../../Downloads/kaggle_output/model/checkpoints/latest.pt"
TEST_DATASET_PATH = "../data/main/7_datasets/test.pt"

# Prefer the GPU, but fall back to the CPU so the notebook still runs on machines without CUDA.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Build the model from its config, switch it to eval mode and restore the trained weights.
config = load_config(CONFIG_PATH)
model = create_model(config).to(device)
model.eval()
load_checkpoint_for_inference(model, CHECKPOINT_PATH, device)

# NOTE(review): weights_only=False lets torch.load unpickle arbitrary Python objects, which can
# execute code. Only load dataset files that come from a trusted source.
test_ds = torch.load(TEST_DATASET_PATH, weights_only=False)
len(test_ds)

Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.


Loading checkpoint from ../../../Downloads/kaggle_output/model/checkpoints/latest.pt
Checkpoint info:
  Epoch: 390
  Step: 228344
  Metrics: {'train_loss': 2.6828079223632812, 'train_accuracy': 0.3786788582801819, 'val_loss': 4.793875694274902, 'val_accuracy': 0.21112987399101257}


91

In [4]:
# Index of the test_ds example that is fed to get_generated_audio below.
SAMPLE_NUMBER = 27

def get_generated_audio(
    model: torch.nn.Module,
    example: Tuple[torch.Tensor, torch.Tensor],
    save_dir: str = "./outputs/",
    greedy: bool = True,
    teacher_forcing: bool = False,
    **sampling_args,
) -> torch.Tensor:
    """Generate target tokens for one example and write the decoded audio to disk.

    The prediction is produced in one of three ways, checked in this order:

    1. ``teacher_forcing=True``: a single forward pass on the ground-truth target
       (shifted by one step), taking the argmax at every position.
    2. ``greedy=True``: ``model.generate_greedy``.
    3. otherwise: ``model.generate`` with ``**sampling_args`` forwarded to it.

    Five files are written into ``save_dir`` (created if missing): ``tgt_model.wav``,
    ``src.wav``, ``tgt_ground_truth.wav``, ``combined_model.wav`` and
    ``combined_ground_truth.wav``. The two "combined" files are the sum of source and
    target audio, clipped to [-1, 1].

    Args:
        model: Model exposing ``__call__(src, tgt)``, ``generate_greedy`` and ``generate``.
        example: ``(src, tgt)`` token tensors for a single, un-batched example.
        save_dir: Directory that receives the wav files.
        greedy: Use greedy decoding instead of sampling (ignored when ``teacher_forcing``).
        teacher_forcing: Predict with a teacher-forced forward pass instead of generating.
        **sampling_args: Extra keyword arguments passed to ``model.generate``.

    Returns:
        The predicted token tensor (with a leading batch dimension), before detokenization.

    Note:
        Relies on the notebook-level ``device`` for tensor placement.
    """
    start_token = 2048
    max_len = 1505
    sample_rate = 32000

    src, tgt = example
    src = src.to(device).unsqueeze(0)
    tgt = tgt.to(device)

    # torchaudio.save does not create missing directories, so make sure the target exists.
    save_dir = Path(save_dir)
    save_dir.mkdir(parents=True, exist_ok=True)

    start_tokens = torch.tensor(
        [[start_token] * 4],
        dtype=torch.int,
        device=device,
    )

    # Pure inference: skip autograd bookkeeping.
    with torch.no_grad():
        if teacher_forcing:
            out = model(src, tgt[:, :-1].unsqueeze(0)).argmax(dim=-1, keepdim=False)
            # Prepend the start-token column. new_full keeps it on the same device and dtype
            # as `out`; a plain torch.tensor(...) would live on the CPU and make torch.cat fail.
            start_column = out.new_full((*out.shape[:-1], 1), start_token)
            out = torch.cat([start_column, out], dim=-1)
        elif greedy:
            out = model.generate_greedy(
                src,
                max_len=max_len,
                start_tokens=start_tokens,
            )
        else:
            out = model.generate(
                src,
                max_len=max_len,
                start_tokens=start_tokens,
                **sampling_args,
            )

    # NOTE(review): the "- 10" presumably discounts start/delay-padding positions that match
    # trivially; confirm against the interleaving scheme.
    num_correct = (out[0] == tgt).sum().item()
    print(f"num entries correct: {num_correct} / {out.numel() - 10}")
    print(out)
    tokens = out

    # Undo the delay interleaving and decode tokens to waveforms on the CPU.
    src = detokenize(remove_delay_interleaving(src[0].cpu()))
    out = detokenize(remove_delay_interleaving(out[0].cpu()))
    tgt = detokenize(remove_delay_interleaving(tgt.cpu()))

    torchaudio.save(save_dir / "tgt_model.wav", out, sample_rate)
    torchaudio.save(save_dir / "src.wav", src, sample_rate)
    torchaudio.save(save_dir / "tgt_ground_truth.wav", tgt, sample_rate)

    # Mix source and target; clip so the sum stays inside the valid [-1, 1] range.
    torchaudio.save(
        save_dir / "combined_model.wav",
        torch.clip(out + src, -1.0, 1.0),
        sample_rate,
    )
    torchaudio.save(
        save_dir / "combined_ground_truth.wav",
        torch.clip(tgt + src, -1.0, 1.0),
        sample_rate,
    )
    return tokens
    
# Seed the RNG so temperature sampling yields the same tokens on every Restart & Run All.
# NOTE(review): assumes model.generate draws from torch's global RNG; confirm.
torch.manual_seed(0)
out = get_generated_audio(
    model,
    test_ds[SAMPLE_NUMBER],
    greedy=False,
    temperature=0.7,
)

num entries correct: 97 / 6010
tensor([[[2048,  798,  932,  ...,   83,   83,   83],
         [2048, 2048, 1931,  ..., 1931, 1931, 1931],
         [2048, 2048, 2048,  ..., 2019, 2019, 2019],
         [2048, 2048, 2048,  ..., 1854, 1770, 1770]]], device='cuda:0')


In [83]:
# Show only the first 100 positions along the last axis of the generated tokens.
out[..., :100]

tensor([[[2048,   83,   83,   83,   83,   83,   83,   83,   83,   83,   83,
            83,   83,   83,   83,   83,   83,   83,   83,   83,   83,   83,
            83,   83,   83,   83,   83,   83,   83,   83,   83,   83,   83,
            83,   83,   83,   83,   83,   83,   83,   83,   83,   83,  778,
           237, 1600, 1126,  457,  457,  457,  457, 1221,  457,  457, 1103,
          1103, 1103, 1111, 1111, 1111,  960, 1257, 1103, 1744, 1744, 1744,
           903, 1744, 1103, 1544, 1544, 1639,  361,  813, 1111,  879,  311,
           624,  311,  624, 1103, 1231, 1472, 1201,  457,  624, 1201, 1701,
           311, 1697, 1201, 1103,  879, 1103, 1103, 1103, 1103, 1103, 1103,
          1103],
         [2048, 2048, 1931, 1931, 1931, 1931, 1931, 2044, 2044, 1931, 2044,
          2044, 2044, 2044, 2044, 1931, 2044, 2044, 2044, 2044, 2044, 2044,
          2044, 2044, 2044, 2044, 2044, 2044, 2044, 2044, 2044, 2044, 2044,
          2044, 1931, 2044, 1931, 2044, 2044, 2044, 2044, 1489, 2044, 2