In [1]:
# !pip3 install torch torchvision tqdm matplotlib numpy

In [2]:
%load_ext autoreload

In [3]:
%autoreload 2

import torch
from torch import nn
from torchvision import transforms
from tqdm.auto import tqdm

from carvana import Carvana
from unet import Unet
import numpy as np
from wandb_logger import WanDBWriter
from trainer import train_block1

In [4]:
# Preprocessing pipeline handed to the dataset: resize to 256x256, then
# convert to a tensor.
train_dataset = Carvana(
    transform=transforms.Compose(
        [transforms.Resize((256, 256)), transforms.ToTensor()]
    ),
    root=".",
)

# Shuffled batches of 128, produced by 14 worker processes with pinned
# host memory enabled.
train_loader = torch.utils.data.DataLoader(
    train_dataset,
    num_workers=14,
    pin_memory=True,
    shuffle=True,
    batch_size=128,
)

In [5]:
from dataclasses import dataclass


@dataclass
class Config:
    """Run configuration handed to the W&B logger and the training loop.

    Every option carries a type annotation so that ``dataclass`` treats it as
    a real field: it then appears in ``__init__``, ``repr``,
    ``dataclasses.fields`` and ``dataclasses.asdict``. An un-annotated
    assignment in a dataclass body is only a class attribute and is invisible
    to all of those.
    """
    # Weights & Biases project name.
    wandb_project: str = 'Fast Pipelines'

    # presumably toggles automatic mixed precision in the trainer — confirm in trainer.py
    use_amp: bool = True

    # Gradient-scaler selection. The training cell checks these flags in the
    # order listed here and uses the first one that is True.
    use_torch_scaler: bool = False
    use_empty_scaler: bool = False
    use_constant_scaler: bool = False
    use_dynamic_scaler: bool = True

    
# Build the configuration with its default values and pass it to the
# project's W&B logger wrapper.
config = Config()
logger = WanDBWriter(config)

[34m[1mwandb[0m: Currently logged in as: [33mtimothyxp[0m (use `wandb login --relogin` to force relogin)


In [6]:
import torch, torch.nn as nn
from utils import count_zero_grads


class GradScalerEmpty:
    """No-op object exposing ``scale`` / ``unscale_`` / ``update`` / ``step``.

    The loss is returned unscaled, gradients are left untouched, and ``step``
    simply runs the optimizer.
    """

    def __init__(self):
        # Stored only; no method of this class reads it.
        self.scalar = float(2 ** 16)

    def scale(self, loss):
        """Return ``loss`` unchanged."""
        return loss

    def unscale_(self, optimizer):
        """Do nothing; ``optimizer`` is ignored."""
        return None

    def update(self):
        """Do nothing."""
        return None

    def step(self, optimizer):
        """Call ``optimizer.step()``; its return value is discarded."""
        optimizer.step()
        
        
class GradScalerCustom:
    """Hand-rolled loss scaler with a constant or growing scale factor.

    Usage mirrors the calls made on a grad scaler in a training loop:
    ``scale(loss)`` before backward, ``unscale_(optimizer)`` before the
    optimizer step, ``step(optimizer)``, then ``update()``.

    Args:
        init_scale: Initial multiplier applied to the loss.
        growth_factor: Factor the scale is multiplied by when ``update``
            decides to grow it (only used when ``dynamic`` is True).
        dynamic: If False the scale stays at ``init_scale`` forever.
    """

    def __init__(self, init_scale=2. ** 16, growth_factor: float = 2.0, dynamic=False):
        self.scale_rate = init_scale
        self.dynamic = dynamic
        self.growth_factor = growth_factor
        # Remembered by step(); update() inspects this optimizer's gradients.
        self.optimizer = None

    def scale(self, loss):
        """Return ``loss`` multiplied by the current scale."""
        return loss * self.scale_rate

    @torch.no_grad()
    def unscale_(self, optimizer):
        """Divide every gradient by the current scale, zeroing infinite entries.

        Parameters whose ``grad`` is ``None`` (frozen or unused in the last
        backward pass) are skipped. NaN entries are not treated specially:
        ``torch.isinf`` is False for NaN, so they go through the division.
        """
        for param_group in optimizer.param_groups:
            for param in param_group['params']:
                if param.grad is None:
                    # Nothing to unscale; torch.isinf(None) would raise.
                    continue
                param.grad = torch.where(
                    torch.isinf(param.grad),
                    torch.zeros_like(param.grad),
                    param.grad / self.scale_rate
                )

    def update(self):
        """Grow the scale when too many gradients are zero.

        Does nothing for a constant scaler or before the first ``step`` call.
        """
        if not self.dynamic or self.optimizer is None:
            return

        # Use the optimizer recorded by step(), not a notebook-global name.
        # presumably count_zero_grads returns the fraction of zero-valued
        # gradient entries — confirm in utils.py
        if count_zero_grads(self.optimizer) > 0.001:
            self.scale_rate *= self.growth_factor

    def step(self, optimizer):
        """Run ``optimizer.step()`` and remember the optimizer for ``update``."""
        optimizer.step()
        self.optimizer = optimizer

In [7]:
# NOTE: hard-coded to the CUDA device with index 1; adjust for your machine.
device = "cuda:1"
model = Unet().to(device)
criterion = nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(
    model.parameters(),
    lr=1e-4
)

# Select the gradient scaler from the config flags; the first enabled flag wins.
if config.use_torch_scaler:
    scaler = torch.cuda.amp.GradScaler()
elif config.use_empty_scaler:
    scaler = GradScalerEmpty()
elif config.use_constant_scaler:
    scaler = GradScalerCustom(dynamic=False)
elif config.use_dynamic_scaler:
    scaler = GradScalerCustom(init_scale=1., dynamic=True)
else:
    # Without this branch `scaler` would be undefined on a fresh kernel, or
    # silently keep the value from an earlier run of this cell.
    raise ValueError(
        "No gradient scaler selected: enable one of use_torch_scaler, "
        "use_empty_scaler, use_constant_scaler or use_dynamic_scaler in Config."
    )

num_epochs = 5
for epoch in range(0, num_epochs):
    train_block1(train_loader, model, criterion, optimizer, epoch,
          num_epochs, device=device, scaler=scaler, config=config, wandb_logger=logger)

Loss: 0.6098Accuracy: 0.9490Epoch acc: 91.3414: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 40/40 [00:21<00:00,  1.88it/s]
Loss: 0.5957Accuracy: 0.9745Epoch acc: 96.4878: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 40/40 [00:18<00:00,  2.15it/s]
Loss: 0.5904Accuracy: 0.9805Epoch acc: 97.9182: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 40/40 [00:18<00:00,  2.13it/s]
Loss: 0.5865Accuracy: 0.9850Epoch acc: 98.3601: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 40/40 [00:18<00:00,  2.11it/s]
Loss: 0.5839Accuracy: 0.9867Epoch acc: 98.5816: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 40/40 

In [8]:
# Inspect the dtype of the gradient stored on the first parameter of the
# optimizer's first parameter group.
optimizer.param_groups[0]['params'][0].grad.dtype

torch.float32