In [1]:
import glob
import json
import os
import sys
import warnings

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from reformer_pytorch import Reformer

In [2]:
# Project root is taken to be the parent of the current working directory.
base_dir = os.path.dirname(os.getcwd())
data_dir = os.path.join(base_dir, "data", "original")
# glob order is filesystem-dependent; sort so that train_files[i] is reproducible.
train_files = sorted(glob.glob(os.path.join(data_dir, "train", "*", "*.obj")))
valid_files = sorted(glob.glob(os.path.join(data_dir, "val", "*", "*.obj")))
print(len(train_files), len(valid_files))

src_dir = os.path.join(base_dir, "src")
# Put the project sources first so local `utils` / `tokenizers` win over installed
# packages of the same name, and don't add a duplicate entry when the cell is re-run.
if src_dir not in sys.path:
    sys.path.insert(0, src_dir)

7003 1088


In [3]:
from utils import load_pipeline
from tokenizers import DecodeVertexTokenizer

In [4]:
# Load the first three training meshes. Each mesh contributes one vertex tensor to
# v_batch and a list of per-face tensors to f_batch; load_pipeline's middle return
# value is discarded.
v_batch, f_batch = [], []
for i in range(3):
    vs, _, fs = load_pipeline(train_files[i])
    
    vs = torch.tensor(vs)
    fs = [torch.tensor(f) for f in fs]
    
    v_batch.append(vs)
    f_batch.append(fs)
    print(vs.shape, len(fs))
    print("="*60)

torch.Size([655, 3]) 588
torch.Size([310, 3]) 220
torch.Size([396, 3]) 304


In [5]:
# Tokenizer with default settings, used by the exploratory cells below.
dec_tokenizer = DecodeVertexTokenizer()

In [6]:
# Tokenize the vertex tensors of the three loaded meshes into a dict of padded
# token tensors (keys shown in the output below).
input_tokens = dec_tokenizer.tokenize(v_batch)
input_tokens

{'value_tokens': tensor([[  0, 163, 139,  ..., 125,   1,   2],
         [  0, 137, 164,  ...,   2,   2,   2],
         [  0, 165, 101,  ...,   2,   2,   2]]),
 'target_tokens': tensor([[163, 139, 135,  ...,   1,   2,   2],
         [137, 164, 132,  ...,   2,   2,   2],
         [165, 101, 136,  ...,   2,   2,   2]]),
 'coord_type_tokens': tensor([[0, 1, 2,  ..., 3, 0, 0],
         [0, 1, 2,  ..., 0, 0, 0],
         [0, 1, 2,  ..., 0, 0, 0]]),
 'position_tokens': tensor([[  0,   1,   1,  ..., 655,   0,   0],
         [  0,   1,   1,  ...,   0,   0,   0],
         [  0,   1,   1,  ...,   0,   0,   0]]),
 'padding_mask': tensor([[ True,  True,  True,  ...,  True,  True, False],
         [ True,  True,  True,  ..., False, False, False],
         [ True,  True,  True,  ..., False, False, False]])}

In [7]:
class VertexDecoderEmbedding(nn.Module):
    """Embed vertex decoder tokens as the sum of three learned lookups.

    Each position's vector is ``value + coord_type + position`` embedding, all of
    width ``embed_dim``. Every table has its own padding index, whose row is
    initialised to zero by ``nn.Embedding``.

    Args:
        embed_dim: width of every embedding table.
        vocab_value / pad_idx_value: size and padding row of the value table.
        vocab_coord_type / pad_idx_coord_type: size and padding row of the
            coordinate-type table.
        vocab_position / pad_idx_position: size and padding row of the
            position table.
    """

    def __init__(self, embed_dim=256,
                 vocab_value=259, pad_idx_value=2,
                 vocab_coord_type=4, pad_idx_coord_type=0,
                 vocab_position=1000, pad_idx_position=0):
        super().__init__()
        # Construction order (value -> coord_type -> position) fixes how the RNG is
        # consumed during initialisation; keep it stable for seeded runs.
        self.value_embed = nn.Embedding(vocab_value, embed_dim, padding_idx=pad_idx_value)
        self.coord_type_embed = nn.Embedding(vocab_coord_type, embed_dim, padding_idx=pad_idx_coord_type)
        self.position_embed = nn.Embedding(vocab_position, embed_dim, padding_idx=pad_idx_position)

    def forward(self, value_tokens, coord_type_tokens, position_tokens):
        """Return the summed embedding; output shape is the input index shape + ``(embed_dim,)``."""
        lookups = (
            self.value_embed(value_tokens),
            self.coord_type_embed(coord_type_tokens),
            self.position_embed(position_tokens),
        )
        summed = lookups[0]
        for extra in lookups[1:]:
            summed = summed + extra
        return summed

In [10]:
# Small (embed_dim=128) embedding used only for the shape checks in the next cells.
embed = VertexDecoderEmbedding(embed_dim=128)

In [11]:
# Re-display of the tokenized batch (same object as the earlier cell).
input_tokens

{'value_tokens': tensor([[  0, 163, 139,  ..., 125,   1,   2],
         [  0, 137, 164,  ...,   2,   2,   2],
         [  0, 165, 101,  ...,   2,   2,   2]]),
 'target_tokens': tensor([[163, 139, 135,  ...,   1,   2,   2],
         [137, 164, 132,  ...,   2,   2,   2],
         [165, 101, 136,  ...,   2,   2,   2]]),
 'coord_type_tokens': tensor([[0, 1, 2,  ..., 3, 0, 0],
         [0, 1, 2,  ..., 0, 0, 0],
         [0, 1, 2,  ..., 0, 0, 0]]),
 'position_tokens': tensor([[  0,   1,   1,  ..., 655,   0,   0],
         [  0,   1,   1,  ...,   0,   0,   0],
         [  0,   1,   1,  ...,   0,   0,   0]]),
 'padding_mask': tensor([[ True,  True,  True,  ...,  True,  True, False],
         [ True,  True,  True,  ..., False, False, False],
         [ True,  True,  True,  ..., False, False, False]])}

In [12]:
# All three input streams must share the same (batch, seq_len) shape.
print(*(input_tokens[key].shape
        for key in ("value_tokens", "coord_type_tokens", "position_tokens")))

torch.Size([3, 1968]) torch.Size([3, 1968]) torch.Size([3, 1968])


In [13]:
# Embed the tokenized batch; expected shape is (batch, seq_len, embed_dim).
emb = embed(
    value_tokens=input_tokens["value_tokens"],
    coord_type_tokens=input_tokens["coord_type_tokens"],
    position_tokens=input_tokens["position_tokens"],
)
emb.shape

torch.Size([3, 1968, 128])

In [14]:
# One-layer Reformer for a shape smoke test. `causal` is not passed here, whereas
# VertexPolyGenConfig sets causal=True for the real model.
reformer = Reformer(dim=128, depth=1, max_seq_len=8192, bucket_size=24)

In [15]:
# Smoke test only: unlike VertexPolyGen.forward, no input_mask is passed, so padded
# positions are presumably attended to here. Only the output shape is checked.
output = reformer(emb)
output.shape

torch.Size([3, 1968, 128])

In [16]:
class Config:
    """Minimal dict-backed configuration with JSON persistence.

    Subclasses (or ``load_from_json``) are expected to populate ``self.config``;
    item access is forwarded to that dict.
    """

    def write_to_json(self, out_path):
        """Serialise ``self.config`` to ``out_path`` as JSON indented by 4 spaces."""
        with open(out_path, "w") as handle:
            json.dump(self.config, handle, indent=4)

    def load_from_json(self, file_path):
        """Replace ``self.config`` with the JSON document stored at ``file_path``."""
        with open(file_path) as handle:
            loaded = json.load(handle)
        self.config = loaded

    def __getitem__(self, key):
        """Return the top-level config entry stored under ``key``."""
        stored = self.config
        return stored[key]

In [17]:
class VertexPolyGenConfig(Config):
    """Hyper-parameters for VertexPolyGen, grouped into tokenizer/embedding/reformer sections.

    A keyword named ``<section>__<key>`` is stored at ``self.config[<section>][<key>]``.
    The shared ``embed_dim`` and ``max_seq_len`` are also kept at the top level and
    copied into the sections that need them (``embed_dim`` becomes the Reformer's
    ``dim``). Key order within each section is fixed because it determines the
    layout of ``write_to_json`` output.
    """

    def __init__(self,
                 embed_dim=256,
                 max_seq_len=2400,
                 tokenizer__bos_id=0,
                 tokenizer__eos_id=1,
                 tokenizer__pad_id=2,
                 embedding__vocab_value=256 + 3,
                 embedding__vocab_coord_type=4,
                 embedding__vocab_position=1000,
                 embedding__pad_idx_value=2,
                 embedding__pad_idx_coord_type=0,
                 embedding__pad_idx_position=0,
                 reformer__depth=12,
                 reformer__heads=8,
                 reformer__n_hashes=8,
                 reformer__bucket_size=48,
                 reformer__causal=True,
                 reformer__lsh_dropout=0.2,
                 reformer__ff_dropout=0.2,
                 reformer__post_attn_dropout=0.2,
                 reformer__ff_mult=4):
        tokenizer_section = dict(
            bos_id=tokenizer__bos_id,
            eos_id=tokenizer__eos_id,
            pad_id=tokenizer__pad_id,
            max_seq_len=max_seq_len,
        )
        embedding_section = dict(
            vocab_value=embedding__vocab_value,
            vocab_coord_type=embedding__vocab_coord_type,
            vocab_position=embedding__vocab_position,
            pad_idx_value=embedding__pad_idx_value,
            pad_idx_coord_type=embedding__pad_idx_coord_type,
            pad_idx_position=embedding__pad_idx_position,
            embed_dim=embed_dim,
        )
        # Keyword names match reformer_pytorch.Reformer's constructor (splatted in VertexPolyGen).
        reformer_section = dict(
            dim=embed_dim,
            depth=reformer__depth,
            max_seq_len=max_seq_len,
            heads=reformer__heads,
            bucket_size=reformer__bucket_size,
            n_hashes=reformer__n_hashes,
            causal=reformer__causal,
            lsh_dropout=reformer__lsh_dropout,
            ff_dropout=reformer__ff_dropout,
            post_attn_dropout=reformer__post_attn_dropout,
            ff_mult=reformer__ff_mult,
        )
        self.config = dict(
            embed_dim=embed_dim,
            max_seq_len=max_seq_len,
            tokenizer=tokenizer_section,
            embedding=embedding_section,
            reformer=reformer_section,
        )

In [18]:
def accuracy(y_pred, y_true, ignore_label=None, device=None):
    """Fraction of positions where the arg-max prediction equals the target.

    Args:
        y_pred: scores with the class dimension at ``dim=1`` (e.g. ``(N, C)``).
        y_true: integer targets of any shape whose element count matches
            ``y_pred.argmax(dim=1)``.
        ignore_label: target value excluded from both numerator and denominator
            (e.g. the padding id). ``None`` disables masking; ``0`` is honoured
            as a real label to ignore.
        device: unused; kept for backward compatibility. The result lives on the
            inputs' device.

    Returns:
        0-dim float32 tensor. NaN if every target equals ``ignore_label``.
    """
    # Flatten both sides so the mask and the normalizer agree for any target shape.
    y_true = y_true.reshape(-1)
    y_pred = y_pred.argmax(dim=1).reshape(-1)
    correct = y_pred == y_true

    if ignore_label is not None:
        keep = y_true != ignore_label
        correct = correct & keep
        normalizer = keep.sum()
    else:
        normalizer = y_true.numel()

    return correct.sum().to(torch.float32) / normalizer

In [19]:
class VertexPolyGen(nn.Module):
    """Autoregressive vertex model: token embedding -> Reformer -> LayerNorm.

    Logits are obtained by projecting hidden states onto the value-embedding
    matrix (weight tying), so the output vocabulary is the value vocabulary.

    Args:
        model_config: mapping (e.g. ``VertexPolyGenConfig``) with ``"tokenizer"``,
            ``"embedding"``, ``"reformer"`` and ``"embed_dim"`` entries.
    """

    def __init__(self, model_config):
        super().__init__()
        self.tokenizer = DecodeVertexTokenizer(**model_config["tokenizer"])
        self.embedding = VertexDecoderEmbedding(**model_config["embedding"])
        self.reformer = Reformer(**model_config["reformer"])
        self.layernorm = nn.LayerNorm(model_config["embed_dim"])
        # Padded target positions do not contribute to the loss.
        self.loss_func = nn.CrossEntropyLoss(ignore_index=model_config["tokenizer"]["pad_id"])

    def forward(self, tokens, device=None):
        """Return layer-normalised hidden states for a dict of token tensors.

        ``tokens`` must contain ``value_tokens``, ``coord_type_tokens``,
        ``position_tokens`` and ``padding_mask`` (passed to the Reformer as
        ``input_mask``). ``device`` is accepted for call symmetry and unused.
        """
        hs = self.embedding(
            tokens["value_tokens"],
            tokens["coord_type_tokens"],
            tokens["position_tokens"],
        )
        hs = self.reformer(hs, input_mask=tokens["padding_mask"])
        return self.layernorm(hs)

    def __call__(self, inputs, device=None):
        """Tokenize ``inputs`` and return teacher-forced loss and accuracy.

        NOTE: this overrides ``nn.Module.__call__``; ``model(x)`` calls ``forward``
        directly, so registered forward hooks do not run.

        Args:
            inputs: list of vertex tensors accepted by ``tokenizer.tokenize``.
            device: device the token tensors are moved to.

        Returns:
            dict with scalar tensors ``"accuracy"`` and ``"loss"``; padded targets
            are excluded from both.
        """
        tokens = self.tokenizer.tokenize(inputs)
        tokens = {k: v.to(device) for k, v in tokens.items()}

        hs = self.forward(tokens, device=device)
        # Tied output projection -> logits over the value vocabulary.
        logits = F.linear(hs, self.embedding.value_embed.weight)
        batch, length, vocab = logits.shape
        logits = logits.reshape(batch * length, vocab)
        targets = tokens["target_tokens"].reshape(batch * length)

        acc = accuracy(logits, targets, ignore_label=self.tokenizer.pad_id, device=device)
        loss = self.loss_func(logits, targets)
        return {"accuracy": acc, "loss": loss}

    @torch.no_grad()
    def predict(self, max_seq_len=2400, device=None):
        """Greedily decode a token sequence one step at a time.

        The whole prefix is re-tokenized and re-encoded at every step; the
        arg-max of the tied logits at the last prefix position becomes the next
        token. Decoding stops after ``max_seq_len`` steps or once EOS is emitted.

        Args:
            max_seq_len: maximum number of tokens to generate.
            device: device every token tensor is moved to before ``forward``.

        Returns:
            1-D tensor of predicted token ids including the special-token offset
            (a final EOS appears as ``special_tokens["eos"]``). Empty if the very
            first forward pass raises ``IndexError``.
        """
        tokenizer = self.tokenizer
        special_tokens = tokenizer.special_tokens
        n_special = len(special_tokens)
        # preds hold values with the special-token offset removed (the form
        # tokenizer.tokenize() takes), so EOS shows up as eos - n_special.
        eos_value = special_tokens["eos"] - n_special

        tokens = tokenizer.get_pred_start()
        preds = []
        pred_idx = 0

        while pred_idx < max_seq_len and (not preds or preds[-1] != eos_value):
            if pred_idx >= 1:
                tokens = tokenizer.tokenize([torch.stack(preds)])
                # Replace the token appended right after the prefix with padding.
                # NOTE(review): the padding_mask convention (True = keep vs True = pad)
                # differs between tokenizer outputs in this notebook; confirm intent.
                tokens["value_tokens"][:, pred_idx + 1] = special_tokens["pad"]
                tokens["padding_mask"][:, pred_idx + 1] = True
            # Re-tokenized tensors are freshly created, so move them on every step.
            tokens = {k: v.to(device) for k, v in tokens.items()}

            try:
                hs = self.forward(tokens, device=device)
            except IndexError as err:
                # NOTE(review): presumably an embedding index (e.g. position) ran past
                # its vocabulary; stop decoding and report instead of dumping tensors.
                warnings.warn(
                    "VertexPolyGen.predict stopped at step {}: {!r} (token shapes: {})".format(
                        pred_idx, err, {k: tuple(v.shape) for k, v in tokens.items()}
                    )
                )
                break

            logits = F.linear(hs[:, pred_idx], self.embedding.value_embed.weight)
            pred = logits.argmax(dim=1) - n_special
            preds.append(pred[0])
            pred_idx += 1

        if not preds:
            return torch.empty(0, dtype=torch.long, device=device)
        return torch.stack(preds) + n_special

In [29]:
# Seed before building the model so weight initialisation is reproducible across runs.
torch.manual_seed(0)
config = VertexPolyGenConfig(
    embed_dim=128, reformer__depth=6,
    # Dropout disabled for the single-mesh training run below.
    reformer__lsh_dropout=0., reformer__ff_dropout=0.,
    reformer__post_attn_dropout=0.
)
model = VertexPolyGen(config)

In [30]:
# Adam over all model parameters with a fixed learning rate of 1e-3.
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

In [31]:
# A single-mesh batch (list holding one vertex tensor), trained on repeatedly below.
data = v_batch[:1]
data

[tensor([[160, 136, 132],
         [160, 136, 130],
         [160, 136, 126],
         ...,
         [ 94, 136, 127],
         [ 94, 136, 124],
         [ 94, 136, 122]], dtype=torch.int32)]

In [None]:
# Repeated updates on the single `data` batch (each "epoch" is one optimizer step).
epoch_num = 700
log_every = 10
model.train()
losses = []
accs = []

for i in range(epoch_num):
    optimizer.zero_grad()
    outputs = model(data)

    loss = outputs["loss"]
    acc = outputs["accuracy"]
    losses.append(loss.item())
    accs.append(acc.item())
    if i % log_every == 0:
        # Running mean over the most recent `log_every` iterations.
        ave_loss = np.mean(losses[-log_every:])
        ave_acc = np.mean(accs[-log_every:])
        print("iteration: {}\tloss: {:.5f}\tacc: {:.5f}".format(i, ave_loss, ave_acc))
    loss.backward()
    optimizer.step()

iteration: 0	loss: 71.26718	acc: 0.01373346894979477
iteration: 10	loss: 34.52539	acc: 0.020752797555178403
iteration: 20	loss: 11.62529	acc: 0.03056968506425619
iteration: 30	loss: 7.15253	acc: 0.03723296038806438
iteration: 40	loss: 5.38933	acc: 0.04328585974872112
iteration: 50	loss: 4.46247	acc: 0.04587995931506157
iteration: 60	loss: 4.03079	acc: 0.05330620557069778
iteration: 70	loss: 3.80348	acc: 0.05879959352314472
iteration: 80	loss: 3.63885	acc: 0.064394710958004
iteration: 90	loss: 3.54084	acc: 0.07054933831095696
iteration: 100	loss: 3.46305	acc: 0.08092573806643485
iteration: 110	loss: 3.39959	acc: 0.09257375374436379


In [30]:
# Greedy decoding without a device argument; the result is compared against
# data[0] in the following cells.
model.eval()
pred = model.predict(max_seq_len=2400)
pred

tensor([163, 163, 163,  ..., 139, 125,   1])

In [45]:
# Drop the trailing EOS and undo the special-token offset before comparing with
# the ground-truth vertex coordinates.
n_special = len(model.tokenizer.special_tokens)
((pred[:-1] - n_special) == data[0].reshape(-1)).sum()

tensor(1860)

In [51]:
# Predicted tokens (EOS dropped, special-token offset removed) as (num_vertices, 3) coordinates.
(pred[:-1] - len(model.tokenizer.special_tokens)).reshape(-1, 3)

tensor([[160, 160, 160],
        [160, 136, 136],
        [160, 136, 132],
        ...,
        [ 94, 136, 127],
        [ 94, 136, 124],
        [ 94, 136, 122]])

In [50]:
# Ground-truth vertices for comparison with the decoded prediction above.
data[0]

tensor([[160, 136, 132],
        [160, 136, 130],
        [160, 136, 126],
        ...,
        [ 94, 136, 127],
        [ 94, 136, 124],
        [ 94, 136, 122]], dtype=torch.int32)

In [42]:
# FIXME(review): tokenize() is called with no argument here, while every other call
# passes a list of vertex tensors; the output below was presumably produced by an
# earlier version of this cell. Restore the intended input so the cell re-runs.
dec_tokenizer.tokenize()

{'value_tokens': tensor([[  0, 166, 142, 132, 132, 142, 132, 132, 142, 132, 132, 142, 122, 122,
          142, 142, 142, 142, 142, 142, 142, 125, 125, 142, 132, 132, 142, 121,
          121,  90, 121, 121, 121, 121, 121, 121, 121, 121, 121, 121, 121, 176,
          142, 142, 142, 122, 122, 142, 132, 132, 142, 132, 132, 142, 122, 122,
          133, 129, 129, 142, 132, 132, 142, 142, 142, 142, 142, 142, 142, 132,
          132, 142, 142, 142, 142, 142, 142, 142, 142, 142, 142, 122, 122, 142,
          132, 132, 142, 142, 142, 142, 142, 142, 142, 142, 142, 142, 132, 132,
          142, 126, 126,   1,   2]]),
 'target_tokens': tensor([[166, 142, 132, 132, 142, 132, 132, 142, 132, 132, 142, 122, 122, 142,
          142, 142, 142, 142, 142, 142, 125, 125, 142, 132, 132, 142, 121, 121,
           90, 121, 121, 121, 121, 121, 121, 121, 121, 121, 121, 121, 176, 142,
          142, 142, 122, 122, 142, 132, 132, 142, 132, 132, 142, 122, 122, 133,
          129, 129, 142, 132, 132, 142, 142, 142,

In [25]:
# Minimal example: one sequence holding the single value 0.
# NOTE(review): padding_mask below is False for real tokens and True for the pad,
# the opposite of the mask shown for tokenize(v_batch) earlier; confirm which
# convention the tokenizer and the Reformer input_mask actually use.
dec_tokenizer.tokenize([torch.tensor([0])])

{'value_tokens': tensor([[0, 3, 1, 2]]),
 'target_tokens': tensor([[3, 1, 2, 2]]),
 'coord_type_tokens': tensor([[0, 1, 0, 0]]),
 'position_tokens': tensor([[0, 1, 0, 0]]),
 'padding_mask': tensor([[False, False, False,  True]])}