In [1]:
import os
import sys
import glob
import torch

In [2]:
# The project root is taken to be the parent of the current working directory,
# i.e. the notebook is expected to be launched from a direct subdirectory of it.
base_dir = os.path.dirname(os.getcwd())
out_dir = os.path.join(base_dir, "results", "models")
data_dir = os.path.join(base_dir, "data", "original")
# glob returns matches in filesystem order, which is not guaranteed to be stable;
# sort so that positional access (e.g. train_files[0]) picks the same mesh on every run.
train_files = sorted(glob.glob(os.path.join(data_dir, "train", "*", "*.obj")))
valid_files = sorted(glob.glob(os.path.join(data_dir, "val", "*", "*.obj")))
print(len(train_files), len(valid_files))

# Make the project's `src` modules importable; guard so re-running this cell
# does not keep appending duplicate entries to sys.path.
src_dir = os.path.join(base_dir, "src")
if src_dir not in sys.path:
    sys.path.append(src_dir)

7003 1088


In [3]:
from utils import load_pipeline
from pytorch_trainer import Trainer, Reporter
from models import FacePolyGenConfig, FacePolyGen, VertexPolyGenConfig, VertexPolyGen

In [4]:
# Load the first three training meshes and convert them to tensors.
# load_pipeline returns a 3-tuple; its middle element is not used here.
# Each face is kept as its own tensor, so f_batch holds one list of tensors per mesh.
v_batch, f_batch = [], []
for idx in range(3):
    vertices, _, faces = load_pipeline(train_files[idx])

    vertex_tensor = torch.tensor(vertices)
    face_tensors = list(map(torch.tensor, faces))

    v_batch.append(vertex_tensor)
    f_batch.append(face_tensors)
    print(vertex_tensor.shape, len(face_tensors))
    print("=" * 60)

torch.Size([655, 3]) 588
torch.Size([310, 3]) 220
torch.Size([396, 3]) 304


In [5]:
# One candidate model per generation stage, keyed by the names used for
# model_type. Note that BOTH models are instantiated here, even though only the
# selected one is trained. Every reformer dropout rate is set to 0 (presumably so
# the model can fit the tiny single-mesh dataset used below; TODO confirm intent).
model_conditions = dict(
    face=FacePolyGen(
        FacePolyGenConfig(
            embed_dim=64,
            # "src" reformer settings
            src__reformer__depth=4,
            src__reformer__lsh_dropout=0.,
            src__reformer__ff_dropout=0.,
            src__reformer__post_attn_dropout=0.,
            # "tgt" reformer settings
            tgt__reformer__depth=4,
            tgt__reformer__lsh_dropout=0.,
            tgt__reformer__ff_dropout=0.,
            tgt__reformer__post_attn_dropout=0.,
        )
    ),
    vertex=VertexPolyGen(
        VertexPolyGenConfig(
            embed_dim=128,
            reformer__depth=6,
            reformer__lsh_dropout=0.,
            reformer__ff_dropout=0.,
            reformer__post_attn_dropout=0.,
        )
    ),
)

src__max_seq_len changed, because of lsh-attention's bucket_size
before: 2400 --> after: 2592 (with bucket_size: 48)
tgt__max_seq_len changed, because of lsh-attention's bucket_size
before: 3900 --> after: 3936 (with bucket_size: 48)


In [6]:
# Which model to train; must be one of the keys of model_conditions: "face" or "vertex".
model_type = "vertex"
model = model_conditions[model_type]

learning_rate = 3e-4
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

In [7]:
class VertexDataset(torch.utils.data.Dataset):
    """Map-style dataset over an indexable sequence of per-mesh vertex data.

    Items are returned unchanged; batching is left to the DataLoader's collate_fn.
    """

    def __init__(self, vertices):
        # presumably a list of per-mesh vertex tensors (e.g. shape (num_vertices, 3)); not enforced
        self.vertices = vertices

    def __len__(self):
        return len(self.vertices)

    def __getitem__(self, idx):
        return self.vertices[idx]


class FaceDataset(torch.utils.data.Dataset):
    """Map-style dataset pairing each mesh's vertices with its faces.

    Items are ``(vertices, faces)`` tuples taken from the same index of the two
    sequences passed to the constructor.

    Raises:
        ValueError: if ``vertices`` and ``faces`` have different lengths.
    """

    def __init__(self, vertices, faces):
        # Both sequences are indexed in lockstep, so they must align one-to-one.
        # Fail early here instead of raising IndexError (or silently ignoring
        # trailing faces) once iteration starts.
        if len(vertices) != len(faces):
            raise ValueError(
                "vertices and faces must have the same length, "
                f"got {len(vertices)} and {len(faces)}"
            )
        self.vertices = vertices
        self.faces = faces

    def __len__(self):
        return len(self.vertices)

    def __getitem__(self, idx):
        return self.vertices[idx], self.faces[idx]

In [8]:
# Train on only the first loaded mesh (presumably a sanity check that the model
# can fit a single example). Slice into the datasets instead of overwriting
# v_batch / f_batch, so the data loaded by the earlier cell is left intact and
# this cell stays safe to re-run with a different n_samples.
n_samples = 1
v_dataset = VertexDataset(v_batch[:n_samples])
f_dataset = FaceDataset(v_batch[:n_samples], f_batch[:n_samples])
len(v_dataset), len(f_dataset)

(1, 1)

In [9]:
def collate_fn_vertex(batch):
    """Wrap a list of vertex items as a one-element list holding a {"vertices": batch} dict.

    The incoming list object itself is stored (no copy, no stacking).
    """
    return [dict(vertices=batch)]


def collate_fn_face(batch):
    """Split a list of (vertices, faces) items into parallel lists.

    Returns a one-element list holding a dict with "vertices" and "faces" keys,
    each mapping to a list in the original batch order.
    """
    sample = {
        "vertices": [pair[0] for pair in batch],
        "faces": [pair[1] for pair in batch],
    }
    return [sample]

In [10]:
# One shuffled DataLoader per model type; each uses the collate function that
# matches its dataset's item format.
batch_size = 1
v_loader = torch.utils.data.DataLoader(
    v_dataset,
    batch_size=batch_size,
    shuffle=True,
    collate_fn=collate_fn_vertex,
)
f_loader = torch.utils.data.DataLoader(
    f_dataset,
    batch_size=batch_size,
    shuffle=True,
    collate_fn=collate_fn_face,
)
# Keyed by the same names as model_type.
loader_condition = dict(face=f_loader, vertex=v_loader)
len(v_loader), len(f_loader)

(1, 1)

In [11]:
epoch_num = 300
report_interval = 10
save_interval = 10
eval_interval = 1
# Loader whose batch format matches the selected model ("face" or "vertex").
loader = loader_condition[model_type]

reporter = Reporter(print_keys=['main/loss', 'main/perplexity', 'main/accuracy'])
trainer = Trainer(
    # Feed the loader chosen for model_type; a hard-coded f_loader would hand
    # face batches to the vertex model. The same loader is passed twice
    # (presumably [train, eval]; TODO confirm against pytorch_trainer.Trainer),
    # so evaluation runs on the training data.
    model, optimizer, [loader, loader], gpu="gpu",
    reporter=reporter, stop_trigger=(epoch_num, 'epoch'),
    report_trigger=(report_interval, 'iteration'), save_trigger=(save_interval, 'epoch'),
    # NOTE(review): logging reuses save_interval; introduce a separate interval if intended.
    log_trigger=(save_interval, 'epoch'), eval_trigger=(eval_interval, 'epoch'),
    out_dir=out_dir,  # presumably resumable via ckpt_path=os.path.join(out_dir, 'ckpt_18')
)

In [12]:
# Launch training with the Trainer configured in the previous cell.
# NOTE(review): the saved output ends in KeyboardInterrupt, i.e. this run was stopped manually.
trainer.run()

epoch: 0	iteration: 0	main/loss: 5.61779	main/perplexity: 275.27975	main/accuracy: 0.00661


KeyboardInterrupt: 