In [1]:
import os
import json
import time

import numpy as np
import xarray as xr
from sklearn.decomposition import TruncatedSVD
from memory_profiler import memory_usage

In [2]:
def run_truncated_svd(nc_path: str,
                      var_name: str = 'msl',
                      k: int = 50,
                      algorithm: str = 'randomized',
                      n_iter: int = 5,
                      random_state: int = 0,
                      output_file: str = None,
                      **mem_kwargs) -> dict:
    """
    Unified Truncated SVD driver with single-pass profiling.

    Parameters
    ----------
    nc_path      : Path to NetCDF ('slp.nc' or 't2m.nc').
    var_name     : Variable name ('msl' or 't2m').
    k            : Number of components.
    algorithm    : 'randomized' or 'arpack'.
    n_iter       : Number of iterations (if randomized).
    random_state : Seed for reproducibility.
    output_file  : Optional JSONL path.
    **mem_kwargs : Extra args for memory_usage.

    Returns
    -------
    dict with keys:
      'method','dataset','shape','k','elapsed_time_s','peak_memory_MiB',
      'reconstruction_error','energy_captured','cond_full','cond_trunc'
    """
    # Load & reshape
    ds = xr.open_dataset(nc_path)
    data = ds[var_name].values       # (time, lat, lon)
    nt, ny, nx = data.shape
    A = data.reshape(nt, ny*nx).T    # (m, n)
    m, n = A.shape

    # Bound k
    k = min(k, m, n)

    # Define compute task
    def compute_task():
        model = TruncatedSVD(
            n_components=k,
            algorithm=algorithm,
            n_iter=n_iter,
            random_state=random_state)
        # Fit and get factors
        US = model.fit_transform(A)   # (m, k)
        S  = model.singular_values_   # (k,)
        VT = model.components_        # (k, n)
        return US, S, VT

    # Profile and capture US, S, VT in one pass
    t0 = time.time()
    peak_mem, (US, S, VT) = memory_usage(
        (compute_task, (), {}),
        retval=True,
        max_usage=True,
        **mem_kwargs
    )
    elapsed = time.time() - t0

    # Truncate to K (US already truncated)
    US_k = US
    S_k  = S
    VT_k = VT

    # Reconstruction error
    A_rec = US_k @ VT_k
    recon_err = np.linalg.norm(A - A_rec, ord='fro')

    # Energy captured
    total_energy = np.linalg.norm(A, ord='fro')**2
    energy = float(np.sum(S_k**2) / total_energy)

    # Condition numbers
    # full SVD would have S_full[-1] possibly < S_k[-1], but we compare within truncation
    cond_full  = float(S_k[0] / S_k[-1])
    cond_trunc = cond_full

    # Package results
    results = {
        'method': 'Truncated SVD',
        'dataset': os.path.basename(nc_path),
        'shape': (m, n),
        'k': k,
        'elapsed_time_s': float(elapsed),
        'peak_memory_MiB': float(peak_mem),
        'reconstruction_error': float(recon_err),
        'energy_captured': float(energy),
        'cond_full': 'none',
        'cond_trunc': cond_trunc
    }

    # Print unified report
    print(f"=== {results['method']} on {results['dataset']} (m={m}, n={n}, k={k}) ===")
    for key, val in results.items():
        if key not in ('method', 'dataset', 'shape'):
            print(f"{key.replace('_',' ').capitalize():<20}: {val}")

    # Append to JSONL
    if output_file:
        os.makedirs(os.path.dirname(output_file) or '.', exist_ok=True)
        safe = {kk: (vv.item() if hasattr(vv,'item') else vv)
                for kk, vv in results.items()}
        with open(output_file, 'a') as f:
            f.write(json.dumps(safe) + '\n')

    return results

In [3]:
res_trunc = run_truncated_svd('slp.nc', var_name='msl', k=10, algorithm='randomized', output_file='svd_results/truncated_svd_k10.jsonl', multiprocess=True)

=== Truncated SVD on slp.nc (m=16261, n=16071, k=10) ===
K                   : 10
Elapsed time s      : 3.3546559810638428
Peak memory mib     : 2241.7421875
Reconstruction error: 2068758.25
Energy captured     : 1.089472770690918
Cond full           : none
Cond trunc          : 2293.0693359375
