In [1]:
import os
import sys
import glob
import torch
import numpy as np
import pandas as pd
import open3d as o3d

In [2]:
base_dir = os.path.dirname(os.getcwd())
data_dir = os.path.join(base_dir, "data")
src_dir = os.path.join(base_dir, "src")
sys.path.append(os.path.join(src_dir))

In [3]:
train_files = glob.glob(os.path.join(data_dir, "original", "train", "*", "*.obj"))
valid_files = glob.glob(os.path.join(data_dir, "original", "val", "*", "*.obj"))
len(train_files), len(valid_files)

(7003, 1088)

In [4]:
from utils import load_pipeline

In [5]:
class Tokenizer(object):
    
    def _padding(self, ids_tensor, pad_token):
        max_length = max([len(ids) for ids in ids_tensor])
        ids_tensor = [
            torch.cat([
                ids, pad_token.repeat(max_length - len(ids) + 1)
            ])
            for ids in ids_tensor
        ]
        return ids_tensor
    
    def _make_padding_mask(self, ids_tensor, pad_id):
        mask = torch.where(
            ids_tensor==pad_id,
            torch.ones_like(ids_tensor),
            torch.zeros_like(ids_tensor)
        ).type(torch.bool)
        return mask

    def _make_future_mask(self, ids_tensor):
        batch, length = ids_tensor.shape
        arange = torch.arange(length)
        mask = torch.where(
            arange[None, :] <= arange[:, None],
            torch.zeros((length, length)),
            torch.ones((length, length))*(-np.inf)
        ).type(torch.float32)
        return mask

In [6]:
v_batch, f_batch = [], []
for i in range(3):
    v, f = load_pipeline(train_files[i])
    v = torch.tensor(v)
    f = torch.tensor(f)
    v_batch.append(v)
    f_batch.append(f)
    print(v.shape, f.shape)
    print("="*60)

torch.Size([655, 3]) torch.Size([1317, 3])
torch.Size([310, 3]) torch.Size([676, 3])
torch.Size([396, 3]) torch.Size([1184, 3])


## 頂点エンコーダ

In [7]:
class EncodeVertexTokenizer(Tokenizer):
    
    def __init__(self, pad_id=0):
        self.pad_token = torch.tensor([pad_id])
        self.pad_id = pad_id
        
    def tokenize(self, vertices, padding=True):
        vertices = [v.reshape(-1,) + 1 for v in vertices]
        coord_type_tokens = [torch.arange(len(v)) % 3 + 1 for v in vertices]
        position_ids = [torch.arange(len(v)) // 3 + 1 for v in vertices]
        
        if padding:
            vertices = torch.stack(self._padding(vertices, self.pad_token))
            coord_type_tokens = torch.stack(self._padding(coord_type_tokens, self.pad_token))
            position_ids = torch.stack(self._padding(position_ids, self.pad_token))
            padding_mask = self._make_padding_mask(vertices, self.pad_id)
            
            outputs = {
                "value_tokens": vertices,
                "coord_type_tokens": coord_type_tokens,
                "position_ids": position_ids,
                "padding_mask": padding_mask,
            }
        else:
            outputs = {
                "value_token": vertices,
                "coord_type_tokens": coord_type_tokens,
                "position_ids": position_ids,
            }
            
        return outputs

In [8]:
enc_vtk = EncodeVertexTokenizer()

In [9]:
enc_vtk.tokenize(v_batch)

{'value_tokens': tensor([[161, 137, 133,  ..., 137, 123,   0],
         [135, 162, 130,  ...,   0,   0,   0],
         [163,  99, 134,  ...,   0,   0,   0]]),
 'coord_type_tokens': tensor([[1, 2, 3,  ..., 2, 3, 0],
         [1, 2, 3,  ..., 0, 0, 0],
         [1, 2, 3,  ..., 0, 0, 0]]),
 'position_ids': tensor([[  1,   1,   1,  ..., 655, 655,   0],
         [  1,   1,   1,  ...,   0,   0,   0],
         [  1,   1,   1,  ...,   0,   0,   0]]),
 'padding_mask': tensor([[False, False, False,  ..., False, False,  True],
         [False, False, False,  ...,  True,  True,  True],
         [False, False, False,  ...,  True,  True,  True]])}

In [67]:
for k, vs in enc_vtk.tokenize(v_batch).items():
    print(k, ":")
    print(vs.shape)
    print(torch.stack([v[:10] for v in vs]))
    print(torch.stack([v[-10:] for v in vs]))
    print("="*60)

value_tokens :
torch.Size([3, 1966])
tensor([[161, 137, 133, 161, 137, 131, 161, 137, 127, 161],
        [135, 162, 130, 135, 162, 126, 135, 153, 130, 135],
        [163,  99, 134, 163,  99, 130, 163,  99, 125, 163]])
tensor([[ 95, 137, 128,  95, 137, 125,  95, 137, 123,   0],
        [  0,   0,   0,   0,   0,   0,   0,   0,   0,   0],
        [  0,   0,   0,   0,   0,   0,   0,   0,   0,   0]])
coord_type_tokens :
torch.Size([3, 1966])
tensor([[1, 2, 3, 1, 2, 3, 1, 2, 3, 1],
        [1, 2, 3, 1, 2, 3, 1, 2, 3, 1],
        [1, 2, 3, 1, 2, 3, 1, 2, 3, 1]])
tensor([[1, 2, 3, 1, 2, 3, 1, 2, 3, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0]])
position_ids :
torch.Size([3, 1966])
tensor([[1, 1, 1, 2, 2, 2, 3, 3, 3, 4],
        [1, 1, 1, 2, 2, 2, 3, 3, 3, 4],
        [1, 1, 1, 2, 2, 2, 3, 3, 3, 4]])
tensor([[653, 653, 653, 654, 654, 654, 655, 655, 655,   0],
        [  0,   0,   0,   0,   0,   0,   0,   0,   0,   0],
        [  0,   0,   0,   0,   0,   0, 

In [10]:
enc_vtk.tokenize(v_batch, padding=False)

{'value_token': [tensor([161, 137, 133,  ...,  95, 137, 123]),
  tensor([135, 162, 130, 135, 162, 126, 135, 153, 130, 135, 153, 128, 135, 153,
          126, 135, 143, 128, 135, 143, 126, 135, 134, 130, 135, 134, 128, 135,
          134, 126, 135, 124, 128, 135, 124, 126, 135, 115, 130, 135, 115, 128,
          135, 115, 126, 135, 105, 130, 135, 105, 126, 135,  67, 130, 135,  67,
          128, 135,  67, 126, 135,  65, 130, 135,  65, 128, 135,  65, 126, 134,
          181, 130, 134, 181, 126, 134, 172, 131, 134, 172, 130, 134, 172, 128,
          134, 172, 126, 134, 172, 125, 134, 162, 131, 134, 162, 125, 134, 153,
          132, 134, 153, 124, 134, 143, 132, 134, 143, 124, 134, 134, 132, 134,
          134, 124, 134, 124, 132, 134, 124, 124, 134, 115, 132, 134, 115, 124,
          134, 105, 131, 134, 105, 125, 134,  96, 131, 134,  96, 130, 134,  96,
          128, 134,  96, 126, 134,  96, 125, 134,  86, 130, 134,  86, 126, 134,
           67, 132, 134,  67, 124, 134,  65, 132, 134,  6

In [69]:
for k, vs in enc_vtk.tokenize(v_batch, padding=False).items():
    print(k, ":")
    print(torch.stack([v[:10] for v in vs]))
    print(torch.stack([v[-10:] for v in vs]))
    print("="*60)

value_token :
tensor([[161, 137, 133, 161, 137, 131, 161, 137, 127, 161],
        [135, 162, 130, 135, 162, 126, 135, 153, 130, 135],
        [163,  99, 134, 163,  99, 130, 163,  99, 125, 163]])
tensor([[132,  95, 137, 128,  95, 137, 125,  95, 137, 123],
        [126, 121,  65, 130, 121,  65, 128, 121,  65, 126],
        [122,  93,  88, 131,  93,  88, 126,  93,  88, 122]])
coord_type_tokens :
tensor([[1, 2, 3, 1, 2, 3, 1, 2, 3, 1],
        [1, 2, 3, 1, 2, 3, 1, 2, 3, 1],
        [1, 2, 3, 1, 2, 3, 1, 2, 3, 1]])
tensor([[3, 1, 2, 3, 1, 2, 3, 1, 2, 3],
        [3, 1, 2, 3, 1, 2, 3, 1, 2, 3],
        [3, 1, 2, 3, 1, 2, 3, 1, 2, 3]])
position_ids :
tensor([[1, 1, 1, 2, 2, 2, 3, 3, 3, 4],
        [1, 1, 1, 2, 2, 2, 3, 3, 3, 4],
        [1, 1, 1, 2, 2, 2, 3, 3, 3, 4]])
tensor([[652, 653, 653, 653, 654, 654, 654, 655, 655, 655],
        [307, 308, 308, 308, 309, 309, 309, 310, 310, 310],
        [393, 394, 394, 394, 395, 395, 395, 396, 396, 396]])


In [11]:
[v.reshape(-1, ) for v in v_batch]

[tensor([160, 136, 132,  ...,  94, 136, 122]),
 tensor([134, 161, 129, 134, 161, 125, 134, 152, 129, 134, 152, 127, 134, 152,
         125, 134, 142, 127, 134, 142, 125, 134, 133, 129, 134, 133, 127, 134,
         133, 125, 134, 123, 127, 134, 123, 125, 134, 114, 129, 134, 114, 127,
         134, 114, 125, 134, 104, 129, 134, 104, 125, 134,  66, 129, 134,  66,
         127, 134,  66, 125, 134,  64, 129, 134,  64, 127, 134,  64, 125, 133,
         180, 129, 133, 180, 125, 133, 171, 130, 133, 171, 129, 133, 171, 127,
         133, 171, 125, 133, 171, 124, 133, 161, 130, 133, 161, 124, 133, 152,
         131, 133, 152, 123, 133, 142, 131, 133, 142, 123, 133, 133, 131, 133,
         133, 123, 133, 123, 131, 133, 123, 123, 133, 114, 131, 133, 114, 123,
         133, 104, 130, 133, 104, 124, 133,  95, 130, 133,  95, 129, 133,  95,
         127, 133,  95, 125, 133,  95, 124, 133,  85, 129, 133,  85, 125, 133,
          66, 131, 133,  66, 123, 133,  64, 131, 133,  64, 123, 132, 190, 130,
     

## 頂点デコーダ

In [12]:
class DecodeVertexTokenizer(Tokenizer):
    
    def __init__(self, bos_id=0, eos_id=1, pad_id=2):
        
        self.special_tokens = {
            "bos": torch.tensor([bos_id]),
            "eos": torch.tensor([eos_id]),
            "pad": torch.tensor([pad_id]),
        }
        self.pad_id = pad_id
        self.not_coord_token = torch.tensor([0])
    
    def tokenize(self, vertices, padding=True):
        special_tokens = self.special_tokens
        not_coord_token = self.not_coord_token
        
        vertices = [
            torch.cat([
                special_tokens["bos"], 
                v.reshape(-1,)  + len(special_tokens), 
                special_tokens["eos"]
            ])
            for v in vertices
        ]
        
        coord_type_tokens = [
            torch.cat([
                not_coord_token,
                torch.arange(len(v)-2) % 3 + 1,
                not_coord_token
            ])
            for v in vertices
        ]
        
        position_ids = [
            torch.cat([
                not_coord_token,
                torch.arange(len(v)-2) // 3 + 1,
                not_coord_token
            ])
            for v in vertices
        ]
        
        if padding:
            vertices = torch.stack(self._padding(vertices, special_tokens["pad"]))
            coord_type_tokens = torch.stack(self._padding(coord_type_tokens, not_coord_token))
            position_ids = torch.stack(self._padding(position_ids, not_coord_token))
            padding_mask = self._make_padding_mask(vertices, self.pad_id)
            future_mask = self._make_future_mask(vertices)
            outputs = {
                "value_tokens": vertices,
                "coord_type_tokens": coord_type_tokens,
                "position_ids": position_ids,
                "padding_mask": padding_mask,
                "future_mask": future_mask,
            }
        else:
            outputs = {
                "value_tokens": vertices,
                "coord_type_tokens": coord_type_tokens,
                "position_ids": position_ids,
            }
            
        return outputs
    
    def detokenize(self, vertices):
        special_tokens = self.special_tokens
        
        result = []
        for vertex in vertices:
            vertex = vertex - len(special_tokens)
            result.append(
                vertex[torch.where(vertex >= 0)]
            )
        return result

In [13]:
dec_vtk = DecodeVertexTokenizer()

In [14]:
dec_vtk.tokenize(v_batch)

{'value_tokens': tensor([[  0, 163, 139,  ..., 125,   1,   2],
         [  0, 137, 164,  ...,   2,   2,   2],
         [  0, 165, 101,  ...,   2,   2,   2]]),
 'coord_type_tokens': tensor([[0, 1, 2,  ..., 3, 0, 0],
         [0, 1, 2,  ..., 0, 0, 0],
         [0, 1, 2,  ..., 0, 0, 0]]),
 'position_ids': tensor([[  0,   1,   1,  ..., 655,   0,   0],
         [  0,   1,   1,  ...,   0,   0,   0],
         [  0,   1,   1,  ...,   0,   0,   0]]),
 'padding_mask': tensor([[False, False, False,  ..., False, False,  True],
         [False, False, False,  ...,  True,  True,  True],
         [False, False, False,  ...,  True,  True,  True]]),
 'future_mask': tensor([[0., -inf, -inf,  ..., -inf, -inf, -inf],
         [0., 0., -inf,  ..., -inf, -inf, -inf],
         [0., 0., 0.,  ..., -inf, -inf, -inf],
         ...,
         [0., 0., 0.,  ..., 0., -inf, -inf],
         [0., 0., 0.,  ..., 0., 0., -inf],
         [0., 0., 0.,  ..., 0., 0., 0.]])}

In [68]:
for k, vs in dec_vtk.tokenize(v_batch).items():
    print(k, ":")
    print(vs.shape)
    print(torch.stack([v[:10] for v in vs]))
    print(torch.stack([v[-10:] for v in vs]))
    print("="*60)

value_tokens :
torch.Size([3, 1968])
tensor([[  0, 163, 139, 135, 163, 139, 133, 163, 139, 129],
        [  0, 137, 164, 132, 137, 164, 128, 137, 155, 132],
        [  0, 165, 101, 136, 165, 101, 132, 165, 101, 127]])
tensor([[139, 130,  97, 139, 127,  97, 139, 125,   1,   2],
        [  2,   2,   2,   2,   2,   2,   2,   2,   2,   2],
        [  2,   2,   2,   2,   2,   2,   2,   2,   2,   2]])
coord_type_tokens :
torch.Size([3, 1968])
tensor([[0, 1, 2, 3, 1, 2, 3, 1, 2, 3],
        [0, 1, 2, 3, 1, 2, 3, 1, 2, 3],
        [0, 1, 2, 3, 1, 2, 3, 1, 2, 3]])
tensor([[2, 3, 1, 2, 3, 1, 2, 3, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0]])
position_ids :
torch.Size([3, 1968])
tensor([[0, 1, 1, 1, 2, 2, 2, 3, 3, 3],
        [0, 1, 1, 1, 2, 2, 2, 3, 3, 3],
        [0, 1, 1, 1, 2, 2, 2, 3, 3, 3]])
tensor([[653, 653, 654, 654, 654, 655, 655, 655,   0,   0],
        [  0,   0,   0,   0,   0,   0,   0,   0,   0,   0],
        [  0,   0,   0,   0,   0,   0, 

In [15]:
dec_vtk.tokenize(v_batch, padding=False)

{'value_tokens': [tensor([  0, 163, 139,  ..., 139, 125,   1]),
  tensor([  0, 137, 164, 132, 137, 164, 128, 137, 155, 132, 137, 155, 130, 137,
          155, 128, 137, 145, 130, 137, 145, 128, 137, 136, 132, 137, 136, 130,
          137, 136, 128, 137, 126, 130, 137, 126, 128, 137, 117, 132, 137, 117,
          130, 137, 117, 128, 137, 107, 132, 137, 107, 128, 137,  69, 132, 137,
           69, 130, 137,  69, 128, 137,  67, 132, 137,  67, 130, 137,  67, 128,
          136, 183, 132, 136, 183, 128, 136, 174, 133, 136, 174, 132, 136, 174,
          130, 136, 174, 128, 136, 174, 127, 136, 164, 133, 136, 164, 127, 136,
          155, 134, 136, 155, 126, 136, 145, 134, 136, 145, 126, 136, 136, 134,
          136, 136, 126, 136, 126, 134, 136, 126, 126, 136, 117, 134, 136, 117,
          126, 136, 107, 133, 136, 107, 127, 136,  98, 133, 136,  98, 132, 136,
           98, 130, 136,  98, 128, 136,  98, 127, 136,  88, 132, 136,  88, 128,
          136,  69, 134, 136,  69, 126, 136,  67, 134, 1

In [71]:
for k, vs in dec_vtk.tokenize(v_batch, padding=False).items():
    print(k, ":")
    print(torch.stack([v[:10] for v in vs]))
    print(torch.stack([v[-10:] for v in vs]))
    print("="*60)

value_tokens :
tensor([[  0, 163, 139, 135, 163, 139, 133, 163, 139, 129],
        [  0, 137, 164, 132, 137, 164, 128, 137, 155, 132],
        [  0, 165, 101, 136, 165, 101, 132, 165, 101, 127]])
tensor([[ 97, 139, 130,  97, 139, 127,  97, 139, 125,   1],
        [123,  67, 132, 123,  67, 130, 123,  67, 128,   1],
        [ 95,  90, 133,  95,  90, 128,  95,  90, 124,   1]])
coord_type_tokens :
tensor([[0, 1, 2, 3, 1, 2, 3, 1, 2, 3],
        [0, 1, 2, 3, 1, 2, 3, 1, 2, 3],
        [0, 1, 2, 3, 1, 2, 3, 1, 2, 3]])
tensor([[1, 2, 3, 1, 2, 3, 1, 2, 3, 0],
        [1, 2, 3, 1, 2, 3, 1, 2, 3, 0],
        [1, 2, 3, 1, 2, 3, 1, 2, 3, 0]])
position_ids :
tensor([[0, 1, 1, 1, 2, 2, 2, 3, 3, 3],
        [0, 1, 1, 1, 2, 2, 2, 3, 3, 3],
        [0, 1, 1, 1, 2, 2, 2, 3, 3, 3]])
tensor([[653, 653, 653, 654, 654, 654, 655, 655, 655,   0],
        [308, 308, 308, 309, 309, 309, 310, 310, 310,   0],
        [394, 394, 394, 395, 395, 395, 396, 396, 396,   0]])


In [16]:
[v.reshape(-1, ) for v in v_batch]

[tensor([160, 136, 132,  ...,  94, 136, 122]),
 tensor([134, 161, 129, 134, 161, 125, 134, 152, 129, 134, 152, 127, 134, 152,
         125, 134, 142, 127, 134, 142, 125, 134, 133, 129, 134, 133, 127, 134,
         133, 125, 134, 123, 127, 134, 123, 125, 134, 114, 129, 134, 114, 127,
         134, 114, 125, 134, 104, 129, 134, 104, 125, 134,  66, 129, 134,  66,
         127, 134,  66, 125, 134,  64, 129, 134,  64, 127, 134,  64, 125, 133,
         180, 129, 133, 180, 125, 133, 171, 130, 133, 171, 129, 133, 171, 127,
         133, 171, 125, 133, 171, 124, 133, 161, 130, 133, 161, 124, 133, 152,
         131, 133, 152, 123, 133, 142, 131, 133, 142, 123, 133, 133, 131, 133,
         133, 123, 133, 123, 131, 133, 123, 123, 133, 114, 131, 133, 114, 123,
         133, 104, 130, 133, 104, 124, 133,  95, 130, 133,  95, 129, 133,  95,
         127, 133,  95, 125, 133,  95, 124, 133,  85, 129, 133,  85, 125, 133,
          66, 131, 133,  66, 123, 133,  64, 131, 133,  64, 123, 132, 190, 130,
     

In [17]:
outputs = dec_vtk.detokenize(dec_vtk.tokenize(v_batch)["value_tokens"])
outputs

[tensor([160, 136, 132,  ...,  94, 136, 122]),
 tensor([134, 161, 129, 134, 161, 125, 134, 152, 129, 134, 152, 127, 134, 152,
         125, 134, 142, 127, 134, 142, 125, 134, 133, 129, 134, 133, 127, 134,
         133, 125, 134, 123, 127, 134, 123, 125, 134, 114, 129, 134, 114, 127,
         134, 114, 125, 134, 104, 129, 134, 104, 125, 134,  66, 129, 134,  66,
         127, 134,  66, 125, 134,  64, 129, 134,  64, 127, 134,  64, 125, 133,
         180, 129, 133, 180, 125, 133, 171, 130, 133, 171, 129, 133, 171, 127,
         133, 171, 125, 133, 171, 124, 133, 161, 130, 133, 161, 124, 133, 152,
         131, 133, 152, 123, 133, 142, 131, 133, 142, 123, 133, 133, 131, 133,
         133, 123, 133, 123, 131, 133, 123, 123, 133, 114, 131, 133, 114, 123,
         133, 104, 130, 133, 104, 124, 133,  95, 130, 133,  95, 129, 133,  95,
         127, 133,  95, 125, 133,  95, 124, 133,  85, 129, 133,  85, 125, 133,
          66, 131, 133,  66, 123, 133,  64, 131, 133,  64, 123, 132, 190, 130,
     

## 表面デコーダ

In [72]:
class FaceTokenizer(Tokenizer):
    
    def __init__(self, eof_id=0, eos_id=1, pad_id=2):
        self.special_tokens = {
            "eof": torch.tensor([eof_id]),
            "eos": torch.tensor([eos_id]),
            "pad": torch.tensor([pad_id]),
        }
        self.pad_id = pad_id
        self.not_coord_token = torch.tensor([0])
        
    def tokenize(self, faces, target=False, padding=True):
        special_tokens = self.special_tokens
        not_coord_token = self.not_coord_token
        
        
        if padding:
            faces = [
                torch.cat([
                    torch.cat([
                        f + len(special_tokens),
                        special_tokens["eof"].repeat(len(f))[:, None]
                    ], dim=1).reshape(-1,),
                    special_tokens["eos"]
                ]) for f in faces
            ]
            
            coord_type_tokens = [
                torch.cat([
                    torch.arange(len(f)-1) % 4 + 1,
                    not_coord_token
                ])
                for f in faces
            ]

            position_ids = [
                torch.cat([
                    torch.arange(len(f)-1) // 4 + 1,
                    not_coord_token
                ])
                for f in faces
            ]
            
            faces = self._padding(faces, special_tokens["pad"])
            
            
            if target:
                faces = [torch.cat([f, special_tokens["pad"]])[1:] for f in faces]
                outputs = {
                    "value_tokens": torch.stack(faces)
                }
            else: 
                faces = torch.stack(faces)
                coord_type_tokens = torch.stack(self._padding(coord_type_tokens, not_coord_token))
                position_ids = torch.stack(self._padding(position_ids, not_coord_token))
                
                padding_mask = self._make_padding_mask(faces, self.pad_id)
                future_mask = self._make_future_mask(faces)
                
                cond_vertice = faces >= len(special_tokens)
                reference_vertices_mask = torch.where(cond_vertice, 1., 0.)
                reference_vertices_ids = torch.where(cond_vertice, faces-len(special_tokens), 0)
                reference_embed_mask = torch.where(cond_vertice, 0., 1.)
                reference_embed_ids = torch.where(cond_vertice, 0, faces)
                
                outputs = {
                    "value_tokens": faces,
                    "coord_type_tokens": coord_type_tokens,
                    "position_ids": position_ids,
                    "ref_v_mask": reference_vertices_mask,
                    "ref_v_ids": reference_vertices_ids,
                    "ref_e_mask": reference_embed_mask,
                    "ref_e_ids": reference_embed_ids,
                    "padding_mask": padding_mask,
                    "future_mask": future_mask,
                }
            
        else:
            faces_ids = []
            coord_type_tokens = []
            position_ids = []
            reference_vertices_mask = []
            reference_vertices_ids = []
            reference_embed_mask = []
            reference_embed_ids = []

            for f in faces:
                f = torch.cat([
                    f + len(special_tokens),
                    special_tokens["eof"].repeat(len(f))[:, None]
                ], dim=1).reshape(-1, )
                f = torch.cat([f, special_tokens["eos"]])
                
                c_t_tokens = torch.cat([
                    torch.arange(len(f)-1) % 4 + 1,
                    not_coord_token
                ])
                pos_ids = torch.cat([
                    torch.arange(len(f)-1) // 4 + 1,
                    not_coord_token
                ])
                
                if target:
                    f = torch.cat([f, special_tokens["pad"]])[1:]
                
                cond_vertice = f >= len(special_tokens)

                ref_v_mask = torch.where(cond_vertice, 1., 0.)
                ref_e_mask = torch.where(cond_vertice, 0., 1.)
                ref_v_ids = torch.where(cond_vertice, f-len(special_tokens), 0)
                ref_e_ids = torch.where(cond_vertice, 0, f)

                faces_ids.append(f)
                coord_type_tokens.append(c_t_tokens)
                position_ids.append(pos_ids)
                
                reference_vertices_mask.append(ref_v_mask)
                reference_vertices_ids.append(ref_v_ids)
                reference_embed_mask.append(ref_e_mask)
                reference_embed_ids.append(ref_e_ids)
            
            if target:
                faces_ids = [torch.cat([f, special_tokens["pad"]])[1:] for f in faces_ids]
                outputs = {
                    "value_tokens": faces_ids
                }
            else:
                outputs = {
                    "value_tokens": faces_ids,
                    "coord_type_tokens": coord_type_tokens,
                    "position_ids": position_ids,
                    "ref_v_mask": reference_vertices_mask,
                    "ref_v_ids": reference_vertices_ids,
                    "ref_e_mask": reference_embed_mask,
                    "ref_e_ids": reference_embed_ids,
                }
        
        return outputs

    def detokenize(self, faces):
        special_tokens = self.special_tokens
        
        result = []
        for face in faces:
            face = face - len(special_tokens)
            result.append(
                face[torch.where(face >= 0)]
            )
        return result

In [73]:
ftk = FaceTokenizer()

In [74]:
for k, vs in ftk.tokenize(f_batch).items():
    print(k, ":")
    print(vs.shape)
    print(torch.stack([v[:10] for v in vs]))
    print(torch.stack([v[-10:] for v in vs]))
    print("="*60)

value_tokens :
torch.Size([3, 5270])
tensor([[  3,   4, 156,   0,   3,   9, 106,   0,   3,   9],
        [  3,   5,  29,   0,   3,   5,  33,   0,   3,  29],
        [  3,   3,   4,   0,   3,   3,  12,   0,   3,   4]])
tensor([[651, 655, 657,   0, 655, 656, 657,   0,   1,   2],
        [  2,   2,   2,   2,   2,   2,   2,   2,   2,   2],
        [  2,   2,   2,   2,   2,   2,   2,   2,   2,   2]])
coord_type_tokens :
torch.Size([3, 5270])
tensor([[1, 2, 3, 4, 1, 2, 3, 4, 1, 2],
        [1, 2, 3, 4, 1, 2, 3, 4, 1, 2],
        [1, 2, 3, 4, 1, 2, 3, 4, 1, 2]])
tensor([[1, 2, 3, 4, 1, 2, 3, 4, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0]])
position_ids :
torch.Size([3, 5270])
tensor([[1, 1, 1, 1, 2, 2, 2, 2, 3, 3],
        [1, 1, 1, 1, 2, 2, 2, 2, 3, 3],
        [1, 1, 1, 1, 2, 2, 2, 2, 3, 3]])
tensor([[1316, 1316, 1316, 1316, 1317, 1317, 1317, 1317,    0,    0],
        [   0,    0,    0,    0,    0,    0,    0,    0,    0,    0],
        [   0,    0

In [75]:
ftk.tokenize(f_batch, target=True)

{'value_tokens': tensor([[  4, 156,   0,  ...,   1,   2,   2],
         [  5,  29,   0,  ...,   2,   2,   2],
         [  3,   4,   0,  ...,   2,   2,   2]])}

In [76]:
for k, vs in ftk.tokenize(f_batch, padding=False).items():
    print(k, ":")
    print(torch.stack([v[:10] for v in vs]))
    print(torch.stack([v[-10:] for v in vs]))
    print("="*60)

value_tokens :
tensor([[  3,   4, 156,   0,   3,   9, 106,   0,   3,   9],
        [  3,   5,  29,   0,   3,   5,  33,   0,   3,  29],
        [  3,   3,   4,   0,   3,   3,  12,   0,   3,   4]])
tensor([[  0, 651, 655, 657,   0, 655, 656, 657,   0,   1],
        [  0, 308, 310, 311,   0, 309, 311, 312,   0,   1],
        [  0, 394, 395, 398,   0, 394, 397, 398,   0,   1]])
coord_type_tokens :
tensor([[1, 2, 3, 4, 1, 2, 3, 4, 1, 2],
        [1, 2, 3, 4, 1, 2, 3, 4, 1, 2],
        [1, 2, 3, 4, 1, 2, 3, 4, 1, 2]])
tensor([[4, 1, 2, 3, 4, 1, 2, 3, 4, 0],
        [4, 1, 2, 3, 4, 1, 2, 3, 4, 0],
        [4, 1, 2, 3, 4, 1, 2, 3, 4, 0]])
position_ids :
tensor([[1, 1, 1, 1, 2, 2, 2, 2, 3, 3],
        [1, 1, 1, 1, 2, 2, 2, 2, 3, 3],
        [1, 1, 1, 1, 2, 2, 2, 2, 3, 3]])
tensor([[1315, 1316, 1316, 1316, 1316, 1317, 1317, 1317, 1317,    0],
        [ 674,  675,  675,  675,  675,  676,  676,  676,  676,    0],
        [1182, 1183, 1183, 1183, 1183, 1184, 1184, 1184, 1184,    0]])
ref_v_mask :
t

In [77]:
ftk.tokenize(f_batch, padding=False, target=True)

{'value_tokens': [tensor([156,   0,   3,  ...,   1,   2,   2]),
  tensor([29,  0,  3,  ...,  1,  2,  2]),
  tensor([4, 0, 3,  ..., 1, 2, 2])]}

In [78]:
outputs = ftk.detokenize(ftk.tokenize(f_batch)["value_tokens"])
outputs

[tensor([  0,   1, 153,  ..., 652, 653, 654]),
 tensor([  0,   2,  26,  ..., 306, 308, 309]),
 tensor([  0,   0,   1,  ..., 391, 394, 395])]