In [7]:
import torch
from torchsynth.config import SynthConfig
from dataset import chains

config = SynthConfig(batch_size=1, sample_rate=48000, reproducible=False)
synth = chains.SimpleSynth(config)
data = torch.load('/home/ubuntu/Sound2Synth/data/SimpleSynth/train/processed/sound_0.pt')

In [20]:
param, _ = data['label']
unnormalized_param = []
for p ,(k, v) in zip(param,synth.get_parameters().items()):
    unnormalizer = v.parameter_range
    unnormalized = unnormalizer.from_0to1(p)
    unnormalized_param.append(unnormalized)
print(torch.tensor(unnormalized_param))

tensor([ 4.8038e-01,  8.0411e-01,  7.6822e-01,  3.9141e-02,  1.9138e+00,
         6.0000e+01,  3.0000e+00, -7.2531e+00, -5.2643e-04,  2.4909e+00,
         6.3231e-01])


In [2]:
import os
import torch
import torchaudio
from pytorch_lightning import Trainer
from dataset import DATASET_MAPPING, DATASET_PATHS
from interface import INTERFACE_MAPPING
from model import Net, get_backbone, get_classifier
from sound2synth import Sound2SynthModel
from pathlib import Path
from pyheaven import *
from dataset.chains import SimpleSynth, Synplant2
from interface.torchsynth import REG_NCLASS
from torchsynth.config import SynthConfig


def load_checkpoint(checkpoint_path, args):
    interface = INTERFACE_MAPPING[args.synth]
    net = Net(
        backbone=get_backbone(args.backbone, args),
        classifier=get_classifier(args.classifier, interface, args),
    )
    model = Sound2SynthModel.load_from_checkpoint(checkpoint_path, net=net, interface=interface, args=args)
    return model

def save_wavs(audio, path, name):
    os.makedirs(path, exist_ok=True)
    filepath = os.path.join(path, f"{name}.wav")
    torchaudio.save(filepath, audio, 48000)
    
def parse_outputs(outputs, unnormalizer, logits, from_0to1):
    if logits:
        outputs = torch.stack(outputs.squeeze(0).chunk(78)).argmax(dim=1)/(REG_NCLASS-1)
    
    mydict = {}
    assert len(unnormalizer) == len(outputs), f"Length mismatch: {len(unnormalizer)} vs {len(outputs)}"
    for (k, f), v in zip(unnormalizer.items(), outputs):
        if 'keyboard' in k:
            continue # Skip keyboard parameters as they are initialized frozen
        mydict[k] = f.from_0to1(v.unsqueeze(0)) if from_0to1 else v.unsqueeze(0)
    return mydict

def dict_to_audio(mydict, synth):
    synth.set_parameters(mydict)
    audio = synth.output()
    return audio

def main():
    checkpoint_path = './checkpoints/Voice_unnorm/ckpt-epoch=01-valid_celoss=0.82.ckpt'  # Replace with the actual path to the checkpoint
    output_dir = "output_wavs"
    
    args = MemberDict(dict(
        synth="TorchSynth",
        synth_class="Voice",
        backbone="multimodal",
        classifier="parameter",
        dataset_type="my_multimodal",
        dataset="Voice",
        feature_dim=2048,
    ))
    
    dataset = DATASET_MAPPING[args.dataset_type](dir='data/'+args.dataset, chain=args.synth_class, split='train')

    unnormalizer = dataset.unnormalizer

    model = load_checkpoint(checkpoint_path, args)
    model.eval()

    dataloader = torch.utils.data.DataLoader(dataset, batch_size=1)
    
    with torch.no_grad():
        for idx, batch in enumerate(dataloader):
            if idx >= 5:  # Iterate through 5 samples
                break
            (features, sample_rate), label = batch
            outputs = model(features)
            true_params = label[0].squeeze(0)
            # print((true_params.max() <= 1).item())
            parsed_outputs = parse_outputs(outputs, unnormalizer, logits=True, from_0to1=True)
            parsed_true_params = parse_outputs(true_params, unnormalizer, logits=False, from_0to1=(true_params.max() <= 1).item())

            reconstructions = dict_to_audio(parsed_outputs, synth)
            ground_truth = dict_to_audio(parsed_true_params, synth)
            
            # Save ground truth and reconstructions as WAV files
            save_wavs(ground_truth, output_dir, f"ground_truth_{idx}")
            save_wavs(reconstructions, output_dir, f"reconstruction_{idx}")
            print(f"Saved sample {idx}")

main()

AttributeError: 'list' object has no attribute 'items'

In [10]:
torch.load('/home/ubuntu/Sound2Synth/data/Synplant2/train/processed/sound_0.pt')['label'][0]

tensor([0.4724, 0.8657, 0.1759, 0.2698, 0.0317, 0.1507, 0.9320, 0.6341, 0.4901,
        0.4556, 0.8964, 0.3074, 0.2081, 0.7231, 0.9298, 0.2437, 0.0332, 0.5263,
        0.7423, 0.5846, 0.6323, 0.4017, 0.3489, 0.2939, 0.6977, 0.1689, 0.0223,
        0.5185, 0.8742, 0.3971, 0.9152, 0.8000, 0.2823, 0.6816, 0.1610, 0.3051,
        0.3734, 0.1852, 0.4194, 0.9527, 0.0362, 0.5529, 0.0885, 0.1320])

In [3]:
from dataset.chains import Voice
from torchsynth.config import SynthConfig

params = []
synth = Voice(SynthConfig(batch_size=1, sample_rate=48000, reproducible=False))
for k, v in synth.named_parameters():
    k = k.split('.')
    k = (k[0], k[-1])
    params.append(k)
# print(params)

[('keyboard', 'midi_f0'), ('keyboard', 'duration'), ('adsr_1', 'attack'), ('adsr_1', 'decay'), ('adsr_1', 'sustain'), ('adsr_1', 'release'), ('adsr_1', 'alpha'), ('adsr_2', 'attack'), ('adsr_2', 'decay'), ('adsr_2', 'sustain'), ('adsr_2', 'release'), ('adsr_2', 'alpha'), ('lfo_1', 'frequency'), ('lfo_1', 'mod_depth'), ('lfo_1', 'initial_phase'), ('lfo_1', 'sin'), ('lfo_1', 'tri'), ('lfo_1', 'saw'), ('lfo_1', 'rsaw'), ('lfo_1', 'sqr'), ('lfo_2', 'frequency'), ('lfo_2', 'mod_depth'), ('lfo_2', 'initial_phase'), ('lfo_2', 'sin'), ('lfo_2', 'tri'), ('lfo_2', 'saw'), ('lfo_2', 'rsaw'), ('lfo_2', 'sqr'), ('lfo_1_amp_adsr', 'attack'), ('lfo_1_amp_adsr', 'decay'), ('lfo_1_amp_adsr', 'sustain'), ('lfo_1_amp_adsr', 'release'), ('lfo_1_amp_adsr', 'alpha'), ('lfo_2_amp_adsr', 'attack'), ('lfo_2_amp_adsr', 'decay'), ('lfo_2_amp_adsr', 'sustain'), ('lfo_2_amp_adsr', 'release'), ('lfo_2_amp_adsr', 'alpha'), ('lfo_1_rate_adsr', 'attack'), ('lfo_1_rate_adsr', 'decay'), ('lfo_1_rate_adsr', 'sustain'), (

In [None]:
output_dir = "output_wavs"
    
args = MemberDict(dict(
    synth="TorchSynth",
    backbone="multimodal",
    classifier="parameter",
    dataset_type="my_multimodal",
    dataset="Synplant2",
    feature_dim=2048,
))

dataset = DATASET_MAPPING[args.dataset_type](dir=DATASET_PATHS[args.dataset], split='test')

unnormalizer = {}
synth = SimpleSynth(SynthConfig(batch_size=1, sample_rate=48000, reproducible=False))
for k, v in synth.named_parameters():
    k = k.split('.')
    k = (k[0], k[-1])
    unnormalizer[k] = v.parameter_range

model = load_checkpoint(checkpoint_path, args)
model.eval()

dataloader = torch.utils.data.DataLoader(dataset, batch_size=1)

for idx, batch in enumerate(dataloader):
    if idx >= 50:  # Iterate through 5 samples
        break
    (features, sample_rate), label = batch

    true_params = label[0]
    parsed_true_params = parse_outputs(true_params, unnormalizer, logits=False, from_0to1=(true_params.max() <= 1).item())
    ground_truth = dict_to_audio(parsed_true_params, synth)
    
    # Save ground truth and reconstructions as WAV files
    save_wavs(ground_truth, output_dir, f"ground_truth_{idx}")
    print(f"Saved sample {idx}")

In [None]:
from dataset.chains import SimpleSynth
from torchsynth.config import SynthConfig

unnormalizer = {}
synth = SimpleSynth(SynthConfig(batch_size=1, sample_rate=48000, reproducible=False))
for k, v in synth.named_parameters():
    k = k.split('.')
    k = (k[0], k[-1])
    unnormalizer[k] = v.parameter_range
unnormalizer
# next(synth.named_parameters())

In [3]:
from torchsynth.module import MonophonicKeyboard
from torchsynth.config import SynthConfig
import torch

params = {"midi_f0": torch.tensor([60.0]), "duration": torch.tensor([3.0])}
keyboard = MonophonicKeyboard(SynthConfig(batch_size=1, sample_rate=48000, reproducible=False), **params)

keyboard.get_parameter("midi_f0").from_0to1()

tensor([60.], grad_fn=<AddBackward0>)

In [19]:
params = []
for n in range(5):
    data = torch.load(f'/home/ubuntu/Sound2Synth/data/Synplant2/train/unprocessed/sound_{n}.pt')
    params.append(data['params'])
torch.stack(params).shape

torch.Size([5, 44])