In [2]:
import os
import json
import time

import numpy as np
import xarray as xr
from sklearn.utils.extmath import randomized_svd
from memory_profiler import memory_usage

In [None]:
def run_randomized_svd(nc_path: str,
                       var_name: str = 'msl',
                       k: int = 50,
                       n_iter: int = 5,
                       random_state: int = 0,
                       output_file: str = None,
                       **mem_kwargs) -> dict:
    """
    Unified Randomized SVD driver with single‑pass profiling.

    Parameters
    ----------
    nc_path      : Path to NetCDF ('slp.nc' or 't2m.nc').
    var_name     : Variable name ('msl' or 't2m').
    k            : Target number of modes.
    n_iter       : Number of power iterations.
    random_state : Seed for reproducibility.
    output_file  : Optional JSONL path.
    **mem_kwargs : Extra args for memory_usage.

    Returns
    -------
    dict with keys:
      'method','dataset','shape','k','elapsed_time_s','peak_memory_MiB',
      'reconstruction_error','energy_captured','cond_full','cond_trunc'
    """
    # Load & reshape data
    ds = xr.open_dataset(nc_path)
    data = ds[var_name].values       # (time, lat, lon)
    nt, ny, nx = data.shape
    A = data.reshape(nt, ny*nx).T    # (m, n)
    m, n = A.shape

    # Truncation rank k
    k = min(k, m, n)

    # Define compute task
    def compute_task():
        U, S, Vt = randomized_svd(
            A, n_components=k,
            n_iter=n_iter,
            random_state=random_state)
        return U, S, Vt

    # Profile and capture U,S,Vt in one pass
    t0 = time.time()
    peak_mem, (U, S, Vt) = memory_usage(
        (compute_task, (), {}),
        retval=True,
        max_usage=True,
        **mem_kwargs
    )
    elapsed = time.time() - t0

    # Truncate to K modes (already sized), but safe-slice
    U_k  = U[:, :k]
    S_k  = S[:k]
    Vt_k = Vt[:k, :]

    # Reconstruction error
    A_rec = (U_k * S_k) @ Vt_k
    recon_err = np.linalg.norm(A - A_rec, ord='fro')

    # Energy captured
    total_energy = np.sum(S**2)
    energy = float(np.sum(S_k**2) / total_energy)

    # Condition numbers
    cond_full  = float(S[0] / S[-1]) if n > 0 else np.nan
    cond_trunc = float(S[0] / S_k[-1])

    # Package results
    results = {
        'method': 'Randomized SVD',
        'dataset': os.path.basename(nc_path),
        'shape': (m, n),
        'k': k,
        'elapsed_time_s': float(elapsed),
        'peak_memory_MiB': float(peak_mem),
        'reconstruction_error': float(recon_err),
        'energy_captured': float(energy),
        'cond_full': cond_full,
        'cond_trunc': cond_trunc
    }

    # Print unified report
    print(f"=== {results['method']} on {results['dataset']} (m={m}, n={n}, k={k}) ===")
    for key, val in results.items():
        if key not in ('method', 'dataset', 'shape'):
            print(f"{key.replace('_',' ').capitalize():<20}: {val}")

    # Append to JSONL if requested
    if output_file:
        os.makedirs(os.path.dirname(output_file) or '.', exist_ok=True)
        safe = {kk: (vv.item() if hasattr(vv,'item') else vv)
                for kk, vv in results.items()}
        with open(output_file, 'a') as f:
            f.write(json.dumps(safe) + '\n')

    return results 

In [9]:
res_rand = run_randomized_svd('slp.nc', var_name='msl', k=10, n_iter=5, random_state=0, output_file='svd_results/randomized_svd_k10.jsonl', multiprocess=True)



=== Randomized SVD on slp.nc (m=16261, n=16071, k=10) ===
K                   : 10
Elapsed time s      : 2.073974847793579
Peak memory mib     : 1251.4140625
Reconstruction error: 2068759.25
Energy captured     : 1.0
Cond full           : 2293.0693359375
Cond trunc          : 2293.0693359375
