-
Notifications
You must be signed in to change notification settings - Fork 37
/
test.py
129 lines (106 loc) · 3.89 KB
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
# System libs
import os
import time
# Numerical libs
import torch
import torch.nn as nn
from torch import optim
from torch.utils.data import DataLoader
import numpy as np
# Our libs
from data.sist_line import SISTLine
import data.transforms as tf
from models.lsd_test import LSDTestModule
from utils import AverageMeter, graph2line, draw_lines, draw_jucntions
# tensorboard
from tensorboardX import SummaryWriter
import torchvision.utils as vutils
import fire
import cv2
class LSD:
    """Command-line front end (exposed through ``fire``) for the LSD test model.

    Builds an ``LSDTestModule``, optionally restores its weights from a
    training checkpoint under ``ckpt/<exp_name>/``, and provides a ``test``
    command that runs the model on one image and displays the predicted
    junctions and lines in an OpenCV window.
    """

    def __init__(
        self,
        # exp params
        exp_name="u50_block",
        # arch params
        backbone="resnet50",
        backbone_kwargs=None,
        dim_embedding=256,
        feature_spatial_scale=0.25,
        max_junctions=512,
        junction_pooling_threshold=0.2,
        junc_pooling_size=15,
        block_inference_size=64,
        # data params
        img_size=416,
        gpus=(0,),
        resume_epoch="latest",
        # vis params
        vis_junc_th=0.3,
        vis_line_th=0.3,
    ):
        """Build the model, create log/checkpoint dirs and optionally resume.

        Args:
            exp_name: experiment name. It selects ``log/<exp_name>`` (the
                TensorBoard writer directory) and ``ckpt/<exp_name>`` (the
                checkpoint directory). Both directories are created if missing.
            backbone, backbone_kwargs, dim_embedding, feature_spatial_scale,
            max_junctions, junction_pooling_threshold: forwarded unchanged to
                ``LSDTestModule``. ``backbone_kwargs=None`` means ``{}``.
            junc_pooling_size: forwarded to ``LSDTestModule`` as
                ``junction_pooling_size``.
            block_inference_size: stored as ``self.block_size``. It is not
                used by the methods of this class.
            img_size: input images are resized to ``img_size x img_size``.
            gpus: device ids written to ``CUDA_VISIBLE_DEVICES``. An empty
                sequence (or ``None``) runs on CPU. A single int or a
                comma-separated string is also accepted, because ``fire``
                may deliver the flag in either form.
            resume_epoch: if truthy and
                ``ckpt/<exp_name>/train_states_<resume_epoch>.pth`` exists,
                that file's ``state_dict`` is loaded into the model, and its
                entries are merged into ``self.states``.
            vis_junc_th, vis_line_th: stored on the instance. They are not
                used by ``test`` as written.
        """
        gpus = self._normalize_gpus(gpus)
        # CUDA reads this variable when it initialises, so set it before any
        # tensor or module is moved to the GPU.
        os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(str(c) for c in gpus)
        self.is_cuda = bool(gpus)

        if backbone_kwargs is None:
            backbone_kwargs = {}
        self.model = LSDTestModule(
            backbone=backbone,
            dim_embedding=dim_embedding,
            backbone_kwargs=backbone_kwargs,
            junction_pooling_threshold=junction_pooling_threshold,
            max_junctions=max_junctions,
            feature_spatial_scale=feature_spatial_scale,
            junction_pooling_size=junc_pooling_size,
        )

        self.exp_name = exp_name
        log_dir = os.path.join("log", exp_name)
        os.makedirs(log_dir, exist_ok=True)
        os.makedirs(os.path.join("ckpt", exp_name), exist_ok=True)
        self.writer = SummaryWriter(log_dir=log_dir)

        # checkpoints
        self.states = dict(
            last_epoch=-1,
            elapsed_time=0,
            state_dict=None,
        )
        ckpt_path = os.path.join("ckpt", exp_name, f"train_states_{resume_epoch}.pth")
        if resume_epoch and os.path.isfile(ckpt_path):
            # map_location="cpu": a checkpoint saved from GPU tensors must still
            # load on a CPU-only run. test() moves the model to CUDA itself
            # when self.is_cuda is set.
            states = torch.load(ckpt_path, map_location="cpu")
            print(f"resume traning from epoch {states['last_epoch']}")
            self.model.load_state_dict(states["state_dict"])
            self.states.update(states)

        self.vis_junc_th = vis_junc_th
        self.vis_line_th = vis_line_th
        self.block_size = block_inference_size
        self.max_junctions = max_junctions
        self.img_size = img_size

    @staticmethod
    def _normalize_gpus(gpus):
        """Return *gpus* as a tuple of device ids (empty tuple means CPU).

        Accepts ``None``, a single int, a comma-separated string, or any
        iterable of ids. ``fire`` parses ``--gpus=0`` as a bare int, which the
        original ``for c in gpus`` could not iterate.
        """
        if gpus is None:
            return ()
        if isinstance(gpus, int):
            return (gpus,)
        if isinstance(gpus, str):
            return tuple(part.strip() for part in gpus.split(",") if part.strip())
        return tuple(gpus)

    def end(self):
        """Close the TensorBoard writer and return a completion message."""
        self.writer.close()
        return "command queue finished."

    def test(self, path_to_image, wait_ms=0):
        """Run the model on one image and display the predicted lines.

        Processing steps:

        1. Read the image with OpenCV and resize it to
           ``img_size x img_size``.
        2. Reorder channels from OpenCV's BGR to RGB and feed the image to
           the model as a ``1 x C x H x W`` float tensor.
        3. Draw the predicted junctions with ``draw_jucntions``.
        4. Draw the lines that ``graph2line`` decodes from the predicted
           adjacency matrix.
        5. Show the result in an OpenCV window named ``"result"``.

        Args:
            path_to_image: path of the image file to run on.
            wait_ms: passed to ``cv2.waitKey`` after ``imshow``. ``0`` (the
                default) keeps the window open until a key is pressed.

        Raises:
            FileNotFoundError: if OpenCV cannot read ``path_to_image``.
        """
        print(f"test for image: {path_to_image}", flush=True)
        model = self.model.cuda() if self.is_cuda else self.model
        model.eval()

        img = cv2.imread(path_to_image)
        if img is None:
            # cv2.imread signals failure by returning None rather than raising.
            raise FileNotFoundError(f"could not read image: {path_to_image}")
        img = cv2.resize(img, (self.img_size, self.img_size))
        # BGR -> RGB, then HWC -> 1xCxHxW float tensor.
        img_rgb = img[..., [2, 1, 0]]
        img = torch.from_numpy(img_rgb).float().permute(2, 0, 1).unsqueeze(0)
        if self.is_cuda:
            img = img.cuda()

        # Scoped instead of a global torch.set_grad_enabled(False), so the
        # process-wide autograd state is left as the caller had it.
        with torch.no_grad():
            # forward pass; the heatmap output is not used for visualization
            junc_pred, _heatmap_pred, adj_mtx_pred = model(img)

            # visualize eval
            img_np = img.cpu().numpy()
            junctions_pred = junc_pred.cpu().numpy()
            adj_mtx = adj_mtx_pred.cpu().numpy()
            img_with_junc = draw_jucntions(img_np, junctions_pred)
            img_with_junc = img_with_junc[0].numpy()[None]
            # Reverse axis 1 (presumably channels, assuming draw_jucntions keeps
            # the 1xCxHxW layout — TODO confirm) back to BGR for OpenCV display.
            # Make a contiguous copy, because consumers such as torch.from_numpy
            # reject negative-stride arrays.
            img_with_junc = np.ascontiguousarray(img_with_junc[:, ::-1, :, :])
            lines_pred, score_pred = graph2line(junctions_pred, adj_mtx)
            vis_line_pred = draw_lines(img_with_junc, lines_pred, score_pred)[0]
            vis_line_pred = vis_line_pred.permute(1, 2, 0).numpy()

        cv2.imshow("result", vis_line_pred)
        # imshow only paints the window while the HighGUI event loop runs.
        # Without waitKey the window never renders before the process exits.
        cv2.waitKey(wait_ms)
        cv2.destroyWindow("result")
if __name__ == "__main__":
    # python-fire turns LSD's constructor arguments into command-line flags and
    # its public methods (test, end) into commands.
    fire.Fire(LSD)