In [43]:
import time
from typing import List, Dict

import autograd.numpy as np

import torch
from torch import Tensor
import torchvision.models as models
import torchvision.transforms as transforms
import torch.nn.functional as F
import torchvision.datasets
import torch.nn as nn
import torch.optim as optim


Train and test functions for our network / resnet18

In [44]:
def train(model: nn.Module, loss_fn: nn.modules.loss._Loss, optimizer: torch.optim.Optimizer, train_loader: torch.utils.data.DataLoader, epoch: int=0)-> List:

    train_loss = []
    model.train()
    for batch_idx, (images, targets) in enumerate(train_loader ):
        images = images.to(device)
        targets = targets.to(device)
        #model.to(device)
        optimizer.zero_grad()

        # forward + backward + optimiz
        t = torch.linspace(0, 64, steps=64)
        outputs = model(images,t)
        loss = loss_fn(outputs, targets)
        optimizer.step()
        train_loss.append(loss)
        if batch_idx % int(len(train_loader.dataset)/(10*len(images))) == 0 :
          print(f'Epoch {epoch}: [{batch_idx*len(images)}/{len(train_loader.dataset)}] Loss: {loss.item():.3f}')

    return train_loss



def test(model: nn.Module, loss_fn: nn.modules.loss._Loss, test_loader: torch.utils.data.DataLoader,epoch: int=0)-> Dict:

  model.eval()

  test_loss = 0
  correct = 0
  predictions = torch.Tensor()
  predictions = predictions.to(device)
  with torch.no_grad():
    for images, targets in test_loader:
      images = images.to(device)
      targets = targets.to(device)
      output = model(images)
      test_loss += loss_fn(output, targets).item()
      pred = output.data.max(1, keepdim=True)[1]
      predictions = torch.cat((predictions,pred),0)
      correct += pred.eq(targets.data.view_as(pred)).sum()


    test_loss /= len(test_loader)
    accuracy = correct / len(test_loader.dataset)
    total_num = len(test_loader.dataset)
    test_stat = {
      "loss": test_loss,
      "accuracy": accuracy,
      "prediction": predictions,
    }
    print(f"Test result on epoch {epoch}: total sample: {total_num}, Avg loss: {test_stat['loss']:.3f}, Acc: {100*test_stat['accuracy']:.3f}%")
    return test_stat

Data loader for Fashion mnist

In [45]:
# Define transformations for the dataset
Batch_size = 64


transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5,), (0.5,))])

# Load Fashion MNIST dataset
trainset = torchvision.datasets.FashionMNIST('data', train=True, download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=Batch_size, shuffle=True)

testset = torchvision.datasets.FashionMNIST('data', train=False, download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=Batch_size, shuffle=False)

Defeintions of classes and methods for Resdiaul ODE netowrk with 1 ode layer


In [46]:
def grad_f(z, t, p, a, module):
  '''
  Calculates the gradient of the function f with respect to z, t and p (parameters theta).

  ---
  Inputs:
    z: np.array
      Hidden state
    t: float
      Time
    p: ##type##
      Parameters, theta
    a: ##type##
      Adjoint, a
    module: function
      Function f
  ---
  Outputs:
    adfdz: ##type##
      Derivative of f wrt z
    adfdt: ##type##
      Derivative of f wrt t
    adfdp: ##type##
      Derivative of f wrt theta
    f: ##type##
      Function f evaluated at given z and t
  '''
  with torch.set_grad_enabled(True):
    # ensure that we can find gradients using autograd
    z.requires_grad_(True)
    t.requires_grad_(True)
    p.requires_grad_(True)

    # calculate output f; inputs are z, t
    f = module(z, t)
    # p = module.parameters() # in case we need the parameters after function calculation ?
    # torch autograd grad with grad_outputs computes jacobian product
    adfdz = torch.autograd.grad(f, z, grad_outputs=(a), allow_unused=True)
    adfdt = torch.autograd.grad(f, t, grad_outputs=(a), allow_unused=True)
    adfdp = torch.autograd.grad(f, p, grad_outputs=(a), allow_unused=True)

  return adfdz, adfdt, adfdp, f

def aug_dynamics(z, t, theta, a, module):
  '''
  Defines dynamics of augmented state.
  ---
  Inputs:
    z: np.array
      Hidden state
    t: float
      Time
    theta: ##type##
      Dynamic parameters
    a: ##type##
      Adjoint, a
    delL_deltheta: ##type##
      Derivative of loss wrt theta
    delL_delt: ##type##
      Derivative of loss wrt t
    module: function
      Function f

  ---

  Returns:
    delz_delt: np.array
      Time derivative of state, z

    dela_delt: ##[type]##
      Time derivative of adjoint, a

    deldelL_deltheta_delt: ##[type]##
      Time derivative of loss gradient wrt dynamic parameters, delL_deltheta

  '''
  with torch.enable_grad():
      z = z.detach().requires_grad_(True)
      f = module(z, t, theta)  # Evaluate the ODE function at z, t, and theta

      # Compute gradients of f with respect to z and theta
      df_dz, df_dtheta = torch.autograd.grad(f, (z, theta), grad_outputs=a, create_graph=True)

      # Time derivative of state, z
      delz_delt = df_dz

      # Time derivative of adjoint, a
      dela_delt = -a.T @ df_dz  # Vector-Jacobian product

      # Time derivative of loss gradient wrt dynamic parameters, delL_deltheta
      deldelL_deltheta_delt = -a.T @ df_dtheta  # Vector-Jacobian product
  return delz_delt, dela_delt, deldelL_deltheta_delt




def ode_solve_augstate(s1, t0, t1, dynamics, z, theta, module):
  '''Solves the ODE for the augmented state dynamics.'''

  # Initializing parameters
  h = .5 # step size
  t = np.linspace(t0, t1, int((abs(t1 - t0))/h))

  # Initial variables at t1
  s = s1 # start with s1
  a = s[1] #adjoint is the second element of s


  # Integrating for each timestep
  for t in range(len(t)):

    # Euler method
    s = s - h * dynamics(z, t, theta, a, module)

    #Update z and a
    z = s[0] #hidden state is the first element of s
    a = s[1] #adjoint is the second element of s

  return s1



# we have to create a custom forward and backwards pass
def gradient_loss(theta, t0, t1, zt1, delL_delzt1, module):
  '''
  Reverse-mode derivative of an ODE initial value problem.
  Returns gradients of the loss.
  ---
  Inputs:
    theta: np.array
      Dynamic parameters

    t0: float
      Start time

    t1: float
      Stop time

    zt0:
      initial state

    zt1: np.array
      Final state

    delL_delzt1: ##[type]##
      Loss gradient at stop time

  ---

  Returns:
    delL_delzt0: ##[type]##
      Loss gradient at start time

    delL_deltheta: ##[type]##
      Loss gradient wrt dynamic parameters

    delL_delt0: ##[type]##
      Loss gradient wrt initial time

    delL_delt1: ##[type]##
      Loss gradient wrt stop time

  '''
  print("test b.1")
  # Calculate f(z(t1), t1, theta)
  ft1 = module(zt1, t1, theta)

  # Calculate gradient of loss wrt t1
  delL_delt1 = delL_delzt1.T@ft1

  # Define initial augmented state
  s1 = np.array([zt1, delL_delzt1, torch.zeros(theta.shape[0]), -delL_delt1]) #s1 = [zt1, delL_delzt1, delL_deltheta1, -delL_delt1]

  # Solve reverse-time ODE
  s0 = ode_solve_augstate(s1, t0, t1, aug_dynamics, zt1, theta, module)
  s0 = s0
  print("test b.2")
  #s0 = [zt0, delL_delzt0, delL_deltheta0, -delL_delt0]
  zt0 = s0[0]
  delL_delzt0 = s0[1]
  delL_deltheta = s0[2]
  delL_delt0 = -s0[3]

  # Return gradients
  return delL_delzt0, delL_deltheta, delL_delt0, delL_delt1



class ODEForwardBackward(torch.autograd.Function):
    @staticmethod
    def forward(ctx, z0, t, p, func):
      '''
      Finds the z values by solving the ODE
      ---
      Inputs:
        z0: initial state
        t: time vector
        p: parameters
        func: f(z(t), t, theta)
      ---
      Returns:
        z: filled z matrix
      '''
      # because we have our own backwards method, we do not need to track any gradients:
      # for now lets assume z is a vector ( tesnor ), no batch
      with torch.enable_grad():
        t_len = t.size(0)
        new_shape = (t_len, *z0.size())
        z = torch.zeros([t_len] + list(z0.size()), device=z0.device)
        z[0] = z0
        for i in range(t_len -1):
          # func is a nn.Module. it should apply forward when calling func(--)
          res = ode_solve(z0, t[i], t[i+1], func)
          z[i] = res
      # save the function and t, z, p for backwards pass
      ctx.func = func
      ctx.save_for_backward(z,t ) # !!! save the paramteters in time step you want !!!!
      ctx.params = p

      return  z[1]

    @staticmethod
    # pytorch AUTOMATICALLY gives us the loss gradient over the entire function.
    # therefore, from that gradient, we must return dldz0, dldt, and dldp
    # this is outlined in the appendix
    def backward(ctx, loss_grad):
      '''
      Finds our gradients wrt our inputs in forward pass
      ---
      Inputs:
        loss_grad: total loss gradient over z
      ---
      Returns:
        grad_z0: Loss gradient wrt z0
        grad_t: loss gradient wrt t
        grad_p: loss gradient wrt parameters
      '''

      #adjoint is dldz
      a = loss_grad
      # how can we get the gradients wrt our inputs?
      # dAaug/dt = -[a∂f/∂z, a∂f/∂θ, a∂f/∂t], integrate both sides
      with torch.enable_grad():
      # get our saved tensors
        func = ctx.func
        z, t, p = ctx.saved_tensors
        adfdz, adfdt, adfdp, out = grad_f(z, t, p, a, func)

        if adfdz[0] is None:
          adfdz = torch.zeros(z.size())
        if adfdt[0] is None:
          adfdt = torch.zeros(t.size())
        if adfdp[0] is None:
          adfdp = torch.zeros(p.size())

      # we can now solve for our gradients by using augmented dynamics
      # can we do it one by one too?
        for i in range(len(t) - 1, 1, -1):
          grad_z0 = adjoint_solve(a, t[i], t[i-1], -adfdz[-1])
          grad_t = adjoint_solve(t, t[i], t[i-1], -adfdt[-1])
          grad_p = adjoint_solve(p, t[i], t[i-1], -adfdp[-1])

      return grad_z0, grad_t, grad_p


#basic ODE solver ( maybe replace with RPK45 ) curretnly eurler method
def ode_solve(z0, t0, t1, f):
  h = .5 # step size
  t = np.linspace(t0, t1, int((abs(t1 - t0))/h))
  z = z0
  for i in range(len(t)-1):
    z = z + h * f(z, t[i])
  return z

def adjoint_solve(atf, t0, t1, grad_f):
  # solve the ode backwards
  h = 0.05
  steps = int(abs(t1 - t0) / h)
  t = np.linspace(t0, t1, steps + 1)
  z = atf
  for i in range(1, len(t)):
    time = t[i]
    z = z + h * grad_f
  return z


class ResidualBlock_ODE(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1)
        self.norm = nn.BatchNorm2d(64)

    def forward(self, x, theta):
        if x.shape[0] == 2:
          # Sometimes the extra time channel makes it into this step which is a no no
          x = x[1]
        #basic reisdaul layer to be ODE magicify
        out = self.norm(F.relu(self.conv1(x)))
        return out

class ResODE(nn.Module):
    def __init__(self, ode):
        super().__init__()
        self.ode = ode

    def forward(self, z0, t=torch.tensor([0., 1.])):
        t = t.to(z0)

        # the magic awful break
        z = ODEForwardBackward.apply(z0, t, self.parameters , self.ode)
        return z

class ODE_model(nn.Module):
    def __init__(self, ode_step):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 64, kernel_size=3, stride=1, padding=1)
        self.batchnorm = nn.BatchNorm2d(64)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1)
        self.ODESTEP = ode_step
        self.fc = nn.Linear(64*7 * 7, 10)
        for param in self.parameters():
            param.requires_grad = True

    def forward(self, x,t):
        # 2 down sampling layers
        x = self.pool(F.relu(self.batchnorm(self.conv1(x))))
        x = self.pool(F.relu(self.batchnorm(self.conv2(x))))
        # Resdiaul ODE layer
        x = self.ODESTEP(x)
        # fully connected layers
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        return x




In [None]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

max_epoch = 3


#our model
t = torch.tensor([0., 1.])

Ode_model= ODE_model(ResODE(ResidualBlock_ODE()))
Ode_model = Ode_model.to(device)
#RESNET Time

Res = models.resnet18()
ODE_optim = torch.optim.SGD(Ode_model.parameters(), lr= 0.6, momentum=0.9)
optimizerRES = optim.SGD(Res.parameters(), lr=0.01, momentum=0.8)
criterion = nn.CrossEntropyLoss()


#train(Ode_net,criterion,optimizer,trainloader,epoch)

Ode_model = Ode_model.to(device)
for epoch in range(max_epoch):
  train(Ode_model,criterion,ODE_optim,trainloader,epoch)
test(Ode_model, criterion, testloader, epoch)

# # Need to remove the 3 channel input (RGB) to 1 channel greyscale
# Res.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)


# start = time.time()

# for epoch in range(max_epoch):
#   train(Res,criterion,optimizerRES,trainloader,epoch)
# test(Res, criterion, testloader, epoch)
# end = time.time()
# print(f'Finished Training after {end-start} s ')


Epoch 0: [0/60000] Loss: 2.304
Epoch 0: [5952/60000] Loss: 2.303
Epoch 0: [11904/60000] Loss: 2.303
Epoch 0: [17856/60000] Loss: 2.302
Epoch 0: [23808/60000] Loss: 2.304
Epoch 0: [29760/60000] Loss: 2.304
Epoch 0: [35712/60000] Loss: 2.301
Epoch 0: [41664/60000] Loss: 2.303
Epoch 0: [47616/60000] Loss: 2.303
