In [1]:
import inspect
import dualcodec
import torch
import os
import numpy as np
import torchaudio
import matplotlib.pyplot as plt
from utils import prepare_data
from IPython.display import Audio

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Load the DualCodec checkpoint identified by `model_id` and wrap it in the
# library's inference helper; every later cell goes through `dualcodec_inference`.
model_id = "12hz_v1"
dualcodec_model = dualcodec.get_model(model_id)
# NOTE(review): the device is hard-coded to "cuda"; presumably this cell fails on a CPU-only machine — confirm.
dualcodec_inference = dualcodec.Inference(dualcodec_model=dualcodec_model, device="cuda")

Fetching 7 files: 100%|██████████| 7/7 [00:00<00:00, 58603.05it/s]
  WeightNorm.apply(module, name, dim)


Loading model from /home/vansh/.cache/huggingface/hub/models--amphion--dualcodec/snapshots/b5d3158cbd1007441794398435438228f1e80c28/dualcodec_12hz_16384_4096.safetensors
Model loaded


Fetching 7 files: 100%|██████████| 7/7 [00:00<00:00, 69739.02it/s]
Fetching 6 files: 100%|██████████| 6/6 [00:00<00:00, 56552.41it/s]


In [None]:
# Keep a handle on the `dac` submodule of the wrapped model; later cells call
# `dac_model.decoder` directly.
dac_model = dualcodec_inference.model.dac

In [None]:
# Dump the source of the two decode entry points to see how codes are turned
# back into audio.
decode_source = inspect.getsource(dualcodec_inference.decode)
print(decode_source)
decode_from_codes_source = inspect.getsource(dualcodec_inference.model.decode_from_codes)
print(decode_from_codes_source)

In [None]:
# Collect the first 10 clips of the first shard as float tensors, together with
# their sampling rates.
ds = prepare_data(max_shards=1)
audios = []
sample_rates = []
for i in range(10):
    # Index the dataset once per item instead of twice.
    # NOTE(review): presumably every `ds[i]` access decodes the mp3 again — confirm.
    clip = ds[i]["mp3"]  # type: ignore[attr-defined]
    audio = torch.from_numpy(clip["array"]).float()
    sample_rates.append(clip["sampling_rate"])
    audios.append(audio)

In [None]:
# Load the reference clip and resample it to 24 kHz, the rate used for playback below.
sample, sr = torchaudio.load("tara.wav")
# Down-mix to a single channel: later cells flatten the tensor with
# reshape(1, 1, -1) and squeeze(0).squeeze(0), which is only correct for mono
# audio. For a mono file the mean over one channel leaves the samples unchanged.
sample = sample.mean(dim=0, keepdim=True)
sample = torchaudio.transforms.Resample(orig_freq=sr, new_freq=24000)(sample)
print(sample.shape)

In [None]:
# Encode the waveform into semantic and acoustic codes, requesting 8 quantizers.
# The waveform is flattened to shape (1, 1, num_samples) before encoding.
# NOTE(review): reshape(1, 1, -1) would place channels end-to-end if `sample` had more than one channel — confirm the input is mono.
semantic_code, acoustic_code = dualcodec_inference.encode(sample.reshape(1,1,-1), n_quantizers=8)

In [None]:
# Display the model's `convnext_decoder` submodule to inspect its structure.
dualcodec_inference.model.convnext_decoder

In [None]:
# Largest code index in the first semantic row and in acoustic row 5.
# Tensor.max() reduces in one call; the builtin max() would iterate the tensor
# element by element in Python.
print(semantic_code[0][0].max())
print(acoustic_code[0][5].max())

In [None]:
# Show the shapes of both code tensors, semantic first.
for codes in (semantic_code, acoustic_code):
    print(codes.shape)

In [None]:
# Feed the decoder one frame at a time (with a limited context window) and collect
# the output samples, to see whether the decoder still works when streamed.
def calculate_audio_with_receptive(semantic_code, acoustic_code, look_ahead, look_back, samples_per_code=1920):
    """Decode frame by frame with a bounded context window and stitch the result.

    For every frame index ``i`` the codes in ``[max(i - look_back, 0),
    min(i + look_ahead, num_codes))`` are decoded with
    ``dualcodec_inference.decode`` and only the samples belonging to frame ``i``
    are kept.

    Args:
        semantic_code: Code tensor whose last axis is the frame axis; the frame
            count is read from ``semantic_code[0][0]``.
        acoustic_code: Code tensor with the same number of frames.
        look_ahead: Exclusive end offset of the window. Because the end is
            exclusive, ``look_ahead=1`` means "current frame only" and
            ``look_ahead=k`` gives ``k - 1`` future frames.
        look_back: Number of past frames included in the window.
        samples_per_code: Number of output samples attributed to one frame.
            Defaults to 1920, the value that used to be hard-coded.

    Returns:
        A 1-D NumPy array with the kept samples of all frames, in order.

    Raises:
        ValueError: If ``look_ahead`` is smaller than 1; the window would then
            not contain the frame being decoded.
        AssertionError: If the two code tensors have different frame counts.
    """
    if look_ahead < 1:
        raise ValueError("look_ahead must be >= 1: the window end is exclusive and must cover the current frame")

    num_codes = len(semantic_code[0][0])
    assert num_codes == len(acoustic_code[0][0])

    # Seed with an empty float64 array so the result keeps the dtype the
    # incremental concatenation used to produce; concatenate once at the end
    # instead of re-copying the growing array on every frame.
    chunks = [np.array([])]

    for i in range(num_codes):
        start = max(i - look_back, 0)
        stop = min(i + look_ahead, num_codes)

        out_audio = dualcodec_inference.decode(
            semantic_code[:, :, start:stop],
            acoustic_code[:, :, start:stop],
        )
        out_audio = out_audio.squeeze(0).squeeze(0).cpu().numpy()

        # Position of frame i inside the decoded window, in samples.
        offset = (i - start) * samples_per_code
        chunks.append(out_audio[offset : offset + samples_per_code])

    return np.concatenate(chunks)

In [None]:
# original audio
# Print the shape first: a notebook only renders the value of a cell's last
# expression, so the Audio player must come last to be displayed.
print(sample.shape)
Audio(sample.squeeze(0).squeeze(0).cpu().numpy(), rate=24000)

In [None]:
# Decode the full code sequence in a single call; this is the non-streamed
# reference that the streamed variants are compared against.
# NOTE(review): after squeeze(0) the array is presumably (1, num_samples) — a later cell indexes it with [0]; confirm.
non_streamed_audio = dualcodec_inference.decode(semantic_code, acoustic_code).squeeze(0).cpu().numpy()
print(non_streamed_audio.shape)
Audio(non_streamed_audio, rate=24000)

In [None]:
# stream_1 = calculate_audio_with_receptive(semantic_code, acoustic_code, 10, 10)
# stream_2 = calculate_audio_with_receptive(semantic_code, acoustic_code, 20, 20)
# stream_3 = calculate_audio_with_receptive(semantic_code, acoustic_code, 30, 30)


In [None]:
# stream_4 = calculate_audio_with_receptive(semantic_code, acoustic_code, 40, 40)
# stream_5 = calculate_audio_with_receptive(semantic_code, acoustic_code, 50, 50)
# stream_6 = calculate_audio_with_receptive(semantic_code, acoustic_code, 60, 60)


In [None]:
# stream_7 = calculate_audio_with_receptive(semantic_code, acoustic_code, 70, 70)
# stream_8 = calculate_audio_with_receptive(semantic_code, acoustic_code, 80, 80)


In [None]:
# Streamed decode with look_ahead=5 and look_back=5.
stream_9 = calculate_audio_with_receptive(semantic_code, acoustic_code, 5, 5)


In [None]:
# Streamed decode with look_ahead=5 and a longer look_back of 25.
stream_8 = calculate_audio_with_receptive(semantic_code, acoustic_code, 5, 25)

In [None]:
# Overlay both streamed decodes; title and axis labels let the figure stand on
# its own when the notebook is skimmed.
fig, ax = plt.subplots()
ax.plot(stream_8, label="stream_8")
ax.plot(stream_9, label="stream_9")
ax.set(title="Streamed decodes", xlabel="sample index", ylabel="amplitude")
ax.legend()
plt.show()

In [None]:
# Listen to the streamed decode stored in stream_9 at 24 kHz.
Audio(stream_9, rate=24000)


In [None]:
# Listen to the streamed decode stored in stream_8 at 24 kHz.
Audio(stream_8, rate=24000)

In [None]:
# Streamed decodes to compare against the non-streamed reference, in plot order.
streams = [stream_8, stream_9]

In [None]:
# Plot the sample-wise difference between each streamed decode and the
# non-streamed reference, one figure per stream.
# The labels name the actual variables; an enumerate-based "stream_{j+1}" label
# would read stream_1 / stream_2 although the data are stream_8 / stream_9.
stream_labels = ("stream_8", "stream_9")  # must follow the order of `streams`
assert len(stream_labels) == len(streams), "stream_labels is out of sync with streams"
for stream_label, stream in zip(stream_labels, streams):
    differences = stream - non_streamed_audio[0]
    plt.plot(differences, label=stream_label)
    plt.title(f"{stream_label} minus non-streamed decode")
    plt.xlabel("sample index")
    plt.legend()
    plt.show()

### DAC decoder: streamed vs. non-streamed decoding

In [None]:
# Random input for the DAC decoder experiments, shape (1, 1024, 252).
# A dedicated seeded generator makes the tensor identical across kernel
# restarts without touching the global RNG state.
dac_rng = torch.Generator(device="cuda").manual_seed(0)
dac_inputs = torch.randn(1, 1024, 252, generator=dac_rng, device="cuda")

In [None]:
# Run the DAC decoder on the whole input at once and convert the result to a
# NumPy array; this is the non-streamed reference for the DAC experiments.
non_stream_dac = dac_model.decoder(dac_inputs).squeeze(0).squeeze(0).cpu().detach().numpy()

In [None]:
# Display the decoder_rates attribute of the DAC model.
# NOTE(review): presumably the per-stage upsampling factors whose product is the 1920 samples-per-frame used below — confirm.
dac_model.decoder_rates

In [None]:
def lol(inputs, model_now, look_ahead, look_back, space_cons):
    """Run ``model_now`` frame by frame with a bounded context window.

    For every frame index ``i`` the slice ``inputs[:, :, start:stop]`` with
    ``start = max(i - look_back, 0)`` and ``stop = min(i + look_ahead,
    num_codes)`` is passed through ``model_now``; only the ``space_cons``
    output samples that belong to frame ``i`` are kept.

    NOTE: a later cell defines another function with the same name and a
    different signature; after that cell has run, this definition is shadowed.

    Args:
        inputs: Tensor whose last axis is the frame axis; the frame count is
            read from ``inputs[0][0]``.
        model_now: Callable returning a tensor that becomes 1-D after two
            ``squeeze(0)`` calls.
        look_ahead: Exclusive end offset of the window; ``look_ahead=1`` means
            "current frame only".
        look_back: Number of past frames included in the window.
        space_cons: Number of output samples attributed to one input frame.

    Returns:
        A 1-D NumPy array with the kept samples of all frames, in order.

    Raises:
        ValueError: If ``look_ahead`` is smaller than 1; the window would then
            not contain the frame being processed.
    """
    if look_ahead < 1:
        raise ValueError("look_ahead must be >= 1: the window end is exclusive and must cover the current frame")

    num_codes = len(inputs[0][0])

    # Seed with an empty float64 array so the result keeps the dtype the
    # incremental concatenation used to produce; concatenate once at the end
    # instead of re-copying the growing array on every frame.
    chunks = [np.array([])]

    for i in range(num_codes):
        start = max(i - look_back, 0)
        stop = min(i + look_ahead, num_codes)

        # Inference only: skip building an autograd graph on every call.
        with torch.no_grad():
            out_audio = model_now(inputs[:, :, start:stop])
        out_audio = out_audio.squeeze(0).squeeze(0).cpu().detach().numpy()

        # Position of frame i inside the processed window, in samples.
        offset = (i - start) * space_cons
        chunks.append(out_audio[offset : offset + space_cons])

    return np.concatenate(chunks)

In [None]:
# Streamed DAC decodes with a fixed look_back of 55 frames and a growing look_ahead.
streamed_dac_audio = lol(dac_inputs, dac_model.decoder, look_ahead=65, look_back=55, space_cons=1920)
streamed_dac_audio_2 = lol(dac_inputs, dac_model.decoder, look_ahead=70, look_back=55, space_cons=1920)
streamed_dac_audio_3 = lol(dac_inputs, dac_model.decoder, look_ahead=75, look_back=55, space_cons=1920)


In [None]:
# Streamed DAC decode with the window skewed towards the past (75 back, 55 ahead).
streamed_dac_audio_4 = lol(dac_inputs, dac_model.decoder, look_ahead=55, look_back=75, space_cons=1920)

In [None]:
# Compare the length of the reference decode with the first three streamed decodes.
for dac_audio in (non_stream_dac, streamed_dac_audio, streamed_dac_audio_2, streamed_dac_audio_3):
    print(dac_audio.shape)

In [None]:
# Show the streamed and the non-streamed DAC output as two stacked panels with
# a shared x axis. Each panel gets its own legend; showing the first figure
# before calling legend() would leave its label undisplayed.
fig, (ax_streamed, ax_full) = plt.subplots(2, 1, sharex=True)
ax_streamed.plot(streamed_dac_audio_2, label="streamed")
ax_streamed.legend()
ax_full.plot(non_stream_dac, label="non-streamed")
ax_full.legend()
ax_full.set_xlabel("sample index")
plt.show()

In [None]:
# Error of the (look_ahead=55, look_back=75) streamed decode against the reference.
dac_stream_error = streamed_dac_audio_4 - non_stream_dac
plt.plot(dac_stream_error, label="diff")
# Red markers 4 samples before frame 75 and before frame 252 - 55.
# NOTE(review): presumably where the context window stops / starts being clipped at the sequence edges — confirm.
for marker_x in (1920 * 75 - 4, 1920 * (252 - 55) - 4):
    plt.vlines(x=marker_x, ymin=-0.001, ymax=0.001, color="red")
plt.legend()
plt.show()

In [None]:
# Grab the decoder's `model` container so individual layers can be indexed,
# and print the layer at index 1.
layers = dac_model.decoder.model
print(layers[1])

In [None]:
def lol(inputs, model_now, look_ahead, look_back, space_cons, num_frames):
    """Run ``model_now`` frame by frame and keep the channel dimension.

    For every frame index ``i`` below ``num_frames`` the slice
    ``inputs[:, :, start:stop]`` with ``start = max(i - look_back, 0)`` and
    ``stop = min(i + look_ahead, num_frames)`` is passed through
    ``model_now``; the ``space_cons`` output steps that belong to frame ``i``
    are written into the result.

    NOTE: this redefines an earlier function of the same name that has a
    different signature; cells written for the earlier version stop working
    once this cell has run.

    Args:
        inputs: Tensor indexed as ``[batch, channel, frame]``.
        model_now: Callable mapping a 3-D tensor to a 3-D tensor.
        look_ahead: Exclusive end offset of the window; ``look_ahead=1`` means
            "current frame only".
        look_back: Number of past frames included in the window.
        space_cons: Number of output steps ``model_now`` produces per input
            frame (1 for a length-preserving layer).
        num_frames: Number of input frames to process.

    Returns:
        A float64 NumPy array of shape
        ``(batch, output_channels, num_frames * space_cons)``.

    Raises:
        ValueError: If ``look_ahead`` is smaller than 1, or if ``model_now``
            does not produce ``space_cons`` steps for the current frame.
    """
    if look_ahead < 1:
        raise ValueError("look_ahead must be >= 1: the window end is exclusive and must cover the current frame")

    my_audio = None

    for i in range(num_frames):
        start = max(i - look_back, 0)
        stop = min(i + look_ahead, num_frames)

        # Inference only: skip building an autograd graph on every call.
        with torch.no_grad():
            out_audio = model_now(inputs[:, :, start:stop])
        out_audio = out_audio.cpu().detach().numpy()

        # Position of frame i inside the processed window, in output steps.
        offset = (i - start) * space_cons
        chunk = out_audio[:, :, offset : offset + space_cons]
        if chunk.shape[-1] != space_cons:
            raise ValueError(
                f"expected {space_cons} output steps for frame {i}, got {chunk.shape[-1]}; "
                "space_cons must equal the number of output steps model_now produces per input frame"
            )

        if my_audio is None:
            # Size the buffer from the model output: the layer may change the
            # channel count, and each frame contributes space_cons steps.
            my_audio = np.zeros((out_audio.shape[0], out_audio.shape[1], num_frames * space_cons))

        my_audio[:, :, i * space_cons : (i + 1) * space_cons] = chunk

    if my_audio is None:
        # No frames requested: return an empty buffer shaped after the input.
        my_audio = np.zeros((inputs.shape[0], inputs.shape[1], 0))

    return my_audio

In [None]:
# Compare the first decoder layer run on the full input with the same layer
# run frame by frame (current frame plus one frame of look-back).
dac_outputs_non_stream_1 = layers[0](dac_inputs)
# NOTE(review): space_cons=1920 is the samples-per-frame value used for the full decoder above;
# presumably a single layer has a different (possibly 1) output-steps-per-frame factor — confirm.
dac_outputs_stream_1 = lol(
    dac_inputs,
    layers[0],
    look_ahead=1,
    look_back=1,
    space_cons=1920,
    num_frames=dac_inputs.shape[2],
)

for layer_output in (dac_outputs_non_stream_1, dac_outputs_stream_1):
    print(layer_output.shape)

### Wav2Vec2-BERT feature extractor

In [5]:
# Despite its name, `new_model` holds the feature extractor from the semantic
# config (it prints as a SeamlessM4TFeatureExtractor), not a neural model.
new_model = dualcodec_inference.semantic_cfg.feature_extractor

In [13]:
# Print the feature extractor's configuration.
print(new_model)

SeamlessM4TFeatureExtractor {
  "feature_extractor_type": "SeamlessM4TFeatureExtractor",
  "feature_size": 80,
  "num_mel_bins": 80,
  "padding_side": "right",
  "padding_value": 1,
  "processor_class": "Wav2Vec2BertProcessor",
  "return_attention_mask": true,
  "sampling_rate": 16000,
  "stride": 2
}



In [8]:
# Two white-noise signals at 16 kHz that share the same first 5 seconds, used to
# check whether the features of the shared part depend on what follows it.
# A seeded generator keeps the signals identical across kernel restarts.
noise_rng = torch.Generator().manual_seed(0)
audio_5s = torch.randn(16000 * 5, generator=noise_rng).numpy()
# NOTE: despite its name, audio_7s is 10 s long (5 s + another 5 s = 160000 samples).
audio_7s = np.concatenate((audio_5s, torch.randn(16000 * 5, generator=noise_rng).numpy()))

print(audio_5s.shape)
print(audio_7s.shape)

o1 = new_model(audio_5s, sampling_rate=16000)["input_features"][0]
o2 = new_model(audio_7s, sampling_rate=16000)["input_features"][0]

(80000,)
(160000,)


In [9]:
# Feature matrix shapes for the short and the long signal.
for features in (o1, o2):
    print(features.shape)

(249, 160)
(499, 160)


In [12]:
# First 10 feature values of frame 0 for each signal, separated by a rule.
first_frame_short = o1[0][:10]
first_frame_long = o2[0][:10]
print(first_frame_short, "--------------------------------", first_frame_long, sep="\n")

[-1.5893896  -0.66483647 -0.4504154  -1.3175714  -0.68299586 -0.02689089
  0.27578333  0.03610433 -0.01030014 -0.20770043]
--------------------------------
[-1.4899344  -0.6980322  -0.4099009  -1.320732   -0.7294301  -0.05279103
  0.2647866  -0.01022197 -0.00367124 -0.17570677]
