In [None]:
import os
import torch
import torchaudio
from pytorch_lightning import Trainer
from dataset import DATASET_MAPPING, DATASET_PATHS
from interface import INTERFACE_MAPPING
from model import Net, get_backbone, get_classifier
from sound2synth import Sound2SynthModel
from pathlib import Path
from pyheaven import *
from dataset.chains import SimpleSynth, Synplant2
from interface.torchsynth import REG_NCLASS
from torchsynth.config import SynthConfig


def load_checkpoint(checkpoint_path, args):
    interface = INTERFACE_MAPPING[args.synth]
    net = Net(
        backbone=get_backbone(args.backbone, args),
        classifier=get_classifier(args.classifier, interface, args),
    )
    model = Sound2SynthModel.load_from_checkpoint(checkpoint_path, net=net, interface=interface, args=args)
    return model

def save_wavs(audio, path, name):
    os.makedirs(path, exist_ok=True)
    filepath = os.path.join(path, f"{name}.wav")
    torchaudio.save(filepath, audio, 48000)
    
def parse_outputs(outputs, unnormalizer, logits, from_0to1):
    n_params = len(unnormalizer)
    # print(n_params)
    if logits:
        outputs = torch.stack(outputs.squeeze(0).chunk(n_params)).argmax(dim=1)/(REG_NCLASS-1)
    
    mydict = {}
    assert len(unnormalizer) == len(outputs), f"Length mismatch: {len(unnormalizer)} vs {len(outputs)}"
    for (k, f), v in zip(unnormalizer.items(), outputs):
        if 'keyboard' in k:
            continue # Skip keyboard parameters as they are initialized frozen
        mydict[k] = f.from_0to1(v.unsqueeze(0)) if from_0to1 else v.unsqueeze(0)
    return mydict

def dict_to_audio(mydict, synth):
    synth.set_parameters(mydict)
    # print(synth.keyboard.torchparameters.midi_f0)
    audio = synth.output()
    return audio

# def main():
if True:
    checkpoint_path = './checkpoints/SimpleSynth_audioLoss/last.ckpt'  # Replace with the actual path to the checkpoint
    output_dir = "output_wavs"
    
    args = MemberDict(dict(
        synth="TorchSynth",
        synth_class="SimpleSynth",
        backbone="multimodal",
        classifier="parameter",
        dataset_type="my_multimodal",
        dataset="SimpleSynth",
        feature_dim=2048,
    ))
    
    dataset = DATASET_MAPPING[args.dataset_type](dir='data/'+args.dataset, chain=args.synth_class, split='train')
    
    synth = dataset.synth
    unnormalizer = dataset.unnormalizer

    model = load_checkpoint(checkpoint_path, args)
    model.eval()

    dataloader = torch.utils.data.DataLoader(dataset, batch_size=1)
    
    with torch.no_grad():
        for idx, batch in enumerate(dataloader):
            if idx >= 5:  # Iterate through 5 samples
                break
            (features, sample_rate), label = batch
            outputs = model(features)
            true_params = label[0].squeeze(0)
            # print((true_params.max() <= 1).item())
            parsed_outputs = parse_outputs(outputs, unnormalizer, logits=True, from_0to1=True)
            parsed_true_params = parse_outputs(true_params, unnormalizer, logits=False, from_0to1=(true_params.max() <= 1).item())

            print(parsed_outputs)
            print(parsed_true_params)

            reconstructions = dict_to_audio(parsed_outputs, synth)
            ground_truth = dict_to_audio(parsed_true_params, synth)
            
            # Save ground truth and reconstructions as WAV files
            save_wavs(ground_truth, output_dir, f"ground_truth_{idx}")
            save_wavs(reconstructions, output_dir, f"reconstruction_{idx}")
            print(f"Saved sample {idx}")

# main()

In [14]:
%load_ext autoreload
%autoreload 2
from dataset.chains import SimpleSynth
from torchsynth.config import SynthConfig
import torch
import torch.nn.functional as F
from pyheaven import *
from dataset import DATASET_MAPPING, DATASET_PATHS
from interface.torchsynth import REG_NCLASS
from utils.loss_utils import AudioLoss

args = MemberDict(dict(
        synth="TorchSynth",
        synth_class="SimpleSynth",
        backbone="multimodal",
        classifier="parameter",
        dataset_type="my_multimodal",
        dataset="SimpleSynth",
        feature_dim=2048,
))

dataset = DATASET_MAPPING[args.dataset_type](dir='data/'+args.dataset, chain=args.synth_class, split='train')

unnormalizer = dataset.unnormalizer

class DifferentiableArgmax(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input):
        idx = input.argmax(dim=1)
        output = F.one_hot(idx, num_classes=input.size(1)).float()
        ctx.save_for_backward(input, output)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        input, output = ctx.saved_tensors
        grad_input = grad_output * output
        return grad_input

def parse_outputs(outputs, unnormalizer, logits, from_0to1):
    n_params = len(unnormalizer)
    # print(n_params)
    if logits:
        outputs = torch.stack(outputs.squeeze(0).chunk(n_params, dim=-1))
        # print("Grad function after stacking:", outputs.grad_fn)
        
        # Apply DifferentiableArgmax
        outputs_one_hot = DifferentiableArgmax.apply(outputs)
        # print("Grad function after DifferentiableArgmax:", outputs_one_hot.grad_fn)
        
        # Convert one-hot back to index and normalize
        outputs = outputs_one_hot.argmax(dim=1).float() / (REG_NCLASS - 1)
        # print("Grad function after conversion and normalization:", outputs.grad_fn)
        
        # Ensure outputs requires grad
        if not outputs.requires_grad:
            outputs = outputs.detach().requires_grad_()
        
        # print("Final grad function:", outputs.grad_fn)
    
    mydict = {}
    assert len(unnormalizer) == len(outputs), f"Length mismatch: {len(unnormalizer)} vs {len(outputs)}"
    for (k, f), v in zip(unnormalizer.items(), outputs):
        if 'keyboard' in k:
            continue # Skip keyboard parameters as they are initialized frozen
        mydict[k] = f.from_0to1(v.unsqueeze(0)) if from_0to1 else v.unsqueeze(0)
    return mydict

random_params = torch.rand((32, 704), requires_grad=True).to('cuda')
model = torch.nn.Linear(704, 704).to('cuda')
optim = torch.optim.Adam(model.parameters(), lr=1e-3)
loss_fn = AudioLoss(scales=[64, 128, 256, 512, 1024], synth='SimpleSynth')

outputs = torch.rand((32,704), requires_grad=True).to('cuda')
# params = parse_outputs(outputs, unnormalizer, logits=True, from_0to1=True)

# synth = SimpleSynth(SynthConfig(batch_size=1, sample_rate=48000, reproducible=False, no_grad=False))
# print('Params: ', params)
# synth.set_parameters(params)
# set_params = synth.get_parameters()
# for k, v in params.items():
#     print(v)

for i in range(1000):
    # output = synth.output()
    output = model(outputs)
    loss = loss_fn(output, random_params)
    loss.backward()
    optim.step()
    optim.zero_grad()
    
    if i % 1 == 0:
        print(f"Iteration {i}, Loss: {loss.item()}")


[autoreload of utils.loss_utils failed: Traceback (most recent call last):
  File "/home/ubuntu/anaconda3/envs/sound2synth/lib/python3.9/site-packages/IPython/extensions/autoreload.py", line 276, in check
    superreload(m, reload, self.old_objects)
  File "/home/ubuntu/anaconda3/envs/sound2synth/lib/python3.9/site-packages/IPython/extensions/autoreload.py", line 500, in superreload
    update_generic(old_obj, new_obj)
  File "/home/ubuntu/anaconda3/envs/sound2synth/lib/python3.9/site-packages/IPython/extensions/autoreload.py", line 397, in update_generic
    update(a, b)
  File "/home/ubuntu/anaconda3/envs/sound2synth/lib/python3.9/site-packages/IPython/extensions/autoreload.py", line 349, in update_class
    if update_generic(old_obj, new_obj):
  File "/home/ubuntu/anaconda3/envs/sound2synth/lib/python3.9/site-packages/IPython/extensions/autoreload.py", line 397, in update_generic
    update(a, b)
  File "/home/ubuntu/anaconda3/envs/sound2synth/lib/python3.9/site-packages/IPython/ext

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload
start  True
end  True
start  True
end  True
Iteration 0, Loss: 29.13892936706543
start  True
end  True
start  True
end  True
Iteration 1, Loss: 28.324331283569336
start  True
end  True
start  True
end  True
Iteration 2, Loss: 28.363880157470703
start  True
end  True
start  True
end  True
Iteration 3, Loss: 28.853649139404297
start  True
end  True
start  True
end  True
Iteration 4, Loss: 29.501190185546875
start  True
end  True
start  True
end  True


KeyboardInterrupt: 

In [11]:
def softargmax1d(input, beta=100):
    *_, n = input.shape
    input = torch.nn.functional.softmax(beta * input, dim=-1)
    indices = torch.linspace(0, 1, n)
    result = torch.sum((n - 1) * input * indices, dim=-1)
    return result

test = torch.randn((1, 10))
print(test)
print(softargmax1d(test))

tensor([[-0.6636,  0.2846, -1.4506,  1.2309, -0.2987,  0.5017,  1.1112,  0.5043,
          1.6322, -0.2264]])
tensor([8.])


In [3]:
true = torch.rand((1, 704), requires_grad=True).to('cuda')
outputs = torch.rand((1,704), requires_grad=True).to('cuda')
outputs = parse_outputs(outputs, unnormalizer, logits=True, from_0to1=True)
true = parse_outputs(true, unnormalizer, logits=True, from_0to1=True)

synth = SimpleSynth(SynthConfig(batch_size=1, sample_rate=48000, reproducible=False, no_grad=False))
# print('Params: ', params)
synth.set_parameters(outputs)
pred = synth.output()
synth.set_parameters(true)
true = synth.output()

import matplotlib.pyplot as plt
fig, ax = plt.subplots(2, 1, figsize=(10, 10))
ax[0].plot(pred[0].detach().cpu().numpy())
ax[1].plot(true[0].detach().cpu().numpy())

tensor([0.5079], device='cuda:0', grad_fn=<UnsqueezeBackward0>)


AttributeError: 'Tensor' object has no attribute 'data_with_grad'

In [10]:
import torch
import torch.nn as nn

def traverse_grad_fn(grad_fn):
    if grad_fn is None:
        return
    ops = [grad_fn]
    while grad_fn is not None:
        ops.append(grad_fn)
        if hasattr(grad_fn, 'next_functions'):
            for f in grad_fn.next_functions:
                if f[0] is not None:
                    ops.extend(traverse_grad_fn(f[0]))
        grad_fn = None
    return ops

# class CustomLayer(nn.Module):
#     def __init__(self):
#         super(CustomLayer, self).__init__()
#         self.param = nn.Parameter(torch.tensor(1.0, requires_grad=True))

#     def forward(self, x):
#         return self.param * x

# a = torch.tensor(1., requires_grad=True)
# b = torch.tensor(2., requires_grad=True)
# c = a + b

# layer = CustomLayer()
# d = layer(c)

# type(d)
operations = traverse_grad_fn(output.grad_fn)
print(len(operations))
for op in operations:
    print(op)


664
<AliasBackward0 object at 0x7f68ea1b6220>
<AliasBackward0 object at 0x7f68ea1b6220>
<MulBackward0 object at 0x7f68ea1b65e0>
<MulBackward0 object at 0x7f68ea1b65e0>
<AliasBackward0 object at 0x7f68ea1b6250>
<AliasBackward0 object at 0x7f68ea1b6250>
<MulBackward0 object at 0x7f68ea1b6b50>
<MulBackward0 object at 0x7f68ea1b6b50>
<MulBackward0 object at 0x7f68ea1b6af0>
<MulBackward0 object at 0x7f68ea1b6af0>
<RsubBackward1 object at 0x7f68ea1b6340>
<RsubBackward1 object at 0x7f68ea1b6340>
<DivBackward0 object at 0x7f68ea1b68b0>
<DivBackward0 object at 0x7f68ea1b68b0>
<UnsqueezeBackward0 object at 0x7f68ea1b6d60>
<UnsqueezeBackward0 object at 0x7f68ea1b6d60>
<AddBackward0 object at 0x7f68ea1b6880>
<AddBackward0 object at 0x7f68ea1b6880>
<MulBackward0 object at 0x7f68ea1b62e0>
<MulBackward0 object at 0x7f68ea1b62e0>
<AccumulateGrad object at 0x7f68ea1b6970>
<AccumulateGrad object at 0x7f68ea1b6970>
<TanhBackward0 object at 0x7f68ea1b63d0>
<TanhBackward0 object at 0x7f68ea1b63d0>
<DivBack

In [36]:
%load_ext autoreload
%autoreload 2

from utils.loss_utils import AudioLoss

pred_params = torch.randn(1, 64*44)

loss_fn = AudioLoss(scales=[512,1024], synth='Synplant2')

loss, stfts = loss_fn(pred_params, true_params)
# for stft in stfts:
#     print(f'Max: {stft[0].max()}, Min: {stft[0].min()}')
print(loss)

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload
Signal(4.9198, grad_fn=<AliasBackward0>)
