In [3]:
device = "cuda" if torch.cuda.is_available() else "cpu"

In [46]:
import torch
import torch.nn as nn
import numpy as np
from copy import deepcopy



class RBFlayer(nn.Module):
    def __init__(self, timelag):
        super(RBFlayer, self).__init__()

        self.timelag = timelag

        device = "cuda" if torch.cuda.is_available() else "cpu"
        torch.cuda.manual_seed(0)

        self.init_weight_cause = nn.Parameter(torch.rand(self.timelag, device=device))
        self.init_weight_target = nn.Parameter(torch.rand(self.timelag, device=device))
        self.cause_clt  = nn.Parameter(torch.rand(1,device=device))
        self.cause_std = nn.Parameter(torch.rand(1,device=device))
        self.target_clt = nn.Parameter(torch.rand(1,device=device))
        self.target_std = nn.Parameter(torch.rand(1, device=device))

    def init_clt(self):
        return nn.Parameter(torch.rand(1,device=device))

    def init_std(self):
        return nn.Parameter(torch.rand(1, device=device))

    def rbf(self, x, cluster, std):
        return (x - cluster) * (x - cluster) / (std * std)

    def forward(self, cause, target):
        cause_c, cause_s = self.cause_clt[0], self.cause_std[0]
        target_c, target_s = self.target_clt[0], self.target_std[0]
        
        cause = self.init_weight_cause * torch.tensor(
            [self.rbf(cause[i], cause_c, cause_s) for i in range(len(cause))], device = device)
        target = self.init_weight_target * torch.tensor(
            [self.rbf(target[j], target_c, target_s) for j in range(len(target))], device = device)

        return cause, target


class RBFnet(nn.Module):
    def __init__(self, input_size , output_size, timelag):
        super(RBFnet,self).__init__()
    
        self.input_size = input_size      # number of data
        self.output_size = output_size
        self.timelag = timelag

        self.linear = nn.ModuleList([nn.Linear(self.timelag*2,1) for _ in range(self.input_size)])
        self.relu = nn.ReLU()
        self.networks = nn.ModuleList([RBFlayer(self.timelag) for _ in range(self.input_size)])

    def cause_target(self, cause, target):
        x = torch.cat((cause, target), 0)

        return x

    def GC(self, threshold=True):
        '''
        Extract learned Granger causality.
        Args:
          threshold: return norm of weights, or whether norm is nonzero.
        Returns:
          GC: (p x p) matrix. Entry (i, j) indicates whether variable j is
            Granger causal of variable i.
        '''
        GC = [torch.norm(net.init_weight_cause, dim=0)
              for net in self.networks]
        GC = torch.stack(GC)
        if threshold:
            return (GC > 0).int()
        else:
            return GC


    def forward(self, causes, targets):
        out_list = []
        for i in range(self.input_size):
            cause, target = self.networks[i](causes[i], targets[i])
            cause, target = self.relu(cause), self.relu(target)
            pred = torch.cat((cause, target),0)
            pred = self.linear[i](pred)
            out_list.append(pred)

        return out_list


def restore_parameters(model, best_model):
    '''Move parameter values from best_model to model.'''
    for params, best_params in zip(model.parameters(), best_model.parameters()):
        params.data = best_params


def train_rbf(model, input_causes, input_targets, Y, lr, epochs, lookback=5,device = device):
    # input_causes, input_targets : X
    # Y : Y
    model.to(device)
    loss_fn = nn.MSELoss(reduction='mean')
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    train_loss_list = []

    best_it = None
    best_model = None
    best_loss = np.inf

    for epoch in range(epochs):
        pred = model(input_causes, input_targets)
        loss = sum([loss_fn(pred[i], Y[i]) for i in range(len(Y))])
        print("epoch {} loss {} :".format(epoch, loss / len(Y)))
        loss.backward()
        optimizer.step()
        model.zero_grad()

        mean_loss = loss / len(Y)

        if mean_loss < best_loss:
            best_loss = mean_loss
            best_it = epoch
            best_model = deepcopy(model)
        elif (epoch - best_it) == lookback:
            if verbose:
                print('Stopping early')
            break

    restore_parameters(model, best_model)

    return train_loss_list



def data_split(X, cause, target, timelag, device = device):
    input_cause = []
    input_target = []
    Y = []

    for i in range(len(X) - (timelag + 1)):
        input_cause.append(X[cause].values[i: i + timelag])
        input_target.append(X[target].values[i: i + timelag])
        Y.append([X[target][i + timelag + 1]])

    return torch.tensor(input_cause, device=device).float(), torch.tensor(input_target,device=device).float(), torch.tensor(Y, device=device).float()


In [47]:
import pandas as pd
df = pd.read_csv('C:/Users/chanyoung/Desktop/Neural-GC-master/lorenz_96_10_10_1000.csv')
X2d = df[['a','b']]
torch.manual_seed(1234)
input_cause, input_target, Y = data_split(X2d, 'a', 'b', 10)

In [48]:
model1 = RBFnet(input_cause.size()[0], 1, input_cause.size()[1])
train_rbf(model1, input_cause, input_target, Y, 0.01, 300, device)

epoch 0 loss 1984.463623046875 :
epoch 1 loss 1305.23583984375 :
epoch 2 loss 831.1533203125 :
epoch 3 loss 528.258056640625 :
epoch 4 loss 344.78375244140625 :
epoch 5 loss 236.38897705078125 :
epoch 6 loss 176.8124542236328 :
epoch 7 loss 149.3320770263672 :
epoch 8 loss 141.1910858154297 :
epoch 9 loss 141.9173126220703 :
epoch 10 loss 144.8268585205078 :
epoch 11 loss 146.95358276367188 :
epoch 12 loss 147.05538940429688 :
epoch 13 loss 143.9816436767578 :
epoch 14 loss 137.0172882080078 :
epoch 15 loss 126.33810424804688 :
epoch 16 loss 113.26525115966797 :
epoch 17 loss 99.3742904663086 :
epoch 18 loss 85.97523498535156 :
epoch 19 loss 73.65156555175781 :
epoch 20 loss 62.84405517578125 :
epoch 21 loss 53.8270263671875 :
epoch 22 loss 46.70965576171875 :
epoch 23 loss 41.34923553466797 :
epoch 24 loss 37.36167907714844 :
epoch 25 loss 34.29926681518555 :
epoch 26 loss 31.757936477661133 :
epoch 27 loss 29.37957763671875 :
epoch 28 loss 26.96114158630371 :
epoch 29 loss 24.4794788

epoch 221 loss 9.047792264027521e-05 :
epoch 222 loss 7.948128768475726e-05 :
epoch 223 loss 7.677499525016174e-05 :
epoch 224 loss 8.86929192347452e-05 :
epoch 225 loss 9.160939953289926e-05 :
epoch 226 loss 7.673873915337026e-05 :
epoch 227 loss 8.428399451076984e-05 :
epoch 228 loss 8.960432751337066e-05 :
epoch 229 loss 0.000101084602647461 :
epoch 230 loss 9.046193008543923e-05 :
epoch 231 loss 9.382428834214807e-05 :
epoch 232 loss 9.94754591374658e-05 :
epoch 233 loss 8.243831689469516e-05 :
epoch 234 loss 8.581011206842959e-05 :
epoch 235 loss 8.24955859570764e-05 :
epoch 236 loss 8.266515214927495e-05 :
epoch 237 loss 9.991026308853179e-05 :
epoch 238 loss 9.121901530306786e-05 :
epoch 239 loss 9.145677177002653e-05 :
epoch 240 loss 9.478822175879031e-05 :
epoch 241 loss 0.0001143175977631472 :
epoch 242 loss 9.322172263637185e-05 :
epoch 243 loss 9.216415492119268e-05 :
epoch 244 loss 0.00011507054296089336 :
epoch 245 loss 0.00010681723506422713 :
epoch 246 loss 0.0001031416

[]

In [68]:
torch.mean(model1.networks[1].init_weight_cause)

tensor(0.5548, device='cuda:0', grad_fn=<MeanBackward0>)