In [1]:
#!/usr/bin/env python
# -*- coding: utf-8 -*-
from __future__ import print_function
import os
import argparse
import torch
import torch.nn as nn
import torch.optim as optim
import sklearn.metrics as metrics
import numpy as np
import random

from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.autograd import Variable
from torch.utils.data import DataLoader
from tqdm import tqdm


# from SageMix import SageMix
from data import ModelNet40, ScanObjectNN
from model import PointNet, DGCNN
from util import cal_loss, cal_loss_mix, IOStream
import wandb
import torch.nn.functional as F
# import io

import torch
from emd_ import emd_module

In [None]:
class SageMix:
    """N-way point-cloud mixing augmentation.

    For every sample in a batch, ``n_mix`` point clouds (the sample itself plus
    ``n_mix - 1`` randomly permuted batch partners) are aligned point-by-point
    with an EMD assignment and then blended with per-point weights. The weights
    come from a Gaussian (RBF) kernel centred on one randomly chosen anchor
    point per cloud, scaled by a Dirichlet-distributed mixing ratio.
    """

    def __init__(self, args, num_class=40):
        """
        Args:
            args: object exposing ``sigma`` (kernel bandwidth) and ``theta``.
            num_class: number of classes used for the soft one-hot labels.
        """
        self.num_class = num_class
        self.EMD = emd_module.emdModule()
        self.sigma = args.sigma
        # Not used by ``mix`` (which samples a Dirichlet); kept so the attribute stays available.
        self.beta = torch.distributions.beta.Beta(torch.tensor([args.theta]), torch.tensor([args.theta]))

    def mix(self, xyz, label, saliency=None, n_mix=4, theta=0.2):
        """Blend ``n_mix`` point clouds per batch entry and build matching soft labels.

        Args:
            xyz (B,N,3): batch of point clouds.
            label (B): integer class index of every cloud.
            saliency (B,N): Defaults to None. Accepted for interface
                compatibility only: anchors are drawn uniformly at random
                (saliency-weighted sampling is disabled), so the values do not
                influence the result.
            n_mix: number of clouds blended into each output sample.
            theta: concentration of the symmetric Dirichlet the mixing ratios
                are drawn from.

        Returns:
            x (B,N,3): mixed point clouds.
            label (B,num_class): soft labels; every row sums to 1.
        """
        B, N, _ = xyz.shape
        # Same placement as the previous unconditional ``.cuda()`` calls when a GPU
        # exists; otherwise stay on the input's device instead of crashing.
        device = torch.device("cuda") if torch.cuda.is_available() else xyz.device

        # One batch permutation per mixing slot. Row 0 is drawn as well (keeps the
        # random stream stable) but slot 0 always holds the un-permuted batch.
        idxs = torch.stack([torch.randperm(B) for _ in range(n_mix)])

        # all_xyz[i]: i-th cloud of every sample, re-ordered so that point k
        # corresponds to point k of the original cloud (slot 0).
        all_xyz = torch.zeros((n_mix, B, N, 3), device=device)
        all_xyz[0] = xyz
        for i in range(1, n_mix):
            permuted = xyz[idxs[i]].to(device=device, dtype=all_xyz.dtype)
            _, ass = self.EMD(all_xyz[0], permuted, 0.005, 500)
            ass = ass.long()  # index tensors must be integer-typed; cast explicitly
            for j in range(B):
                all_xyz[i, j] = permuted[j][ass[j]]

        # One anchor point per cloud, chosen uniformly over the N points that are
        # actually present (the index range used to be hard-coded to 1024).
        batch_idx = torch.arange(B, device=device)
        anchors = torch.zeros(n_mix, B, 3, device=device)
        for i in range(n_mix):
            anc_idx = torch.randint(0, N, (B, 1)).to(device)
            anchors[i] = all_xyz[i][batch_idx, anc_idx[:, 0]]

        # Per-sample mixing ratios over the n_mix clouds, shape (B, n_mix).
        pi = torch.distributions.dirichlet.Dirichlet(torch.tensor([theta] * n_mix)).sample((B,)).to(device)

        per_cloud_weights = []
        for i in range(n_mix):
            dist_to_anchor = ((all_xyz[i] - anchors[i][:, None, :]) ** 2).sum(2).sqrt()
            # Gaussian kernel around the anchor, scaled by the cloud's mixing ratio.
            kernel = torch.exp(-0.5 * (dist_to_anchor ** 2) / (self.sigma ** 2))  # (B,N)
            per_cloud_weights.append((kernel * pi[:, i][:, None])[..., None])

        # (B,N,n_mix); the epsilon keeps the normalisation finite where every kernel underflows.
        weight = torch.cat(per_cloud_weights, -1) + 1e-16
        weight = weight / weight.sum(-1)[..., None]

        x = torch.zeros((B, N, 3), device=device)
        for i in range(n_mix):
            x += weight[:, :, i:i + 1] * all_xyz[i]

        # Label share of each cloud = its total point weight, normalised per sample.
        target = weight.sum(1)
        target = target / target.sum(-1, keepdim=True)

        label_onehot = torch.zeros(B, self.num_class, device=device).scatter(1, label.view(-1, 1).to(device), 1)
        mixed_label = torch.zeros(B, self.num_class, device=device)
        for i in range(n_mix):
            source = label_onehot if i == 0 else label_onehot[idxs[i].to(device)]
            mixed_label += source * target[:, i, None]

        return x, mixed_label

In [None]:
def distance(z, dist_type='l2'):
    """Pairwise distance matrix between the rows of ``z``.

    ``dist_type`` selects the metric: ``'l2'`` (Euclidean), any other string
    starting with ``'l2'`` such as ``'l22'`` (squared Euclidean), ``'l1'``
    (Manhattan) or ``'linf'`` (Chebyshev). Any other value yields ``None``.
    The computation runs without gradient tracking.
    """
    with torch.no_grad():
        pairwise = z.unsqueeze(1) - z.unsqueeze(0)

        if dist_type[:2] == 'l2':
            squared = (pairwise ** 2).sum(-1)
            # Only the exact name 'l2' takes the root; 'l22' stays squared.
            return torch.sqrt(squared) if dist_type == 'l2' else squared

        if dist_type == 'l1':
            return pairwise.abs().sum(-1)

        if dist_type == 'linf':
            return pairwise.abs().max(-1)[0]

        return None


def calc_A_dist(saliency, theta=0.5):
    """Build a (B, B) sample-affinity matrix from per-sample saliency peaks.

    The location of the largest saliency value of every sample is treated as a
    2-D index; the L1 distances between those indices are normalised to sum to
    B and blended with the identity matrix using a ratio drawn from
    ``Beta(theta, theta)``.

    Args:
        saliency: tensor whose first dimension is the batch (B, ...).
        theta: concentration of the symmetric Beta distribution for the blend ratio.

    Returns:
        Tensor of shape (B, B) on ``saliency``'s device.
    """
    # Batch size comes from the tensor itself so short (last) batches work and
    # no global ``args`` is required.
    n_input = saliency.shape[0]

    z = saliency.unsqueeze(1)
    z_reshape = z.reshape(n_input, -1)
    z_idx_1d = torch.argmax(z_reshape, dim=1)
    # Split the flat argmax into (row, column) using the size of the last dimension.
    z_idx_2d = torch.zeros((n_input, 2), device=z.device)
    z_idx_2d[:, 0] = z_idx_1d // z.shape[-1]
    z_idx_2d[:, 1] = z_idx_1d % z.shape[-1]
    A_dist = distance(z_idx_2d, dist_type='l1')

    # Previously read the device from an undefined name ``out``.
    A_base = torch.eye(n_input, device=saliency.device)

    # When all peaks coincide (or B == 1) the distances are all zero; skip the
    # normalisation instead of dividing by zero and returning NaNs.
    total = torch.sum(A_dist)
    if total > 0:
        A_dist = A_dist / total * n_input
    m_omega = torch.distributions.beta.Beta(theta, theta).sample()
    A = (1 - m_omega) * A_base + m_omega * A_dist
    return A
