## CP_APR & CP_ALS Profiling

In [None]:
import cProfile
import glob
import os
import pstats

from pyttb import cp_als, cp_apr, import_data, ktensor

In [None]:
def profile(test_files, ranks, algorithm):
    """
    Profile cp_apr or cp_als on every combination of test tensor and rank.

    Each file is loaded with ``import_data`` (a ktensor is converted with
    ``full()``), then the selected algorithm is run once per rank under a fresh
    ``cProfile.Profile``. After each successful run, the 10 entries with the
    highest cumulative time are printed.

    Failures are reported on stdout rather than raised: a file that cannot be
    loaded is skipped entirely, and a run that raises for one rank does not
    prevent the remaining ranks of the same file from being profiled.

    Parameters
    ----------
    test_files:
        A list of strings representing the file paths to the test tensors.
    ranks:
        A list of integers representing the tensor testing ranks.
    algorithm:
        The algorithm to profile. Must be either 'cp_apr' or 'cp_als'.

    Raises
    ------
    ValueError
        If ``algorithm`` is neither 'cp_apr' nor 'cp_als'.
    """

    # Reject unknown names up front; silently falling back to cp_als would
    # label the printed statistics with the wrong algorithm.
    if algorithm not in ("cp_apr", "cp_als"):
        raise ValueError(
            f"algorithm must be 'cp_apr' or 'cp_als', got {algorithm!r}"
        )
    alg_func = cp_apr if algorithm == "cp_apr" else cp_als

    for test_file in test_files:
        print("*" * 50)

        # Loading is handled separately from the runs: no rank is in play yet,
        # so the error message must not refer to one.
        try:
            tensor = import_data(test_file)
            # cp_als() and cp_apr() only handles tensors and sptensors, so convert ktensor to tensor
            if isinstance(tensor, ktensor):
                tensor = tensor.full()
        except Exception as e:
            print(
                f"Error when loading {os.path.basename(test_file)}: {type(e).__name__}: {e}"
            )
            continue

        for rank in ranks:
            # a fresh profiler per run keeps the statistics of each rank separate
            profiler = cProfile.Profile()
            try:
                profiler.enable()
                try:
                    alg_func(tensor, rank)
                # ensure the profiler is always disabled before the next profiler starts
                finally:
                    profiler.disable()
            except Exception as e:
                print(
                    f"Error when testing {os.path.basename(test_file)} with Rank = {rank} and Algorithm = {algorithm}: {type(e).__name__}: {e}"
                )
                # no statistics are printed for a failed run; try the next rank
                continue

            # sort the statistics based on cumulative time spent on funcs and sub-funcs
            stats = pstats.Stats(profiler).sort_stats("cumulative")
            print(f"Test file: {test_file}, Rank: {rank}, Algorithm: {algorithm}")
            stats.print_stats(10)

In [None]:
# Ranks at which every test tensor is decomposed.
ranks = [2, 3, 4]
# glob returns matches in arbitrary order; sort so that the tensors are
# profiled (and their statistics printed) in a reproducible order.
test_files = sorted(glob.glob("data/*.tns"))

In [None]:
# Approximate runtimes: ~5m30s for cp_apr, ~9s for cp_als.
# To profile cp_als instead, pass algorithm="cp_als".
profile(test_files=test_files, ranks=ranks, algorithm="cp_apr")