# A Demo of the Kriging Convolutional Network

## 1 Data loading and visualization
In this notebook, we will show how to run KCN on the bird dataset from the KCN paper. This section will load the data and visualize them on a map.

In [1]:
import pickle
import numpy as np
import pandas as pd
from geographiclib.geodesic import Geodesic
import os
import torch

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
import numpy as np
import pandas as pd
from matplotlib import pyplot as plt
import data
import argument
import geopandas as gpd
import geoplot as gplt
import os
import torch
import experiment
%matplotlib inline
%load_ext autoreload
%autoreload 2
%autosave 180

Autosaving every 180 seconds


## get data


In [3]:
X_raw=np.load('all_t.npy',allow_pickle=True)
X_raw.shape

(34141, 7, 24)

In [4]:
X_raw=X_raw[:,2:,:]

In [5]:
stations = pd.read_csv("Tehiku_coordinate.csv")

In [6]:
stations=stations.T
stations=stations.iloc[1:,1:]
stations = stations.rename(columns={1: 'lon', 2: 'lat'})

In [7]:
# X_raw=np.load('all.npy')
# stations = pd.read_csv("location_of_Shenzhen.csv")
# stations = pd.read_csv("/home/dma312/Spatial_interpolation/SSIN/data/Station_info.csv")
lat=stations["lat"]
lon=stations["lon"]
lat=list(lat)
lon=list(lon)
x=X_raw.tolist()
for i in range(len(x)):
    x[i].append(lat)
    x[i].append(lon)
X=np.array(x)
for i in range(len(x)):
    X[i]=X[i].astype(np.float32)

In [8]:
X.shape

(34141, 7, 24)

In [9]:
X.shape

(34141, 7, 24)

In [10]:
xx=X[0]
xx.shape

(7, 24)

0  1  2  3  4  
5 6

In [11]:
xx=X[0]
columns_with_missing_data = np.any(np.isnan(xx), axis=0)
missing_columns = np.where(columns_with_missing_data)[0]
result = np.delete(xx, missing_columns, axis=1)
train_array = [i for i in range(int(result.shape[1]*0.8))]
test_array = [i for i in range(int(result.shape[1]*0.8),result.shape[1])]
x_train=result[[0,1,2,3,5,6]]
# x_train=x_train[:,[0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30]]
x_train=x_train[:,train_array]
x_train=np.transpose(x_train)
y_train=result[4]
# y_train=y_train[[0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30]]
y_train=y_train[train_array]
y_train = y_train[:, None]
x_test=result[[0,1,2,3,5,6]]
# x_test=x_test[:,[31,32,33,34,35,36,37,38,39,40,41,42,43,44,45,46,47,48,49,50,51,52,53,54,55,56]]
x_test=x_test[:,test_array]
x_test=np.transpose(x_test)
y_test=result[4]
# y_test=y_test[[31,32,33,34,35,36,37,38,39,40,41,42,43,44,45,46,47,48,49,50,51,52,53,54,55,56]]
y_test=y_test[test_array]
y_test = y_test[:, None]

In [12]:
args = argument.parse_opt()
args.model = 'kcn_sage'
args.dataset = "t_data"
args.validation_size=4

In [13]:
# args.dataset = "bird_count"
#load_ndbc_data(X_train,Y_train,X_test,Y_test,args)
trainset, testset = data.load_t_data(x_train,y_train,x_test,y_test,args)

Using the default test set from the data


In [14]:
testset.coords

tensor([[173.1239],
        [173.1324],
        [173.1338],
        [173.1223],
        [173.1163]])

### 分界点

In [15]:
err,err1 = experiment.run_kcn(x_train,y_train,x_test,y_test,args)
print('Model: {}, test error: {}\n'.format(args.model, err))

Using the default test set from the data
The t_data dataset has 18 training instances and 5 test instances.
Length scale is set to 0.0046539306640625


100%|██████████| 1/1 [00:01<00:00,  1.41s/it]


Epoch: 0, train error: 0.14376601576805115, validation error: 0.2641015946865082


100%|██████████| 1/1 [00:00<00:00, 111.05it/s]


Epoch: 1, train error: 0.12571638822555542, validation error: 0.24166281521320343


100%|██████████| 1/1 [00:00<00:00, 99.99it/s]


Epoch: 2, train error: 0.11063648760318756, validation error: 0.21897980570793152


100%|██████████| 1/1 [00:00<00:00, 100.02it/s]


Epoch: 3, train error: 0.10059290379285812, validation error: 0.20884394645690918


100%|██████████| 1/1 [00:00<00:00, 90.90it/s]


Epoch: 4, train error: 0.08798076212406158, validation error: 0.2509421706199646


100%|██████████| 1/1 [00:00<00:00, 124.99it/s]


Epoch: 5, train error: 0.09477738291025162, validation error: 0.268224835395813


100%|██████████| 1/1 [00:00<00:00, 142.80it/s]


Epoch: 6, train error: 0.07782776653766632, validation error: 0.1741212159395218


100%|██████████| 1/1 [00:00<00:00, 133.22it/s]


Epoch: 7, train error: 0.08007638901472092, validation error: 0.16284361481666565


100%|██████████| 1/1 [00:00<00:00, 71.43it/s]


Epoch: 8, train error: 0.07630448043346405, validation error: 0.16213230788707733


100%|██████████| 1/1 [00:00<00:00, 142.74it/s]


Epoch: 9, train error: 0.06426939368247986, validation error: 0.15436872839927673


100%|██████████| 1/1 [00:00<00:00, 142.71it/s]


Epoch: 10, train error: 0.07064256072044373, validation error: 0.14915713667869568


100%|██████████| 1/1 [00:00<00:00, 111.12it/s]


Epoch: 11, train error: 0.06296203285455704, validation error: 0.1538366824388504


100%|██████████| 1/1 [00:00<00:00, 142.85it/s]


Epoch: 12, train error: 0.08326944708824158, validation error: 0.19422295689582825


100%|██████████| 1/1 [00:00<00:00, 124.99it/s]


Epoch: 13, train error: 0.06864151358604431, validation error: 0.142757385969162


100%|██████████| 1/1 [00:00<00:00, 99.99it/s]


Epoch: 14, train error: 0.05572962760925293, validation error: 0.1338346004486084


100%|██████████| 1/1 [00:00<00:00, 66.67it/s]


Epoch: 15, train error: 0.07013662159442902, validation error: 0.14390775561332703


100%|██████████| 1/1 [00:00<00:00, 100.02it/s]


Epoch: 16, train error: 0.06412647664546967, validation error: 0.12911494076251984


100%|██████████| 1/1 [00:00<00:00, 125.02it/s]


Epoch: 17, train error: 0.06865471601486206, validation error: 0.18873830139636993


100%|██████████| 1/1 [00:00<00:00, 124.95it/s]


Epoch: 18, train error: 0.06176281347870827, validation error: 0.12326797842979431


100%|██████████| 1/1 [00:00<00:00, 79.93it/s]

Epoch: 19, train error: 0.07235080003738403, validation error: 0.13273084163665771



100%|██████████| 1/1 [00:00<00:00, 142.85it/s]


Epoch: 20, train error: 0.06269484013319016, validation error: 0.12392818182706833


100%|██████████| 1/1 [00:00<00:00, 124.81it/s]


Epoch: 21, train error: 0.06459904462099075, validation error: 0.10405139625072479


100%|██████████| 1/1 [00:00<00:00, 125.00it/s]


Epoch: 22, train error: 0.05363856628537178, validation error: 0.1853996217250824


100%|██████████| 1/1 [00:00<00:00, 125.01it/s]

Epoch: 23, train error: 0.07580909132957458, validation error: 0.11375778168439865



100%|██████████| 1/1 [00:00<00:00, 110.99it/s]


Epoch: 24, train error: 0.07484669983386993, validation error: 0.16516922414302826


100%|██████████| 1/1 [00:00<00:00, 133.19it/s]


Epoch: 25, train error: 0.07001838088035583, validation error: 0.12890015542507172


100%|██████████| 1/1 [00:00<00:00, 124.99it/s]


Epoch: 26, train error: 0.06482144445180893, validation error: 0.12160670757293701


100%|██████████| 1/1 [00:00<00:00, 142.88it/s]


Epoch: 27, train error: 0.06675852835178375, validation error: 0.12566912174224854


100%|██████████| 1/1 [00:00<00:00, 111.15it/s]


Epoch: 28, train error: 0.06002921238541603, validation error: 0.14166590571403503


100%|██████████| 1/1 [00:00<00:00, 124.89it/s]


Epoch: 29, train error: 0.06022525206208229, validation error: 0.16234193742275238


100%|██████████| 1/1 [00:00<00:00, 66.62it/s]


Epoch: 30, train error: 0.06601700186729431, validation error: 0.13341395556926727

Early stopping at epoch 30
well done
Model: kcn_sage, test error: tensor([[0.3163],
        [0.3147],
        [0.3136],
        [0.3164],
        [0.3159]], device='cuda:0', grad_fn=<MmBackward0>)



In [10]:
loss_func=torch.nn.MSELoss(reduction="mean")

In [11]:
test_error = loss_func(err,err1)
test_error = torch.sqrt(test_error).item()
mae=torch.nn.L1Loss(reduction='mean')
test_error1=mae(err, err1).item()

### 探索多时间点怎么合并

In [34]:
# xx=X[1]
# columns_with_missing_data = np.any(np.isnan(xx), axis=0)
# missing_columns = np.where(columns_with_missing_data)[0]
# result = np.delete(xx, missing_columns, axis=1)
# train_array = [i for i in range(int(result.shape[1]*0.8))]
# test_array = [i for i in range(int(result.shape[1]*0.8),result.shape[1])]
# x_train=result[[1,2,3,4,5,6,7]]
# # x_train=x_train[:,[0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30]]
# x_train=x_train[:,train_array]
# x_train=np.transpose(x_train)
# y_train=result[0]
# # y_train=y_train[[0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30]]
# y_train=y_train[train_array]
# y_train = y_train[:, None]
# x_test=result[[1,2,3,4,5,6,7]]
# # x_test=x_test[:,[31,32,33,34,35,36,37,38,39,40,41,42,43,44,45,46,47,48,49,50,51,52,53,54,55,56]]
# x_test=x_test[:,test_array]
# x_test=np.transpose(x_test)
# y_test=result[0]
# # y_test=y_test[[31,32,33,34,35,36,37,38,39,40,41,42,43,44,45,46,47,48,49,50,51,52,53,54,55,56]]
# y_test=y_test[test_array]
# y_test = y_test[:, None]

In [35]:
# trainset, testset = data.load_shenzhen_data(x_train,y_train,x_test,y_test,args)

Using the default test set from the data


In [36]:
# err_,err1_ = experiment.run_kcn(x_train,y_train,x_test,y_test,args)

Using the default test set from the data
The shenzhen dataset has 8 training instances and 2 test instances.
Length scale is set to 0.1319656367690134


100%|██████████| 1/1 [00:00<00:00, 83.34it/s]


Epoch: 0, train error: 0.16008472442626953, validation error: 0.3570149540901184


100%|██████████| 1/1 [00:00<00:00, 83.33it/s]


Epoch: 1, train error: 0.1499241590499878, validation error: 0.20207837224006653


100%|██████████| 1/1 [00:00<00:00, 83.31it/s]


Epoch: 2, train error: 0.15931293368339539, validation error: 0.24741914868354797


100%|██████████| 1/1 [00:00<00:00, 86.87it/s]


Epoch: 3, train error: 0.15046222507953644, validation error: 0.21700280904769897


100%|██████████| 1/1 [00:00<00:00, 58.82it/s]


Epoch: 4, train error: 0.15475526452064514, validation error: 0.23863467574119568


100%|██████████| 1/1 [00:00<00:00, 90.91it/s]


Epoch: 5, train error: 0.13675718009471893, validation error: 0.21358159184455872


100%|██████████| 1/1 [00:00<00:00, 125.00it/s]


Epoch: 6, train error: 0.0891076922416687, validation error: 0.23296692967414856


100%|██████████| 1/1 [00:00<00:00, 166.69it/s]


Epoch: 7, train error: 0.06890174746513367, validation error: 0.23645728826522827


100%|██████████| 1/1 [00:00<00:00, 166.52it/s]


Epoch: 8, train error: 0.07243487238883972, validation error: 0.44552603363990784


100%|██████████| 1/1 [00:00<00:00, 181.49it/s]


Epoch: 9, train error: 0.05065607279539108, validation error: 0.13535884022712708


100%|██████████| 1/1 [00:00<00:00, 166.72it/s]


Epoch: 10, train error: 0.05985722690820694, validation error: 0.25318267941474915


100%|██████████| 1/1 [00:00<00:00, 166.67it/s]


Epoch: 11, train error: 0.03721706569194794, validation error: 0.1624194085597992


100%|██████████| 1/1 [00:00<00:00, 76.92it/s]


Epoch: 12, train error: 0.048990558832883835, validation error: 0.26880621910095215


100%|██████████| 1/1 [00:00<00:00, 62.51it/s]


Epoch: 13, train error: 0.07921621948480606, validation error: 0.2725543677806854


100%|██████████| 1/1 [00:00<00:00, 124.87it/s]


Epoch: 14, train error: 0.07748941332101822, validation error: 0.21962617337703705


100%|██████████| 1/1 [00:00<00:00, 142.87it/s]


Epoch: 15, train error: 0.047063447535037994, validation error: 0.2811576724052429


100%|██████████| 1/1 [00:00<00:00, 124.88it/s]


Epoch: 16, train error: 0.05565207451581955, validation error: 0.11948040872812271


100%|██████████| 1/1 [00:00<00:00, 166.69it/s]


Epoch: 17, train error: 0.06410031020641327, validation error: 0.4532088041305542


100%|██████████| 1/1 [00:00<00:00, 142.87it/s]


Epoch: 18, train error: 0.05317217856645584, validation error: 0.29026708006858826


100%|██████████| 1/1 [00:00<00:00, 166.66it/s]


Epoch: 19, train error: 0.09417768567800522, validation error: 0.33071160316467285


100%|██████████| 1/1 [00:00<00:00, 166.67it/s]


Epoch: 20, train error: 0.05610799789428711, validation error: 0.3198240399360657


100%|██████████| 1/1 [00:00<00:00, 142.83it/s]

Epoch: 21, train error: 0.05816706269979477, validation error: 0.30599284172058105

Early stopping at epoch 21
well done





In [37]:
# err_1=torch.cat((err,err_),dim=0)

##  test结果

In [40]:
for i in range(1,len(X)):
    xx=X[i]
    # columns_with_missing_data = np.any(np.isnan(xx), axis=0)
    # missing_columns = np.where(columns_with_missing_data)[0]
    # result = np.delete(xx, missing_columns, axis=1)
    # x_train=result[[1,2,3,4,5,6,7]]
    # x_train=x_train[:,[0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30]]
    # x_train=np.transpose(x_train)
    # y_train=result[0]
    # y_train=y_train[[0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30]]
    # y_train = y_train[:, None]
    # x_test=result[[1,2,3,4,5,6,7]]
    # x_test=x_test[:,[31,32,33,34,35,36,37,38,39,40,41,42,43,44,45,46,47,48,49,50,51,52,53,54,55,56]]
    # x_test=np.transpose(x_test)
    # y_test=result[0]
    # y_test=y_test[[31,32,33,34,35,36,37,38,39,40,41,42,43,44,45,46,47,48,49,50,51,52,53,54,55,56]]
    # y_test = y_test[:, None]
    columns_with_missing_data = np.any(np.isnan(xx), axis=0)
    missing_columns = np.where(columns_with_missing_data)[0]
    result = np.delete(xx, missing_columns, axis=1)
    train_array = [i for i in range(int(result.shape[1]*0.8))]
    test_array = [i for i in range(int(result.shape[1]*0.8),result.shape[1])]
    x_train=result[[0,1,2,3,5,6]]
    # x_train=x_train[:,[0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30]]
    x_train=x_train[:,train_array]
    x_train=np.transpose(x_train)
    y_train=result[4]
    # y_train=y_train[[0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30]]
    y_train=y_train[train_array]
    y_train = y_train[:, None]
    x_test=result[[0,1,2,3,5,6]]
    # x_test=x_test[:,[31,32,33,34,35,36,37,38,39,40,41,42,43,44,45,46,47,48,49,50,51,52,53,54,55,56]]
    x_test=x_test[:,test_array]
    x_test=np.transpose(x_test)
    y_test=result[4]
    # y_test=y_test[[31,32,33,34,35,36,37,38,39,40,41,42,43,44,45,46,47,48,49,50,51,52,53,54,55,56]]
    y_test=y_test[test_array]
    y_test = y_test[:, None]
    args = argument.parse_opt()
    args.model = 'kcn_sage'
    args.dataset = "t_data"
    args.validation_size=4
    # trainset, testset = data.load_ndbc_data(x_train,y_train,x_test,y_test,args)
    err_,err1_ = experiment.run_kcn(x_train,y_train,x_test,y_test,args)
    err=torch.cat((err,err_),dim=0)
    err1=torch.cat((err1,err1_),dim=0)

Using the default test set from the data
The t_data dataset has 18 training instances and 5 test instances.
Length scale is set to 0.0046539306640625


100%|██████████| 1/1 [00:00<00:00, 28.56it/s]


Epoch: 0, train error: 0.2084318995475769, validation error: 0.37731146812438965


100%|██████████| 1/1 [00:00<00:00, 27.77it/s]


Epoch: 1, train error: 0.19389766454696655, validation error: 0.36084169149398804


100%|██████████| 1/1 [00:00<00:00, 28.57it/s]


Epoch: 2, train error: 0.19018517434597015, validation error: 0.32865142822265625


100%|██████████| 1/1 [00:00<00:00, 41.67it/s]


Epoch: 3, train error: 0.18589498102664948, validation error: 0.34724414348602295


100%|██████████| 1/1 [00:00<00:00, 29.84it/s]


Epoch: 4, train error: 0.1773330271244049, validation error: 0.3414277136325836


100%|██████████| 1/1 [00:00<00:00, 66.66it/s]


Epoch: 5, train error: 0.1729830652475357, validation error: 0.3200293183326721


100%|██████████| 1/1 [00:00<00:00, 124.93it/s]


Epoch: 6, train error: 0.1673320084810257, validation error: 0.32107260823249817


100%|██████████| 1/1 [00:00<00:00, 117.49it/s]


Epoch: 7, train error: 0.16597387194633484, validation error: 0.3217805027961731


100%|██████████| 1/1 [00:00<00:00, 142.70it/s]


Epoch: 8, train error: 0.1587730348110199, validation error: 0.3225417137145996


100%|██████████| 1/1 [00:00<00:00, 111.11it/s]


Epoch: 9, train error: 0.16072332859039307, validation error: 0.3178783357143402


100%|██████████| 1/1 [00:00<00:00, 125.00it/s]


Epoch: 10, train error: 0.15783582627773285, validation error: 0.31834444403648376


100%|██████████| 1/1 [00:00<00:00, 111.11it/s]


Epoch: 11, train error: 0.1576351523399353, validation error: 0.3158969581127167


100%|██████████| 1/1 [00:00<00:00, 166.68it/s]


Epoch: 12, train error: 0.157884880900383, validation error: 0.3146369755268097


100%|██████████| 1/1 [00:00<00:00, 125.02it/s]


Epoch: 13, train error: 0.1587303876876831, validation error: 0.3143956661224365


100%|██████████| 1/1 [00:00<00:00, 125.00it/s]


Epoch: 14, train error: 0.15667961537837982, validation error: 0.3108440935611725


100%|██████████| 1/1 [00:00<00:00, 181.43it/s]


Epoch: 15, train error: 0.15567204356193542, validation error: 0.31354039907455444


100%|██████████| 1/1 [00:00<00:00, 142.67it/s]


Epoch: 16, train error: 0.15339715778827667, validation error: 0.3106345236301422


100%|██████████| 1/1 [00:00<00:00, 142.87it/s]


Epoch: 17, train error: 0.15364143252372742, validation error: 0.3049579858779907


100%|██████████| 1/1 [00:00<00:00, 142.92it/s]


Epoch: 18, train error: 0.1523398756980896, validation error: 0.30914050340652466


100%|██████████| 1/1 [00:00<00:00, 142.92it/s]


Epoch: 19, train error: 0.1511790007352829, validation error: 0.3050669729709625


100%|██████████| 1/1 [00:00<00:00, 142.86it/s]


Epoch: 20, train error: 0.15058906376361847, validation error: 0.30293071269989014


100%|██████████| 1/1 [00:00<00:00, 124.98it/s]


Epoch: 21, train error: 0.1505717784166336, validation error: 0.29815128445625305


100%|██████████| 1/1 [00:00<00:00, 142.75it/s]


Epoch: 22, train error: 0.14935782551765442, validation error: 0.2968007028102875


100%|██████████| 1/1 [00:00<00:00, 125.02it/s]


Epoch: 23, train error: 0.14416442811489105, validation error: 0.29221364855766296


100%|██████████| 1/1 [00:00<00:00, 142.87it/s]


Epoch: 24, train error: 0.14626745879650116, validation error: 0.28596949577331543


100%|██████████| 1/1 [00:00<00:00, 76.93it/s]


Epoch: 25, train error: 0.14003406465053558, validation error: 0.2863178849220276


100%|██████████| 1/1 [00:00<00:00, 90.90it/s]


Epoch: 26, train error: 0.13925454020500183, validation error: 0.2823161780834198


100%|██████████| 1/1 [00:00<00:00, 111.15it/s]


Epoch: 27, train error: 0.13636468350887299, validation error: 0.29805096983909607


100%|██████████| 1/1 [00:00<00:00, 124.92it/s]


Epoch: 28, train error: 0.13227711617946625, validation error: 0.258978933095932


100%|██████████| 1/1 [00:00<00:00, 166.67it/s]


Epoch: 29, train error: 0.12902021408081055, validation error: 0.2521899938583374


100%|██████████| 1/1 [00:00<00:00, 166.65it/s]


Epoch: 30, train error: 0.12856057286262512, validation error: 0.2645750343799591


100%|██████████| 1/1 [00:00<00:00, 142.85it/s]


Epoch: 31, train error: 0.11635857820510864, validation error: 0.25712254643440247


100%|██████████| 1/1 [00:00<00:00, 124.88it/s]


Epoch: 32, train error: 0.11950750648975372, validation error: 0.2513967454433441


100%|██████████| 1/1 [00:00<00:00, 111.09it/s]

Epoch: 33, train error: 0.1105675920844078, validation error: 0.22353366017341614



100%|██████████| 1/1 [00:00<00:00, 100.01it/s]


Epoch: 34, train error: 0.1164623573422432, validation error: 0.2416875809431076


100%|██████████| 1/1 [00:00<00:00, 124.98it/s]


Epoch: 35, train error: 0.10968822240829468, validation error: 0.20865057408809662


100%|██████████| 1/1 [00:00<00:00, 117.62it/s]


Epoch: 36, train error: 0.10268065333366394, validation error: 0.2175963670015335


100%|██████████| 1/1 [00:00<00:00, 142.88it/s]


Epoch: 37, train error: 0.0946812853217125, validation error: 0.22051770985126495


100%|██████████| 1/1 [00:00<00:00, 124.97it/s]


Epoch: 38, train error: 0.09602305293083191, validation error: 0.24074506759643555


100%|██████████| 1/1 [00:00<00:00, 142.83it/s]


Epoch: 39, train error: 0.08856703341007233, validation error: 0.20653891563415527


100%|██████████| 1/1 [00:00<00:00, 125.00it/s]


Epoch: 40, train error: 0.08705911040306091, validation error: 0.21481019258499146


100%|██████████| 1/1 [00:00<00:00, 100.00it/s]


Epoch: 41, train error: 0.10743537545204163, validation error: 0.19166859984397888


100%|██████████| 1/1 [00:00<00:00, 166.65it/s]


Epoch: 42, train error: 0.09801294654607773, validation error: 0.16506990790367126


100%|██████████| 1/1 [00:00<00:00, 142.75it/s]


Epoch: 43, train error: 0.0758679136633873, validation error: 0.17743396759033203


100%|██████████| 1/1 [00:00<00:00, 132.99it/s]


Epoch: 44, train error: 0.08225987106561661, validation error: 0.1829298436641693


100%|██████████| 1/1 [00:00<00:00, 142.86it/s]


Epoch: 45, train error: 0.07029583305120468, validation error: 0.1578584611415863


100%|██████████| 1/1 [00:00<00:00, 142.81it/s]


Epoch: 46, train error: 0.060872405767440796, validation error: 0.14457425475120544


100%|██████████| 1/1 [00:00<00:00, 142.83it/s]


Epoch: 47, train error: 0.058840516954660416, validation error: 0.1582060158252716


100%|██████████| 1/1 [00:00<00:00, 76.91it/s]


Epoch: 48, train error: 0.06311637163162231, validation error: 0.14405520260334015


100%|██████████| 1/1 [00:00<00:00, 83.33it/s]


Epoch: 49, train error: 0.06268465518951416, validation error: 0.1548343300819397


100%|██████████| 1/1 [00:00<00:00, 111.14it/s]


Epoch: 50, train error: 0.061079803854227066, validation error: 0.13986361026763916


100%|██████████| 1/1 [00:00<00:00, 124.98it/s]


Epoch: 51, train error: 0.05670483037829399, validation error: 0.17984254658222198


100%|██████████| 1/1 [00:00<00:00, 125.00it/s]


Epoch: 52, train error: 0.0819154679775238, validation error: 0.12037539482116699


100%|██████████| 1/1 [00:00<00:00, 166.63it/s]


Epoch: 53, train error: 0.09948417544364929, validation error: 0.13529209792613983


100%|██████████| 1/1 [00:00<00:00, 124.99it/s]


Epoch: 54, train error: 0.07427565008401871, validation error: 0.1134343221783638


100%|██████████| 1/1 [00:00<00:00, 125.18it/s]


Epoch: 55, train error: 0.08317866921424866, validation error: 0.11620558798313141


100%|██████████| 1/1 [00:00<00:00, 142.88it/s]


Epoch: 56, train error: 0.09931828081607819, validation error: 0.13775615394115448


100%|██████████| 1/1 [00:00<00:00, 142.74it/s]


Epoch: 57, train error: 0.06489233672618866, validation error: 0.1152774840593338


100%|██████████| 1/1 [00:00<00:00, 142.81it/s]


Epoch: 58, train error: 0.06419380754232407, validation error: 0.13045388460159302


100%|██████████| 1/1 [00:00<00:00, 124.97it/s]


Epoch: 59, train error: 0.06209234893321991, validation error: 0.133784681558609


100%|██████████| 1/1 [00:00<00:00, 142.84it/s]

Epoch: 60, train error: 0.06579867005348206, validation error: 0.10949383676052094



100%|██████████| 1/1 [00:00<00:00, 142.80it/s]


Epoch: 61, train error: 0.07355524599552155, validation error: 0.16373606026172638


100%|██████████| 1/1 [00:00<00:00, 142.84it/s]


Epoch: 62, train error: 0.06437113136053085, validation error: 0.13127776980400085


100%|██████████| 1/1 [00:00<00:00, 142.88it/s]


Epoch: 63, train error: 0.062239810824394226, validation error: 0.13054093718528748


100%|██████████| 1/1 [00:00<00:00, 76.94it/s]


Epoch: 64, train error: 0.05223112180829048, validation error: 0.12887410819530487


100%|██████████| 1/1 [00:00<00:00, 76.89it/s]


Epoch: 65, train error: 0.05800241976976395, validation error: 0.13087700307369232


100%|██████████| 1/1 [00:00<00:00, 142.87it/s]


Epoch: 66, train error: 0.06256839632987976, validation error: 0.13219264149665833


100%|██████████| 1/1 [00:00<00:00, 124.98it/s]


Epoch: 67, train error: 0.06569831818342209, validation error: 0.1324567049741745


100%|██████████| 1/1 [00:00<00:00, 90.92it/s]


Epoch: 68, train error: 0.06596504151821136, validation error: 0.13574565947055817


100%|██████████| 1/1 [00:00<00:00, 142.85it/s]


Epoch: 69, train error: 0.060388725250959396, validation error: 0.11428578943014145


100%|██████████| 1/1 [00:00<00:00, 142.84it/s]


Epoch: 70, train error: 0.08394981175661087, validation error: 0.1338452249765396


100%|██████████| 1/1 [00:00<00:00, 124.99it/s]


Epoch: 71, train error: 0.07287438213825226, validation error: 0.13375553488731384


100%|██████████| 1/1 [00:00<00:00, 142.85it/s]


Epoch: 72, train error: 0.06324230134487152, validation error: 0.12552306056022644


100%|██████████| 1/1 [00:00<00:00, 133.04it/s]


Epoch: 73, train error: 0.06859540939331055, validation error: 0.13588277995586395

Early stopping at epoch 73
well done
Using the default test set from the data
The t_data dataset has 18 training instances and 5 test instances.
Length scale is set to 0.0046539306640625


100%|██████████| 1/1 [00:00<00:00, 142.74it/s]


Epoch: 0, train error: 0.08750800043344498, validation error: 0.18255183100700378


100%|██████████| 1/1 [00:00<00:00, 125.01it/s]


Epoch: 1, train error: 0.06181592494249344, validation error: 0.15028657019138336


100%|██████████| 1/1 [00:00<00:00, 166.67it/s]


Epoch: 2, train error: 0.07107515633106232, validation error: 0.16405604779720306


100%|██████████| 1/1 [00:00<00:00, 111.11it/s]


Epoch: 3, train error: 0.07543571293354034, validation error: 0.1341608315706253


100%|██████████| 1/1 [00:00<00:00, 166.64it/s]


Epoch: 4, train error: 0.06527414917945862, validation error: 0.17743022739887238


100%|██████████| 1/1 [00:00<00:00, 124.72it/s]


Epoch: 5, train error: 0.05219247564673424, validation error: 0.14154459536075592


100%|██████████| 1/1 [00:00<00:00, 133.11it/s]


Epoch: 6, train error: 0.06113499402999878, validation error: 0.10083775222301483


100%|██████████| 1/1 [00:00<00:00, 166.68it/s]


Epoch: 7, train error: 0.056267816573381424, validation error: 0.20245863497257233


100%|██████████| 1/1 [00:00<00:00, 142.79it/s]


Epoch: 8, train error: 0.06395263969898224, validation error: 0.11327295005321503


100%|██████████| 1/1 [00:00<00:00, 166.65it/s]


Epoch: 9, train error: 0.058201633393764496, validation error: 0.13555744290351868


100%|██████████| 1/1 [00:00<00:00, 125.02it/s]


Epoch: 10, train error: 0.052247051149606705, validation error: 0.14825236797332764


100%|██████████| 1/1 [00:00<00:00, 111.11it/s]


Epoch: 11, train error: 0.05185684189200401, validation error: 0.1634427309036255


100%|██████████| 1/1 [00:00<00:00, 124.98it/s]


Epoch: 12, train error: 0.056514281779527664, validation error: 0.16389364004135132


100%|██████████| 1/1 [00:00<00:00, 166.55it/s]


Epoch: 13, train error: 0.04723675549030304, validation error: 0.11831419914960861


100%|██████████| 1/1 [00:00<00:00, 133.30it/s]


Epoch: 14, train error: 0.06344802677631378, validation error: 0.10633005201816559


100%|██████████| 1/1 [00:00<00:00, 125.01it/s]

Epoch: 15, train error: 0.06657876819372177, validation error: 0.1357678472995758



100%|██████████| 1/1 [00:00<00:00, 124.80it/s]


Epoch: 16, train error: 0.04814622551202774, validation error: 0.1109289899468422


100%|██████████| 1/1 [00:00<00:00, 142.86it/s]


Epoch: 17, train error: 0.047543007880449295, validation error: 0.10589964687824249


100%|██████████| 1/1 [00:00<00:00, 125.10it/s]


Epoch: 18, train error: 0.0356709286570549, validation error: 0.14931458234786987


100%|██████████| 1/1 [00:00<00:00, 166.73it/s]


Epoch: 19, train error: 0.048476044088602066, validation error: 0.10725773870944977


100%|██████████| 1/1 [00:00<00:00, 111.11it/s]


Epoch: 20, train error: 0.05238541588187218, validation error: 0.11784426867961884


100%|██████████| 1/1 [00:00<00:00, 166.52it/s]


Epoch: 21, train error: 0.0510527677834034, validation error: 0.15621641278266907


100%|██████████| 1/1 [00:00<00:00, 153.73it/s]


Epoch: 22, train error: 0.048479147255420685, validation error: 0.11751069128513336


100%|██████████| 1/1 [00:00<00:00, 125.00it/s]


Epoch: 23, train error: 0.04074955731630325, validation error: 0.07558367401361465


100%|██████████| 1/1 [00:00<00:00, 111.24it/s]


Epoch: 24, train error: 0.04226970300078392, validation error: 0.13232438266277313


100%|██████████| 1/1 [00:00<00:00, 142.63it/s]


Epoch: 25, train error: 0.058858923614025116, validation error: 0.10852497816085815


100%|██████████| 1/1 [00:00<00:00, 166.72it/s]


Epoch: 26, train error: 0.032965973019599915, validation error: 0.0895887166261673


100%|██████████| 1/1 [00:00<00:00, 111.06it/s]


Epoch: 27, train error: 0.036510080099105835, validation error: 0.1792927384376526


100%|██████████| 1/1 [00:00<00:00, 124.98it/s]


Epoch: 28, train error: 0.03071792796254158, validation error: 0.1340283751487732

Early stopping at epoch 28
well done
Using the default test set from the data
The t_data dataset has 18 training instances and 5 test instances.
Length scale is set to 0.0046539306640625


100%|██████████| 1/1 [00:00<00:00, 90.92it/s]


Epoch: 0, train error: 0.5070742964744568, validation error: 0.7268956899642944


100%|██████████| 1/1 [00:00<00:00, 71.44it/s]

Epoch: 1, train error: 0.43196800351142883, validation error: 0.6377021074295044



100%|██████████| 1/1 [00:00<00:00, 142.67it/s]


Epoch: 2, train error: 0.42308714985847473, validation error: 0.6293940544128418


100%|██████████| 1/1 [00:00<00:00, 166.65it/s]


Epoch: 3, train error: 0.37784555554389954, validation error: 0.5857270956039429


100%|██████████| 1/1 [00:00<00:00, 99.97it/s]


Epoch: 4, train error: 0.3525611460208893, validation error: 0.5436990261077881


100%|██████████| 1/1 [00:00<00:00, 133.12it/s]


Epoch: 5, train error: 0.3274170458316803, validation error: 0.499416708946228


100%|██████████| 1/1 [00:00<00:00, 111.13it/s]


Epoch: 6, train error: 0.30890151858329773, validation error: 0.49786022305488586


100%|██████████| 1/1 [00:00<00:00, 142.85it/s]


Epoch: 7, train error: 0.276294082403183, validation error: 0.376359760761261


100%|██████████| 1/1 [00:00<00:00, 142.85it/s]


Epoch: 8, train error: 0.258078396320343, validation error: 0.42766448855400085


100%|██████████| 1/1 [00:00<00:00, 142.82it/s]


Epoch: 9, train error: 0.23502330482006073, validation error: 0.3918536305427551


100%|██████████| 1/1 [00:00<00:00, 142.83it/s]


Epoch: 10, train error: 0.23022761940956116, validation error: 0.3902027904987335


100%|██████████| 1/1 [00:00<00:00, 142.86it/s]


Epoch: 11, train error: 0.19187264144420624, validation error: 0.3770028352737427


100%|██████████| 1/1 [00:00<00:00, 99.96it/s]


Epoch: 12, train error: 0.2059686779975891, validation error: 0.37076056003570557


100%|██████████| 1/1 [00:00<00:00, 100.01it/s]


Epoch: 13, train error: 0.19397276639938354, validation error: 0.351646363735199


100%|██████████| 1/1 [00:00<00:00, 125.04it/s]


Epoch: 14, train error: 0.18893447518348694, validation error: 0.3430408537387848


100%|██████████| 1/1 [00:00<00:00, 142.88it/s]


Epoch: 15, train error: 0.17377278208732605, validation error: 0.3079794645309448


100%|██████████| 1/1 [00:00<00:00, 142.81it/s]


Epoch: 16, train error: 0.17694059014320374, validation error: 0.32142341136932373


100%|██████████| 1/1 [00:00<00:00, 111.25it/s]


Epoch: 17, train error: 0.172624409198761, validation error: 0.31232285499572754


100%|██████████| 1/1 [00:00<00:00, 142.83it/s]


Epoch: 18, train error: 0.16991637647151947, validation error: 0.2935725450515747


100%|██████████| 1/1 [00:00<00:00, 142.72it/s]


Epoch: 19, train error: 0.1665431708097458, validation error: 0.2883151173591614


100%|██████████| 1/1 [00:00<00:00, 143.00it/s]


Epoch: 20, train error: 0.15975791215896606, validation error: 0.2873709499835968


100%|██████████| 1/1 [00:00<00:00, 124.97it/s]


Epoch: 21, train error: 0.152793288230896, validation error: 0.284396767616272


100%|██████████| 1/1 [00:00<00:00, 100.01it/s]


Epoch: 22, train error: 0.1371733397245407, validation error: 0.2742289900779724


100%|██████████| 1/1 [00:00<00:00, 111.10it/s]


Epoch: 23, train error: 0.14002510905265808, validation error: 0.2721410095691681


100%|██████████| 1/1 [00:00<00:00, 100.01it/s]


Epoch: 24, train error: 0.14632150530815125, validation error: 0.25834909081459045


100%|██████████| 1/1 [00:00<00:00, 111.09it/s]


Epoch: 25, train error: 0.13109712302684784, validation error: 0.2782185673713684


100%|██████████| 1/1 [00:00<00:00, 124.89it/s]

Epoch: 26, train error: 0.12518921494483948, validation error: 0.2604619264602661



100%|██████████| 1/1 [00:00<00:00, 153.41it/s]


Epoch: 27, train error: 0.1313788890838623, validation error: 0.26311880350112915


100%|██████████| 1/1 [00:00<00:00, 166.67it/s]


Epoch: 28, train error: 0.11987990140914917, validation error: 0.23950716853141785


100%|██████████| 1/1 [00:00<00:00, 166.45it/s]


Epoch: 29, train error: 0.11488155275583267, validation error: 0.2231806516647339


100%|██████████| 1/1 [00:00<00:00, 125.02it/s]


Epoch: 30, train error: 0.10715574771165848, validation error: 0.22896510362625122


100%|██████████| 1/1 [00:00<00:00, 142.86it/s]


Epoch: 31, train error: 0.10512708127498627, validation error: 0.27685537934303284


100%|██████████| 1/1 [00:00<00:00, 125.00it/s]


Epoch: 32, train error: 0.10635816305875778, validation error: 0.22250959277153015


100%|██████████| 1/1 [00:00<00:00, 111.09it/s]


Epoch: 33, train error: 0.09730391949415207, validation error: 0.20584914088249207


100%|██████████| 1/1 [00:00<00:00, 90.93it/s]


Epoch: 34, train error: 0.11193317174911499, validation error: 0.21228213608264923


100%|██████████| 1/1 [00:00<00:00, 100.00it/s]


Epoch: 35, train error: 0.10271499305963516, validation error: 0.20768439769744873


100%|██████████| 1/1 [00:00<00:00, 90.92it/s]


Epoch: 36, train error: 0.09367328137159348, validation error: 0.2038383036851883


100%|██████████| 1/1 [00:00<00:00, 124.99it/s]


Epoch: 37, train error: 0.09351536631584167, validation error: 0.2648191452026367


100%|██████████| 1/1 [00:00<00:00, 142.85it/s]


Epoch: 38, train error: 0.0907578319311142, validation error: 0.18956458568572998


100%|██████████| 1/1 [00:00<00:00, 142.87it/s]


Epoch: 39, train error: 0.08273449540138245, validation error: 0.18685051798820496


100%|██████████| 1/1 [00:00<00:00, 111.05it/s]


Epoch: 40, train error: 0.09678284078836441, validation error: 0.18170218169689178


100%|██████████| 1/1 [00:00<00:00, 133.14it/s]


Epoch: 41, train error: 0.10217741876840591, validation error: 0.1946040540933609


100%|██████████| 1/1 [00:00<00:00, 142.87it/s]


Epoch: 42, train error: 0.08890310674905777, validation error: 0.17450734972953796


100%|██████████| 1/1 [00:00<00:00, 99.97it/s]


Epoch: 43, train error: 0.10258162766695023, validation error: 0.17382007837295532


100%|██████████| 1/1 [00:00<00:00, 100.01it/s]


Epoch: 44, train error: 0.0748843178153038, validation error: 0.16961155831813812


100%|██████████| 1/1 [00:00<00:00, 166.66it/s]

Epoch: 45, train error: 0.07614638656377792, validation error: 0.17884603142738342



100%|██████████| 1/1 [00:00<00:00, 125.02it/s]


Epoch: 46, train error: 0.08536919951438904, validation error: 0.18109533190727234


100%|██████████| 1/1 [00:00<00:00, 142.87it/s]


Epoch: 47, train error: 0.08823179453611374, validation error: 0.16128860414028168


100%|██████████| 1/1 [00:00<00:00, 124.98it/s]


Epoch: 48, train error: 0.07251191139221191, validation error: 0.1983519047498703


100%|██████████| 1/1 [00:00<00:00, 142.86it/s]


Epoch: 49, train error: 0.08268678188323975, validation error: 0.1517205536365509


100%|██████████| 1/1 [00:00<00:00, 142.88it/s]


Epoch: 50, train error: 0.08660269528627396, validation error: 0.15716132521629333


100%|██████████| 1/1 [00:00<00:00, 142.82it/s]


Epoch: 51, train error: 0.067258819937706, validation error: 0.15208733081817627


100%|██████████| 1/1 [00:00<00:00, 142.86it/s]


Epoch: 52, train error: 0.0709386095404625, validation error: 0.14574027061462402


100%|██████████| 1/1 [00:00<00:00, 142.84it/s]


Epoch: 53, train error: 0.06503716111183167, validation error: 0.25572100281715393


100%|██████████| 1/1 [00:00<00:00, 124.80it/s]


Epoch: 54, train error: 0.08457411080598831, validation error: 0.14649948477745056


100%|██████████| 1/1 [00:00<00:00, 124.87it/s]


Epoch: 55, train error: 0.12844812870025635, validation error: 0.16701622307300568

Early stopping at epoch 55
well done
Using the default test set from the data
The t_data dataset has 18 training instances and 5 test instances.
Length scale is set to 0.0046539306640625


100%|██████████| 1/1 [00:00<00:00, 142.88it/s]


Epoch: 0, train error: 0.1547141969203949, validation error: 0.26212772727012634


100%|██████████| 1/1 [00:00<00:00, 125.00it/s]


Epoch: 1, train error: 0.12161541730165482, validation error: 0.27153098583221436


100%|██████████| 1/1 [00:00<00:00, 142.64it/s]


Epoch: 2, train error: 0.12145522236824036, validation error: 0.24477367103099823


100%|██████████| 1/1 [00:00<00:00, 124.99it/s]


Epoch: 3, train error: 0.11381849646568298, validation error: 0.23098963499069214


100%|██████████| 1/1 [00:00<00:00, 100.01it/s]


Epoch: 4, train error: 0.09640004485845566, validation error: 0.23053672909736633


100%|██████████| 1/1 [00:00<00:00, 111.13it/s]


Epoch: 5, train error: 0.10414844006299973, validation error: 0.24934256076812744


100%|██████████| 1/1 [00:00<00:00, 100.09it/s]


Epoch: 6, train error: 0.08494512736797333, validation error: 0.18476757407188416


100%|██████████| 1/1 [00:00<00:00, 142.83it/s]


Epoch: 7, train error: 0.09680651873350143, validation error: 0.19292905926704407


100%|██████████| 1/1 [00:00<00:00, 111.13it/s]


Epoch: 8, train error: 0.07742416858673096, validation error: 0.18635772168636322


100%|██████████| 1/1 [00:00<00:00, 142.87it/s]


Epoch: 9, train error: 0.08021040260791779, validation error: 0.21262134611606598


100%|██████████| 1/1 [00:00<00:00, 111.13it/s]


Epoch: 10, train error: 0.05855107679963112, validation error: 0.1532398909330368


100%|██████████| 1/1 [00:00<00:00, 125.00it/s]


Epoch: 11, train error: 0.06699761748313904, validation error: 0.1397205889225006


100%|██████████| 1/1 [00:00<00:00, 86.85it/s]


Epoch: 12, train error: 0.06402711570262909, validation error: 0.1796552836894989


100%|██████████| 1/1 [00:00<00:00, 100.01it/s]


Epoch: 13, train error: 0.10536161810159683, validation error: 0.13753820955753326


100%|██████████| 1/1 [00:00<00:00, 125.01it/s]


Epoch: 14, train error: 0.06753547489643097, validation error: 0.1332227736711502


100%|██████████| 1/1 [00:00<00:00, 142.81it/s]


Epoch: 15, train error: 0.054844118654727936, validation error: 0.10460922122001648


100%|██████████| 1/1 [00:00<00:00, 100.00it/s]


Epoch: 16, train error: 0.0589950829744339, validation error: 0.1402449607849121


100%|██████████| 1/1 [00:00<00:00, 90.86it/s]


Epoch: 17, train error: 0.06925247609615326, validation error: 0.12425597012042999


100%|██████████| 1/1 [00:00<00:00, 111.05it/s]


Epoch: 18, train error: 0.06488791108131409, validation error: 0.10434825718402863


100%|██████████| 1/1 [00:00<00:00, 133.14it/s]


Epoch: 19, train error: 0.06351107358932495, validation error: 0.14166371524333954


100%|██████████| 1/1 [00:00<00:00, 125.01it/s]


Epoch: 20, train error: 0.06457220017910004, validation error: 0.12738199532032013


100%|██████████| 1/1 [00:00<00:00, 111.10it/s]


Epoch: 21, train error: 0.07744397968053818, validation error: 0.14884714782238007


100%|██████████| 1/1 [00:00<00:00, 125.01it/s]


Epoch: 22, train error: 0.0702144056558609, validation error: 0.2706342339515686

Early stopping at epoch 22
well done
Using the default test set from the data
The t_data dataset has 18 training instances and 5 test instances.
Length scale is set to 0.0046539306640625


100%|██████████| 1/1 [00:00<00:00, 91.01it/s]


Epoch: 0, train error: 0.2350178211927414, validation error: 0.36055582761764526


100%|██████████| 1/1 [00:00<00:00, 124.92it/s]


Epoch: 1, train error: 0.2151230275630951, validation error: 0.31849193572998047


100%|██████████| 1/1 [00:00<00:00, 133.19it/s]


Epoch: 2, train error: 0.17031329870224, validation error: 0.3006844222545624


100%|██████████| 1/1 [00:00<00:00, 100.00it/s]


Epoch: 3, train error: 0.1438046246767044, validation error: 0.28573697805404663


100%|██████████| 1/1 [00:00<00:00, 99.98it/s]


Epoch: 4, train error: 0.13743963837623596, validation error: 0.2669917345046997


100%|██████████| 1/1 [00:00<00:00, 99.97it/s]


Epoch: 5, train error: 0.12730537354946136, validation error: 0.25990673899650574


100%|██████████| 1/1 [00:00<00:00, 90.92it/s]


Epoch: 6, train error: 0.124836266040802, validation error: 0.2424270659685135


100%|██████████| 1/1 [00:00<00:00, 142.75it/s]


Epoch: 7, train error: 0.11599458009004593, validation error: 0.23513665795326233


100%|██████████| 1/1 [00:00<00:00, 86.88it/s]


Epoch: 8, train error: 0.10995818674564362, validation error: 0.22069096565246582


100%|██████████| 1/1 [00:00<00:00, 99.99it/s]


Epoch: 9, train error: 0.10640020668506622, validation error: 0.253146231174469


100%|██████████| 1/1 [00:00<00:00, 111.10it/s]


Epoch: 10, train error: 0.09850533306598663, validation error: 0.20546090602874756


100%|██████████| 1/1 [00:00<00:00, 125.01it/s]


Epoch: 11, train error: 0.08945891261100769, validation error: 0.2263573408126831


100%|██████████| 1/1 [00:00<00:00, 125.01it/s]


Epoch: 12, train error: 0.10645201057195663, validation error: 0.19409868121147156


100%|██████████| 1/1 [00:00<00:00, 124.86it/s]


Epoch: 13, train error: 0.10492878407239914, validation error: 0.18725578486919403


100%|██████████| 1/1 [00:00<00:00, 124.95it/s]


Epoch: 14, train error: 0.07863079756498337, validation error: 0.17632639408111572


100%|██████████| 1/1 [00:00<00:00, 133.27it/s]


Epoch: 15, train error: 0.09075973182916641, validation error: 0.17576411366462708


100%|██████████| 1/1 [00:00<00:00, 124.99it/s]


Epoch: 16, train error: 0.07286768406629562, validation error: 0.17963586747646332


100%|██████████| 1/1 [00:00<00:00, 111.10it/s]


Epoch: 17, train error: 0.0710226446390152, validation error: 0.16638396680355072


100%|██████████| 1/1 [00:00<00:00, 99.99it/s]


Epoch: 18, train error: 0.0693829208612442, validation error: 0.16195048391819


100%|██████████| 1/1 [00:00<00:00, 166.66it/s]


Epoch: 19, train error: 0.06666044145822525, validation error: 0.1611129343509674


100%|██████████| 1/1 [00:00<00:00, 124.99it/s]


Epoch: 20, train error: 0.06749998778104782, validation error: 0.17718607187271118


100%|██████████| 1/1 [00:00<00:00, 142.89it/s]


Epoch: 21, train error: 0.08419124037027359, validation error: 0.15288561582565308


100%|██████████| 1/1 [00:00<00:00, 125.00it/s]


Epoch: 22, train error: 0.064925417304039, validation error: 0.13939779996871948


100%|██████████| 1/1 [00:00<00:00, 62.49it/s]


Epoch: 23, train error: 0.08341149240732193, validation error: 0.13911309838294983


100%|██████████| 1/1 [00:00<00:00, 62.51it/s]


Epoch: 24, train error: 0.06299947202205658, validation error: 0.1413303166627884


100%|██████████| 1/1 [00:00<00:00, 76.93it/s]


Epoch: 25, train error: 0.08747952431440353, validation error: 0.1749313324689865


100%|██████████| 1/1 [00:00<00:00, 111.09it/s]


Epoch: 26, train error: 0.07267922163009644, validation error: 0.1399839073419571


100%|██████████| 1/1 [00:00<00:00, 117.56it/s]


Epoch: 27, train error: 0.07903167605400085, validation error: 0.13407769799232483


100%|██████████| 1/1 [00:00<00:00, 100.01it/s]


Epoch: 28, train error: 0.05467795580625534, validation error: 0.11564312875270844


100%|██████████| 1/1 [00:00<00:00, 76.92it/s]


Epoch: 29, train error: 0.062400657683610916, validation error: 0.13464051485061646


100%|██████████| 1/1 [00:00<00:00, 142.89it/s]


Epoch: 30, train error: 0.0801919549703598, validation error: 0.14391271770000458


100%|██████████| 1/1 [00:00<00:00, 142.71it/s]


Epoch: 31, train error: 0.05953177809715271, validation error: 0.13555960357189178


100%|██████████| 1/1 [00:00<00:00, 124.99it/s]


Epoch: 32, train error: 0.06754162162542343, validation error: 0.14548663794994354


100%|██████████| 1/1 [00:00<00:00, 90.91it/s]


Epoch: 33, train error: 0.05771259218454361, validation error: 0.11558470875024796


100%|██████████| 1/1 [00:00<00:00, 166.64it/s]


Epoch: 34, train error: 0.08624463528394699, validation error: 0.13232667744159698


100%|██████████| 1/1 [00:00<00:00, 142.88it/s]


Epoch: 35, train error: 0.06408469378948212, validation error: 0.13108889758586884


100%|██████████| 1/1 [00:00<00:00, 100.01it/s]


Epoch: 36, train error: 0.06256572157144547, validation error: 0.14862026274204254


100%|██████████| 1/1 [00:00<00:00, 142.84it/s]


Epoch: 37, train error: 0.059872858226299286, validation error: 0.1341283917427063


100%|██████████| 1/1 [00:00<00:00, 100.00it/s]


Epoch: 38, train error: 0.0816131979227066, validation error: 0.19720472395420074

Early stopping at epoch 38
well done
Using the default test set from the data
The t_data dataset has 18 training instances and 5 test instances.
Length scale is set to 0.0046539306640625


100%|██████████| 1/1 [00:00<00:00, 117.44it/s]


Epoch: 0, train error: 0.21344663202762604, validation error: 0.36981600522994995


100%|██████████| 1/1 [00:00<00:00, 125.01it/s]


Epoch: 1, train error: 0.1870061606168747, validation error: 0.331590473651886


100%|██████████| 1/1 [00:00<00:00, 125.01it/s]


Epoch: 2, train error: 0.18808498978614807, validation error: 0.33042335510253906


100%|██████████| 1/1 [00:00<00:00, 125.01it/s]


Epoch: 3, train error: 0.1700606793165207, validation error: 0.31470954418182373


100%|██████████| 1/1 [00:00<00:00, 125.00it/s]


Epoch: 4, train error: 0.16203483939170837, validation error: 0.3021267056465149


100%|██████████| 1/1 [00:00<00:00, 124.99it/s]


Epoch: 5, train error: 0.15550456941127777, validation error: 0.3002071678638458


100%|██████████| 1/1 [00:00<00:00, 117.59it/s]


Epoch: 6, train error: 0.1471630185842514, validation error: 0.31186991930007935


100%|██████████| 1/1 [00:00<00:00, 90.91it/s]


Epoch: 7, train error: 0.14397649466991425, validation error: 0.29069623351097107


100%|██████████| 1/1 [00:00<00:00, 111.12it/s]


Epoch: 8, train error: 0.14576654136180878, validation error: 0.2875605523586273


100%|██████████| 1/1 [00:00<00:00, 124.88it/s]


Epoch: 9, train error: 0.13795730471611023, validation error: 0.27790865302085876


100%|██████████| 1/1 [00:00<00:00, 142.81it/s]


Epoch: 10, train error: 0.14313681423664093, validation error: 0.2816944122314453


100%|██████████| 1/1 [00:00<00:00, 142.91it/s]

Epoch: 11, train error: 0.13087210059165955, validation error: 0.2927841544151306



100%|██████████| 1/1 [00:00<00:00, 125.01it/s]


Epoch: 12, train error: 0.12849879264831543, validation error: 0.27385711669921875


100%|██████████| 1/1 [00:00<00:00, 111.13it/s]


Epoch: 13, train error: 0.12042711675167084, validation error: 0.2539862096309662


100%|██████████| 1/1 [00:00<00:00, 142.90it/s]


Epoch: 14, train error: 0.12092430889606476, validation error: 0.2599131762981415


100%|██████████| 1/1 [00:00<00:00, 111.12it/s]


Epoch: 15, train error: 0.11375658214092255, validation error: 0.2481643259525299


100%|██████████| 1/1 [00:00<00:00, 142.93it/s]


Epoch: 16, train error: 0.1296127885580063, validation error: 0.23933175206184387


100%|██████████| 1/1 [00:00<00:00, 142.88it/s]


Epoch: 17, train error: 0.11311449855566025, validation error: 0.23078148066997528


100%|██████████| 1/1 [00:00<00:00, 111.10it/s]


Epoch: 18, train error: 0.10317330807447433, validation error: 0.2277361899614334


100%|██████████| 1/1 [00:00<00:00, 111.11it/s]


Epoch: 19, train error: 0.10816773772239685, validation error: 0.21689003705978394


100%|██████████| 1/1 [00:00<00:00, 124.98it/s]


Epoch: 20, train error: 0.10350099205970764, validation error: 0.211081400513649


100%|██████████| 1/1 [00:00<00:00, 143.09it/s]

Epoch: 21, train error: 0.09650438278913498, validation error: 0.20505942404270172



100%|██████████| 1/1 [00:00<00:00, 117.50it/s]


Epoch: 22, train error: 0.09291071444749832, validation error: 0.21583354473114014


100%|██████████| 1/1 [00:00<00:00, 142.84it/s]


Epoch: 23, train error: 0.08625245839357376, validation error: 0.20223966240882874


100%|██████████| 1/1 [00:00<00:00, 111.12it/s]


Epoch: 24, train error: 0.10020225495100021, validation error: 0.2013118714094162


100%|██████████| 1/1 [00:00<00:00, 124.89it/s]


Epoch: 25, train error: 0.08061301708221436, validation error: 0.19695322215557098


100%|██████████| 1/1 [00:00<00:00, 111.10it/s]

Epoch: 26, train error: 0.0792708620429039, validation error: 0.227297842502594



100%|██████████| 1/1 [00:00<00:00, 111.04it/s]


Epoch: 27, train error: 0.08698560297489166, validation error: 0.171286940574646


100%|██████████| 1/1 [00:00<00:00, 133.14it/s]


Epoch: 28, train error: 0.07996606826782227, validation error: 0.17408666014671326


100%|██████████| 1/1 [00:00<00:00, 124.99it/s]


Epoch: 29, train error: 0.0896938368678093, validation error: 0.16978846490383148


100%|██████████| 1/1 [00:00<00:00, 142.87it/s]


Epoch: 30, train error: 0.06288652867078781, validation error: 0.15645447373390198


100%|██████████| 1/1 [00:00<00:00, 142.89it/s]


Epoch: 31, train error: 0.06238916888833046, validation error: 0.1698499470949173


100%|██████████| 1/1 [00:00<00:00, 125.01it/s]


Epoch: 32, train error: 0.060330793261528015, validation error: 0.26935890316963196


100%|██████████| 1/1 [00:00<00:00, 99.89it/s]


Epoch: 33, train error: 0.08331810683012009, validation error: 0.22606895864009857

Early stopping at epoch 33
well done
Using the default test set from the data
The t_data dataset has 18 training instances and 5 test instances.
Length scale is set to 0.0046539306640625


100%|██████████| 1/1 [00:00<00:00, 133.16it/s]


Epoch: 0, train error: 0.16472789645195007, validation error: 0.29760223627090454


100%|██████████| 1/1 [00:00<00:00, 124.99it/s]


Epoch: 1, train error: 0.1573771834373474, validation error: 0.29275253415107727


100%|██████████| 1/1 [00:00<00:00, 111.18it/s]


Epoch: 2, train error: 0.15649299323558807, validation error: 0.2850402593612671


100%|██████████| 1/1 [00:00<00:00, 125.01it/s]


Epoch: 3, train error: 0.15921267867088318, validation error: 0.2848922312259674


100%|██████████| 1/1 [00:00<00:00, 111.12it/s]


Epoch: 4, train error: 0.13851232826709747, validation error: 0.26615867018699646


100%|██████████| 1/1 [00:00<00:00, 99.99it/s]

Epoch: 5, train error: 0.13700705766677856, validation error: 0.26884013414382935



100%|██████████| 1/1 [00:00<00:00, 124.90it/s]


Epoch: 6, train error: 0.13008791208267212, validation error: 0.2549518346786499


100%|██████████| 1/1 [00:00<00:00, 132.85it/s]


Epoch: 7, train error: 0.12483391165733337, validation error: 0.23798447847366333


100%|██████████| 1/1 [00:00<00:00, 125.13it/s]


Epoch: 8, train error: 0.11729519069194794, validation error: 0.23383915424346924


100%|██████████| 1/1 [00:00<00:00, 142.88it/s]


Epoch: 9, train error: 0.11175493896007538, validation error: 0.22637496888637543


100%|██████████| 1/1 [00:00<00:00, 166.71it/s]


Epoch: 10, train error: 0.10159424692392349, validation error: 0.2550400197505951


100%|██████████| 1/1 [00:00<00:00, 124.98it/s]


Epoch: 11, train error: 0.09675363451242447, validation error: 0.18906480073928833


100%|██████████| 1/1 [00:00<00:00, 166.48it/s]


Epoch: 12, train error: 0.0935480073094368, validation error: 0.1852511763572693


100%|██████████| 1/1 [00:00<00:00, 166.64it/s]


Epoch: 13, train error: 0.08311589062213898, validation error: 0.17344191670417786


100%|██████████| 1/1 [00:00<00:00, 111.10it/s]


Epoch: 14, train error: 0.08652277290821075, validation error: 0.1687992364168167


100%|██████████| 1/1 [00:00<00:00, 142.74it/s]


Epoch: 15, train error: 0.1009039580821991, validation error: 0.17435450851917267


100%|██████████| 1/1 [00:00<00:00, 166.69it/s]


Epoch: 16, train error: 0.07928303629159927, validation error: 0.17608237266540527


100%|██████████| 1/1 [00:00<00:00, 200.05it/s]


Epoch: 17, train error: 0.07433299720287323, validation error: 0.2008170634508133


100%|██████████| 1/1 [00:00<00:00, 199.66it/s]


Epoch: 18, train error: 0.0749240294098854, validation error: 0.1648516058921814


100%|██████████| 1/1 [00:00<00:00, 142.61it/s]


Epoch: 19, train error: 0.06854212284088135, validation error: 0.1422528177499771


100%|██████████| 1/1 [00:00<00:00, 199.80it/s]


Epoch: 20, train error: 0.08757422864437103, validation error: 0.17135047912597656


100%|██████████| 1/1 [00:00<00:00, 142.69it/s]


Epoch: 21, train error: 0.0655045285820961, validation error: 0.14550529420375824


100%|██████████| 1/1 [00:00<00:00, 199.79it/s]


Epoch: 22, train error: 0.0746610090136528, validation error: 0.15867826342582703


100%|██████████| 1/1 [00:00<00:00, 199.58it/s]


Epoch: 23, train error: 0.06776577979326248, validation error: 0.18484897911548615


100%|██████████| 1/1 [00:00<00:00, 181.34it/s]


Epoch: 24, train error: 0.06349625438451767, validation error: 0.15010125935077667


100%|██████████| 1/1 [00:00<00:00, 181.34it/s]


Epoch: 25, train error: 0.09173872321844101, validation error: 0.12913183867931366


100%|██████████| 1/1 [00:00<00:00, 153.00it/s]


Epoch: 26, train error: 0.07546310871839523, validation error: 0.12824974954128265


100%|██████████| 1/1 [00:00<00:00, 153.45it/s]


Epoch: 27, train error: 0.059400491416454315, validation error: 0.1489991545677185


100%|██████████| 1/1 [00:00<00:00, 117.53it/s]


Epoch: 28, train error: 0.06726405769586563, validation error: 0.12883539497852325


100%|██████████| 1/1 [00:00<00:00, 104.88it/s]


Epoch: 29, train error: 0.06662720441818237, validation error: 0.12336324155330658


100%|██████████| 1/1 [00:00<00:00, 95.00it/s]


Epoch: 30, train error: 0.06281890720129013, validation error: 0.1282402127981186


100%|██████████| 1/1 [00:00<00:00, 99.70it/s]


Epoch: 31, train error: 0.06481549143791199, validation error: 0.1123085469007492


100%|██████████| 1/1 [00:00<00:00, 117.33it/s]


Epoch: 32, train error: 0.057553481310606, validation error: 0.12183484435081482


100%|██████████| 1/1 [00:00<00:00, 110.94it/s]

In [None]:
loss_func=torch.nn.MSELoss(reduction="mean")

In [None]:
test_error = loss_func(err,err1)
test_error = torch.sqrt(test_error).item()
mae=torch.nn.L1Loss(reduction='mean')
test_error1=mae(err, err1).item()

### 分界点

In [11]:
args.model = 'kcn_sage'
err = experiment.run_kcn(args)
print('Model: {}, test error: {}\n'.format(args.model, err))

The bird_count dataset has 53623 training instances and 53623 test instances.
Length scale is set to 0.0077334818661507305


100%|██████████| 760/760 [00:23<00:00, 32.18it/s] 


Epoch: 0, train error: 0.5084476398304105, validation error: 0.5692890286445618


100%|██████████| 760/760 [00:05<00:00, 142.06it/s]


Epoch: 1, train error: 0.48033421052325714, validation error: 0.529566764831543


100%|██████████| 760/760 [00:05<00:00, 141.71it/s]


Epoch: 2, train error: 0.46672908845847766, validation error: 0.5387781858444214


100%|██████████| 760/760 [00:05<00:00, 137.51it/s]


Epoch: 3, train error: 0.4671565787768678, validation error: 0.5296153426170349


100%|██████████| 760/760 [00:05<00:00, 138.33it/s]


Epoch: 4, train error: 0.46231354863235824, validation error: 0.5157771110534668


100%|██████████| 760/760 [00:05<00:00, 143.46it/s]


Epoch: 5, train error: 0.4619791304986728, validation error: 0.5345360636711121


100%|██████████| 760/760 [00:05<00:00, 141.44it/s]


Epoch: 6, train error: 0.4587763080042542, validation error: 0.5191469788551331


100%|██████████| 760/760 [00:05<00:00, 137.57it/s]


Epoch: 7, train error: 0.45509065244542923, validation error: 0.529217541217804


100%|██████████| 760/760 [00:05<00:00, 144.21it/s]


Epoch: 8, train error: 0.4535137154872676, validation error: 0.5076262354850769


100%|██████████| 760/760 [00:05<00:00, 145.51it/s]


Epoch: 9, train error: 0.45178459898666723, validation error: 0.5183165669441223


100%|██████████| 760/760 [00:05<00:00, 148.50it/s]


Epoch: 10, train error: 0.4512821324733331, validation error: 0.5233545899391174


100%|██████████| 760/760 [00:04<00:00, 164.25it/s]


Epoch: 11, train error: 0.4521951329208126, validation error: 0.528048038482666


100%|██████████| 760/760 [00:04<00:00, 182.35it/s]


Epoch: 12, train error: 0.4503923704872202, validation error: 0.5127835869789124


100%|██████████| 760/760 [00:04<00:00, 179.47it/s]


Epoch: 13, train error: 0.4526019830991955, validation error: 0.5146411061286926


100%|██████████| 760/760 [00:04<00:00, 166.39it/s]


Epoch: 14, train error: 0.4487560413312167, validation error: 0.5174217224121094


100%|██████████| 760/760 [00:05<00:00, 150.01it/s]


Epoch: 15, train error: 0.44878184503728624, validation error: 0.5186971426010132


100%|██████████| 760/760 [00:04<00:00, 185.63it/s]


Epoch: 16, train error: 0.44719123895721213, validation error: 0.5109809637069702


100%|██████████| 760/760 [00:04<00:00, 177.54it/s]


Epoch: 17, train error: 0.446163627171987, validation error: 0.5187528133392334


100%|██████████| 760/760 [00:04<00:00, 158.41it/s]


Epoch: 18, train error: 0.4473458596774818, validation error: 0.515716016292572


100%|██████████| 760/760 [00:05<00:00, 142.90it/s]


Epoch: 19, train error: 0.4451551899068842, validation error: 0.5190639495849609


100%|██████████| 760/760 [00:05<00:00, 147.01it/s]


Epoch: 20, train error: 0.4453455313599031, validation error: 0.5246556997299194


100%|██████████| 760/760 [00:04<00:00, 171.73it/s]


Epoch: 21, train error: 0.44922255450173426, validation error: 0.5105816125869751


100%|██████████| 760/760 [00:04<00:00, 188.71it/s]


Epoch: 22, train error: 0.4457699354615455, validation error: 0.5126059651374817


100%|██████████| 760/760 [00:04<00:00, 187.21it/s]


Epoch: 23, train error: 0.44512024700739666, validation error: 0.5053388476371765


100%|██████████| 760/760 [00:04<00:00, 178.83it/s]


Epoch: 24, train error: 0.4458492526367895, validation error: 0.5276523232460022


100%|██████████| 760/760 [00:04<00:00, 175.95it/s]


Epoch: 25, train error: 0.44771861556525294, validation error: 0.5215075612068176


100%|██████████| 760/760 [00:04<00:00, 185.84it/s]


Epoch: 26, train error: 0.4473644845187664, validation error: 0.5016493201255798


100%|██████████| 760/760 [00:04<00:00, 183.10it/s]


Epoch: 27, train error: 0.44300437793625813, validation error: 0.5219486951828003


100%|██████████| 760/760 [00:04<00:00, 177.13it/s]


Epoch: 28, train error: 0.4450833098325682, validation error: 0.5263857841491699


100%|██████████| 760/760 [00:04<00:00, 173.77it/s]


Epoch: 29, train error: 0.44489568774956034, validation error: 0.5291870832443237

Early stopping at epoch 29
Test error is 0.4571419060230255
Model: kcn_sage, test error: 0.4571419060230255



In [12]:
args.model = 'kcn_gat'
err = experiment.run_kcn(args)
print('Model: {}, test error: {}\n'.format(args.model, err))

The bird_count dataset has 53623 training instances and 53623 test instances.
Length scale is set to 0.0077334818661507305


100%|██████████| 760/760 [00:06<00:00, 122.74it/s]


Epoch: 0, train error: 0.5226912329757684, validation error: 0.5775631070137024


100%|██████████| 760/760 [00:06<00:00, 124.00it/s]


Epoch: 1, train error: 0.503716991858949, validation error: 0.5576544404029846


100%|██████████| 760/760 [00:06<00:00, 120.49it/s]


Epoch: 2, train error: 0.4845468202262725, validation error: 0.5315265655517578


100%|██████████| 760/760 [00:05<00:00, 130.31it/s]


Epoch: 3, train error: 0.4619069013107372, validation error: 0.5013281106948853


100%|██████████| 760/760 [00:05<00:00, 130.94it/s]


Epoch: 4, train error: 0.4568927689234873, validation error: 0.4956985116004944


100%|██████████| 760/760 [00:06<00:00, 125.80it/s]


Epoch: 5, train error: 0.4511670098419448, validation error: 0.4986773431301117


100%|██████████| 760/760 [00:06<00:00, 124.59it/s]


Epoch: 6, train error: 0.45175081673568407, validation error: 0.49998167157173157


100%|██████████| 760/760 [00:05<00:00, 129.72it/s]


Epoch: 7, train error: 0.4521404561929797, validation error: 0.5165917873382568


100%|██████████| 760/760 [00:05<00:00, 127.69it/s]


Epoch: 8, train error: 0.4480192456737553, validation error: 0.509813666343689


100%|██████████| 760/760 [00:05<00:00, 127.00it/s]


Epoch: 9, train error: 0.44893726038403414, validation error: 0.5018283724784851


100%|██████████| 760/760 [00:06<00:00, 126.42it/s]


Epoch: 10, train error: 0.4489624211613677, validation error: 0.514763355255127


100%|██████████| 760/760 [00:05<00:00, 127.49it/s]


Epoch: 11, train error: 0.44864917587194786, validation error: 0.4941472113132477


100%|██████████| 760/760 [00:06<00:00, 125.69it/s]


Epoch: 12, train error: 0.44761378665601737, validation error: 0.5156534314155579


100%|██████████| 760/760 [00:06<00:00, 126.11it/s]


Epoch: 13, train error: 0.4504722493308547, validation error: 0.49642133712768555


100%|██████████| 760/760 [00:06<00:00, 125.25it/s]


Epoch: 14, train error: 0.44772941883604384, validation error: 0.5256237983703613


100%|██████████| 760/760 [00:06<00:00, 122.99it/s]


Epoch: 15, train error: 0.4463864963501692, validation error: 0.5133916139602661


100%|██████████| 760/760 [00:06<00:00, 122.23it/s]


Epoch: 16, train error: 0.44635045291659864, validation error: 0.5094889998435974


100%|██████████| 760/760 [00:06<00:00, 117.45it/s]


Epoch: 17, train error: 0.4448921820038537, validation error: 0.5086689591407776


100%|██████████| 760/760 [00:06<00:00, 112.23it/s]


Epoch: 18, train error: 0.4458668149247962, validation error: 0.5158348679542542


100%|██████████| 760/760 [00:06<00:00, 124.86it/s]


Epoch: 19, train error: 0.4426966260677498, validation error: 0.5017677545547485


100%|██████████| 760/760 [00:06<00:00, 123.08it/s]


Epoch: 20, train error: 0.4414260424771591, validation error: 0.5601100325584412


100%|██████████| 760/760 [00:06<00:00, 120.50it/s]


Epoch: 21, train error: 0.4462428587831949, validation error: 0.5013173222541809

Early stopping at epoch 21
Test error is 0.4536028802394867
Model: kcn_gat, test error: 0.4536028802394867



In [33]:
datafile = os.path.join(args.data_path, args.dataset + ".npz")

In [34]:
data = np.load(datafile)

In [35]:
Y_train = data['Ytrain'].astype(np.float32)

In [31]:
X_train = np.ndarray.astype(data['Xtrain'], np.float32)
Y_train = data['Ytrain'].astype(np.float32)
Y_train = Y_train[:, None]
X_test = np.ndarray.astype(data['Xtest'], np.float32)
Y_test = data['Ytest'].astype(np.float32)
Y_test = Y_test[:, None]

In [32]:
X_train.shape

(53623, 21)

In [None]:
data = np.load(datafile)
X_train = np.ndarray.astype(data['Xtrain'], np.float32)
Y_train = data['Ytrain'].astype(np.float32)
Y_train = Y_train[:, None]
X_test = np.ndarray.astype(data['Xtest'], np.float32)
Y_test = data['Ytest'].astype(np.float32)
Y_test = Y_test[:, None]

In [23]:
trainset.coords

tensor([[ 33.6317, -84.3828],
        [ 43.2835, -72.8161],
        [ 40.7715, -74.3768],
        ...,
        [ 41.5547, -83.8526],
        [ 27.7913, -97.3992],
        [ 44.5384, -77.6191]])

In [25]:
X_train[:, 0:2].shape

(53623, 2)

In [26]:
x=[1,2,3,4,5,6,7,8,9,10]