In [None]:
from __future__ import print_function
import os
import argparse
import torch
import torch.nn as nn
import torch.optim as optim
import sklearn.metrics as metrics
import numpy as np

from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.autograd import Variable
from torch.utils.data import DataLoader
from tqdm import tqdm


# from SageMix import SageMix
from data import ModelNet40, ScanObjectNN
from model import PointNet, DGCNN
from util import cal_loss, cal_loss_mix, IOStream
import gco

In [None]:
import os
# Set CUDA_LAUNCH_BLOCKING=1 so CUDA errors surface at the offending call
# (debugging aid; slows execution).
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'

# Notebook stand-in for the command-line arguments of the training script.
args = argparse.Namespace(
    # experiment / data
    exp_name='SageMix',
    data='MN40',
    model='pointnet',
    model_path='',
    eval=False,
    seed=1,
    no_cuda=False,
    # batching
    batch_size=3,
    test_batch_size=16,
    num_points=1024,
    # optimisation
    epochs=50,
    use_sgd=True,
    lr=0.0001,
    momentum=0.9,
    # network
    dropout=0.5,
    emb_dims=1024,
    k=20,
    # mixing hyper-parameters
    sigma=-1,
    theta=0.2,
)

In [None]:
# Sizes come from `args` so the loaders cannot drift from the configuration
# cell (num_points was previously a second, hard-coded copy of 1024).
num_points = args.num_points
batch_size = args.batch_size
test_batch_size = args.test_batch_size

# ModelNet40 training split; incomplete final batches are dropped.
dataset = ModelNet40(partition='train', num_points=num_points)
train_loader = DataLoader(
    dataset,
    num_workers=8,
    batch_size=batch_size,
    shuffle=True,
    drop_last=True,
)

# Test split; every sample is kept (drop_last=False).
test_loader = DataLoader(
    ModelNet40(partition='test', num_points=num_points),
    num_workers=8,
    batch_size=test_batch_size,
    shuffle=True,
    drop_last=False,
)

# Number of target classes handed to the model and to SageMix below.
num_class = 40

In [None]:
import torch
from emd_ import emd_module

class SageMix:
    """Pairwise EMD alignment experiment ("Heuristic-1") over a batch of point clouds.

    Attributes:
        num_class: number of classes (stored; not used by `mix`).
        EMD: `emd_module.emdModule()` instance used as the matching solver.
        sigma: `args.sigma` (stored; not used by `mix`).
        beta: Beta(theta, theta) distribution built from `args.theta`
            (stored; not used by `mix`).
        device: device on which the distance matrix and the re-ordered
            clouds are placed.
    """

    def __init__(self, args, device, num_class=40):
        self.num_class = num_class
        self.EMD = emd_module.emdModule()
        self.sigma = args.sigma
        self.beta = torch.distributions.beta.Beta(torch.tensor([args.theta]), torch.tensor([args.theta]))
        self.device = device

    def mix(self, xyz, label, saliency=None):
        """Align every cloud in the batch against a single selected anchor cloud.

        Each cloud in turn is used as an anchor: it is tiled B times and the
        whole batch is matched against it with `self.EMD`. The anchor with
        the smallest averaged alignment norm is selected and its assignment
        is used to re-order the points of every cloud.

        Args:
            xyz: (B, N, 3) batch of point clouds.
            label: (B) class labels. Currently unused.
            saliency: (B, N) per-point saliency. Defaults to None. Currently unused.

        Returns:
            ass: (B, N) long tensor, the assignment stored for the selected
                anchor (kept on `xyz.device`).
            perm_new: (B, N, 3) tensor on `self.device` where
                `perm_new[i] = xyz[i][ass[i]]`.
            dist_mat: (B, B) tensor on `self.device`; entry `[a, i]` is the
                L2 norm over points of the distances returned when matching
                `xyz[i]` against anchor `a`.
        """
        # N is taken from the input instead of the former hard-coded 1024.
        B, N, _ = xyz.shape

        dist_mat = torch.empty(B, B, N, device=self.device)
        # Assignments are indices: keep them as integers next to the data
        # rather than round-tripping through an uninitialised float CPU buffer.
        ass_mat = torch.empty(B, B, N, dtype=torch.long, device=xyz.device)

        # Optimal assignment in Eq.(3), evaluated once per candidate anchor.
        for idx, point in enumerate(xyz):
            # Tile the anchor so it can be matched against the whole batch.
            anchor = point.unsqueeze(0).repeat(B, 1, 1)  # (B, N, 3)

            # presumably returns per-point distances and matched indices, both
            # (B, N) — confirm against emd_module.
            dist, ass = self.EMD(xyz, anchor, 0.005, 500)
            dist_mat[idx] = dist
            ass_mat[idx] = ass

        # Collapse the per-point distances into one score per (anchor, cloud).
        dist_mat = torch.norm(dist_mat, dim=2)

        # NOTE(review): dim=0 averages over anchors, yet the argmin below is
        # then used as an anchor index into ass_mat; the two agree only if the
        # distance matrix is symmetric — confirm dim=1 was not intended.
        avg_alignment_dist = torch.mean(dist_mat, dim=0)
        idx = torch.argmin(avg_alignment_dist).item()

        ass = ass_mat[idx]

        # NOTE(review): `ass` was computed against the anchor, but (as in the
        # original code) it indexes the points of xyz[i] itself — confirm.
        gather_index = ass.unsqueeze(-1).expand(-1, -1, xyz.shape[-1])
        perm_new = torch.gather(xyz, 1, gather_index).to(self.device)

        return ass, perm_new, dist_mat
        

In [None]:
# The whole run is pinned to the CUDA device.
device = torch.device("cuda")

# PointNet classifier with `num_class` outputs, moved to the GPU.
model = PointNet(args, num_class).to(device)
# model = nn.DataParallel(model)
print("Let's use", torch.cuda.device_count(), "GPUs!")

# SGD starts at 100 * args.lr; the cosine schedule anneals it over
# args.epochs steps down to eta_min = args.lr.
opt = optim.SGD(
    model.parameters(),
    lr=args.lr * 100,
    momentum=args.momentum,
    weight_decay=1e-4,
)
scheduler = CosineAnnealingLR(opt, args.epochs, eta_min=args.lr)

# Mixing helper and the loss callable intended for mixed samples.
sagemix = SageMix(args, device, num_class)
criterion = cal_loss_mix


best_test_acc = 0
for epoch in range(args.epochs):

    ####################
    # Train
    ####################
    train_loss = 0.0
    count = 0.0
    model.train()
    train_pred = []
    train_true = []
    for data, label in tqdm(train_loader):
        data, label = data.to(device), label.to(device).squeeze()
        batch_size = data.size()[0]

        ####################
        # generate augmented sample
        ####################
        # Saliency pass: run in eval mode and take the gradient of the loss
        # with respect to the input points.
        model.eval()
        # The permuted input is a leaf tensor that records gradients
        # (replaces the deprecated torch.autograd.Variable wrapper).
        data_var = data.permute(0, 2, 1).detach().requires_grad_(True)
        logits = model(data_var)
        loss = cal_loss(logits, label, smoothing=False)
        loss.backward()
        # Drop the parameter gradients produced by the saliency pass;
        # data_var.grad is not owned by the optimizer and survives this call.
        opt.zero_grad()
        # Root-mean-square of the input gradient over dim 1 of data_var.
        saliency = torch.sqrt(torch.mean(data_var.grad**2, 1))

        assignment, perm_new, align_dist = sagemix.mix(data, label, saliency)

        # NOTE(review): the parameter update below is disabled, so the model
        # stays in eval mode, no weights change, and train_loss/count remain
        # zero for the whole epoch.
        # model.train()

        # opt.zero_grad()
        # logits = model(data.permute(0,2,1))
        # loss = criterion(logits, label)
        # loss.backward()
        # opt.step()
        # preds = logits.max(dim=1)[1]
        # count += batch_size
        # train_loss += loss.item() * batch_size

    scheduler.step()
    # Guard the average: while no sample is counted, `count` is 0 and the
    # former unconditional division raised ZeroDivisionError after epoch 0.
    avg_train_loss = train_loss / count if count else float('nan')
    outstr = 'Train %d, loss: %.6f' % (epoch, avg_train_loss)
    print(outstr)
    # io.cprint(outstr)