# In [None]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim

import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt

from frameworks.data.time_series import generate_time_series
from frameworks.tda.embedding import embedding_time_series
from frameworks.utils.plots import plot_run_chart, plot_persistence_diagram
from gtda.homology import VietorisRipsPersistence
from gtda.diagrams import Amplitude

figsize = 4  # per-panel size in inches: the figure is figsize*n wide and figsize tall

# Signal families to generate, plus the embedding / sliding-window settings
# handed to embedding_time_series below.
signal_types = [
    "random",
    "periodic",
    "quasi-periodic",
    "oscillatory",
    "log-periodic-power",
    "geometric_random_walk",
    "random_walk_critical",
]
length = 15
embedding_delay = 1
embedding_dimension = 2
sliding_window_size = 10
sliding_stride = 1

# Generate one series per signal type and draw each on its own panel.
signals = []
fig, axes = plt.subplots(
    1, len(signal_types), figsize=(figsize * len(signal_types), figsize)
)
for ax, signal_type in zip(axes, signal_types):
    t, signal = generate_time_series(
        length=length,
        signal_type=signal_type,
        snr=40,
        amplitude=1,
        frequency=10,
        amplitude_ratio=0.25,
        frequency_ratio=0.3,
        alpha=0.5,
        exponential_factor=-1,
        nonlinearity=0.5,
        critical_time=0.75,
    )
    signals.append(signal)
    plot_run_chart(ax, t, signal, label="Signal")

# Embed every generated series with the delay / dimension / window settings above.
embedded_signals = []
for series in signals:
    embedded_signals.append(
        embedding_time_series(
            series,
            embedding_delay,
            embedding_dimension,
            sliding_window_size,
            stride=sliding_stride,
        )
    )


class TopoAutoencoder(nn.Module):
    """Fully connected autoencoder: input_dim -> 128 -> 64 -> latent_dim and back.

    ``forward`` returns ``(reconstruction, latent_code)``. Every layer is an
    ``nn.Linear`` acting on the last axis, so leading (batch / sequence) axes
    pass through unchanged.
    """

    def __init__(self, input_dim, latent_dim):
        super().__init__()
        # Mirror-image stacks. Layer order is fixed, which also fixes the
        # ``encoder.<idx>`` / ``decoder.<idx>`` parameter names.
        self.encoder = self._mlp([input_dim, 128, 64, latent_dim])
        self.decoder = self._mlp([latent_dim, 64, 128, input_dim])

    @staticmethod
    def _mlp(widths):
        """Chain Linear layers between consecutive widths, with ReLU between (not after) them."""
        layers = []
        for idx, (fan_in, fan_out) in enumerate(zip(widths[:-1], widths[1:])):
            if idx:
                layers.append(nn.ReLU())
            layers.append(nn.Linear(fan_in, fan_out))
        return nn.Sequential(*layers)

    def forward(self, x):
        latent = self.encoder(x)
        return self.decoder(latent), latent


# class TopoAutoencoder(nn.Module):
#     def __init__(self, input_dim, latent_dim):
#         super(TopoAutoencoder, self).__init__()
#
#         # Encoder using LSTM
#         self.encoder = nn.LSTM(input_dim, latent_dim, batch_first=True)
#
#         # Decoder
#         self.decoder = nn.LSTM(latent_dim, input_dim, batch_first=True)
#
#     def forward(self, x):
#         _, (z, _) = self.encoder(x)  # Get last hidden state
#         x_recon, _ = self.decoder(z.unsqueeze(0))  # Decode latent state
#         return x_recon, z.squeeze(0)


def compute_persistence_diagram(data):
    """Compute a Vietoris-Rips persistence diagram (H0 and H1) for each sample.

    Parameters
    ----------
    data : iterable of array-like
        One entry per sample. A 2-D entry is treated as a point cloud of shape
        ``(n_points, n_coords)`` and used as is. A 1-D entry is treated as a
        sequence of scalars and turned into an ``(n_points, 1)`` point cloud.

    Returns
    -------
    list of numpy.ndarray
        One array per sample, as returned by
        ``VietorisRipsPersistence.fit_transform`` on a single-cloud batch, i.e.
        shape ``(1, n_diagram_points, 3)`` holding ``(birth, death, dim)``
        triples. Index ``[i][0]`` to get the diagram of sample ``i``.
    """
    # fit_transform refits on every call, so one transformer can serve all
    # samples instead of being rebuilt inside the loop.
    vr = VietorisRipsPersistence(homology_dimensions=[0, 1], n_jobs=-1)
    diagrams = []
    for sample in data:  # iterate over the batch
        cloud = np.asarray(sample)
        # Only scalar sequences need reshaping into a column of 1-D points.
        # Multi-dimensional clouds (e.g. delay embeddings) are kept as-is:
        # flattening them would replace their geometry with a 1-D cloud of
        # individual coordinates.
        if cloud.ndim == 1:
            cloud = cloud.reshape(-1, 1)
        diagrams.append(vr.fit_transform([cloud]))  # batch of one cloud
    return diagrams


def topological_loss(diagrams_x, diagrams_z):
    """Compare two sets of persistence diagrams through their Wasserstein amplitudes.

    ``Amplitude`` reduces each diagram to one value per homology dimension: its
    1-Wasserstein distance from the empty diagram. The loss is the mean
    absolute difference between the amplitudes of ``diagrams_x`` and
    ``diagrams_z``.

    Parameters
    ----------
    diagrams_x, diagrams_z : numpy.ndarray
        Persistence diagrams in giotto-tda format ``(n_samples, n_points, 3)``.
        Their amplitude arrays must broadcast against each other.

    Returns
    -------
    torch.Tensor
        0-dim tensor holding the loss value. It is computed with NumPy and
        giotto-tda, so it is NOT connected to the autograd graph of whatever
        produced the diagrams. Adding it to a training loss changes the
        reported value but contributes no gradient.
    """
    amplitude = Amplitude(metric="wasserstein", metric_params={"p": 1})
    amp_x = amplitude.fit_transform(diagrams_x)
    amp_z = amplitude.fit_transform(diagrams_z)
    # Deliberately no ``requires_grad=True``: it would only create a
    # disconnected leaf tensor and wrongly suggest that gradients reach the model.
    return torch.tensor(np.mean(np.abs(amp_x - amp_z)))


# Training setup
latent_dim = 2  # Low-dimensional latent space
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = TopoAutoencoder(input_dim=embedding_dimension, latent_dim=latent_dim).to(device)
optimizer = optim.Adam(model.parameters(), lr=0.001)
mse_loss = nn.MSELoss()

# Training loop
num_epochs = 1
log_every = 10
topo_weight = 0.1
# Batch = first window of every embedded signal, stacked along a new leading
# axis (torch.stack requires all first windows to share one shape).
x = torch.stack([torch.tensor(signal[0], dtype=torch.float32) for signal in embedded_signals]).to(device)

# The inputs never change, so their persistence diagrams are computed once.
diagrams_x = compute_persistence_diagram(x.detach().cpu().numpy())

model.train()
for epoch in range(num_epochs):
    total_loss = 0.0

    optimizer.zero_grad()

    # Forward pass
    x_recon, z = model(x)

    # Latent persistence diagrams for every sample in the batch.
    diagrams_z = compute_persistence_diagram(z.detach().cpu().numpy())

    # Losses. The topological term is averaged over all samples in the batch.
    # NOTE: it is built from detached NumPy arrays, so it only changes the
    # reported loss; the gradients reaching the model come from the
    # reconstruction term alone.
    loss_recon = mse_loss(x_recon, x)
    loss_topo = torch.stack(
        [topological_loss(dx, dz) for dx, dz in zip(diagrams_x, diagrams_z)]
    ).mean()
    loss = loss_recon + topo_weight * loss_topo  # Weighted combination

    # Backpropagation
    loss.backward()
    optimizer.step()
    total_loss += loss.item()

    # Logging: every `log_every` epochs, and always after the final epoch
    # (otherwise a run shorter than `log_every` epochs prints nothing).
    if (epoch + 1) % log_every == 0 or epoch + 1 == num_epochs:
        print(f"Epoch [{epoch + 1}/{num_epochs}], Loss: {total_loss:.4f}")


import matplotlib.pyplot as plt
from persim import plot_diagrams

# Select a sample signal and evaluate the model on its first window.
signal_idx = 5
signal = embedded_signals[signal_idx]
window = torch.tensor(signal[0], dtype=torch.float32).to(device)

# Encode and decode once, outside autograd (evaluation only).
model.eval()
with torch.no_grad():
    z = model.encoder(window)
    x_rec = model.decoder(z).cpu()

# Compute persistence diagrams. compute_persistence_diagram iterates over its
# input, so each point cloud is wrapped in a one-element list. Otherwise every
# individual point would be treated as a separate sample.
diagrams_x = compute_persistence_diagram([signal[0]])
diagrams_z = compute_persistence_diagram([z.cpu().numpy()])
diagrams_x_rec = compute_persistence_diagram([x_rec.numpy()])

fig, axes = plt.subplots(1, 3, figsize=(18, 6))
plot_persistence_diagram(diagrams_x[0][0], ax=axes[0])
axes[0].set_title("Persistence Diagram - Input Space")

plot_persistence_diagram(diagrams_z[0][0], ax=axes[1])
axes[1].set_title("Persistence Diagram - Latent Space")

plot_persistence_diagram(diagrams_x_rec[0][0], ax=axes[2])
axes[2].set_title("Persistence Diagram - Reconstruction")
# plt.show()

# Compute evaluation metrics
# mse_losses = []
# topo_losses = []
#
# for signal in embedded_signals:
#     x = torch.tensor(signal[0], dtype=torch.float32).to(device)
#     x_recon, z = model(x)
#
#     # Compute persistence diagrams
#     diagrams_x = compute_persistence_diagram(x.cpu().detach().numpy())[0]
#     diagrams_z = compute_persistence_diagram(z.cpu().detach().numpy())[0]
#
#     # Compute losses
#     mse_losses.append(mse_loss(x_recon, x).item())
#     topo_losses.append(topological_loss(diagrams_x, diagrams_z).item())
#
# # Print Table
# import pandas as pd
# df_results = pd.DataFrame({"Signal Type": signal_types, "MSE Loss": mse_losses, "Topological Loss": topo_losses})
# print(df_results)