In [1]:
import os
import sys
import json
import glob
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from reformer_pytorch import Reformer

In [2]:
base_dir = os.path.dirname(os.getcwd())
data_dir = os.path.join(base_dir, "data", "original")
train_files = glob.glob(os.path.join(data_dir, "train", "*", "*.obj"))
valid_files = glob.glob(os.path.join(data_dir, "val", "*", "*.obj"))
print(len(train_files), len(valid_files))

src_dir = os.path.join(base_dir, "src")
sys.path.append(os.path.join(src_dir))

7003 1088


In [3]:
from utils import load_pipeline
from tokenizers import DecodeVertexTokenizer

In [4]:
v_batch, f_batch = [], []
for i in range(3):
    vs, _, fs = load_pipeline(train_files[i])
    
    vs = torch.tensor(vs)
    fs = [torch.tensor(f) for f in fs]
    
    v_batch.append(vs)
    f_batch.append(fs)
    print(vs.shape, len(fs))
    print("="*60)

torch.Size([655, 3]) 588
torch.Size([310, 3]) 220
torch.Size([396, 3]) 304


In [5]:
dec_tokenizer = DecodeVertexTokenizer()

In [6]:
input_tokens = dec_tokenizer.tokenize(v_batch)
input_tokens

{'value_tokens': tensor([[  0, 163, 139,  ..., 125,   1,   2],
         [  0, 137, 164,  ...,   2,   2,   2],
         [  0, 165, 101,  ...,   2,   2,   2]]),
 'target_tokens': tensor([[163, 139, 135,  ...,   1,   2,   2],
         [137, 164, 132,  ...,   2,   2,   2],
         [165, 101, 136,  ...,   2,   2,   2]]),
 'coord_type_tokens': tensor([[0, 1, 2,  ..., 3, 0, 0],
         [0, 1, 2,  ..., 0, 0, 0],
         [0, 1, 2,  ..., 0, 0, 0]]),
 'position_tokens': tensor([[  0,   1,   1,  ..., 655,   0,   0],
         [  0,   1,   1,  ...,   0,   0,   0],
         [  0,   1,   1,  ...,   0,   0,   0]]),
 'padding_mask': tensor([[False, False, False,  ..., False, False,  True],
         [False, False, False,  ...,  True,  True,  True],
         [False, False, False,  ...,  True,  True,  True]])}

In [7]:
class VertexDecoderEmbedding(nn.Module):
    
    def __init__(self, embed_dim=256,
                 vocab_value=259, pad_idx_value=2, 
                 vocab_coord_type=4, pad_idx_coord_type=0,
                 vocab_position=1000, pad_idx_position=0):
        
        super().__init__()
        
        self.value_embed = nn.Embedding(
            vocab_value, embed_dim, padding_idx=pad_idx_value
        )
        self.coord_type_embed = nn.Embedding(
            vocab_coord_type, embed_dim, padding_idx=pad_idx_coord_type
        )
        self.position_embed = nn.Embedding(
            vocab_position, embed_dim, padding_idx=pad_idx_position
        )
        
        self.embed_scaler = math.sqrt(embed_dim)
        
    def forward(self, tokens):
        
        """get embedding for vertex model.
        
        Args
            tokens [dict]: tokenized vertex info.
                `value_tokens` [torch.tensor]:
                        padded (batch, length)-shape long tensor
                        with coord value from 0 to 2^n(bit).
                `coord_type_tokens` [torch.tensor]:
                        padded (batch, length) shape long tensor implies x or y or z.
                `position_tokens` [torch.tensor]:
                        padded (batch, length) shape long tensor
                        representing coord position (NOT sequence position).
        
        Returns
            embed [torch.tensor]: (batch, length, embed) shape tensor after embedding.
                        
        """
        
        embed = self.value_embed(tokens["value_tokens"]) * self.embed_scaler
        embed = embed + (self.coord_type_embed(tokens["coord_type_tokens"]) * self.embed_scaler)
        embed = embed + (self.position_embed(tokens["position_tokens"]) * self.embed_scaler)
        
        return embed

In [8]:
embed = VertexDecoderEmbedding(embed_dim=128)

In [9]:
input_tokens

{'value_tokens': tensor([[  0, 163, 139,  ..., 125,   1,   2],
         [  0, 137, 164,  ...,   2,   2,   2],
         [  0, 165, 101,  ...,   2,   2,   2]]),
 'target_tokens': tensor([[163, 139, 135,  ...,   1,   2,   2],
         [137, 164, 132,  ...,   2,   2,   2],
         [165, 101, 136,  ...,   2,   2,   2]]),
 'coord_type_tokens': tensor([[0, 1, 2,  ..., 3, 0, 0],
         [0, 1, 2,  ..., 0, 0, 0],
         [0, 1, 2,  ..., 0, 0, 0]]),
 'position_tokens': tensor([[  0,   1,   1,  ..., 655,   0,   0],
         [  0,   1,   1,  ...,   0,   0,   0],
         [  0,   1,   1,  ...,   0,   0,   0]]),
 'padding_mask': tensor([[False, False, False,  ..., False, False,  True],
         [False, False, False,  ...,  True,  True,  True],
         [False, False, False,  ...,  True,  True,  True]])}

In [10]:
print(
    input_tokens["value_tokens"].shape,
    input_tokens["coord_type_tokens"].shape,
    input_tokens["position_tokens"].shape
)

torch.Size([3, 1968]) torch.Size([3, 1968]) torch.Size([3, 1968])


In [11]:
emb = embed(input_tokens)
emb.shape

torch.Size([3, 1968, 128])

In [12]:
reformer = Reformer(dim=128, depth=1, max_seq_len=8192, bucket_size=24)

In [13]:
output = reformer(emb)
output.shape

torch.Size([3, 1968, 128])

In [14]:
class Config(object):
    
    def write_to_json(self, out_path):
        with open(out_path, "w") as fw:
            json.dump(self.config, fw, indent=4)
            
    def load_from_json(self, file_path):
        with open(file_path) as fr:
            self.config = json.load(fr)
        
    def __getitem__(self, key):
        return self.config[key]

In [15]:
class VertexPolyGenConfig(Config):
    
    def __init__(self,
                 embed_dim=256, 
                 max_seq_len=2400, 
                 tokenizer__bos_id=0,
                 tokenizer__eos_id=1,
                 tokenizer__pad_id=2,
                 embedding__vocab_value=256 + 3, 
                 embedding__vocab_coord_type=4, 
                 embedding__vocab_position=1000,
                 embedding__pad_idx_value=2,
                 embedding__pad_idx_coord_type=0,
                 embedding__pad_idx_position=0,
                 reformer__depth=12,
                 reformer__heads=8,
                 reformer__n_hashes=8,
                 reformer__bucket_size=48,
                 reformer__causal=True,
                 reformer__lsh_dropout=0.2, 
                 reformer__ff_dropout=0.2,
                 reformer__post_attn_dropout=0.2,
                 reformer__ff_mult=4):
        
        # tokenizer config
        tokenizer_config = {
            "bos_id": tokenizer__bos_id,
            "eos_id": tokenizer__eos_id,
            "pad_id": tokenizer__pad_id,
            "max_seq_len": max_seq_len,
        }
        
        # embedding config
        embedding_config = {
            "vocab_value": embedding__vocab_value,
            "vocab_coord_type": embedding__vocab_coord_type,
            "vocab_position": embedding__vocab_position,
            "pad_idx_value": embedding__pad_idx_value,
            "pad_idx_coord_type": embedding__pad_idx_coord_type,
            "pad_idx_position": embedding__pad_idx_position,
            "embed_dim": embed_dim,
        }
        
        # reformer info
        reformer_config = {
            "dim": embed_dim,
            "depth": reformer__depth,
            "max_seq_len": max_seq_len,
            "heads": reformer__heads,
            "bucket_size": reformer__bucket_size,
            "n_hashes": reformer__n_hashes,
            "causal": reformer__causal,
            "lsh_dropout": reformer__lsh_dropout, 
            "ff_dropout": reformer__ff_dropout,
            "post_attn_dropout": reformer__post_attn_dropout,
            "ff_mult": reformer__ff_mult,
        }
        
        self.config = {
            "embed_dim": embed_dim,
            "max_seq_len": max_seq_len,
            "tokenizer": tokenizer_config,
            "embedding": embedding_config,
            "reformer": reformer_config,
        }

In [16]:
# utility functions

def accuracy(y_pred, y_true, ignore_label=None, device=None):
    y_pred = y_pred.argmax(dim=1)

    if ignore_label:
        normalizer = torch.sum(y_true!=ignore_label)
        ignore_mask = torch.where(
            y_true == ignore_label,
            torch.zeros_like(y_true, device=device),
            torch.ones_like(y_true, device=device)
        ).type(torch.float32)
    else:
        normalizer = y_true.shape[0]
        ignore_mask = torch.ones_like(y_true, device=device).type(torch.float32)

    acc = (y_pred.reshape(-1)==y_true.reshape(-1)).type(torch.float32)
    acc = torch.sum(acc*ignore_mask)
    return acc / normalizer


def init_weights(m):
    if type(m) == nn.Linear:
        nn.init.xavier_normal_(m.weight)
    if type(m) == nn.Embedding:
        nn.init.uniform_(m.weight, -0.05, 0.05)

In [17]:
class VertexPolyGen(nn.Module):
    
    """Vertex model in PolyGen.
    this model learn/predict vertices like OpenAI-GPT.
    UNLIKE the paper, this model is only for unconditional generation.
    
    Args
        model_config [Config]:
                hyper parameters. see VertexPolyGenConfig class for details. 
    """
    
    def __init__(self, model_config):
        super().__init__()
        
        self.tokenizer = DecodeVertexTokenizer(**model_config["tokenizer"])
        self.embedding = VertexDecoderEmbedding(**model_config["embedding"])
        self.reformer = Reformer(**model_config["reformer"])
        self.layernorm = nn.LayerNorm(model_config["embed_dim"])
        self.loss_func = nn.CrossEntropyLoss(ignore_index=model_config["tokenizer"]["pad_id"])
        
        self.apply(init_weights)
    
    def forward(self, tokens, device=None):
        
        """forward function which can be used for both train/predict.
        
        Args
            tokens [dict]: tokenized vertex info.
                `value_tokens` [torch.tensor]:
                        padded (batch, length)-shape long tensor
                        with coord value from 0 to 2^n(bit).
                `coord_type_tokens` [torch.tensor]:
                        padded (batch, length) shape long tensor implies x or y or z.
                `position_tokens` [torch.tensor]:
                        padded (batch, length) shape long tensor
                        representing coord position (NOT sequence position).
                `padding_mask` [torch.tensor]:
                        (batch, length) shape mask implies <pad> tokens.
            device [torch.device]: gpu or not gpu, that's the problem.
            
        
        Returns
            hs [torch.tensor]:
                    hidden states from transformer(reformer) model.
                    this takes (batch, length, embed) shape.
        
        """
        
        hs = self.embedding(tokens)
        hs = self.reformer(
            hs, input_mask=tokens["padding_mask"]
        )
        hs = self.layernorm(hs)
        
        return hs
        
        
    def __call__(self, inputs, device=None):
        
        """Calculate loss while training.
        
        Args
            inputs [dict]: dict containing batched inputs.
                `vertices` [list(torch.tensor)]:
                        variable-length-list of 
                        (length, 3) shaped tensor of quantized-vertices.
            device [torch.device]: gpu or not gpu, that's the problem.
                
        Returns
            outputs [dict]: dict containing calculated variables.
                `loss` [torch.tensor]:
                        calculated scalar-shape loss with backprop info.
                `accuracy` [torch.tensor]:
                        calculated scalar-shape accuracy.
            
        """
        
        tokens = self.tokenizer.tokenize(inputs["vertices"])
        tokens = {k: v.to(device) for k, v in tokens.items()}
        
        hs = self.forward(tokens, device=device)
        
        hs = F.linear(hs, self.embedding.value_embed.weight)
        BATCH, LENGTH, EMBED = hs.shape
        hs = hs.reshape(BATCH*LENGTH, EMBED)
        targets = tokens["target_tokens"].reshape(BATCH*LENGTH,)
        
        acc = accuracy(
            hs, targets, ignore_label=self.tokenizer.pad_id, device=device
        )
        loss = self.loss_func(hs, targets)
        
        outputs = {
            "accuracy": acc,
            "perplexity": torch.exp(loss),
            "loss": loss,
        }
        return outputs
    
    
    @torch.no_grad()
    def predict(self, max_seq_len=2400, device=None):
        """predict function
        
        Args
            max_seq_len[int]: max sequence length to predict.
            device [torch.device]: gpu or not gpu, that's the problem.
            
        Return
            preds [torch.tensor]: predicted (length, ) shape tensor.
        
        """
        
        tokenizer = self.tokenizer
        special_tokens = tokenizer.special_tokens
        
        tokens = tokenizer.get_pred_start()
        tokens = {k: v.to(device) for k, v in tokens.items()}
        preds = []
        pred_idx = 0
        
        while (pred_idx <= max_seq_len-1)\
        and ((len(preds) == 0) or (preds[-1] != special_tokens["eos"]-len(special_tokens))):
            
            if pred_idx >= 1:
                tokens = tokenizer.tokenize([torch.stack(preds)])
                tokens["value_tokens"][:, pred_idx+1] = special_tokens["pad"]
                tokens["padding_mask"][:, pred_idx+1] = True
            
            hs = self.forward(tokens, device=device)

            hs = F.linear(hs[:, pred_idx], self.embedding.value_embed.weight)
            pred = hs.argmax(dim=1) - len(special_tokens)
            preds.append(pred[0])
            pred_idx += 1
            
        preds = torch.stack(preds) + len(special_tokens)
        preds = self.tokenizer.detokenize([preds])[0]
        return preds

In [18]:
config = VertexPolyGenConfig(
    embed_dim=128, reformer__depth=6, 
    reformer__lsh_dropout=0., reformer__ff_dropout=0.,
    reformer__post_attn_dropout=0.
)
model = VertexPolyGen(config)

In [19]:
optimizer = torch.optim.Adam(model.parameters(), lr=3e-4)

In [20]:
inputs = {
    "vertices": v_batch[:2],
}
for b in inputs["vertices"]:
    print(b.shape)

torch.Size([655, 3])
torch.Size([310, 3])


In [21]:
import numpy as np
epoch_num = 500
model.train()
losses = []
accs = []
perps = []

for i in range(epoch_num):
    optimizer.zero_grad()
    outputs = model(inputs)
    
    loss = outputs["loss"]
    acc = outputs["accuracy"]
    perp = outputs["perplexity"]
    losses.append(loss.item())
    accs.append(acc.item())
    perps.append(perp.item())
    
    if i % 10 == 0:
        ave_loss = np.mean(losses[-10:])
        ave_acc = np.mean(accs[-10:])
        ave_perp = np.mean(perps[-10:])
        print("iteration: {}\tloss: {:.5f}\tperp: {:.3f}\tacc: {:.5f}".format(
            i, ave_loss, ave_perp, ave_acc))
    
    loss.backward()
    optimizer.step()

iteration: 0	loss: 5.59800	perp: 269.885	acc: 0.00518
iteration: 10	loss: 5.04489	perp: 159.285	acc: 0.04833
iteration: 20	loss: 4.50083	perp: 90.761	acc: 0.07977
iteration: 30	loss: 4.09772	perp: 60.588	acc: 0.12589
iteration: 40	loss: 3.70801	perp: 41.032	acc: 0.21049
iteration: 50	loss: 3.29563	perp: 27.205	acc: 0.36093
iteration: 60	loss: 2.84007	perp: 17.276	acc: 0.53673
iteration: 70	loss: 2.34739	perp: 10.571	acc: 0.71177
iteration: 80	loss: 1.84495	perp: 6.393	acc: 0.85854
iteration: 90	loss: 1.37129	perp: 3.973	acc: 0.94411
iteration: 100	loss: 0.97338	perp: 2.660	acc: 0.97953
iteration: 110	loss: 0.68129	perp: 1.981	acc: 0.99234
iteration: 120	loss: 0.48655	perp: 1.628	acc: 0.99475
iteration: 130	loss: 0.36314	perp: 1.438	acc: 0.99579
iteration: 140	loss: 0.28008	perp: 1.323	acc: 0.99644
iteration: 150	loss: 0.22367	perp: 1.251	acc: 0.99655
iteration: 160	loss: 0.18443	perp: 1.203	acc: 0.99655
iteration: 170	loss: 0.15588	perp: 1.169	acc: 0.99655
iteration: 180	loss: 0.13437	

In [22]:
model.eval()
pred = model.predict(max_seq_len=2400)
pred

tensor([160, 136, 132,  ...,  94, 136, 122])

In [27]:
true = inputs["vertices"][0].reshape(-1, )
true

tensor([160, 136, 132,  ...,  94, 136, 122], dtype=torch.int32)

In [29]:
true.shape, pred.shape

(torch.Size([1965]), torch.Size([1965]))

In [30]:
accuracy = (true == pred).sum() / len(true)
accuracy

tensor(0.8789)

In [31]:
true = inputs["vertices"][1].reshape(-1, )
true.shape

torch.Size([930])

In [32]:
accuracy = (true == pred[:930]).sum() / len(true)
accuracy

tensor(0.2667)

In [34]:
torch.save(model.state_dict(), "../results/models/vertex")