In [2]:
import numpy as np
import torch
import torch.nn.functional as F

In [20]:
def onehot(indexes, N=None, ignore_index=None):
    """
    Build a one-hot encoding of ``indexes`` along a new trailing axis of size N.

    indexes: integer (long) tensor of class indices, any shape.
    N: number of classes; when None it is inferred as ``indexes.max() + 1``.
    ignore_index: when given and non-negative, positions whose index equals it
        become all-zero rows in the encoding.
    Returns a uint8 tensor of shape ``indexes.shape + (N,)`` on ``indexes``' device.
    """
    num_classes = indexes.max() + 1 if N is None else N
    encoded = torch.zeros(*indexes.size(), num_classes,
                          dtype=torch.uint8, device=indexes.device)
    encoded.scatter_(-1, indexes.unsqueeze(-1), 1)

    wants_ignore = ignore_index is not None and ignore_index >= 0
    if wants_ignore:
        # Zero the whole row wherever the index is the ignored one.
        ignored_rows = indexes.eq(ignore_index).unsqueeze(-1)
        encoded.masked_fill_(ignored_rows, 0)
    return encoded

def _is_long(x):
    if hasattr(x, 'data'):
        x = x.data
    return isinstance(x, torch.LongTensor) or isinstance(x, torch.cuda.LongTensor)


def cross_entropy(inputs, target, weight=None, ignore_index=-100, reduction='mean',
                  smooth_eps=None, smooth_dist=None, from_logits=True):
    """Cross-entropy loss with soft-target and label-smoothing support.

    Label smoothing follows https://arxiv.org/abs/1512.00567: an index target
    ``y`` becomes ``(1 - smooth_eps) * onehot(y) + smooth_eps * smooth_dist``,
    where ``smooth_dist`` defaults to the uniform distribution over the classes
    (this matches ``F.cross_entropy(..., label_smoothing=smooth_eps)``).

    Args:
        inputs: scores with classes on the last dim (the smoothing path uses
            ``dim=-1``); logits if ``from_logits`` else log-probabilities.
        target: long class indices, or a float distribution over the classes.
            A float target is never modified in place.
        weight: optional per-class weights.
        ignore_index: for index targets, positions equal to it contribute zero
            loss and are excluded from the mean (negative values work too).
        reduction: 'mean', 'sum', or anything else for per-element losses.
        smooth_eps: amount of label smoothing; ``None`` means 0.
        smooth_dist: optional smoothing distribution over the classes. For a
            float target, smoothing is applied only when this is given.
        from_logits: apply ``log_softmax`` to ``inputs`` first.

    Returns:
        The loss tensor, reduced according to ``reduction``.
    """
    smooth_eps = smooth_eps or 0
    is_index_target = torch.is_tensor(target) and target.dtype == torch.long

    # Ordinary log-likelihood: defer to the built-in kernels.
    if is_index_target and smooth_eps == 0:
        if from_logits:
            return F.cross_entropy(inputs, target, weight, ignore_index=ignore_index, reduction=reduction)
        return F.nll_loss(inputs, target, weight, ignore_index=ignore_index, reduction=reduction)

    lsm = F.log_softmax(inputs, dim=-1) if from_logits else inputs
    num_classes = inputs.size(-1)

    masked_indices = None
    if is_index_target:
        # Mask ignored positions for any ignore_index (including the negative
        # default), then give them a valid class so gather/one_hot stay in range;
        # their loss is zeroed below.
        masked_indices = target.eq(ignore_index)
        target = target.masked_fill(masked_indices, 0)

    if smooth_eps > 0 and smooth_dist is not None:
        if is_index_target:
            target = F.one_hot(target, num_classes).type_as(inputs)
            is_index_target = False
        if smooth_dist.dim() < target.dim():
            smooth_dist = smooth_dist.unsqueeze(0)
        # Out-of-place: never modify a caller-supplied soft target.
        target = torch.lerp(target, smooth_dist, smooth_eps)

    if weight is not None:
        lsm = lsm * weight.unsqueeze(0)

    if is_index_target:
        # Smoothed target: (1 - eps) on the true class plus eps / C on every class,
        # which sums to 1 and equals the uniform smooth_dist path above.
        eps_sum = smooth_eps / num_classes
        eps_nll = 1. - smooth_eps
        likelihood = lsm.gather(dim=-1, index=target.unsqueeze(-1)).squeeze(-1)
        loss = -(eps_nll * likelihood + eps_sum * lsm.sum(-1))
    else:
        loss = -(target * lsm).sum(-1)

    if masked_indices is not None:
        loss = loss.masked_fill(masked_indices, 0)

    if reduction == 'sum':
        loss = loss.sum()
    elif reduction == 'mean':
        if masked_indices is None:
            loss = loss.mean()
        else:
            # Average over every non-ignored position, not just the first dim.
            loss = loss.sum() / float(loss.numel() - int(masked_indices.sum()))

    return loss

In [5]:
# batch_size must be defined before use (it was undefined on a fresh kernel);
# 256 is the batch size used by the tensors in this notebook.
batch_size = 256
lam = np.random.beta(1, 1, size=batch_size)

In [7]:
np.shape(lam)

(256,)

In [None]:
def mixup_data(x, y, alpha=1.0):
    '''Returns mixed inputs, pairs of targets, and lambda.

    One coefficient ``lam ~ Beta(alpha, alpha)`` (or 1 when ``alpha <= 0``) is
    shared by the whole batch, and every sample is mixed with a randomly chosen
    partner from the same batch.

    Returns ``(mixed_x, y_a, y_b, lam)``: ``y_a`` are the original targets and
    ``y_b`` the targets of the partners, so a loss can be formed as
    ``lam * loss(y_a) + (1 - lam) * loss(y_b)``.
    '''
    lam = np.random.beta(alpha, alpha) if alpha > 0 else 1

    batch_size = x.size(0)
    # Build the permutation on the data's own device instead of relying on an
    # external `use_cuda` flag (which is undefined here).
    index = torch.randperm(batch_size, device=x.device)

    # Index only the batch dim so inputs of any rank (including 1-D) work.
    mixed_x = lam * x + (1 - lam) * x[index]
    y_a, y_b = y, y[index]
    return mixed_x, y_a, y_b, lam

In [10]:
index = torch.randperm(256, dtype=torch.long)

In [12]:
# Two random image-shaped batches and one batch of feature vectors.
x_a, x_b = torch.randn(256, 3, 32, 32), torch.randn(256, 3, 32, 32)

rep_a = torch.randn(256, 128)  # B * dim

In [157]:
def data_mixer(x_aug1, x_aug2, alpha=1.0, eps=0.0):
    """Mix two augmented views per sample and build soft instance labels.

    Sample ``i`` of ``x_aug1`` is mixed with sample ``perm[i]`` of ``x_aug2``
    using its own coefficient ``lam[i] ~ Beta(alpha, alpha)``.

    Args:
        x_aug1: first view; its first dimension is the batch size ``b``.
        x_aug2: second view, same shape as ``x_aug1``.
        alpha: Beta concentration for the per-sample mixing coefficients.
        eps: label smoothing toward the uniform distribution over the ``b`` instances.

    Returns:
        x_mix: mixed batch, same shape as ``x_aug1``.
        lam: mixing coefficients, shape (b, 1).
        ins_label: soft instance labels, shape (b, b). Before smoothing, row ``i``
            puts weight ``lam[i]`` on instance ``i`` and ``1 - lam[i]`` on ``perm[i]``.
    """
    device = x_aug1.device
    b = x_aug1.shape[0]
    # Integer images would truncate lam, so fall back to the default float dtype.
    dtype = x_aug1.dtype if x_aug1.is_floating_point() else torch.get_default_dtype()

    # One permutation drives both the shuffle of x_aug2 and the second target,
    # so labels stay aligned with the mixed samples.
    perm = torch.randperm(b, device=device)

    lam = torch.as_tensor(np.random.beta(alpha, alpha, size=b), dtype=dtype, device=device)
    # Broadcast one coefficient per sample over all non-batch dimensions.
    lam_x = lam.view(b, *([1] * (x_aug1.dim() - 1)))
    x_mix = lam_x * x_aug1 + (1 - lam_x) * x_aug2[perm]

    lam = lam.view(b, 1)

    eye = torch.eye(b, dtype=dtype, device=device)
    uniform = torch.full((b,), 1.0 / b, dtype=dtype, device=device)
    target1 = eps * uniform + (1 - eps) * eye        # row i -> instance i
    target2 = eps * uniform + (1 - eps) * eye[perm]  # row i -> instance perm[i]

    ins_label = lam * target1 + (1 - lam) * target2

    return x_mix, lam, ins_label

In [158]:
a, b = torch.randn(256, 128), torch.randn(256, 128)

target = torch.randn(256, 256)

In [159]:
x_mix, lam, ins_label = data_mixer(x_aug1=x_a, x_aug2=x_b)

In [161]:
ins_label.size()

torch.Size([256, 256])

In [156]:
a.max(dim=1).indices.dtype

torch.int64

In [154]:
# NOTE(review): `loss` is not assigned in any visible cell (the only `loss` is a
# local inside cross_entropy), so this cell relies on hidden kernel state and
# fails after Restart & Run All. A negative cross-entropy such as the recorded
# value usually means the target was not a probability distribution (e.g. a
# torch.randn target) — TODO confirm against the deleted cell and recreate it.
loss

tensor(-11.0752)