In [1]:
"""
use this to run a simple grid search
"""

import datetime
import os
import time

import presets
import torch
import torch.utils.data
import torchvision
import torchvision.models.detection
import torchvision.models.detection.mask_rcnn
import utils
from coco_utils import get_coco, get_coco_kp
from engine import evaluate, train_one_epoch
from group_by_aspect_ratio import create_aspect_ratio_groups, GroupedBatchSampler
from torchvision.transforms import InterpolationMode
from transforms import SimpleCopyPaste

# Report the CUDA environment visible to this kernel before any training starts.
cuda_ok = torch.cuda.is_available()
print('is_available: ', cuda_ok)
print('device_count: ', torch.cuda.device_count())
if cuda_ok:
    # The per-device queries below can raise when no CUDA device is usable,
    # so they only run when CUDA is reported as available.
    print('current_device: ', torch.cuda.current_device())
    # NOTE: this prints the repr of a torch.cuda.device context-manager object,
    # not a device index; the label is kept as-is to match earlier recorded runs.
    print('current_device: ', torch.cuda.device(0))
    print('get_device_name: ', torch.cuda.get_device_name(0))
else:
    print('CUDA is not available; skipping per-device queries.')


is_available:  True
device_count:  1
current_device:  0
current_device:  <torch.cuda.device object at 0x000001816BACC7C0>
get_device_name:  NVIDIA GeForce RTX 3090


In [None]:
def get_dataset(name, image_set, transform, data_path):
    """Build the dataset registered under ``name`` and report its class count.

    Args:
        name: Registry key; ``"coco"`` or ``"coco_kp"``.
        image_set: Split identifier forwarded to the builder as ``image_set``.
        transform: Transform object forwarded to the builder as ``transforms``.
        data_path: Dataset root, passed to the builder as its first argument.

    Returns:
        A ``(dataset, num_classes)`` tuple, where ``num_classes`` is 91 for
        ``"coco"`` and 2 for ``"coco_kp"``.

    Raises:
        KeyError: If ``name`` is not one of the registered keys.
    """
    # Each entry pairs a builder function with the class count of its dataset.
    registry = {
        "coco": (get_coco, 91),
        "coco_kp": (get_coco_kp, 2),
    }
    builder, num_classes = registry[name]
    return builder(data_path, image_set=image_set, transforms=transform), num_classes

In [None]:

# ---------------------------------------------------------------------------
# Configuration: every value a grid-search run might tune lives here.
# The rest of the cell reads these names directly (there is no `args` object).
# ---------------------------------------------------------------------------
# NOTE(review): machine-specific absolute paths; edit for your environment.
output_dir_path = r'C:\Users\endle\Desktop\pytorch-retinanet\outputdir'
data_dir_path = r"C:\Users\endle\Desktop\pytorch-retinanet\data"
dataset_type = 'coco'
# Name of a builder function in torchvision.models.detection.
model_name = "retinanet_resnet50_fpn"
device_type = "cuda"
batch_size = 8
epochs = 3
workers = 1
# "sgd", "sgd_nesterov" or "adamw" (matched case-insensitively below).
opt_name = "sgd"
# NOTE(review): 0.9 equals the momentum value and is unusually large for a
# weight decay; set to None to give all parameters `weight_decay` -- confirm.
norm_weight_decay = 0.9
momentum = 0.9
lr = 0.0005 #0.001
weight_decay = 1e-4
# Not read anywhere in this cell; kept so existing references to it still resolve.
lr_step_size = 8

# Settings the original cell read from an undefined `args` object. The values
# below are the defaults of the torchvision detection reference script
# (TODO confirm they suit this experiment).
start_epoch = 0
print_freq = 20
aspect_ratio_group_factor = 3
data_augmentation = "hflip"
use_copypaste = False
trainable_backbone_layers = None
rpn_score_thresh = None
amp = False
lr_scheduler_name = "multisteplr"
# With epochs = 3 these milestones are never reached, so the LR stays constant.
lr_steps = [16, 22]
lr_gamma = 0.1
test_only = False

# Snapshot of the run settings, stored in every checkpoint under "args".
run_config = {
    "output_dir": output_dir_path,
    "data_path": data_dir_path,
    "dataset": dataset_type,
    "model": model_name,
    "device": device_type,
    "batch_size": batch_size,
    "epochs": epochs,
    "start_epoch": start_epoch,
    "workers": workers,
    "opt": opt_name,
    "norm_weight_decay": norm_weight_decay,
    "momentum": momentum,
    "lr": lr,
    "weight_decay": weight_decay,
    "lr_scheduler": lr_scheduler_name,
    "lr_steps": lr_steps,
    "lr_gamma": lr_gamma,
    "aspect_ratio_group_factor": aspect_ratio_group_factor,
    "data_augmentation": data_augmentation,
    "use_copypaste": use_copypaste,
    "trainable_backbone_layers": trainable_backbone_layers,
    "amp": amp,
}

device = torch.device(device_type)
# NOTE(review): this makes PyTorch raise for ops that lack a deterministic
# implementation -- confirm the chosen model trains under this setting.
torch.use_deterministic_algorithms(True)

# ---------------------------------------------------------------------------
# Data
# ---------------------------------------------------------------------------
print("Loading data")

# Transforms come straight from the `presets` module imported at the top.
# NOTE(review): assumes presets exposes DetectionPresetTrain(data_augmentation=...)
# and DetectionPresetEval() as in the torchvision detection references -- confirm.
train_transform = presets.DetectionPresetTrain(data_augmentation=data_augmentation)
eval_transform = presets.DetectionPresetEval()

dataset, num_classes = get_dataset(dataset_type, "train", train_transform, data_dir_path)
dataset_test, _ = get_dataset(dataset_type, "val", eval_transform, data_dir_path)

print("Creating data loaders")
# The notebook runs in a single process, so no distributed samplers are used.
train_sampler = torch.utils.data.RandomSampler(dataset)
test_sampler = torch.utils.data.SequentialSampler(dataset_test)

if aspect_ratio_group_factor >= 0:
    group_ids = create_aspect_ratio_groups(dataset, k=aspect_ratio_group_factor)
    train_batch_sampler = GroupedBatchSampler(train_sampler, group_ids, batch_size)
else:
    train_batch_sampler = torch.utils.data.BatchSampler(train_sampler, batch_size, drop_last=True)

train_collate_fn = utils.collate_fn
if use_copypaste:
    if data_augmentation != "lsj":
        raise RuntimeError("SimpleCopyPaste algorithm currently only supports the 'lsj' data augmentation policies")

    # NOTE(review): constructor arguments follow the torchvision detection
    # reference; confirm against the local `transforms` module.
    _copypaste = SimpleCopyPaste(blending=True, resize_interpolation=InterpolationMode.BILINEAR)

    def train_collate_fn(batch):
        """Collate a batch with utils.collate_fn, then apply SimpleCopyPaste to it."""
        return _copypaste(*utils.collate_fn(batch))

data_loader = torch.utils.data.DataLoader(
    dataset, batch_sampler=train_batch_sampler, num_workers=workers, collate_fn=train_collate_fn
)

data_loader_test = torch.utils.data.DataLoader(
    dataset_test, batch_size=1, sampler=test_sampler, num_workers=workers, collate_fn=utils.collate_fn
)

# ---------------------------------------------------------------------------
# Model, optimizer, scheduler
# ---------------------------------------------------------------------------
print("Creating model")
kwargs = {"trainable_backbone_layers": trainable_backbone_layers}
if data_augmentation in ["multiscale", "lsj"]:
    kwargs["_skip_resize"] = True
if "rcnn" in model_name and rpn_score_thresh is not None:
    kwargs["rpn_score_thresh"] = rpn_score_thresh

# Resolve the builder by name so `model_name` actually selects the architecture.
model_builder = getattr(torchvision.models.detection, model_name)
model = model_builder(pretrained=True, num_classes=num_classes, **kwargs)
model.to(device)

# No DistributedDataParallel wrapper in a notebook; the name is kept for checkpointing.
model_without_ddp = model

if norm_weight_decay is None:
    parameters = [p for p in model.parameters() if p.requires_grad]
else:
    # Pair each parameter group with its weight decay; empty groups are dropped.
    param_groups = torchvision.ops._utils.split_normalization_params(model)
    wd_groups = [norm_weight_decay, weight_decay]
    parameters = [{"params": p, "weight_decay": w} for p, w in zip(param_groups, wd_groups) if p]

opt_key = opt_name.lower()
if opt_key.startswith("sgd"):
    optimizer = torch.optim.SGD(
        parameters,
        lr=lr,
        momentum=momentum,
        weight_decay=weight_decay,
        nesterov="nesterov" in opt_key,
    )
elif opt_key == "adamw":
    optimizer = torch.optim.AdamW(parameters, lr=lr, weight_decay=weight_decay)
else:
    raise RuntimeError(f"Invalid optimizer {opt_name}. Only SGD and AdamW are supported.")

scaler = torch.cuda.amp.GradScaler() if amp else None

scheduler_key = lr_scheduler_name.lower()
if scheduler_key == "multisteplr":
    lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=lr_steps, gamma=lr_gamma)
elif scheduler_key == "cosineannealinglr":
    lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)
else:
    raise RuntimeError(
        f"Invalid lr scheduler '{scheduler_key}'. Only MultiStepLR and CosineAnnealingLR are supported."
    )

# ---------------------------------------------------------------------------
# Evaluate only, or train + evaluate after every epoch
# ---------------------------------------------------------------------------
if test_only:
    # A top-level cell cannot `return`, so evaluation-only mode is a branch.
    torch.backends.cudnn.deterministic = True
    evaluate(model, data_loader_test, device=device)
else:
    print("Start training")
    if output_dir_path:
        # Create the checkpoint directory up front so saving cannot fail on a missing path.
        os.makedirs(output_dir_path, exist_ok=True)

    start_time = time.time()
    for epoch in range(start_epoch, epochs):
        train_one_epoch(model, optimizer, data_loader, device, epoch, print_freq, scaler)

        lr_scheduler.step()
        if output_dir_path:
            checkpoint = {
                "model": model_without_ddp.state_dict(),
                "optimizer": optimizer.state_dict(),
                "lr_scheduler": lr_scheduler.state_dict(),
                "args": run_config,
                "epoch": epoch,
            }
            if amp:
                checkpoint["scaler"] = scaler.state_dict()
            # Written twice: once per epoch, and once under a fixed "latest" name.
            utils.save_on_master(checkpoint, os.path.join(output_dir_path, f"model_{epoch}.pth"))
            utils.save_on_master(checkpoint, os.path.join(output_dir_path, "checkpoint.pth"))

        # evaluate after every epoch
        evaluate(model, data_loader_test, device=device)

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print(f"Training time {total_time_str}")