In [None]:
import os
import torch
import torchaudio
from pytorch_lightning import Trainer
from dataset import DATASET_MAPPING, DATASET_PATHS
from interface import INTERFACE_MAPPING
from model import Net, get_backbone, get_classifier
from sound2synth import Sound2SynthModel
from pathlib import Path
from pyheaven import *
from dataset.chains import SimpleSynth, Synplant2
from interface.torchsynth import REG_NCLASS
from torchsynth.config import SynthConfig


def load_checkpoint(checkpoint_path, args):
    """Rebuild the network described by ``args`` and load weights from a Lightning checkpoint.

    The synth interface is looked up from ``args.synth``; the backbone and
    classifier are built from ``args.backbone`` / ``args.classifier``. The
    resulting ``Net`` is handed to ``Sound2SynthModel.load_from_checkpoint``
    together with the interface and ``args``.

    Returns:
        The restored ``Sound2SynthModel``.
    """
    synth_interface = INTERFACE_MAPPING[args.synth]
    backbone = get_backbone(args.backbone, args)
    classifier = get_classifier(args.classifier, synth_interface, args)
    return Sound2SynthModel.load_from_checkpoint(
        checkpoint_path,
        net=Net(backbone=backbone, classifier=classifier),
        interface=synth_interface,
        args=args,
    )

def save_wavs(audio, path, name, sample_rate=48000):
    """Write ``audio`` to ``<path>/<name>.wav``, creating ``path`` if it does not exist.

    Args:
        audio: Waveform tensor passed to ``torchaudio.save`` (which expects a
            (channels, frames) layout by default). It is detached and moved to
            CPU first so tensors carrying autograd history or living on a GPU
            can be written as well.
        path: Output directory.
        name: File name without the ``.wav`` extension.
        sample_rate: Sample rate in Hz written to the WAV header. Defaults to
            48000, the value that used to be hard-coded here; pass the synth's
            real rate if it differs.

    Returns:
        The path of the written file.
    """
    os.makedirs(path, exist_ok=True)
    filepath = os.path.join(path, f"{name}.wav")
    torchaudio.save(filepath, audio.detach().cpu(), sample_rate)
    return filepath
    
def parse_outputs(outputs, unnormalizer, logits, from_0to1):
    """Map a model output (or label) tensor onto named synth parameters.

    When ``logits`` is true, ``outputs`` is squeezed on dim 0, split into
    ``len(unnormalizer)`` chunks of class scores, and every chunk is reduced to
    its argmax class divided by ``REG_NCLASS - 1``.

    Parameters whose name contains ``'keyboard'`` are left out of the result
    (they are initialized frozen). Every other value becomes a 1-element
    tensor, passed through the matching ``from_0to1`` converter when
    ``from_0to1`` is true.

    Raises:
        AssertionError: if the number of values differs from ``len(unnormalizer)``.
    """
    if logits:
        per_param_scores = outputs.squeeze(0).chunk(len(unnormalizer))
        outputs = torch.stack(per_param_scores).argmax(dim=1) / (REG_NCLASS - 1)

    assert len(unnormalizer) == len(outputs), f"Length mismatch: {len(unnormalizer)} vs {len(outputs)}"

    named_values = zip(unnormalizer.items(), outputs)
    return {
        name: converter.from_0to1(value.unsqueeze(0)) if from_0to1 else value.unsqueeze(0)
        for (name, converter), value in named_values
        if 'keyboard' not in name
    }

def dict_to_audio(mydict, synth):
    """Load ``mydict`` into ``synth`` via ``set_parameters`` and return ``synth.output()``."""
    synth.set_parameters(mydict)
    return synth.output()

# Inference demo: load a trained checkpoint, predict parameters for the first five
# training samples, render predicted and ground-truth parameters through the synth,
# and save both renders as WAV files under `output_dir`.
# This runs at notebook top level (`if True:` stands in for a `main()` function), so
# every variable it assigns -- including `true_params`, which the AudioLoss cell
# further down reads -- remains in the kernel namespace after the cell finishes.
if True:
    checkpoint_path = './checkpoints/SimpleSynth_audioLoss/last.ckpt'  # Replace with the actual path to the checkpoint
    output_dir = "output_wavs"  # created on demand by save_wavs

    # load_checkpoint rebuilds the network from these fields before loading weights,
    # so they presumably must match the configuration the checkpoint was trained with.
    args = MemberDict(dict(
        synth="TorchSynth",
        synth_class="SimpleSynth",
        backbone="multimodal",
        classifier="parameter",
        dataset_type="my_multimodal",
        dataset="SimpleSynth",
        feature_dim=2048,
    ))

    # The training split supplies (features, label) pairs; its `synth` renders audio and
    # its `unnormalizer` provides the per-parameter converters used by parse_outputs.
    dataset = DATASET_MAPPING[args.dataset_type](dir='data/'+args.dataset, chain=args.synth_class, split='train')

    synth = dataset.synth
    unnormalizer = dataset.unnormalizer

    model = load_checkpoint(checkpoint_path, args)
    model.eval()

    dataloader = torch.utils.data.DataLoader(dataset, batch_size=1)

    with torch.no_grad():
        for idx, batch in enumerate(dataloader):
            if idx >= 5:  # only the first 5 samples are rendered
                break
            (features, sample_rate), label = batch
            # NOTE(review): `sample_rate` is unpacked but never used; save_wavs is called without it.
            outputs = model(features)
            true_params = label[0].squeeze(0)
            # Predictions are parsed as per-parameter class scores (logits=True) and always
            # passed through from_0to1. Labels go through from_0to1 only when their maximum
            # is <= 1 -- a heuristic for "already in [0, 1]", not a guarantee.
            parsed_outputs = parse_outputs(outputs, unnormalizer, logits=True, from_0to1=True)
            parsed_true_params = parse_outputs(true_params, unnormalizer, logits=False, from_0to1=(true_params.max() <= 1).item())

            print(parsed_outputs)
            print(parsed_true_params)

            # Both renders reuse the same `synth` instance, one after the other.
            reconstructions = dict_to_audio(parsed_outputs, synth)
            ground_truth = dict_to_audio(parsed_true_params, synth)

            # Save ground truth and reconstructions as WAV files
            save_wavs(ground_truth, output_dir, f"ground_truth_{idx}")
            save_wavs(reconstructions, output_dir, f"reconstruction_{idx}")
            print(f"Saved sample {idx}")

# NOTE(review): wrapping the block above in main() would hide `true_params` from the AudioLoss cell below.

In [2]:
from dataset.chains import SimpleSynth
from torchsynth.config import SynthConfig
import torch
import torch.nn.functional as F
from pyheaven import *
from dataset import DATASET_MAPPING, DATASET_PATHS
from interface.torchsynth import REG_NCLASS

# Same experiment configuration as the inference run in the first cell.
args = MemberDict({
    "synth": "TorchSynth",
    "synth_class": "SimpleSynth",
    "backbone": "multimodal",
    "classifier": "parameter",
    "dataset_type": "my_multimodal",
    "dataset": "SimpleSynth",
    "feature_dim": 2048,
})

# Only the dataset's unnormalizer (its per-parameter `from_0to1` converters) is used below.
dataset = DATASET_MAPPING[args.dataset_type](
    dir=f"data/{args.dataset}",
    chain=args.synth_class,
    split='train',
)
unnormalizer = dataset.unnormalizer

class DifferentiableArgmax(torch.autograd.Function):
    """Hard argmax over dim 1 with a masked pass-through gradient.

    Forward: one-hot encoding of ``input.argmax(dim=1)``, in the input's dtype.
    Assumes a 2-D ``(rows, classes)`` input -- ``F.one_hot`` appends the class
    axis last, so for higher-rank inputs the result would not line up with dim 1.

    Backward: the incoming gradient is kept only at the selected (argmax)
    positions and zeroed everywhere else.
    """

    @staticmethod
    def forward(ctx, input):
        idx = input.argmax(dim=1)
        # Match the input dtype instead of forcing float32.
        output = F.one_hot(idx, num_classes=input.size(1)).to(input.dtype)
        # Only the one-hot mask is needed in backward; not saving `input`
        # avoids keeping the full logit tensor alive in the graph.
        ctx.save_for_backward(output)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        (output,) = ctx.saved_tensors
        # Route gradient only to the entries that won the argmax.
        return grad_output * output

def parse_outputs(outputs, unnormalizer, logits, from_0to1, n_classes=None):
    """Map a model output (or label) tensor onto named synth parameters, keeping gradients.

    When ``logits`` is true, ``outputs`` is squeezed on dim 0 and split into
    ``len(unnormalizer)`` rows of class scores. Each parameter's value is the
    hard argmax class divided by ``n_classes - 1`` (exactly as before), but
    the gradient flows back to the scores through a straight-through
    estimator: the backward pass uses the softmax-expected class index. The
    previous version re-applied ``argmax`` and then built a fresh leaf with
    ``detach().requires_grad_()``, so no gradient ever reached ``outputs``.

    Args:
        outputs: Class scores (``logits=True``) or one value per parameter.
        unnormalizer: Ordered mapping ``name -> converter`` whose converters
            expose ``from_0to1``.
        logits: Whether ``outputs`` holds per-parameter class scores.
        from_0to1: Whether to pass each value through its converter.
        n_classes: Number of classes per parameter used for normalization;
            defaults to ``REG_NCLASS``.

    Returns:
        Dict mapping every parameter name that does not contain ``'keyboard'``
        to a 1-element tensor.

    Raises:
        AssertionError: if the number of values differs from ``len(unnormalizer)``.
    """
    n_params = len(unnormalizer)
    if logits:
        if n_classes is None:
            n_classes = REG_NCLASS
        # One row of class scores per parameter.
        scores = torch.stack(outputs.squeeze(0).chunk(n_params))
        positions = torch.arange(scores.size(1), dtype=scores.dtype, device=scores.device)

        # Forward value: the hard argmax class normalized to [0, 1].
        hard = scores.argmax(dim=1).to(scores.dtype) / (n_classes - 1)
        # Differentiable surrogate: expected class index under softmax.
        soft = (scores.softmax(dim=1) * positions).sum(dim=1) / (n_classes - 1)
        # `soft - soft.detach()` is exactly zero in the forward pass, so the value
        # equals `hard`, while the backward pass follows `soft`.
        outputs = hard + (soft - soft.detach())

        # Inputs without a graph (e.g. plain tensors or under torch.no_grad) still
        # yield grad-enabled values, as before.
        if not outputs.requires_grad:
            outputs = outputs.detach().requires_grad_()

    assert len(unnormalizer) == len(outputs), f"Length mismatch: {len(unnormalizer)} vs {len(outputs)}"

    mydict = {}
    for (k, f), v in zip(unnormalizer.items(), outputs):
        if 'keyboard' in k:
            continue  # Skip keyboard parameters as they are initialized frozen
        mydict[k] = f.from_0to1(v.unsqueeze(0)) if from_0to1 else v.unsqueeze(0)
    return mydict

# Smoke test: push random scores through parse_outputs into a freshly built synth and
# print the parsed values to check that they are attached to an autograd graph.
# Size follows parse_outputs' layout: len(unnormalizer) chunks that are normalized by
# REG_NCLASS - 1 (the previously hard-coded 704 is presumably 11 params x 64 classes).
n_scores = len(unnormalizer) * REG_NCLASS
rng = torch.Generator().manual_seed(0)  # local seeded RNG: reproducible without touching the global seed
outputs = torch.randn((1, n_scores), generator=rng, requires_grad=True)
params = parse_outputs(outputs, unnormalizer, logits=True, from_0to1=True)

synth = SimpleSynth(SynthConfig(batch_size=1, sample_rate=48000, reproducible=False, no_grad=False))

synth.set_parameters(params)
set_params = synth.get_parameters()  # kept for interactive inspection; not used below
for name, value in params.items():
    print(name, value)


Module grads, tensor([0.4724], requires_grad=True), tensor([0.4724], requires_grad=True)
Module grads, tensor([1.], requires_grad=True), tensor([1.], requires_grad=True)
Module grads, tensor([0.4724], requires_grad=True), tensor([0.4724], requires_grad=True)
Module grads, tensor([1.], requires_grad=True), tensor([1.], requires_grad=True)
Module grads, tensor([0.6508], grad_fn=<CloneBackward0>), tensor([0.6508], grad_fn=<CloneBackward0>)
Module grads, tensor([0.5079], grad_fn=<CloneBackward0>), tensor([0.5079], grad_fn=<CloneBackward0>)
Module grads, tensor([0.8571], grad_fn=<CloneBackward0>), tensor([0.8571], grad_fn=<CloneBackward0>)
Module grads, tensor([0.9683], grad_fn=<CloneBackward0>), tensor([0.9683], grad_fn=<CloneBackward0>)
Module grads, tensor([0.0794], grad_fn=<CloneBackward0>), tensor([0.0794], grad_fn=<CloneBackward0>)
Module grads, tensor([0.5714], grad_fn=<CloneBackward0>), tensor([0.5714], grad_fn=<CloneBackward0>)
Module grads, tensor([0.1587], grad_fn=<CloneBackward0

In [4]:
# NOTE(review): the saved output of this cell is `{}`, yet the previous cell iterates over this
# same mapping (via `unnormalizer`) and its recorded output shows parameters being set. The two
# outputs likely come from different kernel states (execution counts are non-sequential);
# re-run from a fresh kernel before trusting either.
dataset.unnormalizer

{}

In [15]:
import torch
import torch.nn as nn

def traverse_grad_fn(grad_fn):
    """Collect every autograd node reachable from ``grad_fn``, depth-first in pre-order.

    Each node is listed once, even when several paths lead to it, so shared
    subgraphs are not re-walked. The walk is iterative, so deep graphs do not
    hit Python's recursion limit.

    Args:
        grad_fn: A tensor's ``grad_fn`` (or any autograd node), or ``None``.

    Returns:
        List of nodes, starting with ``grad_fn`` itself, with children visited in
        ``next_functions`` order. Empty when ``grad_fn`` is ``None`` (e.g. for a
        leaf tensor).
    """
    ops = []
    seen = set()  # ids of visited nodes; safe because `ops` keeps those nodes alive
    stack = [] if grad_fn is None else [grad_fn]
    while stack:
        node = stack.pop()
        if id(node) in seen:
            continue
        seen.add(id(node))
        ops.append(node)
        children = [fn for fn, _ in getattr(node, 'next_functions', ()) if fn is not None]
        # Reverse so the first child is popped (visited) first.
        stack.extend(reversed(children))
    return ops

class CustomLayer(nn.Module):
    """Toy module that multiplies its input by one learnable scalar, initialised to 1.0."""

    def __init__(self):
        super().__init__()
        # nn.Parameter is trainable (requires_grad=True) by default.
        initial_scale = torch.tensor(1.0)
        self.param = nn.Parameter(initial_scale)

    def forward(self, x):
        scale = self.param
        return scale * x

# Tiny autograd graph: two scalar leaves, their sum, and the custom layer applied on top.
a, b = (torch.tensor(value, requires_grad=True) for value in (1., 2.))
c = torch.add(a, b)

layer = CustomLayer()
d = layer(c)

# `traverse_grad_fn(d.grad_fn)` lists the autograd nodes behind `d`.
type(d)


torch.Tensor

In [36]:
%load_ext autoreload
%autoreload 2

from utils.loss_utils import AudioLoss

pred_params = torch.randn(1, 64*44)

loss_fn = AudioLoss(scales=[512,1024], synth='Synplant2')

loss, stfts = loss_fn(pred_params, true_params)
# for stft in stfts:
#     print(f'Max: {stft[0].max()}, Min: {stft[0].min()}')
print(loss)

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload
Signal(4.9198, grad_fn=<AliasBackward0>)
