In [None]:
import os
import logging
from typing import List
import hydra
from omegaconf import DictConfig, OmegaConf
import torch
import torch.nn as nn
from torch.cuda.amp import GradScaler
from torch.amp import autocast
from accelerate import Accelerator
import wandb
from ffcv.loader import Loader, OrderOption
from ffcv.fields.decoders import IntDecoder, SimpleRGBImageDecoder
from ffcv.transforms import RandomHorizontalFlip, Cutout, RandomTranslate, Convert, ToDevice, ToTensor, ToTorchImage
from ffcv.transforms.common import Squeeze
from ffcv.pipeline.operation import Operation
import torchvision

from models import create_model_from_config

class Trainer:
    """Build the runtime pieces for training: accelerator, W&B run and FFCV loaders.

    Args:
        cfg: Hydra/OmegaConf config. Keys read here: ``training.mixed_precision``,
            ``training.batch_size``, ``training.num_workers``, optional
            ``training.gradient_accumulation_steps`` (defaults to 1),
            ``project.name`` / ``project.tags`` / ``project.notes`` and
            ``dataset.train_dataset`` / ``dataset.val_dataset`` /
            ``dataset.mean`` / ``dataset.std``.
    """

    def __init__(self, cfg: DictConfig):
        self.cfg = cfg
        # Order matters: W&B and the loaders both rely on ``self.accelerator`` /
        # ``self.device``, which are created by ``setup_accelerator``.
        self.setup_accelerator()
        self.setup_wandb()
        self.setup_dataloaders()

    def setup_accelerator(self):
        """Create the ``Accelerator`` and record the device this process should use."""
        self.accelerator = Accelerator(
            mixed_precision='fp16' if self.cfg.training.mixed_precision else 'no',
            # Optional config key; ``select`` with a default does not raise in
            # Hydra's struct mode when the key is absent.
            gradient_accumulation_steps=OmegaConf.select(
                self.cfg, 'training.gradient_accumulation_steps', default=1
            ),
        )
        self.device = self.accelerator.device

    def setup_wandb(self):
        """Start a W&B run on the main process only, logging the resolved config."""
        if self.accelerator.is_main_process:
            wandb.init(
                project=self.cfg.project.name,
                tags=self.cfg.project.tags,
                notes=self.cfg.project.notes,
                config=OmegaConf.to_container(self.cfg, resolve=True),
            )

    def _input_dtype(self) -> torch.dtype:
        """Return the dtype images are converted to: fp16 under mixed precision, else fp32."""
        return torch.float16 if self.cfg.training.mixed_precision else torch.float32

    def _label_pipeline(self) -> List[Operation]:
        """Build a fresh label pipeline (decode int -> tensor -> device -> squeeze)."""
        return [
            IntDecoder(),
            ToTensor(),
            ToDevice(self.device),
            Squeeze(),
        ]

    def _image_pipeline(self, train: bool) -> List[Operation]:
        """Build a fresh image pipeline; ``train=True`` adds flip/translate/cutout augmentation."""
        # Plain tuples instead of OmegaConf ListConfig objects for the transforms.
        mean = tuple(self.cfg.dataset.mean)
        std = tuple(self.cfg.dataset.std)

        pipeline: List[Operation] = [SimpleRGBImageDecoder()]
        if train:
            # Padding/cutout regions are filled with the (integer-truncated) dataset mean.
            fill = tuple(map(int, mean))
            pipeline += [
                RandomHorizontalFlip(),
                RandomTranslate(padding=2, fill=fill),
                Cutout(4, fill),
            ]
        pipeline += [
            ToTensor(),
            # Move the batch to this process's device before the dtype conversion
            # and normalization, following FFCV's reference pipeline order.
            ToDevice(self.device, non_blocking=True),
            ToTorchImage(),
            Convert(self._input_dtype()),
            torchvision.transforms.Normalize(mean, std),
        ]
        return pipeline

    def setup_dataloaders(self):
        """Create ``self.train_loader`` (shuffled, augmented) and ``self.val_loader``.

        Each loader gets its own freshly built pipelines: FFCV binds loader-specific
        state (field metadata, memory reader) onto the ``Operation`` instances, so
        they must not be shared between loaders.
        """
        self.train_loader = Loader(
            self.cfg.dataset.train_dataset,
            batch_size=self.cfg.training.batch_size,
            num_workers=self.cfg.training.num_workers,
            order=OrderOption.RANDOM,
            drop_last=True,
            # Accelerate only shards torch DataLoaders; an FFCV Loader has to shard
            # itself, otherwise every process would iterate the whole training set.
            distributed=self.accelerator.num_processes > 1,
            pipelines={
                'image': self._image_pipeline(train=True),
                'label': self._label_pipeline(),
            },
        )

        # Validation is left unsharded: every process sees the full validation set.
        self.val_loader = Loader(
            self.cfg.dataset.val_dataset,
            batch_size=self.cfg.training.batch_size,
            num_workers=self.cfg.training.num_workers,
            order=OrderOption.SEQUENTIAL,
            drop_last=False,
            pipelines={
                'image': self._image_pipeline(train=False),
                'label': self._label_pipeline(),
            },
        )

  from .autonotebook import tqdm as notebook_tqdm


ModuleNotFoundError: No module named 'models'