import itertools
import logging
import os
import sys

import torch
from torch.utils.data import DataLoader, ConcatDataset
from torch.optim.lr_scheduler import CosineAnnealingLR, MultiStepLR
from livelossplot import PlotLosses

# Model builders, datasets, transforms, and loss come from the pytorch-ssd
# package. `args` and `parser` are assumed to be defined by the argparse
# setup earlier in this script.
from vision.utils.misc import Timer, freeze_net_layers, store_labels
from vision.ssd.ssd import MatchPrior
from vision.ssd.vgg_ssd import create_vgg_ssd
from vision.ssd.mobilenetv1_ssd import create_mobilenetv1_ssd
from vision.ssd.mobilenetv1_ssd_lite import create_mobilenetv1_ssd_lite
from vision.ssd.mobilenet_v2_ssd_lite import create_mobilenetv2_ssd_lite
from vision.ssd.squeezenet_ssd_lite import create_squeezenet_ssd_lite
from vision.datasets.voc_dataset import VOCDataset
from vision.datasets.open_images import OpenImagesDataset
from vision.nn.multibox_loss import MultiboxLoss
from vision.ssd.config import vgg_ssd_config, mobilenetv1_ssd_config, squeezenet_ssd_config
from vision.ssd.data_preprocessing import TrainAugmentation, TestTransform

DEVICE = torch.device("cuda:0" if torch.cuda.is_available() and args.use_cuda else "cpu")

if args.use_cuda and torch.cuda.is_available():
    torch.backends.cudnn.benchmark = True
    logging.info("Use Cuda.")


def train(loader, net, criterion, optimizer, device, debug_steps=100, epoch=-1):
    net.train(True)
    running_loss = 0.0
    running_regression_loss = 0.0
    running_classification_loss = 0.0
    # Initialize so the function still returns a value when the loader has
    # fewer than debug_steps batches.
    avg_loss = 0.0
    for i, data in enumerate(loader):
        images, boxes, labels = data
        images = images.to(device)
        boxes = boxes.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        confidence, locations = net(images)
        regression_loss, classification_loss = criterion(confidence, locations, labels, boxes)  # TODO CHANGE BOXES
        loss = regression_loss + classification_loss
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        running_regression_loss += regression_loss.item()
        running_classification_loss += classification_loss.item()
        if i and i % debug_steps == 0:
            avg_loss = running_loss / debug_steps
            avg_reg_loss = running_regression_loss / debug_steps
            avg_clf_loss = running_classification_loss / debug_steps
            # logging.info(
            #     f"Epoch: {epoch}, Step: {i}, " +
            #     f"Average Loss: {avg_loss:.4f}, " +
            #     f"Average Regression Loss {avg_reg_loss:.4f}, " +
            #     f"Average Classification Loss: {avg_clf_loss:.4f}"
            # )
            running_loss = 0.0
            running_regression_loss = 0.0
            running_classification_loss = 0.0
    return avg_loss


def test(loader, net, criterion, device):
    net.eval()
    running_loss = 0.0
    running_regression_loss = 0.0
    running_classification_loss = 0.0
    num = 0
    for data in loader:
        images, boxes, labels = data
        images = images.to(device)
        boxes = boxes.to(device)
        labels = labels.to(device)
        num += 1

        with torch.no_grad():
            confidence, locations = net(images)
            regression_loss, classification_loss = criterion(confidence, locations, labels, boxes)
            loss = regression_loss + classification_loss

        running_loss += loss.item()
        running_regression_loss += regression_loss.item()
        running_classification_loss += classification_loss.item()
    return running_loss / num, running_regression_loss / num, running_classification_loss / num
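
# A minimal smoke-test sketch for train()/test() above, not called by default.
# Everything in it (TinySSD, toy_criterion, the tensor shapes) is made up for
# illustration; it only mirrors the interface the real loops rely on: a loader
# yielding (images, boxes, labels), a net returning (confidence, locations),
# and a criterion returning (regression_loss, classification_loss).
def _smoke_test_train_loop():
    import torch.nn as nn
    import torch.nn.functional as F
    from torch.utils.data import TensorDataset

    num_priors, num_classes = 8, 3
    images = torch.randn(4, 3, 32, 32)
    boxes = torch.randn(4, num_priors, 4)                    # fake box targets
    labels = torch.randint(0, num_classes, (4, num_priors))  # fake class targets
    loader = DataLoader(TensorDataset(images, boxes, labels), batch_size=2)

    class TinySSD(nn.Module):
        def __init__(self):
            super().__init__()
            self.backbone = nn.Conv2d(3, 4, 3, padding=1)
            self.conf = nn.Linear(4 * 32 * 32, num_priors * num_classes)
            self.loc = nn.Linear(4 * 32 * 32, num_priors * 4)

        def forward(self, x):
            f = self.backbone(x).flatten(1)
            return (self.conf(f).view(-1, num_priors, num_classes),
                    self.loc(f).view(-1, num_priors, 4))

    def toy_criterion(confidence, locations, labels, boxes):
        # Stand-in for MultiboxLoss: one localization term, one classification term.
        reg = F.smooth_l1_loss(locations, boxes)
        clf = F.cross_entropy(confidence.reshape(-1, num_classes), labels.reshape(-1))
        return reg, clf

    net = TinySSD()
    optimizer = torch.optim.SGD(net.parameters(), lr=0.01)
    train(loader, net, toy_criterion, optimizer, device="cpu", debug_steps=1)
    print(test(loader, net, toy_criterion, device="cpu"))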
def mymain():
    timer = Timer()
    liveloss = PlotLosses()
    logs = {}
    logging.info(args)
    if args.net == 'vgg16-ssd':
        create_net = create_vgg_ssd
        config = vgg_ssd_config
    elif args.net == 'mb1-ssd':
        create_net = create_mobilenetv1_ssd
        config = mobilenetv1_ssd_config
    elif args.net == 'mb1-ssd-lite':
        create_net = create_mobilenetv1_ssd_lite
        config = mobilenetv1_ssd_config
    elif args.net == 'sq-ssd-lite':
        create_net = create_squeezenet_ssd_lite
        config = squeezenet_ssd_config
    elif args.net == 'mb2-ssd-lite':
        create_net = lambda num: create_mobilenetv2_ssd_lite(num, width_mult=args.mb2_width_mult)
        config = mobilenetv1_ssd_config
    else:
        logging.fatal("The net type is wrong.")
        parser.print_help(sys.stderr)
        sys.exit(1)

    train_transform = TrainAugmentation(config.image_size, config.image_mean, config.image_std)
    target_transform = MatchPrior(config.priors, config.center_variance, config.size_variance, 0.5)
    test_transform = TestTransform(config.image_size, config.image_mean, config.image_std)

    logging.info("Prepare training datasets.")
    datasets = []
    for dataset_path in args.datasets:
        if args.dataset_type == 'voc':
            dataset = VOCDataset(dataset_path, transform=train_transform,
                                 target_transform=target_transform)
            label_file = os.path.join(args.checkpoint_folder, "voc-model-labels.txt")
            store_labels(label_file, dataset.class_names)
            num_classes = len(dataset.class_names)
        elif args.dataset_type == 'open_images':
            dataset = OpenImagesDataset(dataset_path, transform=train_transform,
                                        target_transform=target_transform,
                                        dataset_type="train", balance_data=args.balance_data)
            label_file = os.path.join(args.checkpoint_folder, "open-images-model-labels.txt")
            store_labels(label_file, dataset.class_names)
            logging.info(dataset)
            num_classes = len(dataset.class_names)
        else:
            raise ValueError(f"Dataset type {args.dataset_type} is not supported.")
        datasets.append(dataset)
    logging.info(f"Stored labels into file {label_file}.")
    train_dataset = ConcatDataset(datasets)
    logging.info("Train dataset size: {}".format(len(train_dataset)))
    train_loader = DataLoader(train_dataset, args.batch_size,
                              num_workers=args.num_workers, shuffle=True)

    logging.info("Prepare Validation datasets.")
    if args.dataset_type == "voc":
        val_dataset = VOCDataset(args.validation_dataset, transform=test_transform,
                                 target_transform=target_transform, is_test=True)
    elif args.dataset_type == 'open_images':
        # Use the validation dataset path, not the last training path.
        val_dataset = OpenImagesDataset(args.validation_dataset, transform=test_transform,
                                        target_transform=target_transform, dataset_type="test")
        logging.info(val_dataset)
    logging.info("Validation dataset size: {}".format(len(val_dataset)))
    val_loader = DataLoader(val_dataset, args.batch_size,
                            num_workers=args.num_workers, shuffle=False)

    logging.info("Build network.")
    net = create_net(num_classes)
    min_loss = -10000.0
    last_epoch = -1

    base_net_lr = args.base_net_lr if args.base_net_lr is not None else args.lr
    extra_layers_lr = args.extra_layers_lr if args.extra_layers_lr is not None else args.lr
    if args.freeze_base_net:
        logging.info("Freeze base net.")
        freeze_net_layers(net.base_net)
        params = [
            {'params': itertools.chain(
                net.source_layer_add_ons.parameters(),
                net.extras.parameters()
            ), 'lr': extra_layers_lr},
            {'params': itertools.chain(
                net.regression_headers.parameters(),
                net.classification_headers.parameters()
            )}
        ]
    elif args.freeze_net:
        freeze_net_layers(net.base_net)
        freeze_net_layers(net.source_layer_add_ons)
        freeze_net_layers(net.extras)
        params = itertools.chain(net.regression_headers.parameters(),
                                 net.classification_headers.parameters())
        logging.info("Freeze all the layers except prediction heads.")
    else:
        params = [
            {'params': net.base_net.parameters(), 'lr': base_net_lr},
            {'params': itertools.chain(
                net.source_layer_add_ons.parameters(),
                net.extras.parameters()
            ), 'lr': extra_layers_lr},
            {'params': itertools.chain(
                net.regression_headers.parameters(),
                net.classification_headers.parameters()
            )}
        ]

    timer.start("Load Model")
    if args.resume:
        logging.info(f"Resume from the model {args.resume}")
        net.load(args.resume)
    elif args.base_net:
        logging.info(f"Init from base net {args.base_net}")
        net.init_from_base_net(args.base_net)
    elif args.pretrained_ssd:
        logging.info(f"Init from pretrained ssd {args.pretrained_ssd}")
        net.init_from_pretrained_ssd(args.pretrained_ssd)
    logging.info(f'Took {timer.end("Load Model"):.2f} seconds to load the model.')

    net.to(DEVICE)

    criterion = MultiboxLoss(config.priors, iou_threshold=0.5, neg_pos_ratio=3,
                             center_variance=0.1, size_variance=0.2, device=DEVICE)
    optimizer = torch.optim.SGD(params, lr=args.lr, momentum=args.momentum,
                                weight_decay=args.weight_decay)
    logging.info(f"Learning rate: {args.lr}, Base net learning rate: {base_net_lr}, " +
                 f"Extra Layers learning rate: {extra_layers_lr}.")

    if args.scheduler == 'multi-step':
        logging.info("Uses MultiStepLR scheduler.")
        milestones = [int(v.strip()) for v in args.milestones.split(",")]
        scheduler = MultiStepLR(optimizer, milestones=milestones,
                                gamma=0.1, last_epoch=last_epoch)
    elif args.scheduler == 'cosine':
        logging.info("Uses CosineAnnealingLR scheduler.")
        scheduler = CosineAnnealingLR(optimizer, args.t_max, last_epoch=last_epoch)
    else:
        logging.fatal(f"Unsupported Scheduler: {args.scheduler}.")
        parser.print_help(sys.stderr)
        sys.exit(1)

    logging.info(f"Start training from epoch {last_epoch + 1}.")
    for epoch in range(last_epoch + 1, args.num_epochs):
        avg_loss = train(train_loader, net, criterion, optimizer,
                         device=DEVICE, debug_steps=args.debug_steps, epoch=epoch)
        # Step the scheduler after the optimizer updates, as PyTorch >= 1.1 expects.
        scheduler.step()
        logs['train loss'] = avg_loss

        if epoch % args.validation_epochs == 0 or epoch == args.num_epochs - 1:
            val_loss, val_regression_loss, val_classification_loss = test(val_loader, net, criterion, DEVICE)
            # logging.info(
            #     f"Epoch: {epoch}, " +
            #     f"Validation Loss: {val_loss:.4f}, " +
            #     f"Validation Regression Loss {val_regression_loss:.4f}, " +
            #     f"Validation Classification Loss: {val_classification_loss:.4f}"
            # )
            model_path = os.path.join(args.checkpoint_folder,
                                      f"{args.net}-Epoch-{epoch}-Loss-{val_loss}.pth")
            net.save(model_path)
            logging.info(f"Saved model {model_path}")

        liveloss.update(logs)
        liveloss.draw()
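
# mymain() above hands torch.optim.SGD a list of parameter groups so the base
# net, the extra layers, and the prediction heads can train at different
# learning rates. A minimal sketch of that mechanism on a made-up two-layer
# model (the layer split and rates below are illustrative only), not called
# by default:
def _param_group_lr_sketch():
    import torch.nn as nn

    model = nn.Sequential(nn.Linear(4, 8), nn.Linear(8, 2))
    optimizer = torch.optim.SGD([
        {'params': model[0].parameters(), 'lr': 0.001},  # e.g. a slow backbone rate
        {'params': model[1].parameters()},               # falls back to the default lr
    ], lr=0.01, momentum=0.9)
    for group in optimizer.param_groups:
        print(group['lr'])  # prints 0.001, then 0.01
    # Schedulers such as MultiStepLR scale every group's lr by gamma at each
    # milestone, so the ratios between groups are preserved.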
if __name__ == '__main__':
    mymain()
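
# Example invocation (a sketch; it assumes the argparse section elided above
# defines the usual pytorch-ssd flags, one per args.* attribute referenced in
# this script, and all paths and values below are placeholders):
#
#   python train_ssd.py --dataset_type voc --datasets ~/data/VOC2007 \
#       --validation_dataset ~/data/VOC2007 --net mb1-ssd \
#       --base_net models/mobilenet_v1_base.pth \
#       --batch_size 24 --num_epochs 100 --scheduler multi-step \
#       --milestones "80,100" --lr 0.01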