In [2]:
device = "cuda" if torch.cuda.is_available() else "cpu"

In [3]:

import torch
import torch.nn as nn
import numpy as np
from copy import deepcopy

device = "cuda" if torch.cuda.is_available() else "cpu"

class RBFlayer(nn.Module):
    def __init__(self, timelag):
        super(RBFlayer, self).__init__()

        self.timelag = timelag

        device = "cuda" if torch.cuda.is_available() else "cpu"
        torch.cuda.manual_seed(0)

        self.init_weight_cause = nn.Parameter(torch.rand(self.timelag, device=device))
        self.init_weight_target = nn.Parameter(torch.rand(self.timelag, device=device))
        self.cause_clt  = nn.Parameter(torch.rand(1,device=device))
        self.cause_std = nn.Parameter(torch.rand(1,device=device))
        self.target_clt = nn.Parameter(torch.rand(1,device=device))
        self.target_std = nn.Parameter(torch.rand(1, device=device))

    def init_clt(self):
        return nn.Parameter(torch.rand(1,device=device))

    def init_std(self):
        return nn.Parameter(torch.rand(1, device=device))

    def rbf(self, x, cluster, std):
        return (x - cluster) * (x - cluster) / (std * std)

    def forward(self, cause, target):
        cause_c, cause_s = self.cause_clt[0], self.cause_std[0]
        target_c, target_s = self.target_clt[0], self.target_std[0]
        
        cause = self.init_weight_cause * torch.tensor(
            [self.rbf(cause[i], cause_c, cause_s) for i in range(len(cause))], device = device)
        target = self.init_weight_target * torch.tensor(
            [self.rbf(target[j], target_c, target_s) for j in range(len(target))], device = device)

        return cause, target


class RBFnet(nn.Module):
    def __init__(self, input_size , output_size, timelag):
        super(RBFnet,self).__init__()
    
        self.input_size = input_size      # number of data
        self.output_size = output_size
        self.timelag = timelag

        self.linear = nn.ModuleList([nn.Linear(self.timelag*2,1) for _ in range(self.input_size)])
        self.relu = nn.ReLU()
        self.networks = nn.ModuleList([RBFlayer(self.timelag) for _ in range(self.input_size)])

    def cause_target(self, cause, target):
        x = torch.cat((cause, target), 0)

        return x

    def GC(self, threshold=True):
        '''
        Extract learned Granger causality.
        Args:
          threshold: return norm of weights, or whether norm is nonzero.
        Returns:
          GC: (p x p) matrix. Entry (i, j) indicates whether variable j is
            Granger causal of variable i.
        '''
        GC = [torch.norm(net.init_weight_cause, dim=0)
              for net in self.networks]
        GC = torch.stack(GC)
        if threshold:
            return (GC > 0).int()
        else:
            return GC


    def forward(self, causes, targets):
        out_list = []
        for i in range(self.input_size):
            cause, target = self.networks[i](causes[i], targets[i])
            cause, target = self.relu(cause), self.relu(target)
            pred = torch.cat((cause, target),0)
            pred = self.linear[i](pred)
            out_list.append(pred)

        return out_list


def restore_parameters(model, best_model):
    '''Move parameter values from best_model to model.'''
    for params, best_params in zip(model.parameters(), best_model.parameters()):
        params.data = best_params


def train_rbf(model, input_causes, input_targets, Y, lr, epochs, lookback=5,device = device):
    # input_causes, input_targets : X
    # Y : Y
    model.to(device)
    loss_fn = nn.MSELoss(reduction='mean')
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    train_loss_list = []

    best_it = None
    best_model = None
    best_loss = np.inf

    for epoch in range(epochs):
        pred = model(input_causes, input_targets)
        loss = sum([loss_fn(pred[i], Y[i]) for i in range(len(Y))])
        print("epoch {} loss {} :".format(epoch, loss / len(Y)))
        loss.backward()
        optimizer.step()
        model.zero_grad()

        mean_loss = loss / len(Y)

        if mean_loss < best_loss:
            best_loss = mean_loss
            best_it = epoch
            best_model = deepcopy(model)
        elif (epoch - best_it) == lookback:
            if verbose:
                print('Stopping early')
            break

    restore_parameters(model, best_model)

    return train_loss_list



def data_split(X, cause, target, timelag, device = device):
    input_cause = []
    input_target = []
    Y = []

    for i in range(len(X) - (timelag + 1)):
        input_cause.append(X[cause].values[i: i + timelag])
        input_target.append(X[target].values[i: i + timelag])
        Y.append([X[target][i + timelag + 1]])

    return torch.tensor(input_cause, device=device).float(), torch.tensor(input_target,device=device).float(), torch.tensor(Y, device=device).float()


In [5]:
import pandas as pd
df = pd.read_csv('C:/Users/chanyoung/Desktop/Neural-GC-master/lorenz_96_10_10_1000.csv')
X2d = df[['e','f']]
torch.manual_seed(1234)
input_cause, input_target, Y = data_split(X2d, 'e', 'f', 10)

In [6]:
model1 = RBFnet(input_cause.size()[0], 1, input_cause.size()[1])
train_rbf(model1, input_cause, input_target, Y, 0.01, 300, device)

epoch 0 loss 2303.729248046875 :
epoch 1 loss 1529.8668212890625 :
epoch 2 loss 990.0054931640625 :
epoch 3 loss 641.9430541992188 :
epoch 4 loss 426.7265319824219 :
epoch 5 loss 296.246337890625 :
epoch 6 loss 222.8704071044922 :
epoch 7 loss 187.57601928710938 :
epoch 8 loss 174.10560607910156 :
epoch 9 loss 169.48468017578125 :
epoch 10 loss 166.3971710205078 :
epoch 11 loss 162.55313110351562 :
epoch 12 loss 157.82945251464844 :
epoch 13 loss 151.90773010253906 :
epoch 14 loss 144.04678344726562 :
epoch 15 loss 134.2952880859375 :
epoch 16 loss 123.69828796386719 :
epoch 17 loss 113.07479858398438 :
epoch 18 loss 102.78964233398438 :
epoch 19 loss 92.58745574951172 :
epoch 20 loss 82.51065826416016 :
epoch 21 loss 72.85395050048828 :
epoch 22 loss 63.89337921142578 :
epoch 23 loss 55.802040100097656 :
epoch 24 loss 48.69577407836914 :
epoch 25 loss 42.499778747558594 :
epoch 26 loss 37.18221664428711 :
epoch 27 loss 32.57992172241211 :
epoch 28 loss 28.694046020507812 :
epoch 29 lo

epoch 221 loss 0.00013160282105673105 :
epoch 222 loss 0.000115273091068957 :
epoch 223 loss 0.000131659529870376 :
epoch 224 loss 0.00013304151070769876 :
epoch 225 loss 0.00012003486335743219 :
epoch 226 loss 0.00014918956730980426 :
epoch 227 loss 0.00012520876771304756 :
epoch 228 loss 0.00012996034638490528 :
epoch 229 loss 0.00014677568105980754 :
epoch 230 loss 0.0001454580924473703 :
epoch 231 loss 0.00010869632387766615 :
epoch 232 loss 0.00012246060941834003 :
epoch 233 loss 0.0001510038273409009 :
epoch 234 loss 0.0001408054813509807 :
epoch 235 loss 0.00013785867486149073 :
epoch 236 loss 0.00012910035729873925 :
epoch 237 loss 0.00012549044913612306 :
epoch 238 loss 0.00013645130093209445 :
epoch 239 loss 0.00011782851652242243 :
epoch 240 loss 0.0001397442538291216 :
epoch 241 loss 0.0001525851257611066 :
epoch 242 loss 0.00011073617497459054 :
epoch 243 loss 0.00014074519276618958 :
epoch 244 loss 0.00013142618990968913 :
epoch 245 loss 0.00012675017933361232 :
epoch 246

[]

In [10]:
[torch.norm(net.init_weight_cause, dim=0)
              for net in model1.networks]

[tensor(2.0684, device='cuda:0', grad_fn=<CopyBackwards>),
 tensor(1.9792, device='cuda:0', grad_fn=<CopyBackwards>),
 tensor(2.0982, device='cuda:0', grad_fn=<CopyBackwards>),
 tensor(1.9732, device='cuda:0', grad_fn=<CopyBackwards>),
 tensor(2.0933, device='cuda:0', grad_fn=<CopyBackwards>),
 tensor(2.0757, device='cuda:0', grad_fn=<CopyBackwards>),
 tensor(2.0494, device='cuda:0', grad_fn=<CopyBackwards>),
 tensor(1.9651, device='cuda:0', grad_fn=<CopyBackwards>),
 tensor(2.1263, device='cuda:0', grad_fn=<CopyBackwards>),
 tensor(2.0744, device='cuda:0', grad_fn=<CopyBackwards>),
 tensor(2.0202, device='cuda:0', grad_fn=<CopyBackwards>),
 tensor(2.1177, device='cuda:0', grad_fn=<CopyBackwards>),
 tensor(2.0877, device='cuda:0', grad_fn=<CopyBackwards>),
 tensor(2.0278, device='cuda:0', grad_fn=<CopyBackwards>),
 tensor(2.0869, device='cuda:0', grad_fn=<CopyBackwards>),
 tensor(1.9017, device='cuda:0', grad_fn=<CopyBackwards>),
 tensor(2.1139, device='cuda:0', grad_fn=<CopyBackwards>

In [16]:
torch.mean(model1.networks[1].init_weight_cause)

tensor(0.5541, device='cuda:0', grad_fn=<MeanBackward0>)