In [1]:
import csv
import os

import numpy as np
import pandas as pd
from epiweeks import Week

from data_downloader import GenerateTrainingData
from utils import date_today, gravity_law_commute_dist

os.environ['NUMEXPR_MAX_THREADS'] = '16'
os.environ['NUMEXPR_NUM_THREADS'] = '8'

import pickle
import matplotlib.pyplot as plt
import dgl
import torch
from torch import nn
import torch.nn.functional as F
from model import STAN

import sklearn
from sklearn.metrics import mean_absolute_error
from sklearn.metrics import mean_squared_error

Using backend: pytorch


In [2]:
GenerateTrainingData().download_jhu_data('2020-05-01', '2020-12-01')

Finish download


Unnamed: 0,state,latitude,longitude,fips,date_today,confirmed,deaths,recovered,active,hospitalization,new_cases
0,Alabama,32.3182,-86.9023,1,2020-05-04,8203,8,0.0,7814.0,0.0,226
0,Alabama,32.3182,-86.9023,1,2020-05-05,8520,17,0.0,8122.0,0.0,317
0,Alabama,32.3182,-86.9023,1,2020-05-06,8769,28,0.0,8348.0,0.0,249
0,Alabama,32.3182,-86.9023,1,2020-05-07,9115,26,0.0,8677.0,0.0,346
0,Alabama,32.3182,-86.9023,1,2020-05-08,9437,14,0.0,9002.0,0.0,322
...,...,...,...,...,...,...,...,...,...,...,...
57,Wyoming,42.7560,-107.3025,56,2020-11-27,31773,0,21700.0,9858.0,4.0,1012
57,Wyoming,42.7560,-107.3025,56,2020-11-28,31928,0,22798.0,8915.0,4.0,155
57,Wyoming,42.7560,-107.3025,56,2020-11-29,32489,0,23022.0,9252.0,2.0,561
57,Wyoming,42.7560,-107.3025,56,2020-11-30,33305,0,24478.0,8612.0,0.0,816


In [3]:
#Merge population data with downloaded data
raw_data = pickle.load(open('./data/state_covid_data.pickle','rb'))
pop_data = pd.read_csv('./uszips.csv')
pop_data = pop_data.groupby('state_name').agg({'population':'sum', 'density':'mean', 'lat':'mean', 'lng':'mean'}).reset_index()
raw_data = pd.merge(raw_data, pop_data, how='inner', left_on='state', right_on='state_name')

In [4]:
raw_data.tail(3)

Unnamed: 0,state,latitude,longitude,fips,date_today,confirmed,deaths,recovered,active,hospitalization,new_cases,state_name,population,density,lat,lng
11021,Wyoming,42.756,-107.3025,56,2020-11-29,32489,0,23022.0,9252.0,2.0,561,Wyoming,582091,90.858989,42.932801,-107.378368
11022,Wyoming,42.756,-107.3025,56,2020-11-30,33305,0,24478.0,8612.0,0.0,816,Wyoming,582091,90.858989,42.932801,-107.378368
11023,Wyoming,42.756,-107.3025,56,2020-12-01,33805,24,26003.0,7563.0,2.0,500,Wyoming,582091,90.858989,42.932801,-107.378368


In [5]:
# Generate location similarity
loc_list = list(raw_data['state'].unique())
loc_dist_map = {}

for each_loc in loc_list:
    loc_dist_map[each_loc] = {}
    for each_loc2 in loc_list:
        lat1 = raw_data[raw_data['state']==each_loc]['latitude'].unique()[0]
        lng1 = raw_data[raw_data['state']==each_loc]['longitude'].unique()[0]
        pop1 = raw_data[raw_data['state']==each_loc]['population'].unique()[0]
        
        lat2 = raw_data[raw_data['state']==each_loc2]['latitude'].unique()[0]
        lng2 = raw_data[raw_data['state']==each_loc2]['longitude'].unique()[0]
        pop2 = raw_data[raw_data['state']==each_loc2]['population'].unique()[0]
        
        loc_dist_map[each_loc][each_loc2] = gravity_law_commute_dist(lat1, lng1, pop1, lat2, lng2, pop2, r=0.5)

In [6]:
#Generate Graph
dist_threshold = 18

for each_loc in loc_dist_map:
    loc_dist_map[each_loc] = {k: v for k, v in sorted(loc_dist_map[each_loc].items(), key=lambda item: item[1], reverse=True)}
    
adj_map = {}
for each_loc in loc_dist_map:
    adj_map[each_loc] = []
    for i, each_loc2 in enumerate(loc_dist_map[each_loc]):
        if loc_dist_map[each_loc][each_loc2] > dist_threshold:
            if i <= 3:
                adj_map[each_loc].append(each_loc2)
            else:
                break
        else:
            if i <= 1:
                adj_map[each_loc].append(each_loc2)
            else:
                break

rows = []
cols = []
for each_loc in adj_map:
    for each_loc2 in adj_map[each_loc]:
        rows.append(loc_list.index(each_loc))
        cols.append(loc_list.index(each_loc2))

In [7]:
g = dgl.graph((rows, cols))

In [8]:
#Preprocess features

active_cases = []
confirmed_cases = []
new_cases = []
death_cases = []
static_feat = []

for i, each_loc in enumerate(loc_list):
    active_cases.append(raw_data[raw_data['state'] == each_loc]['active'])
    confirmed_cases.append(raw_data[raw_data['state'] == each_loc]['confirmed'])
    new_cases.append(raw_data[raw_data['state'] == each_loc]['new_cases'])
    death_cases.append(raw_data[raw_data['state'] == each_loc]['deaths'])
    static_feat.append(np.array(raw_data[raw_data['state'] == each_loc][['population','density','lng','lat']]))
    
active_cases = np.array(active_cases)
confirmed_cases = np.array(confirmed_cases)
death_cases = np.array(death_cases)
new_cases = np.array(new_cases)
static_feat = np.array(static_feat)[:, 0, :]
recovered_cases = confirmed_cases - active_cases - death_cases
susceptible_cases = np.expand_dims(static_feat[:, 0], -1) - active_cases - recovered_cases

# Batch_feat: new_cases(dI), dR, dS
#dI = np.array(new_cases)
dI = np.concatenate((np.zeros((active_cases.shape[0],1), dtype=np.float32), np.diff(active_cases)), axis=-1)
dR = np.concatenate((np.zeros((recovered_cases.shape[0],1), dtype=np.float32), np.diff(recovered_cases)), axis=-1)
dS = np.concatenate((np.zeros((susceptible_cases.shape[0],1), dtype=np.float32), np.diff(susceptible_cases)), axis=-1)

In [9]:
static_feat.shape

(52, 4)

In [10]:
#Build normalizer
normalizer = {'S':{}, 'I':{}, 'R':{}, 'dS':{}, 'dI':{}, 'dR':{}}

for i, each_loc in enumerate(loc_list):
    normalizer['S'][each_loc] = (np.mean(susceptible_cases[i]), np.std(susceptible_cases[i]))
    normalizer['I'][each_loc] = (np.mean(active_cases[i]), np.std(active_cases[i]))
    normalizer['R'][each_loc] = (np.mean(recovered_cases[i]), np.std(recovered_cases[i]))
    normalizer['dI'][each_loc] = (np.mean(dI[i]), np.std(dI[i]))
    normalizer['dR'][each_loc] = (np.mean(dR[i]), np.std(dR[i]))
    normalizer['dS'][each_loc] = (np.mean(dS[i]), np.std(dS[i]))

In [11]:
def prepare_data(data, sum_I, sum_R, history_window=5, pred_window=15, slide_step=5):
    # Data shape n_loc, timestep, n_feat
    # Reshape to n_loc, t, history_window*n_feat
    n_loc = data.shape[0]
    timestep = data.shape[1]
    n_feat = data.shape[2]
    
    x = []
    y_I = []
    y_R = []
    last_I = []
    last_R = []
    concat_I = []
    concat_R = []
    for i in range(0, timestep, slide_step):
        if i+history_window+pred_window-1 >= timestep or i+history_window >= timestep:
            break
        x.append(data[:, i:i+history_window, :].reshape((n_loc, history_window*n_feat)))
        
        concat_I.append(data[:, i+history_window-1, 0])
        concat_R.append(data[:, i+history_window-1, 1])
        last_I.append(sum_I[:, i+history_window-1])
        last_R.append(sum_R[:, i+history_window-1])

        y_I.append(data[:, i+history_window:i+history_window+pred_window, 0])
        y_R.append(data[:, i+history_window:i+history_window+pred_window, 1])

    x = np.array(x, dtype=np.float32).transpose((1, 0, 2))
    last_I = np.array(last_I, dtype=np.float32).transpose((1, 0))
    last_R = np.array(last_R, dtype=np.float32).transpose((1, 0))
    concat_I = np.array(concat_I, dtype=np.float32).transpose((1, 0))
    concat_R = np.array(concat_R, dtype=np.float32).transpose((1, 0))
    y_I = np.array(y_I, dtype=np.float32).transpose((1, 0, 2))
    y_R = np.array(y_R, dtype=np.float32).transpose((1, 0, 2))
    return x, last_I, last_R, concat_I, concat_R, y_I, y_R

In [12]:
valid_window = 25
test_window = 25

history_window=6
pred_window=15
slide_step=5

dynamic_feat = np.concatenate((np.expand_dims(dI, axis=-1), np.expand_dims(dR, axis=-1), np.expand_dims(dS, axis=-1)), axis=-1)
    
#Normalize
for i, each_loc in enumerate(loc_list):
    dynamic_feat[i, :, 0] = (dynamic_feat[i, :, 0] - normalizer['dI'][each_loc][0]) / normalizer['dI'][each_loc][1]
    dynamic_feat[i, :, 1] = (dynamic_feat[i, :, 1] - normalizer['dR'][each_loc][0]) / normalizer['dR'][each_loc][1]
    dynamic_feat[i, :, 2] = (dynamic_feat[i, :, 2] - normalizer['dS'][each_loc][0]) / normalizer['dS'][each_loc][1]

dI_mean = []
dI_std = []
dR_mean = []
dR_std = []

for i, each_loc in enumerate(loc_list):
    dI_mean.append(normalizer['dI'][each_loc][0])
    dR_mean.append(normalizer['dR'][each_loc][0])
    dI_std.append(normalizer['dI'][each_loc][1])
    dR_std.append(normalizer['dR'][each_loc][1])

dI_mean = np.array(dI_mean)
dI_std = np.array(dI_std)
dR_mean = np.array(dR_mean)
dR_std = np.array(dR_std)

#Split train-test
train_feat = dynamic_feat[:, :-valid_window-test_window, :]
val_feat = dynamic_feat[:, -valid_window-test_window:-test_window, :]
test_feat = dynamic_feat[:, -test_window:, :]

train_x, train_I, train_R, train_cI, train_cR, train_yI, train_yR = prepare_data(train_feat, active_cases[:, :-valid_window-test_window], recovered_cases[:, :-valid_window-test_window], history_window, pred_window, slide_step)
val_x, val_I, val_R, val_cI, val_cR, val_yI, val_yR = prepare_data(val_feat, active_cases[:, -valid_window-test_window:-test_window], recovered_cases[:, -valid_window-test_window:-test_window], history_window, pred_window, slide_step)
test_x, test_I, test_R, test_cI, test_cR, test_yI, test_yR = prepare_data(test_feat, active_cases[:, -test_window:], recovered_cases[:, -test_window:], history_window, pred_window, slide_step)

In [13]:
dynamic_feat.shape

(52, 212, 3)

In [14]:
#Build STAN model

in_dim = 3*history_window
hidden_dim1 = 32
hidden_dim2 = 32
gru_dim = 32
num_heads = 1
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

g = g.to(device)
model = STAN(g, in_dim, hidden_dim1, hidden_dim2, gru_dim, num_heads, pred_window, device).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)
criterion = nn.MSELoss()

In [15]:
model

STAN(
  (layer1): MultiHeadGATLayer(
    (heads): ModuleList(
      (0): GATLayer(
        (fc): Linear(in_features=18, out_features=32, bias=True)
        (attn_fc): Linear(in_features=64, out_features=1, bias=True)
      )
    )
  )
  (layer2): MultiHeadGATLayer(
    (heads): ModuleList(
      (0): GATLayer(
        (fc): Linear(in_features=32, out_features=32, bias=True)
        (attn_fc): Linear(in_features=64, out_features=1, bias=True)
      )
    )
  )
  (gru): GRUCell(32, 32)
  (nn_res_I): Linear(in_features=34, out_features=15, bias=True)
  (nn_res_R): Linear(in_features=34, out_features=15, bias=True)
  (nn_res_sir): Linear(in_features=34, out_features=2, bias=True)
)

In [16]:
train_x = torch.tensor(train_x).to(device)
train_I = torch.tensor(train_I).to(device)
train_R = torch.tensor(train_R).to(device)
train_cI = torch.tensor(train_cI).to(device)
train_cR = torch.tensor(train_cR).to(device)
train_yI = torch.tensor(train_yI).to(device)
train_yR = torch.tensor(train_yR).to(device)

val_x = torch.tensor(val_x).to(device)
val_I = torch.tensor(val_I).to(device)
val_R = torch.tensor(val_R).to(device)
val_cI = torch.tensor(val_cI).to(device)
val_cR = torch.tensor(val_cR).to(device)
val_yI = torch.tensor(val_yI).to(device)
val_yR = torch.tensor(val_yR).to(device)

test_x = torch.tensor(test_x).to(device)
test_I = torch.tensor(test_I).to(device)
test_R = torch.tensor(test_R).to(device)
test_cI = torch.tensor(test_cI).to(device)
test_cR = torch.tensor(test_cR).to(device)
test_yI = torch.tensor(test_yI).to(device)
test_yR = torch.tensor(test_yR).to(device)

dI_mean = torch.tensor(dI_mean, dtype=torch.float32).to(device).reshape((dI_mean.shape[0], 1, 1))
dI_std = torch.tensor(dI_std, dtype=torch.float32).to(device).reshape((dI_mean.shape[0], 1, 1))
dR_mean = torch.tensor(dR_mean, dtype=torch.float32).to(device).reshape((dI_mean.shape[0], 1, 1))
dR_std = torch.tensor(dR_std, dtype=torch.float32).to(device).reshape((dI_mean.shape[0], 1, 1))

N = torch.tensor(static_feat[:, 0], dtype=torch.float32).to(device).unsqueeze(-1)

In [17]:
def get_real_y(data, history_window=5, pred_window=15, slide_step=5):
    # Data shape n_loc, timestep, n_feat
    # Reshape to n_loc, t, history_window*n_feat
    n_loc = data.shape[0]
    timestep = data.shape[1]
    
    y = []
    for i in range(0, timestep, slide_step):
        if i+history_window+pred_window-1 >= timestep or i+history_window >= timestep:
            break
        y.append(data[:, i+history_window:i+history_window+pred_window])
    y = np.array(y, dtype=np.float32).transpose((1, 0, 2))
    return y

In [18]:
I_true = get_real_y(active_cases[:], history_window, pred_window, slide_step)

In [19]:
history_window

6

In [21]:
previous_day_mse = 0
predicted_mse = 0

for i in range(52):
    in_dim = 3*history_window
    hidden_dim1 = 32
    hidden_dim2 = 32
    gru_dim = 32
    num_heads = 1
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    g = g.to(device)
    model = STAN(g, in_dim, hidden_dim1, hidden_dim2, gru_dim, num_heads, pred_window, device).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)
    criterion = nn.MSELoss()
    
    all_loss = []
    file_name = './save/stan'
    min_loss = 1e10

    loc_name = loc_list[i]
    cur_loc = i
    
    for epoch in range(50):
        
        model.train()
        optimizer.zero_grad()

        active_pred, recovered_pred, phy_active, phy_recover, _ = model(train_x, train_cI[cur_loc], train_cR[cur_loc], N[cur_loc], train_I[cur_loc], train_R[cur_loc])
        phy_active = (phy_active - dI_mean[cur_loc]) / dI_std[cur_loc]
        phy_recover = (phy_recover - dR_mean[cur_loc]) / dR_std[cur_loc]
        loss = criterion(active_pred.squeeze(), train_yI[cur_loc])+criterion(recovered_pred.squeeze(), train_yR[cur_loc])+0.1*criterion(phy_active.squeeze(), train_yI[cur_loc])+0.1*criterion(phy_recover.squeeze(), train_yR[cur_loc])

        loss.backward()
        optimizer.step()
        all_loss.append(loss.item())

        model.eval()
        print(train_x.shape)
        print(train_cI[cur_loc].shape)
        print(train_cR[cur_loc].shape)
        print(N[cur_loc])
        _, _, _, _, prev_h = model(train_x, train_cI[cur_loc], train_cR[cur_loc], N[cur_loc], train_I[cur_loc], train_R[cur_loc])
        val_active_pred, val_recovered_pred, val_phy_active, val_phy_recover, _ = model(val_x, val_cI[cur_loc], val_cR[cur_loc], N[cur_loc], val_I[cur_loc], val_R[cur_loc], prev_h)

        val_phy_active = (val_phy_active - dI_mean[cur_loc]) / dI_std[cur_loc]
        val_loss = criterion(val_active_pred.squeeze(), val_yI[cur_loc]) + 0.1*criterion(val_phy_active.squeeze(), val_yI[cur_loc])
        if val_loss < min_loss:    
            state = {
                'state': model.state_dict(),
                'optimizer': optimizer.state_dict(),
            }
            torch.save(state, file_name)
            min_loss = val_loss
            print('-----Save best model-----')

        print('Epoch %d, Loss %.2f, Val loss %.2f'%(epoch, all_loss[-1], val_loss.item()))
        
    file_name = './save/stan'
    checkpoint = torch.load(file_name)
    model.load_state_dict(checkpoint['state'])
    optimizer.load_state_dict(checkpoint['optimizer'])
    model.eval()


    prev_x = torch.cat((train_x, val_x), dim=1)
    prev_I = torch.cat((train_I, val_I), dim=1)
    prev_R = torch.cat((train_R, val_R), dim=1)
    prev_cI = torch.cat((train_cI, val_cI), dim=1)
    prev_cR = torch.cat((train_cR, val_cR), dim=1)
    prev_active_pred, _, prev_phyactive_pred, _, h = model(prev_x, prev_cI[cur_loc], prev_cR[cur_loc], N[cur_loc], prev_I[cur_loc], prev_R[cur_loc])

    test_pred_active, test_pred_recovered, test_pred_phy_active, test_pred_phy_recover, _ = model(test_x, test_cI[cur_loc], test_cR[cur_loc], N[cur_loc], test_I[cur_loc], test_R[cur_loc], h)

    pred_I = []

    for i in range(test_pred_active.size(1)):
        cur_pred = (test_pred_active[0, i, :].detach().cpu().numpy() * dI_std[cur_loc].reshape(1, 1).detach().cpu().numpy()) + dI_mean[cur_loc].reshape(1, 1).detach().cpu().numpy()
        #cur_pred = test_pred_phy_active[0, i, :].detach().cpu().numpy()
        cur_pred = (cur_pred + test_pred_phy_active[0, i, :].detach().cpu().numpy()) / 2
        cur_pred = np.cumsum(cur_pred)
        cur_pred = cur_pred + test_I[cur_loc, i].detach().cpu().item()
        pred_I.append(cur_pred)
    pred_I = np.array(pred_I)
    pred_I = pred_I
    
    ground_truth = np.array(list(I_true[cur_loc, -1, :]))
    previous_day = np.array(list([I_true[cur_loc, -2, 4]]) + list(I_true[cur_loc, -1, :-1]))
    predicted = np.array(list(pred_I[-1, :]))
    
    plt.plot(list(range(15)), ground_truth,c='r', label='Ground truth')
    plt.plot(list(range(15)), previous_day,c='g', label='Previous day')
    plt.plot(list(range(15)), predicted,c='b', label='STAN')
    plt.title(loc_name)
    plt.legend()
    plt.savefig('images/'+str(title)+'.png')
    plt.cla()
    
    predicted_mse += ((ground_truth - previous_day) **2).mean()
    previous_day_mse += ((ground_truth - predicted) **2).mean()

In [28]:
prev_x

### GATSIR

In [21]:
class GAT(nn.Module):
    def __init__(self, g, input_size, hidden_size, output_size, gcn_nlayers, num_heads=5):
        super(GAT, self).__init__()
        self.g = g
        self.gcn_layers = nn.ModuleList()
        self.gcn_layers.append(dglnn.conv.GATConv(input_size, hidden_size, num_heads=num_heads,
                                                  residual=True, activation=F.relu))
        for i in range(gcn_nlayers):
            self.gcn_layers.append(dglnn.conv.GATConv(num_heads * hidden_size, hidden_size, num_heads=num_heads,
                                                      residual=True, activation=F.relu))

        self.linear_layers = nn.Sequential(
            nn.Linear(num_heads * hidden_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, output_size),
            nn.ReLU()
        )

    def forward(self, x):
        h = x
        for layer in self.gcn_layers:
            h = layer(self.g, h).flatten(1)
        return self.linear_layers(h)

In [22]:
from dgl.nn import pytorch as dglnn
import torch.nn.functional as F
import torch.nn as nn

class GAT_SIR(nn.Module):
    def __init__(self, g, input_size, hidden_size, output_size, gcn_nlayers, num_heads=5):
        super(GAT_SIR, self).__init__()
        self.g = g
        self.gcn_layers = nn.ModuleList()
        self.gcn_layers.append(dglnn.conv.GATConv(input_size, hidden_size, num_heads=num_heads,
                                                  residual=True, activation=F.relu))
        for i in range(gcn_nlayers):
            self.gcn_layers.append(dglnn.conv.GATConv(num_heads * hidden_size, hidden_size, num_heads=num_heads,
                                                      residual=True, activation=F.relu))

        self.linear_layers = nn.Sequential(
            nn.Linear(num_heads * hidden_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, output_size),
            nn.Sigmoid()
        )

    def forward(self, x):
        h = x
        for layer in self.gcn_layers:
            h = layer(self.g, h).flatten(1)
        return self.linear_layers(h)

In [23]:
hidden_size = 100
learning_rate = 0.0005
num_epochs = 500
batch_size = 256
output_size = 2

device = th.device('cuda:0')
net_gat = GAT_SIR(g, train_past_history[0].shape[1] +train_past_mobility[0].shape[1]+ attrs.shape[1], hidden_size, output_size, 1)

# Move model to GPU
net_gat = net_gat.to(device)

criteria = nn.MSELoss()

# Move data to GPU
pop = pop.to(device)
attrs = attrs.to(device)
attrs[th.isnan(attrs)] = 0
for i in range(len(train_past_cases)):
    train_past_cases[i] = train_past_cases[i].to(device)
    train_past_deaths[i] = train_past_deaths[i].to(device)
    train_past_combined[i] = train_past_combined[i].to(device)
    train_past_history[i] = train_past_history[i].to(device)
    train_past_mobility[i] = train_past_mobility[i].to(device)
    train_labels_cases[i] = train_labels_cases[i].to(device)
    train_labels_deaths[i] = train_labels_deaths[i].to(device)
    if len(train_labels_cases[i].shape) == 1:
        train_labels_cases[i] = train_labels_cases[i].unsqueeze(1)
    if len(train_labels_deaths[i].shape) == 1:
        train_labels_deaths[i] = train_labels_deaths[i].unsqueeze(1)
        
for i in range(len(valid_past_cases)):
    valid_past_cases[i] = valid_past_cases[i].to(device)
    valid_past_deaths[i] = valid_past_deaths[i].to(device)
    valid_past_history[i] = valid_past_history[i].to(device)
    valid_past_combined[i] = valid_past_combined[i].to(device)
    valid_past_mobility[i] = valid_past_mobility[i].to(device)
    valid_labels_cases[i] = valid_labels_cases[i].to(device)
    valid_labels_deaths[i] = valid_labels_deaths[i].to(device)
    if len(valid_labels_cases[i].shape) == 1:
         valid_labels_cases[i] = valid_labels_cases[i].unsqueeze(1)
    if len(test_labels_deaths[i].shape) == 1:
        valid_labels_deaths[i] = valid_labels_deaths[i].unsqueeze(1)

for i in range(len(test_past_cases)):
    test_past_cases[i] = test_past_cases[i].to(device)
    test_past_deaths[i] = test_past_deaths[i].to(device)
    test_past_history[i] = test_past_history[i].to(device)
    test_past_combined[i] = test_past_combined[i].to(device)
    test_past_mobility[i] = test_past_mobility[i].to(device)
    test_labels_cases[i] = test_labels_cases[i].to(device)
    test_labels_deaths[i] = test_labels_deaths[i].to(device)
    if len(test_labels_cases[i].shape) == 1:
        test_labels_cases[i] = test_labels_cases[i].unsqueeze(1)
    if len(test_labels_deaths[i].shape) == 1:
        test_labels_deaths[i] = test_labels_deaths[i].unsqueeze(1)        

attr_temp = attrs
attr_mean = th.mean(attr_temp, dim=1, keepdim=True)
attrs_norm = attr_temp - attr_mean
attr_std = th.std(attr_temp, dim=1, keepdim=True)
attrs_norm = attrs_norm / attr_std

gat_optimizer = th.optim.Adam(net_gat.parameters(), lr=learning_rate)

last_loss = 0
for epoch in range(num_epochs):
    sample_idxs = th.randperm(len(train_past_deaths))
    losses = []   
    real_losses = []
    for idx in sample_idxs:
        labels_cases = train_labels_cases[idx]
        labels_deaths = train_labels_deaths[idx]
        past_combined = th.cat([th.log(th.add(train_past_history[idx],1)), train_past_mobility[idx]], dim=1)
        batch = th.cat([past_combined, attrs_norm], dim=1)
        gat_optimizer.zero_grad() 
        vals = net_gat(batch)
        I = train_past_cases[idx][:,-1].view(3142,1)
        D = train_past_deaths[idx][:,-1].view(3142,1)
        I_lst1, D_lst1= sir_1d_output(vals, I, D)    
        I_lst2, D_lst2= sir_1d_output(vals, I_lst1, D_lst1) 
        I_lst3, D_lst3= sir_1d_output(vals, I_lst2, D_lst2)
        I_lst4, D_lst4= sir_1d_output(vals, I_lst3, D_lst3)
        I_lst5, D_lst5= sir_1d_output(vals, I_lst4, D_lst4) 
        I_lst6, D_lst6= sir_1d_output(vals, I_lst5, D_lst5) 
        I_lst7, D_lst7= sir_1d_output(vals, I_lst6, D_lst6) 
        loss1 = my_msle(D_lst1, labels_deaths[:,0].view(3142,1))
        loss2 = my_msle(D_lst2, labels_deaths[:,1].view(3142,1))
        loss3 = my_msle(D_lst3, labels_deaths[:,2].view(3142,1))
        loss4 = my_msle(D_lst4, labels_deaths[:,3].view(3142,1))
        loss5 = my_msle(D_lst5, labels_deaths[:,4].view(3142,1))
        loss6 = my_msle(D_lst6, labels_deaths[:,5].view(3142,1))
        loss7 = my_msle(D_lst7, labels_deaths[:,6].view(3142,1))
        loss = loss1 + loss2 + loss3 + loss4 + loss5 + loss6 + loss7
        loss.backward(retain_graph=True)
        gat_optimizer.step()
        losses.append(loss.detach().cpu().numpy())
        real_losses.append(real_loss.detach().cpu().numpy())

    with th.no_grad():
        eval_mses = []
        eval_msles = []
        for idx in range(len(valid_past_deaths)):
            eval_labels_cases = valid_labels_cases[idx]
            eval_labels_deaths = valid_labels_deaths[idx]
            eval_past_combined = th.cat([th.log(th.add(valid_past_history[idx],1)), valid_past_mobility[idx]], dim=1)
            eval_batch = th.cat([eval_past_combined, attrs_norm], dim=1)
            eval_vals = net_gat(eval_batch)
            eval_I = valid_past_cases[idx][:,-1].view(3142,1)
            eval_D = valid_past_deaths[idx][:,-1].view(3142,1)
            eval_I_lst1, eval_D_lst1= sir_1d_output(eval_vals, eval_I, eval_D)    
            eval_I_lst2, eval_D_lst2= sir_1d_output(eval_vals, eval_I_lst1, eval_D_lst1) 
            eval_I_lst3, eval_D_lst3= sir_1d_output(eval_vals, eval_I_lst2, eval_D_lst2)
            eval_I_lst4, eval_D_lst4= sir_1d_output(eval_vals, eval_I_lst3, eval_D_lst3)
            eval_I_lst5, eval_D_lst5= sir_1d_output(eval_vals, eval_I_lst4, eval_D_lst4) 
            eval_I_lst6, eval_D_lst6= sir_1d_output(eval_vals, eval_I_lst5, eval_D_lst5) 
            eval_I_lst7, eval_D_lst7= sir_1d_output(eval_vals, eval_I_lst6, eval_D_lst6) 
            eval_loss1 = criteria(eval_D_lst1, eval_labels_deaths[:,0].view(3142,1))
            eval_loss2 = criteria(eval_D_lst2, eval_labels_deaths[:,1].view(3142,1))
            eval_loss3 = criteria(eval_D_lst3, eval_labels_deaths[:,2].view(3142,1))
            eval_loss4 = criteria(eval_D_lst4, eval_labels_deaths[:,3].view(3142,1))
            eval_loss5 = criteria(eval_D_lst5, eval_labels_deaths[:,4].view(3142,1))
            eval_loss6 = criteria(eval_D_lst6, eval_labels_deaths[:,5].view(3142,1))
            eval_loss7 = criteria(eval_D_lst7, eval_labels_deaths[:,6].view(3142,1))
            eval_loss = eval_loss1 + eval_loss2 + eval_loss3 + eval_loss4 + eval_loss5 + eval_loss6 + eval_loss7
            eval_mses.append(eval_loss.cpu().numpy())
            
            eval_msle1 = my_msle(eval_D_lst1, eval_labels_deaths[:,0].view(3142,1))
            eval_msle2 = my_msle(eval_D_lst2, eval_labels_deaths[:,1].view(3142,1))
            eval_msle3 = my_msle(eval_D_lst3, eval_labels_deaths[:,2].view(3142,1))
            eval_msle4 = my_msle(eval_D_lst4, eval_labels_deaths[:,3].view(3142,1))
            eval_msle5 = my_msle(eval_D_lst5, eval_labels_deaths[:,4].view(3142,1))
            eval_msle6 = my_msle(eval_D_lst6, eval_labels_deaths[:,5].view(3142,1))
            eval_msle7 = my_msle(eval_D_lst7, eval_labels_deaths[:,6].view(3142,1))
            eval_msle = eval_msle1 + eval_msle2 + eval_msle3 + eval_msle4 + eval_msle5 + eval_msle6 + eval_msle7
            eval_msles.append(eval_msle.cpu().numpy())

        test_err1 = []
        test_err2 = []
        test_err3 = []
        test_err4 = []
        test_err5 = []
        test_err6 = []
        test_err7 = []
        test_msles = []
        test_mses = []
        test_maes = []
        for idx in range(len(test_past_deaths)):
            test_labels_d = test_labels_deaths[idx]
            test_past_combined = th.cat([th.log(th.add(test_past_history[idx],1)), test_past_mobility[idx]], dim=1)
            test_batch = th.cat([test_past_combined, attrs_norm], dim=1)
            test_vals = net_gat(test_batch)
            test_I = test_past_cases[idx][:,-1].view(3142,1)
            test_D = test_past_deaths[idx][:,-1].view(3142,1)
            test_I_lst1, test_D_lst1= sir_1d_output(test_vals, test_I, test_D)    
            test_I_lst2, test_D_lst2= sir_1d_output(test_vals, test_I_lst1, test_D_lst1) 
            test_I_lst3, test_D_lst3= sir_1d_output(test_vals, test_I_lst2, test_D_lst2)
            test_I_lst4, test_D_lst4= sir_1d_output(test_vals, test_I_lst3, test_D_lst3)
            test_I_lst5, test_D_lst5= sir_1d_output(test_vals, test_I_lst4, test_D_lst4) 
            test_I_lst6, test_D_lst6= sir_1d_output(test_vals, test_I_lst5, test_D_lst5) 
            test_I_lst7, test_D_lst7= sir_1d_output(test_vals, test_I_lst6, test_D_lst6) 
            test_loss1 = criteria(test_D_lst1, test_labels_d[:,0].view(3142,1))
            test_loss2 = criteria(test_D_lst2, test_labels_d[:,1].view(3142,1))
            test_loss3 = criteria(test_D_lst3, test_labels_d[:,2].view(3142,1))
            test_loss4 = criteria(test_D_lst4, test_labels_d[:,3].view(3142,1))
            test_loss5 = criteria(test_D_lst5, test_labels_d[:,4].view(3142,1))
            test_loss6 = criteria(test_D_lst6, test_labels_d[:,5].view(3142,1))
            test_loss7 = criteria(test_D_lst7, test_labels_d[:,6].view(3142,1))
            test_loss = test_loss1 + test_loss2 + test_loss3 + test_loss4 + test_loss5 + test_loss6 + test_loss7
            test_mses.append(test_loss.cpu().numpy())
            test_err1.append(test_loss1.cpu().numpy())
            test_err2.append(test_loss2.cpu().numpy())
            test_err3.append(test_loss3.cpu().numpy())
            test_err4.append(test_loss4.cpu().numpy())
            test_err5.append(test_loss5.cpu().numpy())
            test_err6.append(test_loss6.cpu().numpy())
            test_err7.append(test_loss7.cpu().numpy())
            
            test_msle1 = my_msle(test_D_lst1, test_labels_d[:,0].view(3142,1))
            test_msle2 = my_msle(test_D_lst2, test_labels_d[:,1].view(3142,1))
            test_msle3 = my_msle(test_D_lst3, test_labels_d[:,2].view(3142,1))
            test_msle4 = my_msle(test_D_lst4, test_labels_d[:,3].view(3142,1))
            test_msle5 = my_msle(test_D_lst5, test_labels_d[:,4].view(3142,1))
            test_msle6 = my_msle(test_D_lst6, test_labels_d[:,5].view(3142,1))
            test_msle7 = my_msle(test_D_lst7, test_labels_d[:,6].view(3142,1))
            test_msle = test_msle1 + test_msle2 + test_msle3 + test_msle4 + test_msle5 + test_msle6 + test_msle7
            test_msles.append(test_msle.cpu().numpy())
            
            test_mae1 = my_mae(test_D_lst1, test_labels_d[:,0].view(3142,1))
            test_mae2 = my_mae(test_D_lst2, test_labels_d[:,1].view(3142,1))
            test_mae3 = my_mae(test_D_lst3, test_labels_d[:,2].view(3142,1))
            test_mae4 = my_mae(test_D_lst4, test_labels_d[:,3].view(3142,1))
            test_mae5 = my_mae(test_D_lst5, test_labels_d[:,4].view(3142,1))
            test_mae6 = my_mae(test_D_lst6, test_labels_d[:,5].view(3142,1))
            test_mae7 = my_mae(test_D_lst7, test_labels_d[:,6].view(3142,1))
            test_mae = test_mae1 + test_mae2 + test_mae3 + test_mae4 + test_mae5 + test_mae6 + test_mae7
            test_maes.append(test_mae.cpu().numpy())
            
        print('epoch={}, loss={:.5f}, validation msle = {:.5f}, validation mse = {:.5f}, test msle={:.5f}, test mse = {:.3f}, test mae = {:.3f}'.format(epoch, np.mean(losses), np.mean(eval_msles),np.mean(eval_mses),np.mean(test_msles), np.mean(test_mses), np.mean(test_maes)))
        print('day1={:.2f},day2={:.2f},day3={:.2f},day4={:.2f},day5={:.2f},day6={:.2f},day7={:.2f}'.format(np.mean(test_err1), np.mean(test_err2),np.mean(test_err3),np.mean(test_err4),np.mean(test_err5),np.mean(test_err6),np.mean(test_err7)))

NameError: name 'th' is not defined