In [162]:
from typing import Union

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from torch.nn.utils import weight_norm

def WNConv1d(*args, **kwargs):
    return weight_norm(nn.Conv1d(*args, **kwargs))

class VectorQuantize(nn.Module):
    """
    Implementation of VQ similar to Karpathy's repo:
    https://github.com/karpathy/deep-vector-quantization
    Additionally uses following tricks from Improved VQGAN
    (https://arxiv.org/pdf/2110.04627.pdf):
        1. Factorized codes: Perform nearest neighbor lookup in low-dimensional space
            for improved codebook usage
        2. l2-normalized codes: Converts euclidean distance to cosine similarity which
            improves training stability
    """

    def __init__(self, input_dim: int, codebook_size: int, codebook_dim: int):
        super().__init__()
        self.codebook_size = codebook_size
        self.codebook_dim = codebook_dim

        self.in_proj = WNConv1d(input_dim, codebook_dim, kernel_size=1)
        self.out_proj = WNConv1d(codebook_dim, input_dim, kernel_size=1)
        self.codebook = nn.Embedding(codebook_size, codebook_dim)

    def forward(self, z):
        """Quantized the input tensor using a fixed codebook and returns
        the corresponding codebook vectors

        Parameters
        ----------
        z : Tensor[B x D x T]

        Returns
        -------
        Tensor[B x D x T]
            Quantized continuous representation of input
        Tensor[1]
            Commitment loss to train encoder to predict vectors closer to codebook
            entries
        Tensor[1]
            Codebook loss to update the codebook
        Tensor[B x T]
            Codebook indices (quantized discrete representation of input)
        Tensor[B x D x T]
            Projected latents (continuous representation of input before quantization)
        """

        # Factorized codes (ViT-VQGAN) Project input into low-dimensional space
        z_e = self.in_proj(z)  # z_e : (B x D x T)
        z_q, indices = self.decode_latents(z_e)

        commitment_loss = F.mse_loss(z_e, z_q.detach(), reduction="none").mean([1, 2])
        codebook_loss = F.mse_loss(z_q, z_e.detach(), reduction="none").mean([1, 2])

        z_q = (
            z_e + (z_q - z_e).detach()
        )  # noop in forward pass, straight-through gradient estimator in backward pass

        z_q = self.out_proj(z_q)

        return z_q, commitment_loss, codebook_loss, indices, z_e

    def embed_code(self, embed_id):
        ans = F.embedding(embed_id, self.codebook.weight)
        return ans

    def decode_code(self, embed_id):
        return self.embed_code(embed_id).transpose(1, 2)

    def decode_latents(self, latents):
        encodings = rearrange(latents, "b d t -> (b t) d")
        codebook = self.codebook.weight  # codebook: (N x D)

        # L2 normalize encodings and codebook (ViT-VQGAN)
        encodings = F.normalize(encodings)
        codebook = F.normalize(codebook)

        # Compute euclidean distance with codebook
        dist = (
            encodings.pow(2).sum(1, keepdim=True)
            - 2 * encodings @ codebook.t()
            + codebook.pow(2).sum(1, keepdim=True).t()
        )
        indices = rearrange((-dist).max(1)[1], "(b t) -> b t", b=latents.size(0))
        z_q = self.decode_code(indices)
        return z_q, indices

torch.manual_seed(0)
vq = VectorQuantize(1024, 1024, 8)
vq.in_proj.weight_v.data = torch.randn(8, 1024, 1)
vq.in_proj.weight_g.data = torch.randn(8, 1, 1)
vq.in_proj.bias.data = torch.randn(8)
vq.out_proj.weight_v.data = torch.randn(1024, 8, 1)
vq.out_proj.weight_g.data = torch.randn(1024, 1, 1)
vq.out_proj.bias.data = torch.randn(1024)
vq.codebook.weight.data = torch.randn(1024, 8)

z = torch.randn(32, 1024, 500)
z_q, commitment_loss, codebook_loss, indices, z_e = vq(z)
print(z_q, indices)

  WeightNorm.apply(module, name, dim)


tensor([[[ 0.9988,  0.9286,  0.8645,  ...,  0.8869,  0.7973,  0.8903],
         [-0.1705, -0.7351,  0.4765,  ..., -0.1090, -0.1078,  0.0302],
         [-0.4289, -1.6530, -0.5824,  ..., -0.9791, -0.4523, -0.6140],
         ...,
         [-3.6225, -1.5333, -2.9127,  ..., -1.9451, -3.5091, -1.4743],
         [ 0.4280,  0.5835,  0.6545,  ...,  0.5954,  0.4389,  0.8003],
         [-1.1396, -1.8115, -0.7024,  ..., -1.7644, -1.3501, -0.7024]],

        [[ 0.9899,  0.7915,  0.8645,  ...,  0.9607,  0.8546,  0.9808],
         [-0.8788, -0.0732,  0.4765,  ..., -0.2098, -1.0007,  0.0928],
         [-0.5386, -0.6403, -0.5824,  ..., -1.4164, -1.9763, -0.4482],
         ...,
         [-3.0648, -2.4138, -2.9127,  ..., -3.3440, -3.2999, -3.0438],
         [ 0.1935,  0.4268,  0.6545,  ...,  0.6232,  0.8056,  0.6058],
         [-1.5807, -1.2421, -0.7024,  ..., -2.1177, -1.9800, -1.4128]],

        [[ 0.8903,  0.9075,  0.8546,  ...,  0.8645,  0.8645,  0.8451],
         [ 0.0302, -0.5262, -1.0007,  ...,  0

In [163]:
from typing import Union

import numpy as np
import torch
import torch.nn as nn

class ResidualVectorQuantize(nn.Module):
    """
    Introduced in SoundStream: An end2end neural audio codec
    https://arxiv.org/abs/2107.03312
    """

    def __init__(
        self,
        input_dim: int = 512,
        n_codebooks: int = 9,
        codebook_size: int = 1024,
        codebook_dim: Union[int, list] = 8,
        quantizer_dropout: float = 0.0,
    ):
        super().__init__()
        if isinstance(codebook_dim, int):
            codebook_dim = [codebook_dim for _ in range(n_codebooks)]

        self.n_codebooks = n_codebooks
        self.codebook_dim = codebook_dim
        self.codebook_size = codebook_size

        self.quantizers = nn.ModuleList(
            [
                VectorQuantize(input_dim, codebook_size, codebook_dim[i])
                for i in range(n_codebooks)
            ]
        )
        self.quantizer_dropout = quantizer_dropout

    def forward(self, z, n_quantizers: int = None):
        """Quantized the input tensor using a fixed set of `n` codebooks and returns
        the corresponding codebook vectors
        Parameters
        ----------
        z : Tensor[B x D x T]
        n_quantizers : int, optional
            No. of quantizers to use
            (n_quantizers < self.n_codebooks ex: for quantizer dropout)
            Note: if `self.quantizer_dropout` is True, this argument is ignored
                when in training mode, and a random number of quantizers is used.
        Returns
        -------
        dict
            A dictionary with the following keys:

            "z" : Tensor[B x D x T]
                Quantized continuous representation of input
            "codes" : Tensor[B x N x T]
                Codebook indices for each codebook
                (quantized discrete representation of input)
            "latents" : Tensor[B x N*D x T]
                Projected latents (continuous representation of input before quantization)
            "vq/commitment_loss" : Tensor[1]
                Commitment loss to train encoder to predict vectors closer to codebook
                entries
            "vq/codebook_loss" : Tensor[1]
                Codebook loss to update the codebook
        """
        z_q = 0
        residual = z
        commitment_loss = 0
        codebook_loss = 0

        codebook_indices = []
        latents = []

        if n_quantizers is None:
            n_quantizers = self.n_codebooks
        if self.training:
            n_quantizers = torch.ones((z.shape[0],)) * self.n_codebooks + 1
            dropout = torch.randint(1, self.n_codebooks + 1, (z.shape[0],))
            n_dropout = int(z.shape[0] * self.quantizer_dropout)
            n_quantizers[:n_dropout] = dropout[:n_dropout]
            n_quantizers = n_quantizers.to(z.device)

        for i, quantizer in enumerate(self.quantizers):
            if self.training is False and i >= n_quantizers:
                break

            z_q_i, commitment_loss_i, codebook_loss_i, indices_i, z_e_i = quantizer(
                residual
            )

            # Create mask to apply quantizer dropout
            mask = (
                torch.full((z.shape[0],), fill_value=i, device=z.device) < n_quantizers
            )
            z_q = z_q + z_q_i * mask[:, None, None]
            residual = residual - z_q_i

            # Sum losses
            commitment_loss += (commitment_loss_i * mask).mean()
            codebook_loss += (codebook_loss_i * mask).mean()

            codebook_indices.append(indices_i)
            latents.append(z_e_i)

        codes = torch.stack(codebook_indices, dim=1)
        latents = torch.cat(latents, dim=1)

        return z_q, codes, latents, commitment_loss, codebook_loss

    def from_codes(self, codes: torch.Tensor):
        """Given the quantized codes, reconstruct the continuous representation
        Parameters
        ----------
        codes : Tensor[B x N x T]
            Quantized discrete representation of input
        Returns
        -------
        Tensor[B x D x T]
            Quantized continuous representation of input
        """
        z_q = 0.0
        z_p = []
        n_codebooks = codes.shape[1]
        for i in range(n_codebooks):
            z_p_i = self.quantizers[i].decode_code(codes[:, i, :])
            z_p.append(z_p_i)

            z_q_i = self.quantizers[i].out_proj(z_p_i)
            z_q = z_q + z_q_i
        return z_q, torch.cat(z_p, dim=1), codes

    def from_latents(self, latents: torch.Tensor):
        """Given the unquantized latents, reconstruct the
        continuous representation after quantization.

        Parameters
        ----------
        latents : Tensor[B x N x T]
            Continuous representation of input after projection

        Returns
        -------
        Tensor[B x D x T]
            Quantized representation of full-projected space
        Tensor[B x D x T]
            Quantized representation of latent space
        """
        z_q = 0
        z_p = []
        codes = []
        dims = np.cumsum([0] + [q.codebook_dim for q in self.quantizers])

        n_codebooks = np.where(dims <= latents.shape[1])[0].max(axis=0, keepdims=True)[
            0
        ]
        for i in range(n_codebooks):
            j, k = dims[i], dims[i + 1]
            z_p_i, codes_i = self.quantizers[i].decode_latents(latents[:, j:k, :])
            z_p.append(z_p_i)
            codes.append(codes_i)

            z_q_i = self.quantizers[i].out_proj(z_p_i)
            z_q = z_q + z_q_i

        return z_q, torch.cat(z_p, dim=1), torch.stack(codes, dim=1)

torch.manual_seed(0)
rvq = ResidualVectorQuantize(1024, 2, 1024, 8)
rvq.quantizers[0].in_proj.weight_v.data = torch.randn(8, 1024, 1)
rvq.quantizers[0].in_proj.weight_g.data = torch.randn(8, 1, 1)
rvq.quantizers[0].in_proj.bias.data = torch.randn(8)
rvq.quantizers[0].out_proj.weight_v.data = torch.randn(1024, 8, 1)
rvq.quantizers[0].out_proj.weight_g.data = torch.randn(1024, 1, 1)
rvq.quantizers[0].out_proj.bias.data = torch.randn(1024)
rvq.quantizers[0].codebook.weight.data = torch.randn(1024, 8)

rvq.quantizers[1].in_proj.weight_v.data = torch.randn(8, 1024, 1)
rvq.quantizers[1].in_proj.weight_g.data = torch.randn(8, 1, 1)
rvq.quantizers[1].in_proj.bias.data = torch.randn(8)
rvq.quantizers[1].out_proj.weight_v.data = torch.randn(1024, 8, 1)
rvq.quantizers[1].out_proj.weight_g.data = torch.randn(1024, 1, 1)
rvq.quantizers[1].out_proj.bias.data = torch.randn(1024)
rvq.quantizers[1].codebook.weight.data = torch.randn(1024, 8)

rvq_z_q, rvq_indices, _, _, _ = rvq(z)
print(rvq_z_q, rvq_indices)

tensor([[[-3.8284, -3.7353, -3.2694,  ..., -4.2950, -3.4027, -4.8893],
         [-0.5032, -3.1676, -1.6385,  ..., -2.0904, -2.8227, -0.1386],
         [ 0.5786, -0.2436, -0.4937,  ...,  1.3624,  1.1955, -0.2150],
         ...,
         [-2.7138, -2.9814, -3.2482,  ..., -2.6789, -2.8203, -2.7572],
         [-0.4085, -0.2236, -0.1695,  ..., -1.7220, -0.4804, -0.7997],
         [ 0.2923, -0.1119,  0.0231,  ...,  1.4829,  1.1780,  0.3640]],

        [[-1.7458, -4.6101, -2.9354,  ..., -1.5002, -5.4848, -2.7086],
         [-0.8326, -0.8330, -0.9082,  ..., -0.7709, -2.2841, -0.4977],
         [ 3.0759,  0.1574,  0.9528,  ...,  1.3162,  0.3702,  1.2299],
         ...,
         [-2.7903, -3.0200, -2.9933,  ..., -2.3979, -2.5885, -2.9738],
         [-0.3413, -0.1408, -0.0683,  ...,  0.7651, -0.9241, -0.0609],
         [ 0.1140, -0.3490,  0.1787,  ...,  0.0482,  0.8981, -0.1373]],

        [[-3.4563, -3.7545, -5.5988,  ..., -6.0925, -4.1387, -3.4020],
         [-2.0591, -2.2037, -2.6103,  ..., -0

In [164]:
from typing import Optional
import tvm
from tvm import relax
from tvm.relax import op as _op
from tvm.relax.frontend import nn
from tvm import te
from tvm import dlight as dl
from tvm.target import Target
import numpy as np


def normalize(x: nn.Tensor, axis: Optional[int] = 1, eps: float = 1e-12):
    denom = nn.op.sum(nn.op.square(x), axis=axis, keepdims=True)
    denom = nn.op.sqrt(nn.op.maximum(nn.op.broadcast_to(denom, x.shape), nn.Tensor.from_const(eps)))
    return x / denom


class WNConv1d(nn.Module):
    """
    Module for conv1d layer.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        stride: int = 1,
        padding: int = 0,
        dilation: int = 1,
        groups: int = 1,
        bias: bool = True,
        dtype: Optional[str] = None,
    ) -> None:
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.dilation = dilation
        self.groups = groups
        self.dtype = dtype

        self.weight_g = nn.Parameter(
            (
                self.out_channels,
                1,
                1,
            ),
            dtype,
        )

        self.weight_v = nn.Parameter(
            (
                self.out_channels,
                self.in_channels // self.groups,
                self.kernel_size,
            ),
            dtype,
        )
        if bias:
            self.bias = nn.Parameter((self.out_channels,), dtype)
        else:
            self.bias = None

    def forward(self, x: nn.Tensor) -> nn.Tensor:
        dim = [i for i in range(1, x.ndim)]
        norm_v = _op.sqrt(
            _op.sum(_op.square(self.weight_v._expr), axis=dim, keepdims=True),
        )
        weight = nn.wrap_nested(
            self.weight_g._expr * (self.weight_v._expr / norm_v), name="wnconv1d"
        )
        return nn.op.conv1d(
            x, weight, self.bias, self.stride, self.padding, self.dilation, self.groups
        )


class VectorQuantize(nn.Module):
    def __init__(self, input_dim: int, codebook_size: int, codebook_dim: int):
        self.codebook_size = codebook_size
        self.codebook_dim = codebook_dim

        self.in_proj = WNConv1d(input_dim, codebook_dim, kernel_size=1)
        self.out_proj = WNConv1d(codebook_dim, input_dim, kernel_size=1)
        self.codebook = nn.Embedding(codebook_size, codebook_dim)

    def forward(self, z: nn.Tensor):
        z_e = self.in_proj(z)
        z_q, indices = self.decode_latents(z_e)
        z_q = self.out_proj(z_q)
        return z_q, indices

    def decode_latents(self, latents: nn.Tensor):    
        encodings = nn.op.permute_dims(latents, [0, 2, 1])  # (b, t, d)
        encodings = nn.op.reshape(encodings, [-1, encodings.shape[2]])  # (b*t, d)
        codebook = self.codebook.weight

        encodings = normalize(encodings)  # (b*t, d)
        codebook = normalize(codebook)  # (N, d)

        dist = (
            nn.op.sum(nn.op.square(encodings), axis=1, keepdims=True)  # (b*t, 1)
            - 2
            * nn.op.matmul(encodings, nn.op.permute_dims(codebook, [1, 0]))  # (b*t, N)
            + nn.op.permute_dims(
                nn.op.sum(nn.op.square(codebook), axis=1, keepdims=True), [1, 0]
            )  # (1, N)
        )  # (b*t, N)

        indices = nn.op.argsort(dist, axis=1)  # (b*t, N)
        indices = nn.op.take(indices, nn.Tensor.from_const([0]), axis=1)  # (b*t, 1)
        indices = nn.op.reshape(indices, [latents.shape[0], latents.shape[2]])  # (b, t)

        z_q = self.codebook(indices) # (b, t, d)
        z_q = nn.op.permute_dims(z_q, [0, 2, 1]) # (b, d, t)
        return z_q, indices


mod, params = VectorQuantize(1024, 1024, 8).export_tvm(
    {"forward": {"z": nn.spec.Tensor((32, 1024, 500), "float32")}}, debug=True
)
mod.show()

In [166]:
import time

target = Target.from_device("metal")
print(target)
with target:
    mod = relax.transform.LegalizeOps()(mod)
    mod = dl.ApplyDefaultSchedule(
        # dl.gpu.Matmul(),
        dl.gpu.GEMV(),
        dl.gpu.Reduction(),
        dl.gpu.GeneralReduction(),
        dl.gpu.Fallback(),
    )(mod)
ex = relax.build(mod, target)
device = tvm.metal()

vm = relax.VirtualMachine(ex, device)
tvm_data = tvm.nd.array(z, device=device)
# print(params)

tvm_params = [
    vq.in_proj.weight_g.data.numpy().astype("float32"),
    vq.in_proj.weight_v.data.numpy().astype("float32"),
    vq.in_proj.bias.data.numpy().astype("float32"),
    vq.out_proj.weight_g.data.numpy().astype("float32"),
    vq.out_proj.weight_v.data.numpy().astype("float32"),
    vq.out_proj.bias.data.numpy().astype("float32"),
    vq.codebook.weight.data.numpy().astype("float32"),
]
tvm_params = [tvm.nd.array(param, device=device) for param in tvm_params]
# print(tvm_params)

start = time.time()
effects = vm["_initialize_effect"]()
z_q_tvm, indices_tvm = vm["forward"](tvm_data, *effects, *tvm_params)[0]
print("Elapsed time: ", time.time() - start)

print(z_q_tvm.numpy())
assert np.allclose(z_q.detach().numpy(), z_q_tvm.numpy(), atol=1e-3)
assert np.allclose(indices.detach().numpy(), indices_tvm.numpy(), atol=1e-3)

metal -keys=metal,gpu -max_function_args=31 -max_num_threads=256 -max_shared_memory_per_block=32768 -max_threads_per_block=1024 -thread_warp_size=32
Elapsed time:  0.0020580291748046875
[[[ 0.9988062   0.9286075   0.8645287  ...  0.8869156   0.79730624
    0.8903235 ]
  [-0.17046398 -0.7351244   0.47651774 ... -0.10897435 -0.10779934
    0.03020364]
  [-0.42893085 -1.6530032  -0.5824416  ... -0.97907007 -0.45231423
   -0.6140453 ]
  ...
  [-3.6224961  -1.5332791  -2.9127078  ... -1.9451369  -3.50908
   -1.4742826 ]
  [ 0.42801517  0.5834759   0.65450764 ...  0.59538335  0.43885562
    0.80032337]
  [-1.1395632  -1.8114793  -0.7024274  ... -1.7644299  -1.3500917
   -0.7024373 ]]

 [[ 0.9899451   0.7914734   0.8645287  ...  0.96066576  0.85459185
    0.9807827 ]
  [-0.87882215 -0.07316126  0.47651774 ... -0.20981178 -1.0006511
    0.09284002]
  [-0.5385559  -0.64032966 -0.5824416  ... -1.4164237  -1.976311
   -0.44821826]
  ...
  [-3.064777   -2.413759   -2.9127078  ... -3.3440204  -3.29

In [167]:
class ResidualVectorQuantize(nn.Module):
    def __init__(
        self,
        input_dim: int = 512,
        n_codebooks: int = 9,
        codebook_size: int = 1024,
        codebook_dim: Union[int, list] = 8,
        quantizer_dropout: float = 0.0,
    ):
        if isinstance(codebook_dim, int):
            codebook_dim = [codebook_dim for _ in range(n_codebooks)]

        self.n_codebooks = n_codebooks
        self.codebook_dim = codebook_dim
        self.codebook_size = codebook_size

        self.quantizers = nn.ModuleList(
            [
                VectorQuantize(input_dim, codebook_size, codebook_dim[i])
                for i in range(n_codebooks)
            ]
        )
        self.quantizer_dropout = quantizer_dropout

    def forward(self, z: nn.Tensor):
        z_q = nn.zeros(z.shape, dtype=z.dtype)
        residual = z
        codebook_indices = []

        for quantizer in self.quantizers:
            z_q_i, indices = quantizer(residual)
            z_q = z_q + z_q_i
            residual = residual - z_q_i
            codebook_indices.append(indices)
        
        return z_q, codebook_indices

mod, params = ResidualVectorQuantize(1024, 2, 1024, 8).export_tvm(
    {"forward": {"z": nn.spec.Tensor((32, 1024, 500), "float32")}}, debug=True
)
mod.show()

In [168]:
import time

target = Target.from_device("metal")
print(target)
with target:
    mod = relax.transform.LegalizeOps()(mod)
    mod = dl.ApplyDefaultSchedule(
        # dl.gpu.Matmul(),
        dl.gpu.GEMV(),
        dl.gpu.Reduction(),
        dl.gpu.GeneralReduction(),
        dl.gpu.Fallback(),
    )(mod)
ex = relax.build(mod, target)
device = tvm.metal()

vm = relax.VirtualMachine(ex, device)
tvm_data = tvm.nd.array(z, device=device)
print(params)

tvm_params = [
    rvq.quantizers[0].in_proj.weight_g.data.numpy().astype("float32"),
    rvq.quantizers[0].in_proj.weight_v.data.numpy().astype("float32"),
    rvq.quantizers[0].in_proj.bias.data.numpy().astype("float32"),
    rvq.quantizers[0].out_proj.weight_g.data.numpy().astype("float32"),
    rvq.quantizers[0].out_proj.weight_v.data.numpy().astype("float32"),
    rvq.quantizers[0].out_proj.bias.data.numpy().astype("float32"),
    rvq.quantizers[0].codebook.weight.data.numpy().astype("float32"),
    rvq.quantizers[1].in_proj.weight_g.data.numpy().astype("float32"),
    rvq.quantizers[1].in_proj.weight_v.data.numpy().astype("float32"),
    rvq.quantizers[1].in_proj.bias.data.numpy().astype("float32"),
    rvq.quantizers[1].out_proj.weight_g.data.numpy().astype("float32"),
    rvq.quantizers[1].out_proj.weight_v.data.numpy().astype("float32"),
    rvq.quantizers[1].out_proj.bias.data.numpy().astype("float32"),
    rvq.quantizers[1].codebook.weight.data.numpy().astype("float32"),
]
tvm_params = [tvm.nd.array(param, device=device) for param in tvm_params]
# print(tvm_params)

start = time.time()
effects = vm["_initialize_effect"]()
rvq_z_q_tvm, rvq_indices_tvm = vm["forward"](tvm_data, *effects, *tvm_params)[0]
rvq_indices_tvm = map(lambda x: x.numpy(), rvq_indices_tvm)
rvq_indices_tvm = np.stack(list(rvq_indices_tvm), axis=1)
print(rvq_indices.shape, rvq_indices_tvm.shape)
print("Elapsed time: ", time.time() - start)

print(rvq_z_q_tvm.numpy())
assert np.allclose(rvq_z_q.detach().numpy(), rvq_z_q_tvm.numpy(), atol=1e-3)
assert np.allclose(rvq_indices.detach().numpy(), rvq_indices_tvm)

metal -keys=metal,gpu -max_function_args=31 -max_num_threads=256 -max_shared_memory_per_block=32768 -max_threads_per_block=1024 -thread_warp_size=32
[('quantizers.0.in_proj.weight_g', Tensor([8, 1, 1], "float32")), ('quantizers.0.in_proj.weight_v', Tensor([8, 1024, 1], "float32")), ('quantizers.0.in_proj.bias', Tensor([8], "float32")), ('quantizers.0.out_proj.weight_g', Tensor([1024, 1, 1], "float32")), ('quantizers.0.out_proj.weight_v', Tensor([1024, 8, 1], "float32")), ('quantizers.0.out_proj.bias', Tensor([1024], "float32")), ('quantizers.0.codebook.weight', Tensor([1024, 8], "float32")), ('quantizers.1.in_proj.weight_g', Tensor([8, 1, 1], "float32")), ('quantizers.1.in_proj.weight_v', Tensor([8, 1024, 1], "float32")), ('quantizers.1.in_proj.bias', Tensor([8], "float32")), ('quantizers.1.out_proj.weight_g', Tensor([1024, 1, 1], "float32")), ('quantizers.1.out_proj.weight_v', Tensor([1024, 8, 1], "float32")), ('quantizers.1.out_proj.bias', Tensor([1024], "float32")), ('quantizers.1.c