# Imports

In [1]:
# -*- coding: UTF-8 -*-
# Local packages:
import argparse
import logging
import os
import pickle
import random
import shutil
import time
from typing import Dict, Union

# 3rd party packages:
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from sklearn.model_selection import train_test_split
from tqdm import tqdm,trange

from torch.utils.tensorboard import SummaryWriter
# TODO: Implement Neptune logger

# personal packages:
#from Data.preprocess import preprocess_data
#from model.timegan import TimeGAN
#from model.utils import timegan_trainer, timegan_generator

  from .autonotebook import tqdm as notebook_tqdm


# Model

## Dataset class

In [2]:
class TimeGAN_Dataset(torch.utils.data.dataset):
    """A time series dataset for TimeGAN.
    Args:
        data(numpy.ndarray): the padded dataset to be fitted. Has to transform to ndarray from DataFrame during initializ
        time(numpy.ndarray): the length of each data
    Parameters:
        - x (torch.FloatTensor): the real value features of the data
        - t (torch.LongTensor): the temporal feature of the data
    """
    def __init__(self,args):
        #sanity check data and time
        value = args.data['Sleeping stage'].values
        time = args.data['time'].values
        if len(value) != len(time):
            raise ValueError( f"len(value) `{len(value)}` != len(time) {len(time)}")
        if isinstance(time,type(None)):
            time = [len(x) for x in data]
        self.X = torch.FloatTensor(value)
        self.T = torch.LongTensor(time)

    def __len__(self):
        return len(self.X)
    
    def __getitem__(self,idx):
        return self.X[idx],self.T[idx]
    
    def collate_fn(self, batch):
        """Minibatch sampling
        """
        # Pad sequences to max length
        X_mb = [X for X in batch[0]]
        
        # The actual length of each data
        T_mb = [T for T in batch[1]]
        
        return X_mb, T_mb

TypeError: module() takes at most 2 arguments (3 given)

## TimeGAN

In [None]:
class EmbeddingNetwork(nn.Module):
    """
    The embedding network (encoder) that maps the input data to a latent space.
    """
    def __init__(self,args):
        super(EmbeddingNetwork, self).__init__()
        self.feature_dim = args.feature_dim
        self.hidden_dim = args.hidden_dim
        self.num_layers = args.num_layers
        self.padding_value = args.padding_value
        self. max_seq_len = args.max_seq_len

        #Embedder Architecture
        self.emb_rnn = nn.GRU(
            input_size=self.feature_dim,
            hidden_size=self.hidden_dim,
            num_layers=self.num_layers,
            batch_first=True,
        )
        self.emb_linear = nn.Linear(self.hidden_dim, self.hidden_dim)
        self.emb_sigmoid = nn.Sigmoid()

        
        # Init weights
        # Default weights of TensorFlow is Xavier Uniform for W and 1 or 0 for b
        # Reference: 
        # - https://www.tensorflow.org/api_docs/python/tf/compat/v1/get_variable
        # - https://github.com/tensorflow/tensorflow/blob/v2.3.1/tensorflow/python/keras/layers/legacy_rnn/rnn_cell_impl.py#L484-L61

        with torch.no_grad():
            for name, param in self.emb_rnn.named_parameters():
                if 'weight_ih' in name:
                    nn.init.xavier_uniform_(param.data)
                elif 'weight_hh' in name:
                    nn.init.orthogonal_(param.data)
                elif 'bias_ih' in name:
                    param.data.fill_(1)
                elif 'bias_hh' in name:
                    param.data.fill_(1)

            for name, param in self.emb_linear.named_parameters():
                if 'weight' in name:
                    nn.init.xavier_uniform_(param)
                elif 'bias' in name:
                    param.data.fill_(0)
    def forward(self,X,T):
        """Forward pass of the embedding features from original space to latent space.
        Args:
            X: Input time series feature (B x S x F)
            T: INput temporal information (B)
        Returns:
            H: latent space embeddings (B x S x H)
        """
        # Dynamic RNN input for ignoring paddings

        X_pack = nn.utils.rnn.pack_padded_sequence(
            input =X,
            lengths=T,
            batch_first=True,
            enforce_sorted=False,
        )

        # 128*100*71
        H_o,H_t = self.emb_rnn(X_pack)

        #pad RNN output back to sequence length

        H_o,T = nn.utils.rnn.pad_packed_sequence(
            sequence=H_o,
            batch_first=True,
            padding_value=self.padding_value,
            total_length=self.max_seq_len,
        )

        #128*100*10
        logits = self.emb_linear(H_o)
        H = self.emb_sigmoid(logits)

        return H
    
class RecoveryNetwork(nn.Module):
    """The recovery network (decoder) for TimeGAN
    """
    def __init__(self,arg):
        super(RecoveryNetwork, self).__init__()
        self.hidden_dim = arg.hidden_dim
        self.feature_dim = arg.feature_dim
        self.num_layers = arg.num_layers
        self.padding_value = arg.padding_value
        self.max_seq_len = arg.max_seq_len

        #Recovery Architecture
        self.rec_rnn = nn.GRU(
            input_size=self.hidden_dim,
            hidden_size=self.hidden_dim,
            num_layers=self.num_layers,
            batch_first=True,
        )

        self.rec_linear = nn.Linear(self.hidden_dim, self.feature_dim)

        # Init weights
        # Default weights of TensorFlow is Xavier Uniform for W and 1 or 0 for b
        # Reference: 
        # - https://www.tensorflow.org/api_docs/python/tf/compat/v1/get_variable
        # - https://github.com/tensorflow/tensorflow/blob/v2.3.1/tensorflow/python/keras/layers/legacy_rnn/rnn_cell_impl.py#L484-L614

        with torch.no_grad():
            for name,param in self.rec_rnn.named_parameters():
                if 'weight_ih' in name:
                    nn.init.xavier_uniform_(param.data)
                elif 'weight_hh' in name:
                    nn.init.xavier_uniform_(param.data)
                elif 'bias_ih' in name:
                    param.data.fill_(1)
                elif 'bias_hh' in name:
                    param.data.fill_(0)
            for name,param in self.rec_linear.named_parameters():
                if 'weight' in name:
                    nn.init.xavier_uniform_(param)
                elif 'bias' in name:
                    param.data.fill_(0)
        
    def forward(self,H,T):
        """ Forward pass of the recovery features from latent space to original space.
        Args:
            H: latent representation (B x S x E)
            T: input temporal information (B)
        Returns:
            X_tilde: recovered features (B x S x F)
        """
        #Dynamic RNN input for ignoring paddings
        H_pack = nn.utils.rnn.pack_padded_sequence(
            input = H,
            lengths=T,
            batch_first=True,
            enforce_sorted=False,
        )
        #128 x 100 x 10
        H_o,H_t = self.rec_rnn(H_pack)
        #pad RNN output back to sequence length
        H_o,T = nn.utils.rnn.pad_packed_sequence(
            sequence=H_o,
            batch_first=True,
            padding_value=self.padding_value,
            total_length=self.max_seq_len,
        )
        #128 x 100 x 71
        X_tilde = self.rec_linear(H_o)
        return X_tilde

class SupervisorNetwork(nn.Module):
        """The supervisor network for TimeGAN
        """
        def __init__(self,args):
            super(SupervisorNetwork,self).__init__()
            self.hidden_dim = args.hidden_dim
            self.num_layers = args.num_layers
            self.padding_value = args.padding_value
            self.max_seq_len = args.max_seq_len

            #supervisor architecture
            self.sup_rnn = nn.GRU(
                input_size=self.hidden_dim,
                hidden_size=self.hidden_dim,
                num_layers=self.num_layers-1,
                batch_first=True,
            )
            self.sup_linear = nn.Linear(self.hidden_dim,self.hidden_dim)
            self.sup_sigmoid = nn.Sigmoid()
             # Init weights
            # Default weights of TensorFlow is Xavier Uniform for W and 1 or 0 for b
            # Reference: 
            # - https://www.tensorflow.org/api_docs/python/tf/compat/v1/get_variable
            # - https://github.com/tensorflow/tensorflow/blob/v2.3.1/tensorflow/python/keras/layers/legacy_rnn/rnn_cell_impl.py#L484-L614
            with torch.no_grad():
                for name, param in self.sup_rnn.named_parameters():
                    if 'weight_ih' in name:
                        torch.nn.init.xavier_uniform_(param.data)
                    elif 'weight_hh' in name:
                        torch.nn.init.xavier_uniform_(param.data)
                    elif 'bias_ih' in name:
                        param.data.fill_(1)
                    elif 'bias_hh' in name:
                        param.data.fill_(0)
                for name, param in self.sup_linear.named_parameters():
                    if 'weight' in name:
                        torch.nn.init.xavier_uniform_(param)
                    elif 'bias' in name:
                        param.data.fill_(0)
        def forward(self,H,T):
            """Forward pass for the supervisor for predicting next step
            Args:
                H: latent representation (B x S x E)
                T: input temporal information (B)
            Returns:
                H_hat: predicted next step data (B x S x E)
            """

            #Dynamic RNN input for ignoring paddings
            H_pack = nn.utils.rnn.pack_padded_sequence(
                input = H,
                lengths=T,
                batch_first=True,
                enforce_sorted=False,
            )

            H_o,H_t = self.sup_rnn(H_pack)
            #pad RNN output back to sequence length
            H_o,T = nn.utils.rnn.pad_packed_sequence(
                sequence=H_o,
                batch_first=True,
                padding_value=self.padding_value,
                total_length=self.max_seq_len,
            )
            logits = self.sup_linear(H_o)
            H_hat = self.sup_sigmoid(logits)
            return H_hat

class GeneratorNetwork(nn.Module):
    """The generator network for TimeGAN
    """
    def __init__(self,args):
        super(GeneratorNetwork,self).__init__()
        self.Z_dim = args.Z_dim
        self.hidden_dim = args.hidden_dim
        self.num_layers = args.num_layers
        self.padding_value = args.padding_value
        self.max_seq_len = args.max_seq_len

        #Generator Architecture
        self.gen_rnn = nn.GRU(
            input_size=self.Z_dim,
            hidden_size=self.hidden_dim,
            num_layers=self.num_layers,
            batch_first=True,
        )
        self.gen_linear = nn.Linear(self.hidden_dim,self.hidden_dim)
        self.gen_sigmoid = nn.Sigmoid()
                # Init weights
        # Default weights of TensorFlow is Xavier Uniform for W and 1 or 0 for b
        # Reference: 
        # - https://www.tensorflow.org/api_docs/python/tf/compat/v1/get_variable
        # - https://github.com/tensorflow/tensorflow/blob/v2.3.1/tensorflow/python/keras/layers/legacy_rnn/rnn_cell_impl.py#L484-L614
        with torch.no_grad():
            for name, param in self.gen_rnn.named_parameters():
                if 'weight_ih' in name:
                    torch.nn.init.xavier_uniform_(param.data)
                elif 'weight_hh' in name:
                    torch.nn.init.xavier_uniform_(param.data)
                elif 'bias_ih' in name:
                    param.data.fill_(1)
                elif 'bias_hh' in name:
                    param.data.fill_(0)
            for name, param in self.gen_linear.named_parameters():
                if 'weight' in name:
                    torch.nn.init.xavier_uniform_(param)
                elif 'bias' in name:
                    param.data.fill_(0)
        
    def forward(self,Z,T):
        """ Takes in random noise (features) and generates synthetic features within the last latent space
        Args:
            Z: input random noise (B x S x Z)
            T: input temporal information (B)
        Returns:
            H: embeddings (B x S x E)
        """
        #Dynamic RNN input for ignoring paddings
        Z_pack = nn.utils.rnn.pack_padded_sequence(
            input = Z,
            lengths=T,
            batch_first=True,
            enforce_sorted=False,
        )

        # 128*100*71
        H_o,H_t = self.gen_rnn(Z_pack)

        #pad RNN output back to sequence length

        H_o,T = nn.utils.rnn.pad_packed_sequence(
            sequence=H_o,
            batch_first=True,
            padding_value=self.padding_value,
            total_length=self.max_seq_len,
        )

        #128*100*10
        logits = self.gen_linear(H_o)
        H = self.gen_sigmoid(logits)

        return H

class DiscriminatorNetwork(nn.Module):
    """The discriminator network for TimeGAN
    """
    def __init__(self,args):
        super(DiscriminatorNetwork,self).__init__()
        self.hidden_dim = args.hidden_dim
        self.num_layers = args.num_layers
        self.padding_value = args.padding_value
        self.max_seq_len = args.max_seq_len

        #Discriminator Architecture
        self.dis_rnn = nn.GRU(
            input_size=self.hidden_dim,
            hidden_size=self.hidden_dim,
            num_layers=self.num_layers,
            batch_first=True,
        )
        self.dis_linear = nn.Linear(self.hidden_dim,1)

        # Init weights
        # Default weights of TensorFlow is Xavier Uniform for W and 1 or 0 for b
        # Reference: 
        # - https://www.tensorflow.org/api_docs/python/tf/compat/v1/get_variable
        # - https://github.com/tensorflow/tensorflow/blob/v2.3.1/tensorflow/python/keras/layers/legacy_rnn/rnn_cell_impl.py#L484-L614
        with torch.no_grad():
            for name, param in self.dis_rnn.named_parameters():
                if 'weight_ih' in name:
                    torch.nn.init.xavier_uniform_(param.data)
                elif 'weight_hh' in name:
                    torch.nn.init.xavier_uniform_(param.data)
                elif 'bias_ih' in name:
                    param.data.fill_(1)
                elif 'bias_hh' in name:
                    param.data.fill_(0)
            for name, param in self.dis_linear.named_parameters():
                if 'weight' in name:
                    torch.nn.init.xavier_uniform_(param)
                elif 'bias' in name:
                    param.data.fill_(0)
    
    def forward(self, H, T):
        """ Forward pass for predicting if the data is real or synthetic
        
        Args:
            H: latent representation (B x S x E)
            T: input temporal information (B)
        Returns:
        logits: prediction logits(B x S x 1)
        """
        # dynamic RNN input for ignoring paddings
        H_pack = nn.utils.rnn.pack_padded_sequence(
            input = H,
            lengths=T,
            batch_first=True,
            enforce_sorted=False,
        )

        # 128*100*10
        H_o,H_t = self.dis_rnn(H_pack)

        # pad RNN output back to sequence length
        H_o,T = nn.utils.rnn.pad_packed_sequence(
            sequence=H_o,
            batch_first=True,
            padding_value=self.padding_value,
            total_length=self.max_seq_len,
        )

        logits = self.dis_linear(H_o).squeeze(-1)
        return logits

class TimeGAN(nn.Module):
    """ Implementation of TimeGan (Yoon et al., 2019) using PyTorch
    
    Reference:
        - Yoon, J., Jarret, D., van der Schaar, M. (2019). Time-series Generative Adversarial Networks. (https://papers.nips.cc/paper/2019/hash/c9efe5f26cd17ba6216bbe2a7d26d490-Abstract.html)
        - https://github.com/jsyoon0823/TimeGAN
    """
    def __init__(self,args):
        super(TimeGAN,self).__init__()
        self.device =args.device
        self.feature_dim = args.feature_dim
        self.Z_dim = args.Z_dim
        self.hidden_dim = args.hidden_dim
        self.max_seq_len = args.max_seq_len
        self.batch_size = args.batch_size

        self.embedder = EmbeddingNetwork(args)
        self.recovery = RecoveryNetwork(args)
        self.generator = GeneratorNetwork(args)
        self.discriminator = DiscriminatorNetwork(args)
        self.supervisor = SupervisorNetwork(args)

    def _recovery_forward(self, X, T):
        """ The embedding network forward pass and the embedder network loss
        Args:
            X: input features
            T: input temporal information
        Returns:
            E_loss: the reconstruction loss
            X_tilde: the reconstructed features
        """

        # FOrward pass
        H = self.embedder(X,T)
        X_tilde = self.recovery(H,T)

        #for Joint training
        H_hat_supervise = self.supervisor(H,T)
        G_loss_S = F.mse_loss(
            H_hat_supervise[:,:-1,:],
            H[:,1:,:],
        ) #Teacher forcing next output

        #Reconstruction loss
        E_loss_T0 = F.mse_loss(X_tilde,X)
        E_loss0 = 10*torch.sqrt(E_loss_T0)
        E_loss = E_loss0 + 0.1*G_loss_S
        return E_loss, E_loss0,E_loss_T0
    def _supervisor_forward(self, X, T):
        """ The supervisor training forward pass
        Args:
            X: original input features
            T: input temporal information
        Returns:
            S_loss: the supervisor's loss
        """
        #supervisor forward pass
        H = self.embedder(X,T)
        H_hat_supervise = self.supervisor(H,T)

        #supervised loss
        S_loss = F.mse_loss(
            H_hat_supervise[:,:-1,:],
            H[:,1:,:],
        ) #Teacher forcing next output
        return S_loss
    def _discriminator_forward(self, X, T, Z, gamma=1):
        """ The discriminator forward pass and adversarial loss
        Args:
            X: input features
            T: input temporal information
            Z: input noise
            gamma: the weight for the adversarial loss
        Returns:
            D_loss: adversarial loss
        """
        #Real
        H = self.embedder(X, T).detach()

        #generator
        E_hat = self.generator(Z,T).detach()
        H_hat = self.supervisor(E_hat,T).detach()
        
        #forward pass
        Y_real = self.discriminator(H,T)        #Encode original data
        Y_fake = self.discriminator(H_hat,T)    #Output of generator + supervisor
        Y_fake_e = self.discriminator(E_hat,T)  #Output of generator

        D_loss_real = F.binary_cross_entropy_with_logits(Y_real, torch.ones_like(Y_real))
        D_loss_fake = F.binary_cross_entropy_with_logits(Y_fake, torch.zeros_like(Y_fake))
        D_loss_fake_e = F.binary_cross_entropy_with_logits(Y_fake_e, torch.zeros_like(Y_fake_e))

        D_loss = D_loss_real + D_loss_fake + gamma * D_loss_fake_e

        return D_loss
    
    def _generator_forward(self, X, T, Z, gamma=1):
        """ The generator forward pass
        Args:
            X: original input features
            T: input temporal information
            Z: input noise for the generator
            gamma: the weight for the adversarial loss
        Returns:
            G_loss: the generator loss
        """
        #supervisor forward pass
        H = self.embedder(X,T)
        H_hat_supervise = self.supervisor(H,T)

        #generator forward pass
        E_hat = self.generator(Z,T)
        H_hat = self.supervisor(E_hat,T)

        #synthetic data generated
        X_hat = self.recovery(H_hat,T)

        #generator loss
        #Adversarial loss
        Y_fake = self.discriminator(H_hat,T)        #Output of supervisor
        Y_fake_e = self.discriminator(E_hat,T)      #Output of generator

        G_loss_U = F.binary_cross_entropy_with_logits(Y_fake, torch.ones_like(Y_fake))
        G_loss_U_e = F.binary_cross_entropy_with_logits(Y_fake_e, torch.ones_like(Y_fake_e))

        #Supervised loss
        G_loss_S = F.mse_loss(
            H_hat_supervise[:,:-1,:],
            H[:,1:,:],
        ) #Teacher forcing next output

        #Two moments losses
        G_loss_V1 = torch.mean(
            torch.abs(torch.sqrt(X_hat.var(dim=0,unbiased=False)+1e-6) - torch.sqrt(X.var(dim=0,unbiased=False)+1e-6))
        )
        G_loss_V2 = torch.mean(torch.abs((X_hat.mean(dim=0)) - (X.mean(dim=0))))
        G_loss_V = G_loss_V1 + G_loss_V2
        
        #sum of losses
        G_loss = G_loss_U + gamma * G_loss_U_e + 100 * torch.sqrt(G_loss_S) + 100 * G_loss_V
    
        return G_loss
    
    def _inference(self, Z,T):
        """ Inference for generating synthetic data
        Args:
            Z: input noise
            T: temporal information
        Returns:
            X_hat: the generated data
        """

        #generator forward pass
        E_hat = self.generator(Z,T)
        H_hat = self.supervisor(E_hat,T)

        #synthetic data generated
        X_hat = self.recovery(H_hat,T)
        return X_hat

    def forward(self,X,T,Z, obj, gamma=1):
        """
        Args:
            X: input features (B,H,F)
            T: The temporal information (B)
            Z: the sampled noise (B,H,Z)
            obj: the network to be trained ('autoencoder','supervisor','generator','discriminator')
            gamma: loss hyperparameter
        Returns:
            loss: loss for the forward pass
            X_hat: the generated data
        """

        #Move variables to device
        if obj !='inference':
            if X is None:
                raise ValueError('X cannot be empty')
            
            X = torch.FloatTensor(X)
            X = X.to(self.device)

        if Z is not None:
            Z = torch.FloatTensor(Z)
            Z = Z.to(self.device)
        
        if obj == 'autoencoder':
            #embedder and recovery forward
            loss = self._recovery_forward(X,T)
        elif obj == 'supervisor':
            loss = self._supervisor_forward(X,T)
        elif obj == 'generator':
            if Z is None:
                raise ValueError('Z cannot be empty')
            loss = self._generator_forward(X,T,Z,gamma)
        elif obj == 'discriminator':
            if Z is None:
                raise ValueError('Z cannot be empty')
            loss = self._discriminator_forward(X,T,Z,gamma)
            return loss
        elif obj == 'inference':
            X_hat = self._inference(Z,T)
            X_hat = X_hat.cpu.detach()

            return X_hat
        else:
            raise ValueError('obj must be autoencoder, supervisor, generator or discriminator')
        return loss
        

## Trainers

In [None]:
def embedding_trainer(
        model: torch.nn.Module,
        dataloader: torch.utils.data.DataLoader,
        e_opt: torch.optim.Optimizer,
        r_opt: torch.optim.Optimizer,
        args: Dict,
        writer: Union[torch.utils.tensorboard.SummaryWriter, type(None)]=None
):
    """
    Training loop for embedding and recovery functions.
    Args:
        model (torch.nn.Module): The model to train
        dataloader (torch.utils.data.DataLoader): The dataloader to use
        e_opt (torch.optim.Optimizer): The optimizer for the embedding function
        r_opt (torch.optim.Optimizer): The optimizer for the recovery function
        args (Dict): The model/training configuration
        writer (Union[torch.utils.tensorboard.SummaryWriter, type(None)], optional): The tensorboard writer to use. Defaults to None.
    """
    logger = trange(args.emb_epochs, desc =f"Epoch:0, Loss:0")
    for epoch in logger:
        for X_mb,T_mb in dataloader:

            #reset gradients
            model.zero_grad()

            #forward pass
            _,E_loss0,E_loss_T0 = model(X=X_mb,T=T_mb,Z=None,obj="autoencoder")
            loss = np.sqrt(E_loss_T0.item())

            #backward pass
            E_loss0.backward()

            #update weights
            e_opt.step()
            r_opt.step()

        # Log loss for final batch of each epochs
        logger.set_description(f"Epoch:{epoch}, Loss:{loss:.4f}")
        if writer:
            writer.add_scalar("Embedding/Loss:",loss,epoch)
            writer.flush()

def supervisor_trainer(
    model: torch.nn.Module, 
    dataloader: torch.utils.data.DataLoader, 
    s_opt: torch.optim.Optimizer, 
    g_opt: torch.optim.Optimizer, 
    args: Dict, 
    writer: Union[torch.utils.tensorboard.SummaryWriter, type(None)]=None
):
    """
    The training loop for the supervisor function
    Args:
        model (torch.nn.Module): The model to train
        dataloader (torch.utils.data.DataLoader): The dataloader to use
        s_opt (torch.optim.Optimizer): The optimizer for the supervisor function
        g_opt (torch.optim.Optimizer): The optimizer for the generator function
        args (Dict): The model/training configuration
        writer (Union[torch.utils.tensorboard.SummaryWriter, type(None)], optional): The tensorboard writer to use. Defaults to None.
    """
    logger = trange(args.sup_epochs, desc=f"Epoch: 0, Loss: 0")
    for epoch in logger:
        for X_mb, T_mb in dataloader:
            # Reset gradients
            model.zero_grad()

            # Forward Pass
            S_loss = model(X=X_mb, T=T_mb, Z=None, obj="supervisor")

            # Backward Pass
            S_loss.backward()
            loss = np.sqrt(S_loss.item())

            # Update model parameters
            s_opt.step()

        # Log loss for final batch of each epoch (29 iters)
        logger.set_description(f"Epoch: {epoch}, Loss: {loss:.4f}")
        if writer:
            writer.add_scalar(
                "Supervisor/Loss:",loss,epoch)
            writer.flush()

In [None]:
def joint_trainer(
    model: torch.nn.Module, 
    dataloader: torch.utils.data.DataLoader, 
    e_opt: torch.optim.Optimizer, 
    r_opt: torch.optim.Optimizer, 
    s_opt: torch.optim.Optimizer, 
    g_opt: torch.optim.Optimizer, 
    d_opt: torch.optim.Optimizer, 
    args: Dict, 
    writer: Union[torch.utils.tensorboard.SummaryWriter, type(None)]=None, 
    ):
    """
    The training loop for training the model altogether
    Args:
        model (torch.nn.Module): The model to train
        dataloader (torch.utils.data.DataLoader): The dataloader to use
        e_opt (torch.optim.Optimizer): The optimizer for the embedding function
        r_opt (torch.optim.Optimizer): The optimizer for the recovery function
        s_opt (torch.optim.Optimizer): The optimizer for the supervisor function
        g_opt (torch.optim.Optimizer): The optimizer for the generator function
        d_opt (torch.optim.Optimizer): The optimizer for the discriminator function
        args (Dict): The model/training configuration
        writer (Union[torch.utils.tensorboard.SummaryWriter, type(None)], optional): The tensorboard writer to use. Defaults to None.
    """
    logger = trange(
        args.sup_epochs, 
        desc=f"Epoch: 0, E_loss: 0, G_loss: 0, D_loss: 0"
    )
    for epoch in logger:
        for X_mb, T_mb in dataloader:
            ## Generator Training
            for _ in range(2):
                # Random Generator
                Z_mb = torch.rand((args.batch_size, args.max_seq_len, args.Z_dim))

                # Forward Pass (Generator)
                model.zero_grad()
                G_loss = model(X=X_mb, T=T_mb, Z=Z_mb, obj="generator")
                G_loss.backward()
                G_loss = np.sqrt(G_loss.item())

                # Update model parameters
                g_opt.step()
                s_opt.step()

                # Forward Pass (Embedding)
                model.zero_grad()
                E_loss, _, E_loss_T0 = model(X=X_mb, T=T_mb, Z=Z_mb, obj="autoencoder")
                E_loss.backward()
                E_loss = np.sqrt(E_loss.item())
                
                # Update model parameters
                e_opt.step()
                r_opt.step()

            # Random Generator
            Z_mb = torch.rand((args.batch_size, args.max_seq_len, args.Z_dim))

            ## Discriminator Training
            model.zero_grad()
            # Forward Pass
            D_loss = model(X=X_mb, T=T_mb, Z=Z_mb, obj="discriminator")

            # Check Discriminator loss
            if D_loss > args.dis_thresh:
                # Backward Pass
                D_loss.backward()

                # Update model parameters
                d_opt.step()
            D_loss = D_loss.item()

        logger.set_description(
            f"Epoch: {epoch}, E: {E_loss:.4f}, G: {G_loss:.4f}, D: {D_loss:.4f}"
        )
        if writer:
            writer.add_scalar(
                'Joint/Embedding_Loss:',E_loss,epoch)
            writer.add_scalar(
                'Joint/Generator_Loss:',G_loss,epoch)
            writer.add_scalar('Joint/Discriminator_Loss:',D_loss,epoch)
            writer.flush()

def timegan_trainer(model,loaded_data,args):
    """
    The trainign procedure for TimeGAN.
    Args:
        model (torch.nn.module): The model that generates synthetic data
        loaded_data(pandas.DataFrame): The data to train on, including data and time
        args (Dict): The model/training configuration
    Returns:
        generated_data (np.array): The synthetic data generated by the model
    """
    dataset = TimeGAN_Dataset(args=args,data=loaded_data["data"],time=loaded_data["time"])
    dataloader = torch.utils.data.DataLoader(dataset=dataset, batch_size=args.batch_size, shuffle=False)
    model.to(args.device)

    #initialize optimizers
    e_opt = torch.optim.Adam(model.embedder.parameters(), lr=args.lr)
    r_opt = torch.optim.Adam(model.recovery.parameters(), lr=args.lr)
    s_opt = torch.optim.Adam(model.supervisor.parameters(), lr=args.lr)
    g_opt = torch.optim.Adam(model.generator.parameters(), lr=args.lr)
    d_opt = torch.optim.Adam(model.discriminator.parameters(), lr=args.lr)

    #initialize tensorboard writer
    writer = SummaryWriter(os.path.join(f"tensorboard/{args.exp}"))

    print("\nStart Embedding Network Training")
    embedding_trainer(model=model, dataloader=dataloader, e_opt=e_opt, r_opt=r_opt, args=args, writer=writer)

    print("\nStart Training with Supervised Loss Only")
    supervisor_trainer(model=model, dataloader=dataloader, s_opt=s_opt,g_opt=g_opt, args=args, writer=writer)

    print("\nStart Joint Training")
    joint_trainer(model=model, dataloader=dataloader, e_opt=e_opt, r_opt=r_opt, s_opt=s_opt, g_opt=g_opt, d_opt=d_opt, args=args, writer=writer)


    #save model,args, and hyperparameters
    torch.save(args,f"{args.model_path}/args.pickle")
    torch.save(model.state_dict(),f"{args.model_path}/model.pt")
    print(f"Model saved to {args.model_path}")

def timegan_generator(model,T,args):
    """
    The interference procedure for TimeGAN.
    Args:
        model (torch.nn.module): The model that generates synthetic data
        T (List[int]): The time to generate data for
        args (Dict): The model/training configuration
    returns:
        generated_data (np.array): The synthetic data generated by the model
    """
    #load model
    if not os.path.exists(args.model_path):
        raise ValueError(f"Model not found at {args.model_path}")
    model.load_state_dict(torch.load(f"{args.model_path}/model.pt"))
    print("\nStart Generating Synthetic Data")
    #Initialize model to evaluation mode and run without gradients
    model.to(args.device)
    model.eval()
    with torch.no_grad():
        # Random Generator
        Z = torch.rand((args.batch_size, args.max_seq_len, args.Z_dim))
        # Forward Pass (Generator)
        generated_data = model(X=None, T=T, Z=Z, obj="inference")
    return generated_data.numpy()

# Data handling

## Data loading

In [None]:
def load_data(args):
    """
    data_limit=None,save_dataset=None
    Load and preprocess real life datasets.
    
    Args:
        data_limit (int): The number of data points to load. If None, all data points are loaded. Default: None. Used for testing.
        save_dataset (bool): If 'Full', the dataset is saved to a csv file. If it's 'limited', than save the limited dataset if data_limit is not None. Default: None.


    Returns:
        dataset (pandas.DataFrame): The dataset.
    
    """

    main_dataset = pd.DataFrame()

    
    Tk().withdraw() # we don't want a full GUI, so keep the root window from appearing
    filenames = askopenfilenames() # show an "Open" dialog box and return the path to the selected file
    print(filenames)
    for filename in filenames:
        
        if filename.endswith('.mat'):
            #output df format: [id,value_array]
            df = load_mat_as_df(filename)
            print(df)
        
        elif filename.endswith('.csv'):

            #use create_dataset_csv.py to create a csv file
            if filename.find('dataset') != -1:
                df = pd.read_csv(filename,sep=';',index_col=0)

            """ CSV format:
            ID|time|Sleeping stage|length|additional_info

            """
            pass

        elif filename.endswith('.xml'):
            ## TODO: add xml support
            #df =
            pass
        else:
            print("Unsupported file format, skipping file:",filename,".")
            pass
    
        main_dataset.append(df)

    #Cut df to data_limit size for testing purposes
    if args.data_limit is not None:
        if args.save_dataset == 'Full':
            #save dataset to a csv file
            main_dataset.to_csv('Full_dataset.csv',sep=';')
        
        elif args.save_dataset == 'Limited':
            main_dataset = main_dataset[:args.data_limit]
            #save dataset to a csv file
            main_dataset.to_csv('limited_dataset.csv',sep=';')
        elif args.save_dataset == 'None':
            main_dataset = main_dataset[:args.data_limit]
            pass
        else:
            raise ValueError("Invalid save_dataset value, valid values are 'Full','Limited','None'.")

    elif args.data_limit is None:
        if args.save_dataset == 'Full':
            #save dataset to a csv file
            main_dataset.to_csv('full_dataset.csv',sep=';')
        elif args.save_dataset == 'Limited':
            print("Warning: data_limit is None, dataset is not limited, saving full dataset.")
            main_dataset.to_csv('full_dataset.csv',sep=';')
        elif args.save_dataset == 'None':
            pass
        else:
            raise ValueError("Invalid save_dataset value, valid values are 'Full','Limited','None'.")

    
    return main_dataset #dataset as df

def load_mat_as_df(mat_file_path, var_name):
    mat = sio.loadmat(mat_file_path,simplify_cells=True)

    if var_name not in list(mat.keys()):
        var_name = get_variable_name(mat)   
        

    return pd.DataFrame(mat[var_name])

def get_variable_name(loaded_mat):

    
    root = tk.Tk()
    root.title('.mat variable selector')
    tk.Label(root, text="Choose a variable:").pack()
    choices = list(loaded_mat.keys())

    variable = tk.StringVar(root)
    variable.set(choices[0]) # default value
    w = tk.Combobox(root, textvariable=variable,values=choices)

    w.pack()
    def ok():
        print ("value is:" + variable.get())
        root.destroy()
    def cancel():
        root.destroy()
        raise InterruptedError('User cancelled, invalid variable name')

    button1 = tk.Button(root, text="OK", command=ok)
    button2 = tk.Button(root, text="Cancel", command=cancel)
    button1.pack()
    button2.pack()
    root.mainloop()
    
    return variable.get()

## Data preprocess

In [None]:
def preprocess_data(args):
    """
    padding_value: int = -1.0,
    data_limit: int = None"""
    # Load and preprocess data
    # 
    # 1. Load data from files (csv,mat,xml)
    # 2. Preprocess data:
    # 2.1. Remove outliers
    # 2.2. Extract sequence length and time
    # 2.3. Resample data
    # 2.4. Normalize data
    # 2.5. Padding 
    #  
    # Args:
    #     data_limit (int): The number of lines to load from the data file
    #     padding_value (int): The value used for padding
    #     
    # 
    # Returns:
    #     prep_data (pandas.DataFrame): The processed data

    #######################################
    # 1. Load data from files (csv,mat,xml)
    #######################################

    loaded_data = load_data(data_limit=args.data_limit)
    """
    loaded data =       time_data   , data  , length
    (pandas.DataFrame), (np.array)  ,(list) ,(int)
    
    ()
    """
    #######################################
    # 2. Preprocess data:
    #######################################
    # 2.1. Remove outliers
    #######################################
    """
    Remove row's with unacceptable sleep stages values
    """
    sleep_stages = np.array([1,2,3,4,5])
    loaded_data[loaded_data['data'].apply(lambda x: all(elem in sleep_stages for elem in x))]

    #######################################
    # 2.2. Extract sequence length and time
    #######################################
    """
    Extract sequence length of all lines and time of each line
    """
    loaded_data['length'] = loaded_data['data'].apply(lambda x: len(x))

    
    #######################################
    # 2.4. Normalize data
    #######################################
    """
    Normalize data to [0,1] using MinMaxScaler algorithm
    """

    

    if args.norm_enable == True:
        loaded_data['data']=MinMaxNormalizer(loaded_data['data'])
    

    #######################################
    # 2.5. Padding
    #######################################
    """
    Padding data to given length
    """
   
    # Question: Current padding value is 0, is it ok? Do we need it or just resample?
    data_info = {
        'length' : len(loaded_data),
        'max_length' : max(loaded_data['length']),
        'paddding_value' : args.padding_value,

    }
    
    loaded_data['data'] = loaded_data['data'].apply(lambda x: np.transpose(x))
    prep_data = pd.DataFrame(columns=['time','data'])
    for i in tqdm(range(data_info.length)):
        #create empty array with padding value
        tmp_array = np.empty([data_info.max_length,1])
        tmp_array.fill(args.padding_value)
        #fill array with data
        tmp_array[:loaded_data['data'][i].shape[0],:loaded_data['data'][i].shape[1]] = loaded_data['data'][i]
        #append to prep_data
        prep_data.append(tmp_array)

    return prep_data.to_numpy(),loaded_data['time_data'].to_numpy(),data_info
    

    
    
    


def MinMaxNormalizer(data,min_value=1,max_value=5):
    numerator = data-min_value
    denominator = max_value-min_value
    norm_data = numerator/denominator
    return norm_data

# Main

In [None]:
parser = argparse.ArgumentParser()
# Experiment Arguments
parser.add_argument(
    "--device",
    choices=["cuda", "cpu"],
    default="cuda",
    type=str,
    help="Device to use for training",
)
parser.add_argument(
    "--seed",
    default=0,
    type=int,
    help="Random seed for reproducibility",
)
parser.add_argument(
    "--exp",
    default="test",
    type=str,
    help="Experiment name",
)
parser.add_argument(
    "--norm_enable",
    default=False,
    type=bool,
    help="Enable normalization",
)
parser.add_argument(
    "--save_dataset",
    choices=["Full","Limited","None"],
    default="Full",
    type=str,
    help="Save loaded datasets as csv",
)
# Data Arguments
parser.add_argument(
    "--data_limit",
    default=None,
    type=int,
    help="Number of data points to use (None or int)",
)
parser.add_argument(
    "--train_rate",
    default=0.6,
    type=float,
    help="Train test split rate",
)
parser.add_argument(
    "--max_seq_len",
    default=1000,
    type=int,
    help="Maximum sequence length",
)

# Model Arguments
parser.add_argument(
    "--emb_epochs",
    default=600,
    type=int,
    help="Number of epochs to train embedding model",
)
parser.add_argument(
    "--gan_epochs",
    default=600,
    type=int,
    help="Number of epochs to train GAN model",
)
parser.add_argument(
    "--sup_epochs",
    default=600,
    type=int,
    help="Number of epochs to train supervised model",
)
parser.add_argument(
    "--batch_size",
    default=64,
    type=int,
    help="Batch size for training",
)
parser.add_argument(
    "--hidden_dim",
    default=20,
    type=int,
    help="Hidden dimension of RNN",
)
parser.add_argument(
    "--num_layers",
    default=3,
    type=int,
    help="Number of layers in RNN",
)
parser.add_argument(
    "--dis_thresh",
    default=0.15,
    type=float,
    help="Discriminator threshold",
)
parser.add_argument(
    "--padding_value",
    default=0,
    type=int,
    help="Data padding value",
)
parser.add_argument(
    '--learning_rate',
    default=1e-3,
    type=float)
args = parser.parse_args()