# Import

In [None]:
%matplotlib inline
import torch
import torch.nn as nn
import torch.nn.functional as F
import lightning.pytorch as pl
import torchvision.transforms as transforms
import torchvision
from lightning.pytorch.tuner import Tuner
from utils import get_logger

# Notebook-wide logger, provided by the project's `utils.get_logger` helper.
logger = get_logger()
logger.info("Setup Complete")

# Setting a fixed global seed (42) for reproducibility; `workers=True` asks Lightning
# to also configure dataloader workers so their seeds derive from this one.
pl.seed_everything(seed=42, workers=True)

# Learning Rate Finder

## Input

In [None]:
from model import TestNet
from datamodule import TestDataModule

# Network whose learning rate will be searched.
# NOTE(review): Tuner.lr_find looks for an `lr` / `learning_rate` attribute on the model
# (or its hparams) — assumed TestNet defines one; confirm.
model = TestNet()

# CPU-only trainer with experiment logging disabled (logger=False).
trainer = pl.Trainer(accelerator="cpu", max_epochs=100, logger=False)

# Data used by the search; the positional 32 is presumably the batch size — TODO confirm.
datamodule = TestDataModule(32)

## Assessing & outputting
Rather than picking the learning rate that achieves the lowest loss, pick one around the middle of the steepest downward slope of the loss curve (marked by the red point). The finder does not give you the perfect learning rate; it is only an aid. Treat its suggestion as a ballpark starting point for your own search for the actual learning rate.

In [None]:
# Run Lightning's learning-rate range test on the model/datamodule defined above.
tuner = Tuner(trainer)
lr_finder = tuner.lr_find(model, datamodule=datamodule)

if lr_finder is None:
    # Tuner.lr_find returns None instead of a result object in some trainer
    # configurations (e.g. fast_dev_run), so there is nothing to plot or suggest.
    logger.warning("Learning rate finder did not run; no suggestion available.")
else:
    # Loss-vs-learning-rate curve; suggest=True marks the suggested rate (red point).
    # The figure is created through pyplot, so the `%matplotlib inline` backend renders
    # it when the cell finishes; calling fig.show() on that non-GUI backend would only
    # emit a matplotlib warning.
    fig = lr_finder.plot(suggest=True)
    # suggestion() may return None when no suggestion can be computed.
    suggested_lr = lr_finder.suggestion()
    logger.info(f"Learning rate suggested: {suggested_lr}")