# PointNet

## モデル構築用パッケージをインポート

In [16]:
import torch
import torch.nn as nn
import torch.optim as optim

## Architecture

In [17]:
class NonLinear(nn.Module):
    """Shared per-row fully connected layer: Linear -> ReLU -> BatchNorm1d.

    Maps a tensor of shape (N, input_channels) to (N, output_channels). In this
    notebook N is batch_size * num_points, so the same weights are applied to
    every point of every cloud.
    """

    def __init__(self, input_channels, output_channels):
        super().__init__()
        self.input_channels = input_channels
        self.output_channels = output_channels

        # Index layout (0: Linear, 1: ReLU, 2: BatchNorm1d) must stay fixed:
        # state_dict keys such as 'main.0.weight' / 'main.2.running_mean' use it,
        # and the Linear must be built first to keep the same random init order.
        layers = [
            nn.Linear(input_channels, output_channels),
            nn.ReLU(inplace=True),
            nn.BatchNorm1d(output_channels),
        ]
        self.main = nn.Sequential(*layers)

    def forward(self, input_data):
        features = self.main(input_data)
        return features

In [18]:
class MaxPool(nn.Module):
    """Symmetric max-pooling over the points of each cloud.

    Input is a 2-D tensor of shape (batch * num_points, num_channels) in which the
    num_points rows of one cloud are contiguous. Output is (batch, num_channels):
    for every cloud, the per-channel maximum over its points. The result does not
    depend on the order of points within a cloud.
    """

    def __init__(self, num_channels, num_points):
        super(MaxPool, self).__init__()
        self.num_channels = num_channels
        self.num_points = num_points
        # Kernel covers all points, so each channel collapses to a single value.
        self.main = nn.MaxPool1d(self.num_points)

    def forward(self, input_data):
        # Recover the cloud axis: (batch, num_points, num_channels).
        out = input_data.reshape(-1, self.num_points, self.num_channels)
        # MaxPool1d pools over the last axis, so move points there:
        # (batch, num_channels, num_points). Viewing the flat buffer directly as
        # (batch, num_channels, num_points) would instead mix channels of
        # neighbouring points and break permutation invariance.
        out = out.transpose(1, 2)
        out = self.main(out)  # (batch, num_channels, 1)
        return out.reshape(-1, self.num_channels)

In [19]:
class InputTNet(nn.Module):
    """Input transform network: predicts a 3x3 matrix per cloud and applies it.

    forward takes and returns tensors of shape (batch * num_points, 3).
    NOTE: the final Linear (self.main[6]) is randomly initialised here; this
    notebook later overwrites its bias with the flattened identity through the
    state_dict key 'main.0.main.6.bias'.
    """

    def __init__(self, num_points):
        super().__init__()
        self.num_points = num_points

        # Shared per-point MLP 3 -> 64 -> 128 -> 1024, MaxPool collapses each cloud
        # to one 1024-d vector, then an MLP regresses the 9 matrix entries.
        # Layer order is fixed: index 6 is addressed by state_dict key and the
        # construction order determines the random initialisation.
        self.main = nn.Sequential(
            NonLinear(3, 64),
            NonLinear(64, 128),
            NonLinear(128, 1024),
            MaxPool(1024, num_points),
            NonLinear(1024, 512),
            NonLinear(512, 256),
            nn.Linear(256, 9),
        )

    def forward(self, input_data):
        transform = self.main(input_data).view(-1, 3, 3)  # (batch, 3, 3)
        clouds = input_data.view(-1, self.num_points, 3)  # (batch, num_points, 3)
        # Row-vector convention: each point p is mapped to p @ transform.
        aligned = clouds @ transform
        return aligned.view(-1, 3)

In [20]:
class FeatureTNet(nn.Module):
    """Feature transform network: predicts a 64x64 matrix per cloud and applies it.

    forward takes and returns tensors of shape (batch * num_points, 64).
    NOTE: the final Linear (self.main[6]) is randomly initialised here; this
    notebook later overwrites its bias with the flattened 64x64 identity through
    the state_dict key 'main.3.main.6.bias'.
    """

    def __init__(self, num_points):
        super().__init__()
        self.num_points = num_points

        # Same layout as the input T-Net but on 64-d features, regressing the
        # 64 * 64 = 4096 matrix entries. Layer order is fixed for the state_dict
        # key above and for the random initialisation order.
        self.main = nn.Sequential(
            NonLinear(64, 64),
            NonLinear(64, 128),
            NonLinear(128, 1024),
            MaxPool(1024, num_points),
            NonLinear(1024, 512),
            NonLinear(512, 256),
            nn.Linear(256, 4096),
        )

    def forward(self, input_data):
        transform = self.main(input_data).view(-1, 64, 64)  # (batch, 64, 64)
        features = input_data.view(-1, self.num_points, 64)  # (batch, num_points, 64)
        # Row-vector convention: each feature vector f is mapped to f @ transform.
        aligned = features @ transform
        return aligned.view(-1, 64)

In [21]:
class PointNet(nn.Module):
    """PointNet-style classifier over fixed-size point clouds.

    forward takes a tensor of shape (batch * num_points, 3) and returns scores of
    shape (batch, num_labels). The training loop applies the sigmoid itself.
    """

    def __init__(self, num_points, num_labels):
        super().__init__()
        self.num_points = num_points
        self.num_labels = num_labels

        # Module order defines both the state_dict keys used elsewhere
        # ('main.0.*' = InputTNet, 'main.3.*' = FeatureTNet) and the order in
        # which parameters are randomly initialised, so it must not change.
        self.main = nn.Sequential(
            InputTNet(num_points),          # 0
            NonLinear(3, 64),               # 1
            NonLinear(64, 64),              # 2
            FeatureTNet(num_points),        # 3
            NonLinear(64, 64),              # 4
            NonLinear(64, 128),             # 5
            NonLinear(128, 1024),           # 6
            MaxPool(1024, num_points),      # 7: per-point -> per-cloud features
            NonLinear(1024, 512),           # 8
            nn.Dropout(p=0.3),              # 9
            NonLinear(512, 256),            # 10
            nn.Dropout(p=0.3),              # 11
            # NOTE(review): the head is a NonLinear, so the returned scores pass
            # through ReLU and BatchNorm1d; a plain nn.Linear is the usual choice
            # for a classifier head (changing it would alter state_dict keys).
            NonLinear(256, num_labels),     # 12
        )

    def forward(self, input_data):
        scores = self.main(input_data)
        return scores

## Sample Data

In [22]:
def data_sampler(batch_size, num_points):
    """Sample a shuffled toy batch of labelled 3-D point clouds.

    Roughly half the clouds are drawn from a standard normal (label 1.0); the
    rest are drawn uniformly from [0, 1) (label 0.0). For odd batch sizes the
    extra cloud is uniform.

    Args:
        batch_size: number of clouds in the batch.
        num_points: number of points per cloud.

    Returns:
        (points, labels): points has shape (batch_size * num_points, 3), with
        the rows of each cloud contiguous; labels has shape (batch_size, 1)
        (float). The same permutation is applied to clouds and labels.
    """
    # Split by floor/remainder so the two parts always sum to batch_size; the old
    # int(batch_size / 2) for both halves lost one cloud when batch_size was odd,
    # and randperm(batch_size) then indexed past the end.
    num_normal = batch_size // 2
    num_uniform = batch_size - num_normal

    normal_sampled = torch.randn(num_normal, num_points, 3)
    uniform_sampled = torch.rand(num_uniform, num_points, 3)
    normal_labels = torch.ones(num_normal)
    uniform_labels = torch.zeros(num_uniform)

    input_data = torch.cat((normal_sampled, uniform_sampled), dim=0)
    labels = torch.cat((normal_labels, uniform_labels), dim=0)

    data_shuffle = torch.randperm(batch_size)

    return input_data[data_shuffle].view(-1, 3), labels[data_shuffle].view(-1, 1)

In [23]:
# Experiment configuration: clouds per batch, points per cloud, output units
# (a single unit for the binary normal-vs-uniform task).
batch_size, num_points, num_labels = 64, 64, 1

In [24]:
pointnet = PointNet(num_points=num_points, num_labels=num_labels)
# Parameters and buffers keyed by module path; the printed entry is the bias of
# the InputTNet's final Linear (9 values for its 3x3 transform).
new_param = pointnet.state_dict()
print(new_param['main.0.main.6.bias'])

tensor([ 0.0225, -0.0600, -0.0314,  0.0354, -0.0047, -0.0578, -0.0469,  0.0073,
         0.0413])


In [25]:
# Dump every learnable tensor in registration order, one per line.
print(*pointnet.parameters(), sep="\n")

2e-03, -1.9187e-02,
         3.0759e-02,  2.2969e-02,  3.0003e-02,  2.9062e-02, -1.1159e-02,
        -1.7297e-02, -2.3087e-02, -1.2868e-02,  2.7561e-02, -1.7277e-02,
        -7.9646e-03, -1.5188e-02,  6.1635e-03, -2.4922e-02,  7.1074e-03,
        -2.5371e-02,  6.7718e-03, -1.3580e-02, -1.4029e-02,  2.7437e-02,
         1.7473e-03,  2.6656e-02, -1.6679e-03,  1.1729e-02, -2.4906e-02,
         1.6040e-02, -9.7106e-03,  2.2595e-02,  1.2265e-02,  1.0617e-02,
         1.6129e-02, -1.3345e-02,  3.6715e-04,  1.2542e-03,  2.2798e-02,
        -1.7539e-02,  2.3377e-02, -1.1164e-02, -2.0425e-02,  1.2335e-02,
         1.2354e-02,  3.0750e-03,  2.3295e-02,  2.5754e-02, -3.0740e-02,
        -1.6611e-02,  3.0707e-02,  5.3548e-04, -3.0337e-02, -1.9122e-02,
        -6.4766e-03,  6.2205e-04, -2.9238e-02, -2.9834e-02,  9.9700e-03,
         1.2641e-03, -1.9878e-02, -2.9550e-03, -8.6126e-03,  1.0351e-02,
         2.8127e-02, -7.6938e-03, -2.9751e-02,  1.9012e-02, -1.9640e-02,
        -1.3597e-02,  1.4651e-0

In [26]:
# Start both T-Nets near the identity transform: replace the bias of each T-Net's
# final Linear with a flattened identity matrix and load the state back. The
# Linear weights keep their random initialisation.
new_param.update({
    'main.0.main.6.bias': torch.eye(3, 3).view(-1),    # InputTNet, 3x3
    'main.3.main.6.bias': torch.eye(64, 64).view(-1),  # FeatureTNet, 64x64
})
pointnet.load_state_dict(new_param)

# BCELoss expects probabilities, so the training loop applies a sigmoid first.
criterion = nn.BCELoss()
optimizer = optim.Adam(pointnet.parameters(), lr=0.001)

# Per-iteration training history.
loss_list, accuracy_list = [], []

In [28]:
# Train for 10,001 iterations on freshly sampled batches, logging every 10 steps.
for iteration in range(10000 + 1):
    optimizer.zero_grad()

    input_data, labels = data_sampler(batch_size, num_points)

    # The model returns unbounded scores; BCELoss expects probabilities in [0, 1].
    output = torch.sigmoid(pointnet(input_data))

    error = criterion(output, labels)
    error.backward()

    optimizer.step()

    with torch.no_grad():
        # Threshold into a new tensor instead of overwriting `output` in place.
        # A single comparison assigns every probability to exactly one class; the
        # previous pair of masked writes left values equal to 0.5 unlabelled, so
        # they were always counted as wrong.
        predictions = (output > 0.5).float()
        accuracy = (predictions == labels).sum().item() / labels.size(0)

    loss_list.append(error.item())
    accuracy_list.append(accuracy)

    if iteration % 10 == 0:
        print('Iteration : {}   Loss : {}'.format(iteration, error.item()))
        print('Iteration : {}   Accuracy : {}'.format(iteration, accuracy))

Iteration : 0   Loss : 0.6997767686843872
Iteration : 0   Accuracy : 0.53125
Iteration : 10   Loss : 0.40883487462997437
Iteration : 10   Accuracy : 0.96875
Iteration : 20   Loss : 0.334423303604126
Iteration : 20   Accuracy : 1.0
Iteration : 30   Loss : 0.3193380832672119
Iteration : 30   Accuracy : 1.0
Iteration : 40   Loss : 0.31470927596092224
Iteration : 40   Accuracy : 1.0
Iteration : 50   Loss : 0.3165438771247864
Iteration : 50   Accuracy : 1.0
Iteration : 60   Loss : 0.30594080686569214
Iteration : 60   Accuracy : 1.0
Iteration : 70   Loss : 0.29978641867637634
Iteration : 70   Accuracy : 1.0
Iteration : 80   Loss : 0.2950426936149597
Iteration : 80   Accuracy : 1.0
Iteration : 90   Loss : 0.29417240619659424
Iteration : 90   Accuracy : 1.0
Iteration : 100   Loss : 0.29515373706817627
Iteration : 100   Accuracy : 1.0
Iteration : 110   Loss : 0.28817838430404663
Iteration : 110   Accuracy : 1.0
Iteration : 120   Loss : 0.2861238718032837
Iteration : 120   Accuracy : 1.0
Iterati

KeyboardInterrupt: 