In [1]:
import os
import json
import time

import numpy as np
import xarray as xr
from scipy.linalg import qr, svd
from memory_profiler import memory_usage


In [2]:
def run_streaming_svd(nc_path: str,
                      var_name: str = 'msl',
                      k: int = None,
                      ff: float = 1.0,
                      batch_size: int = 100,
                      output_file: str = None,
                      **mem_kwargs) -> dict:
    """
    Unified Streaming SVD driver with single‑pass profiling and correct reconstruction.
    """
    # Load & reshape data
    ds = xr.open_dataset(nc_path)
    data = ds[var_name].values        # (time, lat, lon)
    nt, ny, nx = data.shape
    A = data.reshape(nt, ny*nx).T     # (m, n)
    m, n = A.shape

    # Determine k
    if k is None:
        k = min(m, n)
    else:
        k = min(k, m, n)

    # Streaming compute task: only builds left modes and singular values
    def compute_task():
        # Initial batch
        B0 = A[:, :batch_size]
        Q, R = qr(B0, mode='economic')
        U, S, _ = svd(R, full_matrices=False)
        modes = Q @ U[:, :k]
        sing  = S[:k]
        # Iterate over remaining batches
        for start in range(batch_size, n, batch_size):
            B = A[:, start:start+batch_size]
            weighted = ff * (modes @ np.diag(sing))
            concat   = np.concatenate((weighted, B), axis=1)
            Q, R    = qr(concat, mode='economic')
            U, S, _ = svd(R, full_matrices=False)
            idx      = np.argsort(S)[::-1][:k]
            sing     = S[idx]
            modes    = Q @ U[:, idx]
        return modes, sing

    # Profile and capture modes, singular values in one pass
    t0 = time.time()
    peak_mem, (modes, sing) = memory_usage(
        (compute_task, (), {}),
        retval=True,
        max_usage=True,
        **mem_kwargs
    )
    elapsed = time.time() - t0

    # Reconstruct via projection: A_rec = modes @ (modes^T @ A)
    proj = modes.T @ A                  # shape (k, n)
    A_rec = modes @ proj                # shape (m, n)
    recon_err = np.linalg.norm(A - A_rec, ord='fro')

    # Energy captured
    total_energy = np.sum((sing)**2)
    energy = float(np.linalg.norm(A_rec, ord='fro')**2 / np.linalg.norm(A, ord='fro')**2)


    # Condition numbers
    cond_full    = float(sing[0] / sing[-1])
    cond_trunc   = cond_full

    # Package results
    results = {
        'method': 'Streaming SVD',
        'dataset': os.path.basename(nc_path),
        'shape':    (m, n),
        'k':        k,
        'elapsed_time_s': float(elapsed),
        'peak_memory_MiB': float(peak_mem),
        'reconstruction_error': float(recon_err),
        'energy_captured':      float(energy),
        'cond_full':            cond_full,
        'cond_trunc':           cond_trunc
    }

    # Print unified report
    print(f"=== {results['method']} on {results['dataset']} (m={m}, n={n}, k={k}) ===")
    for key, val in results.items():
        if key not in ('method', 'dataset', 'shape'):
            print(f"{key.replace('_',' ').capitalize():<20}: {val}")

    # Append to JSONL
    if output_file:
        os.makedirs(os.path.dirname(output_file) or '.', exist_ok=True)
        safe = {kk: (vv.item() if hasattr(vv, 'item') else vv) for kk, vv in results.items()}
        with open(output_file, 'a') as f:
            f.write(json.dumps(safe) + '\n')

    return results

In [3]:
res_stream = run_streaming_svd('slp.nc', var_name='msl', k=10, ff=0.95, batch_size=200, output_file='svd_results/streaming_svd_k10_ff95.jsonl', multiprocess=True)

=== Streaming SVD on slp.nc (m=16261, n=16071, k=10) ===
K                   : 10
Elapsed time s      : 8.098215103149414
Peak memory mib     : 1240.21875
Reconstruction error: 2081826.875
Energy captured     : 1.0020482540130615
Cond full           : 2385.95849609375
Cond trunc          : 2385.95849609375
