In [5]:
import time
import torch
import torch.nn as nn
import numpy as np
from contextualized.modules import MLP
from contextualized.functions import identity_link, identity


n, x_dim, y_dim = 1000, 100, 20
X = torch.rand((n, x_dim)) * 2 - 1
Y = torch.rand((n, y_dim)) * 2 - 1

In [15]:
mse_loss = lambda y_true, y_pred: ((y_true - y_pred)**2).mean()

def train(model):
    opt = torch.optim.Adam(model.parameters(), lr=1e-3)
    start = time.time()
    for _ in range(100):
        opt.zero_grad()
        loss = mse_loss(Y, model(X))
        loss.backward()
        opt.step()
    end = time.time()
    print(f"time: {end - start}")
    
train(MLP(x_dim, y_dim, 25, 2))

time: 0.13069415092468262


In [29]:
# Current implementation
class NGAM(nn.Module):
    """
    Neural generalized additive model
    """
    def __init__(self, input_dim, output_dim, width, layers, activation=nn.ReLU, link_fn=identity_link):
        super(NGAM, self).__init__()
        self.intput_dim = input_dim
        self.output_dim = output_dim
        self.nams = nn.ModuleList([MLP(1, output_dim, width, layers, activation=activation, link_fn=identity_link) for _ in range(input_dim)])
        self.link_fn = link_fn

    def forward(self, x):
        batch_size = x.shape[0]
        ret = torch.zeros((batch_size, self.output_dim))
        for i, nam in enumerate(self.nams):
            ret += nam(x[:, i].unsqueeze(-1))
        return self.link_fn(ret)
    
train(NGAM(x_dim, y_dim, 25, 1))

time: 3.776171922683716


In [26]:
# Current implementation
class NGAM(nn.Module):
    """
    Neural generalized additive model
    """
    def __init__(self, input_dim, output_dim, width, layers, activation=nn.ReLU, link_fn=identity_link):
        super(NGAM, self).__init__()
        self.intput_dim = input_dim
        self.output_dim = output_dim
        self.nams = nn.ModuleList([MLP(1, output_dim, width, layers, activation=activation, link_fn=identity_link) for _ in range(input_dim)])
        self.link_fn = link_fn
        self.register_buffer('ret_buffer', torch.zeros(input_dim, 1000, output_dim))

    def forward(self, x):
        batch_size = x.shape[0]
        ret = [] # self.nams[0](x[:, 0].unsqueeze(-1))
        for i, nam in enumerate(self.nams):
            ret += [nam(x[:, i].unsqueeze(-1)).unsqueeze(0)]
        return self.link_fn(torch.cat(ret, dim=0).sum(dim=0))
    
train(NGAM(x_dim, y_dim, 25, 1))

time: 5.214995861053467


In [32]:
class FastNGAM(nn.Module):
    """
    Fast training neural generalized additive model: requires extra memory
    """
    def __init__(self, input_dim, output_dim, width, layers, activation=nn.ReLU, link_fn=identity_link):
        super(NGAM, self).__init__()
        self.intput_dim = input_dim
        self.output_dim = output_dim
        self.weights1 = nn.parameter.Parameter(torch.rand() * 2e-2 - 1e-2, requires_grad=True)
        self.nams = nn.ModuleList([MLP(1, output_dim, width, layers, activation=activation, link_fn=identity_link) for _ in range(input_dim)])
        self.link_fn = link_fn
        self.register_buffer('ret_buffer', torch.zeros(input_dim, 1000, output_dim))

    def forward(self, x):
        ret = self.nams[0](x[:, 0].unsqueeze(-1))
        for i, nam in enumerate(self.nams[1:]):
            ret += nam(x[:, i].unsqueeze(-1))
        return self.link_fn(ret)
    
train(NGAM(x_dim, y_dim, 25, 1))

time: 3.8774161338806152
