In [None]:
import numpy as np

import torch as th
import hypergrad as hg

TARGET_DEVICE = th.device('cuda') if th.cuda.is_available() else th.device('cpu')

Code for training a simple equilibrium network with "RNN-style" dynamics on a subset of the MNIST data.

For more details, refer to Section 4.2 of the paper
_On the Iteration Complexity of Hypergradient Computation_.

In [None]:
# --------------------------------------------
# UTILS
# --------------------------------------------

def to_numpy(tensor):
    """Convert a tensor, or a (possibly nested) list of tensors, to NumPy.

    Lists are converted element by element and come back as plain lists;
    anything else is detached from the graph, moved to the CPU and turned
    into a NumPy array.
    """
    if not isinstance(tensor, list):
        return tensor.detach().cpu().numpy()
    return [to_numpy(item) for item in tensor]


def set_requires_grad(lst):
    """Enable gradient tracking, in place, on every tensor in ``lst``."""
    for tensor in lst:
        tensor.requires_grad_(True)


def acc(preds, targets):
    """Return the fraction of rows whose arg-max along dim 1 equals the target."""
    predicted = th.argmax(preds, dim=1)
    hits = th.eq(predicted, targets)
    return hits.float().mean()


class NamedLists(list):
    """A list whose elements can also be reached by name.

    Each element of ``lst`` is paired with the name at the same position in
    ``names`` and bound as an attribute of the instance, so the first element
    is available as ``x[0]``, ``x.<first name>`` and ``x['<first name>']``.

    Args:
        lst: the elements to store.
        names: one name per element, in the same order.
    """

    def __init__(self, lst, names) -> None:
        super().__init__(lst)
        assert len(lst) == len(names)
        self.names = names
        # Bind the names here so that name-based access works on the base
        # class itself and not only on subclasses that assign the attributes.
        for name, item in zip(names, self):
            setattr(self, name, item)

    def __getitem__(self, i):
        """Look up an attribute when ``i`` is a string, else index like a list."""
        if isinstance(i, str):
            return getattr(self, i)
        return super().__getitem__(i)


class TVT(NamedLists):  # train val & test
    """Three-item named list exposing ``train``, ``val`` and ``test``."""

    def __init__(self, lst) -> None:
        super().__init__(lst, ['train', 'val', 'test'])
        train_item, val_item, test_item = lst
        self.train = train_item
        self.val = val_item
        self.test = test_item


class DT(NamedLists):  # data & targets
    """Two-item named list exposing ``data`` and ``targets``."""

    def __init__(self, lst) -> None:
        super().__init__(lst, ['data', 'targets'])
        data_item, targets_item = lst
        self.data = data_item
        self.targets = targets_item


class LA(NamedLists):  # loss and accuracy
    """Two-item named list exposing ``loss`` and ``acc``."""

    def __init__(self, lst):
        super().__init__(lst, ['loss', 'acc'])
        loss_item, acc_item = lst
        self.loss = loss_item
        self.acc = acc_item


def load_mnist(seed=0, num_train=50000, num_valid=10000):
    """Load MNIST and split it into training, validation and test sets.

    The indices ``0..59999`` of the MNIST training set are shuffled with a
    ``RandomState`` seeded by ``seed``. The first ``num_train`` shuffled
    indices form the training split and the following ``num_valid`` form the
    validation split; the test split is the MNIST test set as downloaded.

    Returns:
        A ``TVT`` of ``DT`` objects. In each ``DT`` the data are float32
        values divided by 255 and reshaped to rows of ``28 * 28`` entries,
        and both data and targets are placed on ``TARGET_DEVICE``.
    """
    from torchvision import datasets

    def _process_dataset(split):
        # numpy round-trip: cast images to float32, rescale, flatten each image
        images = np.array(split.data.numpy(), dtype=np.float32)
        labels = split.targets.numpy()
        flat_images = np.reshape(images / 255., (-1, 28 * 28))
        return DT([
            th.from_numpy(flat_images).to(TARGET_DEVICE),
            th.from_numpy(labels).to(TARGET_DEVICE),
        ])

    rng = np.random.RandomState(seed)
    full_train = datasets.MNIST('../data', download=True, train=True)
    shuffled = rng.permutation(list(range(60000)))
    all_images, all_labels = full_train.data, full_train.targets

    # consecutive, non-overlapping slices of the shuffled indices
    train_idx = shuffled[:num_train]
    valid_idx = shuffled[num_train:num_train + num_valid]
    train_split = DT([all_images[train_idx], all_labels[train_idx]])
    valid_split = DT([all_images[valid_idx], all_labels[valid_idx]])

    test_split = datasets.MNIST('../data', download=True, train=False)

    return TVT([_process_dataset(split)
                for split in (train_split, valid_split, test_split)])


In [None]:
i_sig = 0.01  # scale applied to the random normal draws that initialize the parameters
dw = 200  # dimensionality of the hidden state

lr = 0.5  # learning rate passed to the SGD optimizer

th.manual_seed(0)  # seed torch's RNG before the random parameter initialization
# seed-0 shuffle of MNIST, keeping 5000 training and 5000 validation examples
data = load_mnist(0, num_train=5000, num_valid=5000)
# number of training examples and number of input features per example
num_exp, dim_x = data.train.data.shape

In [None]:
# passed as `project=` to matrix_projection_on_spectral_ball in the training loop
do_projection = True

T = K = 20  # number of iterations; T for forward iterations, K for backward;

In [None]:
# Hypergradient method used by the training loop; choose between
# rm (reverse-mode iterative differentiation),
# fp (fixed point implicit differentiation) and
# cg (conjugate gradient implicit differentiation).
# Any other value makes the training loop raise NotImplementedError.

hg_mode = 'rm'


In [None]:

def matrix_projection_on_spectral_ball(a, radius=0.99, project=True):
    """Clip the singular values of a matrix at ``radius``.

    Args:
        a: 2-D tensor to project; it is detached and copied to the CPU for
            the NumPy SVD.
        radius: upper bound applied to every singular value.
        project: when False the matrix values are returned unchanged (the
            singular values are still computed and returned).

    Returns:
        A pair ``(tensor, S)``: a float tensor on ``TARGET_DEVICE`` with
        ``requires_grad`` enabled, and the NumPy array of singular values of
        the input *before* clipping.

    Raises:
        numpy.linalg.LinAlgError: if the SVD does not converge.
    """
    A = a.detach().cpu().numpy()
    # Reduced SVD keeps the factor shapes compatible for rectangular inputs
    # too; for square inputs it coincides with the full decomposition.
    # NumPy returns the right factor already transposed, hence `Vh`.
    U, S, Vh = np.linalg.svd(A, full_matrices=False)
    if project:
        # Scaling the columns of U equals U @ diag(clipped S).
        a = (U * np.minimum(S, radius)) @ Vh
    else:
        a = A
    return th.from_numpy(a).type(th.FloatTensor).to(TARGET_DEVICE).requires_grad_(True), S


In [None]:
# one all-zero initial hidden state per split, with one row per example
initial_states = TVT([th.zeros(d.data.shape[0], dw, device=TARGET_DEVICE) for d in data])

# The method is selected by `hg_mode`; `hg` is the imported hypergrad module,
# so comparing it with a string would never be true.
if hg_mode == 'rm':
    set_requires_grad(initial_states)  # necessary only for reverse-mode with unrolling

# define model's parameters
parameters = [
    i_sig * th.randn(dw, dw, device=TARGET_DEVICE),  # A: state-to-state weights
    i_sig * th.randn(dim_x, dw, device=TARGET_DEVICE),  # B: input-to-state weights
    i_sig * th.randn(dw, device=TARGET_DEVICE),  # c: state bias
    i_sig * th.randn(dw, 10, device=TARGET_DEVICE),  # output weights (used by `linear`)
    th.zeros(10, device=TARGET_DEVICE),  # output bias (used by `linear`)
]
set_requires_grad(parameters)


In [None]:
def get_fully_connected_dynamics(x):
    """Build the state-update map for the fixed input batch ``x``.

    The returned callable takes a one-element state list and the parameter
    list (only the first three entries are used) and returns a one-element
    list with the next state, ``tanh(state @ A + x @ B + c)``.
    """
    def fully_connected_dynamics(state_list, params):
        # RNNs like dynamics (the fp_map of the bi-level problem)
        w_state, w_input, bias = params[:3]
        current = state_list[0]
        pre_activation = current @ w_state + x @ w_input + bias
        return [th.tanh(pre_activation)]

    return fully_connected_dynamics


# factory used to build the dynamics below
get_dynamics = get_fully_connected_dynamics  # change this line for changing type of dynamics

# obtain one dynamics per set (training, validation and test) which is a callable
# each callable is bound to the input data of its own split
tvt_dynamics = TVT([get_dynamics(dt.data) for dt in data])

In [None]:

def linear(state_list, params):
    """Apply the output layer: first state times ``params[-2]`` plus ``params[-1]``."""
    weight, bias = params[-2], params[-1]
    hidden = state_list[0]
    return hidden @ weight + bias


def get_loss(targets):
    """Build the outer loss for a fixed vector of ``targets``.

    The returned callable maps ``(state_list, params)`` to the cross entropy
    between the output layer applied to the state and ``targets``.
    """
    criterion = th.nn.CrossEntropyLoss()

    def loss(state_list, params):
        # cross entropy loss (the outer loss of the bi-level problem)
        logits = linear(state_list, params)
        return th.mean(criterion(logits, targets))

    return loss


# obtain one loss per dataset (note: the losses remain callable as well as the dynamics!).
# each loss is bound to the targets of its own split
tvt_losses = TVT([get_loss(dt.targets) for dt in data])


In [None]:
# forward pass
def get_forward(initial_state, dynamics):
    """Build a zero-argument callable that unrolls ``dynamics`` for ``T`` steps.

    The callable returns the whole trajectory: a list of ``T + 1`` state
    lists, starting with ``[initial_state]``. It reads the module-level
    ``T`` and ``parameters`` each time it is called.
    """
    def forward():
        current = [initial_state]
        trajectory = [current]
        for _ in range(T):
            current = dynamics(current, parameters)
            trajectory.append(current)
        return trajectory

    return forward


# one per dataset: pairs each split's initial state with that split's dynamics
tvt_forward = TVT([get_forward(s, dyna) for s, dyna in zip(initial_states, tvt_dynamics)])


def metric_after_fw(forward, metric):
    """Compose a forward pass with a metric evaluated on its final state.

    The returned zero-argument callable runs ``forward()`` and passes the
    last state list, together with the module-level ``parameters``, to
    ``metric``.
    """
    def _f():
        final_state = forward()[-1]
        return metric(final_state, parameters)

    return _f


def accuracy(targets):
    """Build an accuracy metric for a fixed vector of ``targets``.

    The returned callable maps ``(states, params)`` to the accuracy of the
    output layer's predictions against ``targets``.
    """
    def _f(states, params):
        predictions = linear(states, params)
        return acc(predictions, targets)

    return _f


# obtain callables for loss and accuracy for each set (after executing the model's dynamics)
# every callable takes no arguments and re-runs the forward pass of its split when called
tvt_metrics = TVT([
    LA([metric_after_fw(fww, lss), metric_after_fw(fww, accuracy(dt.targets))])
    for fww, lss, dt in zip(tvt_forward, tvt_losses, data)
])

In [None]:
# optimizer: SGD over all model parameters with learning rate `lr` and momentum 0.9
opt = th.optim.SGD(parameters, lr, momentum=0.9)

In [None]:
# training!

for t in range(1000):
    opt.zero_grad()
    states = tvt_forward.train()

    # compute the hypergradient with the method selected by hg_mode
    if hg_mode == 'rm':
        hg.reverse_unroll(states[-1], parameters, tvt_losses.train)
    elif hg_mode == 'fp':
        hg.fixed_point(states[-1], parameters, K, tvt_dynamics.train, tvt_losses.train)
    elif hg_mode == 'cg':
        hg.CG_normaleq(states[-1], parameters, K, tvt_dynamics.train, tvt_losses.train)
    else:
        raise NotImplementedError('{} not available!'.format(hg_mode))

    opt.step()

    # project the state-to-state matrix after the step; a failing SVD stops training
    try:
        A_proj, svl = matrix_projection_on_spectral_ball(parameters[0], project=do_projection)
        parameters[0].data = A_proj.data
    except (ValueError, np.linalg.LinAlgError):
        print('there were nans most probably: aborting all')
        break

    # report only every 20th iteration
    if t % 20 != 0:
        continue

    valid_acc = to_numpy(tvt_metrics.val.acc())
    hgs = to_numpy([l.grad for l in parameters])

    print('Validation accuracy at iteration {}:'.format(t), valid_acc)

