<a href="https://colab.research.google.com/github/shnhrtkyk/JTCcode/blob/main/05_%E6%B7%B1%E5%B1%A4%E5%AD%A6%E7%BF%92%E3%81%AB%E3%82%88%E3%82%8B%E7%82%B9%E7%BE%A4%E3%81%AE%E5%88%86%E9%A1%9E.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# 点群の分類を行う深層学習手法

ここでは、入力された点群がどのようなクラスであるかを予測する点群の分類タスクを行います。例えば、ロボットの目の前にある物体を計測した点群がなんのクラスなのかがわかれば、ロボットの行動を最適に行えます。

## 点群に対する深層学習手法のライブラリインストール

このノートブックではPyTorchと、PyTorch向けの3次元点群処理用ライブラリであるPyTorch Geometricを用います。

In [None]:
# Install required packages.
import os
import torch
# Export the installed torch version as an environment variable so the shell
# commands below can expand ${TORCH} in the wheel index URL.
os.environ['TORCH'] = torch.__version__
print(torch.__version__)

!pip install -q torch-scatter -f https://data.pyg.org/whl/torch-${TORCH}.html
!pip install -q torch-sparse -f https://data.pyg.org/whl/torch-${TORCH}.html
!pip install -q torch-cluster -f https://data.pyg.org/whl/torch-${TORCH}.html
# PyTorch Geometric itself is installed from the GitHub repository.
!pip install -q git+https://github.com/pyg-team/pytorch_geometric.git
!pip install tensorboardX

# Helper functions for visualization.
%matplotlib inline
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D

# Visualization helpers.
def visualize_mesh(pos, face):
    """Render a triangle mesh as a 3D surface plot.

    Args:
        pos: vertex coordinates, indexed as ``pos[:, 0]``, ``pos[:, 1]``,
            ``pos[:, 2]`` for x, y and z.
        face: triangle vertex indices. It is transposed with ``.t()`` before
            being passed to ``plot_trisurf`` (assumes a (3, num_faces)
            layout -- TODO confirm against the caller).
    """
    fig = plt.figure()
    ax = fig.add_subplot(projection='3d')
    # Hide the tick labels on all three axes.
    ax.axes.xaxis.set_ticklabels([])
    ax.axes.yaxis.set_ticklabels([])
    ax.axes.zaxis.set_ticklabels([])
    # Use the `face` argument; the previous code read an unrelated global
    # `data.face`, so the parameter was ignored (or a NameError was raised).
    ax.plot_trisurf(pos[:, 0], pos[:, 1], pos[:, 2], triangles=face.t(), antialiased=False)
    plt.show()


def visualize_points(pos, edge_index=None, index=None):
    """Scatter-plot the first two coordinates of ``pos``.

    Args:
        pos: point coordinates; only columns 0 and 1 are drawn.
        edge_index: optional index pairs; after ``.t()`` each row is a
            (source, destination) pair drawn as a black line segment.
        index: optional indices of points to highlight. When given, all
            other points are drawn in light gray.
    """
    plt.figure(figsize=(4, 4))

    if edge_index is not None:
        for src_id, dst_id in edge_index.t().tolist():
            start = pos[src_id].tolist()
            end = pos[dst_id].tolist()
            plt.plot([start[0], end[0]], [start[1], end[1]], linewidth=1, color='black')

    if index is not None:
        # Boolean mask that is True only for the highlighted points.
        selected = torch.zeros(pos.size(0), dtype=torch.bool)
        selected[index] = True
        plt.scatter(pos[~selected, 0], pos[~selected, 1], s=50, color='lightgray', zorder=1000)
        plt.scatter(pos[selected, 0], pos[selected, 1], s=50, zorder=1000)
    else:
        plt.scatter(pos[:, 0], pos[:, 1], s=50, zorder=1000)

    plt.axis('off')
    plt.show()

## データをダウンロード

点群の分類タスク向けの公開データセットであるModelNetを用います。ModelNetは家具や家電などの3Dモデルから仮想的に作成した点群のデータで、点群とそれに対応するクラス情報が紐づいています。ModelNetには分類するクラス数が違うデータセットがあり、ここでは簡単な10クラス分類のデータセットを用います。
※ダウンロードに時間がかかります。

In [None]:
from pathlib import Path

from torch_geometric.datasets import ModelNet
import torch_geometric.transforms as T

# The dataset is stored in a "modelnet10" folder under the working directory.
current_path = Path.cwd()
dataset_dir = current_path / "modelnet10"

# Passed as `pre_transform`: sample 1024 points (with normals, dropping the
# faces) and then normalize the scale.
pre_transform = T.Compose([
    T.SamplePoints(1024, remove_faces=True, include_normals=True),
    T.NormalizeScale(),
])

# `name` selects which ModelNet variant to use; both splits share all other settings.
modelnet_kwargs = dict(name="10", transform=None, pre_transform=pre_transform, pre_filter=None)
train_dataset = ModelNet(dataset_dir, train=True, **modelnet_kwargs)
test_dataset = ModelNet(dataset_dir, train=False, **modelnet_kwargs)

分類するクラスには、以下のようにIDが振られています。

In [None]:
# Build a {class folder name: ID} mapping by enumerating the sub-directories
# of the raw download in sorted order. The path is derived from `dataset_dir`
# (defined in the download cell) rather than hard-coding "/content/...", so
# it stays correct outside Colab.
path = dataset_dir / "raw"
folders = sorted(entry.name for entry in path.iterdir() if entry.is_dir())
classes = {folder: i for i, folder in enumerate(folders)}
classes

### データの確認

データセットとして読み込んだModelNetから点群を取得して可視化します。
今回は訓練データの1000番目の点群を可視化してみます。

In [None]:
# Take training sample number 1000 and print its label.
tmp =  train_dataset[1000]
print(tmp.y)
# Split the coordinate matrix into its three columns.
point = tmp.pos
x = point[:, 0]
y = point[:, 1]
z = point[:, 2]

import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import proj3d

# Draw the points as a 3D scatter plot.
fig = plt.figure(figsize=(8, 8))
ax = fig.add_subplot(111, projection='3d')

ax.scatter(x, y, z)
plt.show()

訓練データの数を確認します。

In [None]:
# Number of training samples, followed by a summary of the first sample.
print("train_dataset len:", len(train_dataset))
print(train_dataset[0])

Pytorch Geometricで読み込んだ点群データの扱い方について説明します。
`.pos`がXYZ座標を保持しています。 N点×XYZ座標の行列です。


In [None]:
# Shape and contents of the coordinate matrix of the first training sample.
print(train_dataset[0].pos.shape)
print(train_dataset[0].pos)

次に、PytorchGeometricのデータ読み込みについて説明します。
深層学習のモデルを学習する際には、バッチと呼ばれる小さなデータ単位で処理を行います。今回はバッチサイズを32として、一回の学習に32個の点群を呼び出します。



In [None]:
# PyG 2.x provides DataLoader in `torch_geometric.loader`; importing it from
# `torch_geometric.data` is the deprecated location.
from torch_geometric.loader import DataLoader
dataloader = DataLoader(train_dataset, batch_size=32, shuffle=False)
# Fetch the first mini-batch and inspect it.
batch = next(iter(dataloader))
print(batch)
print(batch.pos)  # coordinates
print(batch.y)  # ground-truth labels

## PointNetの層設計

## 層設計

### シンメトリック関数

まず、入力された点群の順番に対する依存を抜くためのシンメトリック関数を実装します。ここでは、シンメトリック関数としてMaxPoolingを使用します。

In [None]:
from torch_geometric.nn import global_max_pool
import torch.nn as nn

class SymmFunction(nn.Module):
    """Per-point MLP followed by max pooling over each point cloud.

    Max pooling is used as the symmetric function, so the pooled feature does
    not depend on the order of the points within a cloud.
    """

    def __init__(self):
        super().__init__()
        layers = [
            nn.Linear(3, 64), nn.BatchNorm1d(64), nn.ReLU(),
            nn.Linear(64, 128), nn.BatchNorm1d(128), nn.ReLU(),
            nn.Linear(128, 512),
        ]
        self.shared_mlp = nn.Sequential(*layers)

    def forward(self, batch):
        """Encode ``batch.pos`` point-wise, then pool using ``batch.batch``."""
        point_features = self.shared_mlp(batch.pos)
        return global_max_pool(point_features, batch.batch)

# Smoke test: run the mini-batch loaded earlier through the module.
f = SymmFunction()
print(batch)
y = f(batch)
# The last layer of the MLP has 512 outputs, so the feature size is 512;
# presumably one row per point cloud in the batch -- TODO confirm.
print(y.shape)

### TNet

TNetは、入力された点群に加わった剛体変換の影響を吸収するために導入された関数です。


In [None]:
class InputTNet(nn.Module):
    """Predict one 3x3 transform per point cloud from the raw coordinates."""

    def __init__(self):
        super().__init__()
        # Applied to every point independently.
        self.input_mlp = nn.Sequential(
            nn.Linear(3, 64), nn.BatchNorm1d(64), nn.ReLU(),
            nn.Linear(64, 128), nn.BatchNorm1d(128), nn.ReLU(),
            nn.Linear(128, 1024), nn.BatchNorm1d(1024), nn.ReLU(),
        )
        # Applied to the pooled feature; the 9 outputs are reshaped to 3x3.
        self.output_mlp = nn.Sequential(
            nn.Linear(1024, 512), nn.BatchNorm1d(512), nn.ReLU(),
            nn.Linear(512, 256), nn.BatchNorm1d(256), nn.ReLU(),
            nn.Linear(256, 9)
        )

    def forward(self, x, batch):
        """Return ``I + offset`` with shape (-1, 3, 3).

        Args:
            x: per-point input with 3 features.
            batch: index passed to ``global_max_pool`` to group the points.
        """
        per_point = self.input_mlp(x)
        pooled = global_max_pool(per_point, batch)
        offsets = self.output_mlp(pooled).view(-1, 3, 3)
        # Adding the identity means a zero prediction yields the identity transform.
        identity = torch.eye(3).to(offsets.device).view(1, 3, 3)
        return identity + offsets

### TNet(特徴量)

TNetは点群の座標だけでなく特徴量空間にも適用することで、剛体変換の影響を低減します。


In [None]:
class FeatureTNet(nn.Module):
    """Predict one 64x64 transform per point cloud from 64-dim point features."""

    def __init__(self):
        super().__init__()
        # Applied to every point independently.
        self.input_mlp = nn.Sequential(
            nn.Linear(64, 64), nn.BatchNorm1d(64), nn.ReLU(),
            nn.Linear(64, 128), nn.BatchNorm1d(128), nn.ReLU(),
            nn.Linear(128, 1024), nn.BatchNorm1d(1024), nn.ReLU(),
        )
        # Applied to the pooled feature; the 64*64 outputs are reshaped to 64x64.
        self.output_mlp = nn.Sequential(
            nn.Linear(1024, 512), nn.BatchNorm1d(512), nn.ReLU(),
            nn.Linear(512, 256), nn.BatchNorm1d(256), nn.ReLU(),
            nn.Linear(256, 64*64)
        )

    def forward(self, x, batch):
        """Return ``I + offset`` with shape (-1, 64, 64).

        Args:
            x: per-point input with 64 features.
            batch: index passed to ``global_max_pool`` to group the points.
        """
        per_point = self.input_mlp(x)
        pooled = global_max_pool(per_point, batch)
        offsets = self.output_mlp(pooled).view(-1, 64, 64)
        # Adding the identity means a zero prediction yields the identity transform.
        identity = torch.eye(64).to(offsets.device).view(1, 64, 64)
        return identity + offsets

### PointNetの本体

PointNetは非常に単純な構造を持っています。
入力された点群の各点に対して、MLPを適用し、その後にシンメトリック関数であるマックスプーリングによって入力された点群の順番に関する依存を抜きます。
また、TNetを入力された点群に適用し、さらに特徴量空間でも適用します。

In [None]:
class PointNetClassification(nn.Module):
    """PointNet classifier with input and feature T-Nets (10 output classes)."""

    def __init__(self):
        super().__init__()
        # T-Net on the raw coordinates.
        self.input_tnet = InputTNet()
        # Per-point MLP: 3 -> 64 -> 64.
        self.mlp1 = nn.Sequential(
            nn.Linear(3, 64), nn.BatchNorm1d(64), nn.ReLU(),
            nn.Linear(64, 64), nn.BatchNorm1d(64), nn.ReLU(),
        )
        # T-Net on the 64-dim point features.
        self.feature_tnet = FeatureTNet()
        # Per-point MLP: 64 -> 64 -> 128 -> 1024.
        self.mlp2 = nn.Sequential(
            nn.Linear(64, 64), nn.BatchNorm1d(64), nn.ReLU(),
            nn.Linear(64, 128), nn.BatchNorm1d(128), nn.ReLU(),
            nn.Linear(128, 1024), nn.BatchNorm1d(1024), nn.ReLU(),
        )
        # Classification head, applied to the pooled feature.
        self.mlp3 = nn.Sequential(
            nn.Linear(1024, 512), nn.BatchNorm1d(512), nn.ReLU(), nn.Dropout(p=0.3),
            nn.Linear(512, 256), nn.BatchNorm1d(256), nn.ReLU(), nn.Dropout(p=0.3),
            nn.Linear(256, 10)
        )

    def forward(self, batch_data):
        """Classify every point cloud in ``batch_data``.

        Returns:
            A tuple ``(class scores, input transform, feature transform)``.
        """
        cloud_of_point = batch_data.batch
        points = batch_data.pos

        # Transform the coordinates with the matrix predicted for their cloud.
        input_transform = self.input_tnet(points, cloud_of_point)
        per_point_matrix = input_transform[cloud_of_point]
        points = torch.bmm(per_point_matrix, points.unsqueeze(-1)).squeeze(-1)

        features = self.mlp1(points)

        # Same idea in feature space.
        feature_transform = self.feature_tnet(features, cloud_of_point)
        per_point_matrix = feature_transform[cloud_of_point]
        features = torch.bmm(per_point_matrix, features.unsqueeze(-1)).squeeze(-1)

        features = self.mlp2(features)

        # Max pooling is the symmetric function that removes the point-order dependence.
        pooled = global_max_pool(features, cloud_of_point)
        scores = self.mlp3(pooled)

        return scores, input_transform, feature_transform

## PointNetの訓練

### モデルの呼び出しなど

In [None]:
import torch
from torch.utils.tensorboard import SummaryWriter

num_epoch = 3
batch_size = 32

# Use the first GPU when one is available, otherwise the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
model = PointNetClassification()
model = model.to(device)

optimizer = torch.optim.Adam(lr=1e-4, params=model.parameters())
# `num_epoch // 4` is 0 when num_epoch < 4 (it is 3 here); StepLR takes the
# epoch count modulo step_size, so a zero step_size raises ZeroDivisionError
# on the first scheduler.step(). Clamp it to at least 1.
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=max(1, num_epoch // 4), gamma=0.5)

# TensorBoard logs and model checkpoints go to this directory.
log_dir = current_path / "log_modelnet10_classification_pointnet"
log_dir.mkdir(exist_ok=True)
writer = SummaryWriter(log_dir=log_dir)

train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

criteria = torch.nn.CrossEntropyLoss()

### 訓練コード

In [None]:
from tqdm import tqdm

# Weight of the feature-transform regularizer in the total loss.
REG_LOSS_WEIGHT = 0.001


def _feature_transform_regularizer(feature_transform):
    """Return the batch mean of the Frobenius norm of ``A @ A^T - I``."""
    # Take the size from the tensor itself instead of hard-coding 64.
    dim = feature_transform.shape[1]
    identity = torch.eye(dim, device=feature_transform.device, dtype=feature_transform.dtype)
    gram = torch.bmm(feature_transform, feature_transform.transpose(1, 2))
    return torch.norm(gram - identity, dim=(1, 2)).mean()


def _run_pointnet_epoch(model, dataloader, criteria, device, optimizer=None):
    """Run one pass over ``dataloader`` and return sample-weighted mean metrics.

    Args:
        model: callable returning ``(scores, input_transform, feature_transform)``.
        dataloader: iterable of batches exposing ``.to()``, ``.batch`` and ``.y``.
        criteria: classification loss taking ``(scores, labels)``.
        device: device each batch is moved to.
        optimizer: when given, the model is updated after every batch;
            when ``None`` the pass only measures the metrics.

    Returns:
        dict with the keys ``loss``, ``class_loss``, ``reg_loss`` and
        ``accuracy``, each averaged over all samples seen.
    """
    sums = {"loss": 0.0, "class_loss": 0.0, "reg_loss": 0.0, "accuracy": 0.0}
    seen = 0
    for batch_data in tqdm(dataloader, total=len(dataloader)):
        batch_data = batch_data.to(device)
        # Number of point clouds in this batch (largest batch index + 1).
        this_batch_size = int(batch_data.batch.max()) + 1

        pred_y, _, feature_transform = model(batch_data)
        true_y = batch_data.y

        class_loss = criteria(pred_y, true_y)
        reg_loss = _feature_transform_regularizer(feature_transform)
        loss = class_loss + reg_loss * REG_LOSS_WEIGHT

        if optimizer is not None:
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        seen += this_batch_size
        sums["loss"] += loss.item() * this_batch_size
        sums["class_loss"] += class_loss.item() * this_batch_size
        sums["reg_loss"] += reg_loss.item() * this_batch_size
        # Number of correct predictions; dividing by `seen` below gives accuracy.
        sums["accuracy"] += int((pred_y.argmax(dim=1) == true_y).sum())
    return {name: total / seen for name, total in sums.items()}


for epoch in tqdm(range(num_epoch)):
    model.train()
    train_metrics = _run_pointnet_epoch(model, train_dataloader, criteria, device, optimizer)
    # NOTE(review): the scheduler created in the setup cell is never stepped;
    # the original had the call commented out, so the learning rate stays constant.

    # Save a checkpoint after every epoch.
    model_path = log_dir / f"model_{epoch:06}.pth"
    torch.save(model.state_dict(), model_path)

    for name, value in train_metrics.items():
        writer.add_scalar(f"train_epoch/{name}", value, epoch)

    model.eval()
    with torch.no_grad():
        test_metrics = _run_pointnet_epoch(model, test_dataloader, criteria, device)
    for name, value in test_metrics.items():
        writer.add_scalar(f"test_epoch/{name}", value, epoch)

## PointNet++の層設計

ここからは、深層学習モデルの構築に移ります。PytorchGeometricを用いて実際にPointNet++を実装してみます。

### サブサンプリング関数

PointNet++では、サブサンプリングを階層的に行いながら特徴量の算出を行います。
Pytorch Geometricでは、最遠点サンプリング（Farthest Point Sampling; FPS）が関数として定義されています。
PointNet++では、FPSによるサブサンプリング点の選択、選択された点の周辺にある点の探索、周辺点の特徴量の選択された点への集約、という流れで処理を行います。

In [None]:
from torch_geometric.nn import MLP, PointNetConv, fps, global_max_pool, radius
class SAModule(torch.nn.Module):
    """Set-abstraction layer: FPS sub-sampling, radius grouping, PointNetConv."""

    def __init__(self, ratio, r, nn):
        super().__init__()
        self.ratio = ratio  # fraction of points kept by sub-sampling
        self.r = r  # radius used when grouping neighbours
        self.conv = PointNetConv(nn, add_self_loops=False)  # feature extractor

    def forward(self, x, pos, batch):
        """Sub-sample the points and aggregate neighbour features onto them.

        Returns:
            ``(features, positions, batch index)`` of the sampled points.
        """
        # Pick the representative points with farthest point sampling.
        centroid_idx = fps(pos, batch, ratio=self.ratio)
        centroid_pos = pos[centroid_idx]
        centroid_batch = batch[centroid_idx]

        # Find the points within radius r of each representative point.
        row, col = radius(pos, centroid_pos, self.r, batch, centroid_batch,
                          max_num_neighbors=64)
        # NOTE(review): presumably `row` indexes the representative points and
        # `col` indexes `pos` -- confirm against the PyG `radius` documentation.
        edge_index = torch.stack([col, row], dim=0)

        x_dst = x[centroid_idx] if x is not None else None
        out = self.conv((x, x_dst), (pos, centroid_pos), edge_index)
        return out, centroid_pos, centroid_batch


### グローバルサブサンプリング

グローバルサンプリングでは、各点で算出した特徴量を、ひとつのベクトルへ集約することでクラス分類が実行できます。

In [None]:
# Defined as a class.
class GlobalSAModule(torch.nn.Module):
    """Collapse all points of each cloud into a single global feature vector."""

    def __init__(self, nn):
        super().__init__()
        self.nn = nn  # feature-extraction layer supplied by the caller

    def forward(self, x, pos, batch):
        """Apply ``self.nn`` to ``[x, pos]`` and max-pool per ``batch`` index.

        Returns:
            ``(pooled features, zero positions with 3 columns, new batch index)``.
        """
        features = self.nn(torch.cat([x, pos], dim=1))
        pooled = global_max_pool(features, batch)
        num_rows = pooled.size(0)
        # One all-zero position and one fresh batch index per pooled row.
        zero_pos = pos.new_zeros((num_rows, 3))
        new_batch = torch.arange(num_rows, device=batch.device)
        return pooled, zero_pos, new_batch


### PointNet++の本体

上記で定義したサブサンプリングを繰り返し適用します。  
SAModuleの引数には、サブサンプリングする際の割合、グルーピングの半径、特徴抽出を行う際のチャンネル数を渡します。
最初の層への入力は点群のXYZ座標のみなので、MLPの入力チャンネル数は3から始まります。


In [None]:
class PointNet2Classification(torch.nn.Module):
    """PointNet++ classifier: two set-abstraction stages, a global stage, an MLP head."""

    def __init__(self):
        super().__init__()

        # Feature extraction.
        self.sa1_module = SAModule(0.5, 0.2, MLP([3, 64, 64, 128]))  # 1st sub-sampling
        self.sa2_module = SAModule(0.25, 0.4, MLP([128 + 3, 128, 128, 256]))  # 2nd sub-sampling
        self.sa3_module = GlobalSAModule(MLP([256 + 3, 256, 512, 1024]))  # global feature

        # Plain MLP classification head with 10 outputs.
        self.mlp = MLP([1024, 512, 256, 10], dropout=0.5, norm=None)

    def forward(self, data):
        """Return log-softmax class scores for ``data``."""
        x, pos, batch = data.x, data.pos, data.batch
        # Each stage consumes and produces an (x, pos, batch) triple.
        for stage in (self.sa1_module, self.sa2_module, self.sa3_module):
            x, pos, batch = stage(x, pos, batch)
        return self.mlp(x).log_softmax(dim=-1)


## PointNet++の訓練

### モデルの呼び出しなど

In [None]:
import torch
from torch.utils.tensorboard import SummaryWriter

num_epoch = 3  # train for 3 epochs this time
batch_size = 32

# Use the first GPU when one is available, otherwise the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
# Instantiate the model.
model = PointNet2Classification()
model = model.to(device)
# Optimizer and learning-rate schedule.
optimizer = torch.optim.Adam(lr=1e-4, params=model.parameters())
# `num_epoch // 4` is 0 when num_epoch < 4 (it is 3 here); StepLR takes the
# epoch count modulo step_size, so a zero step_size raises ZeroDivisionError
# on the first scheduler.step(). Clamp it to at least 1.
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=max(1, num_epoch // 4), gamma=0.5)
# TensorBoard logs and model checkpoints go to this directory.
log_dir = current_path / "log_modelnet10_classification_pointnet2"
log_dir.mkdir(exist_ok=True)
writer = SummaryWriter(log_dir=log_dir)
# Data loaders.
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
# Cross-entropy loss.
criteria = torch.nn.CrossEntropyLoss()

### 訓練コード

In [None]:
from tqdm import tqdm

def _run_pointnet2_epoch(model, dataloader, criteria, device, optimizer=None):
    """Run one pass over ``dataloader`` and return sample-weighted mean metrics.

    Args:
        model: callable returning class scores for a batch.
        dataloader: iterable of batches exposing ``.to()``, ``.batch`` and ``.y``.
        criteria: classification loss taking ``(scores, labels)``.
        device: device each batch is moved to.
        optimizer: when given, the model is updated after every batch;
            when ``None`` the pass only measures the metrics.

    Returns:
        dict with the keys ``loss``, ``class_loss`` and ``accuracy``, each
        averaged over all samples seen.
    """
    sums = {"loss": 0.0, "class_loss": 0.0, "accuracy": 0.0}
    seen = 0
    for batch_data in tqdm(dataloader, total=len(dataloader)):
        batch_data = batch_data.to(device)
        # Number of point clouds in this batch (largest batch index + 1).
        this_batch_size = int(batch_data.batch.max()) + 1

        pred_y = model(batch_data)
        true_y = batch_data.y

        class_loss = criteria(pred_y, true_y)
        # The classification loss is the only loss term for this model.
        loss = class_loss

        if optimizer is not None:
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        seen += this_batch_size
        sums["loss"] += loss.item() * this_batch_size
        sums["class_loss"] += class_loss.item() * this_batch_size
        # Number of correct predictions; dividing by `seen` below gives accuracy.
        sums["accuracy"] += int((pred_y.argmax(dim=1) == true_y).sum())
    return {name: total / seen for name, total in sums.items()}


for epoch in tqdm(range(num_epoch)):
    model.train()
    train_metrics = _run_pointnet2_epoch(model, train_dataloader, criteria, device, optimizer)
    # NOTE(review): the scheduler created in the setup cell is never stepped;
    # the original had the call commented out, so the learning rate stays constant.

    # Save a checkpoint after every epoch.
    model_path = log_dir / f"model_{epoch:06}.pth"
    torch.save(model.state_dict(), model_path)

    for name, value in train_metrics.items():
        writer.add_scalar(f"train_epoch/{name}", value, epoch)

    model.eval()
    with torch.no_grad():
        test_metrics = _run_pointnet2_epoch(model, test_dataloader, criteria, device)
    for name, value in test_metrics.items():
        writer.add_scalar(f"test_epoch/{name}", value, epoch)

## 結果の比較

PointNetの結果

In [None]:
%load_ext tensorboard
 
 
# Launch TensorBoard (--logdir selects the log directory to display).
%tensorboard --logdir="/content/log_modelnet10_classification_pointnet" --port 6006

PointNet++の結果

In [None]:
%load_ext tensorboard
 
 
# Launch TensorBoard (--logdir selects the log directory to display).
%tensorboard --logdir="/content/log_modelnet10_classification_pointnet2"  --port 6007

In [None]:
# Show the lsof network entries that mention 6006, the port given to the first %tensorboard call.
!lsof -i -P | grep 6006

In [None]:
# NOTE(review): no PID is given, so this command has nothing to kill as written;
# append the PID reported by the lsof cell above before running it.
!kill -9 

### PointNet++をライブラリを使わずに書く

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from time import time
import numpy as np

def timeit(tag, t):
    """Print the seconds elapsed since ``t`` under ``tag`` and return the current time."""
    elapsed = time() - t
    print(f"{tag}: {elapsed}s")
    return time()

def pc_normalize(pc):
    """Center a point cloud at its centroid and scale it to unit radius.

    Args:
        pc: NumPy array with one point per row.

    Returns:
        Array of the same shape whose centroid is at the origin and whose
        farthest point lies at distance 1. If all points coincide, the
        radius is 0 and the centered (all-zero) cloud is returned as is,
        instead of dividing by zero and producing NaN.
    """
    centroid = np.mean(pc, axis=0)
    pc = pc - centroid
    # Largest distance from the centroid to any point.
    m = np.max(np.sqrt(np.sum(pc ** 2, axis=1)))
    if m > 0:
        pc = pc / m
    return pc
# Compute distances.
def square_distance(src, dst):
    """
    Calculate the squared Euclidean distance between every pair of points.

    Uses the expansion
        |s - d|^2 = sum(s**2, dim=-1) + sum(d**2, dim=-1) - 2 * s^T * d

    Input:
        src: source points, [B, N, C]
        dst: target points, [B, M, C]
    Output:
        dist: per-point square distance, [B, N, M]
    """
    num_batches, num_src, _ = src.shape
    _, num_dst, _ = dst.shape
    cross = torch.matmul(src, dst.permute(0, 2, 1))
    src_sq = torch.sum(src ** 2, -1).view(num_batches, num_src, 1)
    dst_sq = torch.sum(dst ** 2, -1).view(num_batches, 1, num_dst)
    return -2 * cross + src_sq + dst_sq


def index_points(points, idx):
    """
    Gather points by index, independently for every batch element.

    Input:
        points: input points data, [B, N, C]
        idx: sample index data, [B, S] (extra trailing index dimensions are kept)
    Return:
        new_points: indexed points data, [B, S, C]
    """
    # Batch ids shaped [B, 1, ..., 1] so they broadcast against `idx`.
    batch_shape = [idx.shape[0]] + [1] * (idx.dim() - 1)
    batch_ids = torch.arange(points.shape[0], dtype=torch.long).to(points.device).view(batch_shape)
    return points[batch_ids, idx, :]

# Sub-sampling function.
def farthest_point_sample(xyz, npoint):
    """
    Iterative farthest point sampling starting from a random point.

    Input:
        xyz: pointcloud data, [B, N, 3]
        npoint: number of samples
    Return:
        centroids: sampled pointcloud index, [B, npoint]
    """
    device = xyz.device
    B, N, C = xyz.shape
    centroids = torch.zeros(B, npoint, dtype=torch.long, device=device)
    # Squared distance from every point to its nearest already-chosen sample.
    nearest_sq = torch.full((B, N), 1e10, device=device)
    # The first sample of every batch element is drawn at random.
    current = torch.randint(0, N, (B,), dtype=torch.long).to(device)
    batch_ids = torch.arange(B, dtype=torch.long, device=device)
    for step in range(npoint):
        centroids[:, step] = current
        chosen = xyz[batch_ids, current, :].view(B, 1, 3)
        sq_to_chosen = torch.sum((xyz - chosen) ** 2, -1)
        nearest_sq = torch.where(sq_to_chosen < nearest_sq, sq_to_chosen, nearest_sq)
        # Next sample: the point farthest from everything chosen so far.
        current = torch.max(nearest_sq, -1)[1]
    return centroids

# Grouping function.
def query_ball_point(radius, nsample, xyz, new_xyz):
    """
    For every query point, collect up to ``nsample`` point indices within ``radius``.

    Groups with fewer than ``nsample`` points in range are filled by repeating
    the group's first index, so the result always has ``nsample`` columns.
    This now also holds when the cloud has fewer than ``nsample`` points
    (N < nsample), which previously raised a mask shape-mismatch error.

    Input:
        radius: local region radius
        nsample: max sample number in local region
        xyz: all points, [B, N, 3]
        new_xyz: query points, [B, S, 3]
    Return:
        group_idx: grouped points index, [B, S, nsample]
    """
    device = xyz.device
    B, N, C = xyz.shape
    _, S, _ = new_xyz.shape
    group_idx = torch.arange(N, dtype=torch.long).to(device).view(1, 1, N).repeat([B, S, 1])
    sqrdists = square_distance(new_xyz, xyz)
    # Mark out-of-range points with the sentinel N so they sort to the end.
    group_idx = group_idx.masked_fill(sqrdists > radius ** 2, N)
    group_idx = group_idx.sort(dim=-1)[0][:, :, :nsample]
    missing = nsample - group_idx.shape[-1]
    if missing > 0:
        # Fewer than nsample points exist at all: pad with the sentinel.
        padding = group_idx.new_full((B, S, missing), N)
        group_idx = torch.cat([group_idx, padding], dim=-1)
    # Replace every sentinel with the first index of its group.
    # NOTE(review): if a query point has no point within the radius, the first
    # index is itself the sentinel N and stays in the output, as before.
    group_first = group_idx[:, :, :1]
    group_idx = torch.where(group_idx == N, group_first, group_idx)
    return group_idx

# FPS and grouping combined.
def sample_and_group(npoint, radius, nsample, xyz, points, returnfps=False):
    """
    Sample ``npoint`` centroids with FPS and group their neighbours.

    Input:
        npoint: number of centroids to sample
        radius: grouping radius
        nsample: max number of points per group
        xyz: input points position data, [B, N, 3]
        points: input points data, [B, N, D] (may be None)
        returnfps: also return the grouped positions and the FPS indices
    Return:
        new_xyz: sampled points position data, [B, npoint, 3]
        new_points: sampled points data, [B, npoint, nsample, 3+D]
    """
    B, _, C = xyz.shape
    fps_idx = farthest_point_sample(xyz, npoint)  # [B, npoint]
    new_xyz = index_points(xyz, fps_idx)
    neighbor_idx = query_ball_point(radius, nsample, xyz, new_xyz)
    grouped_xyz = index_points(xyz, neighbor_idx)  # [B, npoint, nsample, C]
    # Express every neighbour relative to its centroid.
    local_xyz = grouped_xyz - new_xyz.view(B, npoint, 1, C)

    if points is None:
        new_points = local_xyz
    else:
        grouped_points = index_points(points, neighbor_idx)
        new_points = torch.cat([local_xyz, grouped_points], dim=-1)  # [B, npoint, nsample, C+D]

    if returnfps:
        return new_xyz, new_points, grouped_xyz, fps_idx
    return new_xyz, new_points

# Grouping variant that puts all points into a single group.
def sample_and_group_all(xyz, points):
    """
    Treat the whole cloud as one group centred at the origin.

    Input:
        xyz: input points position data, [B, N, 3]
        points: input points data, [B, N, D] (may be None)
    Return:
        new_xyz: sampled points position data, [B, 1, 3] (all zeros)
        new_points: sampled points data, [B, 1, N, 3+D]
    """
    B, N, C = xyz.shape
    new_xyz = torch.zeros(B, 1, C).to(xyz.device)
    grouped_xyz = xyz.view(B, 1, N, C)
    if points is None:
        return new_xyz, grouped_xyz
    grouped_points = points.view(B, 1, N, -1)
    return new_xyz, torch.cat([grouped_xyz, grouped_points], dim=-1)

# Feature-extraction class.
class PointNetSetAbstraction(nn.Module):
    """Set-abstraction layer: sample, group, shared MLP (1x1 conv), max pool."""

    def __init__(self, npoint, radius, nsample, in_channel, mlp, group_all):
        super().__init__()
        self.npoint = npoint
        self.radius = radius
        self.nsample = nsample
        self.group_all = group_all
        self.mlp_convs = nn.ModuleList()
        self.mlp_bns = nn.ModuleList()
        # Chain the layers: in_channel -> mlp[0] -> mlp[1] -> ...
        widths = [in_channel, *mlp]
        for c_in, c_out in zip(widths[:-1], widths[1:]):
            self.mlp_convs.append(nn.Conv2d(c_in, c_out, 1))
            self.mlp_bns.append(nn.BatchNorm2d(c_out))

    def forward(self, xyz, points):
        """
        Input:
            xyz: input points position data, [B, C, N]
            points: input points data, [B, D, N]
        Return:
            new_xyz: sampled points position data, [B, C, S]
            new_points_concat: sample points feature data, [B, D', S]
        """
        # The sampling helpers expect channels-last tensors.
        xyz = xyz.permute(0, 2, 1)
        if points is not None:
            points = points.permute(0, 2, 1)

        # Sub-sampling and grouping.
        if self.group_all:
            new_xyz, grouped = sample_and_group_all(xyz, points)
        else:
            new_xyz, grouped = sample_and_group(self.npoint, self.radius, self.nsample, xyz, points)
        # new_xyz: [B, npoint, C]; grouped: [B, npoint, nsample, C+D]

        grouped = grouped.permute(0, 3, 2, 1)  # [B, C+D, nsample, npoint]
        # Apply the shared MLP.
        for conv, bn in zip(self.mlp_convs, self.mlp_bns):
            grouped = F.relu(bn(conv(grouped)))

        # Max pooling over the samples of each group picks the representative value.
        pooled = torch.max(grouped, 2)[0]
        return new_xyz.permute(0, 2, 1), pooled


class PointNetSetAbstractionMsg(nn.Module):
    """Set-abstraction layer that groups at several radii and concatenates the results."""

    def __init__(self, npoint, radius_list, nsample_list, in_channel, mlp_list):
        super().__init__()
        self.npoint = npoint
        self.radius_list = radius_list
        self.nsample_list = nsample_list
        self.conv_blocks = nn.ModuleList()
        self.bn_blocks = nn.ModuleList()
        for widths in mlp_list:
            convs = nn.ModuleList()
            bns = nn.ModuleList()
            # The relative xyz offsets are concatenated, hence the extra 3 channels.
            c_in = in_channel + 3
            for c_out in widths:
                convs.append(nn.Conv2d(c_in, c_out, 1))
                bns.append(nn.BatchNorm2d(c_out))
                c_in = c_out
            self.conv_blocks.append(convs)
            self.bn_blocks.append(bns)

    def forward(self, xyz, points):
        """
        Input:
            xyz: input points position data, [B, C, N]
            points: input points data, [B, D, N]
        Return:
            new_xyz: sampled points position data, [B, C, S]
            new_points_concat: sample points feature data, [B, D', S]
        """
        # The sampling helpers expect channels-last tensors.
        xyz = xyz.permute(0, 2, 1)
        if points is not None:
            points = points.permute(0, 2, 1)

        B, _, C = xyz.shape
        S = self.npoint
        new_xyz = index_points(xyz, farthest_point_sample(xyz, S))

        per_scale = []
        for scale, (radius, K) in enumerate(zip(self.radius_list, self.nsample_list)):
            group_idx = query_ball_point(radius, K, xyz, new_xyz)
            # Neighbour positions relative to their centroid.
            local_xyz = index_points(xyz, group_idx) - new_xyz.view(B, S, 1, C)
            if points is None:
                grouped = local_xyz
            else:
                grouped = torch.cat([index_points(points, group_idx), local_xyz], dim=-1)

            grouped = grouped.permute(0, 3, 2, 1)  # [B, D, K, S]
            for conv, bn in zip(self.conv_blocks[scale], self.bn_blocks[scale]):
                grouped = F.relu(bn(conv(grouped)))
            per_scale.append(torch.max(grouped, 2)[0])  # [B, D', S]

        return new_xyz.permute(0, 2, 1), torch.cat(per_scale, dim=1)




In [None]:
import torch.nn as nn
import torch.nn.functional as F



class get_model(nn.Module):
    """PointNet++ classifier built from three set-abstraction layers and an FC head."""

    def __init__(self,num_class,normal_channel=True):
        super().__init__()
        self.normal_channel = normal_channel
        # With normals the first layer sees 3 extra channels besides xyz.
        in_channel = 6 if normal_channel else 3
        self.sa1 = PointNetSetAbstraction(npoint=512, radius=0.2, nsample=32, in_channel=in_channel, mlp=[64, 64, 128], group_all=False)
        self.sa2 = PointNetSetAbstraction(npoint=128, radius=0.4, nsample=64, in_channel=128 + 3, mlp=[128, 128, 256], group_all=False)
        self.sa3 = PointNetSetAbstraction(npoint=None, radius=None, nsample=None, in_channel=256 + 3, mlp=[256, 512, 1024], group_all=True)
        self.fc1 = nn.Linear(1024, 512)
        self.bn1 = nn.BatchNorm1d(512)
        self.drop1 = nn.Dropout(0.4)
        self.fc2 = nn.Linear(512, 256)
        self.bn2 = nn.BatchNorm1d(256)
        self.drop2 = nn.Dropout(0.4)
        self.fc3 = nn.Linear(256, num_class)

    def forward(self, xyz):
        """Return ``(log-softmax class scores, output features of the last abstraction layer)``.

        ``xyz`` is a 3-D tensor whose second dimension holds the channels: the
        first three are used as coordinates and, when ``normal_channel`` is
        set, the remaining ones are used as input features.
        """
        batch_size, _, _ = xyz.shape
        normals = None
        if self.normal_channel:
            normals = xyz[:, 3:, :]
            xyz = xyz[:, :3, :]

        l1_xyz, l1_points = self.sa1(xyz, normals)
        l2_xyz, l2_points = self.sa2(l1_xyz, l1_points)
        l3_xyz, l3_points = self.sa3(l2_xyz, l2_points)

        x = l3_points.view(batch_size, 1024)
        hidden_stages = (
            (self.fc1, self.bn1, self.drop1),
            (self.fc2, self.bn2, self.drop2),
        )
        for fc, bn, drop in hidden_stages:
            x = drop(F.relu(bn(fc(x))))
        log_probs = F.log_softmax(self.fc3(x), -1)

        return log_probs, l3_points



class get_loss(nn.Module):
    """Negative log-likelihood loss on log-probability predictions."""

    def __init__(self):
        super().__init__()

    def forward(self, pred, target, trans_feat):
        """Return ``F.nll_loss(pred, target)``; ``trans_feat`` is accepted but not used."""
        return F.nll_loss(pred, target)