In [1]:
from __future__ import print_function
import os
import argparse
import torch
import torch.nn as nn
import torch.optim as optim
import sklearn.metrics as metrics
import numpy as np

from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.autograd import Variable
from torch.utils.data import DataLoader
from tqdm import tqdm


# from SageMix import SageMix
from data import ModelNet40, ScanObjectNN
from model import PointNet, DGCNN
from util import cal_loss, cal_loss_mix, IOStream
import gco

In [2]:
import os
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'

In [3]:
args = argparse.Namespace(batch_size=3, data='MN40', dropout=0.5, emb_dims=1024, epochs=50, eval=False, exp_name='SageMix', k=20, lr=0.0001, model='pointnet', model_path='', momentum=0.9, no_cuda=False, num_points=1024, seed=1, sigma=-1, test_batch_size=16, theta=0.2, use_sgd=True)

In [4]:
num_points = 1024
dataset = ModelNet40(partition='train', num_points=num_points)
batch_size=args.batch_size
# print(dataset)
# dataset = dataset[:100]
# batch_size = len(dataset)
# print('args.batch_size:',batch_size)
test_batch_size = args.test_batch_size
train_loader = DataLoader(dataset, num_workers=8,
                        batch_size=batch_size, shuffle=True, drop_last=True)
test_loader = DataLoader(ModelNet40(partition='test', num_points=num_points), num_workers=8,
                        batch_size=test_batch_size, shuffle=True, drop_last=False)
num_class=40

In [5]:
import torch
from emd_ import emd_module

class SageMix:
    def __init__(self, args, device, num_class=40):
        self.num_class = num_class
        self.EMD = emd_module.emdModule()
        self.sigma = args.sigma
        self.beta = torch.distributions.beta.Beta(torch.tensor([args.theta]), torch.tensor([args.theta]))
        self.device = device

    def mix(self, xyz, label, saliency=None):
        """
        Args:
            xyz (B,N,3)
            label (B)
            saliency (B,N): Defaults to None.
        """        
        B, N, _ = xyz.shape
        # print(xyz.shape)
        idxs = torch.randperm(B)

        
        #Optimal assignment in Eq.(3)
        # perm = xyz[idxs]
        dist_mat = torch.empty(B, B, 1024)
        ass_mat = torch.empty(B,B,1024)
        dist_mat = dist_mat.to(self.device)
        
        # print("Starting to compute optimal assignment (Heuristic-1)")
        for idx,point in enumerate(xyz):
            # perm = torch.tensor([point for x in range(B))
            # print(point.shape)
            perm = point.repeat(B,1)
            # print(perm.shape)

            perm  = perm.reshape(perm.shape[0]//1024,1024,3)
            
            dist, ass = self.EMD(xyz, perm, 0.005, 500) # mapping
                 # 32,1024
            dist_mat[idx] = dist
            ass_mat[idx] = ass

            # print('dist:',dist.shape)
            # if idx % 10 == 0:
            #     print("Now doing", idx)
        
        # print(dist_mat.shape)
        dist_mat = torch.norm(dist_mat,dim=2)
        avg_alignment_dist = torch.mean(dist_mat,dim=0)
        # print(avg_alignment_dist.shape)
        # print('avg_alignment:',avg_alignment_dist)
        # print('mean:',torch.mean(avg_alignment_dist))
        # print('min:',torch.min(avg_alignment_dist))
        # print('max:',torch.max(avg_alignment_dist))
        # print(torch.min(avg_alignment_dist))
        # print(torch.argmin(avg_alignment_dist).item())

        idx = torch.argmin(avg_alignment_dist).item()
        # dist_mat = dist_mat.fill_diagonal_(100000.0)
    
        
        # i,j = divmod(torch.argmin(dist_mat).item(),dist_mat.shape[1])
        ass = ass_mat[idx]
        
        ass = ass.long()

        # sz = ass.size(0)
        perm_new = torch.zeros_like(perm).to(self.device)
        # print('perm:',perm)
        # print(perm_new.shape)
        perm = xyz.clone()
        # print("idx:",idx)
        for i in range(B):
            # print('i:',i)
            perm_new[i] = perm[i][ass[i]]
            # print('perm_i',perm[i])
            # print('perm_new_i',perm_new[i])

        # print('perm_new',perm_new)

        return ass,perm_new,dist_mat

        # print("Done with compute optimal assignment (Heuristic-1)")
        # print(ass.shape)
        

In [6]:
def distance(z, dist_type='l2'):
    '''Return distance matrix between vectors'''
    with torch.no_grad():
        diff = z.unsqueeze(1) - z.unsqueeze(0)
        if dist_type[:2] == 'l2':
            A_dist = (diff**2).sum(-1)
            if dist_type == 'l2':
                A_dist = torch.sqrt(A_dist)
            elif dist_type == 'l22':
                pass
        elif dist_type == 'l1':
            A_dist = diff.abs().sum(-1)
        elif dist_type == 'linf':
            A_dist = diff.abs().max(-1)[0]
        else:
            return None
    return A_dist


In [7]:
z = torch.tensor([[1,0,0],[0,1,0], [0,0,1]])

# distance(z)

tensor([[0.0000, 1.4142, 1.4142],
        [1.4142, 0.0000, 1.4142],
        [1.4142, 1.4142, 0.0000]])

In [8]:
import numpy as np
import torch
import torch.nn.functional as F
import warnings
from match import get_onehot_matrix, mix_input
from math import ceil

import importlib
import match
importlib.reload(match)
from match import get_onehot_matrix, mix_input

importlib.reload(gco)
# from match import get_onehot_matrix, mix_input

warnings.filterwarnings("ignore")


def mixup_process(out, target_reweighted, args=None, sc=None, A_dist=None):
    m_block_num = args.m_block_num
    m_part = args.m_part

    # batch_size = out.shape[0]
    # width = out.shape[-1]

    if A_dist is None:
        A_dist = torch.eye(batch_size, device=out.device)

    if m_block_num == -1:
        m_block_num = 2**np.random.randint(1, 5)

    
    # block_size = width // m_block_num
    block_size = 8
    # print("block size:",block_size)
    # print("sc:",sc.shape) # 8,1024
    # sc = sc.unsqueeze(1)
    # sc = F.avg_pool1d(sc, block_size)


    out_list = []
    target_list = []


    
    with torch.no_grad():
        sc_part = sc
        A_dist_part = A_dist

        n_input = sc.shape[0]
        # print("n_input", n_input)
        # print("scpart rehspae", sc_part.reshape(n_input, -1).shape)
        # print("sc_part reshape sum", sc_part.sum(1).shape)
        # print("total shape", sc_part.reshape(n_input, -1).sum(1).reshape(n_input, 1, 1).shape)

        ## ORIGINAL CODE
        # sc_norm = sc_part / sc_part.reshape(n_input, -1).sum(1).reshape(n_input, 1, 1)

        ## NEW CODE
        # print("sc part shape", sc_part.shape)
        # print("sc part sum", sc_part.sum(1).shape)
        # sc_norm = sc_part / sc_part.sum(1).reshape(n_input, 1, 1)
        sc_norm = sc/torch.sum(sc, dim=1).view(-1,1)
        # print("sc_norm", sc_norm.shape)
        cost_matrix = -sc_norm
        # print(cost_matrix.shape)

        A_base = torch.eye(n_input, device=out.device)
        A_dist_part = A_dist_part / torch.sum(A_dist_part) * n_input
        A = (1 - args.m_omega) * A_base + args.m_omega * A_dist_part
        print("A:",A)

        # print("new A shape:",A.shape)
        # Return a batch(partitioned) of mixup labeling
        # mask_onehot = get_onehot_matrix(cost_matrix.detach(),
        #                                 A,
        #                                 n_output=n_input,
        #                                 beta=args.m_beta,
        #                                 gamma=args.m_gamma,
        #                                 eta=args.m_eta,
        #                                 mixup_alpha=args.mixup_alpha,
        #                                 thres=args.m_thres,
        #                                 thres_type=args.m_thres_type,
        #                                 set_resolve=args.set_resolve,
        #                                 niter=args.m_niter,
        #                                 device='cuda')
        mask_onehot = get_onehot_matrix(cost_matrix.detach(),
                                        A,
                                        n_output=1,
                                        beta=args.m_beta,
                                        gamma=args.m_gamma,
                                        eta=args.m_eta,
                                        mixup_alpha=args.mixup_alpha,
                                        thres=args.m_thres,
                                        thres_type=args.m_thres_type,
                                        set_resolve=args.set_resolve,
                                        niter=args.m_niter,
                                        device='cuda')
        
    print('mask onehot shape:',mask_onehot.shape)
    # print(mask_onehot)
    # Generate image and corrsponding soft target
    output_part, target_part = mix_input(mask_onehot, out,
                                             target_reweighted)

    out_list = output_part
    print(out_list)
    target_list = target_part
    print(target_list)
    # out_list.append(output_part)
    # target_list.append(target_part)




    
    # Partition a batch
    # for i in range(ceil(batch_size / m_part)):
    #     with torch.no_grad():
    #         sc_part = sc[i * m_part:(i + 1) * m_part]
    #         A_dist_part = A_dist[i * m_part:(i + 1) * m_part, i * m_part:(i + 1) * m_part]

    #         n_input = sc_part.shape[0]
    #         sc_norm = sc_part / sc_part.reshape(n_input, -1).sum(1).reshape(n_input, 1, 1)
    #         cost_matrix = -sc_norm

    #         A_base = torch.eye(n_input, device=out.device)
    #         A_dist_part = A_dist_part / torch.sum(A_dist_part) * n_input
    #         A = (1 - args.m_omega) * A_base + args.m_omega * A_dist_part

    #         # Return a batch(partitioned) of mixup labeling
    #         mask_onehot = get_onehot_matrix(cost_matrix.detach(),
    #                                         A,
    #                                         n_output=n_input,
    #                                         beta=args.m_beta,
    #                                         gamma=args.m_gamma,
    #                                         eta=args.m_eta,
    #                                         mixup_alpha=args.mixup_alpha,
    #                                         thres=args.m_thres,
    #                                         thres_type=args.m_thres_type,
    #                                         set_resolve=args.set_resolve,
    #                                         niter=args.m_niter,
    #                                         device='cuda')

    #     # Generate image and corrsponding soft target
    #     output_part, target_part = mix_input(mask_onehot, out[i * m_part:(i + 1) * m_part],
    #                                          target_reweighted[i * m_part:(i + 1) * m_part])

    #     out_list.append(output_part)
    #     target_list.append(target_part)

    # with torch.no_grad():
    #     out = torch.cat(out_list, dim=0)
    #     target_reweighted = torch.cat(target_list, dim=0)

    return output_part, target_part


In [9]:
import argparse
args2 = {'arch': 'preactresnet18', 'batch_size': 100, 
         'clean_lam': 1.0, 'comix': True, 
         'data_dir': './data/cifar100/', 'dataset': 'cifar100', 
         'decay': 0.0001, 'dropout': False, 'epochs': 300, 
         'evaluate': True, 'gammas': [0.1, 0.1], 'initial_channels': 64, 
         'labels_per_class': 500, 'learning_rate': 0.2, 
         'log_off': True, 'm_beta': 0.32, 
         'm_block_num': 4, 'm_eta': 0.05, 
         'm_gamma': 1.0, 'm_niter': 4, 'm_omega': 0.001, 
         'm_part': 20, 'm_thres': 0.83, 
         'm_thres_type': 'hard', 
         'mixup_alpha': 2.0, 
         'momentum': 0.9, 'ngpu': 1, 
         'parallel': False, 'print_freq': 100, 
         'resume': './checkpoint/cifar100_preactresnet18_eph300_comixup/checkpoint.pth.tar', 
         'root_dir': 'experiments', 'schedule': [100, 200], 'seed': 0, 
         'set_resolve': True, 'start_epoch': 0, 'tag': '', 
         'use_cuda': True, 'valid_labels_per_class': 0, 'workers': 0}

args2 = argparse.Namespace(**args2)

In [10]:
print(args2.m_block_num)

4


In [11]:
device = torch.device("cuda")

model = PointNet(args, num_class).to(device)

# model = nn.DataParallel(model)
print("Let's use", torch.cuda.device_count(), "GPUs!")
opt = optim.SGD(model.parameters(), lr=args.lr*100, momentum=args.momentum, weight_decay=1e-4)


scheduler = CosineAnnealingLR(opt, args.epochs, eta_min=args.lr)
    

sagemix = SageMix(args, device, num_class)
criterion = cal_loss_mix


best_test_acc = 0
for epoch in range(args.epochs):

    ####################
    # Train
    ####################
    train_loss = 0.0
    count = 0.0
    model.train()
    train_pred = []
    train_true = []
    for data, label in tqdm(train_loader):
        data, label = data.to(device), label.to(device).squeeze()
        print(data)
        batch_size = data.size()[0]
        
        ####################
        # generate augmented sample
        ####################
        model.eval()
        print(data.permute(0,2,1).shape)
        data_var = Variable(data.permute(0,2,1), requires_grad=True)
        logits = model(data_var)
        loss = cal_loss(logits, label, smoothing=False)
        loss.backward()
        opt.zero_grad()
        saliency = torch.sqrt(torch.mean(data_var.grad**2,1))
        
        assignment,perm_new,align_dist = sagemix.mix(data, label, saliency)
        print(perm_new.shape)
        # data_var2 = Variable(perm_new.permute(0,2,1), requires_grad=True)
        
        # sc = torch.sqrt(torch.mean(data_var2.grad**2,1))
        # print('assignment:',assignment.shape)
        # print(assignment[0])
        # print('new_permutation',perm_new.shape)
        
        target_reweighted = F.one_hot(label, num_classes=num_class).float()
        

        with torch.no_grad():
            # print(saliency.shape)
            sc = saliency.unsqueeze(1)
            # print("sc:",sc.shape)
            z = F.avg_pool1d(sc, kernel_size=8, stride=1)
            # print("z:",z.shape)
            z_reshape = z.reshape(args.batch_size, -1)
            # print("z_reshape:",z_reshape.shape)
            z_idx_1d = torch.argmax(z_reshape, dim=1)
            z_idx_2d = torch.zeros((args.batch_size, 2), device=z.device)
            z_idx_2d[:, 0] = z_idx_1d // z.shape[-1]
            z_idx_2d[:, 1] = z_idx_1d % z.shape[-1]
            A_dist = distance(z_idx_2d, dist_type='l1')
            # print("A_dist:",A_dist.shape)
            # print(A_dist)

        # print(A_dist)
        print("perm_new", perm_new.shape)
        print("target_reweighted", target_reweighted.shape)
        print(args2)
        print("sc", sc.shape)
        print("A_dist", A_dist.shape)
        out, target_reweighted = mixup_process(perm_new,
                                                target_reweighted,
                                                args=args2,
                                                sc=saliency,
                                                A_dist=A_dist)
        break
        # model.train()
            
        # opt.zero_grad()
        # logits = model(data.permute(0,2,1))
        # loss = criterion(logits, label)
        # loss.backward()
        # opt.step()
        # preds = logits.max(dim=1)[1]
        # count += batch_size
        # train_loss += loss.item() * batch_size
        
    scheduler.step()
    outstr = 'Train %d, loss: %.6f' % (epoch, train_loss*1.0/count)
    print(outstr)
    # io.cprint(outstr)

Let's use 8 GPUs!


  0%|          | 0/3280 [00:00<?, ?it/s]

tensor([[[ 0.0873, -0.0416,  0.3306],
         [-0.2527,  0.0031, -0.0017],
         [ 0.4310, -0.0219,  0.0915],
         ...,
         [ 0.2252, -0.0219, -0.2534],
         [-0.0662,  0.0031,  0.4442],
         [-0.1810, -0.0219,  0.7870]],

        [[ 0.5484, -0.0995, -0.9262],
         [ 0.4643,  0.0499, -0.0856],
         [ 0.2219,  0.4488,  0.9309],
         ...,
         [ 0.3038,  0.0499, -0.6911],
         [-0.0686,  0.0499, -0.1031],
         [ 0.2990,  0.0499, -0.3497]],

        [[ 0.2852, -0.0534, -0.3134],
         [-0.3129,  0.0405,  0.2038],
         [ 0.4006,  0.0374, -0.2213],
         ...,
         [ 0.0184,  0.0915,  0.5804],
         [-0.5774, -0.0564, -0.2586],
         [ 0.0683, -0.0121, -0.2597]]], device='cuda:0')
torch.Size([3, 3, 1024])


  0%|          | 0/3280 [00:03<?, ?it/s]

torch.Size([3, 1024, 3])
perm_new torch.Size([3, 1024, 3])
target_reweighted torch.Size([3, 40])
Namespace(arch='preactresnet18', batch_size=100, clean_lam=1.0, comix=True, data_dir='./data/cifar100/', dataset='cifar100', decay=0.0001, dropout=False, epochs=300, evaluate=True, gammas=[0.1, 0.1], initial_channels=64, labels_per_class=500, learning_rate=0.2, log_off=True, m_beta=0.32, m_block_num=4, m_eta=0.05, m_gamma=1.0, m_niter=4, m_omega=0.001, m_part=20, m_thres=0.83, m_thres_type='hard', mixup_alpha=2.0, momentum=0.9, ngpu=1, parallel=False, print_freq=100, resume='./checkpoint/cifar100_preactresnet18_eph300_comixup/checkpoint.pth.tar', root_dir='experiments', schedule=[100, 200], seed=0, set_resolve=True, start_epoch=0, tag='', use_cuda=True, valid_labels_per_class=0, workers=0)
sc torch.Size([3, 1, 1024])
A_dist torch.Size([3, 3])
A: tensor([[9.9900e-01, 2.9242e-04, 7.5000e-04],
        [2.9242e-04, 9.9900e-01, 4.5758e-04],
        [7.5000e-04, 4.5758e-04, 9.9900e-01]], device='




ZeroDivisionError: float division by zero

In [None]:
# !pip install open3dfind_alignment_and_mapping
# import open3d as o3d
# import numpy as np

# # Assuming you have "points" as your numpy array containing the points you want to plot
# points = np.random.rand(100, 3)

# point_cloud = o3d.geometry.PointCloud()
# point_cloud.points = o3d.utility.Vector3dVector(points)

# o3d.visualization.draw_geometries([point_cloud])

The history saving thread hit an unexpected error (OperationalError('unable to open database file')).History will not be written to the database.
Jupyter environment detected. Enabling Open3D WebVisualizer.
[Open3D INFO] WebRTC GUI backend enabled.
[Open3D INFO] WebRTCWindowSystem: HTTP handshake server disabled.


In [97]:
import open3d as o3d
import numpy as np

points = out.squeeze().cpu().numpy()

point_cloud = o3d.geometry.PointCloud()
point_cloud.points = o3d.utility.Vector3dVector(points)

o3d.io.write_point_cloud("my_point_cloud.pcd", point_cloud)

True

In [98]:
# out.squeeze().cpu().numpy()

import open3d as o3d
import numpy as np

for i in range(3):
    points = data[i].squeeze().cpu().numpy()

    point_cloud = o3d.geometry.PointCloud()
    point_cloud.points = o3d.utility.Vector3dVector(points)

    o3d.io.write_point_cloud("clouds/my_point_cloud_input_{}.pcd".format(i), point_cloud)

points = out.squeeze().cpu().numpy()

point_cloud = o3d.geometry.PointCloud()
point_cloud.points = o3d.utility.Vector3dVector(points)

o3d.io.write_point_cloud("clouds/my_point_cloud_output.pcd", point_cloud)

True

In [99]:
# data[:,:,0]