In [378]:
import torch
import random

def remove_consecutive_repeated_indices_optimized(min_encoding_indices, mask, z_q):
    """Collapse each run of repeated indices to one randomly chosen time step.

    For every batch row, the valid prefix (the positions before the first zero
    in ``mask``) is split into runs of equal consecutive values of
    ``min_encoding_indices``. One position is drawn uniformly from each run
    with ``random.randint`` (so ``random.seed`` controls the result), and the
    vectors of ``z_q`` at the drawn positions are gathered. Rows with fewer
    runs than the longest row are zero-padded on the right.

    Args:
        min_encoding_indices: ``(B, T)`` tensor of indices; only equality of
            neighbouring entries is used.
        mask: ``(B, T)`` tensor where a zero marks padding. Scanning stops at
            the first zero of a row, i.e. right padding is assumed.
        z_q: ``(B, T, C)`` tensor to sample from.

    Returns:
        A tuple ``(n_z_q, masks)``:
            n_z_q: ``(B, L, C)`` tensor of the selected vectors, where ``L`` is
                the largest number of runs in any row; padded slots are
                multiplied by zero.
            masks: ``(B, L)`` float tensor on ``z_q.device`` with 1 for real
                entries and 0 for padding.
    """
    B, T = min_encoding_indices.shape

    # One host transfer per tensor instead of one tensor op per element.
    index_rows = min_encoding_indices.tolist()
    mask_rows = mask.tolist()

    selected_rows = []
    for codes, valid in zip(index_rows, mask_rows):
        # Length of the valid prefix: index of the first masked step, else T.
        valid_len = T
        for i in range(T):
            if valid[i] == 0:
                valid_len = i
                break

        picks = []
        start = 0
        for i in range(1, valid_len):
            if codes[i] != codes[i - 1]:
                picks.append(random.randint(start, i - 1))
                start = i
        # Close the last run at the last *valid* step, never inside the
        # padding; a fully masked (or empty) row contributes no picks.
        if valid_len > 0:
            picks.append(random.randint(start, valid_len - 1))
        selected_rows.append(picks)

    max_len = max((len(picks) for picks in selected_rows), default=0)

    # The picks are time positions used as a gather index, so they must be
    # int64 and live on the same device as z_q, whatever the inputs' dtype.
    device = z_q.device
    padded_indices = torch.zeros((B, max_len), dtype=torch.long, device=device)
    masks = torch.zeros((B, max_len), dtype=torch.float, device=device)

    for b, picks in enumerate(selected_rows):
        length = len(picks)
        if length:
            padded_indices[b, :length] = torch.tensor(picks, dtype=torch.long, device=device)
            masks[b, :length] = 1

    # Gather the chosen steps, then zero the padded slots (they point at step 0).
    out = padded_indices.unsqueeze(-1).expand(-1, -1, z_q.shape[-1])
    n_z_q = z_q.gather(dim=1, index=out)
    n_z_q *= masks.unsqueeze(-1)

    return n_z_q, masks


# Seed both RNGs so this demo is reproducible: z_q comes from torch.randn and
# the function draws its picks with the stdlib `random` module.
random.seed(0)
torch.manual_seed(0)

# Row 0 is fully valid; row 1 has its last two steps masked out.
mask = torch.tensor([[1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 0, 0]])
min_encoding_indices = torch.tensor([[1, 1, 2, 2, 3, 4, 3, 5], [2, 3, 3, 3, 2, 2, 2, 2]])
z_q = torch.randn(2, 8, 2)  # (B, T, C)
remove_consecutive_repeated_indices_optimized(min_encoding_indices, mask, z_q)

(tensor([[[-0.8454,  0.3304],
          [-1.3009,  0.7267],
          [ 2.2446, -0.9767],
          [ 2.5101, -1.4793],
          [ 0.4955, -0.4704],
          [ 1.2795,  0.7289]],
 
         [[-0.2518, -0.3794],
          [-0.9299, -0.1263],
          [-1.2400,  1.1243],
          [-0.0000, -0.0000],
          [-0.0000, -0.0000],
          [-0.0000, -0.0000]]]),
 tensor([[1., 1., 1., 1., 1., 1.],
         [1., 1., 1., 0., 0., 0.]]))

In [379]:
# Display the demo inputs (index tensor and mask) as a tuple.
min_encoding_indices, mask

(tensor([[1, 1, 2, 2, 3, 4, 3, 5],
         [2, 3, 3, 3, 2, 2, 2, 2]]),
 tensor([[1, 1, 1, 1, 1, 1, 1, 1],
         [1, 1, 1, 1, 1, 1, 0, 0]]))

In [377]:
# NOTE(review): this cell's execution count (377) is lower than that of the
# cell that reassigns z_q with torch.randn (378), so the output shown here
# appears to be an older z_q, not the tensor passed to the demo call.
# Re-run the notebook top to bottom to confirm.
z_q

tensor([[[-0.1182, -1.3978],
         [-0.7669, -2.1116],
         [ 1.4357, -1.0185],
         [ 0.1194,  1.1072],
         [-0.4323, -0.7386],
         [-1.0836,  0.8543],
         [-0.9251,  0.6196],
         [ 1.0231,  1.6800]],

        [[ 1.1913,  0.7126],
         [-0.5653, -0.7836],
         [-0.0749,  0.9504],
         [ 0.7116,  0.1887],
         [-1.0938,  0.9753],
         [ 2.0332,  0.2463],
         [-0.2469, -0.8310],
         [ 0.8509, -0.0595]]])