In [11]:

import numpy as np
import tensorflow as tf

"""
Adapted from https://github.com/gpeyre/SinkhornAutoDiff
and from https://github.com/dfdazac/wassdistance/blob/master/layers.py
and from https://github.com/michaelsdr/sinkformers/blob/main/nlp-tutorial/text-classification-transformer/sinkhorn.py
"""

def shape_list(x, out_type=tf.int32):
  """Return the shape of `x`, using static dimensions wherever they are known.

  Args:
    x: a tensor exposing `.shape` as a `tf.TensorShape`.
    out_type: integer dtype forwarded to `tf.shape` for dynamic dimensions.

  Returns:
    A list with one entry per dimension: the Python int when the dimension is
    statically known, otherwise the matching scalar slice of `tf.shape(x)`.
    If the rank itself is unknown, the dynamic shape tensor is returned as-is
    (it still supports integer indexing).
  """
  # Unknown rank: `as_list()` would raise, so fall back to the dynamic shape.
  if x.shape.ndims is None:
    return tf.shape(x, out_type=out_type)

  static = x.shape.as_list()
  # Fully static shape: avoid adding a redundant `tf.shape` op to the graph.
  if None not in static:
    return static

  dynamic = tf.shape(x, out_type=out_type)
  return [dynamic[i] if s is None else s for i, s in enumerate(static)]

def sinkhorn_distance(input_tensor, eps, max_iter, 
                  reduction='none',
                  stopThr=1e-2):
  """Entropic-regularised OT between uniform marginals (log-domain Sinkhorn).

  Uses the TF1-style API (`tf.log`, graph-mode `tf.while_loop`).

  Args:
    input_tensor: cost matrix, indexed here as (batch, x_points, y_points).
    eps: entropic regularisation coefficient.
    max_iter: maximum number of Sinkhorn iterations.
    reduction: 'none' | 'mean' | 'sum'. How the per-batch cost is reduced;
      any other value leaves the cost unreduced.
    stopThr: iterations stop once the mean L1 change of `u` is not above it.

  Returns:
    Tuple `(pi, C, U, V, cost)`: transport plan, the input cost matrix, the
    two dual potentials and the transport cost `sum(pi * C)` (reduced
    according to `reduction`).
  """
  C = input_tensor
  C_shape = shape_list(C)

  x_points = C_shape[-2]
  y_points = C_shape[-1]
  batch_size = C_shape[0]

  def uniform(n_points):
    # Cast explicitly: `1.0 / n_points` fails when the dimension is a dynamic
    # int32 tensor rather than a Python int.
    return tf.ones((batch_size, n_points)) / tf.cast(n_points, tf.float32)

  # both marginals are fixed with equal weights
  mu = uniform(x_points)
  nu = uniform(y_points)

  # Loop-invariant log-marginals, computed once outside the loop body.
  log_mu = tf.log(mu + 1e-8)
  log_nu = tf.log(nu + 1e-8)

  u = tf.zeros_like(mu)
  v = tf.zeros_like(nu)

  cpt = tf.constant(0)
  err = tf.constant(1.0)

  def keep_going(cpt, u, v, err):
    return tf.logical_and(cpt < max_iter, err > stopThr)

  def M(C, u, v):
    r"""Modified cost for logarithmic updates.

    $M_{ij} = (-c_{ij} + u_i + v_j) / \epsilon$
    """
    return (-C + tf.expand_dims(u, -1) + tf.expand_dims(v, -2)) / eps

  def loop_func(cpt, u, v, err):
    u_prev = u  # kept to measure how much this sweep moved `u`

    u = eps * (log_mu - tf.reduce_logsumexp(M(C, u, v), axis=-1)) + u
    # Reducing over axis -2 is the column-wise log-sum-exp (same as
    # transposing the last two axes and reducing over the last one).
    v = eps * (log_nu - tf.reduce_logsumexp(M(C, u, v), axis=-2)) + v

    err = tf.reduce_mean(tf.reduce_sum(tf.abs(u - u_prev), axis=-1))

    return cpt + 1, u, v, err

  _, U, V, _ = tf.while_loop(keep_going, loop_func, loop_vars=[cpt, u, v, err])

  # Transport plan pi = diag(a)*K*diag(b)
  pi = tf.exp(M(C, U, V))

  cost = tf.reduce_sum(pi * C, axis=(-2, -1))

  # Honour `reduction` (previously accepted but silently ignored).
  if reduction == 'mean':
    cost = tf.reduce_mean(cost)
  elif reduction == 'sum':
    cost = tf.reduce_sum(cost)

  return pi, C, U, V, cost



In [63]:
# Sinkhorn hyper-parameters: entropic regularisation, iteration cap and
# convergence tolerance.
eps, max_iter, stopThr = 1.0, 10, 1e-10

In [64]:
def _cost_matrix(x, y, p=2):
    """Pairwise cost $|x_i-y_j|^p$, summed over the last (feature) axis.

    `x` gets a new axis at -2 and `y` a new axis at -3 so that their
    difference broadcasts to every (i, j) pair of points.
    """
    pairwise_diff = tf.expand_dims(x, axis=-2) - tf.expand_dims(y, axis=-3)
    return tf.reduce_sum(tf.pow(tf.abs(pairwise_diff), p), axis=-1)

In [39]:
# Seeded generator so that a fresh "Restart & Run All" reproduces the same
# two point clouds (10 and 16 points, 32 features, batch of 1).
rng = np.random.RandomState(0)
x = rng.random_sample((1, 10, 32)).astype(np.float32)
y = rng.random_sample((1, 16, 32)).astype(np.float32)

In [65]:
# Build the TensorFlow graph: pairwise squared-distance cost, then Sinkhorn.
C = _cost_matrix(tf.constant(x), tf.constant(y), p=2)

# Use the configured tolerance instead of repeating the literal here.
[pi, C_, U, V, final_cost] = sinkhorn_distance(C, eps, max_iter,
                  reduction='none',
                  stopThr=stopThr)

In [66]:
# Evaluate the graph once; `sess` is kept open because later cells reuse it.
sess = tf.Session()
resp = sess.run(fetches=[pi, C, U, V, final_cost])

In [69]:
# Column marginals of the transport plan: summing `pi` over the x-points axis
# should give 1 / y_points (= 1/16 = 0.0625) for every y point.
resp[0].sum(axis=-2)

array([[0.06250001, 0.06250001, 0.06250001, 0.0625    , 0.06250001,
        0.06250001, 0.06250002, 0.06250002, 0.06250003, 0.06250001,
        0.06250003, 0.06250001, 0.06250001, 0.06250001, 0.0625    ,
        0.06250001]], dtype=float32)

In [18]:
import torch
import torch.nn as nn

# Adapted from https://github.com/gpeyre/SinkhornAutoDiff
class SinkhornDistance(nn.Module):
    r"""
    Given two empirical measures with :math:`P_1` locations :math:`x` and
    :math:`P_2` locations :math:`y` (sharing the same feature dimension
    :math:`D`, since the cost is built from :math:`x_i - y_j`), outputs an
    approximation of the regularized OT cost for point clouds.

    Args:
        eps (float): regularization coefficient
        max_iter (int): maximum number of Sinkhorn iterations
        reduction (string, optional): Specifies the reduction to apply to the cost:
            'none' | 'mean' | 'sum'. 'none': no reduction will be applied,
            'mean': the costs are averaged over the batch, 'sum': the costs
            are summed. Default: 'none'
        thresh (float, optional): early-stopping threshold on the mean L1
            change of the potential ``u`` between iterations. Default: 1e-1

    Shape:
        - Input: :math:`(N, P_1, D)`, :math:`(N, P_2, D)`; the leading batch
          dimension may be omitted.
        - Output: tuple ``(cost, pi, C)`` where ``cost`` is :math:`(N)` or
          :math:`()` depending on `reduction`, and ``pi`` / ``C`` are
          :math:`(N, P_1, P_2)`.
    """
    def __init__(self, eps, max_iter, reduction='none', thresh=1e-1):
        super(SinkhornDistance, self).__init__()
        self.eps = eps
        self.max_iter = max_iter
        self.reduction = reduction
        self.thresh = thresh

    def forward(self, x, y):
        """Run log-domain Sinkhorn iterations and return ``(cost, pi, C)``."""
        C = self._cost_matrix(x, y)  # Wasserstein cost function
        x_points = x.shape[-2]
        y_points = y.shape[-2]

        # Marginals follow the cost matrix's batch dims, device and dtype.
        # Shapes are built explicitly instead of via `.squeeze()`, which also
        # collapsed a size-1 *points* axis and corrupted broadcasting.
        batch_dims = tuple(C.shape[:-2])
        dtype = C.dtype if C.is_floating_point() else torch.float

        # both marginals are fixed with equal weights
        mu = torch.full(batch_dims + (x_points,), 1.0 / x_points,
                        dtype=dtype, device=C.device)
        nu = torch.full(batch_dims + (y_points,), 1.0 / y_points,
                        dtype=dtype, device=C.device)

        # Loop-invariant log-marginals.
        log_mu = torch.log(mu + 1e-8)
        log_nu = torch.log(nu + 1e-8)

        u = torch.zeros_like(mu)
        v = torch.zeros_like(nu)

        # Sinkhorn iterations
        for _ in range(self.max_iter):
            u_prev = u  # useful to check the update
            u = self.eps * (log_mu - torch.logsumexp(self.M(C, u, v), dim=-1)) + u
            # dim=-2 is the column-wise log-sum-exp (same as transposing the
            # last two axes and reducing over the last one).
            v = self.eps * (log_nu - torch.logsumexp(self.M(C, u, v), dim=-2)) + v
            err = (u - u_prev).abs().sum(-1).mean()

            # Stopping criterion
            if err.item() < self.thresh:
                break

        # Transport plan pi = diag(a)*K*diag(b)
        pi = torch.exp(self.M(C, u, v))
        # Sinkhorn distance
        cost = torch.sum(pi * C, dim=(-2, -1))

        if self.reduction == 'mean':
            cost = cost.mean()
        elif self.reduction == 'sum':
            cost = cost.sum()

        return cost, pi, C

    def M(self, C, u, v):
        r"""Modified cost for logarithmic updates.

        $M_{ij} = (-c_{ij} + u_i + v_j) / \epsilon$
        """
        return (-C + u.unsqueeze(-1) + v.unsqueeze(-2)) / self.eps

    @staticmethod
    def _cost_matrix(x, y, p=2):
        "Returns the matrix of $|x_i-y_j|^p$."
        x_col = x.unsqueeze(-2)
        y_lin = y.unsqueeze(-3)
        C = torch.sum((torch.abs(x_col - y_lin)) ** p, -1)
        return C

    @staticmethod
    def ave(u, u1, tau):
        "Barycenter subroutine, used by kinetic acceleration through extrapolation."
        return tau * u + (1 - tau) * u1

In [19]:
# PyTorch Sinkhorn module built from the same `eps` / `max_iter` settings.
sink = SinkhornDistance(max_iter=max_iter, eps=eps)

In [70]:

# Call the module itself rather than `.forward` so nn.Module hooks are honoured.
# NOTE: this rebinds `pi` and `C`, which held TensorFlow tensors until now.
cost, pi, C = sink(torch.tensor(x), torch.tensor(y))

In [74]:
# Total transported mass of the plan; with a single batch element and uniform
# marginals this is expected to be ~1.
pi.sum()

tensor(1.0000)

In [79]:
from tokenizers import (ByteLevelBPETokenizer,
      CharBPETokenizer,
      SentencePieceBPETokenizer,
      BertWordPieceTokenizer)

# Vocabulary file for the tokenizer (absolute, machine-specific path).
vocab = '/data/xuht/uncased_L-12_H-768_A-12_ilm_v1/vocab_uncased_en.txt'

# NOTE(review): despite the variable name, this builds a BertWordPieceTokenizer
# (not a BPE one); the vocab file name suggests uncased English — confirm.
chinese_bpe_tokenizer = BertWordPieceTokenizer(vocab, lowercase=True)

In [80]:
# Exploratory: print the signature/docstring of the tokenizer's `decode` method.
help(chinese_bpe_tokenizer.decode)

Help on method decode in module tokenizers.implementations.base_tokenizer:

decode(ids:List[int], skip_special_tokens:Union[bool, NoneType]=True) -> str method of tokenizers.implementations.bert_wordpiece.BertWordPieceTokenizer instance
    Decode the given list of ids to a string sequence
    
    Args:
        ids: List[unsigned int]:
            A list of ids to be decoded
    
        skip_special_tokens: (`optional`) boolean:
            Whether to remove all the special tokens from the output string
    
    Returns:
        The decoded string



In [82]:
# Scratch check: evaluate `tf.range(10)` in the existing session.
sess.run(tf.range(10))

array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], dtype=int32)

In [84]:
# Scratch check: exp of a large negative number underflows to exactly 0.0.
np.exp(-1000)

0.0

In [88]:
# Scratch check: element-wise boolean mask of the non-zero entries.
sess.run(tf.not_equal([1.,2.,3.,0.], 0))

array([ True,  True,  True, False])

In [89]:
import sklearn.preprocessing
# One-hot encode integer labels with scikit-learn; the binarizer is fitted on
# every class id from 0 up to max(a), so unused ids still get a column.
a = [1, 0, 3]
label_binarizer = sklearn.preprocessing.LabelBinarizer().fit(range(max(a) + 1))
b = label_binarizer.transform(a)
print('{0}'.format(b))

[[0 1 0 0]
 [1 0 0 0]
 [0 0 0 1]]


In [130]:
# Fixed seed so the sampled labels (and everything derived from them below)
# are reproducible on a fresh kernel. Values are drawn from {1, 2, 3}.
label = np.random.RandomState(0).randint(1, 4, size=[2, 5])

In [131]:
# One-hot encode the labels with 4 classes (ids 0..3); labels are drawn from
# 1..3, so channel 0 is never set.
label_tf = tf.one_hot(label, depth=4)

In [132]:
# Materialise the one-hot tensor as a NumPy array using the existing session.
one_hot_label = sess.run(label_tf)

In [133]:
# Display the sampled integer labels.
label

array([[2, 1, 1, 3, 3],
       [2, 3, 1, 2, 2]])

In [137]:
# Sum the one-hot vectors over axis 0 (the first axis of `label`), then turn
# the counts into a 0/1 indicator of which classes occur at each position.
(one_hot_label.sum(axis=0) !=0)*1

array([[0, 0, 1, 0],
       [0, 1, 0, 1],
       [0, 1, 0, 0],
       [0, 0, 1, 1],
       [0, 0, 1, 1]])

In [138]:
# Display the full one-hot array for comparison with the indicator above.
one_hot_label

array([[[0., 0., 1., 0.],
        [0., 1., 0., 0.],
        [0., 1., 0., 0.],
        [0., 0., 0., 1.],
        [0., 0., 0., 1.]],

       [[0., 0., 1., 0.],
        [0., 0., 0., 1.],
        [0., 1., 0., 0.],
        [0., 0., 1., 0.],
        [0., 0., 1., 0.]]], dtype=float32)

In [140]:
# Scratch arithmetic.
24*8

192