In [None]:
!pip install cuda-python
!pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118
!pip install matplotlib scipy

In [None]:
# setup:
import torch, torch.nn as nn, torch.nn.functional as F, torch.optim as optim
# Pick the compute device: the first CUDA GPU when one is visible, otherwise the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if device.type == "cuda":
    print(f"Charging on {torch.cuda.get_device_name(0)}")
else:
    print("Marching on the CPU")

Charging on NVIDIA GeForce RTX 3060 Ti


In [None]:
import os, random, numpy as np, open3d as o3d
import pickle
from glob import glob
from typing import Tuple, Dict, Any
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
from torch.utils.data import DataLoader, Dataset
import torch, torch.nn as nn, torch.nn.functional as F, torch.optim as optim
import copy
from scipy.spatial import KDTree
from itertools import product

Jupyter environment detected. Enabling Open3D WebVisualizer.
[Open3D INFO] WebRTC GUI backend enabled.
[Open3D INFO] WebRTCWindowSystem: HTTP handshake server disabled.


In [None]:
from m02_code import CustomDataset,infer_point_clouds

In [None]:
def _prepare_batch(batch, device):
    """Unpack an ``(inputs, labels, index)`` batch into float inputs and long labels on ``device``."""
    inputs, labels, _index = batch
    return inputs.float().to(device), labels.long().to(device)


def train_and_save_model(device, train_loader, val_loader, num_classes, model, save_path="model.pth", epochs=100,
                         lr=0.001, ignore_index=-1):
    """Train ``model`` with Adam + cross-entropy, report per-epoch validation accuracy, then save its weights.

    Args:
        device: torch device that batches (and the model) are moved to.
        train_loader: iterable of ``(inputs, labels, index)`` training batches.
        val_loader: iterable of ``(inputs, labels, index)`` validation batches.
        num_classes: number of classes; model output is flattened to ``(-1, num_classes)``.
        model: module whose output for a batch can be reshaped to ``(-1, num_classes)``
            in the same point order as ``labels.reshape(-1)``.
        save_path: file the final ``state_dict`` is written to.
        epochs: number of passes over ``train_loader``.
        lr: Adam learning rate.
        ignore_index: label value excluded from both the loss and the accuracy.

    Prints one line per epoch with the mean training loss over the epoch's batches and
    the validation accuracy over points whose label is not ``ignore_index``
    (``nan`` when a loader yields nothing to average over).
    """
    model.to(device)
    criterion = nn.CrossEntropyLoss(ignore_index=ignore_index)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    for epoch in range(epochs):
        model.train()
        running_loss, num_batches = 0.0, 0
        for batch in train_loader:
            p, labels = _prepare_batch(batch, device)
            # reshape (not view): the model output may be a non-contiguous permutation.
            seg_pred = model(p).reshape(-1, num_classes)
            loss = criterion(seg_pred, labels.reshape(-1))
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
            num_batches += 1
        mean_loss = running_loss / num_batches if num_batches else float("nan")

        model.eval()
        total_correct = 0
        total_points = 0
        with torch.no_grad():
            for batch in val_loader:
                p, labels = _prepare_batch(batch, device)
                predicted = model(p).reshape(-1, num_classes).argmax(dim=1)
                labels = labels.reshape(-1)
                # Score only the points the loss is trained on; ignored labels are not "wrong".
                valid = labels != ignore_index
                total_correct += (predicted[valid] == labels[valid]).sum().item()
                total_points += valid.sum().item()
        accuracy = 100 * total_correct / total_points if total_points else float("nan")
        print(f"Epoch [{epoch+1}/{epochs}], Loss {mean_loss:.4f}, Accuracy: {accuracy:.2f}%")

    torch.save(model.state_dict(), save_path)

In [None]:
def square_distance(src, dst):
    """Pairwise squared Euclidean distances between two batched point sets.

    Args:
        src: tensor of shape (B, N, C).
        dst: tensor of shape (B, M, C).

    Returns:
        Tensor of shape (B, N, M) with ``out[b, n, m] = ||src[b, n] - dst[b, m]||^2``.

    Uses ||a - b||^2 = ||a||^2 + ||b||^2 - 2 a.b, which is cheap but can yield small
    negative values through floating-point cancellation (e.g. for coincident points).
    Those are clamped to 0 so callers that take reciprocals never see a negative distance.
    """
    B, N, _ = src.shape
    _, M, _ = dst.shape
    dist = -2 * torch.matmul(src, dst.permute(0, 2, 1))
    dist += torch.sum(src ** 2, -1).view(B, N, 1)
    dist += torch.sum(dst ** 2, -1).view(B, 1, M)
    # Out-of-place clamp keeps autograd happy if the inputs ever require grad.
    return dist.clamp(min=0)

In [None]:
def index_points(points,idx):
    """Gather per-batch rows of ``points`` selected by ``idx``.

    Args:
        points: tensor of shape (B, N, C).
        idx: long tensor of shape (B, *S) holding indices into the N axis.

    Returns:
        Tensor of shape (B, *S, C) with ``out[b, s..., :] = points[b, idx[b, s...], :]``.
    """
    batch_size = points.shape[0]
    # Shape (idx.shape[0], 1, 1, ...) so the batch ids line up with every index dimension.
    broadcast_shape = [idx.shape[0]] + [1] * (idx.dim() - 1)
    batch_ids = (
        torch.arange(batch_size, dtype=torch.long)
        .to(points.device)
        .view(broadcast_shape)
        .expand_as(idx)
    )
    return points[batch_ids, idx, :]

In [None]:
def farthest_point_sample(xyz, npoint):
    """Iterative farthest point sampling.

    Args:
        xyz: tensor of shape (B, N, C) with point coordinates (any C >= 1).
        npoint: number of points to select per batch element.

    Returns:
        LongTensor of shape (B, npoint) with indices into the N axis. The first index of
        each batch element is drawn uniformly at random from torch's RNG; each following
        index is the point with the largest squared distance to its nearest already
        selected point.
    """
    device = xyz.device
    B, N, C = xyz.shape
    centroids = torch.zeros(B, npoint, dtype=torch.long, device=device)
    # Squared distance from every point to its nearest selected centroid so far.
    distance = torch.ones(B, N, device=device) * 1e10
    # Drawn on the CPU generator and then moved, so the random stream is device-independent.
    farthest = torch.randint(0, N, (B,), dtype=torch.long).to(device)
    batch_indices = torch.arange(B, dtype=torch.long, device=device)
    for i in range(npoint):
        centroids[:, i] = farthest
        # view(B, 1, C) rather than a hard-coded 3 so non-3-D coordinates work too.
        centroid = xyz[batch_indices, farthest, :].view(B, 1, C)
        dist = torch.sum((xyz - centroid) ** 2, -1)
        distance = torch.minimum(distance, dist)
        farthest = distance.max(dim=-1).indices
    return centroids

In [None]:
def query_ball_point(radius,nsample,xyz,new_xyz):
    """Ball query: for each centre, return ``nsample`` point indices lying within ``radius``.

    Args:
        radius: ball radius (compared against squared distances as ``radius ** 2``).
        nsample: number of indices returned per centre (requires N >= nsample).
        xyz: (B, N, C) candidate points.
        new_xyz: (B, S, C) query centres.

    Returns:
        LongTensor (B, S, nsample) of indices into N. Within each ball the
        lowest-numbered in-radius points are kept (not necessarily the nearest ones);
        when a ball holds fewer than ``nsample`` points, the remaining slots repeat the
        first selected index. If no point lies within ``radius`` of a centre, that row
        is filled with N, which is out of range.
    """
    B, N, _ = xyz.shape
    _, S, _ = new_xyz.shape
    out_of_ball = N  # sentinel larger than any valid index
    candidates = torch.arange(N, dtype=torch.long).to(xyz.device).view(1, 1, N).repeat([B, S, 1])
    candidates[square_distance(new_xyz, xyz) > radius ** 2] = out_of_ball
    # Ascending sort pushes the sentinels to the end, so the first nsample columns are in-ball indices.
    picked = torch.sort(candidates, dim=-1).values[:, :, :nsample]
    first_hit = picked[:, :, 0].view(B, S, 1).repeat([1, 1, nsample])
    is_padding = picked == out_of_ball
    picked[is_padding] = first_hit[is_padding]
    return picked

In [None]:
def sample_and_group(npoint,radius,nsample,xyz,points,returnfps=False):
    """Choose ``npoint`` centroids by farthest point sampling and gather a ball neighbourhood around each.

    Args:
        npoint: number of centroids S.
        radius: ball-query radius.
        nsample: neighbours kept per centroid.
        xyz: (B, N, C) point coordinates.
        points: optional (B, N, D) per-point features, or None.
        returnfps: when True, also return the absolute grouped coordinates and FPS indices.

    Returns:
        ``(new_xyz, new_points)`` where ``new_xyz`` is (B, S, C) centroid coordinates and
        ``new_points`` holds centroid-relative neighbour coordinates, followed along the last
        axis by the gathered neighbour features when ``points`` is given. With ``returnfps``,
        additionally the absolute grouped coordinates and the (B, S) FPS indices.
    """
    batch, _, dims = xyz.shape
    centroid_idx = farthest_point_sample(xyz, npoint)
    centroids = index_points(xyz, centroid_idx)
    neighbour_idx = query_ball_point(radius, nsample, xyz, centroids)
    neighbours = index_points(xyz, neighbour_idx)
    # Express every neighbourhood relative to its own centroid.
    local_xyz = neighbours - centroids.view(batch, npoint, 1, dims)
    if points is None:
        grouped = local_xyz
    else:
        grouped = torch.cat([local_xyz, index_points(points, neighbour_idx)], dim=-1)
    if not returnfps:
        return centroids, grouped
    return centroids, grouped, neighbours, centroid_idx

In [None]:
def sample_and_group_all(xyz,points):
    """Treat the whole cloud as one group centred at the origin.

    Args:
        xyz: (B, N, C) point coordinates.
        points: optional (B, N, D) per-point features, or None.

    Returns:
        ``(new_xyz, new_points)``: a (B, 1, C) float32 zero tensor as the single centre, and
        the coordinates reshaped to (B, 1, N, C), with the features concatenated along the
        last axis when ``points`` is given.
    """
    batch, count, dims = xyz.shape
    origin = torch.zeros(batch, 1, dims).to(xyz.device)
    one_group = xyz.view(batch, 1, count, dims)
    if points is None:
        return origin, one_group
    features = points.view(batch, 1, count, -1)
    return origin, torch.cat([one_group, features], dim=-1)

In [None]:
class PointNetSetAbstraction(nn.Module):
    """PointNet++ set-abstraction layer: sample and group points, then a shared MLP and max-pool.

    Args:
        npoint: number of centroids to sample (unused when ``group_all``).
        radius: ball-query radius (unused when ``group_all``).
        nsample: neighbours per centroid (unused when ``group_all``).
        in_channel: channels entering the first 1x1 conv (coordinate dims + feature dims).
        mlp: output channel sizes of the stacked 1x1 Conv2d + BatchNorm2d + ReLU layers.
        group_all: if True, the whole cloud forms one group (a global feature).
    """

    def __init__(self, npoint, radius, nsample, in_channel, mlp, group_all) -> None:
        super(PointNetSetAbstraction, self).__init__()
        self.npoint = npoint
        self.radius = radius
        self.nsample = nsample
        self.mlp_convs = nn.ModuleList()
        self.mlp_bns = nn.ModuleList()
        last_channel = in_channel
        for out_channel in mlp:
            self.mlp_convs.append(nn.Conv2d(last_channel, out_channel, 1))
            self.mlp_bns.append(nn.BatchNorm2d(out_channel))
            last_channel = out_channel
        self.group_all = group_all

    def forward(self, xyz, points):
        """Abstract a point set.

        Args:
            xyz: coordinates laid out as (B, C, N); permuted to (B, N, C) below.
            points: optional features laid out as (B, D, N), or None.

        Returns:
            ``(new_xyz, new_points)``: centroid coordinates as (B, C, S) and pooled features
            as (B, mlp[-1], S), where S is 1 when ``group_all``.
        """
        xyz = xyz.permute(0, 2, 1)
        if points is not None:
            points = points.permute(0, 2, 1)
        if self.group_all:
            # Whole cloud as a single group: this needs the group-all helper, since
            # sample_and_group also requires npoint/radius/nsample.
            new_xyz, new_points = sample_and_group_all(xyz, points)
        else:
            new_xyz, new_points = sample_and_group(self.npoint, self.radius, self.nsample, xyz, points)
        # (B, S, K, C+D) -> (B, C+D, K, S) so the 1x1 Conv2d mixes the channel axis.
        new_points = new_points.permute(0, 3, 2, 1)
        for conv, bn in zip(self.mlp_convs, self.mlp_bns):
            new_points = F.relu(bn(conv(new_points)))
        # Max-pool over the K grouped points of each centroid.
        new_points = torch.max(new_points, 2)[0]
        new_xyz = new_xyz.permute(0, 2, 1)
        return new_xyz, new_points

In [None]:
class PointNetFeaturePropagation(nn.Module):
    """PointNet++ feature-propagation layer.

    Interpolates features from a sparse point set (``xyz2``, ``points2``) onto a denser one
    (``xyz1``) using inverse-squared-distance weights over up to 3 nearest source points.
    It optionally concatenates skip features (``points1``), then applies stacked
    1x1 Conv1d + BatchNorm1d + ReLU layers.

    Args:
        in_channel: channels entering the first conv (D1 + D2, or D2 when ``points1`` is None).
        mlp: output channel sizes of the conv layers.
    """

    def __init__(self, in_channel, mlp) -> None:
        super(PointNetFeaturePropagation, self).__init__()
        self.mlp_convs = nn.ModuleList()
        self.mlp_bns = nn.ModuleList()
        last_channel = in_channel
        for out_channel in mlp:
            self.mlp_convs.append(nn.Conv1d(last_channel, out_channel, 1))
            self.mlp_bns.append(nn.BatchNorm1d(out_channel))
            last_channel = out_channel

    def forward(self, xyz1, xyz2, points1, points2):
        """Propagate ``points2`` from ``xyz2`` onto ``xyz1``.

        Args:
            xyz1: dense target coordinates, (B, C, N).
            xyz2: sparse source coordinates, (B, C, S).
            points1: optional skip features at the targets, (B, D1, N), or None.
            points2: features at the sources, (B, D2, S).

        Returns:
            Tensor of shape (B, mlp[-1], N).
        """
        xyz1 = xyz1.permute(0, 2, 1)
        xyz2 = xyz2.permute(0, 2, 1)
        points2 = points2.permute(0, 2, 1)
        B, N, _ = xyz1.shape
        _, S, _ = xyz2.shape

        if S == 1:
            # A single source point: broadcast its feature to every target point.
            interpolated_points = points2.repeat(1, N, 1)
        else:
            # Up to 3 neighbours; fewer when the source set is smaller (S == 2).
            k = min(3, S)
            # Guard against tiny negative squared distances, which would make 1 / (d + 1e-8)
            # negative or explode.
            dists = square_distance(xyz1, xyz2).clamp(min=0)
            # The k smallest distances in ascending order, without fully sorting all S.
            dists, idx = torch.topk(dists, k, dim=-1, largest=False, sorted=True)
            dist_recip = 1.0 / (dists + 1e-8)
            weight = dist_recip / torch.sum(dist_recip, dim=2, keepdim=True)
            interpolated_points = torch.sum(index_points(points2, idx) * weight.view(B, N, k, 1), dim=2)

        if points1 is not None:
            points1 = points1.permute(0, 2, 1)
            new_points = torch.cat([points1, interpolated_points], dim=-1)
        else:
            new_points = interpolated_points
        # (B, N, D) -> (B, D, N) for Conv1d.
        new_points = new_points.permute(0, 2, 1)
        for conv, bn in zip(self.mlp_convs, self.mlp_bns):
            new_points = F.relu(bn(conv(new_points)))
        return new_points

In [None]:
class PointNet2(nn.Module):
    """PointNet++ semantic-segmentation network: 4 set-abstraction levels, 4 feature-propagation levels, per-point head.

    Args:
        num_classes: number of output classes per point.

    NOTE(review): attribute names (sa1..sa4, fp1..fp4, conv1, bn1, drop1, conv2) are the
    state_dict keys of saved checkpoints, and construction order fixes the random
    initialisation; keep both stable.
    """
    def __init__(self, num_classes) -> None:
        super(PointNet2,self).__init__()

        # Args: npoint, radius, nsample, in_channel, mlp, group_all.
        # sa1 in_channel 3+3: relative xyz (3) + features, and forward() passes xyz itself as the features.
        # NOTE(review): radii 0.1..0.8 presumably assume coordinates normalised to roughly unit scale — TODO confirm against CustomDataset.
        self.sa1=PointNetSetAbstraction(1024,0.1,32,3+3,[32,32,64],False)
        self.sa2=PointNetSetAbstraction(256,0.2,32,64+3,[64,64,128],False)
        self.sa3=PointNetSetAbstraction(64,0.4,32,128+3,[128,128,256],False)
        self.sa4=PointNetSetAbstraction(16,0.8,32,256+3,[256,256,512],False)

        # in_channel = skip features + interpolated features:
        # fp4: 256 (sa3) + 512 (sa4) = 768; fp3: 128 (sa2) + 256 (fp4) = 384;
        # fp2: 64 (sa1) + 256 (fp3) = 320; fp1: 128 (fp2 only, no skip features).
        self.fp4=PointNetFeaturePropagation(768,[256,256])
        self.fp3=PointNetFeaturePropagation(384,[256,256])
        self.fp2=PointNetFeaturePropagation(320,[256,128])
        self.fp1=PointNetFeaturePropagation(128,[256,128,128])

        # Per-point classification head.
        self.conv1=nn.Conv1d(128,128,1)
        self.bn1=nn.BatchNorm1d(128)
        self.drop1=nn.Dropout(0.5)
        self.conv2=nn.Conv1d(128,num_classes,1)
    def forward(self,xyz):
        """Segment a batch of point clouds.

        Args:
            xyz: coordinates laid out as (B, 3, N) — the SA/FP layers permute(0, 2, 1) their inputs.
                NOTE(review): the training loop feeds dataset tensors directly; TODO confirm
                CustomDataset yields (B, 3, N) rather than (B, N, 3).

        Returns:
            (B, N, num_classes) per-point log-probabilities (log_softmax over classes).
            Feeding these to CrossEntropyLoss equals NLL on them, because log_softmax is idempotent.
        """
        # Encoder: progressively downsample (1024 -> 256 -> 64 -> 16 centroids).
        l1_xyz,l1_points=self.sa1(xyz,xyz)
        l2_xyz,l2_points=self.sa2(l1_xyz,l1_points)
        l3_xyz,l3_points=self.sa3(l2_xyz,l2_points)
        l4_xyz,l4_points=self.sa4(l3_xyz,l3_points)

        # Decoder: interpolate features back up to the original N points.
        l3_points=self.fp4(l3_xyz,l4_xyz,l3_points,l4_points)
        l2_points=self.fp3(l2_xyz,l3_xyz,l2_points,l3_points)
        l1_points=self.fp2(l1_xyz,l2_xyz,l1_points,l2_points)
        l0_points=self.fp1(xyz,l1_xyz,None,l1_points)

        x=self.drop1(F.relu(self.bn1(self.conv1(l0_points))))
        x=self.conv2(x)
        x=F.log_softmax(x,dim=1)
        # (B, num_classes, N) -> (B, N, num_classes)
        x=x.permute(0,2,1)
        return x

In [None]:
# Data configuration.
project_dir = '../../../data/AHN4_34EN2_18'
SEED = 42
num_point = 4096
BATCH_SIZE = 32

# Seed every RNG the pipeline touches so that the train/validation split, loader
# shuffling and model initialisation are reproducible across runs.
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# glob order is filesystem-dependent; sort so the seeded split always selects the same files.
pointcloud_train_files = sorted(glob(os.path.join(project_dir, "strain/*.txt")))
pointcloud_test_files = sorted(glob(os.path.join(project_dir, "stest/*.txt")))
assert pointcloud_train_files, f"no training tiles found under {project_dir}/strain"

# Hold out a random fifth of the training tiles for validation.
num_train_files = len(pointcloud_train_files)
valid_index = np.random.choice(num_train_files, num_train_files // 5, replace=False)
valid_list = [pointcloud_train_files[i] for i in valid_index]
train_list = [pointcloud_train_files[i] for i in np.setdiff1d(np.arange(num_train_files), valid_index)]
test_list = pointcloud_test_files

train_dataset = CustomDataset(train_list, "xyz", num_point=num_point)
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
# Evaluation loaders need no shuffling; a fixed order keeps validation and test output deterministic.
val_dataset = CustomDataset(valid_list, "xyz", is_training=False, num_point=num_point)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)
# NOTE(review): `has_ground_trust` is the keyword defined by m02_code.CustomDataset; keep it as spelled there.
test_dataset = CustomDataset(test_list, "xyz", is_training=False, has_ground_trust=False, num_point=num_point)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)
print(train_dataset.num_classes)

In [None]:
# Build the segmentation network for the dataset's class count, place it on the chosen device, then train and save it.
model = PointNet2(num_classes=train_dataset.num_classes).to(device)
train_and_save_model(
    device,
    train_loader,
    val_loader,
    num_classes=train_dataset.num_classes,
    model=model,
    save_path="pnet2.pht",
    epochs=100,
)

---

## Testing

In [None]:
# Rebuild the network and load the trained weights.
num_classes = train_dataset.num_classes
model = PointNet2(num_classes)
# map_location lets a checkpoint written on a GPU load on a CPU-only machine;
# weights_only avoids unpickling arbitrary objects from the file.
model.load_state_dict(torch.load("pnet2.pht", map_location=device, weights_only=True))
model.to(device)
# Inference mode: disable dropout and use BatchNorm running statistics.
model.eval()
test_prediction = infer_point_clouds(device, model, test_loader, num_classes)
print(test_prediction.shape)
# Stack the per-tile predictions into one array; the savetxt format below expects 4 columns (x y z label).
resulting_point_cloud = np.vstack(test_prediction)
print(resulting_point_cloud.shape)
np.savetxt("results.txt", resulting_point_cloud, fmt="%.6f %.6f %.6f %d")

In [None]:
label_mapping = test_dataset.label_mapping
num_unique_labels = len(label_mapping.keys())

points = resulting_point_cloud[:, :3]
labels = resulting_point_cloud[:, -1]
# Colour by a dense 0..K-1 id per distinct predicted label. Indexing with `labels - 1`
# assumed 1-based labels: a label of 0 silently wrapped to the last colour, and any label
# above num_unique_labels raised IndexError.
unique_labels, label_ids = np.unique(labels.astype(int), return_inverse=True)
# Seeded so reruns show the same colours; sized to cover both the dataset's label set and the predicted labels.
colormap = np.random.default_rng(0).random((max(num_unique_labels, len(unique_labels)), 3))
colors = colormap[label_ids]

pcd = o3d.geometry.PointCloud()
pcd.points = o3d.utility.Vector3dVector(points)
pcd.colors = o3d.utility.Vector3dVector(colors)
pcd.estimate_normals()

o3d.visualization.draw_geometries([pcd])