In [21]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import time
import cProfile, pstats, io

# One WaveNet-style block with causal, dilated convolution
class NSynthWaveNetBlock(nn.Module):
    """Gated residual unit built on a causal, dilated 1-D convolution.

    Two parallel kernel-size-2 dilated convolutions (a tanh "filter" and a
    sigmoid "gate") are multiplied element-wise; the product then feeds two
    1x1 convolutions that produce the residual and the skip contributions.

    Args:
        channels: Channel count used for the input and for both outputs.
        dilation: Dilation of the filter/gate convolutions; also the number
            of zero steps padded on the left of the time axis.
    """

    def __init__(self, channels, dilation):
        super().__init__()
        dilated_kwargs = {"kernel_size": 2, "dilation": dilation}
        self.conv_filter = nn.Conv1d(channels, channels, **dilated_kwargs)
        self.conv_gate = nn.Conv1d(channels, channels, **dilated_kwargs)
        self.residual = nn.Conv1d(channels, channels, kernel_size=1)
        self.skip = nn.Conv1d(channels, channels, kernel_size=1)
        self.dilation = dilation

    def forward(self, x):
        """Apply the block to ``x`` and return ``(residual_output, skip)``.

        ``x`` is padded on the left only, so every output step is computed
        from the current and earlier input steps.
        """
        # Left-only padding keeps the convolution causal.
        causal_in = F.pad(x, (self.dilation, 0))

        filter_act = self.conv_filter(causal_in).tanh()
        gate_act = self.conv_gate(causal_in).sigmoid()
        gated = filter_act * gate_act

        residual_out = self.residual(gated)
        skip_out = self.skip(gated)

        # Align the input with the residual branch along the time axis
        # before adding them together.
        num_steps = residual_out.shape[-1]
        return x[..., -num_steps:] + residual_out, skip_out

# Full NSynth-style decoder
class NSynthDecoder(nn.Module):
    """Stack of WaveNet-style blocks with summed skip connections.

    The input is projected to ``hidden_channels`` with a 1x1 convolution,
    passed through ``num_layers`` ``NSynthWaveNetBlock`` modules whose
    dilations cycle through powers of two, and the sum of all skip outputs
    is mapped back to ``in_channels`` by a small ReLU / 1x1-conv head.

    Args:
        in_channels: Channel count of the input and of the returned tensor.
        hidden_channels: Channel count used inside the block stack.
        num_layers: Number of blocks; must be at least 1.
        dilation_cycle: Number of layers after which the dilation restarts
            at 1. Layer ``i`` uses dilation ``2 ** (i % dilation_cycle)``, so
            the defaults give ``[1, 2, ..., 512]`` repeated three times.

    Attributes:
        receptive_field: ``sum(dilations) + 1``, the number of input time
            steps that can influence one output step when every block uses
            a kernel of size 2.

    Raises:
        ValueError: If ``num_layers`` or ``dilation_cycle`` is less than 1.
    """

    def __init__(self, in_channels=256, hidden_channels=128, num_layers=30,
                 dilation_cycle=10):
        super().__init__()
        # Without any block there is no skip output to feed the output head,
        # so reject the configuration here rather than failing in forward().
        if num_layers < 1:
            raise ValueError(f"num_layers must be >= 1, got {num_layers}")
        if dilation_cycle < 1:
            raise ValueError(
                f"dilation_cycle must be >= 1, got {dilation_cycle}"
            )

        self.input_proj = nn.Conv1d(in_channels, hidden_channels, 1)
        dilations = [2 ** (i % dilation_cycle) for i in range(num_layers)]
        self.blocks = nn.ModuleList([
            NSynthWaveNetBlock(hidden_channels, d) for d in dilations
        ])
        self.output_proj = nn.Sequential(
            nn.ReLU(),
            nn.Conv1d(hidden_channels, hidden_channels, 1),
            nn.ReLU(),
            nn.Conv1d(hidden_channels, in_channels, 1),
        )
        self.receptive_field = sum(dilations) + 1

    def forward(self, x):
        """Run the decoder on ``x`` and return the projected skip sum.

        Args:
            x: Tensor with ``in_channels`` channels in the layout expected
                by ``nn.Conv1d``, e.g. ``(batch, in_channels, time)``.

        Returns:
            Tensor with ``in_channels`` channels produced by ``output_proj``
            from the sum of every block's skip output.
        """
        x = self.input_proj(x)
        skip_total = None
        for block in self.blocks:
            x, skip = block(x)
            # The first skip seeds the accumulator; later ones are added.
            skip_total = skip if skip_total is None else skip_total + skip
        return self.output_proj(skip_total)

# -----------------------------
# Run a single forward pass with profiling
# -----------------------------
# Dummy input: batch=1, time=1024, one-hot encoded to 256 channels
x = F.one_hot(torch.randint(0, 256, (1, 1024)), num_classes=256).float()
x = x.permute(0, 2, 1)  # (B, T, C) -> (B, C, T)

model = NSynthDecoder()

# Profile the forward pass. The context manager enables the profiler on entry
# and always disables it on exit, so a failing forward pass cannot leave the
# profiler running for later cells.
with cProfile.Profile() as pr:
    output = model(x)

# Render the statistics into an in-memory buffer, then print them.
s = io.StringIO()
ps = pstats.Stats(pr, stream=s).sort_stats('cumulative')
ps.print_stats(20)  # Top 20 entries by cumulative time
print(s.getvalue())


         1555 function calls (1243 primitive calls) in 0.215 seconds

   Ordered by: cumulative time
   List reduced from 51 to 20 due to restriction <20>

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
        2    0.000    0.000    0.215    0.107 /usr/local/lib/python3.11/dist-packages/IPython/core/interactiveshell.py:3512(run_code)
        2    0.000    0.000    0.215    0.107 {built-in method builtins.exec}
        1    0.000    0.000    0.215    0.215 <ipython-input-21-02187589da75>:1(<cell line: 0>)
    157/1    0.001    0.000    0.215    0.215 /usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py:1735(_wrapped_call_impl)
    157/1    0.002    0.000    0.215    0.215 /usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py:1743(_call_impl)
        1    0.005    0.005    0.215    0.215 <ipython-input-21-02187589da75>:45(forward)
       30    0.015    0.001    0.203    0.007 <ipython-input-21-02187589da75>:17(forward)
      123    0.0