# SciTeX Linalg Tutorial

This notebook demonstrates how to use the `scitex.linalg` module for linear algebra operations.

## Features Covered

* Distance calculations (Euclidean, cdist)
* Geometric median computation
* Cosine similarity
* Vector operations with NaN handling
* Coordinate transformations
* Robust linear algebra utilities

## 1. Basic Setup

In [None]:
import numpy as np
import torch
import matplotlib.pyplot as plt
from scitex import linalg as stx_linalg

# Set random seed for reproducibility
np.random.seed(42)
torch.manual_seed(42)

print("SciTeX Linalg Tutorial")
print("Available functions:", [name for name in dir(stx_linalg) if not name.startswith("_")])

## 2. Distance Calculations

### Euclidean Distance

In [None]:
# Create sample vectors
u = np.array([1, 2, 3])
v = np.array([4, 5, 6])

# Compute Euclidean distance
dist = stx_linalg.euclidean_distance(u, v)
print(f"Euclidean distance between {u} and {v}: {dist:.3f}")

# Alternative using edist (alias)
dist_alias = stx_linalg.edist(u, v)
print(f"Using edist alias: {dist_alias:.3f}")
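For reference, the Euclidean distance is the L2 norm of the difference vector, `‖u - v‖ = sqrt(sum((u_i - v_i)^2))`. Here is a NumPy-only cross-check that does not depend on scitex, using the same two vectors. Both forms give sqrt(27) ≈ 5.196.

```python
import numpy as np

u = np.array([1, 2, 3])
v = np.array([4, 5, 6])

# Explicit formula: square the component differences, sum them, take the square root
manual = np.sqrt(np.sum((u - v) ** 2))

# The same quantity via NumPy's vector norm
via_norm = np.linalg.norm(u - v)

print(f"manual: {manual:.3f}, np.linalg.norm: {via_norm:.3f}")
```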

### Multi-dimensional Distance Calculations

In [None]:
# Create 2D arrays (multiple vectors)
points_a = np.random.randn(3, 5)  # 3 points in 5D space
points_b = np.random.randn(3, 5)  # 3 points in 5D space

print("Points A shape:", points_a.shape)
print("Points B shape:", points_b.shape)

# Compute distances with each axis setting and compare the result shapes
dist_axis0 = stx_linalg.euclidean_distance(points_a, points_b, axis=0)
print(f"\nResult shape with axis=0: {dist_axis0.shape}")
print(f"Values:\n{dist_axis0}")

dist_axis1 = stx_linalg.euclidean_distance(points_a, points_b, axis=1)
print(f"\nResult shape with axis=1: {dist_axis1.shape}")
print(f"Values:\n{dist_axis1}")
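A distance routine applied to two point sets can return either one distance per matched pair of rows or a full matrix of all pairs. Which you get depends on how the routine treats its `axis` argument, so it helps to know what both results look like. The NumPy-only sketch below computes both. It makes no claim about which one `euclidean_distance` returns for a given `axis`.

```python
import numpy as np

rng = np.random.default_rng(0)
A = rng.standard_normal((3, 5))  # 3 points in 5D
B = rng.standard_normal((3, 5))  # 3 points in 5D

# Matched pairs: distance between A[i] and B[i] -> shape (3,)
paired = np.linalg.norm(A - B, axis=1)

# All pairs via broadcasting: (3, 1, 5) - (1, 3, 5) -> (3, 3, 5), then reduce the coordinate axis
all_pairs = np.sqrt(((A[:, None, :] - B[None, :, :]) ** 2).sum(axis=-1))  # shape (3, 3)

print("paired:", paired.shape, "| all pairs:", all_pairs.shape)
```

The matched-pair distances are exactly the diagonal of the all-pairs matrix.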

### Pairwise Distance Matrix

In [None]:
# Create sample data points
data_points = np.random.randn(5, 3)  # 5 points in 3D space
print("Data points shape:", data_points.shape)
print("Data points:")
print(data_points)

# Compute pairwise distance matrix using cdist
distance_matrix = stx_linalg.cdist(data_points, data_points)
print(f"\nDistance matrix shape: {distance_matrix.shape}")
print("Distance matrix:")
print(distance_matrix)

# Visualize distance matrix
plt.figure(figsize=(8, 6))
plt.imshow(distance_matrix, cmap='viridis')
plt.colorbar(label='Distance')
plt.title('Pairwise Distance Matrix')
plt.xlabel('Point Index')
plt.ylabel('Point Index')
plt.show()
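Pairwise distance matrices are often computed without forming the full difference tensor, by expanding `‖x - y‖² = ‖x‖² + ‖y‖² - 2 x·y`. The sketch below illustrates that identity with NumPy only. The helper name `pairwise_euclidean` is hypothetical and not part of scitex. The sketch also checks two properties that any distance matrix of a point set with itself should have: symmetry and a zero diagonal.

```python
import numpy as np

def pairwise_euclidean(X, Y):
    """Hypothetical helper: all-pairs Euclidean distances via the Gram-matrix identity."""
    sq_norms_x = np.sum(X ** 2, axis=1)[:, None]  # shape (n, 1)
    sq_norms_y = np.sum(Y ** 2, axis=1)[None, :]  # shape (1, m)
    sq_dists = sq_norms_x + sq_norms_y - 2.0 * X @ Y.T
    # Round-off can make tiny entries slightly negative; clip before the square root
    np.maximum(sq_dists, 0.0, out=sq_dists)
    return np.sqrt(sq_dists)

rng = np.random.default_rng(0)
X = rng.standard_normal((5, 3))  # 5 points in 3D
D = pairwise_euclidean(X, X)

# Reference result from explicit broadcasting
D_ref = np.sqrt(((X[:, None, :] - X[None, :, :]) ** 2).sum(axis=-1))

print("max abs difference:", np.max(np.abs(D - D_ref)))
print("symmetric:", np.allclose(D, D.T), "| zero diagonal:", np.allclose(np.diag(D), 0.0, atol=1e-6))
```

The identity trades some accuracy for speed. When two points are very close together, the subtraction can cancel most significant digits, so the explicit broadcasting form is the safer choice for small inputs.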

## 3. Geometric Median

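The geometric median is the point that minimizes the sum of Euclidean distances to all data points. It is a robust estimator. A single distant outlier can drag the arithmetic mean arbitrarily far, but it cannot move the geometric median far from the bulk of the data. The estimate stays bounded as long as fewer than half of the points are outliers.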
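One classical way to compute the geometric median is Weiszfeld's algorithm. It is an iteratively re-weighted mean in which each point is weighted by the inverse of its distance to the current estimate. The NumPy sketch below applies that technique to the same five points used next. It is not necessarily how `stx_linalg.geometric_median` is implemented, and the helper names `weiszfeld_median` and `total_distance` are hypothetical.

```python
import numpy as np

def weiszfeld_median(points, n_iter=1000, tol=1e-10, eps=1e-12):
    """Hypothetical helper: geometric median of the rows of `points` (Weiszfeld iteration)."""
    estimate = points.mean(axis=0)  # start from the arithmetic mean
    for _ in range(n_iter):
        dists = np.linalg.norm(points - estimate, axis=1)
        # Simple guard against division by zero if the estimate lands on a data point
        weights = 1.0 / np.maximum(dists, eps)
        new_estimate = (weights[:, None] * points).sum(axis=0) / weights.sum()
        if np.linalg.norm(new_estimate - estimate) < tol:
            return new_estimate
        estimate = new_estimate
    return estimate

def total_distance(points, center):
    """Sum of Euclidean distances from every point to `center` (the quantity being minimized)."""
    return np.linalg.norm(points - center, axis=1).sum()

pts = np.array([[1.0, 1.0], [2.0, 1.0], [1.0, 2.0], [2.0, 2.0], [10.0, 10.0]])
median = weiszfeld_median(pts)
mean = pts.mean(axis=0)

print("Weiszfeld median:", median)
print("mean:", mean)
print("sum of distances (median vs mean):", total_distance(pts, median), total_distance(pts, mean))
```

These points are symmetric under swapping x and y, so the median lies on the line x = y, at roughly (1.79, 1.79). The mean is pulled to (3.2, 3.2) by the outlier.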

In [None]:
# Create sample data with outliers
points = torch.tensor([
    [1.0, 1.0],
    [2.0, 1.0],
    [1.0, 2.0],
    [2.0, 2.0],
    [10.0, 10.0]  # Outlier point
])

print("Original points:")
print(points)

# Compute geometric median
geom_median = stx_linalg.geometric_median(points, dim=0)
print(f"\nGeometric median: {geom_median}")

# Compare with arithmetic mean (less robust to outliers)
arith_mean = torch.mean(points, dim=0)
print(f"Arithmetic mean: {arith_mean}")

# Visualize
plt.figure(figsize=(8, 6))
plt.scatter(points[:, 0], points[:, 1], c='blue', s=100, alpha=0.7, label='Data points')
plt.scatter(geom_median[0], geom_median[1], c='red', s=200, marker='x', 
           linewidths=3, label='Geometric median')
plt.scatter(arith_mean[0], arith_mean[1], c='green', s=200, marker='+', 
           linewidths=3, label='Arithmetic mean')
plt.xlabel('X coordinate')
plt.ylabel('Y coordinate')
plt.title('Geometric Median vs Arithmetic Mean')
plt.legend()
plt.grid(True, alpha=0.3)
plt.show()

print("\nNote: Geometric median is more robust to outliers than arithmetic mean")

## 4. Cosine Similarity and Vector Operations

In [None]:
# Create sample vectors
v1 = np.array([1, 2, 3])
v2 = np.array([2, 4, 6])  # Parallel to v1 (v2 = 2 * v1)
v3 = np.array([1, 0, 0])  # Unit x-axis vector (not orthogonal to v1)

print("Vector 1:", v1)
print("Vector 2:", v2)
print("Vector 3:", v3)

# Compute cosine similarities
cos_sim_12 = stx_linalg.cosine(v1, v2)
cos_sim_13 = stx_linalg.cosine(v1, v3)
cos_sim_23 = stx_linalg.cosine(v2, v3)

print(f"\nCosine similarity between v1 and v2: {cos_sim_12:.3f}")
print(f"Cosine similarity between v1 and v3: {cos_sim_13:.3f}")
print(f"Cosine similarity between v2 and v3: {cos_sim_23:.3f}")

# Compute norms with NaN handling
norm_v1 = stx_linalg.nannorm(v1)
norm_v2 = stx_linalg.nannorm(v2)
norm_v3 = stx_linalg.nannorm(v3)

print(f"\nNorm of v1: {norm_v1:.3f}")
print(f"Norm of v2: {norm_v2:.3f}")
print(f"Norm of v3: {norm_v3:.3f}")
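Cosine similarity is the dot product divided by the product of the norms, so it depends only on direction. v2 = 2·v1 therefore scores exactly 1. v3 = [1, 0, 0] scores 1/sqrt(14) ≈ 0.267 against v1, because it is not orthogonal to it. Here is a NumPy-only sketch of the formula for cross-checking `stx_linalg.cosine`. The helper name `cosine_sim` is hypothetical.

```python
import numpy as np

def cosine_sim(a, b):
    """Hypothetical helper: cosine of the angle between vectors a and b, (a . b) / (|a| |b|)."""
    a = np.asarray(a, dtype=float)
    b = np.asarray(b, dtype=float)
    return float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)))

v1 = np.array([1, 2, 3])
v2 = np.array([2, 4, 6])
v3 = np.array([1, 0, 0])

print(f"{cosine_sim(v1, v2):.3f}")  # parallel vectors -> 1.000
print(f"{cosine_sim(v1, v3):.3f}")  # 1 / sqrt(14) -> 0.267
```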

### Handling NaN Values

In [None]:
# Create vectors with NaN values
v_nan = np.array([1, np.nan, 3])
v_clean = np.array([1, 2, 3])

print("Vector with NaN:", v_nan)
print("Clean vector:", v_clean)

# Test NaN handling
cos_sim_nan = stx_linalg.cosine(v_nan, v_clean)
norm_nan = stx_linalg.nannorm(v_nan)

print(f"\nCosine similarity with NaN: {cos_sim_nan}")
print(f"Norm with NaN: {norm_nan}")

print("\nNote: cosine() and nannorm() return NaN when an input contains NaN, so missing data propagates instead of raising an error")
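Returning NaN is one policy. It flags the problem but discards the valid entries too. If you would rather compare only the positions present in both vectors, you can write a pairwise-deletion variant yourself. The helper `cosine_ignoring_nan` below is hypothetical, not part of scitex. It changes the semantics, so use it only when missing values are known to be safe to ignore.

```python
import numpy as np

def cosine_ignoring_nan(a, b):
    """Hypothetical helper: cosine similarity over positions where both a and b are not NaN."""
    a = np.asarray(a, dtype=float)
    b = np.asarray(b, dtype=float)
    mask = ~(np.isnan(a) | np.isnan(b))  # keep entries valid in both vectors
    if not mask.any():
        return np.nan  # nothing left to compare
    a, b = a[mask], b[mask]
    return float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)))

v_nan = np.array([1, np.nan, 3])
v_clean = np.array([1, 2, 3])

sim = cosine_ignoring_nan(v_nan, v_clean)   # compares [1, 3] with [1, 3]
norm_ignoring_nan = np.sqrt(np.nansum(v_nan ** 2))  # sqrt(1 + 9)
print(sim, norm_ignoring_nan)
```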

## 5. Vector Rebasing

In [None]:
# Create base vector and target vector
v_base = np.array([10, 0])  # Base vector along x-axis
v_target = np.array([3, 4])  # Target vector

print("Base vector:", v_base)
print("Target vector:", v_target)

# Rebase the target vector onto the base vector
rebased = stx_linalg.rebase_a_vec(v_target, v_base)
print(f"\nRebased projection: {rebased:.3f}")

# Visualize the rebasing
plt.figure(figsize=(8, 6))

# Plot vectors
plt.arrow(0, 0, v_base[0], v_base[1], head_width=0.3, head_length=0.5, 
          fc='blue', ec='blue', label='Base vector')
plt.arrow(0, 0, v_target[0], v_target[1], head_width=0.3, head_length=0.5, 
          fc='red', ec='red', label='Target vector')

# Plot projection
proj_x = rebased * v_base[0] / np.linalg.norm(v_base)
proj_y = rebased * v_base[1] / np.linalg.norm(v_base)
plt.arrow(0, 0, proj_x, proj_y, head_width=0.3, head_length=0.5, 
          fc='green', ec='green', linestyle='--', label='Projection')

plt.xlim(-1, 11)
plt.ylim(-1, 5)
plt.xlabel('X')
plt.ylabel('Y')
plt.title('Vector Rebasing (Projection)')
plt.legend()
plt.grid(True, alpha=0.3)
plt.axis('equal')
plt.show()

print(f"\nProjection magnitude: {rebased:.3f}")
print(f"Original target magnitude: {np.linalg.norm(v_target):.3f}")
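The plotting code above treats the value returned by `rebase_a_vec` as a signed scalar projection: it scales the unit base vector by that value to draw the projection. Under that interpretation, the same quantity can be computed directly as `(v · b) / ‖b‖`, which is negative when the target points away from the base. The helpers `scalar_projection` and `vector_projection` are a hypothetical NumPy sketch for cross-checking, not the scitex implementation.

```python
import numpy as np

def scalar_projection(v, base):
    """Hypothetical helper: signed length of the projection of v onto the direction of base."""
    v = np.asarray(v, dtype=float)
    base = np.asarray(base, dtype=float)
    return float(np.dot(v, base) / np.linalg.norm(base))

def vector_projection(v, base):
    """Hypothetical helper: projection of v onto base as a vector (scalar projection times unit base)."""
    base = np.asarray(base, dtype=float)
    return scalar_projection(v, base) * base / np.linalg.norm(base)

v_base = np.array([10, 0])
print(scalar_projection([3, 4], v_base))   # 3.0
print(scalar_projection([-3, 4], v_base))  # -3.0 (points away from the base)
print(vector_projection([3, 4], v_base))   # [3. 0.]
```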

## 6. Coordinate Transformations

Given three side lengths, OA = a, OB = b and AB = c, `three_line_lengths_to_coords` returns 3D coordinates for the triangle's vertices O, A and B.
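One way to place the three points is to put O at the origin, A on the positive x-axis, and B in the upper half of the xy-plane. This sketch assumes that convention; compare it with the coordinates printed below to see which convention `three_line_lengths_to_coords` actually uses. The angle θ at O follows from the law of cosines, `cos θ = (a² + b² - c²) / (2ab)`, and then `B = (b cos θ, b sin θ, 0)`. The function name `triangle_coords` is hypothetical.

```python
import numpy as np

def triangle_coords(a, b, c):
    """Hypothetical helper: coordinates of O, A, B given |OA| = a, |OB| = b, |AB| = c."""
    cos_theta = (a**2 + b**2 - c**2) / (2 * a * b)  # law of cosines for the angle at O
    if not -1.0 <= cos_theta <= 1.0:
        raise ValueError("side lengths violate the triangle inequality")
    sin_theta = np.sqrt(1.0 - cos_theta**2)
    O = np.zeros(3)
    A = np.array([a, 0.0, 0.0])
    B = np.array([b * cos_theta, b * sin_theta, 0.0])
    return O, A, B

O, A, B = triangle_coords(2.0, np.sqrt(3), 1.0)
print(O, A, B)  # B is approximately (1.5, 0.866, 0)
```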

In [None]:
# Define triangle side lengths
a = 2.0          # Length OA
b = np.sqrt(3)   # Length OB  
c = 1.0          # Length AB

print(f"Triangle side lengths: a={a}, b={b:.3f}, c={c}")

# Convert to coordinates
O, A, B = stx_linalg.three_line_lengths_to_coords(a, b, c)

print(f"\nCoordinates:")
print(f"O (origin): {O}")
print(f"A: {A}")
print(f"B: {B}")

# Convert to numpy arrays for easier manipulation
O_np = np.array(O)
A_np = np.array(A)
B_np = np.array(B)

# Verify the distances
dist_OA = np.linalg.norm(A_np - O_np)
dist_OB = np.linalg.norm(B_np - O_np)
dist_AB = np.linalg.norm(B_np - A_np)

print(f"\nVerification:")
print(f"Computed OA distance: {dist_OA:.3f} (expected: {a})")
print(f"Computed OB distance: {dist_OB:.3f} (expected: {b:.3f})")
print(f"Computed AB distance: {dist_AB:.3f} (expected: {c})")

# Visualize the triangle
plt.figure(figsize=(8, 6))

# Plot points
plt.scatter([O[0], A[0], B[0]], [O[1], A[1], B[1]], 
           c=['red', 'blue', 'green'], s=100, zorder=5)

# Plot triangle edges
triangle_x = [O[0], A[0], B[0], O[0]]
triangle_y = [O[1], A[1], B[1], O[1]]
plt.plot(triangle_x, triangle_y, 'k-', linewidth=2, alpha=0.7)

# Label points
plt.text(O[0]-0.1, O[1]-0.1, 'O', fontsize=12, fontweight='bold')
plt.text(A[0]+0.1, A[1]-0.1, 'A', fontsize=12, fontweight='bold')
plt.text(B[0]+0.1, B[1]+0.1, 'B', fontsize=12, fontweight='bold')

# Add distance labels at the midpoint of each edge
edges = [(O, A, f'a={a}'), (O, B, f'b={b:.2f}'), (A, B, f'c={c}')]
for P, Q, label in edges:
    mid_x, mid_y = (P[0] + Q[0]) / 2, (P[1] + Q[1]) / 2
    plt.text(mid_x, mid_y, label, fontsize=10, ha='center', va='bottom')

plt.xlabel('X coordinate')
plt.ylabel('Y coordinate')
plt.title('Triangle from Side Lengths')
plt.grid(True, alpha=0.3)
plt.axis('equal')
plt.show()

## 7. Practical Applications

### Clustering Analysis with Robust Distance Metrics

In [None]:
# Generate sample data with clusters
np.random.seed(42)
cluster1 = np.random.randn(20, 2) + [2, 2]
cluster2 = np.random.randn(20, 2) + [6, 6]
outliers = np.array([[0, 8], [8, 0], [10, 10]])  # Add some outliers

all_points = np.vstack([cluster1, cluster2, outliers])
print(f"Total data points: {all_points.shape[0]}")

# Compute pairwise distances
distance_matrix = stx_linalg.cdist(all_points, all_points)

# Nearest-neighbor distance for each point: mask the zero self-distance on the diagonal
# (filtering with `distances > 0` would also drop genuine duplicate points)
masked = np.array(distance_matrix, dtype=float)  # independent copy
np.fill_diagonal(masked, np.inf)
nearest_neighbors = masked.min(axis=1)

# Convert to PyTorch for geometric median
points_tensor = torch.tensor(all_points, dtype=torch.float32)
geom_median = stx_linalg.geometric_median(points_tensor, dim=0)
arith_mean = torch.mean(points_tensor, dim=0)

# Visualize
plt.figure(figsize=(12, 5))

# Plot 1: Data points with centroids
plt.subplot(1, 2, 1)
plt.scatter(cluster1[:, 0], cluster1[:, 1], c='blue', alpha=0.6, label='Cluster 1')
plt.scatter(cluster2[:, 0], cluster2[:, 1], c='red', alpha=0.6, label='Cluster 2')
plt.scatter(outliers[:, 0], outliers[:, 1], c='black', s=100, marker='x', 
           linewidths=3, label='Outliers')
plt.scatter(geom_median[0], geom_median[1], c='green', s=200, marker='s', 
           label='Geometric median')
plt.scatter(arith_mean[0], arith_mean[1], c='orange', s=200, marker='^', 
           label='Arithmetic mean')
plt.xlabel('X coordinate')
plt.ylabel('Y coordinate')
plt.title('Data Clusters with Centroids')
plt.legend()
plt.grid(True, alpha=0.3)

# Plot 2: Distance matrix heatmap
plt.subplot(1, 2, 2)
plt.imshow(distance_matrix, cmap='viridis')
plt.colorbar(label='Distance')
plt.title('Pairwise Distance Matrix')
plt.xlabel('Point Index')
plt.ylabel('Point Index')

plt.tight_layout()
plt.show()

print(f"\nGeometric median: [{geom_median[0]:.2f}, {geom_median[1]:.2f}]")
print(f"Arithmetic mean: [{arith_mean[0]:.2f}, {arith_mean[1]:.2f}]")
print(f"Average nearest neighbor distance: {np.mean(nearest_neighbors):.2f}")
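The nearest-neighbor distances above are only averaged, but they also give a simple outlier score, because isolated points sit far from their nearest neighbor. The sketch below flags points whose nearest-neighbor distance exceeds the median by more than three scaled median absolute deviations (MAD). The helper name `flag_isolated_points` and the threshold rule are illustrative assumptions, not scitex features. The threshold usually needs tuning per dataset.

```python
import numpy as np

def flag_isolated_points(points, k=3.0):
    """Hypothetical helper: indices of points whose nearest-neighbor distance is unusually large."""
    diff = points[:, None, :] - points[None, :, :]
    dist = np.sqrt((diff ** 2).sum(axis=-1))
    np.fill_diagonal(dist, np.inf)          # ignore each point's distance to itself
    nn_dist = dist.min(axis=1)              # nearest-neighbor distance per point
    median = np.median(nn_dist)
    mad = 1.4826 * np.median(np.abs(nn_dist - median))  # scaled to match the std for Gaussian data
    return np.flatnonzero(nn_dist > median + k * mad)

# A regular 5 x 5 grid (nearest-neighbor distance 1) plus one far-away point
grid = np.array([[x, y] for x in range(5) for y in range(5)], dtype=float)
data = np.vstack([grid, [[10.0, 10.0]]])
flagged = flag_isolated_points(data)
print(flagged)  # [25] -> the far point is the last row
```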

### Signal Similarity Analysis

In [None]:
# Generate sample signals
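# Sample exactly one period (endpoint=False avoids counting t = 2π twice)
t = np.linspace(0, 2*np.pi, 100, endpoint=False)
signal1 = np.sin(t)
signal2 = np.sin(t + np.pi/4)  # Phase-shifted by π/4
signal3 = np.cos(t)  # Orthogonal to sin(t) over a full period
signal4 = np.sin(2*t)  # Twice the frequency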

signals = [signal1, signal2, signal3, signal4]
signal_names = ['sin(t)', 'sin(t+π/4)', 'cos(t)', 'sin(2t)']

# Compute similarity matrix
n_signals = len(signals)
similarity_matrix = np.zeros((n_signals, n_signals))

for i in range(n_signals):
    for j in range(n_signals):
        similarity_matrix[i, j] = stx_linalg.cosine(signals[i], signals[j])

# Visualize
plt.figure(figsize=(15, 5))

# Plot 1: Original signals
plt.subplot(1, 3, 1)
for signal, name in zip(signals, signal_names):
    plt.plot(t, signal, label=name, linewidth=2)
plt.xlabel('Time')
plt.ylabel('Amplitude')
plt.title('Original Signals')
plt.legend()
plt.grid(True, alpha=0.3)

# Plot 2: Similarity matrix
plt.subplot(1, 3, 2)
im = plt.imshow(similarity_matrix, cmap='RdBu_r', vmin=-1, vmax=1)
plt.colorbar(im, label='Cosine Similarity')
plt.title('Signal Similarity Matrix')
plt.xticks(range(n_signals), signal_names, rotation=45)
plt.yticks(range(n_signals), signal_names)

# Add text annotations
for i in range(n_signals):
    for j in range(n_signals):
        plt.text(j, i, f'{similarity_matrix[i, j]:.2f}', 
                ha='center', va='center', fontsize=10)

# Plot 3: Distance relationships
plt.subplot(1, 3, 3)
signal_array = np.array(signals)
distance_matrix = stx_linalg.cdist(signal_array, signal_array)
plt.imshow(distance_matrix, cmap='viridis')
plt.colorbar(label='Euclidean Distance')
plt.title('Signal Distance Matrix')
plt.xticks(range(n_signals), signal_names, rotation=45)
plt.yticks(range(n_signals), signal_names)

plt.tight_layout()
plt.show()

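# Pick the most and least similar distinct pairs from the off-diagonal entries
off_diag = similarity_matrix.copy()
np.fill_diagonal(off_diag, np.nan)
i_max, j_max = np.unravel_index(np.nanargmax(off_diag), off_diag.shape)
i_min, j_min = np.unravel_index(np.nanargmin(off_diag), off_diag.shape)

print("\nSignal Analysis Summary:")
print(f"Most similar pair: {signal_names[i_max]} and {signal_names[j_max]} (cosine similarity: {off_diag[i_max, j_max]:.3f})")
print(f"Least similar pair: {signal_names[i_min]} and {signal_names[j_min]} (cosine similarity: {off_diag[i_min, j_min]:.3f})")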
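Over a whole number of periods, the cosine similarity between two equal-frequency sinusoids equals the cosine of their phase difference. That is why sin(t + π/4) scores about cos(π/4) ≈ 0.707 against sin(t), while cos(t) and sin(2t) both score about 0. The NumPy check below samples exactly one period (`endpoint=False`, so 2π is not counted twice). It then compares the measured similarity with cos φ for several phase shifts φ.

```python
import numpy as np

t = np.linspace(0, 2 * np.pi, 100, endpoint=False)  # exactly one period
reference = np.sin(t)

results = {}
for phi in [0.0, np.pi / 4, np.pi / 2, np.pi]:
    shifted = np.sin(t + phi)
    # Cosine similarity: dot product over the product of the norms
    results[phi] = np.dot(reference, shifted) / (np.linalg.norm(reference) * np.linalg.norm(shifted))
    print(f"phi = {phi:.3f}: cosine similarity = {results[phi]:+.3f}, cos(phi) = {np.cos(phi):+.3f}")
```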

## 8. Performance Comparison

Compare SciTeX implementations with their SciPy and scikit-learn counterparts, both for speed and for agreement of results.

In [None]:
import time
from scipy.spatial.distance import cdist as scipy_cdist
from sklearn.metrics.pairwise import cosine_similarity

# Generate test data
n_points = 1000
n_dims = 50
test_data = np.random.randn(n_points, n_dims)

print(f"Performance test with {n_points} points in {n_dims}D space")
print("=" * 50)

# Test distance computation
print("\n1. Distance Matrix Computation:")

# SciTeX implementation (time.perf_counter is a high-resolution timer suited to benchmarks)
start_time = time.perf_counter()
stx_distances = stx_linalg.cdist(test_data, test_data)
stx_time = time.perf_counter() - start_time

# SciPy implementation
start_time = time.perf_counter()
scipy_distances = scipy_cdist(test_data, test_data)
scipy_time = time.perf_counter() - start_time

print(f"SciTeX cdist: {stx_time:.4f} seconds")
print(f"SciPy cdist: {scipy_time:.4f} seconds")
print(f"Difference in results: {np.max(np.abs(stx_distances - scipy_distances)):.2e}")

# Test cosine similarity
print("\n2. Cosine Similarity:")

# Select two random vectors for comparison
vec1 = test_data[0]
vec2 = test_data[1]

# SciTeX implementation
start_time = time.perf_counter()
for _ in range(1000):  # Multiple iterations for timing
    stx_cosine = stx_linalg.cosine(vec1, vec2)
stx_cosine_time = time.perf_counter() - start_time

# Sklearn implementation
start_time = time.perf_counter()
for _ in range(1000):
    sklearn_cosine = cosine_similarity([vec1], [vec2])[0, 0]
sklearn_cosine_time = time.perf_counter() - start_time

print(f"SciTeX cosine (1000 iterations): {stx_cosine_time:.4f} seconds")
print(f"Sklearn cosine (1000 iterations): {sklearn_cosine_time:.4f} seconds")
print(f"SciTeX result: {stx_cosine:.6f}")
print(f"Sklearn result: {sklearn_cosine:.6f}")
print(f"Difference: {abs(stx_cosine - sklearn_cosine):.2e}")

print("\n3. Robust Statistics:")
print("SciTeX also provides features beyond the functions timed above:")
print("- NaN-aware cosine() and nannorm() (return NaN on NaN input)")
print("- Geometric median for outlier resistance")
print("- Vector rebasing utilities")
print("- Coordinate transformation tools")
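Wall-clock timings from a single run are noisy, especially for calls that take only microseconds. A common alternative is the standard-library `timeit.repeat`, reporting the best of several batches. The helper `best_time` below is a hypothetical sketch. Pass it any zero-argument callable, for example `lambda: stx_linalg.cdist(test_data, test_data)`.

```python
import timeit

import numpy as np

def best_time(fn, repeat=5, number=10):
    """Hypothetical helper: best average seconds per call of `fn` over `repeat` batches of `number` calls."""
    batch_times = timeit.repeat(fn, repeat=repeat, number=number)
    return min(batch_times) / number

rng = np.random.default_rng(0)
x = rng.standard_normal(50)
y = rng.standard_normal(50)

t_norm = best_time(lambda: np.linalg.norm(x - y), number=1000)
t_manual = best_time(lambda: np.sqrt(np.sum((x - y) ** 2)), number=1000)
print(f"np.linalg.norm: {t_norm * 1e6:.2f} us/call, manual formula: {t_manual * 1e6:.2f} us/call")
```

The minimum is reported because slower batches mostly reflect interference from other processes rather than the code being measured.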

## 9. Summary

The `scitex.linalg` module provides robust linear algebra utilities with the following key features:

### Core Functions
- **Distance calculations**: `euclidean_distance()`, `edist()`, `cdist()` 
- **Geometric median**: `geometric_median()` - robust center estimation
- **Similarity measures**: `cosine()` - cosine similarity with NaN handling
- **Robust norms**: `nannorm()` - norm computation with NaN handling
- **Vector operations**: `rebase_a_vec()` - vector projection and rebasing
- **Coordinate transforms**: `three_line_lengths_to_coords()` - triangle geometry

### Key Advantages
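1. **Robustness**: NaN-aware `cosine()` and `nannorm()` return NaN when inputs contain NaN
2. **Flexibility**: Multi-dimensional operations with axis control
3. **Interoperability**: Works with NumPy arrays (and PyTorch tensors for `geometric_median()`); Section 8 shows how to check speed and results against SciPy and scikit-learn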
4. **Outlier resistance**: Geometric median for robust statistics
5. **Specialized tools**: Coordinate transformations and vector rebasing

### Use Cases
- **Data analysis**: Clustering, similarity analysis, outlier detection
- **Signal processing**: Signal similarity, distance-based analysis
- **Geometry**: Coordinate transformations, triangle calculations
- **Machine learning**: Feature similarity, robust statistics
- **Scientific computing**: Robust numerical operations

The module is designed for scientific applications where robustness and reliability are critical.