In [1]:
import os
import argparse
import json
import numpy as np
import torch
import torch.nn as nn

from sklearn.manifold import TSNE
from sklearn.decomposition import PCA
import matplotlib.pyplot as plt
import umap                        

from utils.util import find_max_epoch, print_size, training_loss, calc_diffusion_hyperparams
from utils.util import get_mask_mnr, get_mask_bm, get_mask_rm

from imputers.DiffWaveImputer import DiffWaveImputer
from imputers.SSSDSAImputer import SSSDSAImputer
from imputers.SSSDS4Imputer import SSSDS4Imputer


CUDA extension for cauchy multiplication not found. Install by going to extensions/cauchy/ and running `python setup.py install`. This should speed up end-to-end training by 10-50%
Falling back on slow Cauchy kernel. Install at least one of pykeops or the CUDA extension for efficiency.


In [2]:
parser = argparse.ArgumentParser()
parser.add_argument('-c', '--config', type=str, default='./config/config_SSSDS4-sp500.json')  
# parser.add_argument('-c', '--config', type=str, default='./config/config_DiffWave-sp500.json') 
                                                                                                                                                                                                                               
args = parser.parse_args(args=[])

with open(args.config) as f:
    data = f.read()

config = json.loads(data)

In [3]:
train_config = config["train_config"]  # training parameters

global trainset_config
trainset_config = config["trainset_config"]  # to load trainset

global diffusion_config
diffusion_config = config["diffusion_config"]  # basic hyperparameters

global diffusion_hyperparams
diffusion_hyperparams = calc_diffusion_hyperparams(
                      **diffusion_config)  # dictionary of all diffusion hyperparameters

global model_config

if train_config['use_model'] == 0:
    model_config = config['wavenet_config']
elif train_config['use_model'] == 1:
    model_config = config['sashimi_config']
elif train_config['use_model'] == 2:
    model_config = config['wavenet_config']
       
model_config['num_res_layers']=18           

In [4]:
output_directory = './results/mujoco/90'
ckpt_iter = 'max'
n_iters = 10000
iters_per_ckpt = 100
iters_per_logging = 100
# batch_size_per_gpu              
learning_rate = 0.0002
use_model = 2
only_generate_missing = 1       
masking = 'rm'               
missing_k = 20

In [5]:
local_path = "T{}_beta0{}_betaT{}".format(diffusion_config["T"],
                                              diffusion_config["beta_0"],
                                              diffusion_config["beta_T"])

output_directory = os.path.join(output_directory, local_path)
if not os.path.isdir(output_directory):
    os.makedirs(output_directory)
    os.chmod(output_directory, 0o775)               
print("output directory", output_directory, flush=True)

# map diffusion hyperparameters to gpu
for key in diffusion_hyperparams:
    if key != "T":
        diffusion_hyperparams[key] = diffusion_hyperparams[key].cuda()
        
        
# predefine model
if use_model == 0:
    net = DiffWaveImputer(**model_config).cuda()
elif use_model == 1:
    net = SSSDSAImputer(**model_config).cuda()
elif use_model == 2:
    net = SSSDS4Imputer(**model_config).cuda()    
else:
    print('Model chosen not available.')
    
print_size(net)
optimizer = torch.optim.Adam(net.parameters(), lr=learning_rate)

output directory ./results/mujoco/90\T200_beta00.0001_betaT0.02
SSSDS4Imputer Parameters: 7.547142M


In [6]:
training_data = np.load(trainset_config['train_data_path'])
np.random.shuffle(training_data)
print(training_data.shape)

(5775, 30, 6)


In [7]:
training_data = np.split(training_data, 55, 0)      
print(len(training_data), training_data[0].shape)

training_data = np.array(training_data)
training_data = torch.from_numpy(training_data).float().cuda()
print(training_data.shape)


55 (105, 30, 6)
torch.Size([55, 105, 30, 6])


In [8]:
missing_k = 20                    
iters = 0
loss_list=[]

# while iters < n_iters + 1:

#     for batch in training_data:             
#         transposed_mask = get_mask_rm(batch[0], missing_k)     
#         mask = transposed_mask.permute(1, 0)    
#         mask = mask.repeat(batch.size()[0], 1, 1).float().cuda()  
#         loss_mask = ~mask.bool()        
#         batch = batch.permute(0, 2, 1)   
#         optimizer.zero_grad()
#         X = batch, batch, mask, loss_mask    
#         loss = training_loss(net, nn.MSELoss(), X, diffusion_hyperparams,
#                              only_generate_missing=only_generate_missing)

#         loss.backward()
#         optimizer.step()
           
#         if iters % iters_per_logging == 0:
#                 print("iteration: {} \tloss: {}".format(iters, loss.item()))
                
#         loss_list.append(loss.item())
        
#         iters += 1
  

In [9]:
# torch.save( net.state_dict(),"./sp500_S4_iter_10000.pth" )

In [10]:
net = SSSDS4Imputer(**model_config).cuda()
net.load_state_dict(torch.load( "./sp500_S4_iter_10000.pth"))

# net.eval()

<All keys matched successfully>

In [11]:
def std_normal(size):
    """
    Generate the standard Gaussian variable of a certain size
    """
    return torch.normal(0, 1, size=size).cuda()


size=(5775,6,30)

_dh = diffusion_hyperparams
T, Alpha, Alpha_bar, Sigma = _dh["T"], _dh["Alpha"], _dh["Alpha_bar"], _dh["Sigma"]

assert len(Alpha) == T
assert len(Alpha_bar) == T
assert len(Sigma) == T
assert len(size) == 3

# print(T, Alpha)

x = std_normal(size)      # xT  
cond = 1
mask = 1


In [12]:
import matplotlib.pyplot as plt

# training_data = np.load(trainset_config['train_data_path'])
# training_data = np.array(training_data)
# x = torch.from_numpy(training_data).float().cuda()
# x = x[30:31].permute(0,2,1)
# print(x.shape)


In [13]:
# inverted_x = x

# with torch.no_grad():        
#     for t in range(0, T-1, 1):                      
#         diffusion_steps = (t * torch.ones((1, 1))).cuda()          
#         epsilon_theta = net((inverted_x, cond, mask, diffusion_steps)) 
#         inverted_x = ( torch.sqrt(Alpha_bar[t+1]) * (inverted_x - torch.sqrt(1-Alpha_bar[t]) * epsilon_theta) ) / torch.sqrt(Alpha_bar[t]) + \
#                                                                                                      torch.sqrt(1-Alpha_bar[t+1])* epsilon_theta

In [14]:
# criterion = nn.MSELoss(reduction="mean")
# RMSE = torch.sqrt(criterion(x, inverted_x))

In [15]:
# recove = inverted_x
# with torch.no_grad():        
#     for t in range(T-1, 0, -1):                       
#         diffusion_steps = (t * torch.ones((1, 1))).cuda()          
#         epsilon_theta = net((recove, cond, mask, diffusion_steps))  
#         recove = ( torch.sqrt(Alpha_bar[t-1]) * (recove - torch.sqrt(1-Alpha_bar[t]) * epsilon_theta) ) / torch.sqrt(Alpha_bar[t]) + \
#                                                                                               torch.sqrt(1-Alpha_bar[t-1])* epsilon_theta

In [16]:
# criterion = nn.MSELoss(reduction="mean")
# RMSE = torch.sqrt(criterion(x, recove))
# print(RMSE)

In [17]:
# x_plot = x.permute(0, 2, 1).cpu().numpy()
# recove_plot = recove.permute(0, 2, 1).cpu().numpy()
# inverted_plot = inverted_x.permute(0, 2, 1).cpu().numpy()


# fig = plt.figure(figsize=(7,7))

# plt.subplot(3,2,1)
# plt.plot(x_plot[0][:,3])       
# plt.subplot(3,2,2)
# plt.plot(recove_plot[0][:,3])      
# plt.subplot(3,2,3)
# plt.plot(recove_plot[0][:,:])  
# plt.subplot(3,2,4)
# plt.plot(recove_plot[0][:,:]) 
# plt.subplot(3,2,5)
# plt.plot(inverted_plot[0][:,:]) 

In [18]:
from torch.autograd import Variable
from pypots.data import load_specific_dataset, mcar, masked_fill
from pypots.imputation import MRNN,BRITS,SAITS
from pypots.utils.metrics import cal_mse,cal_rmse,cal_mae
from sklearn.preprocessing import MinMaxScaler,StandardScaler

In [19]:
training_data = np.load(trainset_config['train_data_path'])
training_data = np.array(training_data)
x = torch.from_numpy(training_data).float().cuda()
x = x[30:31]

np.random.seed(1)
torch.manual_seed(1)
torch.cuda.manual_seed_all(1)

X_ori, X_missed, missing_mask, indicating_mask = mcar(x, 0.3, np.nan)   

In [20]:
oneshot_data = np.loadtxt('AAL.csv',delimiter = ",")
norm_data = MinMaxScaler().fit_transform(oneshot_data)
x_oneshot = np.array(norm_data[10:40])
x_oneshot = torch.from_numpy(x_oneshot).float().cuda()
x_oneshot = x_oneshot.unsqueeze(dim=0)
print(x_oneshot.shape)

X_ori, X_missed, missing_mask, indicating_mask = mcar(x_oneshot, 0.3, np.nan)
print(X_ori[0][0])

torch.Size([1, 30, 6])
tensor([0.3361, 0.3359, 0.0015, 0.3310, 0.3304, 0.3304], device='cuda:0')


In [21]:
import gc
gc.collect()
torch.cuda.empty_cache()

rmse_list =[]
imputation_list = []

for i in range(1):            
    
    saits = SAITS(n_steps=30, n_features=6, n_layers=2, d_model=128, d_inner=64, n_heads=4, d_k=64, d_v=64, dropout=0.1, batch_size=10, epochs=1000)
    
#     saits = BRITS(n_steps=30, n_features=6, rnn_hidden_size=64, batch_size=1, epochs=50)
#     saits = MRNN(n_steps=30, n_features=6, rnn_hidden_size=64, batch_size=1, epochs=100)
    
    dataset = {"X": X_missed}
    
    saits.fit(dataset)               # train the model. Here I use the whole dataset as the training set, because ground truth is not visible to the model.
    imputation = saits.impute(dataset)  # impute the originally-missing values and artificially-missing values
    imputation_list.append(imputation)
    
    imputation = torch.from_numpy(imputation).cuda()

    mae = cal_mae(imputation, X_ori, indicating_mask)    
    mse = cal_mse(imputation, X_ori, indicating_mask)
    rmse = cal_rmse(imputation, X_ori, indicating_mask)
    print(mae,mse, rmse)
    
    rmse_list.append(rmse)
    

2024-11-26 22:04:07 [INFO]: No given device, using default device: cuda
2024-11-26 22:04:07 [INFO]: saving_path not given. Model files and tensorboard file will not be saved.
2024-11-26 22:04:07 [INFO]: Model initialized successfully with the number of trainable parameters: 597,780
2024-11-26 22:04:08 [INFO]: epoch 0: training loss 0.8778
2024-11-26 22:04:08 [INFO]: epoch 1: training loss 0.7553
2024-11-26 22:04:08 [INFO]: epoch 2: training loss 0.6656
2024-11-26 22:04:08 [INFO]: epoch 3: training loss 0.5323
2024-11-26 22:04:08 [INFO]: epoch 4: training loss 0.5394
2024-11-26 22:04:08 [INFO]: epoch 5: training loss 0.4927
2024-11-26 22:04:08 [INFO]: epoch 6: training loss 0.4734
2024-11-26 22:04:08 [INFO]: epoch 7: training loss 0.4020
2024-11-26 22:04:08 [INFO]: epoch 8: training loss 0.3823
2024-11-26 22:04:08 [INFO]: epoch 9: training loss 0.3892
2024-11-26 22:04:08 [INFO]: epoch 10: training loss 0.4543
2024-11-26 22:04:08 [INFO]: epoch 11: training loss 0.4388
2024-11-26 22:04:08

2024-11-26 22:04:10 [INFO]: epoch 134: training loss 0.0958
2024-11-26 22:04:10 [INFO]: epoch 135: training loss 0.0961
2024-11-26 22:04:10 [INFO]: epoch 136: training loss 0.1025
2024-11-26 22:04:10 [INFO]: epoch 137: training loss 0.0993
2024-11-26 22:04:10 [INFO]: epoch 138: training loss 0.1085
2024-11-26 22:04:10 [INFO]: epoch 139: training loss 0.1242
2024-11-26 22:04:10 [INFO]: epoch 140: training loss 0.1116
2024-11-26 22:04:10 [INFO]: epoch 141: training loss 0.0960
2024-11-26 22:04:10 [INFO]: epoch 142: training loss 0.1132
2024-11-26 22:04:10 [INFO]: epoch 143: training loss 0.1046
2024-11-26 22:04:10 [INFO]: epoch 144: training loss 0.0858
2024-11-26 22:04:10 [INFO]: epoch 145: training loss 0.1287
2024-11-26 22:04:10 [INFO]: epoch 146: training loss 0.1189
2024-11-26 22:04:10 [INFO]: epoch 147: training loss 0.0988
2024-11-26 22:04:10 [INFO]: epoch 148: training loss 0.0925
2024-11-26 22:04:10 [INFO]: epoch 149: training loss 0.1114
2024-11-26 22:04:10 [INFO]: epoch 150: t

2024-11-26 22:04:12 [INFO]: epoch 271: training loss 0.0539
2024-11-26 22:04:12 [INFO]: epoch 272: training loss 0.0560
2024-11-26 22:04:12 [INFO]: epoch 273: training loss 0.0599
2024-11-26 22:04:12 [INFO]: epoch 274: training loss 0.0728
2024-11-26 22:04:12 [INFO]: epoch 275: training loss 0.0589
2024-11-26 22:04:12 [INFO]: epoch 276: training loss 0.0537
2024-11-26 22:04:12 [INFO]: epoch 277: training loss 0.0565
2024-11-26 22:04:12 [INFO]: epoch 278: training loss 0.0598
2024-11-26 22:04:12 [INFO]: epoch 279: training loss 0.0564
2024-11-26 22:04:12 [INFO]: epoch 280: training loss 0.0580
2024-11-26 22:04:12 [INFO]: epoch 281: training loss 0.0583
2024-11-26 22:04:12 [INFO]: epoch 282: training loss 0.0568
2024-11-26 22:04:12 [INFO]: epoch 283: training loss 0.0576
2024-11-26 22:04:12 [INFO]: epoch 284: training loss 0.0496
2024-11-26 22:04:12 [INFO]: epoch 285: training loss 0.0648
2024-11-26 22:04:12 [INFO]: epoch 286: training loss 0.0583
2024-11-26 22:04:12 [INFO]: epoch 287: t

2024-11-26 22:04:14 [INFO]: epoch 408: training loss 0.0388
2024-11-26 22:04:14 [INFO]: epoch 409: training loss 0.0394
2024-11-26 22:04:14 [INFO]: epoch 410: training loss 0.0367
2024-11-26 22:04:14 [INFO]: epoch 411: training loss 0.0396
2024-11-26 22:04:14 [INFO]: epoch 412: training loss 0.0381
2024-11-26 22:04:14 [INFO]: epoch 413: training loss 0.0394
2024-11-26 22:04:14 [INFO]: epoch 414: training loss 0.0399
2024-11-26 22:04:14 [INFO]: epoch 415: training loss 0.0475
2024-11-26 22:04:14 [INFO]: epoch 416: training loss 0.0406
2024-11-26 22:04:14 [INFO]: epoch 417: training loss 0.0380
2024-11-26 22:04:14 [INFO]: epoch 418: training loss 0.0414
2024-11-26 22:04:14 [INFO]: epoch 419: training loss 0.0455
2024-11-26 22:04:14 [INFO]: epoch 420: training loss 0.0468
2024-11-26 22:04:14 [INFO]: epoch 421: training loss 0.0395
2024-11-26 22:04:14 [INFO]: epoch 422: training loss 0.0404
2024-11-26 22:04:14 [INFO]: epoch 423: training loss 0.0437
2024-11-26 22:04:14 [INFO]: epoch 424: t

2024-11-26 22:04:16 [INFO]: epoch 545: training loss 0.0414
2024-11-26 22:04:16 [INFO]: epoch 546: training loss 0.0475
2024-11-26 22:04:16 [INFO]: epoch 547: training loss 0.0357
2024-11-26 22:04:16 [INFO]: epoch 548: training loss 0.0394
2024-11-26 22:04:16 [INFO]: epoch 549: training loss 0.0441
2024-11-26 22:04:16 [INFO]: epoch 550: training loss 0.0421
2024-11-26 22:04:16 [INFO]: epoch 551: training loss 0.0403
2024-11-26 22:04:16 [INFO]: epoch 552: training loss 0.0470
2024-11-26 22:04:16 [INFO]: epoch 553: training loss 0.0411
2024-11-26 22:04:16 [INFO]: epoch 554: training loss 0.0382
2024-11-26 22:04:16 [INFO]: epoch 555: training loss 0.0438
2024-11-26 22:04:16 [INFO]: epoch 556: training loss 0.0425
2024-11-26 22:04:16 [INFO]: epoch 557: training loss 0.0332
2024-11-26 22:04:16 [INFO]: epoch 558: training loss 0.0442
2024-11-26 22:04:16 [INFO]: epoch 559: training loss 0.0415
2024-11-26 22:04:16 [INFO]: epoch 560: training loss 0.0345
2024-11-26 22:04:16 [INFO]: epoch 561: t

2024-11-26 22:04:17 [INFO]: epoch 682: training loss 0.0360
2024-11-26 22:04:17 [INFO]: epoch 683: training loss 0.0292
2024-11-26 22:04:17 [INFO]: epoch 684: training loss 0.0385
2024-11-26 22:04:17 [INFO]: epoch 685: training loss 0.0366
2024-11-26 22:04:17 [INFO]: epoch 686: training loss 0.0392
2024-11-26 22:04:17 [INFO]: epoch 687: training loss 0.0312
2024-11-26 22:04:17 [INFO]: epoch 688: training loss 0.0306
2024-11-26 22:04:17 [INFO]: epoch 689: training loss 0.0304
2024-11-26 22:04:18 [INFO]: epoch 690: training loss 0.0320
2024-11-26 22:04:18 [INFO]: epoch 691: training loss 0.0334
2024-11-26 22:04:18 [INFO]: epoch 692: training loss 0.0303
2024-11-26 22:04:18 [INFO]: epoch 693: training loss 0.0313
2024-11-26 22:04:18 [INFO]: epoch 694: training loss 0.0338
2024-11-26 22:04:18 [INFO]: epoch 695: training loss 0.0338
2024-11-26 22:04:18 [INFO]: epoch 696: training loss 0.0381
2024-11-26 22:04:18 [INFO]: epoch 697: training loss 0.0321
2024-11-26 22:04:18 [INFO]: epoch 698: t

2024-11-26 22:04:19 [INFO]: epoch 819: training loss 0.0256
2024-11-26 22:04:19 [INFO]: epoch 820: training loss 0.0263
2024-11-26 22:04:19 [INFO]: epoch 821: training loss 0.0290
2024-11-26 22:04:19 [INFO]: epoch 822: training loss 0.0301
2024-11-26 22:04:19 [INFO]: epoch 823: training loss 0.0291
2024-11-26 22:04:19 [INFO]: epoch 824: training loss 0.0261
2024-11-26 22:04:19 [INFO]: epoch 825: training loss 0.0268
2024-11-26 22:04:19 [INFO]: epoch 826: training loss 0.0308
2024-11-26 22:04:19 [INFO]: epoch 827: training loss 0.0339
2024-11-26 22:04:19 [INFO]: epoch 828: training loss 0.0264
2024-11-26 22:04:19 [INFO]: epoch 829: training loss 0.0284
2024-11-26 22:04:19 [INFO]: epoch 830: training loss 0.0320
2024-11-26 22:04:19 [INFO]: epoch 831: training loss 0.0356
2024-11-26 22:04:19 [INFO]: epoch 832: training loss 0.0283
2024-11-26 22:04:19 [INFO]: epoch 833: training loss 0.0326
2024-11-26 22:04:19 [INFO]: epoch 834: training loss 0.0308
2024-11-26 22:04:19 [INFO]: epoch 835: t

2024-11-26 22:04:21 [INFO]: epoch 956: training loss 0.0318
2024-11-26 22:04:21 [INFO]: epoch 957: training loss 0.0249
2024-11-26 22:04:21 [INFO]: epoch 958: training loss 0.0245
2024-11-26 22:04:21 [INFO]: epoch 959: training loss 0.0311
2024-11-26 22:04:21 [INFO]: epoch 960: training loss 0.0228
2024-11-26 22:04:21 [INFO]: epoch 961: training loss 0.0270
2024-11-26 22:04:21 [INFO]: epoch 962: training loss 0.0252
2024-11-26 22:04:21 [INFO]: epoch 963: training loss 0.0261
2024-11-26 22:04:21 [INFO]: epoch 964: training loss 0.0324
2024-11-26 22:04:21 [INFO]: epoch 965: training loss 0.0305
2024-11-26 22:04:21 [INFO]: epoch 966: training loss 0.0267
2024-11-26 22:04:21 [INFO]: epoch 967: training loss 0.0250
2024-11-26 22:04:21 [INFO]: epoch 968: training loss 0.0268
2024-11-26 22:04:21 [INFO]: epoch 969: training loss 0.0271
2024-11-26 22:04:21 [INFO]: epoch 970: training loss 0.0226
2024-11-26 22:04:21 [INFO]: epoch 971: training loss 0.0292
2024-11-26 22:04:21 [INFO]: epoch 972: t

tensor(0.0132, device='cuda:0') tensor(0.0003, device='cuda:0') tensor(0.0166, device='cuda:0')


In [22]:
rmse = cal_rmse(imputation, X_ori, indicating_mask)
print(rmse)

tensor(0.0166, device='cuda:0')


In [23]:
inverted_x = imputation.permute(0,2,1) 
with torch.no_grad():        
    for t in range(0, T-1, 1):                         
        diffusion_steps = (t * torch.ones((1, 1))).cuda()          
        epsilon_theta = net((inverted_x, cond, mask, diffusion_steps))  
        inverted_x = ( torch.sqrt(Alpha_bar[t+1]) * (inverted_x - torch.sqrt(1-Alpha_bar[t]) * epsilon_theta) ) / torch.sqrt(Alpha_bar[t]) + \
                                                                                                     torch.sqrt(1-Alpha_bar[t+1])* epsilon_theta

In [24]:
initial_data = Variable( inverted_x, requires_grad=True)
optimizer_o = torch.optim.Adam(  [initial_data],lr=0.001)  
mse_criterion = nn.MSELoss().cuda()
lista = []

import gc
gc.collect()
torch.cuda.empty_cache()  

torch.manual_seed(5)
torch.cuda.manual_seed_all(5)

for i in range(0,100):

    xT = initial_data
    
    for t in range(T-1, 0, -1):                        
        diffusion_steps = (t * torch.ones((1, 1))).cuda()         
        epsilon_theta = net((xT, cond, mask, diffusion_steps))  
        xT = ( torch.sqrt(Alpha_bar[t-1]) * (xT - torch.sqrt(1-Alpha_bar[t]) * epsilon_theta) ) / torch.sqrt(Alpha_bar[t]) + \
                                                                                       torch.sqrt(1-Alpha_bar[t-1])* epsilon_theta  
    mse = mse_criterion( xT.permute(0,2,1) * missing_mask, X_ori * missing_mask )  
    print( torch.sqrt(mse) )
    optimizer_o.zero_grad()
    torch.sqrt(mse).backward() 
    optimizer_o.step()
    

tensor(0.0020, device='cuda:0', grad_fn=<SqrtBackward0>)
tensor(0.0016, device='cuda:0', grad_fn=<SqrtBackward0>)
tensor(0.0012, device='cuda:0', grad_fn=<SqrtBackward0>)
tensor(0.0011, device='cuda:0', grad_fn=<SqrtBackward0>)
tensor(0.0012, device='cuda:0', grad_fn=<SqrtBackward0>)
tensor(0.0013, device='cuda:0', grad_fn=<SqrtBackward0>)
tensor(0.0013, device='cuda:0', grad_fn=<SqrtBackward0>)
tensor(0.0013, device='cuda:0', grad_fn=<SqrtBackward0>)
tensor(0.0012, device='cuda:0', grad_fn=<SqrtBackward0>)
tensor(0.0011, device='cuda:0', grad_fn=<SqrtBackward0>)
tensor(0.0010, device='cuda:0', grad_fn=<SqrtBackward0>)
tensor(0.0010, device='cuda:0', grad_fn=<SqrtBackward0>)
tensor(0.0010, device='cuda:0', grad_fn=<SqrtBackward0>)
tensor(0.0011, device='cuda:0', grad_fn=<SqrtBackward0>)
tensor(0.0011, device='cuda:0', grad_fn=<SqrtBackward0>)
tensor(0.0011, device='cuda:0', grad_fn=<SqrtBackward0>)
tensor(0.0010, device='cuda:0', grad_fn=<SqrtBackward0>)
tensor(0.0010, device='cuda:0',

In [27]:
rmse = cal_rmse( xT.permute(0,2,1), X_ori, indicating_mask)   #  计算缺失位置的损失
print(rmse)

tensor(0.0164, device='cuda:0', grad_fn=<SqrtBackward0>)
