In [21]:
from time import sleep
from pathlib import Path
from itertools import tee
from functools import lru_cache

import trimesh
import numpy as np
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F

from torch_geometric.nn import MessagePassing
from torch_geometric.utils import add_self_loops, remove_self_loops
from torch_geometric.transforms import BaseTransform, Compose, FaceToEdge
from torch_geometric.data import Data, InMemoryDataset, extract_zip, DataLoader
from torch_geometric.datasets import FAUST


In [2]:
# Pick the compute device once: the GPU when CUDA is usable, the CPU otherwise.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

In [3]:
def load_mesh(mesh_filename: Path):
    """Read a mesh file and return its geometry as torch tensors.

    Parameters
    ----------
    mesh_filename: PathLike
        Path to mesh `.ply` file.

    Returns
    -------
    vertices: torch.tensor
        Float tensor of size (|V|, 3), where each row
        specifies the spatial position of a vertex in 3D space.
    faces: torch.tensor
        Integer (long) tensor of size (3, |M|): the face array read from the
        file is transposed, so each *column* defines a triangular face.
    """
    # process=False is forwarded to trimesh unchanged.
    raw_mesh = trimesh.load_mesh(mesh_filename, process=False)

    positions = torch.from_numpy(raw_mesh.vertices).to(torch.float)

    # Transpose to one face per column, then make the memory layout contiguous.
    triangles = torch.from_numpy(raw_mesh.faces).t()
    triangles = triangles.to(torch.long).contiguous()

    return positions, triangles

In [4]:
class SegmentationFaust(InMemoryDataset):
    """MPI-FAUST meshes paired with per-vertex body-part segmentation labels.

    ``MPI-FAUST.zip`` is expected inside ``root``. When the processed files are
    missing, the archive is extracted, every ``.ply`` mesh is turned into a
    ``Data`` object, and the filename-sorted meshes are split into a training
    file (the first ``_NUM_TRAIN_MESHES``) and a test file (the remainder).
    """

    # Body-part name -> integer class id.
    map_seg_label_to_id = dict(
        head=0,
        torso=1,
        left_arm=2,
        left_hand=3,
        right_arm=4,
        right_hand=5,
        left_upper_leg=6,
        left_lower_leg=7,
        left_foot=8,
        right_upper_leg=9,
        right_lower_leg=10,
        right_foot=11,
    )

    # Number of (filename-sorted) meshes written to the training split.
    _NUM_TRAIN_MESHES = 80

    def __init__(self, root, train: bool = True, pre_transform=None):
        """
        Parameters
        ----------
        root: PathLike
            Root directory where the dataset should be saved.
        train: bool
            Whether to load training data or test data.
        pre_transform: Optional[Callable]
            A function that takes in a torch_geometric.data.Data object
            and outputs a transformed version. Note that the transformed
            data object will be saved to disk.

        """
        # ``pre_transform`` must be passed by keyword: the second positional
        # parameter of ``InMemoryDataset.__init__`` is ``transform``. Passing it
        # positionally left ``self.pre_transform`` as None, so ``process()``
        # saved untransformed data and the transform ran on every access.
        # NOTE(review): processed files cached by the old code were written
        # without the pre-transform; delete the ``processed`` directory once so
        # they are regenerated.
        super().__init__(root, pre_transform=pre_transform)
        path = self.processed_paths[0] if train else self.processed_paths[1]
        # NOTE(review): recent torch releases may need ``weights_only=False``
        # here to unpickle ``Data`` objects -- confirm against the installed
        # version.
        self.data, self.slices = torch.load(path)

    @property
    def processed_file_names(self) -> list:
        """Names of the processed files: training split first, test split second."""
        return ["training.pt", "test.pt"]

    @property
    @lru_cache(maxsize=32)
    def _segmentation_labels(self):
        """Load the ``segmentation_labels`` array as an int64 tensor (cached)."""
        path_to_labels = Path(self.root) / "MPI-FAUST" / "segmentations.npz"
        seg_labels = np.load(str(path_to_labels))["segmentation_labels"]
        return torch.from_numpy(seg_labels).type(torch.int64)

    def _mesh_filenames(self):
        """Return an iterator over all ``.ply`` files in the extracted meshes folder."""
        path_to_meshes = Path(self.root) / "MPI-FAUST" / "meshes"
        return path_to_meshes.glob("*.ply")

    def _unzip_dataset(self):
        """Extract ``MPI-FAUST.zip`` into the root directory."""
        path_to_zip = Path(self.root) / "MPI-FAUST.zip"
        extract_zip(str(path_to_zip), self.root, log=False)

    def process(self):
        """Process the raw meshes files and their corresponding class labels."""
        self._unzip_dataset()

        data_list = []
        # Sort so the train/test split is deterministic across runs.
        for mesh_filename in sorted(self._mesh_filenames()):
            vertices, faces = load_mesh(mesh_filename)
            data = Data(x=vertices, face=faces)
            # The same label tensor is attached to every mesh.
            data.segmentation_labels = self._segmentation_labels
            if self.pre_transform is not None:
                data = self.pre_transform(data)
            data_list.append(data)

        split = self._NUM_TRAIN_MESHES
        torch.save(self.collate(data_list[:split]), self.processed_paths[0])
        torch.save(self.collate(data_list[split:]), self.processed_paths[1])

In [5]:
class NormalizeUnitSphere(BaseTransform):
    """Center node-level features and scale them to fit inside the unit sphere.

    After the transform the feature centroid is at the origin and the row with
    the largest Euclidean norm has norm 1 (individual rows are *not* each
    normalized to unit length).
    """

    @staticmethod
    def _re_center(x):
        """Recenter node-level features onto feature centroid."""
        centroid = torch.mean(x, dim=0)
        return x - centroid

    @staticmethod
    def _re_scale_to_unit_length(x):
        """Divide all rows by the largest row norm.

        Degenerate inputs are returned unchanged instead of failing or
        producing NaN/inf: an empty tensor (``torch.max`` would raise) and a
        tensor whose largest norm is zero (division by zero).
        """
        if x.numel() == 0:
            return x
        max_dist = torch.max(torch.norm(x, dim=1))
        if max_dist == 0:
            return x
        return x / max_dist

    def __call__(self, data: Data):
        """Normalize ``data.x`` in place (when present) and return ``data``."""
        if data.x is not None:
            data.x = self._re_scale_to_unit_length(self._re_center(data.x))

        return data

    def __repr__(self):
        return "{}()".format(self.__class__.__name__)

In [13]:
# NOTE(review): machine-specific absolute path; MPI-FAUST.zip must be located
# directly inside this directory.
root = '/home/sofialima'

# Build edges from the triangle faces (keeping the faces), then normalize
# the vertex coordinates.
pre_transform = Compose([
    FaceToEdge(remove_faces=False),
    NormalizeUnitSphere(),
])

train_data = SegmentationFaust(root=root, train=True, pre_transform=pre_transform)
test_data = SegmentationFaust(root=root, train=False, pre_transform=pre_transform)
# train_loader = DataLoader(train_data,  shuffle=True)
# test_loader = DataLoader(test_data, shuffle=False)

In [28]:
# Sanity check: for every training mesh show the object type, the number of
# node features, the (normalized) vertex coordinates and their shape.
for data in train_data:
    print(type(data))
    print(data.num_node_features)
    print(data.x)
    print(data.x.shape)

<class 'torch_geometric.data.data.Data'>
3
tensor([[ 0.0583,  0.6440,  0.1784],
        [ 0.0535,  0.6296,  0.1831],
        [ 0.0637,  0.6292,  0.1746],
        ...,
        [-0.0492,  0.6164,  0.0742],
        [-0.0497,  0.6173,  0.0759],
        [-0.0510,  0.6145,  0.0776]])
torch.Size([6890, 3])
<class 'torch_geometric.data.data.Data'>
3
tensor([[ 0.0429,  0.5554,  0.1249],
        [ 0.0385,  0.5411,  0.1290],
        [ 0.0492,  0.5413,  0.1217],
        ...,
        [-0.0508,  0.5262,  0.0154],
        [-0.0514,  0.5273,  0.0170],
        [-0.0530,  0.5246,  0.0188]])
torch.Size([6890, 3])
<class 'torch_geometric.data.data.Data'>
3
tensor([[ 0.0358,  0.5648,  0.1484],
        [ 0.0316,  0.5516,  0.1542],
        [ 0.0414,  0.5508,  0.1462],
        ...,
        [-0.0647,  0.5296,  0.0537],
        [-0.0649,  0.5308,  0.0550],
        [-0.0662,  0.5283,  0.0569]])
torch.Size([6890, 3])
<class 'torch_geometric.data.data.Data'>
3
tensor([[ 0.0343,  0.3101,  0.0257],
        [ 0.0304,

<class 'torch_geometric.data.data.Data'>
3
tensor([[ 0.0405,  0.3622,  0.0061],
        [ 0.0351,  0.3514,  0.0129],
        [ 0.0452,  0.3485,  0.0064],
        ...,
        [-0.0601,  0.3180, -0.0742],
        [-0.0604,  0.3192, -0.0733],
        [-0.0617,  0.3181, -0.0726]])
torch.Size([6890, 3])
<class 'torch_geometric.data.data.Data'>
3
tensor([[ 0.0615,  0.6262,  0.1473],
        [ 0.0554,  0.6119,  0.1552],
        [ 0.0679,  0.6086,  0.1474],
        ...,
        [-0.0528,  0.5676,  0.0423],
        [-0.0538,  0.5686,  0.0440],
        [-0.0556,  0.5670,  0.0453]])
torch.Size([6890, 3])
<class 'torch_geometric.data.data.Data'>
3
tensor([[-0.0263,  0.6293,  0.1417],
        [-0.0275,  0.6147,  0.1510],
        [-0.0144,  0.6159,  0.1446],
        ...,
        [-0.0974,  0.5291,  0.0396],
        [-0.0991,  0.5301,  0.0411],
        [-0.1004,  0.5281,  0.0423]])
torch.Size([6890, 3])
<class 'torch_geometric.data.data.Data'>
3
tensor([[ 0.0406,  0.4103, -0.0894],
        [ 0.0353,

In [32]:
import os
import json

In [37]:
# Dump every training graph to <root>/data_dicts/data_<idx>.json.
save_path = os.path.join(root, 'data_dicts')

# makedirs(exist_ok=True) replaces the racy exists()/mkdir() pair and also
# creates missing parent directories.
os.makedirs(save_path, exist_ok=True)

for idx, data in enumerate(train_data):
    data_dict = data.to_dict()
    # Tensors are not JSON-serializable; convert each one to nested lists.
    data_dict2 = {k: v.numpy().tolist() for k, v in data_dict.items()}
    with open(os.path.join(save_path, f'data_{idx}.json'), 'w') as f:
        json.dump(data_dict2, f)
            

In [39]:
# Resolve the labels file from `root`, the same location the dataset class reads
# it from, so this cell does not depend on the kernel's working directory.
npzfile = np.load(os.path.join(root, 'MPI-FAUST', 'segmentations.npz'))

In [40]:
type(npzfile)

numpy.lib.npyio.NpzFile

In [41]:
print(npzfile.files)

['segmentation_labels']


In [42]:
type(npzfile['segmentation_labels'])

numpy.ndarray

In [43]:
segmentations = npzfile['segmentation_labels']

In [44]:
segmentations.size

6890

In [45]:
segmentations[0]

array([0], dtype=int32)

In [46]:
segmentations[-1]

array([0], dtype=int32)

In [47]:
segmentations[2000]

array([3], dtype=int32)

In [48]:
import pandas as pd

from torch_geometric.nn import SAGEConv

In [None]:
def construct_graph(save_path,
                    data_num):
    """Rebuild a graph from one of the JSON dumps in ``save_path``.

    Parameters
    ----------
    save_path: PathLike
        Directory containing the ``data_<n>.json`` files.
    data_num: int
        Index ``<n>`` of the file to load.

    Returns
    -------
    graph: torch_geometric.data.Data
        Graph with float node features ``x``, long ``edge_index`` and long
        node labels ``y`` of shape (num_nodes,).

    Raises
    ------
    KeyError
        If the file contains neither a ``'y'`` nor a
        ``'segmentation_labels'`` entry.
    """
    with open(os.path.join(save_path, f'data_{data_num}.json'), 'r') as f:
        graph_dict = json.load(f)

    # The dumps written from the dataset store the labels under
    # 'segmentation_labels'; an explicit 'y' entry still takes precedence.
    if 'y' in graph_dict:
        labels = graph_dict['y']
    elif 'segmentation_labels' in graph_dict:
        labels = graph_dict['segmentation_labels']
    else:
        raise KeyError("graph file has neither 'y' nor 'segmentation_labels'")

    x = torch.tensor(graph_dict['x'], dtype=torch.float)
    y = torch.tensor(labels, dtype=torch.long)
    # Labels may be stored as an (N, 1) column; the loss and the accuracy
    # helper expect one class index per node, i.e. shape (N,).
    if y.dim() == 2 and y.size(1) == 1:
        y = y.squeeze(1)
    edge_index = torch.tensor(graph_dict['edge_index'], dtype=torch.long)

    graph = Data(x=x, edge_index=edge_index, y=y)

    print(f'Number of nuclei: {graph.num_nodes}')
    print(f'Number of edges: {graph.num_edges}')
    print(f'Number of features: {graph.num_node_features}')

    return graph

def accuracy(pred_y,
             y):
    """Return the share of matching entries of ``pred_y`` and ``y`` as a Python number.

    The number of matches is divided by ``len(y)``.
    """
    matches = pred_y == y
    fraction_correct = matches.sum() / len(y)
    return fraction_correct.item()
	

def test_model(model, 
		 loader):
		
	    criterion = torch.nn.CrossEntropyLoss()
		
	    model.eval()
		
	    loss = 0
	    acc = 0
	    predict = []

	    with torch.no_grad():
		    for data in loader:
		        predicted = model(data.x, data.edge_index)				
		        loss += criterion(predicted, data.y) / len(loader)
		        acc += accuracy(predicted.argmax(dim=1), data.y) / len(loader)
		        predict.append(predicted.argmax(dim=1))
		
	    model.train()

	    return loss, acc, predict
		
def train_model(model,
                loader,
                val_loader,
                epochs=100):
    """Train ``model`` with cross-entropy loss and validate once per epoch.

    Parameters
    ----------
    model: torch.nn.Module
        Must expose its optimizer as ``model.optimizer`` and be callable as
        ``model(data.x, data.edge_index)``.
    loader: sized iterable
        Training batches exposing ``x``, ``edge_index`` and ``y``.
    val_loader: sized iterable
        Validation batches, evaluated with ``test_model``.
    epochs: int
        Last epoch index; epochs ``0..epochs`` inclusive are run.

    Returns
    -------
    model:
        The trained model. The per-epoch metrics (columns ``loss``, ``acc``,
        ``val_loss``, ``val_acc``; one row per epoch) are attached to it as
        ``model.metrics_df``.
    """
    criterion = torch.nn.CrossEntropyLoss()
    optimizer = model.optimizer

    metric_names = ['loss', 'acc', 'val_loss', 'val_acc']
    history = []
    num_batches = len(loader)

    model.train()
    # range(epochs + 1) is kept from the original: the run ends on epoch
    # `epochs`, which is therefore included in the every-10-epochs report.
    for epoch in range(epochs + 1):

        total_loss = 0
        acc = 0

        # Train on batches
        for data in loader:
            optimizer.zero_grad()
            predicted = model(data.x, data.edge_index)
            loss = criterion(predicted, data.y)

            # Record losses as plain numbers: reuse the loss already computed
            # and do not keep the autograd graph alive across batches.
            total_loss += loss.item() / num_batches
            acc += accuracy(predicted.argmax(dim=1), data.y) / num_batches

            loss.backward()
            optimizer.step()

        # Validation, once per epoch after the last optimizer step. This
        # yields the same reported values as validating after every batch and
        # keeping only the last result, and also works for an empty loader.
        val_loss, val_acc, _ = test_model(model, val_loader)
        val_loss = float(val_loss)

        # One row per epoch.
        history.append({'loss': total_loss,
                        'acc': acc,
                        'val_loss': val_loss,
                        'val_acc': val_acc})

        # Print metrics every 10 epochs
        if epoch % 10 == 0:
            print(f'Epoch {epoch:>3} | Train Loss: {total_loss:.2f} '
                  f'| Train Acc: {acc*100:>5.2f}% '
                  f'| Val Loss: {val_loss:.2f} '
                  f'| Val Acc: {val_acc*100:.2f}%')

    # Build the frame once and expose it on the model so the collected
    # metrics are not discarded.
    model.metrics_df = pd.DataFrame(history, columns=metric_names)

    return model
	
def save_model(model, path=r'./data/model'):
    """Write ``model.state_dict()`` to ``path``.

    Parameters
    ----------
    model: torch.nn.Module
        Model whose parameters are saved.
    path: PathLike
        Destination file. Defaults to ``./data/model``. Missing parent
        directories are created first.
    """
    target = Path(path)
    target.parent.mkdir(parents=True, exist_ok=True)
    torch.save(model.state_dict(), str(target))
	    
class GraphSAGE(torch.nn.Module):
    """Two-layer GraphSAGE node classifier that owns its Adam optimizer."""

    def __init__(self, dim_h, dim_in, dim_out):
        """Create the two ``SAGEConv`` layers and the optimizer.

        Note the argument order: hidden width first, then input width, then
        number of output classes.
        """
        super().__init__()
        self.sage1 = SAGEConv(dim_in, dim_h)
        self.sage2 = SAGEConv(dim_h, dim_out)

        adam_settings = dict(lr=0.01, weight_decay=5e-4)
        self.optimizer = torch.optim.Adam(self.parameters(), **adam_settings)

    def forward(self, x, edge_index):
        """Return per-node log-probabilities (log-softmax over dim 1)."""
        hidden = F.relu(self.sage1(x, edge_index))
        # Dropout (p=0.5) between the two layers is currently disabled.
        scores = self.sage2(hidden, edge_index)
        return F.log_softmax(scores, dim=1)