In [None]:
import datetime
import os

import torch
import torch.optim as optim
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm

from models import Yolov1
from loss import YoloLoss
from dataset import CreateDataset, Compose
from utils import (intersection_over_union, non_max_suppression, mean_average_precision, 
                    plot_image, get_bboxes, convert_cellboxes, cellboxes_to_boxes, 
                    save_checkpoint, load_checkpoint)

In [None]:
# Human-readable class names passed to plot_image for drawing predictions.
# NOTE(review): presumably ordered by the class ids used in the annotation file;
# the list length should equal the number of classes C — confirm.
labels = [
    'pipe',
    'corner',
    'flange',
    'anode',
]

In [None]:
# Single-channel normalisation statistics.
# NOTE(review): presumably measured on this dataset's grayscale images — confirm.
mean, std = 0.4732, 0.1271

# Preprocessing: resize to 448x448, collapse to one grayscale channel, convert
# to a tensor, then standardise with (mean, std). `Compose` is the project's own
# wrapper imported from `dataset` (not torchvision's).
transform = Compose(
    [
        transforms.Resize((448, 448)),
        transforms.Grayscale(),
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ]
)

In [None]:
# Reproducibility: seed torch's global RNG (e.g. weight init, DataLoader shuffling).
seed = 123
torch.manual_seed(seed)

# `is_available` must be *called*: the bare function object is always truthy,
# which previously selected "cuda" even on CPU-only machines.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Hyperparameters
IMG_SIZE = [1080, 1920]  # passed to CreateDataset as img_size — presumably (height, width) of raw images; confirm
LEARNING_RATE = 2e-5
BATCH_SIZE = 16
WEIGHT_DECAY = 0
EPOCHS = 10
NUM_WORKERS = 2
PIN_MEMORY = True
LOAD_MODEL = False  # when True, restore model/optimizer state from LOAD_MODEL_FILE
LOAD_MODEL_FILE = "model/all/default.pth"
IMG_DIR = "datasets/images"

FILE_DIR = "datasets/info_all.json"  # annotation file read by CreateDataset (box_format="coco")
S = 7  # grid split size (split_size)
B = 2  # boxes predicted per grid cell (num_boxes)
C = 4  # number of classes (num_classes); should equal len(labels)

In [None]:
def train_fn(train_loader, model, optimizer, loss_fn, step, epoch, num_epochs):
    """Run one training epoch and return the updated global step.

    Uses the notebook-level ``DEVICE``, ``writer`` (TensorBoard SummaryWriter)
    and ``scheduler`` (ReduceLROnPlateau) defined in other cells.

    Args:
        train_loader: iterable yielding ``(x, y)`` batches.
        model: network being trained; put into training mode here.
        optimizer: optimizer updating ``model``'s parameters.
        loss_fn: callable ``loss_fn(predictions, targets)`` returning a scalar loss tensor.
        step: global step at which the first batch loss is logged.
        epoch: current epoch index (progress-bar display only).
        num_epochs: total number of epochs (progress-bar display only).

    Returns:
        The global step after the last processed batch.
    """
    # In case an evaluation helper (e.g. get_bboxes) left the model in eval mode.
    model.train()

    loop = tqdm(train_loader, leave=True)
    batch_losses = []

    for x, y in loop:
        x, y = x.to(DEVICE), y.to(DEVICE)
        out = model(x)
        loss = loss_fn(out, y)
        # Read the scalar once: .item() forces a device sync, and logging a
        # plain float avoids handing a graph-attached tensor to the writer.
        loss_value = loss.item()
        batch_losses.append(loss_value)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        writer.add_scalar('Training Loss', loss_value, global_step=step)
        step += 1

        # update progress bar
        loop.set_description(f"Epoch [{epoch}/{num_epochs}]")
        loop.set_postfix(loss=loss_value)

    if not batch_losses:
        # e.g. drop_last=True with fewer samples than one batch: nothing was trained,
        # so there is no mean loss to report or feed to the scheduler.
        print("No batches in train_loader; skipping scheduler step")
        return step

    epoch_loss = sum(batch_losses) / len(batch_losses)
    print(f"Mean loss was {epoch_loss}")
    # ReduceLROnPlateau monitors the mean training loss of the epoch.
    scheduler.step(epoch_loss)

    return step

In [None]:
# Model, optimizer, loss and LR schedule. The scheduler cuts the LR by 10x after
# 2 scheduler steps without improvement of the value passed to scheduler.step.
model = Yolov1(config="test2", in_channels=1, split_size=S, num_boxes=B, num_classes=C).to(DEVICE)
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)
loss_fn = YoloLoss(split_size=S, num_boxes=B, num_classes=C)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=0.1, patience=2, verbose=True)

# One TensorBoard run directory per launch, named by the start timestamp.
writer = SummaryWriter('logs/all/' + datetime.datetime.now().strftime("%Y%m%d-%H%M%S"))

if LOAD_MODEL:
    # map_location lets a checkpoint saved on GPU be restored on a CPU-only machine.
    load_checkpoint(torch.load(LOAD_MODEL_FILE, map_location=DEVICE), model, optimizer)

dataset = CreateDataset(file_dir=FILE_DIR, img_dir=IMG_DIR, img_size=IMG_SIZE, 
                        split_size=S, num_boxes=B, num_classes=C, 
                        box_format="coco", bb_ratio=False, offset=1,
                        transform=transform)

# Reproducible 90/10 train/validation split (dedicated generator seeded with 42).
n_train = int(0.9 * len(dataset))
n_val = len(dataset) - n_train
train_dataset, val_dataset = torch.utils.data.random_split(dataset, [n_train, n_val], generator=torch.Generator().manual_seed(42))
train_loader = DataLoader(dataset=train_dataset, batch_size=BATCH_SIZE, num_workers=NUM_WORKERS, pin_memory=PIN_MEMORY, shuffle=True, drop_last=True)
# NOTE(review): drop_last=True means a validation split smaller than BATCH_SIZE
# yields zero batches; shuffle=True gives a different sample to visualise each run.
val_loader = DataLoader(dataset=val_dataset, batch_size=BATCH_SIZE, num_workers=NUM_WORKERS, pin_memory=PIN_MEMORY, shuffle=True, drop_last=True)

In [None]:
# Main training loop: one call to train_fn per epoch. (Previously an extra
# `for x, y in train_loader:` wrapper re-ran the whole epoch once per batch.)
step = 0
for epoch in range(EPOCHS):
    # Training-set mAP with the weights as they are *before* this epoch's update,
    # so it is logged at the step it was measured at.
    pred_boxes, target_boxes = get_bboxes(train_loader, model, iou_threshold=0.5, threshold=0.7, S=S, C=C)
    mean_avg_prec = mean_average_precision(pred_boxes, target_boxes, iou_threshold=0.5, box_format="midpoint")
    print(f"Train mAP: {mean_avg_prec}")
    writer.add_scalar('mean average precision', mean_avg_prec, global_step=step)

    step = train_fn(train_loader, model, optimizer, loss_fn, step, epoch, EPOCHS)

In [None]:
# Persist model and optimizer state.
# NOTE(review): keys presumably match what load_checkpoint reads — confirm.
print("saving model")
checkpoint_path = "model/all/default.pth"
# torch.save does not create missing parent directories; do it up front.
os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)
save_model = {"state_dict": model.state_dict(), "optimizer": optimizer.state_dict()}
torch.save(save_model, checkpoint_path)

In [None]:
# Plot NMS-filtered predictions for every image of one validation batch.
# Inference runs in eval mode without autograd; the previous mode is restored after.
# NOTE(review): x holds normalised tensors, so plotted intensities are not the raw image.
was_training = model.training
model.eval()
with torch.no_grad():
    for x, y in val_loader:
        x = x.to(DEVICE)
        # Single forward pass for the whole batch (was recomputed per image).
        batch_bboxes = cellboxes_to_boxes(model(x), S=S, C=C)
        for idx in range(x.shape[0]):
            bboxes = non_max_suppression(batch_bboxes[idx], iou_threshold=0.5, threshold=0.4, box_format="midpoint")
            plot_image(x[idx].permute(1,2,0).to("cpu"), bboxes, labels)
        break
model.train(was_training)