<a href="https://colab.research.google.com/github/ridhimagarg/PyTorchBook/blob/main/Chapter2_1.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import numpy as np
import datetime
import torch
import torch.optim as optim
import torch.nn as nn
import torch.functional as F
from torch.utils.data import DataLoader, TensorDataset, random_split
from torch.utils.tensorboard import SummaryWriter
import matplotlib.pyplot as plt
%matplotlib inline
plt.style.use('fivethirtyeight')

In [5]:
def make_val_step(model, loss_fn):
  """Build a closure that computes the validation loss for one mini-batch.

  Args:
    model: the ``torch.nn.Module`` to evaluate.
    loss_fn: callable ``loss_fn(yhat, y)`` returning a scalar loss tensor.

  Returns:
    A function ``perform_val_step(x, y)`` that puts ``model`` in eval mode,
    runs a forward pass, and returns the loss as a Python float.
  """

  def perform_val_step(x, y):
    # Eval mode switches layers such as dropout / batch-norm to inference behaviour.
    model.eval()
    # Parameters are never updated during validation, so don't build the
    # autograd graph: it would only cost memory and time.
    with torch.no_grad():
      # Step 1 - forward pass
      yhat = model(x)
      # Step 2 - loss (no backward / optimizer step during evaluation)
      loss = loss_fn(yhat, y)
    return loss.item()

  return perform_val_step

In [2]:
class StepByStep(object):
  """Bundle a model, loss function and optimizer with train/validation helpers.

  On construction the model is moved to CUDA when available, otherwise CPU.
  Data loaders and an optional TensorBoard writer are attached afterwards via
  ``set_loaders`` and ``set_tensorboard``.
  """

  def __init__(self, model, loss_fn, optimizer):
    self.model = model
    self.loss_fn = loss_fn
    self.optimizer = optimizer
    self.device = 'cuda' if torch.cuda.is_available() else 'cpu'

    self.model.to(self.device)

    # Attached later via set_loaders / set_tensorboard.
    self.train_loader = None
    self.val_loader = None
    self.writer = None

    # Loss history and epoch counter; initialised empty, nothing in this
    # class appends to them yet.
    self.losses = []
    self.val_losses = []
    self.total_epochs = 0

    # Build the per-batch step closures once. They read self.model /
    # self.loss_fn / self.optimizer at call time, so later changes are seen.
    self.train_step = self._make_train_step()
    self.val_step = self._make_val_step()

  def to(self, device):
    """Move the model to ``device`` and remember it for incoming batches."""
    self.device = device
    self.model.to(self.device)

  def set_loaders(self, train_loader, val_loader=None):
    """Attach the training loader and an optional validation loader."""
    self.train_loader = train_loader
    self.val_loader = val_loader

  def set_tensorboard(self, name, folder='runs'):
    """Create a SummaryWriter logging to ``{folder}/{name}_{YYYYmmddHHMMSS}``."""
    # datetime.now() takes an optional tzinfo, not a format string; the
    # timestamp must be formatted explicitly with strftime.
    suffix = datetime.datetime.now().strftime('%Y%m%d%H%M%S')
    self.writer = SummaryWriter('{}/{}_{}'.format(folder, name, suffix))

  def _make_train_step(self):
    """Return a closure performing one optimisation step on a mini-batch.

    The closure sets train mode, runs the forward pass, computes the loss,
    back-propagates, updates the parameters, zeroes the gradients and returns
    the loss as a Python float.
    """
    def perform_train_step(x, y):
      self.model.train()
      # Step 1 - forward pass
      yhat = self.model(x)
      # Step 2 - loss
      loss = self.loss_fn(yhat, y)
      # Step 3 - gradients
      loss.backward()
      # Step 4 - update parameters, then clear gradients for the next batch
      self.optimizer.step()
      self.optimizer.zero_grad()
      return loss.item()

    return perform_train_step

  # Backward-compatible alias for the original public method name.
  make_train_step = _make_train_step

  def _make_val_step(self):
    """Return a closure computing the loss of one mini-batch in eval mode."""
    def perform_val_step(x, y):
      # Eval mode switches layers such as dropout / batch-norm to inference.
      self.model.eval()
      # No parameter updates during evaluation, so skip graph construction.
      with torch.no_grad():
        # Step 1 - forward pass
        yhat = self.model(x)
        # Step 2 - loss; steps 3 and 4 are not needed for evaluation
        loss = self.loss_fn(yhat, y)
      return loss.item()

    return perform_val_step

  def _mini_batch(self, validation=False):
    """Run one pass over a loader and return the mean per-batch loss.

    Args:
      validation: if True use ``val_loader`` / ``val_step``, otherwise
        ``train_loader`` / ``train_step``.

    Returns:
      The mean of the per-batch losses (a NumPy float), or ``None`` when the
      selected loader has not been set. An empty loader yields NaN, because
      ``np.mean`` of an empty list is NaN.
    """
    if validation:
      data_loader = self.val_loader
      step = self.val_step
    else:
      data_loader = self.train_loader
      step = self.train_step

    if data_loader is None:
      return None

    mini_batch_losses = []
    for x_batch, y_batch in data_loader:
      # Batches come off the loader on CPU; move them to the model's device.
      x_batch = x_batch.to(self.device)
      y_batch = y_batch.to(self.device)
      mini_batch_losses.append(step(x_batch, y_batch))

    return np.mean(mini_batch_losses)





    