In [1]:
import os

import math
import numpy as np

import tensorboard

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable

from torch.utils.data import Dataset
from torch.utils.data import DataLoader

from torch.utils.tensorboard import SummaryWriter

from auxiliary.laserscan import LaserScan
from auxiliary.laserscan import SemLaserScan

import random

import yaml

%load_ext autoreload
%autoreload 2

In [2]:
# Use the first CUDA GPU when available, otherwise fall back to the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
print(device)

cuda:0


In [3]:
# Data locations. The defaults are the original machine-specific paths; set the
# KITTI_DATASET_PATH / KITTI_PREDICTIONS_PATH environment variables to run elsewhere.
DATASET_PATH = os.environ.get("KITTI_DATASET_PATH", "/mnt/raid/xyiheng/kittiDataset/sequences")
TESTLABEL_PATH = os.environ.get("KITTI_PREDICTIONS_PATH", "/mnt/raid/xyiheng/method_predictions/sequences")
NUM_OF_CLASSES = 20  # output classes, including class 0 which the loss ignores
BATCH_SIZE = 3
SAMPLE_SIZE = 50000  # points randomly sampled from each scan per dataset item

In [None]:
## Define PointNet new version
class PointNet_new(nn.Module):
    """PointNet variant with dense skip connections and no BatchNorm.

    The raw input and every intermediate per-point feature map (conv1..conv4) are
    concatenated with the broadcast global feature before the segmentation head.

    Args:
        input_dimension: channels per input point (e.g. 3 for xyz).
        output_dimension: per-point output channels (classes) in segmentation mode.
        feature_dimension: size of the max-pooled global feature.
        isSegmentation: if True, ``forward`` returns logits ``[B, output_dimension, N]``;
            otherwise it returns the global feature ``[B, feature_dimension]``.
    """

    def __init__(self, input_dimension, output_dimension, feature_dimension, isSegmentation):
        # Must name this class: the former `super(PointNet2, self)` failed at construction.
        super(PointNet_new, self).__init__()

        self.input_dimension = input_dimension
        self.output_dimension = output_dimension
        self.feature_dimension = feature_dimension

        # Shared per-point encoder MLP (1x1 convolutions).
        self.conv1 = nn.Conv1d(input_dimension, 64, 1)
        self.relu1 = nn.ReLU(inplace=True)

        self.conv2 = nn.Conv1d(64, 128, 1)
        self.relu2 = nn.ReLU(inplace=True)

        self.conv3 = nn.Conv1d(128, 128, 1)
        self.relu3 = nn.ReLU(inplace=True)

        self.conv4 = nn.Conv1d(128, 512, 1)
        self.relu4 = nn.ReLU(inplace=True)

        self.conv5 = nn.Conv1d(512, self.feature_dimension, 1)
        self.relu5 = nn.ReLU()

        # Segmentation-head input = raw input + conv1..conv4 outputs + global feature.
        # Derived from input_dimension instead of a fixed 835 (which assumed 3-d input).
        skip_channels = input_dimension + 64 + 128 + 128 + 512
        self.conv6 = nn.Conv1d(skip_channels + self.feature_dimension, 256, 1)
        self.relu6 = nn.ReLU(inplace=True)

        self.conv7 = nn.Conv1d(256, 256, 1)
        self.relu7 = nn.ReLU(inplace=True)

        self.conv8 = nn.Conv1d(256, 128, 1)
        self.relu8 = nn.ReLU(inplace=True)
        self.dropout1 = nn.Dropout(0.5)
        self.conv9 = nn.Conv1d(128, self.output_dimension, 1)

        self.isSegmentation = isSegmentation

    def forward(self, cp):
        """Run the network on a point cloud ``cp`` of shape [B, N, input_dimension]."""
        N = cp.shape[1]
        cp = cp.permute(0, 2, 1)  # -> [B, C, N] for Conv1d
        cp_copy_0 = cp

        cp = self.relu1(self.conv1(cp))
        cp_copy_1 = cp

        cp = self.relu2(self.conv2(cp))
        cp_copy_2 = cp

        cp = self.relu3(self.conv3(cp))
        cp_copy_3 = cp

        cp = self.relu4(self.conv4(cp))
        cp_copy_4 = cp

        cp = self.relu5(self.conv5(cp))

        # Max-pool over points once; reused for both output modes.
        global_feature = torch.max(cp, 2)[0]  # [B, feature_dimension]
        if not self.isSegmentation:
            return global_feature

        expanded = global_feature.view(-1, self.feature_dimension, 1).repeat(1, 1, N)
        tensor_1 = torch.cat([cp_copy_0, cp_copy_1, cp_copy_2, cp_copy_3, cp_copy_4, expanded], 1)

        output = self.relu6(self.conv6(tensor_1))
        output = self.relu7(self.conv7(output))
        output = self.relu8(self.conv8(output))
        output = self.dropout1(output)
        return self.conv9(output)  # [B, output_dimension, N]

In [None]:
## Define PointNet
class PointNet(nn.Module):
    """PointNet with BatchNorm, for per-point segmentation or a global feature.

    Args:
        input_dimension: channels per input point (e.g. 3 for xyz).
        output_dimension: per-point output channels (classes) in segmentation mode.
        feature_dimension: size of the max-pooled global feature.
        isSegmentation: if True, ``forward`` returns logits ``[B, output_dimension, N]``;
            otherwise it returns the global feature ``[B, feature_dimension]``.
    """

    def __init__(self, input_dimension, output_dimension, feature_dimension, isSegmentation):
        super(PointNet, self).__init__()

        self.input_dimension = input_dimension
        self.output_dimension = output_dimension
        self.feature_dimension = feature_dimension

        # Shared per-point encoder MLP (1x1 convolutions).
        self.conv1 = nn.Conv1d(input_dimension, 64, 1)
        self.batch1 = nn.BatchNorm1d(64)
        self.relu1 = nn.ReLU(inplace=True)

        self.conv1_1 = nn.Conv1d(64, 64, 1)
        self.batch1_1 = nn.BatchNorm1d(64)
        self.relu1_1 = nn.ReLU(inplace=True)

        self.conv1_2 = nn.Conv1d(64, 64, 1)
        self.batch1_2 = nn.BatchNorm1d(64)
        self.relu1_2 = nn.ReLU(inplace=True)

        self.conv2 = nn.Conv1d(64, 128, 1)
        self.batch2 = nn.BatchNorm1d(128)
        self.relu2 = nn.ReLU(inplace=True)

        self.conv3 = nn.Conv1d(128, self.feature_dimension, 1)
        self.batch3 = nn.BatchNorm1d(self.feature_dimension)
        self.relu3 = nn.ReLU()

        # Segmentation head: input is [per-point 128-d feature ; broadcast global feature].
        self.conv4 = nn.Conv1d(128 + self.feature_dimension, 512, 1)
        self.batch4 = nn.BatchNorm1d(512)
        self.relu4 = nn.ReLU(inplace=True)

        self.conv5 = nn.Conv1d(512, 256, 1)
        self.batch5 = nn.BatchNorm1d(256)
        self.relu5 = nn.ReLU(inplace=True)

        self.conv6 = nn.Conv1d(256, 128, 1)
        self.batch6 = nn.BatchNorm1d(128)
        self.relu6 = nn.ReLU(inplace=True)

        self.conv6_1 = nn.Conv1d(128, 128, 1)
        self.batch6_1 = nn.BatchNorm1d(128)  # torch has no `nn.BatchNorm`; 1d matches Conv1d output
        self.relu6_1 = nn.ReLU(inplace=True)

        self.conv7 = nn.Conv1d(128, self.output_dimension, 1)

        self.isSegmentation = isSegmentation

    def forward(self, cp):
        """Run the network on a point cloud ``cp`` of shape [B, N, input_dimension]."""
        B, N = cp.shape[0], cp.shape[1]
        cp = cp.permute(0, 2, 1)  # -> [B, C, N] for Conv1d

        cp = self.relu1(self.batch1(self.conv1(cp)))
        cp = self.relu1_1(self.batch1_1(self.conv1_1(cp)))
        cp = self.relu1_2(self.batch1_2(self.conv1_2(cp)))
        cp = self.relu2(self.batch2(self.conv2(cp)))

        cp_copy = cp  # per-point 128-d features, reused by the segmentation head

        cp = self.relu3(self.batch3(self.conv3(cp)))

        cp = torch.max(cp, 2, keepdim=True)[0]
        cp = cp.view(-1, self.feature_dimension)  # global feature [B, feature_dimension]

        if not self.isSegmentation:
            return cp

        # Broadcast the global feature to every point; its width is feature_dimension
        # (the former hard-coded 128 only worked for feature_dimension == 128).
        cp = cp.view(B, self.feature_dimension, 1).repeat(1, 1, N)
        tensor_1 = torch.cat([cp_copy, cp], 1)

        output = self.relu4(self.batch4(self.conv4(tensor_1)))
        output = self.relu5(self.batch5(self.conv5(output)))
        output = self.relu6(self.batch6(self.conv6(output)))
        output = self.relu6_1(self.batch6_1(self.conv6_1(output)))
        return self.conv7(output)  # [B, output_dimension, N]

In [4]:
## Define kittiDataset
class KittiDataset(Dataset):
    """SemanticKITTI scans returned as fixed-size random point samples.

    Sequences 00-10 except 08 form the training split; sequence 08 is the
    validation split. All scans and labels are loaded into memory up front.

    Args:
        train: select the training split (True) or the validation split (False).
        num_classes: forwarded to ``SemLaserScan``.
        dataset_path: directory containing the two-digit sequence folders.
        sample_size: number of points returned per item.
        augment: apply random flips / rotation / scaling (training split only).
    """

    def __init__(self, train, num_classes, dataset_path, sample_size, augment):
        super(KittiDataset, self).__init__()
        self.train = train
        self.augment = augment
        self.sample_size = sample_size

        scan = SemLaserScan(num_classes)
        frames = []
        labels = []

        # sorted(): os.listdir order is arbitrary; this makes item order (and the
        # unshuffled validation order) deterministic.
        for father_file in sorted(os.listdir(dataset_path)):
            sequence = int(father_file)
            if train:
                in_split = sequence <= 10 and sequence != 8
            else:
                in_split = sequence == 8
            if not in_split:
                continue
            points_path = os.path.join(dataset_path, father_file, "velodyne")
            labels_path = os.path.join(dataset_path, father_file, "labels")
            for points_file in sorted(os.listdir(points_path)):
                scan.open_scan(os.path.join(points_path, points_file))
                frames.append(scan.points)
                index = points_file.split(".")[0]
                scan.open_label(os.path.join(labels_path, index + ".label"))
                labels.append(scan.sem_label)

        self.frames = frames
        self.labels = labels

    def __len__(self):
        return len(self.frames)

    def _sample_indexes(self, num_points):
        """Pick ``sample_size`` point indexes; repeat points only if the scan is too small."""
        if num_points >= self.sample_size:
            picked = random.sample(range(num_points), self.sample_size)
        else:
            # random.sample would raise ValueError here; sample with replacement instead.
            picked = random.choices(range(num_points), k=self.sample_size)
        return np.array(picked, dtype=np.int64)

    def _augment(self, frame):
        """Random YZ/XZ flips, a +/-5 degree rotation about Z and a 0.9-1.1 rescale."""
        if np.random.random() > 0.5:
            # Flipping along the YZ plane
            frame[:, 0] = -1 * frame[:, 0]

        if np.random.random() > 0.5:
            # Flipping along the XZ plane
            frame[:, 1] = -1 * frame[:, 1]

        # Rotation along up-axis/Z-axis
        theta = (np.random.random() * np.pi / 18) - np.pi / 36  # -5 ~ +5 degree
        cos_t, sin_t = np.cos(theta), np.sin(theta)
        rotation = np.array([[cos_t, -sin_t, 0.0],
                             [sin_t, cos_t, 0.0],
                             [0.0, 0.0, 1.0]])
        frame[:, 0:3] = np.dot(frame[:, 0:3], np.transpose(rotation))

        # Rescale randomly by 0.9 - 1.1
        proportion = np.random.uniform(0.9, 1.1, 1)
        return frame * proportion

    def __getitem__(self, idx):
        """Return (points [sample_size, C] float32, labels [sample_size] int64)."""
        frame = np.array(self.frames[idx])
        # np.int64 instead of np.long (removed in NumPy 1.24).
        label = np.array(self.labels[idx], dtype=np.int64)
        sample_indexes = self._sample_indexes(frame.shape[0])
        if self.augment and self.train:
            frame = self._augment(frame)
        return torch.FloatTensor(frame[sample_indexes]), torch.tensor(label[sample_indexes])

In [5]:
# Training split with augmentation; validation split without.
training_set = KittiDataset(train=True, num_classes=NUM_OF_CLASSES, dataset_path=DATASET_PATH,
                            sample_size=SAMPLE_SIZE, augment=True)
val_set = KittiDataset(train=False, num_classes=NUM_OF_CLASSES, dataset_path=DATASET_PATH,
                       sample_size=SAMPLE_SIZE, augment=False)

In [6]:
# Both loaders drop the last incomplete batch; only the training loader shuffles.
training_dataloader = DataLoader(training_set, batch_size=BATCH_SIZE, shuffle=True,
                                 num_workers=2, drop_last=True)
val_dataloader = DataLoader(val_set, batch_size=BATCH_SIZE, shuffle=False,
                            num_workers=2, drop_last=True)

In [7]:
# Number of batches per epoch: training, then validation.
print(len(training_dataloader), len(val_dataloader), sep="\n")

6376
1357


In [8]:
## Get weight for cross entropy
# yaml.safe_load: plain yaml.load without a Loader is deprecated and raises TypeError
# on PyYAML >= 6. The config only holds plain mappings, so the safe loader suffices,
# and the file is read once instead of twice.
with open("config/semantic-kitti.yaml", "r") as f:
    semkitti_config = yaml.safe_load(f)
content = semkitti_config["content"]
learning_map = semkitti_config["learning_map"]

# Accumulate the per-raw-label frequencies ("content") into the learning classes.
proportion = np.zeros(NUM_OF_CLASSES)
for key, mapped_class in learning_map.items():
    proportion[mapped_class] += content[key]

  after removing the cwd from sys.path.
  import sys


In [9]:
# normalize weight
# Named `labeled_total` rather than `sum`: rebinding the builtin `sum` here breaks
# later `sum(...)` calls such as the trainable-parameter counts.
labeled_total = np.sum(proportion[1:])
new_proportion = proportion / labeled_total  # frequencies normalised over classes 1..N-1
# Inverse fourth-root frequency weights; class 0 gets weight 0 (it is also the
# criterion's ignore_index).
weight = np.sqrt(np.sqrt(1 / proportion))
weight[0] = 0
weight = torch.FloatTensor(weight).to(device)

In [None]:
# Per-class cross-entropy weights (class 0 is zeroed).
print(str(weight))

In [None]:
# Dense-skip PointNet: 3-d input, 20 classes, 2048-d global feature, segmentation head.
point_net = PointNet_new(input_dimension=3, output_dimension=20, feature_dimension=2048, isSegmentation=True)
criterion = nn.CrossEntropyLoss(weight=weight, ignore_index=0)
optimizer = torch.optim.Adam(
    point_net.parameters(),
    lr=1e-3,
    betas=(0.9, 0.999),
    weight_decay=1e-3,
)
lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.5)

In [None]:
import builtins

# Trainable parameter count. builtins.sum is used because another cell may have
# rebound the name `sum` to a numpy scalar, which would make `sum(...)` fail.
builtins.sum(p.numel() for p in point_net.parameters() if p.requires_grad)

In [None]:
from datetime import datetime

TIMESTAMP = "{0:%Y-%m-%dT%H-%M-%S/}".format(datetime.now())
writer = SummaryWriter("runs/loss" + TIMESTAMP)

POINTNET_EVAL_EVERY = 200  # training iterations between validation passes
IGNORE_LABEL = 0  # same as the criterion's ignore_index; excluded from the accuracy

point_net.to(device)
for epoch in range(5000):
    # Halve the learning rate (ExponentialLR, gamma=0.5) at the start of every 4th epoch.
    if (epoch + 1) % 4 == 0:
        lr_scheduler.step()
    train_running_loss = 0.0
    point_net.train()
    for i, (X, y) in enumerate(training_dataloader):
        X = X.to(device)  # [B, N, 3]; PointNet_new permutes internally
        y = y.to(device)
        optimizer.zero_grad()

        y_pred = point_net(X)
        loss = criterion(y_pred, y)
        loss.backward()
        optimizer.step()

        train_running_loss += loss.item()

        if i % POINTNET_EVAL_EVERY == POINTNET_EVAL_EVERY - 1:
            val_running_loss = 0.0
            num_val_batches = 0
            correct = 0
            total = 0
            point_net.eval()
            with torch.no_grad():
                for X_val, y_val in val_dataloader:
                    X_val = X_val.to(device)
                    y_val = y_val.to(device)
                    y_pred = point_net(X_val)
                    val_running_loss += criterion(y_pred, y_val).item()
                    num_val_batches += 1

                    preds = torch.max(y_pred, 1)[1]
                    # Score only labeled points, consistent with the loss ignoring label 0.
                    labeled = y_val != IGNORE_LABEL
                    correct += (preds.eq(y_val) & labeled).sum().item()
                    total += labeled.sum().item()
            point_net.train()

            train_running_loss /= POINTNET_EVAL_EVERY
            # Divide by the number of batches actually evaluated (was `j`, one short).
            val_running_loss /= max(num_val_batches, 1)
            val_accuracy = correct / max(total, 1)

            with open('loss.txt', 'a') as f:
                f.write("[Epoch %d, Iteration %5d] train_loss: %.3f acc: %.2f %% val_loss: %.3f\n" %
                        (epoch + 1, i + 1, train_running_loss, 100 * val_accuracy, val_running_loss))

            writer.add_scalars('loss', {'training_loss': train_running_loss,
                                        'val_loss': val_running_loss}, epoch * len(training_dataloader) + i)

            train_running_loss = 0.0

    writer.flush()

writer.close()

In [None]:
# NOTE(review): `point_net_2` is first created in the PointNet++ section further down;
# on a fresh kernel this cell only works after that section has been run.
torch.save(obj=point_net_2.state_dict(), f="point_net_2_1")

In [None]:
# Rebuild the PointNet++ model and restore the saved weights. Bound to `point_net_2`
# (not `point_net`) because the following evaluation cells use that name.
# map_location lets a checkpoint saved on GPU load on a CPU-only machine.
point_net_2 = PointNet2(20)
point_net_2.load_state_dict(torch.load("point_net_2_1", map_location=device))
point_net_2.to(device)

In [None]:
# Equivalent to .eval(): switch dropout to inference behaviour.
point_net_2.train(False)

In [None]:
## build val results & confusion matrix
# Runs point_net_2 over every full scan of sequence 08 and accumulates a confusion matrix
# (rows = ground-truth labels, columns = predicted labels). It also writes each scan's
# predictions as a flat uint32 array to TESTLABEL_PATH/<seq>/predictions/<scan>.label.
# NOTE(review): the written values are network class indices; if the benchmark tooling
# expects original SemanticKITTI label ids, they must be mapped back first — confirm.
# NOTE(review): labels are used directly as confusion-matrix indices, which assumes
# scan.sem_label holds values in 0..NUM_OF_CLASSES-1 — confirm.
CHUNK_SIZE = 50000  # points in the first forward pass, to bound GPU memory

c_matrix = np.zeros((NUM_OF_CLASSES, NUM_OF_CLASSES))
scan = SemLaserScan(NUM_OF_CLASSES)
point_net_2.eval()
for father_file in sorted(os.listdir(DATASET_PATH)):
    if int(father_file) != 8:
        continue
    points_path = os.path.join(DATASET_PATH, father_file, "velodyne")
    labels_path = os.path.join(DATASET_PATH, father_file, "labels")
    predictions_path = os.path.join(TESTLABEL_PATH, father_file, "predictions")
    os.makedirs(predictions_path, exist_ok=True)
    for points_file in sorted(os.listdir(points_path)):
        scan.open_scan(os.path.join(points_path, points_file))
        X_test = torch.FloatTensor(scan.points)

        index = points_file.split(".")[0]
        scan.open_label(os.path.join(labels_path, index + ".label"))
        # int64 rather than uint8: label values above 255 would wrap around.
        y = np.array(scan.sem_label, dtype=np.int64).reshape(-1)

        # To save CUDA memory, predict in two parts (first CHUNK_SIZE points, then the
        # rest). An empty second part (small scan) is skipped instead of crashing.
        chunk_preds = []
        with torch.no_grad():
            for chunk in (X_test[:CHUNK_SIZE], X_test[CHUNK_SIZE:]):
                if chunk.shape[0] == 0:
                    continue
                chunk = chunk.to(device).view(1, -1, 3).permute(0, 2, 1)  # [1, 3, n]
                chunk_preds.append(torch.max(point_net_2(chunk), 1)[1].view(-1).cpu())
        preds = torch.cat(chunk_preds).numpy()

        correct = int(np.sum(preds == y))
        total = y.shape[0]

        # Vectorised confusion-matrix update: c_matrix[y[k], preds[k]] += 1 for every point.
        np.add.at(c_matrix, (y, preds), 1)
        preds.astype(np.uint32).tofile(os.path.join(predictions_path, index + ".label"))

In [None]:
## calculate accuracy
# Per-class accuracy (recall): diagonal / row sum of the confusion matrix.
# Classes with no ground-truth points give NaN and are skipped by nanmean, instead of
# turning the average into NaN.
with np.errstate(divide="ignore", invalid="ignore"):
    accuracy = np.diag(c_matrix) / np.sum(c_matrix, axis=1)
# Class 0 is ignored during training, so it is excluded from the average
# (previously a hard-coded division by 19).
average_accuracy = np.nanmean(accuracy[1:])

In [None]:
# Per-class accuracy on one line block, then the average over classes 1..N-1.
print(accuracy, average_accuracy, sep="\n")

In [None]:
########################## START POINT_NET++ #################################

In [10]:
def index_points(points, idx):
    """Gather points with per-batch indices.

    Input:
        points: input points data, [B, N, C]
        idx: sample index data, [B, S] (extra index dims such as [B, S, K] also work)
    Return:
        new_points: indexed points data, [B, S, C] (or [B, S, K, C])
    """
    batch_size = points.shape[0]
    # Batch index shaped [B, 1, ..., 1] and tiled to idx's shape so it pairs with idx.
    broadcast_shape = [idx.shape[0]] + [1] * (idx.dim() - 1)
    tile_shape = [1] + list(idx.shape[1:])
    batch_indices = (
        torch.arange(batch_size, dtype=torch.long)
        .to(points.device)
        .view(broadcast_shape)
        .repeat(tile_shape)
    )
    return points[batch_indices, idx, :]

In [11]:
def farthest_point_sample(xyz, num_centroids):
    """Iterative farthest point sampling.

    Starts from a random point in each batch element, then repeatedly picks the point
    whose squared distance to the already-chosen set is largest.

    Input:
        xyz: pointcloud data, [B, N, C] (usually C == 3; any coordinate dimension works)
        num_centroids: number of samples(centroids)
    Return:
        centroids: sampled pointcloud index, [B, num_centroids]
    """
    device = xyz.device
    B, N, C = xyz.shape
    centroids = torch.zeros(B, num_centroids, dtype=torch.long).to(device)
    # Squared distance from every point to its nearest chosen centroid so far.
    distance = torch.ones(B, N).to(device) * 1e10
    farthest = torch.randint(0, N, (B,), dtype=torch.long).to(device)
    batch_indices = torch.arange(B, dtype=torch.long).to(device)
    for i in range(num_centroids):
        centroids[:, i] = farthest
        # Use the actual coordinate dimension C instead of a hard-coded 3.
        centroid = xyz[batch_indices, farthest, :].view(B, 1, C)
        dist = torch.sum((xyz - centroid) ** 2, -1)
        mask = dist < distance
        distance[mask] = dist[mask]
        farthest = torch.max(distance, -1)[1]
    return centroids

In [12]:
def square_distance(src, dst):
    """Pairwise squared Euclidean distance between two point sets.

    Uses the expansion |s - d|^2 = |s|^2 + |d|^2 - 2 s.d, so the whole matrix
    comes from one batched matmul.

    Input:
        src: source points, [B, N, C]
        dst: target points, [B, M, C]
    Output:
        dist: per-point square distance, [B, N, M]
    """
    batch, n_src, _ = src.shape
    n_dst = dst.shape[1]
    cross = torch.matmul(src, dst.permute(0, 2, 1))  # [B, N, M] dot products
    src_sq = torch.sum(src ** 2, -1).view(batch, n_src, 1)
    dst_sq = torch.sum(dst ** 2, -1).view(batch, 1, n_dst)
    return -2 * cross + src_sq + dst_sq

In [13]:
def query_ball_point(radius, nsample, xyz, new_xyz):
    """Ball query: up to ``nsample`` point indices within ``radius`` of each query.

    Input:
        radius: local region radius
        nsample: max sample number in local region
        xyz: all points, [B, N, 3]
        new_xyz: query points, [B, S, 3]
    Return:
        group_idx: grouped points index, [B, S, nsample]; slots without a
            neighbour are filled with the query's first (lowest-index) neighbour
    """
    batch, n_points, _ = xyz.shape
    n_queries = new_xyz.shape[1]
    # Every query starts with all indices 0..N-1; out-of-radius ones become the sentinel N.
    candidates = (
        torch.arange(n_points, dtype=torch.long)
        .to(xyz.device)
        .view(1, 1, n_points)
        .repeat([batch, n_queries, 1])
    )
    candidates[square_distance(new_xyz, xyz) > radius ** 2] = n_points
    # Sorting pushes sentinels to the end; keep the first nsample slots.
    candidates = candidates.sort(dim=-1)[0][:, :, :nsample]
    # Replace leftover sentinels with the first in-radius index of the same query.
    first_hit = candidates[:, :, 0].view(batch, n_queries, 1).repeat([1, 1, nsample])
    padding = candidates == n_points
    candidates[padding] = first_hit[padding]
    return candidates

In [14]:
def sample_and_group(xyz, feature, num_centroids, num_neighbors, radius):
    """Pick centroids by farthest point sampling and group their ball neighbourhoods.

    Input:
        xyz: input points position data, [B, N, 3]
        feature: input feature data, [B, N, D], or None
        num_centroids: number of centroids S chosen by farthest point sampling
        num_neighbors: number of neighbours K gathered per centroid
        radius: ball-query radius
    Return:
        centroids: sampled centroid positions, [B, S, 3]
        new_points: neighbour positions relative to their centroid, concatenated with
            the neighbours' features when given, [B, S, K, 3 + D]
    """
    batch, _, coord_dim = xyz.shape
    centroid_idx = farthest_point_sample(xyz, num_centroids)  # [B, S]
    centroids = index_points(xyz, centroid_idx)  # [B, S, 3]
    neighbor_idx = query_ball_point(radius, num_neighbors, xyz, centroids)  # [B, S, K]
    # Express each neighbour relative to its centroid.
    local_xyz = index_points(xyz, neighbor_idx) - centroids.view(batch, num_centroids, 1, coord_dim)

    if feature is None:
        return centroids, local_xyz
    return centroids, torch.cat([local_xyz, index_points(feature, neighbor_idx)], dim=-1)

In [15]:
class PointNetSetAbstraction(nn.Module):
    """PointNet++ set-abstraction layer (single-scale grouping, no BatchNorm).

    Args:
        num_centroids: number of centroids S chosen by farthest point sampling.
        radius: ball-query radius for grouping.
        num_neighbors: neighbours K gathered per centroid.
        in_channel: channels of each grouped point (3 relative coordinates + D features).
        mlp: output channels of each 1x1 Conv2d layer.
    """

    def __init__(self, num_centroids, radius, num_neighbors, in_channel, mlp):
        super(PointNetSetAbstraction, self).__init__()
        self.num_centroids = num_centroids
        self.radius = radius
        self.num_neighbors = num_neighbors

        channels = [in_channel] + list(mlp)
        self.conv_list = nn.ModuleList(
            nn.Conv2d(c_in, c_out, 1, 1) for c_in, c_out in zip(channels[:-1], channels[1:])
        )

    def forward(self, xyz, feature):
        """xyz [B, 3, N], feature [B, D, N] or None -> (new_xyz [B, 3, S], new_feature [B, mlp[-1], S])."""
        points_last = xyz.permute(0, 2, 1)  # [B, N, 3]
        feature_last = None if feature is None else feature.permute(0, 2, 1)  # [B, N, D]

        centroids, grouped = sample_and_group(points_last, feature_last, self.num_centroids,
                                              self.num_neighbors, self.radius)
        # [B, S, K, 3 + D] -> [B, 3 + D, K, S] so each Conv2d acts per (neighbour, centroid).
        grouped = grouped.permute(0, 3, 2, 1)
        for conv in self.conv_list:
            grouped = F.relu(conv(grouped), inplace=True)
        # Max-pool over the neighbour axis (dim 2) -> [B, C_out, S].
        return centroids.permute(0, 2, 1), torch.max(grouped, 2)[0]

In [16]:
class PointNetFeaturePropagation(nn.Module):
    """PointNet++ feature-propagation (upsampling) layer.

    Interpolates features from a sparse point set onto a dense one using
    inverse-squared-distance weights over the nearest (up to 3) sparse points. It then
    concatenates the dense set's own features (if any) and applies a 1x1-conv + ReLU MLP.

    Args:
        in_channel: channels after concatenation (D1 + D2).
        mlp: output channels of each 1x1 Conv1d layer.
    """

    def __init__(self, in_channel, mlp):
        super(PointNetFeaturePropagation, self).__init__()
        self.conv_list = nn.ModuleList()
        last_channel = in_channel
        for out_channel in mlp:
            self.conv_list.append(nn.Conv1d(last_channel, out_channel, 1))
            last_channel = out_channel

    def forward(self, xyz1, xyz2, feature1, feature2):
        """
        Input:
            xyz1: input points position data, [B, C, N]
            xyz2: sampled input points position data, [B, C, S]
            feature1: input feature data, [B, D1, N], or None
            feature2: input feature data, [B, D2, S]
        Return:
            new_feature: upsampled feature data, [B, mlp[-1], N]
        """
        xyz1 = xyz1.permute(0, 2, 1)  # [B, N, C]
        xyz2 = xyz2.permute(0, 2, 1)  # [B, S, C]

        B = xyz1.shape[0]
        N = xyz1.shape[1]
        S = xyz2.shape[1]
        # Interpolate from the 3 nearest sparse points, or from all of them when S < 3
        # (a fixed 3 made the later view(B, N, 3, 1) fail).
        k = min(3, S)

        dists = square_distance(xyz1, xyz2)  # [B, N, S]
        # k smallest distances in ascending order, without sorting all S.
        dists, idx = torch.topk(dists, k, dim=-1, largest=False)  # [B, N, k]
        # The expanded-square formula can return slightly negative values for
        # (near-)coincident points through float cancellation; clamp before inverting.
        dists = torch.clamp(dists, min=0.0)

        dist_rev = 1.0 / (dists + 1e-8)  # [B, N, k]
        norm = torch.sum(dist_rev, dim=2, keepdim=True)  # [B, N, 1]
        weights = dist_rev / norm  # [B, N, k], each row sums to 1

        feature2 = feature2.permute(0, 2, 1)  # [B, S, D2]
        interpolated_feature = torch.sum(index_points(feature2, idx) * weights.view(B, N, k, 1), dim=2)  # [B, N, D2]

        if feature1 is None:
            new_feature = interpolated_feature
        else:
            feature1 = feature1.permute(0, 2, 1)  # [B, N, D1]
            new_feature = torch.cat([feature1, interpolated_feature], dim=-1)  # [B, N, D1 + D2]

        new_feature = new_feature.permute(0, 2, 1)  # [B, D1 + D2, N]

        for conv in self.conv_list:
            new_feature = F.relu(conv(new_feature), inplace=True)
        return new_feature

In [17]:
class PointNet2(nn.Module):
    """PointNet++ segmentation network: 4 set-abstraction and 4 feature-propagation levels.

    ``forward`` takes xyz as [B, 3, N] (only the first 3 channels are used) and
    returns per-point logits [B, num_classes, N].
    """

    def __init__(self, num_classes):
        super(PointNet2, self).__init__()
        # Encoder levels: (num_centroids, radius, num_neighbors, in_channel = 3 + D, mlp)
        self.sa1 = PointNetSetAbstraction(1024, 2, 64, 3, [32, 32, 64])
        self.sa2 = PointNetSetAbstraction(512, 4, 64, 64 + 3, [64, 64, 128])
        self.sa3 = PointNetSetAbstraction(256, 6, 64, 128 + 3, [128, 128, 256])
        self.sa4 = PointNetSetAbstraction(128, 8, 64, 256 + 3, [256, 256, 512])
        # Decoder levels: in_channel = skip-feature channels + upsampled-feature channels
        self.fp4 = PointNetFeaturePropagation(256 + 512, [256, 256])
        self.fp3 = PointNetFeaturePropagation(128 + 256, [256, 256])
        self.fp2 = PointNetFeaturePropagation(64 + 256, [256, 128])
        self.fp1 = PointNetFeaturePropagation(0 + 128, [128, 128, 128])
        # Per-point classification head.
        self.conv1 = nn.Conv1d(128, 128, 1)
        self.drop1 = nn.Dropout(0.5)
        self.conv2 = nn.Conv1d(128, num_classes, 1)
        self.num_classes = num_classes

    def forward(self, xyz):
        """xyz: [B, 3, N] -> logits [B, num_classes, N]."""
        # Level 0 is the input; levels 1..4 come from successive set abstractions.
        xyz_levels = [xyz[:, 0:3, :]]
        feature_levels = [None]
        for sa in (self.sa1, self.sa2, self.sa3, self.sa4):
            next_xyz, next_feature = sa(xyz_levels[-1], feature_levels[-1])
            xyz_levels.append(next_xyz)
            feature_levels.append(next_feature)

        # Propagate features back down, coarse to fine, overwriting each level's features.
        for level, fp in zip((3, 2, 1, 0), (self.fp4, self.fp3, self.fp2, self.fp1)):
            feature_levels[level] = fp(xyz_levels[level], xyz_levels[level + 1],
                                       feature_levels[level], feature_levels[level + 1])

        hidden = self.drop1(F.relu(self.conv1(feature_levels[0])))
        return self.conv2(hidden)

In [18]:
# PointNet++ with 20 output classes; class 0 is ignored by the loss.
point_net_2 = PointNet2(num_classes=20)
criterion = nn.CrossEntropyLoss(ignore_index=0, weight=weight)
optimizer = torch.optim.Adam(
    point_net_2.parameters(),
    lr=1e-3,
    betas=(0.9, 0.999),
    weight_decay=1e-3,
)
lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.5)

In [19]:
# He (Kaiming) normal initialisation, fan_in mode, for every 1d/2d convolution weight.
for m in filter(lambda mod: isinstance(mod, (nn.Conv2d, nn.Conv1d)), point_net_2.modules()):
    nn.init.kaiming_normal_(m.weight, mode='fan_in')

In [None]:
import builtins

# Trainable parameter count of the PointNet++ model. builtins.sum is used because
# another cell may have rebound the name `sum` to a numpy scalar.
builtins.sum(p.numel() for p in point_net_2.parameters() if p.requires_grad)

In [None]:
from datetime import datetime

TIMESTAMP = "{0:%Y-%m-%dT%H-%M-%S/}".format(datetime.now())
writer = SummaryWriter("runs/loss" + TIMESTAMP)

POINTNET2_EVAL_EVERY = 300  # training iterations between validation passes
IGNORE_LABEL = 0  # same as the criterion's ignore_index; excluded from the accuracy

point_net_2.to(device)
for epoch in range(10):
    # Halve the learning rate (ExponentialLR, gamma=0.5) at the start of every 4th epoch.
    if (epoch + 1) % 4 == 0:
        lr_scheduler.step()
    train_running_loss = 0.0
    point_net_2.train()
    for i, (X, y) in enumerate(training_dataloader):
        X = X.permute(0, 2, 1).to(device)  # [B, N, 3] -> [B, 3, N]
        y = y.to(device)
        optimizer.zero_grad()

        y_pred = point_net_2(X)
        loss = criterion(y_pred, y)
        loss.backward()
        optimizer.step()

        train_running_loss += loss.item()

        if i % POINTNET2_EVAL_EVERY == POINTNET2_EVAL_EVERY - 1:
            # Each pass evaluates half of the validation batches: even-numbered batches
            # (1-based) on even passes, odd-numbered ones on odd passes.
            eval_round = (i + 1) // POINTNET2_EVAL_EVERY
            val_running_loss = 0.0
            num_val_batches = 0
            correct = 0
            total = 0
            point_net_2.eval()
            with torch.no_grad():
                for j, (X_val, y_val) in enumerate(val_dataloader):
                    if (j + 1) % 2 != eval_round % 2:
                        continue
                    X_val = X_val.permute(0, 2, 1).to(device)
                    y_val = y_val.to(device)
                    y_pred = point_net_2(X_val)
                    val_running_loss += criterion(y_pred, y_val).item()
                    num_val_batches += 1

                    preds = torch.max(y_pred, 1)[1]
                    # Score only labeled points, consistent with the loss ignoring label 0.
                    labeled = y_val != IGNORE_LABEL
                    correct += (preds.eq(y_val) & labeled).sum().item()
                    total += labeled.sum().item()
            point_net_2.train()

            train_running_loss /= POINTNET2_EVAL_EVERY
            # Divide by the number of batches actually evaluated (was j/2, which is off).
            val_running_loss /= max(num_val_batches, 1)
            val_accuracy = correct / max(total, 1)

            with open('loss.txt', 'a') as f:
                f.write("[Epoch %d, Iteration %5d] train_loss: %.3f acc: %.2f %% val_loss: %.3f\n" %
                        (epoch + 1, i + 1, train_running_loss, 100 * val_accuracy, val_running_loss))

            writer.add_scalars('loss', {'training_loss': train_running_loss,
                                        'val_loss': val_running_loss}, epoch * len(training_dataloader) + i)

            train_running_loss = 0.0

    writer.flush()

writer.close()

In [None]:
point_net_2.train(False)  # equivalent to .eval(): dropout switches to inference behaviour
#torch.save(point_net_2.state_dict(), "point_net_2_3")

In [None]:
# Reload the saved PointNet++ weights and build a confusion matrix over the validation
# loader: rows are ground-truth labels, columns are predicted labels.
point_net_2 = PointNet2(20)
point_net_2.to(device)
point_net_2.load_state_dict(torch.load("point_net_2_1", map_location=device))
point_net_2.eval()

c_matrix = np.zeros((20, 20))
with torch.no_grad():  # inference only: no autograd graph, far less memory
    for X_val, y_val in val_dataloader:
        X_val = X_val.permute(0, 2, 1).to(device)  # [B, N, 3] -> [B, 3, N]
        preds = torch.max(point_net_2(X_val), 1)[1]

        truth = y_val.reshape(-1).numpy()  # labels stay on the CPU
        predicted = preds.reshape(-1).cpu().numpy()
        # Vectorised update: c_matrix[truth[k], predicted[k]] += 1 for every point
        # (replaces a per-point Python loop over GPU tensors).
        np.add.at(c_matrix, (truth, predicted), 1)

In [None]:
# Drop row/column 0 (class 0, ignored during training) from the confusion matrix.
c_new = c_matrix[1:, :][:, 1:]
print(c_new.shape)

In [None]:
# IoU of row/column 0 of c_new: TP / (row sum + column sum - TP).
# `union` instead of `sum` so the builtin sum() is not shadowed.
a = c_new[0, 0]
union = np.sum(c_new[0, :]) + np.sum(c_new[:, 0]) - a
print(a / union)

In [None]:
# IoU of row/column 1 of c_new: TP / (row sum + column sum - TP).
# `union` instead of `sum` so the builtin sum() is not shadowed.
a = c_new[1, 1]
union = np.sum(c_new[1, :]) + np.sum(c_new[:, 1]) - a
print(a / union)

In [None]:
# IoU of row/column 2 of c_new: TP / (row sum + column sum - TP).
# Full column sum: the former slice 0:1 dropped c_new[1, 2] from the union.
a = c_new[2, 2]
union = np.sum(c_new[2, :]) + np.sum(c_new[:, 2]) - a
print(a / union)

In [None]:
# IoU of row/column 3 of c_new: TP / (row sum + column sum - TP).
# Full column sum: the former slice 0:2 dropped c_new[2, 3] from the union.
a = c_new[3, 3]
union = np.sum(c_new[3, :]) + np.sum(c_new[:, 3]) - a
print(a / union)

In [None]:
# IoU of row/column 4 of c_new: TP / (row sum + column sum - TP).
# Full column sum: the former slice 0:3 dropped c_new[3, 4] from the union.
a = c_new[4, 4]
union = np.sum(c_new[4, :]) + np.sum(c_new[:, 4]) - a
print(a / union)

In [None]:
# IoU of row/column 5 of c_new: TP / (row sum + column sum - TP).
# Full column sum: the former slice 0:4 dropped c_new[4, 5] from the union.
a = c_new[5, 5]
union = np.sum(c_new[5, :]) + np.sum(c_new[:, 5]) - a
print(a / union)

In [None]:
# IoU of row/column 6 of c_new: TP / (row sum + column sum - TP).
# Full column sum: the former slice 0:5 dropped c_new[5, 6] from the union.
a = c_new[6, 6]
union = np.sum(c_new[6, :]) + np.sum(c_new[:, 6]) - a
print(a / union)

In [None]:
# IoU of row/column 7 of c_new: TP / (row sum + column sum - TP).
# Full column sum: the former slice 0:6 dropped c_new[6, 7] from the union.
a = c_new[7, 7]
union = np.sum(c_new[7, :]) + np.sum(c_new[:, 7]) - a
print(a / union)

In [None]:
# IoU of row/column 8 of c_new: TP / (row sum + column sum - TP).
# Full column sum: the former slice 0:7 dropped c_new[7, 8] from the union.
a = c_new[8, 8]
union = np.sum(c_new[8, :]) + np.sum(c_new[:, 8]) - a
print(a / union)

In [None]:
# IoU of row/column 9 of c_new: TP / (row sum + column sum - TP).
# Full column sum: the former slice 0:8 dropped c_new[8, 9] from the union.
a = c_new[9, 9]
union = np.sum(c_new[9, :]) + np.sum(c_new[:, 9]) - a
print(a / union)

In [None]:
# IoU of row/column 10 of c_new: TP / (row sum + column sum - TP).
# Full column sum: the former slice 0:9 dropped c_new[9, 10] from the union.
a = c_new[10, 10]
union = np.sum(c_new[10, :]) + np.sum(c_new[:, 10]) - a
print(a / union)

In [None]:
# IoU of row/column 11 of c_new: TP / (row sum + column sum - TP).
# Full column sum: the former slice 0:10 dropped c_new[10, 11] from the union.
a = c_new[11, 11]
union = np.sum(c_new[11, :]) + np.sum(c_new[:, 11]) - a
print(a / union)

In [None]:
# IoU of row/column 12 of c_new: TP / (row sum + column sum - TP).
# Full column sum: the former slice 0:11 dropped c_new[11, 12] from the union.
a = c_new[12, 12]
union = np.sum(c_new[12, :]) + np.sum(c_new[:, 12]) - a
print(a / union)

In [None]:
# IoU of row/column 13 of c_new: TP / (row sum + column sum - TP).
# Full column sum: the former slice 0:12 dropped c_new[12, 13] from the union.
a = c_new[13, 13]
union = np.sum(c_new[13, :]) + np.sum(c_new[:, 13]) - a
print(a / union)

In [None]:
# IoU of row/column 14 of c_new: TP / (row sum + column sum - TP).
# Full column sum: the former slice 0:13 dropped c_new[13, 14] from the union.
a = c_new[14, 14]
union = np.sum(c_new[14, :]) + np.sum(c_new[:, 14]) - a
print(a / union)

In [None]:
# IoU of row/column 15 of c_new: TP / (row sum + column sum - TP).
# Full column sum: the former slice 0:14 dropped c_new[14, 15] from the union.
a = c_new[15, 15]
union = np.sum(c_new[15, :]) + np.sum(c_new[:, 15]) - a
print(a / union)

In [None]:
# IoU of row/column 16 of c_new: TP / (row sum + column sum - TP).
# Full column sum: the former slice 0:15 dropped c_new[15, 16] from the union.
a = c_new[16, 16]
union = np.sum(c_new[16, :]) + np.sum(c_new[:, 16]) - a
print(a / union)

In [None]:
# IoU of row/column 17 of c_new: TP / (row sum + column sum - TP).
# Full column sum: the former slice 0:16 dropped c_new[16, 17] from the union.
a = c_new[17, 17]
union = np.sum(c_new[17, :]) + np.sum(c_new[:, 17]) - a
print(a / union)

In [None]:
# IoU of row/column 18 of c_new: TP / (row sum + column sum - TP).
# Full column sum: the former slice 0:17 dropped c_new[17, 18] from the union.
a = c_new[18, 18]
union = np.sum(c_new[18, :]) + np.sum(c_new[:, 18]) - a
print(a / union)