About linking strategy for action cube #21

Closed
dagongji10 opened this issue Mar 2, 2020 · 13 comments
@dagongji10

Thanks for sharing the nice work!
I modified the script train.py and use it as a demo for inference. Now I can get frame-level results, but I still have some problems:

  1. Where is the linking strategy that your paper introduces in Section 3.3? I can't find it in the code.
  2. What does the action cube mean? What is it used for? I read it as the start and end time of an action instance; is that right?
@abhigoku10

@dagongji10 can you share your inference code?

@dagongji10
Author

dagongji10 commented Mar 2, 2020

@abhigoku10

The code comes from train.py; I hope it is useful for you.

from __future__ import print_function
import sys, os, time
import random
import math
import cv2 as cv
from PIL import Image
import numpy as np  # np is used below in post_process() and the main loop

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.backends.cudnn as cudnn
from torchvision import datasets, transforms

import dataset
from opts import parse_opts
from utils import *
from cfg import parse_cfg
from region_loss import RegionLoss

from model import YOWO


def bbox_iou(box1, box2, x1y1x2y2=True):
	if x1y1x2y2:
		mx = min(box1[0], box2[0])
		Mx = max(box1[2], box2[2])
		my = min(box1[1], box2[1])
		My = max(box1[3], box2[3])
		w1 = box1[2] - box1[0]
		h1 = box1[3] - box1[1]
		w2 = box2[2] - box2[0]
		h2 = box2[3] - box2[1]
	else:
		mx = min(float(box1[0] - box1[2] / 2.0), float(box2[0] - box2[2] / 2.0))
		Mx = max(float(box1[0] + box1[2] / 2.0), float(box2[0] + box2[2] / 2.0))
		my = min(float(box1[1] - box1[3] / 2.0), float(box2[1] - box2[3] / 2.0))
		My = max(float(box1[1] + box1[3] / 2.0), float(box2[1] + box2[3] / 2.0))
		w1 = box1[2]
		h1 = box1[3]
		w2 = box2[2]
		h2 = box2[3]
	uw = Mx - mx
	uh = My - my
	cw = w1 + w2 - uw
	ch = h1 + h2 - uh
	carea = 0
	if cw <= 0 or ch <= 0:
		return 0.0

	area1 = w1 * h1
	area2 = w2 * h2
	carea = cw * ch
	uarea = area1 + area2 - carea
	return carea / uarea


def nms(boxes, nms_thresh):
	if len(boxes) == 0:
		return boxes

	det_confs = torch.zeros(len(boxes))
	for i in range(len(boxes)):
		det_confs[i] = 1 - boxes[i][4]

	_, sortIds = torch.sort(det_confs)
	out_boxes = []
	for i in range(len(boxes)):
		box_i = boxes[sortIds[i]]
		if box_i[4] > 0:
			out_boxes.append(box_i)
			for j in range(i + 1, len(boxes)):
				box_j = boxes[sortIds[j]]
				if bbox_iou(box_i, box_j, x1y1x2y2=True) > nms_thresh:
					box_j[4] = 0
	return out_boxes


def get_config():
	opt = parse_opts()  # Training settings
	dataset_use = opt.dataset  # which dataset to use
	datacfg = opt.data_cfg  # path for dataset of training and validation, e.g: cfg/ucf24.data
	cfgfile = opt.cfg_file  # path for cfg file, e.g: cfg/ucf24.cfg
	assert dataset_use == 'ucf101-24' or dataset_use == 'jhmdb-21', 'invalid dataset'

	# loss parameters
	loss_options = parse_cfg(cfgfile)[1]
	region_loss = RegionLoss()
	anchors = loss_options['anchors'].split(',')
	region_loss.anchors = [float(i) for i in anchors]
	region_loss.num_classes = int(loss_options['classes'])
	region_loss.num_anchors = int(loss_options['num'])

	return opt, region_loss


def load_model(opt, pretrained_path):
	seed = int(time.time())
	use_cuda = True
	gpus = '0'
	torch.manual_seed(seed)
	if use_cuda:
		os.environ['CUDA_VISIBLE_DEVICES'] = gpus
	torch.cuda.manual_seed(seed)

	# Create model
	model = YOWO(opt)
	model = model.cuda()
	model = nn.DataParallel(model, device_ids=None)  # in multi-gpu case
	model.seen = 0

	checkpoint = torch.load(pretrained_path)
	epoch = checkpoint['epoch']
	fscore = checkpoint['fscore']
	model.load_state_dict(checkpoint['state_dict'])

	return model, epoch, fscore


def infer(model, data, region_loss):
	num_classes = region_loss.num_classes
	anchors = region_loss.anchors
	num_anchors = region_loss.num_anchors
	conf_thresh_valid = 0.005
	nms_thresh = 0.4

	model.eval()

	data = data.cuda()
	res = []
	with torch.no_grad():
		output = model(data).data
		all_boxes = get_region_boxes(output, conf_thresh_valid, num_classes, anchors, num_anchors, 0, 1)

		for i in range(output.size(0)):
			boxes = all_boxes[i]
			boxes = nms(boxes, nms_thresh)

			for box in boxes:
				x1 = round(float(box[0] - box[2] / 2.0) * 320.0)
				y1 = round(float(box[1] - box[3] / 2.0) * 240.0)
				x2 = round(float(box[0] + box[2] / 2.0) * 320.0)
				y2 = round(float(box[1] + box[3] / 2.0) * 240.0)
				det_conf = float(box[4])

				for j in range((len(box) - 5) // 2):
					cls_conf = float(box[5 + 2 * j].item())
					if type(box[6 + 2 * j]) == torch.Tensor:
						cls_id = int(box[6 + 2 * j].item())
					else:
						cls_id = int(box[6 + 2 * j])
					prob = det_conf * cls_conf
					# use the class id of the current (cls_conf, cls_id) pair, shifted to 1-based
					res.append(str(cls_id + 1) + ' ' + str(prob) + ' ' + str(x1) + ' ' + str(y1) + ' ' + str(x2) + ' ' + str(y2))
	return res


def pre_process_image(images, clip_duration, input_shape=(224, 224)):
	# resize to (224,224)
	clip = [img.resize(input_shape) for img in images]
	# numpy to tensor
	op_transforms = transforms.Compose([transforms.ToTensor()])
	clip = [op_transforms(img) for img in clip]
	# change dimension
	clip = torch.cat(clip, 0).view((clip_duration, -1) + input_shape).permute(1, 0, 2, 3)
	# expand dimension to (batch_size, channel, duration, h, w)
	clip = clip.unsqueeze(0)

	return clip


def post_process(images, bboxs):
	jhmdb_cls = ['', 'brush_hair', 'catch', 'clap', 'climb_stairs', 'golf', 'jump', 'kick_ball', 'pick', 'pour', 'pullup',
				 'push', 'run', 'shoot_ball', 'shoot_bow', 'shoot_gun', 'sit', 'stand', 'swing_baseball', 'throw', 'walk', 'wave']
	conf_thresh = 0.1
	nms_thresh = 0.4

	proposals = []
	for i in range(len(bboxs)):
		line = bboxs[i]
		cls, score, x1, y1, x2, y2 = line.strip().split(' ')

		if float(score) < conf_thresh:
			continue
		proposals.append([int(int(x1) * 2.25), int(int(y1) * 2.25), int(int(x2) * 2.25), int(int(y2) * 2.25), float(score), int(cls)])

	proposals = nms(proposals, nms_thresh)

	image = cv.cvtColor(np.asarray(images[-1], dtype=np.uint8), cv.COLOR_RGB2BGR)
	for proposal in proposals:
		x1, y1, x2, y2, score, cls = proposal
		cv.rectangle(image, (x1, y1), (x2, y2), (255, 255, 255), 2, cv.LINE_4)

		text = '[{:.2f}] {}'.format(score, jhmdb_cls[cls])
		font_type = 5
		font_size = 1
		line_size = 1
		textsize = cv.getTextSize(text, font_type, font_size, line_size)
		y1 = y1 - 10
		p1 = (x1, y1 - textsize[0][1])
		p2 = (x1 + textsize[0][0], y1 + textsize[1])
		cv.rectangle(image, p1, p2, (180, 238, 180), -1)
		cv.putText(image, text, (x1, y1), font_type, font_size, (255, 255, 255), line_size, 1)

	return image


if __name__ == '__main__':
	duration = 16
	num_sample = 8
	pretrained_path = 'model_zoo/JHMDB-21/yowo_jhmdb-21_16f_best_fmap_07668.pth'
	video_path = 'demo_data/after-cali.avi'

	fourcc = cv.VideoWriter_fourcc(*'XVID')
	out = cv.VideoWriter('output.avi', fourcc, 30.0, (720, 540))

	# load parameters
	opt, region_loss = get_config()
	# load model
	model, epoch, fscore = load_model(opt, pretrained_path)
	# read video
	video = cv.VideoCapture(video_path)

	stack = []
	n = 0
	t0 = time.time()
	while True:
		ret, frame = video.read()
		if not ret:
			break
		n += 1

		frame = cv.cvtColor(frame, cv.COLOR_BGR2RGB)
		frame = Image.fromarray(np.uint8(frame))
		stack.append(frame)

		if len(stack) == duration:
			# 1. preprocess images
			input_data = pre_process_image(stack, duration)
			# 2. YOWO detect action tube
			output_data = infer(model, input_data, region_loss)
			# 3. draw result to images
			result_img = post_process(stack, output_data)
			# 4. write to video
			out.write(result_img)

			for i in range(num_sample):
				stack.pop(0)

			t = time.time() - t0
			print('cost {:.2f}, {:.2f} FPS'.format(t, num_sample / t))
			t0 = time.time()

	out.release()
	video.release()

@wei-tim
Owner

wei-tim commented Mar 3, 2020

@dagongji10

Thanks for your interest!

  1. This part can be found in video_map.py, and the details (linking strategy) are in eval_results.py.
  2. "Action tube" is a concept introduced in the paper "Finding Action Tubes" (https://arxiv.org/abs/1411.6031). Your understanding is correct: it links the detections across frames and generates a path for each plausible action instance (a rough sketch of the linking idea follows after this list).
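
Roughly, the linking looks like the sketch below. This is only an illustration, not the actual code in eval_results.py; the link_tubes name and the detection tuple layout are made up, and it reuses the bbox_iou() helper from the inference script posted above.

# Illustrative sketch of greedy tube linking (NOT the code in eval_results.py).
# A detection in frame t is appended to the tube whose last box has the same
# class and the highest IoU above a threshold; unmatched detections start new
# tubes. Reuses bbox_iou() from the inference script above.
def link_tubes(frame_dets, iou_thresh=0.3):
	# frame_dets: list over frames; each entry is a list of
	# (x1, y1, x2, y2, score, cls_id) tuples in pixel coordinates
	tubes = []  # each tube: {'cls': cls_id, 'boxes': [dets], 'score': summed det score}
	for dets in frame_dets:
		unmatched = list(dets)
		for tube in tubes:
			best, best_iou = None, iou_thresh
			for det in unmatched:
				if det[5] != tube['cls']:
					continue
				iou = bbox_iou(tube['boxes'][-1][:4], det[:4], x1y1x2y2=True)
				if iou > best_iou:
					best, best_iou = det, iou
			if best is not None:
				tube['boxes'].append(best)
				tube['score'] += best[4]
				unmatched.remove(best)
		# detections that matched no existing tube start new tubes
		for det in unmatched:
			tubes.append({'cls': det[5], 'boxes': [det], 'score': det[4]})
	# rank tubes by mean per-frame score
	tubes.sort(key=lambda tb: tb['score'] / len(tb['boxes']), reverse=True)
	return tubes

The actual code in the repository may differ in details, e.g. in how tube scores are computed and how tubes are trimmed in time.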

@dagongji10
Author

dagongji10 commented Mar 6, 2020

@wei-tim Thanks for your reply! Now I understand the action tube; a few more questions may need your help.
I made a video in which each frame is 720*540. There are 3 people in the video, and each of them performs some action (walk, sit, stand, wave...) from JHMDB-21. I found that only the people in the center of the frame are detected. So I cropped the region where detection failed out of the 720*540 image as a 320*240 input, and then the actions are detected normally.
(screenshot)
How can I improve YOWO's performance on actions near the edge of the frame?
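
One way to make that cropping systematic would be something like the sketch below. It is only a sketch, not part of YOWO: sliding_crops and detect_with_crops are made-up names, pre_process_image and infer are the helpers from the demo script above, and it assumes 720*540 frames with the 320*240 working resolution used there.

# Hypothetical workaround (not part of YOWO): run infer() on overlapping
# 320x240 crops of a 720x540 frame so that people near the border also end up
# near a crop center, then shift the boxes back into full-frame coordinates.
def sliding_crops(frame_w=720, frame_h=540, crop_w=320, crop_h=240, stride=200):
	# yield (x0, y0) top-left corners of overlapping crops covering the frame
	xs = sorted(set(list(range(0, frame_w - crop_w + 1, stride)) + [frame_w - crop_w]))
	ys = sorted(set(list(range(0, frame_h - crop_h + 1, stride)) + [frame_h - crop_h]))
	for y0 in ys:
		for x0 in xs:
			yield x0, y0


def detect_with_crops(model, region_loss, pil_clip):
	# pil_clip: list of full-resolution PIL images forming one clip;
	# returns 'cls prob x1 y1 x2 y2' strings in full-frame coordinates
	results = []
	for x0, y0 in sliding_crops():
		crop_clip = [img.crop((x0, y0, x0 + 320, y0 + 240)) for img in pil_clip]
		data = pre_process_image(crop_clip, len(crop_clip))
		for line in infer(model, data, region_loss):
			cls, prob, x1, y1, x2, y2 = line.split(' ')
			results.append(cls + ' ' + prob + ' ' + str(int(x1) + x0) + ' ' + str(int(y1) + y0) + ' ' + str(int(x2) + x0) + ' ' + str(int(y2) + y0))
	return results

Since infer() already scales boxes to 320*240, adding the crop offsets gives full-frame pixels directly; a cross-crop NMS on results is still needed, and the 2.25 rescaling in post_process should then be dropped. The price is one forward pass per crop.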

@wei-tim
Owner

wei-tim commented Mar 10, 2020

@dagongji10
The training samples are 224*224, and in the J-HMDB-21 dataset the actors are relatively large, i.e. they occupy large regions of the image. Another point is that this project is based on YOLOv2, which isn't very good at capturing tiny persons.

Hope this helps solve your problem ;-)

@dagongji10
Author

@wei-tim
Thanks for your help. Now I can use YOWO more easily. I will keep following this nice work and look forward to the model on the AVA dataset.

@usamahjundia

> (quoting @dagongji10's inference script from the earlier comment)

Hi @dagongji10, I'm currently implementing a custom, vectorized postprocessing to reduce some runtime cost, and I'm getting noisy results. Have you experienced the same when doing inference on a random video sampled from YouTube?

Also I noticed 2 things in your code:

  1. You did NMS twice, in infer and in the postprocessing
  2. In the postprocessing, you multiplied the coordinates by 2.25

What is the rationale behind both, if I may ask?

@dagongji10
Author

@usamahjundia
YOWO outputs many scored bboxes and a lot of them overlap, so you can filter most of them out by score with NMS.

  1. The first NMS in infer is what the author uses to filter out most wrong detections. The second NMS in the postprocessing is for my own needs.
  2. YOWO's input image size defaults to 320*240, but my input is 720*540, so I multiplied the coordinates by 2.25 to rescale the bboxes (see the small helper after this list).
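
For clarity, 2.25 is just 720/320 = 540/240. A small helper like the sketch below (rescale_box is only an illustrative name, assuming the 320*240 reference resolution of the demo) generalizes it to other frame sizes:

# 2.25 works only because 720/320 == 540/240; compute the scale from the frame size instead
def rescale_box(x1, y1, x2, y2, frame_w, frame_h, ref_w=320, ref_h=240):
	# map a box from detector coordinates (ref_w x ref_h) to frame coordinates
	sx, sy = frame_w / ref_w, frame_h / ref_h  # e.g. 720/320 = 2.25, 540/240 = 2.25
	return int(x1 * sx), int(y1 * sy), int(x2 * sx), int(y2 * sy)

# example: rescale_box(100, 50, 180, 200, 720, 540) -> (225, 112, 405, 450)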

@usamahjundia

> (quoting @dagongji10's reply above)

Thanks @dagongji10! I actually managed to solve the issue myself, and my postprocessing is now consistent with the results of the paper implementation. I compared my implementation against both your code and the paper's and found a silly bug on my part, but now it is all fixed and fine.

If I may ask, have you managed to fiddle with the linking? Does it still work well on just 2 frames (for real-time cases)?

@dagongji10
Author

@usamahjundia I can't reach the speed reported in the paper, so I run detection only once every 8 frames; each detection costs about 0.25 s.

@mqabbari

@wei-tim First of all, thanks for the great work.

My question is also related to the action tube. I understand how it works on video input, but how can I create an action tube for a live camera source? In other words, if I detect every 16 frames and only create bboxes for the last frame, how do I set bounding boxes for the past 15 frames?

@dagongji10 Thanks for the inference code. I think you are not writing out the frames between detections, so frames from the original video are lost. Is my understanding correct?

I would appreciate it if you could also answer the above question regarding the action tube.

@cbiras

cbiras commented Mar 16, 2021

Hi @mqabbari. Did you find out how to draw bboxes for all the frames in the 16-frame clip?

@lix4

lix4 commented May 1, 2021

> (quoting @mqabbari's question above)

I am also working on applying this model to real-time inference and I have the same question. Right now I just make my algorithm start inferring at frame 16.
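
For a live source, one possible approach is a rolling buffer, as in the sketch below. This is not official YOWO code: run_live is a made-up name, pre_process_image / infer / post_process are the helpers from the demo script earlier in this thread (so the same 720*540 / 320*240 assumptions apply), and reusing the latest boxes for the frames between detections is only an approximation.

import collections

# Sketch of rolling-buffer inference for a live camera (not official YOWO code).
def run_live(video, model, region_loss, duration=16, stride=8):
	buffer = collections.deque(maxlen=duration)
	last_dets = []  # latest detections, reused for the frames between runs
	frame_idx = 0
	while True:
		ret, frame = video.read()
		if not ret:
			break
		frame_idx += 1
		buffer.append(Image.fromarray(cv.cvtColor(frame, cv.COLOR_BGR2RGB)))
		# run the detector only when the buffer is full, once every `stride` frames
		if len(buffer) == duration and frame_idx % stride == 0:
			clip = pre_process_image(list(buffer), duration)
			last_dets = infer(model, clip, region_loss)
		# draw the most recent boxes on every frame so no output frames are lost;
		# for the frames between detections this is only an approximation
		yield post_process(list(buffer), last_dets)

# usage: for out_frame in run_live(cv.VideoCapture(0), model, region_loss): out.write(out_frame)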
