In [23]:
import numpy as np
import torch
import matplotlib
from matplotlib import pyplot as plt
from grnewt import compute_Hg, nesterov_lrs, fullbatch_gradient
from grnewt import partition as build_partition

In [2]:
# Build dummy regression dataset

# Seed torch's global RNG so the random data below -- and, when the cells are
# run in order, the model initialisation in the following cells, which draws
# from the same generator -- are reproducible under Restart & Run All.
SEED = 0
torch.manual_seed(SEED)

size_in = 5
size_out = 4
batch_size = 10

data_in = torch.randn(batch_size, size_in)
data_tar = torch.randn(batch_size, size_out)

dataset = torch.utils.data.TensorDataset(data_in, data_tar)
# batch_size equals the number of samples, so the loader yields exactly one
# (full) batch, in the original order (no shuffling requested).
data_loader = torch.utils.data.DataLoader(dataset, batch_size)

In [3]:
# Define simple model

size_hidden = 6
act_function_cl = torch.nn.Tanh

class Model(torch.nn.Module):
    """One-hidden-layer MLP: Linear -> activation -> Linear.

    Every size and the activation class can be passed explicitly. When an
    argument is omitted, it falls back to the corresponding notebook-level
    global (``size_in``, ``size_hidden``, ``size_out``, ``act_function_cl``).
    As in the original code, that global is read at construction time, so
    ``Model()`` behaves exactly as before.

    Args:
        n_in: number of input features (default: global ``size_in``).
        n_hidden: number of hidden units (default: global ``size_hidden``).
        n_out: number of output features (default: global ``size_out``).
        activation_cl: activation module class, instantiated with no
            arguments (default: global ``act_function_cl``).
    """

    def __init__(self, n_in=None, n_hidden=None, n_out=None, activation_cl=None):
        super().__init__()

        # Only touch the globals when no explicit value was supplied.
        n_in = size_in if n_in is None else n_in
        n_hidden = size_hidden if n_hidden is None else n_hidden
        n_out = size_out if n_out is None else n_out
        activation_cl = act_function_cl if activation_cl is None else activation_cl

        # Keep the original attribute names and assignment order. Parameter
        # order/names (used by the partition cell) and RNG consumption during
        # initialisation therefore stay the same.
        self.hidden_layer = torch.nn.Linear(n_in, n_hidden)
        self.activation = activation_cl()
        self.out_layer = torch.nn.Linear(n_hidden, n_out)

    def forward(self, x):
        """Map inputs of shape (..., n_in) to outputs of shape (..., n_out)."""
        x = self.hidden_layer(x)
        x = self.activation(x)
        return self.out_layer(x)

In [4]:
# Build model
model = Model()

# The model's parameter tensors in `model.parameters()` order (a list, despite the name)
tup_params = list(model.parameters())

# Partition of the parameters into subsets (original note: canonical, trivial, wb)
param_groups, name_groups = build_partition.canonical(model)

# Number of parameter tensors in each subset
group_sizes = [len(group['params']) for group in param_groups]

# Cumulative subset boundaries: entry i is the total size of subsets 0..i-1,
# so subset i corresponds to the half-open range [group_indices[i], group_indices[i + 1]).
# Entries after the leading 0 are numpy integers.
group_indices = [0, *np.cumsum(group_sizes)]

# Show the partition
print(f'Partition with {len(name_groups)} subset(s).')
for subset_idx, subset_names in enumerate(name_groups):
    print(f'Subset #{subset_idx} (size = {len(subset_names)}):')
    for param_name in subset_names:
        print(f'    {param_name}')

Partition with 4 subset(s).
Subset #0 (size = 1):
    hidden_layer.weight
Subset #1 (size = 1):
    hidden_layer.bias
Subset #2 (size = 1):
    out_layer.weight
Subset #3 (size = 1):
    out_layer.bias


In [5]:
# Build losses

# Alternative kept from an earlier experiment: torch.nn.MSELoss().
# Its original side note read "order3" -- presumably related to the
# third-order term returned by compute_Hg; TODO confirm.
def loss_fn(x, y):
    """Root-mean-square error between prediction ``x`` and target ``y`` (scalar tensor)."""
    squared_error = (x - y) ** 2
    return torch.sqrt(torch.mean(squared_error))


def full_loss(x, y):
    """Loss of the notebook-level ``model`` on inputs ``x`` against targets ``y``."""
    return loss_fn(model(x), y)

In [10]:
# Gradient of `loss_fn` over `data_loader`, as computed by grnewt's
# `fullbatch_gradient` (presumably one tensor per entry of `tup_params`).
gradient = fullbatch_gradient(model, loss_fn, tup_params, data_loader, batch_size)

# Descent direction: each gradient component negated, collected in a tuple.
direction = tuple(map(torch.neg, gradient))

In [19]:
# Quantities along `direction`, computed per parameter subset by grnewt.
# NOTE(review): presumably H is the per-subset second-order matrix, g the
# per-subset first-order term and order3 the third-order term -- TODO confirm.
H, g, order3 = compute_Hg(
    tup_params, full_loss, data_in, data_tar, direction,
    param_groups=param_groups,
    group_sizes=group_sizes,
    group_indices=group_indices,
)

In [21]:
# Damping value forwarded to grnewt's `nesterov_lrs` via its `damping_int` keyword.
damping_int = 1.0

# Per-subset learning rates plus the solver's root value and convergence flag
# (names taken from the unpacking; TODO confirm their exact meaning in grnewt).
lrs, r_root, r_converged = nesterov_lrs(H, g, order3, damping_int=damping_int)

In [22]:
# Display the learning rates returned by `nesterov_lrs`; there appears to be one
# entry per parameter subset (4 entries for the 4-subset canonical partition).
lrs

tensor([ 1.8101, 21.4407, -9.0042, -4.3080])