In [1]:
import torch
from quantization import GroupVQ

In [2]:
# Seed the global RNG so this input -- and the GroupVQ weights initialised in
# the next cell, which draw from the same global generator -- are reproducible
# under Restart & Run All.
torch.manual_seed(0)

# Toy input of shape (2, 6, 4); the last dim matches `in_dim=4` passed to
# GroupVQ below. Presumably (batch, time, features) -- TODO confirm.
x = torch.randn(2, 2*3, 4)
x

tensor([[[-0.5565, -1.1160,  1.5750,  1.2781],
         [ 0.4280,  0.4349, -0.8033, -0.0299],
         [ 1.1778,  0.9756, -0.1200,  0.1954],
         [ 1.0845, -0.9608,  1.3249,  1.0767],
         [-0.2449,  0.6976,  0.8688, -1.6953],
         [-0.0492,  2.4842, -1.4148,  0.0741]],

        [[ 0.8033, -0.1597, -0.6752,  1.3199],
         [-0.0881, -0.0349,  0.4298, -0.9522],
         [ 0.9399,  0.8171, -0.0028, -0.9964],
         [-0.5293, -1.7170,  0.9610,  0.4428],
         [-0.1163,  2.4159, -0.2126, -1.6512],
         [-0.4930, -1.7803,  0.7488, -0.5361]]])

In [3]:
# Hyper-parameters shared by both quantizers, defined once so the two
# instances cannot silently drift apart.
vq_config = dict(
    in_dim=4, H=3, overlap=2, num_vqs=3,
    codebook_dim=2, codebook_size=2,
)
vq1 = GroupVQ(**vq_config)
vq2 = GroupVQ(**vq_config)

# Give vq1 exactly the same embedding / projection weights as vq2.
# `zip` would stop at the shorter list without complaint, so check first.
assert len(vq1.vqs) == len(vq2.vqs), "GroupVQ instances have different numbers of codebooks"
with torch.no_grad():
    for mod1, mod2 in zip(vq1.vqs, vq2.vqs):
        # In-place copy_ keeps the existing Parameter objects instead of
        # swapping their `.data`, and records no autograd history.
        mod1.embedding.weight.copy_(mod2.embedding.weight)
        mod1.proj_down.weight.copy_(mod2.proj_down.weight)
        mod1.proj_up.weight.copy_(mod2.proj_up.weight)

# Only vq1 is used in the cells below; eval() returns the module, so its
# repr is what this cell displays.
vq1.eval()

GroupVQ(
  (vqs): ModuleList(
    (0-2): 3 x Codebook(
      (embedding): Embedding(2, 2)
      (proj_down): Linear(in_features=8, out_features=2, bias=False)
      (proj_up): Linear(in_features=2, out_features=8, bias=False)
    )
  )
)

In [4]:
# Run the quantizer forward without building an autograd graph, then show
# everything it returns (a tuple, per the saved output below).
with torch.no_grad():
    vq_out = vq1(x)
print(vq_out)

(tensor([[[-0.0850,  0.5452, -0.1945, -0.5143],
         [ 0.5460, -0.2685,  0.2954, -0.3716],
         [ 0.2048, -0.0037, -0.0631, -0.3721],
         [ 0.0064, -0.0079, -0.2214,  0.2875],
         [-0.9251, -0.5298,  0.3773, -0.0345],
         [-0.1411, -0.2568, -0.1273,  0.5143]],

        [[ 0.0126,  0.1211, -0.0347, -0.5143],
         [ 0.5460, -0.2685,  0.3106,  0.7087],
         [ 0.0625, -0.0085, -0.0595, -0.3721],
         [ 0.0064,  1.1988, -0.9671,  0.6533],
         [-0.1729, -0.0988,  0.3773, -0.0345],
         [-0.1411,  0.1365,  1.2267,  0.1263]]]), tensor([0.6005, 1.0182]), tensor([0.6005, 1.0182]))


In [7]:
# Round trip: map the input to discrete codes, then reconstruct from them.
# In the saved outputs the decoded tensor equals the first element returned
# by vq1(x) in the forward-pass cell.
# NOTE(review): `dim=3` is a bare constant -- presumably tied to num_vqs=3;
# confirm against GroupVQ.decode.
with torch.no_grad():
    codes = vq1.encode(x)
    print(codes)
    decoded = vq1.decode(codes, dim=3)
    print(decoded)

tensor([[[1],
         [0],
         [1]],

        [[0],
         [0],
         [0]]])
tensor([[[-0.0850,  0.5452, -0.1945, -0.5143],
         [ 0.5460, -0.2685,  0.2954, -0.3716],
         [ 0.2048, -0.0037, -0.0631, -0.3721],
         [ 0.0064, -0.0079, -0.2214,  0.2875],
         [-0.9251, -0.5298,  0.3773, -0.0345],
         [-0.1411, -0.2568, -0.1273,  0.5143]],

        [[ 0.0126,  0.1211, -0.0347, -0.5143],
         [ 0.5460, -0.2685,  0.3106,  0.7087],
         [ 0.0625, -0.0085, -0.0595, -0.3721],
         [ 0.0064,  1.1988, -0.9671,  0.6533],
         [-0.1729, -0.0988,  0.3773, -0.0345],
         [-0.1411,  0.1365,  1.2267,  0.1263]]])


In [1]:
# Scratch check: built-in sum over a small sequence (displays 3).
sum([1, 2])

3

In [1]:
# Number of devices and their integer ids, 0 .. num_device - 1.
num_device = 4
device_ids = list(range(num_device))

In [2]:
# Display the list of device ids built from `num_device`.
device_ids

[0, 1, 2, 3]

In [8]:
", ".join([str(ids) for ids in device_ids])

'0, 1, 2, 3'

In [9]:
import torch.nn as nn
# Build a ModuleList incrementally. Named `layers` rather than `x` so the
# input tensor `x` used by the GroupVQ cells is not clobbered (re-running
# those cells after this one would otherwise fail).
layers = nn.ModuleList()
layers.append(nn.Linear(10, 10))
layers

ModuleList(
  (0): Linear(in_features=10, out_features=10, bias=True)
)

In [11]:
# ModuleList built in one step from a list of modules. Named `layers` so the
# input tensor `x` used by the GroupVQ cells is not overwritten.
layers = nn.ModuleList([nn.Linear(10, 10)])
layers

ModuleList(
  (0): Linear(in_features=10, out_features=10, bias=True)
)

In [12]:
class A:
    """Minimal base class whose initialiser does nothing."""

    def __init__(self) -> None:
        pass

class B(A):
    """Subclass of A that stores `m = 6` on the instance at construction."""

    def __init__(self) -> None:
        super().__init__()
        # Attributes must be set on the instance itself: the original
        # `super().m = 6` raises AttributeError because `super` proxy
        # objects do not support attribute assignment.
        self.m = 6

    def train(self) -> None:
        """Print the stored value of `m`."""
        print(self.m)

In [14]:
codebook_size = 2
max_streams = 4
num_vq = 3
# One zero-initialised table per (stream, group) pair -- presumably per-code
# usage counts. Keys are ordered stream-major, then group, and every entry
# gets its own fresh inner dict.
# NOTE(review): inner keys run 1..codebook_size, while vq1.encode above
# produced codes 0 and 1 for codebook_size=2 -- confirm the +1 offset is intended.
vq_stats_dict = dict(
    (f"stream_{stream}_group_{group}", dict.fromkeys(range(1, codebook_size + 1), 0))
    for stream in range(max_streams)
    for group in range(num_vq)
)

In [15]:
# Display the per-(stream, group) table built above (all counts start at 0).
vq_stats_dict

{'stream_0_group_0': {1: 0, 2: 0},
 'stream_0_group_1': {1: 0, 2: 0},
 'stream_0_group_2': {1: 0, 2: 0},
 'stream_1_group_0': {1: 0, 2: 0},
 'stream_1_group_1': {1: 0, 2: 0},
 'stream_1_group_2': {1: 0, 2: 0},
 'stream_2_group_0': {1: 0, 2: 0},
 'stream_2_group_1': {1: 0, 2: 0},
 'stream_2_group_2': {1: 0, 2: 0},
 'stream_3_group_0': {1: 0, 2: 0},
 'stream_3_group_1': {1: 0, 2: 0},
 'stream_3_group_2': {1: 0, 2: 0}}