### Goal
This code example shows you how to further customize the devtorch trainer, giving you a little more power over the training process to potentially accelerate training and make it more stable.

In [2]:
import logging
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms

import devtorch
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [3]:
class ANNClassifier(devtorch.DevModel):
    """Two-layer, bias-free fully connected classifier.

    Each linear layer is followed by a leaky ReLU, including the output
    layer. Both weight matrices are initialised via ``self.init_weight``
    with the "glorot_uniform" scheme (``init_weight`` is not defined in this
    class; presumably it is inherited from ``devtorch.DevModel``).
    """

    def __init__(self, n_in, n_hidden, n_out):
        """Build the two linear layers and initialise their weights.

        Args:
            n_in: Number of features per sample once the input is flattened.
            n_hidden: Width of the hidden layer.
            n_out: Number of output units.
        """
        super().__init__()
        self.layer1 = nn.Linear(n_in, n_hidden, bias=False)
        self.layer2 = nn.Linear(n_hidden, n_out, bias=False)
        # Same order as the layers are declared: layer1 first, then layer2.
        for layer in (self.layer1, self.layer2):
            self.init_weight(layer.weight, "glorot_uniform")

    def forward(self, x):
        """Run the network on a batch.

        Dimensions 1 through 3 of ``x`` are collapsed into one before the
        first linear layer, so ``x`` needs at least four dimensions.
        """
        flat = torch.flatten(x, start_dim=1, end_dim=3)
        hidden = F.leaky_relu(self.layer1(flat))
        return F.leaky_relu(self.layer2(hidden))

In [6]:
# Preprocessing applied to every image: convert to a tensor, then normalise
# with the single-channel mean / std constants below.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.1307,), std=(0.3081,)),
    ]
)

# Train and test splits share the same root folder and preprocessing;
# download=True fetches the data if it is not already present.
train_dataset = datasets.MNIST(root="../../data", train=True, transform=transform, download=True)
test_dataset = datasets.MNIST(root="../../data", train=False, transform=transform, download=True)

def loss(output, target, model):
    """Cross-entropy loss between the model output and the target labels.

    Args:
        output: Model outputs, passed to ``F.cross_entropy`` as the input.
        target: Target labels; cast to ``long`` before the loss is computed.
        model: Accepted as the third argument of the loss callback but not
            used here.

    Returns:
        The value returned by ``F.cross_entropy``.
    """
    class_indices = target.long()
    return F.cross_entropy(input=output, target=class_indices)


# Fully customised trainer.
# Each hyperparameter is declared once in this section and forwarded by name
# to devtorch.get_trainer below, so editing a value here is what changes the run.
model = ANNClassifier(784, 4000, 10)
n_epochs = 10
batch_size = 128
lr = 0.001
optimizer_func = torch.optim.Adam  # Can swap out for any other torch optimizer (see https://pytorch.org/docs/stable/optim.html)
scheduler_func = torch.optim.lr_scheduler.ExponentialLR  # Can specify a LR scheduler (see https://pytorch.org/docs/stable/optim.html)
device = "cuda"  # Make GPU go brrrr or switch out for "cpu"
dtype = torch.float  # Changing the dtype to half precision (torch.half) can speed up training

# Grad clipping can stabilize training by preventing gradients from blowing up
# A good online tutorial should cover these different types (https://machinelearningmastery.com/how-to-avoid-exploding-gradients-in-neural-networks-with-gradient-clipping/)
# Options are: "GRAD_VALUE_CLIP_PRE", "GRAD_VALUE_CLIP_POST" or "GRAD_NORM_CLIP"
grad_clip_type = "GRAD_VALUE_CLIP_POST"
grad_clip_value = 0.05  # Usually determined by trial-and-error

save_type = "SAVE_DICT"  # Preferred default, you can also save as "SAVE_OBJECT"
# The name of the folder where all model params, logs and hyperparams are stored
# You will also need to define a root and call trainer.train(save=True)
id = "my_awesome_model"  # NOTE: this name shadows the Python builtin `id` in this notebook

# You can pass arguments to the torch optimizer, scheduler and data loader using these dicts
optimizer_kwargs = {}
scheduler_kwargs = {"gamma": 0.9}
loader_kwargs = {}

# NOTE(review): save_type, id, optimizer_kwargs and loader_kwargs are declared
# above for illustration but are not forwarded to get_trainer in this example;
# confirm the matching keyword names in devtorch before passing them.
trainer = devtorch.get_trainer(
    loss,
    model=model,
    train_dataset=train_dataset,
    n_epochs=n_epochs,
    batch_size=batch_size,
    lr=lr,
    optimizer_func=optimizer_func,
    scheduler_func=scheduler_func,
    device=device,
    dtype=dtype,
    grad_clip_type=grad_clip_type,
    grad_clip_value=grad_clip_value,
    scheduler_kwargs=scheduler_kwargs,
)

# save=False: nothing is written to disk for this run.
trainer.train(save=False)

INFO:trainer:Completed epoch 0 with loss 149.45636756252497 in 7.6585s
INFO:trainer:Completed epoch 1 with loss 39.71987604862079 in 7.6597s
INFO:trainer:Completed epoch 2 with loss 22.427124134730548 in 7.6438s
INFO:trainer:Completed epoch 3 with loss 14.296442218270386 in 7.6427s
INFO:trainer:Completed epoch 4 with loss 9.480451079107297 in 7.6288s
INFO:trainer:Completed epoch 5 with loss 5.96110432043497 in 7.6378s
INFO:trainer:Completed epoch 6 with loss 3.871650189132197 in 7.6436s
INFO:trainer:Completed epoch 7 with loss 3.1468983609738643 in 7.6309s
INFO:trainer:Completed epoch 8 with loss 2.915643396354426 in 7.6351s
INFO:trainer:Completed epoch 9 with loss 1.5679306890942826 in 7.6324s


In [7]:
def eval_metric(output, target):
    """Count how many predictions in a batch match their targets.

    The predicted class of each row is the index of its largest value along
    dimension 1.

    Args:
        output: Per-class scores, one row per sample.
        target: Expected class index for each sample.

    Returns:
        The number of matching predictions as a Python number.
    """
    _, predicted = torch.max(output, 1)
    n_correct = predicted.eq(target).sum()
    return n_correct.cpu().item()

# eval_metric returns a count of correct predictions, so the summed scores
# divided by the number of test samples gives the overall accuracy.
scores = devtorch.compute_metric(model, test_dataset, eval_metric, batch_size=256)
accuracy = torch.Tensor(scores).sum() / len(test_dataset)
print(f"Accuracy = {accuracy}")

Accuracy = 0.984499990940094


**Exercise**: Try tweaking the training arguments further to improve the test score.