In [None]:
import inspect
import dualcodec
import torch
import torchaudio
import matplotlib.pyplot as plt
from utils import prepare_data
from IPython.display import Audio

In [None]:
model_id = "12hz_v1"
dualcodec_model = dualcodec.get_model(model_id)
dualcodec_inference = dualcodec.Inference(dualcodec_model=dualcodec_model, device="cuda")

In [83]:
print(dualcodec_inference)

<dualcodec.infer.dualcodec.inference_with_semantic.Inference object at 0x7fdd76650370>


In [None]:
ds = prepare_data(max_shards=2)
audios = []
sample_rates = []
for i in range(10):                                          
    audio = torch.from_numpy(ds[i]["mp3"]["array"]).float()  # type: ignore[attr-defined] 
    sample_rates.append(ds[i]["mp3"]["sampling_rate"])       # type: ignore[attr-defined]
    audios.append(audio)

In [85]:
# x = torchaudio.load("1.wav")
# sample = x[0][:1, :]
sample = audios[0]
sample = torchaudio.transforms.Resample(orig_freq=sample_rates[0], new_freq=24000)(sample)
print(sample.shape)

torch.Size([150336])


In [86]:
semantic_code, acoustic_code = dualcodec_inference.encode(sample.reshape(1,1,-1), n_quantizers=8)

In [92]:
audio_5s = torch.randn(1920 * 25)
audio_7s = torch.concatenate([audio_5s, torch.randn(1920 * 35)])

audio_5s = audio_5s.reshape(1,1,-1)
audio_7s = audio_7s.reshape(1,1,-1)

print(audio_5s.shape)
print(audio_7s.shape)

torch.Size([1, 1, 48000])
torch.Size([1, 1, 115200])


In [93]:
semantic_codes_5s, acoustic_codes_5s = dualcodec_inference.encode(audio_5s, n_quantizers=8)
semantic_codes_7s, acoustic_codes_7s = dualcodec_inference.encode(audio_7s, n_quantizers=8)

In [94]:
print(semantic_codes_5s.shape)
print(acoustic_codes_5s.shape)
print(semantic_codes_7s.shape)
print(acoustic_codes_7s.shape)

torch.Size([1, 1, 24])
torch.Size([1, 7, 24])
torch.Size([1, 1, 59])
torch.Size([1, 7, 59])


In [95]:
# print all positions where the semantic codes do not match
print(f"Semantic codes do not match at positions:")
for i in range(len(semantic_codes_5s[0][0])):
    if semantic_codes_5s[0][0][i] != semantic_codes_7s[0][0][i]:
        print(i, end=" ")

print("\n")

# print all positions where the acoustic codes do not match
print(f"Acoustic code 1 do not match at positions:")
for i in range(len(acoustic_codes_5s[0][0])):
    if acoustic_codes_5s[0][0][i] != acoustic_codes_7s[0][0][i]:
        print(i, end=" ")
        
print("\n")

# print all positions where the acoustic codes do not match
print(f"Acoustic code 2 do not match at positions:")
for i in range(len(acoustic_codes_5s[0][1])):
    if acoustic_codes_5s[0][1][i] != acoustic_codes_7s[0][1][i]:
        print(i, end=" ")

Semantic codes do not match at positions:
5 6 7 8 9 11 14 15 16 18 20 21 22 23 

Acoustic code 1 do not match at positions:
6 7 18 20 22 23 

Acoustic code 2 do not match at positions:
4 8 11 15 16 20 22 

In [114]:
# Give the decoder codes 1-by-1 instead and collect the output samples to see if the decoder works
import numpy as np
my_audio = np.array([])

j = 10

for i in range(1, len(semantic_code[0][0])):
    
    sm = semantic_code[:, :, :i]
    ac = acoustic_code[:, :, :i]
    
    print(sm.shape)
    print(ac.shape)
    
    out_audio = dualcodec_inference.decode(sm, ac)
    out_audio = out_audio.squeeze(0).squeeze(0).cpu().numpy()
    
    l = len(out_audio)
    if i == 0:
        my_audio = out_audio[:1920 * (j-5)]
    else:
        my_audio = np.concatenate([my_audio, out_audio[l - (1920 * (j-5)) : l - (1920 * (j-6))]])
        
    print(my_audio.shape)
    

torch.Size([1, 1, 1])
torch.Size([1, 7, 1])
(0,)
torch.Size([1, 1, 2])
torch.Size([1, 7, 2])
(0,)
torch.Size([1, 1, 3])
torch.Size([1, 7, 3])
(1920,)
torch.Size([1, 1, 4])
torch.Size([1, 7, 4])
(3840,)
torch.Size([1, 1, 5])
torch.Size([1, 7, 5])
(3840,)
torch.Size([1, 1, 6])
torch.Size([1, 7, 6])
(5760,)
torch.Size([1, 1, 7])
torch.Size([1, 7, 7])
(7680,)
torch.Size([1, 1, 8])
torch.Size([1, 7, 8])
(9600,)
torch.Size([1, 1, 9])
torch.Size([1, 7, 9])
(11520,)
torch.Size([1, 1, 10])
torch.Size([1, 7, 10])
(13440,)
torch.Size([1, 1, 11])
torch.Size([1, 7, 11])
(15360,)
torch.Size([1, 1, 12])
torch.Size([1, 7, 12])
(17280,)
torch.Size([1, 1, 13])
torch.Size([1, 7, 13])
(19200,)
torch.Size([1, 1, 14])
torch.Size([1, 7, 14])
(21120,)
torch.Size([1, 1, 15])
torch.Size([1, 7, 15])


(23040,)
torch.Size([1, 1, 16])
torch.Size([1, 7, 16])
(24960,)
torch.Size([1, 1, 17])
torch.Size([1, 7, 17])
(26880,)
torch.Size([1, 1, 18])
torch.Size([1, 7, 18])
(28800,)
torch.Size([1, 1, 19])
torch.Size([1, 7, 19])
(30720,)
torch.Size([1, 1, 20])
torch.Size([1, 7, 20])
(32640,)
torch.Size([1, 1, 21])
torch.Size([1, 7, 21])
(34560,)
torch.Size([1, 1, 22])
torch.Size([1, 7, 22])
(36480,)
torch.Size([1, 1, 23])
torch.Size([1, 7, 23])
(38400,)
torch.Size([1, 1, 24])
torch.Size([1, 7, 24])
(40320,)
torch.Size([1, 1, 25])
torch.Size([1, 7, 25])
(42240,)
torch.Size([1, 1, 26])
torch.Size([1, 7, 26])
(44160,)
torch.Size([1, 1, 27])
torch.Size([1, 7, 27])
(46080,)
torch.Size([1, 1, 28])
torch.Size([1, 7, 28])
(48000,)
torch.Size([1, 1, 29])
torch.Size([1, 7, 29])
(49920,)
torch.Size([1, 1, 30])
torch.Size([1, 7, 30])
(51840,)
torch.Size([1, 1, 31])
torch.Size([1, 7, 31])
(53760,)
torch.Size([1, 1, 32])
torch.Size([1, 7, 32])
(55680,)
torch.Size([1, 1, 33])
torch.Size([1, 7, 33])
(57600,)
t

In [115]:
Audio(sample.squeeze(0).squeeze(0).cpu().numpy(), rate=24000)

In [116]:
print(my_audio.shape)
Audio(my_audio, rate=24000)

(142080,)


In [None]:
outer_audio = dualcodec_inference.decode(semantic_code, acoustic_code).squeeze(0).cpu().numpy()
print(outer_audio.shape)
Audio(outer_audio, rate=24000)

In [None]:
plt.plot(my_audio, label="streamed")
plt.plot(outer_audio, label="non-streamed")
plt.legend()
plt.show()

In [None]:
import torchaudio
import os

# After your existing code that creates audios and sample_rates
os.makedirs("saved_audio", exist_ok=True)

for i, (audio, sr) in enumerate(zip(audios, sample_rates)):
    if audio.dim() == 1:
        audio = audio.unsqueeze(0)
    
    filename = f"saved_audio/audio_{i:03d}.wav"
    torchaudio.save(filename, audio, sr)
    print(f"Saved: {filename}")

In [None]:
print(inspect.getsource(dualcodec_inference.decode))

In [109]:
import torch
import torch.nn as nn

In [None]:

x = torch.randn(1, 64, 10)  # (batch, channels, length)

# ConvTranspose1d layer
conv_transpose = nn.ConvTranspose1d(64, 32, kernel_size=3, stride=2, padding=1)

# Forward pass
output = conv_transpose(x)

print(output)

print(f"Input shape: {x.shape}")
print(f"Output shape: {output.shape}")
print("ConvTranspose1d done!")


In [None]:
print(inspect.getsource(nn.ConvTranspose1d.forward))