In [1]:
import torch
from transformers import EncodecModel, AutoProcessor, EncodecConfig

# Subclass EncodecModel to modify the forward method
class CustomEncodecModel(EncodecModel):
    """EncodecModel variant whose forward pass runs the encoder only.

    ``forward`` delegates to ``self.encode`` and returns just the audio
    codes, so exporting this module yields a graph with no decoder.
    """

    def forward(self, input_values: torch.Tensor,
                padding_mask=None, bandwidth=None, audio_codes=None,
                audio_scales=None, return_dict=True):
        """Encode ``input_values`` and return only the audio codes.

        Args:
            input_values: Audio tensor forwarded to ``self.encode``.
            padding_mask: Optional mask forwarded to ``self.encode``.
            bandwidth: Optional target bandwidth forwarded to ``self.encode``.
            audio_codes: Accepted so the signature is unchanged; ignored.
            audio_scales: Accepted so the signature is unchanged; ignored.
            return_dict: Accepted so the signature is unchanged; ignored
                (``encode`` is always called with ``return_dict=False``).

        Returns:
            The first element of the tuple returned by ``self.encode``
            (the encoded audio codes).
        """
        # Keyword arguments keep the call correct even if the positional
        # order of ``encode`` parameters differs from what was assumed.
        encoded = self.encode(
            input_values,
            padding_mask=padding_mask,
            bandwidth=bandwidth,
            return_dict=False,
        )
        # Index instead of unpacking so this does not depend on how many
        # items the returned tuple holds; only the codes are needed.
        return encoded[0]

def convert_encodec_to_onnx(model_name="facebook/encodec_24khz", output_path="../public/encodec_24khz.onnx"):
    """Export the encoder of a pretrained Encodec model to an ONNX file.

    Loads ``model_name`` into ``CustomEncodecModel`` (whose forward returns
    only the audio codes), traces it with a dummy mono waveform and padding
    mask, and writes the resulting graph to ``output_path``.

    Args:
        model_name: Checkpoint identifier passed to ``from_pretrained``.
        output_path: Destination passed to ``torch.onnx.export``. When it is
            a filesystem path, missing parent directories are created.

    Returns:
        None. The ONNX file is written as a side effect and progress is
        printed to stdout.
    """
    import os

    # Load the custom Encodec model
    model = CustomEncodecModel.from_pretrained(model_name)
    model.eval()

    # Dummy inputs used only for tracing: batch of 1, 1 channel, 24000 samples.
    num_samples = 24000
    dummy_input = torch.randn(1, 1, num_samples, dtype=torch.float32)
    # All-True mask: per the EncodecModel docs a mask value of 1 means
    # "not masked", so an all-zero mask would mark the whole sample as padding.
    dummy_padding_mask = torch.ones(1, num_samples, dtype=torch.bool)

    # Create the target directory up front so the export does not fail at the
    # very end because the folder is missing. Non-path destinations
    # (e.g. file-like objects) are left untouched.
    if isinstance(output_path, (str, os.PathLike)):
        output_dir = os.path.dirname(os.fspath(output_path))
        if output_dir:
            os.makedirs(output_dir, exist_ok=True)

    print('Preparing to export model to ONNX')

    # Export the model with dynamic batch size and sequence length
    torch.onnx.export(
        model,
        (dummy_input, dummy_padding_mask),
        output_path,
        export_params=True,
        opset_version=16,
        do_constant_folding=True,  # Apply constant folding optimization
        # One name per traced input tensor. No bandwidth tensor is passed,
        # so a third name would label something that is not a bandwidth input.
        input_names=["input_values", "padding_mask"],
        output_names=["audio_codes"],  # Only output the encoded audio codes
        dynamic_axes={
            "input_values": {0: "batch_size", 2: "sequence_length"},  # Dynamic axes for input
            "padding_mask": {0: "batch_size", 1: "sequence_length"},  # Dynamic axes for padding mask
            # EncodecModel documents audio_codes as
            # (nb_frames, batch_size, nb_quantizers, frame_len), so batch is
            # axis 1 and the length-dependent axis is 3.
            # NOTE(review): confirm against the installed transformers version.
            "audio_codes": {1: "batch_size", 3: "num_frames"},
        }
    )
    print(f"ONNX model saved to {output_path}")

# Usage: run the export with the default checkpoint name and output path.
convert_encodec_to_onnx()


  from .autonotebook import tqdm as notebook_tqdm
  self.register_buffer("padding_total", torch.tensor(kernel_size - stride, dtype=torch.int64), persistent=False)


Preparing to export model to ONNX


  if channels < 1 or channels > 2:
  if (input_length % stride) - step != 0:
  max_pad = max(padding_left, padding_right)
  if length <= max_pad:


verbose: False, log level: Level.ERROR

ONNX model saved to ../public/encodec_24khz.onnx
