import numpy as np
import pyedflib
import plotly.graph_objects as go
import pandas as pd
from gtda.time_series import SingleTakensEmbedding
from gtda.homology import VietorisRipsPersistence
from gtda.diagrams import PersistenceEntropy, Amplitude, NumberOfPoints, ComplexPolynomial, PersistenceLandscape, HeatKernel
from gtda.plotting import plot_point_cloud, plot_heatmap
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from gtda.pipeline import Pipeline # Could later be useful, not importable with current sklearn version



def read_edf_file(file_path):
    """
    Read an .edf file and return its first two signals as a pandas DataFrame.

    The file is assumed to store the EEG signal in channel 0 and the EMG signal
    in channel 1, both sampled at the rate of channel 0.

    Parameters
    ----------
    file_path : str
        Path of the .edf file to read.

    Returns
    -------
    pandas.DataFrame
        A ``'Time'`` column (seconds since the first sample) followed by one
        column per signal, named after the channel label stored in the file.
        Both signals are truncated to the length of the shorter one.
    """
    f = pyedflib.EdfReader(file_path)
    try:
        # Assuming the EEG channel is the first channel and EMG is the second channel
        eeg_signal = f.readSignal(0)
        emg_signal = f.readSignal(1)

        # Channel labels become the DataFrame column names
        signal_labels = f.getSignalLabels()
        eeg_channel_name = signal_labels[0]
        emg_channel_name = signal_labels[1]

        # Assuming both streams have the same frequency
        sample_frequency = f.getSampleFrequency(0)
    finally:
        # Release the file handle even when reading a channel fails
        f.close()

    n_samples = min(len(eeg_signal), len(emg_signal))
    # Sample i is taken at i / fs seconds
    time = np.arange(n_samples) / sample_frequency

    return pd.DataFrame({
        'Time': time,
        eeg_channel_name: eeg_signal[:n_samples],
        emg_channel_name: emg_signal[:n_samples],
    })

file = 'edf_293.edf'

data = read_edf_file(file)

# NOTE(review): these attributes only exist if the EDF channel labels are exactly
# "EEG" and "EMG" -- read_edf_file names the columns after the labels in the file.
x = data.Time
y = data.EEG

# Print the installed package versions. This replaces the IPython-only
# `!pip freeze`, which is a SyntaxError when the file runs as plain Python.
# The output is captured and printed so it also shows up in a notebook cell.
import subprocess
import sys

print(
    subprocess.run(
        [sys.executable, "-m", "pip", "freeze"],
        capture_output=True,
        text=True,
        check=False,
    ).stdout
)

# Labels
label_df = pd.read_csv("Data_293.csv")
# NOTE(review): the first row of the label column is skipped, and labels[i] is
# later used as the label of eeg_segments[i] -- confirm that this alignment is
# intended.
labels = [int(label) for label in label_df["NAPS_Numeric"].iloc[1:]]

# Examining EEG data

# For every distinct label, collect the indices of its first
# `max_indices_per_label` occurrences in `labels` (single pass over the list).
max_indices_per_label = 30
indices_dict = {}  # label -> list of segment indices, in increasing order
for index, value in enumerate(labels):
    bucket = indices_dict.setdefault(value, [])
    if len(bucket) < max_indices_per_label:
        bucket.append(index)

def segment_data(df, segment_size, step_size=2, eeg_column="EEG", emg_column="EMG"):
    """
    Split the recording into consecutive, non-overlapping segments.

    Parameters
    ----------
    df : pandas.DataFrame
        Recording with a ``'Time'`` column in seconds plus the EEG/EMG columns.
    segment_size : int
        Segment length in seconds.
    step_size : int or float, default 2
        Time between consecutive rows in milliseconds (2 ms corresponds to
        500 Hz). It converts ``segment_size`` into a number of rows.
    eeg_column, emg_column : str, default "EEG" / "EMG"
        Names of the columns to segment.

    Returns
    -------
    tuple of (list of list, list of list)
        EEG segments and EMG segments. One segment is produced for every full
        ``segment_size`` window within ``int(df["Time"].iloc[-1])`` seconds.
        An empty DataFrame yields two empty lists.
    """
    if df.empty:
        # No samples -> no segments (df["Time"].iloc[-1] would raise IndexError)
        return [], []

    n_segments = int(df["Time"].iloc[-1]) // segment_size
    rows_per_segment = int(segment_size * 1000 / step_size)
    eeg_segments = []
    emg_segments = []

    for i in range(n_segments):
        start_idx = int(i * segment_size * 1000 / step_size)
        end_idx = start_idx + rows_per_segment
        segment = df.iloc[start_idx:end_idx]
        eeg_segments.append(list(segment[eeg_column]))
        emg_segments.append(list(segment[emg_column]))

    return eeg_segments, emg_segments

# Cut the recording into 4-second windows (2 ms between consecutive rows)
segment_length = 4  # seconds
eeg_segments, emg_segments = segment_data(data, segment_length, step_size=2)

# Fixed Takens-embedding parameters (values chosen after the parameter search below)
embedding_dimension = 5
embedding_time_delay = 25
stride = 10

embedder_periodic = SingleTakensEmbedding(
    parameters_type="fixed",
    dimension=embedding_dimension,
    time_delay=embedding_time_delay,
    stride=stride,
    n_jobs=2,
)

## Examining segments with label 1

# Takens embedding of each selected label-1 segment, keyed by segment index
y_embedded = {
    segment_idx: embedder_periodic.fit_transform(eeg_segments[segment_idx])
    for segment_idx in indices_dict[1]
}

# Show the point cloud of the first selected label-1 segment
plot_point_cloud(y_embedded[indices_dict[1][0]])

# VietorisRipsPersistence works on batches of point clouds, so each cloud gets
# a leading sample axis: (n_points, n_dims) -> (1, n_points, n_dims)
for label_idx in indices_dict[1]:
    y_embedded[label_idx] = y_embedded[label_idx][np.newaxis]

# 0 - connected components, 1 - loops, 2 - voids
homology_dimensions = [0, 1, 2]

persistence = VietorisRipsPersistence(
    homology_dimensions=homology_dimensions, n_jobs=10
)

print("Persistence diagram for periodic signal")

# Persistence diagram of every segment, keyed by segment index
diagrams = {}
for label_idx in indices_dict[1]:
    diagrams[label_idx] = persistence.fit_transform_plot(y_embedded[label_idx])

### Tuning the embedding dimension and time delay

# Two techniques can determine these parameters automatically:
# - mutual information, to choose the time delay
# - false nearest neighbours, to choose the embedding dimension

# Search-based embedder: with parameters_type="search", the dimension and time
# delay given here are upper bounds for the values found per series
max_embedding_dimension = 30
max_time_delay = 30
stride = 5

embedder = SingleTakensEmbedding(
    parameters_type="search",
    dimension=max_embedding_dimension,
    time_delay=max_time_delay,
    stride=stride,
)

def fit_embedder(embedder: SingleTakensEmbedding, y: np.ndarray, verbose: bool=True) -> np.ndarray:
    """
    Fit ``embedder`` on the series ``y`` and return the resulting embedding.

    When ``verbose`` is true, also print the embedding's shape and the fitted
    ``dimension_`` / ``time_delay_`` attributes of the embedder.
    """
    embedding = embedder.fit_transform(y)
    if not verbose:
        return embedding

    print(f"Shape of embedded time series: {embedding.shape}")
    print(
        f"Optimal embedding dimension is {embedder.dimension_} and time delay is {embedder.time_delay_}"
    )
    return embedding

# Run the parameter search on a few arbitrary segments and compare the optima;
# only the embedding of the last one is kept
for sample_segment in (0, 100, 177, 1000):
    y_embedded = fit_embedder(embedder, eeg_segments[sample_segment])


# The optimal values are all similar (Use embedding dimension 5 and time delay 25)

# Project the last embedding onto 3 principal components for visualisation
from sklearn.decomposition import PCA
pca = PCA(n_components=3)
y_embedded_pca = pca.fit_transform(y_embedded)
plot_point_cloud(y_embedded_pca)

### Extracting features

# Summaries of each label-1 persistence diagram, keyed by segment index.
# Each transformer is refitted on every diagram individually.
# NOTE(review): grid-based outputs (e.g. HeatKernel) presumably derive their
# sampling range from the data they are fitted on, so per-diagram fitting may
# make them incomparable across segments -- consider fitting once on all diagrams.
PE = PersistenceEntropy()
AM = Amplitude()
NP = NumberOfPoints()
CP = ComplexPolynomial()
HK = HeatKernel()

persistence_entropy1 = {}
amplitude1 = {}
no_points1 = {}
complex_polynomial1 = {}
heatkernel1 = {}

for label_idx in indices_dict[1]:
    diagram = diagrams[label_idx]
    persistence_entropy1[label_idx] = PE.fit_transform(diagram)
    amplitude1[label_idx] = AM.fit_transform(diagram)
    no_points1[label_idx] = NP.fit_transform(diagram)
    complex_polynomial1[label_idx] = CP.fit_transform(diagram)
    # Use the HeatKernel transformer created above (PL is only defined much
    # later, so referencing it here raises NameError)
    heatkernel1[label_idx] = HK.fit_transform(diagram)

## Examining segments with label 3

# TODO: merge the per-label dictionaries into a single one
# Takens embedding of each selected label-3 segment, keyed by segment index
y_embedded = {
    segment_idx: embedder_periodic.fit_transform(eeg_segments[segment_idx])
    for segment_idx in indices_dict[3]
}

# Show the point cloud of the first selected label-3 segment
plot_point_cloud(y_embedded[indices_dict[3][0]])

# Add a leading sample axis: (n_points, n_dims) -> (1, n_points, n_dims)
for label_idx in indices_dict[3]:
    y_embedded[label_idx] = y_embedded[label_idx][np.newaxis]

# 0 - connected components, 1 - loops, 2 - voids
homology_dimensions = [0, 1, 2]

persistence = VietorisRipsPersistence(
    homology_dimensions=homology_dimensions, n_jobs=10
)

print("Persistence diagram for periodic signal")

# Persistence diagram of every segment, keyed by segment index
diagrams = {}
for label_idx in indices_dict[3]:
    diagrams[label_idx] = persistence.fit_transform_plot(y_embedded[label_idx])

### Extracting features

# Summaries of each label-3 persistence diagram, keyed by segment index
PE = PersistenceEntropy()
AM = Amplitude()
NP = NumberOfPoints()
CP = ComplexPolynomial()

persistence_entropy3 = {}
amplitude3 = {}
no_points3 = {}
complex_polynomial3 = {}

for label_idx in indices_dict[3]:
    diagram = diagrams[label_idx]
    persistence_entropy3[label_idx] = PE.fit_transform(diagram)
    amplitude3[label_idx] = AM.fit_transform(diagram)
    no_points3[label_idx] = NP.fit_transform(diagram)
    complex_polynomial3[label_idx] = CP.fit_transform(diagram)

## Examining segments with label 4

# TODO: merge the per-label dictionaries into a single one
# Takens embedding of each selected label-4 segment, keyed by segment index
y_embedded = {
    segment_idx: embedder_periodic.fit_transform(eeg_segments[segment_idx])
    for segment_idx in indices_dict[4]
}

# Show the point cloud of the first selected label-4 segment
plot_point_cloud(y_embedded[indices_dict[4][0]])

# Add a leading sample axis: (n_points, n_dims) -> (1, n_points, n_dims)
for label_idx in indices_dict[4]:
    y_embedded[label_idx] = y_embedded[label_idx][np.newaxis]

# 0 - connected components, 1 - loops, 2 - voids
homology_dimensions = [0, 1, 2]

persistence = VietorisRipsPersistence(
    homology_dimensions=homology_dimensions, n_jobs=10
)

print("Persistence diagram for periodic signal")

# Persistence diagram of every segment, keyed by segment index
diagrams = {}
for label_idx in indices_dict[4]:
    diagrams[label_idx] = persistence.fit_transform_plot(y_embedded[label_idx])

### Extracting features

# Summaries of each label-4 persistence diagram, keyed by segment index
PE = PersistenceEntropy()
AM = Amplitude()
NP = NumberOfPoints()
CP = ComplexPolynomial()

persistence_entropy4 = {}
amplitude4 = {}
no_points4 = {}
complex_polynomial4 = {}

for label_idx in indices_dict[4]:
    diagram = diagrams[label_idx]
    persistence_entropy4[label_idx] = PE.fit_transform(diagram)
    amplitude4[label_idx] = AM.fit_transform(diagram)
    no_points4[label_idx] = NP.fit_transform(diagram)
    complex_polynomial4[label_idx] = CP.fit_transform(diagram)

## Plotting the distribution of features for different labels

# Transpose per-segment feature vectors into one tuple per feature axis
def extract_coordinates(dictionary):
    """
    Return ``zip(*rows)``, where each row is the first row of one array in ``dictionary``.

    For values of shape ``(1, k)`` this yields ``k`` tuples. The i-th tuple holds
    feature i of every entry, in dictionary order, ready to unpack as e.g. x, y, z.
    """
    rows = []
    for arr in dictionary.values():
        rows.append(tuple(arr[0]))
    return zip(*rows)

### Persistence Entropy

# One value per homology dimension for each segment -> one 3D point per segment
x1, y1, z1 = extract_coordinates(persistence_entropy1)
x3, y3, z3 = extract_coordinates(persistence_entropy3)
x4, y4, z4 = extract_coordinates(persistence_entropy4)

fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')

# Overlay the three label groups, each with its own colour and marker
for group_x, group_y, group_z, colour, marker_style, group_name in (
    (x1, y1, z1, 'r', 'o', 'Label 1'),
    (x3, y3, z3, 'g', 's', 'Label 3'),
    (x4, y4, z4, 'b', '^', 'Label 4'),
):
    ax.scatter(group_x, group_y, group_z, c=colour, marker=marker_style, label=group_name)

ax.set_xlabel('dimension 0')
ax.set_ylabel('dimension 1')
ax.set_zlabel('dimension 2')
ax.legend()
plt.show()

### Amplitude (Persistence)

# One amplitude per homology dimension for each segment -> one 3D point per segment
x1, y1, z1 = extract_coordinates(amplitude1)
x3, y3, z3 = extract_coordinates(amplitude3)
x4, y4, z4 = extract_coordinates(amplitude4)

fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')

# Overlay the three label groups, each with its own colour and marker
for group_x, group_y, group_z, colour, marker_style, group_name in (
    (x1, y1, z1, 'r', 'o', 'Label 1'),
    (x3, y3, z3, 'g', 's', 'Label 3'),
    (x4, y4, z4, 'b', '^', 'Label 4'),
):
    ax.scatter(group_x, group_y, group_z, c=colour, marker=marker_style, label=group_name)

ax.set_xlabel('dimension 0')
ax.set_ylabel('dimension 1')
ax.set_zlabel('dimension 2')
ax.legend()
plt.show()

### Number of Points ("Betti numbers")

# Number of diagram points per homology dimension -> one 3D point per segment
x1, y1, z1 = extract_coordinates(no_points1)
x3, y3, z3 = extract_coordinates(no_points3)
x4, y4, z4 = extract_coordinates(no_points4)

fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')

# Overlay the three label groups, each with its own colour and marker
for group_x, group_y, group_z, colour, marker_style, group_name in (
    (x1, y1, z1, 'r', 'o', 'Label 1'),
    (x3, y3, z3, 'g', 's', 'Label 3'),
    (x4, y4, z4, 'b', '^', 'Label 4'),
):
    ax.scatter(group_x, group_y, group_z, c=colour, marker=marker_style, label=group_name)

ax.set_xlabel('dimension 0')
ax.set_ylabel('dimension 1')
ax.set_zlabel('dimension 2')
ax.legend()
plt.show()

### Complex Polynomials

def _leading_value_per_dimension(features, n_dims=3):
    """
    Pick one value per homology dimension from each ComplexPolynomial feature row.

    Each row holds several values per homology dimension (2 * n_coefficients
    according to the giotto-tda docs), so it cannot be unpacked directly into
    x, y, z. The row is split into ``n_dims`` equal blocks and the first value
    of each block is kept. Returns ``zip`` of the per-dimension tuples, like
    ``extract_coordinates``.

    NOTE(review): assumes the output is laid out dimension by dimension with the
    real part of the leading coefficient first in each block -- confirm.
    """
    picked = []
    for arr in features.values():
        row = np.asarray(arr[0])
        block = row.shape[0] // n_dims
        picked.append(tuple(row[d * block] for d in range(n_dims)))
    return zip(*picked)


# The feature dictionaries are complex_polynomial1/3/4 (built in the
# feature-extraction sections above)
x1, y1, z1 = _leading_value_per_dimension(complex_polynomial1)
x3, y3, z3 = _leading_value_per_dimension(complex_polynomial3)
x4, y4, z4 = _leading_value_per_dimension(complex_polynomial4)

# Create a 3D plot
fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')

# Plot each label group with its own colour and marker
ax.scatter(x1, y1, z1, c='r', marker='o', label='Label 1')
ax.scatter(x3, y3, z3, c='g', marker='s', label='Label 3')
ax.scatter(x4, y4, z4, c='b', marker='^', label='Label 4')

ax.set_xlabel('dimension 0')
ax.set_ylabel('dimension 1')
ax.set_zlabel('dimension 2')

ax.legend()

plt.show()

# Further pre- and post-processing:
# - Remove outliers beforehand? (artifact removal)
# - Remove noise from the persistence diagrams (holes with very low persistence)
# - Other preprocessing (PCA? wavelet decomposition? downsampling?)
# - Later, filter the persistence diagrams (gtda.diagrams.Filtering())

# Further representations of the diagrams:
# - PersistenceLandscape
# - BettiCurve
# - HeatKernel
# - PersistenceImage
# - Silhouette

## Plotting representations of the Persistence Diagrams of all data for each label

### Persistence Landscape

# At this point `diagrams` and `y_embedded` hold the label-4 segments, keyed by
# segment index. The keys are segment indices, not 0..n-1, so diagrams[0] etc.
# only exist if those segments carry this label.
# Each diagram came from a separate VietorisRipsPersistence call, so the
# diagrams have different numbers of points and cannot be stacked into one
# (n_samples, n_points, 3) batch. Recompute them from the matching embeddings
# (all of identical shape) in one call, so giotto-tda pads them to a common shape.
# NOTE(review): the section heading mentions "each label"; only the label-4
# segments are available here.
diagram_batch = persistence.fit_transform(
    np.concatenate([y_embedded[idx] for idx in diagrams])
)

PL = PersistenceLandscape()

persistence_landscape = PL.fit_transform(diagram_batch)

### Heat Kernel 

# Inspect the individual diagrams
for idx in diagrams.keys():
    print(diagrams[idx])

HK = HeatKernel()

# NOTE(review): this rebinds `heatkernel1`, which previously held the
# per-segment label-1 heat kernels -- consider a distinct name.
heatkernel1 = HK.fit_transform(diagram_batch)

# Toy data set: six 3D points, each tagged with one of three targets
data = {
    "x": [1, 2, 3, 4, 5, 6],
    "y": [10, 20, 30, 40, 50, 60],
    "z": [100, 200, 300, 400, 500, 600],
    "target": ["A", "B", "A", "C", "B", "C"]
}
df = pd.DataFrame(data)

# One point cloud per target, in order of first appearance
coordinate_columns = ["x", "y", "z"]
point_clouds = np.asarray([
    df.query("target == @shape")[coordinate_columns].values
    for shape in df["target"].unique()
])

persistence = VietorisRipsPersistence(
    homology_dimensions=homology_dimensions,
    metric="euclidean",
    collapse_edges=True,
    n_jobs=6,
)
persistence_diagrams = persistence.fit_transform(point_clouds)


from sklearn.ensemble import RandomForestClassifier
from sklearn.pipeline import make_union

# Select a variety of metrics to calculate amplitudes
metrics = [
    {"metric": metric}
    for metric in ["bottleneck", "wasserstein", "landscape", "persistence_image"]
]

# Concatenate to generate 3 + 3 + (4 x 3) = 18 topological features
feature_union = make_union(
    PersistenceEntropy(normalize=True),
    NumberOfPoints(n_jobs=-1),
    *[Amplitude(**metric, n_jobs=-1) for metric in metrics]
)

# One class label per toy diagram: point_clouds was built from
# df["target"].unique(), so the same order gives the matching labels. The EEG
# `labels` list has one entry per EEG segment and does not line up with these
# diagrams.
toy_labels = df["target"].unique()

pipe = Pipeline(
    [
        ("features", feature_union),
        ("rf", RandomForestClassifier(oob_score=True, random_state=42)),
    ]
)
pipe.fit(persistence_diagrams, toy_labels)
pipe["rf"].oob_score_


## Experimental

# If the labels were added to the EEG DataFrame as a `target` column, the query
# below would be a compact way to pull out one data set per label.

data = {
    "x": [1, 2, 3, 4, 5, 6],
    "y": [10, 20, 30, 40, 50, 60],
    "z": [100, 200, 300, 400, 500, 600],
    "target": ["A", "B", "A", "C", "B", "C"]
}
df = pd.DataFrame(data)

# One point cloud per distinct target, in order of first appearance
distinct_targets = df["target"].unique()
xyz_columns = ["x", "y", "z"]
point_clouds = np.asarray([
    df.query("target == @shape")[xyz_columns].values
    for shape in distinct_targets
])

# Inspect the first point of the last processed embedding
y_embedded[label_idx][0]