import numpy as np
import time
from typing import TextIO, Optional
import pyttb as ttb

# Build the benchmark input: a random 200 x 200 x 200 sparse tensor.
# The second argument (0.01) is the 'density' param, i.e. 1% nonzeros.
dat_sparse = ttb.sptenrand([200, 200, 200], 0.01)
shape, nz = dat_sparse.shape, dat_sparse.nnz
dim = len(shape)

# Output paths, one per export strategy being compared.
file_original, file_numpy = (
    "test_original.sptensor.txt",
    "test_numpy.sptensor.txt",
)

def export_sparse_array(fp: TextIO, A: ttb.sptensor, fmt_data: Optional[str] = "%.16e"):
    """Write the nonzeros of ``A`` to ``fp`` in coordinate format, entry by entry.

    Each output line holds one entry: its subscripts shifted to 1-based
    indexing and separated by single spaces, then a space, then the value
    formatted with ``fmt_data``.

    Args:
        fp: Open text file to write to; it is passed to NumPy's ``tofile``.
        A: Sparse tensor providing ``nnz``, ``subs`` and ``vals``.
        fmt_data: printf-style format for the values. A falsy value
            (``None`` or ``""``) selects ``"%.16e"``.
    """
    value_format = fmt_data or "%.16e"
    for row in range(A.nnz):
        # The package stores 0-based subscripts; the file format is 1-based.
        one_based = A.subs[row, :] + 1
        one_based.tofile(fp, sep=" ", format="%d")
        fp.write(" ")
        A.vals[row][0].tofile(fp, sep=" ", format=value_format)
        fp.write("\n")
        
def export_sparse_array_numpy(fp: TextIO, A, fmt_data: Optional[str] = "%.16e"):
    """Export sparse array data in coordinate format using NumPy.

    Writes one line per entry: the subscripts shifted to 1-based indexing,
    separated by single spaces, followed by a space and the value formatted
    with ``fmt_data``. All rows are written with a single ``np.savetxt`` call.

    Args:
        fp: Open file (or any target accepted by ``np.savetxt``) to write to.
        A: Object providing ``subs`` (2-D array of 0-based subscripts, one
            row per entry) and ``vals`` (2-D array whose first column holds
            the values).
        fmt_data: printf-style format for the values. A falsy value
            (``None`` or ``""``) selects ``"%.16e"``, matching
            ``export_sparse_array``.
    """
    if not fmt_data:
        fmt_data = "%.16e"
    # Nothing to write when there are no entries. Returning early also avoids
    # indexing a value column that may not exist.
    if A.subs.size == 0:
        return
    # 0-based indexing in package, 1-based indexing in file
    subs = A.subs + 1
    vals = A.vals[:, 0].reshape(-1, 1)
    # hstack promotes subscripts and values to one common dtype; the "%d"
    # fields below still print the subscripts as integers.
    data = np.hstack((subs, vals))
    row_fmt = " ".join(["%d"] * subs.shape[1] + [fmt_data])
    np.savetxt(fp, data, fmt=row_fmt)
        
# Emit the four-line preamble that precedes the tensor entries
def write_header(fp: TextIO, dim: int, shape: list, nnz: int):
    """Write the file header to ``fp``.

    The header consists of four lines: the literal ``sptensor`` tag, the
    number of dimensions, the space-separated extents of ``shape`` (no
    brackets or commas), and the number of nonzero values.

    Args:
        fp: Open text file to write to.
        dim: Number of dimensions.
        shape: Extent of each dimension.
        nnz: Number of nonzero values.
    """
    print("sptensor", file=fp)
    print(dim, file=fp)  # Dimension
    print(*shape, file=fp)  # Extents separated by single spaces
    print(nnz, file=fp)  # Number of nonzero values

# Compare the two approaches.
# Each file is opened in a ``with`` block so it is closed (and its buffered
# output flushed) even if the export raises. Only the export call itself is
# timed; writing the header and closing the file are excluded from both
# measurements. time.perf_counter() is used because it is the clock intended
# for measuring elapsed intervals.

# Measure time for original approach
with open(file_original, "w") as fp_original:  # File for original approach
    write_header(fp_original, dim, shape, nz)
    start_time = time.perf_counter()
    export_sparse_array(fp_original, dat_sparse)
    original_time = time.perf_counter() - start_time

# Measure time for NumPy approach
with open(file_numpy, "w") as fp_numpy:  # File for NumPy approach
    write_header(fp_numpy, dim, shape, nz)
    start_time = time.perf_counter()
    export_sparse_array_numpy(fp_numpy, dat_sparse)
    numpy_time = time.perf_counter() - start_time

# Compare results and timings: read each file back and show its timing
# together with the first five subscripts and values.
dat_original = ttb.import_data(file_original)
print("Original approach:")
print(f"Time taken: {original_time:.6f} seconds")
print(f"First few subs:\n{dat_original.subs[:5]}")
print(f"First few vals:\n{dat_original.vals[:5]}")

dat_numpy = ttb.import_data(file_numpy)
print("\nNumPy approach:")
print(f"Time taken: {numpy_time:.6f} seconds")
print(f"First few subs:\n{dat_numpy.subs[:5]}")
print(f"First few vals:\n{dat_numpy.vals[:5]}")

# Verify that both approaches produce the same results.
# Explicit raises are used instead of ``assert`` statements because asserts
# are removed when Python runs with -O, which would let the success message
# below print without any check having been performed.
if not np.array_equal(dat_original.subs, dat_numpy.subs):
    raise AssertionError("Indices do not match!")
if not np.array_equal(dat_original.vals, dat_numpy.vals):
    raise AssertionError("Values do not match!")
print("\nBoth approaches produce identical results.")
