https://qiita.com/opeco17/items/707a5c57bca41a145122

In [1]:
import numpy as np
import math
import random
import os
import scipy.spatial.distance
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import torch.nn.functional as F


In [2]:
class NonLinear(nn.Module):
    """Linear layer followed by ReLU and 1-D batch normalisation.

    Args:
        input_channels: Size of the last dimension of the incoming tensor.
        output_channels: Number of features produced for each input row.
    """

    def __init__(self, input_channels, output_channels):
        super().__init__()
        self.input_channels = input_channels
        self.output_channels = output_channels

        # The position of each layer inside the Sequential determines its
        # state_dict key ("main.0.*" for the Linear, "main.2.*" for the
        # BatchNorm1d), so the order below must stay as it is.
        layers = [
            nn.Linear(input_channels, output_channels),
            nn.ReLU(inplace=True),
            nn.BatchNorm1d(output_channels),
        ]
        self.main = nn.Sequential(*layers)

    def forward(self, input):
        """Apply Linear -> ReLU -> BatchNorm1d to ``input`` and return it."""
        return self.main(input)



In [3]:
class MaxPool(nn.Module):
    """Per-channel max over the points of each cloud.

    The incoming tensor is point-major: every row is one point and rows
    belonging to the same cloud are consecutive, i.e. the shape is
    ``(batch * num_points, num_channels)``.

    Args:
        num_channels: Number of feature channels per point.
        num_points: Number of points in each cloud.
    """

    def __init__(self, num_channels, num_points):
        super(MaxPool, self).__init__()
        self.num_channels = num_channels
        self.num_points = num_points
        self.main = nn.MaxPool1d(self.num_points)

    def forward(self, input):
        """Return a ``(batch, num_channels)`` tensor of per-channel maxima."""
        # Recover the (batch, points, channels) layout first. Viewing the
        # flat rows directly as (batch, channels, points) would only
        # reinterpret memory and pool over runs of channels of a single
        # point instead of over the points of a cloud.
        out = input.view(-1, self.num_points, self.num_channels)
        # MaxPool1d pools over the last dimension, so move points there.
        out = out.transpose(1, 2)
        out = self.main(out)
        out = out.reshape(-1, self.num_channels)

        return out

In [4]:
class InputTNet(nn.Module):
    """Predict a 3x3 matrix per cloud and multiply the cloud's points by it.

    Args:
        num_points: Number of points in each cloud.
    """

    def __init__(self, num_points):
        super().__init__()
        self.num_points = num_points

        # Per-point features -> pooling over each cloud -> regression of the
        # nine matrix entries. The index of each stage is part of its
        # state_dict key, so the order must not change.
        stages = [
            NonLinear(3, 64),
            NonLinear(64, 128),
            NonLinear(128, 1024),
            MaxPool(1024, self.num_points),
            NonLinear(1024, 512),
            NonLinear(512, 256),
            nn.Linear(256, 9),
        ]
        self.main = nn.Sequential(*stages)

    def forward(self, input):
        """Transform the points and return them flattened to ``(-1, 3)``.

        ``input`` is expected to hold one 3-D point per row, with the
        ``num_points`` rows of each cloud stored consecutively.
        """
        transform = self.main(input).view(-1, 3, 3)
        clouds = input.view(-1, self.num_points, 3)
        # Batched product: every cloud is multiplied by its own matrix.
        return torch.matmul(clouds, transform).view(-1, 3)

In [5]:
class FeatureTNet(nn.Module):
    """Predict a 64x64 matrix per cloud and multiply its point features by it.

    Args:
        num_points: Number of points in each cloud.
    """

    def __init__(self, num_points):
        super().__init__()
        self.num_points = num_points

        # Per-point features -> pooling over each cloud -> regression of the
        # 64 * 64 = 4096 matrix entries. The index of each stage is part of
        # its state_dict key, so the order must not change.
        stages = [
            NonLinear(64, 64),
            NonLinear(64, 128),
            NonLinear(128, 1024),
            MaxPool(1024, self.num_points),
            NonLinear(1024, 512),
            NonLinear(512, 256),
            nn.Linear(256, 4096),
        ]
        self.main = nn.Sequential(*stages)

    def forward(self, input):
        """Transform the features and return them flattened to ``(-1, 64)``.

        ``input`` is expected to hold one 64-dimensional feature vector per
        row, with the ``num_points`` rows of each cloud stored consecutively.
        """
        transform = self.main(input).view(-1, 64, 64)
        features = input.view(-1, self.num_points, 64)
        # Batched product: every cloud is multiplied by its own matrix.
        return torch.matmul(features, transform).view(-1, 64)

In [6]:
class PointNet(nn.Module):
    """PointNet-style network built from the T-Nets, NonLinear and MaxPool.

    Args:
        num_points: Number of points in each cloud.
        num_labels: Number of values produced for each cloud.
    """

    def __init__(self, num_points, num_labels):
        super().__init__()
        self.num_points = num_points
        self.num_labels = num_labels

        # The index of every stage is part of its state_dict key
        # (e.g. "main.0.*" is the input T-Net and "main.3.*" the feature
        # T-Net), so stages must keep their positions.
        stages = [
            # Input transform and per-point embedding.
            InputTNet(self.num_points),
            NonLinear(3, 64),
            NonLinear(64, 64),
            # Feature transform and deeper per-point embedding.
            FeatureTNet(self.num_points),
            NonLinear(64, 64),
            NonLinear(64, 128),
            NonLinear(128, 1024),
            # Collapse the points of each cloud into one 1024-d vector.
            MaxPool(1024, self.num_points),
            # Head operating on one row per cloud.
            NonLinear(1024, 512),
            nn.Dropout(p=0.3),
            NonLinear(512, 256),
            nn.Dropout(p=0.3),
            NonLinear(256, self.num_labels),
        ]
        self.main = nn.Sequential(*stages)

    def forward(self, input):
        """Run ``input`` through every stage and return the final output."""
        return self.main(input)

In [7]:
def data_sampler(batch_size, num_points):
    """Draw a shuffled batch of Gaussian and uniform random point clouds.

    Args:
        batch_size: Total number of clouds to generate. ``batch_size // 2``
            of them are sampled with ``torch.randn`` and the remaining ones
            with ``torch.rand``, so an odd size yields one extra uniform
            cloud rather than an index error.
        num_points: Number of 3-D points in each cloud.

    Returns:
        A tuple ``(points, labels)`` where ``points`` has shape
        ``(batch_size * num_points, 3)`` with the points of each cloud stored
        consecutively, and ``labels`` has shape ``(batch_size, 1)`` holding
        1.0 for Gaussian clouds and 0.0 for uniform ones.
    """
    num_normal = batch_size // 2
    # Give the remainder to the uniform class so the batch always contains
    # exactly ``batch_size`` clouds.
    num_uniform = batch_size - num_normal

    normal_sampled = torch.randn(num_normal, num_points, 3)
    uniform_sampled = torch.rand(num_uniform, num_points, 3)
    normal_labels = torch.ones(num_normal)
    uniform_labels = torch.zeros(num_uniform)

    input_data = torch.cat((normal_sampled, uniform_sampled), dim=0)
    labels = torch.cat((normal_labels, uniform_labels), dim=0)

    # Permute over the clouds that were actually built; the same permutation
    # is applied to the labels to keep them aligned.
    data_shuffle = torch.randperm(input_data.size(0))

    return input_data[data_shuffle].view(-1, 3), labels[data_shuffle].view(-1, 1)

In [8]:
# Main function
# Trains PointNet to separate clouds drawn with torch.randn (label 1) from
# clouds drawn with torch.rand (label 0), as produced by data_sampler.

batch_size = 64
num_points = 64
num_labels = 1

pointnet = PointNet(num_points, num_labels)

# Overwrite the bias of the last linear layer of both T-Nets with a
# flattened identity matrix, so each predicted transform starts out as the
# identity plus whatever the (randomly initialised) weights contribute.
new_param = pointnet.state_dict()
new_param['main.0.main.6.bias'] = torch.eye(3, 3).view(-1)
new_param['main.3.main.6.bias'] = torch.eye(64, 64).view(-1)
pointnet.load_state_dict(new_param)

criterion = nn.BCELoss()
# Built once instead of on every iteration.
sigmoid = nn.Sigmoid()
optimizer = optim.Adam(pointnet.parameters(), lr=0.001)

loss_list = []
accuracy_list = []

for iteration in range(10000+1):

    pointnet.zero_grad()
    # Named ``points`` rather than ``input`` so the built-in ``input`` is not
    # shadowed for the rest of the notebook session.
    points, labels = data_sampler(batch_size, num_points)

    output = sigmoid(pointnet(points))

    error = criterion(output, labels)
    error.backward()

    optimizer.step()

    with torch.no_grad():
        # Threshold into a new tensor; ``>=`` ensures a probability of
        # exactly 0.5 is still mapped to a class instead of being left as
        # 0.5 and never matching any label.
        predictions = (output >= 0.5).float()
        accuracy = (predictions == labels).sum().item()/batch_size

    loss_list.append(error.item())
    accuracy_list.append(accuracy)

    if iteration % 10 == 0:
        print(f'Iteration: {iteration}    Loss: {error.item()}')
        print(f'Iteration: {iteration}    Accuracy: {accuracy}')

Iteration: 7490    Loss: 0.0060462504625320435
Iteration: 7490    Accuracy: 1.0
Iteration: 7500    Loss: 0.006015968043357134
Iteration: 7500    Accuracy: 1.0
Iteration: 7510    Loss: 0.005978365894407034
Iteration: 7510    Accuracy: 1.0
Iteration: 7520    Loss: 0.005948164034634829
Iteration: 7520    Accuracy: 1.0
Iteration: 7530    Loss: 0.00598184997215867
Iteration: 7530    Accuracy: 1.0
Iteration: 7540    Loss: 0.005906805396080017
Iteration: 7540    Accuracy: 1.0
Iteration: 7550    Loss: 0.00591204734519124
Iteration: 7550    Accuracy: 1.0
Iteration: 7560    Loss: 0.00590001605451107
Iteration: 7560    Accuracy: 1.0
Iteration: 7570    Loss: 0.005846172571182251
Iteration: 7570    Accuracy: 1.0
Iteration: 7580    Loss: 0.0058439914137125015
Iteration: 7580    Accuracy: 1.0
Iteration: 7590    Loss: 0.005805515684187412
Iteration: 7590    Accuracy: 1.0
Iteration: 7600    Loss: 0.0057623437605798244
Iteration: 7600    Accuracy: 1.0
Iteration: 7610    Loss: 0.005750629585236311
Iteration