In [381]:
import torch
import torch.nn as nn
import numpy as np
from copy import deepcopy


class RBFlayer(nn.Module):
    def __init__(self, timelag):
        super(RBFlayer, self).__init__()
        self.timelag = timelag

    def init_weight_cause(self):
        return nn.Parameter(torch.rand(self.timelag))

    def init_weight_target(self):
        return nn.Parameter(torch.rand(self.timelag))

    def cluster(self, x):
        # simple cluster means and stds list about time series data parameters
        self.clusters, self.stds = torch.mean(x), torch.std(x)

        return nn.Parameter(torch.tensor((self.clusters, self.stds)))

    def rbf(self, x, cluster, std):
        return (x - cluster)*(x-cluster)/(std*std)


    def forward(self, cause, target):
        cause_cluster = self.cluster(cause)
        target_cluster = self.cluster(target)
        
        cause = self.init_weight_cause() * torch.tensor([self.rbf(cause[i], cause_cluster[0], cause_cluster[1]) for i in range(len(cause))])
        target = self.init_weight_target() * torch.tensor([self.rbf(target[j], target_cluster[0], target_cluster[1]) for j in range(len(target))])
        
        return cause, target


class RBFnet(nn.Module):
    def __init__(self, input_size , output_size, timelag):
        super(RBFnet,self).__init__()

        self.input_size = input_size      # number of data
        self.output_size = output_size
        self.timelag = timelag

        self.linear = nn.ModuleList([nn.Linear(self.timelag*2,1) for _ in range(self.input_size)])
        self.relu = nn.ReLU()
        self.networks = nn.ModuleList([RBFlayer(self.timelag) for _ in range(self.input_size)])

    def cause_target(self, cause, target):
        x = torch.cat((cause, target), 0)

        return x

    # forward 쪽 다시 건들기
    def forward(self, causes, targets):
        out_list = []
        for i in range(self.input_size):
            cause, target = self.networks[i](causes[i], targets[i])
            cause, target = self.relu(cause), self.relu(target)
            pred = torch.cat((cause, target),0)
            pred = self.linear[i](pred)
            out_list.append(pred)

        return out_list

In [382]:
def data_split(X, cause, target, timelag):
    input_cause = []
    input_target = []
    Y = []
    for i in range(len(X) - (timelag + 1)):
        input_cause.append(X[cause].values[i: i + timelag])
        input_target.append(X[target].values[i: i + timelag])
        Y.append([X[target][i + timelag + 1]])

    return torch.tensor(input_cause).float(), torch.tensor(input_target).float(), torch.tensor(Y).float()

In [383]:
def restore_parameters(model, best_model):
    '''Move parameter values from best_model to model.'''
    for params, best_params in zip(model.parameters(), best_model.parameters()):
        params.data = best_params



def train_rbf(model,input_causes, input_targets ,Y, lr, epochs, lookback = 5):

    # input_causes, input_targets : X
    # Y : Y

    loss_fn = nn.MSELoss(reduction='mean')
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    train_loss_list = []

    best_it = None
    best_model = None
    best_loss = np.inf

    for epoch in range(epochs):
        pred = model(input_causes, input_targets)
        loss = sum([loss_fn(pred[i], Y[i]) for i in range(len(Y))])
        print("epoch {} loss {} :".format(epoch, loss/len(Y)))
        loss.backward()
        optimizer.step()
        model.zero_grad()
        
        mean_loss = loss/len(Y)

        if mean_loss < best_loss:
            best_loss = mean_loss
            best_it = epoch
            best_model = deepcopy(model)
        elif (epoch - best_it) == lookback:
            if verbose:
                print('Stopping early')
            break

    restore_parameters(model, best_model)

    return train_loss_list

In [384]:
import pandas as pd

In [385]:
df = pd.read_csv('C:/Users/chanyoung/Desktop/Neural-GC-master/lorenz_96_10_10_1000.csv')
X2d = df[['a','b']]

In [386]:
input_cause, input_target, Y = data_split(X2d, 'a', 'b', 10)

In [387]:
model = RBFnet(input_cause.size()[0], 1, input_cause.size()[1])

In [388]:
train_rbf(model, input_cause, input_target, Y, 0.001, 1000)

epoch 0 loss 27.640209197998047 :
epoch 1 loss 27.50505828857422 :
epoch 2 loss 27.378297805786133 :
epoch 3 loss 27.467397689819336 :
epoch 4 loss 27.31884002685547 :
epoch 5 loss 27.232345581054688 :
epoch 6 loss 27.203283309936523 :
epoch 7 loss 27.1650447845459 :
epoch 8 loss 27.01943588256836 :
epoch 9 loss 26.95881462097168 :
epoch 10 loss 26.78447151184082 :
epoch 11 loss 26.738237380981445 :
epoch 12 loss 26.78936004638672 :
epoch 13 loss 26.627731323242188 :
epoch 14 loss 26.484527587890625 :
epoch 15 loss 26.598133087158203 :
epoch 16 loss 26.355073928833008 :
epoch 17 loss 26.329360961914062 :
epoch 18 loss 26.236576080322266 :
epoch 19 loss 26.133935928344727 :
epoch 20 loss 26.136205673217773 :
epoch 21 loss 26.030014038085938 :
epoch 22 loss 25.983556747436523 :
epoch 23 loss 25.965269088745117 :
epoch 24 loss 25.784732818603516 :
epoch 25 loss 25.841960906982422 :
epoch 26 loss 25.704875946044922 :
epoch 27 loss 25.602378845214844 :
epoch 28 loss 25.539955139160156 :
epo

KeyboardInterrupt: 