In [3]:
%%writefile going_modular/engine.py

import torch
from tqdm.auto import tqdm
from collections import defaultdict

from typing import Tuple, Dict, List


def train_step(model: torch.nn.Module,
               dataloader: torch.utils.data.DataLoader,
               loss_fn: torch.nn.Module,
               optimizer: torch.optim.Optimizer,
               device: torch.device) -> Tuple[float, float]:
    """Train a PyTorch model for a single epoch.

    Puts ``model`` in training mode, then for every batch runs a forward
    pass, computes the loss, backpropagates and steps the optimizer.

    Args:
        model: Model to train; its parameters are updated in place by
            ``optimizer``.
        dataloader: Yields ``(X, y)`` batches. ``y`` is compared element-wise
            with the argmax class index of the model output, so it is
            expected to hold integer class labels.
        loss_fn: Loss called as ``loss_fn(logits, y)``.
        optimizer: Optimizer over ``model``'s parameters.
        device: Device each batch is moved to before the forward pass.

    Returns:
        ``(train_loss, train_accuracy)``: the per-batch loss and per-batch
        accuracy (a fraction in [0, 1]), each averaged over the number of
        batches.

    Raises:
        ValueError: If ``dataloader`` yields no batches.
    """
    model.train()

    train_loss, train_acc = 0.0, 0.0
    # Count batches while iterating instead of calling len(dataloader), so
    # loaders without __len__ work and an empty loader is detected cleanly.
    num_batches = 0

    for X, y in dataloader:
        X, y = X.to(device), y.to(device)

        y_logits = model(X)
        loss = loss_fn(y_logits, y)

        # Softmax is monotonic, so argmax over the raw logits selects the
        # same class without the extra computation.
        y_pred_label = y_logits.argmax(dim=1)

        train_loss += loss.item()
        train_acc += torch.eq(y_pred_label, y).sum().item() / len(y_pred_label)
        num_batches += 1

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    if num_batches == 0:
        raise ValueError("dataloader yielded no batches; cannot average metrics")

    # NOTE: these are means of per-batch values, so a smaller final batch is
    # weighted the same as a full one.
    return train_loss / num_batches, train_acc / num_batches


def test_step(model: torch.nn.Module,
            dataloader:torch.utils.data.DataLoader,
            loss_fn: torch.nn.Module,
              device: torch.device
             ) -> tuple[float, float]:
    
    model.eval()
    
    test_loss, test_acc = 0,0
    
    with torch.inference_mode():
        for batch, (X,y) in enumerate(dataloader):
            X,y = X.to(device), y.to(device)
            
            test_pred_logits = model(X)
            y_pred_label = torch.softmax(test_pred_logits,dim=1).argmax(dim=1)
            
            test_loss += loss_fn(test_pred_logits, y).item()
            test_acc += torch.eq(y_pred_label, y).sum().item() / len(y_pred_label)
            
            
    test_loss = test_loss/len(dataloader)
    test_acc = test_acc/len(dataloader)

    return test_loss, test_acc
    
def train(model: torch.nn.Module,
          train_dataloader: torch.utils.data.DataLoader,
          test_dataloader: torch.utils.data.DataLoader,
          loss_fn: torch.nn.Module,
          optimizer: torch.optim.Optimizer,
          epochs: int,
          device: torch.device) -> Dict[str, List[float]]:
    """Train and evaluate a model for a number of epochs.

    Each epoch runs ``train_step`` over ``train_dataloader`` followed by
    ``test_step`` over ``test_dataloader``, prints that epoch's metrics and
    appends them to the returned history.

    Args:
        model: Model to train and evaluate.
        train_dataloader: Batches used for the training step.
        test_dataloader: Batches used for the evaluation step.
        loss_fn: Loss passed to both steps.
        optimizer: Optimizer passed to the training step.
        epochs: Number of train/evaluate rounds to run.
        device: Device passed to both steps.

    Returns:
        A dictionary of training and testing loss and accuracy. Each metric
        maps to a list holding one value per epoch (empty when ``epochs`` is
        0), in the form::

            {"train_loss": [...],
             "train_acc": [...],
             "test_loss": [...],
             "test_acc": [...]}
    """
    # A plain dict with the documented keys pre-created: a mistyped key now
    # raises KeyError instead of silently creating a new, empty list.
    results: Dict[str, List[float]] = {
        "train_loss": [],
        "train_acc": [],
        "test_loss": [],
        "test_acc": [],
    }

    for epoch in tqdm(range(epochs)):
        train_loss, train_acc = train_step(model,
                                           train_dataloader,
                                           loss_fn,
                                           optimizer,
                                           device)
        test_loss, test_acc = test_step(model,
                                        test_dataloader,
                                        loss_fn,
                                        device)

        print(
            f"epochs: {epoch}\n"
            f"train_loss:{train_loss:.4f}\n"
            f"train_acc: {train_acc:.4f}\n"
            f"test_loss: {test_loss:.4f}\n"
            f"test_acc: {test_acc:.4f}"
        )

        results["train_loss"].append(train_loss)
        results["train_acc"].append(train_acc)
        results["test_loss"].append(test_loss)
        results["test_acc"].append(test_acc)

    # Return only after the loop so that every requested epoch runs.
    return results
            
            
            
            
            

Overwriting going_modular/engine.py
