# PRSNET 实验复现

1. 数据处理：加载点云文件，每个点云文件包括点的坐标和颜色，从点云文件中加载出：体素表示、网格点到最近点的距离。注意，训练数据要使用.mat进行存储，将下面的代码块中的file_path修改为存储数据的文件地址。

In [16]:
import torch
import numpy
from scipy.io import loadmat
import os

def loadmat_dir(dir_path):
    """Load every ``.mat`` sample found in ``dir_path`` into stacked tensors.

    Each file is read with ``scipy.io.loadmat`` and must contain the keys
    ``'Volume'``, ``'surfaceSamples'`` and ``'closestPoints'``.

    Args:
        dir_path: Directory that is scanned (non-recursively) for ``*.mat``.

    Returns:
        A dict with float32 tensors:
        ``'Volumes'`` (stack of the ``'Volume'`` arrays),
        ``'surfaceSamples'`` (stack of ``'surfaceSamples'`` with the last two
        axes swapped) and ``'closestPoints'`` (stack of ``'closestPoints'``).

    Raises:
        ValueError: If the directory contains no ``.mat`` file.
    """
    volumes = []          # voxel occupancy grids, e.g. (32, 32, 32) each
    surface_samples = []  # surface sample points, e.g. (3, 1000) each
    closest_points = []   # closest surface point per grid cell

    # os.listdir returns entries in arbitrary order; sorting makes the sample
    # order (and therefore any index-based access) reproducible.
    for file_name in sorted(os.listdir(dir_path)):
        if not file_name.endswith('.mat'):
            continue
        sample = loadmat(os.path.join(dir_path, file_name))
        volumes.append(sample['Volume'])
        surface_samples.append(sample['surfaceSamples'])
        closest_points.append(sample['closestPoints'])

    if not volumes:
        raise ValueError(f"no .mat files found in {dir_path!r}")

    def _stack(arrays):
        # Stack in NumPy first: torch.tensor() on a list of ndarrays converts
        # element by element and is extremely slow for large datasets.
        return torch.from_numpy(numpy.stack(arrays).astype(numpy.float32))

    data = {
        'Volumes': _stack(volumes),
        # stored as (3, N) per sample -> expose as (batch, N, 3)
        'surfaceSamples': _stack(surface_samples).permute(0, 2, 1),
        'closestPoints': _stack(closest_points),
    }

    return data

# # 替换为你的.mat文件路径
# file_path = 'datasets/shapenet/train'

# # 加载.mat文件
# data = loadmat_dir(file_path)
# print(data.keys())
# print(data['Volumes'].shape)
# print(data['surfaceSamples'].shape)
# print(data['closestPoints'].shape)

In [17]:
import random

# 生成一个小批量，小批量数据X长度为batch_size
# Volume.shape = (batch_size, 32, 32, 32)
# surfaceSamples.shape = (batch_size, 1000, 3)
# closestPoints = (batch_size, 32, 32, 32, 3)
def data_iter_random(data, batch_size, shuffle=True):
    """Yield ``(Volume, surfaceSamples, closestPoints)`` mini-batches.

    The sample indices are optionally shuffled with ``random.shuffle`` and
    then cut into consecutive groups of ``batch_size``.  Samples that do not
    fill a complete final batch are dropped.

    Args:
        data: Dict with the keys ``'Volumes'``, ``'surfaceSamples'`` and
            ``'closestPoints'``; each value is indexed with a list of ints.
        batch_size: Number of samples per yielded batch.
        shuffle: Whether to randomise the sample order first.
    """
    sample_count = len(data['Volumes'])
    order = list(range(sample_count))
    if shuffle:
        random.shuffle(order)

    full_batches = sample_count // batch_size
    for batch_no in range(full_batches):
        begin = batch_no * batch_size
        picked = order[begin:begin + batch_size]
        yield (
            data['Volumes'][picked],
            data['surfaceSamples'][picked],
            data['closestPoints'][picked],
        )

class SeqDataLoader:
    """Re-iterable wrapper that loads a ``.mat`` directory and batches it.

    Every call to ``iter()`` builds a fresh batch generator, so the loader can
    be reused across epochs.
    """

    def __init__(self, file_path, batch_size, shuffle=True):
        """Load all samples from ``file_path`` and remember batching options."""
        self.batch_size, self.shuffle = batch_size, shuffle
        self.data = loadmat_dir(file_path)
        self.data_iter_fn = data_iter_random

    def __iter__(self):
        """Return a new iterator over mini-batches of the loaded data."""
        batches = self.data_iter_fn(self.data, self.batch_size, self.shuffle)
        return batches

# data_loader = SeqDataLoader(file_path, batch_size=1, shuffle=False)
# for Volume, surfaceSamples, closestPoints in data_loader:
#     print(Volume.shape)
#     print(surfaceSamples.shape)
#     print(surfaceSamples[0, 0])
#     print(closestPoints.shape)
#     print(closestPoints[0, 0, 0, 0])
#     break

2. 设计网络结构：

In [18]:
from torch import nn
import math
class MyLinear(nn.Module):
    """A bank of ``n`` independent fully connected layers evaluated together.

    Head ``k`` owns ``weight[k]`` of shape ``(in_features, out_features)`` and
    ``bias[k]`` of shape ``(out_features,)``.  Weights are drawn from a
    standard normal distribution, biases start at zero.
    """

    def __init__(self, n, in_features, out_features):
        super().__init__()
        self.n = n
        weight_shape = (n, in_features, out_features)
        self.weight = nn.Parameter(torch.randn(*weight_shape))
        self.bias = nn.Parameter(torch.zeros(n, out_features))

    def forward(self, x):
        """Apply head ``k`` to the ``k``-th slice of every input row.

        ``x`` may be ``(b, n, in)`` or the flattened ``(b, n * in)``; it is
        viewed as ``(b, n, in)`` and mapped to ``(b, n, out)``.
        """
        batch = x.size(0)
        per_head = x.view(batch, self.n, -1)
        # (b, n, in) x (n, in, out) -> (b, n, out), one matrix per head
        projected = torch.einsum('bni,nio->bno', per_head, self.weight)
        # bias (n, out) broadcasts over the batch axis
        return projected + self.bias
    
class PRS_NET(nn.Module):
    """Network that predicts reflection planes and rotation quaternions.

    A stack of five ``Conv3d -> MaxPool3d(2) -> LeakyReLU`` stages reduces the
    voxel grid, followed by ``2 * output_size`` small fully connected heads
    (32 -> 16 -> 4).  The first ``output_size`` heads are returned as planes,
    the remaining ones as quaternions.

    Args:
        intput_size: Voxel resolution.  The (misspelled) name is kept for
            backward compatibility; the value is currently not used and the
            architecture is sized for 32^3 inputs (five poolings: 32 -> 1).
        output_size: Number of planes and, separately, of quaternions.
    """

    def __init__(self, intput_size=32, output_size=3):
        super(PRS_NET, self).__init__()
        self.output_size = output_size
        conv_layer_depth = 6

        # Channels double at every stage: 1 -> 4 -> 8 -> 16 -> 32 -> 64.
        conv = nn.Sequential()
        conv.add_module('conv0', nn.Conv3d(in_channels=1, out_channels=4, kernel_size=3, padding=1))
        conv.add_module('pool0', nn.MaxPool3d(kernel_size=2))
        conv.add_module('leaky_relu0', nn.LeakyReLU())
        for i in range(2, conv_layer_depth):
            conv.add_module(f'conv{i-1}', nn.Conv3d(in_channels=2**i, out_channels=2**(i+1), kernel_size=3, padding=1))
            conv.add_module(f'pool{i-1}', nn.MaxPool3d(kernel_size=2))
            conv.add_module(f'leaky_relu{i-1}', nn.LeakyReLU())
        self.conv = conv

        # One head per plane plus one per quaternion; layer widths 32, 16, 4.
        fnn_num = output_size * 2
        # The module names historically picked up the last value of the conv
        # loop variable ("fnn5_*").  Spell that out explicitly so the
        # state_dict keys of already saved checkpoints stay valid.
        prefix = f'fnn{conv_layer_depth - 1}'
        fnn = nn.Sequential()
        fnn.add_module(f'{prefix}_layer0', nn.Linear(64, 32 * fnn_num))
        fnn.add_module(f'{prefix}_leaky_relu0', nn.LeakyReLU())
        fnn.add_module(f'{prefix}_layer1', MyLinear(fnn_num, 32, 16))
        fnn.add_module(f'{prefix}_leaky_relu1', nn.LeakyReLU())
        fnn.add_module(f'{prefix}_layer2', MyLinear(fnn_num, 16, 4))
        self.fnn = fnn
        self.fnn_num = fnn_num

    def forward(self, volume):
        """Predict planes and quaternions for a batch of voxel grids.

        Args:
            volume: Tensor of shape ``(batch_size, 32, 32, 32)``.

        Returns:
            ``(planes, quats)``, each of shape ``(output_size, batch_size, 4)``.
        """
        # add the channel axis: (batch_size, 1, 32, 32, 32)
        volume = volume.unsqueeze(1)
        # (batch_size, 64, 1, 1, 1) for a 32^3 input
        conv_output = self.conv(volume)
        # (batch_size, 64)
        flatten = conv_output.view(conv_output.size(0), -1)
        # (2 * output_size, batch_size, 4)
        fnn_output = self.fnn(flatten).permute(1, 0, 2)

        # Split by the configured head count instead of a hard-coded 3 so the
        # network also works when output_size != 3.
        planes = fnn_output[:self.output_size]
        quats = fnn_output[self.output_size:]

        return planes, quats
    
# Instantiate the network with its default arguments and print the module
# tree as a quick structural sanity check.
net = PRS_NET()
print(net)

PRS_NET(
  (conv): Sequential(
    (conv0): Conv3d(1, 4, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
    (pool0): MaxPool3d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (leaky_relu0): LeakyReLU(negative_slope=0.01)
    (conv1): Conv3d(4, 8, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
    (pool1): MaxPool3d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (leaky_relu1): LeakyReLU(negative_slope=0.01)
    (conv2): Conv3d(8, 16, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
    (pool2): MaxPool3d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (leaky_relu2): LeakyReLU(negative_slope=0.01)
    (conv3): Conv3d(16, 32, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
    (pool3): MaxPool3d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (leaky_relu3): LeakyReLU(negative_slope=0.01)
    (conv4): Conv3d(32, 64, kernel_size=(3, 3, 3), stride=(1, 1, 1), p

3. 设计损失函数：包括距离损失和正则化损失

In [88]:
# 体素坐标系与点云坐标系之间的转换：(0 - 32)以及(-0.5 - 0.5)
def point2voxel(px, gridSize=32, gridBound=0.5):
    """Map point-cloud coordinates to (fractional) voxel-grid coordinates.

    The coordinate ``-gridBound + gridBound / gridSize`` maps to voxel 0 and
    one voxel spans ``2 * gridBound / gridSize`` point-cloud units.
    """
    first_cell = gridBound / gridSize - gridBound
    shifted = px - first_cell
    scaled = shifted * gridSize
    return scaled / (2 * gridBound)

def voxel2point(x, gridSize=32, gridBound=0.5):
    """Map voxel-grid coordinates back to point-cloud coordinates.

    Inverse of ``point2voxel``: voxel 0 maps to
    ``-gridBound + gridBound / gridSize`` and each voxel step covers
    ``2 * gridBound / gridSize`` point-cloud units.
    """
    gridMin = -gridBound + gridBound / gridSize
    # Divide by gridSize (not a literal 32) so non-default grids convert
    # correctly and the function stays the inverse of point2voxel.
    return x * (2 * gridBound) / gridSize + gridMin

"""四元数乘法，支持批量操作"""
def quaternion_multiply(q1, q2):
    """Hamilton product ``q1 * q2`` of scalar-first ``(w, x, y, z)`` quaternions.

    The last axis holds the four components; all leading axes broadcast, e.g.
    ``(batch, 1, 4)`` with ``(batch, 1000, 4)``.
    """
    a = [q1[..., k] for k in range(4)]
    b = [q2[..., k] for k in range(4)]
    real = a[0] * b[0] - a[1] * b[1] - a[2] * b[2] - a[3] * b[3]
    i_part = a[0] * b[1] + a[1] * b[0] + a[2] * b[3] - a[3] * b[2]
    j_part = a[0] * b[2] - a[1] * b[3] + a[2] * b[0] + a[3] * b[1]
    k_part = a[0] * b[3] + a[1] * b[2] - a[2] * b[1] + a[3] * b[0]
    return torch.stack((real, i_part, j_part, k_part), dim=-1)

"""四元数共轭"""
def quaternion_conjugate(q):
    """Return the conjugate ``(w, -x, -y, -z)`` of scalar-first quaternions."""
    scalar = q[..., 0]
    flipped = [q[..., axis].mul(-1) for axis in (1, 2, 3)]
    return torch.stack([scalar, *flipped], dim=-1)

"""使用四元数q旋转点p，支持批量操作"""
def rotate_point(p, q):
    """Transform points ``p`` by ``q * p * conj(q)`` (batched).

    Args:
        p: Points of shape ``(batch, N, 3)``.
        q: Scalar-first quaternions of shape ``(batch, 1, 4)``.  ``q`` is used
            as given; it is not normalised here.

    Returns:
        Tensor of shape ``(batch, N, 3)``: the vector part of the product.
    """
    batch, count = p.shape[0], p.shape[1]
    # embed each point as a pure quaternion (0, x, y, z)
    zero_scalar = torch.zeros(batch, count, 1, device=p.device)
    pure = torch.cat([zero_scalar, p], dim=-1)
    left = quaternion_multiply(q, pure)
    rotated = quaternion_multiply(left, quaternion_conjugate(q))
    # drop the scalar component
    return rotated[..., 1:]

"""计算点关于一个平面的对称点"""
def plane_point(p, plane):
    """Reflect points across the plane ``a*x + b*y + c*z + d = 0``.

    Args:
        p: Points of shape ``(batch, N, 3)``.
        plane: Plane coefficients ``(a, b, c, d)`` of shape ``(batch, 1, 4)``.

    Returns:
        Mirrored points of shape ``(batch, N, 3)``.
    """
    normal = plane[..., :3]
    offset = plane[..., 3:4]
    # twice the (unnormalised) signed distance term, shape (batch, N, 1)
    signed = torch.sum(p * normal, dim=2, keepdim=True) + offset
    length = 2 * signed
    # squared normal length, shape (batch, 1, 1); epsilon guards a zero normal
    norm2 = torch.sum(normal ** 2, dim=2, keepdim=True)
    tiled_normal = normal.repeat(1, length.shape[1], 1)
    return p - length * tiled_normal / (norm2 + 1e-8)

"""计算点到最近点的距离，输入为点云坐标单位，输出也为点云坐标单位"""
def calculate_distance(points, closestPoints, volume, device, gridSize=32):
    """Look up, per point, the offset to the precomputed closest surface point.

    Each point is snapped to its (clamped, rounded) voxel cell; the closest
    surface point stored for that cell is gathered from ``closestPoints``.
    Points that fall into a cell whose ``volume`` value is 1 get a zero
    result, because the mask ``1 - volume[cell]`` is applied to both outputs.

    Args:
        points: ``(batch_size, N, 3)`` points in point-cloud coordinates.
        closestPoints: ``(batch_size, gridSize, gridSize, gridSize, 3)``.
        volume: ``(batch_size, gridSize, gridSize, gridSize)`` occupancy grid.
        device: Device on which the index weights are created and to which
            the gathered closest points are moved.
        gridSize: Resolution of the voxel grid along each axis.

    Returns:
        ``((points - cps) * mask, cps * mask)``, both ``(batch_size, N, 3)``.
    """
    batch = closestPoints.shape[0]

    # Forward gridSize so the coordinate conversion matches the grid that is
    # being indexed (previously the converter always assumed 32).
    voxel_coords = point2voxel(points, gridSize=gridSize)
    voxel_coords = torch.round(torch.clamp(voxel_coords, min=0, max=gridSize - 1))

    # Flatten (x, y, z) cell coordinates into a row-major linear index.
    strides = torch.tensor([gridSize ** 2, gridSize, 1],
                           dtype=voxel_coords.dtype, device=device)
    inds = torch.matmul(voxel_coords, strides).long()  # (batch_size, N)

    # (batch_size, gridSize**3)
    v = volume.reshape(-1, gridSize ** 3)
    # (batch_size, N, 1): 0 where the point's cell is occupied, else 1
    mask = (1 - torch.gather(v, 1, inds)).unsqueeze(2)

    gather_inds = inds.unsqueeze(2).repeat(1, 1, 3)  # (batch_size, N, 3)
    cps = closestPoints.reshape(batch, -1, 3)        # (batch_size, gridSize**3, 3)
    cps = torch.gather(cps, 1, gather_inds).to(device)  # (batch_size, N, 3)

    return (points - cps) * mask, cps * mask

"""计算对称性损失"""
def sym_loss(planes, quats, closestPoints, surfaceSamples, volume, device):
    """Symmetry distance loss for predicted planes and quaternions.

    For every predicted plane the surface samples are mirrored, and for every
    quaternion they are rotated; the loss is the per-sample sum of distances
    to the looked-up closest points, averaged over the batch and then over
    the predictions.

    Args:
        planes: ``(output_size, batch_size, 4)``.
        quats: ``(output_size, batch_size, 4)``.
        closestPoints: ``(batch_size, 32, 32, 32, 3)``.
        surfaceSamples: ``(batch_size, N, 3)``.
        volume: ``(batch_size, 32, 32, 32)``.
        device: Forwarded to ``calculate_distance``.

    Returns:
        ``(plane_loss, quat_loss)``.
    """
    def _mean_total_distance(candidates):
        # candidates: (batch_size, N, 3) transformed sample points
        offsets, _ = calculate_distance(candidates, closestPoints, volume, device)
        per_sample = torch.sum(torch.norm(offsets, dim=2), dim=1)
        return torch.mean(per_sample)

    plane_total = 0
    quat_total = 0
    for idx in range(planes.size(0)):
        plane = planes[idx].unsqueeze(1).float()  # (batch_size, 1, 4)
        quat = quats[idx].unsqueeze(1).float()    # (batch_size, 1, 4)

        plane_total += _mean_total_distance(plane_point(surfaceSamples, plane))
        quat_total += _mean_total_distance(rotate_point(surfaceSamples, quat))

    return plane_total / planes.size(0), quat_total / quats.size(0)

"""计算正则化损失""" 
def reg_loss(planes, quats, device):
    """Orthogonality regulariser for plane normals and rotation axes.

    For each sample the ``output_size`` direction vectors are normalised and
    stacked as rows of ``M``; the penalty is the squared Frobenius norm of
    ``M @ M.T - I``, averaged over the batch.  Plane directions are
    ``planes[..., :3]``; quaternion directions are the vector part
    ``quats[..., 1:4]``.

    Args:
        planes: ``(output_size, batch_size, 4)``.
        quats: ``(output_size, batch_size, 4)``.
        device: Kept for backward compatibility; the identity matrix is
            created on the device of the inputs.

    Returns:
        ``(loss_planes_reg, loss_quats_reg)`` as scalar tensors.
    """
    def _orthogonality_penalty(directions):
        # (output_size, batch_size, 3) -> (batch_size, output_size, 3)
        rows = directions.permute(1, 0, 2)
        # normalise every direction vector; epsilon guards zero vectors
        rows = rows.div(torch.norm(rows, dim=2, keepdim=True) + 1e-8)
        gram = torch.matmul(rows, rows.permute(0, 2, 1))  # (batch, n, n)
        # Size the identity from the Gram matrix instead of a literal 3 so
        # any output_size works.
        eye = torch.eye(gram.size(-1), dtype=gram.dtype, device=gram.device).unsqueeze(0)
        return (gram - eye).pow(2).sum(2).sum(1).mean()

    loss_planes_reg = _orthogonality_penalty(planes[..., :3])
    loss_quats_reg = _orthogonality_penalty(quats[..., 1:4])

    return loss_planes_reg, loss_quats_reg

# 数据中的网格点实际上是以0，0为中心，每两个网格点跨度为gridBound/gridSize的网格点
# 这是为了和点云中的坐标表示对齐
# 但是在计算最近点的时候，需要将坐标转换回32*32*32，用于查询closestPoints数组
# 这段代码参考了prsnet源码

# M1 = (plane1[:3], plane2.., plane3..) = (batch_size, 3, 3)
# M2 = (quat_dir1, quat_dir2, quat_dir3) = (batch_size, 3, 3)
# quat_dir1为四元数对应到的归一化的旋转向量

## 可视化结果

In [37]:
import matplotlib.pyplot as plt
import numpy as np
from scipy.spatial.transform import Rotation as R
import matplotlib.animation as animation
import trimesh
import cv2
'''用于可视化结果'''
def view_result(save_path, obj_path=None, volume=None, points=None, sym_points=None, planes=None, quats=None, gridSize=32):
    """Render a rotating 3D view of a shape and its predicted symmetries.

    Everything is drawn in voxel coordinates (axes limited to
    ``0..gridSize``) and written to ``save_path`` as a video via ffmpeg.

    Args:
        save_path: Output video path.
        obj_path: Optional directory containing ``model_normalized.obj``.
        volume: Optional 3D occupancy grid (tensor or array) drawn as voxels.
        points: Optional ``(N, 3)`` tensor of sample points (green).
        sym_points: Optional ``(N, 3)`` tensor of transformed points (red).
        planes: Optional iterable of per-plane tensors; the first row of each
            entry is read as ``(a, b, c, d)``.
        quats: Optional iterable of per-axis tensors of scalar-first
            ``(w, x, y, z)`` quaternions; the first row of each is drawn.
        gridSize: Voxel grid resolution.
    """
    fig = plt.figure()
    ax = fig.add_subplot(111, projection='3d')

    # original mesh
    if obj_path is not None:
        mesh_file = os.path.join(obj_path, "model_normalized.obj")
        scene = trimesh.load(mesh_file)
        for _, mesh in scene.geometry.items():
            vertices = point2voxel(mesh.vertices, gridSize=gridSize)
            faces = mesh.faces
            ax.plot_trisurf(vertices[:, 0], vertices[:, 1], vertices[:, 2], linewidth=0.2, triangles=faces, alpha=0.7, edgecolor='gray')

    # voxels
    if volume is not None:
        # Hand matplotlib a host-side boolean array; a (possibly CUDA) tensor
        # would otherwise be indexed element by element.
        if torch.is_tensor(volume):
            volume = volume.detach().cpu().numpy()
        ax.voxels(np.asarray(volume) != 0, facecolors='cyan', edgecolors='k')

    # sample points
    if points is not None:
        points = point2voxel(points.detach().cpu(), gridSize=gridSize)
        ax.scatter(points[:, 0], points[:, 1], points[:, 2], c='g')
    if sym_points is not None:
        sym_points = point2voxel(sym_points.detach().cpu(), gridSize=gridSize)
        ax.scatter(sym_points[:, 0], sym_points[:, 1], sym_points[:, 2], c='r')

    # symmetry planes
    if planes is not None:
        plane_colors = ('lightblue', 'lightgreen', 'lightpink')
        half = int(gridSize / 2)
        for j, planej in enumerate(planes):
            planej = planej.detach().cpu().numpy()[0]

            # Solve the plane equation for the axis with the largest normal
            # component so the division below is as well conditioned as
            # possible; the other two axes span the sampling grid.
            indices = np.argsort(np.abs(planej[:3]), axis=-1)

            xx, yy = np.meshgrid(range(int(gridSize/2-10), int(gridSize/2+10)), range(half-10, half+10))
            xx1 = voxel2point(xx, gridSize=gridSize)
            yy1 = voxel2point(yy, gridSize=gridSize)
            zz = (-planej[indices[0]] * xx1 - planej[indices[1]] * yy1 - planej[3]) * 1. / planej[indices[2]]
            zz = point2voxel(zz, gridSize=gridSize)
            xyz = np.empty(shape=(3,)+xx.shape)
            xyz[indices[0]] = xx
            xyz[indices[1]] = yy
            xyz[indices[2]] = zz

            ax.plot_surface(xyz[0], xyz[1], xyz[2], alpha=0.7, color=plane_colors[j % 3])

    # rotation axes
    if quats is not None:
        center = np.array([gridSize/2, gridSize/2, gridSize/2])
        for quatj in quats:
            quat_wxyz = quatj.detach().cpu().numpy()
            # The quaternions in this file are scalar-first (w, x, y, z),
            # while SciPy's Rotation.from_quat expects scalar-last
            # (x, y, z, w): reorder before converting.
            quat_xyzw = quat_wxyz[..., [1, 2, 3, 0]]
            rotation = R.from_quat(quat_xyzw)
            axis = rotation.as_rotvec()[0]  # rotation vector of the first row
            ax.quiver(center[0], center[1], center[2], axis[0], axis[1], axis[2], color='m', length=np.linalg.norm(axis)*2)

    ax.set_xlim(0, gridSize)
    ax.set_ylim(0, gridSize)
    ax.set_zlim(0, gridSize)

    # spin the camera and export the frames as a video
    def update(num):
        ax.view_init(elev=num, azim=num)

    ani = animation.FuncAnimation(fig, update, frames=np.arange(0, 360, 10), interval=100)
    ani.save(save_path, writer='ffmpeg', fps=10)
    plt.close(fig)

# points = torch.tensor([[0.2,0.2,0.2], [0.1,0.1,0.1], [0.4,0.4,0.4]])
# plane = torch.tensor([0,0,0,0.05])
# sym_points = torch.tensor([[-0.2,-0.2,-0.2], [-0.1,-0.1,-0.1], [-0.4,-0.4,-0.4]])
# generate_result(points, sym_points, plane, name='test/images/1')

In [23]:
# 测试画图功能
import matplotlib.pyplot as plt
from tqdm import tqdm
# 创建3D图像
def plot_model(obj_path, gridSize=32):
    """Render ``model_normalized.obj`` from ``obj_path`` to a PNG next to it.

    The mesh vertices are converted to voxel coordinates, drawn as a
    triangulated surface with axes limited to ``0..gridSize`` and saved as
    ``model_normalized.png`` in the same directory.
    """
    figure = plt.figure()
    axes = figure.add_subplot(111, projection='3d')

    mesh_file = os.path.join(obj_path, "model_normalized.obj")
    for mesh in trimesh.load(mesh_file).geometry.values():
        verts = point2voxel(mesh.vertices)
        axes.plot_trisurf(verts[:, 0], verts[:, 1], verts[:, 2], linewidth=0.2, triangles=mesh.faces, alpha=0.7, edgecolor='gray')

    # same visible range on all three axes
    for set_limit in (axes.set_xlim, axes.set_ylim, axes.set_zlim):
        set_limit(0, gridSize)

    axes.view_init(elev=30, azim=30)

    image_file = os.path.join(obj_path, "model_normalized.png")
    plt.savefig(image_file)
    plt.close(figure)

# Smoke test for plot_model: walk the model directories in sorted order and
# render only the first one.
base_dir = "preprocess/shapenet/02691156/"
for dir_name in tqdm(sorted(os.listdir(base_dir))):
    print(dir_name)
    obj_path = os.path.join(base_dir, dir_name, "models")
    plot_model(obj_path)
    break  # only the first entry is processed

  0%|          | 0/492 [00:00<?, ?it/s]

1a04e3eab45ca15dd86060f189eb133


  0%|          | 0/492 [00:00<?, ?it/s]


In [24]:
test_path = 'test/test_models_2aec'
data = loadmat_dir(test_path)
# Check the loss building blocks: closest-point lookup and plane reflection
# print(data.keys())
# print(data['Volumes'].shape)
# print(data['surfaceSamples'].shape)
# print(data['closestPoints'].shape)
points = data['surfaceSamples']
closestPoints = data['closestPoints']
volumes = data['Volumes']
# hand-picked plane (a, b, c, d) = (0, 0, 0.25, 0.05), shaped (1, 1, 4)
plane = torch.tensor([[[0,0,0.25,0.05]]])
sym_points = plane_point(points, plane)      
print(points.shape, closestPoints.shape, volumes.shape)
distance, cps = calculate_distance(sym_points, closestPoints, volumes, 'cpu')
# total distance of the mirrored samples to their looked-up closest points
print(torch.sum(torch.norm(distance, dim=2)))
# print(points[0, :5])
# print(sym_points[0, :5])
print(closestPoints[0, 0, 0, :5])
# flatten the closest-point grid to (batch, cells, 3); this overwrites the
# cps returned by calculate_distance above
cps = closestPoints.reshape(closestPoints.shape[0], -1, 3)

save_path = 'test/images/testsym.mp4'
obj_path = 'test/test_models_2aec'
# view_result(save_path, None, volumes[0], None, cps[0])

torch.Size([1, 1000, 3]) torch.Size([1, 32, 32, 32, 3]) torch.Size([1, 32, 32, 32])
tensor(343.3512)
tensor([[-0.0696, -0.2409,  0.0015],
        [-0.0696, -0.2409,  0.0015],
        [-0.0696, -0.2409,  0.0015],
        [-0.0696, -0.2409,  0.0015],
        [-0.0696, -0.2409,  0.0015]])


In [25]:
class symLoss(nn.Module):
    """``nn.Module`` wrapper around ``sym_loss`` that remembers the device."""

    def __init__(self, device):
        super().__init__()
        self.device = device

    def forward(self, planes, quats, closestPoints, surfaceSamples, volume):
        """Return ``sym_loss(...)`` evaluated with the stored device."""
        losses = sym_loss(planes, quats, closestPoints, surfaceSamples, volume, self.device)
        return losses

class regLoss(nn.Module):
    """``nn.Module`` wrapper around ``reg_loss`` that remembers the device."""

    def __init__(self, device):
        super().__init__()
        self.device = device

    def forward(self, planes, quats):
        """Return ``reg_loss(planes, quats, device)`` with the stored device."""
        losses = reg_loss(planes, quats, self.device)
        return losses

4. 设计训练函数

In [26]:
'''训练统计模块，参考自动手学深度学习库d2l'''
# from d2l import torch as d2l
import matplotlib.pyplot as plt
import time

class Accumulator:
    """Keep ``n`` running float totals that are updated together."""

    def __init__(self, n):
        """Start with ``n`` totals, all zero."""
        self.data = [0.0 for _ in range(n)]

    def add(self, *args):
        """Add ``float(arg)`` to the total in the same position."""
        updated = []
        for total, increment in zip(self.data, args):
            updated.append(total + float(increment))
        self.data = updated

    def reset(self):
        """Set every total back to zero."""
        self.data = [0.0 for _ in self.data]

    def __getitem__(self, idx):
        """Return the total stored at position ``idx``."""
        return self.data[idx]

class Timer:
    """Stopwatch that records the duration of every start/stop interval."""

    def __init__(self):
        """Create an empty record and start timing immediately."""
        self.times = []
        self.start()

    def start(self):
        """Remember the current wall-clock time as the interval start."""
        self.tik = time.time()

    def stop(self):
        """Record the time elapsed since ``start`` and return it."""
        elapsed = time.time() - self.tik
        self.times.append(elapsed)
        return elapsed

    def avg(self):
        """Return the mean of all recorded intervals."""
        total = sum(self.times)
        return total / len(self.times)

    def sum(self):
        """Return the total of all recorded intervals."""
        return sum(self.times)

    def cumsum(self):
        """Return the running totals of the recorded intervals as a list."""
        return np.cumsum(self.times).tolist()

In [27]:
"""训练模型一个迭代周期（定义见第3章）。"""
def train_epoch(net, train_iter, sym_loss, reg_loss, updater, device, weight=25):
    """Run one optimisation pass over ``train_iter``.

    For every batch the total loss is
    ``plane_sym + quat_sym + weight * (plane_reg + quat_reg)``.

    Args:
        net: Model called as ``net(Volume)`` and returning ``(planes, quats)``.
        train_iter: Iterable of ``(Volume, surfaceSamples, closestPoints)``.
        sym_loss: Callable ``(planes, quats, closestPoints, surfaceSamples,
            Volume) -> (plane_loss, quat_loss)``.
        reg_loss: Callable ``(planes, quats) -> (plane_loss, quat_loss)``.
        updater: Optimiser providing ``zero_grad()`` and ``step()``.
        device: Device the batch tensors are moved to.
        weight: Factor applied to the regularisation losses.

    Returns:
        ``(plane_sym, plane_reg, quat_sym, quat_reg, samples_per_second)``
        where the four losses are averages weighted by batch size.

    Raises:
        ValueError: If ``train_iter`` yields no batch.
    """
    started = time.time()
    # running sums in the order: plane_sym, plane_reg, quat_sym, quat_reg
    totals = [0.0, 0.0, 0.0, 0.0]
    num_samples = 0

    for Volume, surfaceSamples, closestPoints in train_iter:
        Volume = Volume.to(device)
        surfaceSamples = surfaceSamples.to(device)
        closestPoints = closestPoints.to(device)
        batch_size = Volume.shape[0]

        updater.zero_grad()
        planes, quats = net(Volume)
        loss_plane_sym, loss_quat_sym = sym_loss(planes, quats, closestPoints, surfaceSamples, Volume)
        loss_plane_reg, loss_quat_reg = reg_loss(planes, quats)
        loss = loss_plane_sym + loss_quat_sym + (loss_plane_reg + loss_quat_reg) * weight
        loss.backward()
        updater.step()

        batch_losses = (loss_plane_sym, loss_plane_reg, loss_quat_sym, loss_quat_reg)
        for k, value in enumerate(batch_losses):
            # weight by batch size so the epoch value is a per-sample mean
            totals[k] += float(value) * batch_size
        num_samples += batch_size

    if num_samples == 0:
        raise ValueError("train_iter yielded no batches; cannot average losses")

    # guard against a zero reading from a coarse clock
    elapsed = max(time.time() - started, 1e-12)
    return (totals[0] / num_samples,
            totals[1] / num_samples,
            totals[2] / num_samples,
            totals[3] / num_samples,
            num_samples / elapsed)

In [28]:
from torch.utils import tensorboard
from tqdm import tqdm
# 将matplotlib中画出来的图像转成tensor，以便在tensorboard中进行显示
def get_tensor_from_video(video_path):
    """Read a video file into a uint8 tensor laid out for TensorBoard.

    Args:
        video_path: Path of the video file to decode with OpenCV.

    Returns:
        ``torch.uint8`` tensor of shape ``(1, T, C, H, W)`` in RGB order, the
        ``(N, T, C, H, W)`` layout expected by ``SummaryWriter.add_video``.

    Raises:
        ValueError: If no frame could be read from ``video_path``.
    """
    cap = cv2.VideoCapture(video_path)
    frames_list = []
    try:
        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                break
            # OpenCV decodes to BGR; convert to RGB
            frames_list.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
    finally:
        # release the capture even if decoding raised
        cap.release()

    if not frames_list:
        raise ValueError(f"no frames could be read from {video_path!r}")

    # stacked frames are (T, H, W, C): OpenCV images are rows x cols x channels
    result_frames = torch.as_tensor(np.stack(frames_list), dtype=torch.uint8)
    # (T, H, W, C) -> (T, C, H, W) -> (1, T, C, H, W).  Moving only the channel
    # axis keeps height and width in place; swapping them would transpose
    # every frame.
    result_frames = result_frames.permute(0, 3, 1, 2).unsqueeze(0)
    return result_frames

'''训练函数'''
def train(net, train_iter, lr, num_epochs, device, weight = 25, test_file_path='test/test_models_2aec'):
    """Train ``net`` with Adam and log progress to TensorBoard under ./logs.

    Conv3d layers are re-initialised with Kaiming-normal weights and
    ``nn.Linear`` layers with Xavier-uniform weights before training.  Every
    40 epochs the prediction for the first sample found in ``test_file_path``
    is rendered to ``test/images/train_<epoch>.mp4`` and added to TensorBoard.

    Args:
        net: Model to train (moved to ``device``).
        train_iter: Re-iterable source of training batches.
        lr: Adam learning rate.
        num_epochs: Number of passes over ``train_iter``.
        device: Device used for the model and the batches.
        weight: Factor applied to the regularisation losses.
        test_file_path: Directory with ``.mat`` files used for the snapshots.
    """
    def init_weights(layer):
        if isinstance(layer, nn.Conv3d):
            # Kaiming normal initialisation
            nn.init.kaiming_normal_(layer.weight, nonlinearity='relu')
            if layer.bias is not None:
                nn.init.zeros_(layer.bias)
        elif isinstance(layer, nn.Linear):
            # Xavier uniform initialisation
            nn.init.xavier_uniform_(layer.weight)
            if layer.bias is not None:
                nn.init.zeros_(layer.bias)

    net.to(device)
    net.apply(init_weights)

    sym_criterion = symLoss(device)
    reg_criterion = regLoss(device)

    # Only assigned once an epoch has finished; lets the summary below be
    # skipped when num_epochs is 0.
    speed = None

    writer = tensorboard.SummaryWriter(log_dir="./logs", filename_suffix='prsnet')
    try:
        test_input = torch.rand((32, 32, 32, 32)).to(device)
        writer.add_graph(net, test_input)

        updater = torch.optim.Adam(net.parameters(), lr)

        test_data = loadmat_dir(test_file_path)
        # Slice instead of torch.tensor(tensor): keeps the batch axis and
        # avoids re-wrapping an existing tensor.
        test_Volume = test_data['Volumes'][0:1].to(torch.float32).to(device)

        # the snapshot videos are written here; make sure the folder exists
        os.makedirs("test/images", exist_ok=True)

        for epoch in tqdm(range(num_epochs)):
            loss_plane_sym, loss_plane_reg, loss_quat_sym, loss_quat_reg, speed = train_epoch(
                net, train_iter, sym_criterion, reg_criterion, updater, device, weight)
            writer.add_scalars("sym_loss", {'plane_loss': loss_plane_sym}, epoch)
            writer.add_scalars("sym_loss", {'quat_loss': loss_quat_sym}, epoch)
            writer.add_scalars("reg_loss", {'plane_loss': loss_plane_reg}, epoch)
            writer.add_scalars("reg_loss", {'quat_loss': loss_quat_reg}, epoch)

            if epoch % 40 == 0:
                with torch.no_grad():
                    planes, quats = net(test_Volume)
                save_path = "test/images/train_"+str(epoch)+'.mp4'
                view_result(save_path, None, test_Volume[0], None, None, planes, quats)
                video = get_tensor_from_video(save_path)
                writer.add_video("Animation", video, epoch)
    finally:
        # flush pending events and release the log file handle
        writer.close()

    if speed is None:
        return

    print(f'loss_plane_sym: {loss_plane_sym:.1f},\n\
          loss_plane_reg: {loss_plane_reg:.1f}, \n\
          loss_quat_sym: {loss_quat_sym:.1f}, \n\
          loss_quat_reg: {loss_quat_reg:.1f}, \n\
          {speed:.1f} 模型/秒 {str(device)}')

5. 进行训练

In [29]:
# Build the network, pick the device and load the training set
file_path = 'datasets/shapenet/train'
net = PRS_NET()
# use the third CUDA device when CUDA is available, otherwise the CPU
device = torch.device('cuda:2' if torch.cuda.is_available() else 'cpu')
print("Training on: ", device)
print("Loading data")
batch_size = 32
# loads every .mat file below file_path at construction time
data_loader = SeqDataLoader(file_path, batch_size=batch_size, shuffle=True)

Training on:  cuda:2
Loading data


In [32]:
lr, num_epochs = 0.001, 400
weight = 25
# torch.save does not create missing parent directories; create the target
# folder before the (long) training run so the final save cannot fail on it.
os.makedirs('./logs/model', exist_ok=True)
train(net, data_loader, lr, num_epochs, device, weight)
torch.save(net.state_dict(), './logs/model/model_weights_400.pth')

In [33]:
!tensorboard --logdir ./logs --port=6006 

6. 设计推理函数，给定一张图片，推理出需要的景象结果。

7. 进行测试，输入一个mat文件，输出对应的体素3D视图和对称平面

In [94]:
# 清除掉太过于重合或者置信度太低的平面和旋转向量
def validation(planes, quats, closestPoints, surfaceSamples, volume, device, eps=4e-4):
    """Filter out unreliable or near-duplicate planes and rotation axes.

    A prediction is dropped when its symmetry loss exceeds ``eps``.  Of two
    remaining predictions whose directions are less than 30 degrees apart
    (``|cos| > sqrt(3) / 2``), the one with the larger loss is dropped.

    Args:
        planes: ``(output_size, batch_size, 4)`` plane coefficients.
        quats: ``(output_size, batch_size, 4)`` scalar-first quaternions.
        closestPoints, surfaceSamples, volume, device: Forwarded to
            ``sym_loss``.
        eps: Maximum accepted symmetry loss.

    Returns:
        ``(valid_planes, valid_quats)``: the rows of ``planes`` / ``quats``
        that survived both filters.

    Note:
        Directions are compared after flattening each prediction, which is
        the per-sample angle only for batch size 1 (as used by ``predict``).
    """
    valid_plane = torch.ones(planes.shape[0], dtype=torch.bool).to(device)
    valid_quat = torch.ones(quats.shape[0], dtype=torch.bool).to(device)
    plane_losses, quat_losses = [], []

    # drop predictions whose symmetry loss is too large
    for i in range(planes.shape[0]):
        plane_loss, quat_loss = sym_loss(planes[i].unsqueeze(0),
                                         quats[i].unsqueeze(0), closestPoints, surfaceSamples, volume, device)
        if plane_loss > eps:
            valid_plane[i] = 0
        if quat_loss > eps:
            valid_quat[i] = 0

        plane_losses.append(plane_loss)
        # record the quaternion loss (not the plane loss) for the axes
        quat_losses.append(quat_loss)

    print("Plane_losses:", plane_losses)
    print("quat_losses:", quat_losses)

    # cos(30 degrees); a plain float, torch.sqrt only accepts tensors
    cos_threshold = 3 ** 0.5 / 2

    def test_angle(vec1, vec2):
        """Return 1 if the two directions are closer than 30 degrees."""
        # torch.dot needs 1-D inputs
        vec1 = vec1.reshape(-1)
        vec2 = vec2.reshape(-1)
        # cosine = a.b / (|a| |b|); epsilon guards zero vectors
        cosine = torch.dot(vec1, vec2) / (torch.norm(vec1) * torch.norm(vec2) + 1e-8)
        if torch.abs(cosine) > cos_threshold:
            return 1
        return 0

    def remove_overlap_vec(vecs, valid_vec, vec_losses):
        """Invalidate the worse of every still-valid, overlapping pair."""
        for i in range(vecs.shape[0]):
            for j in range(0, i):
                if valid_vec[i] == 0 or valid_vec[j] == 0:
                    continue
                if test_angle(vecs[i], vecs[j]) == 1:
                    if vec_losses[i] > vec_losses[j]:
                        valid_vec[i] = 0
                    else:
                        valid_vec[j] = 0

    remove_overlap_vec(planes[..., :3], valid_plane, plane_losses)
    # the rotation axis is the vector part (x, y, z) of a scalar-first
    # quaternion, i.e. components 1..3
    remove_overlap_vec(quats[..., 1:4], valid_quat, quat_losses)

    valid_planes = planes[valid_plane]
    valid_quats = quats[valid_quat]

    return valid_planes, valid_quats

# 输入Volume，输出预测出来的对称平面视频到save_path
def predict(net, obj_path, save_path, Volume, closestPoints, surfaceSamples, device):
    """Predict symmetry planes/quaternions and render the planes to a video.

    Args:
        net: Trained network.
        obj_path: Path of the original mesh; currently not used.
        save_path: Where the rendered video is written.
        Volume: Voxel grid of the object, ``(1, 32, 32, 32)``.
        closestPoints: Closest surface points, ``(1, 32, 32, 32, 3)``;
            accepted for interface compatibility, not used for rendering.
        surfaceSamples: Surface samples, ``(1, 1000, 3)``; accepted for
            interface compatibility, not used for rendering.
        device: Device on which inference runs.

    Returns:
        ``(planes, quats)`` exactly as produced by ``net``.
    """
    Volume = Volume.to(torch.float32).to(device)
    with torch.no_grad():
        planes, quats = net(Volume)
    # The predictions are rendered unfiltered: with the eps value from the
    # paper the measured losses are so large that validation() would discard
    # every plane and axis.  The previous debug-only reflection / distance
    # computation was removed because its result was never used.
    view_result(save_path, None, Volume[0], None, None, planes, None)
    return planes, quats

In [74]:
device = torch.device('cuda:1' if torch.cuda.is_available() else 'cpu')
net = PRS_NET()
net = net.to(device)
# map_location remaps the stored tensors onto the device chosen above; without
# it torch.load tries to restore them onto the device they were saved from,
# which fails when that device is not available on this machine.
net.load_state_dict(torch.load('logs/model/model_weights_400.pth', map_location=device))

<All keys matched successfully>

In [96]:
# Run inference on a slice of the test set and write one video per sample
obj_path = 'datasets/shapenet/test'
data = loadmat_dir(obj_path)
# samples 6..24; each is passed as a batch of one via i:i+1 slicing
for i in range(6, 25):
    save_path = "test/images/predict" + str(i) + ".mp4"
    predict(net, None, save_path, data['Volumes'][i:i+1], data['closestPoints'][i:i+1], data['surfaceSamples'][i:i+1], device)
print('done')

Num: tensor(645, device='cuda:1')
tensor(-2.9934, device='cuda:1')
Num: tensor(731, device='cuda:1')
tensor(-0.3752, device='cuda:1')
Num: tensor(462, device='cuda:1')
tensor(14.8257, device='cuda:1')
Num: tensor(542, device='cuda:1')
tensor(-12.1591, device='cuda:1')
Num: tensor(385, device='cuda:1')
tensor(5.2771, device='cuda:1')
Num: tensor(489, device='cuda:1')
tensor(-7.9840, device='cuda:1')
Num: tensor(679, device='cuda:1')
tensor(-10.0823, device='cuda:1')
Num: tensor(245, device='cuda:1')
tensor(0.8730, device='cuda:1')
Num: tensor(446, device='cuda:1')
tensor(-4.2351, device='cuda:1')
Num: tensor(650, device='cuda:1')
tensor(13.8890, device='cuda:1')
Num: tensor(713, device='cuda:1')
tensor(-9.9370, device='cuda:1')
Num: tensor(509, device='cuda:1')
tensor(5.0329, device='cuda:1')
Num: tensor(416, device='cuda:1')
tensor(-7.3138, device='cuda:1')
Num: tensor(621, device='cuda:1')
tensor(18.7652, device='cuda:1')
Num: tensor(372, device='cuda:1')
tensor(11.7815, device='cuda: