In [3]:
from poi.rqvae import RQVAE
from poi.train.config import RQVAEConfig
from poi.settings import DEVICE
from poi.dataset.rqvae import get_dataloader

config = RQVAEConfig()

train_loader = get_dataloader(config.dataset_path, batch_size=config.batch_size, num_workers=config.num_dataloader_workers, device=DEVICE)

model = RQVAE(
    embedding_dim=config.embedding_dim,
    vae_hidden_dims=config.vae_hidden_dims,
    vector_dim=config.vector_dim,
    vector_num=config.vector_num,
    codebook_num=config.codebook_num,
    commitment_weight=config.commitment_weight,
    random_state=config.random_state,
).to(DEVICE)

model.eval()



[POIDataset] Loaded features from /home/vislab/poi/datasets/GWL/poi_features.pt, shape: torch.Size([16368, 8706])
[DataLoader] Created with batch_size=128, shuffle=True, num_workers=4, pin_memory=True, drop_last=False


RQVAE(
  (encoder): Encoder(
    (layers): Sequential(
      (0): Linear(in_features=8706, out_features=1024, bias=True)
      (1): BatchNorm1d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
      (3): Linear(in_features=1024, out_features=512, bias=True)
      (4): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (5): ReLU(inplace=True)
      (6): Linear(in_features=512, out_features=128, bias=True)
      (7): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (8): ReLU(inplace=True)
    )
    (proj): Linear(in_features=128, out_features=64, bias=True)
  )
  (decoder): Decoder(
    (layers): Sequential(
      (0): Linear(in_features=64, out_features=128, bias=True)
      (1): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
      (3): Linear(in_features=128, out_features=512, bias=True)
      (4): B

In [37]:
import torch

ckp = torch.load(config.checkpoint_best_path)

# config.checkpoint_best_path

In [40]:
res = model.load_state_dict(ckp["model_state_dict"])

# ckp.keys()

In [41]:
res

<All keys matched successfully>

In [24]:
len(train_loader.dataset)

16368

In [17]:
x = next(iter(train_loader)).to(DEVICE)

# 

In [18]:
x.device

device(type='cuda', index=0)

In [21]:
x.size()

torch.Size([128, 8706])

In [42]:
quantized, cur_loss, all_indices = model.forward(x[:2])
all_indices[0].size()


torch.Size([2])

In [43]:
all_indices

[tensor([31, 51], device='cuda:0'),
 tensor([28, 44], device='cuda:0'),
 tensor([ 1, 23], device='cuda:0')]

In [57]:
import string

def encode_poi_sid(model: RQVAE, batch: torch.Tensor):
    letters = string.ascii_lowercase
    batch_size = batch.size(0)
    all_indices:list[torch.Tensor]
    
    _, _, all_indices = model.forward(batch)

    codebook_num = len(all_indices)
    assert codebook_num <= len(letters)
    
    sids = [
        "".join([
            f"<{letters[level]}_{all_indices[level][i]}>"
            for level in range(codebook_num)
        ])
        for i in range(batch_size)
    ]

    return sids

encode_poi_sid(model, x)


['<a_31><b_28><c_1>',
 '<a_51><b_44><c_23>',
 '<a_2><b_19><c_3>',
 '<a_2><b_32><c_32>',
 '<a_46><b_8><c_7>',
 '<a_28><b_46><c_28>',
 '<a_23><b_61><c_35>',
 '<a_52><b_0><c_39>',
 '<a_19><b_60><c_45>',
 '<a_50><b_48><c_45>',
 '<a_4><b_10><c_11>',
 '<a_40><b_36><c_45>',
 '<a_15><b_63><c_55>',
 '<a_51><b_14><c_34>',
 '<a_45><b_12><c_50>',
 '<a_36><b_58><c_3>',
 '<a_2><b_54><c_9>',
 '<a_1><b_6><c_38>',
 '<a_50><b_57><c_19>',
 '<a_30><b_24><c_2>',
 '<a_22><b_14><c_15>',
 '<a_55><b_5><c_40>',
 '<a_2><b_14><c_17>',
 '<a_35><b_59><c_42>',
 '<a_51><b_42><c_56>',
 '<a_44><b_61><c_29>',
 '<a_6><b_56><c_54>',
 '<a_15><b_57><c_49>',
 '<a_12><b_11><c_58>',
 '<a_44><b_56><c_62>',
 '<a_59><b_39><c_11>',
 '<a_9><b_49><c_1>',
 '<a_0><b_14><c_48>',
 '<a_34><b_16><c_43>',
 '<a_44><b_61><c_23>',
 '<a_20><b_0><c_26>',
 '<a_6><b_46><c_34>',
 '<a_60><b_56><c_55>',
 '<a_60><b_50><c_61>',
 '<a_24><b_63><c_16>',
 '<a_59><b_31><c_35>',
 '<a_57><b_30><c_34>',
 '<a_30><b_58><c_43>',
 '<a_15><b_18><c_41>',
 '<a_13><b

In [54]:
import string

string.ascii_lowercase

'abcdefghijklmnopqrstuvwxyz'