# Matrix Operations

This notebook demonstrates how visu-hlo visualizes matrix operations and linear algebra computations in JAX.

## Setup

First, let's import the necessary libraries:

In [None]:
import os

os.environ['JAX_PLATFORMS'] = 'cpu'

import jax
import jax.numpy as jnp
from visu_hlo import show
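JAX reads `JAX_PLATFORMS` when it is imported, so the variable has to be set before `import jax` for the CPU pin to take effect. The quick check below uses only standard JAX calls (`jax.default_backend`, `jax.devices`) to confirm which backend was selected. It is a sketch that does not touch visu-hlo.

```python
import os

# Must run before `import jax`; setting it afterwards has no effect on backend selection.
os.environ['JAX_PLATFORMS'] = 'cpu'

import jax

# Report the backend JAX picked and the devices it can see.
backend = jax.default_backend()
devices = jax.devices()
print('Backend:', backend)
print('Devices:', devices)
```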

## Basic Matrix Multiplication

Let's start with a simple matrix multiplication operation:

In [None]:
@jax.jit
def matrix_multiply(A, B):
    """Simple matrix multiplication using jnp.dot."""
    return jnp.dot(A, B)


# Create sample matrices
A = jnp.ones((3, 4))
B = jnp.ones((4, 2))

print('Matrix shapes: A', A.shape, '× B', B.shape, '= result', matrix_multiply(A, B).shape)
print('\nVisualization of matrix multiplication:')
show(matrix_multiply, A, B)
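You can also inspect the same program as text. JAX exposes it at two levels: the jaxpr (JAX's own intermediate representation) and the StableHLO module that is handed to XLA before XLA's own optimizations. The sketch below uses the standard `jax.make_jaxpr` and `jax.jit(...).lower(...).as_text()` APIs. The exact printed text varies between JAX versions, so the expectation that `jnp.dot` appears as a single `dot_general` is something to check against your own output rather than a fixed format.

```python
import jax
import jax.numpy as jnp


def matrix_multiply(A, B):
    return jnp.dot(A, B)


A = jnp.ones((3, 4))
B = jnp.ones((4, 2))

# JAX-level IR: a 2-D jnp.dot is expressed with the dot_general primitive.
jaxpr = jax.make_jaxpr(matrix_multiply)(A, B)
print(jaxpr)

# StableHLO text produced by lowering, before XLA compiles and optimizes it.
hlo_text = jax.jit(matrix_multiply).lower(A, B).as_text()
print(hlo_text)

# Each output entry sums 4 products of 1.0 * 1.0, so every value is 4.0.
result = matrix_multiply(A, B)
print(result.shape)
```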

## Matrix-Vector Multiplication

Matrix-vector products are the core of a dense (fully connected) layer in machine learning:

In [None]:
@jax.jit
def matrix_vector_multiply(W, x):
    """Matrix-vector multiplication: y = Wx"""
    return jnp.dot(W, x)


# Weight matrix and input vector
W = jnp.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
x = jnp.array([0.5, 1.0, 1.5])

print('Weight matrix shape:', W.shape)
print('Input vector shape:', x.shape)
print('\nVisualization:')
show(matrix_vector_multiply, W, x)
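In practice the same `W` is applied to many inputs at once. One way to express that is to map the matrix-vector product over a batch with `jax.vmap`. The batch `X` below is a hypothetical example, and the claim that the batched program still contains a `dot_general` (possibly followed by a transpose) instead of a loop is what we expect from `vmap`'s batching rules. Verify it in your own output.

```python
import jax
import jax.numpy as jnp


def matrix_vector_multiply(W, x):
    return jnp.dot(W, x)


W = jnp.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
x = jnp.array([0.5, 1.0, 1.5])

# Single input: (2, 3) @ (3,) -> (2,); expected values 7.0 and 16.0.
y = matrix_vector_multiply(W, x)
print(y)

# Hypothetical batch of 4 inputs stacked as rows: shape (4, 3).
X = jnp.stack([x, 2 * x, 3 * x, 4 * x])

# Map over axis 0 of X while sharing W (in_axes=None means "not batched").
batched = jax.vmap(matrix_vector_multiply, in_axes=(None, 0))
Y = batched(W, X)
print(Y.shape)

# Inspect the batched program: no Python-level loop over the 4 inputs.
print(jax.make_jaxpr(batched)(W, X))
```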

## Linear System Solving

Solving systems of linear equations Ax = b:

In [None]:
@jax.jit
def solve_linear_system(A, b):
    """Solve Ax = b using JAX's linear algebra solver."""
    return jnp.linalg.solve(A, b)


# Create an invertible matrix and target vector
A = jnp.array([[3.0, 1.0, 2.0], [1.0, 4.0, 1.0], [2.0, 1.0, 3.0]])
b = jnp.array([1.0, 2.0, 3.0])

print('System matrix A:')
print(A)
print('\nTarget vector b:', b)
print('\nVisualization of linear system solver:')
show(solve_linear_system, A, b)
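A residual check (`A @ x` should reproduce `b`) is the simplest way to confirm the solver's output. When the same `A` is reused with several right-hand sides, a common pattern is to factor it once and reuse the factors. This sketch uses `jax.scipy.linalg.lu_factor` and `lu_solve`. `jnp.linalg.solve` is expected to do a comparable LU-based solve internally, but that is an implementation detail this notebook does not depend on. The second right-hand side `b2` is a hypothetical example.

```python
import jax.numpy as jnp
from jax.scipy.linalg import lu_factor, lu_solve

A = jnp.array([[3.0, 1.0, 2.0], [1.0, 4.0, 1.0], [2.0, 1.0, 3.0]])
b = jnp.array([1.0, 2.0, 3.0])

x = jnp.linalg.solve(A, b)

# Residual check: should be close to zero, up to float32 rounding.
print('residual norm:', jnp.linalg.norm(A @ x - b))

# Factor once (row-pivoted LU), then reuse the factors for each right-hand side.
lu, piv = lu_factor(A)
x_lu = lu_solve((lu, piv), b)

# Hypothetical second right-hand side, solved with the same factors.
b2 = jnp.array([0.0, 1.0, 0.0])
x2 = lu_solve((lu, piv), b2)
print(x_lu, x2)
```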

## Eigenvalue Decomposition

Computing eigenvalues and eigenvectors of symmetric matrices:

In [None]:
@jax.jit
def compute_eigenvalues(matrix):
    """Compute eigenvalues of a symmetric matrix."""
    return jnp.linalg.eigvals(matrix)


# Symmetric matrix
sym_matrix = jnp.array([[4.0, 1.0, 2.0], [1.0, 3.0, 1.0], [2.0, 1.0, 5.0]])

print('Symmetric matrix:')
print(sym_matrix)
print('\nVisualization of eigenvalue computation:')
show(compute_eigenvalues, sym_matrix)
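For a symmetric matrix, JAX offers two routes to the eigenvalues, and they lower to different routines. The general `jnp.linalg.eigvals` returns complex values with no ordering guarantee, and has historically been supported only on the CPU backend. The symmetric `jnp.linalg.eigvalsh` returns real values in ascending order. The sketch below checks that the two agree and that the eigenvalues sum to the trace (4 + 3 + 5 = 12).

```python
import os

# The general eigensolver has historically been CPU-only in JAX.
os.environ['JAX_PLATFORMS'] = 'cpu'

import jax.numpy as jnp

sym_matrix = jnp.array([[4.0, 1.0, 2.0], [1.0, 3.0, 1.0], [2.0, 1.0, 5.0]])

# General solver: complex output, order not guaranteed.
general = jnp.linalg.eigvals(sym_matrix)

# Symmetric solver: real output, ascending order.
symmetric = jnp.linalg.eigvalsh(sym_matrix)

print(general.dtype, symmetric.dtype)
print(jnp.sort(general.real), symmetric)

# The eigenvalues of any square matrix sum to its trace.
print(symmetric.sum(), jnp.trace(sym_matrix))
```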

In [None]:
@jax.jit
def compute_eigenvectors(matrix):
    """Compute both eigenvalues and eigenvectors."""
    eigenvals, eigenvecs = jnp.linalg.eigh(matrix)
    return eigenvals, eigenvecs


print('Visualization of full eigendecomposition:')
show(compute_eigenvectors, sym_matrix)
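The output of `jnp.linalg.eigh` can be checked directly. Each column `v_i` of the eigenvector matrix satisfies `A v_i = λ_i v_i`, the columns are orthonormal, and `V diag(λ) Vᵀ` rebuilds the original matrix. The sketch below verifies all three, using broadcasting (`eigenvecs * eigenvals`) to scale column `i` by `λ_i` without forming `diag(λ)`.

```python
import jax.numpy as jnp

sym_matrix = jnp.array([[4.0, 1.0, 2.0], [1.0, 3.0, 1.0], [2.0, 1.0, 5.0]])
eigenvals, eigenvecs = jnp.linalg.eigh(sym_matrix)

# A @ V versus V scaled column-wise by the eigenvalues: checks A v_i = lambda_i v_i for all i.
lhs = sym_matrix @ eigenvecs
rhs = eigenvecs * eigenvals

# Eigenvectors of a symmetric matrix are orthonormal: V^T V = I.
gram = eigenvecs.T @ eigenvecs

# Reassemble A = V diag(lambda) V^T.
reconstructed = (eigenvecs * eigenvals) @ eigenvecs.T
print(jnp.abs(reconstructed - sym_matrix).max())
```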

## Matrix Decompositions

### QR Decomposition

In [None]:
@jax.jit
def qr_decomposition(matrix):
    """QR decomposition of a matrix."""
    Q, R = jnp.linalg.qr(matrix)
    return Q, R


# Rectangular matrix for QR decomposition
rect_matrix = jnp.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])

print('Input matrix for QR decomposition:')
print(rect_matrix)
print('\nVisualization:')
show(qr_decomposition, rect_matrix)
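With the default `mode='reduced'`, `jnp.linalg.qr` on a 3×2 input returns a 3×2 `Q` with orthonormal columns and a 2×2 upper-triangular `R`. The sketch below checks those properties and the reconstruction `Q @ R`, and shows `mode='complete'` for comparison, which pads `Q` to a full 3×3 orthogonal matrix.

```python
import jax.numpy as jnp

rect_matrix = jnp.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])

# Reduced QR: Q is (3, 2), R is (2, 2).
Q, R = jnp.linalg.qr(rect_matrix)
print(Q.shape, R.shape)

# Orthonormal columns: Q^T Q = I_2 (Q Q^T is not I_3, since Q is not square).
print(Q.T @ Q)

# R has zeros below the diagonal, and Q @ R reproduces the input.
print(R)
print(Q @ R)

# Complete QR: Q is a full (3, 3) orthogonal matrix, R is (3, 2).
Qc, Rc = jnp.linalg.qr(rect_matrix, mode='complete')
print(Qc.shape, Rc.shape)
```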

### Singular Value Decomposition (SVD)

In [None]:
@jax.jit
def svd_decomposition(matrix):
    """Singular Value Decomposition."""
    U, s, Vt = jnp.linalg.svd(matrix, full_matrices=False)
    return U, s, Vt


print('Visualization of SVD:')
show(svd_decomposition, rect_matrix)
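With `full_matrices=False`, the SVD of the 3×2 matrix returns `U` of shape (3, 2), two singular values in descending order, and `Vt` of shape (2, 2). The sketch below reconstructs the input from the factors and checks a standard identity: the squared singular values of `A` equal the eigenvalues of `AᵀA`.

```python
import jax.numpy as jnp

rect_matrix = jnp.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])

U, s, Vt = jnp.linalg.svd(rect_matrix, full_matrices=False)
print(U.shape, s.shape, Vt.shape)

# Singular values are returned in descending order.
print(s)

# A = U diag(s) V^T; scaling U's columns by s avoids building diag(s).
reconstructed = (U * s) @ Vt

# Squared singular values match eigenvalues of A^T A (eigvalsh returns them ascending).
gram_eigs = jnp.linalg.eigvalsh(rect_matrix.T @ rect_matrix)
print(s**2, gram_eigs[::-1])
```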

## Matrix Operations with Broadcasting

JAX follows NumPy's broadcasting rules, so a vector can be combined with every row of a matrix without tiling it by hand:

In [None]:
@jax.jit
def broadcasted_operations(matrix, vector):
    """Matrix operations with broadcasting."""
    # Add vector to each row of matrix
    added = matrix + vector
    # Element-wise multiplication: column j is scaled by vector[j]
    multiplied = matrix * vector
    return added, multiplied


matrix = jnp.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
vector = jnp.array([0.1, 0.2, 0.3])

print('Matrix shape:', matrix.shape)
print('Vector shape:', vector.shape)
print('\nVisualization of broadcasted operations:')
show(broadcasted_operations, matrix, vector)
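The rule behind this is NumPy's: shapes are aligned from the right, and size-1 (or missing) dimensions are stretched. A `(3,)` vector is treated as `(1, 3)` and repeated down the rows. To act on columns instead, the vector needs shape `(2, 1)`. The column vector below is a hypothetical example, and the sketch catches both `TypeError` and `ValueError` for the mismatched case because the exact exception type is a JAX implementation detail.

```python
import jax.numpy as jnp

matrix = jnp.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])  # shape (2, 3)
vector = jnp.array([0.1, 0.2, 0.3])                     # shape (3,)

# Implicit broadcasting equals adding an explicitly tiled (2, 3) copy.
row_added = matrix + vector
explicit = matrix + jnp.broadcast_to(vector[None, :], matrix.shape)

# Hypothetical per-row offsets: reshape to (2, 1) so they stretch across columns.
col_vector = jnp.array([10.0, 20.0])
col_added = matrix + col_vector[:, None]
print(col_added)

# (2, 3) and (2,) do not align from the right, so this raises an error.
failed = False
try:
    matrix + col_vector
except (TypeError, ValueError) as err:
    failed = True
    print('incompatible shapes:', type(err).__name__)
```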

## Advanced Linear Algebra

### Matrix Inverse

In [None]:
@jax.jit
def matrix_inverse(matrix):
    """Compute matrix inverse."""
    return jnp.linalg.inv(matrix)


# Well-conditioned matrix
well_conditioned = jnp.eye(3) + 0.1 * jnp.ones((3, 3))

print('Well-conditioned matrix:')
print(well_conditioned)
print('\nVisualization of matrix inversion:')
show(matrix_inverse, well_conditioned)
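This particular matrix has a closed-form inverse, which makes a convenient check. Writing it as `I + c·11ᵀ` with `c = 0.1` and `n = 3`, the Sherman–Morrison formula gives `inv(I + c·11ᵀ) = I − c / (1 + n·c) · 11ᵀ`. The sketch compares this with `jnp.linalg.inv` and also illustrates the usual advice that, to apply the inverse to a vector, `jnp.linalg.solve` is generally preferable to forming the inverse explicitly. The vector `b` is an arbitrary illustrative choice.

```python
import jax.numpy as jnp

n, c = 3, 0.1
well_conditioned = jnp.eye(n) + c * jnp.ones((n, n))

inv_numeric = jnp.linalg.inv(well_conditioned)

# Sherman-Morrison with u = c*1, v = 1: inv(I + u v^T) = I - u v^T / (1 + v^T u).
inv_closed_form = jnp.eye(n) - (c / (1 + n * c)) * jnp.ones((n, n))
print(jnp.abs(inv_numeric - inv_closed_form).max())

# Sanity check: A @ inv(A) should be the identity.
print(well_conditioned @ inv_numeric)

# Illustrative right-hand side: solving avoids materializing the inverse.
b = jnp.array([1.0, 2.0, 3.0])
x_solve = jnp.linalg.solve(well_conditioned, b)
x_inv = inv_numeric @ b
```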

### Matrix Norms

In [None]:
@jax.jit
def matrix_norms(matrix):
    """Compute various matrix norms."""
    frobenius = jnp.linalg.norm(matrix, 'fro')
    spectral = jnp.linalg.norm(matrix, 2)
    return frobenius, spectral


test_matrix = jnp.array([[1.0, 2.0], [3.0, 4.0]])

print('Test matrix:')
print(test_matrix)
print('\nVisualization of matrix norm computation:')
show(matrix_norms, test_matrix)
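The two norms are computed very differently. The Frobenius norm is the square root of the sum of squared entries, here `sqrt(1 + 4 + 9 + 16) = sqrt(30)`. The spectral (`ord=2`) norm is the largest singular value, so we would expect its graph to involve an SVD-style routine while the Frobenius norm needs only element-wise ops and a reduction. The sketch recomputes both by hand.

```python
import jax.numpy as jnp

test_matrix = jnp.array([[1.0, 2.0], [3.0, 4.0]])

# Frobenius norm from its definition.
fro_manual = jnp.sqrt(jnp.sum(test_matrix**2))

# Spectral norm as the largest singular value (svd returns them in descending order).
spectral_manual = jnp.linalg.svd(test_matrix, compute_uv=False)[0]

fro = jnp.linalg.norm(test_matrix, 'fro')
spectral = jnp.linalg.norm(test_matrix, 2)
print(fro, fro_manual, spectral, spectral_manual)
```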

## Batch Matrix Operations

Working with batches of matrices:

In [None]:
@jax.jit
def batch_matrix_multiply(batch_A, batch_B):
    """Multiply batches of matrices."""
    return jnp.matmul(batch_A, batch_B)


# Batches of two 2×2 matrices each: shape (2, 2, 2)
batch_A = jnp.array([[[1.0, 2.0], [3.0, 4.0]], [[5.0, 6.0], [7.0, 8.0]]])
batch_B = jnp.array([[[0.1, 0.2], [0.3, 0.4]], [[0.5, 0.6], [0.7, 0.8]]])

print('Batch shape:', batch_A.shape)
print('\nVisualization of batch matrix multiplication:')
show(batch_matrix_multiply, batch_A, batch_B)
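`jnp.matmul` treats every axis except the last two as a batch axis. The sketch below checks it against two equivalent spellings, `jnp.einsum` and `jax.vmap` over a plain 2-D product. All three are expected to lower to a `dot_general` with a batch dimension, though whether their graphs look identical depends on the JAX version. It also shows that a single 2×2 matrix broadcasts across the batch.

```python
import jax
import jax.numpy as jnp

batch_A = jnp.array([[[1.0, 2.0], [3.0, 4.0]], [[5.0, 6.0], [7.0, 8.0]]])
batch_B = jnp.array([[[0.1, 0.2], [0.3, 0.4]], [[0.5, 0.6], [0.7, 0.8]]])

# 1) matmul: leading axis is treated as a batch axis.
via_matmul = jnp.matmul(batch_A, batch_B)

# 2) einsum: batch index b is kept, index j is contracted.
via_einsum = jnp.einsum('bij,bjk->bik', batch_A, batch_B)

# 3) vmap: map an ordinary 2-D matmul over the leading axis.
via_vmap = jax.vmap(jnp.matmul)(batch_A, batch_B)

# A single (2, 2) operand is broadcast against every matrix in the batch.
shared = jnp.matmul(batch_A, jnp.eye(2))
print(via_matmul)
```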

## Summary

This notebook demonstrated various matrix operations in JAX and how their computational graphs are visualized with visu-hlo:

- **Basic operations**: Matrix multiplication, matrix-vector products
- **Linear algebra**: System solving, eigendecomposition
- **Matrix decompositions**: QR, SVD
- **Broadcasting**: Combining a vector with every row of a matrix
- **Advanced operations**: Matrix inverse, norms
- **Batch operations**: Working with multiple matrices simultaneously

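Each visualization shows the HLO program that JAX produces for these high-level linear algebra operations. Products and broadcasts appear as individual HLO ops such as `dot_general`, while on the CPU backend decompositions such as eigh, QR, and SVD are typically lowered to custom calls into LAPACK routines rather than spelled out as primitive ops. Seeing which is which gives insight into the computational structure and potential optimization opportunities.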