Update Eval
toandaominh1997 committed Jan 3, 2020
1 parent a586ffd commit 8a9ec2f
Showing 2 changed files with 19 additions and 14 deletions.
2 changes: 1 addition & 1 deletion datasets/augmentation.py
@@ -35,7 +35,7 @@ def get_augumentation(phase, width=512, height=512, min_area=0., min_visibility=
albu.HorizontalFlip(p=0.5),
albu.VerticalFlip(p=0.5),
])
-if(phase == 'test'):
+if(phase == 'test' or phase=='valid'):
list_transforms.extend([
albu.Resize(height=height, width=width)
])
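Note: with this change, phase='valid' gets the same resize-only handling as phase='test'. A minimal sketch of building the validation transform under the updated function (assumes the repository's datasets/augmentation.py and the albumentations package are importable; the 512x512 size is illustrative):

from datasets.augmentation import get_augumentation

# After this commit, phase='valid' reaches the same albu.Resize(height, width) step as
# phase='test', so validation images are scaled to the requested input size.
valid_transform = get_augumentation(phase='valid', width=512, height=512)
print(valid_transform)  # expected to be an albumentations Compose that includes Resize(512, 512)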
31 changes: 18 additions & 13 deletions eval.py
@@ -8,6 +8,7 @@

from torch.utils.data import DataLoader
from models.efficientdet import EfficientDet
+from utils import EFFICIENTDET


parser = argparse.ArgumentParser(
@@ -21,7 +22,7 @@
help='Choose model for training')
parser.add_argument('-t', '--threshold', default=0.5,
type=float, help='Visualization threshold')
-parser.add_argument('--weights', default='./weights/checkpoint_efficientdet-d0_154.pth', type=str,
+parser.add_argument('--weights', default='./weights/checkpoint_VOC_efficientdet-d1_20.pth', type=str,
help='Checkpoint state_dict file to resume training from')
parser.add_argument('--batch_size', default=32, type=int,
help='Batch size for training')
@@ -49,28 +50,32 @@ def prepare_device(device):
return device, list_ids


-if(args.weights is not None):
-resume_path = str(args.weights)
-print("Loading checkpoint: {} ...".format(resume_path))
-checkpoint = torch.load(
-args.weights, map_location=lambda storage, loc: storage)
-args.num_class = checkpoint['num_class']
-args.network = checkpoint['network']
-model = EfficientDet(num_classes=args.num_class, network=args.network, is_training=False)
-model.load_state_dict(checkpoint['state_dict'])
-device, device_ids = prepare_device(args.device)

if(args.dataset == 'VOC'):
valid_dataset = VOCDetection(root=args.dataset_root,
-transform=get_augumentation(phase='valid'))
+transform=get_augumentation(phase='valid', width=EFFICIENTDET[args.network]['input_size'], height=EFFICIENTDET[args.network]['input_size']))
elif(args.dataset == 'COCO'):
valid_dataset = COCODetection(root=args.dataset_root,
-transform=get_augumentation(phase='valid'))
+transform=get_augumentation(phase='valid', width=EFFICIENTDET[args.network]['input_size'], height=EFFICIENTDET[args.network]['input_size']))

valid_dataloader = DataLoader(valid_dataset,
batch_size=1,
num_workers=args.num_worker,
shuffle=False,
collate_fn=detection_collate,
pin_memory=False)
+if(args.weights is not None):
+resume_path = str(args.weights)
+print("Loading checkpoint: {} ...".format(resume_path))
+checkpoint = torch.load(
+args.weights, map_location=lambda storage, loc: storage)
+num_class = checkpoint['num_class']
+network = checkpoint['network']
+model = EfficientDet(num_classes=num_class, network=network, is_training=False)
+model.load_state_dict(checkpoint['state_dict'])
+device, device_ids = prepare_device(args.device)


model = model.to(device)
if(len(device_ids) > 1):
model = torch.nn.DataParallel(model, device_ids=device_ids)
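Note: the new get_augumentation calls in eval.py size the validation images from the per-network table imported via "from utils import EFFICIENTDET". A minimal sketch of that lookup (the exact layout of utils.EFFICIENTDET is assumed here; the example size follows standard EfficientDet scaling):

from utils import EFFICIENTDET  # repository's per-network config table (assumed to map name -> dict)

network = 'efficientdet-d1'                       # in eval.py this value normally comes from the checkpoint
input_size = EFFICIENTDET[network]['input_size']  # e.g. 640 for d1 under standard EfficientDet scaling
print(network, input_size)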
