In [10]:
import inspect
import dualcodec
import torch
import numpy as np
import torchaudio
import matplotlib.pyplot as plt
from utils import prepare_data
from IPython.display import Audio

In [None]:
# Fetch the DualCodec checkpoint named by `model_id` and wrap it in the
# library's inference helper, placed on the CUDA device.
model_id = "12hz_v1"
dualcodec_model = dualcodec.get_model(model_id)
dualcodec_inference = dualcodec.Inference(
    dualcodec_model=dualcodec_model,
    device="cuda",
)

In [57]:
# Print the source of the inference-level `decode` wrapper followed by the
# model-level `decode_from_codes` it delegates to.
for decode_fn in (
    dualcodec_inference.decode,
    dualcodec_inference.model.decode_from_codes,
):
    print(inspect.getsource(decode_fn))

    @torch.no_grad()
    def decode(self, semantic_codes, acoustic_codes):
        """
        Args:
        - semantic_codes: torch.Tensor, shape=(B, 1, T), dtype=torch.int, semantic codes
        - acoustic_codes: torch.Tensor, shape=(B, num_vq-1, T), dtype=torch.int, acoustic codes
        Returns:
        - audio: torch.Tensor, shape=(B, 1, T), dtype=torch.float32, output audio waveform
        """
        audio = self.model.decode_from_codes(semantic_codes, acoustic_codes).to(
            torch.float32
        )
        return audio

    @torch.no_grad()
    def decode_from_codes(self, semantic_codes, acoustic_codes):
        """both [B, n_q, T]"""
        semantic = self.semantic_vq.from_codes(semantic_codes)[0]
        if self.decode_semantic_for_codec:
            semantic = self.convnext_decoder(semantic)

        audio = self.dac.decode_from_codes(acoustic_codes, semantic)
        return audio



In [None]:
# Load the first NUM_CLIPS clips from the dataset as float32 tensors and keep
# each clip's sampling rate at the same list index as its waveform.
NUM_CLIPS = 10

ds = prepare_data(max_shards=2)
audios = []
sample_rates = []
for i in range(NUM_CLIPS):
    # Look the record up once: every `ds[i]` access may re-read/re-decode the
    # clip, so reusing it avoids doing that work twice per clip.
    clip = ds[i]["mp3"]  # type: ignore[attr-defined]
    audios.append(torch.from_numpy(clip["array"]).float())
    sample_rates.append(clip["sampling_rate"])

In [39]:
# Alternative input: load a local file and keep only its first channel.
# x = torchaudio.load("1.wav")
# sample = x[0][:1, :]

# Pick one clip and resample it to 24 kHz. The waveform and its sampling rate
# are read with the SAME index; previously the clip came from index 1 while the
# rate came from index 0, which resamples incorrectly whenever the two clips
# were recorded at different rates.
sample_idx = 1
sample = audios[sample_idx]
sample = torchaudio.transforms.Resample(
    orig_freq=sample_rates[sample_idx], new_freq=24000
)(sample)
print(sample.shape)

torch.Size([192744])


In [40]:
# Reshape the waveform to (1, 1, -1) before encoding it with 8 quantizers.
batched_sample = sample.reshape(1, 1, -1)
semantic_code, acoustic_code = dualcodec_inference.encode(
    batched_sample,
    n_quantizers=8,
)

In [60]:
# Show the first code stream of the first batch item: semantic, then acoustic.
for code_tensor in (semantic_code, acoustic_code):
    print(code_tensor[0][0])

# Hypothesis (not verified in this notebook): the semantic codes do not affect
# the receptive fields.

tensor([10803,  9329,  8499,  3535,  7573, 10698, 16039,   484, 10315,  8992,
         7978,  1962, 10158,   900, 12087,  5951,  4531,  3545, 11430,  9306,
         7911, 12127, 13828, 15167,  8000, 10370, 10785,  9003,  1565, 16233,
        11702,   866, 15143, 15510, 14554,  6484,  2010, 12264, 10199,  8863,
        16057,  3219,   841, 14554, 11374,  7851, 14000,  3190, 11117,  8961,
        11374, 14093, 15583,   199, 14766,  3750,  8009, 11349, 10258, 12202,
         6584,  9954, 15636,  2045, 11434, 11152, 10369, 14052, 15915,  1962,
        14034,  2728, 10215,   182,  5113, 10228,  3800,  5533,  8315, 15666,
        16081,  7087, 15722,  8251, 10186,  8063, 14160, 10390, 14082, 10540,
        10114, 15894, 16168, 10369, 11167,  2563,  1812, 10191, 12204, 16081],
       device='cuda:0')
tensor([2210,  553,  469, 1954, 2926,  751, 3592,  364, 1767, 2926, 3682, 3530,
        4070, 1632, 3225, 1035, 3596, 1632, 3971, 3187, 2572, 3065, 1080, 3337,
        1832, 3368, 1287, 2944,   7

In [None]:
# audio_5s = torch.randn(1920 * 25)
# audio_7s = torch.concatenate([audio_5s, torch.randn(1920 * 35)])

# audio_5s = audio_5s.reshape(1,1,-1)
# audio_7s = audio_7s.reshape(1,1,-1)

# print(audio_5s.shape)
# print(audio_7s.shape)

In [None]:
# semantic_codes_5s, acoustic_codes_5s = dualcodec_inference.encode(audio_5s, n_quantizers=8)
# semantic_codes_7s, acoustic_codes_7s = dualcodec_inference.encode(audio_7s, n_quantizers=8)

In [None]:
# print(semantic_codes_5s.shape)
# print(acoustic_codes_5s.shape)
# print(semantic_codes_7s.shape)
# print(acoustic_codes_7s.shape)

In [None]:
# # print all positions where the semantic codes do not match
# print(f"Semantic codes do not match at positions:")
# for i in range(len(semantic_codes_5s[0][0])):
#     if semantic_codes_5s[0][0][i] != semantic_codes_7s[0][0][i]:
#         print(i, end=" ")

# print("\n")

# # print all positions where the acoustic codes do not match
# print(f"Acoustic code 1 do not match at positions:")
# for i in range(len(acoustic_codes_5s[0][0])):
#     if acoustic_codes_5s[0][0][i] != acoustic_codes_7s[0][0][i]:
#         print(i, end=" ")
        
# print("\n")

# # print all positions where the acoustic codes do not match
# print(f"Acoustic code 2 do not match at positions:")
# for i in range(len(acoustic_codes_5s[0][1])):
#     if acoustic_codes_5s[0][1][i] != acoustic_codes_7s[0][1][i]:
#         print(i, end=" ")

In [52]:
# Give the decoder the codes one position at a time (with a window of context
# around each position) and stitch the per-position outputs together, to check
# whether windowed decoding reproduces the full-sequence decode.

# Number of output samples attributed to one code position by the slicing
# below. NOTE(review): presumably 24000 Hz / 12.5 codes per second -- confirm
# against the model's frame rate.
SAMPLES_PER_CODE = 1920

num_codes = len(semantic_code[0][0])
assert num_codes == len(acoustic_code[0][0]), "semantic/acoustic code lengths differ"

look_ahead = 11  # window end is exclusive: covers positions up to i + look_ahead - 1
look_back = 12   # window start is inclusive: covers positions from i - look_back

# Collect the chunks in a list and concatenate once at the end; concatenating
# onto a growing array inside the loop copies all previous audio every step.
# The empty float64 seed keeps the result's dtype, and the zero-code case,
# identical to incrementally concatenating onto np.array([]).
stream_chunks = [np.array([])]

for i in range(num_codes):
    left = max(i - look_back, 0)
    right = min(i + look_ahead, num_codes)

    window_audio = dualcodec_inference.decode(
        semantic_code[:, :, left:right],
        acoustic_code[:, :, left:right],
    )
    window_audio = window_audio.squeeze(0).squeeze(0).cpu().numpy()

    # Position i sits (i - left) codes into the decoded window; keep only the
    # samples belonging to that position. The slice simply comes out shorter
    # if the decoder returned fewer samples for the final position.
    start = SAMPLES_PER_CODE * (i - left)
    stream_chunks.append(window_audio[start : start + SAMPLES_PER_CODE])

my_audio = np.concatenate(stream_chunks)

In [46]:
# Print the shape first: only the LAST expression of a cell is rendered, so the
# Audio player has to come last or it is silently discarded.
print(sample.shape)
Audio(sample.squeeze(0).squeeze(0).cpu().numpy(), rate=24000)

torch.Size([192744])


In [51]:
# Report the length of the stitched (windowed) decode, then play it at 24 kHz.
print(my_audio.shape)
Audio(data=my_audio, rate=24000)

(191996,)


In [48]:
# Reference decode: feed the complete code sequences to the decoder in a single
# call, drop the leading batch axis and move the result to a NumPy array.
full_decode = dualcodec_inference.decode(semantic_code, acoustic_code)
outer_audio = full_decode.squeeze(0).cpu().numpy()
print(outer_audio.shape)
Audio(data=outer_audio, rate=24000)

(1, 191996)


In [None]:
# Overlay the windowed ("streamed") decode and the single-call decode.
# Both signals are flattened to 1-D first: `outer_audio` still carries a
# leading channel axis, and matplotlib draws every COLUMN of a 2-D input as its
# own line, which would produce one single-point line per sample.
fig, ax = plt.subplots()
ax.plot(np.ravel(my_audio), label="streamed")
ax.plot(np.ravel(outer_audio), label="non-streamed")
ax.set(
    title="Windowed vs. full-sequence decode",
    xlabel="sample index",
    ylabel="amplitude",
)
ax.legend()
plt.show()

In [None]:
import torchaudio
import os

# Write every loaded clip to the saved_audio/ directory as a WAV file, using
# the sampling rate stored alongside it. Requires `audios` and `sample_rates`
# from the data-loading cell.
os.makedirs("saved_audio", exist_ok=True)

for i, (audio, sr) in enumerate(zip(audios, sample_rates)):
    # 1-D clips get a leading axis added before being handed to torchaudio.save.
    waveform = audio.unsqueeze(0) if audio.dim() == 1 else audio

    filename = f"saved_audio/audio_{i:03d}.wav"
    torchaudio.save(filename, waveform, sr)
    print(f"Saved: {filename}")

In [None]:
# Print the source of the inference-level `decode` wrapper.
decode_source = inspect.getsource(dualcodec_inference.decode)
print(decode_source)

In [None]:
import torch
import torch.nn as nn

In [None]:

# Minimal ConvTranspose1d experiment: push a random tensor through one
# transposed-convolution layer and compare input and output shapes.

# Random input laid out as (batch, channels, length).
x = torch.randn(1, 64, 10)

# 64 -> 32 channels with kernel 3, stride 2, padding 1.
conv_transpose = nn.ConvTranspose1d(
    in_channels=64,
    out_channels=32,
    kernel_size=3,
    stride=2,
    padding=1,
)

# Forward pass through the layer.
output = conv_transpose(x)

print(output)
print(f"Input shape: {x.shape}")
print(f"Output shape: {output.shape}")
print("ConvTranspose1d done!")


In [None]:
# Print the source of ConvTranspose1d.forward as shipped with the installed torch.
forward_source = inspect.getsource(nn.ConvTranspose1d.forward)
print(forward_source)