In [2]:
import os

import dualcodec
import torch
import numpy as np
import torchaudio
import matplotlib.pyplot as plotter
from IPython.display import Audio

#### Emilia dataset (EN shards)

In [None]:
from datasets import load_dataset

def prepare_data(
    load_from="/mnt/disks/emilia/emilia_dataset/Emilia/EN",
    max_shards=1000,
    num_proc=200,
    language="en",
    cache_dir="/mnt/disks/emilia/emilia_cache/",
):
    """Load the first ``max_shards`` ``.tar`` shards found in ``load_from`` as a HF dataset split.

    Shards are taken in sorted filename order, loaded with ``datasets.load_dataset`` under a
    split named ``language``, and every column except ``"mp3"`` and ``"json"`` is dropped.

    Args:
        load_from: directory containing the ``.tar`` shards.
        max_shards: how many shards (after sorting) to load; ``None`` loads all of them.
        num_proc: worker processes passed through to ``load_dataset``.
        language: split name used for both ``data_files`` and ``split``.
        cache_dir: cache directory passed through to ``load_dataset``.

    Returns:
        The loaded split containing only the ``"mp3"`` and ``"json"`` columns.

    Raises:
        FileNotFoundError: if ``load_from`` contains no ``.tar`` files.
        ValueError: if ``max_shards`` selects no shards (e.g. ``max_shards=0``).
    """
    tar_paths = sorted(name for name in os.listdir(load_from) if name.endswith(".tar"))
    if not tar_paths:
        # Fail fast with a clear message instead of an opaque load_dataset error.
        raise FileNotFoundError(f"No .tar shards found in {load_from!r}")

    selected_tar_paths = tar_paths[:max_shards]
    if not selected_tar_paths:
        raise ValueError(f"max_shards={max_shards!r} selects no shards out of {len(tar_paths)}")

    ds = load_dataset(
        load_from,
        data_files={language: selected_tar_paths},
        split=language,
        num_proc=num_proc,
        cache_dir=cache_dir,
    )

    keep_columns = {"mp3", "json"}
    return ds.remove_columns([c for c in ds.column_names if c not in keep_columns])  # type: ignore[attr-defined]

#### Streaming decode vs. full-sequence decode

In [None]:
# Model variant plus the two training checkpoints compared in this notebook.
model_id = "12hz_v1"

path_causal, path_look_ahead = (
    # fully causal variant
    "output_checkpoints_2/dualcodec_experiments_fully_causal/checkpoint/epoch-0000_step-0118000_loss-118.932007-dualcodec_experiments_fully_causal",
    # look-ahead variant
    "output_checkpoints_3/dualcodec_experiments_look_ahead/checkpoint/epoch-0000_step-0097000_loss-109.211716-dualcodec_experiments_look_ahead",
)

In [None]:
# Load the fully causal checkpoint and wrap it in an Inference helper on CUDA.
dualcodec_model = dualcodec.get_model(model_id, path_causal, is_checkpoint=True)
dualcodec_inference = dualcodec.Inference(
    dualcodec_model=dualcodec_model,
    device="cuda",
)

In [None]:
# DAC sub-module of the loaded DualCodec model.
# NOTE(review): the next cell rebinds `dac_model` to a newly constructed DAC, so this binding
# only survives if that cell is skipped.
dac_model = dualcodec_inference.model.dac

In [None]:
from dualcodec.model_codec.dac_model import DAC

# Stand-alone causal DAC (no look-ahead) built directly from its constructor; no checkpoint
# weights are loaded in this cell. Moved to CUDA in float64 (presumably so the causality
# diffs below are not dominated by float32 rounding) and switched to eval mode.
dac_model = (
    DAC(
        encoder_rates=[2, 4, 8, 8],
        latent_dim=1024,
        decoder_dim=1536,
        decoder_rates=[8, 8, 4, 2],
        n_codebooks=9,
        make_dac_causal=True,
        add_dac_look_ahead=False,
    )
    .to(device="cuda", dtype=torch.float64)
    .eval()
)

print(dac_model.decoder)

In [None]:
# ds = prepare_data(max_shards=1)
# audios = []
# sample_rates = []
# for i in range(10):                                          
#     audio = torch.from_numpy(ds[i]["mp3"]["array"]).float()  # type: ignore[attr-defined] 
#     sample_rates.append(ds[i]["mp3"]["sampling_rate"])       # type: ignore[attr-defined]
#     audios.append(audio)

In [None]:
# Load the reference clip, downmix to mono and resample to 24 kHz.
sample, sr = torchaudio.load("audio_samples/tara.wav")
# torchaudio returns (channels, num_frames). Later cells flatten with reshape(1, 1, -1), which
# would concatenate the channels of a stereo file end to end; averaging is exact for mono input.
sample = sample.mean(dim=0, keepdim=True)
sample = torchaudio.transforms.Resample(orig_freq=sr, new_freq=24000)(sample)
print(sample.shape)

In [None]:
# Encode the waveform as a single (1, 1, num_samples) batch, requesting 8 quantizers.
semantic_code, acoustic_code = dualcodec_inference.encode(
    sample.reshape(1, 1, -1),
    n_quantizers=8,
)

In [None]:
# Largest code value in semantic_code[0][0] and in row 5 of acoustic_code[0]
# (presumably acoustic codebook 5 — confirm).
# Tensor.max() reduces on-device; builtin max() would iterate the tensor element by element in Python.
print(semantic_code[0][0].max())
print(acoustic_code[0][5].max())

In [None]:
# Shapes of both code streams, one per line.
print(semantic_code.shape, acoustic_code.shape, sep="\n")

In [None]:
# Decode codes one at a time with a bounded context window and stitch the outputs, to check
# whether windowed (streamed) decoding reproduces the full-sequence decode.
def calculate_audio_with_receptive(semantic_code, acoustic_code, look_ahead, look_back,
                                   samples_per_code=1920, inference=None):
    """Decode the code sequence frame by frame with a limited receptive window and stitch the audio.

    For each code index ``i`` the decoder sees codes ``[i - look_back, i + look_ahead]``
    (clipped to the sequence). Only the ``samples_per_code`` output samples that belong to
    code ``i`` are kept.

    Args:
        semantic_code: semantic codes, sliced along dim 2 (time); assumed (batch=1, codebook, time).
        acoustic_code: acoustic codes with the same number of time steps as ``semantic_code``.
        look_ahead: number of future codes visible when decoding code ``i``.
        look_back: number of past codes visible when decoding code ``i``.
        samples_per_code: output samples produced per code (1920 by default).
        inference: object exposing ``decode(semantic, acoustic)``; defaults to the notebook-level
            ``dualcodec_inference``.

    Returns:
        1-D float64 numpy array made of one ``samples_per_code`` chunk per code. A chunk is
        shorter only if the window's decode returns fewer samples than expected.

    Raises:
        AssertionError: if the two code tensors have different lengths.
    """
    if inference is None:
        inference = dualcodec_inference  # looked up at call time so the notebook global can change
    num_codes = semantic_code.shape[-1]
    assert num_codes == acoustic_code.shape[-1], "semantic and acoustic codes must have the same length"

    chunks = []
    for i in range(num_codes):
        left = max(i - look_back, 0)
        # +1 so code i itself is always in the window; look_ahead counts *future* codes only.
        right = min(i + look_ahead + 1, num_codes)

        out_audio = inference.decode(semantic_code[:, :, left:right], acoustic_code[:, :, left:right])
        out_audio = out_audio.squeeze(0).squeeze(0).cpu().numpy()

        # Code i sits (i - left) codes into the window, so its samples start at this offset.
        offset = (i - left) * samples_per_code
        chunks.append(out_audio[offset:offset + samples_per_code])

    # Concatenate once instead of re-copying the growing array every iteration. The empty
    # float64 seed keeps the float64 result dtype and makes empty input return an empty array.
    return np.concatenate([np.empty(0), *chunks])

In [None]:
# original audio
# Original audio. The shape is printed first so that the Audio player is the cell's last
# expression; otherwise Jupyter never renders it.
print(sample.shape)
Audio(sample.squeeze(0).squeeze(0).cpu().numpy(), rate=24000)

In [None]:
# Reference: decode the whole code sequence in a single call (no streaming).
non_streamed_audio = (
    dualcodec_inference.decode(semantic_code, acoustic_code)
    .squeeze(0)
    .cpu()
    .numpy()
)
print(non_streamed_audio.shape)
Audio(non_streamed_audio, rate=24000)

In [None]:
# Streamed decodes with symmetric context windows of 10, 20 and 30 codes.
stream_1, stream_2, stream_3 = (
    calculate_audio_with_receptive(semantic_code, acoustic_code, window, window)
    for window in (10, 20, 30)
)


In [None]:
# Overlay two streamed decodes on the full-sequence decode. Legend labels match the
# stream_N variable names used elsewhere in this notebook.
fig, ax = plotter.subplots()
ax.plot(stream_1, label="stream_1")
ax.plot(stream_2, label="stream_2")
ax.plot(non_streamed_audio[0], label="non-stream")
ax.set(title="Streamed vs. non-streamed decode", xlabel="sample", ylabel="amplitude")
ax.legend()
plotter.show()

In [None]:
# All streamed variants, in window-size order, for the per-stream error plots.
streams = [stream_1, stream_2, stream_3]

In [None]:
# Error of each streamed decode against the full-sequence decode, one figure per stream.
reference = non_streamed_audio[0]
for j, stream in enumerate(streams):
    # Streamed and full decodes are not guaranteed to have equal length; subtracting arrays of
    # different lengths would raise a broadcast error, so compare only the overlapping samples.
    n = min(len(stream), len(reference))
    differences = stream[:n] - reference[:n]
    fig, ax = plotter.subplots()
    ax.plot(differences, label=f"stream_{j+1}")
    ax.set(title=f"stream_{j+1} minus non-streamed", xlabel="sample", ylabel="difference")
    ax.legend()
    plotter.show()

#### DAC decoder causality check

In [None]:
# Random latents for the DAC decoder experiments.
torch.manual_seed(0)  # fixed seed so the printed outputs below are reproducible across runs
k = 4  # number of latent frames in the base input
k_prime = 1  # presumably the number of extra frames appended for the causality comparison — confirm
# 1024 matches latent_dim in the DAC constructor above.
inputs_dac = torch.randn(1, 1024, k, device="cuda", dtype=torch.float64)

In [None]:
# The same latents plus `k_prime` extra random frames appended on the time axis (dim 2).
# Batch/channel sizes, device and dtype are taken from inputs_dac so the two stay consistent.
inputs_dac_longer = torch.cat(
    [
        inputs_dac,
        torch.randn(*inputs_dac.shape[:2], k_prime, device=inputs_dac.device, dtype=inputs_dac.dtype),
    ],
    dim=2,
)

In [None]:
# Causality check: decode the k-frame input and the same input with frames appended at the end.
# For a causal decoder, the samples produced for the shared prefix must be identical.
print(inputs_dac_longer.shape)
print(inputs_dac.shape)

with torch.no_grad():
    outputs = dac_model.decoder(inputs_dac)
    outputs_longer = dac_model.decoder(inputs_dac_longer)

print(outputs[0][0][:k])
print(outputs_longer[0][0][:k])

print("-----------")

print(outputs[0][0][:k] - outputs_longer[0][0][:k])

# The first k *samples* are only a tiny slice of the k-frame output. A look-ahead leak would
# most likely appear near the end of the shared prefix (next to the appended frames), so also
# compare every overlapping sample.
overlap = min(outputs.shape[-1], outputs_longer.shape[-1])
prefix_diff = (outputs[..., :overlap] - outputs_longer[..., :overlap]).abs()
print(f"max |diff| over the {overlap} shared samples: {prefix_diff.max().item():.3e}")


#### DAC decoder streaming

In [None]:
# Full-sequence decode of the random latents; reference for the streamed decode below.
# Inference only, so run under no_grad rather than building an autograd graph and detaching it.
with torch.no_grad():
    non_streamed_dac_outputs = dac_model.decoder(inputs_dac).squeeze(0).squeeze(0).cpu().numpy()

In [None]:
def dac_streamer(inputs, model, look_ahead, look_back, space_cons=1920, verbose=False):
    """Run ``model`` frame by frame over ``inputs`` with a bounded context window and stitch the audio.

    For each frame index ``i`` the model is called on frames ``[i - look_back, i + look_ahead]``
    (clipped to the sequence). Only the ``space_cons`` output samples that belong to frame ``i``
    are kept.

    Args:
        inputs: tensor sliced along dim 2 (frames); assumed (1, channels, frames).
        model: callable mapping a frame slice to audio that squeezes to 1-D; assumed to emit
            ``space_cons`` samples per input frame.
        look_ahead: number of future frames visible when decoding frame ``i``.
        look_back: number of past frames visible when decoding frame ``i``.
        space_cons: output samples per input frame (the model's hop size).
        verbose: print per-step sizes (debugging aid; off by default).

    Returns:
        1-D float64 numpy array with one ``space_cons`` chunk per input frame. A chunk is shorter
        only if the model's output is shorter than expected.
    """
    num_codes = inputs.shape[-1]
    chunks = []
    total = 0

    with torch.no_grad():  # inference only; no autograd graph needed
        for i in range(num_codes):
            left = max(i - look_back, 0)
            # +1 so frame i itself is always in the window; look_ahead counts *future* frames only.
            right = min(i + look_ahead + 1, num_codes)

            out_audio = model(inputs[:, :, left:right])
            out_audio = out_audio.squeeze(0).squeeze(0).cpu().detach().numpy()

            # Frame i sits (i - left) frames into the window, so its samples start at this offset.
            offset = (i - left) * space_cons
            chunk = out_audio[offset:offset + space_cons]
            chunks.append(chunk)
            total += len(chunk)

            if verbose:
                print(f"num samples generated: {out_audio.shape}")
                print(f"num samples added so far: {total} at loop {i}")

    # Single concatenate instead of re-copying per iteration. The empty float64 seed keeps the
    # float64 result dtype and makes empty input return an empty array.
    return np.concatenate([np.empty(0), *chunks])

In [None]:
# Samples produced per latent frame, measured from the full-sequence decode instead of relying
# on dac_streamer's default space_cons=1920. This DAC's decoder_rates multiply to 512, so the
# default presumably does not match its hop.
dac_hop = non_streamed_dac_outputs.shape[-1] // inputs_dac.shape[-1]
streamed_dac_outputs = dac_streamer(inputs_dac, dac_model.decoder, 65, 0, space_cons=dac_hop)

In [None]:
# Full vs. streamed DAC output lengths, one per line.
print(non_streamed_dac_outputs.shape, streamed_dac_outputs.shape, sep="\n")

In [None]:
# Streamed vs. full-sequence DAC decode, one labelled panel each (both legends shown).
fig, (ax_streamed, ax_full) = plotter.subplots(2, 1, sharex=True)
ax_streamed.plot(streamed_dac_outputs, label="streamed")
ax_full.plot(non_streamed_dac_outputs, label="non-streamed")
for ax in (ax_streamed, ax_full):
    ax.set(ylabel="amplitude")
    ax.legend()
ax_streamed.set(title="DAC decoder: streamed vs. non-streamed")
ax_full.set(xlabel="sample")
plotter.show()