In [1]:
import numpy as np

import torch as th
import hypergrad as hg

TARGET_DEVICE = th.device('cuda') if th.cuda.is_available() else th.device('cpu')

Code for training a simple equilibrium network with an "RNN-style" dynamics on a subset of Mnist data.

For more details refer to Section 4.2 of the paper 
_On the iteration complexity of hypergradient computations_

In [2]:
# --------------------------------------------
# UTILS
# --------------------------------------------

def store(tensor):
    if isinstance(tensor, list):
        return [store(v) for v in tensor]
    else:
        return tensor.detach().to(th.device('cpu')).numpy()
    

def set_requires_grad(lst): [l.requires_grad_(True) for l in lst]


def acc(preds, targets):
    """Computes the accuracy"""
    return preds.argmax(dim=1).eq(targets).float().mean()

    

class NamedLists(list):
    def __init__(self, lst, names) -> None:
        super().__init__(lst)
        assert len(lst) == len(names)
        self.names = names

    def __getitem__(self, i):
        if isinstance(i, str):
            return self.__getattribute__(i)
        else:
            return super().__getitem__(i)


class TVT(NamedLists):  # train val & test
    def __init__(self, lst) -> None:
        super().__init__(lst, ['train', 'val', 'test'])
        self.train, self.val, self.test = lst


class DT(NamedLists):  # data & targets
    def __init__(self, lst) -> None:
        super().__init__(lst, ['data', 'targets'])
        self.data, self.targets = lst


class LA(NamedLists):  # loss and accuracy
    def __init__(self, lst):
        super().__init__(lst, ['loss', 'acc'])
        self.loss, self.acc = lst


def load_mnist(seed=0, num_train=50000, num_valid=10000):
    """Load MNIST dataset with given number of training and validation examples"""
    from torchvision import datasets
    rnd = np.random.RandomState(seed)
    mnist_train = datasets.MNIST('../data', download=True, train=True)
    train_indices = rnd.permutation(list(range(60000)))
    dta, targets = mnist_train.data, mnist_train.targets

    # print(train_indices)
    tr_inds = train_indices[:num_train]
    mnist_tr1 = DT([dta[tr_inds], targets[tr_inds]])

    val_inds = train_indices[num_train:num_train + num_valid]
    mnist_valid = DT([dta[val_inds], targets[val_inds]])

    mnist_test = datasets.MNIST('../data', download=True, train=False)

    def _process_dataset(dts):
        dt, tgt = np.array(dts.data.numpy(), dtype=np.float32), dts.targets.numpy()
        return DT([th.from_numpy(
            np.reshape(dt / 255., (-1, 28 * 28))).to(TARGET_DEVICE),
                   th.from_numpy(tgt).to(TARGET_DEVICE)])

    return TVT([_process_dataset(dtt) for dtt in [mnist_tr1, mnist_valid, mnist_test]])

In [3]:
i_sig = 0.01  # initialization 
dw = 200# dimensionality of the hidden state 

lr = 0.5


th.manual_seed(0)
data = load_mnist(0, num_train=5000, num_valid=5000)
num_exp, dim_x = data.train.data.shape




In [4]:
do_projection = True  

T = K = 20  # number of iterations; T for forward iterations, K for backward;

In [5]:
# choose between 
# rm (reverse-mode iterative differentiation),
# fp (fixed point implicit differentiation) and
# cg (conjugate gradient implicit differentiation)

hg_mode = 'rm'  


In [6]:
# functions that define the dynamics (see eq )
def phi_simple_RNN_like(x):
    def _phi(w, lmd):
        A, B, c = lmd[:3]
        ww = w[0]
        return [th.tanh(ww @ A + x @ B + c)]

    return _phi

phi = phi_simple_RNN_like  # change this line for changing type of dynamics 

# obtain one dynamics per set (training, validation and test) which is a callable
PHIs = TVT([phi(dt.data) for dt in data])


def matrix_projection_on_spectral_ball(a, radius=0.999, project=True):
    A = a.detach()
    if A.is_cuda: A = A.cpu()
    A = A.numpy()
    U, S, V = np.linalg.svd(A)
    if project:
        S1 = np.minimum(S, radius)
        a = U @ np.diag(S1) @ V
    else:
        a = A
    return th.from_numpy(a).type(th.FloatTensor).to(TARGET_DEVICE).requires_grad_(True), S

In [7]:
# define initial state
w0s = TVT([th.zeros(d.data.shape[0], dw, device=TARGET_DEVICE) for d in data])

if hg == 'rm': set_requires_grad(w0s)  # necessary only for reverse-mode ITD

# define model's parameters
lmbd = [
    i_sig * th.randn(dw, dw, device=TARGET_DEVICE),
    i_sig * th.randn(dim_x, dw, device=TARGET_DEVICE),
    i_sig * th.randn(dw, device=TARGET_DEVICE),
    i_sig * th.randn(dw, 10, device=TARGET_DEVICE),
    th.zeros(10, device=TARGET_DEVICE)
]
set_requires_grad(lmbd)

In [8]:
# define the linear model for computing the output
def out_lin_mod(ww, lmd):
    D, e1 = lmd[-2], lmd[-1]
    return ww[0] @ D + e1

# cross entropy losses 
def build_loss(tgts):
    def loss(ww, lmd):
        ce = th.nn.CrossEntropyLoss()
        outputs = out_lin_mod(ww, lmd)
        return th.mean(ce(outputs, tgts))

    return loss

# obtain one loss per dataset (note: the losses remain callable as well as the dynamics!).   
lss = TVT([build_loss(dt.targets) for dt in data])

In [9]:
# forward pass
def fw(w0, the_phi):
    def _f():
        _vals = [[w0]]
        for k in range(T):
            _vals.append(the_phi(_vals[-1], lmbd))
        return _vals

    return _f

# one per dataset
fws = TVT([fw(ww, ph) for ww, ph in zip(w0s, PHIs)])


def stat_after_fw(fww, which):
    def _f():
        _vls = fww()
        return which(_vls[-1], lmbd)
    return _f

def accuracy(tgts):
    def _f(w, lmd):
        return acc(out_lin_mod(w, lmd), tgts)
    return _f

# obtrain callables for loss and accuracy for each training set (after executing the model's dyanmics)
lss_acs_after_fw = TVT([
    LA([stat_after_fw(fww, lss), stat_after_fw(fww, accuracy(dt.targets))])
    for fww, lss, dt in zip(fws, lss, data)
])

In [10]:
# optimizer
opt = th.optim.SGD(lmbd, lr, momentum=0.9)

In [12]:
# training!

for t in range(1000):
    vals = fws.train()  # compute the w_T

    # compute the hypergradient (with different methods)
    if hg_mode == 'fp':
        hg.fixed_point(vals[-1], lmbd, K, PHIs.train, lss.train)
    elif hg_mode == 'cg':
        hg.CG_normaleq(vals[-1], lmbd, K, PHIs.train, lss.train)
    elif hg_mode == 'rm':
        hg.reverse_unroll(vals[-1], lmbd, lss.train)
    else:
        raise NotImplementedError('{} not available!'.format(hg_mode))

    opt.step()

    try:  # perform projection
        A_proj, svl = matrix_projection_on_spectral_ball(lmbd[0], project=do_projection)
        lmbd[0].data = A_proj.data
    except (ValueError, np.linalg.LinAlgError) as e:
        print('there were nans most probably: aborting all')
        break

    if t % 20 == 0:
        valid_acc = store(lss_acs_after_fw.val.acc())
        hgs = store([l.grad for l in lmbd])

        print('Validation accuracy at iteration {}:'.format(t), valid_acc)  # update early stopping


Validation accuracy at iteration 0: 0.5084
Validation accuracy at iteration 20: 0.5118
Validation accuracy at iteration 40: 0.111


KeyboardInterrupt: 