In [1]:
import os
import sys
import glob
import torch
import numpy as np
import pandas as pd
import open3d as o3d

In [2]:
# Project layout is resolved relative to the working directory: base_dir is the
# parent of the directory the notebook is run from.
base_dir = os.path.dirname(os.getcwd())
data_dir = os.path.join(base_dir, "data")
src_dir = os.path.join(base_dir, "src")
# Make modules under src/ importable. The membership check keeps sys.path free of
# duplicate entries when this cell is executed more than once.
if src_dir not in sys.path:
    sys.path.append(src_dir)

In [3]:
# Collect every .obj file found under data/original/{train,val}/<any sub-directory>/.
# glob returns matches in arbitrary, filesystem-dependent order, so the lists are
# sorted to keep index-based selection (e.g. train_files[0]) reproducible.
train_files = sorted(glob.glob(os.path.join(data_dir, "original", "train", "*", "*.obj")))
valid_files = sorted(glob.glob(os.path.join(data_dir, "original", "val", "*", "*.obj")))
len(train_files), len(valid_files)

(7003, 1088)

In [4]:
from utils_polygen import load_pipeline

In [5]:
class Tokenizer(object):
    """Padding and masking helpers shared by the tokenizers in this notebook.

    The class defines no ``__init__``. ``get_pred_start`` reads the attributes
    ``special_tokens`` (a dict of 1-element tensors that contains ``"pad"``),
    ``not_coord_token``, ``max_seq_len`` and ``pad_id``, which the subclass must set.
    """

    def _padding(self, ids_tensor, pad_token, max_length=None):
        """Right-pad every 1-D tensor of a batch with ``pad_token``.

        Every returned sequence has length ``max_length + 1`` (not ``max_length``):
        one extra pad slot is always added. The subclasses compensate by storing
        ``max_seq_len - 1``.

        Args:
            ids_tensor: list of 1-D tensors.
            pad_token: 1-element tensor that is repeated to fill the tail.
            max_length: reference length; defaults to the longest sequence in the batch.

        Returns:
            list of 1-D tensors, each of length ``max_length + 1``.

        Raises:
            RuntimeError: if a sequence is longer than ``max_length + 1``.
            ValueError: if ``ids_tensor`` is empty and ``max_length`` is None.
        """
        if max_length is None:
            max_length = max(len(ids) for ids in ids_tensor)

        padded = []
        for ids in ids_tensor:
            n_pad = max_length - len(ids) + 1
            if n_pad < 0:
                # Without this check torch fails with an opaque
                # "negative dimension" error raised from ``repeat``.
                raise RuntimeError(
                    "sequence of length {} does not fit into max_length + 1 = {} tokens".format(
                        len(ids), max_length + 1
                    )
                )
            padded.append(torch.cat([ids, pad_token.repeat(n_pad)]))
        return padded

    def _make_padding_mask(self, ids_tensor, pad_id):
        """Return a bool tensor that is True wherever ``ids_tensor`` equals ``pad_id``."""
        return ids_tensor == pad_id

    def _make_future_mask(self, ids_tensor):
        """Build a (length, length) float32 mask from a (batch, length) tensor.

        Entry ``[i, j]`` is 0 where ``j <= i`` and ``-inf`` where ``j > i``.
        The mask is created on the same device as ``ids_tensor``.
        """
        _, length = ids_tensor.shape
        mask = torch.full(
            (length, length), float("-inf"),
            dtype=torch.float32, device=ids_tensor.device,
        )
        # triu(diagonal=1) keeps the strictly upper triangle (-inf) and zeroes the rest.
        return torch.triu(mask, diagonal=1)

    def get_pred_start(self, start_token="bos", batch_size=1):
        """Create the initial inputs for generation: one start token, then padding.

        Args:
            start_token: key into ``self.special_tokens`` of the token placed first.
            batch_size: number of identical rows to create.

        Returns:
            dict with ``"value_tokens"``, ``"coord_type_tokens"``, ``"position_tokens"``
            and ``"padding_mask"``; each is a tensor with ``batch_size`` rows.

        NOTE(review): the key names are the same for every subclass, including ones
        whose ``tokenize`` uses different key names - confirm callers expect this.
        """
        special_tokens = self.special_tokens
        not_coord_token = self.not_coord_token
        max_seq_len = self.max_seq_len

        values = torch.stack(
            self._padding(
                [special_tokens[start_token]] * batch_size,
                special_tokens["pad"],
                max_seq_len,
            )
        )
        # Coordinate-type and position inputs both start as "not a coordinate"
        # everywhere; build the tensor once and hand out an independent copy.
        coord_type_tokens = torch.stack(
            self._padding(
                [not_coord_token] * batch_size,
                not_coord_token,
                max_seq_len,
            )
        )
        position_tokens = coord_type_tokens.clone()

        padding_mask = self._make_padding_mask(values, self.pad_id)

        outputs = {
            "value_tokens": values,
            "coord_type_tokens": coord_type_tokens,
            "position_tokens": position_tokens,
            "padding_mask": padding_mask,
        }
        return outputs

In [6]:
# Build a small demo batch from the first three training files.
v_batch, f_batch = [], []
for i in range(3):
    # load_pipeline returns three values; only the first (vertices) and last (faces) are used.
    vs, _, fs = load_pipeline(train_files[i])
    
    vs = torch.tensor(vs)
    # One tensor per face, because faces have different lengths.
    fs = [torch.tensor(f) for f in fs]
    
    v_batch.append(vs)
    f_batch.append(fs)
    print(vs.shape, len(fs))
    print("="*60)

torch.Size([204, 3]) 160
torch.Size([62, 3]) 45
torch.Size([64, 3]) 601


In [7]:
# Inspect the vertex tensor of the first mesh: one row of three integer values per vertex.
v_batch[0]

tensor([[166, 121, 166],
        [166, 121,  88],
        [166, 108, 166],
        [166, 108,  88],
        [165, 106, 165],
        [165, 106,  89],
        [165, 104, 165],
        [165, 104,  89],
        [165, 103, 165],
        [165, 103,  89],
        [164, 121, 164],
        [164, 121,  90],
        [164, 108, 164],
        [164, 108,  90],
        [164, 106, 164],
        [164, 106,  90],
        [164, 105, 164],
        [164, 105,  90],
        [164, 101, 164],
        [164, 101,  90],
        [163, 103, 163],
        [163, 103,  91],
        [163, 102, 163],
        [163, 102,  91],
        [163,  99, 163],
        [163,  99,  91],
        [162, 100, 162],
        [162, 100,  92],
        [162,  98, 162],
        [162,  98,  92],
        [161,  99, 161],
        [161,  99,  93],
        [160,  97, 160],
        [160,  97,  94],
        [159,  98, 159],
        [159,  98,  95],
        [159,  96, 159],
        [159,  96,  95],
        [158,  97, 158],
        [158,  97,  96],


## Tokenizer for the vertex model

In [8]:
class EncodeVertexTokenizer(Tokenizer):
    """Convert vertex tensors into value, coordinate-type and position token sequences.

    No start/end tokens are added; the padding id is the only reserved id.
    """

    def __init__(self, pad_id=0, max_seq_len=None):
        self.pad_id = pad_id
        self.pad_token = torch.tensor([pad_id])
        # `_padding` always produces one slot more than the length it is given,
        # so the stored limit is one less than the requested sequence length.
        self.max_seq_len = None if max_seq_len is None else max_seq_len - 1

    def tokenize(self, vertices, padding=True):
        """Tokenize a batch of vertex tensors.

        Each tensor is flattened and shifted by one, which keeps id 0 (the default
        pad id) free. Consecutive triples of the flattened sequence share one
        position token (1, 2, ...), and the coordinate-type token cycles 1, 2, 3.

        Args:
            vertices: iterable of tensors, one per mesh.
            padding: if True, pad and stack every output into a 2-D tensor and add
                a "padding_mask" entry; otherwise return lists of 1-D tensors.

        Returns:
            dict with "value_tokens", "coord_type_tokens", "position_tokens"
            and, when padding is enabled, "padding_mask".
        """
        value_seqs, coord_type_seqs, position_seqs = [], [], []
        for mesh_vertices in vertices:
            flat_values = mesh_vertices.reshape(-1,) + 1
            running_index = torch.arange(len(flat_values))
            value_seqs.append(flat_values)
            coord_type_seqs.append(running_index % 3 + 1)
            position_seqs.append(running_index // 3 + 1)

        if not padding:
            return {
                "value_tokens": value_seqs,
                "coord_type_tokens": coord_type_seqs,
                "position_tokens": position_seqs,
            }

        limit = self.max_seq_len
        pad = self.pad_token
        value_tokens = torch.stack(self._padding(value_seqs, pad, limit))
        coord_type_tokens = torch.stack(self._padding(coord_type_seqs, pad, limit))
        position_tokens = torch.stack(self._padding(position_seqs, pad, limit))

        return {
            "value_tokens": value_tokens,
            "coord_type_tokens": coord_type_tokens,
            "position_tokens": position_tokens,
            "padding_mask": self._make_padding_mask(value_tokens, self.pad_id),
        }

In [9]:
# Defaults: pad_id=0 and no fixed max_seq_len, so padding is relative to the longest sequence in a batch.
enc_vtk = EncodeVertexTokenizer()

In [10]:
# Padded mode: every entry of the returned dict is a stacked 2-D tensor.
enc_vtk.tokenize(v_batch)

{'value_tokens': tensor([[167, 122, 167,  ..., 109,  89,   0],
         [165, 164, 165,  ...,   0,   0,   0],
         [165, 165, 128,  ...,   0,   0,   0]]),
 'coord_type_tokens': tensor([[1, 2, 3,  ..., 2, 3, 0],
         [1, 2, 3,  ..., 0, 0, 0],
         [1, 2, 3,  ..., 0, 0, 0]]),
 'position_tokens': tensor([[  1,   1,   1,  ..., 204, 204,   0],
         [  1,   1,   1,  ...,   0,   0,   0],
         [  1,   1,   1,  ...,   0,   0,   0]]),
 'padding_mask': tensor([[False, False, False,  ..., False, False,  True],
         [False, False, False,  ...,  True,  True,  True],
         [False, False, False,  ...,  True,  True,  True]])}

In [11]:
# Print the shape plus the first and last ten tokens of every padded output.
for k, vs in enc_vtk.tokenize(v_batch).items():
    print(k, ":")
    print(vs.shape)
    print(torch.stack([v[:10] for v in vs]))
    print(torch.stack([v[-10:] for v in vs]))
    print("="*60)

value_tokens :
torch.Size([3, 613])
tensor([[167, 122, 167, 167, 122,  89, 167, 109, 167, 167],
        [165, 164, 165, 165, 164,  91, 165, 155, 165, 165],
        [165, 165, 128, 164, 165, 128, 163, 165, 137, 163]])
tensor([[ 89, 122,  89,  89, 109, 167,  89, 109,  89,   0],
        [  0,   0,   0,   0,   0,   0,   0,   0,   0,   0],
        [  0,   0,   0,   0,   0,   0,   0,   0,   0,   0]])
coord_type_tokens :
torch.Size([3, 613])
tensor([[1, 2, 3, 1, 2, 3, 1, 2, 3, 1],
        [1, 2, 3, 1, 2, 3, 1, 2, 3, 1],
        [1, 2, 3, 1, 2, 3, 1, 2, 3, 1]])
tensor([[1, 2, 3, 1, 2, 3, 1, 2, 3, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0]])
position_tokens :
torch.Size([3, 613])
tensor([[1, 1, 1, 2, 2, 2, 3, 3, 3, 4],
        [1, 1, 1, 2, 2, 2, 3, 3, 3, 4],
        [1, 1, 1, 2, 2, 2, 3, 3, 3, 4]])
tensor([[202, 202, 202, 203, 203, 203, 204, 204, 204,   0],
        [  0,   0,   0,   0,   0,   0,   0,   0,   0,   0],
        [  0,   0,   0,   0,   0,   0, 

In [12]:
# Unpadded mode: each entry is a list of variable-length 1-D tensors and there is no padding mask.
enc_vtk.tokenize(v_batch, padding=False)

{'value_tokens': [tensor([167, 122, 167, 167, 122,  89, 167, 109, 167, 167, 109,  89, 166, 107,
          166, 166, 107,  90, 166, 105, 166, 166, 105,  90, 166, 104, 166, 166,
          104,  90, 165, 122, 165, 165, 122,  91, 165, 109, 165, 165, 109,  91,
          165, 107, 165, 165, 107,  91, 165, 106, 165, 165, 106,  91, 165, 102,
          165, 165, 102,  91, 164, 104, 164, 164, 104,  92, 164, 103, 164, 164,
          103,  92, 164, 100, 164, 164, 100,  92, 163, 101, 163, 163, 101,  93,
          163,  99, 163, 163,  99,  93, 162, 100, 162, 162, 100,  94, 161,  98,
          161, 161,  98,  95, 160,  99, 160, 160,  99,  96, 160,  97, 160, 160,
           97,  96, 159,  98, 159, 159,  98,  97, 158,  97, 158, 158,  97,  98,
          158,  96, 158, 158,  96,  98, 156,  97, 156, 156,  97, 100, 156,  95,
          156, 156,  95, 100, 154,  96, 154, 154,  96, 102, 154,  95, 154, 154,
           95, 102, 153,  96, 153, 153,  96, 103, 153,  95, 153, 153,  95, 103,
          132, 161, 162,

In [13]:
# First and last ten tokens of every unpadded sequence. The values are lists here,
# so there is no .shape to print; the fixed-size slices are what makes stacking possible.
for k, vs in enc_vtk.tokenize(v_batch, padding=False).items():
    print(k, ":")
    print(torch.stack([v[:10] for v in vs]))
    print(torch.stack([v[-10:] for v in vs]))
    print("="*60)

value_tokens :
tensor([[167, 122, 167, 167, 122,  89, 167, 109, 167, 167],
        [165, 164, 165, 165, 164,  91, 165, 155, 165, 165],
        [165, 165, 128, 164, 165, 128, 163, 165, 137, 163]], dtype=torch.int32)
tensor([[167,  89, 122,  89,  89, 109, 167,  89, 109,  89],
        [165,  91, 164,  91,  91, 155, 165,  91, 155,  91],
        [137,  93, 165, 119,  92, 165, 128,  91, 165, 128]], dtype=torch.int32)
coord_type_tokens :
tensor([[1, 2, 3, 1, 2, 3, 1, 2, 3, 1],
        [1, 2, 3, 1, 2, 3, 1, 2, 3, 1],
        [1, 2, 3, 1, 2, 3, 1, 2, 3, 1]])
tensor([[3, 1, 2, 3, 1, 2, 3, 1, 2, 3],
        [3, 1, 2, 3, 1, 2, 3, 1, 2, 3],
        [3, 1, 2, 3, 1, 2, 3, 1, 2, 3]])
position_tokens :
tensor([[1, 1, 1, 2, 2, 2, 3, 3, 3, 4],
        [1, 1, 1, 2, 2, 2, 3, 3, 3, 4],
        [1, 1, 1, 2, 2, 2, 3, 3, 3, 4]])
tensor([[201, 202, 202, 202, 203, 203, 203, 204, 204, 204],
        [ 59,  60,  60,  60,  61,  61,  61,  62,  62,  62],
        [ 61,  62,  62,  62,  63,  63,  63,  64,  64,  64]])


## Tokenizer for the face model (encoder)

In [14]:
class DecodeVertexTokenizer(Tokenizer):
    """Tokenize vertex tensors into bos/eos-delimited sequences with shifted targets."""

    def __init__(self, bos_id=0, eos_id=1, pad_id=2, max_seq_len=None):
        self.special_tokens = {
            name: torch.tensor([token_id])
            for name, token_id in (("bos", bos_id), ("eos", eos_id), ("pad", pad_id))
        }
        self.pad_id = pad_id
        # Id used in the coordinate-type / position sequences at non-coordinate slots.
        self.not_coord_token = torch.tensor([0])
        # `_padding` always produces one slot more than the length it is given,
        # so the stored limit is one less than the requested sequence length.
        self.max_seq_len = None if max_seq_len is None else max_seq_len - 1

    def tokenize(self, vertices, padding=True):
        """Tokenize a batch of vertex tensors.

        Each tensor is flattened, offset by the number of special tokens and
        wrapped in bos ... eos. The target sequence is the value sequence shifted
        left by one, with a pad token appended. Coordinate-type tokens cycle
        1, 2, 3 and position tokens count triples from 1; both are 0 at the
        bos/eos slots.

        Args:
            vertices: iterable of tensors, one per mesh.
            padding: if True, pad and stack every output into a 2-D tensor and add
                a "padding_mask" entry; otherwise return lists of 1-D tensors.

        Returns:
            dict with "value_tokens", "target_tokens", "coord_type_tokens",
            "position_tokens" and, when padding is enabled, "padding_mask".
        """
        tokens = self.special_tokens
        no_coord = self.not_coord_token
        offset = len(tokens)

        value_seqs, target_seqs, coord_type_seqs, position_seqs = [], [], [], []
        for mesh_vertices in vertices:
            sequence = torch.cat([
                tokens["bos"], mesh_vertices.reshape(-1,) + offset, tokens["eos"]
            ])
            # Index over the coordinate slots only (bos and eos excluded).
            coord_index = torch.arange(len(sequence) - 2)

            value_seqs.append(sequence)
            target_seqs.append(torch.cat([sequence, tokens["pad"]])[1:])
            coord_type_seqs.append(torch.cat([no_coord, coord_index % 3 + 1, no_coord]))
            position_seqs.append(torch.cat([no_coord, coord_index // 3 + 1, no_coord]))

        if not padding:
            return {
                "value_tokens": value_seqs,
                "target_tokens": target_seqs,
                "coord_type_tokens": coord_type_seqs,
                "position_tokens": position_seqs,
            }

        limit = self.max_seq_len
        value_tokens = torch.stack(self._padding(value_seqs, tokens["pad"], limit))
        target_tokens = torch.stack(self._padding(target_seqs, tokens["pad"], limit))
        coord_type_tokens = torch.stack(self._padding(coord_type_seqs, no_coord, limit))
        position_tokens = torch.stack(self._padding(position_seqs, no_coord, limit))

        return {
            "value_tokens": value_tokens,
            "target_tokens": target_tokens,
            "coord_type_tokens": coord_type_tokens,
            "position_tokens": position_tokens,
            "padding_mask": self._make_padding_mask(value_tokens, self.pad_id),
        }

    def detokenize(self, vertices):
        """Remove the special-token offset and drop every special token.

        Returns:
            list with one flat 1-D tensor of values per input sequence.
        """
        offset = len(self.special_tokens)
        result = []
        for sequence in vertices:
            shifted = sequence - offset
            # After the shift, special tokens are negative; keep the rest.
            result.append(shifted[shifted >= 0])
        return result

In [15]:
# With max_seq_len set, padded outputs have a fixed length of 2400 tokens per sequence.
dec_vtk = DecodeVertexTokenizer(max_seq_len=2400)

In [16]:
# Padded mode: value tokens start with bos (0), and target_tokens is the same sequence shifted left by one.
dec_vtk.tokenize(v_batch)

{'value_tokens': tensor([[  0, 169, 124,  ...,   2,   2,   2],
         [  0, 167, 166,  ...,   2,   2,   2],
         [  0, 167, 167,  ...,   2,   2,   2]]),
 'target_tokens': tensor([[169, 124, 169,  ...,   2,   2,   2],
         [167, 166, 167,  ...,   2,   2,   2],
         [167, 167, 130,  ...,   2,   2,   2]]),
 'coord_type_tokens': tensor([[0, 1, 2,  ..., 0, 0, 0],
         [0, 1, 2,  ..., 0, 0, 0],
         [0, 1, 2,  ..., 0, 0, 0]]),
 'position_tokens': tensor([[0, 1, 1,  ..., 0, 0, 0],
         [0, 1, 1,  ..., 0, 0, 0],
         [0, 1, 1,  ..., 0, 0, 0]]),
 'padding_mask': tensor([[False, False, False,  ...,  True,  True,  True],
         [False, False, False,  ...,  True,  True,  True],
         [False, False, False,  ...,  True,  True,  True]])}

In [17]:
# Print the shape plus the first and last ten tokens of every padded output.
for k, vs in dec_vtk.tokenize(v_batch).items():
    print(k, ":")
    print(vs.shape)
    print(torch.stack([v[:10] for v in vs]))
    print(torch.stack([v[-10:] for v in vs]))
    print("="*60)

value_tokens :
torch.Size([3, 2400])
tensor([[  0, 169, 124, 169, 169, 124,  91, 169, 111, 169],
        [  0, 167, 166, 167, 167, 166,  93, 167, 157, 167],
        [  0, 167, 167, 130, 166, 167, 130, 165, 167, 139]])
tensor([[2, 2, 2, 2, 2, 2, 2, 2, 2, 2],
        [2, 2, 2, 2, 2, 2, 2, 2, 2, 2],
        [2, 2, 2, 2, 2, 2, 2, 2, 2, 2]])
target_tokens :
torch.Size([3, 2400])
tensor([[169, 124, 169, 169, 124,  91, 169, 111, 169, 169],
        [167, 166, 167, 167, 166,  93, 167, 157, 167, 167],
        [167, 167, 130, 166, 167, 130, 165, 167, 139, 165]])
tensor([[2, 2, 2, 2, 2, 2, 2, 2, 2, 2],
        [2, 2, 2, 2, 2, 2, 2, 2, 2, 2],
        [2, 2, 2, 2, 2, 2, 2, 2, 2, 2]])
coord_type_tokens :
torch.Size([3, 2400])
tensor([[0, 1, 2, 3, 1, 2, 3, 1, 2, 3],
        [0, 1, 2, 3, 1, 2, 3, 1, 2, 3],
        [0, 1, 2, 3, 1, 2, 3, 1, 2, 3]])
tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0]])
position_tokens :
torch.Size([3, 2400

In [18]:
# Unpadded mode: each entry is a list of variable-length 1-D tensors and there is no padding mask.
dec_vtk.tokenize(v_batch, padding=False)

{'value_tokens': [tensor([  0, 169, 124, 169, 169, 124,  91, 169, 111, 169, 169, 111,  91, 168,
          109, 168, 168, 109,  92, 168, 107, 168, 168, 107,  92, 168, 106, 168,
          168, 106,  92, 167, 124, 167, 167, 124,  93, 167, 111, 167, 167, 111,
           93, 167, 109, 167, 167, 109,  93, 167, 108, 167, 167, 108,  93, 167,
          104, 167, 167, 104,  93, 166, 106, 166, 166, 106,  94, 166, 105, 166,
          166, 105,  94, 166, 102, 166, 166, 102,  94, 165, 103, 165, 165, 103,
           95, 165, 101, 165, 165, 101,  95, 164, 102, 164, 164, 102,  96, 163,
          100, 163, 163, 100,  97, 162, 101, 162, 162, 101,  98, 162,  99, 162,
          162,  99,  98, 161, 100, 161, 161, 100,  99, 160,  99, 160, 160,  99,
          100, 160,  98, 160, 160,  98, 100, 158,  99, 158, 158,  99, 102, 158,
           97, 158, 158,  97, 102, 156,  98, 156, 156,  98, 104, 156,  97, 156,
          156,  97, 104, 155,  98, 155, 155,  98, 105, 155,  97, 155, 155,  97,
          105, 134, 163,

In [19]:
# First and last ten tokens of every unpadded sequence; the tails show the eos (1) token
# and, for the targets, the appended pad (2) token.
for k, vs in dec_vtk.tokenize(v_batch, padding=False).items():
    print(k, ":")
    print(torch.stack([v[:10] for v in vs]))
    print(torch.stack([v[-10:] for v in vs]))
    print("="*60)

value_tokens :
tensor([[  0, 169, 124, 169, 169, 124,  91, 169, 111, 169],
        [  0, 167, 166, 167, 167, 166,  93, 167, 157, 167],
        [  0, 167, 167, 130, 166, 167, 130, 165, 167, 139]])
tensor([[ 91, 124,  91,  91, 111, 169,  91, 111,  91,   1],
        [ 93, 166,  93,  93, 157, 167,  93, 157,  93,   1],
        [ 95, 167, 121,  94, 167, 130,  93, 167, 130,   1]])
target_tokens :
tensor([[169, 124, 169, 169, 124,  91, 169, 111, 169, 169],
        [167, 166, 167, 167, 166,  93, 167, 157, 167, 167],
        [167, 167, 130, 166, 167, 130, 165, 167, 139, 165]])
tensor([[124,  91,  91, 111, 169,  91, 111,  91,   1,   2],
        [166,  93,  93, 157, 167,  93, 157,  93,   1,   2],
        [167, 121,  94, 167, 130,  93, 167, 130,   1,   2]])
coord_type_tokens :
tensor([[0, 1, 2, 3, 1, 2, 3, 1, 2, 3],
        [0, 1, 2, 3, 1, 2, 3, 1, 2, 3],
        [0, 1, 2, 3, 1, 2, 3, 1, 2, 3]])
tensor([[1, 2, 3, 1, 2, 3, 1, 2, 3, 0],
        [1, 2, 3, 1, 2, 3, 1, 2, 3, 0],
        [1, 2, 3, 1, 2, 

In [20]:
# Round trip: detokenize removes the special tokens and the id offset, giving back
# the flat coordinate values of each mesh.
outputs = dec_vtk.detokenize(dec_vtk.tokenize(v_batch)["value_tokens"])
outputs

[tensor([166, 121, 166, 166, 121,  88, 166, 108, 166, 166, 108,  88, 165, 106,
         165, 165, 106,  89, 165, 104, 165, 165, 104,  89, 165, 103, 165, 165,
         103,  89, 164, 121, 164, 164, 121,  90, 164, 108, 164, 164, 108,  90,
         164, 106, 164, 164, 106,  90, 164, 105, 164, 164, 105,  90, 164, 101,
         164, 164, 101,  90, 163, 103, 163, 163, 103,  91, 163, 102, 163, 163,
         102,  91, 163,  99, 163, 163,  99,  91, 162, 100, 162, 162, 100,  92,
         162,  98, 162, 162,  98,  92, 161,  99, 161, 161,  99,  93, 160,  97,
         160, 160,  97,  94, 159,  98, 159, 159,  98,  95, 159,  96, 159, 159,
          96,  95, 158,  97, 158, 158,  97,  96, 157,  96, 157, 157,  96,  97,
         157,  95, 157, 157,  95,  97, 155,  96, 155, 155,  96,  99, 155,  94,
         155, 155,  94,  99, 153,  95, 153, 153,  95, 101, 153,  94, 153, 153,
          94, 101, 152,  95, 152, 152,  95, 102, 152,  94, 152, 152,  94, 102,
         131, 160, 161, 131, 160, 160, 131, 160, 159

In [21]:
# Initial input for generation: a single bos token followed by padding.
dec_vtk.get_pred_start()

{'value_tokens': tensor([[0, 2, 2,  ..., 2, 2, 2]]),
 'coord_type_tokens': tensor([[0, 0, 0,  ..., 0, 0, 0]]),
 'position_tokens': tensor([[0, 0, 0,  ..., 0, 0, 0]]),
 'padding_mask': tensor([[False,  True,  True,  ...,  True,  True,  True]])}

## Tokenizer for the face model (decoder)

In [22]:
# First five faces of the first mesh: each face is a variable-length integer tensor
# (presumably indices into the mesh's vertex list - confirm against load_pipeline).
f_batch[0][:5]

[tensor([203, 202, 200, 201]),
 tensor([203, 201, 147, 143,  97, 101,   1,   3]),
 tensor([203, 195, 194, 202]),
 tensor([203,   3,   5, 195]),
 tensor([202, 194,   4,   2])]

In [23]:
class FaceTokenizer(Tokenizer):
    """Tokenize the faces of a mesh into one flat sequence per mesh.

    Each face is preceded by a "bof" token and the whole sequence ends with "eos".
    Face values are offset by the number of special tokens so that the two id
    ranges do not overlap.
    """

    def __init__(self, bof_id=0, eos_id=1, pad_id=2, max_seq_len=None):
        self.special_tokens = {
            "bof": torch.tensor([bof_id]),
            "eos": torch.tensor([eos_id]),
            "pad": torch.tensor([pad_id]),
        }
        self.pad_id = pad_id
        # Id used in the position sequences at slots that belong to no face.
        self.not_coord_token = torch.tensor([0])
        # `_padding` always produces one slot more than the length it is given,
        # so the stored limit is one less than the requested sequence length.
        if max_seq_len is not None:
            self.max_seq_len = max_seq_len - 1
        else:
            self.max_seq_len = max_seq_len

    def _make_reference_tokens(self, ids):
        """Split token ids into vertex references and special ("embedding") tokens.

        Works on a single 1-D sequence as well as on a padded 2-D batch.

        Returns:
            dict with
            "ref_v_mask": 1.0 where the id is >= the number of special tokens, else 0.0;
            "ref_v_ids":  the id with the special-token offset removed at those
                          positions, 0 elsewhere;
            "ref_e_mask": the complement of "ref_v_mask";
            "ref_e_ids":  the special-token id at special positions, 0 elsewhere.
        """
        n_special = len(self.special_tokens)
        is_vertex = ids >= n_special
        return {
            "ref_v_mask": torch.where(is_vertex, 1., 0.),
            "ref_v_ids": torch.where(is_vertex, ids - n_special, 0),
            "ref_e_mask": torch.where(is_vertex, 0., 1.),
            "ref_e_ids": torch.where(is_vertex, 0, ids),
        }

    def _pad_and_pack(self, faces_ids, faces_target, in_position_tokens, out_position_tokens):
        """Pad and stack the per-mesh sequences and assemble the batched output dict."""
        pad = self.special_tokens["pad"]
        not_coord_token = self.not_coord_token
        max_seq_len = self.max_seq_len

        faces_ids = torch.stack(self._padding(faces_ids, pad, max_seq_len))
        faces_target = torch.stack(self._padding(faces_target, pad, max_seq_len))
        in_position_tokens = torch.stack(
            self._padding(in_position_tokens, not_coord_token, max_seq_len)
        )
        out_position_tokens = torch.stack(
            self._padding(out_position_tokens, not_coord_token, max_seq_len)
        )

        outputs = {
            "value_tokens": faces_ids,
            "target_tokens": faces_target,
            "in_position_tokens": in_position_tokens,
            "out_position_tokens": out_position_tokens,
        }
        outputs.update(self._make_reference_tokens(faces_ids))
        outputs["padding_mask"] = self._make_padding_mask(faces_ids, self.pad_id)
        return outputs

    def tokenize(self, faces, padding=True):
        """Tokenize a batch of meshes, each given as a list of face tensors.

        For every mesh the value sequence is ``bof f1... bof f2... eos`` with face
        values offset by the number of special tokens. The target sequence is the
        value sequence shifted left by one, with a pad token appended.
        "in_position_tokens" counts the position inside a face from 1 (the bof
        token is position 1); "out_position_tokens" is the 1-based index of the
        face. Both are 0 at the eos slot.

        Args:
            faces: iterable of meshes; each mesh is an iterable of 1-D tensors.
            padding: if True, pad and stack every output into a 2-D tensor and add
                a "padding_mask" entry; otherwise return lists of 1-D tensors.

        Returns:
            dict with "value_tokens", "target_tokens", "in_position_tokens",
            "out_position_tokens", "ref_v_mask", "ref_v_ids", "ref_e_mask",
            "ref_e_ids" and, when padding is enabled, "padding_mask".
        """
        special_tokens = self.special_tokens
        not_coord_token = self.not_coord_token
        n_special = len(special_tokens)

        faces_ids = []
        faces_target = []
        in_position_tokens = []
        out_position_tokens = []

        for face in faces:
            face_with_bof = [
                torch.cat([special_tokens["bof"], f + n_special])
                for f in face
            ]
            face_ids = torch.cat([
                torch.cat(face_with_bof),
                special_tokens["eos"]
            ])
            faces_ids.append(face_ids)
            faces_target.append(torch.cat([face_ids, special_tokens["pad"]])[1:])

            in_position_token = torch.cat([
                torch.arange(1, len(f) + 1)
                for f in face_with_bof
            ])
            in_position_tokens.append(
                torch.cat([in_position_token, not_coord_token])
            )

            out_position_token = torch.cat([
                torch.ones((len(f), ), dtype=torch.int32) * (idx + 1)
                for idx, f in enumerate(face_with_bof)
            ])
            out_position_tokens.append(
                torch.cat([out_position_token, not_coord_token])
            )

        if padding:
            return self._pad_and_pack(
                faces_ids, faces_target, in_position_tokens, out_position_tokens
            )

        references = [self._make_reference_tokens(ids) for ids in faces_ids]
        outputs = {
            "value_tokens": faces_ids,
            "target_tokens": faces_target,
            "in_position_tokens": in_position_tokens,
            "out_position_tokens": out_position_tokens,
            "ref_v_mask": [ref["ref_v_mask"] for ref in references],
            "ref_v_ids": [ref["ref_v_ids"] for ref in references],
            "ref_e_mask": [ref["ref_e_mask"] for ref in references],
            "ref_e_ids": [ref["ref_e_ids"] for ref in references],
        }
        return outputs

    def tokenize_prediction(self, faces):
        """Tokenize sequences that are already expressed as token ids.

        Unlike ``tokenize``, the inputs are flat 1-D tensors of token ids (special
        tokens and offset values mixed). A bof token is prepended, and the
        position tokens are derived from where bof tokens occur. The output is
        always padded.

        Args:
            faces: iterable of 1-D tensors of token ids.

        Returns:
            dict with the same keys as ``tokenize(..., padding=True)``.
        """
        special_tokens = self.special_tokens

        faces_ids = []
        faces_target = []
        in_position_tokens = []
        out_position_tokens = []

        for face in faces:
            face = torch.cat([special_tokens["bof"], face])
            faces_ids.append(face)
            faces_target.append(torch.cat([face, special_tokens["pad"]])[1:])

            # A Python set gives O(1) membership tests inside the loop below;
            # testing `idx in <tensor>` would rescan the tensor for every token.
            bof_positions = set(
                torch.where(face == special_tokens["bof"])[0].tolist()
            )
            now_pos_in = 1
            now_pos_out = 0
            in_position_token = []
            out_position_token = []

            for idx in range(len(face)):
                if idx in bof_positions:
                    # A new face starts: advance the face index and restart the
                    # in-face counter.
                    now_pos_out += 1
                    now_pos_in = 1

                in_position_token.append(now_pos_in)
                out_position_token.append(now_pos_out)
                now_pos_in += 1

            in_position_tokens.append(torch.tensor(in_position_token))
            out_position_tokens.append(torch.tensor(out_position_token))

        return self._pad_and_pack(
            faces_ids, faces_target, in_position_tokens, out_position_tokens
        )

    def detokenize(self, faces):
        """Remove the special-token offset and drop every special token.

        Because the bof separators are dropped too, the result is one flat 1-D
        tensor per input sequence; face boundaries are not preserved.
        """
        special_tokens = self.special_tokens

        result = []
        for face in faces:
            face = face - len(special_tokens)
            result.append(
                face[torch.where(face >= 0)]
            )
        return result

In [24]:
# With max_seq_len set, padded outputs have a fixed length of 3800 tokens per sequence.
ftk = FaceTokenizer(max_seq_len=3800)

In [25]:
# Print the shape plus the first and last ten tokens of every padded output.
for k, vs in ftk.tokenize(f_batch).items():
    print(k, ":")
    print(vs.shape)
    print(torch.stack([v[:10] for v in vs]))
    print(torch.stack([v[-10:] for v in vs]))
    print("="*60)

value_tokens :
torch.Size([3, 3800])
tensor([[  0, 206, 205, 203, 204,   0, 206, 204, 150, 146],
        [  0,  64,  63,  61,  62,   0,  64,  62,   4,   6],
        [  0,  66,  66,  64,  64,   0,  66,  66,  64,  64]])
tensor([[2, 2, 2, 2, 2, 2, 2, 2, 2, 2],
        [2, 2, 2, 2, 2, 2, 2, 2, 2, 2],
        [2, 2, 2, 2, 2, 2, 2, 2, 2, 2]])
target_tokens :
torch.Size([3, 3800])
tensor([[206, 205, 203, 204,   0, 206, 204, 150, 146, 100],
        [ 64,  63,  61,  62,   0,  64,  62,   4,   6,  58],
        [ 66,  66,  64,  64,   0,  66,  66,  64,  64,   0]])
tensor([[2, 2, 2, 2, 2, 2, 2, 2, 2, 2],
        [2, 2, 2, 2, 2, 2, 2, 2, 2, 2],
        [2, 2, 2, 2, 2, 2, 2, 2, 2, 2]])
in_position_tokens :
torch.Size([3, 3800])
tensor([[1, 2, 3, 4, 5, 1, 2, 3, 4, 5],
        [1, 2, 3, 4, 5, 1, 2, 3, 4, 5],
        [1, 2, 3, 4, 5, 1, 2, 3, 4, 5]])
tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0]])
out_position_tokens :
torch.Size([3,

In [26]:
# First and last ten tokens of every unpadded sequence; the values are lists here,
# so there is no .shape to print.
for k, vs in ftk.tokenize(f_batch, padding=False).items():
    print(k, ":")
    print(torch.stack([v[:10] for v in vs]))
    print(torch.stack([v[-10:] for v in vs]))
    print("="*60)

value_tokens :
tensor([[  0, 206, 205, 203, 204,   0, 206, 204, 150, 146],
        [  0,  64,  63,  61,  62,   0,  64,  62,   4,   6],
        [  0,  66,  66,  64,  64,   0,  66,  66,  64,  64]])
tensor([[ 8,  6,  5,  7,  0,  6,  4,  3,  5,  1],
        [11,  9,  8, 10,  0,  6,  4,  3,  5,  1],
        [ 5,  4,  4,  5,  0,  5,  4,  3,  5,  1]])
target_tokens :
tensor([[206, 205, 203, 204,   0, 206, 204, 150, 146, 100],
        [ 64,  63,  61,  62,   0,  64,  62,   4,   6,  58],
        [ 66,  66,  64,  64,   0,  66,  66,  64,  64,   0]])
tensor([[ 6,  5,  7,  0,  6,  4,  3,  5,  1,  2],
        [ 9,  8, 10,  0,  6,  4,  3,  5,  1,  2],
        [ 4,  4,  5,  0,  5,  4,  3,  5,  1,  2]])
in_position_tokens :
tensor([[1, 2, 3, 4, 5, 1, 2, 3, 4, 5],
        [1, 2, 3, 4, 5, 1, 2, 3, 4, 5],
        [1, 2, 3, 4, 5, 1, 2, 3, 4, 5]])
tensor([[2, 3, 4, 5, 1, 2, 3, 4, 5, 0],
        [2, 3, 4, 5, 1, 2, 3, 4, 5, 0],
        [2, 3, 4, 5, 1, 2, 3, 4, 5, 0]])
out_position_tokens :
tensor([[1, 1, 1, 1,

In [27]:
# detokenize removes the special tokens and the id offset. The result is one flat
# sequence per mesh: the bof separators are dropped, so face boundaries are not kept.
outputs = ftk.detokenize(ftk.tokenize(f_batch)["value_tokens"])
outputs

[tensor([203, 202, 200, 201, 203, 201, 147, 143,  97, 101,   1,   3, 203, 195,
         194, 202, 203,   3,   5, 195, 202, 194,   4,   2, 202,   2,   0,  98,
          94, 140, 144, 200, 201, 200, 144, 145, 184, 185, 146, 147, 199, 198,
         196, 197, 199, 197,   7,   9, 199, 193, 192, 198, 199,   9,  19, 193,
         198, 192,  18,   8, 198,   8,   6, 196, 197, 196, 194, 195, 197, 195,
           5,   7, 196,   6,   4, 194, 193, 183, 182, 192, 193,  19,  25, 183,
         192, 182,  24,  18, 191, 190, 178, 179, 191, 189, 188, 190, 191, 179,
          21,  17, 191,  17,  15, 189, 190, 188,  14,  16, 190,  16,  20, 178,
         189, 187, 186, 188, 189,  15,  13, 187, 188, 186,  12,  14, 187, 185,
         184, 186, 187,  13,  11, 100,  96, 142, 146, 185, 186, 184, 145, 141,
          95,  99,  10,  12, 183, 177, 176, 182, 183,  25,  29, 177, 182, 176,
          28,  24, 181, 180, 174, 175, 181, 179, 178, 180, 181, 175,  27,  23,
         181,  23,  21, 179, 180, 178,  20,  22, 180

In [28]:
# FaceTokenizer has no "bos" entry in special_tokens, so the start token is named explicitly.
ftk.get_pred_start("bof")

{'value_tokens': tensor([[0, 2, 2,  ..., 2, 2, 2]]),
 'coord_type_tokens': tensor([[0, 0, 0,  ..., 0, 0, 0]]),
 'position_tokens': tensor([[0, 0, 0,  ..., 0, 0, 0]]),
 'padding_mask': tensor([[False,  True,  True,  ...,  True,  True,  True]])}

In [29]:
# The input is already in token-id space; the 0 in the middle is a bof token that starts a second face.
ftk.tokenize_prediction([torch.tensor([589, 423, 0, 30, 21])])

{'value_tokens': tensor([[  0, 589, 423,  ...,   2,   2,   2]]),
 'target_tokens': tensor([[589, 423,   0,  ...,   2,   2,   2]]),
 'in_position_tokens': tensor([[1, 2, 3,  ..., 0, 0, 0]]),
 'out_position_tokens': tensor([[1, 1, 1,  ..., 0, 0, 0]]),
 'ref_v_mask': tensor([[0., 1., 1.,  ..., 0., 0., 0.]]),
 'ref_v_ids': tensor([[  0, 586, 420,  ...,   0,   0,   0]]),
 'ref_e_mask': tensor([[1., 0., 0.,  ..., 1., 1., 1.]]),
 'ref_e_ids': tensor([[0, 0, 0,  ..., 2, 2, 2]]),
 'padding_mask': tensor([[False, False, False,  ...,  True,  True,  True]])}