Mesh classification on the FAUST dataset using graph convolutional networks.
Each triangular mesh is treated as a graph whose node features are the vertex coordinates, and the whole mesh receives a single class label.

In [35]:
import os.path as osp
import torch
import torch.nn.functional as F
from torch_geometric.datasets import FAUST, TOSCA
import torch_geometric.transforms as T
from torch_geometric.data import DataLoader
import torch.nn as nn
from torch_geometric.nn import GCNConv, ChebConv

In [36]:
class MyTransform(object):
    """Transform that exposes vertex positions as node features.

    Sets ``data.x`` to the very same object as ``data.pos`` (no copy) and
    returns the mutated ``data``.
    """

    def __call__(self, data):
        positions = data.pos
        data.x = positions
        return data

# Load training and testing data.
path = osp.join('/home/sumukh/Documents/DataSets/pytorch_datasets/', 'FAUST')
# FaceToEdge derives graph edges from the mesh faces; MyTransform copies the
# vertex positions into data.x so they serve as node features.
transform = T.Compose([T.FaceToEdge(), MyTransform()])
# FAUST's third positional parameter is `transform` (run on every sample access),
# not `pre_transform`, so pass it by keyword to make that explicit. Moving it to
# `pre_transform=` would most likely require deleting the already-processed
# dataset directory first, since PyG caches processed data on disk.
train_dataset = FAUST(path, train=True, transform=transform)
test_dataset = FAUST(path, train=False, transform=transform)

train_loader = DataLoader(train_dataset, batch_size=1, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=1)

# A single training sample; Net reads its vertex count to size the dense layer.
d = train_loader.dataset[1]

In [37]:
# Model: a single graph convolutional layer followed by a dense layer.
class Net(torch.nn.Module):
    """Graph-level classifier: one GCN layer, then a linear layer over all nodes.

    The ReLU-activated GCN output of every vertex is flattened into one vector
    per graph and fed to a fully connected layer. Because that layer has a
    fixed input size, every input graph must have exactly ``num_nodes``
    vertices, stored contiguously per graph.

    Args:
        num_features: Input features per node. Defaults to
            ``train_dataset.num_features``.
        num_classes: Number of output classes. Defaults to
            ``train_dataset.num_classes``.
        num_nodes: Vertices per graph. Defaults to ``d.num_nodes`` (a training
            sample).
        hidden_channels: Output channels of the GCN layer.
    """

    def __init__(self, num_features=None, num_classes=None, num_nodes=None,
                 hidden_channels=64):
        super(Net, self).__init__()
        # Fall back to the module-level dataset/sample so `Net()` keeps working.
        if num_features is None:
            num_features = train_dataset.num_features
        if num_classes is None:
            num_classes = train_dataset.num_classes
        if num_nodes is None:
            num_nodes = d.num_nodes
        # No `cached=True`: it would reuse the normalized adjacency computed on
        # the first call for every later graph (including test graphs). PyG
        # documents caching as valid only for transductive learning.
        self.conv1 = GCNConv(num_features, hidden_channels)
        self.fc1 = nn.Linear(hidden_channels * num_nodes, num_classes)

    def forward(self, data):
        """Return class log-probabilities with shape ``(num_graphs, num_classes)``."""
        x, edge_index = data.x, data.edge_index
        x = F.relu(self.conv1(x, edge_index))
        # One row per graph: each graph contributes num_nodes consecutive node
        # rows. Identical to reshape(1, -1) when the batch holds a single graph.
        x = x.reshape(-1, self.fc1.in_features)
        x = self.fc1(x)
        return F.log_softmax(x, dim=1)

In [39]:
# Prefer the GPU when one is visible, build the model there, optimize with Adam.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
model = Net()
model = model.to(device)
optimizer = torch.optim.Adam(params=model.parameters(), lr=0.01)


def train(epoch):
    """Run one training epoch of `model` over `train_loader`.

    Step-decays the learning rate: every optimizer param group is set to
    0.001 when epoch 16 starts and to 0.0001 when epoch 26 starts; otherwise
    the current rate is left as is.
    """
    model.train()

    # Epoch number -> learning rate to switch to at the start of that epoch.
    lr_schedule = {16: 0.001, 26: 0.0001}
    new_lr = lr_schedule.get(epoch)
    if new_lr is not None:
        for group in optimizer.param_groups:
            group['lr'] = new_lr

    for batch in train_loader:
        batch = batch.to(device)
        optimizer.zero_grad()
        loss = F.nll_loss(model(batch), batch.y)
        loss.backward()
        optimizer.step()


def test():
    """Return the classification accuracy of `model` on `test_loader`.

    Accuracy is the number of correct predictions divided by the size of the
    dataset the loader iterates over.
    """
    model.eval()
    correct = 0

    # Inference only: no_grad avoids building autograd graphs, saving memory/time.
    with torch.no_grad():
        for data in test_loader:
            data = data.to(device)
            pred = model(data).argmax(dim=1)
            correct += pred.eq(data.y).sum().item()
    return correct / len(test_loader.dataset)


# Train for 30 epochs, reporting test accuracy after each one.
num_epochs = 30
for epoch in range(1, num_epochs + 1):
    train(epoch)
    test_acc = test()
    print('Epoch: {:02d}, Test: {:.4f}'.format(epoch, test_acc))
    