In [7]:
%load_ext autoreload
%autoreload 2

import os
import sys
import time
from types import SimpleNamespace

import numpy as np
from tqdm import tqdm_notebook as tqdm

import matplotlib.pyplot as plt 
%matplotlib inline

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch_geometric
from torch_geometric.nn import (NNConv, GMMConv, GraphConv, Set2Set)
from torch_geometric.nn import (SplineConv, graclus, max_pool, max_pool_x, global_mean_pool)

import trimesh

from visualization_utils import plot_mesh_3d

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [8]:
def make_data_instance_from_mesh(mesh: trimesh.base.Trimesh) -> torch_geometric.data.Data:
    ''' Convert a raw trimesh mesh into a PyTorch Geometric ``Data`` instance
        for predicting a normal at every vertex.

        * ``x``         - vertex positions (float32).
        * ``y``         - vertex normals (float32), the regression target.
        * ``edge_index``- ``mesh.edges`` transposed to shape (2, num_edges).
        * ``edge_attr`` - for every edge ``(a, b)`` the vector
                          ``vertices[b] - vertices[a]`` (float32, (num_edges, 3)).
        * ``face``      - the raw ``mesh.faces`` array, stored unchanged.

        Note: ``mesh.edges`` is derived from the faces; presumably each shared
        edge is listed once per adjacent face with opposite orientation, which
        would explain why the resulting graph behaves as undirected
        (NOTE(review): confirm against trimesh's ``Trimesh.edges`` docs).
    '''
    vertices = np.asarray(mesh.vertices)
    # reshape keeps a (0, 2) shape for meshes without faces, so the edge tensors
    # below keep their expected rank even when empty.
    edge_indices = np.asarray(mesh.edges, dtype=np.int64).reshape(-1, 2)

    # One vectorised subtraction instead of a per-edge Python loop; this also
    # avoids building a tensor from a list of ndarrays, which is very slow.
    edge_attributes = vertices[edge_indices[:, 1]] - vertices[edge_indices[:, 0]]

    data = torch_geometric.data.Data(
        x=torch.tensor(mesh.vertices, dtype=torch.float),
        y=torch.tensor(mesh.vertex_normals, dtype=torch.float),
        edge_index=torch.tensor(edge_indices, dtype=torch.long).t().contiguous(),
        edge_attr=torch.tensor(edge_attributes, dtype=torch.float),
        face=mesh.faces)
    return data

class NoramalsDataset(torch_geometric.data.Dataset):
    ''' On-the-fly dataset: loads one ``.obj`` mesh per item and converts it
        with ``make_data_instance_from_mesh``.

        All ``.obj`` files found recursively under ``root`` are sorted by path
        and split by position: the first ``delimetr`` fraction is the training
        set (``train=True``), the remainder the validation set.

        Args:
            root: directory searched (recursively) for ``.obj`` meshes.
            transform, pre_transform: forwarded to the PyG ``Dataset`` base.
            apply_rotation: if True, every loaded mesh gets a fresh random
                rotation before conversion (data augmentation).
            train: select the training (True) or validation (False) part.
            delimetr: fraction of meshes assigned to the training part.

        Warning! Loading from disk on every access is slow; an
        ``InMemoryDataset`` and/or cached preprocessed instances would be much
        more efficient.
    '''

    def __init__(self, root, transform=None, pre_transform=None,
                 apply_rotation=False, train=True, delimetr=0.7):
        super(NoramalsDataset, self).__init__(root, transform, pre_transform)

        self.apply_rotation = apply_rotation

        # Sorted so that separately constructed train/validation datasets agree
        # on the split: os.walk order is filesystem-dependent.
        objects = sorted(
            os.path.join(dirpath, file)
            for dirpath, _, filenames in os.walk(root)
            for file in filenames
            if file.endswith('.obj')
        )

        split = int(delimetr * len(objects))
        self.objects = objects[:split] if train else objects[split:]

    @property
    def raw_file_names(self):
        return []

    @property
    def processed_file_names(self):
        return []

    def len(self):
        # Newer PyG versions route __getitem__ through ``indices()``, which
        # calls ``self.len()``; without this override indexing raises
        # NotImplementedError.
        return len(self.objects)

    def __len__(self):
        return len(self.objects)

    def get(self, idx):
        mesh = trimesh.load(self.objects[idx])
        if self.apply_rotation:
            mesh.apply_transform(trimesh.transformations.random_rotation_matrix())
        return make_data_instance_from_mesh(mesh)
    
class NoramalsInMemoryDataset(torch_geometric.data.InMemoryDataset):
    ''' Preprocessed, in-memory version of the normals dataset.

        On first use ``process()`` converts every ``.obj`` mesh found under
        ``self.raw_dir`` with ``make_data_instance_from_mesh``, applies the
        optional ``pre_filter`` / ``pre_transform``, collates the result and
        saves it to ``self.processed_paths[0]``; later constructions just load
        that file.
    '''

    def __init__(self, root, transform=None, pre_transform=None):
        super(NoramalsInMemoryDataset, self).__init__(root, transform, pre_transform)
        # NOTE(review): recent PyTorch releases default torch.load to
        # weights_only=True, which may reject pickled Data objects - confirm
        # for the installed version.
        self.data, self.slices = torch.load(self.processed_paths[0])

    @property
    def raw_file_names(self):
        # Only the first mesh is listed; it is used as a marker that the raw
        # meshes are present in ``self.raw_dir``.
        return ['0.obj']

    @property
    def processed_file_names(self):
        return ['data.pt']

    def download(self):
        # There is no remote source; raw meshes must be copied in manually.
        print("Attempt to download dataset")

    def process(self):
        # Sorted so that the item order (and hence any positional
        # train/validation split) is reproducible across filesystems.
        object_paths = sorted(
            os.path.join(dirpath, file)
            for dirpath, _, filenames in os.walk(self.raw_dir)
            for file in filenames
            if file.endswith('.obj')
        )
        if not object_paths:
            raise RuntimeError(
                'No .obj meshes found in {}; copy the raw meshes there first.'.format(self.raw_dir))

        data_list = [make_data_instance_from_mesh(trimesh.load(object_path))
                     for object_path in tqdm(object_paths)]

        if self.pre_filter is not None:
            data_list = [data for data in data_list if self.pre_filter(data)]

        if self.pre_transform is not None:
            data_list = [self.pre_transform(data) for data in data_list]

        data, slices = self.collate(data_list)
        torch.save((data, slices), self.processed_paths[0])

In [9]:
class SplineBlock(nn.Module):
    """Three SplineConv layers with a skip connection from the block input.

    Pipeline: conv1 -> ELU -> conv2 -> ELU, then the original ``data.x`` is
    concatenated onto the features and passed through conv3 (no activation).

    Args:
        num_in_features: width of ``data.x``.
        num_outp_features: width of the block output.
        mid_features: width after conv1; conv2 doubles it.
        kernel: SplineConv kernel size.
        dim: pseudo-coordinate dimensionality (width of ``data.edge_attr``).

    NOTE(review): SplineConv is documented to expect pseudo-coordinates in
    [0, 1]; the raw edge vectors produced for this dataset may fall outside
    that range - confirm whether a normalising transform is needed.
    """

    def __init__(self, num_in_features, num_outp_features, mid_features, kernel=3, dim=3):
        super(SplineBlock, self).__init__()
        hidden_width = 2 * mid_features
        self.conv1 = SplineConv(num_in_features, mid_features, dim, kernel, is_open_spline=False)
        self.conv2 = SplineConv(mid_features, hidden_width, dim, kernel, is_open_spline=False)
        # conv3 sees the hidden features plus the skip-connected block input.
        self.conv3 = SplineConv(hidden_width + num_in_features, num_outp_features, dim, kernel,
                                is_open_spline=False)

    def forward(self, data):
        inputs, edge_index, pseudo = data.x, data.edge_index, data.edge_attr
        hidden = inputs
        for conv in (self.conv1, self.conv2):
            hidden = F.elu(conv(hidden, edge_index, pseudo))
        skip = torch.cat([hidden, inputs], dim=1)
        return self.conv3(skip, edge_index, pseudo)


class SplineCNN(nn.Module):
    """Smallest spline model: a single SplineBlock producing 3 outputs per vertex.

    Args:
        num_features: width of the input node features ``data.x``.
        kernel: SplineConv kernel size forwarded to the block.
        dim: pseudo-coordinate dimensionality forwarded to the block.
    """

    def __init__(self, num_features, kernel=3, dim=3):
        super(SplineCNN, self).__init__()
        self.block1 = SplineBlock(num_in_features=num_features, num_outp_features=3,
                                  mid_features=9, kernel=kernel, dim=dim)

    def forward(self, data):
        # The block's raw (un-activated) output is the prediction.
        return self.block1(data)
    
class SplineCNN2(nn.Module):
    """Two stacked SplineBlocks (num_features -> 16 -> 3), ELU after each.

    The input ``data`` is not modified: intermediate features are passed to
    the second block through a lightweight view instead of being written back
    into ``data.x``, so the same batch can safely be fed to the model again.

    Args:
        num_features: width of the input node features ``data.x``.
        kernel: SplineConv kernel size.
        dim: pseudo-coordinate dimensionality.
    """

    def __init__(self, num_features, kernel=3, dim=3):
        super(SplineCNN2, self).__init__()
        self.block1 = SplineBlock(num_features, 16, 9, kernel, dim)
        self.block2 = SplineBlock(16, 3, 32, kernel, dim)

    def forward(self, data):
        hidden = F.elu(self.block1(data))
        # SplineBlock only reads x / edge_index / edge_attr.
        view = SimpleNamespace(x=hidden, edge_index=data.edge_index, edge_attr=data.edge_attr)
        return F.elu(self.block2(view))
    
class SplineCNN4(nn.Module):
    """Four stacked SplineBlocks (num_features -> 16 -> 64 -> 64 -> 3).

    ELU follows the first three blocks; the last block's output is returned
    as-is (one 3-vector per vertex). The input ``data`` is not modified: new
    features are threaded through lightweight views rather than by
    overwriting ``data.x``, so the same batch can be fed to the model again.

    Args:
        num_features: width of the input node features ``data.x``.
        kernel: SplineConv kernel size.
        dim: pseudo-coordinate dimensionality.
    """

    def __init__(self, num_features, kernel=3, dim=3):
        super(SplineCNN4, self).__init__()
        self.block1 = SplineBlock(num_features, 16, 9, kernel, dim)
        self.block2 = SplineBlock(16, 64, 32, kernel, dim)
        self.block3 = SplineBlock(64, 64, 128, kernel, dim)
        self.block4 = SplineBlock(64, 3, 16, kernel, dim)

    @staticmethod
    def _with_features(data, x):
        # SplineBlock only reads x / edge_index / edge_attr.
        return SimpleNamespace(x=x, edge_index=data.edge_index, edge_attr=data.edge_attr)

    def forward(self, data):
        x = F.elu(self.block1(data))
        x = F.elu(self.block2(self._with_features(data, x)))
        x = F.elu(self.block3(self._with_features(data, x)))
        return self.block4(self._with_features(data, x))
    

class MoNet(nn.Module):
    """Four GMMConv layers (num_features -> 8 -> 16 -> 8 -> 3), ELU between them.

    Args:
        num_features: width of the input node features ``data.x``.
        kernel: number of Gaussian kernels per GMMConv layer (GMMConv's
            ``kernel_size`` is an int; the previous list default could not
            be used to build the layers).
        dim: pseudo-coordinate dimensionality (width of ``data.edge_attr``).

    The input ``data`` is left untouched; the prediction is returned.
    """

    def __init__(self, num_features, kernel=3, dim=3):
        super(MoNet, self).__init__()
        self.conv1 = GMMConv(in_channels=num_features, out_channels=8, dim=dim, kernel_size=kernel)
        self.conv2 = GMMConv(in_channels=8, out_channels=16, dim=dim, kernel_size=kernel)
        self.conv3 = GMMConv(in_channels=16, out_channels=8, dim=dim, kernel_size=kernel)
        self.conv4 = GMMConv(in_channels=8, out_channels=3, dim=dim, kernel_size=kernel)

    def forward(self, data):
        edge_index, pseudo = data.edge_index, data.edge_attr
        x = F.elu(self.conv1(data.x, edge_index, pseudo))
        x = F.elu(self.conv2(x, edge_index, pseudo))
        x = F.elu(self.conv3(x, edge_index, pseudo))
        return self.conv4(x, edge_index, pseudo)

In [20]:
def train(epoch, model, train_loader, device, optimizer):
    """Run one optimisation pass over ``train_loader`` with an MSE loss on ``data.y``.

    Args:
        epoch: current epoch index (unused; kept for call-site compatibility).
        model: network mapping a batch to predictions shaped like ``data.y``.
        train_loader: iterable of PyG batches / Data objects.
        device: device every batch is moved to.
        optimizer: optimizer stepping ``model``'s parameters.

    Returns:
        Mean training loss over the batches as a float, or ``nan`` if the
        loader yielded nothing.
    """
    model.train()

    # Accumulated as a detached tensor so the GPU is not synchronised per step.
    total_loss = 0.0
    num_batches = 0
    for data in tqdm(train_loader, leave=False):
        data = data.to(device)
        optimizer.zero_grad()
        response = model(data)
        loss = F.mse_loss(response, data.y)
        loss.backward()
        optimizer.step()
        total_loss = total_loss + loss.detach()
        num_batches += 1

    # No explicit ``del`` needed: locals are released on return (the old
    # ``del`` also raised UnboundLocalError for an empty loader).
    return float(total_loss / num_batches) if num_batches else float('nan')

def validate(model, test_loader, device):
    """Return the mean per-batch MSE between predictions and ``data.y``.

    Runs under ``torch.no_grad()`` so no autograd graphs are kept alive
    during evaluation.

    Args:
        model: network mapping a batch to predictions shaped like ``data.y``.
        test_loader: iterable of PyG batches / Data objects.
        device: device every batch is moved to.

    Returns:
        Mean of the per-batch MSE losses as a Python float.

    Raises:
        ValueError: if ``test_loader`` yields no batches.
    """
    model.eval()
    total_loss = 0.0
    num_batches = 0

    with torch.no_grad():
        for data in tqdm(test_loader):
            data = data.to(device)
            pred = model(data)
            total_loss += F.mse_loss(pred, data.y).item()
            num_batches += 1

    if num_batches == 0:
        raise ValueError('validate() received an empty test_loader')
    return total_loss / num_batches

def process_model(network, out_file_name, train_loader, validation_loader,
                  init_lr=0.1, num_epochs=150, num_features=None):
    """Build ``network``, train it with Adam and log validation loss per epoch.

    Args:
        network: model class, called as ``network(num_features)``.
        out_file_name: text file each epoch's log line is appended to.
        train_loader: iterable of training batches.
        validation_loader: iterable of validation batches.
        init_lr: initial Adam learning rate.
        num_epochs: number of training epochs.
        num_features: input feature width; if None it is read from the first
            training batch instead of from a notebook global.

    Returns:
        The trained model.

    Raises:
        ValueError: if ``num_features`` is None and ``train_loader`` is empty.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    if num_features is None:
        first_batch = next(iter(train_loader), None)
        if first_batch is None:
            raise ValueError('train_loader is empty; cannot infer num_features')
        num_features = first_batch.num_features
    model = network(num_features).to(device)

    optimizer = torch.optim.Adam(model.parameters(), lr=init_lr)
    # The monitored value is an MSE loss (lower is better), hence mode='min';
    # with 'max' the LR was being cut while the loss was still improving.
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='min', factor=0.8, patience=3, min_lr=0.00001, verbose=True)

    start_time = time.time()
    for epoch in tqdm(range(num_epochs)):
        train(epoch, model, train_loader, device, optimizer)
        val_loss = validate(model, validation_loader, device)
        scheduler.step(val_loss)
        with open(out_file_name, 'a') as file:
            print('Epoch: {:02d}, Time: {:.4f}, Validation Loss: {:.4f}'
                  .format(epoch, time.time() - start_time, val_loss), file=file)

    return model

In [11]:
# Disabled alternative: stream meshes from disk with the on-the-fly NoramalsDataset
# (default split: first 70% of the meshes for training, the rest for validation).
# train_dataset = NoramalsDataset('/cvlabsrc1/cvlab/dataset_shapenet/manifolds/02691156/')
# train_loader = torch_geometric.data.DataLoader(train_dataset, batch_size=16, shuffle=False)

# validation_dataset = NoramalsDataset('/cvlabsrc1/cvlab/dataset_shapenet/manifolds/02691156/', train=False)
# validation_loader = torch_geometric.data.DataLoader(validation_dataset, batch_size=16, shuffle=False)

# Load the preprocessed in-memory dataset (NoramalsInMemoryDataset.process() builds
# it if the processed file is not there yet); %time reports how long this takes.
%time dataset = NoramalsInMemoryDataset('/cvlabsrc1/cvlab/dataset_shapenet/normals_dataset/')

CPU times: user 16.8 s, sys: 26.4 s, total: 43.3 s
Wall time: 43.3 s


In [None]:
# Hold out the last 16 meshes for validation; both loaders batch 4 meshes, unshuffled.
train_loader, validation_loader = (
    torch_geometric.data.DataLoader(subset, batch_size=4, shuffle=False)
    for subset in (dataset[:-16], dataset[-16:])
)

In [13]:
# Iterate the dataset slices directly (one mesh per step, no batching); this
# rebinds the loader names defined in the previous cell.
train_loader, validation_loader = dataset[:1500], dataset[1500:]

In [None]:
# Train the 4-block spline network; each epoch's validation MSE is appended to 'SplineV4'.
model = process_model(
    network=SplineCNN4,
    out_file_name='SplineV4',
    train_loader=train_loader,
    validation_loader=validation_loader,
    init_lr=0.01,
)

HBox(children=(IntProgress(value=0, max=150), HTML(value='')))

HBox(children=(IntProgress(value=0, max=1500), HTML(value='')))

In [None]:
import gc

# Collect unreachable Python objects (stale batches, models), then return
# PyTorch's cached-but-unused CUDA blocks to the driver; gc alone leaves them
# reserved by the caching allocator, so nvidia-smi would still report them.
gc.collect()
if torch.cuda.is_available():
    torch.cuda.empty_cache()

In [1]:
# Shell out to inspect current GPU memory usage and the processes holding it.
!nvidia-smi

Mon Feb 24 12:26:41 2020       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 440.26       Driver Version: 440.26       CUDA Version: 10.2     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|   0  Tesla V100-SXM2...  Off  | 00000000:18:00.0 Off |                    0 |
| N/A   56C    P0    72W / 300W |      0MiB / 32510MiB |      0%      Default |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Processes:                                                       GPU Memory |
|  GPU       PID   Type   Process name                             Usage    

In [None]:
# Sanity check: does the training pipeline work on PyG's standard Cora benchmark?
from torch_geometric.datasets import Planetoid

cora_dataset = Planetoid(root='/tmp/Cora', name='Cora')

In [None]:
import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv

class Net(torch.nn.Module):
    ''' Two-layer GCN node classifier, used as a sanity check on a benchmark.

        Args:
            aa: number of input node features (``process_model`` passes the
                feature width as the only constructor argument).
            num_classes: number of output classes. When omitted it falls back
                to ``dataset.num_classes`` from the notebook's global
                ``dataset`` - NOTE(review): that global is the normals dataset,
                so pass the benchmark's class count explicitly
                (e.g. ``cora_dataset.num_classes``).

        ``forward`` returns per-node log-probabilities (log-softmax over classes).
    '''

    def __init__(self, aa, num_classes=None):
        super(Net, self).__init__()
        if num_classes is None:
            num_classes = dataset.num_classes
        # Use the width the caller passes in rather than the global ``dataset``,
        # which may be a different dataset from the one being trained on.
        self.conv1 = GCNConv(aa, 16)
        self.conv2 = GCNConv(16, num_classes)

    def forward(self, data):
        x, edge_index = data.x, data.edge_index

        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = F.dropout(x, training=self.training)
        x = self.conv2(x, edge_index)

        return F.log_softmax(x, dim=1)

In [None]:
# NOTE(review): process_model trains with an MSE loss against data.y while Net
# returns log-probabilities; Cora's targets are presumably integer class labels,
# so this run likely needs a classification loss (e.g. F.nll_loss) to be meaningful.
loader = torch_geometric.data.DataLoader(dataset=cora_dataset, batch_size=1, shuffle=False)
model = process_model(
    network=Net,
    out_file_name='SplineV1',
    train_loader=loader,
    validation_loader=loader,
    init_lr=0.1,
    num_epochs=250,
)

In [None]:
# Copy files to a new directory

from shutil import copyfile

# Source meshes (one ShapeNet category) and the raw folder of the normals dataset.
# PyG's InMemoryDataset looks for raw files in ``<root>/raw`` (its ``raw_dir``),
# which is what NoramalsInMemoryDataset.process() walks, so copy them there.
SOURCE_MESH_DIR = '/cvlabsrc1/cvlab/dataset_shapenet/manifolds/02933112/'
RAW_MESH_DIR = '/cvlabsrc1/cvlab/dataset_shapenet/normals_dataset/raw/'

# Sorted so the numeric target names (0.obj, 1.obj, ...) are assigned reproducibly.
objects = sorted(
    os.path.join(dirpath, file)
    for dirpath, _, filenames in os.walk(SOURCE_MESH_DIR)
    for file in filenames
    if file.endswith('.obj')
)

if objects:
    os.makedirs(RAW_MESH_DIR, exist_ok=True)
for idx, obj_path in enumerate(objects):
    copyfile(obj_path, os.path.join(RAW_MESH_DIR, '{}.obj'.format(idx)))