# SciTeX Linear Algebra Utilities

This notebook demonstrates the linear algebra utilities provided by the `scitex.linalg` module, which offers efficient and NaN-aware implementations of common linear algebra operations for scientific computing.

## 1. Setup and Imports

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import scitex as stx

# Set up reproducible environment.
# NOTE(review): presumably fix_seeds also seeds NumPy's global RNG, which the
# cells below draw from via np.random.* -- confirm against scitex.repro.
stx.repro.fix_seeds(42)

# Configure visualization.
# Matplotlib 3.6 renamed its bundled seaborn styles to 'seaborn-v0_8-*';
# older releases only know 'seaborn-darkgrid'. Use whichever name this
# installation provides so the setup cell does not raise OSError on an
# older Matplotlib. If neither exists, the default style is kept.
for style_name in ('seaborn-v0_8-darkgrid', 'seaborn-darkgrid'):
    if style_name in plt.style.available:
        plt.style.use(style_name)
        break
plt.rcParams['figure.figsize'] = (10, 6)
plt.rcParams['font.size'] = 12

# Record library versions next to the results for provenance
print(f"SciTeX version: {stx.__version__}")
print(f"NumPy version: {np.__version__}")

## 2. Distance Computations

In [None]:
# Build a toy 3D dataset: two Gaussian blobs stacked into one array.
n_points = 100
n_dims = 3
points_per_cluster = n_points // 2

# Unit-variance noise shifted to (+2, +2, +2) and (-2, -2, -2) respectively.
# The two draws must stay in this order so the random stream is unchanged.
cluster1 = np.random.randn(points_per_cluster, n_dims) + np.array([2, 2, 2])
cluster2 = np.random.randn(points_per_cluster, n_dims) + np.array([-2, -2, -2])

# Rows 0..49 come from cluster1, rows 50..99 from cluster2
points = np.concatenate((cluster1, cluster2), axis=0)

print(f"Data shape: {points.shape}")
print(f"Data range: [{points.min():.2f}, {points.max():.2f}]")

In [None]:
# Distance between one point from each cluster.
# `points` stacks cluster1 (rows 0-49) on top of cluster2 (rows 50-99),
# so index 0 and index 50 are the first point of each cluster.
point_a, point_b = points[0], points[50]

# Full-name API
dist = stx.linalg.euclidean_distance(point_a, point_b)
print(f"Distance between point A and B: {dist:.4f}")

# Shorthand alias exposed by the module
dist_edist = stx.linalg.edist(point_a, point_b)
print(f"Distance using edist: {dist_edist:.4f}")

# 3D view of both clusters with the two selected points highlighted
fig = plt.figure(figsize=(10, 8))
ax = fig.add_subplot(111, projection='3d')

for cluster_xyz, cluster_color, cluster_name in (
    (cluster1, 'blue', 'Cluster 1'),
    (cluster2, 'red', 'Cluster 2'),
):
    ax.scatter(cluster_xyz[:, 0], cluster_xyz[:, 1], cluster_xyz[:, 2],
               c=cluster_color, label=cluster_name, alpha=0.6)

for marked_point, marker_color, marker_name in (
    (point_a, 'cyan', 'Point A'),
    (point_b, 'magenta', 'Point B'),
):
    ax.scatter(*marked_point, c=marker_color, s=200, marker='*', label=marker_name)

# Dashed segment joining A and B: zip pairs up the x, y and z coordinates
ax.plot(*zip(point_a, point_b), 'k--', linewidth=2)

ax.set_xlabel('X')
ax.set_ylabel('Y')
ax.set_zlabel('Z')
ax.legend()
ax.set_title(f'Euclidean Distance: {dist:.4f}')
plt.show()

In [None]:
# Pairwise distances on a random 20-point subset (keeps the heatmap readable)
subset_indices = np.random.choice(len(points), 20, replace=False)
subset_points = points[subset_indices]

# Square matrix: entry [i, j] is the distance between subset points i and j
dist_matrix = stx.linalg.cdist(subset_points, subset_points)

# Heatmap of the distance matrix
plt.figure(figsize=(10, 8))
im = plt.imshow(dist_matrix, cmap='viridis', aspect='auto')
plt.colorbar(im, label='Distance')
plt.xlabel('Point Index')
plt.ylabel('Point Index')
plt.title('Pairwise Distance Matrix')

# Per-cell value labels are only drawn when the matrix is small enough
# (at most 10x10) for the text to remain legible
n_subset = len(subset_points)
if n_subset <= 10:
    for i, j in np.ndindex(n_subset, n_subset):
        plt.text(j, i, f'{dist_matrix[i, j]:.2f}',
                 ha='center', va='center', color='white')

plt.tight_layout()
plt.show()

# Summary statistics; zeros (e.g. the diagonal) are excluded from the minimum
positive_distances = dist_matrix[dist_matrix > 0]
print(f"Distance matrix shape: {dist_matrix.shape}")
print(f"Min distance (non-zero): {positive_distances.min():.4f}")
print(f"Max distance: {dist_matrix.max():.4f}")

## 3. Geometric Median

In [None]:
# Compare the mean with the geometric median on contaminated data:
# 50 standard-normal 2D points plus three far-away outliers.
normal_points = np.random.randn(50, 2)
outliers = np.array([[10, 10], [-10, -10], [10, -10]])
all_points = np.concatenate([normal_points, outliers], axis=0)

# Two candidate "centers" of the same point cloud
mean_point = all_points.mean(axis=0)
geometric_median_point = stx.linalg.geometric_median(all_points)

# Scatter plot: inliers, outliers, and the two center estimates
fig, ax = plt.subplots(figsize=(10, 8))
ax.scatter(normal_points[:, 0], normal_points[:, 1],
           alpha=0.6, label='Normal points', s=50)
ax.scatter(outliers[:, 0], outliers[:, 1],
           color='red', s=100, label='Outliers', marker='^')
ax.scatter(*mean_point, color='orange', s=200,
           marker='s', label='Mean', edgecolor='black')
ax.scatter(*geometric_median_point, color='green', s=200,
           marker='*', label='Geometric Median', edgecolor='black')

ax.set_xlabel('X')
ax.set_ylabel('Y')
ax.legend()
ax.set_title('Geometric Median vs Mean (Robustness to Outliers)')
ax.grid(True, alpha=0.3)
ax.axis('equal')
fig.tight_layout()
plt.show()

# How far apart the two estimates ended up
center_gap = np.linalg.norm(mean_point - geometric_median_point)
print(f"Mean position: {mean_point}")
print(f"Geometric median: {geometric_median_point}")
print(f"Distance between mean and geometric median: {center_gap:.4f}")

## 4. Cosine Similarity

In [None]:
# Cosine similarity for measuring angular similarity
# Create vectors with different magnitudes but similar directions
vec1 = np.array([1, 2, 3])
vec2 = np.array([2, 4, 6])  # Same direction, different magnitude
vec3 = np.array([-1, -2, -3])  # Opposite direction
vec4 = np.array([3, -1, 2])  # Different direction

# Compute cosine similarities of vec1 against each of the others
cos_sim_12 = stx.linalg.cosine(vec1, vec2)
cos_sim_13 = stx.linalg.cosine(vec1, vec3)
cos_sim_14 = stx.linalg.cosine(vec1, vec4)

print("Cosine Similarities:")
print(f"vec1 vs vec2 (same direction): {cos_sim_12:.4f}")
print(f"vec1 vs vec3 (opposite direction): {cos_sim_13:.4f}")
print(f"vec1 vs vec4 (different direction): {cos_sim_14:.4f}")

# Visualize vectors in 3D
fig = plt.figure(figsize=(10, 8))
ax = fig.add_subplot(111, projection='3d')

# All arrows start at the origin
origin = np.zeros(3)

# Plot vectors
vectors = [vec1, vec2, vec3, vec4]
colors = ['blue', 'green', 'red', 'orange']
labels = ['vec1', 'vec2 (2×vec1)', 'vec3 (-vec1)', 'vec4']

for vec, color, label in zip(vectors, colors, labels):
    ax.quiver(*origin, *vec,
              color=color, arrow_length_ratio=0.1, linewidth=3, label=label)

# Axes3D.quiver autoscales from the arrow base points only. Every base is the
# origin here, so without explicit limits the arrows run far outside the
# visible box. Use a symmetric cube sized to the largest vector component.
axis_limit = np.abs(vectors).max()
ax.set_xlim(-axis_limit, axis_limit)
ax.set_ylim(-axis_limit, axis_limit)
ax.set_zlim(-axis_limit, axis_limit)

ax.set_xlabel('X')
ax.set_ylabel('Y')
ax.set_zlabel('Z')
ax.legend()
ax.set_title('Vector Directions and Cosine Similarity')
plt.show()

## 5. NaN-Aware Norm Computation

In [None]:
# Small matrix whose rows contain zero, one, or two missing entries
data_with_nan = np.array([
    [1.0, 2.0, 3.0],
    [4.0, np.nan, 6.0],
    [7.0, 8.0, 9.0],
    [np.nan, np.nan, 12.0],
])

print("Data with NaN values:")
print(data_with_nan)
print()

# Row by row: NumPy's norm next to SciTeX's NaN-aware variant
for i in range(len(data_with_nan)):
    row = data_with_nan[i]

    # np.linalg.norm propagates NaN: any missing entry makes the result NaN
    standard_norm = np.linalg.norm(row)

    # nannorm is the NaN-aware counterpart under demonstration
    nan_aware_norm = stx.linalg.nannorm(row)

    print(f"Row {i}: {row}")
    print(f"  Standard norm: {standard_norm:.4f}")
    print(f"  NaN-aware norm: {nan_aware_norm:.4f}")
    print()

## 6. Vector Rebasing

In [None]:
# Shift a vector so that it starts from a new base value
# (the notebook describes this as rebasing to a new minimum)
original_vec = np.array([5, 10, 15, 20, 25])
new_base = 100

rebased_vec = stx.linalg.rebase_a_vec(original_vec, new_base)

original_min = original_vec.min()
print(f"Original vector: {original_vec}")
print(f"Original min: {original_min}")
print(f"Rebased vector: {rebased_vec}")
print(f"New min: {rebased_vec.min()}")
print(f"Shift applied: {rebased_vec[0] - original_vec[0]}")

# Before/after line plot with dashed reference lines at the old and new base
fig, ax = plt.subplots(figsize=(10, 6))
positions = np.arange(len(original_vec))
ax.plot(positions, original_vec, 'bo-', label='Original', markersize=10, linewidth=2)
ax.plot(positions, rebased_vec, 'ro-', label=f'Rebased (base={new_base})', markersize=10, linewidth=2)
ax.axhline(y=original_min, color='blue', linestyle='--', alpha=0.5, label=f'Original min: {original_min}')
ax.axhline(y=new_base, color='red', linestyle='--', alpha=0.5, label=f'New base: {new_base}')
ax.set_xlabel('Index')
ax.set_ylabel('Value')
ax.set_title('Vector Rebasing')
ax.legend()
ax.grid(True, alpha=0.3)
fig.tight_layout()
plt.show()

## 7. Triangle Coordinate Computation

In [None]:
# Turn three side lengths into vertex coordinates of a triangle.
# Example: sides 3, 4, 5 (a right triangle)
side_lengths = [3, 4, 5]

# NOTE(review): the code below assumes `coords` is an indexable collection of
# three 2D NumPy points (they are subtracted, added and passed to Polygon) --
# confirm against the scitex.linalg implementation.
coords = stx.linalg.three_line_lengths_to_coords(side_lengths)

print(f"Side lengths: {side_lengths}")
print("\nComputed coordinates:")
for i, coord in enumerate(coords):
    print(f"  Point {i}: {coord}")

# Verify: re-measure each edge and compare with the requested length.
# Edge order (0-1, 1-2, 0-2) is paired with side_lengths[0], [1], [2].
print("\nVerification:")
edges = ((0, 1), (1, 2), (0, 2))
dist_01, dist_12, dist_02 = (
    np.linalg.norm(coords[start] - coords[end]) for start, end in edges
)
for (start, end), measured, expected in zip(edges, (dist_01, dist_12, dist_02), side_lengths):
    print(f"  Distance {start}-{end}: {measured:.4f} (expected: {expected})")

# Draw the triangle outline
fig, ax = plt.subplots(figsize=(8, 8))
triangle = plt.Polygon(coords, fill=False, edgecolor='blue', linewidth=2)
ax.add_patch(triangle)

# Mark and name each vertex
for i, coord in enumerate(coords):
    ax.plot(coord[0], coord[1], 'ro', markersize=10)
    ax.text(coord[0] + 0.1, coord[1] + 0.1, f'P{i}', fontsize=12)

# Label each side at its midpoint, nudged slightly off the edge
mid_01 = (coords[0] + coords[1]) / 2
mid_12 = (coords[1] + coords[2]) / 2
mid_02 = (coords[0] + coords[2]) / 2

ax.text(mid_01[0], mid_01[1] - 0.2, f'{side_lengths[0]}', fontsize=10, ha='center')
ax.text(mid_12[0] + 0.2, mid_12[1], f'{side_lengths[1]}', fontsize=10)
ax.text(mid_02[0], mid_02[1] + 0.2, f'{side_lengths[2]}', fontsize=10, ha='center')

ax.axis('equal')
ax.grid(True, alpha=0.3)
ax.set_xlabel('X')
ax.set_ylabel('Y')
ax.set_title('Triangle from Side Lengths (3-4-5 Right Triangle)')
fig.tight_layout()
plt.show()

## 8. Advanced Distance Computations with Missing Data

In [None]:
# Distance between samples when some features are missing
n_samples = 50
n_features = 5

# Start from a complete random matrix ...
complete_data = np.random.randn(n_samples, n_features)

# ... then blank out each entry independently with probability 0.2
missing_mask = np.random.random((n_samples, n_features)) < 0.2
data_with_missing = np.where(missing_mask, np.nan, complete_data)

nan_flags = np.isnan(data_with_missing)
print(f"Data shape: {data_with_missing.shape}")
print(f"Missing values: {nan_flags.sum()} ({nan_flags.mean()*100:.1f}%)")

# Compare the first two samples (either may contain NaNs)
sample1, sample2 = data_with_missing[0], data_with_missing[1]

print(f"\nSample 1: {sample1}")
print(f"Sample 2: {sample2}")

# Keep only the dimensions observed in BOTH samples, then measure the
# distance in that reduced space
valid_dims = ~np.isnan(sample1) & ~np.isnan(sample2)
if not valid_dims.any():
    print("\nNo valid dimensions for comparison")
else:
    distance = stx.linalg.euclidean_distance(sample1[valid_dims], sample2[valid_dims])
    print(f"\nDistance using {valid_dims.sum()} valid dimensions: {distance:.4f}")

## 9. Performance Comparison

In [None]:
# Compare performance of different distance computation methods
import time

# Generate larger dataset
n_points = 1000
n_dims = 100
data = np.random.randn(n_points, n_dims)

# Both methods are timed on the same leading slice so the pure-Python
# baseline finishes quickly; one constant keeps every size in sync.
n_bench = 100
bench_data = data[:n_bench]

# Method 1: SciTeX cdist
# perf_counter is a monotonic, high-resolution clock meant for measuring short
# intervals; time.time() can be too coarse and report a zero elapsed time.
start_time = time.perf_counter()
dist_matrix_stx = stx.linalg.cdist(bench_data, bench_data)
stx_time = time.perf_counter() - start_time

# Method 2: NumPy manual computation (deliberately naive double loop)
start_time = time.perf_counter()
dist_matrix_np = np.zeros((n_bench, n_bench))
for i in range(n_bench):
    for j in range(n_bench):
        dist_matrix_np[i, j] = np.linalg.norm(bench_data[i] - bench_data[j])
np_time = time.perf_counter() - start_time

print(f"SciTeX cdist time: {stx_time:.4f} seconds")
print(f"NumPy manual time: {np_time:.4f} seconds")
# Guard the ratio: a zero measured interval must not crash the cell
if stx_time > 0:
    print(f"Speedup: {np_time/stx_time:.2f}x")
else:
    print("Speedup: n/a (SciTeX cdist finished below timer resolution)")
print(f"\nResults match: {np.allclose(dist_matrix_stx, dist_matrix_np)}")

## 10. Integration Example: Clustering with Robust Center Estimation

In [None]:
# Use linear algebra utilities for clustering analysis
# Generate synthetic dataset with 3 clusters
np.random.seed(42)
n_points_per_cluster = 50

# Create clusters with different spreads and centers
cluster1 = np.random.randn(n_points_per_cluster, 2) * 0.5 + [0, 0]
cluster2 = np.random.randn(n_points_per_cluster, 2) * 0.8 + [5, 5]
cluster3 = np.random.randn(n_points_per_cluster, 2) * 0.6 + [5, -5]

all_data = np.vstack([cluster1, cluster2, cluster3])
labels = np.array([0]*n_points_per_cluster + [1]*n_points_per_cluster + [2]*n_points_per_cluster)

# Add some outliers BEFORE estimating the centers: five random points are
# displaced by large Gaussian noise. Estimating the centers on the clean data
# first would make the mean-vs-median comparison below blind to the outliers.
outlier_indices = np.random.choice(len(all_data), 5, replace=False)
all_data[outlier_indices] += np.random.randn(5, 2) * 5

# Compute cluster centers on the contaminated data with both estimators
centers_median = []
centers_mean = []

for label in range(3):
    cluster_points = all_data[labels == label]
    centers_median.append(stx.linalg.geometric_median(cluster_points))
    centers_mean.append(np.mean(cluster_points, axis=0))

centers_median = np.array(centers_median)
centers_mean = np.array(centers_mean)

# Visualize clusters and centers
plt.figure(figsize=(12, 8))

# Plot points, colored by their generating cluster
colors = ['blue', 'red', 'green']
for i in range(3):
    mask = labels == i
    plt.scatter(all_data[mask, 0], all_data[mask, 1],
                c=colors[i], alpha=0.6, label=f'Cluster {i}', s=50)

# Mark outliers
plt.scatter(all_data[outlier_indices, 0], all_data[outlier_indices, 1],
            c='black', marker='x', s=200, linewidths=3, label='Outliers')

# Plot centers
plt.scatter(centers_mean[:, 0], centers_mean[:, 1],
            c='orange', marker='s', s=200, edgecolor='black', linewidth=2, label='Mean Centers')
plt.scatter(centers_median[:, 0], centers_median[:, 1],
            c='yellow', marker='*', s=300, edgecolor='black', linewidth=2, label='Geometric Median Centers')

# Compute and show inter-cluster distances
inter_cluster_dist = stx.linalg.cdist(centers_median, centers_median)

# Draw lines between cluster centers; each unordered pair is visited once
for i in range(3):
    for j in range(i+1, 3):
        plt.plot([centers_median[i, 0], centers_median[j, 0]],
                 [centers_median[i, 1], centers_median[j, 1]],
                 'k--', alpha=0.3, linewidth=1)
        mid_x = (centers_median[i, 0] + centers_median[j, 0]) / 2
        mid_y = (centers_median[i, 1] + centers_median[j, 1]) / 2
        plt.text(mid_x, mid_y, f'{inter_cluster_dist[i, j]:.2f}',
                 fontsize=10, ha='center', bbox=dict(boxstyle='round', facecolor='white', alpha=0.8))

plt.xlabel('X')
plt.ylabel('Y')
plt.title('Clustering with Robust Center Estimation')
plt.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
plt.grid(True, alpha=0.3)
plt.axis('equal')
plt.tight_layout()
plt.show()

# Print statistics (zeros on the diagonal are excluded from the mean)
print("Inter-cluster distances (using geometric median centers):")
print(inter_cluster_dist)
print(f"\nMean inter-cluster distance: {inter_cluster_dist[inter_cluster_dist > 0].mean():.4f}")

# Compare robustness: how far the outliers pull the mean from the median
print("\nCenter differences (Mean vs Geometric Median):")
for i in range(3):
    diff = np.linalg.norm(centers_mean[i] - centers_median[i])
    print(f"  Cluster {i}: {diff:.4f}")

## Summary

The `scitex.linalg` module provides essential linear algebra utilities for scientific computing:

1. **Distance Computations**: Efficient Euclidean distance and pairwise distance matrix calculations
2. **Geometric Median**: Robust center estimation that is less sensitive to outliers than the mean
3. **Cosine Similarity**: Measure angular similarity between vectors regardless of magnitude
4. **NaN-Aware Operations**: Handle missing data gracefully in norm computations
5. **Vector Utilities**: Rebase vectors and convert geometric descriptions to coordinates

These utilities are particularly useful for:
- Clustering and classification tasks
- Robust statistical analysis
- Geometric computations in scientific applications
- Handling real-world data with missing values