In [1]:
import torch
from torch import Tensor
import torch.nn.functional as F
from torch import nn
from tqdm import tqdm
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
import numpy as np
from PIL import Image
from torch_cluster import knn_graph, fps
from math import sqrt
import matplotlib.pyplot as plt
import h5py
from torch_cluster import fps
from torch_cluster import knn
from datetime import datetime
from torch_geometric.datasets import ShapeNet
from collections import Counter

In [2]:
def distance(index, point_cloud):
    """Measure how far every point of a cloud lies from one reference point.

    Args:
        index: Index of the reference point inside ``point_cloud``.
        point_cloud: Tensor of points; the norm is taken over its last axis.

    Returns:
        Tensor holding the norm of ``point_cloud[index] - point_cloud`` along
        the last axis (one value per point for a 2-D cloud).
    """
    offsets = point_cloud[index] - point_cloud
    return offsets.norm(dim=-1)

def fps_complete(point_cloud, unique_indexes, M):
    """Extend a set of selected point indexes to ``M`` indexes by farthest point sampling.

    Starting from the points listed in ``unique_indexes``, the point whose
    distance to the already-selected set is largest is appended repeatedly
    until ``M`` indexes are available.

    Args:
        point_cloud: Floating-point tensor whose first axis enumerates the points.
        unique_indexes: Indexes (into ``point_cloud``) that are already selected.
        M: Number of indexes wanted.

    Returns:
        ``unique_indexes`` itself, unchanged, when it already holds at least
        ``M`` entries; otherwise a ``torch.long`` tensor of length ``M`` on the
        device of ``point_cloud`` whose first ``len(unique_indexes)`` entries
        are the given indexes.
    """
    t = len(unique_indexes)
    if t >= M:
        return unique_indexes

    # The result is used to index ``point_cloud``, so it has to be an integer
    # tensor living on the same device (a float tensor cannot be used as index).
    fps_indexes = torch.zeros(M, dtype=torch.long, device=point_cloud.device)
    fps_indexes[:t] = torch.as_tensor(unique_indexes, dtype=torch.long, device=point_cloud.device)

    # Distance of every point to the selected set; infinity while the set is
    # empty, so that an empty seed does not silently treat point 0 as selected.
    minimum_distance = point_cloud.new_full((point_cloud.shape[0],), float('inf'))
    for i in range(t):
        minimum_distance = torch.minimum(minimum_distance, distance(fps_indexes[i], point_cloud))

    for i in range(t, M):
        # Farthest remaining point, then refresh the distances with it included.
        fps_indexes[i] = minimum_distance.argmax()
        minimum_distance = torch.minimum(minimum_distance, distance(fps_indexes[i], point_cloud))
    return fps_indexes

class SmapleNet(nn.Module):
    """Learned point-cloud down-sampler.

    Five point-wise ``Conv1d`` layers encode the input cloud, a max over the
    point axis yields one global feature vector, and four fully-connected
    layers turn it into ``output_points`` 3-D query points. In training mode
    the query points are softly projected onto their ``k`` nearest input
    points (unless ``skip_projection`` is set); in evaluation mode they are
    replaced by actual input points via :meth:`matching`.

    Args:
        input_points: Stored on the instance; not used by the layers themselves.
        output_points: Number of points generated per cloud.
        k: Neighbourhood size used by the soft projection.
        bottleneck: Width of the global feature vector.
        device: Stored on the instance for callers that read it.
        initial_temperature: Start value of the projection temperature.
        is_temperature_trainable: Whether the temperature receives gradients.
        skip_projection: If true, training mode returns the raw generated points.
    """

    def __init__(self, input_points, output_points, k, bottleneck, device, initial_temperature=1.0, is_temperature_trainable=True, skip_projection=False):
        super().__init__()

        self.input_points = input_points
        self.output_points = output_points
        self.k = k
        self.device = device
        self.skip_projection = skip_projection
        # The flag has to be passed to nn.Parameter itself: Parameter defaults
        # to requires_grad=True and overrides the flag of the wrapped tensor.
        self.temperature = torch.nn.Parameter(
            torch.tensor(initial_temperature, dtype=torch.float32),
            requires_grad=is_temperature_trainable,
        )

        self.conv1 = nn.Sequential(nn.Conv1d(3,64,1), nn.BatchNorm1d(64), nn.PReLU())
        self.conv2 = nn.Sequential(nn.Conv1d(64,64,1), nn.BatchNorm1d(64), nn.PReLU())
        self.conv3 = nn.Sequential(nn.Conv1d(64,64,1), nn.BatchNorm1d(64), nn.PReLU())
        self.conv4 = nn.Sequential(nn.Conv1d(64,128,1), nn.BatchNorm1d(128), nn.PReLU())
        self.conv5 = nn.Sequential(nn.Conv1d(128,bottleneck,1), nn.BatchNorm1d(bottleneck), nn.PReLU())

        self.fc1 = nn.Sequential(nn.Linear(bottleneck,256), nn.BatchNorm1d(256), nn.PReLU())
        self.fc2 = nn.Sequential(nn.Linear(256,256), nn.BatchNorm1d(256), nn.PReLU())
        self.fc3 = nn.Sequential(nn.Linear(256,256), nn.BatchNorm1d(256), nn.PReLU())
        self.fc4 = nn.Sequential(nn.Linear(256,3*output_points))

        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        """Generate ``output_points`` points for every cloud in ``x``.

        Args:
            x: Tensor of shape B x N x 3 (``conv1`` expects 3 input channels).

        Returns:
            Tensor of shape B x output_points x 3: softly projected points in
            training mode (raw generated points when ``skip_projection`` is
            set), matched input points in evaluation mode.
        """
        points = x
        # Conv1d works on B x C x N
        y = torch.swapaxes(points, 1, 2)
        y = self.conv1(y)
        y = self.conv2(y)
        y = self.conv3(y)
        y = self.conv4(y)
        y = self.conv5(y)
        # global feature: maximum over the point axis
        y = y.max(dim=2)[0]
        y = self.fc1(y)
        y = self.fc2(y)
        y = self.fc3(y)
        y = self.fc4(y)
        y = y.reshape((-1,3,self.output_points)).swapaxes(1,2)

        if self.training:
            if self.skip_projection:
                return y
            return self.softProjection(points, y, self.k)
        # matching() takes no neighbourhood size; it always uses the single nearest point
        return self.matching(points, y)

    def softProjection(self, point_cloud, query_cloud, k):
        """Replace each query point by a soft average of its ``k`` nearest cloud points.

        The weights are a softmax over ``-distance / temperature**2``.

        Args:
            point_cloud: Tensor of shape B x N x C; needs at least ``k`` points per cloud.
            query_cloud: Tensor of shape B x M x C.
            k: Number of neighbours averaged per query point.

        Returns:
            Tensor of shape B x M x C.
        """
        B, N, _ = point_cloud.shape
        _, M, _ = query_cloud.shape
        device = point_cloud.device
        flat_points = point_cloud.reshape((B*N,-1))
        flat_queries = query_cloud.reshape((B*M,-1))
        batch_x = torch.repeat_interleave(torch.arange(B, device=device), N)
        batch_y = torch.repeat_interleave(torch.arange(B, device=device), M)
        # dist[0]: index of the query point, dist[1]: index of the neighbouring cloud point
        dist = knn(flat_points, flat_queries, k, batch_x, batch_y)

        neighbours = flat_points[dist[1]]
        exponent = -torch.norm(neighbours - flat_queries[dist[0]], dim=-1).reshape(B,M,k) / (self.temperature ** 2)
        weights = self.softmax(exponent)
        projected_points = torch.sum(neighbours.reshape(B,M,k,-1) * weights[:,:,:,None], dim=2)
        return projected_points

    def matching(self, point_cloud, query_cloud):
        """Snap every query point to its nearest cloud point, keeping M distinct points.

        Several queries may hit the same cloud point; the duplicates are
        dropped and the selection is refilled to ``M`` indexes with
        ``fps_complete``.

        Args:
            point_cloud: Tensor of shape B x N x C.
            query_cloud: Tensor of shape B x M x C.

        Returns:
            Tensor of shape B x M x C made of rows of ``point_cloud``.
        """
        B, N, _ = point_cloud.shape
        _, M, _ = query_cloud.shape
        device = point_cloud.device
        flat_points = point_cloud.reshape((B*N,-1))
        flat_queries = query_cloud.reshape((B*M,-1))
        batch_x = torch.repeat_interleave(torch.arange(B, device=device), N)
        batch_y = torch.repeat_interleave(torch.arange(B, device=device), M)
        dist = knn(flat_points, flat_queries, 1, batch_x, batch_y)
        # knn indexes address the flattened (B*N) cloud
        indexes = dist[1].reshape((B,M))
        # integer dtype: the tensor is used as an index below
        sampled_points = torch.zeros((B,M), dtype=torch.long, device=device)
        for b in range(B):
            offset = b * N
            # indexes local to cloud b, so they can be completed on that cloud alone
            local_unique = indexes[b].unique() - offset
            completed = fps_complete(point_cloud[b], local_unique, M)
            sampled_points[b,:] = torch.as_tensor(completed, dtype=torch.long, device=device) + offset

        return flat_points[sampled_points.flatten()].reshape(B,M,-1)



In [3]:
def display_pointClouds(pc1:torch.Tensor, pc2:torch.Tensor):
    """Show two batches of 3-D point clouds side by side.

    Row ``i`` of the figure holds cloud ``i`` of ``pc1`` in the left column
    and cloud ``i`` of ``pc2`` in the right column.

    Args:
        pc1: Tensor of shape B x N x 3.
        pc2: Tensor indexed like ``pc1``; only its first B clouds are drawn.
    """
    batch_count, _, _ = pc1.shape
    figure = plt.figure(figsize=(8,120))
    for row in range(batch_count):
        # subplot slots are numbered row by row, starting at 1
        for column, clouds in enumerate((pc1, pc2), start=1):
            axes = figure.add_subplot(batch_count, 2, 2*row + column, projection='3d')
            sample = clouds[row]
            axes.scatter(sample[:,0], sample[:,1], sample[:,2])
    plt.show()

In [4]:
class MVPDataset(Dataset):
    """Pairs of (down-sampled partial cloud, complete cloud) read from an HDF5 file.

    The second key of the file is read as the collection of partial clouds and
    the first key as the collection of complete clouds (assumes that key
    order; verify against the data files). Every ``count`` consecutive partial
    clouds share one complete cloud.

    Args:
        path: Location of the HDF5 file; it is opened read-only and kept open.
    """

    def __init__(self,path):
        super().__init__()

        self.file = h5py.File(path,'r')
        self.keys = list(self.file.keys())
        # number of consecutive partial clouds that map to one complete cloud
        self.count = 26

    def __getitem__(self, index):
        """Return ``(sampled_partial, complete_pc)`` for the partial cloud at ``index``.

        The partial cloud is reduced to a quarter of its points with ``fps``.
        """
        partial_pc = torch.tensor(self.file[self.keys[1]][index][()])
        complete_pc = torch.tensor(self.file[self.keys[0]][index//self.count][()])
        N, D = partial_pc.shape
        indexes = (fps(partial_pc,ratio=0.25)%N)
        sampled_partial = partial_pc[indexes]
        return sampled_partial, complete_pc

    def __len__(self):
        """Number of partial clouds."""
        # Ask the dataset for its length directly; reading it with [()] would
        # load every partial cloud into memory on each call.
        return len(self.file[self.keys[1]])

In [5]:
# Loader configuration: mini-batch size and the device used by the later cells.
batch_size = 32
device = torch.device('cpu' if not torch.cuda.is_available() else 'cuda')

# Both splits are read from HDF5 files given relative to the working directory.
trainset = MVPDataset(path='./MVP_Train_CP.h5')
testset = MVPDataset(path='./MVP_Test_CP.h5')

# Both loaders reshuffle their split on every pass.
trainLoader = DataLoader(dataset=trainset, batch_size=batch_size, shuffle=True)
testLoader = DataLoader(dataset=testset, batch_size=batch_size, shuffle=True)

In [9]:
# Scratch check of nn.MultiheadAttention used as self-attention:
# the same random B x N x 3 tensor serves as query, key and value.
queries = torch.rand(32, 512, 3)
keys = queries
values = queries
atten = nn.MultiheadAttention(embed_dim=3, num_heads=3, batch_first=True)
# y is the tuple returned by the layer (attention output, attention weights)
y = atten(queries, keys, values)

In [None]:
class multiHeadAttension(nn.Module):
    """Unfinished custom attention layer (work in progress).

    Only the first projection exists so far; no ``forward`` is defined yet,
    so calling an instance raises ``NotImplementedError``.

    Args:
        in_channel: Feature width of the input.
        out_channel: Feature width produced by the projection.
    """
    def __init__(self,in_channel, out_channel):
        super().__init__()
        # NOTE(review): the original assignment had no right-hand side, which
        # made the whole cell a SyntaxError. A linear projection is assumed
        # from the attribute name ("linq") - confirm the intended layer.
        # TODO: add the remaining projections and the forward pass.
        self.linq = nn.Linear(in_channel, out_channel)

In [6]:
class SRGCN(nn.Module):
    """Neighbourhood operation on the k-nearest-neighbour graph of a batched point set.

    For every vertex selected by :meth:`traverse_graph`, the ``k`` (optionally
    dilated) nearest neighbours are gathered and combined according to
    ``optype``:

    * ``'conv'`` - maximum over neighbours of ``neighbour - centre``,
      concatenated with the centre feature and passed through an MLP
      (output width ``out_channel``);
    * ``'max'`` / ``'avg'`` - maximum / mean over the centre and its
      neighbours (output width equals the input width).

    Args:
        optype: One of ``'conv'``, ``'max'``, ``'avg'``.
        in_channel: Feature width of the input.
        out_channel: Feature width produced by the ``'conv'`` MLP.
        device: Stored on the instance; tensors are created on the input's device.
        k: Number of neighbours combined per vertex.
        dilation: Every ``dilation``-th of the ``k * dilation`` nearest neighbours is kept.
        stride: Values above 1 keep roughly one vertex out of ``stride``.
    """

    def __init__(self, optype:str, in_channel, out_channel, device, k=9, dilation=1, stride=1):
        super().__init__()
        assert optype in {'conv','max','avg'}
        self.optype = optype
        self.dilation = dilation
        self.k = k
        self.stride = stride
        self.in_channel = in_channel
        self.out_channel = out_channel
        self.device = device

        # NOTE(review): an argument-less nn.MultiheadAttention() used to be
        # created here; it raised TypeError (embed_dim and num_heads are
        # required) and nothing in this class used it, so it is not created.
        self.mlp = nn.Sequential(nn.Linear(2*in_channel,out_channel), nn.LeakyReLU(inplace=True), nn.Linear(out_channel,out_channel))

    def forward(self, x:torch.Tensor):
        """Apply the operation to ``x``.

        Args:
            x: Tensor of shape B x N x D; each cloud needs more than
               ``k * dilation`` points so that every vertex has that many neighbours.

        Returns:
            Tensor of shape B x M x out_channel for ``'conv'`` and
            B x M x D otherwise, where M is the number of visited vertices.
        """
        # path has shape B x M
        path = self.traverse_graph(x)
        B, N, D = x.shape
        _, M = path.shape
        batches = torch.arange(B, device=x.device)[:,None]
        batch_tensor = torch.repeat_interleave(torch.arange(B, device=x.device), N)

        neighbour_total = self.k * self.dilation
        edge_index = knn_graph(x.reshape(B*N,-1), neighbour_total, batch_tensor)
        # row 0 lists the neighbours of each vertex; "% N" turns the indexes of
        # the flattened batch into indexes local to each cloud
        indexes = (edge_index % N)[0].reshape((B,N,-1))[batches,path]
        # keep every dilation-th neighbour (a no-op for dilation == 1)
        indexes = indexes[:,:,0:neighbour_total:self.dilation]

        centres = x[batches,path]
        message = x[batches,indexes.reshape((B,-1))].reshape((B,M,self.k,-1))
        if self.optype == 'conv':
            # this branch had only commented-out code, which is a syntax error;
            # the commented implementation is restored here
            aggr = (message - centres[:,:,None,:]).max(dim=-2)[0]
            mlp_prep = torch.cat((centres,aggr),dim=-1).reshape((B*M,-1))
            return self.mlp(mlp_prep).reshape(B,M,-1)
        elif self.optype == 'max':
            return torch.cat((centres[:,:,None,:],message),dim=-2).max(dim=-2)[0]
        else:
            return torch.cat((centres[:,:,None,:],message),dim=-2).mean(dim=-2)

    def traverse_graph(self, x:torch.Tensor):
        """Choose the vertices of every cloud that the operation is evaluated on.

        With ``stride <= 1`` all vertices are returned in order. Otherwise a
        walk starts at a random vertex; at each step the ``stride`` points
        nearest to the current vertex (the vertex itself included) are marked
        as visited and the walk moves to the next-nearest point, until
        ``N // stride`` vertices were collected.

        Args:
            x: Tensor of shape B x N x D (left unmodified).

        Returns:
            Long tensor of shape B x M with vertex indexes, on the device of ``x``.
        """
        B, N, D = x.shape
        # use the input's device instead of a notebook-level global
        device = x.device
        if self.stride <= 1:
            return torch.arange(N, device=device).repeat(B,1)

        # working copy: visited points are overwritten with infinity
        x = x.clone().detach()
        batches = torch.arange(B, device=device)[:,None]
        current_vertex = torch.randint(N, size=(B,1), device=device)
        # squeeze(1) keeps the batch axis even when B == 1
        path = [current_vertex.squeeze(1)]
        while len(path) < N//self.stride:
            distances = torch.norm(x[batches,current_vertex]-x, dim=-1)
            _, indexes = torch.topk(distances, k=self.stride+1, largest=False)
            x[batches,indexes[:,0:self.stride]] = float('inf')
            current_vertex = indexes[:,-1:]
            path.append(current_vertex.squeeze(1))
        return torch.stack(path, dim=1)

In [37]:
# Smoke test of a dilated SRGCN 'conv' layer on random features (B=32, N=256, D=64).
# The notebook-wide `device` is reused instead of being rebound to the CPU here,
# so running this cell no longer changes the device seen by every later cell.
x = torch.rand(32, 256, 64, device=device)
conv = SRGCN('conv', 64, 64, device, dilation=8).to(device)
y = conv(x)

In [8]:
class Upsampling(nn.Module):
    """Multiply the number of points by ``factor`` by generating new vertices.

    For every point, ``D`` neighbourhoods of ``k`` neighbours each are read
    from its nearest-neighbour list with steps 1, 2, 4, ... (``2**d``). Each
    neighbourhood is flattened and passed through a linear layer, the maximum
    over the ``D`` neighbourhoods gives ``factor - 1`` new feature vectors per
    point, and these are interleaved with the original points.

    Args:
        in_channel: Feature width C of the input.
        device: Device the batch-assignment vector is moved to.
        factor: Ratio between output and input point count.
        k: Neighbours per neighbourhood.
        D: Number of neighbourhoods (dilation levels).
    """

    def __init__(self, in_channel:int, device, factor:int=2, k:int=9, D:int=4):
        super().__init__()
        self.factor, self.k, self.D = factor, k, D
        self.device = device
        self.mlp = nn.Sequential(nn.Linear(k * in_channel, (factor - 1) * in_channel))

    def forward(self,x):
        """Map ``x`` of shape B x N x C to shape B x (factor * N) x C."""
        batch_size, num_points, channels = x.shape
        # the largest step (2**(D-1)) times k neighbours must be available
        neighbour_count = self.k * 2 ** (self.D - 1)
        batch_ids = torch.arange(batch_size).repeat_interleave(num_points).to(self.device)
        flat_points = x.reshape(batch_size * num_points, -1)
        edge_index = knn_graph(flat_points, k=neighbour_count, batch=batch_ids)
        # "% N" turns indexes of the flattened batch into per-cloud indexes
        neighbours = (edge_index[0] % num_points).reshape(batch_size, num_points, -1)

        # entry (d, j) is the position j * 2**d inside the neighbour list
        steps = torch.pow(2, torch.arange(self.D))
        picks = torch.outer(steps, torch.arange(self.k))
        indexes = neighbours[:, :, picks]

        batch_rows = torch.arange(batch_size)[:, None]
        gathered = x[batch_rows, indexes.reshape(batch_size, -1)]
        grouped = gathered.reshape((batch_size, num_points, self.D, -1))
        # B x N x D x (factor-1)*C, reduced over the D neighbourhoods
        new_vertices = self.mlp(grouped).max(dim=-2)[0]
        stacked = torch.cat((x, new_vertices), dim=-1)
        return stacked.reshape((batch_size, -1, channels))

class upScaling(nn.Module):
    """Repeat every point ``factor`` times, then apply an SRGCN 'conv' layer.

    Args:
        in_channel: Feature width of the input.
        out_channel: Feature width produced by the convolution.
        factor: Number of consecutive copies made of each point.
        device: Passed through to the SRGCN layer.
    """

    def __init__(self, in_channel, out_channel, factor, device):
        super().__init__()
        self.factor = factor
        self.conv = SRGCN('conv', in_channel, out_channel, device)

    def forward(self,x:torch.Tensor):
        """Map ``x`` (points along axis 1) to ``factor`` times as many points."""
        repeated = torch.repeat_interleave(x, self.factor, dim=1)
        return self.conv(repeated)


class batchNorm(nn.Module):
    """BatchNorm1d for channel-last input of shape B x N x D.

    ``nn.BatchNorm1d`` normalises axis 1, so the point and channel axes are
    swapped before the call and swapped back afterwards.

    Args:
        in_channels: Number of channels D.
        eps: Value added to the variance for numerical stability.
    """

    def __init__(self, in_channels, eps=0.00001):
        super().__init__()
        self.b = nn.BatchNorm1d(in_channels, eps)

    def forward(self,x:torch.Tensor):
        """Normalise ``x`` (B x N x D) over its channel axis; the shape is kept."""
        # unpacking rejects anything that is not 3-dimensional
        _, _, _ = x.shape
        channels_first = x.transpose(1, 2)
        return self.b(channels_first).transpose(1, 2)

class DownBlock(nn.Module):
    """Residual block of two SRGCN 'conv' layers, the first optionally strided.

    The shortcut is the identity when the block changes neither the number of
    points nor the channel width; otherwise the input goes through its own
    SRGCN 'conv' layer plus batch norm so that it can be added to the main
    branch.

    Args:
        in_channel: Feature width of the input.
        out_channel: Feature width of the output.
        device: Passed through to the SRGCN layers.
        stride: Stride of the first convolution and of the shortcut.
    """

    def __init__(self, in_channel, out_channel, device, stride=1):
        super().__init__()
        self.conv1 = SRGCN('conv',in_channel,out_channel, device, stride=stride)
        self.batch1 = batchNorm(out_channel)
        self.conv2 = SRGCN('conv',out_channel,out_channel,device)
        self.batch2 = batchNorm(out_channel)
        self.relu = nn.PReLU()
        self.id = None
        # A projection is needed as soon as the main branch changes the shape,
        # i.e. also for stride == 1 with a different channel width; otherwise
        # the residual addition in forward() fails on mismatched shapes.
        # NOTE(review): conv1 and the shortcut each run their own SRGCN; whether
        # both select the same vertices for stride > 1 depends on SRGCN - confirm.
        if stride > 1 or in_channel != out_channel:
            self.id = nn.Sequential(SRGCN('conv',in_channel,out_channel,device,stride=stride),batchNorm(out_channel))

    def forward(self,x):
        """Return ``PReLU(main(x) + shortcut(x))``."""
        identity = x
        x = self.conv1(x)
        x = self.batch1(x)
        x = self.conv2(x)
        x = self.batch2(x)
        if self.id is not None:
            identity = self.id(identity)
        return self.relu(x+identity)

class UpBlock(nn.Module):
    """Residual block: SRGCN 'conv' followed by an Upsampling layer.

    The shortcut is the identity when the block changes neither the number of
    points nor the channel width; otherwise the input goes through
    ``upScaling`` plus batch norm so that it can be added to the main branch.

    Args:
        in_channel: Feature width of the input.
        out_channel: Feature width of the output.
        device: Passed through to the sub-layers.
        factor: Ratio between output and input point count.
        D: Forwarded to ``Upsampling``.
        k: Forwarded to ``Upsampling``.
        last_block: If true, the final activation is skipped.
    """

    def __init__(self,in_channel, out_channel, device, factor=2, D=4, k=9, last_block=False):
        super().__init__()
        self.last_block = last_block
        self.conv = SRGCN('conv',in_channel,out_channel,device)
        self.batch1 = batchNorm(out_channel)

        self.up = Upsampling(out_channel,device,factor,k,D)
        self.batch2 = batchNorm(out_channel)

        self.relu = nn.PReLU()
        self.id = None
        # A projection is needed as soon as the main branch changes the shape,
        # i.e. also for factor == 1 with a different channel width; otherwise
        # the residual addition in forward() fails on mismatched shapes.
        if factor > 1 or in_channel != out_channel:
            self.id = nn.Sequential(upScaling(in_channel,out_channel,factor,device),batchNorm(out_channel))

    def forward(self,x):
        """Return ``main(x) + shortcut(x)``, passed through PReLU unless ``last_block``."""
        identity = x
        x = self.conv(x)
        x = self.batch1(x)
        x = self.up(x)
        x = self.batch2(x)
        if self.id is not None:
            identity = self.id(identity)
        if self.last_block:
            return x + identity
        return self.relu(x+identity)

In [9]:
class PointCloudCompletion(nn.Module):
    """Encoder-decoder network that maps a partial point cloud to a denser cloud.

    The encoder halves the number of points four times (strided max-pool and
    three DownBlocks); the decoder doubles it six times (UpBlocks). The
    outputs of ``down2`` and ``down1`` are concatenated channel-wise to the
    inputs of ``up2`` and ``up3``. The "points x channels" remarks below
    assume 512 input points.

    Args:
        device: Passed through to every sub-layer.
    """

    def __init__(self,device):
        super().__init__()
        # stem
        self.conv1 = SRGCN('conv', 3, 64, device, k=49)
        self.relu1 = nn.PReLU()
        self.maxpool = SRGCN('max', 64, 64, device, k=3, stride=2)  # 256 x 64
        self.batch1 = batchNorm(64)
        # encoder
        self.down1 = DownBlock(64, 128, device, stride=2)   # 128 x 128
        self.down2 = DownBlock(128, 256, device, stride=2)  # 64 x 256
        self.down3 = DownBlock(256, 512, device, stride=2)  # 32 x 512
        # bottleneck
        self.conv2 = SRGCN('conv', 512, 512, device)
        self.batch2 = batchNorm(512)
        # decoder; up2 and up3 take twice the channels because of the skip inputs
        self.up1 = UpBlock(512, 256, device, 2, k=3)   # 64 x 256
        self.up2 = UpBlock(512, 128, device, 2, k=5)   # 128 x 128
        self.up3 = UpBlock(256, 64, device, 2, k=7)    # 256 x 64
        self.up4 = UpBlock(64, 3, device, 2, k=9)      # 512 x 3
        self.up5 = UpBlock(3, 3, device, 2, k=11)      # 1024 x 3
        self.up6 = UpBlock(3, 3, device, 2, last_block=True, k=11)  # 2048 x 3

    def forward(self,x):
        """Map ``x`` of shape B x N x 3 to a cloud with 4 * N points (B x 4N x 3)."""
        stem = self.batch1(self.maxpool(self.relu1(self.conv1(x))))

        skip_shallow = self.down1(stem)
        skip_deep = self.down2(skip_shallow)
        bottleneck = self.batch2(self.conv2(self.down3(skip_deep)))

        decoded = self.up1(bottleneck)
        decoded = self.up2(torch.cat((decoded, skip_deep), dim=-1))
        decoded = self.up3(torch.cat((decoded, skip_shallow), dim=-1))
        for block in (self.up4, self.up5, self.up6):
            decoded = block(decoded)
        return decoded

In [10]:
class ChafmerDistance(nn.Module):
    """Symmetric nearest-neighbour distance between two batches of point clouds.

    For every point of ``P`` the Euclidean distance to its nearest point in
    ``G`` is averaged over the points of ``P``; the same is done from ``G`` to
    ``P``; both terms are added and the result is averaged over the batch.

    Args:
        device: Stored on the instance; tensors are created on the device of the inputs.
    """

    def __init__(self,device):
        super().__init__()
        self.device = device

    def forward(self,P:torch.Tensor,G:torch.Tensor):
        """Compute the distance.

        Args:
            P: Tensor of shape B x N x C.
            G: Tensor of shape B x M x C (M may differ from N).

        Returns:
            Scalar tensor.
        """
        B, N, _ = P.shape
        _, M, _ = G.shape
        # create helper tensors on the input's device, not a notebook-level global
        device = P.device
        batches = torch.arange(B, device=device)[:,None]
        batch_P = torch.repeat_interleave(torch.arange(B, device=device), N)
        batch_G = torch.repeat_interleave(torch.arange(B, device=device), M)
        flat_P = P.reshape(B*N,-1)
        flat_G = G.reshape(B*M,-1)

        # nearest point of G for every point of P; "% M" turns indexes of the
        # flattened batch into per-cloud indexes
        indexes = (knn(flat_G,flat_P,1,batch_G,batch_P)[1]%M).reshape(B,N)
        termP_G = torch.norm(G[batches,indexes] - P,dim=-1).sum(dim=-1) / N

        # nearest point of P for every point of G
        indexes = (knn(flat_P,flat_G,1,batch_P,batch_G)[1]%N).reshape(B,M)
        termG_P = torch.norm(P[batches,indexes] - G,dim=-1).sum(dim=-1) / M

        return (termG_P + termP_G).mean()

In [None]:
# Hyper-parameters
learning_rate = 0.0001
epoch_nums = 1
b1 = 0.5
b2 = 0.999

model = PointCloudCompletion(device).to(device)
loss_criterion = ChafmerDistance(device).to(device)
optimizer = torch.optim.Adam(params=model.parameters(), lr=learning_rate, betas=(b1,b2))
batch_num = len(trainLoader)
losses = []  # mean training loss of every finished epoch
for epoch in range(epoch_nums):  # loop over the dataset multiple times

    running_loss = 0.0
    tqdm_bar = tqdm(trainLoader, desc=f'Training Epoch {epoch} ', total=batch_num)
    model.train()
    for i, data in enumerate(tqdm_bar):
        partial_pc, complete_pc  = data
        partial_pc = partial_pc.to(device)
        complete_pc = complete_pc.to(device)

        optimizer.zero_grad()
        predicted_pc = model(partial_pc)
        loss = loss_criterion(predicted_pc, complete_pc)
        loss.backward()
        optimizer.step()

        batch_loss = loss.item()
        running_loss += batch_loss
        # show the loss of this batch and the running mean next to the progress bar
        tqdm_bar.set_postfix(loss=batch_loss, running_loss=running_loss / (i + 1))

    # max(..., 1) keeps an empty loader from raising ZeroDivisionError
    running_loss /= max(batch_num, 1)
    losses.append(running_loss)
    print('epoch : ',epoch,', loss : ',running_loss)
print('Finished Training')
