In [1]:
import os
os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"   # see issue #152
os.environ["CUDA_VISIBLE_DEVICES"]="0"

import argparse
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import pyvista as pv
import numpy as np
import torch.optim as optim


from torch.autograd import Variable
from cls.data import ModelNet40
from util import cal_loss
import cls._init_path
from cls.model_mixup import PointNet, DGCNN, Pointnet2_MSG


parser = argparse.ArgumentParser(description='Point Cloud Recognition')
parser.add_argument('--model', type=str, default='dgcnn', metavar='N',
                    choices=['pointnet', 'dgcnn', 'pointnet2_MSG'],
                    help='Model to use, [pointnet, dgcnn]')
parser.add_argument('--data', type=str, default='MN40', metavar='N',
                    choices=['MN40', 'SONN_EASY', 'SONN_HARD'])
parser.add_argument('--batch_size', type=int, default=16, metavar='batch_size',
                    help='Size of batch)')
parser.add_argument('--test_batch_size', type=int, default=16, metavar='batch_size',
                    help='Size of batch)')
parser.add_argument('--epochs', type=int, default=250, metavar='N',
                    help='number of episode to train ')
parser.add_argument('--optim', type=str, default="sgd",
                    choices=['sgd', 'adam'],
                    help='Optimizer, [sgd, adam]')
parser.add_argument('--lr', type=float, default=0.001, metavar='LR',
                    help='learning rate (default: 0.001, 0.1 if using sgd)')
parser.add_argument('--momentum', type=float, default=0.9, metavar='M',
                    help='SGD momentum (default: 0.9)')
parser.add_argument('--scheduler', type=str, default='cos', metavar='N',
                    choices=['cos', 'step'],
                    help='Scheduler to use, [cos, step]')
parser.add_argument('--no_cuda', type=bool, default=False,
                    help='enables CUDA training')
parser.add_argument('--seed', type=int, default=1, metavar='S',
                    help='random seed (default: 1)')
parser.add_argument('--eval', type=bool,  default=False,
                    help='evaluate the model')
parser.add_argument('--num_points', type=int, default=1024,
                    help='num of points to use')
parser.add_argument('--dropout', type=float, default=0.5,
                    help='initial dropout rate')
parser.add_argument('--emb_dims', type=int, default=1024, metavar='N',
                    help='Dimension of embeddings')
parser.add_argument('--k', type=int, default=20, metavar='N',
                    help='Num of nearest neighbors to use')
parser.add_argument('--model_path', type=str, default='', metavar='N',
                    help='Pretrained model path')

parser.add_argument('--aug', type=str, default='default', metavar='N',
                    choices=['default', 'MN40'])
parser.add_argument("--kermix", type=bool, default= True)    
parser.add_argument("--manimix", type=bool, default=False) 
parser.add_argument('--sigma', type=float, default=0.3) 
parser.add_argument('--beta', type=float, default=5.)  
parser.add_argument('--no_saliency', action='store_true')
parser.add_argument('--smoothing_k', type=int, default=20) 
parser.add_argument('--temperature', type=float, default=2)
parser.add_argument('--temperature2', type=float, default=1)  
parser.add_argument('--sample_ver', type=int, default=3) 
args = parser.parse_args([])


num_class = 40

train_loader = DataLoader(ModelNet40(partition='test', num_points=args.num_points), num_workers=8,
                                batch_size=args.batch_size, shuffle=True, drop_last=False)
device = torch.device("cuda")

model = DGCNN(args, num_class).to(device)
# model = PointNet(args, num_class).to(device)
model =  nn.DataParallel(model)
model.load_state_dict(torch.load("cls/outputs/base_dgcnn/models/model.t7"))
# model.load_state_dict(torch.load("cls/outputs/pointnet_base_MN40/models/model.t7"))
model = model.module
model.train()


opt = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=1e-4)

criterion2 = cal_loss
for data, label in train_loader:
    data, label = data.to(device), label.to(device).squeeze()
    batch_size = data.size()[0]

    model.eval()
    data_var = Variable(data.permute(0, 2, 1), requires_grad=True)
    logits, _ = model(data_var, mixup=False)
    loss = criterion2(logits, label, smoothing=False)
    loss.backward()
    opt.zero_grad()
    model.train()
    saliency = torch.sqrt(torch.mean(data_var.grad**2,1))

#     logits, _, temp = model(data.permute(0,2,1), label, mixup=True, saliency=saliency, get_mix=True)
    break
   
    

In [2]:
from PointMixup.emd_ import emd_module

def knn(x, k):
    inner = -2*torch.matmul(x.transpose(2, 1), x)
    xx = torch.sum(x**2, dim=1, keepdim=True)
    pairwise_distance = -xx - inner - xx.transpose(2, 1)
 
    idx = pairwise_distance.topk(k=k, dim=-1)[1]   # (batch_size, num_points, k)
    return idx

class namemix:
    def __init__(self, args, num_class=40):
        self.num_class = num_class
        self.smoothing_k = args.smoothing_k
        self.temperature = args.temperature
        self.temperature2 = args.temperature2
        self.sample_ver = args.sample_ver
        
        # if self.kermix:
        self.EMD = emd_module.emdModule()
        self.sigma = args.sigma
        if args.beta != -1:
            self.beta = torch.distributions.beta.Beta(torch.tensor([args.beta]), torch.tensor([args.beta]))
        else:
            self.beta = None
    
    def mixup(self, xyz, label, x=None, saliency=None):
        """
        Args:
            xyz (B,N,3): 
            label (B): 
            x (B,D,N): Defaults to None.
            saliency (B,N): Defaults to None.
        """        
        
        B, N, _ = xyz.shape
        idxs = torch.randperm(B)
        perm = xyz[idxs]
        
        _, ass = self.EMD(xyz, perm, 0.005, 500) # mapping
        ass = ass.long()
        perm_new = torch.zeros_like(perm).cuda()
        if saliency is not None:
            saliency = self.saliency_smoothing(xyz, saliency)
            saliency_perm = torch.zeros_like(saliency).cuda()
        
        for i in range(B):
            perm_new[i] = perm[i][ass[i]]
            if saliency is not None:
                saliency_perm[i] = saliency[idxs][i][ass[i]]

        #random_sampling
        if saliency is None:
            anc_idx = torch.randperm(self.args.num_points)[:2]
            anchor_ori, anchor_perm = xyz[:,anc_idx[0],:], perm_new[:,anc_idx[1], :]
        #weighted sampling
        else:
            #max_to_max
            if self.sample_ver == 0:
                saliency = saliency/saliency.sum(-1, keepdim=True)
                saliency_perm = saliency_perm/saliency_perm.sum(-1, keepdim=True)
                anc_idx = torch.multinomial(saliency, 1, replacement=True)
                anc_idx2 = torch.multinomial(saliency_perm, 1, replacement=True)
                anchor_ori = xyz[torch.arange(B),anc_idx[:,0]]
                anchor_perm = perm_new[torch.arange(B),anc_idx2[:,0]]
            
            #max_to_min
            elif self.sample_ver == 1:
                saliency_perm =1/saliency_perm
                
                saliency = saliency/saliency.sum(-1, keepdim=True)
                saliency_perm = saliency_perm/saliency_perm.sum(-1, keepdim=True)
                
                anc_idx = torch.multinomial(saliency, 1, replacement=True)
                anc_idx2 = torch.multinomial(saliency_perm, 1, replacement=True)
                anchor_ori = xyz[torch.arange(B),anc_idx[:,0]]
                anchor_perm = perm_new[torch.arange(B),anc_idx2[:,0]]
            
            #min_to_min
            elif self.sample_ver == 2:
                saliency = 1/saliency
                saliency_perm = 1/saliency_perm
                
                saliency = saliency/saliency.sum(-1, keepdim=True)
                saliency_perm = saliency_perm/saliency_perm.sum(-1, keepdim=True)
                
                anc_idx = torch.multinomial(saliency, 1, replacement=True)
                anc_idx2 = torch.multinomial(saliency_perm, 1, replacement=True)
                anchor_ori = xyz[torch.arange(B),anc_idx[:,0]]
                anchor_perm = perm_new[torch.arange(B),anc_idx2[:,0]]
                
            #max_to_max, distance based re-weighting
            elif self.sample_ver == 3:
                saliency = saliency/saliency.sum(-1, keepdim=True)
                
                
                anc_idx = torch.multinomial(saliency, 1, replacement=True)
                anchor_ori = xyz[torch.arange(B),anc_idx[:,0]]
                
                sub = perm_new - anchor_ori[:,None,:]
                dist = ((sub) ** 2).sum(2).sqrt()
                
                saliency_perm = saliency_perm * dist
                saliency_perm = saliency_perm/saliency_perm.sum(-1, keepdim=True)
                
                anc_idx2 = torch.multinomial(saliency_perm, 1, replacement=True)
                anchor_perm = perm_new[torch.arange(B),anc_idx2[:,0]]
            
           

        if self.beta is not None:
            alpha = self.beta.sample((B,)).cuda()
        else:
            alpha = 0.5
        
        sub_ori = xyz - anchor_ori[:,None,:]
        sub_ori = ((sub_ori) ** 2).sum(2).sqrt()
        ker_weight_ori = torch.exp(-0.5 * (sub_ori ** 2) / (self.sigma ** 2))  #(M,N)
        
        sub_perm = perm_new - anchor_perm[:,None,:]
        sub_perm = ((sub_perm) ** 2).sum(2).sqrt()   
        ker_weight_perm = torch.exp(-0.5 * (sub_perm ** 2) / (self.sigma ** 2))  #(M,N)
        
        weight_ori = ker_weight_ori * alpha 
        weight_perm = ker_weight_perm * (1-alpha)
          
        weight = (torch.cat([weight_ori[...,None],weight_perm[...,None]],-1)) + 1e-16
        weight = weight/weight.sum(-1)[...,None]

        x = weight[:,:,0:1] * xyz + weight[:,:,1:] * perm_new
        x = x.permute(0,2,1)
        

        
        #label generation
        target = weight.sum(1)
        target = target / target.sum(-1, keepdim=True)
        label_onehot = torch.zeros(B, self.num_class).cuda().scatter(1, label.view(-1, 1), 1)
        label_perm_onehot = label_onehot[idxs]
        label = target[:, 0, None] * label_onehot + target[:, 1, None] * label_perm_onehot 
        
        return x, label, {"mix" : x, "perm_idxs" : idxs,
                  "perm" : perm_new, "ker_weight_perm" : ker_weight_perm, "weight_perm":weight_perm,
                  "ker_weight_ori" : ker_weight_ori,"weight_ori":weight_ori,\
                  "saliency" : saliency, "saliency_perm" : saliency_perm, "ratio":weight}
    
    
    def saliency_smoothing(self, xyz, saliency):
        B, N, _ = xyz.shape
        idx = knn(xyz.permute(0,2,1), self.smoothing_k)
        idx_base = torch.arange(0, B, device=xyz.device).view(-1, 1, 1)*1024
        idx = idx + idx_base
        idx = idx.view(-1)
        saliency = saliency.view(-1)[idx].reshape(B, N, self.smoothing_k) #(b, 1024, k)
        saliency = saliency.mean(-1)**self.temperature #(b, 1024)
        
        return saliency
        

    
idxs = torch.randperm(args.batch_size)
num_class = 40
mix = namemix(args, num_class)
_, new_label, temp = mix.mixup(data, label, saliency =saliency)
ker_weight_perm = temp["ker_weight_perm"] /  temp["ker_weight_perm"].max(-1, keepdim=True)[0]
weight_perm = temp["weight_perm"]/  temp["weight_perm"].max(-1, keepdim=True)[0]
saliency_perm = temp["saliency_perm"]/  temp["saliency_perm"].max(-1, keepdim=True)[0]

ker_weight_ori = temp["ker_weight_ori"]/  temp["ker_weight_ori"].max(-1, keepdim=True)[0]
weight_ori = temp["weight_ori"]/  temp["weight_ori"].max(-1, keepdim=True)[0]
saliency_ori = temp["saliency"]/  temp["saliency"].max(-1, keepdim=True)[0]



In [3]:
def get_shadow(plotter, pos, point_size, scale=3, origin=None, opacitiy=0.1, color= "#dddddd"):
    #pos (n,3)
    pos  =pos.repeat(scale, 0) 
    pos +=  np.random.rand(pos.shape[0],pos.shape[1])/20

    point_cloud = pv.PolyData(pos)

    if origin is None:
        origin = [0,0,pos[:,2].min() - 0.05]
    point_cloud = point_cloud.project_points_to_plane(origin=origin, normal=[0,0,-1])

    plotter.add_points(point_cloud, point_size=point_size, render_points_as_spheres=True,\
                        opacity=opacitiy, color = color, lighting=False) 

import seaborn as sns
# pv.start_xvfb()
camera_pos = [1,-1,1]#,[0,0,1]]

subplot_point_size = 10
g_subplot_point_size= 15

cmap = sns.color_palette("light:salmon", 5).as_hex()
scale = 1.1
cmap = ["#{:02x}{:02x}{:02x}".format(min(255, int(int(x[1:3],16)*scale)),
                                        min(255, int(int(x[3:5],16)*scale)),
                                        min(255, int(int(x[5:],16)*scale))) for x in cmap.as_hex()]



plotter = pv.Plotter(notebook=True,shape=(4,4), lighting="none")

for i in range(16):
    plotter.subplot(i//4, i%4)
    plotter.background_color = "W"
    point_cloud = pv.PolyData(data[i][:,[0,2,1]].cpu().numpy())
    point_cloud['y'] = temp["saliency"][i].cpu().numpy()
    plotter.add_points(point_cloud, point_size=subplot_point_size, render_points_as_spheres=True, cmap=cmap)
    plotter.camera_position = camera_pos
    
    get_shadow(plotter, data[i].cpu().numpy()[:,[0,2,1]], subplot_point_size)
    
    if i<8:
        light = pv.Light(intensity=0.12)
        light.set_direction_angle(30,90)
        plotter.add_light(light)
        light = pv.Light(intensity=0.12)
        light.set_direction_angle(30,-90)
        plotter.add_light(light)

#     plotter.enable_shadows()
    plotter.remove_scalar_bar()
    
plotter.show()

plotter = pv.Plotter(notebook=True,shape=(4,4), lighting="none")

for i in range(16):
    plotter.subplot(i//4, i%4)
    plotter.background_color = "W"
    point_cloud = pv.PolyData(temp["perm"][i][:,[0,2,1]].cpu().numpy())
    point_cloud['y'] = temp["saliency_perm"][i].cpu().numpy()
    plotter.add_points(point_cloud, point_size=subplot_point_size, render_points_as_spheres=True, cmap=cmap)
    plotter.camera_position = camera_pos
    
    get_shadow(plotter,temp["perm"][i].cpu().numpy()[:,[0,2,1]], subplot_point_size)
    
    if i<8:
        light = pv.Light(intensity=0.12)
        light.set_direction_angle(30,90)
        plotter.add_light(light)
        light = pv.Light(intensity=0.12)
        light.set_direction_angle(30,-90)
        plotter.add_light(light)

#     plotter.enable_shadows()
    plotter.remove_scalar_bar()
    
plotter.show()


# plotter = pv.Plotter(notebook=True, window_size=(900, 4800), shape=(16,3), lighting="none")
# for i in range(16):
#     plotter.subplot(i, 0)
#     plotter.background_color = "W"
#     point_cloud = pv.PolyData(data[i][:,[0,2,1]].cpu().numpy())
#     point_cloud['y'] = temp["ratio"][i,:,0].cpu().numpy()
#     plotter.add_points(point_cloud, point_size=subplot_point_size, render_points_as_spheres=True, cmap=cmap)
#     plotter.camera_position = camera_pos
    
#     get_shadow(plotter, data[i].cpu().numpy()[:,[0,2,1]], subplot_point_size)
    

#     plotter.remove_scalar_bar()


# perm = temp["perm"]
# for i in range(16):
#     plotter.subplot(i, 1)
#     plotter.background_color = "W"
#     point_cloud = pv.PolyData(perm[i][:,[0,2,1]].cpu().numpy())
#     point_cloud['y'] =temp["ratio"][i,:,1].cpu().numpy()
#     plotter.add_points(point_cloud, point_size=subplot_point_size, render_points_as_spheres=True, cmap=cmap)
#     plotter.camera_position = camera_pos
    
#     get_shadow(plotter, perm[i].cpu().numpy()[:,[0,2,1]], subplot_point_size)
    

#     plotter.remove_scalar_bar()

# ori = temp["mix"].permute(0,2,1)
# for i in range(16):
#     plotter.subplot(i, 2)
#     plotter.background_color = "W"
#     point_cloud = pv.PolyData(ori[i][:,[0,2,1]].cpu().numpy())
#     point_cloud['y'] = torch.zeros(args.num_points).numpy()
#     plotter.add_points(point_cloud, point_size=subplot_point_size, render_points_as_spheres=True, cmap=cmap)
#     plotter.camera_position = camera_pos
    
#     get_shadow(plotter, ori[i].cpu().numpy()[:,[0,2,1]], subplot_point_size)
    
#     if i<8:
#         light = pv.Light(intensity=0.12)
#         light.set_direction_angle(30,90)
#         plotter.add_light(light)
#         light = pv.Light(intensity=0.12)
#         light.set_direction_angle(30,-90)
#         plotter.add_light(light)

#     plotter.remove_scalar_bar()

# plotter.show()


ViewInteractiveWidget(height=768, layout=Layout(height='auto', width='100%'), width=1024)

ViewInteractiveWidget(height=768, layout=Layout(height='auto', width='100%'), width=1024)

In [4]:
for idx in range(16):
    point_size = 9
    ori  = data[idx][:,[0,2,1]].cpu().numpy()
    perm = temp["perm"][idx][:,[0,2,1]].cpu().numpy()

    plotter = pv.Plotter(notebook=True, window_size=(1500, 600), shape=(2,5))
    plotter.background_color = "W"

    plotter.subplot(0, 0)

    pos = ori
    point_cloud = pv.PolyData(pos)
    point_cloud['y'] =  saliency_ori.cpu().numpy()[idx]#**(3/4)
    plotter.add_points(point_cloud, point_size=point_size, render_points_as_spheres=True, cmap=cmap)
    plotter.camera_position = camera_pos

    get_shadow(plotter, pos, point_size,5)

    plotter.subplot(1, 0)

    pos = perm
    point_cloud = pv.PolyData(pos)
    point_cloud['y'] =  saliency_perm.cpu().numpy()[idx]#**(3/4)
    plotter.add_points(point_cloud, point_size=point_size, render_points_as_spheres=True, cmap=cmap)
    plotter.camera_position = camera_pos

    get_shadow(plotter, pos, point_size,5)

    plotter.subplot(0, 1)

    pos = ori
    point_cloud = pv.PolyData(pos)
    point_cloud['y'] =  ker_weight_ori.cpu().numpy()[idx]**(1/4)
    plotter.add_points(point_cloud, point_size=point_size, render_points_as_spheres=True, cmap=cmap)
    plotter.camera_position = camera_pos

    get_shadow(plotter, pos, point_size,5)

    plotter.subplot(1, 1)

    pos = perm
    point_cloud = pv.PolyData(pos)
    point_cloud['y'] =  ker_weight_perm.cpu().numpy()[idx]**(1/4) 
    plotter.add_points(point_cloud, point_size=point_size, render_points_as_spheres=True, cmap=cmap)
    plotter.camera_position = camera_pos

    get_shadow(plotter, pos, point_size,5)


    plotter.subplot(0, 2)

    pos = ori
    point_cloud = pv.PolyData(pos)
    
    t =  ker_weight_ori.cpu().numpy()[idx]**(1/4) * saliency_ori.cpu().numpy()[idx]#**(3/4)
    point_cloud['y'] = t/t.max()
#     point_cloud['y'] =  weight_ori.cpu().numpy()[idx]**(1/4)
    plotter.add_points(point_cloud, point_size=point_size, render_points_as_spheres=True, cmap=cmap)
    plotter.camera_position = camera_pos

    get_shadow(plotter, pos, point_size,5)

    plotter.subplot(1, 2)

    pos = perm
    point_cloud = pv.PolyData(pos)
    t=  ker_weight_perm.cpu().numpy()[idx]**(1/4) * saliency_perm.cpu().numpy()[idx]#**(3/4)
    point_cloud['y'] = t/t.max()
#     point_cloud['y'] =  weight_perm.cpu().numpy()[idx]**(1/4)
    plotter.add_points(point_cloud, point_size=point_size, render_points_as_spheres=True, cmap=cmap)
    plotter.camera_position = camera_pos

    get_shadow(plotter, pos, point_size,5)


    plotter.subplot(0, 3)

    pos = ori
    point_cloud = pv.PolyData(pos)
    point_cloud['y'] =  temp["ratio"][idx,:,0].cpu().numpy()#**(3/4)
    plotter.add_points(point_cloud, point_size=point_size, render_points_as_spheres=True, cmap=cmap)
    plotter.camera_position = camera_pos

    get_shadow(plotter, pos, point_size,5)

    plotter.subplot(1, 3)

    pos = perm
    point_cloud = pv.PolyData(pos)
    point_cloud['y'] =  temp["ratio"][idx,:,1].cpu().numpy()#**(3/4)
    plotter.add_points(point_cloud, point_size=point_size, render_points_as_spheres=True, cmap=cmap)
    plotter.camera_position = camera_pos

    get_shadow(plotter, pos, point_size,5)
    
    plotter.subplot(0, 4)

    pos = ori
    point_cloud = pv.PolyData(pos)
    point_cloud['y'] =  temp["ratio2"][idx,:,0].cpu().numpy()#**(3/4)
    plotter.add_points(point_cloud, point_size=point_size, render_points_as_spheres=True, cmap=cmap)
    plotter.camera_position = camera_pos

    get_shadow(plotter, pos, point_size,5)

    plotter.subplot(1, 4)

    pos = perm
    point_cloud = pv.PolyData(pos)
    point_cloud['y'] =  temp["ratio2"][idx,:,1].cpu().numpy()#**(3/4)
    plotter.add_points(point_cloud, point_size=point_size, render_points_as_spheres=True, cmap=cmap)
    plotter.camera_position = camera_pos

    get_shadow(plotter, pos, point_size,5)

    light = pv.Light(intensity=0.15, shadow_attenuation = 0)
    light.set_direction_angle(30,0)
    plotter.add_light(light)
    light = pv.Light(intensity=0.12, shadow_attenuation = 0)
    light.set_direction_angle(30,180)
    plotter.add_light(light)
    plotter.remove_scalar_bar()
    plotter.show()





    # plotter = pv.Plotter(notebook=True,shape=(1,2))
    # plotter.background_color = "W"

    # pos = data[idx][:,[0,2,1]].cpu().numpy()
    # point_cloud = pv.PolyData(pos)
    # point_cloud['y'] =  temp["ratio"][idx,:,0].cpu().numpy()
    # plotter.add_points(point_cloud, point_size=point_size, render_points_as_spheres=True, cmap=cmap)
    # plotter.camera_position = camera_pos

    # get_shadow(plotter, pos, point_size,5)

    # plotter.subplot(0, 1)

    # pos = temp["perm"][idx][:,[0,2,1]].cpu().numpy()
    # point_cloud = pv.PolyData(pos)
    # point_cloud['y'] =  temp["ratio"][idx,:,1].cpu().numpy()
    # plotter.add_points(point_cloud, point_size=point_size, render_points_as_spheres=True, cmap=cmap)

    # get_shadow(plotter, pos, point_size,5)

    # plotter.camera_position = camera_pos
    # plotter.remove_scalar_bar()
    # light = pv.Light(intensity=0.15, shadow_attenuation = 0)
    # light.set_direction_angle(30,0)
    # plotter.add_light(light)
    # light = pv.Light(intensity=0.12, shadow_attenuation = 0)
    # light.set_direction_angle(30,180)
    # plotter.add_light(light)
    # #         plotter.screenshot('figs/{}/{}_compare.png'.format(class_choices[label[idx]], cnt[label[idx]]))

    # plotter.show()

    # point_size = 25

    plotter = pv.Plotter(notebook=True, window_size=(1200, 600), shape=(1,2))
    plotter.background_color = "W"
#     plotter.subplot(0,0)
    
    
    pos = temp["mix"].permute(0,2,1)[idx][:,[0,2,1]].cpu().numpy()
    point_cloud = pv.PolyData(pos)
    point_cloud['y'] =  temp["ratio"][idx,:,0].cpu().numpy()#**(3/4)
    plotter.add_points(point_cloud, point_size=18, render_points_as_spheres=True, cmap="bwr")
    plotter.camera_position = camera_pos

    get_shadow(plotter, pos, 18,5)
    
    plotter.camera_position = camera_pos
    plotter.remove_scalar_bar()
    light = pv.Light(intensity=0.15, shadow_attenuation = 0)
    light.set_direction_angle(30,0)
    plotter.add_light(light)
    light = pv.Light(intensity=0.12, shadow_attenuation = 0)
    light.set_direction_angle(30,180)
    plotter.add_light(light)
    #         plotter.screenshot('figs/{}/{}_compare.png'.format(class_choices[label[idx]], cnt[label[idx]]))

    plotter.show()

ViewInteractiveWidget(height=600, layout=Layout(height='auto', width='100%'), width=1500)

ViewInteractiveWidget(height=600, layout=Layout(height='auto', width='100%'), width=1200)

ViewInteractiveWidget(height=600, layout=Layout(height='auto', width='100%'), width=1500)

ViewInteractiveWidget(height=600, layout=Layout(height='auto', width='100%'), width=1200)

ViewInteractiveWidget(height=600, layout=Layout(height='auto', width='100%'), width=1500)

ViewInteractiveWidget(height=600, layout=Layout(height='auto', width='100%'), width=1200)

ViewInteractiveWidget(height=600, layout=Layout(height='auto', width='100%'), width=1500)

ViewInteractiveWidget(height=600, layout=Layout(height='auto', width='100%'), width=1200)

ViewInteractiveWidget(height=600, layout=Layout(height='auto', width='100%'), width=1500)

ViewInteractiveWidget(height=600, layout=Layout(height='auto', width='100%'), width=1200)

ViewInteractiveWidget(height=600, layout=Layout(height='auto', width='100%'), width=1500)

ViewInteractiveWidget(height=600, layout=Layout(height='auto', width='100%'), width=1200)

ViewInteractiveWidget(height=600, layout=Layout(height='auto', width='100%'), width=1500)

ViewInteractiveWidget(height=600, layout=Layout(height='auto', width='100%'), width=1200)

ViewInteractiveWidget(height=600, layout=Layout(height='auto', width='100%'), width=1500)

ViewInteractiveWidget(height=600, layout=Layout(height='auto', width='100%'), width=1200)

ViewInteractiveWidget(height=600, layout=Layout(height='auto', width='100%'), width=1500)

ViewInteractiveWidget(height=600, layout=Layout(height='auto', width='100%'), width=1200)

ViewInteractiveWidget(height=600, layout=Layout(height='auto', width='100%'), width=1500)

ViewInteractiveWidget(height=600, layout=Layout(height='auto', width='100%'), width=1200)

ViewInteractiveWidget(height=600, layout=Layout(height='auto', width='100%'), width=1500)

ViewInteractiveWidget(height=600, layout=Layout(height='auto', width='100%'), width=1200)

ViewInteractiveWidget(height=600, layout=Layout(height='auto', width='100%'), width=1500)

ViewInteractiveWidget(height=600, layout=Layout(height='auto', width='100%'), width=1200)

ViewInteractiveWidget(height=600, layout=Layout(height='auto', width='100%'), width=1500)

ViewInteractiveWidget(height=600, layout=Layout(height='auto', width='100%'), width=1200)

ViewInteractiveWidget(height=600, layout=Layout(height='auto', width='100%'), width=1500)

ViewInteractiveWidget(height=600, layout=Layout(height='auto', width='100%'), width=1200)

ViewInteractiveWidget(height=600, layout=Layout(height='auto', width='100%'), width=1500)

ViewInteractiveWidget(height=600, layout=Layout(height='auto', width='100%'), width=1200)

ViewInteractiveWidget(height=600, layout=Layout(height='auto', width='100%'), width=1500)

ViewInteractiveWidget(height=600, layout=Layout(height='auto', width='100%'), width=1200)

In [5]:
temp["weight_ori2"]

tensor([[0.5089, 0.5766, 0.5027,  ..., 0.5057, 0.5028, 0.5058],
        [0.3202, 0.3202, 0.3202,  ..., 0.3202, 0.3202, 0.3202],
        [0.7237, 0.9213, 0.7635,  ..., 0.6374, 0.6946, 0.6257],
        ...,
        [0.4139, 0.4055, 0.4045,  ..., 0.4040, 0.4974, 0.4039],
        [0.6722, 0.6695, 0.7768,  ..., 0.6699, 0.6981, 0.6667],
        [0.2767, 0.2736, 0.2790,  ..., 0.2741, 0.2948, 0.2736]],
       device='cuda:0')

In [6]:
saliency = torch.tensor([[0.1,0.2,0.7], [10, 1, 1]])

torch.multinomial(saliency, 1, replacement=True)

tensor([[2],
        [0]])