In [7]:
from src.training.util.training import load_checkpoint_for_inference, load_config, create_model
from src.data.util.tokenization import detokenize
from src.data.util.codes_interleaving import remove_delay_interleaving
import torch
import torchaudio
from typing import Tuple
from pathlib import Path

device = torch.device("cpu")
config = load_config("../config/model/medium.yaml")
model = create_model(config).to(device)
model.eval()
load_checkpoint_for_inference(model, "../../../Downloads/model/checkpoints/latest.pt", device)

test_ds = torch.load("../data/main/7_datasets/train.pt", weights_only=False)
len(test_ds)

Loading checkpoint from ../../../Downloads/model/checkpoints/latest.pt
Checkpoint info:
  Epoch: 248
  Step: 31125
  Metrics: {'train_loss': 2.151264190673828, 'train_accuracy': 0.4707900881767273, 'val_loss': 6.298453330993652, 'val_accuracy': 0.11502083390951157}


18688

In [39]:
SAMPLE_NUMBER = 100

def get_generated_audio(
    model: torch.nn.Module,
    example: Tuple[torch.Tensor, torch.Tensor],
    save_dir: str = "./outputs/",
    greedy: bool = True,
    teacher_forcing: bool = False,
    **sampling_args,
):
    src, tgt = example
    src = src.to(device).unsqueeze(0)
    tgt = tgt.to(device)
    save_dir = Path(save_dir)
    start_tokens = torch.tensor(
        [[2048 for _ in range(4)]],
        dtype=torch.int,
        device=device,
    )

    if teacher_forcing:
        out = model(src, tgt[:, :-1].unsqueeze(0)).argmax(dim=-1, keepdim=False)
        out = torch.cat(
            [torch.tensor([[[2048] for _ in range(4)]]), out],
            dim=-1
        )
    elif greedy:
        out = model.generate_greedy(
            src,
            max_len=1505,
            start_tokens=start_tokens,
        )
    else:
        out = model.generate(
            src,
            max_len=1505,
            start_tokens=start_tokens,
            **sampling_args,
        )
    num_nums = 1
    for i in out.shape:
        num_nums *= i
    print(f"num entries correct: {(out[0] == tgt).sum().item()} / {num_nums - 10}")
    print(out)
    tokens = out
    src = detokenize(remove_delay_interleaving(src[0]))
    out = detokenize(remove_delay_interleaving(out[0]))
    tgt = detokenize(remove_delay_interleaving(tgt))
    
    torchaudio.save(save_dir / "tgt_model.wav", out, 32000)
    torchaudio.save(save_dir / "src.wav", src, 32000)
    torchaudio.save(save_dir / "tgt_ground_truth.wav", tgt, 32000)

    torchaudio.save(
        save_dir / "combined_model.wav",
        torch.clip(out + src, -1.0, 1.0),
        32000,
    )
    torchaudio.save(
        save_dir / "combined_ground_truth.wav",
        torch.clip(tgt + src, -1.0, 1.0), 
        32000,
    )
    return tokens
    
out = get_generated_audio(model, test_ds[SAMPLE_NUMBER], greedy=False)

num entries correct: 301 / 6010
tensor([[[2048, 1667, 1667,  ...,   83,   83,   83],
         [2048, 2048, 1758,  ...,  938, 1955, 1497],
         [2048, 2048, 2048,  ...,  924, 1938, 1938],
         [2048, 2048, 2048,  ..., 1595, 1854, 2040]]])


In [40]:
out[:, :, :100]

tensor([[[2048, 1667, 1667, 1318,  793,   66,   83, 1131,  248,   50,  457,
          1913, 1364,  932,  840,  865, 1627,  685, 2036, 2031,  494, 1479,
           520, 1667,  965, 1702,  520,  494,   66,   83,   83,  166,  166,
           166,  166,  166,  166,   83,  965, 1388, 2036,  954, 1667,  443,
            83, 1757,    8, 1373, 1044, 1389, 1479,  237, 1506,  974, 1648,
           573, 1814, 1685,  629, 2016, 1685, 1900,  894,  875, 1771,  730,
          1877, 1429,  178,  178, 1778, 1390,  277, 1188,  952, 1188,  277,
           178, 1119, 1074,  941,  624, 1303,  457, 1132,  484,  952,  178,
          1188, 1550,  634, 1989,  378, 1557, 1557, 1104,   35,  443,  506,
          1074],
         [2048, 2048, 1758, 1469, 1671, 1583,  654, 1394,  826, 1889, 1113,
           902,  334, 1143,  789, 1583, 1336,  235, 1482, 1394, 1955, 1137,
          1583, 1275, 1941, 1102, 1793, 1583, 1085, 1394, 1931, 1793, 2044,
           154, 2044, 2044, 1497, 2044, 1931, 1881, 1931, 1938, 1394, 1