In [1]:
import os
import json
import numpy as np
from numpy.linalg import svd
import xarray as xr
import time
from memory_profiler import memory_usage


In [2]:
# %%
import os
import json
import time
import numpy as np
import xarray as xr
from numpy.linalg import svd
from memory_profiler import memory_usage

# %%
def run_full_svd(path: str,
                 var: str = 'msl',
                 k: int = None,
                 output_file: str = None,
                 **mem_kwargs) -> dict:
    """
    Benchmark a full SVD (full_matrices=True) of one NetCDF variable.

    The variable is read into memory, flattened to a 2-D matrix whose columns
    are indexed by the variable's first axis, decomposed with
    ``numpy.linalg.svd`` under ``memory_usage`` profiling, and then scored on
    its rank-``k`` truncation. A report is printed and, optionally, appended
    to a JSONL file.

    Parameters
    ----------
    path         : Path to NetCDF file (e.g., 'slp.nc').
    var          : Variable name in dataset.
    k            : Number of retained modes, 1 <= k <= min(m, n).
                   Defaults to full rank.
    output_file  : Optional JSONL file; one JSON record is appended per call
                   and missing parent directories are created.
    **mem_kwargs : Extra args to memory_usage.

    Returns
    -------
    dict with benchmarking results, keyed by:
        'method', 'dataset', 'shape', 'k', 'elapsed_time_s',
        'peak_memory_MiB', 'reconstruction_error' (Frobenius norm of
        A - A_k), 'energy_captured' (sum of the k leading squared singular
        values over the sum of all of them), 'cond_full' and 'cond_trunc'
        (largest / smallest singular value, overall and among the first k).

    Raises
    ------
    ValueError : if the flattened matrix has a zero-length dimension, or if
                 ``k`` lies outside [1, min(m, n)]. Both are checked before
                 the decomposition so a bad argument never costs an SVD run.
    """
    # Load the variable and release the file handle right away: `.values`
    # yields an in-memory array, so nothing below needs the open dataset.
    with xr.open_dataset(path) as ds:
        data = ds[var].values             # e.g. (time, lat, lon)
    A = data.reshape(data.shape[0], -1).T  # (m, n): n = length of first axis
    m, n = A.shape

    # Validate the truncation rank *before* the expensive decomposition.
    max_rank = min(m, n)
    if max_rank == 0:
        raise ValueError(
            f"variable {var!r} yields an empty {m}x{n} matrix; nothing to decompose")
    if k is None:
        k = max_rank                      # default to full rank
    elif not 1 <= k <= max_rank:
        raise ValueError(
            f"k must satisfy 1 <= k <= min(m, n) = {max_rank}; got {k}")

    # Define SVD compute task
    def compute_task():
        return svd(A, full_matrices=True)

    # Profile time & memory. perf_counter is monotonic, so the interval is
    # immune to system clock adjustments; it spans the whole memory_usage
    # call and therefore includes the profiler's own overhead.
    t0 = time.perf_counter()
    peak_mem, (U, S, VT) = memory_usage(
        (compute_task, (), {}),
        retval=True,
        max_usage=True,
        **mem_kwargs)
    elapsed = time.perf_counter() - t0

    # Truncate to the k leading modes
    U_k = U[:, :k]
    S_k = S[:k]
    VT_k = VT[:k, :]

    # Reconstruction error: scaling the columns of U_k by S_k is equivalent
    # to U_k @ diag(S_k) without building the diagonal matrix.
    A_rec = (U_k * S_k) @ VT_k
    recon = np.linalg.norm(A - A_rec, ord='fro')

    # Energy: fraction of the squared singular-value mass retained
    energy = float(np.sum(S_k**2) / np.sum(S**2))

    # Condition numbers (singular values are returned in descending order)
    cond_full = float(S[0] / S[-1])
    cond_trunc = float(S_k[0] / S_k[-1])

    # Pack results
    results = {
        'method': 'Full SVD',
        'dataset': os.path.basename(path),
        'shape': (m, n),
        'k': k,
        'elapsed_time_s': float(elapsed),
        'peak_memory_MiB': float(peak_mem),
        'reconstruction_error': float(recon),
        'energy_captured': float(energy),
        'cond_full': cond_full,
        'cond_trunc': cond_trunc
    }

    # Print report (identity fields are already shown in the header line)
    print(f"=== {results['method']} on {results['dataset']} (m={m}, n={n}, k={k}) ===")
    for key, val in results.items():
        if key not in ('method', 'dataset', 'shape'):
            print(f"{key.replace('_',' ').capitalize():<22}: {val}")

    # Save: append one JSON line, unwrapping NumPy scalars via .item()
    if output_file:
        os.makedirs(os.path.dirname(output_file) or '.', exist_ok=True)
        safe = {kk: (vv.item() if hasattr(vv, 'item') else vv) for kk, vv in results.items()}
        with open(output_file, 'a') as f:
            f.write(json.dumps(safe) + '\n')

    return results

In [3]:
# Benchmark the full SVD on the 'msl' variable of slp.nc, keeping the 10
# leading modes and appending the metrics to a JSONL log.
res_slp = run_full_svd(
    'slp.nc',
    var='msl',
    k=10,
    output_file='svd_results/full_svd_k10.jsonl',
)

=== Full SVD on slp.nc (m=16261, n=16071, k=10) ===
K                     : 10
Elapsed time s        : 1107.3296983242035
Peak memory mib       : 17780.96875
Reconstruction error  : 2068757.375
Energy captured       : 0.9999984502792358
Cond full             : 3680575488.0
Cond trunc            : 2293.06982421875
