# Quick Einsum Examples

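Based on this [Einsum tutorial video](https://www.youtube.com/watch?v=pkVwUVEHmfI). Each cell computes something with `tf.einsum` and checks it against the equivalent built-in TensorFlow operation.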

In [1]:
import tensorflow as tf

# Fix the global seed so every tf.random.* call in this notebook produces
# the same values on Restart & Run All, making the printed outputs reproducible.
tf.random.set_seed(42)
tf.__version__

'2.9.0'

In [5]:
# A random 2x3 float matrix that most of the examples below operate on.
x = tf.random.uniform(shape=(2, 3))
print(x)

tf.Tensor(
[[0.22854877 0.91063654 0.26433897]
 [0.8855262  0.8261514  0.1433636 ]], shape=(2, 3), dtype=float32)


In [22]:
# Transpose: both i and j appear after "->", so nothing is summed.
# Listing them in a different order just permutes the axes.
x_permuted = tf.einsum("ij->ji", x)
tf.debugging.assert_equal(
    x_permuted,
    tf.transpose(x, perm=[1, 0]),
)
print(x_permuted)

tf.Tensor(
[[0.22854877 0.8855262 ]
 [0.91063654 0.8261514 ]
 [0.26433897 0.1433636 ]], shape=(3, 2), dtype=float32)


In [24]:
# Full reduction: neither i nor j appears after "->", so both are summed
# over and the result is a scalar.
x_sum = tf.einsum("ij->", x)
# Compare against an independent reduction of `x` itself; checking x_sum
# against reduce_sum(x_sum) would always pass and verify nothing.
# assert_near tolerates float32 noise from a different summation order.
tf.debugging.assert_near(x_sum, tf.reduce_sum(x))
print(x_sum)

tf.Tensor(3.2585654, shape=(), dtype=float32)


In [25]:
# Keep one index in the output; the index that is dropped gets summed over.

# "ij->j": i is dropped, so we sum over the rows (axis 0).
x_sum_along_axis_0 = tf.einsum("ij->j", x)
tf.debugging.assert_equal(x_sum_along_axis_0, tf.reduce_sum(x, axis=0))
print(x_sum_along_axis_0)

# "ij->i": j is dropped, so we sum over the columns (axis 1).
x_sum_along_axis_1 = tf.einsum("ij->i", x)
tf.debugging.assert_equal(x_sum_along_axis_1, tf.reduce_sum(x, axis=1))
print(x_sum_along_axis_1)

tf.Tensor([1.114075   1.7367879  0.40770257], shape=(3,), dtype=float32)
tf.Tensor([1.4035243 1.8550411], shape=(2,), dtype=float32)


In [16]:
# Another random input for the matrix-vector example. Note that it is
# created with shape (1, 3), i.e. a single-row matrix, not a rank-1 vector.
v = tf.random.uniform(shape=(1, 3))
print(v)

tf.Tensor([[0.07929385 0.292992   0.34711337]], shape=(1, 3), dtype=float32)


In [39]:
# Matrix-vector multiplication: out[i] = sum_j x[i, j] * v[j].
# Write out what you want as loops and summations as shown in the video;
# the summed index (j) is the one that disappears from the output.
# Flatten `v` into a true rank-1 vector first. This avoids a dummy leading
# axis that einsum would silently sum over (adding rows together if there
# were more than one). It also makes matvec return shape (2,) rather than
# (1, 2), so the comparison does not rely on broadcasting.
v_vec = tf.reshape(v, [-1])
mv_mult = tf.einsum("ij,j->i", x, v_vec)
tf.debugging.assert_near(mv_mult, tf.linalg.matvec(x, v_vec))
print(mv_mult)

tf.Tensor([0.37668732 0.36203593], shape=(2,), dtype=float32)


In [40]:
# A random 3x2 matrix so that x (2x3) @ y (3x2) is well defined.
y = tf.random.uniform(shape=(3, 2))
print(y)

tf.Tensor(
[[0.90759623 0.37361777]
 [0.24184144 0.02720988]
 [0.63535583 0.6381259 ]], shape=(3, 2), dtype=float32)


In [41]:
# Matrix-matrix multiplication: out[i, k] = sum_j x[i, j] * y[j, k].
# j is shared by both operands and absent from the output, so it is contracted.
mm_mult = tf.einsum("ij,jk->ik", x, y)
tf.debugging.assert_equal(
    mm_mult,
    tf.matmul(x, y),
)
print(mm_mult)

tf.Tensor(
[[0.59560895 0.27884972]
 [1.0945847  0.44481182]], shape=(2, 2), dtype=float32)


In [57]:
# Dot product of the first row of x with itself: the shared index i is
# summed away, leaving a scalar.
first_row_dot = tf.einsum("i,i->", x[0], x[0])
tf.debugging.assert_near(
    first_row_dot,
    tf.tensordot(x[0], x[0], axes=1),
)

In [64]:
# Frobenius inner product of two matrices: out = sum_{i,j} x[i, j] * x[i, j].
# Both indices are shared and both vanish from the output, so everything
# is summed into a scalar.
mm_dot = tf.einsum("ij,ij->", x, x)
# tensordot with axes=2 contracts the last two axes of the first operand
# with the first two axes of the second, i.e. exactly sum_ij x_ij * x_ij.
tf.debugging.assert_near(mm_dot, tf.tensordot(x, x, axes=2))
print(mm_dot)

tf.Tensor(2.4386044, shape=(), dtype=float32)


In [66]:
# Hadamard (element-wise) product: i and j are kept in the output, so
# nothing is summed; matching positions are simply multiplied.
hadamard = tf.einsum("ij,ij->ij", x, x)
tf.debugging.assert_near(
    hadamard,
    tf.multiply(x, x),
)

In [69]:
# Two single-row inputs for the outer-product example: shapes (1, 2) and (1, 3).
x1 = tf.random.uniform(shape=(1, 2))
print(x1)
x2 = tf.random.uniform(shape=(1, 3))
print(x2)

tf.Tensor([[0.24436545 0.2866938 ]], shape=(1, 2), dtype=float32)
tf.Tensor([[0.5276089  0.28632748 0.6086776 ]], shape=(1, 3), dtype=float32)


In [74]:
# Outer product: out[j, l] = x1[0, j] * x2[0, l].
# x1 and x2 have a leading axis of size 1 (labelled i and k here). Those
# labels are absent from the output, so einsum sums over them, which is a
# no-op for size 1 and leaves a (2, 3) result.
outer_product = tf.einsum("ij,kl->jl", x1, x2)
# tensordot with axes=0 takes the full outer product, with shape (1, 2, 1, 3);
# squeezing drops the size-1 axes.
tf.debugging.assert_near(
    outer_product,
    tf.squeeze(tf.tensordot(x1, x2, axes=0)),
)

In [75]:
# Batch matrix multiplication: i indexes the batch and appears in every term,
# so it is carried through; k is contracted within each batch element.
xb_1 = tf.random.uniform(shape=(4, 3, 2))
xb_2 = tf.random.uniform(shape=(4, 2, 3))
batch_mm = tf.einsum("ijk,ikl->ijl", xb_1, xb_2)
tf.debugging.assert_near(batch_mm, tf.matmul(xb_1, xb_2))

In [83]:
# A random 3x3 (square) matrix for the diagonal and trace examples.
x_d = tf.random.uniform(shape=(3, 3))
print(x_d)

tf.Tensor(
[[0.3436594  0.5534719  0.2811308 ]
 [0.36066985 0.4359125  0.6491792 ]
 [0.5480267  0.27624607 0.30795145]], shape=(3, 3), dtype=float32)


In [84]:
# Diagonal of a square matrix: repeating the index ("ii") selects the
# entries x_d[i, i]; keeping i in the output returns them as a vector.
diag = tf.einsum("ii->i", x_d)
print(diag)
tf.debugging.assert_equal(diag, tf.linalg.diag_part(x_d))

tf.Tensor([0.3436594  0.4359125  0.30795145], shape=(3,), dtype=float32)


In [85]:
# Matrix trace: the sum of the main diagonal (top-left to bottom-right).
# The repeated index i selects x_d[i, i]; dropping i from the output sums them.
diag_sum = tf.einsum("ii->", x_d)
print(diag_sum)
# Check against the built-in trace directly, so this cell does not depend on
# `diag` from another cell. assert_near tolerates float32 summation-order noise.
tf.debugging.assert_near(diag_sum, tf.linalg.trace(x_d))

tf.Tensor(1.0875233, shape=(), dtype=float32)
