# test.py — standalone evaluation script (73 lines, 2.35 KB)
# NOTE(review): web-scrape residue (page chrome and gutter line numbers) removed
# so the file is valid Python again.
import hydra
import torch
import torch.distributed as dist
from hydra.utils import instantiate
from timm.utils import accuracy
from data import create_dataloader
from logger import MetricLogger, SmoothedValue
from utils import init_distributed_mode
# Let cuDNN auto-tune convolution algorithms; beneficial when input shapes are
# fixed across batches (typical for a fixed-size validation pipeline).
torch.backends.cudnn.benchmark = True
@torch.no_grad()
def evaluate(dataloader, model, n_iter):
    """Run one full validation pass and print global top-1 accuracy and loss.

    Args:
        dataloader: iterable yielding ``(images, labels)`` batches.
        model: classification model already on GPU (possibly DDP-wrapped);
            outputs per-class logits.
        n_iter: expected number of iterations, passed to the logger for
            progress display.

    Side effects: puts ``model`` in eval mode, synchronizes metric meters
    across distributed processes, and prints the final summary line.
    """
    criterion = torch.nn.CrossEntropyLoss().cuda()
    model.eval()
    eval_logger = MetricLogger(delimiter=' ')
    eval_logger.add_meter('eval_loss', SmoothedValue(window_size=1, fmt='{global_avg:.4f}'))
    eval_logger.add_meter('eval_acc1', SmoothedValue(window_size=1, fmt='{value:.3f}'))
    header = 'Val:'
    for data in eval_logger.log_every(dataloader, n_iter, 10, header):
        images = data[0].cuda(non_blocking=True)
        labels = data[1].cuda(non_blocking=True)
        # Mixed-precision forward pass; loss is computed inside autocast too.
        with torch.cuda.amp.autocast():
            outputs = model(images)
            eval_loss = criterion(outputs, labels)
        # Top-5 accuracy is computed but deliberately discarded; only top-1 is reported.
        acc1, _ = accuracy(outputs, labels, topk=(1, 5))
        torch.cuda.synchronize()
        batch_size = images.size(0)
        # BUG FIX: weight the loss meter by batch size, matching the acc1 meter
        # below. Without n=batch_size the global average over-weights the final
        # (possibly smaller) batch.
        eval_logger.update(eval_loss=eval_loss.item(), n=batch_size)
        eval_logger.update(eval_acc1=acc1.item(), n=batch_size)
    # Aggregate meters across all distributed ranks before reporting.
    eval_logger.synchronize_between_processes()
    print('* Acc@1: {top1.global_avg:.3f} Eval loss: {losses.global_avg:.3f}'.format(top1=eval_logger.eval_acc1, losses=eval_logger.eval_loss))
@hydra.main(config_path='./configs', config_name='test')
def main(cfg):
    """Entry point: evaluate a trained checkpoint on the validation set.

    Hydra composes ``cfg`` from ./configs/test.yaml; it supplies the
    ``dist``, ``data``, ``model`` and ``ckpt`` settings read below.
    """
    # Initialize torch.distributed using MPI
    init_distributed_mode(cfg.dist)
    # Create Dataloader
    world_size = dist.get_world_size()
    # Effective batch size across all ranks, used to derive iterations per pass.
    total_batch_size = cfg.data.loader.batch_size * world_size
    valloader = create_dataloader(cfg.data, is_training=False)
    # Validation iterations per rank (floor division drops any partial batch).
    n_val_iter = cfg.data.baseinfo.val_imgs // total_batch_size
    # model
    model = instantiate(cfg.model.arch, num_classes=cfg.data.baseinfo.num_classes)
    print(f'Model[{cfg.model.arch.model_name}] was created')
    # Move to GPU before wrapping in DDP so parameters live on the right device.
    model.cuda()
    model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[cfg.dist.local_rank])
    # Keep a handle to the bare module so state_dict keys match the checkpoint
    # (DDP prefixes parameter names with 'module.').
    model_without_ddp = model.module
    # Load trained weights
    # map_location='cpu' avoids a transient GPU copy on every rank during load.
    checkpoint = torch.load(cfg.ckpt, map_location='cpu')
    model_without_ddp.load_state_dict(checkpoint['model'])
    print(f'Checkpoint was loaded from {cfg.ckpt}\n')
    evaluate(valloader, model, n_val_iter)
if __name__ == '__main__':
    # Hydra parses CLI overrides and composes the config before invoking main.
    main()