In [1]:
!git clone https://github.com/vincent-leguen/DILATE
%cd /content/DILATE
!pip install tslearn

import numpy as np
import torch
import random
from torch.utils.data import Dataset, DataLoader



class SDataset(torch.utils.data.Dataset):
    def __init__(self, X_input, X_target):
        super(SDataset, self).__init__()
        self.X_input = X_input
        self.X_target = X_target


    def __len__(self):
        return (self.X_input).shape[0]

    def __getitem__(self, idx):
        return (self.X_input[idx,:,np.newaxis], self.X_target[idx,:,np.newaxis] )

Cloning into 'DILATE'...
remote: Enumerating objects: 69, done.[K
remote: Counting objects: 100% (17/17), done.[K
remote: Compressing objects: 100% (16/16), done.[K
remote: Total 69 (delta 6), reused 2 (delta 0), pack-reused 52 (from 1)[K
Receiving objects: 100% (69/69), 4.71 MiB | 25.49 MiB/s, done.
Resolving deltas: 100% (23/23), done.
/content/DILATE
Collecting tslearn
  Downloading tslearn-0.7.0-py3-none-any.whl.metadata (16 kB)
Downloading tslearn-0.7.0-py3-none-any.whl (372 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m372.7/372.7 kB[0m [31m7.4 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: tslearn
Successfully installed tslearn-0.7.0


In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class EncoderRNN(torch.nn.Module):
    def __init__(self,input_size, hidden_size, num_grulstm_layers, batch_size):
        super(EncoderRNN, self).__init__()
        self.hidden_size = hidden_size
        self.batch_size = batch_size
        self.num_grulstm_layers = num_grulstm_layers
        self.gru = nn.GRU(input_size=input_size, hidden_size=hidden_size, num_layers=num_grulstm_layers,batch_first=True)

    def forward(self, input, hidden): # input [batch_size, length T, dimensionality d]
        output, hidden = self.gru(input, hidden)
        return output, hidden

    def init_hidden(self,device):
        #[num_layers*num_directions,batch,hidden_size]
        return torch.zeros(self.num_grulstm_layers, self.batch_size, self.hidden_size, device=device)

class DecoderRNN(nn.Module):
    def __init__(self, input_size, hidden_size, num_grulstm_layers,fc_units, output_size):
        super(DecoderRNN, self).__init__()
        self.gru = nn.GRU(input_size=input_size, hidden_size=hidden_size, num_layers=num_grulstm_layers,batch_first=True)
        self.fc = nn.Linear(hidden_size, fc_units)
        self.out = nn.Linear(fc_units, output_size)

    def forward(self, input, hidden):
        output, hidden = self.gru(input, hidden)
        output = F.relu( self.fc(output) )
        output = self.out(output)
        return output, hidden

class Net_GRU(nn.Module):
    def __init__(self, encoder, decoder, target_length, device):
        super(Net_GRU, self).__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.target_length = target_length
        self.device = device

    def forward(self, x):
        input_length  = x.shape[1]
        encoder_hidden = self.encoder.init_hidden(self.device)
        for ei in range(input_length):
            encoder_output, encoder_hidden = self.encoder(x[:,ei:ei+1,:]  , encoder_hidden)

        decoder_input = x[:,-1,:].unsqueeze(1) # first decoder input= last element of input sequence
        decoder_hidden = encoder_hidden

        outputs = torch.zeros([x.shape[0], self.target_length, x.shape[2]]  ).to(self.device)
        for di in range(self.target_length):
            decoder_output, decoder_hidden = self.decoder(decoder_input, decoder_hidden)
            decoder_input = decoder_output
            outputs[:,di:di+1,:] = decoder_output
        return outputs

In [3]:
import pandas as pd
import numpy as np
from sklearn.preprocessing import MinMaxScaler

# Charger le fichier CSV
file_path = '/content/ETTh1.csv'
data = pd.read_csv(file_path)

# Extraire la colonne 'Close'
if 'LUFL' in data.columns:
    close_prices = data['LUFL'].dropna().values  # Suppression des valeurs NaN si elles existent
else:
    raise ValueError("La colonne 'max. wv (m/s)' n'existe pas dans le fichier.")

# Normalisation des données (Min-Max Scaling entre 0 et 1)
scaler = MinMaxScaler(feature_range=(0, 1))
close_prices_normalized = scaler.fit_transform(close_prices.reshape(-1, 1)).flatten()

# Fonction pour créer les séquences d'entrée et de cible
def create_sequences(data, input_size, target_size, stride=1):
    inputs = []
    targets = []
    for i in range(0, len(data) - input_size - target_size + 1, stride):
        inputs.append(data[i:i + input_size])
        targets.append(data[i + input_size:i + input_size + target_size])
    return np.array(inputs), np.array(targets)

# Diviser en séquences (25 pour input, 10 pour target)
input_size = 25
target_size = 10
stride = 5

X, y = create_sequences(close_prices_normalized, input_size, target_size, stride)

print(f"Dimensions des données d'entrée (X) : {X.shape}")
print(f"Dimensions des données de sortie (y) : {y.shape}")

# Exemple des premières séquences
print("Première séquence d'entrée :", X[0])
print("Première séquence de sortie :", y[0])

# Pour inverser la normalisation plus tard, vous pouvez utiliser :
# scaler.inverse_transform(normalized_data.reshape(-1, 1))
X = np.stack(X)
y = np.stack(y)

Dimensions des données d'entrée (X) : (3478, 25)
Dimensions des données de sortie (y) : (3478, 10)
Première séquence d'entrée : [0.5565765  0.55027876 0.51259548 0.51569273 0.5219905  0.54088375
 0.64154449 0.64784222 0.42773073 0.39933925 0.42773073 0.47800948
 0.41513523 0.4434235  0.38684699 0.38994424 0.39304149 0.44022299
 0.44022299 0.43392526 0.41823248 0.39933925 0.39304149 0.37734874
 0.37425149]
Première séquence de sortie : [0.36475325 0.35845549 0.3490605  0.3490605  0.37425149 0.43082798
 0.44022299 0.42453024 0.39933925 0.43712574]


In [4]:
# parameters
batch_size = 100
N = 200
N_input = 25
N_output = 10
sigma = 0.01
gamma = 0.01

# Load synthetic dataset
X_train_input=X[0:N,0:N_input]
X_train_target=y[0:N,0:N_output]
X_test_input=X[N:N+200,0:N_input]
X_test_target=y[N:N+200, 0:N_output]

dataset_train = SDataset(X_train_input,X_train_target)
dataset_test  = SDataset(X_test_input,X_test_target)
trainloader = DataLoader(dataset_train, batch_size=batch_size,shuffle=True, num_workers=1)
testloader  = DataLoader(dataset_test, batch_size=batch_size,shuffle=False, num_workers=1)


In [5]:
import numpy as np
import torch


from loss.dilate_loss import dilate_loss

from torch.utils.data import DataLoader
import random
from tslearn.metrics import dtw, dtw_path
import matplotlib.pyplot as plt
import warnings
import warnings; warnings.simplefilter('ignore')

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
random.seed(0)



def train_model(net,loss_type, learning_rate, epochs=1000, gamma = 0.001,
                print_every=50,eval_every=50, verbose=1, Lambda=1, alpha=0.5):

    optimizer = torch.optim.Adam(net.parameters(),lr=learning_rate)
    criterion = torch.nn.MSELoss()

    for epoch in range(epochs):
        for i, data in enumerate(trainloader, 0):
            inputs, target= data
            inputs = torch.tensor(inputs, dtype=torch.float32).to(device)
            target = torch.tensor(target, dtype=torch.float32).to(device)
            batch_size = target.shape[0]
            N_output = 1

            # forward + backward + optimize
            outputs = net(inputs)
            loss_mse,loss_shape,loss_temporal = torch.tensor(0),torch.tensor(0),torch.tensor(0)

            if (loss_type=='mse'):
                loss_mse = criterion(target,outputs)
                loss = loss_mse

            if (loss_type=='dilate'):
                loss, loss_shape, loss_temporal = dilate_loss(target,outputs,alpha, gamma, device)

            if (loss_type=='foldt'):
                loss, loss_shape, loss_temporal = foldt_loss(target,outputs,alpha, gamma, device)



            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        if(verbose):
            if (epoch % print_every == 0):
                print('epoch ', epoch, ' loss ',loss.item(),' loss shape ',loss_shape.item(),' loss temporal ',loss_temporal.item())
                eval_model(net,testloader, gamma,verbose=1)


def eval_model(net,loader, gamma,verbose=1):
    criterion = torch.nn.MSELoss()
    losses_mse = []
    losses_dtw = []
    losses_tdi = []

    for i, data in enumerate(loader, 0):
        loss_mse, loss_dtw, loss_tdi = torch.tensor(0),torch.tensor(0),torch.tensor(0)
        # get the inputs
        inputs, target= data
        inputs = torch.tensor(inputs, dtype=torch.float32).to(device)
        target = torch.tensor(target, dtype=torch.float32).to(device)
        batch_size, N_output = target.shape[0:2]
        outputs = net(inputs)

        # MSE
        loss_mse = criterion(target,outputs)
        loss_dtw, loss_tdi = 0,0
        # DTW and TDI
        for k in range(batch_size):
            target_k_cpu = target[k,:,0:1].view(-1).detach().cpu().numpy()
            output_k_cpu = outputs[k,:,0:1].view(-1).detach().cpu().numpy()

            path, sim = dtw_path(target_k_cpu, output_k_cpu)
            loss_dtw += sim

            Dist = 0
            for i,j in path:
                    Dist += (i-j)*(i-j)
            loss_tdi += Dist / (N_output*N_output)

        loss_dtw = loss_dtw /batch_size
        loss_tdi = loss_tdi / batch_size

        # print statistics
        losses_mse.append( loss_mse.item() )
        losses_dtw.append( loss_dtw )
        losses_tdi.append( loss_tdi )

    print( ' Eval mse= ', np.array(losses_mse).mean() ,' dtw= ',np.array(losses_dtw).mean() ,' tdi= ', np.array(losses_tdi).mean())

In [6]:
pip install pot

Collecting pot
  Downloading pot-0.9.6.post1-cp312-cp312-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl.metadata (40 kB)
[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/40.2 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m40.2/40.2 kB[0m [31m2.6 MB/s[0m eta [36m0:00:00[0m
Downloading pot-0.9.6.post1-cp312-cp312-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl (1.5 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.5/1.5 MB[0m [31m23.2 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: pot
Successfully installed pot-0.9.6.post1


In [7]:
import numpy as np
import torch
from torch.autograd import Function
from numba import jit
import ot

@jit(nopython=True)
def my_max(x, gamma):
    max_x = np.max(x)
    exp_x = np.exp((x - max_x) / gamma)
    Z = np.sum(exp_x)
    return gamma * np.log(Z) + max_x, exp_x / Z

@jit(nopython=True)
def my_min(x, gamma):
    min_x, argmax_x = my_max(-x, gamma)
    return -min_x, argmax_x

@jit(nopython=True)
def my_max_hessian_product(p, z, gamma):
    return (p * z - p * np.sum(p * z)) / gamma

@jit(nopython=True)
def my_min_hessian_product(p, z, gamma):
    return -my_max_hessian_product(p, z, gamma)

@jit(nopython=True)
def dtw_grad(theta, gamma):
    m = theta.shape[0]
    n = theta.shape[1]
    V = np.zeros((m + 1, n + 1))
    V[:, 0] = 1e10
    V[0, :] = 1e10
    V[0, 0] = 0

    Q = np.zeros((m + 2, n + 2, 3))

    for i in range(1, m + 1):
        for j in range(1, n + 1):
            v, Q[i, j] = my_min(np.array([V[i, j - 1], V[i - 1, j - 1], V[i - 1, j]]), gamma)
            V[i, j] = theta[i - 1, j - 1] + v

    E = np.zeros((m + 2, n + 2))
    E[m + 1, :] = 0
    E[:, n + 1] = 0
    E[m + 1, n + 1] = 1
    Q[m + 1, n + 1] = 1

    for i in range(m, 0, -1):
        for j in range(n, 0, -1):
            E[i, j] = Q[i, j + 1, 0] * E[i, j + 1] + \
                      Q[i + 1, j + 1, 1] * E[i + 1, j + 1] + \
                      Q[i + 1, j, 2] * E[i + 1, j]

    return V[m, n], E[1:m + 1, 1:n + 1], Q, E

@jit(nopython=True)
def dtw_hessian_prod(theta, Z, Q, E, gamma):
    m = Z.shape[0]
    n = Z.shape[1]

    V_dot = np.zeros((m + 1, n + 1))
    V_dot[0, 0] = 0

    Q_dot = np.zeros((m + 2, n + 2, 3))
    for i in range(1, m + 1):
        for j in range(1, n + 1):
            V_dot[i, j] = Z[i - 1, j - 1] + \
                          Q[i, j, 0] * V_dot[i, j - 1] + \
                          Q[i, j, 1] * V_dot[i - 1, j - 1] + \
                          Q[i, j, 2] * V_dot[i - 1, j]

            v = np.array([V_dot[i, j - 1], V_dot[i - 1, j - 1], V_dot[i - 1, j]])
            Q_dot[i, j] = my_min_hessian_product(Q[i, j], v, gamma)
    E_dot = np.zeros((m + 2, n + 2))

    for j in range(n, 0, -1):
        for i in range(m, 0, -1):
            E_dot[i, j] = Q_dot[i, j + 1, 0] * E[i, j + 1] + \
                          Q[i, j + 1, 0] * E_dot[i, j + 1] + \
                          Q_dot[i + 1, j + 1, 1] * E[i + 1, j + 1] + \
                          Q[i + 1, j + 1, 1] * E_dot[i + 1, j + 1] + \
                          Q_dot[i + 1, j, 2] * E[i + 1, j] + \
                          Q[i + 1, j, 2] * E_dot[i + 1, j]

    return V_dot[m, n], E_dot[1:m + 1, 1:n + 1]

def compute_ot(x, y, reg=1.0):
    cost_matrix = ot.dist(x, y)
    ot_matrix = ot.sinkhorn(np.ones(x.shape[0]) / x.shape[0], np.ones(y.shape[0]) / y.shape[0], cost_matrix, reg)
    ot_distance = np.sum(ot_matrix * cost_matrix)
    return ot_distance

class PathDTWBatch(Function):
    @staticmethod
    def forward(ctx, D, gamma, use_ot=False, x=None, y=None, reg=1.0): # D.shape: [batch_size, N, N]
        batch_size, N, N = D.shape
        device = D.device
        D_cpu = D.detach().cpu().numpy()
        gamma_gpu = torch.FloatTensor([gamma]).to(device)

        grad_gpu = torch.zeros((batch_size, N, N)).to(device)
        Q_gpu = torch.zeros((batch_size, N + 2, N + 2, 3)).to(device)
        E_gpu = torch.zeros((batch_size, N + 2, N + 2)).to(device)

        for k in range(batch_size):
            _, grad_cpu_k, Q_cpu_k, E_cpu_k = dtw_grad(D_cpu[k], gamma)
            grad_gpu[k] = torch.FloatTensor(grad_cpu_k).to(device)
            Q_gpu[k] = torch.FloatTensor(Q_cpu_k).to(device)
            E_gpu[k] = torch.FloatTensor(E_cpu_k).to(device)

        total_loss = torch.mean(grad_gpu, dim=0)

        if use_ot and x is not None and y is not None:
            ot_distances = [compute_ot(x[k].cpu().numpy(), y[k].cpu().numpy(), reg) for k in range(batch_size)]
            ot_loss = torch.tensor(ot_distances, device=dev).mean()
            total_loss += ot_loss

        ctx.save_for_backward(grad_gpu, D, Q_gpu, E_gpu, gamma_gpu, torch.tensor([use_ot], device=device), x, y, torch.tensor([reg], device=device))
        return total_loss

    @staticmethod
    def backward(ctx, grad_output):
        device = grad_output.device
        grad_gpu, D_gpu, Q_gpu, E_gpu, gamma, use_ot, x, y, reg = ctx.saved_tensors
        D_cpu = D_gpu.detach().cpu().numpy()
        Q_cpu = Q_gpu.detach().cpu().numpy()
        E_cpu = E_gpu.detach().cpu().numpy()
        gamma = gamma.detach().cpu().numpy()[0]
        Z = grad_output.detach().cpu().numpy()

        batch_size, N, N = D_cpu.shape
        Hessian = torch.zeros((batch_size, N, N)).to(device)
        for k in range(batch_size):
            _, hess_k = dtw_hessian_prod(D_cpu[k], Z, Q_cpu[k], E_cpu[k], gamma)
            Hessian[k] = torch.FloatTensor(hess_k).to(device)

        # The OT gradient part is not trivial and not implemented here for simplicity
        return Hessian, None, None, None, None, None, None, None

import numpy as np
import torch
from numba import jit
from torch.autograd import Function
from torch.fft import fft

# Fonction pour calculer les coefficients de Fourier
def fourier_coefficients(x, n_coefficients):
    # Calculate Fourier coefficients
    X_fft = fft(x, dim=1)
    return X_fft[:, :n_coefficients]

# Fonction pour calculer les distances combinées (DTW + Fourier)
def pairwise_distances_with_fourier(x, y=None, n_coefficients=5):
    '''
    Input: x is a Nxd matrix
           y is an optional Mxd matrix
           n_coefficients is the number of Fourier coefficients to use
    Output: dist is a NxM matrix where dist[i,j] is the combined distance
    '''
    # Calculate DTW distances
    dtw_distances = pairwise_distances(x, y)

    # Calculate Fourier coefficients
    x_fourier = fourier_coefficients(x, n_coefficients).abs().float()
    if y is not None:
        y_fourier = fourier_coefficients(y, n_coefficients).abs().float()
    else:
        y_fourier = x_fourier

    # Calculate Fourier distances
    x_fourier_norm = (x_fourier**2).sum(1).view(-1, 1)
    y_fourier_t = torch.transpose(y_fourier, 0, 1)
    y_fourier_norm = (y_fourier**2).sum(1).view(1, -1)

    fourier_distances = x_fourier_norm + y_fourier_norm - 2.0 * torch.mm(x_fourier, y_fourier_t)
    fourier_distances = torch.clamp(fourier_distances, 0.0, float('inf'))

    # Combine distances
    combined_distances = dtw_distances + fourier_distances
    return combined_distances

# Fonction pour calculer les distances pairwise (DTW)
def pairwise_distances(x, y=None):
    '''
    Input: x is a Nxd matrix
           y is an optional Mxd matrix
    Output: dist is a NxM matrix where dist[i,j] is the square norm between x[i,:] and y[j,:]
            if y is not given then use 'y=x'.
    i.e. dist[i,j] = ||x[i,:]-y[j,:]||^2
    '''
    x_norm = (x**2).sum(1).view(-1, 1)
    if y is not None:
        y_t = torch.transpose(y, 0, 1)
        y_norm = (y**2).sum(1).view(1, -1)
    else:
        y_t = torch.transpose(x, 0, 1)
        y_norm = x_norm.view(1, -1)

    dist = x_norm + y_norm - 2.0 * torch.mm(x, y_t)
    return torch.clamp(dist, 0.0, float('inf'))

@jit(nopython=True)
def compute_softdtw(D, gamma):
    N = D.shape[0]
    M = D.shape[1]
    R = np.zeros((N + 2, M + 2)) + 1e8
    R[0, 0] = 0
    for j in range(1, M + 1):
        for i in range(1, N + 1):
            r0 = -R[i - 1, j - 1] / gamma
            r1 = -R[i - 1, j] / gamma
            r2 = -R[i, j - 1] / gamma
            rmax = max(max(r0, r1), r2)
            rsum = np.exp(r0 - rmax) + np.exp(r1 - rmax) + np.exp(r2 - rmax)
            softmin = - gamma * (np.log(rsum) + rmax)
            R[i, j] = D[i - 1, j - 1] + softmin
    return R

@jit(nopython=True)
def compute_softdtw_backward(D_, R, gamma):
    N = D_.shape[0]
    M = D_.shape[1]
    D = np.zeros((N + 2, M + 2))
    E = np.zeros((N + 2, M + 2))
    D[1:N + 1, 1:M + 1] = D_
    E[-1, -1] = 1
    R[:, -1] = -1e8
    R[-1, :] = -1e8
    R[-1, -1] = R[-2, -2]
    for j in range(M, 0, -1):
        for i in range(N, 0, -1):
            a0 = (R[i + 1, j] - R[i, j] - D[i + 1, j]) / gamma
            b0 = (R[i, j + 1] - R[i, j] - D[i, j + 1]) / gamma
            c0 = (R[i + 1, j + 1] - R[i, j] - D[i + 1, j + 1]) / gamma
            a = np.exp(a0)
            b = np.exp(b0)
            c = np.exp(c0)
            E[i, j] = E[i + 1, j] * a + E[i, j + 1] * b + E[i + 1, j + 1] * c
    return E[1:N + 1, 1:M + 1]

class SoftDTWBatch(Function):
    @staticmethod
    def forward(ctx, D, gamma=1.0): # D.shape: [batch_size, N , N]
        dev = D.device
        batch_size, N, N = D.shape
        gamma = torch.FloatTensor([gamma]).to(dev)
        D_ = D.detach().cpu().numpy()
        g_ = gamma.item()

        total_loss = 0
        R = torch.zeros((batch_size, N+2, N+2)).to(dev)
        for k in range(0, batch_size):  # loop over all D in the batch
            Rk = torch.FloatTensor(compute_softdtw(D_[k, :, :], g_)).to(dev)
            R[k:k+1, :, :] = Rk
            total_loss = total_loss + Rk[-2, -2]
        ctx.save_for_backward(D, R, gamma)
        return total_loss / batch_size

    @staticmethod
    def backward(ctx, grad_output):
        dev = grad_output.device
        D, R, gamma = ctx.saved_tensors
        batch_size, N, N = D.shape
        D_ = D.detach().cpu().numpy()
        R_ = R.detach().cpu().numpy()
        g_ = gamma.item()

        E = torch.zeros((batch_size, N, N)).to(dev)
        for k in range(batch_size):
            Ek = torch.FloatTensor(compute_softdtw_backward(D_[k, :, :], R_[k, :, :], g_)).to(dev)
            E[k:k+1, :, :] = Ek

        return grad_output * E, None

def combined_dtw_fourier_loss(x, y, gamma=1.0, n_coefficients=3):
    dist_matrix = pairwise_distances_with_fourier(x, y, n_coefficients)
    loss = SoftDTWBatch.apply(dist_matrix, gamma)
    return loss

def foldt_loss(outputs, targets, alpha, gamma, device, n_coefficients=6):
    # outputs, targets: shape (batch_size, N_output, 1)
    batch_size, N_output = outputs.shape[0:2]
    loss_shape = 0
    softdtw_batch = SoftDTWBatch.apply
    D = torch.zeros((batch_size, N_output, N_output)).to(device)

    for k in range(batch_size):
        # Calculate the combined distances (DTW + Fourier)
        Dk = pairwise_distances_with_fourier(targets[k, :, :].view(-1, 1), outputs[k, :, :].view(-1, 1), n_coefficients)
        D[k:k+1, :, :] = Dk

    # Calculate shape loss using SoftDTWBatch
    loss_shape = softdtw_batch(D, gamma)

    # Calculate temporal loss
    path_dtw = PathDTWBatch.apply
    path = path_dtw(D, gamma)
    Omega = pairwise_distances(torch.arange(1, N_output + 1).view(N_output, 1).float()).to(device)
    loss_temporal = torch.sum(path * Omega) / (N_output * N_output)

    # Combine shape and temporal loss
    loss = 0.5 * loss_shape + 0.5 * loss_temporal
    return loss, loss_shape, loss_temporal


In [9]:
def foldt_loss(outputs, targets, alpha, gamma, device, n_coefficients=6):
    # outputs, targets: shape (batch_size, N_output, 1)
    batch_size, N_output = outputs.shape[0:2]
    loss_shape = 0
    softdtw_batch = SoftDTWBatch.apply
    D = torch.zeros((batch_size, N_output, N_output)).to(device)

    for k in range(batch_size):
        # Calculate the combined distances (DTW + Fourier)
        Dk = pairwise_distances_with_fourier(targets[k, :, :].view(-1, 1), outputs[k, :, :].view(-1, 1), n_coefficients)
        D[k:k+1, :, :] = Dk

    # Calculate shape loss using SoftDTWBatch
    loss_shape = softdtw_batch(D, gamma)

    # Calculate temporal loss
    path_dtw = PathDTWBatch.apply
    path = path_dtw(D, gamma)
    Omega = pairwise_distances(torch.arange(1, N_output + 1).view(N_output, 1).float()).to(device)
    loss_temporal = torch.sum(path * Omega) / (N_output * N_output)

    # Combine shape and temporal loss
    loss = 0.9 * loss_shape + 0.1 * loss_temporal
    return loss, loss_shape, loss_temporal


encoder = EncoderRNN(input_size=1, hidden_size=128, num_grulstm_layers=1, batch_size=batch_size).to(device)
decoder = DecoderRNN(input_size=1, hidden_size=128, num_grulstm_layers=1,fc_units=16, output_size=1).to(device)
net_gru_mor = Net_GRU(encoder,decoder, N_output, device).to(device)
train_model(net_gru_mor,loss_type='foldt',learning_rate=0.001, epochs=500, gamma=gamma, print_every=50, eval_every=50,verbose=1)

epoch  0  loss  3.301307201385498  loss shape  3.6681177616119385  loss temporal  1.1307285603834316e-05
 Eval mse=  0.14091452583670616  dtw=  1.125239167978107  tdi=  0.0
epoch  50  loss  0.19098106026649475  loss shape  0.1570797562599182  loss temporal  0.49609291553497314
 Eval mse=  0.019122661091387272  dtw=  0.41189560804597336  tdi=  0.4157000000000002
epoch  100  loss  0.16224819421768188  loss shape  0.12959107756614685  loss temporal  0.4561622142791748
 Eval mse=  0.012047890573740005  dtw=  0.3209455680822627  tdi=  0.6096499999999999
epoch  150  loss  0.11904306709766388  loss shape  0.08389230072498322  loss temporal  0.4354000389575958
 Eval mse=  0.011603337246924639  dtw=  0.31611196137656755  tdi=  0.5563
epoch  200  loss  0.11948588490486145  loss shape  0.08528485894203186  loss temporal  0.42729508876800537
 Eval mse=  0.010952380020171404  dtw=  0.30356173961424304  tdi=  0.6508500000000002
epoch  250  loss  0.1286647766828537  loss shape  0.09269434958696365  l

In [10]:
encoder = EncoderRNN(input_size=1, hidden_size=128, num_grulstm_layers=1, batch_size=batch_size).to(device)
decoder = DecoderRNN(input_size=1, hidden_size=128, num_grulstm_layers=1,fc_units=16, output_size=1).to(device)
net_gru_d = Net_GRU(encoder,decoder, N_output, device).to(device)
train_model(net_gru_d,loss_type='dilate',learning_rate=0.001, epochs=500, gamma=gamma, print_every=50, eval_every=50,verbose=1)

epoch  0  loss  2.5603296756744385  loss shape  5.120659351348877  loss temporal  2.642863683975649e-13
 Eval mse=  0.4212789535522461  dtw=  2.017669797589545  tdi=  0.0
epoch  50  loss  0.22248846292495728  loss shape  0.09670469164848328  loss temporal  0.3482722342014313
 Eval mse=  0.016590879764407873  dtw=  0.32602285216269244  tdi=  0.9183499999999998
epoch  100  loss  0.19872646033763885  loss shape  0.0017260723980143666  loss temporal  0.39572685956954956
 Eval mse=  0.015206643380224705  dtw=  0.3130692204140725  tdi=  0.6131000000000002
epoch  150  loss  0.2000613957643509  loss shape  0.0005584281752817333  loss temporal  0.39956435561180115
 Eval mse=  0.014739493373781443  dtw=  0.3094957848053683  tdi=  0.6293500000000003
epoch  200  loss  0.17989219725131989  loss shape  -0.004184909630566835  loss temporal  0.3639692962169647
 Eval mse=  0.01305232010781765  dtw=  0.29352626605637117  tdi=  0.7777000000000001
epoch  250  loss  0.18893012404441833  loss shape  0.05948

In [11]:
# ==================== PS LOSS FUNCTIONS ====================
def create_patches(x, patch_len, stride):
    """Create patches from time series data"""
    # x shape: [B, L, C]
    B, L, C = x.shape
    num_patches = (L - patch_len) // stride + 1
    patches = x.unfold(1, patch_len, stride)  # [B, num_patches, C, patch_len]
    patches = patches.permute(0, 2, 1, 3)  # [B, C, num_patches, patch_len]
    return patches


def fourier_based_adaptive_patching(true, pred, patch_len_threshold=20):
    """Determine patch size based on frequency analysis"""
    # true, pred shape: [B, L, C]
    true_fft = torch.fft.rfft(true, dim=1)
    frequency_list = torch.abs(true_fft).mean(0).mean(-1)
    frequency_list[:1] = 0.0
    top_index = torch.argmax(frequency_list)
    period = max(true.shape[1] // max(top_index, 1), 4)
    patch_len = min(period // 2, patch_len_threshold)
    patch_len = max(patch_len, 4)  # Minimum patch length
    stride = max(patch_len // 2, 1)

    # Patching
    true_patch = create_patches(true, patch_len, stride=stride)
    pred_patch = create_patches(pred, patch_len, stride=stride)

    return true_patch, pred_patch


def patch_wise_structural_loss(true_patch, pred_patch):
    """Calculate structural losses at patch level"""
    # true_patch, pred_patch shape: [B, C, num_patches, patch_len]

    # Calculate mean
    true_patch_mean = torch.mean(true_patch, dim=-1, keepdim=True)
    pred_patch_mean = torch.mean(pred_patch, dim=-1, keepdim=True)

    # Calculate variance and standard deviation
    true_patch_var = torch.var(true_patch, dim=-1, keepdim=True, unbiased=False)
    pred_patch_var = torch.var(pred_patch, dim=-1, keepdim=True, unbiased=False)
    true_patch_std = torch.sqrt(true_patch_var + 1e-8)
    pred_patch_std = torch.sqrt(pred_patch_var + 1e-8)

    # Calculate Covariance
    true_pred_patch_cov = torch.mean(
        (true_patch - true_patch_mean) * (pred_patch - pred_patch_mean),
        dim=-1, keepdim=True
    )

    # 1. Calculate linear correlation loss
    patch_linear_corr = (true_pred_patch_cov + 1e-5) / (true_patch_std * pred_patch_std + 1e-5)
    linear_corr_loss = (1.0 - patch_linear_corr).mean()

    # 2. Calculate variance (KL divergence)
    true_patch_softmax = torch.softmax(true_patch, dim=-1)
    pred_patch_softmax = torch.log_softmax(pred_patch, dim=-1)
    kl_loss = torch.nn.functional.kl_div(
        pred_patch_softmax, true_patch_softmax, reduction='none'
    )
    var_loss = kl_loss.sum(dim=-1).mean()

    # 3. Mean loss
    mean_loss = torch.abs(true_patch_mean - pred_patch_mean).mean()

    return linear_corr_loss, var_loss, mean_loss


def ps_loss(true, pred, patch_len_threshold=20, use_dynamic_weighting=False):
    """
    Patch-wise Structural Loss
    Args:
        true: ground truth [B, L, C]
        pred: predictions [B, L, C]
        patch_len_threshold: maximum patch length
        use_dynamic_weighting: whether to use gradient-based dynamic weighting
    """
    # Ensure correct shape [B, L, C]
    if len(true.shape) == 2:
        true = true.unsqueeze(-1)
        pred = pred.unsqueeze(-1)

    # Fourier based adaptive patching
    true_patch, pred_patch = fourier_based_adaptive_patching(
        true, pred, patch_len_threshold
    )

    # Patch-wise structural loss
    corr_loss, var_loss, mean_loss = patch_wise_structural_loss(true_patch, pred_patch)

    # Simple weighted combination (without gradient-based weighting for compatibility)
    if use_dynamic_weighting:
        # Calculate similarities for weighting
        true_mean = torch.mean(true, dim=1, keepdim=True)
        pred_mean = torch.mean(pred, dim=1, keepdim=True)
        true_var = torch.var(true, dim=1, keepdim=True, unbiased=False)
        pred_var = torch.var(pred, dim=1, keepdim=True, unbiased=False)
        true_std = torch.sqrt(true_var + 1e-8)
        pred_std = torch.sqrt(pred_var + 1e-8)

        true_pred_cov = torch.mean((true - true_mean) * (pred - pred_mean), dim=1, keepdim=True)
        linear_sim = (true_pred_cov + 1e-5) / (true_std * pred_std + 1e-5)
        linear_sim = (1.0 + linear_sim) * 0.5
        var_sim = (2 * true_std * pred_std + 1e-5) / (true_var + pred_var + 1e-5)

        # Adaptive weights
        alpha = 1.0
        beta = 1.0
        gamma = torch.mean(linear_sim * var_sim).detach()

        total_ps_loss = alpha * corr_loss + beta * var_loss + gamma * mean_loss
    else:
        # Fixed weights
        total_ps_loss = corr_loss + var_loss + mean_loss

    return total_ps_loss, corr_loss, var_loss, mean_loss


# ==================== TRAINING FUNCTION ====================
def train_model(net, loss_type, learning_rate, epochs=1000, gamma=0.001,
                print_every=50, eval_every=50, verbose=1, Lambda=1, alpha=0.5,
                ps_lambda=1.0, patch_len_threshold=20, use_dynamic_weighting=False):

    optimizer = torch.optim.Adam(net.parameters(), lr=learning_rate)
    criterion = torch.nn.MSELoss()

    for epoch in range(epochs):
        for i, data in enumerate(trainloader, 0):
            inputs, target = data
            inputs = torch.tensor(inputs, dtype=torch.float32).to(device)
            target = torch.tensor(target, dtype=torch.float32).to(device)

            # Ensure target has shape [B, L, C]
            if len(target.shape) == 2:
                target = target.unsqueeze(-1)

            # Forward pass
            outputs = net(inputs)

            # Ensure outputs has shape [B, L, C]
            if len(outputs.shape) == 2:
                outputs = outputs.unsqueeze(-1)

            # Initialize losses
            loss_mse = torch.tensor(0.0, device=device)
            loss_shape = torch.tensor(0.0, device=device)
            loss_temporal = torch.tensor(0.0, device=device)
            loss_ps_corr = torch.tensor(0.0, device=device)
            loss_ps_var = torch.tensor(0.0, device=device)
            loss_ps_mean = torch.tensor(0.0, device=device)

            # Calculate loss based on loss_type
            if loss_type == 'mse':
                loss_mse = criterion(target, outputs)
                loss = loss_mse

            elif loss_type == 'dilate':
                loss, loss_shape, loss_temporal = dilate_loss(
                    target, outputs, alpha, gamma, device
                )

            elif loss_type == 'ps':
                loss_ps_total, loss_ps_corr, loss_ps_var, loss_ps_mean = ps_loss(
                    target, outputs, patch_len_threshold, use_dynamic_weighting
                )
                loss = loss_ps_total

            elif loss_type == 'mse+ps':
                # Combined MSE + PS Loss
                loss_mse = criterion(target, outputs)
                loss_ps_total, loss_ps_corr, loss_ps_var, loss_ps_mean = ps_loss(
                    target, outputs, patch_len_threshold, use_dynamic_weighting
                )
                loss = loss_mse + ps_lambda * loss_ps_total

            else:
                raise ValueError(f"Unknown loss_type: {loss_type}")

            # Backward pass
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # Print progress
        if verbose and (epoch % print_every == 0):
            if loss_type in ['dilate', 'dilate+ps']:
                print(f'Epoch {epoch:4d} | Loss: {loss.item():.6f} | '
                      f'Shape: {loss_shape.item():.6f} | Temporal: {loss_temporal.item():.6f}')
            elif loss_type in ['ps', 'mse+ps', 'dilate+ps']:
                print(f'Epoch {epoch:4d} | Loss: {loss.item():.6f} | ')
            else:
                print(f'Epoch {epoch:4d} | MSE Loss: {loss_mse.item():.6f}')

            if epoch % eval_every == 0:
                eval_model(net, testloader, gamma, verbose=1)


def eval_model(net, loader, gamma, verbose=1):
    """Evaluate model on test set"""
    criterion = torch.nn.MSELoss()
    losses_mse = []
    losses_dtw = []
    losses_tdi = []

    net.eval()
    with torch.no_grad():
        for i, data in enumerate(loader, 0):
            inputs, target = data
            inputs = torch.tensor(inputs, dtype=torch.float32).to(device)
            target = torch.tensor(target, dtype=torch.float32).to(device)

            if len(target.shape) == 2:
                target = target.unsqueeze(-1)

            batch_size, N_output = target.shape[0:2]
            outputs = net(inputs)

            if len(outputs.shape) == 2:
                outputs = outputs.unsqueeze(-1)

            # MSE
            loss_mse = criterion(target, outputs)

            # DTW and TDI
            loss_dtw, loss_tdi = 0, 0
            for k in range(batch_size):
                target_k_cpu = target[k, :, 0].detach().cpu().numpy()
                output_k_cpu = outputs[k, :, 0].detach().cpu().numpy()

                path, sim = dtw_path(target_k_cpu, output_k_cpu)
                loss_dtw += sim

                Dist = 0
                for i, j in path:
                    Dist += (i - j) * (i - j)
                loss_tdi += Dist / (N_output * N_output)

            loss_dtw = loss_dtw / batch_size
            loss_tdi = loss_tdi / batch_size

            losses_mse.append(loss_mse.item())
            losses_dtw.append(loss_dtw)
            losses_tdi.append(loss_tdi)

    net.train()

    if verbose:
        print(f'  Eval | MSE: {np.mean(losses_mse):.6f} | '
              f'DTW: {np.mean(losses_dtw):.6f} | TDI: {np.mean(losses_tdi):.6f}')


# Configuration du modèle comme avant
encoder = EncoderRNN(input_size=1, hidden_size=128, num_grulstm_layers=1, batch_size=batch_size).to(device)
decoder = DecoderRNN(input_size=1, hidden_size=128, num_grulstm_layers=1, fc_units=16, output_size=1).to(device)
net_gru_ps = Net_GRU(encoder, decoder, N_output, device).to(device)

train_model(net_gru_ps, loss_type='mse+ps', learning_rate=0.001,
            epochs=500, ps_lambda=1.0, patch_len_threshold=20)


Epoch    0 | Loss: 1.004614 | 
  Eval | MSE: 0.064562 | DTW: 0.709390 | TDI: 0.049800
Epoch   50 | Loss: 0.459615 | 
  Eval | MSE: 0.023211 | DTW: 0.384241 | TDI: 0.363600
Epoch  100 | Loss: 0.370922 | 
  Eval | MSE: 0.018777 | DTW: 0.393808 | TDI: 0.100900
Epoch  150 | Loss: 0.376059 | 
  Eval | MSE: 0.019721 | DTW: 0.416971 | TDI: 0.092100
Epoch  200 | Loss: 0.363192 | 
  Eval | MSE: 0.020267 | DTW: 0.425185 | TDI: 0.103500
Epoch  250 | Loss: 0.371626 | 
  Eval | MSE: 0.020566 | DTW: 0.429255 | TDI: 0.092350
Epoch  300 | Loss: 0.367440 | 
  Eval | MSE: 0.020760 | DTW: 0.431813 | TDI: 0.100100
Epoch  350 | Loss: 0.354366 | 
  Eval | MSE: 0.020693 | DTW: 0.430940 | TDI: 0.113050
Epoch  400 | Loss: 0.376264 | 
  Eval | MSE: 0.020688 | DTW: 0.430864 | TDI: 0.119650
Epoch  450 | Loss: 0.356417 | 
  Eval | MSE: 0.020779 | DTW: 0.432055 | TDI: 0.090700
