In [2]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from torch.nn.utils import weight_norm

import cached_conv as cc

cc.use_cached_conv(True)


def WNConv1d(*args, **kwargs):
    return weight_norm(nn.Conv1d(*args, **kwargs))

def WNConv1dCached(*args, **kwargs):
    return weight_norm(cc.Conv1d(*args, **kwargs))

wn_conv1d = WNConv1d(5, 10, 7, padding=3)
wn_conv1d_cached = WNConv1dCached(5, 10, 7, padding=3)

torch.manual_seed(0)

weight_g_data = torch.randn(10, 1, 1)
weight_v_data = torch.randn(10, 5, 7)
bias_data = torch.randn(10)
wn_conv1d.weight_g.data = weight_g_data
wn_conv1d.weight_v.data = weight_v_data
wn_conv1d.bias.data = bias_data

data = torch.randn(20, 5, 1000)
out = wn_conv1d(data)

wn_conv1d_cached.weight_g.data = weight_g_data
wn_conv1d_cached.weight_v.data = weight_v_data
wn_conv1d_cached.bias.data = bias_data

out_cached = wn_conv1d_cached(data)

print(torch.allclose(out[..., :-wn_conv1d_cached.cumulative_delay], out_cached[..., wn_conv1d_cached.cumulative_delay:], atol=1e-6))
print(out[..., :-wn_conv1d_cached.cumulative_delay])
print(out_cached[..., wn_conv1d_cached.cumulative_delay:])

True
tensor([[[ 0.2391,  0.2759,  0.8915,  ...,  2.2475,  1.3730,  2.1704],
         [ 0.5690,  0.3283,  0.4935,  ...,  0.8048,  0.7730,  0.4228],
         [-2.4467,  0.4928, -3.4065,  ..., -1.7856, -1.5958, -0.1326],
         ...,
         [-2.7717, -0.6683, -1.0477,  ..., -0.8070, -1.1878, -0.1473],
         [ 0.9553,  0.9808,  0.7493,  ...,  0.8611,  1.7939,  1.8238],
         [ 0.8837,  0.9023,  1.0396,  ...,  1.5872,  0.4577,  1.1960]],

        [[ 0.7317,  2.8882,  1.5236,  ...,  3.4878,  4.8952,  0.2520],
         [ 0.4875,  0.8073,  0.5325,  ...,  0.0925, -0.2831,  0.4675],
         [-1.6555, -2.3700, -6.8651,  ..., -0.8240, -1.1611,  0.9850],
         ...,
         [-2.3175,  0.3974, -1.1541,  ..., -2.0387, -2.5608, -2.1537],
         [ 0.1785,  0.1062,  0.9603,  ...,  1.4923,  0.2642,  1.1658],
         [ 0.6244,  0.9748,  0.4891,  ...,  1.1547,  2.0384,  0.3623]],

        [[ 1.5471,  2.3142, -0.0321,  ...,  3.3907,  1.0151,  1.6523],
         [ 0.4586,  0.5339,  0.6181,  ..

In [12]:
cc.use_cached_conv(True)

wn_conv1d = WNConv1d(1, 1, 3, padding=1)
wn_conv1d_cached = WNConv1dCached(1, 1, 3, padding=cc.get_padding(3))

torch.manual_seed(0)

data = torch.randn(1, 1, 6)

weight_g_data = torch.randn(1, 1, 1)
weight_v_data = torch.randn(1, 1, 3)
bias_data = torch.randn(1)
wn_conv1d.weight_g.data = weight_g_data
wn_conv1d.weight_v.data = weight_v_data
wn_conv1d.bias.data = bias_data

wn_conv1d_cached.weight_g.data = weight_g_data
wn_conv1d_cached.weight_v.data = weight_v_data
wn_conv1d_cached.bias.data = bias_data

chunk_size = 3
res = []
for i in range(0, data.shape[-1], chunk_size):
    # Slice along the sequence dimension and apply conv1d to each chunk
    tok = wn_conv1d_cached(data[..., i:i + chunk_size])
    res.append(tok)

# Concatenate all the chunks along the sequence dimension
chunked_output = torch.cat(res, dim=-1)

# Direct (non-chunked) computation for comparison
non_chunked_output = wn_conv1d(data)

print(chunked_output.shape, non_chunked_output.shape)
print(chunked_output)
print(non_chunked_output)

wn_conv1d_cached.cumulative_delay

torch.Size([1, 1, 6]) torch.Size([1, 1, 6])
tensor([[[-0.8099, -0.9363,  0.2203, -0.2220, -1.2132,  0.0279]]],
       grad_fn=<CatBackward0>)
tensor([[[-0.9363,  0.2203, -0.2220, -1.2132,  0.0279, -0.5633]]],
       grad_fn=<ConvolutionBackward0>)


  WeightNorm.apply(module, name, dim)


1

In [1]:
import dac

# Monkey patching the DAC class to use cc.Conv1d instead of nn.Conv1d

# Download a model
model_path = dac.utils.download(model_type="44khz")

import numpy as np
import torch
import time

dac.DAC.enable_streaming(True)
model = dac.DAC.load(model_path).to("cpu")
delay = model.encoder_cumulative_delay

torch.set_printoptions(precision=5, sci_mode=False)

# set numpy random seed
np.random.seed(0)

# Load audio signal file
silence = np.random.randn(*(1, 1, 51200)).astype(np.float32)
data = torch.tensor(silence).to("cpu")

res = []

for i in range(0, data.shape[-1], 5120):
    # Slice along the sequence dimension and apply conv1d to each chunk
    tok = model.encode(data[..., i:i + 5120])
    res.append(tok)

# Concatenate all the chunks along the sequence dimension
out = torch.cat(res, dim=-1)

out = model.encode(data)
print(model.encoder_cumulative_delay)
out = out[..., delay:]
print(out.shape)
print(out[..., delay:-delay])

print("=============================")
print("=============================")


dac.DAC.enable_streaming(False)
model = dac.DAC.load(model_path).to("cpu")

# #print all model parameters
# for name, param in model.named_parameters():
#     print(name, param.shape)

# set numpy random seed
np.random.seed(0)

# Load audio signal file
silence = np.random.randn(*(1, 1, 51200)).astype(np.float32)
data = torch.tensor(silence).to("cpu")

out = model.encode(data)
out = out[..., :-delay]
print(model.encoder_cumulative_delay)
print(out.shape)
print(out[..., delay:-delay])

  model_dict = torch.load(location, "cpu")
  WeightNorm.apply(module, name, dim)


8
torch.Size([1, 1024, 92])
tensor([[[  8.98867, -14.47900, -52.20496,  ...,   5.15977, -28.10490,
           21.93532],
         [ -4.08285,  84.29540, 104.61837,  ...,  25.47849,  59.95752,
            1.54037],
         [-40.37833, -17.11209, -24.56766,  ...,  23.47040, -51.99257,
          -25.59319],
         ...,
         [-29.15557,  18.42965,  50.19261,  ..., -16.10726,  -0.41827,
          -19.61252],
         [ 43.76177, -21.42388,   0.79145,  ...,  -2.60167,  12.37717,
           71.11745],
         [-40.73623,  23.79416,  74.62721,  ..., -23.58209,  28.23978,
          -25.59929]]], grad_fn=<SliceBackward0>)
0
torch.Size([1, 1024, 92])
tensor([[[  8.98873, -14.47901, -52.20497,  ...,   5.15975, -28.10491,
           21.93531],
         [ -4.08285,  84.29539, 104.61831,  ...,  25.47848,  59.95754,
            1.54039],
         [-40.37836, -17.11202, -24.56760,  ...,  23.47041, -51.99262,
          -25.59317],
         ...,
         [-29.15559,  18.42960,  50.19261,  ..., -1