import numpy as np
import time
from typing import TextIO, Tuple
import pyttb as ttb

# Generate test data: a random 300 x 300 x 300 sparse tensor.  The
# 'density' argument of sptenrand (0.01) targets 1% nonzero entries.
file_sparse = "test.sptensor.txt"
dat_sparse = ttb.sptenrand([300, 300, 300], 0.01)

# Write the tensor to disk, then release the in-memory copy so the
# benchmark below works only from the exported file.
ttb.export_data(dat_sparse, file_sparse)
del dat_sparse


# Define the original line-by-line reading approach
def import_sparse_array(
    fp: TextIO, n: int, nz: int, index_base: int = 1
) -> Tuple[np.ndarray, np.ndarray]:
    """Extract sparse data subs and vals from coordinate format data.

    Reads exactly ``nz`` lines from ``fp``, one nonzero per line, each
    holding ``n`` indices followed by a value, separated by whitespace.

    Parameters
    ----------
    fp: Open text stream positioned at the first data line.
    n: Number of index columns on each line.
    nz: Number of data lines to read.
    index_base: Offset subtracted from every index (1 for 1-based files).

    Returns
    -------
    subs: ``(nz, n)`` int64 array of shifted indices.
    vals: ``(nz, 1)`` float64 array of values.

    Raises
    ------
    ValueError: If the stream ends before ``nz`` lines were read, or a
        line does not contain exactly ``n + 1`` fields.
    """
    subs = np.zeros((nz, n), dtype="int64")
    vals = np.zeros((nz, 1))
    for k in range(nz):
        line = fp.readline()
        if not line:
            # readline() returns "" only at end of stream.
            raise ValueError(
                f"Expected {nz} data lines but the stream ended after {k}."
            )
        # split() with no argument tolerates repeated spaces and tabs;
        # split(" ") would turn them into empty tokens that fail to parse.
        fields = line.split()
        if len(fields) != n + 1:
            raise ValueError(
                f"Data line {k}: expected {n + 1} fields, got {len(fields)}."
            )
        subs[k, :] = [int(i) - index_base for i in fields[:-1]]
        vals[k, 0] = float(fields[-1])
    return subs, vals

# Define the new NumPy-based vectorized approach
def import_sparse_array_numpy(
    fp: TextIO, index_base: int = 1
) -> Tuple[np.ndarray, np.ndarray]:
    """Extract sparse data subs and vals from coordinate format data.

    Vectorized counterpart of the line-by-line reader.  It parses every
    remaining line of ``fp`` with a single ``np.loadtxt`` call.  Each line
    holds the indices of one nonzero followed by its value.

    Parameters
    ----------
    fp: Open text stream positioned at the first data line.  Everything
        from the current position to the end of the stream is parsed.
    index_base: Offset subtracted from every index (1 for 1-based files).

    Returns
    -------
    subs: ``(nz, n)`` int64 array of shifted indices.
    vals: ``(nz, 1)`` float64 array of values.
    """
    # ndmin=2 keeps a single-nonzero file as a (1, n + 1) table.  Without
    # it, loadtxt returns a 1-D row and the 2-D slicing below raises.
    data = np.loadtxt(fp, ndmin=2)
    # Indices are parsed as float64 first; float64 represents integers
    # exactly up to 2**53, so the cast back to int64 is lossless there.
    subs = data[:, :-1].astype("int64") - index_base
    vals = data[:, -1].reshape(-1, 1)
    return subs, vals

# Read the file once, then time each parser on the same data section.
with open(file_sparse, "r") as fp:
    # Parse the header: a "sptensor" tag line, the number of index
    # columns, the shape, and the number of nonzeros.
    fp.readline()  # "sptensor"
    dim = int(fp.readline().strip())  # Dimension
    shape = list(map(int, fp.readline().split()))  # Shape
    nz = int(fp.readline().strip())  # Number of nonzero values
    index_base = 1  # Offset subtracted from every index in the file
    # Remember where the data lines begin so each timed run can rewind
    # straight to them instead of re-reading the header.
    data_start = fp.tell()

    # Measure time for the line-by-line approach.
    # perf_counter() is monotonic and high-resolution, unlike time.time(),
    # which makes it the appropriate clock for benchmarking.
    fp.seek(data_start)
    start_time = time.perf_counter()
    subs1, vals1 = import_sparse_array(fp, dim, nz, index_base)
    line_by_line_time = time.perf_counter() - start_time

    # Measure time for the NumPy read approach.
    fp.seek(data_start)
    start_time = time.perf_counter()
    subs2, vals2 = import_sparse_array_numpy(fp, index_base)
    numpy_time = time.perf_counter() - start_time

# Report timings and a preview of what each approach parsed.
print("Line-by-line approach:")
print(f"Time taken: {line_by_line_time:.6f} seconds")
print(f"First few subs:\n{subs1[:5]}")
print(f"First few vals:\n{vals1[:5]}")

print("\nFaster vectorized approach:")
print(f"Time taken: {numpy_time:.6f} seconds")
print(f"First few subs:\n{subs2[:5]}")
print(f"First few vals:\n{vals2[:5]}")

# Verify that both approaches produce the same results.  Use explicit
# checks instead of `assert` statements: `python -O` strips asserts, which
# would skip the comparison yet still print the success message below.
if not np.array_equal(subs1, subs2):
    raise AssertionError("Subs arrays do not match!")
if not np.array_equal(vals1, vals2):
    raise AssertionError("Vals arrays do not match!")
print("\nAll approaches produce identical results.")
