In [None]:

import numpy as np
import tensorflow as tf

"""
Adapted from https://github.com/gpeyre/SinkhornAutoDiff
and from https://github.com/dfdazac/wassdistance/blob/master/layers.py
and from https://github.com/michaelsdr/sinkformers/blob/main/nlp-tutorial/text-classification-transformer/sinkhorn.py
"""

def shape_list(x, out_type=tf.int32):
  """Return the shape of `x` as a list, preferring static dimensions.

  Each entry is the Python int from `x.shape` when that dimension is
  statically known; otherwise it is the matching scalar slice of
  `tf.shape(x, out_type=out_type)`.
  """
  known_dims = x.shape.as_list()
  runtime_dims = tf.shape(x, out_type=out_type)
  dims = []
  for axis, size in enumerate(known_dims):
    if size is None:
      # Unknown at graph-construction time: fall back to the runtime value.
      dims.append(runtime_dims[axis])
    else:
      dims.append(size)
  return dims

def sinkhorn_distance(input_tensor, eps, max_iter, 
                  reduction='none',
                  stopThr=1e-2):
  r"""Log-domain Sinkhorn iterations on a batched cost matrix (TF1 graph mode).

  Both marginals are uniform: every row gets weight 1/x_points and every
  column gets weight 1/y_points.

  Args:
    input_tensor: cost matrix `C`, indexed here as
      [batch, x_points, y_points].
    eps: entropic regularisation coefficient.
    max_iter: maximum number of Sinkhorn iterations.
    reduction: 'none' | 'mean' | 'sum', applied to the returned cost over
      the batch axis. Any other value leaves the cost unreduced.
    stopThr: iterations continue while the batch-mean L1 change of `u`
      is greater than this threshold.

  Returns:
    Tuple `(pi, C, U, V, cost)`: the transport plan, the input cost matrix,
    the two dual potentials, and `sum(pi * C)` over the last two axes
    (reduced according to `reduction`).
  """
  C = input_tensor
  C_shape = shape_list(C)
  dtype = C.dtype

  x_points = C_shape[-2]
  y_points = C_shape[-1]
  batch_size = C_shape[0]

  # Both marginals are fixed with equal weights. Casting the point counts
  # keeps this valid when a dimension is only known at run time (an int32
  # tensor), and building everything in C's dtype keeps the while_loop
  # variables type-stable for non-float32 costs.
  mu = tf.ones((batch_size, x_points), dtype=dtype) / tf.cast(x_points, dtype)
  nu = tf.ones((batch_size, y_points), dtype=dtype) / tf.cast(y_points, dtype)

  # Loop-invariant, so computed once outside the loop body.
  log_mu = tf.log(mu + 1e-8)
  log_nu = tf.log(nu + 1e-8)

  u = tf.zeros_like(mu)
  v = tf.zeros_like(nu)

  cpt = tf.constant(0)
  err = tf.constant(1.0, dtype=dtype)

  def M(C, u, v):
    r"""Modified cost for logarithmic updates.

    $M_{ij} = (-c_{ij} + u_i + v_j) / \epsilon$
    """
    return (-C + tf.expand_dims(u, -1) + tf.expand_dims(v, -2)) / eps

  def keep_iterating(cpt, u, v, err):
    return tf.logical_and(cpt < max_iter, err > stopThr)

  def loop_func(cpt, u, v, err):
    u_prev = u  # kept to measure the size of the update

    u = eps * (log_mu - tf.reduce_logsumexp(M(C, u, v), axis=-1)) + u
    # Reducing over axis -2 equals transposing and reducing over axis -1.
    v = eps * (log_nu - tf.reduce_logsumexp(M(C, u, v), axis=-2)) + v

    err = tf.reduce_mean(tf.reduce_sum(tf.abs(u - u_prev), axis=-1))

    return cpt + 1, u, v, err

  _, u_final, v_final, _ = tf.while_loop(
      keep_iterating, loop_func, loop_vars=[cpt, u, v, err])
  U, V = tf.identity(u_final), tf.identity(v_final)

  # Transport plan pi = diag(a)*K*diag(b)
  pi = tf.exp(M(C, U, V))

  cost = tf.reduce_sum(pi * C, axis=(-2, -1))

  # Honour `reduction`, matching the PyTorch SinkhornDistance in this notebook.
  if reduction == 'mean':
    cost = tf.reduce_mean(cost)
  elif reduction == 'sum':
    cost = tf.reduce_sum(cost)

  return pi, C, U, V, cost



In [None]:
# Settings for the Sinkhorn experiments in the following cells.
eps = 1.0  # entropic regularisation coefficient
max_iter = 10  # upper bound on the number of Sinkhorn iterations
stopThr = 1e-10  # convergence threshold on the change of u between iterations

In [None]:
def _cost_matrix(x, y, p=2):
    """Pairwise cost: sum over the last axis of |x_i - y_j| ** p.

    `x` gains an axis at position -2 and `y` one at position -3 so that the
    subtraction broadcasts every x point against every y point.
    """
    pairwise_diff = tf.expand_dims(x, axis=-2) - tf.expand_dims(y, axis=-3)
    return tf.reduce_sum(tf.abs(pairwise_diff) ** p, -1)

In [None]:
# Two random float32 point clouds in [0, 1): a batch of one with 10 points
# and a batch of one with 16 points, each point 32-dimensional.
# NOTE(review): no random seed is set, so the values differ on every run.
x = np.random.random((1, 10, 32)).astype(np.float32)
y = np.random.random((1, 16, 32)).astype(np.float32)

In [None]:
# Pairwise cost between the two point clouds (graph tensor).
C = _cost_matrix(tf.constant(x), tf.constant(y), p=2)

# Build the Sinkhorn graph. The threshold comes from the `stopThr` setting
# defined with `eps` and `max_iter` rather than from a repeated literal.
pi, C_, U, V, final_cost = sinkhorn_distance(C, eps, max_iter,
                                             reduction='none',
                                             stopThr=stopThr)

In [None]:
# TF1 graph execution: evaluate the Sinkhorn outputs built in the previous cell.
# NOTE(review): this session is never closed in the visible notebook.
sess = tf.Session()
resp = sess.run([pi, C, U, V, final_cost])

In [None]:
# Column marginals of the transport plan (sum over the x points); each entry
# should be close to the uniform target weight 1/16.
resp[0].sum(axis=-2)

In [None]:
import torch
import torch.nn as nn

# Adapted from https://github.com/gpeyre/SinkhornAutoDiff
class SinkhornDistance(nn.Module):
    r"""
    Given two empirical measures each with :math:`P_1` locations
    :math:`x\in\mathbb{R}^{D_1}` and :math:`P_2` locations :math:`y\in\mathbb{R}^{D_2}`,
    outputs an approximation of the regularized OT cost for point clouds.
    Args:
        eps (float): regularization coefficient
        max_iter (int): maximum number of Sinkhorn iterations
        reduction (string, optional): Specifies the reduction to apply to the output:
            'none' | 'mean' | 'sum'. 'none': no reduction will be applied,
            'mean': the sum of the output will be divided by the number of
            elements in the output, 'sum': the output will be summed. Default: 'none'
        thresh (float, optional): iterations stop early once the batch-mean
            L1 change of ``u`` falls below this value. Default: 1e-1
    Shape:
        - Input: :math:`(N, P_1, D_1)`, :math:`(N, P_2, D_2)`
        - Output: :math:`(N)` or :math:`()`, depending on `reduction`
    """
    def __init__(self, eps, max_iter, reduction='none', thresh=1e-1):
        super(SinkhornDistance, self).__init__()
        self.eps = eps
        self.max_iter = max_iter
        self.reduction = reduction
        self.thresh = thresh

    def forward(self, x, y):
        """Run log-domain Sinkhorn between the point clouds `x` and `y`.

        Returns:
            Tuple ``(cost, pi, C)``: the (optionally reduced) transport cost,
            the transport plan and the pairwise cost matrix.
        """
        C = self._cost_matrix(x, y)  # Wasserstein cost function
        x_points = x.shape[-2]
        y_points = y.shape[-2]

        # Both marginals are fixed with equal weights. Their shapes are taken
        # from C (all leading axes plus the row / column axis) instead of being
        # built as (batch, points) and squeezed: squeezing also removes a
        # size-1 *points* axis, which made the potentials broadcast against C
        # along the wrong axes. Dtype and device follow C so that non-float32
        # or non-CPU inputs work.
        dtype = C.dtype if C.is_floating_point() else torch.float
        mu = torch.full(C.shape[:-1], 1.0 / x_points,
                        dtype=dtype, device=C.device)
        nu = torch.full(C.shape[:-2] + C.shape[-1:], 1.0 / y_points,
                        dtype=dtype, device=C.device)

        # Loop-invariant, so computed once.
        log_mu = torch.log(mu + 1e-8)
        log_nu = torch.log(nu + 1e-8)

        u = torch.zeros_like(mu)
        v = torch.zeros_like(nu)

        # Sinkhorn iterations
        for _ in range(self.max_iter):
            u_prev = u  # kept to measure the size of the update
            u = self.eps * (log_mu - torch.logsumexp(self.M(C, u, v), dim=-1)) + u
            # Reducing over dim -2 equals transposing and reducing over dim -1.
            v = self.eps * (log_nu - torch.logsumexp(self.M(C, u, v), dim=-2)) + v
            err = (u - u_prev).abs().sum(-1).mean()

            if err.item() < self.thresh:
                break

        # Transport plan pi = diag(a)*K*diag(b)
        pi = torch.exp(self.M(C, u, v))
        # Sinkhorn distance
        cost = torch.sum(pi * C, dim=(-2, -1))

        if self.reduction == 'mean':
            cost = cost.mean()
        elif self.reduction == 'sum':
            cost = cost.sum()

        return cost, pi, C

    def M(self, C, u, v):
        r"""Modified cost for logarithmic updates.

        $M_{ij} = (-c_{ij} + u_i + v_j) / \epsilon$
        """
        return (-C + u.unsqueeze(-1) + v.unsqueeze(-2)) / self.eps

    @staticmethod
    def _cost_matrix(x, y, p=2):
        "Returns the matrix of $|x_i-y_j|^p$."
        x_col = x.unsqueeze(-2)
        y_lin = y.unsqueeze(-3)
        C = torch.sum((torch.abs(x_col - y_lin)) ** p, -1)
        return C

    @staticmethod
    def ave(u, u1, tau):
        "Barycenter subroutine, used by kinetic acceleration through extrapolation."
        return tau * u + (1 - tau) * u1

In [None]:
# PyTorch counterpart, configured with the same eps / max_iter as the TF run.
sink = SinkhornDistance(eps=eps, max_iter=max_iter)

In [None]:

# Run the PyTorch implementation on the same random point clouds.
# NOTE(review): this rebinds `pi` and `C`, which previously held the TF graph
# tensors; re-running the earlier `sess.run([pi, C, ...])` cell after this one
# would be handed torch tensors instead.
cost, pi, C = sink.forward(torch.tensor(x), torch.tensor(y))

In [None]:
# Total transported mass; expected to be close to 1 because both uniform
# marginals sum to 1.
pi.sum()

In [None]:
from tokenizers import (ByteLevelBPETokenizer,
      CharBPETokenizer,
      SentencePieceBPETokenizer,
      BertWordPieceTokenizer)

# NOTE(review): absolute, machine-specific path; the notebook will not run
# elsewhere without changing it.
vocab = '/data/xuht/uncased_L-12_H-768_A-12_ilm_v1/vocab_uncased_en.txt'

# NOTE(review): despite the variable name, this is a BertWordPieceTokenizer
# loaded from a file named "vocab_uncased_en" — confirm the intended language
# and tokenizer type.
chinese_bpe_tokenizer = BertWordPieceTokenizer(
    vocab, 
    lowercase=True)

In [None]:
# Interactive API lookup for the tokenizer's decode method (exploration aid).
help(chinese_bpe_tokenizer.decode)

In [None]:
# Sanity check of the session: evaluates to the integers 0..9.
sess.run(tf.range(10))

In [None]:
# Underflow check: exp(-1000) is below the float64 range and evaluates to 0.0.
np.exp(-1000)

In [None]:
# Element-wise "non-zero" mask: True for 1., 2., 3. and False for the trailing 0.
sess.run(tf.not_equal([1.,2.,3.,0.], 0))

In [None]:
import sklearn.preprocessing
# One-hot encode integer labels with scikit-learn: fit the binarizer on every
# class id from 0 to max(a), then transform the labels themselves.
a = [1,0,3]
label_binarizer = sklearn.preprocessing.LabelBinarizer()
label_binarizer.fit(range(max(a)+1))
b = label_binarizer.transform(a)
print('{0}'.format(b))

In [None]:
# Random integer labels in {1, 2, 3} (upper bound 4 is exclusive), shape (2, 5).
# NOTE(review): unseeded, so the labels change on every run.
label = np.random.randint(1, 4, size=[2,5])

In [None]:
# One-hot encode the labels over 4 classes; adds a trailing axis of size 4.
label_tf = tf.one_hot(label, depth=4)

In [None]:
# Evaluate the one-hot tensor to a numpy array.
one_hot_label = sess.run(label_tf)

In [None]:
# Display the raw integer labels for comparison with the one-hot encoding.
label

In [None]:
# Collapse axis 0 into a 0/1 indicator: 1 where a class occurs in at least one
# of the two rows at that position, 0 otherwise.
(one_hot_label.sum(axis=0) !=0)*1

In [None]:
# Display the full one-hot array.
one_hot_label

In [None]:
# Scratch arithmetic (192); the notebook does not say what it refers to.
24*8

In [None]:
def _generate_relative_positions_matrix_t5(length, max_relative_position,
                                        num_buckets=32,
                                        bidirectional=True):
  """Build a [length, length] int32 matrix of T5-style relative-position buckets.

  Entry [q, k] is the bucket of the offset between query position q and key
  position k. Small offsets each get their own bucket; larger offsets share
  logarithmically sized buckets, capped at the last bucket. With
  `bidirectional=True` half of the buckets are used for keys after the
  query; otherwise keys after the query are treated as offset 0.

  References:
    https://github.com/bojone/bert4keras/blob/master/bert4keras/layers.py
    https://github.com/tensorflow/mesh/blob/master/mesh_tensorflow/transformer/transformer_layers.py
    (_relative_position_bucket)
    https://gist.github.com/huchenxucs/c65524185e8e35c4bcfae4059f896c16
  """
  tf.logging.info("** apply all distance mat **")

  positions = tf.range(length)
  # n[q, k] = q - k, i.e. the negated (key - query) distance.
  n = tf.expand_dims(positions, 1) - tf.expand_dims(positions, 0)

  max_distance = max_relative_position
  ret = 0
  if bidirectional:
    num_buckets //= 2
    # Keys after the query (n < 0) use the upper half of the buckets.
    ret += tf.cast(tf.less(n, 0), 'int32') * num_buckets
    n = tf.abs(n)
  else:
    n = tf.maximum(n, 0)
  # now n is in the range [0, inf)

  max_exact = num_buckets // 2
  small_flag = tf.cast(tf.less(n, max_exact), dtype=tf.int32)

  log_position = (
      tf.log(tf.cast(n, dtype=tf.float32) / max_exact) /
      tf.log(max_distance / max_exact) * (num_buckets - max_exact)
  )
  val_if_large = tf.minimum(
      max_exact + tf.cast(log_position, 'int32'),
      num_buckets - 1)

  # Arithmetic select: exact offset where it is small, log bucket elsewhere.
  ret += small_flag * n + (1 - small_flag) * val_if_large

  return ret

# Build the bucket matrices for a 64-position sequence, once with
# bidirectional buckets and once with unidirectional buckets, then evaluate both.
length=64
max_relative_position=32
num_buckets=32
bidirectional=True

ret_bi = _generate_relative_positions_matrix_t5(length, max_relative_position,
                                        num_buckets=num_buckets,
                                        bidirectional=bidirectional)

ret_uni = _generate_relative_positions_matrix_t5(length, max_relative_position,
                                        num_buckets=num_buckets,
                                        bidirectional=False)

# ret[0] is the bidirectional matrix, ret[1] the unidirectional one.
ret = sess.run([ret_bi, ret_uni])



In [None]:
# Display the full bidirectional bucket matrix from the TF implementation.
ret[0]

In [None]:
# Last query row of the unidirectional bucket matrix from the TF implementation.
ret[1][-1]

In [None]:
def _relative_position_bucket_(relative_position, bidirectional=True, num_buckets=32, max_distance=128):
        """
        Adapted from Mesh Tensorflow:
        https://github.com/tensorflow/mesh/blob/0cb87fe07da627bf0b7e60475d59f95ed6b5be3d/mesh_tensorflow/transformer/transformer_layers.py#L593
        Translate relative position to a bucket number for relative attention.
        The relative position is defined as memory_position - query_position, i.e.
        the distance in tokens from the attending position to the attended-to
        position.  If bidirectional=False, then positive relative positions are
        invalid.
        We use smaller buckets for small absolute relative_position and larger buckets
        for larger absolute relative_positions.  All relative positions >=max_distance
        map to the same bucket.  All relative positions <=-max_distance map to the
        same bucket.  This should allow for more graceful generalization to longer
        sequences than the model has been trained on.
        Args:
            relative_position: an int32 Tensor
            bidirectional: a boolean - whether the attention is bidirectional
            num_buckets: an integer
            max_distance: an integer
        Returns:
            a Tensor with the same shape as relative_position, containing int32
            values in the range [0, num_buckets)
        """
        ret = 0
        n = -relative_position
        if bidirectional:
            num_buckets //= 2
            ret += (n < 0).to(torch.long) * num_buckets  # mtf.to_int32(mtf.less(n, 0)) * num_buckets
            n = torch.abs(n)
        else:
            n = torch.max(n, torch.zeros_like(n))
        # now n is in the range [0, inf)

        # half of the buckets are for exact increments in positions
        max_exact = num_buckets // 2
        is_small = n < max_exact

        # The other half of the buckets are for logarithmically bigger bins in positions up to max_distance
        val_if_large = max_exact + (
            torch.log(n.float() / max_exact) / math.log(max_distance / max_exact) * (num_buckets - max_exact)
        ).to(torch.long)
        val_if_large = torch.min(val_if_large, torch.full_like(val_if_large, num_buckets - 1))

        ret += torch.where(is_small, n, val_if_large)
        return ret

In [None]:
import torch

In [None]:
# Unidirectional buckets from the PyTorch implementation for 64 positions.
# NOTE(review): rebinds `resp`, which earlier held the TF Sinkhorn results.
context_position = torch.arange(64, dtype=torch.long)[:, None]
memory_position = torch.arange(64, dtype=torch.long)[None, :]
relative_position = memory_position - context_position  # shape (qlen, klen)
resp = _relative_position_bucket_(relative_position, bidirectional=False, num_buckets=32, max_distance=32)

In [None]:
# Row for query position 0: every key is at or after the query, so in the
# unidirectional setting all offsets clamp to 0 and the row is all zeros.
resp[0]

In [None]:
# Row for query position 25 of the TF bidirectional bucket matrix.
ret[0][25]

In [None]:
# Zero out the bucket ids of row 25 wherever segment_ids is 1.
# NOTE(review): `segment_ids` is not defined before this cell in the visible
# notebook; its only assignment is a plain Python list in a later cell, for
# which `1 - segment_ids` would raise TypeError. Presumably an array version
# existed in kernel state — confirm and define it before this cell.
(1-segment_ids) * resp[25].numpy()

In [None]:
# Display segment_ids.
# NOTE(review): no visible cell above assigns this name; it relies on kernel state.
segment_ids

In [None]:
# s1: bidirectional buckets for 64 positions.
# NOTE(review): this cell repeats the same four lines as its neighbours with
# only the length and `bidirectional` flag changed.
context_position = torch.arange(64, dtype=torch.long)[:, None]
memory_position = torch.arange(64, dtype=torch.long)[None, :]
relative_position = memory_position - context_position  # shape (qlen, klen)
s1 = _relative_position_bucket_(relative_position, bidirectional=True, num_buckets=32, max_distance=32)

In [None]:
# s3: unidirectional buckets for 64 positions (same arguments as the earlier
# `resp` computation).
context_position = torch.arange(64, dtype=torch.long)[:, None]
memory_position = torch.arange(64, dtype=torch.long)[None, :]
relative_position = memory_position - context_position  # shape (qlen, klen)
s3 = _relative_position_bucket_(relative_position, bidirectional=False, num_buckets=32, max_distance=32)

In [None]:
# s2: bidirectional buckets for a shorter sequence of 32 positions.
context_position = torch.arange(32, dtype=torch.long)[:, None]
memory_position = torch.arange(32, dtype=torch.long)[None, :]
relative_position = memory_position - context_position  # shape (qlen, klen)
s2 = _relative_position_bucket_(relative_position, bidirectional=True, num_buckets=32, max_distance=32)

In [None]:
# Row for query position 31 of the 64-position bidirectional matrix.
s1[31]

In [None]:
# Last row (query position 31) of the 32-position bidirectional matrix.
s2[-1]

In [None]:
# Row for query position 32 of the 64-position unidirectional matrix.
s3[32]

In [None]:
# Mix the two bucket matrices by segment: bidirectional buckets where both the
# query row and the key column have segment id 0, unidirectional buckets for
# every query row with segment id 1, and 0 elsewhere.
# NOTE(review): requires `segment_ids` to be a 1-D array/tensor (it is indexed
# with [None, :]); the only visible assignment is a Python list in a later
# cell, which does not support this indexing.
a1 = s1 * (1-segment_ids[None, :]) * (1-segment_ids[:, None]) + s3 * (segment_ids[:, None])

In [None]:

# Positions 0-24 carry segment id 0 and positions 25-63 carry segment id 1.
segment_ids = [0]*25+[1]*39

# The same segment layout repeated for a batch of two: shape [2, 64].
segment_mask = tf.cast(np.array([segment_ids, segment_ids]), dtype=tf.int32)
# Bucket matrices evaluated earlier by the TF implementation, as constants.
relative_positions_matrix_bi = tf.constant(ret[0])
relative_positions_matrix_uni = tf.constant(ret[1])

# handle mixture of bi and uni-direction relative position
# [1, seq_len, seq_len]
relative_positions_matrix_bi = tf.expand_dims(relative_positions_matrix_bi, axis=0)
relative_positions_matrix_uni = tf.expand_dims(relative_positions_matrix_uni, axis=0)

# s1 * (1-segment_ids[None, :]) * (1-segment_ids[:, None]) + s3 * (segment_ids[:, None])
# [batch, seq_len, seq_len]
# Expanding the mask at axis=1 masks key columns; expanding at axis=-1 masks
# query rows. Bidirectional buckets are kept only where query and key both have
# segment id 0; query rows with segment id 1 take the unidirectional buckets.
relative_positions_matrix = relative_positions_matrix_bi * (1-tf.expand_dims(segment_mask, axis=1)) * (1-tf.expand_dims(segment_mask, axis=-1)) + relative_positions_matrix_uni * (tf.expand_dims(segment_mask, axis=-1))
  

In [None]:
# Evaluate the mixed bucket matrix to a numpy array of shape [2, 64, 64].
final = sess.run(relative_positions_matrix)

In [None]:
# Scratch arithmetic; the notebook does not say what these numbers refer to.
63356*768-21228*512*4

In [55]:
import numpy as np
import tensorflow as tf
# Random arrays for the queue experiment below. `update_np` is the [1, 2, 3]
# slice pushed into the queue; `init_np` is not referenced by any later
# visible cell.
init_np = np.random.random((4, 2, 3))
update_np = np.random.random((1, 2, 3))

In [62]:
# Fixed-size queue experiment in a dedicated graph: a non-trainable [4, 2, 3]
# variable initialised to zeros, an update op that pushes `update_np` to the
# front and drops the last slot, and a masked log-sum-exp over the contents.
graph = tf.Graph()
with graph.as_default():
    
    with tf.variable_scope("test", reuse=tf.AUTO_REUSE):
        queue = tf.get_variable('queue', 
                      [4, 2, 3], 
                      dtype=tf.float32,
                      initializer=tf.constant_initializer(0),
                      trainable=False)
    
    # NOTE(review): rebinds `sess`; the session created earlier in the notebook
    # is not closed, and later cells using `sess` now run against this graph.
    sess = tf.Session()
    # New queue contents: the update slice followed by all but the last slot.
    queue_op = queue.assign(tf.concat([tf.constant(update_np.astype(np.float32)), queue[:-1, :, :]], axis=0))
    # Ops created in this block run only after queue_op.
    # NOTE(review): confirm that the reads of `queue` below observe the
    # assigned value for this variable type; the printed output of Z further
    # down is consistent with that.
    with tf.control_dependencies([queue_op]):
    #     p = queue + 1
        f = tf.identity(queue)
        # Entries equal to 0 are treated as empty and pushed to -1e10 so they
        # contribute nothing to the log-sum-exp over the last axis.
        queue_mask = tf.cast(tf.not_equal(queue, 0), dtype=tf.float32)
        Z = tf.reduce_logsumexp(queue-(1-queue_mask)*1e10, axis=-1)

    # `f` above and `p` here are built but not evaluated in any visible cell.
    with tf.control_dependencies([queue_op]):
        p = queue + 1

    sess.run(tf.global_variables_initializer())

In [63]:
# Evaluate the masked log-sum-exp. Z was built under a control dependency on
# queue_op, so evaluating it also runs the queue update; in the output below
# only the first slot is non-empty and the rest stay at the -1e10 mask value.
sess.run(Z)

array([[ 1.6429567e+00,  1.5841072e+00],
       [-1.0000000e+10, -1.0000000e+10],
       [-1.0000000e+10, -1.0000000e+10],
       [-1.0000000e+10, -1.0000000e+10]], dtype=float32)

In [64]:
# Underflow check: exp(-10000) evaluates to 0.0 in float64.
np.exp(-10000)

0.0

In [58]:
# Variant of the masked log-sum-exp that reduces over axis 0 (the queue slots)
# instead of the last axis, and is not wrapped in a control dependency on the
# queue update.
# NOTE(review): the execution counts show this cell (In [58]) ran before the
# graph cell above (In [62]), so `queue` here may have referred to an earlier
# definition; running the notebook top to bottom rebinds `queue_mask` and `Z`.
queue_mask = tf.cast(tf.not_equal(queue, 0), dtype=tf.float32)
Z = tf.reduce_logsumexp(queue-(1-queue_mask)*1e10, axis=0)


In [61]:
# Evaluate the axis-0 variant of Z; the result has shape [2, 3].
sess.run(Z)

array([[0.15929994, 0.77011406, 0.60853255],
       [0.08643683, 0.19761226, 0.9424123 ]], dtype=float32)