Skip to content
This repository has been archived by the owner on Sep 29, 2023. It is now read-only.

Commit

Permalink
add CompareTools
Browse files Browse the repository at this point in the history
  • Loading branch information
RainyBlueSky committed May 29, 2018
1 parent c8fb55f commit b613f9a
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 38 deletions.
13 changes: 7 additions & 6 deletions layer/sst.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@

class SST(nn.Module):
#new: combine two vgg_net
def __init__(self, phase, base, extras, selector, final_net):
def __init__(self, phase, base, extras, selector, final_net, use_gpu=config['cuda']):
super(SST, self).__init__()
self.phase = phase

Expand All @@ -49,7 +49,7 @@ def __init__(self, phase, base, extras, selector, final_net):
self.false_objects_column = None
self.false_objects_row = None
self.false_constant = config['false_constant']
self.use_gpu = config['cuda']
self.use_gpu = use_gpu

def forward(self, x_pre, x_next, l_pre , l_next, valid_pre=None, valid_next=None):
'''
Expand Down Expand Up @@ -101,14 +101,14 @@ def forward_feature_extracter(self, x, l):
s = list()

x = self.forward_vgg(x, self.vgg, s)
x = self.forward_extras(x, self.extras,s)
x = self.forward_extras(x, self.extras, s)
x = self.forward_selector_stacker1(s, l, self.selector)

return x

def get_similarity(self, image1, detection1, image2, detection2):
feature1 = self.forward_feature_extracter(image1, detection1)
feature2 = self.forward(image2, detection2)
feature2 = self.forward_feature_extracter(image2, detection2)
return self.forward_stacker_features(feature1, feature2, False)


Expand Down Expand Up @@ -360,7 +360,7 @@ def selector(vgg, extra_layers, batch_normal=True):

return vgg, extra_layers, selector_layers

def build_sst(phase, size=900):
def build_sst(phase, size=900, use_gpu=config['cuda']):
'''
create the SSJ Tracker Object
:return: ssj tracker object
Expand All @@ -382,5 +382,6 @@ def build_sst(phase, size=900):
vgg(base[str(size)], 3),
add_extras(extras[str(size)], 1024)
),
add_final(final[str(size)])
add_final(final[str(size)]),
use_gpu
)
69 changes: 37 additions & 32 deletions tools/compare_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,10 @@
''')

parser = argparse.ArgumentParser(description='Single Shot Joint Tracker Train')
parser.add_argument('--image1', default="/home/ssm/ssj/dataset/MOT17/train/MOT17-11-FRCNN/img1/000001.jpg", help='Previous Image')
parser.add_argument('--image2', default="/home/ssm/ssj/dataset/MOT17/train/MOT17-11-FRCNN/img1/000029.jpg", help='Current Image')
parser.add_argument('--model_path', default="/home/ssm/ssj/weights/MOT17/weights0326-I50k-M80-G30-Continue0509v1.pth", help='sst net model path')
parser.add_argument('--cuda', default=True, help="use gpu or not")
parser.add_argument('--image1', default="C:/Users/00097307/Dropbox/tracking/images/000001.jpg", help='Previous Image')
parser.add_argument('--image2', default="C:/Users/00097307/Dropbox/tracking/images/000030.jpg", help='Current Image')
parser.add_argument('--model_path', default="C:/Users/00097307/Dropbox/tracking/pretrained/sst300_0712_9960.pth", help='sst net model path')
parser.add_argument('--cuda', default=False, help="use gpu or not")

args = parser.parse_args()

Expand All @@ -34,14 +34,15 @@ class CompareTools:
org_img = None
sst = None
cuda = False
resize_rate = 0.4

save_objects = {'rect' : [], 'text': []}

@staticmethod
def init(img1_path, img2_path, model_path, cuda):
print('start init >>>>>>>>>>>>>>')
if not os.path.exists(img1_path) or not os.path.exists(img2_path) or not os.path.exists(model_path):
raise ValueError("input parameter nto right")
raise ValueError("input parameter not right")

CompareTools.cuda = cuda

Expand All @@ -51,20 +52,20 @@ def init(img1_path, img2_path, model_path, cuda):
CompareTools.img2 = cv2.imread(img2_path)
CompareTools.img1_convert = CompareTools.convert_image(CompareTools.img1, CompareTools.cuda)
CompareTools.img2_convert = CompareTools.convert_image(CompareTools.img2, CompareTools.cuda)
CompareTools.img = np.concatenate([CompareTools.img1, CompareTools.img2], axis=1)
CompareTools.img = np.concatenate([CompareTools.img1, CompareTools.img2], axis=0)
CompareTools.img_org = np.copy(CompareTools.img)

print('load model...')
# load net
CompareTools.sst = build_sst('test', 900)
CompareTools.sst = build_sst('test', 900, CompareTools.cuda)
if cuda:
cudnn.benchmark = True
CompareTools.sst.load_state_dict(
torch.load(model_path)
)
CompareTools.sst = CompareTools.sst.cuda()
else:
CompareTools.sst.load_state_dict(torch.load(model_path, map_location='cpu'))
CompareTools.sst.load_state_dict(torch.load(model_path))

print('finish init <<<<<<<<<<<<')
@staticmethod
Expand All @@ -85,7 +86,8 @@ def convert_image(image, cuda):

@staticmethod
def convert_boxes(boxes):
center = (2 * boxes[:, 0:2] + boxes[:, 2:4]) - 1.0
boxes = np.array(boxes)
center = (boxes[:, 0:2] + boxes[:, 2:4]) - 1.0
center = torch.from_numpy(center.astype(float)).float()
center.unsqueeze_(0)
center.unsqueeze_(2)
Expand All @@ -99,53 +101,57 @@ def convert_boxes(boxes):
def select_object(event, x, y, flag, param):
global ix, iy, drawing

color = tuple((np.random.rand(3) * 255).astype(int).tolist())

if event == cv2.EVENT_LBUTTONDOWN:
drawing = True
ix, iy = x, y

elif event == cv2.EVENT_MOUSEMOVE:
if drawing == True:
cv2.rectangle(CompareTools.img, (ix, iy), (x, y), (0, 255, 0), -1)
cv2.rectangle(CompareTools.img, (ix, iy), (x, y), color, 2)

elif event == cv2.EVENT_LBUTTONUP:
drawing = False
cv2.rectangle(CompareTools.img, (ix, iy), (x, y), (0, 255, 0), -1)
cv2.rectangle(CompareTools.img, (ix, iy), (x, y), color, 2)

CompareTools.save_objects['rect'] += [(
(ix, iy),
(x,y),
str(len(CompareTools.save_objects['rect']))
(x, y),
str(len(CompareTools.save_objects['rect'])),
tuple((np.random.rand(3) * 255).astype(int).tolist())
)]

@staticmethod
def draw(img):
boxes = CompareTools.save_objects['rect']
for b in boxes:
start = b[0]
end = b[1]
start = (int(b[0][0] / CompareTools.resize_rate), int(b[0][1] / CompareTools.resize_rate))
end = (int(b[1][0] / CompareTools.resize_rate), int(b[1][1] / CompareTools.resize_rate))
text = b[2]
cv2.rectangle(img, start, end, (0, 255, 0), -1)
cv2.putText(img, text, start, cv2.CV_FONT_HERSHEY_SIMPLEX, 2, 255)
cv2.rectangle(img, start, end, b[3], 3)
cv2.putText(img, text, start, cv2.FONT_HERSHEY_SIMPLEX, 1, b[3], 2)
return img

@staticmethod
def get_similarity():
h, w, _ = CompareTools.img1.shape
boxes = [list(b[0])+list(b[1]) for b in CompareTools.save_objects['rect']]

def convert_box(x):
x[0] /= float(w)
x[2] /= float(w)
x[1] /= float(h)
x[3] /= float(h)
x[0] /= float(w) * CompareTools.resize_rate
x[2] /= float(w) * CompareTools.resize_rate
x[1] /= float(h) * CompareTools.resize_rate
x[3] /= float(h) * CompareTools.resize_rate
return x

boxes = list(map(convert_box, boxes))
boxes1 = [b for b in boxes if b[0] >= 1]
boxes2 = [b for b in boxes if b[0] < 1]
boxes1 = [b for b in boxes if b[1] < 1]
boxes2 = [[b[0], b[1]-1, b[2], b[3]-1] for b in boxes if b[1] >= 1]

boxes1 = CompareTools.convert_boxes(boxes1)
boxes2 = CompareTools.convert_boxes(boxes2)

return CompareTools.get_similarity(CompareTools.img1_convert, boxes1, CompareTools.img2_convert, boxes2)

return CompareTools.sst.get_similarity(CompareTools.img1_convert, boxes1, CompareTools.img2_convert, boxes2)

@staticmethod
def run():
Expand All @@ -155,30 +161,29 @@ def run():
'D': delete all the rectangles
'c': calculate the similarity matrix
''')
h, w, _ = CompareTools.img.shape
# start draw
title = "images(left is previous, right is current)"
cv2.namedWindow(title)
cv2.setMouseCallback(title, CompareTools.select_object)
cv2.imshow(title, CompareTools.img_org)

while(1):
CompareTools.img = CompareTools.draw(np.copy(CompareTools.img_org))
CompareTools.img = cv2.resize(CompareTools.img, (int(CompareTools.resize_rate*w), int(CompareTools.resize_rate*h)))
cv2.imshow(title, CompareTools.img)
key = cv2.waitKey(20)

if key == 'd':
if key == 100: # 'd'
CompareTools.save_objects['rect'] = CompareTools.save_objects['rect'][:-1]
print('delete the latest box!')
elif key == 'D':
elif key == 68: # 'D'
CompareTools.save_objects['rect'] = []
print('delete all boxes')
elif key == 'c':
elif key == 99: # 'c'
print('start calculate the similarity!')
s = CompareTools.get_similarity()
print(s)
if key & 0xFF == 27:
break
print(key)


if __name__ == '__main__':
Expand Down

0 comments on commit b613f9a

Please sign in to comment.