In [1]:
from __future__ import print_function
import os
import argparse
import torch
import torch.nn as nn
import torch.optim as optim
import sklearn.metrics as metrics
import numpy as np

from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.autograd import Variable
from torch.utils.data import DataLoader
from tqdm import tqdm


from data import ModelNet40, ScanObjectNN
from model import PointNet, DGCNN
from util import cal_loss, cal_loss_mix, IOStream
import gco

In [2]:
import os
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'

args = argparse.Namespace(batch_size=30, data='MN40', dropout=0.5, emb_dims=1024, epochs=50, eval=False, exp_name='SageMix', k=20, lr=0.0001, model='pointnet', model_path='', momentum=0.9, no_cuda=False, num_points=1024, seed=1, sigma=-1, test_batch_size=16, theta=0.2, use_sgd=True)

In [3]:
num_points = 1024
dataset = ModelNet40(partition='train', num_points=num_points)
batch_size=args.batch_size

test_batch_size = args.test_batch_size
train_loader = DataLoader(dataset, num_workers=8,
                        batch_size=batch_size, shuffle=True, drop_last=True)
test_loader = DataLoader(ModelNet40(partition='test', num_points=num_points), num_workers=8,
                        batch_size=test_batch_size, shuffle=True, drop_last=False)
num_class=40

In [4]:
# import torch
# from emd_ import emd_module

# class SageMix:
#     def __init__(self, args, device, num_class=40):
#         self.num_class = num_class
#         self.EMD = emd_module.emdModule()
#         self.sigma = args.sigma
#         self.beta = torch.distributions.beta.Beta(torch.tensor([args.theta]), torch.tensor([args.theta]))
#         self.device = device

#     def mix(self, xyz, label, saliency=None):
#         """
#         Args:
#             xyz (B,N,3)
#             label (B)
#             saliency (B,N): Defaults to None.
#         """        
#         B, N, _ = xyz.shape
#         # print(xyz.shape)
#         idxs = torch.randperm(B)

        
#         #Optimal assignment in Eq.(3)
#         # perm = xyz[idxs]
#         dist_mat = torch.empty(B, B, 1024)
#         ass_mat = torch.empty(B,B,1024)
#         dist_mat = dist_mat.to(self.device)
        
#         # print("Starting to compute optimal assignment (Heuristic-1)")
#         for idx,point in enumerate(xyz):
#             # perm = torch.tensor([point for x in range(B))
#             # print(point.shape)
#             perm = point.repeat(B,1)
#             # print(perm.shape)

#             perm  = perm.reshape(perm.shape[0]//1024,1024,3)
            
#             dist, ass = self.EMD(xyz, perm, 0.005, 500) # mapping
#                  # 32,1024
#             dist_mat[idx] = dist
#             ass_mat[idx] = ass

#             # print('dist:',dist.shape)
#             # if idx % 10 == 0:
#             #     print("Now doing", idx)
        
#         # print(dist_mat.shape)
#         dist_mat = torch.norm(dist_mat,dim=2)
#         avg_alignment_dist = torch.mean(dist_mat,dim=0)
#         # print(avg_alignment_dist.shape)
#         # print('avg_alignment:',avg_alignment_dist)
#         # print('mean:',torch.mean(avg_alignment_dist))
#         # print('min:',torch.min(avg_alignment_dist))
#         # print('max:',torch.max(avg_alignment_dist))
#         # print(torch.min(avg_alignment_dist))
#         # print(torch.argmin(avg_alignment_dist).item())

#         idx = torch.argmin(avg_alignment_dist).item()
#         # dist_mat = dist_mat.fill_diagonal_(100000.0)
    
        
#         # i,j = divmod(torch.argmin(dist_mat).item(),dist_mat.shape[1])
#         ass = ass_mat[idx]
        
#         ass = ass.long()

#         # sz = ass.size(0)
#         perm_new = torch.zeros_like(perm).to(self.device)
#         # print('perm:',perm)
#         # print(perm_new.shape)
#         perm = xyz.clone()
#         # print("idx:",idx)
#         for i in range(B):
#             # print('i:',i)
#             perm_new[i] = perm[i][ass[i]]
#             # print('perm_i',perm[i])
#             # print('perm_new_i',perm_new[i])

#         # print('perm_new',perm_new)

#         return ass,perm_new,dist_mat

#         # print("Done with compute optimal assignment (Heuristic-1)")
#         # print(ass.shape)
        

In [5]:
# args = argparse.Namespace(batch_size=30, data='MN40', dropout=0.5, emb_dims=1024, epochs=50, eval=False, exp_name='SageMix', k=20, lr=0.001, model='dgcnn', model_path='', momentum=0.9, no_cuda=False, num_points=1024, seed=1, sigma=-1, test_batch_size=16, theta=0.2, use_sgd=True)

# args.cuda

In [6]:
if args.data == 'MN40':
    dataset = ModelNet40(partition='train', num_points=args.num_points)
    # args.batch_size = len(dataset)
    args.batch_size = 30
    print('args.batch_size:',args.batch_size)
    train_loader = DataLoader(dataset, num_workers=8,
                            batch_size=args.batch_size, shuffle=True, drop_last=True)
    test_loader = DataLoader(ModelNet40(partition='test', num_points=args.num_points), num_workers=8,
                            batch_size=args.test_batch_size, shuffle=True, drop_last=False)
    num_class=40
elif args.data == 'SONN_easy':
    train_loader = DataLoader(ScanObjectNN(partition='train', num_points=args.num_points, ver="easy"), num_workers=8,
                            batch_size=args.batch_size, shuffle=True, drop_last=True)
    test_loader = DataLoader(ScanObjectNN(partition='test', num_points=args.num_points, ver="easy"), num_workers=8,
                            batch_size=args.test_batch_size, shuffle=True, drop_last=False)
    num_class =15
elif args.data == 'SONN_hard':
    train_loader = DataLoader(ScanObjectNN(partition='train', num_points=args.num_points, ver="hard"), num_workers=8,
                            batch_size=args.batch_size, shuffle=True, drop_last=True)
    test_loader = DataLoader(ScanObjectNN(partition='test', num_points=args.num_points, ver="hard"), num_workers=8,
                            batch_size=args.test_batch_size, shuffle=True, drop_last=False)
    num_class =15


device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

#Try to load models
if args.model == 'pointnet':
    model = PointNet(args, num_class).to(device)
elif args.model == 'dgcnn':
    model = DGCNN(args, num_class).to(device)
else:
    raise Exception("Not implemented")

args.batch_size: 30


In [7]:
from emd_ import emd_module
class SageMix:
    def __init__(self, args, num_class=40):
        self.num_class = num_class
        self.EMD = emd_module.emdModule()
        self.sigma = args.sigma
        self.beta = torch.distributions.beta.Beta(torch.tensor([args.theta]), torch.tensor([args.theta]))

    
    def mix(self, xyz, label, saliency=None, mixing_idx=0, device='cuda:0'):
        """
        Args:
            xyz (B,N,3)
            label (B)
            saliency (B,N): Defaults to None.
        """        
        B, N, _ = xyz.shape
        # print(xyz.shape)
        idxs = torch.randperm(B)

        # print(xyz)
        
        #Optimal assignment in Eq.(3)
        perm = xyz[idxs]
        
        _, ass = self.EMD(xyz, perm, 0.005, 500) # mapping
        ass = ass.long()
        # print(ass)
        perm_new = torch.zeros_like(perm).to(device)#.cuda()
        perm_saliency = torch.zeros_like(saliency).to(device)#.cuda()
        
        # print(ass,ass.shape)
        for i in range(B):
            perm_new[i] = perm[i][ass[i]]
            print(idxs)
            print(ass)
            print(saliency)
            perm_saliency[i] = saliency[idxs][i][ass[i]]
        
        #####
        # Saliency-guided sequential sampling
        #####
        #Eq.(4) in the main paper
        saliency = saliency/saliency.sum(-1, keepdim=True)
        anc_idx = torch.multinomial(saliency, 1, replacement=True)
        anchor_ori = xyz[torch.arange(B), anc_idx[:,0]]
        
        #cal distance and reweighting saliency map for Eq.(5) in the main paper
        sub = perm_new - anchor_ori[:,None,:]
        dist = ((sub) ** 2).sum(2).sqrt()
        perm_saliency = perm_saliency * dist
        perm_saliency = perm_saliency/perm_saliency.sum(-1, keepdim=True)
        
        #Eq.(5) in the main paper
        anc_idx2 = torch.multinomial(perm_saliency, 1, replacement=True)
        anchor_perm = perm_new[torch.arange(B),anc_idx2[:,0]]
                
                
        #####
        # Shape-preserving continuous Mixup
        #####
        alpha = self.beta.sample((B,)).to(device)#.cuda()
        sub_ori = xyz - anchor_ori[:,None,:]
        sub_ori = ((sub_ori) ** 2).sum(2).sqrt()
        #Eq.(6) for first sample
        ker_weight_ori = torch.exp(-0.5 * (sub_ori ** 2) / (self.sigma ** 2))  #(M,N)
        
        sub_perm = perm_new - anchor_perm[:,None,:]
        sub_perm = ((sub_perm) ** 2).sum(2).sqrt()
        #Eq.(6) for second sample
        ker_weight_perm = torch.exp(-0.5 * (sub_perm ** 2) / (self.sigma ** 2))  #(M,N)
        
        #Eq.(9)
        weight_ori = ker_weight_ori * alpha 
        weight_perm = ker_weight_perm * (1-alpha)
        weight = (torch.cat([weight_ori[...,None],weight_perm[...,None]],-1)) + 1e-16
        weight = weight/weight.sum(-1)[...,None]

        #Eq.(8) for new sample
        x = weight[:,:,0:1] * xyz + weight[:,:,1:] * perm_new
        
        #Eq.(8) for new label
        target = weight.sum(1)
        target = target / target.sum(-1, keepdim=True)
        
        if mixing_idx == 0:
            label_onehot = torch.zeros(B, self.num_class).to(device).scatter(1, label.view(-1, 1), 1)
            label_perm_onehot = label_onehot[idxs]
            label = target[:, 0, None] * label_onehot + target[:, 1, None] * label_perm_onehot

        else:
            label_onehot = torch.zeros(B, self.num_class).to(device).scatter(1, label.view(-1, 1), 1)

        
        return x, label
    

In [8]:
# #!/usr/bin/env python
# -*- coding: utf-8 -*-
from __future__ import print_function
import os
import argparse
import torch
import torch.nn as nn
import torch.optim as optim
import sklearn.metrics as metrics
import numpy as np

from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.autograd import Variable
from torch.utils.data import DataLoader
from tqdm import tqdm


from data import ModelNet40, ScanObjectNN
from model import PointNet, DGCNN
from util import cal_loss, cal_loss_mix, IOStream

if args.use_sgd:
    print("Use SGD")
    opt = optim.SGD(model.parameters(), lr=args.lr*100, momentum=args.momentum, weight_decay=1e-4)
else:
    print("Use Adam")
    opt = optim.Adam(model.parameters(), lr=args.lr, weight_decay=1e-4)


best_test_acc = 0
sagemix=SageMix(args, num_class)
scheduler = CosineAnnealingLR(opt, args.epochs, eta_min=args.lr)
criterion = cal_loss_mix


for epoch in range(args.epochs):

    ####################
    # Train
    ####################
    train_loss = 0.0
    count = 0.0
    model.train()
    train_pred = []
    train_true = []
    for data, label in tqdm(train_loader):
        data, label = data.to(device), label.to(device).squeeze()
        # print("data shape", data)
        batch_size = data.size()[0]
        split_idx = int(batch_size * 2/3)
        data01 = data[:split_idx, :, :]
        label01 = label[:split_idx]
        data2 = data[split_idx:, :, :]
        label2 = label[split_idx:]
        
        ####################
        # generate augmented sample
        ####################
        model.eval()
        data_var = Variable(data.permute(0,2,1), requires_grad=True)
        logits = model(data_var)
        loss = cal_loss(logits, label, smoothing=False)
        loss.backward()
        opt.zero_grad()
        saliency = torch.sqrt(torch.mean(data_var.grad**2,1))
        # print("saliency shape", saliency.shape)
        data_mix, label_mix = sagemix.mix(data01, label01, saliency[:split_idx,:])

        label2_onehot = torch.zeros(label2.shape[0], num_class).to(device).scatter(1, label2.view(-1, 1), 1)
        # print("label2_onehot shape", label2_onehot.shape)
        # print("label_mix shape", label_mix.shape)
        # label2_perm_onehot = label2_onehot[idxs]
        # label = target[:, 0, None] * label_onehot + target[:, 1, None] * label_perm_onehot

        data_all = torch.cat((data_mix, data2), dim=0)
        # print("data_all shape", data_all.shape)
        label_all = torch.cat((label_mix, label2_onehot), dim=0)

        data_var = Variable(data_mix.permute(0,2,1), requires_grad=True)
        logits = model(data_var)
        loss_mix = criterion(logits, label_mix)
        loss_mix.backward()
        opt.zero_grad()
        saliency_mix = torch.sqrt(torch.mean(data_var.grad**2,1))

        saliency_all = torch.cat((saliency_mix, saliency[split_idx:,:]), dim=0)

        data_total_mix, label_total_mix = sagemix.mix(data_all, label_all, saliency[:split_idx,:])
        print(saliency_all[0,:])
        print(saliency_all[25,:])
        # print("saliency_all shape", saliency_all.shape)
        break

        # data_allmix, label_allmix = sagemix.mix(data_all, label_, saliency)
        # print("label_all shape", label_all.shape)
        
        # mixed_saliency = torch.sqrt(torch.mean(data_var.grad**2,1))
        # print("data shape", data.shape)
        # model.train()
        # break
                
            
        # data3, label3 = sagemix.mix(data2, label2, saliency2, mixing_idx=1)

        
        # mixed_saliency = torch.sqrt(torch.mean(data_var.grad**2,1))
        # print("data shape", data.shape)
        # model.train()
        # # break
            
        # opt.zero_grad()
        # logits3 = model(data3.permute(0,2,1))
        # loss3 = criterion(logits3, label3)
        # loss3.backward()
        # opt.step()
        # preds = logits3.max(dim=1)[1]
        # count += batch_size
        # train_loss += loss3.item() * batch_size
        
    scheduler.step()
    outstr = 'Train %d, loss: %.6f' % (epoch, train_loss*1.0/count)
    io.cprint(outstr)

    ####################
    # Test
    ####################
    test_loss = 0.0
    count = 0.0
    model.eval()
    test_pred = []
    test_true = []
    for data, label in tqdm(test_loader):
        data, label = data.to(device), label.to(device).squeeze()
        data = data.permute(0, 2, 1)
        batch_size = data.size()[0]
        logits = model(data)
        loss = cal_loss(logits, label)
        preds = logits.max(dim=1)[1]
        count += batch_size
        test_loss += loss.item() * batch_size
        test_true.append(label.cpu().numpy())
        test_pred.append(preds.detach().cpu().numpy())
    test_true = np.concatenate(test_true)
    test_pred = np.concatenate(test_pred)
    test_acc = metrics.accuracy_score(test_true, test_pred)
    avg_per_class_acc = metrics.balanced_accuracy_score(test_true, test_pred)
    if test_acc >= best_test_acc:
        best_test_acc = test_acc
        torch.save(model.state_dict(), 'checkpoints/%s/models/model.t7' % args.exp_name)
    outstr = 'Test %d, loss: %.6f, test acc: %.6f, test avg acc: %.6f, best test acc: %.6f' % (epoch,
                                                                            test_loss*1.0/count,
                                                                            test_acc,
                                                                            avg_per_class_acc,
                                                                            best_test_acc)
    io.cprint(outstr)
    






Use SGD


  0%|          | 0/328 [00:00<?, ?it/s]../aten/src/ATen/native/cuda/IndexKernel.cu:92: operator(): block: [44,0,0], thread: [0,0,0] Assertion `index >= -sizes[i] && index < sizes[i] && "index out of bounds"` failed.
../aten/src/ATen/native/cuda/IndexKernel.cu:92: operator(): block: [44,0,0], thread: [1,0,0] Assertion `index >= -sizes[i] && index < sizes[i] && "index out of bounds"` failed.
../aten/src/ATen/native/cuda/IndexKernel.cu:92: operator(): block: [44,0,0], thread: [2,0,0] Assertion `index >= -sizes[i] && index < sizes[i] && "index out of bounds"` failed.
../aten/src/ATen/native/cuda/IndexKernel.cu:92: operator(): block: [44,0,0], thread: [3,0,0] Assertion `index >= -sizes[i] && index < sizes[i] && "index out of bounds"` failed.
../aten/src/ATen/native/cuda/IndexKernel.cu:92: operator(): block: [44,0,0], thread: [4,0,0] Assertion `index >= -sizes[i] && index < sizes[i] && "index out of bounds"` failed.
../aten/src/ATen/native/cuda/IndexKernel.cu:92: operator(): block: [44,0,0],

RuntimeError: CUDA error: device-side assert triggered
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.
