In [1]:
from pdb import set_trace

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

In [2]:
# Select the legacy typed-tensor constructors once, depending on whether a
# CUDA device is available in this session.
use_cuda = torch.cuda.is_available()
if use_cuda:
    # GPU-backed constructors
    Float = torch.cuda.FloatTensor
    Double = torch.cuda.DoubleTensor
    Int = torch.cuda.IntTensor
    Long = torch.cuda.LongTensor
else:
    # CPU constructors
    Float = torch.FloatTensor
    Double = torch.DoubleTensor
    Int = torch.IntTensor
    Long = torch.LongTensor

In [3]:
def pairwise_distance(point_cloud):
    """
    Compute the squared Euclidean distance between every pair of points.

    Args:
    point_cloud: tensor (batch_size, num_points, num_dims). A tensor with more
        than three axes whose extra axes are singletons, e.g.
        (batch_size, num_points, 1, num_dims), is squeezed to 3-D first.
    Returns:
    pairwise distance: (batch_size, num_points, num_points)
    """
    batch_size = point_cloud.size(0)
    # Only squeeze when there really are extra axes; a plain 3-D input is used
    # as is, so singleton batch / point / feature axes are not lost.
    if point_cloud.dim() > 3:
        point_cloud = torch.squeeze(point_cloud)
        if batch_size == 1:
            # squeeze() also dropped the batch axis; restore it
            point_cloud = point_cloud.unsqueeze(0)

    # ||a - b||^2 = ||a||^2 - 2 * a.b + ||b||^2
    point_cloud_transpose = point_cloud.permute(0, 2, 1)
    point_cloud_inner = -2 * torch.bmm(point_cloud, point_cloud_transpose)
    point_cloud_square = (point_cloud ** 2).sum(dim=-1, keepdim=True)
    point_cloud_square_transpose = point_cloud_square.permute(0, 2, 1)
    return point_cloud_square + point_cloud_inner + point_cloud_square_transpose

def knn(dist_mat, k=20):
    """
    Pick, for every point, the indices of its k smallest-distance entries.

    Args:
    pairwise distance: (batch_size, num_points, num_points)
    k: int
    Returns:
    nearest neighbors: (batch_size, num_points, k) integer indices into the
        last axis of dist_mat, in no particular order (sorted=False). The
        result lives on the same device as dist_mat.
    """
    # topk already returns its indices on the input's device, so no explicit
    # transfer is needed (and none that would fail on CPU-only hosts).
    _, nn_idx = torch.topk(dist_mat, k=k, largest=False, sorted=False)
    return nn_idx

def get_edge_feature(point_cloud, nn_idx, k=20):
    """
    Build per-edge features from each point and its k neighbours.

    Args:
    point_cloud: tensor (batch_size, num_points, num_dims). A tensor with more
        than three axes whose extra axes are singletons is squeezed to 3-D.
    nn_idx: integer tensor (batch_size, num_points, k) of neighbour indices
        into the num_points axis.
    k: int, number of neighbours per point (must match nn_idx's last axis).
    Returns:
    edge feature: (batch_size, num_points, k, 2*num_dims); the first num_dims
        channels repeat the central point, the remaining ones hold
        (neighbour - central). The result is on point_cloud's device.
    """
    batch_size = point_cloud.size(0)
    if point_cloud.dim() > 3:
        point_cloud = torch.squeeze(point_cloud)
        if batch_size == 1:
            # squeeze() also dropped the batch axis; restore it
            point_cloud = point_cloud.unsqueeze(0)

    _, num_points, num_dims = point_cloud.size()

    # Work on whatever device the points already live on.
    device = point_cloud.device
    nn_idx = nn_idx.to(device)

    # Offset each batch element's indices so they address the flattened
    # (batch_size * num_points, num_dims) table below.
    batch_offset = torch.arange(batch_size, device=device).long() * num_points
    flat_idx = (nn_idx + batch_offset.view(batch_size, 1, 1)).reshape(-1)

    point_cloud_flat = point_cloud.contiguous().view(-1, num_dims)
    point_cloud_nbrs = torch.index_select(point_cloud_flat, dim=0, index=flat_idx)
    point_cloud_nbrs = point_cloud_nbrs.view(batch_size, num_points, k, -1)

    # Repeat every central point k times so it lines up with its neighbours.
    point_cloud_central = point_cloud.unsqueeze(-2).expand(-1, -1, k, -1)

    edge_feature = torch.cat(
        (point_cloud_central, point_cloud_nbrs - point_cloud_central), dim=-1)
    return edge_feature



In [4]:
class input_transform_net(nn.Module):
    """Predict a K x K transform matrix from edge features.

    The last linear layer's output is added to a flattened identity matrix, so
    the predicted transform is "identity plus a learned offset".

    Args:
    K: side length of the predicted square matrix (default 3).
    """
    def __init__(self, K=3):
        super(input_transform_net, self).__init__()

        self.conv1 = nn.Conv2d(6, 64, 1)
        self.conv2 = nn.Conv2d(64, 128, 1)
        self.conv3 = nn.Conv2d(128, 1024, 1)

        self.fc1 = nn.Linear(1024, 512)
        self.fc2 = nn.Linear(512, 256)
        self.fc3 = nn.Linear(256, K*K)

        self.bn1 = nn.BatchNorm2d(64)
        self.bn2 = nn.BatchNorm2d(128)
        self.bn3 = nn.BatchNorm2d(1024)

        self.bn4 = nn.BatchNorm1d(512)
        self.bn5 = nn.BatchNorm1d(256)

        # Flattened K x K identity. It is a plain attribute (not a parameter),
        # so it is placed on the GPU only when one exists and is re-aligned
        # with the activations' device inside forward().
        const = torch.eye(K).view(-1)
        self.const = const.cuda() if torch.cuda.is_available() else const
        self.K = K

    def forward(self, edge_feat):
        """
        Args:
        edge_feat: tensor with 6 channels on axis 1 (conv1's input width);
            axis 2 is treated as the point axis and the last axis is
            max-reduced -- presumably (batch_size, 6, num_points, k).
        Returns:
        transform: (batch_size, K, K)
        """
        batch_size, num_points = edge_feat.size(0), edge_feat.size(2)

        x = self.bn1(F.relu(self.conv1(edge_feat)))
        x = self.bn2(F.relu(self.conv2(x)))
        # reduce over the last axis
        x, _ = torch.max(x, dim=-1, keepdim=True)

        x = self.bn3(F.relu(self.conv3(x)))
        # Max over all points. Done functionally so that calling forward()
        # does not register a new submodule on every invocation.
        x = F.max_pool2d(x, (num_points, 1), stride=2)

        x = x.view(batch_size, -1)

        x = self.bn4(F.relu(self.fc1(x)))
        x = self.bn5(F.relu(self.fc2(x)))

        # .to() is a no-op when const is already on the activations' device.
        x = self.fc3(x) + self.const.to(x.device)

        return x.view(batch_size, self.K, self.K)


In [5]:
import torch.nn.functional as F
import torch.autograd as grad
import torch.optim as optim
class part_seg_net(nn.Module):
    """Per-point part segmentation network built from edge-feature blocks.

    Args:
    part_num: number of part classes predicted for every point.
    k: number of nearest neighbours used when building edge features.
    cat_num: width of the object-category vector fed to forward().
    """
    def __init__(self, part_num, k=30, cat_num=16):
        super(part_seg_net, self).__init__()
        self.conv1 = nn.Conv2d(6, 64, kernel_size=1)
        self.conv2 = nn.Conv2d(64, 64, kernel_size=1)
        self.conv3 = nn.Conv2d(128, 64, kernel_size=1)
        self.conv4 = nn.Conv2d(128, 64, kernel_size=1)
        self.conv5 = nn.Conv2d(128, 64, kernel_size=1)
        self.conv6 = nn.Conv2d(128, 64, kernel_size=1)
        self.conv7 = nn.Conv2d(128, 64, kernel_size=1)
        self.conv8 = nn.Conv2d(192, 1024, kernel_size=1)
        self.conv9 = nn.Conv2d(cat_num, 128, kernel_size=1)

        # 2752 = (1024 + 128) global/category channels
        #        + 9 * 64 per-block channels + 1024 channels of out8
        self.conv10 = nn.Conv2d(2752, 256, kernel_size=1)
        self.conv11 = nn.Conv2d(256, 256, kernel_size=1)
        self.conv12 = nn.Conv2d(256, 128, kernel_size=1)
        self.conv13 = nn.Conv2d(128, part_num, kernel_size=1)

        self.bn1 = nn.BatchNorm2d(64)
        self.bn2 = nn.BatchNorm2d(64)
        self.bn3 = nn.BatchNorm2d(64)
        self.bn4 = nn.BatchNorm2d(64)
        self.bn5 = nn.BatchNorm2d(64)
        self.bn6 = nn.BatchNorm2d(64)
        self.bn7 = nn.BatchNorm2d(64)
        self.bn8 = nn.BatchNorm2d(1024)
        self.bn9 = nn.BatchNorm2d(128)
        self.bn10 = nn.BatchNorm2d(256)
        self.bn11 = nn.BatchNorm2d(256)
        self.bn12 = nn.BatchNorm2d(128)

        self.dropout = nn.Dropout(p=0.6)

        self.input_transform = input_transform_net().cuda() if torch.cuda.is_available() else input_transform_net()

        self.k = k
        self.part_num = part_num
        self.cat_num = cat_num

    def _edge_conv_input(self, points):
        """Turn (batch, num_points, C) points into (batch, 2*C, num_points, k) edge features."""
        dist_mat = pairwise_distance(points)
        nn_idx = knn(dist_mat, k=self.k)
        edge_feat = get_edge_feature(points, nn_idx=nn_idx, k=self.k)
        # channels-first layout for the 1x1 convolutions
        return edge_feat.permute(0, 3, 1, 2)

    def forward(self, point_cloud, object_label):
        """
        Args:
        point_cloud: tensor (batch_size, num_point, 3).
        object_label: tensor holding batch_size * cat_num values; it is viewed
            as (batch_size, cat_num, 1, 1).
        Returns:
        log-probabilities over parts: (batch_size, part_num, num_point, 1),
            normalised with log_softmax along axis 1.
        """
        batch_size, num_point, _ = point_cloud.size()

        # --- learned input transform -------------------------------------
        edge_feat = self._edge_conv_input(point_cloud)
        transform_mat = self.input_transform(edge_feat)
        # Result stays on the operands' device; no forced GPU transfer.
        point_cloud_transformed = torch.bmm(point_cloud, transform_mat)

        # --- edge block 1 ------------------------------------------------
        edge_feat = self._edge_conv_input(point_cloud_transformed)
        out1 = self.bn1(F.relu(self.conv1(edge_feat)))
        out1 = self.bn2(F.relu(self.conv2(out1)))
        out_max1, _ = torch.max(out1, dim=-1, keepdim=True)
        out_mean1 = torch.mean(out1, dim=-1, keepdim=True)

        out3 = self.bn3(F.relu(self.conv3(torch.cat((out_max1, out_mean1), dim=1))))

        # --- edge block 2 ------------------------------------------------
        # (batch, 64, num_point, 1) -> (batch, num_point, 64)
        edge_feat = self._edge_conv_input(out3.squeeze(-1).permute(0, 2, 1))

        out = self.bn4(F.relu(self.conv4(edge_feat)))
        out_max2, _ = torch.max(out, dim=-1, keepdim=True)
        out_mean2 = torch.mean(out, dim=-1, keepdim=True)

        out5 = self.bn5(F.relu(self.conv5(torch.cat((out_max2, out_mean2), dim=1))))

        # --- edge block 3 ------------------------------------------------
        edge_feat = self._edge_conv_input(out5.squeeze(-1).permute(0, 2, 1))

        out = self.bn6(F.relu(self.conv6(edge_feat)))
        out_max3, _ = torch.max(out, dim=-1, keepdim=True)
        out_mean3 = torch.mean(out, dim=-1, keepdim=True)
        out7 = self.bn7(F.relu(self.conv7(torch.cat((out_max3, out_mean3), dim=1))))

        out8 = self.bn8(F.relu(self.conv8(torch.cat((out3, out5, out7), dim=1))))

        # --- global feature + object category ---------------------------
        # Max over all points. Done functionally so that calling forward()
        # does not register a new submodule on every invocation.
        out_max = F.max_pool2d(out8, (num_point, 1), stride=2)

        one_hot_label_expand = object_label.view(batch_size, self.cat_num, 1, 1)
        one_hot_label_expand = self.bn9(F.relu(self.conv9(one_hot_label_expand)))
        out_max = torch.cat((out_max, one_hot_label_expand), dim=1)
        # broadcast the global vector to every point
        out_max = out_max.expand(-1, -1, num_point, -1)

        concat = torch.cat((out_max, out_max1, out_mean1,
                            out3, out_max2, out_mean2,
                            out5, out_max3, out_mean3,
                            out7, out8), dim=1)

        # --- per-point classification head ------------------------------
        net2 = self.bn10(F.relu(self.conv10(concat)))
        net2 = self.dropout(net2)
        net2 = self.bn11(F.relu(self.conv11(net2)))
        net2 = self.dropout(net2)
        net2 = self.bn12(F.relu(self.conv12(net2)))
        net2 = self.conv13(net2)

        net2 = net2.view(batch_size, self.part_num, num_point, 1)
        net2 = F.log_softmax(net2, dim=1)

        return net2


In [6]:
# 16 part classes, a single object category.
model = part_seg_net(16, cat_num=1)
# Put the model on the GPU (when there is one) *before* the optimizer is
# built, so the optimizer is created from the parameters that will actually
# be trained.
if torch.cuda.is_available():
    model = model.cuda()
# Pairs with the log_softmax output of the network.
loss_fn = nn.NLLLoss()
optimizer = optim.Adam(model.parameters())

In [7]:
def convert_label_to_one_hot(labels, num_classes=None):
    """Convert integer class labels to a one-hot matrix.

    Args:
    labels: integer array-like whose first axis indexes samples. For a 1-D
        input each row gets a single 1; for higher-rank input every label in
        labels[i] is set in row i.
    num_classes: number of columns of the result. Defaults to max(labels) + 1
        (0 for empty input), which makes the width depend on the data; pass
        it explicitly to get a fixed width.
    Returns:
    float array of shape (labels.shape[0], num_classes).
    Raises:
    ValueError: if num_classes is given and a label is >= num_classes.
    """
    labels = np.asarray(labels)
    n_samples = labels.shape[0]

    if num_classes is None:
        num_classes = np.max(labels) + 1 if labels.size else 0
    elif labels.size and np.max(labels) >= num_classes:
        raise ValueError(
            "label %d does not fit in num_classes=%d" % (np.max(labels), num_classes))

    label_one_hot = np.zeros((n_samples, num_classes))
    if labels.size:
        # Row index shaped to broadcast against labels of any rank.
        rows = np.arange(n_samples).reshape((n_samples,) + (1,) * (labels.ndim - 1))
        label_one_hot[rows, labels] = 1
    return label_one_hot

In [8]:
# Number of synthetic samples; used as the leading dimension of the label and
# point-cloud arrays created in the following cells.
n_batches = 300

In [9]:
# Dummy per-point targets: a (n_batches, 100) array in which every entry is 1.
labels = np.ones(shape=(n_batches, 100))

In [10]:
#labels = [convert_label_to_one_hot(i) for i in labels]

In [11]:
# as_tensor accepts both the NumPy array from the previous cell and an
# already-converted tensor, so this cell can safely be re-run.
labels = torch.as_tensor(labels, dtype=torch.long)
# Use the GPU only when there is one.
if torch.cuda.is_available():
    labels = labels.cuda()

In [12]:
# Random synthetic input: n_batches clouds of 100 points with 3 coordinates,
# stored as float32.
point_cloud = torch.from_numpy(np.random.randn(n_batches, 100, 3)).float()
# Use the GPU only when there is one.
if torch.cuda.is_available():
    point_cloud = point_cloud.cuda()

In [13]:
# Use the GPU only when there is one, then switch to training mode
# (enables dropout and batch-statistics BatchNorm).
if torch.cuda.is_available():
    model = model.cuda()
model.train()

part_seg_net(
  (conv1): Conv2d(6, 64, kernel_size=(1, 1), stride=(1, 1))
  (conv2): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1))
  (conv3): Conv2d(128, 64, kernel_size=(1, 1), stride=(1, 1))
  (conv4): Conv2d(128, 64, kernel_size=(1, 1), stride=(1, 1))
  (conv5): Conv2d(128, 64, kernel_size=(1, 1), stride=(1, 1))
  (conv6): Conv2d(128, 64, kernel_size=(1, 1), stride=(1, 1))
  (conv7): Conv2d(128, 64, kernel_size=(1, 1), stride=(1, 1))
  (conv8): Conv2d(192, 1024, kernel_size=(1, 1), stride=(1, 1))
  (conv9): Conv2d(1, 128, kernel_size=(1, 1), stride=(1, 1))
  (conv10): Conv2d(2752, 256, kernel_size=(1, 1), stride=(1, 1))
  (conv11): Conv2d(256, 256, kernel_size=(1, 1), stride=(1, 1))
  (conv12): Conv2d(256, 128, kernel_size=(1, 1), stride=(1, 1))
  (conv13): Conv2d(128, 16, kernel_size=(1, 1), stride=(1, 1))
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=Tr

In [14]:
n_epochs = 200
# Every sample belongs to object category 0.
object_labels_fake = np.zeros((n_batches, 1), dtype='int')
# One-hot encode each row and stack into a single (n_batches, 1, 1) array
# before handing it to torch, instead of converting a Python list of arrays.
object_labels_fake = np.stack([convert_label_to_one_hot(i) for i in object_labels_fake])
object_labels_fake = torch.from_numpy(object_labels_fake).float()
# Use the GPU only when there is one.
if torch.cuda.is_available():
    object_labels_fake = object_labels_fake.cuda()

In [15]:
total_loss = 0.0
total_seg_acc = 0.0

# The same full batch is used for every iteration, so bind it once.
point_clouds = point_cloud
object_labels = object_labels_fake
true_part_labels = labels

for i in range(n_epochs):

    optimizer.zero_grad()

    # (batch, part_num, num_point, 1) log-probabilities
    part_label_probs = model(point_clouds, object_labels)
    # hard predictions: arg-max over the part axis
    _,part_labels = torch.max(part_label_probs, dim=1)

    part_labels.squeeze_(-1)
    #acc = torch.sum(part_labels==true_part_labels)/float(n_batches*100)

    #part_label_probs = part_label_probs.permute(0, 2, 1, 3)
    part_label_probs = part_label_probs.squeeze(-1)
    loss = loss_fn(part_label_probs, true_part_labels)

    # .item() is the supported way to read a 0-dim loss as a Python float;
    # indexing it with [0] is an error on current PyTorch.
    total_loss = loss.item()
    #total_seg_acc = acc.data[0]

    loss.backward()

    optimizer.step()

    print(f"loss: {total_loss}")
    #print(f"acc: {total_seg_acc}")



loss: 2.8719911575317383
loss: 2.8545351028442383
loss: 2.840911865234375
loss: 2.8298330307006836
loss: 2.818873405456543
loss: 2.8078479766845703
loss: 2.797025680541992
loss: 2.787278890609741
loss: 2.775991201400757
loss: 2.7661023139953613
loss: 2.755455732345581
loss: 2.743039131164551
loss: 2.73239803314209
loss: 2.7199881076812744
loss: 2.70757794380188
loss: 2.6940221786499023
loss: 2.681077003479004
loss: 2.6671035289764404
loss: 2.6523802280426025
loss: 2.63742995262146
loss: 2.6219027042388916
loss: 2.605048418045044
loss: 2.5880091190338135
loss: 2.570537567138672
loss: 2.551877498626709
loss: 2.5326383113861084
loss: 2.5130293369293213
loss: 2.4923181533813477
loss: 2.4709537029266357
loss: 2.4488160610198975
loss: 2.425750255584717
loss: 2.402177572250366
loss: 2.3775086402893066
loss: 2.352163076400757
loss: 2.326099157333374
loss: 2.2994720935821533
loss: 2.271369457244873
loss: 2.2428998947143555
loss: 2.2135677337646484
loss: 2.183628797531128
loss: 2.152439832687378

KeyboardInterrupt: 

In [None]:
#res = model(point_cloud, torch.Tensor(convert_label_to_one_hot(np.zeros((3, 100), dtype='int'))).cuda())

In [16]:
point_clouds = point_cloud
object_labels = object_labels_fake
true_part_labels = labels

# This cell only inspects predictions: run the forward pass in eval mode
# (dropout off, BatchNorm running statistics untouched) and without building
# an autograd graph, then restore whatever mode the model was in.
was_training = model.training
model.eval()
with torch.no_grad():
    part_label_probs = model(point_clouds, object_labels)
model.train(was_training)

# hard predictions: arg-max over the part axis
_,part_labels = torch.max(part_label_probs, dim=1)


In [17]:
# Predicted labels for the first ten points of sample 0.
part_labels[0, :10]

tensor([[1],
        [1],
        [1],
        [1],
        [1],
        [1],
        [1],
        [1],
        [1],
        [1]], device='cuda:0')

In [18]:
# Target labels for the first ten points of sample 0.
true_part_labels[0, :10]

tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1], device='cuda:0')