In [1]:
import torch

# Small example tensor of shape (5, 5) on the GPU, filled with random integers
# drawn from {1, 2} (the upper bound passed to torch.randint is exclusive).
tensor = torch.randint(1, 3, (5, 5), device='cuda:0')  # Example tensor
tensor

tensor([[2, 1, 2, 1, 1],
        [2, 2, 2, 2, 2],
        [2, 2, 1, 2, 2],
        [1, 2, 1, 2, 2],
        [1, 2, 2, 1, 2]], device='cuda:0')

In [2]:
# Block of zeros with shape (5, 2) on the same device as `tensor`; the next cell
# appends it to `tensor` as two extra columns.
zeros_tensor = torch.zeros((5, 2), device='cuda:0')  # Zero-filled float tensor
zeros_tensor

tensor([[0., 0.],
        [0., 0.],
        [0., 0.],
        [0., 0.],
        [0., 0.]], device='cuda:0')

In [3]:
# Append the zero columns on the right (dim=1): shape (5, 5) -> (5, 7).
# The displayed result is floating point, because the integer values are
# concatenated with float zeros.
# NOTE(review): this rebinds `tensor`, so re-running the cell appends two more
# zero columns each time; restart the kernel and run all cells to reproduce.
tensor = torch.cat((tensor, zeros_tensor), dim=1)  # Concatenate along the second dimension
tensor

tensor([[2., 1., 2., 1., 1., 0., 0.],
        [2., 2., 2., 2., 2., 0., 0.],
        [2., 2., 1., 2., 2., 0., 0.],
        [1., 2., 1., 2., 2., 0., 0.],
        [1., 2., 2., 1., 2., 0., 0.]], device='cuda:0')

In [4]:
import torch
import torch.nn as nn
import random

class Tokenizer(nn.Module):
    """Module that stores codebook bookkeeping and a run-collapsing helper.

    The only behavior visible here is ``randomly_keep_one_until_zero``, which
    collapses runs of equal consecutive values in each row of a 2-D tensor.
    """

    def __init__(self, num_codebooks, save_dir="histogram_frames"):
        """Store the configuration and initialise the usage counters.

        Args:
            num_codebooks: Length of the ``codebook_usage`` counter tensor.
            save_dir: Stored as ``self.save_dir``; not used by the code shown
                here (presumably an output directory -- confirm with callers).
        """
        super().__init__()
        # Counter initialised to zero; not read or updated by the code shown here.
        self.frame_count = 0
        self.save_dir = save_dir
        self.num_codebooks = num_codebooks
        # NOTE(review): this is a plain attribute, not a registered buffer, so it
        # stays on the CPU when the module is moved and is absent from
        # state_dict(). Left as-is to avoid changing checkpoint compatibility.
        self.codebook_usage = torch.zeros(self.num_codebooks)

    def randomly_keep_one_until_zero(self, tensor):
        """Collapse each run of equal consecutive values to one random member.

        Every row is scanned left to right. Scanning stops at the first zero
        (or at the end of the row). Each run of identical consecutive values is
        replaced by a single element whose position is drawn uniformly from the
        run with ``random.randint``, so results vary unless ``random`` is seeded.

        Args:
            tensor: 2-D tensor; rows are processed independently.

        Returns:
            A tuple ``(filtered, indices)``:
              * ``filtered`` has the same shape, dtype and device as ``tensor``;
                each row holds the kept values followed by zeros.
              * ``indices`` is an int64 tensor of shape ``(rows, max_kept)`` on
                the same device, holding the column index chosen for every kept
                value, right-padded with ``-1``. ``max_kept`` is the largest
                number of kept values in any row (possibly 0).
            For an input with zero rows, ``filtered`` is an empty tensor shaped
            like the input and ``indices`` has shape ``(0, 0)``.
        """
        dtype = tensor.dtype
        device = tensor.device

        # An empty batch has nothing to scan; return well-formed empty tensors
        # instead of failing on an empty stack / max().
        if tensor.size(0) == 0:
            return (
                torch.zeros_like(tensor),
                torch.empty((0, 0), dtype=torch.long, device=device),
            )

        filtered_rows = []
        chosen_indices = []

        # One host transfer for the whole batch instead of a device
        # synchronisation per element via .item().
        for values in tensor.tolist():
            width = len(values)
            kept_values = []
            kept_indices = []
            start = 0

            while start < width:
                value = values[start]

                # Stop processing when zero is encountered
                if value == 0:
                    break

                # Find the end (exclusive) of the run of equal values
                end = start + 1
                while end < width and values[end] == value:
                    end += 1

                # Randomly select one position from the run
                pick = random.randint(start, end - 1)
                kept_values.append(values[pick])
                kept_indices.append(pick)

                # Continue with the next run
                start = end

            # Add trailing zeros to match the original row length
            kept_values.extend([0] * (width - len(kept_values)))
            filtered_rows.append(kept_values)
            chosen_indices.append(kept_indices)

        # Pad the index lists once, after all rows are known.
        max_len = max(len(idx_list) for idx_list in chosen_indices)
        padded_indices = [
            idx_list + [-1] * (max_len - len(idx_list))
            for idx_list in chosen_indices
        ]

        result = torch.tensor(filtered_rows, dtype=dtype, device=device)
        # An explicit integer dtype keeps the index tensor int64 even when no
        # row kept anything (lists of empty lists would otherwise become float).
        indices = torch.tensor(padded_indices, dtype=torch.long, device=device)
        return result, indices


In [5]:
# Collapse runs of equal values in every row of `tensor` (up to the first zero)
# and print the filtered values together with the column indices that were kept
# (-1 marks padding). The picks come from `random`, so the printed result
# differs between runs unless the generator is seeded.
tokenizer = Tokenizer(num_codebooks=30)
result, indices = tokenizer.randomly_keep_one_until_zero(tensor)
print(result, indices)


tensor([[2., 1., 2., 1., 0., 0., 0.],
        [2., 0., 0., 0., 0., 0., 0.],
        [2., 1., 2., 0., 0., 0., 0.],
        [1., 2., 1., 2., 0., 0., 0.],
        [1., 2., 1., 2., 0., 0., 0.]], device='cuda:0') tensor([[ 0,  1,  2,  4],
        [ 2, -1, -1, -1],
        [ 0,  2,  3, -1],
        [ 0,  1,  2,  3],
        [ 0,  1,  3,  4]], device='cuda:0')


In [6]:
# Output buffer: one slot per index column, pre-filled with -1 for unused slots
batch_size, max_indices = indices.shape
extracted = torch.full(
    size=(batch_size, max_indices),
    fill_value=-1.0,
    device=tensor.device,
)

# Gather the selected values one row at a time
for row_number, index_row in enumerate(indices):
    # Drop the -1 padding before indexing into the source row
    kept = index_row[index_row != -1]
    count = kept.numel()
    if count == 0:
        continue
    extracted[row_number, :count] = tensor[row_number, kept]

print("Extracted Tensor:")
print(extracted)

Extracted Tensor:
tensor([[ 2.,  1.,  2.,  1.],
        [ 2., -1., -1., -1.],
        [ 2.,  1.,  2., -1.],
        [ 1.,  2.,  1.,  2.],
        [ 1.,  2.,  1.,  2.]], device='cuda:0')


In [7]:
# Display `tensor` again; the output shows the same values as right after the
# concatenation cell (presumably a check that the source was left unmodified).
tensor

tensor([[2., 1., 2., 1., 1., 0., 0.],
        [2., 2., 2., 2., 2., 0., 0.],
        [2., 2., 1., 2., 2., 0., 0.],
        [1., 2., 1., 2., 2., 0., 0.],
        [1., 2., 2., 1., 2., 0., 0.]], device='cuda:0')