<!--BOOK_INFORMATION-->
<img align="left" style="width:80px;height:98px;padding-right:20px;" src="https://raw.githubusercontent.com/joe-papa/pytorch-book/main/files/pytorch-book-cover.jpg">

This notebook contains an excerpt from the [PyTorch Pocket Reference](http://pytorchbook.com) book by [Joe Papa](http://joepapa.ai); content is available [on GitHub](https://github.com/joe-papa/pytorch-book).

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/joe-papa/pytorch-book/blob/main/05_Customizing_PyTorch.ipynb)

# Chapter 5 - Customizing PyTorch

In [None]:
import torch 

def linear(input, weight, bias=None):
    """Functional linear layer: ``input @ weight.T (+ bias)``.

    Args:
        input: input tensor whose last dimension matches ``weight``'s second one.
        weight: weight matrix, transposed before the product.
        bias: optional additive bias; ``None`` skips the addition.

    Returns:
        The transformed tensor.
    """
    weight_t = weight.t()
    if input.dim() == 2 and bias is not None:
        # fused op is marginally faster
        return torch.addmm(bias, input, weight_t)

    result = input.matmul(weight_t)
    if bias is not None:
        result += bias
    return result

In [None]:
import torch.nn as nn
from torch import Tensor

class Linear(nn.Module):
    r"""Applies a linear transformation to the
    incoming data: :math:`y = xA^T + b`

    Args:
        in_features: size of each input sample
        out_features: size of each output sample
        bias: If set to ``False``, the layer will
            not learn an additive bias.
            Default: ``True``

    Attributes:
        weight: the learnable weights of the
            module of shape ``(out_features, in_features)``
        bias: the learnable bias of the module of
            shape ``(out_features,)``, or ``None`` when
            ``bias=False``

    Examples::
        >>> m = Linear(20, 30)
        >>> input = torch.randn(128, 20)
        >>> output = m(input)
        >>> print(output.size())
        torch.Size([128, 30])
    """

    def __init__(self, in_features,
                 out_features, bias=True): # <1>
        super(Linear, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        # Storage is allocated uninitialized here and filled
        # by reset_parameters() below.
        self.weight = nn.Parameter(
            torch.empty(out_features, in_features))
        if bias:
            self.bias = nn.Parameter(
                torch.empty(out_features))
        else:
            # Registering None keeps `self.bias` defined, so
            # forward() and reset_parameters() can test it.
            self.register_parameter('bias', None)
        self.reset_parameters()

    def reset_parameters(self):
        """Initialize weight (Kaiming-uniform) and bias (uniform in +/- 1/sqrt(fan_in))."""
        import math

        nn.init.kaiming_uniform_(self.weight,
                                 a=math.sqrt(5))
        if self.bias is not None:
            fan_in, _ = \
              nn.init._calculate_fan_in_and_fan_out(
                  self.weight)
            # Guard against in_features == 0 (division by zero).
            bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
            nn.init.uniform_(self.bias, -bound, bound)

    def forward(self, input: Tensor) -> Tensor: # <2>
        """Return ``input @ weight.T + bias`` (bias omitted when ``None``)."""
        return nn.functional.linear(input,
                                    self.weight,
                                    self.bias) # <3>

In [None]:
def complex_linear(in_r, in_i, w_r, w_i, b_i, b_r):
    """Complex-valued linear map on split real/imaginary tensors.

    With x = in_r + j*in_i, W = w_r + j*w_i and b = b_r + j*b_i, computes
    y = x W^T + b, i.e.

        Re(y) = in_r W_r^T - in_i W_i^T + b_r
        Im(y) = in_r W_i^T + in_i W_r^T + b_i

    Note: the bias parameters are ordered ``(b_i, b_r)``; passing them by
    keyword avoids swapping them accidentally.

    Returns:
        ``(out_r, out_i)``, the real and imaginary parts of y.
    """
    out_r = (in_r.matmul(w_r.t())
              - in_i.matmul(w_i.t()) + b_r)
    # (a + jb)(c + jd) has imaginary part ad + bc, so both terms are added.
    out_i = (in_r.matmul(w_i.t())
              + in_i.matmul(w_r.t()) + b_i)

    return out_r, out_i

In [None]:
class ComplexLinear(nn.Module):
    """Linear layer for complex inputs stored as separate real/imaginary tensors.

    Holds real and imaginary weights of shape ``(out_features, in_features)``
    and biases of shape ``(out_features,)``, all initialized from
    ``torch.randn``.
    """
    def __init__(self, in_features, out_features):
        # Name this class here: ComplexLinear does not inherit from Linear,
        # so super(Linear, self) would raise TypeError.
        super(ComplexLinear, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.weight_r = \
          nn.Parameter(torch.randn(out_features,
                                   in_features))
        self.weight_i = \
          nn.Parameter(torch.randn(out_features,
                                   in_features))
        self.bias_r = nn.Parameter(
                        torch.randn(out_features))
        self.bias_i = nn.Parameter(
                        torch.randn(out_features))

    def forward(self, in_r, in_i):
        """Return ``(out_r, out_i)`` for the complex input ``in_r + j*in_i``."""
        # There is no F.complex_linear; call the notebook's complex_linear.
        # Its signature lists the biases as (b_i, b_r), so pass them by
        # keyword to avoid swapping real and imaginary bias.
        return complex_linear(in_r, in_i,
                              self.weight_r, self.weight_i,
                              b_r=self.bias_r, b_i=self.bias_i)

In [None]:
class ComplexLinearSimple(nn.Module):
    """Complex linear layer built from two real linear layers.

    For x = in_r + j*in_i and W = W_r + j*W_i the product is
    (W_r in_r - W_i in_i) + j (W_r in_i + W_i in_r). Each real layer also adds
    its own bias, so the effective real/imaginary biases are (b_r - b_i) and
    (b_r + b_i).
    """
    def __init__(self, in_features, out_features):
        super(ComplexLinearSimple, self).__init__()
        # Use PyTorch's built-in nn.Linear so this layer does not depend on
        # the hand-written Linear class defined earlier in the notebook.
        self.fc_r = nn.Linear(in_features,
                              out_features)
        self.fc_i = nn.Linear(in_features,
                              out_features)

    def forward(self, in_r, in_i):
        """Return ``(real, imag)`` parts of the complex linear map."""
        return (self.fc_r(in_r) - self.fc_i(in_i),
                self.fc_r(in_i) + self.fc_i(in_r))

In [None]:
def my_relu(input, thresh=0.0):
    """Thresholded ReLU.

    Elements with ``input > thresh`` are kept unchanged; every other element
    (including NaN, which fails the comparison) becomes zero.
    """
    keep = input > thresh
    zeros = torch.zeros_like(input)
    return torch.where(keep, input, zeros)

In [None]:
class MyReLU(nn.Module):
    """Module wrapper around ``my_relu``.

    Args:
        thresh: elements not greater than ``thresh`` are set to zero.
            Default: 0.0
    """

    def __init__(self, thresh = 0.0):
        super().__init__()
        self.thresh = thresh

    def forward(self, input):
        return my_relu(input, self.thresh)

In [None]:
import torch.nn.functional as F # <1>

class SimpleNet(nn.Module):
    """Two-layer MLP: Linear(D_in, H) -> ReLU -> Linear(H, D_out).

    The ReLU is applied functionally in ``forward`` rather than held as a
    submodule.
    """

    def __init__(self, D_in, H, D_out):
        super().__init__()
        self.fc1 = nn.Linear(D_in, H)
        self.fc2 = nn.Linear(H, D_out)

    def forward(self, x):
        hidden = F.relu(self.fc1(x)) # <2>
        return self.fc2(hidden)

In [None]:
class SimpleNet(nn.Module):
    """Two-layer MLP expressed as a single ``nn.Sequential``.

    Layers: Linear(D_in, H) -> ReLU -> Linear(H, D_out), stored as
    ``self.net`` (so parameters are named ``net.0.*`` and ``net.2.*``).
    """

    def __init__(self, D_in, H, D_out):
        super().__init__()
        layers = [
            nn.Linear(D_in, H),
            nn.ReLU(), # <2>
            nn.Linear(H, D_out),
        ]
        self.net = nn.Sequential(*layers) # <1>

    def forward(self, x):
        return self.net(x)

In [None]:
# to test code above: one unbatched 10-feature input -> 6 outputs
model = SimpleNet(10, 20, 6)
model(torch.rand(10))

tensor([ 0.1010,  0.0398, -0.0249, -0.0080, -0.2398,  0.0778],
       grad_fn=<AddBackward0>)

In [None]:
def complex_relu(in_r, in_i): # <1>
    """Apply ReLU independently to the real and imaginary parts.

    Returns:
        ``(relu(in_r), relu(in_i))``
    """
    relu_r = F.relu(in_r)
    relu_i = F.relu(in_i)
    return relu_r, relu_i

class ComplexReLU(nn.Module): # <2>
    """Module form of ``complex_relu``: ReLU on real and imaginary parts separately."""

    def __init__(self):
        super().__init__()

    def forward(self, in_r, in_i):
        return complex_relu(in_r, in_i)

In [None]:
# to test code above
model = ComplexReLU()
r = torch.Tensor([-.5, .5, 2])
i = torch.Tensor([0, 4, -2])
# Plain ReLU on the real part, for comparison with the module output below
print(F.relu(r))
model(r, i)

tensor([0.0000, 0.5000, 2.0000])


(tensor([0.0000, 0.5000, 2.0000]), tensor([0., 4., 0.]))

## Custom Model Architectures

In [None]:
class AlexNet(nn.Module):
    """AlexNet image classifier: convolutional feature extractor + MLP head.

    The submodule names (``features``, ``avgpool``, ``classifier``) and the
    layer order inside each ``nn.Sequential`` determine the ``state_dict``
    keys, so they must stay exactly as-is for saved weights to load.

    Args:
        num_classes: number of output logits. Default: 1000
    """

    def __init__(self, num_classes=1000):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Conv2d(64, 192, kernel_size=5, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Conv2d(192, 384, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(384, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
        )
        # Adaptive pooling always yields a 6x6 map, which matches the
        # 256 * 6 * 6 input size of the first classifier layer.
        self.avgpool = nn.AdaptiveAvgPool2d((6, 6))
        self.classifier = nn.Sequential(
            nn.Dropout(),
            nn.Linear(256 * 6 * 6, 4096),
            nn.ReLU(inplace=True),
            nn.Dropout(),
            nn.Linear(4096, 4096),
            nn.ReLU(inplace=True),
            nn.Linear(4096, num_classes),
        )

    def forward(self, x):
        pooled = self.avgpool(self.features(x))
        flat = torch.flatten(pooled, 1)  # keep dim 0 (batch), flatten the rest
        return self.classifier(flat)

In [None]:
from torch.hub import load_state_dict_from_url
model_urls = {
    'alexnet': 'https://download.pytorch.org/models/alexnet-owt-4df8aa71.pth',
}


def alexnet(pretrained=False, progress=True, **kwargs):
    """Build an ``AlexNet``, optionally loading pretrained weights.

    Args:
        pretrained: if True, fetch the checkpoint at ``model_urls['alexnet']``
            via ``load_state_dict_from_url`` and load it into the model.
        progress: whether to display a download progress bar.
        **kwargs: forwarded to the ``AlexNet`` constructor.

    Returns:
        The constructed model.
    """
    model = AlexNet(**kwargs)
    if not pretrained:
        return model
    state_dict = load_state_dict_from_url(model_urls['alexnet'],
                                          progress=progress)
    model.load_state_dict(state_dict)
    return model

In [None]:
# to test code above -- downloads the pretrained checkpoint on first run
# (the recorded output below shows a 244418560-byte file)
model = alexnet(pretrained=True, progress=True)

Downloading: "https://download.pytorch.org/models/alexnet-owt-4df8aa71.pth" to /root/.cache/torch/hub/checkpoints/alexnet-owt-4df8aa71.pth


HBox(children=(FloatProgress(value=0.0, max=244418560.0), HTML(value='')))




## Custom Loss Functions

In [None]:
# dummy variables to get the code below to run:
# a 10x10 "prediction" that tracks gradients and a 10x10 target
outputs = torch.rand(10, 10, requires_grad=True)
targets = torch.rand(10, 10)

In [None]:
loss_fcn = nn.MSELoss()  # default reduction: mean over all elements
loss = loss_fcn(outputs, targets)
loss.backward()  # populates outputs.grad

In [None]:
def mse_loss(input, target):
    """Mean squared error: mean of ``(input - target) ** 2`` over all elements."""
    return ((input - target) ** 2).mean()


class MSELoss(nn.Module):
    """Module wrapper around the hand-written ``mse_loss`` function."""

    def __init__(self):
        super(MSELoss, self).__init__()

    def forward(self, input, target):
        # Delegate to the custom function defined just above so this custom
        # loss module actually exercises it (rather than F.mse_loss).
        return mse_loss(input, target)

In [None]:
# to test code above: the custom loss on the dummy tensors
criterion = MSELoss()
loss = criterion(outputs, targets)

In [None]:
def complex_mse_loss(input_r, input_i,
                     target_r, target_i):
    """Separate mean squared errors for the real and imaginary parts.

    Returns:
        ``(loss_r, loss_i)`` as scalar tensors.
    """
    loss_r = ((input_r - target_r) ** 2).mean()
    loss_i = ((input_i - target_i) ** 2).mean()
    return loss_r, loss_i

class ComplexMSELoss(nn.Module):
    """MSE loss for complex outputs stored as real/imaginary tensor pairs.

    Args:
        real_only: when truthy, return only the real-part MSE (a single
            tensor); otherwise return ``complex_mse_loss``'s
            ``(loss_r, loss_i)`` tuple.
    """

    def __init__(self, real_only=False):
        super(ComplexMSELoss, self).__init__()
        self.real_only = real_only

    def forward(self, input_r, input_i,
                target_r, target_i):
        if not self.real_only:
            return complex_mse_loss(input_r, input_i,
                                    target_r, target_i)
        return F.mse_loss(input_r, target_r)

In [None]:
# to test code above: real and imaginary parts both use the dummy tensors
criterion = ComplexMSELoss()
loss = criterion(outputs, outputs,
                 targets, targets)

## Custom Optimizers

In [None]:
from torch import optim

# SGD with momentum over every parameter of `model`
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)

In [None]:
# Per-group options: the feature extractor uses the default lr=1e-2, the
# classifier overrides it with lr=1e-3; momentum applies to both groups.
optim.SGD(
    [
        {'params': model.features.parameters()},
        {'params': model.classifier.parameters(), 'lr': 1e-3},
    ],
    lr=1e-2,
    momentum=0.9,
)

SGD (
Parameter Group 0
    dampening: 0
    lr: 0.01
    momentum: 0.9
    nesterov: False
    weight_decay: 0

Parameter Group 1
    dampening: 0
    lr: 0.001
    momentum: 0.9
    nesterov: False
    weight_decay: 0
)

In [None]:
from collections import defaultdict

class Optimizer(object):
    """Simplified sketch of the bookkeeping in ``torch.optim.Optimizer``.

    Args:
        params: an iterable of tensors, or an iterable of dicts, each holding
            a ``'params'`` entry plus optional per-group hyperparameters.
        defaults: dict of default hyperparameters (e.g. ``{'lr': 0.01}``),
            filled into any parameter group that does not set them.

    Raises:
        TypeError: if ``params`` is a single tensor instead of an iterable.
        ValueError: if ``params`` is empty.
    """

    def __init__(self, params, defaults):
        if isinstance(params, torch.Tensor):
            # list(tensor) would silently iterate over the tensor's rows.
            raise TypeError(
                "params argument given to the optimizer should be "
                "an iterable of Tensors or dicts, but got "
                + torch.typename(params))
        self.defaults = defaults
        self.state = defaultdict(dict) # <1>
        self.param_groups = [] # <2>

        param_groups = list(params)
        if len(param_groups) == 0:
            raise ValueError(
                "optimizer got an empty parameter list")
        if not isinstance(param_groups[0], dict):
            # A flat list of tensors becomes a single group.
            param_groups = [{'params': param_groups}]

        for param_group in param_groups:
            self.add_param_group(param_group)

    def __getstate__(self):
        return {
            'defaults': self.defaults,
            'state': self.state,
            'param_groups': self.param_groups,
        }

    def __setstate__(self, state):
        self.__dict__.update(state)

    def add_param_group(self, param_group):
        """Normalize one parameter group, fill in defaults, and register it.

        ``param_group['params']`` may be a single tensor or any iterable of
        tensors; it is stored as a list. Hyperparameters missing from the
        group are copied from ``self.defaults``. The dict is modified in place.
        """
        params = param_group['params']
        if isinstance(params, torch.Tensor):
            param_group['params'] = [params]
        else:
            param_group['params'] = list(params)

        for name, default in self.defaults.items():
            param_group.setdefault(name, default)

        self.param_groups.append(param_group)

    def zero_grad(self): # <3>
        r"""Clears the gradients of all
        optimized :class:`torch.Tensor` s."""

        for group in self.param_groups:
            for p in group['params']:
                if p.grad is not None:
                    # Detach first so the in-place zero_() is not recorded
                    # in any autograd graph attached to the gradient.
                    p.grad.detach_()
                    p.grad.zero_()

    def step(self, closure=None): # <4>
        """Perform one optimization step; subclasses must override this.

        Args:
            closure: optional callable that re-evaluates the model and
                returns the loss.
        """
        raise NotImplementedError

In [None]:
from torch.optim import Optimizer

class SimpleSGD(Optimizer):
    """Plain stochastic gradient descent: ``p <- p - lr * p.grad``.

    Args:
        params: iterable of parameters, or of parameter-group dicts.
        lr: learning rate. The string ``'required'`` (the default) means
            "not given"; every parameter group must then set its own ``'lr'``.

    Raises:
        ValueError: if ``lr`` is negative.
    """

    def __init__(self, params, lr='required'):
        # Compare by value, not identity: `lr is not 'required'` relies on
        # string interning and emits a SyntaxWarning on modern Python.
        lr_given = not (isinstance(lr, str) and lr == 'required')
        if lr_given and lr < 0.0:
            raise ValueError(
              "Invalid learning rate: {}".format(lr))

        defaults = dict(lr=lr)
        super(SimpleSGD, self).__init__(
            params, defaults)

    @torch.no_grad()
    def step(self, closure=None):
        """Apply one SGD update to every parameter that has a gradient.

        Runs under ``torch.no_grad()``: an in-place update of a leaf tensor
        that requires grad raises a RuntimeError while autograd is recording.

        Args:
            closure: optional callable that re-evaluates the model and
                returns the loss; it is run with gradients enabled.

        Returns:
            The closure's loss, or ``None`` when no closure is given.

        Raises:
            ValueError: if a parameter group has no learning rate.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            lr = group['lr']
            if isinstance(lr, str) and lr == 'required':
                raise ValueError(
                    "parameter group didn't specify a value of "
                    "required optimization parameter lr")
            for p in group['params']:
                if p.grad is None:
                    continue
                p.add_(p.grad, alpha=-lr)

        return loss

In [None]:
# One parameter group covering all of `model`'s parameters
optimizer = SimpleSGD(model.parameters(), lr=0.001)

In [None]:
# Two parameter groups: features use the default lr=1e-2, the classifier
# overrides it with lr=1e-3
optimizer = SimpleSGD(
    [
        {'params': model.features.parameters()},
        {'params': model.classifier.parameters(), 'lr': 1e-3},
    ],
    lr=1e-2,
)

## Custom Training Loops

In [None]:
# Dummy values to get code to run in the next cells
from torch.utils.data import DataLoader

n_epochs = 1
model = nn.Linear(10, 10)
# 20 references to the *same* random (input, target) pair
dataset = [(torch.rand(10), torch.rand(10))] * 20
train_dataloader = DataLoader(dataset)
val_dataloader = DataLoader(dataset)
test_dataloader = DataLoader(dataset)
criterion = nn.MSELoss()
optimizer = optim.SGD(model.parameters(), lr=0.001)

In [None]:
for epoch in range(n_epochs):

    # Training: one optimizer update per mini-batch
    for data in train_dataloader:
        input, targets = data
        optimizer.zero_grad()
        output = model(input)
        train_loss = criterion(output, targets)
        train_loss.backward()
        optimizer.step()

    # Validation: forward passes only, no autograd graph is recorded
    with torch.no_grad():
        for input, targets in val_dataloader:
            output = model(input)
            val_loss = criterion(output, targets)

# Test: evaluated once, after all epochs
with torch.no_grad():
    for input, targets in test_dataloader:
        output = model(input)
        test_loss = criterion(output, targets)

In [None]:
for epoch in range(n_epochs):
    total_train_loss = 0.0 # <1>
    total_val_loss = 0.0  # <1>

    # Swap in a fresh optimizer once, at the midpoint of training. Compare
    # against n_epochs: `epoch == epoch//2` holds only when epoch == 0.
    if epoch == n_epochs // 2:
        optimizer = optim.SGD(model.parameters(),
                              lr=0.001) # <3>
    # Training
    model.train() # <2>
    for data in train_dataloader:
        input, targets = data
        optimizer.zero_grad()
        output = model(input)
        train_loss = criterion(output, targets)
        train_loss.backward()
        optimizer.step()
        # .item() extracts a Python float so the running total does not keep
        # every batch's autograd graph alive for the whole epoch.
        total_train_loss += train_loss.item() # <1>

    # Validation
    model.eval() # <2>
    with torch.no_grad():
        for input, targets in val_dataloader:
            output = model(input)
            val_loss = criterion(output, targets)
            total_val_loss += val_loss.item() # <1>

    print("""Epoch: {} 
          Train Loss: {} 
          Val Loss {}""".format( 
         epoch, total_train_loss, 
         total_val_loss)) # <1>

# Test
model.eval()
with torch.no_grad():
    for input, targets in test_dataloader:
        output = model(input)
        test_loss = criterion(output, targets)

Epoch: 0 
          Train Loss: 9.062652587890625 
          Val Loss 8.805281639099121
