In [1]:
from pdb import set_trace

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

In [2]:
use_cuda = torch.cuda.is_available()
Float = torch.cuda.FloatTensor if use_cuda else torch.FloatTensor
Long = torch.cuda.LongTensor if use_cuda else torch.LongTensor
Int = torch.cuda.IntTensor if use_cuda else torch.IntTensor
Double = torch.cuda.DoubleTensor if use_cuda else torch.DoubleTensor
torch.backends.cudnn.benchmark = False

# Seed every RNG the notebook draws from (np.random for the train-set shuffle, torch for
# weight initialisation and dropout) so that Restart & Run All reproduces the same run.
SEED = 0
np.random.seed(SEED)
torch.manual_seed(SEED)
if use_cuda:
    torch.cuda.manual_seed_all(SEED)


In [3]:
scale_factor = 1


def _read_labeled_points(path, scale):
    """Parse a whitespace-separated labelled point file.

    Each non-blank line is ``x y z label``. The first three columns become a float
    array multiplied by ``scale``; the fourth becomes an int label stored as a 0-d array.

    Returns:
        (points, labels): two parallel lists, one entry per line.
    """
    points, point_labels = [], []
    with open(path, 'r') as file:
        for line in file:
            fields = line.split()
            if not fields:  # tolerate blank lines instead of crashing on fields[3]
                continue
            points.append(np.array([float(v) * scale for v in fields[:3]]))
            point_labels.append(np.array(int(fields[3])))
    return points, point_labels


# American head scan -> training set, Korean head scan -> test set.
x_train, y_train = _read_labeled_points('norm_head_american.txt', scale_factor)
x_test, y_test = _read_labeled_points('norm_head_korean.txt', scale_factor)


In [4]:
def pairwise_distance(point_cloud):
    """Squared Euclidean distance between every pair of points in each cloud.

    Args:
    point_cloud: tensor (batch_size, num_points, ...). Every dimension after the
        second is flattened into the feature axis, so (B, N, C), (B, N, 1, C) and
        (B, N, C, 1) inputs are all accepted.
    Returns:
    pairwise distance: (batch_size, num_points, num_points)
    """
    batch_size, num_points = point_cloud.size(0), point_cloud.size(1)
    # reshape instead of torch.squeeze: squeeze removes *every* size-1 axis, which
    # collapses the point axis when num_points == 1 or the feature axis when C == 1,
    # not only the batch axis the old special case tried to restore.
    point_cloud = point_cloud.reshape(batch_size, num_points, -1)
    point_cloud_transpose = point_cloud.permute(0, 2, 1)
    # ||a - b||^2 = ||a||^2 - 2 a.b + ||b||^2, evaluated for all pairs at once.
    point_cloud_inner = -2 * torch.bmm(point_cloud, point_cloud_transpose)
    point_cloud_square = (point_cloud ** 2).sum(dim=-1, keepdim=True)
    point_cloud_square_transpose = point_cloud_square.permute(0, 2, 1)
    return point_cloud_square + point_cloud_inner + point_cloud_square_transpose

def knn(dist_mat, k=20):
    """Indices of the k closest points (smallest distances) for every point.

    Args:
    dist_mat: pairwise distance (batch_size, num_points, num_points)
    k: int, must not exceed num_points (torch.topk raises otherwise)
    Returns:
    nearest neighbors: long tensor (batch_size, num_points, k) on dist_mat's device.
        The point itself is included (its distance is 0), and the k indices
        are in no particular order (sorted=False).
    """
    _, nn_idx = torch.topk(dist_mat, k=k, largest=False, sorted=False)
    # No .cuda(): topk already returns indices on dist_mat's device, and forcing CUDA
    # crashed on CPU-only machines.
    return nn_idx

def get_edge_feature(point_cloud, nn_idx, k=20):
    """Build EdgeConv features [x_i, x_j - x_i] for every point i and neighbour j.

    Args:
    point_cloud: tensor (batch_size, num_points, ...); every dimension after the
        second is flattened into the feature axis (num_dims).
    nn_idx: long tensor (batch_size, num_points, k) of neighbour indices local to
        each cloud (e.g. the output of knn()).
    k: int, neighbours per point; must equal nn_idx.size(-1).
    Returns:
    edge features: (batch_size, num_points, k, 2 * num_dims) on point_cloud's device.
    """
    batch_size, num_points = point_cloud.size(0), point_cloud.size(1)
    # reshape rather than torch.squeeze: squeeze drops every size-1 axis, which broke
    # inputs with num_points == 1 or num_dims == 1 (not only batch_size == 1).
    point_cloud = point_cloud.reshape(batch_size, num_points, -1)
    num_dims = point_cloud.size(-1)

    # Shift each cloud's local indices by batch * num_points so they address rows of
    # the flattened (batch_size * num_points, num_dims) tensor.
    idx_base = torch.arange(batch_size, dtype=torch.long, device=point_cloud.device)
    idx_base = idx_base.view(batch_size, 1, 1) * num_points
    flat_idx = (nn_idx.to(point_cloud.device) + idx_base).reshape(-1)

    point_cloud_flat = point_cloud.reshape(-1, num_dims)
    point_cloud_nbrs = torch.index_select(point_cloud_flat, dim=0, index=flat_idx)
    point_cloud_nbrs = point_cloud_nbrs.view(batch_size, num_points, k, num_dims)

    # Repeat each centre point once per neighbour to pair it with every x_j.
    point_cloud_central = point_cloud.unsqueeze(-2).expand(-1, -1, k, -1)

    return torch.cat((point_cloud_central, point_cloud_nbrs - point_cloud_central), dim=-1)

def convert_label_to_one_hot(labels):
    """One-hot encode integer labels along their first axis.

    Args:
    labels: integer numpy array; row ``i`` gets a 1 in column(s) ``labels[i]``.
    Returns:
    float64 array of shape (labels.shape[0], labels.max() + 1).
    """
    n_classes = np.max(labels) + 1
    one_hot = np.zeros((labels.shape[0], n_classes))
    for row, label in enumerate(labels):
        one_hot[row, label] = 1
    return one_hot


def unison_shuffled_copies(a, b):
    """Shuffle two equal-length arrays with one shared random permutation.

    Row ``i`` of ``a`` stays paired with row ``i`` of ``b``. Draws from NumPy's
    global RNG; raises AssertionError when the lengths differ.
    """
    n = len(a)
    assert n == len(b)
    order = np.random.permutation(n)
    shuffled_a = a[order]
    shuffled_b = b[order]
    return shuffled_a, shuffled_b

In [5]:
class input_transform_net(nn.Module):
    """Predict a K x K alignment matrix from EdgeConv features (spatial-transformer block).

    forward(edge_feat):
        edge_feat: (batch_size, 2*K, num_points, k) edge features, i.e. the output of
            get_edge_feature() permuted to channels-first.
        returns: (batch_size, K, K). The fc3 output is offset by the identity, so a
            zero fc3 output yields the identity transform.
    """
    def __init__(self, K=3):
        super(input_transform_net, self).__init__()

        # Edge features concatenate a point with a neighbour offset: 2*K channels (6 for K=3).
        self.conv1 = nn.Conv2d(2 * K, 64, 1)
        self.conv2 = nn.Conv2d(64, 128, 1)
        self.conv3 = nn.Conv2d(128, 1024, 1)

        self.fc1 = nn.Linear(1024, 512)
        self.fc2 = nn.Linear(512, 256)
        self.fc3 = nn.Linear(256, K*K)

        self.bn1 = nn.BatchNorm2d(64)
        self.bn2 = nn.BatchNorm2d(128)
        self.bn3 = nn.BatchNorm2d(1024)

        self.bn4 = nn.BatchNorm1d(512)
        self.bn5 = nn.BatchNorm1d(256)

        # Flattened K x K identity added to fc3's output. Built on the CPU (no hard-coded
        # .cuda(), so the module can be constructed without a GPU) and moved to the
        # activations' device/dtype in forward().
        self.const = torch.from_numpy(np.eye(K).flatten()).float()
        self.K = K

    def forward(self, edge_feat):
        batch_size = edge_feat.size(0)

        x = self.bn1(F.relu(self.conv1(edge_feat)))
        x = self.bn2(F.relu(self.conv2(x)))
        # Max over the neighbour axis -> (batch_size, 128, num_points, 1).
        x, _ = torch.max(x, dim=-1, keepdim=True)

        x = self.bn3(F.relu(self.conv3(x)))
        # Global max over the point axis -> (batch_size, 1024, 1, 1). Same result as a
        # MaxPool2d((num_points, 1)) window, without re-creating and attaching a new
        # pooling submodule on every call.
        x, _ = torch.max(x, dim=2, keepdim=True)

        x = x.view(batch_size, -1)

        x = self.bn4(F.relu(self.fc1(x)))
        x = self.bn5(F.relu(self.fc2(x)))

        x = self.fc3(x) + self.const.to(device=x.device, dtype=x.dtype)

        x = x.view(batch_size, self.K, self.K)

        return x


In [6]:
import torch.nn.functional as F
import torch.autograd as grad
import torch.optim as optim
class part_seg_net(nn.Module):
    """DGCNN-style point-cloud part-segmentation network.

    forward(point_cloud, object_label):
        point_cloud: (batch_size, num_point, 3) point coordinates.
        object_label: batch_size * cat_num values, reshaped to
            (batch_size, cat_num, 1, 1) as a per-sample category code.
        returns: (batch_size, part_num, num_point, 1) per-point part
            probabilities (softmax over dim 1) -- probabilities, not logits.

    Uses the module-level helpers pairwise_distance, knn, get_edge_feature and the
    input_transform_net class. self.k must not exceed num_point (knn uses topk).
    """
    def __init__(self, part_num, k=30, cat_num=16):
        super(part_seg_net, self).__init__()
        self.conv1 = nn.Conv2d(6, 64, kernel_size=1)     # edge features of 3-D points: 2 * 3
        self.conv2 = nn.Conv2d(64, 64, kernel_size=1)
        self.conv3 = nn.Conv2d(128, 64, kernel_size=1)   # max (64) + mean (64) over neighbours
        self.conv4 = nn.Conv2d(128, 64, kernel_size=1)   # edge features of 64-d points: 2 * 64
        self.conv5 = nn.Conv2d(128, 64, kernel_size=1)   # max + mean over neighbours
        self.conv6 = nn.Conv2d(128, 64, kernel_size=1)   # edge features of 64-d points
        self.conv7 = nn.Conv2d(128, 64, kernel_size=1)   # max + mean over neighbours
        self.conv8 = nn.Conv2d(192, 1024, kernel_size=1) # out3 + out5 + out7
        self.conv9 = nn.Conv2d(cat_num, 128, kernel_size=1)

        # 2752 = (1024 global max + 128 label) + 9 * 64 per-point maps + 1024 (out8).
        self.conv10 = nn.Conv2d(2752, 256, kernel_size=1)
        self.conv11 = nn.Conv2d(256, 256, kernel_size=1)
        self.conv12 = nn.Conv2d(256, 128, kernel_size=1)
        self.conv13 = nn.Conv2d(128, part_num, kernel_size=1)

        self.bn1 = nn.BatchNorm2d(64)
        self.bn2 = nn.BatchNorm2d(64)
        self.bn3 = nn.BatchNorm2d(64)
        self.bn4 = nn.BatchNorm2d(64)
        self.bn5 = nn.BatchNorm2d(64)
        self.bn6 = nn.BatchNorm2d(64)
        self.bn7 = nn.BatchNorm2d(64)
        self.bn8 = nn.BatchNorm2d(1024)
        self.bn9 = nn.BatchNorm2d(128)
        self.bn10 = nn.BatchNorm2d(256)
        self.bn11 = nn.BatchNorm2d(256)
        self.bn12 = nn.BatchNorm2d(128)

        self.dropout = nn.Dropout(p=0.2)

        # Registered submodule: it follows model.cuda() / model.to() with the rest of the
        # network, so no explicit device placement here.
        self.input_transform = input_transform_net()

        self.k = k
        self.part_num = part_num
        self.cat_num = cat_num

    def forward(self, point_cloud, object_label):

        batch_size, num_point, _ = point_cloud.size()
        input_img = point_cloud.unsqueeze(-1)

        # Predict an alignment transform from edge features of the raw points.
        dist_mat = pairwise_distance(point_cloud)
        nn_idx = knn(dist_mat, k=self.k)
        edge_feat = get_edge_feature(input_img, nn_idx=nn_idx, k=self.k)
        edge_feat = edge_feat.permute(0, 3, 1, 2)

        transform_mat = self.input_transform(edge_feat)

        # bmm runs on the inputs' device; no forced .cuda() copy.
        point_cloud_transformed = torch.bmm(point_cloud, transform_mat)
        input_img = point_cloud_transformed.unsqueeze(-1)

        dist_mat = pairwise_distance(point_cloud_transformed)
        nn_idx = knn(dist_mat, k=self.k)
        edge_feat = get_edge_feature(input_img, nn_idx=nn_idx, k=self.k)
        edge_feat = edge_feat.permute(0, 3, 1, 2)

        # EdgeConv stage 1 on the aligned coordinates.
        out1 = self.bn1(F.relu(self.conv1(edge_feat)))
        out1 = self.bn2(F.relu(self.conv2(out1)))
        out_max1, _ = torch.max(out1, dim=-1, keepdim=True)
        out_mean1 = torch.mean(out1, dim=-1, keepdim=True)

        out3 = self.bn3(F.relu(self.conv3(torch.cat((out_max1, out_mean1), dim=1))))

        # EdgeConv stage 2: neighbours recomputed in the 64-d feature space.
        out = out3.permute(0, 2, 3, 1)
        dist_mat = pairwise_distance(out)
        nn_idx = knn(dist_mat, k=self.k)
        edge_feat = get_edge_feature(out, nn_idx=nn_idx, k=self.k)
        edge_feat = edge_feat.permute(0, 3, 1, 2)

        out = self.bn4(F.relu(self.conv4(edge_feat)))
        out_max2, _ = torch.max(out, dim=-1, keepdim=True)
        out_mean2 = torch.mean(out, dim=-1, keepdim=True)

        out5 = self.bn5(F.relu(self.conv5(torch.cat((out_max2, out_mean2), dim=1))))

        # EdgeConv stage 3.
        out = out5.permute(0, 2, 3, 1)
        dist_mat = pairwise_distance(torch.squeeze(out, dim=-2))
        nn_idx = knn(dist_mat, k=self.k)
        edge_feat = get_edge_feature(out, nn_idx=nn_idx, k=self.k)
        edge_feat = edge_feat.permute(0, 3, 1, 2)

        out = self.bn6(F.relu(self.conv6(edge_feat)))
        out_max3, _ = torch.max(out, dim=-1, keepdim=True)
        out_mean3 = torch.mean(out, dim=-1, keepdim=True)
        out7 = self.bn7(F.relu(self.conv7(torch.cat((out_max3, out_mean3), dim=1))))

        out8 = self.bn8(F.relu(self.conv8(torch.cat((out3, out5, out7), dim=1))))

        # Global max over the point axis -> (batch_size, 1024, 1, 1). Equivalent to a
        # MaxPool2d((num_point, 1)) window, without attaching a new module every call.
        out_max, _ = torch.max(out8, dim=2, keepdim=True)

        one_hot_label_expand = object_label.view(batch_size, self.cat_num, 1, 1)
        one_hot_label_expand = self.bn9(F.relu(self.conv9(one_hot_label_expand)))
        out_max = torch.cat((out_max, one_hot_label_expand), dim=1)
        # Broadcast the global descriptor to every point.
        out_max = out_max.expand(-1, -1, num_point, -1)

        concat = torch.cat((out_max, out_max1, out_mean1,
                            out3, out_max2, out_mean2,
                            out5, out_max3, out_mean3,
                            out7, out8), dim=1)

        net2 = self.bn10(F.relu(self.conv10(concat)))
        net2 = self.dropout(net2)
        net2 = self.bn11(F.relu(self.conv11(net2)))
        net2 = self.dropout(net2)
        net2 = self.bn12(F.relu(self.conv12(net2)))
        net2 = self.conv13(net2)

        net2 = net2.view(batch_size, self.part_num, num_point, 1)
        net2 = F.softmax(net2, dim=1)

        return net2


In [7]:
# Per-point part labels as flat integer arrays, aligned index-for-index with x_train / x_test.
# TODO: grouping points into fixed-size samples happens later, in the batching cell.
labels, labels_test = np.array(y_train), np.array(y_test)

In [8]:
# Point coordinates stacked into (num_points, 3) float arrays.
point_cloud, point_cloud_test = np.array(x_train), np.array(x_test)

In [9]:
# labels is already an ndarray from the cell above; this only makes an explicit copy.
labels = np.array(labels, copy=True)

In [10]:
np.shape(labels)

(109527,)

In [11]:
np.shape(point_cloud)

(109527, 3)

In [12]:
# Disabled alternative: move the whole arrays to the GPU up front. The batching cell
# below converts each batch to a tensor and moves it with .cuda() instead.
#point_cloud = torch.Tensor(point_cloud).cuda()
#labels = torch.LongTensor(labels).cuda()

#point_cloud_test = torch.Tensor(point_cloud_test).cuda()
#labels_test = torch.LongTensor(labels_test).cuda()


In [13]:
# Disabled random-data smoke test; n_batches_train is not defined by any active cell shown here.
#point_cloud = torch.Tensor(np.random.randn(n_batches_train,100,3)).cuda()

In [14]:
# Apply one shared random permutation to labels and points (pairs stay aligned).
shuffled_labels, shuffled_points = unison_shuffled_copies(labels, point_cloud)
labels, point_cloud = shuffled_labels, shuffled_points

In [15]:
labels.shape[0]

109527

In [16]:
# 6 part classes, one dummy object category (cat_num=1), k=8 neighbours. k must not
# exceed the points per sample (n_of_points in the batching cell), since knn uses topk.
model = part_seg_net(part_num=6, k=8, cat_num=1)
model = model.cuda()
model.train(mode=True)
loss_fn = nn.CrossEntropyLoss()
optimizer = optim.Adam(params=model.parameters(), amsgrad=True)

In [17]:
mini_batch_size = 4096  # samples (groups of points) per batch
n_of_points = 8         # consecutive points grouped into one sample

# Group consecutive points into samples of n_of_points, dropping any trailing partial
# group, then split the samples into batches of at most mini_batch_size. This is a
# vectorised replacement for the previous per-point Python loops, with the same
# ordering and content:
#   batches[b]        -> (<= mini_batch_size, n_of_points, 3) float tensor on the GPU
#   batches_labels[b] -> (<= mini_batch_size, n_of_points) long tensor on the GPU
n_samples = point_cloud.shape[0] // n_of_points
n_used = n_samples * n_of_points
samples = point_cloud[:n_used].reshape(n_samples, n_of_points, -1)
sample_labels = labels[:n_used].reshape(n_samples, n_of_points)

batch_starts = range(0, n_samples, mini_batch_size)
batches = [torch.Tensor(samples[s:s + mini_batch_size]).cuda() for s in batch_starts]
batches_labels = [torch.LongTensor(sample_labels[s:s + mini_batch_size]).cuda()
                  for s in batch_starts]

In [18]:
# Disabled scratch code; `prova` is not defined by any active cell.
#prova = np.array(prova)
#prova = [np.array(i) for i in prova]
#prova = np.array(prova)

In [19]:
n_epochs = 300
# All-zero object-category input (cat_num=1 -> one column), one row per sample of a full
# batch; the loops slice it down to each batch's actual size.
objects_label_fake = torch.zeros((mini_batch_size, 1), device='cuda')

In [None]:
from tqdm import tnrange, tqdm_notebook
from time import sleep

total_loss = 0.0
total_seg_acc = 0.0

# Re-enable training mode in case the evaluation cells (model.eval()) ran before a re-run.
model.train()

t1 = tqdm_notebook(range(n_epochs))

for i in t1:

    # leave=False: each finished per-epoch bar disappears instead of stacking up one
    # widget per epoch in the cell output.
    t2 = tqdm_notebook(range(len(batches)), leave=False)
    for j in t2:

        batch_to_feed = batches[j]
        object_labels = objects_label_fake[0:batch_to_feed.size(0)]
        true_part_labels = batches_labels[j]
        optimizer.zero_grad()

        # (batch, part_num, n_of_points, 1) softmax probabilities -> (batch, part_num, n_of_points)
        part_label_probs = model(batch_to_feed, object_labels).squeeze(-1)

        # The model already applies softmax, and CrossEntropyLoss applies log_softmax
        # internally, so passing probabilities would softmax twice. Because the
        # probabilities sum to 1, log_softmax(log(p)) == log(p); feeding log-probs makes
        # loss_fn compute the true cross-entropy. clamp guards against log(0).
        log_probs = torch.log(part_label_probs.clamp(min=1e-12))
        loss = loss_fn(log_probs, true_part_labels)

        loss.backward()
        optimizer.step()

        # .item() instead of loss.data[0]: indexing a 0-dim tensor raises IndexError on
        # PyTorch >= 0.5. Note: this holds the latest batch's loss, not a running total.
        total_loss = loss.item()
        with torch.no_grad():
            part_labels = part_label_probs.argmax(dim=1)
            total_seg_acc = (part_labels == true_part_labels).float().mean().item()

        t2.set_description(f"Current loss: {total_loss}\t\t")

    t1.set_description(f"Current epoch: {i}")



HBox(children=(IntProgress(value=0, max=300), HTML(value='')))

HBox(children=(IntProgress(value=0, max=4), HTML(value='')))



HBox(children=(IntProgress(value=0, max=4), HTML(value='')))

HBox(children=(IntProgress(value=0, max=4), HTML(value='')))

HBox(children=(IntProgress(value=0, max=4), HTML(value='')))

HBox(children=(IntProgress(value=0, max=4), HTML(value='')))

HBox(children=(IntProgress(value=0, max=4), HTML(value='')))

HBox(children=(IntProgress(value=0, max=4), HTML(value='')))

HBox(children=(IntProgress(value=0, max=4), HTML(value='')))

HBox(children=(IntProgress(value=0, max=4), HTML(value='')))

HBox(children=(IntProgress(value=0, max=4), HTML(value='')))

HBox(children=(IntProgress(value=0, max=4), HTML(value='')))

HBox(children=(IntProgress(value=0, max=4), HTML(value='')))

HBox(children=(IntProgress(value=0, max=4), HTML(value='')))

HBox(children=(IntProgress(value=0, max=4), HTML(value='')))

HBox(children=(IntProgress(value=0, max=4), HTML(value='')))

HBox(children=(IntProgress(value=0, max=4), HTML(value='')))

HBox(children=(IntProgress(value=0, max=4), HTML(value='')))

HBox(children=(IntProgress(value=0, max=4), HTML(value='')))

HBox(children=(IntProgress(value=0, max=4), HTML(value='')))

HBox(children=(IntProgress(value=0, max=4), HTML(value='')))

HBox(children=(IntProgress(value=0, max=4), HTML(value='')))

HBox(children=(IntProgress(value=0, max=4), HTML(value='')))

HBox(children=(IntProgress(value=0, max=4), HTML(value='')))

HBox(children=(IntProgress(value=0, max=4), HTML(value='')))

HBox(children=(IntProgress(value=0, max=4), HTML(value='')))

HBox(children=(IntProgress(value=0, max=4), HTML(value='')))

HBox(children=(IntProgress(value=0, max=4), HTML(value='')))

HBox(children=(IntProgress(value=0, max=4), HTML(value='')))

HBox(children=(IntProgress(value=0, max=4), HTML(value='')))

HBox(children=(IntProgress(value=0, max=4), HTML(value='')))

HBox(children=(IntProgress(value=0, max=4), HTML(value='')))

HBox(children=(IntProgress(value=0, max=4), HTML(value='')))

HBox(children=(IntProgress(value=0, max=4), HTML(value='')))

HBox(children=(IntProgress(value=0, max=4), HTML(value='')))

HBox(children=(IntProgress(value=0, max=4), HTML(value='')))

HBox(children=(IntProgress(value=0, max=4), HTML(value='')))

HBox(children=(IntProgress(value=0, max=4), HTML(value='')))

HBox(children=(IntProgress(value=0, max=4), HTML(value='')))

HBox(children=(IntProgress(value=0, max=4), HTML(value='')))

HBox(children=(IntProgress(value=0, max=4), HTML(value='')))

HBox(children=(IntProgress(value=0, max=4), HTML(value='')))

HBox(children=(IntProgress(value=0, max=4), HTML(value='')))

HBox(children=(IntProgress(value=0, max=4), HTML(value='')))

HBox(children=(IntProgress(value=0, max=4), HTML(value='')))

HBox(children=(IntProgress(value=0, max=4), HTML(value='')))

HBox(children=(IntProgress(value=0, max=4), HTML(value='')))

HBox(children=(IntProgress(value=0, max=4), HTML(value='')))

HBox(children=(IntProgress(value=0, max=4), HTML(value='')))

HBox(children=(IntProgress(value=0, max=4), HTML(value='')))

HBox(children=(IntProgress(value=0, max=4), HTML(value='')))

HBox(children=(IntProgress(value=0, max=4), HTML(value='')))

HBox(children=(IntProgress(value=0, max=4), HTML(value='')))

HBox(children=(IntProgress(value=0, max=4), HTML(value='')))

HBox(children=(IntProgress(value=0, max=4), HTML(value='')))

HBox(children=(IntProgress(value=0, max=4), HTML(value='')))

HBox(children=(IntProgress(value=0, max=4), HTML(value='')))

HBox(children=(IntProgress(value=0, max=4), HTML(value='')))

HBox(children=(IntProgress(value=0, max=4), HTML(value='')))

HBox(children=(IntProgress(value=0, max=4), HTML(value='')))

HBox(children=(IntProgress(value=0, max=4), HTML(value='')))

HBox(children=(IntProgress(value=0, max=4), HTML(value='')))

HBox(children=(IntProgress(value=0, max=4), HTML(value='')))

HBox(children=(IntProgress(value=0, max=4), HTML(value='')))

HBox(children=(IntProgress(value=0, max=4), HTML(value='')))

HBox(children=(IntProgress(value=0, max=4), HTML(value='')))

HBox(children=(IntProgress(value=0, max=4), HTML(value='')))

HBox(children=(IntProgress(value=0, max=4), HTML(value='')))

HBox(children=(IntProgress(value=0, max=4), HTML(value='')))

HBox(children=(IntProgress(value=0, max=4), HTML(value='')))

HBox(children=(IntProgress(value=0, max=4), HTML(value='')))

HBox(children=(IntProgress(value=0, max=4), HTML(value='')))

HBox(children=(IntProgress(value=0, max=4), HTML(value='')))

HBox(children=(IntProgress(value=0, max=4), HTML(value='')))

HBox(children=(IntProgress(value=0, max=4), HTML(value='')))

HBox(children=(IntProgress(value=0, max=4), HTML(value='')))

HBox(children=(IntProgress(value=0, max=4), HTML(value='')))

HBox(children=(IntProgress(value=0, max=4), HTML(value='')))

HBox(children=(IntProgress(value=0, max=4), HTML(value='')))

HBox(children=(IntProgress(value=0, max=4), HTML(value='')))

HBox(children=(IntProgress(value=0, max=4), HTML(value='')))

HBox(children=(IntProgress(value=0, max=4), HTML(value='')))

HBox(children=(IntProgress(value=0, max=4), HTML(value='')))

HBox(children=(IntProgress(value=0, max=4), HTML(value='')))

HBox(children=(IntProgress(value=0, max=4), HTML(value='')))

HBox(children=(IntProgress(value=0, max=4), HTML(value='')))

HBox(children=(IntProgress(value=0, max=4), HTML(value='')))

HBox(children=(IntProgress(value=0, max=4), HTML(value='')))

HBox(children=(IntProgress(value=0, max=4), HTML(value='')))

HBox(children=(IntProgress(value=0, max=4), HTML(value='')))

HBox(children=(IntProgress(value=0, max=4), HTML(value='')))

HBox(children=(IntProgress(value=0, max=4), HTML(value='')))

HBox(children=(IntProgress(value=0, max=4), HTML(value='')))

HBox(children=(IntProgress(value=0, max=4), HTML(value='')))

HBox(children=(IntProgress(value=0, max=4), HTML(value='')))

HBox(children=(IntProgress(value=0, max=4), HTML(value='')))

HBox(children=(IntProgress(value=0, max=4), HTML(value='')))

HBox(children=(IntProgress(value=0, max=4), HTML(value='')))

HBox(children=(IntProgress(value=0, max=4), HTML(value='')))

HBox(children=(IntProgress(value=0, max=4), HTML(value='')))

HBox(children=(IntProgress(value=0, max=4), HTML(value='')))

HBox(children=(IntProgress(value=0, max=4), HTML(value='')))

HBox(children=(IntProgress(value=0, max=4), HTML(value='')))

HBox(children=(IntProgress(value=0, max=4), HTML(value='')))

HBox(children=(IntProgress(value=0, max=4), HTML(value='')))

HBox(children=(IntProgress(value=0, max=4), HTML(value='')))

HBox(children=(IntProgress(value=0, max=4), HTML(value='')))

HBox(children=(IntProgress(value=0, max=4), HTML(value='')))

HBox(children=(IntProgress(value=0, max=4), HTML(value='')))

HBox(children=(IntProgress(value=0, max=4), HTML(value='')))

HBox(children=(IntProgress(value=0, max=4), HTML(value='')))

HBox(children=(IntProgress(value=0, max=4), HTML(value='')))

HBox(children=(IntProgress(value=0, max=4), HTML(value='')))

HBox(children=(IntProgress(value=0, max=4), HTML(value='')))

HBox(children=(IntProgress(value=0, max=4), HTML(value='')))

HBox(children=(IntProgress(value=0, max=4), HTML(value='')))

HBox(children=(IntProgress(value=0, max=4), HTML(value='')))

HBox(children=(IntProgress(value=0, max=4), HTML(value='')))

HBox(children=(IntProgress(value=0, max=4), HTML(value='')))

HBox(children=(IntProgress(value=0, max=4), HTML(value='')))

HBox(children=(IntProgress(value=0, max=4), HTML(value='')))

HBox(children=(IntProgress(value=0, max=4), HTML(value='')))

HBox(children=(IntProgress(value=0, max=4), HTML(value='')))

HBox(children=(IntProgress(value=0, max=4), HTML(value='')))

HBox(children=(IntProgress(value=0, max=4), HTML(value='')))

HBox(children=(IntProgress(value=0, max=4), HTML(value='')))

HBox(children=(IntProgress(value=0, max=4), HTML(value='')))

HBox(children=(IntProgress(value=0, max=4), HTML(value='')))

HBox(children=(IntProgress(value=0, max=4), HTML(value='')))

HBox(children=(IntProgress(value=0, max=4), HTML(value='')))

HBox(children=(IntProgress(value=0, max=4), HTML(value='')))

HBox(children=(IntProgress(value=0, max=4), HTML(value='')))

HBox(children=(IntProgress(value=0, max=4), HTML(value='')))

HBox(children=(IntProgress(value=0, max=4), HTML(value='')))

HBox(children=(IntProgress(value=0, max=4), HTML(value='')))

HBox(children=(IntProgress(value=0, max=4), HTML(value='')))

HBox(children=(IntProgress(value=0, max=4), HTML(value='')))

HBox(children=(IntProgress(value=0, max=4), HTML(value='')))

HBox(children=(IntProgress(value=0, max=4), HTML(value='')))

HBox(children=(IntProgress(value=0, max=4), HTML(value='')))

HBox(children=(IntProgress(value=0, max=4), HTML(value='')))

HBox(children=(IntProgress(value=0, max=4), HTML(value='')))

HBox(children=(IntProgress(value=0, max=4), HTML(value='')))

HBox(children=(IntProgress(value=0, max=4), HTML(value='')))

HBox(children=(IntProgress(value=0, max=4), HTML(value='')))

HBox(children=(IntProgress(value=0, max=4), HTML(value='')))

HBox(children=(IntProgress(value=0, max=4), HTML(value='')))

HBox(children=(IntProgress(value=0, max=4), HTML(value='')))

HBox(children=(IntProgress(value=0, max=4), HTML(value='')))

HBox(children=(IntProgress(value=0, max=4), HTML(value='')))

HBox(children=(IntProgress(value=0, max=4), HTML(value='')))

HBox(children=(IntProgress(value=0, max=4), HTML(value='')))

HBox(children=(IntProgress(value=0, max=4), HTML(value='')))

HBox(children=(IntProgress(value=0, max=4), HTML(value='')))

HBox(children=(IntProgress(value=0, max=4), HTML(value='')))

HBox(children=(IntProgress(value=0, max=4), HTML(value='')))

HBox(children=(IntProgress(value=0, max=4), HTML(value='')))

HBox(children=(IntProgress(value=0, max=4), HTML(value='')))

HBox(children=(IntProgress(value=0, max=4), HTML(value='')))

HBox(children=(IntProgress(value=0, max=4), HTML(value='')))

HBox(children=(IntProgress(value=0, max=4), HTML(value='')))

HBox(children=(IntProgress(value=0, max=4), HTML(value='')))

HBox(children=(IntProgress(value=0, max=4), HTML(value='')))

HBox(children=(IntProgress(value=0, max=4), HTML(value='')))

HBox(children=(IntProgress(value=0, max=4), HTML(value='')))

HBox(children=(IntProgress(value=0, max=4), HTML(value='')))

HBox(children=(IntProgress(value=0, max=4), HTML(value='')))

HBox(children=(IntProgress(value=0, max=4), HTML(value='')))

HBox(children=(IntProgress(value=0, max=4), HTML(value='')))

HBox(children=(IntProgress(value=0, max=4), HTML(value='')))

HBox(children=(IntProgress(value=0, max=4), HTML(value='')))

HBox(children=(IntProgress(value=0, max=4), HTML(value='')))

HBox(children=(IntProgress(value=0, max=4), HTML(value='')))

HBox(children=(IntProgress(value=0, max=4), HTML(value='')))

HBox(children=(IntProgress(value=0, max=4), HTML(value='')))

HBox(children=(IntProgress(value=0, max=4), HTML(value='')))

HBox(children=(IntProgress(value=0, max=4), HTML(value='')))

HBox(children=(IntProgress(value=0, max=4), HTML(value='')))

HBox(children=(IntProgress(value=0, max=4), HTML(value='')))

HBox(children=(IntProgress(value=0, max=4), HTML(value='')))

HBox(children=(IntProgress(value=0, max=4), HTML(value='')))

HBox(children=(IntProgress(value=0, max=4), HTML(value='')))

HBox(children=(IntProgress(value=0, max=4), HTML(value='')))

HBox(children=(IntProgress(value=0, max=4), HTML(value='')))

HBox(children=(IntProgress(value=0, max=4), HTML(value='')))

HBox(children=(IntProgress(value=0, max=4), HTML(value='')))

HBox(children=(IntProgress(value=0, max=4), HTML(value='')))

HBox(children=(IntProgress(value=0, max=4), HTML(value='')))

HBox(children=(IntProgress(value=0, max=4), HTML(value='')))

HBox(children=(IntProgress(value=0, max=4), HTML(value='')))

HBox(children=(IntProgress(value=0, max=4), HTML(value='')))

HBox(children=(IntProgress(value=0, max=4), HTML(value='')))

HBox(children=(IntProgress(value=0, max=4), HTML(value='')))

HBox(children=(IntProgress(value=0, max=4), HTML(value='')))

HBox(children=(IntProgress(value=0, max=4), HTML(value='')))

HBox(children=(IntProgress(value=0, max=4), HTML(value='')))

HBox(children=(IntProgress(value=0, max=4), HTML(value='')))

HBox(children=(IntProgress(value=0, max=4), HTML(value='')))

HBox(children=(IntProgress(value=0, max=4), HTML(value='')))

HBox(children=(IntProgress(value=0, max=4), HTML(value='')))

HBox(children=(IntProgress(value=0, max=4), HTML(value='')))

HBox(children=(IntProgress(value=0, max=4), HTML(value='')))

HBox(children=(IntProgress(value=0, max=4), HTML(value='')))

HBox(children=(IntProgress(value=0, max=4), HTML(value='')))

HBox(children=(IntProgress(value=0, max=4), HTML(value='')))

HBox(children=(IntProgress(value=0, max=4), HTML(value='')))

HBox(children=(IntProgress(value=0, max=4), HTML(value='')))

HBox(children=(IntProgress(value=0, max=4), HTML(value='')))

HBox(children=(IntProgress(value=0, max=4), HTML(value='')))

HBox(children=(IntProgress(value=0, max=4), HTML(value='')))

HBox(children=(IntProgress(value=0, max=4), HTML(value='')))

HBox(children=(IntProgress(value=0, max=4), HTML(value='')))

HBox(children=(IntProgress(value=0, max=4), HTML(value='')))

HBox(children=(IntProgress(value=0, max=4), HTML(value='')))

HBox(children=(IntProgress(value=0, max=4), HTML(value='')))

HBox(children=(IntProgress(value=0, max=4), HTML(value='')))

HBox(children=(IntProgress(value=0, max=4), HTML(value='')))

HBox(children=(IntProgress(value=0, max=4), HTML(value='')))

HBox(children=(IntProgress(value=0, max=4), HTML(value='')))

HBox(children=(IntProgress(value=0, max=4), HTML(value='')))

HBox(children=(IntProgress(value=0, max=4), HTML(value='')))

HBox(children=(IntProgress(value=0, max=4), HTML(value='')))

HBox(children=(IntProgress(value=0, max=4), HTML(value='')))

HBox(children=(IntProgress(value=0, max=4), HTML(value='')))

HBox(children=(IntProgress(value=0, max=4), HTML(value='')))

HBox(children=(IntProgress(value=0, max=4), HTML(value='')))

HBox(children=(IntProgress(value=0, max=4), HTML(value='')))

HBox(children=(IntProgress(value=0, max=4), HTML(value='')))

In [None]:
# Disabled smoke test left over from an earlier (3, 100)-shaped label setup.
#res = model(point_cloud, torch.Tensor(convert_label_to_one_hot(np.zeros((3, 100), dtype='int'))).cuda())

In [None]:
# Checkpointing (disabled). torch.load unpickles -- only load files you trust.
#torch.save(model, "checkpoint.pkl")
#model = torch.load("checkpoint.pkl")
# Inference mode: BatchNorm uses running statistics and dropout is off.
# The trailing ';' suppresses the module repr.
model.train(False);
#model.cuda();

In [None]:
from tqdm import tqdm_notebook

# Per-batch part probabilities, each (items, part_num, n_of_points, 1).
# NOTE(review): `batches` is built from the (shuffled) training set; point_cloud_test
# is never batched or evaluated.
res = []

t2 = tqdm_notebook(range(len(batches)))

# no_grad: inference needs no autograd graph. Without it every forward pass keeps its
# activations alive, which the old detach_() call only partially worked around.
with torch.no_grad():
    for j in t2:
        batch_to_feed = batches[j]
        object_labels = objects_label_fake[0:batch_to_feed.size(0)]
        part_label_probs = model(batch_to_feed, object_labels)

        res.append(part_label_probs)

        torch.cuda.empty_cache()


In [None]:
# Count correctly classified points over every evaluated batch.
# res[b]: (items, part_num, n_of_points, 1) probabilities; batches_labels[b]: (items, n_of_points).
# Vectorised so each item of each batch is compared exactly once, including the shorter
# final batch. The old per-point loop pinned `minibatch` to the final batch's last item,
# counting that one item mini_batch_size times and skipping all the others.
n_right = 0

for batch in tqdm_notebook(range(len(batches_labels))):
    predicted = torch.argmax(res[batch], dim=1).squeeze(-1)
    n_right += int((predicted == batches_labels[batch]).sum().item())

In [None]:
# Fraction of evaluated points classified correctly. The denominator counts the labels
# actually evaluated: the batching cell drops a trailing partial group of points, so
# len(point_cloud) would overstate it. NOTE: these batches come from the training set.
n_right / sum(lbl.numel() for lbl in batches_labels)