In [1]:
import torch
import torch.nn as nn
import numpy as np
from copy import deepcopy

# Default compute device for the notebook: prefer CUDA when a GPU is visible.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"

class RBFlayer(nn.Module):
    """Single-centre Gaussian RBF layer applied to a cause series and a target series.

    Every element of ``cause`` goes through one shared Gaussian RBF (centre
    ``cause_clt``, width ``cause_std``). The result is scaled elementwise by the
    per-lag weights ``init_weight_cause``. ``target`` is handled the same way with
    its own parameters. The prediction is the sum of both weighted vectors plus
    the bias ``b``.

    Args:
        timelag: length of each per-lag weight vector. ``forward`` multiplies the
            concatenated RBF outputs by these weights elementwise, so those outputs
            must broadcast against a vector of length ``timelag``.
    """

    def __init__(self, timelag):
        super().__init__()

        self.timelag = timelag
        # Resolved locally so the class does not depend on a notebook global.
        self._device = "cuda" if torch.cuda.is_available() else "cpu"

        # torch.manual_seed seeds the CPU *and* all CUDA generators, so the initial
        # parameters are reproducible on CPU-only machines too (the previous
        # torch.cuda.manual_seed only covered CUDA).
        torch.manual_seed(0)

        # Creation order is kept stable: parameters() / state_dict() order and the
        # random draws both depend on it.
        self.init_weight_cause = nn.Parameter(torch.rand(self.timelag, device=self._device))
        self.init_weight_target = nn.Parameter(torch.rand(self.timelag, device=self._device))
        self.cause_clt = self.init_clt()
        self.cause_std = self.init_std()
        self.target_clt = self.init_clt()
        self.target_std = self.init_std()

        self.b = nn.Parameter(torch.rand(1, device=self._device))

    def init_clt(self):
        """Return a new trainable RBF centre of shape (1,), drawn from U[0, 1)."""
        return nn.Parameter(torch.rand(1, device=self._device))

    def init_std(self):
        """Return a new trainable RBF width of shape (1,), drawn from U[0, 1)."""
        return nn.Parameter(torch.rand(1, device=self._device))

    def rbf(self, x, cluster, std):
        """Elementwise Gaussian RBF ``exp(-(x - cluster)**2 / (2 * std**2))``."""
        # The denominator must be parenthesised: the previous `/ 2 * (std * std)`
        # evaluated as `(... / 2) * std**2`, multiplying by the width instead of
        # dividing by it.
        return torch.exp(-(x - cluster) ** 2 / (2 * std * std))

    def rbf_gradient(self, x, clt, std):
        """Analytic derivative d/dx of ``rbf(x, clt, std)``.

        Equals ``-(x - clt) / std**2 * rbf(x, clt, std)``.
        """
        return -(x - clt) / (std * std) * self.rbf(x, clt, std)

    def _rbf_params(self, type_):
        # Select (centre, width) for a series. Any value other than "cause"
        # selects the target parameters.
        if type_ == "cause":
            return self.cause_clt, self.cause_std
        return self.target_clt, self.target_std

    def rbf_grad(self, x, type_='cause'):
        """Analytic RBF gradients at the interior points ``x[1], ..., x[-2]``.

        Args:
            x: sequence indexed along its first dimension.
            type_: "cause" uses the cause RBF parameters; anything else uses the
                target ones.

        Returns:
            A list of ``x.shape[0] - 2`` gradient tensors (empty when fewer than
            three points are given).
        """
        clt, std = self._rbf_params(type_)
        return [self.rbf_gradient(x[j], clt, std) for j in range(1, x.shape[0] - 1)]

    def rbf_num_grad(self, x, type_="cause"):
        """Central-difference RBF gradients at the interior points of ``x``.

        Entry ``j`` is ``(rbf(x[j+2]) - rbf(x[j])) / (x[j+2] - x[j])``, which
        approximates the derivative at ``x[j+1]``. Repeated neighbouring x values
        divide by zero.

        Returns:
            A list of ``x.shape[0] - 2`` tensors, aligned with ``rbf_grad``.
        """
        clt, std = self._rbf_params(type_)
        return [
            (self.rbf(x[j + 2], clt, std) - self.rbf(x[j], clt, std)) / (x[j + 2] - x[j])
            for j in range(x.shape[0] - 2)
        ]

    def forward(self, cause, target):
        """Apply the weighted RBFs to both series and combine them into a prediction.

        Args:
            cause: iterable of cause values; each element's RBF output is
                concatenated along dim 0.
            target: iterable of target values, handled the same way.

        Returns:
            ``(weighted_cause, weighted_target, pred)``. ``pred`` is the dim-0 sum
            of both weighted vectors plus the bias ``b``.
        """
        cause_rbf = torch.cat([self.rbf(c, self.cause_clt, self.cause_std) for c in cause], dim=0)
        weighted_cause = self.init_weight_cause * cause_rbf

        target_rbf = torch.cat([self.rbf(t, self.target_clt, self.target_std) for t in target], dim=0)
        weighted_target = self.init_weight_target * target_rbf

        pred = weighted_cause.sum(dim=0) + weighted_target.sum(dim=0) + self.b

        return weighted_cause, weighted_target, pred

In [None]:
def restore_parameters(model, best_model):
    '''Copy parameter values from best_model into model, in place.

    The values are copied rather than assigned through ``.data``. Assigning
    ``.data`` would make both models share storage, so later optimizer steps on
    ``model`` would silently overwrite ``best_model`` as well. ``copy_`` also
    handles best_model living on a different device.

    Args:
        model: module whose parameters are overwritten.
        best_model: module with the same parameter layout (e.g. a deepcopy of model).

    Raises:
        ValueError: if the two modules have different numbers of parameters.
    '''
    model_params = list(model.parameters())
    best_params = list(best_model.parameters())
    if len(model_params) != len(best_params):
        raise ValueError(
            "parameter count mismatch: {} vs {}".format(len(model_params), len(best_params))
        )
    with torch.no_grad():
        for param, best_param in zip(model_params, best_params):
            param.copy_(best_param)

def train_rbf(model, input_causes, input_targets, Y, lr, epochs, lookback=5, device=device, verbose=False):
    '''Train ``model`` with Adam and early stopping, then restore its best parameters.

    Each epoch runs ``model(input_causes, input_targets)`` and minimizes the sum of
    per-element MSE losses between the model's cause/target outputs and the
    matching elements of ``input_causes`` / ``input_targets``.

    Args:
        model: module whose forward returns ``(cause, target, pred)`` in that order,
            as ``RBFlayer.forward`` does.
        input_causes: indexable cause inputs; assumed to already be on ``device``.
        input_targets: indexable target inputs; assumed to already be on ``device``.
        Y: only ``len(Y)`` is used, to normalize the reported losses.
        lr: Adam learning rate.
        epochs: maximum number of epochs.
        lookback: stop once this many epochs pass without a new best loss.
        device: device the model is moved to.
        verbose: if True, print a message when stopping early.

    Returns:
        ``(train_loss_list, model, cause_list, target_list, pred_list)``:
        - train_loss_list: detached per-epoch mean losses.
        - model: the model with its best parameters restored.
        - cause_list, target_list: the cause/target outputs of the last epoch run
          (``None`` if ``epochs == 0``).
        - pred_list: the detached prediction of every epoch.
    '''
    model.to(device)
    loss_fn = nn.MSELoss(reduction='mean')
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    n_samples = len(Y)

    train_loss_list = []
    pred_list = []
    cause_list = target_list = None
    best_it = None
    best_model = None
    best_loss = np.inf

    for epoch in range(epochs):
        # forward returns (cause, target, pred); the previous unpacking order
        # assigned pred to the cause outputs and vice versa.
        cause_list, target_list, pred = model(input_causes, input_targets)
        # Detach so stored predictions do not keep every epoch's graph alive.
        pred_list.append(pred.detach())

        # NOTE(review): the loss reconstructs the inputs and never compares pred
        # to Y; confirm this is the intended objective.
        cause_loss = sum(loss_fn(cause_list[i], input_causes[i]) for i in range(len(input_causes)))
        target_loss = sum(loss_fn(target_list[i], input_targets[i]) for i in range(len(input_targets)))
        loss = cause_loss + target_loss

        print("epoch {} cause loss {} :".format(epoch, cause_loss.item() / n_samples))
        print("epoch {} target loss {} :".format(epoch, target_loss.item() / n_samples))
        print("------------------------------------------------------")
        print()

        mean_loss = loss.detach() / n_samples
        train_loss_list.append(mean_loss)
        if mean_loss < best_loss:
            best_loss = mean_loss
            best_it = epoch
            # Snapshot *before* the optimizer step, so the saved parameters are
            # the ones that actually produced best_loss.
            best_model = deepcopy(model)
        elif best_it is not None and epoch - best_it >= lookback:
            if verbose:
                print('Stopping early')
            break

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    if best_model is not None:
        restore_parameters(model, best_model)

    return train_loss_list, model, cause_list, target_list, pred_list