import torch
import torch.nn.functional as F

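## Maximum Mean Discrepancy

For two samples $X = \{x_1, \dots, x_m\}$ and $Y = \{y_1, \dots, y_m\}$ of equal size $m$, an unbiased estimate of the squared MMD is

$$
\mathrm{MMD}^{2}_u(X,Y) = \frac{1}{m(m-1)}\sum_{i}\sum_{j\neq i} k(x_i, x_j) - \frac{2}{m^{2}}\sum_{i}\sum_{j} k(x_i, y_j) + \frac{1}{m(m-1)}\sum_{i}\sum_{j\neq i} k(y_i, y_j)
$$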
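To make the three terms concrete, here is a direct, loop-based sketch of the estimator, assuming two equally sized samples `x` and `y` of shape `(m, d)` and a Gaussian kernel with bandwidth `sigma = 1.0`. The names `gaussian_kernel` and `mmd2_unbiased_loops` are illustrative, not from any library. It runs $O(m^2)$ Python-level kernel evaluations, so it is only useful as a reference to check faster versions against.

```python
import math
import torch

def gaussian_kernel(a: torch.Tensor, b: torch.Tensor, sigma: float = 1.0) -> float:
    """k(a, b) = exp(-||a - b||^2 / (2 sigma^2)) for two 1-D tensors."""
    return math.exp(-torch.sum((a - b) ** 2).item() / (2 * sigma ** 2))

def mmd2_unbiased_loops(x: torch.Tensor, y: torch.Tensor, kernel=gaussian_kernel) -> float:
    """Unbiased MMD^2 estimate written term by term; x and y must both have m rows."""
    m = x.size(0)
    assert y.size(0) == m, "this form of the estimator assumes equal sample sizes"
    # Within-sample terms skip j == i; the cross term uses every (i, j) pair.
    xx = sum(kernel(x[i], x[j]) for i in range(m) for j in range(m) if j != i)
    yy = sum(kernel(y[i], y[j]) for i in range(m) for j in range(m) if j != i)
    xy = sum(kernel(x[i], y[j]) for i in range(m) for j in range(m))
    return xx / (m * (m - 1)) - 2 * xy / (m * m) + yy / (m * (m - 1))

torch.manual_seed(0)
x = torch.randn(20, 2)
y_same = torch.randn(20, 2)        # drawn from the same distribution as x
y_shift = torch.randn(20, 2) + 3   # mean shifted by 3 in every dimension
print(mmd2_unbiased_loops(x, y_same), mmd2_unbiased_loops(x, y_shift))
```

Because the estimator is unbiased, the same-distribution value fluctuates around zero and can be slightly negative.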

### Using a Gaussian kernel

$$
k(x_i, x_j) = \exp\left(-\frac{\lVert x_i - x_j \rVert^2}{2 \sigma^2}\right) = \exp\left(-\frac{1}{2\sigma^2}\left[x_i^T x_i - 2x_i^T x_j + x_j^T x_j\right]\right)
$$
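A quick numerical check that the direct form and the expanded form (with the exponent $-\frac{1}{2\sigma^2}[\dots]$) agree, for two random vectors and an arbitrarily chosen bandwidth `sigma = 0.7`:

```python
import torch

torch.manual_seed(0)
xi, xj = torch.randn(5), torch.randn(5)
sigma = 0.7  # arbitrary bandwidth for illustration

# Direct form: exp(-||xi - xj||^2 / (2 sigma^2))
k_direct = torch.exp(-torch.sum((xi - xj) ** 2) / (2 * sigma ** 2))

# Expanded form: exp(-(xi.xi - 2 xi.xj + xj.xj) / (2 sigma^2))
k_expanded = torch.exp(-(xi @ xi - 2 * (xi @ xj) + xj @ xj) / (2 * sigma ** 2))

print(k_direct.item(), k_expanded.item())
```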

### Using matrix operations to improve computational efficiency

Stack the samples $x_1, x_2, x_3$ as the rows of a matrix $x$:
$$
x = \begin{pmatrix}
x_{1}\\
x_{2}\\
x_{3}
\end{pmatrix}
$$

$$
x x^T = \begin{pmatrix}
x_{1,1}& x_{1,2} & x_{1,3}\\
x_{2,1}& x_{2,2} & x_{2,3}\\
x_{3,1}& x_{3,2} & x_{3,3}
\end{pmatrix}, \qquad x_{i,j} = x_i^T x_j
$$

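The diagonal entries $x_{i,i} = x_i^T x_i = \lVert x_i \rVert^2$ are the squared norms. Broadcasting them into $3 \times 3$ matrices, once with constant rows and once with constant columns, gives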
$$
x_i^T x_i = rx = 
\begin{pmatrix}
x_{1,1}& x_{1,1} & x_{1,1}\\
x_{2,2}& x_{2,2} & x_{2,2}\\
x_{3,3}& x_{3,3} & x_{3,3}
\end{pmatrix}
$$

$$
x_j^T x_j = rx^T = 
\begin{pmatrix}
x_{1,1}& x_{2,2} & x_{3,3}\\
x_{1,1}& x_{2,2} & x_{3,3}\\
x_{1,1}& x_{2,2} & x_{3,3}
\end{pmatrix}
$$
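In PyTorch both broadcast matrices come from the diagonal of the Gram matrix; the position of the inserted singleton dimension decides which one you get. A small sketch on a hand-picked $3 \times 2$ input (the numbers are arbitrary):

```python
import torch

x = torch.tensor([[1.0, 0.0],
                  [0.0, 2.0],
                  [3.0, 1.0]])              # 3 samples, 2 features
xx = x @ x.T                                # Gram matrix; its diagonal holds 1, 4, 10
sq_norms = xx.diag()

rx = sq_norms.unsqueeze(1).expand_as(xx)    # (3, 1) -> (3, 3): row i is constant ||x_i||^2
rx_t = sq_norms.unsqueeze(0).expand_as(xx)  # (1, 3) -> (3, 3): column j is constant ||x_j||^2

print(rx)
print(rx_t)
```

`expand_as` returns a view without copying data, so the broadcast matrices cost no extra memory until they are combined with another tensor.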

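Adding the two broadcast matrices and subtracting twice the Gram matrix gives all pairwise squared distances at once; the kernel matrix then follows elementwise:

$$
\lVert x_i - x_j \rVert^2 = x_i^T x_i + x_j^T x_j - 2 x_i^T x_j, \qquad k(x_i, x_j) = \exp\left(-\frac{\lVert x_i - x_j \rVert^2}{2\sigma^2}\right)
$$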
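```python
x = torch.randn(3, 4)                       # 3 samples (rows), 4 features
xx = torch.mm(x, x.T)                       # Gram matrix: xx[i, j] = x_i^T x_j
rx = xx.diag().unsqueeze(1).expand_as(xx)   # rx[i, j] = ||x_i||^2 (each row constant)
dxx = rx + rx.T - 2 * xx                    # dxx[i, j] = ||x_i - x_j||^2
sigma = 1.0                                 # kernel bandwidth (arbitrary choice here)
kxx = torch.exp(-dxx / (2 * sigma ** 2))    # Gaussian kernel matrix k(x_i, x_j)
```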
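Putting the pieces together, here is a minimal vectorized sketch of the unbiased $\mathrm{MMD}^2_u$ estimator with a Gaussian kernel. Assumptions: a single fixed bandwidth `sigma`, equal sample sizes, and the illustrative helper names `gaussian_gram` and `mmd2_unbiased`. For the cross term $k(x_i, y_j)$ the Gram matrix $x y^T$ is not square, so its diagonal cannot supply the norms; the sketch uses row-wise squared norms of each input instead. The $j \neq i$ restriction is implemented by subtracting the trace.

```python
import torch

def gaussian_gram(a: torch.Tensor, b: torch.Tensor, sigma: float = 1.0) -> torch.Tensor:
    """K[i, j] = exp(-||a_i - b_j||^2 / (2 sigma^2)) via the Gram-matrix identity."""
    ab = a @ b.T                                  # inner products <a_i, b_j>, shape (n, m)
    ra = (a * a).sum(dim=1).unsqueeze(1)          # ||a_i||^2 as a column, shape (n, 1)
    rb = (b * b).sum(dim=1).unsqueeze(0)          # ||b_j||^2 as a row, shape (1, m)
    dist2 = (ra + rb - 2 * ab).clamp_min(0.0)     # clamp tiny negatives caused by round-off
    return torch.exp(-dist2 / (2 * sigma ** 2))

def mmd2_unbiased(x: torch.Tensor, y: torch.Tensor, sigma: float = 1.0) -> torch.Tensor:
    """Unbiased MMD^2 for two samples with the same number of rows m."""
    m = x.size(0)
    kxx = gaussian_gram(x, x, sigma)
    kyy = gaussian_gram(y, y, sigma)
    kxy = gaussian_gram(x, y, sigma)
    # Drop the i == j terms: sum of all entries minus the trace.
    term_xx = (kxx.sum() - kxx.diagonal().sum()) / (m * (m - 1))
    term_yy = (kyy.sum() - kyy.diagonal().sum()) / (m * (m - 1))
    term_xy = kxy.sum() / (m * m)
    return term_xx + term_yy - 2 * term_xy

torch.manual_seed(0)
x = torch.randn(100, 2)
print(mmd2_unbiased(x, torch.randn(100, 2)).item(),      # same distribution: close to 0
      mmd2_unbiased(x, torch.randn(100, 2) + 3).item())  # shifted mean: clearly positive
```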

## NT-Xent loss
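$$
\mathcal{L}_{i,j} = -\log \frac{\exp(\mathrm{sim}(z_i,z_j)/\tau)}{\sum_{k=1}^{2N} \mathbb{1}_{[k \neq i]} \exp(\mathrm{sim}(z_i,z_k)/\tau)}
$$

where $(z_i, z_j)$ is a positive pair among the $2N$ embeddings of a batch of $N$ samples (two views each), $\mathrm{sim}(u, v) = u^T v / (\lVert u \rVert \lVert v \rVert)$ is the cosine similarity, and $\tau$ is a temperature.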
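A direct, loop-based sketch of $\mathcal{L}_{i,j}$ for a single index pair, assuming `z` stacks all $2N$ embeddings as rows and `sim` is cosine similarity. The function name `nt_xent_pair` and the temperature `tau = 0.5` are illustrative choices.

```python
import math
import torch
import torch.nn.functional as F

def nt_xent_pair(z: torch.Tensor, i: int, j: int, tau: float = 0.5) -> float:
    """-log( exp(sim(z_i, z_j)/tau) / sum_{k != i} exp(sim(z_i, z_k)/tau) )."""
    def sim(a: torch.Tensor, b: torch.Tensor) -> float:
        return F.cosine_similarity(a, b, dim=0).item()
    numerator = math.exp(sim(z[i], z[j]) / tau)
    # The denominator runs over every k except i itself; the positive k == j is included.
    denominator = sum(math.exp(sim(z[i], z[k]) / tau) for k in range(z.size(0)) if k != i)
    return -math.log(numerator / denominator)

torch.manual_seed(0)
z = torch.randn(6, 4)           # 2N = 6 embeddings of dimension 4
print(nt_xent_pair(z, 0, 3))    # loss for anchor z_0 with positive z_3
```

Since the positive term also appears in the denominator, the ratio is below 1 and the loss is always positive; it shrinks as $z_j$ moves toward $z_i$.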

### Cosine similarity from the Gram matrix $x x^T$

$$
x = \begin{pmatrix}
x_{1}\\
x_{2}\\
x_{3}
\end{pmatrix}
$$

$$
x \cdot x^T = \begin{pmatrix}
x_{1,1}& x_{1,2} & x_{1,3}\\
x_{2,1}& x_{2,2} & x_{2,3}\\
x_{3,1}& x_{3,2} & x_{3,3}
\end{pmatrix}
$$

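Its diagonal holds the squared norms $x_{i,i} = \lVert x_i \rVert^2$, so dividing each entry by $\sqrt{x_{i,i}\,x_{j,j}}$ turns the Gram matrix into cosine similarities, $\mathrm{sim}(x_i, x_j) = x_{i,j} / \sqrt{x_{i,i}\,x_{j,j}}$. Extract the diagonal as a row vector:

```python
xx = torch.mm(x, x.T)             # Gram matrix: xx[i, j] = <x_i, x_j>
diag = xx.diag().unsqueeze(0)     # (1, n) row of squared norms ||x_j||^2
```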
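A sketch (random $4 \times 3$ input, an arbitrary choice) that divides the Gram matrix by the outer product of the row norms and checks the result against L2-normalizing the rows first with `F.normalize`:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
x = torch.randn(4, 3)

xx = x @ x.T                                     # Gram matrix
diag = xx.diag().unsqueeze(0)                    # (1, n) squared norms
cos_from_gram = xx / torch.sqrt(diag.T * diag)   # entry (i, j) divided by ||x_i|| ||x_j||

xn = F.normalize(x, dim=1)                       # unit-length rows
cos_from_normalized = xn @ xn.T

print(torch.allclose(cos_from_gram, cos_from_normalized, atol=1e-6))
```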

### For $\mathrm{sim}(z_i, z_k)$ with $i \neq k$, we need the off-diagonal components

Stack the first views $x_1, \dots, x_N$ on top of the second views $y_1, \dots, y_N$ (here $N = 3$):
$$
xy = \begin{pmatrix}
x_{1}\\
x_{2}\\
x_{3}\\
y_{1}\\
y_{2}\\
y_{3}
\end{pmatrix}
$$
$$
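xy \cdot xy^T = \begin{pmatrix}
x_{1}x_{1} & x_{1}x_{2} & \cdots & x_{1}y_{3}\\
x_{2}x_{1} & x_{2}x_{2} & \cdots & x_{2}y_{3}\\
\vdots & \vdots & \ddots & \vdots \\
y_{3}x_{1} & y_{3}x_{2} & \cdots & y_{3}y_{3}
\end{pmatrix}
$$

This is a $2N \times 2N$ matrix (here $6 \times 6$) whose entry $(i, k)$ is the inner product of the $i$-th and $k$-th stacked vectors; after L2-normalizing the rows, these are the cosine similarities.
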
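### Two ways to compute the similarity matrix

1. Broadcast `nn.CosineSimilarity` over every pair of rows:
```python
xy = torch.cat([x, y])   # (2N, d): first views stacked on top of second views
sim = torch.nn.CosineSimilarity(dim=2)(xy.unsqueeze(1), xy.unsqueeze(0))  # (2N, 2N)
```
2. L2-normalize each row, then take the Gram matrix:
```python
xy = torch.cat([x, y])          # (2N, d)
xy = F.normalize(xy, dim=1)     # unit-length rows
sim = torch.mm(xy, xy.T)        # (2N, 2N) cosine similarities
```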
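A sketch checking that the two methods produce the same $2N \times 2N$ matrix, with `batch_size = 4` and 8-dimensional random views chosen arbitrarily:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
batch_size = 4                      # N
x = torch.randn(batch_size, 8)      # first view of each sample
y = torch.randn(batch_size, 8)      # second view of each sample
xy = torch.cat([x, y])              # (2N, d)

# Method 1: broadcast (2N, 1, d) against (1, 2N, d) and reduce over the feature dim
sim1 = torch.nn.CosineSimilarity(dim=2)(xy.unsqueeze(1), xy.unsqueeze(0))

# Method 2: normalize rows, then a single matrix multiply
xy_n = F.normalize(xy, dim=1)
sim2 = xy_n @ xy_n.T

print(sim1.shape, torch.allclose(sim1, sim2, atol=1e-5))
```

The broadcast in method 1 can allocate a $(2N, 2N, d)$ intermediate, while method 2 needs only one matrix multiply, so method 2 is usually the cheaper choice for large batches.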

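### Getting the positive pairs (the $\pm N$ off-diagonals)
```python
sim_i_j = torch.diag(sim, batch_size)    # sim(x_k, y_k): the diagonal batch_size (= N) above the main one
sim_j_i = torch.diag(sim, -batch_size)   # sim(y_k, x_k): the diagonal batch_size below the main one
```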
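With rows ordered as $x_1, \dots, x_N, y_1, \dots, y_N$, the positive pair $(x_k, y_k)$ sits at row $k$, column $k + N$, so the $+N$ and $-N$ diagonals hold exactly the positive-pair similarities. A small sketch verifying this on random data (`batch_size = 3` is an arbitrary choice):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
batch_size = 3
x = torch.randn(batch_size, 5)
y = torch.randn(batch_size, 5)
xy = F.normalize(torch.cat([x, y]), dim=1)
sim = xy @ xy.T                                  # (2N, 2N) cosine similarities

sim_i_j = torch.diag(sim, batch_size)            # sim(x_k, y_k), k = 0..N-1
sim_j_i = torch.diag(sim, -batch_size)           # sim(y_k, x_k), equal by symmetry
expected = F.cosine_similarity(x, y, dim=1)      # row-wise cosine of each (x_k, y_k)

print(torch.allclose(sim_i_j, expected, atol=1e-5), torch.allclose(sim_i_j, sim_j_i))
```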

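A mask can remove the $i = k$ terms (each embedding's similarity with itself) from the denominator; the positive pair stays in:

```python
n = 2 * batch_size
mask = ~torch.eye(n, dtype=torch.bool)     # False on the main diagonal (i == k)
i_not_k = sim[mask].view(n, n - 1)         # row i keeps sim(z_i, z_k) for every k != i
```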
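Combining the similarity matrix, the positive-pair diagonals, and the $i \neq k$ mask gives a vectorized loss averaged over all $2N$ anchors, i.e. over both orderings of every positive pair, as in SimCLR. A minimal sketch; the function name `nt_xent` and `tau = 0.5` are assumptions. It uses `torch.logsumexp` because $-\log(e^{a} / \sum_k e^{b_k}) = \operatorname{logsumexp}_k(b_k) - a$, which avoids overflow for small $\tau$.

```python
import torch
import torch.nn.functional as F

def nt_xent(x: torch.Tensor, y: torch.Tensor, tau: float = 0.5) -> torch.Tensor:
    """Mean NT-Xent loss over all 2N anchors; x[k] and y[k] are two views of sample k."""
    n = x.size(0)
    z = F.normalize(torch.cat([x, y]), dim=1)          # (2N, d) unit-length rows
    sim = z @ z.T / tau                                 # temperature-scaled cosine similarities
    # Positive for anchor i is i + N (first half) or i - N (second half).
    positives = torch.cat([torch.diag(sim, n), torch.diag(sim, -n)])  # (2N,)
    # Denominator: every k != i, positive included, so only the main diagonal is dropped.
    not_self = ~torch.eye(2 * n, dtype=torch.bool)
    denom_logits = sim[not_self].view(2 * n, 2 * n - 1)
    return (torch.logsumexp(denom_logits, dim=1) - positives).mean()

torch.manual_seed(0)
x = torch.randn(4, 8)
print(nt_xent(x, x + 0.01 * torch.randn(4, 8)).item(),  # nearly identical views: low loss
      nt_xent(x, torch.randn(4, 8)).item())              # unrelated views: higher loss
```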

## Pairwise Distance

Given two matrices, $x$ of size $n \times d$ and $y$ of size $m \times d$ (here $n = m = 3$):

$$
x = \begin{pmatrix}
x_{0}\\
x_{1}\\
x_{2}\\
\end{pmatrix}
y = \begin{pmatrix}
y_{0}\\
y_{1}\\
y_{2}\\
\end{pmatrix}
$$
where each $x_{i}$ and $y_{j}$ is a row feature vector of length $d$.

The pairwise distance matrix has entries $\mathrm{dist}[i,j] = \lVert x_i - y_j \rVert^2$, the squared Euclidean distance between `x[i, :]` and `y[j, :]`. To vectorize the calculation, expand both matrices into 3D tensors of shape $n \times m \times d$, so that block $i$ pairs $x_i$ (repeated $m$ times) with every row of $y$:
$$
\begin{bmatrix} \begin{pmatrix}
x_{0}\\
x_{0}\\
x_{0}\\
\end{pmatrix} - \begin{pmatrix}
y_{0}\\
y_{1}\\
y_{2}\\
\end{pmatrix}
\end{bmatrix}_{0}
$$

$$
\begin{bmatrix} \begin{pmatrix}
x_{1}\\
x_{1}\\
x_{1}\\
\end{pmatrix} - \begin{pmatrix}
y_{0}\\
y_{1}\\
y_{2}\\
\end{pmatrix}
\end{bmatrix}_{1}
$$

$$
\begin{bmatrix} \begin{pmatrix}
x_{2}\\
x_{2}\\
x_{2}\\
\end{pmatrix} - \begin{pmatrix}
y_{0}\\
y_{1}\\
y_{2}\\
\end{pmatrix}
\end{bmatrix}_{2}
$$

```python
x = torch.randn(size=(3, 3))
y = torch.randn(size=(3, 3))
n = x.size(0)   # number of rows in x
m = y.size(0)   # number of rows in y
d = x.size(1)   # feature dimension (must equal y.size(1))
```
```python
x = x.unsqueeze(1).expand(n, m, d)   # (n, d) -> (n, 1, d) -> (n, m, d): block i repeats x_i m times
```
$$
\begin{bmatrix} \begin{pmatrix}
x_{0}\\
x_{0}\\
x_{0}\\
\end{pmatrix}
\end{bmatrix}_{0}
$$

$$
\begin{bmatrix} \begin{pmatrix}
x_{1}\\
x_{1}\\
x_{1}\\
\end{pmatrix}
\end{bmatrix}_{1}
$$
```python
y = y.unsqueeze(0).expand(n, m, d)   # (m, d) -> (1, m, d) -> (n, m, d): every block is y_0, ..., y_{m-1}
```
$$
\begin{bmatrix} \begin{pmatrix}
y_{0}\\
y_{1}\\
y_{2}\\
\end{pmatrix}
\end{bmatrix}_{0}
$$

$$
\begin{bmatrix} \begin{pmatrix}
y_{0}\\
y_{1}\\
y_{2}\\
\end{pmatrix}
\end{bmatrix}_{1}
$$
```python
pairwise_distance = torch.pow(x - y, 2).sum(2)   # (n, m): pairwise_distance[i, j] = ||x_i - y_j||^2
```
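A sketch that checks the expand-based computation against `torch.cdist`, which returns Euclidean (not squared) distances. It uses unequal sizes $n = 4$ and $m = 5$ (an arbitrary choice) so that a shape mismatch cannot hide behind $n = m$, and it also shows that plain broadcasting with `x[:, None, :] - y[None, :, :]` builds the same $(n, m, d)$ difference tensor without explicit `expand` calls:

```python
import torch

torch.manual_seed(0)
x = torch.randn(4, 3)    # n = 4 rows
y = torch.randn(5, 3)    # m = 5 rows
n, m, d = x.size(0), y.size(0), x.size(1)

# Expand-based: block i of xe repeats x_i m times; every block of ye is y_0..y_{m-1}
xe = x.unsqueeze(1).expand(n, m, d)
ye = y.unsqueeze(0).expand(n, m, d)
dist_expand = torch.pow(xe - ye, 2).sum(2)                       # (n, m)

# Broadcasting performs the same expansion implicitly
dist_broadcast = ((x[:, None, :] - y[None, :, :]) ** 2).sum(-1)  # (n, m)

# Reference: torch.cdist returns Euclidean distances, so square them
dist_ref = torch.cdist(x, y) ** 2

print(dist_expand.shape)
```

Both vectorized versions materialize an $n \times m \times d$ tensor; for large inputs, the Gram-matrix identity $\lVert x_i - y_j \rVert^2 = x_i^T x_i + y_j^T y_j - 2 x_i^T y_j$ avoids that intermediate.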