In [1]:
from scipy.sparse import coo_array
from scipy.sparse import kron
import scipy

import numpy as np
import torch
import math
import torch.nn as nn
from torch.nn import Parameter
import torch.nn.functional as F

import numpy.typing as npt
import os
import pandas as pd
from typing import Optional, List, Dict, Tuple
from utils import *
from sklearn.metrics import roc_auc_score
import time
from scipy.sparse import coo_matrix

def ToScipy(M):
    N1, N2 = M.size(0), M.size(1)
    indices = M.indices().cpu().numpy()
    values = M.values().cpu().numpy()
    adj = coo_array((values, (indices[0], indices[1])), shape=(N1,N2))
    return adj

def ScipyKron(X, Y, convert_X=False, convert_Y=False):
    if convert_X == True:
        X = ToScipy(X.coalesce())
    # import ipdb; ipdb.set_trace()
    if convert_Y == True:
        Y = ToScipy(Y.coalesce())
    return sparse_mx_to_torch_sparse_tensor(kron(X, Y))
    

def FixedPointForward(W_list, X, A_list, V, X_list, phi_list, num_iter=1000, tol=1e-6, trasposed_A=False, compute_dphi=False):
        
        B_list = []
        
        for i in range(len(X_list)):
#             support_1 = torch.spmm(Omega[i], X_list[i])
#             support_1 = torch.spmm(torch.transpose(A_list[i], 0, 1), support_1.T).T
            support_2 = torch.spmm(V, X_list[i])
            #support_2 = torch.spmm(A_list[i], support_2.T).T
            B_list.append(  support_2)
#             B_list.append(X_list[i])
        
        
#         err = 0
        status = 'max itrs reached'
#         W = W_list[0]
#         import ipdb; ipdb.set_trace()
        for iters in range(num_iter):
            # WXA
            err = 0
            X_old = X
            for j in range(len(A_list)):
                W = W_list[j]
                A = A_list[j]
#                 At = A if trasposed_A else torch.transpose(A, 0, 1)
                B = B_list[j]
                phi = phi_list[j]
                X_ = W @ X
                support = torch.spmm(A, X_.T).T

                X = phi(support + B)
            err = torch.norm(X - X_old, np.inf)
            if err < tol:
                status = 'converged'
                break

        if status == 'max itrs reached':
                
                print('Forward Not Converge! Error: %3.5f, tol: %3.5f' % (err, tol))
        
        X_list = []
        dphi = []
        for j in range(len(A_list)):
            W = W_list[j]
            A = A_list[j]
#             At = A if trasposed_A else torch.transpose(A, 0, 1)
            B = B_list[j]
            phi = phi_list[j]
            X_ = W @ X
            support = torch.spmm(A, X_.T).T
            X = phi(support + B)
            X_list.append(X)
            
            if compute_dphi:
                with torch.enable_grad():
                    Z = support + B
                    Z.requires_grad_(True)
                    tmp = phi(Z)
                    dphi.append(torch.autograd.grad(torch.sum(tmp), Z, only_inputs=True)[0])
                    
    

        
#         if compute_dphi:
#             with torch.enable_grad():
#                 for j in range(len(A_list)):
# #                     W = W_list[j]
#                     A = A_list[j]
#                     At = A if trasposed_A else torch.transpose(A, 0, 1)
#                     B = B_list[j]
#                     X = X_list[j]
#                     phi = phi_list[j] 
                           
#                     X_ = W @ X
#                     support = torch.spmm(A, X_.T).T
#                     Z = support + B
#                     Z.requires_grad_(True)
#                     X_new = phi(Z)
#                     dphi.append(torch.autograd.grad(torch.sum(X_new), Z, only_inputs=True)[0])

        return X_list, err, iters, status, dphi
    
def FixedPointBackward(X, idx, M_list, H_list, phi_list, num_iter=1000, tol=1e-6):
    
#         err = 0
        status = 'max itrs reached'
        for iters in range(num_iter):
            # WXA
            X_old = X
            for j in range(len(M_list)):
                M = M_list[j]
                phi = phi_list[j]
#                 import pdb; pdb.set_trace()
                if idx==-1:
                    tmp = torch.spmm(M, X) + H_list[j]
                elif j==idx:
                    tmp = torch.spmm(M, X) + H_list[j]
                else:
                    tmp = torch.spmm(M, X)
                X = phi * tmp
                
            err = torch.norm(X - X_old, np.inf)
            X_old = X
            if err < tol:
                status = 'converged'
                break
        
        if status == 'max itrs reached':
                print('Backward Not Converge! Error: %3.5f, tol: %3.5f' % (err, tol))

        X_list = []
        for j in range(len(M_list)):
            M = M_list[j]
            phi = phi_list[j]
            if idx==-1:
                tmp = torch.spmm(M, X) + H_list[j]
            elif j==idx:
                tmp = torch.spmm(M, X) + H_list[j]
            else:
                tmp = torch.spmm(M, X)
            X = phi * tmp
            X_list.append(X)
            
        return X_list, err, iters, status            

class ImpDynGNN(nn.Module):
    def __init__(self, num_in, num_hid, num_out, num_node, time_steps, kappa=0.99, phi=F.relu, b_direct=False):
        super(ImpDynGNN, self).__init__()
        self.i = num_in
        self.h = num_hid
        self.o = num_out
        self.n = num_node
        self.t = time_steps
        self.k = kappa
        self.direct = b_direct

        self.phi = [F.relu]*self.t
        self.X_0 = Parameter(torch.zeros(self.h, num_node), requires_grad=False)
        self.W = nn.ParameterList([Parameter(torch.FloatTensor(self.h, self.h)) for i in range(self.t)])
#         self.W = nn.ParameterList([Parameter(torch.FloatTensor(self.h, self.h))])
#         self.Omega = nn.ParameterList([Parameter(torch.FloatTensor(self.h, self.i)) for i in range(self.t)])
        self.V = Parameter(torch.FloatTensor(self.h, self.i))
#         self.linear = nn.Linear(self.h, self.o)
        self.classifier = nn.Sequential(
                                    nn.Softplus(),
                                    nn.Linear(self.h, self.h),
                                    nn.Softplus(),
                                    nn.Linear(self.h, self.o),
                                    # nn.LogSoftmax(dim=1)
                                    )
        self.init()

    def init(self):
#         stdv = 0.01
        for i in range(len(self.W)):
            stdv = 1. / (math.sqrt(self.W[i].size(1)))
            self.W[i].data.uniform_(-stdv, stdv)
#             self.Omega[i].data.uniform_(-stdv, stdv)
#         stdv = 1. / self.W[0].size(1)
#         self.W[0].data.uniform_(-stdv, stdv)
        stdv = 1. /self.V.size(1)
        self.V.data.uniform_(-stdv, stdv)

    def forward(self, X_list, A_list, A_rho, fd_mitr=300):

        for i in range(len(self.W)):
            self.W[i] = projection_norm_inf(self.W[i], kappa=self.k / A_rho[i])

#         self.W[0] = projection_norm_inf(self.W[0], kappa=self.k / min(A_rho))

        with torch.no_grad():
            Z_list, err, iters, status, dphi = FixedPointForward(self.W, self.X_0, 
                                                              A_list, self.V,
                                                              X_list, self.phi,
                                                              compute_dphi=True, tol=1e-6)
            self.dphi = dphi

        
#         Z_list = Z_list[-1:] + Z_list[:-1]
        Z_list = [Z.requires_grad_(True) for Z in Z_list]
#         out_list = [torch.max(Z.T, 1, keepdim=True)[0] for Z in Z_list]
        out_list = [self.classifier(Z.T) for Z in Z_list]
        return Z_list, out_list
    
    def backProp(self, Z, X, A, num_iter=600, tol=1e-6):
        W = self.W
#         Omega = self.Omega
        dphi = self.dphi
        V = self.V
        num_z = Z[0].shape[0]*Z[0].shape[1]
        num_w = W[0].shape[0]*W[0].shape[1]
        num_v = V.shape[0]*V.shape[1]
#         num_omega = Omega[0].shape[0]*Omega[0].shape[1]
    
        device = Z[0].device
        w_grad = torch.zeros(num_z, num_w).to(device)
#         o_grad = torch.zeros(num_z, num_omega).to(device)
        v_grad = torch.zeros(num_z, num_v).to(device)
        
        dphi = [torch.flatten(i.T).reshape(-1,1) for i in dphi]
    
    
        M = []
        H_hat_v = []
        H_hat_w = []
#         H_hat_o = []
        # E_w = scipy.sparse.eye(W[0].shape[0])
        # E_v = scipy.sparse.eye(V.shape[0])
        
        E_w = torch.eye(W[0].shape[0]).to(device)
        E_v = torch.eye(V.shape[0]).to(device)
#         E_o = torch.eye(Omega[0].shape[0]).to(device)

#         w = W[0]
#         wT = torch.transpose(w, 0, 1).contiguous()
        for i in range(len(A)):
            aa = A[i]
            x = X[i]
            w = W[i]
            if i==0:
                z = Z[-1]
            else:
                z = Z[i-1]
                
            
            
            aT = torch.transpose(aa, 0, 1).contiguous()
            if x.shape[0] == 1:
                xT = x.reshape(-1,1)
            else:
                xT = torch.transpose(x, 0, 1).contiguous()
            
            zT = torch.transpose(z, 0, 1).contiguous()
            wT = torch.transpose(w, 0, 1).contiguous()
            
            # H = ToScipy(torch.spmm(aT,zT).to_sparse())
            # N = ToScipy(torch.spmm(aT,xT).to_sparse())
            H = torch.spmm(aT,zT).to(device)
            # N = torch.spmm(aT,xT).to(device)
            # import ipdb; ipdb.set_trace()
            
            # M.append(ScipyKron(wT.cpu().numpy(), aa , convert_X=False, convert_Y=True).to(device))
            # H_hat_v.append(ScipyKron(xT.cpu().numpy(), E_v).to(device))
            # H_hat_w.append(ScipyKron(H, E_w).to(device))
            
            M.append(torch.kron(wT, aa.to_dense()).to(device))
            H_hat_v.append(torch.kron(xT, E_v).to(device))
            H_hat_w.append(torch.kron(H, E_w).to(device))
#             H_hat_o.append(torch.kron(N, E_o).to(device))
    
        for i in range(len(W)):
            W_grad_list, _, _, status = FixedPointBackward(w_grad, i, M, H_hat_w, dphi, num_iter=num_iter, tol=tol)
            
#             O_grad_list, _, _, _ = FixedPointBackward(o_grad, i, M, H_hat_o, dphi, num_iter=num_iter, tol=tol)
            W_grad, O_grad = 0, 0
            for j in range(len(Z)):
                if Z[j].grad is None:
                    z = torch.zeros_like(Z[j]).reshape(1,-1).to(device)
                else:
                    z = torch.flatten(Z[j].grad.T).reshape(1,-1)
                W_grad += z @ W_grad_list[j]
#                 O_grad += z @ O_grad_list[j]
                
            W[i].grad = W_grad.reshape(W[i].shape[1], W[i].shape[0]).T
#             Omega[i].grad = O_grad.reshape(Omega[i].shape[1], Omega[i].shape[0]).T
        
        V_grad_list, _, _, _ = FixedPointBackward(v_grad, -1, M, H_hat_v, dphi, num_iter=num_iter, tol=tol)
        V_grad = 0
        for j in range(len(Z)):
                if Z[j].grad is None:
                    z = torch.zeros_like(Z[j]).reshape(1,-1).to(device)
                else:
                    z = torch.flatten(Z[j].grad.T).reshape(1,-1)
                V_grad += z @ V_grad_list[j]
        V.grad = V_grad.reshape(V.shape[1], V.shape[0]).T
        
        
R"""
"""
#




class PeMS(object):
    R"""
    PeMS dataset.
    """
    #
    DISTRICT: str

    def __init__(
        self,
        dirname: str,
        /,
        *,
        aug_minutes: bool, aug_weekdays: bool,
    ) -> None:
        R"""
        Initialize the class.
        """
        #
        self.from_raw(dirname)
        self.sanitize_edge()

        #
        self.raw_nodes: npt.NDArray[np.generic]

        # Augment global features by exact timestamps.
        # Gap between different steps are 5 minutes, and we use hour as
        # timestamp unit.
        (num_timestamps, num_nodes, _) = self.raw_nodes.shape
        self.timestamps = np.arange(num_timestamps) * 5.0 / 24.0

        # Augment node features by minutes.
        # Gap between different steps are 5 minutes.
        if aug_minutes:
            #
            (num_timestamps, num_nodes, _) = self.raw_nodes.shape
            num_day_minutes = 60 // 5 * 24
            num_days = (
                int(np.ceil(float(num_timestamps) / float(num_day_minutes)))
            )
            day_minutes = np.arange(num_day_minutes) * 5
            minutes = np.tile(day_minutes, num_days)[:num_timestamps]
            minutes = minutes.astype(self.raw_nodes.dtype)
            minutes = np.reshape(minutes, (num_timestamps, 1, 1))
            minutes = np.tile(minutes, (1, num_nodes, 1))
            self.raw_nodes = np.concatenate([self.raw_nodes, minutes], 2)

        # Augment node features by weekdays.
        # Gap between different steps are 5 minutes.
        if aug_weekdays:
            #
            (num_timestamps, num_nodes, _) = self.raw_nodes.shape
            num_day_minutes = 60 // 5 * 24
            num_week_minutes = num_day_minutes * 7
            num_weeks = (
                int(np.ceil(float(num_timestamps) / float(num_week_minutes)))
            )
            weekdays = np.repeat(np.arange(7), num_day_minutes)
            weekdays = np.tile(weekdays, num_weeks)[:num_timestamps]
            weekdays = weekdays.astype(self.raw_nodes.dtype)
            weekdays = np.reshape(weekdays, (num_timestamps, 1, 1))
            weekdays = np.tile(weekdays, (1, num_nodes, 1))
            self.raw_nodes = np.concatenate([self.raw_nodes, weekdays], 2)

    def from_raw(self, dirname: str, /) -> None:
        R"""
        Load from raw data.
        """
        #
        file_edges = "distance.csv"
        file_nodes = "pems{:s}.npz".format(self.DISTRICT)
        raw_edges = pd.read_csv(os.path.join(dirname, file_edges))
        self.raw_edge_srcs = raw_edges["from"].to_numpy()
        self.raw_edge_dsts = raw_edges["to"].to_numpy()
        self.raw_edge_feats = raw_edges["cost"].to_numpy()
        self.raw_nodes = np.load(os.path.join(dirname, file_nodes))["data"]

    def sanitize_edge(self, /) -> None:
        R"""
        Santiize edge data.
        """
        #
        collects: Dict[Tuple[int, int], List[float]]

        # Remove dirty duplications.
        # Duplications are same undirected connections regardless of edge
        # weights (edge weights should be the same for those duplications).
        collects = {}
        for (src, dst, feat) in (
            zip(self.raw_edge_srcs, self.raw_edge_dsts, self.raw_edge_feats)
        ):
            key = (src.item(), dst.item())
            key = (min(key), max(key))
            val = feat.item()
            if key in collects:
                collects[key].append(val)
            else:
                collects[key] = [val]
        edge_srcs_buf = []
        edge_dsts_buf = []
        edge_feats_buf = []
        for ((src, dst), feats) in collects.items():
            #
            edge_srcs_buf.append(src)
            edge_dsts_buf.append(dst)
            edge_feats_buf.append(sum(feats) / len(feats))
            if min(feats) != max(feats):
                # UNEXPECT:
                # Duplicate edges have different edge features.
                raise NotImplementedError(
                    "PeMS duplicate edges have different edge features.",
                )
        self.edge_srcs = np.array(edge_srcs_buf)
        self.edge_dsts = np.array(edge_dsts_buf)
        self.edge_feats = np.array(edge_feats_buf)
        self.edge_hetero = False
        self.edge_symmetric = True

        #
        if not np.all(self.edge_feats > 0):
            # UNEXPECT:
            # Edge features as weights must be positive.
            raise NotImplementedError(
                "PeMS{:s} edge weights is not all-positive."
                .format(self.DISTRICT),
            )



class PeMS04(PeMS):
    R"""
    PeMS (district 4) dataset.
    """
    #
    DISTRICT = "04"


class PeMS08(PeMS):
    R"""
    PeMS (district 8) dataset.
    """
    #
    DISTRICT = "08"
    
def evaluate(output, target):
    mse = torch.mean((output - target) ** 2, dim=1)
    rmse = torch.sqrt(mse)
    mape = (
        torch.mean(torch.abs(output - target) / (torch.abs(target) + 1), dim=1)
    )
    return [
            (len(mse), torch.sum(mse).item()),
            (len(rmse), torch.sum(rmse).item()),
            (len(mape), torch.sum(mape).item()),
        ]

from scipy.sparse import coo_array

device = 'cuda:3' if torch.cuda.is_available() else 'cpu'
data = PeMS04('ICML2022Code/src/PeMS04/', aug_minutes=True, aug_weekdays=True)
A_list = []
N = data.raw_nodes.shape[1]
X = data.raw_nodes
# X_list = [torch.tensor(x).view(-1,1).to(device) for x in X]
# for src, dst, weight in zip(data.edge_srcs, data.edge_dsts, data.edge_feats):
srcs = data.edge_srcs
dsts = data.edge_dsts
src = np.concatenate([srcs, dsts])
dst = np.concatenate([dsts, srcs])
weight = np.concatenate([data.edge_feats, data.edge_feats])
adj = aug_normalized_adjacency(coo_array((weight, (src, dst)), shape=(N,N)))
A = sparse_mx_to_torch_sparse_tensor(adj).to(device)
A_rho = get_spectral_rad(A)


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
class bimodel(nn.Module):
    def __init__(self, num_in, num_hid, num_out, num_node, time_steps, kappa=0.99, phi=F.relu, b_direct=False):
        super(bimodel, self).__init__()
        self.i = num_in
        self.h = num_hid
        self.o = num_out
        self.n = num_node
        self.t = time_steps
        self.k = kappa
        self.direct = b_direct

        self.phi = [F.relu]*self.t
        self.X_0 = Parameter(torch.zeros(self.h, num_node), requires_grad=False)
        self.W = nn.ParameterList([Parameter(torch.FloatTensor(self.h, self.h)) for i in range(self.t)])
#         self.W = nn.ParameterList([Parameter(torch.FloatTensor(self.h, self.h))])
#         self.Omega = nn.ParameterList([Parameter(torch.FloatTensor(self.h, self.i)) for i in range(self.t)])
        self.V = Parameter(torch.FloatTensor(self.h, self.i))
#         self.linear = nn.Linear(self.h, self.o)
        self.classifier = nn.Sequential(
#                                     nn.Dropout(p=0.3)
                                    nn.Linear(self.h, self.h),
#                                     nn.BatchNorm1d(self.h),
                                    nn.Softplus(),
                                    nn.Linear(self.h, self.o),
                                    # nn.LogSoftmax(dim=1)
                                    )
        self.init()

    def init(self):
#         stdv = 0.01
        for i in range(len(self.W)):
            stdv = 1. / (math.sqrt(self.W[i].size(1)))
            self.W[i].data.uniform_(-stdv, stdv)
#             self.Omega[i].data.uniform_(-stdv, stdv)
#         stdv = 1. / self.W[0].size(1)
#         self.W[0].data.uniform_(-stdv, stdv)
        stdv = 1. /self.V.size(1)
        self.V.data.uniform_(-stdv, stdv)
        
    def project(self,X_list, A_list, A_rho):
        
        self.X_list = X_list
        self.A_list = A_list
        for i in range(len(self.W)):
            self.W[i] = projection_norm_inf(self.W[i], kappa=self.k / A_rho[i])
            
        
    def forward(self,Z):

        X_list = self.X_list
        A_list = self.A_list
        V = self.V
        W_list = self.W
        phi = F.relu

        
        for j in range(len(A_list)):
                W = W_list[j]
                A = A_list[j]
                B = torch.spmm(V, X_list[j])
                Z_ = W @ Z
                support = torch.spmm(A, Z_.T).T

                Z = phi(support + B)
        return Z
    
    
    
    
    def predict(self, X_list, A_list, A_rho):
        # for i in range(len(self.W)):
        #     self.W[i] = projection_norm_inf(self.W[i], kappa=self.k / A_rho[i])
        
        V = self.V
        W_list = self.W
        phi = F.relu
        max_iter = 300
        tol = 1e-6
        
        device = W_list[0].device
        Z = torch.zeros(self.X_0.shape).to(device)
        
        status = 'max itrs reached'
        for i in range(max_iter):
            Z_old = Z
            for j in range(len(A_list)):
                    W = W_list[j]
                    A = A_list[j]
                    B = torch.spmm(V, X_list[j])
                    Z_ = W @ Z_old
                    support = torch.spmm(A, Z_.T).T

                    Z = phi(support + B)
            
            err = torch.norm(Z - Z_old, np.inf)
            if err < tol:
                status = 'converged'
                break

        if status == 'max itrs reached':
                
                print('Forward Not Converge! Error: %3.5f, tol: %3.5f' % (err, tol))
            
        Y_hat = self.classifier(Z.T)
        return Y_hat
        
            

In [None]:
# Inductive ##########################################
import random
from tqdm import tqdm
import itertools

idx = list(range(16980))
train_indices = [idx[i*10:(i*10+7)] for i in range(1698)]
# train_indices.append(idx[17840:])
train_indices = list(itertools.chain.from_iterable(train_indices))

val_indices = [idx[i*10+7:(i*10+8)] for i in range(1698)]
val_indices = list(itertools.chain.from_iterable(val_indices))

test_indices = [idx[i*10+8:(i+1)*10] for i in range(1698)]
test_indices = list(itertools.chain.from_iterable(test_indices))

X_train = X[train_indices]
X_min = np.min(X_train)
X_max = np.max(X_train)
X_new = (X-X_min)/(X_max-X_min)
X_list = [torch.tensor(x.T, dtype=torch.float).to(device) for x in X_new]


###################################### Train ############################################
bests = []
for j in range(5):
    model = bimodel(5, 16,3, 307, 12, kappa=0.99).to(device)
    optimizer = torch.optim.AdamW(model.parameters(),
                       lr=0.01, weight_decay=1e-5)
    
    Z_all = [torch.zeros(16*307)]*len(idx)
    V_all = [torch.zeros(16*307)]*len(idx)
    eta_1 = 0.9
    eta_2 = 0.001
    
    for i in range(10):
        losses = []
        random.shuffle(train_indices)
        model.train()
        with tqdm(train_indices) as tq:
            for idx in tq:
                A_list_tmp = [A]*12
                A_rho_tmp = [A_rho]*12
                X_list_tmp = X_list[idx:idx+12]
                y = X_list[idx+12][:3].T

                

                # import ipdb; ipdb.set_trace()

                model.project(X_list_tmp, A_list_tmp, A_rho_tmp)
                
                for k in range(1):
                    
                    Z_0 = Z_all[idx]
                    V_0 = V_all[idx]
                    
                    Z_1 = Z_0.to(device).requires_grad_(True)
                    Z = model(Z_1.view(16, -1))

                    Y_hat = model.classifier(Z.T)
                    loss = torch.square(Y_hat-y).mean()
                    loss2 = torch.norm(Z_1-Z.reshape(-1), 2)

                    z_grad = torch.autograd.grad(loss, Z_1, retain_graph=True)

                    Z_0 = (1 - eta_1)*Z_0 + eta_1*Z.reshape(-1).detach().cpu()

                    g_z_grad = torch.autograd.grad(loss2, Z_1, retain_graph=True, create_graph=True)

                    hv = torch.inner(g_z_grad[0],V_0.to(device))
                    phi_v = torch.autograd.grad(hv, Z_1, retain_graph=True)

                    V_0 = V_0 - eta_2*phi_v[0].cpu() + eta_2*z_grad[0].cpu()
                    # V_0 = V_0 - eta_2*V_0 + eta_2*z_grad[0].cpu()

                    loss3 = -torch.inner(g_z_grad[0],V_0.to(device))
                    loss3.backward(retain_graph=True)
                    loss.backward()

                    optimizer.step()
                    optimizer.zero_grad()

                    Z_all[idx] = Z_0
                    V_all[idx] = V_0
                    # import ipdb; ipdb.set_trace()
                losses.append(loss.item())
            print('epoch',i,f'loss {np.mean(losses) :.6f}')

####################################### Test ##############################################
            y_true = []
            y_pred = []
            model.eval()
            for idx in test_indices:
                A_list_tmp = [A]*12
                A_rho_tmp = [A_rho]*12
                X_list_tmp = X_list[idx:idx+12]
                y = X_list[idx+12][:3].T

                # model.train()
                with torch.no_grad():
                    out_Z = model.predict(X_list_tmp, A_list_tmp, A_rho_tmp)
                # import ipdb; ipdb.set_trace()
                y_pred.append(out_Z.detach())
                y_true.append(y)

            y_true = torch.concat(y_true, 0)
            y_pred = torch.concat(y_pred, 0)
            results = evaluate(y_pred, y_true)
            mse = results[0][1]/results[0][0]
            mape = results[2][1]/results[2][0]
            print(f'mse: {mse:.4f}, mape: {mape:.4f}')
        bests.append(mape)
    
print(np.mean(bests), np.std(bests))

100%|█████████████████████████████████████████████████████████████████████████████████████| 11886/11886 [10:10<00:00, 19.46it/s]


epoch 0 loss 0.000291
mse: 0.0010, mape: 0.0204


100%|█████████████████████████████████████████████████████████████████████████████████████| 11886/11886 [09:49<00:00, 20.17it/s]


epoch 1 loss 0.000201
mse: 0.0004, mape: 0.0083


100%|█████████████████████████████████████████████████████████████████████████████████████| 11886/11886 [09:46<00:00, 20.28it/s]


epoch 2 loss 0.000193
mse: 0.0002, mape: 0.0056


100%|█████████████████████████████████████████████████████████████████████████████████████| 11886/11886 [10:23<00:00, 19.06it/s]


epoch 3 loss 0.000192
mse: 0.0003, mape: 0.0055


100%|█████████████████████████████████████████████████████████████████████████████████████| 11886/11886 [09:39<00:00, 20.51it/s]


epoch 4 loss 0.000191
mse: 0.0002, mape: 0.0060


100%|█████████████████████████████████████████████████████████████████████████████████████| 11886/11886 [09:44<00:00, 20.33it/s]


epoch 5 loss 0.000190
mse: 0.0002, mape: 0.0056


100%|█████████████████████████████████████████████████████████████████████████████████████| 11886/11886 [09:42<00:00, 20.39it/s]


epoch 6 loss 0.000190
mse: 0.0002, mape: 0.0056


100%|█████████████████████████████████████████████████████████████████████████████████████| 11886/11886 [09:43<00:00, 20.38it/s]


epoch 7 loss 0.000190
mse: 0.0003, mape: 0.0068


100%|█████████████████████████████████████████████████████████████████████████████████████| 11886/11886 [09:44<00:00, 20.35it/s]


epoch 8 loss 0.000189
mse: 0.0003, mape: 0.0073


  8%|███████▏                                                                               | 988/11886 [00:48<08:54, 20.38it/s]