\documentclass{article}
\usepackage{amsmath, amssymb, amsfonts}
\usepackage{booktabs}
\usepackage{array}
\usepackage{geometry}
\geometry{margin=1in}

\begin{document}

\section*{Cross-Attention Mechanism: Detailed Explanation and Example}

\subsection*{1. General Formulation}

Suppose we have two graphs:
\begin{itemize}
    \item Graph A with node features $X_A \in \mathbb{R}^{N_A \times d}$.
    \item Graph B with node features $X_B \in \mathbb{R}^{N_B \times d}$.
\end{itemize}

For the cross-attention from Graph A to Graph B, we compute:
\begin{enumerate}
    \item \textbf{Queries, Keys, and Values:}
    \[
    Q_A = X_A W_{Q_A}, \quad K_B = X_B W_{K_A}, \quad V_B = X_B W_{V_A},
    \]
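    where $W_{Q_A}$, $W_{K_A}$, and $W_{V_A}$ are learned weight matrices in $\mathbb{R}^{d \times d}$. The subscript $A$ on the weights marks them as parameters of the A-to-B cross-attention; the keys and values are nevertheless computed from Graph B's features $X_B$.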
    
    \item \textbf{Attention Scores:} Compute the scaled dot-product
    \[
    S = \frac{Q_A K_B^T}{\sqrt{d}} \quad \in \mathbb{R}^{N_A \times N_B}.
    \]
    
    \item \textbf{Softmax:} For each row $i$ (i.e., each node of Graph A), normalize the scores over the $N_B$ nodes of Graph B:
    \[
    A_{ij} = \frac{\exp(S_{ij})}{\sum_{k=1}^{N_B} \exp(S_{ik})}.
    \]
    
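    \item \textbf{Updated Features:}
    \[
    Z_A = A V_B, \quad \text{or, row by row, } Z_A[i] = \sum_{j=1}^{N_B} A_{ij} \, V_B[j],
    \]
    where $Z_A[i]$ and $V_B[j]$ denote the $i$-th row of $Z_A$ and the $j$-th row of $V_B$, respectively.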
\end{enumerate}

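The reverse direction (Graph B $\rightarrow$ Graph A) is computed analogously with the roles of the graphs swapped: queries are computed from $X_B$, keys and values from $X_A$, each with its own set of learned weights, yielding updated features $Z_B \in \mathbb{R}^{N_B \times d}$.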

\subsection*{2. A Simple Numeric Example}

Let:
\[
X_A = \begin{bmatrix} 1 & 2 \\ 3 & 4 \\ 5 & 6 \end{bmatrix}, \quad
X_B = \begin{bmatrix} 1 & 0 \\ 0 & 1 \\ 1 & 1 \end{bmatrix}.
\]
Here $d = 2$; for simplicity, let $W_{Q_A}=W_{K_A}=W_{V_A}=I$ (the $2 \times 2$ identity matrix). Then,
\[
Q_A = X_A,\quad K_B = X_B,\quad V_B = X_B.
\]

\paragraph{Step 1: Compute the Scaled Dot-Product.}

The scaling factor is $\sqrt{2} \approx 1.414$. Compute $S = Q_A K_B^T / \sqrt{2}$.

For node 1 in Graph A ($[1,2]$):
\[
\begin{aligned}
S_{11} &= \frac{1\cdot 1 + 2\cdot 0}{\sqrt{2}} \approx 0.7071,\\[1mm]
S_{12} &= \frac{1\cdot 0 + 2\cdot 1}{\sqrt{2}} \approx 1.4142,\\[1mm]
S_{13} &= \frac{1\cdot 1 + 2\cdot 1}{\sqrt{2}} \approx 2.1213.
\end{aligned}
\]
Thus, the first row of $S$ is approximately $[0.7071,\; 1.4142,\; 2.1213]$.

For node 2 in Graph A ($[3,4]$):
\[
\begin{aligned}
S_{21} &= \frac{3\cdot 1 + 4\cdot 0}{\sqrt{2}} \approx 2.1213,\\[1mm]
S_{22} &= \frac{3\cdot 0 + 4\cdot 1}{\sqrt{2}} \approx 2.8284,\\[1mm]
S_{23} &= \frac{3\cdot 1 + 4\cdot 1}{\sqrt{2}} \approx 4.9497.
\end{aligned}
\]
For node 3 in Graph A ($[5,6]$):
\[
\begin{aligned}
S_{31} &= \frac{5\cdot 1 + 6\cdot 0}{\sqrt{2}} \approx 3.5355,\\[1mm]
S_{32} &= \frac{5\cdot 0 + 6\cdot 1}{\sqrt{2}} \approx 4.2426,\\[1mm]
S_{33} &= \frac{5\cdot 1 + 6\cdot 1}{\sqrt{2}} \approx 7.7782.
\end{aligned}
\]

\paragraph{Step 2: Apply Softmax to Each Row.}

For the first row:
\[
\exp(0.7071)\approx 2.028,\quad \exp(1.4142)\approx 4.113,\quad \exp(2.1213)\approx 8.342.
\]
The sum is $\approx 14.483$, so the softmax weights for node 1 are approximately
\[
A_1 \approx [0.140,\; 0.284,\; 0.576].
\]

Computing the softmax for rows 2 and 3 in the same way gives (approximately)
\[
A_2 \approx [0.050,\; 0.102,\; 0.848],\quad
A_3 \approx [0.0138,\; 0.0279,\; 0.9583].
\]
Each row sums to 1 (up to rounding), as required of softmax weights.

\paragraph{Step 3: Compute the Updated Features.}

Recall $V_B = X_B = \begin{bmatrix} 1 & 0 \\ 0 & 1 \\ 1 & 1 \end{bmatrix}$.

For node 1 in Graph A:
\[
Z_A[1] \approx 0.140\,[1,0] + 0.284\,[0,1] + 0.576\,[1,1] = [0.140+0+0.576,\; 0+0.284+0.576] = [0.716,\; 0.860].
\]

For node 2:
\[
Z_A[2] \approx 0.050\,[1,0] + 0.102\,[0,1] + 0.848\,[1,1] = [0.050+0+0.848,\; 0+0.102+0.848] = [0.898,\; 0.950].
\]

For node 3:
\[
Z_A[3] \approx 0.0138\,[1,0] + 0.0279\,[0,1] + 0.9583\,[1,1] = [0.0138+0+0.9583,\; 0+0.0279+0.9583] = [0.9721,\; 0.9862].
\]

Thus, the updated feature matrix for Graph A is approximately:
\[
Z_A \approx \begin{bmatrix}
0.716 & 0.860 \\
0.898 & 0.950 \\
0.972 & 0.986 
\end{bmatrix}.
\]
Node 3 of Graph B, with feature $[1,1]$, has the largest dot product with every query. The queries of later nodes in Graph A have larger norms, which makes their softmax sharper, so nearly all of their attention falls on that node and their updated features approach $[1,1]$.

\subsection*{3. Summary}

The cross-attention mechanism allows Graph A to update its node features by computing attention over the nodes of Graph B. The key formulas are:
\[
\begin{aligned}
Q_A &= X_A W_{Q_A},\\[1mm]
K_B &= X_B W_{K_A},\\[1mm]
V_B &= X_B W_{V_A},\\[1mm]
S &= \frac{Q_A K_B^T}{\sqrt{d}},\\[1mm]
A_{ij} &= \frac{\exp(S_{ij})}{\sum_{k=1}^{N_B} \exp(S_{ik})},\\[1mm]
Z_A &= A\, V_B.
\end{aligned}
\]

A similar process is applied for Graph B $\rightarrow$ Graph A.

\end{document}
