In [95]:
import numpy as np 

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import SGD
from torchvision import datasets, transforms
from torch.utils.data import TensorDataset, DataLoader

import matplotlib.pyplot as plt 
from tqdm.notebook import tqdm, trange

In [96]:
bs = 128

transform = transforms.Compose([
    transforms.ToTensor(),
])

mnist_train = datasets.MNIST(root="./data", train=True, 
                             download=True, transform=transform)
images, labels = mnist_train.data, mnist_train.targets

labels_idx = torch.where(torch.isin(labels, torch.tensor([1,7])))
images = images[labels_idx] / 255.
labels = labels[labels_idx]

ds = TensorDataset(images, labels)
train_loader = DataLoader(ds, batch_size=bs)

# mnist_test = datasets.MNIST(root="./data", train=False, 
#                             download=True, transform=transform)
# test_loader = torch.utils.data.DataLoader(mnist_test, batch_size=bs)

In [334]:
class LogisticRegression(nn.Module):
    def __init__(self, f_in, f_out):
        super().__init__()
        self.fc = nn.Linear(f_in, f_out)
        self.sigmoid = nn.Sigmoid()
    
    def forward(self, x):
        x = self.fc(x)
        x = self.sigmoid(x)
        return x

In [335]:
model = LogisticRegression(28*28, 1)

In [341]:
optimizer = SGD(model.parameters(), lr=1e-3)
criterion = nn.BCELoss()

epochs = 50

model.train()
for epoch in (pbar := trange(epochs)):
    batch_losses = []
    for batch_idx, (data, target) in enumerate(train_loader):
        optimizer.zero_grad()

        data = data.reshape(-1, 28*28)
        target = torch.where(target == 7, 1., 0.).reshape(-1,1)
        output = model(data)
        
        loss = F.binary_cross_entropy(output, target)
        # loss = F.cross_entropy(output, target)
        loss.backward()
        optimizer.step()
        
        # output = model(data.reshape(-1,28*28))
        # target_preds = torch.where(output < 0.5, 1, 7).reshape(-1)
        # acc = (target == target_preds).sum() / len(output)
        
        batch_losses.append(loss.item())
        
    epoch_loss = np.mean(batch_losses)
    pbar.set_description(f"Epoch: {epoch}, Average loss: {epoch_loss:.4f}")

  0%|          | 0/50 [00:00<?, ?it/s]

In [342]:
model

LogisticRegression(
  (fc): Linear(in_features=784, out_features=1, bias=True)
  (sigmoid): Sigmoid()
)

In [343]:
is_model_functional = False

def _set_attr(obj, names, val):
    if len(names) == 1:
        setattr(obj, names[0], val)
    else:
        _set_attr(getattr(obj, names[0]), names[1:], val)

def _del_attr(obj, names):
    if len(names) == 1:
        delattr(obj, names[0])
    else:
        _del_attr(getattr(obj, names[0]), names[1:])
        
def _param_names(model):
    param_names = tuple(name for name, _ in _model_params(model))
    return param_names
    
def _param_shapes(model):
    param_shapes = tuple(p.shape for _, p in _model_params(model))
    return param_shapes

def _model_params(model, with_names=True):
    assert not is_model_functional
    return tuple((name, p) if with_names else p for name, p in model.named_parameters() if p.requires_grad)

def _model_make_functional(model):
    assert not is_model_functional
    params = tuple(p.detach().requires_grad_() for p in _model_params(model, False))
    
    for name in model._param_names:
        # print("_del_attr", name.split("."))
        _del_attr(model, name.split("."))
    
    if_model_functional = True
    return params
    
def _flatten_params_like(params_like):
    vec = []
    for p in params_like:
        vec.append(p.view(-1))
    return torch.cat(vec)

def _reshape_like_params(vec):
    pointer = 0
    split_tensors = []
    for dim in model._param_shapes:
        num_param = dim.numel()
        split_tensors.append(vec[pointer:pointer+num_param].view(dim))
        pointer += num_param
    return tuple(split_tensors)

def _model_reinsert_params(model, params, register=False):
    for name, p in zip(model._param_names, params):
        # print("_set_attr", name.split("."))
        _set_attr(model, name.split("."), torch.nn.Parameter(p) if register else p)
        
    is_model_functional = not register

In [344]:
# params = _model_make_functional(model)
# flat_params = _flatten_params_like(params)

# _model_reinsert_params(model, _reshape_like_params(flat_params))

In [345]:
# after training...
model._param_shapes = _param_shapes(model)
model._param_names = _param_names(model)

params = _model_make_functional(model)
flat_params = _flatten_params_like(params)

d = flat_params.shape[0]
damp = 0.2
hess = 0
for batch_idx, (data, target) in enumerate(train_loader):
    
    def f(theta_):
        _model_reinsert_params(model, _reshape_like_params(theta_))
        output = model(data.reshape(-1, 28*28))
        # target = torch.where(target == 7, 1., 0.).reshape(-1,1)

        loss = F.binary_cross_entropy(output, torch.where(target == 7, 1., 0.).reshape(-1,1))
        return loss
    
    hess_batch = torch.autograd.functional.hessian(f, flat_params).detach()
    hess += hess_batch * len(data)
    
with torch.no_grad():
    _model_reinsert_params(model, _reshape_like_params(flat_params), register=True)
    hess /= len(train_loader)
    hess += damp * torch.eye(d) # TODO: why do we need this?
    
    check_eigvals = True
    if check_eigvals:
        eigvals = np.linalg.eigvalsh(hess.cpu().numpy())
        print(f"hessian min eigval {np.min(eigvals).item()}")
        print(f"hessian max eigval {np.max(eigvals).item()}")
        if not np.all(eigvals >= 0):
            raise ValueError()
            
    inverse_hess = torch.inverse(hess)

hessian min eigval 0.20000000298023224
hessian max eigval 123.18155670166016


In [346]:
inverse_hess

tensor([[5.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
        [0.0000, 5.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 5.0000,  ..., 0.0000, 0.0000, 0.0000],
        ...,
        [0.0000, 0.0000, 0.0000,  ..., 5.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000,  ..., 0.0000, 5.0000, 0.0000],
        [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 2.3519]])

In [353]:
model.fc

Linear(in_features=784, out_features=1, bias=True)

In [350]:
def _loss_grad_loader_wrapper(model):
    # assume train_loss = test_loss
    params = _model_params(model)
    print(params)
    flat_params = _flatten_params_like(params)
    return flat_params

# def _loss_grad(idxs):
#     grad = 0
#     for grad_batch, batch_size
#         grad += grad_batch * batch_size
#     return grad / len(idxs)

In [351]:
_loss_grad_loader_wrapper(model)

(('fc.weight', Parameter containing:
tensor([[-1.5153e-02, -2.9503e-02,  2.3125e-02, -2.8672e-02,  2.9786e-02,
          1.5562e-02,  3.3335e-02,  2.3777e-03, -3.5537e-02, -2.0688e-02,
         -2.7439e-02, -1.5559e-02, -1.8334e-02, -2.7009e-03, -1.8452e-02,
          1.4241e-02,  2.0550e-02, -1.4958e-02, -2.9243e-02,  1.8480e-03,
         -1.8899e-02,  3.4272e-02,  2.0476e-02, -8.5623e-03,  3.5247e-02,
          1.4442e-02,  5.1014e-03,  1.4080e-02, -2.5379e-02, -2.9842e-03,
          4.2720e-03,  1.4757e-03, -6.0699e-03, -2.7218e-02,  2.9360e-02,
          1.9579e-02,  2.1962e-02, -5.8839e-03, -6.1890e-04, -3.5157e-03,
         -1.2685e-03, -2.3056e-02, -3.5493e-03, -1.6992e-03,  6.7121e-03,
         -2.0090e-02, -3.0533e-02,  1.2062e-02, -8.5477e-03,  4.0748e-03,
         -7.1618e-03,  2.7513e-02,  2.1717e-02, -1.7721e-02,  2.4279e-02,
         -3.2521e-02, -1.7433e-02,  2.2365e-02,  1.9440e-02, -2.6831e-02,
         -2.5402e-03, -3.0336e-02,  2.1283e-02, -2.6553e-02,  2.7480e-02,
 

AttributeError: 'tuple' object has no attribute 'view'

In [362]:
bs = 128

transform = transforms.Compose([
    transforms.ToTensor(),
])

mnist_train = datasets.MNIST(root="./data", train=True, 
                             download=True, transform=transform)
images, labels = mnist_train.data, mnist_train.targets

# mnist_train_7 = torch.utils.data.Subset(mnist_train, indices=torch.where(labels==7)[0])
mnist_train_7 = torch.utils.data.Subset(mnist_train, indices=[0,1,2,3,4])
mnist_train_7 = torch.utils.data.DataLoader(mnist_train_7, batch_size=bs, shuffle=True)

In [363]:
mnist_train_7.data

AttributeError: 'DataLoader' object has no attribute 'data'

In [370]:
for x, y in mnist_train_7:
    print(x)
    break

tensor([[[[0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          ...,
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.]]],


        [[[0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          ...,
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.]]],


        [[[0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          ...,
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.]]],


        [[[0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          ...,
          [0., 0., 0.,  ..., 0., 0., 0.],
    

In [367]:
mnist_train_7.dataset

AttributeError: 'Subset' object has no attribute 'data'

In [373]:
a = torch.randn(4)
print(a)

torch.argsort(a)

tensor([-0.2087,  0.7829, -0.4156, -0.0695])


tensor([2, 0, 3, 1])