In [1]:
import torch
import torch.nn as nn
import numpy as np
from copy import deepcopy

device = "cuda" if torch.cuda.is_available() else "cpu"

In [2]:
class RBFlayer(nn.Module):
    def __init__(self, timelag):
        super(RBFlayer, self).__init__()

        self.timelag = timelag

        # device = "cuda" if torch.cuda.is_available() else "cpu"
        torch.cuda.manual_seed(0)

        self.init_weight = nn.Parameter(torch.rand(self.timelag))
        self.rbf_clt = self.init_clt()
        self.rbf_std = self.init_std()
        
        self.b = nn.Parameter(torch.rand(1))

    def init_clt(self):
        return nn.Parameter(torch.rand(1))

    def init_std(self):
        return nn.Parameter(torch.rand(1))
    
    def rbf(self, x, cluster, std):
        return torch.exp(-(x - cluster) * (x - cluster) / 2 * (std * std))
    
    def rbf_gradient(self, x, clt, std):
        return (-1 * (x - clt) * (x - clt) / (std * std)) * (torch.exp(-(x - clt) * (x - clt) / 2 * (std * std)))
    
    
    def forward(self, x):        
        for i in range(len(x)):
            if i == 0:
                a = self.rbf(x[i], self.rbf_clt, self.rbf_std)
            else:
                a = torch.cat([a, self.rbf(x[i], self.rbf_clt, self.rbf_std)], dim=0)
        cause = self.init_weight * a
        
        
        return cause

In [3]:
def restore_parameters(model, best_model):
    '''Move parameter values from best_model to model.'''
    for params, best_params in zip(model.parameters(), best_model.parameters()):
        params.data = best_params

In [4]:
def train_RBFlayer(model, input_, lr, epochs, lookback = 5, device = device):
    model.to(device)
    loss_fn = nn.MSELoss(reduction='mean')
    optimizer = torch.optim.Adam(model.parameters(), lr = lr)
    
    train_loss_list = []
    
    best_it = None
    best_model = None
    best_loss = np.inf
    target = []
    for j in range(len(input_) - 2):
        target.append((input_[j+2] - input_[j])/2)
    
    loss_list = []
    cause_list = []
    for epoch in range(epochs):
        cause = model(input_)
        cause_list.append(cause)
        grad = []
        
        
        for i in range(len(cause) - 2):
            grad.append((cause[i+2] - cause[i])/2)
        
        loss1 = sum([loss_fn(grad[i], target[i]) for i in range(len(grad))])
        loss2 = sum([loss_fn(cause[i], input_[i]) for i in range(len(input_))])
        
        loss = loss1 + loss2
        
        loss.backward()
        optimizer.step()
        model.zero_grad()
        
        loss_list.append(loss)
        mean_loss = loss / len(grad)
        train_loss_list.append(mean_loss)
        
        if mean_loss < best_loss:
            best_loss = mean_loss
            best_it = epoch
            best_model = deepcopy(model)
            
        elif (epoch - best_it) == lookback:
            if verbose:
                print('Stopping early')
            break
    print("epoch {} cause loss {} :".format(epoch, loss / len(input_)))
    print('gradient loss :', loss1/len(grad))
    print('value loss :', loss2/len(input_))
                
    best_cause = cause_list[best_it]    
    restore_parameters(model, best_model)

    return best_model, loss_list, best_cause

In [5]:
def data_split(X, cause, target, timelag, device = device):
    input_cause = []
    input_target = []
    Y = []

    for i in range(len(X) - (timelag + 1)):
        input_cause.append(X[cause].values[i: i + timelag])
        input_target.append(X[target].values[i: i + timelag])
        Y.append([X[target][i + timelag + 1]])

    return torch.tensor(input_cause, device=device).float(), torch.tensor(input_target,device=device).float(), torch.tensor(Y, device=device).float()


In [6]:
import pandas as pd
df = pd.read_csv('C:/Users/chanyoung/Desktop/Neural-GC-master/lorenz_96_10_10_1000.csv')
X2d = df[['a','b']]
torch.manual_seed(1234)
input_cause, input_target, Y = data_split(X2d, 'a', 'b', 100)

In [8]:
input_cause1 = input_cause[:20]

In [7]:
input_cause2 = input_cause[100:130]

In [9]:
input_cause.size()

torch.Size([20, 100])

In [14]:
import time

cause_list = []
for i in range(len(input_cause)):
    print(i,"번째 time series")
    start = time.time()
    model = RBFlayer(100)
    best_model, loss_list, best_cause = train_RBFlayer(model, input_cause[i], 0.01, 1000, device)
    cause_list.append(best_cause.cpu().detach().numpy())
    print("time :", time.time() - start)
    print('-------------------------------------------------------------------------------------------')
    
import pickle

filePath = './value_grad1_epcoh1000.txt'
with open(filePath, 'wb') as lf:
    pickle.dump(cause_list, lf)

0 번째 time series
epoch 999 cause loss 0.41716113686561584 :
gradient loss : tensor(0.0458, device='cuda:0', grad_fn=<DivBackward0>)
value loss : tensor(0.3722, device='cuda:0', grad_fn=<DivBackward0>)
time : 122.76029992103577
-------------------------------------------------------------------------------------------
1 번째 time series
epoch 999 cause loss 0.5941880941390991 :
gradient loss : tensor(0.0655, device='cuda:0', grad_fn=<DivBackward0>)
value loss : tensor(0.5300, device='cuda:0', grad_fn=<DivBackward0>)
time : 122.80376935005188
-------------------------------------------------------------------------------------------
2 번째 time series
epoch 999 cause loss 0.429005891084671 :
gradient loss : tensor(0.0446, device='cuda:0', grad_fn=<DivBackward0>)
value loss : tensor(0.3853, device='cuda:0', grad_fn=<DivBackward0>)
time : 122.98517227172852
-------------------------------------------------------------------------------------------
3 번째 time series
epoch 999 cause loss 0.572124

In [8]:
import time

cause_list = []
for i in range(len(input_cause2)):
    print(i,"번째 time series")
    start = time.time()
    model = RBFlayer(100)
    best_model, loss_list, best_cause = train_RBFlayer(model, input_cause2[i], 0.01, 1000, device)
    cause_list.append(best_cause.cpu().detach().numpy())
    print("time :", time.time() - start)
    print('-------------------------------------------------------------------------------------------')
    
import pickle

filePath = './value_grad2_epcoh1000.txt'
with open(filePath, 'wb') as lf:
    pickle.dump(cause_list, lf)

0 번째 time series
epoch 999 cause loss 0.3864403963088989 :
gradient loss : tensor(0.0732, device='cuda:0', grad_fn=<DivBackward0>)
value loss : tensor(0.3147, device='cuda:0', grad_fn=<DivBackward0>)
time : 123.18133425712585
-------------------------------------------------------------------------------------------
1 번째 time series
epoch 999 cause loss 0.3593389093875885 :
gradient loss : tensor(0.0705, device='cuda:0', grad_fn=<DivBackward0>)
value loss : tensor(0.2903, device='cuda:0', grad_fn=<DivBackward0>)
time : 123.03879451751709
-------------------------------------------------------------------------------------------
2 번째 time series
epoch 999 cause loss 0.4457787871360779 :
gradient loss : tensor(0.0843, device='cuda:0', grad_fn=<DivBackward0>)
value loss : tensor(0.3631, device='cuda:0', grad_fn=<DivBackward0>)
time : 124.99190902709961
-------------------------------------------------------------------------------------------
3 번째 time series
epoch 999 cause loss 0.443057

epoch 999 cause loss 0.6165334582328796 :
gradient loss : tensor(0.1173, device='cuda:0', grad_fn=<DivBackward0>)
value loss : tensor(0.5016, device='cuda:0', grad_fn=<DivBackward0>)
time : 119.29041838645935
-------------------------------------------------------------------------------------------
27 번째 time series
epoch 999 cause loss 0.6860111355781555 :
gradient loss : tensor(0.1352, device='cuda:0', grad_fn=<DivBackward0>)
value loss : tensor(0.5535, device='cuda:0', grad_fn=<DivBackward0>)
time : 123.04266238212585
-------------------------------------------------------------------------------------------
28 번째 time series
epoch 999 cause loss 0.35542264580726624 :
gradient loss : tensor(0.0687, device='cuda:0', grad_fn=<DivBackward0>)
value loss : tensor(0.2881, device='cuda:0', grad_fn=<DivBackward0>)
time : 123.30172634124756
-------------------------------------------------------------------------------------------
29 번째 time series
epoch 999 cause loss 0.5332555174827576 :
