In [2]:
import torch as t
import utilsd2 as utils
import typing
from typing import Callable, Iterable

# Optimiser Investigations
A series of code snippets exploring how various optimisers work.

## Parameter Groups
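PyTorch allows you to specify parameter groups, so that different subsets of a model's parameters can be optimised with different hyperparameters (e.g. a smaller learning rate for a pretrained base than for a freshly initialised classifier head).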

(Note: this exercise took me a lot longer than anticipated, and I had to look at the solution for some inspiration as well. The sneaky bug was that I forgot generators are exhausted after a single pass: `model.parameters()` can only be iterated once, so it has to be materialised with `list()` before being reused. I understood the high-level idea; my code just didn't reflect it at first.)

In [103]:
class SGD:
    '''Minimal re-implementation of torch.optim.SGD (momentum + weight decay) with parameter groups.'''

    def __init__(self, params, **kwargs):
        '''Implements SGD with momentum.

        Accepts parameters in groups (an iterable of dicts, each with a 'params'
        entry plus optional per-group hyperparameter overrides), or a plain
        iterable of tensors, which is treated as a single group.

        Like the PyTorch version, but assume nesterov=False, maximize=False, and dampening=0
            https://pytorch.org/docs/stable/generated/torch.optim.SGD.html#torch.optim.SGD
        kwargs can contain lr, momentum or weight_decay. They act as defaults for
        every group; values given inside a group dict take precedence.

        Raises AssertionError if some group ends up without an lr, or if the same
        parameter appears more than once (within one group or across groups).
        '''
        # Materialise first: `params` may be a generator (e.g. model.parameters()),
        # which can only be iterated once, and we need to peek at its first element.
        params = list(params)
        # A plain iterable of tensors (rather than of group dicts) is one group.
        if params and not isinstance(params[0], dict):
            params = [{'params': params}]

        self.param_groups = []
        default_param_values = dict(momentum=0., weight_decay=0.)
        seen_params = set()

        for group_spec in params:
            # Precedence: built-in defaults < constructor kwargs < per-group values.
            param_group = {**default_param_values, **kwargs, **group_spec}
            group_params = param_group['params']
            # A bare tensor would otherwise be split row-by-row by list().
            if isinstance(group_params, t.Tensor):
                group_params = [group_params]
            param_group['params'] = list(group_params)
            assert param_group.get('lr') is not None, "lr must be specified"

            for param in param_group['params']:
                assert param not in seen_params, \
                    "some parameters appear more than once across parameter groups"
                seen_params.add(param)

            # One momentum buffer per parameter. Zero-initialised, so the first
            # momentum update `momentum * b + g` reduces to g.
            param_group['bs'] = [t.zeros_like(p) for p in param_group['params']]
            self.param_groups.append(param_group)

        # Index of the next step: starts at 1 and is incremented by every step().
        self.t = 1

    @t.inference_mode()
    def step(self):
        '''Performs a single optimisation step over every parameter group.

        For each parameter with a gradient g, using that group's hyperparameters:
            g <- g + weight_decay * param          (if weight_decay != 0)
            b <- momentum * b + g;  g <- b         (if momentum != 0)
            param <- param - lr * g
        Parameters whose .grad is None are skipped (their momentum buffer is
        left untouched), as torch.optim.SGD does.
        '''
        for pg in self.param_groups:
            params = pg.get('params')
            assert params is not None, 'must have parameter'
            lr = pg['lr']
            momentum = pg['momentum']
            weight_decay = pg['weight_decay']
            buffers = pg['bs']

            for j, param in enumerate(params):
                g = param.grad
                if g is None:
                    # e.g. the parameter did not contribute to the loss.
                    continue
                if weight_decay != 0:
                    g = g + weight_decay * param
                if momentum != 0:
                    # Always builds a new tensor, so the buffer never aliases
                    # param.grad (which the caller may later modify in place).
                    buffers[j] = momentum * buffers[j] + g
                    g = buffers[j]
                param -= lr * g
        self.t = self.t + 1

    def zero_grad(self) -> None:
        '''Clears the gradients of every managed parameter by setting .grad to None.'''
        for param_group in self.param_groups:
            for param in param_group['params']:
                param.grad = None

# Run the provided checks (from the local `utilsd2` module) against the SGD class
# defined above; judging by the output below, it tries several parameter-group
# configurations and also checks that duplicate parameters raise an error.
utils.test_sgd_param_groups(SGD)


Testing configuration:  [{'params': 'base'}, {'params': 'classifier', 'lr': 0.001}]

Testing configuration:  [{'params': 'base'}, {'params': 'classifier'}]

Testing configuration:  [{'params': 'base', 'lr': 0.01, 'momentum': 0.95}, {'params': 'classifier', 'lr': 0.001}]

Testing that your function doesn't allow duplicates (this should raise an error): 
Got an error, as expected.

All tests in `test_sgd_param_groups` passed!


## Readings
I then proceeded to look at Deep Double Descent and the Lottery Ticket Hypothesis. Really cool stuff!