# Problem 3: Random Matrix Multiplication

## (a)

In [None]:
import math
import os

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

In [None]:
class RandomMatrixMultiplication:
    """
    Approximate the product M = A @ B (m x p) by randomized outer-product sampling.

    Implements Algorithm 2 discussed in lecture: draw r indices i_1, ..., i_r
    i.i.d. from {0, ..., n-1} with P(i_t = k) = p_k, and return

        M = (1 / r) * sum_t A[:, i_t] B[i_t, :] / p_{i_t},

    which is an unbiased estimator of A @ B.
    """

    def __init__(self, A: np.ndarray, B: np.ndarray, r: int, seed: "int | None" = None):
        """
        Inputs:
        - A: (m x n) matrix (ndarray or any array-like, e.g. a DataFrame)
        - B: (n x p) matrix (ndarray or any array-like)
        - r: number of samples (positive integer; Python int or NumPy integer)
        - seed: optional seed. When given, sampling uses a private RandomState
          seeded with it, so NumPy's global RNG is not modified. When None,
          sampling draws from NumPy's global RNG.

        Raises:
        - ValueError: if A or B is not 2-D, if the inner dimensions differ,
          if r is not a positive integer, or if the sampling probabilities
          cannot be normalized (see compute_optimal_probabilities).
        """
        # Positional indexing (A[:, k]) below requires ndarrays; this also
        # lets callers pass DataFrames or nested lists directly.
        A = np.asarray(A)
        B = np.asarray(B)
        if A.ndim != 2 or B.ndim != 2:
            raise ValueError(f"A and B must be 2-D matrices: A.shape = {A.shape}, B.shape = {B.shape}")

        m, n = A.shape
        n_B, p = B.shape

        # Make sure that the shape of the matrices match the required inputs and r positive integer
        if n != n_B:
            raise ValueError(f"Wrong matrix dimension: A.shape = {A.shape}, B.shape = {B.shape}")
        # Type check first so `r <= 0` is only evaluated on integers
        # (a string r would otherwise raise TypeError instead of ValueError).
        # bool is excluded explicitly because it subclasses int.
        if isinstance(r, bool) or not isinstance(r, (int, np.integer)) or r <= 0:
            raise ValueError("r must be a positive integer")

        # RandomState(seed).choice draws the same indices that
        # np.random.seed(seed); np.random.choice(...) would, without touching
        # the global RNG. With no seed, the module-level np.random functions
        # (backed by the global RNG) are used, as before.
        self._rng = np.random.RandomState(seed) if seed is not None else np.random

        self.A = A
        self.B = B
        self.r = int(r)
        self.m = m
        self.n = n
        self.p = p  # number of columns of B (not a probability)
        self.probabilities = self.compute_optimal_probabilities()

    def compute_optimal_probabilities(self) -> np.ndarray:
        """
        Return the non-uniform sampling probabilities from slide 19 of lectures:

            p_k = ||A[:, k]|| * ||B[k, :]|| / sum_j ||A[:, j]|| * ||B[j, :]||

        (Euclidean norms). Raises ValueError when the normalizing sum is not a
        positive finite number, e.g. when every product is zero (A or B is all
        zeros), because the probabilities are then undefined.
        """
        # Column norms of A times row norms of B, computed for all k at once.
        probs = np.linalg.norm(self.A, axis=0) * np.linalg.norm(self.B, axis=1)
        probs_sum = probs.sum()

        # Normalization is scale-invariant, so any positive finite sum is valid;
        # an absolute tolerance (np.isclose) would wrongly reject small-scaled matrices.
        if not (np.isfinite(probs_sum) and probs_sum > 0):
            raise ValueError(
                "Cannot normalize sampling probabilities: "
                f"sum of ||A[:, k]|| * ||B[k, :]|| is {probs_sum}"
            )

        return probs / probs_sum

    def sample_indices(self) -> np.ndarray:
        """
        Sample r indices i_1, ..., i_r from {0, ..., n-1} i.i.d. with
        P(i_l = k) = p_k. Indices are 0-based because Python uses 0-based
        indexing; this is a minor deviation from the slides.
        """
        return self._rng.choice(self.n, size=self.r, replace=True, p=self.probabilities)

    def compute(self) -> np.ndarray:
        """
        Draw a fresh set of indices and return the (m x p) estimate
        M = (1 / r) * sum_t A[:, i_t] B[i_t, :] / p_{i_t}.
        """
        indices = self.sample_indices()
        # Per-sample weight 1 / (r * p_{i_t}); scaling the sampled columns of A
        # and multiplying by the sampled rows of B sums all r outer products in
        # a single matmul.
        weights = 1.0 / (self.r * self.probabilities[indices])
        return (self.A[:, indices] * weights) @ self.B[indices, :]

    

## (b)

In [None]:
# Directory containing the homework CSV files. Defaults to the original
# location; set the STA243_DATA_DIR environment variable to run elsewhere.
data_dir = os.environ.get("STA243_DATA_DIR", "/Users/shawheennaderi/Downloads")

# header=None: the CSVs have no header row, so every row is data.
A = pd.read_csv(os.path.join(data_dir, 'STA243_homework_1_matrix_A.csv'), header=None)
B = pd.read_csv(os.path.join(data_dir, 'STA243_homework_1_matrix_B.csv'), header=None)


In [None]:
# Preview the first rows of A (head's default is n=5).
A.head(n=5)

In [None]:
# Exact product AB; the randomized approximations are compared against it.
M_true = A.dot(B)

In [None]:
# Preview the first rows of the exact product (head's default is n=5).
M_true.head(n=5)

In [None]:
# Sample sizes r to compare.
r_values = [20, 50, 100, 200]

# Map each r to its approximation of AB.
approximations = {}

# Continue with plain NumPy arrays rather than DataFrames.
A, B = np.array(A), np.array(B)

# The same seed is reused for every r: each run is reproducible, but the
# runs for different r are not independent of one another.
for r in r_values:
    print(f"\nRunning r = {r}...")
    rmm = RandomMatrixMultiplication(A, B, r=r, seed=42)
    approximations[r] = M_approx = rmm.compute()


## (c)

In [None]:
# Normalizer for the relative error: ||A||_F * ||B||_F.
norm_A = np.linalg.norm(A, ord='fro')
norm_B = np.linalg.norm(B, ord='fro')
denominator = norm_A * norm_B

# One row per r with ||M_r - AB||_F / (||A||_F * ||B||_F).
results = [
    {
        "r": r,
        "Relative Error": np.linalg.norm(approximations[r] - M_true, ord='fro') / denominator,
    }
    for r in r_values
]

In [None]:
# Tabulate the relative errors and print them without the row index.
df_results = pd.DataFrame(data=results)
table_text = df_results.to_string(index=False)
print(table_text)

## (d)

In [None]:
# Panels to draw: the exact product first, then one approximation per r.
panels = [("True Matrix Product AB", np.asarray(M_true))]
panels += [(f"Approximate M (r = {r})", np.asarray(approximations[r])) for r in r_values]

# Shared color limits: without explicit vmin/vmax every heatmap rescales the
# colormap to its own min/max, so equal colors would mean different values in
# different panels and the visual comparison would be misleading.
vmin = min(np.nanmin(mat) for _, mat in panels)
vmax = max(np.nanmax(mat) for _, mat in panels)

# Two panels per row; the grid grows with len(r_values) instead of a fixed 3 x 2.
n_cols = 2
n_rows = math.ceil(len(panels) / n_cols)
fig, axes = plt.subplots(
    n_rows,
    n_cols,
    figsize=(8 * n_cols, 6 * n_rows),
    squeeze=False,
    constrained_layout=True,  # also reserves room for the suptitle
)

for ax, (title, mat) in zip(axes.flat, panels):
    sns.heatmap(mat, cmap='viridis', vmin=vmin, vmax=vmax, cbar=False, ax=ax)
    ax.set_title(title)
    ax.set_xlabel("Columns")
    ax.set_ylabel("Rows")

# Hide grid cells that have no panel (e.g. the 6th slot when there are 5 panels).
for ax in axes.flat[len(panels):]:
    ax.axis("off")

# A single colorbar is meaningful because every panel uses the same vmin/vmax.
fig.colorbar(axes.flat[0].collections[0], ax=axes, shrink=0.6)

fig.suptitle("Matrix Multiplication Approximation (with True AB)", fontsize=18)
plt.show()
