<a href="https://colab.research.google.com/github/trian-gles/pytorch-guitar-amp-modelling/blob/main/GuitarLSTM.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [2]:
import torch
from torch import nn
import time

### Pre-emphasis filter

In [93]:
# First-order pre-emphasis filter: y[n] = x[n] - 0.95 * x[n-1], applied along the
# last axis of a (batch, 1, time) tensor (Conv1d here has a single input channel).
# Conv1d computes a cross-correlation, so the kernel [-0.95, 1.0] weights the
# window [x[n-1], x[n]]; the one-sample left zero pad keeps the output length
# equal to the input length (x[-1] is taken as 0).
filt = torch.nn.Conv1d(1, 1, 2, bias=False)
with torch.no_grad():
    filt.weight.copy_(torch.tensor([[[-0.95, 1.]]]))
# The coefficients are fixed: gradients still flow through the filter to its
# input, but the weights themselves are neither tracked nor accumulated.
filt.requires_grad_(False)

pre_filt = torch.nn.Sequential(
    torch.nn.ConstantPad1d((1, 0), 0),
    filt
)

### Loss functions

In [94]:
def esr_loss(y, y_hat):
  """Error-to-signal ratio: sum((y - y_hat)^2) / sum(y^2).

  The squared error is normalised by the energy of the FIRST argument ``y``.
  NOTE(review): the callers in this notebook pass the model prediction first
  and the target second, so the normalisation uses the prediction's energy;
  the usual ESR definition normalises by the target -- confirm intended order.
  """
  return torch.sum(torch.square(y - y_hat)) / torch.sum(torch.square(y))

def dc_loss(y, y_hat):
  """DC error: squared mean of (y - y_hat) over all elements, divided by mean(y^2)."""
  return torch.square(torch.mean(y - y_hat)) / torch.mean(torch.square(y))

def combined_loss(y, y_hat, pre_filter=None):
  """ESR loss plus DC loss, both computed on pre-filtered signals.

  Both inputs are reshaped to ``(-1, 1, time)`` (time = last axis) before
  filtering, because the pre-emphasis filter is a single-channel Conv1d: this
  accepts 1-D ``(time,)``, 2-D ``(batch, time)`` and 3-D
  ``(batch, channels, time)`` tensors, each row/channel filtered on its own.

  Args:
    y, y_hat: tensors of the same shape, time on the last axis.
    pre_filter: module applied to both signals before the losses; defaults to
      the module-level ``pre_filt``.

  Returns:
    Scalar tensor ``esr_loss + dc_loss`` of the filtered signals.
  """
  if pre_filter is None:
    pre_filter = pre_filt
  y = pre_filter(y.reshape(-1, 1, y.shape[-1]))
  y_hat = pre_filter(y_hat.reshape(-1, 1, y_hat.shape[-1]))
  return esr_loss(y, y_hat) + dc_loss(y, y_hat)

In [88]:
# Sanity check: values 0..15 against an all-zero reference, laid out as 4
# sequences of 4 samples in the (batch, channels=1, time) shape that the
# single-channel Conv1d pre-emphasis filter requires.
pred = torch.arange(16).reshape((4, 1, 4)).float()
exp = torch.zeros(16).reshape((4, 1, 4)).float()
combined_loss(pred, exp)

tensor(1.4098, grad_fn=<AddBackward0>)

### Training

#### Hidden state and cell

In [101]:
from torch.autograd.grad_mode import no_grad
# Taken from https://discuss.pytorch.org/t/implementing-truncated-backpropagation-through-time/15500/3
class TBPTT_CELL():
    '''
    Truncated-backprop-through-time runner for modules that carry both a hidden
    and a cell state (such as an LSTM).

    ``one_step_module(inp, (hidden, cell))`` is called once per time step and
    must return ``(output, (new_hidden, new_cell))``. The runner also sets and
    increments an ``idx`` attribute on it (a debug step counter).

    Args:
        one_step_module: module advanced one time step per call.
        loss_module: called as ``loss_module(outputs, targets)``. Each argument
            is the per-step tensors concatenated along dim 1 and transposed to
            ``(batch, features, steps)`` -- this assumes every per-step
            output/target is batch-first with shape ``(batch, 1, features)``.
        k1: run one optimizer update every ``k1`` gradient-enabled steps.
        k2: keep at most ``k2`` steps of state history; gradients flow back
            through at most that many steps and the loss covers at most
            ``k2 - 1`` steps, so ``k2`` should be at least 2.
        no_grad_steps: number of leading steps run under ``torch.no_grad()``
            to warm up the state before training starts.
        optimizer: optimizer over ``one_step_module``'s parameters.
        verbose: print warm-up step numbers and backward timings.
    '''
    def __init__(self, one_step_module, loss_module, k1, k2, no_grad_steps, optimizer, verbose=True):
        self.one_step_module = one_step_module
        self.loss_module = loss_module
        self.k1 = k1
        self.k2 = k2
        self.no_grad_steps = no_grad_steps
        # The loss backward and the per-step state backward both traverse the
        # same per-step graphs, so those graphs must survive the first pass.
        self.retain_graph = True
        # You can also remove all the optimizer code here, and the
        # train function will just accumulate all the gradients in
        # one_step_module parameters
        self.optimizer = optimizer
        self.verbose = verbose

    def train(self, input_sequence, init_states):
        '''Run one pass over ``input_sequence`` (an iterable of ``(inp, target)``
        pairs), starting from ``init_states = (hidden, cell)``.'''
        hidden_state, cell_state = init_states

        self.one_step_module.idx = 0
        # Each entry is (detached input state that collects gradients, output
        # state still attached to that step's graph). A ``None`` input marks the
        # start of the differentiable history (initial or warm-up state).
        hidden_states = [(None, hidden_state)]
        cell_states = [(None, cell_state)]
        target_sequence = []
        output_sequence = []

        for j, (inp, target) in enumerate(input_sequence):
            if j < self.no_grad_steps:
                # Warm-up: advance the state without building a graph.
                with torch.no_grad():
                    if self.verbose:
                        print(f"no grad step {j}")
                    self.one_step_module.idx += 1
                    _, (hidden_state, cell_state) = self.one_step_module(
                        inp, (hidden_state, cell_state))
                hidden_states = [(None, hidden_state)]
                cell_states = [(None, cell_state)]
                continue

            # Cut the graph at the step boundary; the detached copies collect the
            # gradient that is later pushed back into the previous step.
            hidden_in = hidden_states[-1][1].detach().requires_grad_()
            cell_in = cell_states[-1][1].detach().requires_grad_()

            output, (hidden_out, cell_out) = self.one_step_module(inp, (hidden_in, cell_in))

            hidden_states.append((hidden_in, hidden_out))
            cell_states.append((cell_in, cell_out))
            target_sequence.append(target)
            output_sequence.append(output)

            while len(cell_states) > self.k2:
                # Delete stuff that is too old
                del hidden_states[0]
                del cell_states[0]
                del output_sequence[0]
                del target_sequence[0]

            # Count only gradient-enabled steps so updates happen every k1
            # trained steps regardless of the warm-up length.
            if (j - self.no_grad_steps + 1) % self.k1 == 0:
                self._update(hidden_states, cell_states, output_sequence, target_sequence)

    def _update(self, hidden_states, cell_states, output_sequence, target_sequence):
        '''Backprop the windowed loss through the state chain, then step the optimizer.'''
        # (batch, 1, features) per step -> (batch, features, steps): time on the
        # last axis, the layout a Conv1d-based pre-filter in the loss expects.
        outputs = torch.cat(output_sequence, dim=1).transpose(1, 2)
        targets = torch.cat(target_sequence, dim=1).transpose(1, 2)
        loss = self.loss_module(outputs, targets)

        self.optimizer.zero_grad()
        start = time.time()
        loss.backward(retain_graph=self.retain_graph)
        # Newest step first: the gradient collected on step k's detached input
        # state is complete before it is pushed into step k-1's output state.
        # Stop at the (None, state) start marker.
        for step in range(len(hidden_states) - 1, 0, -1):
            if cell_states[step - 1][0] is None:
                break
            pairs = [
                (hidden_states[step - 1][1], hidden_states[step][0].grad),
                (cell_states[step - 1][1], cell_states[step][0].grad),
            ]
            pairs = [(state, grad) for state, grad in pairs if grad is not None]
            if pairs:
                # One joint traversal per step instead of separate hidden/cell passes.
                torch.autograd.backward(
                    [state for state, _ in pairs],
                    [grad for _, grad in pairs],
                    retain_graph=self.retain_graph,
                )

        if self.verbose:
            print("bw: {}".format(time.time() - start))
        self.optimizer.step()


# --- sequence / TBPTT schedule ---
seq_len = 2000        # number of (input, target) steps fed to the runner
no_grad_steps = 200   # leading warm-up steps run without gradients
k1, k2 = 200, 200     # update cadence / maximum backprop history, in steps

# --- tensor shapes ---
input_size, output_size = 1, 1
batch_size = 20

# Not referenced by the other visible cells.
lstm_blocks = 1
idx = 0

class MyMod(nn.Module):
    """Single LSTM layer followed by a linear projection, plus a residual skip.

    Used in this notebook as TBPTT_CELL's one-step module. When gradients are
    enabled it also registers debug hooks that print when backward reaches the
    step's output and hidden state.
    """
    def __init__(self, hidden_size=5):
        super(MyMod, self).__init__()
        self.lstm = torch.nn.LSTM(input_size=input_size, hidden_size=hidden_size, batch_first=True)
        self.lin = torch.nn.Linear(hidden_size, output_size)
        self.idx = 0  # debug counter of grad-enabled forward calls

    def forward(self, inp, state):
        """Advance one call; returns ``(output, (hidden, cell))``."""
        out, states = self.lstm(inp, state)
        out = self.lin(out)
        hidden_state, cell_state = states

        if out.requires_grad:
            def get_pr(idx_val):
                def pr(*args):
                    print("doing backward {}".format(idx_val))
                return pr
            hidden_state.register_hook(get_pr(self.idx))
            out.register_hook(get_pr(self.idx))
            print("doing fw {}".format(self.idx))
            self.idx += 1

        # Residual connection. Applied in every mode so the model computes the
        # same function with and without gradients (warm-up, inference).
        # Assumes output_size == input_size (or broadcastable) -- TODO confirm.
        out = out + inp
        return out, states

hidden_size = 96
one_step_module = MyMod(hidden_size)
loss_module = combined_loss

# A single random (input, target) pair repeated seq_len times: a smoke test of
# the training loop rather than real audio.
step_input = torch.rand(batch_size, 1, input_size)
step_target = torch.rand(batch_size, 1, output_size)
input_sequence = [(step_input, step_target)] * seq_len

optimizer = torch.optim.Adam(one_step_module.parameters())
runner = TBPTT_CELL(one_step_module, loss_module, k1, k2, no_grad_steps, optimizer)

# LSTM (hidden, cell) start at zero, shaped (num_layers=1, batch, hidden).
initial_state = (
    torch.zeros(1, batch_size, hidden_size),
    torch.zeros(1, batch_size, hidden_size),
)
runner.train(input_sequence, initial_state)
print("done")

no grad step 0
no grad step 1
no grad step 2
no grad step 3
no grad step 4
no grad step 5
no grad step 6
no grad step 7
no grad step 8
no grad step 9
no grad step 10
no grad step 11
no grad step 12
no grad step 13
no grad step 14
no grad step 15
no grad step 16
no grad step 17
no grad step 18
no grad step 19
no grad step 20
no grad step 21
no grad step 22
no grad step 23
no grad step 24
no grad step 25
no grad step 26
no grad step 27
no grad step 28
no grad step 29
no grad step 30
no grad step 31
no grad step 32
no grad step 33
no grad step 34
no grad step 35
no grad step 36
no grad step 37
no grad step 38
no grad step 39
no grad step 40
no grad step 41
no grad step 42
no grad step 43
no grad step 44
no grad step 45
no grad step 46
no grad step 47
no grad step 48
no grad step 49
no grad step 50
no grad step 51
no grad step 52
no grad step 53
no grad step 54
no grad step 55
no grad step 56
no grad step 57
no grad step 58
no grad step 59
no grad step 60
no grad step 61
no grad step 62
no