In [1]:
import os
import sys
import json
import glob
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from reformer_pytorch import Reformer

In [2]:
base_dir = os.path.dirname(os.getcwd())
data_dir = os.path.join(base_dir, "data", "original")
train_files = glob.glob(os.path.join(data_dir, "train", "*", "*.obj"))
valid_files = glob.glob(os.path.join(data_dir, "val", "*", "*.obj"))
print(len(train_files), len(valid_files))

src_dir = os.path.join(base_dir, "src")
sys.path.append(os.path.join(src_dir))

7003 1088


In [3]:
from utils_polygen import load_pipeline
from tokenizers import DecodeVertexTokenizer

In [4]:
v_batch, f_batch = [], []
for i in range(3):
    vs, _, fs = load_pipeline(train_files[i])
    
    vs = torch.tensor(vs)
    fs = [torch.tensor(f) for f in fs]
    
    v_batch.append(vs)
    f_batch.append(fs)
    print(vs.shape, len(fs))
    print("="*60)

torch.Size([204, 3]) 160
torch.Size([62, 3]) 45
torch.Size([64, 3]) 601


In [5]:
dec_tokenizer = DecodeVertexTokenizer(max_seq_len=2592)

In [6]:
input_tokens = dec_tokenizer.tokenize(v_batch)
input_tokens

{'value_tokens': tensor([[  0, 169, 124,  ...,   2,   2,   2],
         [  0, 167, 166,  ...,   2,   2,   2],
         [  0, 167, 167,  ...,   2,   2,   2]]),
 'target_tokens': tensor([[169, 124, 169,  ...,   2,   2,   2],
         [167, 166, 167,  ...,   2,   2,   2],
         [167, 167, 130,  ...,   2,   2,   2]]),
 'coord_type_tokens': tensor([[0, 1, 2,  ..., 0, 0, 0],
         [0, 1, 2,  ..., 0, 0, 0],
         [0, 1, 2,  ..., 0, 0, 0]]),
 'position_tokens': tensor([[0, 1, 1,  ..., 0, 0, 0],
         [0, 1, 1,  ..., 0, 0, 0],
         [0, 1, 1,  ..., 0, 0, 0]]),
 'padding_mask': tensor([[False, False, False,  ...,  True,  True,  True],
         [False, False, False,  ...,  True,  True,  True],
         [False, False, False,  ...,  True,  True,  True]])}

In [7]:
class VertexDecoderEmbedding(nn.Module):
    
    def __init__(self, embed_dim=256,
                 vocab_value=259, pad_idx_value=2, 
                 vocab_coord_type=4, pad_idx_coord_type=0,
                 vocab_position=1000, pad_idx_position=0):
        
        super().__init__()
        
        self.value_embed = nn.Embedding(
            vocab_value, embed_dim, padding_idx=pad_idx_value
        )
        self.coord_type_embed = nn.Embedding(
            vocab_coord_type, embed_dim, padding_idx=pad_idx_coord_type
        )
        self.position_embed = nn.Embedding(
            vocab_position, embed_dim, padding_idx=pad_idx_position
        )
        
        self.embed_scaler = math.sqrt(embed_dim)
        
    def forward(self, tokens):
        
        """get embedding for vertex model.
        
        Args
            tokens [dict]: tokenized vertex info.
                `value_tokens` [torch.tensor]:
                        padded (batch, length)-shape long tensor
                        with coord value from 0 to 2^n(bit).
                `coord_type_tokens` [torch.tensor]:
                        padded (batch, length) shape long tensor implies x or y or z.
                `position_tokens` [torch.tensor]:
                        padded (batch, length) shape long tensor
                        representing coord position (NOT sequence position).
        
        Returns
            embed [torch.tensor]: (batch, length, embed) shape tensor after embedding.
                        
        """
        
        embed = self.value_embed(tokens["value_tokens"])
        embed = embed + self.coord_type_embed(tokens["coord_type_tokens"])
        embed = embed + self.position_embed(tokens["position_tokens"])
        embed = embed * self.embed_scaler
        
        return embed

In [8]:
embed = VertexDecoderEmbedding(embed_dim=128)

In [9]:
input_tokens

{'value_tokens': tensor([[  0, 169, 124,  ...,   2,   2,   2],
         [  0, 167, 166,  ...,   2,   2,   2],
         [  0, 167, 167,  ...,   2,   2,   2]]),
 'target_tokens': tensor([[169, 124, 169,  ...,   2,   2,   2],
         [167, 166, 167,  ...,   2,   2,   2],
         [167, 167, 130,  ...,   2,   2,   2]]),
 'coord_type_tokens': tensor([[0, 1, 2,  ..., 0, 0, 0],
         [0, 1, 2,  ..., 0, 0, 0],
         [0, 1, 2,  ..., 0, 0, 0]]),
 'position_tokens': tensor([[0, 1, 1,  ..., 0, 0, 0],
         [0, 1, 1,  ..., 0, 0, 0],
         [0, 1, 1,  ..., 0, 0, 0]]),
 'padding_mask': tensor([[False, False, False,  ...,  True,  True,  True],
         [False, False, False,  ...,  True,  True,  True],
         [False, False, False,  ...,  True,  True,  True]])}

In [10]:
print(
    input_tokens["value_tokens"].shape,
    input_tokens["coord_type_tokens"].shape,
    input_tokens["position_tokens"].shape
)

torch.Size([3, 2592]) torch.Size([3, 2592]) torch.Size([3, 2592])


In [11]:
emb = embed(input_tokens)
emb.shape

torch.Size([3, 2592, 128])

In [12]:
reformer = Reformer(dim=128, depth=1, max_seq_len=8192, bucket_size=24)

In [13]:
output = reformer(emb)
output.shape

torch.Size([3, 2592, 128])

In [14]:
class Config(object):
    
    def write_to_json(self, out_path):
        with open(out_path, "w") as fw:
            json.dump(self.config, fw, indent=4)
            
    def load_from_json(self, file_path):
        with open(file_path) as fr:
            self.config = json.load(fr)
        
    def __getitem__(self, key):
        return self.config[key]

In [15]:
class VertexPolyGenConfig(Config):
    
    def __init__(self,
                 embed_dim=256, 
                 max_seq_len=2400, 
                 tokenizer__bos_id=0,
                 tokenizer__eos_id=1,
                 tokenizer__pad_id=2,
                 embedding__vocab_value=256 + 3, 
                 embedding__vocab_coord_type=4, 
                 embedding__vocab_position=1000,
                 embedding__pad_idx_value=2,
                 embedding__pad_idx_coord_type=0,
                 embedding__pad_idx_position=0,
                 reformer__depth=12,
                 reformer__heads=8,
                 reformer__n_hashes=8,
                 reformer__bucket_size=48,
                 reformer__causal=True,
                 reformer__lsh_dropout=0.2, 
                 reformer__ff_dropout=0.2,
                 reformer__post_attn_dropout=0.2,
                 reformer__ff_mult=4):
        
        # tokenizer config
        tokenizer_config = {
            "bos_id": tokenizer__bos_id,
            "eos_id": tokenizer__eos_id,
            "pad_id": tokenizer__pad_id,
            "max_seq_len": max_seq_len,
        }
        
        # embedding config
        embedding_config = {
            "vocab_value": embedding__vocab_value,
            "vocab_coord_type": embedding__vocab_coord_type,
            "vocab_position": embedding__vocab_position,
            "pad_idx_value": embedding__pad_idx_value,
            "pad_idx_coord_type": embedding__pad_idx_coord_type,
            "pad_idx_position": embedding__pad_idx_position,
            "embed_dim": embed_dim,
        }
        
        # reformer info
        reformer_config = {
            "dim": embed_dim,
            "depth": reformer__depth,
            "max_seq_len": max_seq_len,
            "heads": reformer__heads,
            "bucket_size": reformer__bucket_size,
            "n_hashes": reformer__n_hashes,
            "causal": reformer__causal,
            "lsh_dropout": reformer__lsh_dropout, 
            "ff_dropout": reformer__ff_dropout,
            "post_attn_dropout": reformer__post_attn_dropout,
            "ff_mult": reformer__ff_mult,
        }
        
        self.config = {
            "embed_dim": embed_dim,
            "max_seq_len": max_seq_len,
            "tokenizer": tokenizer_config,
            "embedding": embedding_config,
            "reformer": reformer_config,
        }

In [16]:
# utility functions

def accuracy(y_pred, y_true, ignore_label=None, device=None):
    y_pred = y_pred.argmax(dim=1)

    if ignore_label:
        normalizer = torch.sum(y_true!=ignore_label)
        ignore_mask = torch.where(
            y_true == ignore_label,
            torch.zeros_like(y_true, device=device),
            torch.ones_like(y_true, device=device)
        ).type(torch.float32)
    else:
        normalizer = y_true.shape[0]
        ignore_mask = torch.ones_like(y_true, device=device).type(torch.float32)

    acc = (y_pred.reshape(-1)==y_true.reshape(-1)).type(torch.float32)
    acc = torch.sum(acc*ignore_mask)
    return acc / normalizer


def init_weights(m):
    if type(m) == nn.Linear:
        nn.init.xavier_normal_(m.weight)
    if type(m) == nn.Embedding:
        nn.init.uniform_(m.weight, -0.05, 0.05)

In [17]:
class VertexPolyGen(nn.Module):
    
    """Vertex model in PolyGen.
    this model learn/predict vertices like OpenAI-GPT.
    UNLIKE the paper, this model is only for unconditional generation.
    
    Args
        model_config [Config]:
                hyper parameters. see VertexPolyGenConfig class for details. 
    """
    
    def __init__(self, model_config):
        super().__init__()
        
        self.tokenizer = DecodeVertexTokenizer(**model_config["tokenizer"])
        self.embedding = VertexDecoderEmbedding(**model_config["embedding"])
        self.reformer = Reformer(**model_config["reformer"])
        self.layernorm = nn.LayerNorm(model_config["embed_dim"])
        self.loss_func = nn.CrossEntropyLoss(ignore_index=model_config["tokenizer"]["pad_id"])
        
        self.apply(init_weights)
    
    def forward(self, tokens, device=None):
        
        """forward function which can be used for both train/predict.
        
        Args
            tokens [dict]: tokenized vertex info.
                `value_tokens` [torch.tensor]:
                        padded (batch, length)-shape long tensor
                        with coord value from 0 to 2^n(bit).
                `coord_type_tokens` [torch.tensor]:
                        padded (batch, length) shape long tensor implies x or y or z.
                `position_tokens` [torch.tensor]:
                        padded (batch, length) shape long tensor
                        representing coord position (NOT sequence position).
                `padding_mask` [torch.tensor]:
                        (batch, length) shape mask implies <pad> tokens.
            device [torch.device]: gpu or not gpu, that's the problem.
            
        
        Returns
            hs [torch.tensor]:
                    hidden states from transformer(reformer) model.
                    this takes (batch, length, embed) shape.
        
        """
        
        hs = self.embedding(tokens)
        hs = self.reformer(
            hs, input_mask=tokens["padding_mask"]
        )
        hs = self.layernorm(hs)
        
        return hs
        
        
    def __call__(self, inputs, device=None):
        
        """Calculate loss while training.
        
        Args
            inputs [dict]: dict containing batched inputs.
                `vertices` [list(torch.tensor)]:
                        variable-length-list of 
                        (length, 3) shaped tensor of quantized-vertices.
            device [torch.device]: gpu or not gpu, that's the problem.
                
        Returns
            outputs [dict]: dict containing calculated variables.
                `loss` [torch.tensor]:
                        calculated scalar-shape loss with backprop info.
                `accuracy` [torch.tensor]:
                        calculated scalar-shape accuracy.
            
        """
        
        tokens = self.tokenizer.tokenize(inputs["vertices"])
        tokens = {k: v.to(device) for k, v in tokens.items()}
        
        hs = self.forward(tokens, device=device)
        
        hs = F.linear(hs, self.embedding.value_embed.weight)
        BATCH, LENGTH, EMBED = hs.shape
        hs = hs.reshape(BATCH*LENGTH, EMBED)
        targets = tokens["target_tokens"].reshape(BATCH*LENGTH,)
        
        acc = accuracy(
            hs, targets, ignore_label=self.tokenizer.pad_id, device=device
        )
        loss = self.loss_func(hs, targets)
        
        outputs = {
            "accuracy": acc,
            "perplexity": torch.exp(loss),
            "loss": loss,
        }
        return outputs
    
    
    @torch.no_grad()
    def predict(self, max_seq_len=2400, device=None):
        """predict function
        
        Args
            max_seq_len[int]: max sequence length to predict.
            device [torch.device]: gpu or not gpu, that's the problem.
            
        Return
            preds [torch.tensor]: predicted (length, ) shape tensor.
        
        """
        
        tokenizer = self.tokenizer
        special_tokens = tokenizer.special_tokens
        
        tokens = tokenizer.get_pred_start()
        tokens = {k: v.to(device) for k, v in tokens.items()}
        preds = []
        pred_idx = 0
        
        while (pred_idx <= max_seq_len-1)\
        and ((len(preds) == 0) or (preds[-1] != special_tokens["eos"]-len(special_tokens))):
            
            if pred_idx >= 1:
                tokens = tokenizer.tokenize([torch.stack(preds)])
                tokens["value_tokens"][:, pred_idx+1] = special_tokens["pad"]
                tokens["padding_mask"][:, pred_idx+1] = True
            
            hs = self.forward(tokens, device=device)

            hs = F.linear(hs[:, pred_idx], self.embedding.value_embed.weight)
            pred = hs.argmax(dim=1) - len(special_tokens)
            preds.append(pred[0])
            pred_idx += 1
            
        preds = torch.stack(preds) + len(special_tokens)
        preds = self.tokenizer.detokenize([preds])[0]
        return preds

In [18]:
config = VertexPolyGenConfig(
    embed_dim=128, reformer__depth=6, 
    reformer__lsh_dropout=0., reformer__ff_dropout=0.,
    reformer__post_attn_dropout=0.
)
model = VertexPolyGen(config)

In [19]:
optimizer = torch.optim.Adam(model.parameters(), lr=3e-4)

In [20]:
inputs = {
    "vertices": v_batch[:2],
}
for b in inputs["vertices"]:
    print(b.shape)

torch.Size([204, 3])
torch.Size([62, 3])


In [21]:
import numpy as np
epoch_num = 300
model.train()
losses = []
accs = []
perps = []

for i in range(epoch_num):
    optimizer.zero_grad()
    outputs = model(inputs)
    
    loss = outputs["loss"]
    acc = outputs["accuracy"]
    perp = outputs["perplexity"]
    losses.append(loss.item())
    accs.append(acc.item())
    perps.append(perp.item())
    
    if i % 10 == 0:
        ave_loss = np.mean(losses[-10:])
        ave_acc = np.mean(accs[-10:])
        ave_perp = np.mean(perps[-10:])
        print("iteration: {}\tloss: {:.5f}\tperp: {:.3f}\tacc: {:.5f}".format(
            i, ave_loss, ave_perp, ave_acc))
    
    loss.backward()
    optimizer.step()

iteration: 0	loss: 5.57170	perp: 262.881	acc: 0.05500
iteration: 10	loss: 4.82134	perp: 129.929	acc: 0.14925
iteration: 20	loss: 4.07001	perp: 59.521	acc: 0.28187
iteration: 30	loss: 3.47319	perp: 32.687	acc: 0.42400
iteration: 40	loss: 2.89775	perp: 18.384	acc: 0.59175
iteration: 50	loss: 2.31088	perp: 10.229	acc: 0.76762
iteration: 60	loss: 1.73632	perp: 5.747	acc: 0.89712
iteration: 70	loss: 1.23784	perp: 3.476	acc: 0.96250
iteration: 80	loss: 0.85495	perp: 2.361	acc: 0.98550
iteration: 90	loss: 0.59418	perp: 1.815	acc: 0.99587
iteration: 100	loss: 0.42693	perp: 1.534	acc: 0.99625
iteration: 110	loss: 0.32102	perp: 1.379	acc: 0.99625
iteration: 120	loss: 0.25241	perp: 1.287	acc: 0.99625
iteration: 130	loss: 0.20504	perp: 1.228	acc: 0.99625
iteration: 140	loss: 0.17135	perp: 1.187	acc: 0.99625
iteration: 150	loss: 0.14645	perp: 1.158	acc: 0.99625
iteration: 160	loss: 0.12735	perp: 1.136	acc: 0.99625
iteration: 170	loss: 0.11230	perp: 1.119	acc: 0.99625
iteration: 180	loss: 0.10016	pe

In [22]:
model.eval()
pred = model.predict(max_seq_len=2400)
pred

tensor([164, 163, 164, 164, 163,  90, 164, 154, 164, 164, 154,  90, 163, 154,
        164, 163, 154, 163, 163, 154,  91, 163,  91, 163, 163,  91,  91, 162,
        163, 162, 162, 163,  92, 162,  92, 162, 162,  92,  92, 162,  91, 162,
        162,  91,  92, 144, 153,  92, 144, 153,  91, 144, 146,  92, 144, 146,
         91, 138, 153, 163, 138, 153, 162, 138, 146, 163, 138, 146, 162, 133,
        153,  92, 133, 153,  91, 133, 146,  92, 133, 146,  91, 128, 154,  92,
        128, 154,  91, 128, 146,  92, 128, 146,  91, 125, 153, 163, 125, 153,
        162, 125, 146, 163, 125, 146, 162, 121, 153, 163, 121, 153, 162, 121,
        146, 163, 121, 146, 162, 117, 154,  92, 117, 154,  91, 117, 146,  92,
        117, 146,  91, 111, 153, 163, 111, 153, 162, 111, 146, 163, 111, 146,
        162,  92, 163, 162,  92, 163,  92,  92,  92, 162,  92,  92,  92,  92,
         91, 162,  92,  91,  92,  91, 154, 163,  91, 154,  91,  91, 154,  90,
         91,  91, 163,  91,  91,  91,  90, 163, 164,  90, 163,  

In [23]:
true = inputs["vertices"][0].reshape(-1, )
true

tensor([166, 121, 166, 166, 121,  88, 166, 108, 166, 166, 108,  88, 165, 106,
        165, 165, 106,  89, 165, 104, 165, 165, 104,  89, 165, 103, 165, 165,
        103,  89, 164, 121, 164, 164, 121,  90, 164, 108, 164, 164, 108,  90,
        164, 106, 164, 164, 106,  90, 164, 105, 164, 164, 105,  90, 164, 101,
        164, 164, 101,  90, 163, 103, 163, 163, 103,  91, 163, 102, 163, 163,
        102,  91, 163,  99, 163, 163,  99,  91, 162, 100, 162, 162, 100,  92,
        162,  98, 162, 162,  98,  92, 161,  99, 161, 161,  99,  93, 160,  97,
        160, 160,  97,  94, 159,  98, 159, 159,  98,  95, 159,  96, 159, 159,
         96,  95, 158,  97, 158, 158,  97,  96, 157,  96, 157, 157,  96,  97,
        157,  95, 157, 157,  95,  97, 155,  96, 155, 155,  96,  99, 155,  94,
        155, 155,  94,  99, 153,  95, 153, 153,  95, 101, 153,  94, 153, 153,
         94, 101, 152,  95, 152, 152,  95, 102, 152,  94, 152, 152,  94, 102,
        131, 160, 161, 131, 160, 160, 131, 160, 159, 131, 160,  

In [24]:
true.shape, pred.shape

(torch.Size([612]), torch.Size([612]))

In [25]:
accuracy = (true == pred).sum() / len(true)
accuracy

tensor(0.8644)

In [26]:
true = inputs["vertices"][1].reshape(-1, )
true.shape

torch.Size([186])

In [28]:
torch.save(model.state_dict(), "../results/models/vertex")