In [None]:
import dualcodec
import torchaudio
from IPython.display import Audio

In [None]:
# Model identifier handed to dualcodec.get_model as its first argument.
model_id = "12hz_v1"
# Checkpoint location handed to dualcodec.get_model as its second argument.
# NOTE(review): hard-coded path relative to the notebook's working directory;
# it points at one specific training step — adjust when using another run.
dualcodec_model_path = "./output_checkpoints/dualcodec_12hz_16384_4096_8vq_scratch/checkpoint/epoch-0000_step-0013800_loss-89.047997-dualcodec_12hz_16384_4096_8vq_scratch"
# Build the model object that the inference wrapper in the next cell consumes.
dualcodec_model = dualcodec.get_model(model_id, dualcodec_model_path)

In [4]:
# Wrap the loaded model in the inference helper used for encode()/decode() below;
# device="cpu" requests CPU execution.
dualcodec_inference = dualcodec.Inference(dualcodec_model=dualcodec_model, device="cpu")

In [6]:
# Load the sample clip; torchaudio.load returns a (channels, samples) tensor
# together with the file's sample rate.
audio, sr = torchaudio.load("audio_samples/tara.wav")
# Downmix to mono before reshaping: reshape(1, 1, -1) on a multi-channel tensor
# would lay the channels end to end (doubling the duration for stereo) instead
# of mixing them. For a mono file this is a no-op.
audio = audio.mean(dim=0, keepdim=True)
# Resample from the file's rate to the 24 kHz rate used throughout this notebook.
audio = torchaudio.functional.resample(audio, sr, 24000)
# Shape the waveform as (1, 1, samples) and keep it on the CPU, matching the
# device the inference wrapper was created with.
audio = audio.reshape(1, 1, -1)
audio = audio.cpu()
# extract codes, for example, using 8 quantizers here:
semantic_codes, acoustic_codes = dualcodec_inference.encode(audio, n_quantizers=8)

In [7]:
# Decode the complete code sequences back into a waveform in a single call.
out_audio = dualcodec_inference.decode(semantic_codes, acoustic_codes)
# Play the reconstruction inline. squeeze(0) drops the leading dimension, and
# rate=24000 matches the resample target used when the input clip was loaded.
Audio(out_audio.cpu().squeeze(0), rate=24000)

In [None]:
# Width (number of positions along dim 2 of the code tensors) of each sliding
# window that is decoded on its own in the first pass below.
group_size = 4
# Passed as `std_ratio` to create_gaussian_window: the Gaussian's standard
# deviation expressed as a fraction of each decoded chunk's length (not an
# absolute number of samples). Smaller values weight the chunk centre more.
gaussian_std = 0.1  # Controllable parameter for Gaussian standard deviation

import time
import torch
import numpy as np

# Step 1: trim each code tensor along dim 2 down to the largest length that is
# a whole multiple of group_size (the remainder at the end is discarded).
semantic_keep = semantic_codes.shape[2] - semantic_codes.shape[2] % group_size
acoustic_keep = acoustic_codes.shape[2] - acoustic_codes.shape[2] % group_size

semantic_codes_cropped = semantic_codes[:, :, :semantic_keep]
acoustic_codes_cropped = acoustic_codes[:, :, :acoustic_keep]

# Create Gaussian window
def create_gaussian_window(length, std_ratio=0.5):
    """Return a 1-D Gaussian weighting window of ``length`` samples.

    The window peaks (value 1.0) at the midpoint index ``(length - 1) / 2`` and
    is symmetric around it. Its standard deviation is ``length * std_ratio``,
    i.e. ``std_ratio`` is expressed as a fraction of the window length.

    Args:
        length: Number of samples in the window. ``0`` yields an empty tensor.
        std_ratio: Standard deviation as a fraction of ``length``; must be > 0.

    Returns:
        A float32 ``torch.Tensor`` of shape ``(length,)`` on the default device.

    Raises:
        ValueError: If ``std_ratio`` is not strictly positive. A zero ratio
            would divide by zero and silently produce NaN/zero weights.
    """
    # `not (x > 0)` also rejects NaN, which `x <= 0` would let through.
    if not std_ratio > 0:
        raise ValueError(f"std_ratio must be > 0, got {std_ratio!r}")

    x = torch.arange(length, dtype=torch.float32)
    center = (length - 1) / 2
    std = length * std_ratio  # std as a ratio of window length
    window = torch.exp(-0.5 * ((x - center) / std) ** 2)
    return window

# First pass: decode every stride-1 sliding window of `group_size` codes and
# keep the decoded chunks so their actual sample counts can be inspected later.
#
# The window count is derived from the *cropped* tensors, because those are the
# tensors being sliced. Using the uncropped length here would run past the end
# of the cropped data and feed the decoder truncated (shorter than group_size)
# windows whenever the original length is not a multiple of group_size.
all_decoded = []
num_windows = max(acoustic_codes_cropped.shape[2] - group_size + 1, 0)
print("First pass: decoding all segments...")
for i in range(num_windows):
    semantic_codes_segment = semantic_codes_cropped[:, :, i:i + group_size]
    acoustic_codes_segment = acoustic_codes_cropped[:, :, i:i + group_size]

    # Time each decode call so per-window latency is visible in the log.
    start_time = time.monotonic()
    out_audio = dualcodec_inference.decode(semantic_codes_segment, acoustic_codes_segment)
    end_time = time.monotonic()
    print(f"Window {i}/{num_windows - 1}, decode time: {end_time-start_time:.3f}s, shape: {out_audio.shape}")

    all_decoded.append(out_audio)

# Second pass: overlap-add the decoded chunks into one waveform.
# Each chunk is weighted by a Gaussian window, summed at an offset of
# i * samples_per_step, and the sum is divided by the accumulated weights so
# overlapping regions keep a consistent amplitude.
# Assumes consecutive windows are shifted by exactly one code step.
if len(all_decoded) > 0:
    # Samples produced per code step, estimated from the first decoded chunk.
    first_chunk_samples = all_decoded[0].shape[2]
    samples_per_step = first_chunk_samples // group_size

    # First chunk at full size, every later chunk adds one more step.
    total_output_length = first_chunk_samples + (len(all_decoded) - 1) * samples_per_step

    # Initialize output tensor and weight accumulator
    device = all_decoded[0].device
    output_audio = torch.zeros(1, 1, total_output_length, device=device)
    weight_sum = torch.zeros(1, 1, total_output_length, device=device)

    # Chunks normally share one length, so build each distinct window once
    # instead of recomputing it on every iteration.
    window_cache = {}

    print("\nSecond pass: applying crossfade...")
    for i, out_audio in enumerate(all_decoded):
        current_samples = out_audio.shape[2]

        if current_samples not in window_cache:
            window_cache[current_samples] = (
                create_gaussian_window(current_samples, std_ratio=gaussian_std)
                .to(device)
                .unsqueeze(0)
                .unsqueeze(0)
            )
        gaussian_window = window_cache[current_samples]

        # Apply Gaussian window to the decoded audio
        windowed_audio = out_audio * gaussian_window

        # Calculate position in output
        output_start = i * samples_per_step
        output_end = min(output_start + current_samples, total_output_length)
        actual_samples = output_end - output_start

        # Clip the chunk (and its weights) if it would run past the buffer end.
        # Slicing yields views, so the cached window itself is never modified.
        if actual_samples < current_samples:
            windowed_audio = windowed_audio[:, :, :actual_samples]
            gaussian_window = gaussian_window[:, :, :actual_samples]

        # Add to output with overlap
        output_audio[:, :, output_start:output_end] += windowed_audio
        weight_sum[:, :, output_start:output_end] += gaussian_window

        print(f"Window {i}: position {output_start}-{output_end}, chunk size: {current_samples}")

    # Normalize by the sum of weights to maintain consistent amplitude.
    # The small epsilon avoids division by zero where no chunk contributed.
    output_audio = output_audio / (weight_sum + 1e-8)

    # Trim only samples that no window touched (weight exactly zero). A larger
    # threshold would also cut real audio under the low-weight tail of the last
    # Gaussian window. reshape(-1) keeps this 1-D even for a length-1 output.
    covered_indices = torch.nonzero(weight_sum.reshape(-1) > 0)
    if len(covered_indices) > 0:
        last_valid_idx = covered_indices[-1].item() + 1
        output_audio = output_audio[:, :, :last_valid_idx]

    # Convert to numpy and display; detach() is a no-op unless the decoder
    # output carries autograd history, in which case .numpy() would raise.
    display(Audio(output_audio.detach().cpu().squeeze(0).squeeze(0).numpy(), rate=24000))
else:
    print("No audio decoded!")