# Training
> Training loop for VAE models

In [None]:
#| default_exp training

In [None]:
#| hide
%load_ext autoreload
%autoreload 2
from nbdev.showdoc import *

In [None]:
#| export
from slg_generative.models.vae import AutoEncoder
from slg_generative.data.datasets import FashionMnistDataset
from torch.utils.data import DataLoader
from torch.optim import Adam
import torch.nn as nn
from tqdm import tqdm
import torch

In [None]:
#| export

class Trainer:
    "Minimal training loop that fits a model to reconstruct its inputs, as used for the VAE/autoencoder models"

    def __init__(self,
        model:nn.Module, # Model mapping a batch `x` to its reconstruction `x_hat` (e.g. `AutoEncoder`); must already be on `device`
        dataloader:torch.utils.data.DataLoader, # Dataloader yielding `(x, y)` batches; `y` is ignored
        loss_func:torch.nn.modules.loss._Loss, # Reconstruction loss, called as `loss_func(x_hat, x)`
        optimizer:torch.optim.Optimizer, # Optimizer over `model`'s parameters
        n_epochs:int, # Number of training epochs
        device:str, # Device each input batch is moved to
        log_every:int=100, # Print a progress line every `log_every` batches (batch 0 included)
    ):
        if log_every < 1:
            # `batch_idx % 0` would raise ZeroDivisionError mid-training; fail fast instead
            raise ValueError(f"log_every must be >= 1, got {log_every}")
        self.model = model
        self.dataloader = dataloader
        self.loss_func = loss_func
        self.optimizer = optimizer
        self.n_epochs = n_epochs
        self.device = device
        self.log_every = log_every

    def fit(self):
        "Train `model` for `n_epochs`, overwriting a single progress line every `log_every` batches"
        # make sure training-mode behaviour (e.g. dropout/batchnorm updates) is on,
        # even if the model was previously switched to eval mode
        self.model.train()
        for epoch in tqdm(range(self.n_epochs)):
            for batch_idx, (x, _) in enumerate(self.dataloader):
                x = x.to(self.device)
                self.optimizer.zero_grad()
                x_hat = self.model(x)
                # autoencoder objective: compare the reconstruction against the input itself
                loss = self.loss_func(x_hat, x)
                loss.backward()
                self.optimizer.step()
                # log on batches 0, log_every, 2*log_every, ... (`== 0`, not a bare truthiness test,
                # which would log on every batch *except* those)
                if batch_idx % self.log_every == 0:
                    print('\r Train Epoch: {}/{} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                        epoch+1,
                        self.n_epochs,
                        batch_idx * len(x),
                        len(self.dataloader.dataset),
                        100. * batch_idx / len(self.dataloader),
                        loss.item()),
                        end='')


In [None]:
# Configuration for the training run below; every name defined here is used by the next cell.
SEED = 42
BATCH_SIZE = 128
LR = 1e-3
# seed the global torch RNG so model initialisation and DataLoader shuffling are reproducible
torch.manual_seed(SEED)
# device: 'cpu' here; 'mps' (Apple silicon) or 'cuda' (NVIDIA GPUs) can be substituted when available
device = 'cpu'
# data
ds = FashionMnistDataset(csv_file="~/Data/fashion-mnist/fashion-mnist_train.csv")
dl = DataLoader(ds, batch_size=BATCH_SIZE, shuffle=True)
# model (moved to `device` here; the Trainer only moves the input batches)
autoencoder = AutoEncoder().to(device)
# training params
n_epochs = 5
# optim
opt = Adam(autoencoder.parameters(), lr=LR)
# mean squared error reconstruction loss
loss_func = nn.MSELoss()

In [None]:
# Build the trainer from the configuration cell above and run the training loop.
trainer = Trainer(
    model=autoencoder,
    dataloader=dl,
    loss_func=loss_func,
    optimizer=opt,
    n_epochs=n_epochs,
    device=device,
)
trainer.fit()

  0%|          | 0/5 [00:00<?, ?it/s]



 20%|██        | 1/5 [00:04<00:19,  4.86s/it]



 40%|████      | 2/5 [00:07<00:10,  3.64s/it]



 60%|██████    | 3/5 [00:10<00:06,  3.26s/it]



 80%|████████  | 4/5 [00:13<00:03,  3.08s/it]



100%|██████████| 5/5 [00:16<00:00,  3.20s/it]






In [None]:
#| hide
# Export the `#| export` cells of this notebook to the `training` module.
from nbdev import nbdev_export
nbdev_export()