In [1]:
import os
import sys
import json
import glob
import torch
import torch.nn as nn
import torch.nn.functional as F

In [2]:
# Resolve project paths relative to the notebook's parent directory.
base_dir = os.path.dirname(os.getcwd())
data_dir = os.path.join(base_dir, "data", "original")
# glob() returns matches in arbitrary filesystem order; sort so that
# train_files[i] names the same mesh on every machine and every run.
train_files = sorted(glob.glob(os.path.join(data_dir, "train", "*", "*.obj")))
valid_files = sorted(glob.glob(os.path.join(data_dir, "val", "*", "*.obj")))
print(len(train_files), len(valid_files))

src_dir = os.path.join(base_dir, "src")
# Put the project's src/ first so local modules (e.g. `tokenizers`, `utils`)
# take precedence over identically named installed packages, and drop any
# earlier copies so re-running this cell does not grow sys.path.
sys.path[:] = [p for p in sys.path if p != src_dir]
sys.path.insert(0, src_dir)

7003 1088


In [3]:
from utils import load_pipeline
from tokenizers import EncodeVertexTokenizer, FaceTokenizer

In [4]:
# Load the first three training meshes as tensors: one (num_vertices, 3)
# vertex tensor per mesh, and a list of per-face index tensors per mesh.
v_batch, f_batch = [], []
for sample_idx in range(3):
    vertices, _, faces = load_pipeline(train_files[sample_idx])
    vertices = torch.tensor(vertices)
    faces = list(map(torch.tensor, faces))

    v_batch.append(vertices)
    f_batch.append(faces)

    print(vertices.shape, len(faces))
    print("=" * 60)

torch.Size([655, 3]) 588
torch.Size([310, 3]) 220
torch.Size([396, 3]) 304


In [5]:
# Encoder-side tokenizer for vertex tensors, decoder-side tokenizer for face lists
# (constructed in the same left-to-right order as before).
enc_tokenizer, dec_tokenizer = EncodeVertexTokenizer(), FaceTokenizer()

In [6]:
# Tokenize the vertex batch for the encoder; the bare name below displays the dict.
src_tokens = enc_tokenizer.tokenize(v_batch)
src_tokens

{'value_tokens': tensor([[161, 137, 133,  ..., 137, 123,   0],
         [135, 162, 130,  ...,   0,   0,   0],
         [163,  99, 134,  ...,   0,   0,   0]]),
 'coord_type_tokens': tensor([[1, 2, 3,  ..., 2, 3, 0],
         [1, 2, 3,  ..., 0, 0, 0],
         [1, 2, 3,  ..., 0, 0, 0]]),
 'position_tokens': tensor([[  1,   1,   1,  ..., 655, 655,   0],
         [  1,   1,   1,  ...,   0,   0,   0],
         [  1,   1,   1,  ...,   0,   0,   0]]),
 'padding_mask': tensor([[ True,  True,  True,  ...,  True,  True, False],
         [ True,  True,  True,  ..., False, False, False],
         [ True,  True,  True,  ..., False, False, False]])}

In [7]:
class FaceEncoderEmbedding(nn.Module):
    """Embed a flat vertex-token sequence and pool it into per-vertex vectors.

    Each token's embedding is the sum of three lookups: quantized value,
    coordinate type and position. The last token of every sequence is then
    dropped, and the remaining tokens are summed in consecutive groups of
    three (a trailing partial group, if any, is summed as-is).

    Args:
        embed_dim: size of every embedding vector.
        vocab_value / pad_idx_value: vocabulary size and padding row of the
            value embedding.
        vocab_coord_type / pad_idx_coord_type: same for the coordinate-type
            embedding.
        vocab_position / pad_idx_position: same for the position embedding.
    """

    def __init__(self, embed_dim=256,
                 vocab_value=259, pad_idx_value=2,
                 vocab_coord_type=4, pad_idx_coord_type=0,
                 vocab_position=1000, pad_idx_position=0):
        super().__init__()
        # Construction order is kept (value, coord type, position) so that
        # seeded initialisation and state_dict ordering are unchanged.
        self.value_embed = nn.Embedding(vocab_value, embed_dim,
                                        padding_idx=pad_idx_value)
        self.coord_type_embed = nn.Embedding(vocab_coord_type, embed_dim,
                                             padding_idx=pad_idx_coord_type)
        self.position_embed = nn.Embedding(vocab_position, embed_dim,
                                           padding_idx=pad_idx_position)

    def forward(self, value_tokens, coord_type_tokens, position_tokens):
        """Return grouped embeddings of shape (batch, ceil((seq_len - 1) / 3), embed_dim).

        The three token tensors are indexed as (batch, seq_len).
        """
        token_embed = (
            self.value_embed(value_tokens)
            + self.coord_type_embed(coord_type_tokens)
            + self.position_embed(position_tokens)
        )
        # Discard the final token of each sequence before grouping.
        token_embed = token_embed[:, :-1]
        triples = token_embed.split(3, dim=1)
        return torch.stack([triple.sum(dim=1) for triple in triples], dim=1)

In [8]:
src_embed = FaceEncoderEmbedding(embed_dim=128)  # small width for this walkthrough

In [9]:
# All token streams and the padding mask should share one (batch, seq_len) shape.
print(*(src_tokens[key].shape for key in (
    "value_tokens", "coord_type_tokens", "position_tokens", "padding_mask"
)))

torch.Size([3, 1966]) torch.Size([3, 1966]) torch.Size([3, 1966]) torch.Size([3, 1966])


In [10]:
# forward() takes the streams under the same names as the tokenizer's dict keys.
src_emb = src_embed(**{
    key: src_tokens[key]
    for key in ("value_tokens", "coord_type_tokens", "position_tokens")
})
src_emb.shape

torch.Size([3, 655, 128])

In [11]:
# Tokenize the face batch for the decoder; the bare name below displays the dict.
tgt_tokens = dec_tokenizer.tokenize(f_batch)
tgt_tokens

{'value_tokens': tensor([[657, 651, 652,  ...,   0,   1,   2],
         [312, 311, 310,  ...,   2,   2,   2],
         [398, 397, 394,  ...,   2,   2,   2]]),
 'target_tokens': tensor([[651, 652, 648,  ...,   1,   2,   2],
         [311, 310, 296,  ...,   2,   2,   2],
         [397, 394, 395,  ...,   2,   2,   2]]),
 'in_position_tokens': tensor([[1, 2, 3,  ..., 5, 0, 0],
         [1, 2, 3,  ..., 0, 0, 0],
         [1, 2, 3,  ..., 0, 0, 0]]),
 'out_position_tokens': tensor([[  1.,   1.,   1.,  ..., 588.,   0.,   0.],
         [  1.,   1.,   1.,  ...,   0.,   0.,   0.],
         [  1.,   1.,   1.,  ...,   0.,   0.,   0.]]),
 'ref_v_mask': tensor([[1., 1., 1.,  ..., 0., 0., 0.],
         [1., 1., 1.,  ..., 0., 0., 0.],
         [1., 1., 1.,  ..., 0., 0., 0.]]),
 'ref_v_ids': tensor([[654, 648, 649,  ...,   0,   0,   0],
         [309, 308, 307,  ...,   0,   0,   0],
         [395, 394, 391,  ...,   0,   0,   0]]),
 'ref_e_mask': tensor([[0., 0., 0.,  ..., 1., 1., 1.],
         [0., 0., 

In [12]:
class FaceDecoderEmbedding(nn.Module):
    """Build decoder input embeddings for the face sequence.

    Each target position is the sum of:
      * the encoder embedding of the referenced vertex (``ref_v_ids``),
        zeroed where ``ref_v_mask`` is 0;
      * a special-token value embedding (``ref_e_ids``), zeroed where
        ``ref_e_mask`` is 0;
      * an in-position embedding (``in_position_tokens``);
      * an out-position embedding (``out_position_tokens``).

    Args:
        embed_dim: embedding width; must match the encoder embedding width.
        vocab_value / pad_idx_value: size and padding row of the
            special-token value embedding.
        vocab_in_position / pad_idx_in_position: same for in-position.
        vocab_out_position / pad_idx_out_position: same for out-position.
    """

    def __init__(self, embed_dim=256,
                 vocab_value=3, pad_idx_value=2,
                 vocab_in_position=100, pad_idx_in_position=0,
                 vocab_out_position=1000, pad_idx_out_position=0):

        super().__init__()

        self.value_embed = nn.Embedding(
            vocab_value, embed_dim, padding_idx=pad_idx_value
        )
        self.in_position_embed = nn.Embedding(
            vocab_in_position, embed_dim, padding_idx=pad_idx_in_position
        )
        self.out_position_embed = nn.Embedding(
            vocab_out_position, embed_dim, padding_idx=pad_idx_out_position
        )

    def forward(self, encoder_embed, tokens):
        """Return decoder embeddings of shape (batch, tgt_len, embed_dim).

        Args:
            encoder_embed: per-vertex encoder output, indexed as
                ``encoder_embed[batch, vertex_id]`` (assumed shape
                (batch, num_vertices, embed_dim)).
            tokens: dict from the face tokenizer. It must contain ``ref_v_ids``,
                ``ref_v_mask``, ``ref_e_ids``, ``ref_e_mask``,
                ``in_position_tokens`` and ``out_position_tokens``.
        """
        ref_v_ids = tokens["ref_v_ids"].long()
        # Vectorised per-sample gather: row b of the result takes
        # encoder_embed[b, ref_v_ids[b]].
        batch_idx = torch.arange(
            ref_v_ids.size(0), device=ref_v_ids.device
        ).unsqueeze(1)
        embed = encoder_embed[batch_idx, ref_v_ids]
        embed = embed * tokens["ref_v_mask"].unsqueeze(dim=2)

        embed = embed + (
            self.value_embed(tokens["ref_e_ids"].long())
            * tokens["ref_e_mask"].unsqueeze(dim=2)
        )

        embed = embed + self.in_position_embed(tokens["in_position_tokens"].long())
        # The out-position embedding must be looked up with the out-position
        # tokens. The tokenizer emits them as floats, so cast to integer
        # indices for nn.Embedding.
        embed = embed + self.out_position_embed(tokens["out_position_tokens"].long())
        return embed

In [13]:
tgt_embed = FaceDecoderEmbedding(embed_dim=128)  # must match src_embed's width

In [15]:
# Decoder inputs are built from the encoder's per-vertex embeddings plus face tokens.
tgt_emb = tgt_embed(src_emb, tgt_tokens)
tgt_emb.shape

torch.Size([3, 3083, 128])

In [16]:
class Config(object):
    """Minimal base class for JSON-serialisable configurations.

    Subclasses populate ``self.config`` (a JSON-compatible dict). It can also
    be filled from disk with :meth:`load_from_json`. Item access and ``in``
    are forwarded to that dict.
    """

    def write_to_json(self, out_path):
        """Write ``self.config`` to ``out_path`` as indented UTF-8 JSON."""
        # JSON is defined as UTF-8 (RFC 8259); do not rely on the locale encoding.
        with open(out_path, "w", encoding="utf-8") as fw:
            json.dump(self.config, fw, indent=4)

    def load_from_json(self, file_path):
        """Replace ``self.config`` with the JSON object stored at ``file_path``."""
        with open(file_path, encoding="utf-8") as fr:
            self.config = json.load(fr)

    def __getitem__(self, key):
        """Return ``self.config[key]`` (raises KeyError if absent)."""
        return self.config[key]

    def __contains__(self, key):
        # Without this, ``key in cfg`` falls back to the sequence protocol,
        # calls ``self[0]`` and raises KeyError instead of returning a bool.
        return key in self.config

In [None]:
class FacePolyGenConfig(Config):
    """Default hyper-parameters for the face model, grouped per sub-component.

    Keyword names follow ``<group>__<subgroup>__<field>``. Each value is routed
    into the matching nested dict of ``self.config``:

    * ``src_tokenizer`` / ``tgt_tokenizer``: special-token ids and maximum
      sequence length for the source (vertex) and target (face) tokenizers.
    * ``src_embedding``: keys match the keyword arguments of
      ``FaceEncoderEmbedding``.
    * ``tgt_embedding``: keys match the keyword arguments of
      ``FaceDecoderEmbedding``.
    * ``reformer``: settings for the Reformer stack (presumably passed as
      keyword arguments, as ``VertexPolyGen`` does; confirm).
    """

    def __init__(self,
                 embed_dim=256, 
                 src__max_seq_len=2400, 
                 src__tokenizer__bos_id=0,
                 src__tokenizer__eos_id=1,
                 src__tokenizer__pad_id=2,
                 tgt__max_seq_len=2400,
                 tgt__tokenizer__eof_id=0,
                 tgt__tokenizer__eos_id=1, 
                 tgt__tokenizer__pad_id=2,
                 src__embedding__vocab_value=256 + 3, 
                 src__embedding__vocab_coord_type=4, 
                 src__embedding__vocab_position=1000, 
                 src__embedding__pad_idx_value=2,
                 src__embedding__pad_idx_coord_type=0,
                 src__embedding__pad_idx_position=0,
                 tgt__embedding__vocab_value=3,
                 tgt__embedding__vocab_in_position=100,
                 tgt__embedding__vocab_out_position=1000,
                 tgt__embedding__pad_idx_value=2,
                 tgt__embedding__pad_idx_in_position=0,
                 tgt__embedding__pad_idx_out_position=0,
                 reformer__depth=12,
                 reformer__heads=8,
                 reformer__n_hashes=8,
                 reformer__bucket_size=48,
                 reformer__causal=True,
                 reformer__lsh_dropout=0.2, 
                 reformer__ff_dropout=0.2,
                 reformer__post_attn_dropout=0.2,
                 reformer__ff_mult=4):
        # Only one Reformer config is emitted, so size it for the longer of the
        # two sequences. NOTE(review): the previous code referenced an
        # undefined `max_seq_len`; confirm this choice matches the model.
        max_seq_len = max(src__max_seq_len, tgt__max_seq_len)

        # tokenizer config
        src_tokenizer_config = {
            "bos_id": src__tokenizer__bos_id,
            "eos_id": src__tokenizer__eos_id,
            "pad_id": src__tokenizer__pad_id,
            "max_seq_len": src__max_seq_len,
        }
        # The target tokenizer is parameterised by an end-of-face id rather
        # than a BOS id. NOTE(review): confirm FaceTokenizer takes `eof_id`.
        tgt_tokenizer_config = {
            "eof_id": tgt__tokenizer__eof_id,
            "eos_id": tgt__tokenizer__eos_id,
            "pad_id": tgt__tokenizer__pad_id,
            "max_seq_len": tgt__max_seq_len,
        }

        # embedding config
        src_embedding_config = {
            "vocab_value": src__embedding__vocab_value,
            "vocab_coord_type": src__embedding__vocab_coord_type,
            "vocab_position": src__embedding__vocab_position,
            "pad_idx_value": src__embedding__pad_idx_value,
            "pad_idx_coord_type": src__embedding__pad_idx_coord_type,
            "pad_idx_position": src__embedding__pad_idx_position,
            "embed_dim": embed_dim,
        }
        # Keys mirror FaceDecoderEmbedding.__init__ (in/out position tables).
        tgt_embedding_config = {
            "vocab_value": tgt__embedding__vocab_value,
            "vocab_in_position": tgt__embedding__vocab_in_position,
            "vocab_out_position": tgt__embedding__vocab_out_position,
            "pad_idx_value": tgt__embedding__pad_idx_value,
            "pad_idx_in_position": tgt__embedding__pad_idx_in_position,
            "pad_idx_out_position": tgt__embedding__pad_idx_out_position,
            "embed_dim": embed_dim,
        }

        # reformer info
        reformer_config = {
            "dim": embed_dim,
            "depth": reformer__depth,
            "max_seq_len": max_seq_len,
            "heads": reformer__heads,
            "bucket_size": reformer__bucket_size,
            "n_hashes": reformer__n_hashes,
            "causal": reformer__causal,
            "lsh_dropout": reformer__lsh_dropout, 
            "ff_dropout": reformer__ff_dropout,
            "post_attn_dropout": reformer__post_attn_dropout,
            "ff_mult": reformer__ff_mult,
        }

        self.config = {
            "embed_dim": embed_dim,
            "max_seq_len": max_seq_len,
            "src_tokenizer": src_tokenizer_config,
            "tgt_tokenizer": tgt_tokenizer_config,
            "src_embedding": src_embedding_config,
            "tgt_embedding": tgt_embedding_config,
            "reformer": reformer_config,
        }

In [None]:
class VertexPolyGen(nn.Module):
    """Autoregressive vertex model: token embedding -> Reformer -> LayerNorm.

    Logits are produced by projecting hidden states onto the value-embedding
    matrix (tied weights), in both ``__call__`` and ``predict``.

    NOTE(review): this cell uses ``DecodeVertexTokenizer``, ``VertexEmbedding``,
    ``Reformer`` and ``accuracy``, none of which are imported in this
    notebook. Import them before running it.
    """

    def __init__(self, model_config):
        """Build sub-modules from a nested config mapping.

        Args:
            model_config: mapping with ``"tokenizer"``, ``"embedding"`` and
                ``"reformer"`` sub-dicts (passed as keyword arguments to the
                respective constructors) and an ``"embed_dim"`` entry. The
                tokenizer sub-dict must contain ``"pad_id"``.
        """
        super().__init__()
        self.tokenizer = DecodeVertexTokenizer(**model_config["tokenizer"])
        self.embedding = VertexEmbedding(**model_config["embedding"])
        self.reformer = Reformer(**model_config["reformer"])
        self.layernorm = nn.LayerNorm(model_config["embed_dim"])
        # Padding targets are excluded from the loss.
        self.loss_func = nn.CrossEntropyLoss(ignore_index=model_config["tokenizer"]["pad_id"])

    def forward(self, tokens, device=None):
        """Return LayerNorm-ed Reformer hidden states for a token dict.

        ``device`` is accepted for interface symmetry but not used here. The
        tensors in ``tokens`` must already be on the model's device.
        """
        hs = self.embedding(
            tokens["value_tokens"],
            tokens["coord_type_tokens"],
            tokens["position_tokens"],
        )
        hs = self.reformer(hs, input_mask=tokens["padding_mask"])
        return self.layernorm(hs)

    def __call__(self, inputs, device=None):
        """Tokenize ``inputs`` and return ``{"accuracy": ..., "loss": ...}``.

        NOTE(review): overriding ``nn.Module.__call__`` bypasses module hooks.
        It is kept as-is for compatibility with existing callers.
        """
        tokens = self.tokenizer.tokenize(inputs)
        tokens = {k: v.to(device) for k, v in tokens.items()}

        hs = self.forward(tokens, device=device)
        # Tied output projection: logits over the value vocabulary.
        logits = F.linear(hs, self.embedding.value_embed.weight)
        BATCH, LENGTH, VOCAB = logits.shape
        logits = logits.reshape(BATCH * LENGTH, VOCAB)
        targets = tokens["target_tokens"].reshape(BATCH * LENGTH)

        acc = accuracy(
            logits, targets, ignore_label=self.tokenizer.pad_id, device=device
        )
        loss = self.loss_func(logits, targets)
        return {
            "accuracy": acc,
            "loss": loss,
        }

    @torch.no_grad()
    def predict(self, max_seq_len=2400, device=None):
        """Greedily decode up to ``max_seq_len`` tokens, stopping after EOS.

        Returns:
            1-D tensor of predicted token ids in tokenizer space, including
            the EOS token if one was produced. It is empty if nothing was
            predicted.
        """
        tokenizer = self.tokenizer
        special_tokens = tokenizer.special_tokens
        n_special = len(special_tokens)
        # Predictions are stored shifted down by n_special, so EOS shows up as this.
        eos_shifted = special_tokens["eos"] - n_special

        tokens = tokenizer.get_pred_start()
        tokens = {k: v.to(device) for k, v in tokens.items()}
        preds = []
        pred_idx = 0

        while pred_idx <= max_seq_len - 1 and (not preds or preds[-1] != eos_shifted):

            if pred_idx >= 1:
                tokens = tokenizer.tokenize([torch.stack(preds)])
                # Expose the next slot as a padded-but-attended position whose
                # hidden state is read below.
                tokens["value_tokens"][:, pred_idx+1] = special_tokens["pad"]
                tokens["padding_mask"][:, pred_idx+1] = True
                # Re-tokenized tensors must be on the model's device as well.
                tokens = {k: v.to(device) for k, v in tokens.items()}

            try:
                hs = self.forward(tokens, device=device)
            except IndexError:
                # NOTE(review): leftover debugging aid. It dumps the inputs and
                # stops decoding early instead of raising.
                print(pred_idx)
                for k, v in tokens.items():
                    print(k)
                    print(v.shape)
                    print(v)
                import traceback
                print(traceback.format_exc())
                break

            hs = F.linear(hs[:, pred_idx], self.embedding.value_embed.weight)
            pred = hs.argmax(dim=1) - n_special
            preds.append(pred[0])
            pred_idx += 1

        if not preds:
            # torch.stack([]) raises; return an empty id tensor instead.
            return torch.empty(0, dtype=torch.long, device=device)
        return torch.stack(preds) + n_special