<a href="https://colab.research.google.com/github/shnhrtkyk/JTCcode/blob/main/05_%E6%B7%B1%E5%B1%A4%E5%AD%A6%E7%BF%92%E3%81%AB%E3%82%88%E3%82%8B%E7%82%B9%E7%BE%A4%E3%81%AE%E3%82%BB%E3%83%9E%E3%83%B3%E3%83%86%E3%82%A3%E3%83%83%E3%82%AF%E3%82%BB%E3%82%B0%E3%83%A1%E3%83%B3%E3%83%86%E3%83%BC%E3%82%B7%E3%83%A7%E3%83%B3.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# セマンティックセグメンテーションの実装

セマンティックセグメンテーションとは、入力された各点にクラス情報を付与する手法です。
セマンティックセグメンテーションにより、入力された点群がどのような環境なのか、具体的には道路の上に車があるなどの環境認識を3次元空間で行うことができます。

# 必要なライブラリをインストール

In [None]:
import os
import torch
# Export the installed PyTorch version as $TORCH; the pip cell below uses it
# to pick the matching PyG extension wheel index.
os.environ.update(TORCH=torch.__version__)
print(os.environ['TORCH'])

In [None]:
# PyG extension wheels must match the installed torch version ($TORCH).
!pip install -q torch-scatter -f https://data.pyg.org/whl/torch-${TORCH}.html
!pip install -q torch-sparse -f https://data.pyg.org/whl/torch-${TORCH}.html
!pip install -q torch-cluster -f https://data.pyg.org/whl/torch-${TORCH}.html
!pip install -q git+https://github.com/pyg-team/pytorch_geometric.git
!pip install -q torchmetrics

## ライブラリのインポート

In [None]:
import os
import random
import numpy as np
from tqdm.auto import tqdm

import torch
import torch.nn.functional as F

from torch_scatter import scatter
from torchmetrics.classification import MulticlassJaccardIndex

import torch_geometric.transforms as T
from torch_geometric.datasets import ShapeNet
from torch_geometric.loader import DataLoader
from torch_geometric.nn import MLP, DynamicEdgeConv

## 初期化

In [None]:
# ---------------------------------------------------------------------------
# Reproducibility and device
# ---------------------------------------------------------------------------
config_seed = 42
if torch.cuda.is_available():
    config_device = 'cuda'
else:
    config_device = 'cpu'
device = torch.device(config_device)

random.seed(config_seed)
torch.manual_seed(config_seed)

# ---------------------------------------------------------------------------
# Data
# ---------------------------------------------------------------------------
# ShapeNet category to segment; this notebook segments guitars.
# Choices: "Bag", "Cap", "Car", "Chair", "Earphone", "Guitar", "Knife",
# "Lamp", "Laptop", "Motorbike", "Mug", "Pistol", "Rocket", "Skateboard",
# "Table"
config_category = "Guitar"
config_random_jitter_translation = 0.01
config_random_rotation_interval_x = 15
config_random_rotation_interval_y = 15
config_random_rotation_interval_z = 15
config_validation_split = 0.2
config_batch_size = 16
config_num_workers = 6

# ---------------------------------------------------------------------------
# Model and optimisation
# ---------------------------------------------------------------------------
config_num_nearest_neighbours = 30
config_aggregation_operator = "max"
config_dropout = 0.5
config_initial_lr = 0.001
config_lr_scheduler_step_size = 5
config_gamma = 0.8
# A single epoch keeps the lecture short; raise it for better segmentation.
config_epochs = 1

# ShapeNetというデータセットを読み込む

ShapeNetという、物体のパーツごとにクラスが付与されたデータセットを用います。
※ダウンロードに時間がかかります。

In [None]:
# Augmentation run on every sample access: a small random jitter followed by
# a random rotation about the x (axis 0), y (axis 1) and z (axis 2) axes.
transform = T.Compose(
    [T.RandomJitter(config_random_jitter_translation)]
    + [
        T.RandomRotate(degrees, axis=axis)
        for axis, degrees in enumerate((
            config_random_rotation_interval_x,
            config_random_rotation_interval_y,
            config_random_rotation_interval_z,
        ))
    ]
)
# Passed as ``pre_transform`` to the dataset (see next cell); NormalizeScale
# centres and rescales each shape's positions.
pre_transform = T.NormalizeScale()

In [None]:
# One root directory per category, e.g. "ShapeNet/Guitar".
dataset_path = os.path.join('ShapeNet', config_category)

# The combined train+val split; it is divided into train/validation below.
train_val_dataset = ShapeNet(
    dataset_path,
    config_category,
    split='trainval',
    pre_transform=pre_transform,
    transform=transform,
)

セグメンテーションデータの前処理

In [None]:
# Count, over the whole train/val set, how many points carry each part label.
segmentation_class_frequency = {}
for idx in tqdm(range(len(train_val_dataset))):
    # Fetch each shape once; only its labels are needed here.
    labels, counts = torch.unique(train_val_dataset[idx].y, return_counts=True)
    for label, count in zip(labels.tolist(), counts.tolist()):
        segmentation_class_frequency[label] = (
            segmentation_class_frequency.get(label, 0) + count
        )

# Smallest part id of the selected category (informational only).
class_offset = min(segmentation_class_frequency)
print("Class Offset:", class_offset)

# The labels are intentionally NOT shifted by ``class_offset``.
# - The IoU code later looks parts up via ``ShapeNet.seg_classes``, which uses
#   the global part ids, so the labels must stay global.
# - ``train_val_dataset[idx]`` builds a freshly transformed object on each
#   access, so an in-place ``y -= class_offset`` on it is not a reliable way
#   to relabel the dataset.

訓練データと検証データへの切り分け

In [None]:
# Keep the leading (1 - config_validation_split) fraction of shapes for
# training; the remainder (in dataset order) is held out for validation.
num_train_examples = int(len(train_val_dataset) * (1 - config_validation_split))
train_dataset, val_dataset = (
    train_val_dataset[:num_train_examples],
    train_val_dataset[num_train_examples:],
)

In [None]:
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D

# Preview one (augmented) training shape, coloured by its ground-truth part.
predicted_pc_viz = train_dataset[1].pos  # x, y, z coordinates per point
label = train_dataset[1].y               # ground-truth part label per point

fig = plt.figure()
ax = fig.add_subplot(projection='3d')
# Unpacking the transposed positions yields the x, y and z columns.
ax.scatter(*predicted_pc_viz.T, c=label, cmap='Spectral', marker="x")
ax.axis('scaled')  # {equal, scaled}
plt.show()

今回のタスクは、物体のパーツ単位のセグメンテーションです。
したがって、今回のセグメンテーション対象であるギターは、ヘッド・ネック・ボディの各パーツに分かれています。

データローダーの設定

In [None]:
# Only the training split is shuffled; validation and visualization keep a
# fixed order so that results are comparable between epochs.
train_loader = DataLoader(
    dataset=train_dataset,
    batch_size=config_batch_size,
    num_workers=config_num_workers,
    shuffle=True,
)
val_loader = DataLoader(
    dataset=val_dataset,
    batch_size=config_batch_size,
    num_workers=config_num_workers,
    shuffle=False,
)
# One shape per batch, taken from the first 10 validation shapes.
visualization_loader = DataLoader(
    dataset=val_dataset[:10],
    batch_size=1,
    num_workers=config_num_workers,
    shuffle=False,
)

# PyTorch Geometricによる実装

PointNet++を利用した点群のセマンティックセグメンテーションを行います．

PointNet++は，入力された点群をダウンサンプリングしながら特徴抽出を行うエンコーダと，ダウンサンプリングされた点群の特徴を元の点群数までアップサンプリングしながら特徴を伝えるデコーダで構成されます．

エンコーダでは，Set Abstraction Module(`SAModule`)を適用します．SAModuleでは，入力された点群に対して，最遠点サンプリング（fps）によって入力点数よりも少ない数の代表点を抽出します．次に，この代表点に対して特徴量を計算します．特徴計算では，各代表点を中心とする半径rの球の中に入る周辺点をグループ化し，このグループ化された点に対してMLPを適用します．その後，グループ内の点群に対してMaxPoolingを適用して，代表点に特徴量を与えます．これは，小領域に対してPointNetを適用していることと同義です．  
なお，点群のグローバルな特徴量を用いると，セグメンテーションの性能が上がるため，一番点数の少なくなった層において，`GlobalSAModule`を用いて，代表点全体の特徴量から代表値を求めます．


デコーダでは，サブサンプリングされた点群から元の点群数に戻すために，アップサンプリングを行います．アップサンプリングには`FPModule`と呼ばれる演算を用います．その中ではknn_interpolate関数を呼び出し，各点についてk個の最近傍点の特徴量を補間することで，特徴量を伝播させます．
この際，エンコーダ側の高解像度な点群の特徴量を用いて，サブサンプリングされる前の高解像度な情報を組み合わせます．これは，画像処理におけるセグメンテーションでも行われており，スキップコネクションとも呼ばれます．


In [None]:
from torch_geometric.nn import MLP, PointNetConv, fps, global_max_pool, radius
class SAModule(torch.nn.Module):
    """PointNet++ Set Abstraction layer.

    Picks ``ratio`` of the input points as centres with farthest point
    sampling, groups up to 64 neighbours within radius ``r`` around each
    centre, and aggregates them with a ``PointNetConv`` built on ``nn``.

    Args:
        ratio: fraction of points kept as centres.
        r: grouping radius.
        nn: MLP used inside ``PointNetConv``.
    """

    def __init__(self, ratio, r, nn):
        super().__init__()
        self.ratio, self.r = ratio, r
        self.conv = PointNetConv(nn, add_self_loops=False)

    def forward(self, x, pos, batch):
        """Return ``(features, positions, batch)`` of the sampled centres."""
        centers = fps(pos, batch, ratio=self.ratio)
        center_pos, center_batch = pos[centers], batch[centers]

        # ``radius`` yields (centre index, neighbour index) pairs; the conv
        # expects edges pointing from neighbour (source) to centre (target).
        center_idx, neighbour_idx = radius(
            pos, center_pos, self.r, batch, center_batch, max_num_neighbors=64
        )
        edge_index = torch.stack([neighbour_idx, center_idx], dim=0)

        x_dst = x[centers] if x is not None else None
        out = self.conv((x, x_dst), (pos, center_pos), edge_index)
        return out, center_pos, center_batch


class GlobalSAModule(torch.nn.Module):
    """Global Set Abstraction: pools each point cloud into one feature vector.

    Applies ``nn`` to ``[x, pos]`` and max-pools per graph. The result is a
    single "point" per graph, placed at the origin.
    """

    def __init__(self, nn):
        super().__init__()
        self.nn = nn

    def forward(self, x, pos, batch):
        """Return ``(pooled_features, zero_positions, graph_indices)``."""
        pooled = global_max_pool(self.nn(torch.cat([x, pos], dim=1)), batch)
        num_graphs = pooled.size(0)
        origin = pos.new_zeros((num_graphs, 3))
        graph_index = torch.arange(num_graphs, device=batch.device)
        return pooled, origin, graph_index


In [None]:
from torch_geometric.nn import MLP, knn_interpolate
from torchmetrics.functional import jaccard_index

class FPModule(torch.nn.Module):
    """PointNet++ Feature Propagation (upsampling) layer.

    Interpolates coarse features onto the finer ``pos_skip`` points using
    their ``k`` nearest coarse neighbours. It concatenates the skip features
    (when given) and applies ``nn``.
    """

    def __init__(self, k, nn):
        super().__init__()
        self.k = k
        self.nn = nn

    def forward(self, x, pos, batch, x_skip, pos_skip, batch_skip):
        """Return ``(features, pos_skip, batch_skip)`` at the finer level."""
        upsampled = knn_interpolate(x, pos, pos_skip, batch, batch_skip, k=self.k)
        if x_skip is None:
            features = upsampled
        else:
            features = torch.cat([upsampled, x_skip], dim=1)
        return self.nn(features), pos_skip, batch_skip


class Net(torch.nn.Module):
    """PointNet++ network for per-point part segmentation.

    Encoder: two ``SAModule`` levels followed by a ``GlobalSAModule``.
    Decoder: three ``FPModule`` levels that upsample back to the input
    points, each concatenating the matching encoder features (skip
    connections). A final MLP produces per-point class scores.

    Args:
        num_classes: number of output classes per point.
    """

    def __init__(self, num_classes):
        super().__init__()

        # Input channels account for both `pos` and node features.
        self.sa1_module = SAModule(0.2, 0.2, MLP([3 + 3, 64, 64, 128]))
        self.sa2_module = SAModule(0.25, 0.4, MLP([128 + 3, 128, 128, 256]))
        self.sa3_module = GlobalSAModule(MLP([256 + 3, 256, 512, 1024]))

        # Each decoder MLP takes upsampled features + the encoder skip features.
        self.fp3_module = FPModule(1, MLP([1024 + 256, 256, 256]))
        self.fp2_module = FPModule(3, MLP([256 + 128, 256, 128]))
        # The last skip connection is the raw input features ``data.x``.
        self.fp1_module = FPModule(3, MLP([128 + 3, 128, 128, 128]))

        # Per-point classifier head (the only head used in ``forward``).
        self.mlp = MLP([128, 128, 128, num_classes], dropout=0.5, norm=None)

    def forward(self, data):
        """Return per-point log-probabilities, one row per point in ``data``."""
        sa0_out = (data.x, data.pos, data.batch)
        sa1_out = self.sa1_module(*sa0_out)
        sa2_out = self.sa2_module(*sa1_out)
        sa3_out = self.sa3_module(*sa2_out)

        fp3_out = self.fp3_module(*sa3_out, *sa2_out)
        fp2_out = self.fp2_module(*fp3_out, *sa1_out)
        x, _, _ = self.fp1_module(*fp2_out, *sa0_out)

        return self.mlp(x).log_softmax(dim=-1)

In [None]:
from torch_geometric.nn import MLP, knn_interpolate
from torchmetrics.functional import jaccard_index
config_num_classes = train_dataset.num_classes
print(config_num_classes)
# Reuse the value above: with a ``transform`` set, querying ``num_classes``
# again may iterate over the whole dataset.
model = Net(config_num_classes).to(device)
# Use the configured learning rate rather than a hard-coded copy of it.
optimizer = torch.optim.Adam(model.parameters(), lr=config_initial_lr)


def train():
    """Run one training epoch of ``model`` over ``train_loader``.

    Every 10 batches it prints the mean loss and point-wise accuracy of
    those batches, then resets the running statistics.
    """
    model.train()

    running_loss, running_correct, running_points = 0, 0, 0
    for step, batch in enumerate(train_loader, start=1):
        batch = batch.to(device)
        optimizer.zero_grad()
        log_probs = model(batch)
        loss = F.nll_loss(log_probs, batch.y)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        running_correct += log_probs.argmax(dim=1).eq(batch.y).sum().item()
        running_points += batch.num_nodes

        if step % 10:
            continue
        print(f'[{step}/{len(train_loader)}] Loss: {running_loss / 10:.4f} '
              f'Train Acc: {running_correct / running_points:.4f}')
        running_loss, running_correct, running_points = 0, 0, 0

@torch.no_grad()
def val_step(epoch):
    """Evaluate ``model`` on ``val_loader``.

    Args:
        epoch: 1-based epoch number; only used in the progress-bar label.

    Returns:
        dict with "Validation/Loss" (mean NLL loss per batch),
        "Validation/Accuracy" (point-wise accuracy) and "Validation/IoU"
        (per-shape mean part IoU, averaged within each category and then
        across categories).
    """
    model.eval()

    ious, categories = [], []
    total_loss = correct_nodes = total_nodes = 0
    # Lookup table from a category's global part ids to 0..num_parts-1,
    # sized to cover every id listed in ``ShapeNet.seg_classes``.
    num_part_ids = max(max(ids) for ids in ShapeNet.seg_classes.values()) + 1
    y_map = torch.empty(num_part_ids, dtype=torch.long, device=device)
    # ``data.category`` indexes the dataset's own ``categories`` list (e.g. 0
    # for "Guitar" when only Guitar is loaded), not the key order of
    # ``ShapeNet.seg_classes``, so resolve category names via the dataset.
    category_names = getattr(
        val_loader.dataset, 'categories', list(ShapeNet.seg_classes.keys())
    )
    num_val_examples = len(val_loader)

    progress_bar = tqdm(
        val_loader, desc=f"Validating Epoch {epoch}/{config_epochs}"
    )

    for data in progress_bar:
        data = data.to(device)
        outs = model(data)

        loss = F.nll_loss(outs, data.y)
        total_loss += loss.item()

        correct_nodes += outs.argmax(dim=1).eq(data.y).sum().item()
        total_nodes += data.num_nodes

        # Per-shape IoU, restricted to the parts of that shape's category.
        sizes = (data.ptr[1:] - data.ptr[:-1]).tolist()
        for out, y, category in zip(outs.split(sizes), data.y.split(sizes),
                                    data.category.tolist()):
            part = ShapeNet.seg_classes[category_names[category]]
            part = torch.tensor(part, device=device)

            y_map[part] = torch.arange(part.size(0), device=device)
            iou_metric = MulticlassJaccardIndex(
                num_classes=part.size(0)
            ).to(device)
            ious.append(iou_metric(out[:, part].argmax(dim=-1), y_map[y]))

        categories.append(data.category)

    iou = torch.stack(ious)
    category = torch.cat(categories, dim=0)
    mean_iou = float(scatter(iou, category, reduce='mean').mean())

    return {
        "Validation/Loss": total_loss / num_val_examples,
        "Validation/Accuracy": correct_nodes / total_nodes,
        "Validation/IoU": mean_iou
    }





In [None]:
@torch.no_grad()
def visualization_step(epoch):
    """Predict part labels for the last shape yielded by ``visualization_loader``.

    Args:
        epoch: unused; kept so existing call sites keep working.

    Returns:
        A list with one ``[x, y, z, label]`` row per point of the last shape,
        where ``label`` is the predicted part id plus one. Returns an empty
        list if the loader yields nothing.
    """
    model.eval()

    # Only the final batch is returned, so run the model on that one alone.
    # The rest of the loader is still iterated, which keeps the progress bar.
    last_batch = None
    for data in tqdm(visualization_loader):
        last_batch = data
    if last_batch is None:
        return []

    data = last_batch.to(device)
    predicted_labels = model(data).argmax(dim=1)

    predicted_pc_viz = data.pos.cpu().tolist()
    # The colour value is the predicted label shifted by one.
    for row, label in zip(predicted_pc_viz, (predicted_labels + 1).cpu().tolist()):
        row.append(label)
    return predicted_pc_viz

In [None]:
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D

scheduler = torch.optim.lr_scheduler.StepLR(
    optimizer, step_size=config_lr_scheduler_step_size, gamma=config_gamma
)
for epoch in range(1, config_epochs + 1):
    train()
    val_metrics = val_step(epoch)

    # Learning rate in effect during this epoch (read before stepping).
    metrics = {**val_metrics, "learning_rate": scheduler.get_last_lr()[-1]}
    print(metrics)

    scheduler.step()

    # Visualize after this epoch's training so the plot reflects the updated
    # weights. Calling it before ``train()`` would plot the untrained model.
    predicted_pc_viz = np.array(visualization_step(epoch))

    fig = plt.figure()
    ax = fig.add_subplot(111, projection='3d')

    ax.scatter(predicted_pc_viz[:, 0],   # x
               predicted_pc_viz[:, 1],   # y
               predicted_pc_viz[:, 2],   # z
               c=predicted_pc_viz[:, 3],  # predicted part label (+1) as colour
               cmap='Spectral',
               marker="x")
    ax.axis('scaled')  # {equal, scaled}
    plt.show()

次に、DGCNNという点群の深層学習手法を実装します。
DGCNNはDynamic Graph CNNという手法で、各点に対して特徴量空間のKNNで周辺の点を検索し、そのKNNの点に対してDynamicEdgeConvを適用します。



層の構造は、各点に対して`DynamicEdgeConv`を適用します。
このDynamicEdgeConvを3回くりかえして、周辺点の情報を集約していきます。
その後、最終層で各点ごとにクラス分類を行います。
クラス分類には、単純なＭＬＰを適用し、出力次元数は分類したいクラス数と同じにします。
この実装では、返り値を計算する際に、`F.log_softmax`を適用します。

In [None]:
class DGCNN(torch.nn.Module):
    """Dynamic Graph CNN for per-point segmentation.

    Three ``DynamicEdgeConv`` layers rebuild a k-NN graph in feature space
    at every layer. Their outputs are concatenated and classified per point
    by an MLP.

    Args:
        out_channels: number of output classes per point.
        k: number of nearest neighbours used by each ``DynamicEdgeConv``.
        aggr: neighbourhood aggregation operator (e.g. ``'max'``).
        in_channels: width of ``[data.x, data.pos]``. The default of 6
            matches 3 feature channels plus 3 coordinates.
        dropout: dropout probability in the classification MLP.
    """

    def __init__(self, out_channels, k=30, aggr='max', in_channels=6,
                 dropout=0.5):
        super().__init__()

        # Edge features are [x_i, x_j - x_i], hence the factor 2.
        self.conv1 = DynamicEdgeConv(MLP([2 * in_channels, 64, 64]), k, aggr)
        self.conv2 = DynamicEdgeConv(MLP([2 * 64, 64, 64]), k, aggr)
        self.conv3 = DynamicEdgeConv(MLP([2 * 64, 64, 64]), k, aggr)

        self.mlp = MLP(
            [3 * 64, 1024, 256, 128, out_channels],
            dropout=dropout, norm=None
        )

    def forward(self, data):
        """Return per-point log-probabilities for the points in ``data``."""
        x, pos, batch = data.x, data.pos, data.batch
        x0 = torch.cat([x, pos], dim=-1)

        x1 = self.conv1(x0, batch)
        x2 = self.conv2(x1, batch)
        x3 = self.conv3(x2, batch)

        out = self.mlp(torch.cat([x1, x2, x3], dim=1))
        return F.log_softmax(out, dim=1)

モデルの呼び出し

In [None]:
config_num_classes = train_dataset.num_classes
print(config_num_classes)
# Reuse the value above: with a ``transform`` set, querying ``num_classes``
# again may iterate over the whole dataset.
model = DGCNN(
    out_channels=config_num_classes,
    k=config_num_nearest_neighbours,
    aggr=config_aggregation_operator
).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=config_initial_lr)
scheduler = torch.optim.lr_scheduler.StepLR(
    optimizer, step_size=config_lr_scheduler_step_size, gamma=config_gamma
)

# 訓練の実装

In [None]:
def train_step(epoch):
    """Train ``model`` on ``train_loader`` for one epoch.

    Args:
        epoch: 1-based epoch number; only used in the progress-bar label.

    Returns:
        dict with "Train/Loss" (mean NLL loss per batch), "Train/Accuracy"
        (point-wise accuracy) and "Train/IoU" (per-shape mean part IoU,
        averaged within each category and then across categories).
    """
    model.train()

    ious, categories = [], []
    total_loss = correct_nodes = total_nodes = 0
    # Lookup table from a category's global part ids to 0..num_parts-1,
    # sized to cover every id listed in ``ShapeNet.seg_classes``.
    num_part_ids = max(max(ids) for ids in ShapeNet.seg_classes.values()) + 1
    y_map = torch.empty(num_part_ids, dtype=torch.long, device=device)
    # ``data.category`` indexes the dataset's own ``categories`` list (e.g. 0
    # for "Guitar" when only Guitar is loaded), not the key order of
    # ``ShapeNet.seg_classes``, so resolve category names via the dataset.
    category_names = getattr(
        train_loader.dataset, 'categories', list(ShapeNet.seg_classes.keys())
    )
    num_train_examples = len(train_loader)

    progress_bar = tqdm(
        train_loader, desc=f"Training Epoch {epoch}/{config_epochs}"
    )

    for data in progress_bar:
        data = data.to(device)

        optimizer.zero_grad()
        outs = model(data)
        loss = F.nll_loss(outs, data.y)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

        correct_nodes += outs.argmax(dim=1).eq(data.y).sum().item()
        total_nodes += data.num_nodes

        # Per-shape IoU, restricted to the parts of that shape's category.
        sizes = (data.ptr[1:] - data.ptr[:-1]).tolist()
        for out, y, category in zip(outs.split(sizes), data.y.split(sizes),
                                    data.category.tolist()):
            part = ShapeNet.seg_classes[category_names[category]]
            part = torch.tensor(part, device=device)

            y_map[part] = torch.arange(part.size(0), device=device)
            iou_metric = MulticlassJaccardIndex(
                num_classes=part.size(0)
            ).to(device)
            ious.append(iou_metric(out[:, part].argmax(dim=-1), y_map[y]))

        categories.append(data.category)

    iou = torch.stack(ious)
    category = torch.cat(categories, dim=0)
    mean_iou = float(scatter(iou, category, reduce='mean').mean())

    return {
        "Train/Loss": total_loss / num_train_examples,
        "Train/Accuracy": correct_nodes / total_nodes,
        "Train/IoU": mean_iou
    }

In [None]:
@torch.no_grad()
def val_step(epoch):
    """Evaluate ``model`` on ``val_loader``.

    Args:
        epoch: 1-based epoch number; only used in the progress-bar label.

    Returns:
        dict with "Validation/Loss" (mean NLL loss per batch),
        "Validation/Accuracy" (point-wise accuracy) and "Validation/IoU"
        (per-shape mean part IoU, averaged within each category and then
        across categories).
    """
    model.eval()

    ious, categories = [], []
    total_loss = correct_nodes = total_nodes = 0
    # Lookup table from a category's global part ids to 0..num_parts-1,
    # sized to cover every id listed in ``ShapeNet.seg_classes``.
    num_part_ids = max(max(ids) for ids in ShapeNet.seg_classes.values()) + 1
    y_map = torch.empty(num_part_ids, dtype=torch.long, device=device)
    # ``data.category`` indexes the dataset's own ``categories`` list (e.g. 0
    # for "Guitar" when only Guitar is loaded), not the key order of
    # ``ShapeNet.seg_classes``, so resolve category names via the dataset.
    category_names = getattr(
        val_loader.dataset, 'categories', list(ShapeNet.seg_classes.keys())
    )
    num_val_examples = len(val_loader)

    progress_bar = tqdm(
        val_loader, desc=f"Validating Epoch {epoch}/{config_epochs}"
    )

    for data in progress_bar:
        data = data.to(device)
        outs = model(data)

        loss = F.nll_loss(outs, data.y)
        total_loss += loss.item()

        correct_nodes += outs.argmax(dim=1).eq(data.y).sum().item()
        total_nodes += data.num_nodes

        # Per-shape IoU, restricted to the parts of that shape's category.
        sizes = (data.ptr[1:] - data.ptr[:-1]).tolist()
        for out, y, category in zip(outs.split(sizes), data.y.split(sizes),
                                    data.category.tolist()):
            part = ShapeNet.seg_classes[category_names[category]]
            part = torch.tensor(part, device=device)

            y_map[part] = torch.arange(part.size(0), device=device)
            iou_metric = MulticlassJaccardIndex(
                num_classes=part.size(0)
            ).to(device)
            ious.append(iou_metric(out[:, part].argmax(dim=-1), y_map[y]))

        categories.append(data.category)

    iou = torch.stack(ious)
    category = torch.cat(categories, dim=0)
    mean_iou = float(scatter(iou, category, reduce='mean').mean())

    return {
        "Validation/Loss": total_loss / num_val_examples,
        "Validation/Accuracy": correct_nodes / total_nodes,
        "Validation/IoU": mean_iou
    }

結果を取得する関数

In [None]:
@torch.no_grad()
def visualization_step(epoch):
    """Predict part labels for the last shape yielded by ``visualization_loader``.

    Args:
        epoch: unused; kept so existing call sites keep working.

    Returns:
        A list with one ``[x, y, z, label]`` row per point of the last shape,
        where ``label`` is the predicted part id plus one. Returns an empty
        list if the loader yields nothing.
    """
    model.eval()

    # Only the final batch is returned, so run the model on that one alone.
    # The rest of the loader is still iterated, which keeps the progress bar.
    last_batch = None
    for data in tqdm(visualization_loader):
        last_batch = data
    if last_batch is None:
        return []

    data = last_batch.to(device)
    predicted_labels = model(data).argmax(dim=1)

    predicted_pc_viz = data.pos.cpu().tolist()
    # The colour value is the predicted label shifted by one.
    for row, label in zip(predicted_pc_viz, (predicted_labels + 1).cpu().tolist()):
        row.append(label)
    return predicted_pc_viz

In [None]:
def save_checkpoint(epoch, path="checkpoint.pt"):
    """Save the current training state to a local file with ``torch.save``.

    The saved dict holds ``epoch``, ``model_state_dict`` and
    ``optimizer_state_dict``, taken from the global ``model`` and
    ``optimizer``.

    Args:
        epoch: epoch number stored alongside the weights.
        path: destination file, ``"checkpoint.pt"`` by default. The file is
            overwritten on every call, so only the latest state is kept.
    """
    torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
    }, path)
    


実際の訓練を回してみましょう。
今回は1エポックですが、`config_epochs`の値を大きくすると学習する回数が多くなります。

In [None]:
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D

for epoch in range(1, config_epochs + 1):
    train_metrics = train_step(epoch)
    val_metrics = val_step(epoch)

    # Learning rate in effect during this epoch (read before stepping).
    metrics = {
        **train_metrics,
        **val_metrics,
        "learning_rate": scheduler.get_last_lr()[-1],
    }
    print(metrics)

    scheduler.step()
    save_checkpoint(epoch)

    # Visualize after this epoch's training so the plot reflects the updated
    # weights. Calling it before ``train_step`` would plot the untrained model.
    predicted_pc_viz = np.array(visualization_step(epoch))

    fig = plt.figure()
    ax = fig.add_subplot(111, projection='3d')

    ax.scatter(predicted_pc_viz[:, 0],   # x
               predicted_pc_viz[:, 1],   # y
               predicted_pc_viz[:, 2],   # z
               c=predicted_pc_viz[:, 3],  # predicted part label (+1) as colour
               cmap='Spectral',
               marker="x")
    ax.axis('scaled')  # {equal, scaled}
    plt.show()


学習回数が1エポックなので、あまり学習が進んでおらず、パーツ単位のセグメンテーションがうまくいっていないことがわかると思います。今回の講義では、時間の制約で1エポックとしていますが、エポック数を多く設定するとセグメンテーションの結果がより良くなります。