In [9]:
from __future__ import print_function
import os
import argparse
import torch
import torch.nn as nn
import torch.optim as optim
import sklearn.metrics as metrics
import numpy as np

from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.autograd import Variable
from torch.utils.data import DataLoader
from tqdm import tqdm

import importlib
import data 
importlib.reload(data)
from data import ModelNet40, ScanObjectNN
from model import PointNet, DGCNN
from util import cal_loss, cal_loss_mix, IOStream
import gco
from emd_ import emd_module

# Experiment configuration shared by the cells below.
# NOTE(review): sigma=-1 is handed to SageMix unchanged; that class only ever
# uses sigma ** 2, so this value behaves like sigma=1 — confirm that is intended.
args = argparse.Namespace(
    batch_size=30,  # replaced by 50 for MN40 in the data-loading cell
    data='MN40',
    dropout=0.5,
    emb_dims=1024,
    epochs=200,
    eval=False,
    exp_name='MultiSageMix',
    k=20,
    lr=0.0001,
    model='pointnet',
    model_path='',
    momentum=0.9,
    no_cuda=False,
    num_points=1024,
    seed=1,
    sigma=-1,
    test_batch_size=16,
    theta=0.2,
    use_sgd=False,
)

# Apply the configured seed so that shuffling, mixing partners and sampled
# anchors are repeatable across "Restart Kernel -> Run All" executions.
torch.manual_seed(args.seed)
np.random.seed(args.seed)

In [19]:
# Convenience copies of the configured sizes. `batch_size` is captured before
# the MN40 override below, matching the value the rest of the notebook saw.
num_points = 1024
batch_size = args.batch_size
test_batch_size = args.test_batch_size

# The run log lives inside the experiment directory; create it (and the
# checkpoint sub-directory) up front so nothing fails on a fresh checkout.
os.makedirs('checkpoints/' + args.exp_name + '/models', exist_ok=True)
io = IOStream('checkpoints/' + args.exp_name + '/run.log')


def _make_loaders(train_set, test_set):
    """Wrap a train/test dataset pair in DataLoaders.

    Batch sizes are read from ``args`` at call time, so any override applied
    to ``args.batch_size`` beforehand is honoured.

    Returns:
        (train_loader, test_loader)
    """
    train = DataLoader(train_set, num_workers=8,
                       batch_size=args.batch_size, shuffle=True, drop_last=True)
    # Evaluation metrics do not depend on sample order, so no shuffling here.
    test = DataLoader(test_set, num_workers=8,
                      batch_size=args.test_batch_size, shuffle=False, drop_last=False)
    return train, test


# Build the loaders exactly once for the selected dataset.
if args.data == 'MN40':
    # NOTE(review): this silently replaces the batch size given in `args`.
    args.batch_size = 50
    dataset = ModelNet40(partition='train', num_points=args.num_points)
    train_loader, test_loader = _make_loaders(
        dataset, ModelNet40(partition='test', num_points=args.num_points))
    num_class = 40
elif args.data == 'SONN_easy':
    dataset = ScanObjectNN(partition='train', num_points=args.num_points, ver="easy")
    train_loader, test_loader = _make_loaders(
        dataset, ScanObjectNN(partition='test', num_points=args.num_points, ver="easy"))
    num_class = 15
elif args.data == 'SONN_hard':
    dataset = ScanObjectNN(partition='train', num_points=args.num_points, ver="hard")
    train_loader, test_loader = _make_loaders(
        dataset, ScanObjectNN(partition='test', num_points=args.num_points, ver="hard"))
    num_class = 15
else:
    # Fail loudly instead of training on an unintended default dataset.
    raise ValueError("Unknown dataset %r (expected 'MN40', 'SONN_easy' or 'SONN_hard')" % (args.data,))

In [20]:
import torch
from emd_ import emd_module

class SageMix:
    """Saliency-guided Mixup for batches of point clouds.

    Every cloud is paired with a randomly chosen partner from the same batch,
    the partner is re-ordered with the assignment returned by ``self.EMD``,
    and the two clouds are blended point-wise with weights derived from
    Gaussian kernels centred on saliency-sampled anchor points.
    """

    def __init__(self, args, num_class=40):
        """
        Args:
            args: config object; ``args.sigma`` (kernel bandwidth) and
                ``args.theta`` (both parameters of the Beta distribution that
                the mixing ratio is drawn from) are read.
            num_class (int): number of classes used for the one-hot labels.
        """
        self.num_class = num_class
        self.EMD = emd_module.emdModule()
        # Only sigma ** 2 is used below, so the sign of sigma has no effect.
        self.sigma = args.sigma
        self.beta = torch.distributions.beta.Beta(torch.tensor([args.theta]), torch.tensor([args.theta]))

    @staticmethod
    def _dist_to(points, anchor):
        """Euclidean distance of every point (B,N,3) to its batch anchor (B,3); returns (B,N)."""
        return ((points - anchor[:, None, :]) ** 2).sum(2).sqrt()

    def mix(self, xyz, label, saliency=None, mixing_idx=0):
        """
        Args:
            xyz (B,N,3)
            label (B): integer class indices.
            saliency (B,N): Defaults to None, in which case every point is
                treated as equally salient.
            mixing_idx: 0 mixes the labels into soft targets; any other value
                returns ``label`` exactly as it was passed in.

        Returns:
            (x, label): the mixed clouds (B,N,3) and, for ``mixing_idx == 0``,
            the soft labels (B,num_class).
        """
        B, N, _ = xyz.shape
        # Allocate everything next to the input instead of assuming CUDA.
        device = xyz.device
        rows = torch.arange(B, device=device)

        if saliency is None:
            saliency = torch.ones(B, N, dtype=xyz.dtype, device=device)

        idxs = torch.randperm(B)

        #Optimal assignment in Eq.(3)
        perm = xyz[idxs]

        # The second return value is used as a per-point index into the partner cloud.
        _, ass = self.EMD(xyz, perm, 0.005, 500) # mapping
        ass = ass.long()
        # Re-order partner points and their saliency in one indexing step
        # (advanced indexing wraps negative indices just like per-sample indexing).
        perm_new = perm[rows[:, None], ass]
        perm_saliency = saliency[idxs][rows[:, None], ass]

        #####
        # Saliency-guided sequential sampling
        #####
        #Eq.(4) in the main paper
        saliency = saliency / saliency.sum(-1, keepdim=True)
        anc_idx = torch.multinomial(saliency, 1, replacement=True)
        anchor_ori = xyz[rows, anc_idx[:, 0]]

        #cal distance and reweighting saliency map for Eq.(5) in the main paper
        perm_saliency = perm_saliency * self._dist_to(perm_new, anchor_ori)
        perm_saliency = perm_saliency / perm_saliency.sum(-1, keepdim=True)

        #Eq.(5) in the main paper
        anc_idx2 = torch.multinomial(perm_saliency, 1, replacement=True)
        anchor_perm = perm_new[rows, anc_idx2[:, 0]]

        #####
        # Shape-preserving continuous Mixup
        #####
        # Beta parameters have shape (1,), so alpha is (B,1) and broadcasts over points.
        alpha = self.beta.sample((B,)).to(device)

        #Eq.(6) for first sample
        sub_ori = self._dist_to(xyz, anchor_ori)
        ker_weight_ori = torch.exp(-0.5 * (sub_ori ** 2) / (self.sigma ** 2))  #(M,N)

        #Eq.(6) for second sample
        sub_perm = self._dist_to(perm_new, anchor_perm)
        ker_weight_perm = torch.exp(-0.5 * (sub_perm ** 2) / (self.sigma ** 2))  #(M,N)

        #Eq.(9)
        weight_ori = ker_weight_ori * alpha
        weight_perm = ker_weight_perm * (1 - alpha)
        # The small constant keeps the normalisation finite when both kernels vanish.
        weight = (torch.cat([weight_ori[..., None], weight_perm[..., None]], -1)) + 1e-16
        weight = weight / weight.sum(-1)[..., None]

        #Eq.(8) for new sample
        x = weight[:, :, 0:1] * xyz + weight[:, :, 1:] * perm_new

        #Eq.(8) for new label
        target = weight.sum(1)
        target = target / target.sum(-1, keepdim=True)

        if mixing_idx == 0:
            label_onehot = torch.zeros(B, self.num_class, device=device).scatter(1, label.view(-1, 1), 1)
            label_perm_onehot = label_onehot[idxs]
            label = target[:, 0, None] * label_onehot + target[:, 1, None] * label_perm_onehot
        # NOTE(review): for mixing_idx != 0 the incoming label is returned
        # unchanged — confirm whether soft-label accumulation was intended.

        return x, label
    

In [23]:
# Run on the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Instantiate the network selected by args.model; anything else is rejected.
if args.model not in ('pointnet', 'dgcnn'):
    raise Exception("Not implemented")
model = (PointNet if args.model == 'pointnet' else DGCNN)(args, num_class).to(device)

# model = nn.DataParallel(model)
print("Let's use", torch.cuda.device_count(), "GPUs!")

# SGD runs at 100x the configured learning rate; Adam uses it as-is.
print("Use SGD" if args.use_sgd else "Use Adam")
opt = (
    optim.SGD(model.parameters(), lr=args.lr*100, momentum=args.momentum, weight_decay=1e-4)
    if args.use_sgd
    else optim.Adam(model.parameters(), lr=args.lr, weight_decay=1e-4)
)

# NOTE(review): eta_min equals args.lr, which is also Adam's starting rate, so
# with Adam this schedule looks flat — confirm that is intended.
scheduler = CosineAnnealingLR(opt, T_max=args.epochs, eta_min=args.lr)

# Mixing augmenter plus the loss that is applied to its soft labels.
sagemix = SageMix(args, num_class)
criterion = cal_loss_mix
import pickle

# Epoch whose first `save_batch_count` batches are dumped for visualisation.
VIZ_EPOCH = 20
direc = 'checkpoints/%s/data_save_for_viz' % args.exp_name

# The best checkpoint is written here; make sure the directory exists.
os.makedirs('checkpoints/%s/models' % args.exp_name, exist_ok=True)

best_test_acc = 0
save_batch_count = 4
for epoch in range(args.epochs):

    ####################
    # Train
    ####################
    train_loss = 0.0
    count = 0.0
    model.train()
    # NOTE: the loop variable `data` shadows the imported `data` module.
    for data, label, label_name in tqdm(train_loader):
        data, label = data.to(device), label.to(device).squeeze()
        batch_size = data.size()[0]

        ####################
        # generate augmented sample
        ####################
        # Saliency = per-point gradient magnitude of the loss w.r.t. the input,
        # computed with the model in eval mode.
        model.eval()
        data_var = Variable(data.permute(0,2,1), requires_grad=True)
        logits = model(data_var)
        loss = cal_loss(logits, label, smoothing=False)
        loss.backward()
        saliency = torch.sqrt(torch.mean(data_var.grad**2,1))

        # Decide once per batch whether this batch is dumped to disk.
        save_viz = False
        if epoch == VIZ_EPOCH:
            save_batch_count -= 1
            save_viz = save_batch_count > -1

        if save_viz:
            print("Saving data for visualization")
            os.makedirs(direc, exist_ok=True)
            data_dict = {'data':data.cpu().numpy(), 'label':label.cpu().numpy(), "label_name":label_name}
            with open(direc + "/data_dict_normal_" + str(save_batch_count) + ".p", "wb") as fh:
                pickle.dump(data_dict, fh)

        data, label = sagemix.mix(data, label, saliency)

        if save_viz:
            data_dict = {'data':data.cpu().numpy(), 'label':label.cpu().numpy()}
            with open(direc + "/data_dict_mix_" + str(save_batch_count) + ".p", "wb") as fh:
                pickle.dump(data_dict, fh)

        # NOTE(review): data_var.grad has not changed since `saliency` was
        # computed, so this is the saliency of the *unmixed* input. Kept as an
        # alias; recompute from the mixed sample if that was the intent.
        mixed_saliency = saliency

        model.train()

        # Also discards the parameter gradients left over from the saliency pass.
        opt.zero_grad()
        logits = model(data.permute(0,2,1))
        loss = criterion(logits, label)
        loss.backward()
        opt.step()
        count += batch_size
        train_loss += loss.item() * batch_size
    scheduler.step()
    outstr = 'Train %d, loss: %.6f' % (epoch, train_loss*1.0/count)
    io.cprint(outstr)

    ####################
    # Test
    ####################
    test_loss = 0.0
    count = 0.0
    model.eval()
    test_pred = []
    test_true = []
    # No gradients are needed for evaluation; skip building autograd graphs.
    with torch.no_grad():
        for data, label, label_name in tqdm(test_loader):
            data, label = data.to(device), label.to(device).squeeze()
            data = data.permute(0, 2, 1)
            batch_size = data.size()[0]
            logits = model(data)
            loss = cal_loss(logits, label)
            preds = logits.max(dim=1)[1]
            count += batch_size
            test_loss += loss.item() * batch_size
            test_true.append(label.cpu().numpy())
            test_pred.append(preds.detach().cpu().numpy())
    test_true = np.concatenate(test_true)
    test_pred = np.concatenate(test_pred)
    test_acc = metrics.accuracy_score(test_true, test_pred)
    avg_per_class_acc = metrics.balanced_accuracy_score(test_true, test_pred)
    # Keep the weights of the best (or equally good, most recent) epoch.
    if test_acc >= best_test_acc:
        best_test_acc = test_acc
        torch.save(model.state_dict(), 'checkpoints/%s/models/model.t7' % args.exp_name)
    outstr = 'Test %d, loss: %.6f, test acc: %.6f, test avg acc: %.6f, best test acc: %.6f' % (epoch,
                                                                            test_loss*1.0/count,
                                                                            test_acc,
                                                                            avg_per_class_acc,
                                                                            best_test_acc)
    io.cprint(outstr)

Let's use 6 GPUs!
Use Adam


100%|██████████| 196/196 [00:40<00:00,  4.82it/s]


Train 0, loss: 2.691006


100%|██████████| 155/155 [00:02<00:00, 76.58it/s] 


Test 0, loss: 2.479731, test acc: 0.476499, test avg acc: 0.326547, best test acc: 0.476499


100%|██████████| 196/196 [00:40<00:00,  4.78it/s]


Train 1, loss: 2.182476


100%|██████████| 155/155 [00:01<00:00, 79.65it/s] 


Test 1, loss: 2.279619, test acc: 0.595624, test avg acc: 0.439936, best test acc: 0.595624


100%|██████████| 196/196 [00:41<00:00,  4.74it/s]


Train 2, loss: 1.949455


100%|██████████| 155/155 [00:01<00:00, 79.18it/s] 


Test 2, loss: 2.102094, test acc: 0.683955, test avg acc: 0.529541, best test acc: 0.683955


100%|██████████| 196/196 [00:41<00:00,  4.77it/s]


Train 3, loss: 1.787769


100%|██████████| 155/155 [00:01<00:00, 79.67it/s] 


Test 3, loss: 2.107139, test acc: 0.692464, test avg acc: 0.556302, best test acc: 0.692464


100%|██████████| 196/196 [00:41<00:00,  4.73it/s]


Train 4, loss: 1.707846


100%|██████████| 155/155 [00:02<00:00, 76.50it/s] 


Test 4, loss: 2.089398, test acc: 0.717180, test avg acc: 0.577965, best test acc: 0.717180


100%|██████████| 196/196 [00:41<00:00,  4.76it/s]


Train 5, loss: 1.637457


100%|██████████| 155/155 [00:01<00:00, 78.72it/s] 


Test 5, loss: 2.028063, test acc: 0.745543, test avg acc: 0.637070, best test acc: 0.745543


100%|██████████| 196/196 [00:41<00:00,  4.76it/s]


Train 6, loss: 1.620633


100%|██████████| 155/155 [00:01<00:00, 79.76it/s] 


Test 6, loss: 1.994169, test acc: 0.756888, test avg acc: 0.658506, best test acc: 0.756888


100%|██████████| 196/196 [00:41<00:00,  4.76it/s]


Train 7, loss: 1.555287


100%|██████████| 155/155 [00:01<00:00, 78.42it/s] 


Test 7, loss: 2.027618, test acc: 0.762966, test avg acc: 0.647791, best test acc: 0.762966


100%|██████████| 196/196 [00:41<00:00,  4.77it/s]


Train 8, loss: 1.513118


100%|██████████| 155/155 [00:01<00:00, 79.81it/s] 


Test 8, loss: 2.033188, test acc: 0.773906, test avg acc: 0.661866, best test acc: 0.773906


100%|██████████| 196/196 [00:41<00:00,  4.76it/s]


Train 9, loss: 1.489029


100%|██████████| 155/155 [00:02<00:00, 74.74it/s] 


Test 9, loss: 2.036738, test acc: 0.748379, test avg acc: 0.643012, best test acc: 0.773906


100%|██████████| 196/196 [00:41<00:00,  4.76it/s]


Train 10, loss: 1.472218


100%|██████████| 155/155 [00:01<00:00, 80.58it/s] 


Test 10, loss: 2.027523, test acc: 0.758914, test avg acc: 0.679110, best test acc: 0.773906


100%|██████████| 196/196 [00:41<00:00,  4.74it/s]


Train 11, loss: 1.459704


100%|██████████| 155/155 [00:01<00:00, 80.01it/s] 


Test 11, loss: 2.022025, test acc: 0.769449, test avg acc: 0.691738, best test acc: 0.773906


100%|██████████| 196/196 [00:41<00:00,  4.77it/s]


Train 12, loss: 1.422612


100%|██████████| 155/155 [00:02<00:00, 76.66it/s] 


Test 12, loss: 2.016528, test acc: 0.782415, test avg acc: 0.694070, best test acc: 0.782415


100%|██████████| 196/196 [00:41<00:00,  4.75it/s]


Train 13, loss: 1.407289


100%|██████████| 155/155 [00:01<00:00, 78.38it/s] 


Test 13, loss: 2.015793, test acc: 0.812804, test avg acc: 0.714535, best test acc: 0.812804


100%|██████████| 196/196 [00:41<00:00,  4.77it/s]


Train 14, loss: 1.366561


100%|██████████| 155/155 [00:02<00:00, 74.41it/s] 


Test 14, loss: 1.948508, test acc: 0.807131, test avg acc: 0.736331, best test acc: 0.812804


100%|██████████| 196/196 [00:41<00:00,  4.74it/s]


Train 15, loss: 1.353684


100%|██████████| 155/155 [00:02<00:00, 74.38it/s] 


Test 15, loss: 2.065248, test acc: 0.771880, test avg acc: 0.699297, best test acc: 0.812804


100%|██████████| 196/196 [00:41<00:00,  4.74it/s]


Train 16, loss: 1.346117


100%|██████████| 155/155 [00:02<00:00, 73.01it/s] 


Test 16, loss: 2.027712, test acc: 0.804295, test avg acc: 0.727640, best test acc: 0.812804


100%|██████████| 196/196 [00:41<00:00,  4.75it/s]


Train 17, loss: 1.330510


100%|██████████| 155/155 [00:01<00:00, 80.97it/s] 


Test 17, loss: 1.971830, test acc: 0.827796, test avg acc: 0.742517, best test acc: 0.827796


100%|██████████| 196/196 [00:41<00:00,  4.77it/s]


Train 18, loss: 1.301974


100%|██████████| 155/155 [00:02<00:00, 76.00it/s] 


Test 18, loss: 2.009313, test acc: 0.801864, test avg acc: 0.727029, best test acc: 0.827796


100%|██████████| 196/196 [00:41<00:00,  4.74it/s]


Train 19, loss: 1.301739


100%|██████████| 155/155 [00:02<00:00, 74.95it/s] 


Test 19, loss: 2.031891, test acc: 0.838331, test avg acc: 0.739494, best test acc: 0.838331


  0%|          | 0/196 [00:00<?, ?it/s]

Saving data for visualization


  1%|          | 2/196 [00:01<02:26,  1.33it/s]

Saving data for visualization


  2%|▏         | 3/196 [00:01<01:36,  1.99it/s]

Saving data for visualization
Saving data for visualization


100%|██████████| 196/196 [00:41<00:00,  4.74it/s]


Train 20, loss: 1.308449


100%|██████████| 155/155 [00:01<00:00, 81.44it/s] 


Test 20, loss: 1.993945, test acc: 0.828606, test avg acc: 0.745715, best test acc: 0.838331


100%|██████████| 196/196 [00:41<00:00,  4.76it/s]


Train 21, loss: 1.305911


100%|██████████| 155/155 [00:02<00:00, 75.19it/s] 


Test 21, loss: 2.074670, test acc: 0.797812, test avg acc: 0.711977, best test acc: 0.838331


100%|██████████| 196/196 [00:41<00:00,  4.76it/s]


Train 22, loss: 1.296699


100%|██████████| 155/155 [00:02<00:00, 73.90it/s] 


Test 22, loss: 2.056212, test acc: 0.814019, test avg acc: 0.734064, best test acc: 0.838331


100%|██████████| 196/196 [00:41<00:00,  4.76it/s]


Train 23, loss: 1.275911


100%|██████████| 155/155 [00:01<00:00, 78.42it/s] 


Test 23, loss: 2.030988, test acc: 0.817666, test avg acc: 0.748872, best test acc: 0.838331


100%|██████████| 196/196 [00:41<00:00,  4.75it/s]


Train 24, loss: 1.263288


100%|██████████| 155/155 [00:01<00:00, 78.21it/s] 


Test 24, loss: 2.037727, test acc: 0.818071, test avg acc: 0.739610, best test acc: 0.838331


100%|██████████| 196/196 [00:41<00:00,  4.76it/s]


Train 25, loss: 1.261541


100%|██████████| 155/155 [00:01<00:00, 78.88it/s] 


Test 25, loss: 2.058082, test acc: 0.817666, test avg acc: 0.752442, best test acc: 0.838331


100%|██████████| 196/196 [00:41<00:00,  4.77it/s]


Train 26, loss: 1.244237


100%|██████████| 155/155 [00:01<00:00, 78.43it/s] 


Test 26, loss: 2.047286, test acc: 0.813209, test avg acc: 0.750035, best test acc: 0.838331


100%|██████████| 196/196 [00:41<00:00,  4.75it/s]


Train 27, loss: 1.259546


100%|██████████| 155/155 [00:01<00:00, 79.37it/s] 


Test 27, loss: 2.006398, test acc: 0.841167, test avg acc: 0.763907, best test acc: 0.841167


100%|██████████| 196/196 [00:41<00:00,  4.70it/s]


Train 28, loss: 1.235081


100%|██████████| 155/155 [00:02<00:00, 76.90it/s] 


Test 28, loss: 2.059597, test acc: 0.815640, test avg acc: 0.760395, best test acc: 0.841167


100%|██████████| 196/196 [00:41<00:00,  4.76it/s]


Train 29, loss: 1.249608


100%|██████████| 155/155 [00:02<00:00, 76.65it/s] 


Test 29, loss: 2.036174, test acc: 0.813614, test avg acc: 0.762180, best test acc: 0.841167


 84%|████████▍ | 165/196 [00:35<00:06,  4.68it/s]


KeyboardInterrupt: 