In [1]:
from __future__ import print_function
import os
import argparse
import torch
import torch.nn as nn
import torch.optim as optim
import sklearn.metrics as metrics
import numpy as np

from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.autograd import Variable
from torch.utils.data import DataLoader
from tqdm import tqdm


# from SageMix import SageMix
from data import ModelNet40, ScanObjectNN
from model import PointNet, DGCNN
from util import cal_loss, cal_loss_mix, IOStream

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
import os
# Debugging aid: ask the CUDA runtime to launch kernels synchronously so that a
# device-side error is reported at the Python line that triggered it.
# NOTE(review): this slows training down noticeably - drop it once debugging is done.
os.environ.update({"CUDA_LAUNCH_BLOCKING": "1"})

In [3]:
# Experiment configuration. Built as a Namespace so downstream code can keep
# using the `args.<field>` attribute access it was written against.
_experiment_config = {
    # run identity / bookkeeping
    "exp_name": 'SageMix',
    "seed": 1,
    "eval": False,
    "model_path": '',
    "no_cuda": False,
    # data
    "data": 'MN40',
    "num_points": 1024,
    "batch_size": 8,
    "test_batch_size": 16,
    # model
    "model": 'pointnet',
    "dropout": 0.5,
    "emb_dims": 1024,
    "k": 20,
    # optimisation
    "epochs": 50,
    "use_sgd": True,
    "lr": 0.0001,
    "momentum": 0.9,
    # SageMix hyper-parameters
    "sigma": -1,
    "theta": 0.2,
}
args = argparse.Namespace(**_experiment_config)

In [4]:
# Build the train / test loaders from the configuration in `args`.
# The point count is taken from `args.num_points` (previously a duplicated
# literal) so that changing the config cell actually changes the data.
num_points = args.num_points
batch_size = args.batch_size
test_batch_size = args.test_batch_size

dataset = ModelNet40(partition='train', num_points=num_points)

# drop_last=True keeps every training batch at exactly `batch_size` samples.
train_loader = DataLoader(
    dataset,
    num_workers=8,
    batch_size=batch_size,
    shuffle=True,
    drop_last=True,
)

# NOTE(review): the test split is shuffled as well; this does not change
# aggregate metrics but makes per-batch output order vary between runs.
test_loader = DataLoader(
    ModelNet40(partition='test', num_points=num_points),
    num_workers=8,
    batch_size=test_batch_size,
    shuffle=True,
    drop_last=False,
)

# Number of target classes; presumably 40 because args.data is 'MN40' - confirm
# before switching datasets.
num_class = 40

In [5]:
import torch
from emd_ import emd_module

class SageMix:
    """Saliency-guided Mixup for point clouds with a closest-pair partner heuristic.

    Every cloud in a batch is mixed with one shared "partner" cloud: the
    reference cloud of the batch pair that has the smallest EMD-based distance.
    Points are aligned with the assignment returned by the EMD module, anchors
    are sampled from the saliency maps, and the two clouds (and their labels)
    are blended with Gaussian-kernel weights around the anchors.

    Attributes:
        num_class: number of classes used for the one-hot soft labels.
        EMD: callable invoked as ``EMD(a, b, 0.005, 500)``; assumed to return
            ``(dist, assignment)``, each of shape (B, N), where
            ``assignment[b, n]`` indexes a point of ``b`` - TODO confirm
            against ``emd_module``.
        sigma: bandwidth of the Gaussian kernel (only ``sigma ** 2`` is used).
        beta: symmetric Beta(theta, theta) distribution for the mixing ratio.
        device: device on which intermediate tensors are allocated.
    """

    def __init__(self, args, device, num_class=40):
        """Store hyper-parameters read from ``args.sigma`` and ``args.theta``.

        Args:
            args: object exposing ``sigma`` and ``theta`` attributes.
            device: torch device used for intermediate tensors.
            num_class: number of label classes. Defaults to 40.
        """
        self.num_class = num_class
        self.EMD = emd_module.emdModule()
        self.sigma = args.sigma
        self.beta = torch.distributions.beta.Beta(torch.tensor([args.theta]), torch.tensor([args.theta]))
        self.device = device

    def mix(self, xyz, label, saliency=None):
        """Mix every cloud of the batch with the batch's closest-pair partner.

        Args:
            xyz (B,N,3): batch of point clouds.
            label (B): integer class labels.
            saliency (B,N): per-point saliency. Required despite the ``None``
                default, which is kept only for signature compatibility.

        Returns:
            tuple ``(x, label)``: mixed clouds of shape (B,N,3) and soft labels
            of shape (B,num_class) whose rows sum to 1.

        Raises:
            ValueError: if ``saliency`` is None.
        """
        if saliency is None:
            raise ValueError("SageMix.mix requires a saliency map of shape (B, N)")

        B, N, _ = xyz.shape

        #####
        # Partner selection + optimal assignment, Eq.(3)
        #####
        # Row r holds the per-point distances / assignments of every cloud in
        # the batch against reference cloud r. Both buffers are sized from N and
        # live on self.device, so no host round-trip happens per iteration.
        dist_mat = torch.empty(B, B, N, device=self.device)
        ass_mat = torch.empty(B, B, N, dtype=torch.long, device=self.device)
        for ref_idx, ref_cloud in enumerate(xyz):
            ref_batch = ref_cloud.unsqueeze(0).repeat(B, 1, 1)  # (B,N,3), contiguous
            dist, ass = self.EMD(xyz, ref_batch, 0.005, 500)
            dist_mat[ref_idx] = dist
            ass_mat[ref_idx] = ass

        # Collapse per-point distances into one score per (reference, cloud)
        # pair and forbid pairing a cloud with itself.
        pair_dist = torch.norm(dist_mat, dim=2)
        pair_dist.fill_diagonal_(100000.0)

        # The reference cloud of the closest pair becomes the partner for the
        # whole batch. NOTE(review): the second index of the pair is unused,
        # as in the original heuristic - confirm this is intended.
        partner_idx, _ = divmod(torch.argmin(pair_dist).item(), pair_dist.shape[1])
        ass = ass_mat[partner_idx]  # (B,N) indices into xyz[partner_idx]

        # Gather partner points and partner saliency with the SAME assignment,
        # so both describe the cloud the assignment was computed for.
        perm_new = xyz[partner_idx][ass]            # (B,N,3)
        perm_saliency = saliency[partner_idx][ass]  # (B,N)

        #####
        # Saliency-guided sequential sampling
        #####
        batch_idx = torch.arange(B, device=xyz.device)

        # Eq.(4) in the main paper; epsilon guards against a zero saliency sum.
        epsilon = 1e-3
        saliency = saliency / (saliency.sum(-1, keepdim=True) + epsilon)
        anc_idx = torch.multinomial(saliency, 1, replacement=True)
        anchor_ori = xyz[batch_idx, anc_idx[:, 0]]

        # Re-weight partner saliency by distance to the first anchor so the
        # second anchor tends to be far from it (Eq.(5) in the main paper).
        dist = ((perm_new - anchor_ori[:, None, :]) ** 2).sum(2).sqrt()
        perm_saliency = perm_saliency * dist
        perm_saliency = perm_saliency / perm_saliency.sum(-1, keepdim=True)

        # Eq.(5) in the main paper
        anc_idx2 = torch.multinomial(perm_saliency, 1, replacement=True)
        anchor_perm = perm_new[batch_idx, anc_idx2[:, 0]]

        #####
        # Shape-preserving continuous Mixup
        #####
        # Follow the configured device instead of unconditionally using CUDA.
        alpha = self.beta.sample((B,)).to(self.device)  # (B,1)

        # Eq.(6) for first sample
        sub_ori = ((xyz - anchor_ori[:, None, :]) ** 2).sum(2).sqrt()
        ker_weight_ori = torch.exp(-0.5 * (sub_ori ** 2) / (self.sigma ** 2))  # (B,N)

        # Eq.(6) for second sample
        sub_perm = ((perm_new - anchor_perm[:, None, :]) ** 2).sum(2).sqrt()
        ker_weight_perm = torch.exp(-0.5 * (sub_perm ** 2) / (self.sigma ** 2))  # (B,N)

        # Eq.(9); the 1e-16 keeps the normalisation finite when both kernels vanish.
        weight_ori = ker_weight_ori * alpha
        weight_perm = ker_weight_perm * (1 - alpha)
        weight = torch.stack([weight_ori, weight_perm], dim=-1) + 1e-16
        weight = weight / weight.sum(-1, keepdim=True)

        # Eq.(8) for new sample
        x = weight[:, :, 0:1] * xyz + weight[:, :, 1:] * perm_new

        # Eq.(8) for new label: per-cloud share of each source, applied to the
        # one-hot label of the cloud itself and of the shared partner.
        target = weight.sum(1)
        target = target / target.sum(-1, keepdim=True)
        label_onehot = torch.zeros(B, self.num_class, device=self.device).scatter(1, label.view(-1, 1), 1)
        label_partner_onehot = label_onehot[partner_idx].unsqueeze(0)  # (1,C), broadcast over B
        label = target[:, 0, None] * label_onehot + target[:, 1, None] * label_partner_onehot

        return x, label
    

In [6]:
# Training cell: builds the model / optimiser and runs SageMix-augmented training.
device = torch.device("cuda:0")

# Apply the configured seed so weight init, shuffling and the mixing draws are
# repeatable across Restart & Run All.
torch.manual_seed(args.seed)
np.random.seed(args.seed)

model = PointNet(args, num_class).to(device)

# model = nn.DataParallel(model)
print("Let's use", torch.cuda.device_count(), "GPUs!")

# SGD starts at 100x the base lr and is annealed down to the base lr.
opt = optim.SGD(model.parameters(), lr=args.lr*100, momentum=args.momentum, weight_decay=1e-4)
scheduler = CosineAnnealingLR(opt, args.epochs, eta_min=args.lr)

sagemix = SageMix(args, device, num_class)
criterion = cal_loss_mix

# Kept for the evaluation code; not updated in this cell.
best_test_acc = 0
for epoch in range(args.epochs):

    ####################
    # Train
    ####################
    train_loss = 0.0
    count = 0.0
    model.train()
    for data, label in tqdm(train_loader):
        data, label = data.to(device), label.to(device).squeeze()
        # Local name on purpose: do not overwrite the notebook-level `batch_size`.
        n_samples = data.size(0)

        ####################
        # generate augmented sample
        ####################
        # Saliency = magnitude of the loss gradient w.r.t. the input points,
        # computed in eval mode.
        model.eval()
        data_var = data.permute(0, 2, 1).detach().requires_grad_(True)
        logits = model(data_var)
        loss = cal_loss(logits, label, smoothing=False)
        loss.backward()
        # Discard the parameter gradients produced by the saliency pass; only
        # data_var.grad is needed.
        opt.zero_grad()
        saliency = torch.sqrt(torch.mean(data_var.grad ** 2, 1))
        data, label = sagemix.mix(data, label, saliency)
        model.train()

        ####################
        # optimisation step on the mixed batch
        ####################
        opt.zero_grad()
        logits = model(data.permute(0, 2, 1))
        loss = criterion(logits, label)
        loss.backward()
        opt.step()

        count += n_samples
        train_loss += loss.item() * n_samples

    scheduler.step()
    outstr = 'Train %d, loss: %.6f' % (epoch, train_loss*1.0/count)
    print(outstr)
    # io.cprint(outstr)

Let's use 6 GPUs!


  0%|          | 0/1230 [00:00<?, ?it/s]