In [1]:
import os
import sys
import types

# --- MPI Dummy Patch for PyParSVD ---
mpi_mod = types.ModuleType("mpi4py.MPI")
class CommDummy:
    def Get_rank(self): return 0
    def Get_size(self): return 1
mpi_mod.COMM_WORLD = CommDummy()
sys.modules['mpi4py.MPI'] = mpi_mod

# Add local PyParSVD repo to path
# Adjust the relative path 'PyParSVD' if your folder has a different name or location
sys.path.insert(0, os.path.join(os.getcwd(), 'PyParSVD'))

import json
import time

import numpy as np
import xarray as xr
from pyparsvd.parsvd_serial import ParSVD_Serial
from pyparsvd.parsvd_parallel import ParSVD_Parallel
from memory_profiler import memory_usage  # local import from cloned repo
from memory_profiler import memory_usage

In [2]:
import os
import sys
import types

# --- MPI Dummy Patch for PyParSVD ---
mpi_mod = types.ModuleType("mpi4py.MPI")
class CommDummy:
    def Get_rank(self): return 0
    def Get_size(self): return 1
mpi_mod.COMM_WORLD = CommDummy()
sys.modules['mpi4py.MPI'] = mpi_mod

# Add local PyParSVD repo to path (adjust 'PyParSVD' if needed)
sys.path.insert(0, os.path.join(os.getcwd(), 'PyParSVD'))

import json
import time

import numpy as np
import xarray as xr
from pyparsvd.parsvd_serial import ParSVD_Serial
from pyparsvd.parsvd_parallel import ParSVD_Parallel
from memory_profiler import memory_usage

# %%

def run_pyparsvd(nc_path: str,
                 var_name: str = 'msl',
                 k: int = 200,
                 ff: float = 0.95,
                 batch_size: int = 100,
                 algorithm: str = 'serial',
                 output_file: str = None,
                 **mem_kwargs) -> dict:
    """
    Unified PyParSVD driver using local repo.

    Parameters
    ----------
    nc_path     : Path to NetCDF ('slp.nc' or 't2m.nc').
    var_name    : Variable name ('msl' or 't2m').
    k           : Number of singular modes.
    ff          : Forget factor for streaming.
    batch_size  : Number of columns per streaming batch.
    algorithm   : 'serial' or 'parallel'.
    output_file : Optional JSONL path to append results.
    **mem_kwargs: Extra args for memory_usage.

    Returns
    -------
    dict with keys:
      'method', 'dataset', 'shape', 'k', 'batch_size', 'elapsed_time_s',
      'peak_memory_MiB', 'reconstruction_error', 'energy_captured',
      'cond_full', 'cond_trunc'
    """
    # Load & reshape data: (grid_points, time)
    ds = xr.open_dataset(nc_path)
    data = ds[var_name].values  # (time, lat, lon)
    nt, ny, nx = data.shape
    A = data.reshape(nt, ny*nx).T  # (m, n)
    m, n = A.shape

    # Bound k and batch_size
    k = min(k, m, n)
    batch_size = min(batch_size, n)

    # Initialize model
    if algorithm.lower() == 'serial':
        model = ParSVD_Serial(K=k, ff=ff)
    else:
        model = ParSVD_Parallel(K=k, ff=ff)

    # Define compute task
    def compute_task():
        # Initialize with first batch
        model.initialize(A[:, :batch_size])
        # Stream remaining batches
        for start in range(batch_size, n, batch_size):
            end = min(start + batch_size, n)
            model.incorporate_data(A[:, start:end])
        # Retrieve modes and singular values
        modes = model._modes               # (m, k)
        sing_vals = model._singular_values  # (k,)
        # Approximate final Vᵀ from projection: Vᵀ ≈ diag(1/σ) · Uᵀ · A
        Vt = (modes.T @ A) / sing_vals[:, None]
        return modes, sing_vals, Vt

    # Profile runtime & memory once
    t0 = time.time()
    peak_mem, (modes, sing_vals, Vt) = memory_usage(
        (compute_task, (), {}),
        retval=True,
        max_usage=True,
        **mem_kwargs
    )
    elapsed = time.time() - t0

    # Reconstruction error
    A_rec = modes @ np.diag(sing_vals) @ Vt
    recon_err = np.linalg.norm(A - A_rec, ord='fro')

    # Energy captured
    total_energy = np.linalg.norm(A, ord='fro')**2
    energy = float(np.sum(sing_vals**2) / total_energy)

    # Condition numbers
    cond_full = None
    cond_trunc = float(sing_vals[0] / sing_vals[-1]) if k > 1 else np.nan

    # Package results
    results = {
        'method': f'PyParSVD ({algorithm})',
        'dataset': os.path.basename(nc_path),
        'shape': (m, n),
        'k': k,
        'batch_size': batch_size,
        'elapsed_time_s': float(elapsed),
        'peak_memory_MiB': float(peak_mem),
        'reconstruction_error': float(recon_err),
        'energy_captured': float(energy),
        'cond_full': cond_full,
        'cond_trunc': cond_trunc
    }

    # Print unified report
    print(f"=== {results['method']} on {results['dataset']}"
          f" (m={m}, n={n}, k={k}, batch_size={batch_size}) ===")
    for key, val in results.items():
        if key not in ('method', 'dataset', 'shape'):
            print(f"{key.replace('_',' ').capitalize():<22}: {val}")

    # Append to JSONL
    if output_file:
        os.makedirs(os.path.dirname(output_file) or '.', exist_ok=True)
        safe = {kk: (vv.item() if hasattr(vv, 'item') else vv)
                for kk, vv in results.items()}
        with open(output_file, 'a') as f:
            f.write(json.dumps(safe) + '\n')

    return results

In [3]:
res_serial = run_pyparsvd(
     'slp.nc', var_name='msl', k=10, ff=0.95,
     batch_size=100, algorithm='serial',
     output_file='svd_results/pyparsvd_k10_ff95.jsonl',
     multiprocess=True)


=== PyParSVD (serial) on slp.nc (m=16261, n=16071, k=10, batch_size=100) ===
K                     : 10
Batch size            : 100
Elapsed time s        : 13.557955980300903
Peak memory mib       : 1253.87890625
Reconstruction error  : 2098938.5
Energy captured       : 0.0675780177116394
Cond full             : None
Cond trunc            : 2414.977783203125
