# Getting started with TNT

In [41]:
import logging
from types import SimpleNamespace
from typing import List, Tuple, Optional, Union, Any, Literal

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.optim import AdamW
from torch.optim.lr_scheduler import OneCycleLR

import torchvision.transforms as T
from torchvision.datasets import FashionMNIST

from torcheval.metrics import MulticlassAccuracy, Mean

from torchtnt.framework import init_fit_state, State, fit, AutoUnit
from torchtnt.utils import get_timer_summary, init_from_env, seed
from torchtnt.utils.device import copy_data_to_device

from wandb_logger import WandbLogger

import timm

# Show INFO-level records (such as torchtnt's progress messages) on the console.
logging.basicConfig(level=logging.INFO)
_logger: logging.Logger = logging.getLogger(__name__)

# Type alias for one (inputs, targets) pair yielded by the dataloaders.
Batch = Tuple[torch.Tensor, torch.Tensor]

In [29]:
def prepare_model(
    model_name: str,
    input_dim: int,
    device: torch.device,
    num_classes: int = 10,
) -> nn.Module:
    """Build a randomly initialised timm classifier and move it to ``device``.

    Args:
        model_name: Architecture name passed to ``timm.create_model``.
        input_dim: Number of input image channels (forwarded as ``in_chans``).
        device: Device the model is moved to before returning.
        num_classes: Size of the classification head. Defaults to 10, matching
            the FashionMNIST setup used in this notebook.

    Returns:
        The model, already on ``device``.
    """
    model = timm.create_model(
        model_name,
        pretrained=False,  # train from scratch
        num_classes=num_classes,
        in_chans=input_dim,
    )
    return model.to(device)

In [30]:
def prepare_dataloaders(
    data_path: str,
    batch_size: int,
    num_workers: int,
    pin_memory: bool = True,
) -> Tuple[DataLoader, DataLoader]:
    """Build the FashionMNIST train and validation DataLoaders.

    Both splits are requested with ``download=True`` under ``data_path``.
    Training images get light augmentation: a 28px random crop with 1px
    padding, plus random horizontal flips. Validation images are only
    converted to tensors.

    Args:
        data_path: Root directory for the FashionMNIST files.
        batch_size: Training batch size. Validation uses twice this value.
        num_workers: Number of worker processes for each DataLoader.
        pin_memory: Whether both loaders use pinned host memory. Previously
            only the training loader did.

    Returns:
        A ``(train_dataloader, valid_dataloader)`` tuple.
    """
    train_tfms = T.Compose([
        T.RandomCrop(28, padding=1),
        T.RandomHorizontalFlip(),
        T.ToTensor(),
    ])
    valid_tfms = T.Compose([
        T.ToTensor(),
    ])

    train_ds = FashionMNIST(data_path, train=True, download=True, transform=train_tfms)
    valid_ds = FashionMNIST(data_path, train=False, download=True, transform=valid_tfms)

    train_dataloader = DataLoader(
        train_ds,
        batch_size=batch_size,
        shuffle=True,
        pin_memory=pin_memory,
        num_workers=num_workers,
    )
    valid_dataloader = DataLoader(
        valid_ds,
        batch_size=batch_size * 2,
        shuffle=False,
        pin_memory=pin_memory,
        num_workers=num_workers,
    )
    return train_dataloader, valid_dataloader

In [50]:
class MyUnit(AutoUnit[Batch]):
    """AutoUnit for 10-class image classification with optional W&B logging.

    Training loss and accuracy are recorded in ``update_metrics`` and reported
    by ``log_metrics`` through ``log``. ``log`` forwards to ``wandb.log`` only
    when ``use_wandb`` is True.

    The validation metrics (``valid_accuracy``, ``valid_loss``) are created and
    reset every epoch, but nothing in this class updates them.
    """

    def __init__(
        self,
        *,
        module: torch.nn.Module,
        optimizer: torch.optim.Optimizer,
        lr_scheduler: torch.optim.lr_scheduler._LRScheduler,
        device: Optional[torch.device],
        log_frequency_steps: int = 10,
        precision: Optional[Union[str, torch.dtype]] = None,
        gradient_accumulation_steps: int = 1,
        detect_anomaly: bool = False,
        clip_grad_norm: Optional[float] = None,
        clip_grad_value: Optional[float] = None,
        use_wandb: bool = False,
    ):
        super().__init__(
            module=module,
            optimizer=optimizer,
            lr_scheduler=lr_scheduler,
            device=device,
            log_frequency_steps=log_frequency_steps,
            precision=precision,
            gradient_accumulation_steps=gradient_accumulation_steps,
            detect_anomaly=detect_anomaly,
            clip_grad_norm=clip_grad_norm,
            clip_grad_value=clip_grad_value,
        )
        self.train_accuracy = MulticlassAccuracy(num_classes=10).to(device)
        self.valid_accuracy = MulticlassAccuracy(num_classes=10).to(device)
        # Kept for backward compatibility. The most recent batch loss is
        # stored in `self.loss` instead.
        self.train_loss = 0
        # Placed on the same device as the other metrics, for consistency.
        self.valid_loss = Mean().to(device)
        # Loss of the most recent training batch; None until update_metrics runs.
        self.loss: Optional[torch.Tensor] = None

        self.use_wandb = use_wandb

    # pyre-fixme[3]: See T137070928
    def compute_loss(self, state: State, data: Batch) -> Tuple[torch.Tensor, Any]:
        """Run the model on one batch and return ``(loss, outputs)``.

        ``outputs`` are the logits; ``update_metrics`` uses them for accuracy.
        """
        inputs, targets = data
        outputs = self.module(inputs)
        # Collapse any trailing dims into (batch, classes). A bare torch.squeeze
        # would also drop the batch dim for a batch of one sample, which breaks
        # cross_entropy.
        if outputs.dim() > 2:
            outputs = outputs.flatten(start_dim=1)
        loss = torch.nn.functional.cross_entropy(outputs, targets)

        return loss, outputs

    def update_metrics(
        self,
        state: State,
        data: Batch,
        loss: torch.Tensor,
        outputs: Any,
    ) -> None:
        """Remember the batch loss and accumulate training accuracy."""
        self.loss = loss
        _, targets = data
        self.train_accuracy.update(outputs, targets)

    def log(self, d: dict) -> None:
        """Send ``d`` to Weights & Biases when ``use_wandb`` is set; otherwise do nothing."""
        if self.use_wandb:
            # Imported lazily so the unit works without a prior `import wandb` cell.
            import wandb

            wandb.log(d)

    def log_metrics(
        self, state: State, step: int, interval: Literal["step", "epoch"]
    ) -> None:
        """Log the latest training loss and the running training accuracy."""
        # `self.loss` is only populated once update_metrics has run.
        if self.loss is not None:
            self.log({"train_loss": self.loss.item()})

        accuracy = self.train_accuracy.compute()
        self.log({"train_accuracy": accuracy})

    def on_train_epoch_end(self, state: State) -> None:
        """Reset all per-epoch metrics after the base class's epoch-end handling."""
        super().on_train_epoch_end(state)
        # reset the metric every epoch
        self.train_accuracy.reset()
        self.valid_accuracy.reset()
        self.valid_loss.reset()

## Train

In [56]:
# Run configuration: everything a reader might tune. The same object is later
# handed to wandb.init as the run config.
config = SimpleNamespace(**{
    "seed": 42,                 # RNG seed
    "model_name": "resnet10t",  # timm architecture name
    "path": ".",                # FashionMNIST data directory
    "input_dim": 1,             # input image channels
    "lr": 1e-3,                 # AdamW lr and OneCycleLR max_lr
    "epochs": 3,
    "batch_size": 512,          # training batch size (validation uses 2x)
    "num_workers": 8,           # DataLoader worker processes
})

In [57]:
# Seed before the model and data loaders are built.
seed(config.seed)

# Pick the device and initialise the process group.
device = init_from_env()

train_dl, valid_dl = prepare_dataloaders(
    config.path,
    config.batch_size,
    config.num_workers,
)

module = prepare_model(config.model_name, config.input_dim, device)
optimizer = AdamW(module.parameters(), lr=config.lr)
# NOTE(review): total_steps is sized in batches (epochs * batches per epoch);
# confirm the unit steps this scheduler once per batch, not once per epoch.
lr_scheduler = OneCycleLR(
    optimizer,
    max_lr=config.lr,
    total_steps=len(train_dl) * config.epochs,
)
# NOTE(review): not passed to MyUnit, which builds its own accuracy metrics.
train_accuracy = MulticlassAccuracy(num_classes=10, device=device)

In [68]:
# Wrap the model, optimizer and scheduler in the training unit.
# W&B logging is switched off for this run.
my_unit = MyUnit(
    device=device,
    module=module,
    optimizer=optimizer,
    lr_scheduler=lr_scheduler,
    use_wandb=False,
    log_frequency_steps=10,
)

In [69]:
# Bundle the train/eval dataloaders and the epoch budget for fit().
state = init_fit_state(
    max_epochs=config.epochs,
    train_dataloader=train_dl,
    eval_dataloader=valid_dl,
)

In [70]:
# Train `my_unit`, interleaving evaluation, using the dataloaders and epoch budget in `state`.
fit(state, my_unit)

INFO:torchtnt.framework.fit:Started fit with max_epochs=3 max_steps=None max_train_steps_per_epoch=None max_eval_steps_per_epoch=None evaluate_every_n_steps=None evaluate_every_n_epochs=1 
INFO:torchtnt.framework.train:Started train epoch
INFO:torchtnt.framework.evaluate:Started evaluate with max_steps_per_epoch=None
INFO:torchtnt.framework.train:Ended train epoch
INFO:torchtnt.framework.train:Started train epoch
INFO:torchtnt.framework.evaluate:Started evaluate with max_steps_per_epoch=None
INFO:torchtnt.framework.train:Ended train epoch
INFO:torchtnt.framework.train:Started train epoch
INFO:torchtnt.framework.evaluate:Started evaluate with max_steps_per_epoch=None
INFO:torchtnt.framework.train:Ended train epoch


In [63]:
import wandb

# Open a W&B run in the "tnt" project, recording `config` as the run config.
run = wandb.init(
    project="tnt",
    entity="capecape",
    config=config,
)

0,1
train_accuracy,▁▅▆▆▇▇▇███
train_loss,█▇▆▅▄▄▃▃▃▃▃▂▂▂▂▂▂▂▂▂▂▂▂▁▁▂▁▂▁▁▂▁▁▁▁▁▁▁▁▁

0,1
train_accuracy,0.86003
train_loss,0.41163


In [64]:
# NOTE(review): this calls fit() again on the same `state`, `my_unit` and
# OneCycleLR scheduler used earlier. Confirm that continuing from their
# current progress is intended; otherwise re-run the setup cells first.
fit(state, my_unit)

INFO:torchtnt.framework.fit:Started fit with max_epochs=3 max_steps=None max_train_steps_per_epoch=None max_eval_steps_per_epoch=None evaluate_every_n_steps=None evaluate_every_n_epochs=1 
INFO:torchtnt.framework.train:Started train epoch
INFO:torchtnt.framework.evaluate:Started evaluate with max_steps_per_epoch=None
INFO:torchtnt.framework.train:Ended train epoch
INFO:torchtnt.framework.train:Started train epoch
INFO:torchtnt.framework.evaluate:Started evaluate with max_steps_per_epoch=None
INFO:torchtnt.framework.train:Ended train epoch
INFO:torchtnt.framework.train:Started train epoch
INFO:torchtnt.framework.evaluate:Started evaluate with max_steps_per_epoch=None
INFO:torchtnt.framework.train:Ended train epoch


In [65]:
# End the W&B run opened by wandb.init.
wandb.finish()

0,1
train_accuracy,▁▁▁▁▁▁▂▂▂▂▂▂▄▄▄▄▄▄▄▄▄▄▄▄▅▅▇▇▆▆▆▆▆▆▇▇▇▇▇█
train_loss,▇██▇▇▇▆▄▆▄▆▆▄▅▄▆▄▄▄▄▃▄▄▄▃▂▃▃▃▃▃▁▃▄▂▂▃▂▁▁

0,1
train_accuracy,0.832
train_loss,0.4602


In [10]:
import wandb

# Start a fresh W&B run under entity "capecape", project "tnt".
run = wandb.init(
    entity="capecape",
    project="tnt",
    config=config,
)

ERROR:wandb.jupyter:Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mcapecape[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [11]:
# `train` is not among the notebook's top-level imports, so bring it into scope
# here; without this the cell fails with NameError on a fresh kernel.
from torchtnt.framework import train

# NOTE(review): `state` was created by init_fit_state and has already been used
# by the fit() calls above. Confirm train() accepts it, or build a dedicated
# train state before running this cell.
train(state, my_unit)

INFO:torchtnt.framework.train:Started train with max_epochs=10, max_steps=None, max_steps_per_epoch=None
INFO:torchtnt.framework.train:Started train epoch
INFO:torchtnt.framework.train:Ended train epoch
INFO:torchtnt.framework.train:Started train epoch
INFO:torchtnt.framework.train:Ended train epoch
INFO:torchtnt.framework.train:Started train epoch
INFO:torchtnt.framework.train:Ended train epoch
INFO:torchtnt.framework.train:Started train epoch
INFO:torchtnt.framework.train:Ended train epoch
INFO:torchtnt.framework.train:Started train epoch
INFO:torchtnt.framework.train:Ended train epoch
INFO:torchtnt.framework.train:Started train epoch
INFO:torchtnt.framework.train:Ended train epoch
INFO:torchtnt.framework.train:Started train epoch
INFO:torchtnt.framework.train:Ended train epoch
INFO:torchtnt.framework.train:Started train epoch
INFO:torchtnt.framework.train:Ended train epoch
INFO:torchtnt.framework.train:Started train epoch
INFO:torchtnt.framework.train:Ended train epoch
INFO:torchtnt