## Data Preprocessing

In [None]:
import tensorflow_datasets as tfds
import torch
import numpy as np
import pandas as pd

from sklearn.model_selection import train_test_split

In [None]:
# Enable IPython's autoreload extension in mode 2 so that edited modules are
# re-imported automatically before code is executed.
# The formatter markers keep IDE auto-formatting away from the magic lines.
# @formatter:off
%load_ext autoreload
%autoreload 2
# @formatter:on

In [None]:
# Dataset selection: the tfds dataset name and how many training examples to keep.
data_str = "synth"
data_size = 6_399

# Training setup shared by the trainer config and the model config further down.
batch_size = 32
num_epochs = 32
latent_dim = 8

# Forwarded unchanged to BetaTCVAEConfig as its alpha / beta / gamma arguments.
alpha = 0.8
beta = 1.6
gamma = 0.8

In [None]:
# Sanity check: print how many examples the full "train" split contains.
# The examples are streamed and counted one by one instead of being collected
# into a list, so the whole dataset is never held in memory just to be counted.
data = tfds.load(data_str)["train"]
print(sum(1 for _ in data))

In [None]:
# Work with only the first `data_size` examples of the train split.
train_split = tfds.load(data_str)["train"]
data = train_split.take(data_size)

In [None]:
# Pull the 'I' entry of every example's 'ecg' record out as a NumPy array.
ecg_data = pd.DataFrame(data)['ecg'].map(lambda ecg: ecg['I'].numpy())
# Stack into (n_examples, 1, 500). Using -1 lets NumPy infer the number of
# examples from the extracted data itself, so the reshape no longer depends on
# len(data), which raises for datasets whose cardinality is unknown.
ecg_data = np.array(ecg_data.tolist()).reshape(-1, 1, 500)

In [None]:
# Display the array shape; the reshape above yields (n_examples, 1, 500).
ecg_data.shape

In [None]:
# Hold out 20 % of the examples for validation; the fixed random_state makes
# the shuffle (and therefore the split) repeatable across runs.
train_data, val_data = train_test_split(
    ecg_data,
    test_size=0.2,
    random_state=42,
)

## Model Configuration

In [None]:
from pythae.models import BetaTCVAEConfig, BetaTCVAE
from pythae.trainers import BaseTrainerConfig
from model import Encoder, Decoder

# Trainer settings: output folder for training runs, learning rate, batch sizes
# and number of epochs (the latter two come from the configuration cell).
config = BaseTrainerConfig(
    output_dir='../results/my_model',
    learning_rate=1e-4,
    num_epochs=num_epochs,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
)

# Model hyper-parameters.
# NOTE(review): each sample in ecg_data has shape (1, 500) while input_dim is
# (1, 1, 500) -- presumably what the custom Encoder/Decoder expect; confirm.
model_config = BetaTCVAEConfig(
    input_dim=(1, 1, 500),
    latent_dim=latent_dim,
    alpha=alpha,
    beta=beta,
    gamma=gamma,
)

# The encoder and decoder are the project's own networks (imported from `model`);
# both are built from the same model_config, encoder first.
encoder = Encoder(model_config)
decoder = Decoder(model_config)
model = BetaTCVAE(model_config=model_config, encoder=encoder, decoder=decoder)

In [None]:
# Show the model's repr as the cell output to inspect what was assembled.
model

In [None]:
from pythae.pipelines import TrainingPipeline

# Bind the model to its trainer configuration; training starts when the
# pipeline object is called with data.
pipeline = TrainingPipeline(model=model, training_config=config)

In [None]:
# Run the training pipeline on the train split, with the held-out split passed
# as evaluation data.
pipeline(train_data=train_data, eval_data=val_data)

## Visualize Reconstructions

In [None]:
from pythae.models import AutoModel

# Folder of the trained model to evaluate -- change this path to your own run.
# Alternative runs kept for reference:
#   ../results/my_model/BetaTCVAE_training_2024-02-26_17-07-52/final_model/
#   ../results/my_model/BetaTCVAE_training_2024-02-27_17-18-23/final_model/
trained_model_dir = "../results/my_model/BetaTCVAE_training_2024-02-26_17-12-57/final_model/"
model = AutoModel.load_from_folder(trained_model_dir)

In [None]:

# Encode every ECG trace into its latent embedding.
# eval() switches off training-only layer behaviour (e.g. dropout), and
# no_grad() skips autograd bookkeeping, which is not needed for inference and
# would otherwise keep a graph alive for the whole dataset.
# The encoder is invoked as a module (not via .forward) so registered hooks run.
model.eval()
with torch.no_grad():
    embeddings = model.encoder(torch.from_numpy(ecg_data.astype('float32')))['embedding']

In [None]:
# Decode all embeddings back into signals and convert them to a NumPy array.
# Inference only: eval() disables training-only layer behaviour and no_grad()
# avoids building an autograd graph over the whole dataset.
model.eval()
with torch.no_grad():
    reconstructions = model.decoder(embeddings)['reconstruction'].detach().numpy()

In [None]:
from matplotlib import pyplot as plt

In [None]:
# Compare one original trace with its reconstruction and with two
# reconstructions for which a single latent dimension was set to zero.
sample_nr = 5
input_data = ecg_data[sample_nr][0]
reconstruction = reconstructions[sample_nr][0]


def decode_with_zeroed_dim(embedding, dim):
    """Decode a single embedding after setting latent dimension `dim` to zero.

    Args:
        embedding: 1-D tensor holding the latent vector of one sample.
        dim: index of the latent dimension to overwrite with 0.

    Returns:
        A tuple (modified_embedding, signal): the altered 1-D latent vector and
        the decoded signal as a 1-D NumPy array.
    """
    modified = embedding.detach().clone()  # copy, so `embeddings` stays untouched
    modified[dim] = 0
    with torch.no_grad():
        # unsqueeze(0) adds the batch axis explicitly instead of relying on the
        # decoder to cope with an unbatched 1-D input.
        decoded = model.decoder(modified.unsqueeze(0))['reconstruction']
    return modified, decoded.numpy()[0][0]


embeddings_mod_1, reconstruction_mod_1 = decode_with_zeroed_dim(embeddings[sample_nr], 0)
embeddings_mod_2, reconstruction_mod_2 = decode_with_zeroed_dim(embeddings[sample_nr], 1)

# Derive the x-axis from the signal length instead of hard-coding 500 points.
x_points = list(range(len(input_data)))

fig, ax = plt.subplots()
ax.plot(x_points, reconstruction_mod_1, color='violet', label='reconstruction dim0=0')
ax.plot(x_points, reconstruction_mod_2, color='turquoise', label='reconstruction dim1=0')
ax.plot(x_points, reconstruction, color='#a9dbb8', label='reconstruction')
ax.plot(x_points, input_data, color='#020887', label='original')
ax.set_xlabel('sample index')
ax.set_ylabel('amplitude')
ax.set_title(f'ECG sample {sample_nr}: original vs. reconstructions')
ax.legend()

#plt.savefig("../results/model_evaluation/synth_reconstruction" + "_5_sample" + ".pdf", bbox_inches='tight')
plt.show()

In [None]:
from collections import namedtuple

# Metadata columns to compare against the latent dimensions; every cell is
# converted from a tensor to its NumPy value.
factor_columns = ['t_height', 'p_height']
meta_data = pd.DataFrame(data)[factor_columns].map(lambda d: d.numpy())
#meta_data = pd.DataFrame(data).drop(['ecg', 'quality'], axis=1).map(lambda d: d.numpy())

# Bundle handed to the similarity measure: metadata, latent codes `z`, and one
# flag per metadata column (all False here -- presumably "not discrete"; confirm).
Data = namedtuple('Data', ['meta_data', 'z', 'discrete_features'])
input_data = Data(
    meta_data=meta_data,
    z=embeddings.detach().numpy(),
    discrete_features=[False] * len(meta_data.columns),
)


In [None]:
from interpretability_component.disentanglement_metrics import *
from interpretability_component.similarity_measures import *

# Build the similarity measure from the bundled metadata / latent codes, using
# mutual_information_regression as the scoring function.
sim_measure = SimilarityMeasure(input_data, mutual_information_regression)
# Derive a label per latent dimension with the given selection strategy
# (presumably the best-matching metadata feature per meaningful dim -- confirm).
dim_labels = sim_measure.get_interpreted_features(get_top_feature_for_meaningful_dims)
print(dim_labels)

# Score disentanglement on top of the similarity measure using `mig_sup`
# (presumably the MIG-sup metric -- confirm against the project module).
dis_metric = DisentanglementMetric(sim_measure, mig_sup)
dis_metric_score = dis_metric.compute_score()
print(dis_metric_score)

In [None]:
from interpretability_component.utils import *
# Plot the similarity scores with the metadata column names as feature labels.
# To also write the figure to disk, use the variant below instead:
#visualize_mi_matrix(sim_measure.scores, meta_data.columns, save_pdf=True, filename="../results/model_evaluation/synth")
feature_names = meta_data.columns
visualize_mi_matrix(sim_measure.scores, feature_names)

In [None]:
import datetime

# Timestamped output location for this evaluation (only used by the disabled
# to_csv call at the bottom).
timestamp = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
path = "../results/model_evaluation/" + timestamp

# One-row summary of the evaluated model.
# NOTE(review): eval_loss is a hard-coded literal, presumably copied by hand
# from the training output -- update it whenever a different model is loaded.
df = pd.DataFrame({
    'config': [model.model_config],
    'dim_labels': [dim_labels],
    'dis_metrics': [dis_metric_score],
    'eval_loss': [4.328],
})

print(df)

#df.to_csv(path)