## クラス分類による頭部の5方向推定

Input: 1つの点群データを表す 600×3 の行列（600点 × xyz 座標）

Label: 頭部の向き 0度、±45度、±90度 の5クラス

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import open3d as o3d

In [2]:
class NonLinear(nn.Module):
    """Fully connected block applied row-wise: Linear -> ReLU -> BatchNorm1d.

    Expects a 2-D input of shape (rows, input_channels) and returns
    (rows, output_channels). A 3-D input would be interpreted by
    BatchNorm1d as (N, C, L), so callers flatten per-point data first.
    """

    def __init__(self, input_channels, output_channels):
        super().__init__()
        self.input_channels, self.output_channels = input_channels, output_channels

        # Kept as one Sequential named ``main`` so the state_dict keys
        # (main.0.*, main.2.*) are stable for code that addresses them by name.
        stages = [
            nn.Linear(input_channels, output_channels),
            nn.ReLU(inplace=True),
            nn.BatchNorm1d(output_channels),
        ]
        self.main = nn.Sequential(*stages)

    def forward(self, input):
        """Run the Linear/ReLU/BatchNorm stack on ``input``."""
        return self.main(input)

In [3]:
class MaxPool(nn.Module):
    """Global max-pool over the points of each cloud.

    Input: a 2-D tensor of shape (batch * num_points, num_channels) in which
    the ``num_points`` rows of one cloud are consecutive (the same layout the
    T-Nets assume with ``view(-1, num_points, ...)``).
    Output: (batch, num_channels) -- the per-channel maximum over each cloud.
    """

    def __init__(self, num_channels, num_points):
        super(MaxPool, self).__init__()
        self.num_channels = num_channels
        self.num_points = num_points
        self.main = nn.MaxPool1d(self.num_points)

    def forward(self, input):
        # Group rows by cloud: (batch, num_points, num_channels).
        out = input.reshape(-1, self.num_points, self.num_channels)
        # MaxPool1d pools over the last axis, so move the point axis there.
        # A plain view(-1, num_channels, num_points) would only reinterpret
        # memory, not transpose it, and would mix values across channels.
        out = out.transpose(1, 2)
        out = self.main(out)  # (batch, num_channels, 1)
        return out.reshape(-1, self.num_channels)

In [None]:
class InputTNet(nn.Module):
    """Input transform network: predicts a 3x3 matrix per cloud and applies it.

    Input: (batch * num_points, 3), with each cloud's points in consecutive
    rows. Output: the same layout, where every cloud's (num_points, 3)
    coordinate block has been right-multiplied by its predicted matrix.
    """

    def __init__(self, num_points):
        super(InputTNet, self).__init__()
        self.num_points = num_points

        # Per-point MLP 3 -> 64 -> 128 -> 1024, a max-pool over the points of
        # each cloud, then a head 1024 -> 512 -> 256 -> 9 (a flattened 3x3).
        # Layer order/indices are fixed: index 6 is the final nn.Linear.
        point_widths = (3, 64, 128, 1024)
        head_widths = (1024, 512, 256)
        layers = [NonLinear(c_in, c_out)
                  for c_in, c_out in zip(point_widths, point_widths[1:])]
        layers.append(MaxPool(1024, num_points))
        layers += [NonLinear(c_in, c_out)
                   for c_in, c_out in zip(head_widths, head_widths[1:])]
        layers.append(nn.Linear(256, 9))
        self.main = nn.Sequential(*layers)

    def forward(self, input):
        transform = self.main(input).view(-1, 3, 3)
        clouds = input.view(-1, self.num_points, 3)
        return (clouds @ transform).view(-1, 3)

In [None]:
class FeatureTNet(nn.Module):
    """Feature transform network: predicts a 64x64 matrix per cloud and applies it.

    Input: (batch * num_points, 64), with each cloud's points in consecutive
    rows. Output: the same layout, where every cloud's (num_points, 64)
    feature block has been right-multiplied by its predicted matrix.
    """

    def __init__(self, num_points):
        super(FeatureTNet, self).__init__()
        self.num_points = num_points

        # Per-point MLP 64 -> 64 -> 128 -> 1024, a max-pool over the points of
        # each cloud, then a head 1024 -> 512 -> 256 -> 4096 (a flattened 64x64).
        # Layer order/indices are fixed: index 6 is the final nn.Linear.
        point_widths = (64, 64, 128, 1024)
        head_widths = (1024, 512, 256)
        layers = [NonLinear(c_in, c_out)
                  for c_in, c_out in zip(point_widths, point_widths[1:])]
        layers.append(MaxPool(1024, num_points))
        layers += [NonLinear(c_in, c_out)
                   for c_in, c_out in zip(head_widths, head_widths[1:])]
        layers.append(nn.Linear(256, 4096))
        self.main = nn.Sequential(*layers)

    def forward(self, input):
        transform = self.main(input).view(-1, 64, 64)
        features = input.view(-1, self.num_points, 64)
        return (features @ transform).view(-1, 64)

In [None]:
class PointNet(nn.Module):
    """PointNet classifier returning per-cloud log-probabilities.

    Every cloud must contain exactly ``num_points`` points, because the
    T-Nets and the global max-pool regroup rows by ``num_points``.
    Output: (batch, num_labels) log-probabilities (ends in LogSoftmax, so
    pair it with ``nn.NLLLoss``).
    """

    def __init__(self, num_points, num_labels):
        super(PointNet, self).__init__()
        self.num_points = num_points
        self.num_labels = num_labels

        self.main = nn.Sequential(
            InputTNet(self.num_points),      # main.0 (its last Linear: main.0.main.6)
            NonLinear(3, 64),
            NonLinear(64, 64),
            FeatureTNet(self.num_points),    # main.3 (its last Linear: main.3.main.6)
            NonLinear(64, 64),
            NonLinear(64, 128),
            NonLinear(128, 1024),
            MaxPool(1024, self.num_points),  # (batch * num_points, 1024) -> (batch, 1024)
            NonLinear(1024, 512),
            nn.Dropout(p=0.3),
            NonLinear(512, 256),
            nn.Dropout(p=0.3),
            # Plain Linear for the logits: a ReLU + BatchNorm here would clip
            # and batch-normalise the class scores before LogSoftmax.
            nn.Linear(256, self.num_labels),
            nn.LogSoftmax(dim=1),
        )

    def forward(self, input):
        """Classify a batch of clouds.

        ``input`` may be (batch, num_points, 3) or already flattened to
        (batch * num_points, 3); the layers operate on the flat layout.
        """
        points = input.reshape(-1, 3)
        return self.main(points)

In [None]:
def data_loader(num_points=600, data_dir="../Data/five_position_class",
                num_labels=5, samples_per_label=10):
    """Load the labelled point clouds and return them in a random order.

    Reads ``{data_dir}/{label}/{i}.pcd`` for every ``label`` in
    ``range(num_labels)`` and ``i`` in ``range(samples_per_label)``, keeping
    the first ``num_points`` points of each cloud. The class index of a cloud
    is the name of the directory it was read from.

    Returns:
        X: float32 tensor of shape
            (num_labels * samples_per_label, num_points, 3).
        y: int64 tensor of class indices, shuffled together with X.

    Raises:
        ValueError: if a cloud has fewer than ``num_points`` points
            (Open3D returns an empty cloud instead of raising when a file
            cannot be read, so this also catches missing files).
    """
    clouds = []
    labels = []
    for label in range(num_labels):
        for i in range(samples_per_label):
            path = f"{data_dir}/{label}/{i}.pcd"
            pcd = o3d.io.read_point_cloud(path)
            points = np.asarray(pcd.points)[:num_points]
            if len(points) < num_points:
                raise ValueError(
                    f"{path}: expected at least {num_points} points, got {len(points)}"
                )
            clouds.append(points)
            labels.append(label)  # class comes from the directory, not the file index

    X = torch.from_numpy(np.stack(clouds)).float()
    y = torch.tensor(labels, dtype=torch.long)

    data_shuffle = torch.randperm(len(y))
    return X[data_shuffle], y[data_shuffle]

In [None]:
# Main function
#
# Trains PointNet on the five head-direction classes. All clouds returned by
# data_loader are used as one full batch, so every iteration sees the whole
# data set once.

num_points = 600  # points kept per cloud; must match what data_loader returns
num_labels = 5    # 0, ±45 and ±90 degrees
num_iterations = 100

pointnet = PointNet(num_points, num_labels)

# Set the bias of each T-Net's final Linear (main.0.main.6 / main.3.main.6)
# to a flattened identity matrix, so the predicted transforms are offset by
# the identity from the start.
new_param = pointnet.state_dict()
new_param['main.0.main.6.bias'] = torch.eye(3, 3).view(-1)
new_param['main.3.main.6.bias'] = torch.eye(64, 64).view(-1)
pointnet.load_state_dict(new_param)

criterion = nn.NLLLoss()  # the model ends in LogSoftmax
optimizer = optim.Adam(pointnet.parameters(), lr=0.001)

loss_list = []
accuracy_list = []

# The data set is fixed, so read the .pcd files once instead of on every
# iteration. The network's layers work on a flat (clouds * num_points, 3)
# layout, and NLLLoss needs int64 class indices.
inputs, labels = data_loader(num_points)
inputs = torch.as_tensor(inputs, dtype=torch.float32).reshape(-1, 3)
labels = torch.as_tensor(labels, dtype=torch.long)
batch_size = len(labels)

for iteration in range(num_iterations + 1):
    pointnet.zero_grad()

    outputs = pointnet(inputs)

    error = criterion(outputs, labels)
    error.backward()

    optimizer.step()

    with torch.no_grad():
        _, pred = torch.max(outputs, 1)
        accuracy = (pred == labels).sum().item() / batch_size

    loss_list.append(error.item())
    accuracy_list.append(accuracy)

    if iteration % 10 == 0:
        print(f'Iteration: {iteration}    Loss: {error.item()}')
        print(f'Iteration: {iteration}    Accuracy: {accuracy}')