In [1]:
import datetime
import os
import random
from argparse import ArgumentParser
from pathlib import Path
from tqdm.autonotebook import tqdm

import lightning as L
import numpy as np
import torch
from odection import (
    SSD,
    CocoDataset,
    Loss,
    ResNet,
    SSDTransformer,
    collate_fn,
)
from odection.utils import Encoder, coco_classes, generate_dboxes
from torch.optim.lr_scheduler import MultiStepLR
from torch.utils.data import DataLoader
from torch.utils.tensorboard.writer import SummaryWriter

  from tqdm.autonotebook import tqdm


In [13]:
# Seed once, up front, through Lightning's helper so later cells start from a
# known RNG state. The call returns the seed it applied.
SEED = 42
L.fabric.utilities.seed.seed_everything(SEED)

Global seed set to 42


42

In [None]:
# --- Training hyper-parameters ----------------------------------------------
BS = 128  # batch size used for both the train and validation DataLoaders
LR = 3e-4  # learning rate passed to the optimizer
# LR = LR * (BS / 32)
EPOCHS = 100  # upper bound (exclusive) of the epoch range
NUM_WORKERS = 4  # DataLoader worker processes
MULTISTEP = [43, 54]  # MultiStepLR milestones, in epochs
NMS_THRESHOLD = 0.5
MOMENTUM = 0.9
WEIGHT_DECAY = 5e-4

# --- Filesystem layout --------------------------------------------------------
DATA_PATH = Path("data/")
LOG_PATH = Path("logs/tensorboard/ssd")
SAVE_PATH = Path("models")
CHECKPOINT_PATH = SAVE_PATH.joinpath("ssd-checkpoint.pth")

# Create the output directories eagerly; re-running the cell is harmless
# because existing directories are accepted.
for _output_dir in (LOG_PATH, SAVE_PATH):
    _output_dir.mkdir(parents=True, exist_ok=True)

In [4]:
# Timestamp taken when this cell runs; it names the TensorBoard run directory
# under LOG_PATH so each execution logs to its own sub-directory.
str_date = f"{datetime.datetime.now():%Y%m%d_%H%M%S}"
writer = SummaryWriter(log_dir=f"{LOG_PATH}/{str_date}")

In [5]:
# Float32 matmul precision is set to "medium"; "high" is the alternative that
# was previously used here.
# torch.set_float32_matmul_precision("high")
torch.set_float32_matmul_precision("medium")

# Fabric chooses the accelerator and the device count on its own; training
# uses the "dp" strategy with 16-bit mixed precision.
fabric = L.Fabric(
    precision="16-mixed",
    strategy="dp",
    accelerator="auto",
    devices="auto",
    # loggers=SummaryWriter(f"{log_path}/{str_date}")
)
fabric.launch()

Using 16-bit Automatic Mixed Precision (AMP)


In [6]:
# Default boxes shared by both transforms (and, in later cells, by the encoder
# and the loss).
dboxes = generate_dboxes(model="ssd")

# DataLoader keyword arguments. The two splits differ only in `shuffle`:
# training batches are shuffled, validation batches are not.
_loader_defaults = {
    "batch_size": BS,
    "drop_last": False,
    "num_workers": NUM_WORKERS,
    "collate_fn": collate_fn,
}
train_params = {**_loader_defaults, "shuffle": True}
valid_params = {**_loader_defaults, "shuffle": False}

# Second positional argument of SSDTransformer -- presumably the target image
# size; confirm against odection.
_size = (300, 300)

# COCO 2017 splits: "train" with val=False transforms, "val" with val=True.
train_set = CocoDataset(DATA_PATH, 2017, "train", SSDTransformer(dboxes, _size, val=False))
test_set = CocoDataset(DATA_PATH, 2017, "val", SSDTransformer(dboxes, _size, val=True))

# Hand both loaders to Fabric in one call; they come back in the same order.
train_loader, test_loader = fabric.setup_dataloaders(
    DataLoader(train_set, **train_params),
    DataLoader(test_set, **valid_params),
)

loading annotations into memory...
Done (t=6.09s)
creating index...
index created!
loading annotations into memory...
Done (t=0.18s)
creating index...
index created!


In [7]:
# Encoder over the default boxes. The training loop does not use it; only the
# currently disabled `evaluate` call references it.
encoder = Encoder(dboxes)

# SSD on a ResNet backbone, with one class per entry in `coco_classes`.
_n_classes = len(coco_classes)
model = SSD(backbone=ResNet(), num_classes=_n_classes)
optimizer = torch.optim.AdamW(params=model.parameters(), lr=LR)

# Wrap model and optimizer with Fabric first, so the scheduler below is bound
# to the wrapped optimizer.
model, optimizer = fabric.setup(model, optimizer)

# Multiply the learning rate by 0.1 at each epoch listed in MULTISTEP.
scheduler = MultiStepLR(optimizer, milestones=MULTISTEP, gamma=0.1)
criterion = Loss(dboxes, device=fabric.device)

# Resume from the last saved checkpoint when one exists; otherwise training
# starts at epoch 0.
first_epoch = 0
if CHECKPOINT_PATH.is_file():
    # map_location remaps the stored tensors onto the device Fabric selected
    # for this run, so a checkpoint written on one device can still be
    # restored on a machine where that device is not available.
    checkpoint = torch.load(CHECKPOINT_PATH, map_location=fabric.device)
    model.load_state_dict(checkpoint["model_state_dict"])
    optimizer.load_state_dict(checkpoint["optimizer"])
    scheduler.load_state_dict(checkpoint["scheduler"])
    # The stored epoch is the last one that finished (the checkpoint is
    # written at the end of each epoch), so continue with the next one. This
    # is assigned last so a failed restore never advances the start epoch.
    first_epoch = checkpoint["epoch"] + 1

# Number of batches per epoch. It does not change between epochs, so it is
# computed once and used to turn (epoch, batch index) into a global step.
num_iter_per_epoch = len(train_loader)

# Checkpoints are written here first and then renamed over CHECKPOINT_PATH, so
# interrupting the cell in the middle of a save cannot leave a truncated
# checkpoint that would break the next resume.
checkpoint_tmp_path = CHECKPOINT_PATH.with_name(CHECKPOINT_PATH.name + ".tmp")

for epoch in range(first_epoch, EPOCHS):
    model.train()
    progress_bar = tqdm(train_loader)
    for i, (img, _, _, gloc, glabel) in enumerate(progress_bar):
        ploc, plabel = model(img)
        # Fabric is configured for 16-bit mixed precision; cast the
        # predictions to float32 before computing the loss.
        ploc, plabel = ploc.float(), plabel.float()
        # Swap dims 1 and 2 of the ground-truth boxes -- presumably to match
        # the layout of `ploc`; confirm against odection.Loss.
        gloc = gloc.transpose(1, 2).contiguous()
        loss = criterion(ploc, plabel, gloc, glabel)

        # Read the scalar once and reuse it for the progress bar and the
        # TensorBoard log.
        loss_value = loss.item()
        progress_bar.set_description(f"Epoch: {epoch + 1}. Loss: {loss_value:.5f}")
        writer.add_scalar("Train/Loss", loss_value, epoch * num_iter_per_epoch + i)

        # Backward goes through Fabric (rather than loss.backward()) so the
        # configured precision and strategy are applied.
        fabric.backward(loss)
        optimizer.step()
        optimizer.zero_grad()
    # The LR schedule advances once per epoch (MULTISTEP is in epochs).
    scheduler.step()

    # TODO: per-epoch evaluation is disabled. `evaluate` is not imported in
    # this notebook, and the threshold constant is NMS_THRESHOLD.
    # evaluate(
    #     model,
    #     test_loader,
    #     epoch,
    #     writer,
    #     encoder,
    #     NMS_THRESHOLD,
    # )

    checkpoint = {
        "epoch": epoch,
        "model_state_dict": model.state_dict(),
        "optimizer": optimizer.state_dict(),
        "scheduler": scheduler.state_dict(),
    }
    # Write to the temporary file, then swap it into place in one step.
    torch.save(checkpoint, checkpoint_tmp_path)
    os.replace(checkpoint_tmp_path, CHECKPOINT_PATH)

    # Push buffered scalars to disk so an interrupted run keeps its logs.
    writer.flush()



  0%|          | 0/925 [00:00<?, ?it/s]

  0%|          | 0/925 [00:00<?, ?it/s]

  0%|          | 0/925 [00:00<?, ?it/s]

  0%|          | 0/925 [00:00<?, ?it/s]

KeyboardInterrupt: 