#Referenced from https://github.com/NVIDIA-AI-IOT/trt_pose

First, let's load the JSON file that describes the human pose task.  It is in COCO format: the category descriptor pulled from the annotations file, modified slightly to add a neck keypoint.  We will use this task description to create a topology tensor, an intermediate data structure that describes the part linkages and which channels in the part affinity field each linkage corresponds to.

In [1]:
import json
import trt_pose.coco

with open('human_pose.json', 'r') as f:
    human_pose = json.load(f)

topology = trt_pose.coco.coco_category_to_topology(human_pose)

Next, we'll load our model.  Each model takes at least two parameters, *cmap_channels* and *paf_channels*, corresponding to the number of heatmap channels
and part affinity field channels.  The number of part affinity field channels is 2x the number of links, because each link has one channel for the x component
and one for the y component of its vector field.

In [2]:
import trt_pose.models

num_parts = len(human_pose['keypoints'])
num_links = len(human_pose['skeleton'])

model = trt_pose.models.resnet18_baseline_att(num_parts, 2 * num_links).cuda().eval()
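
As a rough illustration of the channel arithmetic, the sketch below derives both counts from a made-up three-keypoint category.  `toy_category` is hypothetical and only mimics the shape of `human_pose.json`; it is not the real file.

```python
# Hypothetical, cut-down COCO-style category (assumption for illustration).
# The real human_pose.json has 18 keypoints: 17 COCO keypoints plus the neck.
toy_category = {
    "keypoints": ["nose", "neck", "left_shoulder"],
    "skeleton": [[1, 2], [2, 3]],  # pairs of 1-based keypoint indices
}

cmap_channels = len(toy_category["keypoints"])    # one heatmap per keypoint
paf_channels = 2 * len(toy_category["skeleton"])  # x and y component per link

print(cmap_channels, paf_channels)
```

With three keypoints and two links this gives 3 heatmap channels and 4 part affinity field channels.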

Next, let's load the model weights.  You will need to download these according to the table in the README.

In [3]:
import torch

MODEL_WEIGHTS = 'resnet18_baseline_att_224x224_A_epoch_249.pth'

model.load_state_dict(torch.load(MODEL_WEIGHTS))

To optimize with TensorRT using the Python library *torch2trt*, we'll also need to create some example data.  The dimensions
of this data should match the dimensions that the network was trained with.  Since we're using the resnet18 variant that was trained on
an input resolution of 224x224, we set the width and height to these dimensions.

In [4]:
WIDTH = 224
HEIGHT = 224

data = torch.zeros((1, 3, HEIGHT, WIDTH)).cuda()

Next, we'll use [torch2trt](https://github.com/NVIDIA-AI-IOT/torch2trt) to optimize the model.  We'll enable fp16_mode to allow optimizations to use half precision.

The optimized model may be saved so that we do not need to perform the optimization again; we can just load the saved model.  Please note that TensorRT has device-specific optimizations, so you can only use an optimized model on similar platforms.

In [5]:
import torch2trt

model_trt = torch2trt.torch2trt(model, [data], fp16_mode=True, max_workspace_size=1<<25)

In [6]:
OPTIMIZED_MODEL = 'resnet18_baseline_att_224x224_A_epoch_249_trt.pth'

torch.save(model_trt.state_dict(), OPTIMIZED_MODEL)

We could then load the saved model using *torch2trt* as follows.

In [7]:
from torch2trt import TRTModule

model_trt = TRTModule()
model_trt.load_state_dict(torch.load(OPTIMIZED_MODEL))

<All keys matched successfully>

We can benchmark the model in FPS with the following code:

In [8]:
import time

torch.cuda.current_stream().synchronize()  # finish pending GPU work before starting the clock
t0 = time.time()
for i in range(50):
    y = model_trt(data)
torch.cuda.current_stream().synchronize()
t1 = time.time()

print(50.0 / (t1 - t0))

114.4670517618843
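
The same measurement can be wrapped in a small helper.  This is only a sketch: `measure_fps` and its `warmup` parameter are hypothetical names, and the stand-in workload is plain Python.  For a GPU model, the callable would also need to synchronize the CUDA stream, as the cell above does.

```python
import time

def measure_fps(fn, iterations=50, warmup=5):
    """Return calls per second of fn(), ignoring the first `warmup` calls."""
    for _ in range(warmup):
        fn()
    start = time.perf_counter()
    for _ in range(iterations):
        fn()
    elapsed = time.perf_counter() - start
    return iterations / elapsed

# Stand-in workload; swap in a callable that runs the model and synchronizes.
fps = measure_fps(lambda: sum(range(1000)))
print(fps)
```

Warm-up calls keep one-time setup costs out of the average.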


Next, let's define a function that will preprocess the image, which is originally in BGR8 / HWC format.

In [9]:
import cv2
import torchvision.transforms as transforms
import PIL.Image

mean = torch.Tensor([0.485, 0.456, 0.406]).cuda()
std = torch.Tensor([0.229, 0.224, 0.225]).cuda()
device = torch.device('cuda')

def preprocess(image):
    # image is a BGR8 / HWC array; the module-level `device`, `mean` and `std` are reused
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    image = PIL.Image.fromarray(image)
    image = transforms.functional.to_tensor(image).to(device)
    image.sub_(mean[:, None, None]).div_(std[:, None, None])
    return image[None, ...]
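
To make the arithmetic concrete, the sketch below applies the same steps to a single pixel in plain Python.  `normalize_bgr_pixel` is a hypothetical helper, not part of trt_pose; it assumes, as `to_tensor` does for 8-bit images, that values are first scaled to the range [0, 1].

```python
# Same per-channel statistics as `mean` and `std` above, in RGB order.
MEAN = (0.485, 0.456, 0.406)
STD = (0.229, 0.224, 0.225)

def normalize_bgr_pixel(bgr):
    """Normalize one 8-bit BGR pixel the way preprocess() treats every pixel."""
    rgb = bgr[::-1]                        # BGR -> RGB
    scaled = [c / 255.0 for c in rgb]      # scale to [0, 1]
    return [(c - m) / s for c, m, s in zip(scaled, MEAN, STD)]

white = normalize_bgr_pixel((255, 255, 255))
pure_blue = normalize_bgr_pixel((255, 0, 0))  # blue comes first in BGR
print(white, pure_blue)
```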

Next, we'll define two callable classes that will be used to parse the objects from the neural network, as well as draw the parsed objects on an image.

In [10]:
from trt_pose.draw_objects import DrawObjects
from trt_pose.parse_objects import ParseObjects

parse_objects = ParseObjects(topology)
draw_objects = DrawObjects(topology)

Assuming you're using NVIDIA Jetson, you can use the [jetcam](https://github.com/NVIDIA-AI-IOT/jetcam) package to create an easy-to-use camera that will produce images in BGR8/HWC format.

If you're not on Jetson, you may need to adapt the code below.

In [11]:
from jetcam.usb_camera import USBCamera
# from jetcam.csi_camera import CSICamera
from jetcam.utils import bgr8_to_jpeg

camera = USBCamera(width=WIDTH, height=HEIGHT, capture_fps=30)
# camera = CSICamera(width=WIDTH, height=HEIGHT, capture_fps=30)

camera.running = True

Connect to the joycontrol TCP server.

Skip the following cell if the joycontrol TCP server is not running yet.

In [12]:
import asyncio
loop = asyncio.get_event_loop()
# open_connection uses the running event loop; its `loop` argument was removed in Python 3.10
tcpReader, tcpWriter = await asyncio.open_connection('127.0.0.1', 8080)
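
If no joycontrol server is available, the connection code can be tried against a throwaway local server.  The sketch below is an assumption-laden stand-in: `demo_send_key` is a hypothetical helper, the server only records the bytes it receives, and nothing here reflects the real joycontrol protocol.

```python
import asyncio

async def demo_send_key(key: bytes) -> bytes:
    """Send `key` to a throwaway local server and return what it received."""
    received = asyncio.get_running_loop().create_future()

    async def handle(reader, writer):
        data = await reader.read(16)
        if not received.done():
            received.set_result(data)
        writer.close()
        await writer.wait_closed()

    # Port 0 lets the OS pick a free port; the notebook uses a fixed port (8080).
    server = await asyncio.start_server(handle, "127.0.0.1", 0)
    port = server.sockets[0].getsockname()[1]

    reader, writer = await asyncio.open_connection("127.0.0.1", port)
    writer.write(key)
    await writer.drain()
    data = await asyncio.wait_for(received, timeout=5)
    writer.close()
    await writer.wait_closed()

    server.close()
    await server.wait_closed()
    return data

result = asyncio.run(demo_send_key(b"u"))
print(result)
```

Inside a notebook, where an event loop is already running, the last call would be `await demo_send_key(b"u")` instead of `asyncio.run(...)`.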

Handler for sending a joycontrol command manually.

In [14]:
def on_button_clicked(b):
    global tcpWriter
    tcpWriter.write(b.tooltip.encode())

In [15]:
#toggle button state
def on_button_long_clicked(b):
    global tcpWriter
    if b.button_style == '':
        tcpWriter.write(b.tooltip.encode())
        b.button_style = 'success'
    else:
        tcpWriter.write(b.tooltip.upper().encode())
        b.button_style = ''
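
The toggle handler can be exercised without widgets or a network connection.  In the sketch below, `FakeWriter` and `FakeButton` are hypothetical stand-ins, and reading lowercase as "start holding" and uppercase as "release" is an assumption based on how the handler above uses them.

```python
class FakeWriter:
    """Stand-in for the stream writer; records bytes instead of sending them."""
    def __init__(self):
        self.sent = []

    def write(self, data):
        self.sent.append(data)

class FakeButton:
    """Stand-in for ipywidgets.Button with only the attributes the handler uses."""
    def __init__(self, tooltip):
        self.tooltip = tooltip
        self.button_style = ''

def toggle(button, writer):
    # Same logic as on_button_long_clicked, with the writer passed in explicitly.
    if button.button_style == '':
        writer.write(button.tooltip.encode())          # lowercase: start holding
        button.button_style = 'success'
    else:
        writer.write(button.tooltip.upper().encode())  # uppercase: release
        button.button_style = ''

writer = FakeWriter()
button = FakeButton('m')
toggle(button, writer)
toggle(button, writer)
print(writer.sent)
```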


ipywidgets for user interaction.

Check the "Control with data" checkbox to send the game controls from processTrtData to the joycontrol TCP server.

Next, we'll create a widget which will be used to display the camera feed with visualizations.

In [16]:
import ipywidgets
from IPython.display import display

image_w = ipywidgets.Image(format='jpeg')
display(image_w)

controllerCheckBox_w = ipywidgets.Checkbox(
    value=False,
    description='Control with data',
    disabled=False,
    indent=False
)
display(ipywidgets.HBox([controllerCheckBox_w]))

dataText_w = ipywidgets.Text("Value", description="Data", layout=ipywidgets.Layout(width='100%', height='30px'))
display(dataText_w)

upButton = ipywidgets.Button(tooltip='u',icon='arrow-up')
upButton.on_click(on_button_clicked)
downButton = ipywidgets.Button(tooltip='d',icon='arrow-down')
downButton.on_click(on_button_clicked)
leftButton = ipywidgets.Button(tooltip='l',icon='arrow-left')
leftButton.on_click(on_button_clicked)
rightButton = ipywidgets.Button(tooltip='r',icon='arrow-right')
rightButton.on_click(on_button_clicked)
aButton = ipywidgets.Button(description="A", tooltip='a')
aButton.on_click(on_button_clicked)
bButton = ipywidgets.Button(description="B", tooltip='b')
bButton.on_click(on_button_clicked)
xButton = ipywidgets.Button(description="X", tooltip='x')
xButton.on_click(on_button_clicked)
yButton = ipywidgets.Button(description="Y", tooltip='y')
yButton.on_click(on_button_clicked)
lButton = ipywidgets.Button(description="L", tooltip='L')
lButton.on_click(on_button_clicked)
rButton = ipywidgets.Button(description="R", tooltip='R')
rButton.on_click(on_button_clicked)
lrButton = ipywidgets.Button(description="L+R", tooltip='2')
lrButton.on_click(on_button_clicked)

upButtonLong = ipywidgets.Button(tooltip='m',icon='arrow-up')
upButtonLong.on_click(on_button_long_clicked)
downButtonLong = ipywidgets.Button(tooltip='n',icon='arrow-down')
downButtonLong.on_click(on_button_long_clicked)
leftButtonLong = ipywidgets.Button(tooltip='o',icon='arrow-left')
leftButtonLong.on_click(on_button_long_clicked)
rightButtonLong = ipywidgets.Button(tooltip='p',icon='arrow-right')
rightButtonLong.on_click(on_button_long_clicked)
aButtonLong = ipywidgets.Button(description="a", tooltip='e')
aButtonLong.on_click(on_button_long_clicked)
bButtonLong = ipywidgets.Button(description="b", tooltip='f')
bButtonLong.on_click(on_button_long_clicked)
xButtonLong = ipywidgets.Button(description="x", tooltip='g')
xButtonLong.on_click(on_button_long_clicked)
yButtonLong = ipywidgets.Button(description="y", tooltip='h')
yButtonLong.on_click(on_button_long_clicked)

targetingModeButton = ipywidgets.Button(description="T", tooltip='?')

display(ipywidgets.HBox([upButton, downButton, leftButton, rightButton, aButton, bButton, xButton, yButton, lButton, rButton, lrButton]))
display(ipywidgets.HBox([upButtonLong, downButtonLong, leftButtonLong, rightButtonLong, aButtonLong, bButtonLong, xButtonLong, yButtonLong, lButton, rButton, targetingModeButton]))


Image(value=b'', format='jpeg')

HBox(children=(Checkbox(value=False, description='Control with data', indent=False),))

Text(value='Value', description='Data', layout=Layout(height='30px', width='100%'))

HBox(children=(Button(icon='arrow-up', style=ButtonStyle(), tooltip='u'), Button(icon='arrow-down', style=Butt…

HBox(children=(Button(icon='arrow-up', style=ButtonStyle(), tooltip='m'), Button(icon='arrow-down', style=Butt…

Log data in batches, for testing.

In [17]:
from collections import deque
logDataSize = 100
logDataQueue = deque(maxlen=logDataSize)
logDataCounter = 0

def logData(d):
    global logDataQueue
    global logDataCounter
    global dataText_w
    logDataQueue.append(d)
    logDataCounter = logDataCounter + 1
    if logDataCounter == logDataSize:
        logDataCounter = 0
        dataText_w.value = str(logDataQueue)
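
The logger relies on `deque(maxlen=...)` discarding the oldest entry once the queue is full.  A minimal sketch of that behaviour, using a window of 3 instead of 100:

```python
from collections import deque

window = deque(maxlen=3)
for value in range(5):
    window.append(value)  # once full, each append drops the oldest item

print(list(window))
```

Only the three most recent values remain.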

Major functions for transforming the trt data into game actions.

In [18]:
import numpy as np

#flatten the normalized_peaks from torch.Size([1, 18, 100, 2]) to numpy([1, 36 + 1])
#18 peaks * 2 coordinates (y,x) + 1 timestamp
def getFlattenPeaks(object_counts, objects, normalized_peaks, timestamp):
    count = int(object_counts[0])
    flattenPeaks = np.zeros((count,objects[0][0].shape[0]*2 + 1)) #i.e. (1,36 + 1)
    for i in range(count):
        obj = objects[0][i]
        C = obj.shape[0] #18
        flattenPeaks[i][C*2] = timestamp #adding timestamp to the end
        for j in range(C):
            k = int(obj[j]) #obj index, i.e. k=0 for the 1st obj detected
            if k >= 0:
                peak = normalized_peaks[0][j][k]
                flattenPeaks[i][j*2] = peak[0]
                flattenPeaks[i][j*2+1] = peak[1]
    return flattenPeaks
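
The layout of one flattened row can be sketched in plain Python.  `flatten_person` and the toy data are hypothetical and use 3 parts instead of 18; the slot order (y at `2*j`, x at `2*j + 1`, timestamp last) follows the function above.

```python
def flatten_person(part_indices, peaks, timestamp):
    """Flatten one person's keypoints to [y0, x0, y1, x1, ..., timestamp].

    part_indices[j] is the index of part j's peak, or -1 if part j is missing.
    peaks[j][k] is the (y, x) pair of the k-th peak found for part j.
    """
    row = [0.0] * (2 * len(part_indices) + 1)
    row[-1] = timestamp
    for j, k in enumerate(part_indices):
        if k >= 0:
            y, x = peaks[j][k]
            row[2 * j] = y
            row[2 * j + 1] = x
    return row

# Toy data: part 1 was not detected, so its two slots stay at zero.
toy_peaks = [[(0.10, 0.50)], [], [(0.30, 0.55)]]
row = flatten_person([0, -1, 0], toy_peaks, timestamp=12.5)
print(row)
```

Undetected parts are left as zeros, which is what the validity check later in the notebook tests for.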

In [19]:
class MyPose:
    poseIndexNose = 0
    poseIndexLeftEye = 1
    poseIndexRightEye = 2
    poseIndexLeftEar = 3
    poseIndexRightEar = 4
    poseIndexLeftShoulder = 5
    poseIndexRightShoulder = 6
    poseIndexLeftElbow = 7
    poseIndexRightElbow = 8
    poseIndexLeftWrist = 9
    poseIndexRightWrist = 10
    poseIndexLeftHip = 11
    poseIndexRightHip = 12
    poseIndexLeftKnee = 13
    poseIndexRightKnee = 14
    poseIndexLeftAnkle = 15
    poseIndexRightAnkle = 16
    poseIndexNeck = 17
    poseIndexTimeStamp = 18
    flattenIndexTimeStamp = poseIndexTimeStamp * 2

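    @staticmethod
    def xIndexOf(poseIndex):
        # x is stored in the odd slot of each (y, x) pair
        return poseIndex * 2 + 1

    @staticmethod
    def yIndexOf(poseIndex):
        # y is stored in the even slot of each (y, x) pair
        return poseIndex * 2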
    

In [20]:
class MyStatistic:
    absMaxVal = 0
    absMinVal = 0
    absDVSpeed = 0
    startingDVSign = 0
    isDVInOneSign = 0
    isAbsValIncreasing = 0
    avgVal = 0
    
    def __init__(self, absMaxVal, absDVSpeed, startingDVSign, isDVInOneSign, isAbsValIncreasing, avgVal, absMinVal):
        self.absMaxVal = absMaxVal
        self.absMinVal = absMinVal
        self.absDVSpeed = absDVSpeed
        self.startingDVSign = startingDVSign
        self.isDVInOneSign = isDVInOneSign
        self.isAbsValIncreasing = isAbsValIncreasing
        self.avgVal = avgVal


In [21]:
class MyStatisticOne:
    maxVal = 0
    minVal = 0
    avgVal = 0
    dVSpeed = 0
    
    def __init__(self, maxVal, minVal, avgVal, dVSpeed):
        self.maxVal = maxVal
        self.minVal = minVal
        self.avgVal = avgVal
        self.dVSpeed = dVSpeed


In [22]:
bodyHeight = 0

def getBodyHeight(flattenPeakQueue):
    obj = flattenPeakQueue[0]
    return obj[MyPose.yIndexOf(MyPose.poseIndexLeftAnkle)] - obj[MyPose.yIndexOf(MyPose.poseIndexNeck)]


In [23]:
def getStatisticsBetweenTwoPeaksInQueue(flattenPeakQueue, flattenIndexA, flattenIndexB):
    absMaxVal = 0
    absMinVal = 0
    lastVal = 0
    lastT = 0
    dVSum = 0
    dTSum = 0
    startingDVSign = 0
    isDVInOneSign = True
    lastAbsVal = 0
    isAbsValIncreasing = True
    valSum = 0
    numOfElements = 0
    for i, obj in enumerate(flattenPeakQueue):
        aVal = obj[flattenIndexA]
        bVal = obj[flattenIndexB]
        t = obj[MyPose.flattenIndexTimeStamp]
        val = bVal - aVal
        valSum = valSum + val
        numOfElements = numOfElements + 1
        absVal = abs(val)
        if i == 0:
            absMaxVal = absVal
            absMinVal = absVal
        else:
            if absVal > absMaxVal: absMaxVal = absVal
            if absVal < absMinVal: absMinVal = absVal
            dV = val - lastVal
            dVSign = np.sign(dV)
            if startingDVSign == 0:
                startingDVSign = dVSign
            else:
                if dVSign != 0 and startingDVSign != dVSign: isDVInOneSign = False
            if absVal < lastAbsVal: isAbsValIncreasing = False
            dT = t - lastT
            dVSum = dVSum + dV
            dTSum = dTSum + dT
        lastVal = val
        lastT = t
        lastAbsVal = absVal
    if dTSum != 0:
        dVSpeed = dVSum / dTSum
    else:
        dVSpeed = 0
    avgVal = valSum / numOfElements
    return MyStatistic(absMaxVal, abs(dVSpeed), startingDVSign, isDVInOneSign, isAbsValIncreasing, avgVal, absMinVal)
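
One property of the speed computed here is worth spelling out: summing the frame-to-frame differences and dividing by the summed time steps cancels the intermediate samples, so the result depends only on the first and last frame in the queue.  The sketch below checks this on made-up `(timestamp, value)` samples.

```python
# Made-up (timestamp, value) samples, e.g. a wrist-minus-shoulder offset over 4 frames.
samples = [(0.0, 0.10), (0.1, 0.14), (0.2, 0.13), (0.3, 0.22)]

dv_sum = 0.0
dt_sum = 0.0
for (t_prev, v_prev), (t_cur, v_cur) in zip(samples, samples[1:]):
    dv_sum += v_cur - v_prev
    dt_sum += t_cur - t_prev

speed = dv_sum / dt_sum
endpoint_speed = (samples[-1][1] - samples[0][1]) / (samples[-1][0] - samples[0][0])
print(speed, endpoint_speed)
```

The shape of the motion between the endpoints is captured separately by `isDVInOneSign` and `isAbsValIncreasing`.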


In [24]:
def getStatisticsOfPeakInQueue(flattenPeakQueue, flattenIndex):
    #global dataText_w
    minVal = 0
    maxVal = 0
    valSum = 0
    numOfElements = 0
    dVSum = 0
    dTSum = 0
    lastVal = 0
    lastT = 0
    for i, obj in enumerate(flattenPeakQueue):
        t = obj[MyPose.flattenIndexTimeStamp]
        val = obj[flattenIndex]
        valSum = valSum + val
        numOfElements = numOfElements + 1
        if i == 0:
            minVal = val
            maxVal = val
        else:
            if val > maxVal:
                maxVal = val
            if val < minVal:
                minVal = val            
            dV = val - lastVal
            dT = t - lastT
            dVSum = dVSum + dV
            dTSum = dTSum + dT
        lastVal = val
        lastT = t
    # Guard against a zero time span, as getStatisticsBetweenTwoPeaksInQueue does
    dVSpeed = dVSum / dTSum if dTSum != 0 else 0
    avgVal = valSum / numOfElements
    return MyStatisticOne(maxVal, minVal, avgVal, dVSpeed)


In [25]:
isRunningSpeedScalingFactor=0.4

def isRunningWithAnkle(flattenPeakQueue, hipIndex, ankleIndex):
    global bodyHeight
    global isRunningSpeedScalingFactor
    s = getStatisticsBetweenTwoPeaksInQueue(flattenPeakQueue, hipIndex, ankleIndex)
    return s.absMaxVal > bodyHeight*0.1 and s.absDVSpeed > bodyHeight*isRunningSpeedScalingFactor

def getRunningDirWithKnee(flattenPeakQueue):
    sR = getStatisticsBetweenTwoPeaksInQueue(flattenPeakQueue, MyPose.xIndexOf(MyPose.poseIndexRightHip), MyPose.xIndexOf(MyPose.poseIndexRightKnee))
    sL = getStatisticsBetweenTwoPeaksInQueue(flattenPeakQueue, MyPose.xIndexOf(MyPose.poseIndexLeftHip), MyPose.xIndexOf(MyPose.poseIndexLeftKnee))
    if sR.absMaxVal > sL.absMaxVal:
        s = sR
    else:
        s = sL
    if s.avgVal < 0:
        return "p"
    return "o"

def isRunning(flattenPeakQueue):
    return isRunningWithAnkle(flattenPeakQueue, MyPose.yIndexOf(MyPose.poseIndexRightHip), MyPose.yIndexOf(MyPose.poseIndexRightAnkle)) or isRunningWithAnkle(flattenPeakQueue, MyPose.yIndexOf(MyPose.poseIndexLeftHip), MyPose.yIndexOf(MyPose.poseIndexLeftAnkle))


In [26]:
def isPunchingWithWrist(flattenPeakQueue, shoulderIndex, wristIndex):
    global bodyHeight
    s = getStatisticsBetweenTwoPeaksInQueue(flattenPeakQueue, shoulderIndex, wristIndex)
    return s.absMaxVal > bodyHeight*0.25 and s.absDVSpeed > bodyHeight*0.1 and s.isDVInOneSign and s.isAbsValIncreasing

def isPunching(flattenPeakQueue):
    return isPunchingWithWrist(flattenPeakQueue, MyPose.xIndexOf(MyPose.poseIndexRightShoulder), MyPose.xIndexOf(MyPose.poseIndexRightWrist)) or isPunchingWithWrist(flattenPeakQueue, MyPose.xIndexOf(MyPose.poseIndexLeftShoulder), MyPose.xIndexOf(MyPose.poseIndexLeftWrist))


In [27]:
def isJumping(flattenPeakQueue):
    global bodyHeight
    s = getStatisticsOfPeakInQueue(flattenPeakQueue, MyPose.yIndexOf(MyPose.poseIndexNeck))
    maxDeltaVal = s.maxVal - s.minVal
    return maxDeltaVal > bodyHeight * 0.1 and s.dVSpeed < -bodyHeight*0.1


In [28]:
isExtendingBothWristsScalingFactor = 0.5

#check whether both wrists are extended,
#i.e. the deltaX between the left and right wrist is over the threshold
def isExtendingBothWrists(flattenPeakQueue):
    global bodyHeight
    global isExtendingBothWristsScalingFactor
    s = getStatisticsBetweenTwoPeaksInQueue(flattenPeakQueue, MyPose.xIndexOf(MyPose.poseIndexRightWrist), MyPose.xIndexOf(MyPose.poseIndexLeftWrist))
    return s.absMaxVal > bodyHeight * isExtendingBothWristsScalingFactor

def getFlyingDirWithWrists(flattenPeakQueue):
    global bodyHeight
    sR = getStatisticsBetweenTwoPeaksInQueue(flattenPeakQueue, MyPose.yIndexOf(MyPose.poseIndexRightWrist), MyPose.yIndexOf(MyPose.poseIndexRightShoulder))
    sL = getStatisticsBetweenTwoPeaksInQueue(flattenPeakQueue, MyPose.yIndexOf(MyPose.poseIndexLeftWrist), MyPose.yIndexOf(MyPose.poseIndexLeftShoulder))
    if sR.absMaxVal < bodyHeight*0.1 and sL.absMaxVal < bodyHeight*0.1:
        return "0"
    elif sR.avgVal < -bodyHeight*0.1 and sL.avgVal > bodyHeight*0.1:
        return "p"
    elif sR.avgVal > bodyHeight*0.1 and sL.avgVal < -bodyHeight*0.1:
        return "o"
    return "?"

def isFlying(flattenPeakQueue):
    return isExtendingBothWrists(flattenPeakQueue)


In [29]:
def isRightWristAboveShoulder(flattenPeakQueue):
    global bodyHeight
    sR = getStatisticsBetweenTwoPeaksInQueue(flattenPeakQueue, MyPose.yIndexOf(MyPose.poseIndexRightWrist), MyPose.yIndexOf(MyPose.poseIndexRightShoulder))
    return sR.avgVal > bodyHeight*0.2

def isLeftWristAboveShoulder(flattenPeakQueue):
    global bodyHeight
    sL = getStatisticsBetweenTwoPeaksInQueue(flattenPeakQueue, MyPose.yIndexOf(MyPose.poseIndexLeftWrist), MyPose.yIndexOf(MyPose.poseIndexLeftShoulder))
    return sL.avgVal > bodyHeight*0.2

def isBothWristsAboveShoulder(flattenPeakQueue):
    global bodyHeight
    sR = getStatisticsBetweenTwoPeaksInQueue(flattenPeakQueue, MyPose.yIndexOf(MyPose.poseIndexRightWrist), MyPose.yIndexOf(MyPose.poseIndexRightShoulder))
    sL = getStatisticsBetweenTwoPeaksInQueue(flattenPeakQueue, MyPose.yIndexOf(MyPose.poseIndexLeftWrist), MyPose.yIndexOf(MyPose.poseIndexLeftShoulder))
    return sR.avgVal > bodyHeight*0.3 and sL.avgVal > bodyHeight*0.3

def isBothWristsBelowKnee(flattenPeakQueue):
    global bodyHeight
    sR = getStatisticsBetweenTwoPeaksInQueue(flattenPeakQueue, MyPose.yIndexOf(MyPose.poseIndexRightKnee), MyPose.yIndexOf(MyPose.poseIndexRightWrist))
    sL = getStatisticsBetweenTwoPeaksInQueue(flattenPeakQueue, MyPose.yIndexOf(MyPose.poseIndexLeftKnee), MyPose.yIndexOf(MyPose.poseIndexLeftWrist))
    return sR.avgVal > bodyHeight*0.1 and sL.avgVal > bodyHeight*0.1

def isAnklesWideOpen(flattenPeakQueue):
    global bodyHeight
    s = getStatisticsBetweenTwoPeaksInQueue(flattenPeakQueue, MyPose.xIndexOf(MyPose.poseIndexLeftAnkle), MyPose.xIndexOf(MyPose.poseIndexRightAnkle))
    return s.absMinVal > bodyHeight * 0.5

def getUpDownFrom(flattenPeakQueue):
    #if isBothWristsAboveShoulder(flattenPeakQueue):
    if isRightWristAboveShoulder(flattenPeakQueue):
        return "u"
    #if isBothWristsBelowKnee(flattenPeakQueue):
    if isAnklesWideOpen(flattenPeakQueue):
        return "d"
    return "?"


In [30]:
def getTargetingDirFromRightWrist(flattenPeakQueue):
    global bodyHeight
    sX = getStatisticsBetweenTwoPeaksInQueue(flattenPeakQueue, MyPose.xIndexOf(MyPose.poseIndexRightWrist), MyPose.xIndexOf(MyPose.poseIndexRightShoulder))
    sY = getStatisticsBetweenTwoPeaksInQueue(flattenPeakQueue, MyPose.yIndexOf(MyPose.poseIndexRightWrist), MyPose.yIndexOf(MyPose.poseIndexRightShoulder))

    if sY.absMaxVal > bodyHeight * 0.3 and sY.avgVal > 0:
        return "u"
    elif sY.absMaxVal > bodyHeight * 0.3 and sY.avgVal < 0:
        return "d"
    elif sX.absMaxVal > bodyHeight * 0.05 and sX.avgVal < 0:
        return "l"
    elif sX.absMaxVal > bodyHeight * 0.2 and sX.avgVal > 0:
        return "r"
    
    return "?"

def isTargeting(flattenPeakQueue):
    return isLeftWristAboveShoulder(flattenPeakQueue)


In [31]:
def isFlattennPeakQueueValid(flattenPeakQueue):
    global bodyHeight
    obj = flattenPeakQueue[0]
    val0 = obj[MyPose.xIndexOf(MyPose.poseIndexNeck)]
    val1 = obj[MyPose.xIndexOf(MyPose.poseIndexRightWrist)]
    val2 = obj[MyPose.xIndexOf(MyPose.poseIndexLeftAnkle)]
    return val0 != 0 and val1 != 0 and val2 != 0 and bodyHeight != 0

Major game action logic.

In [32]:
shortDirKeyList = ["r", "l", "u", "d"]
longDirKeyList = ["M", "N", "O", "P", "0"]
allDirKeyList = ["r", "l", "u", "d", "M", "N", "O", "P", "0"]
longPressKeyList = ["E", "F", "G", "H", "M", "N", "O", "P", "0"]

def isLongPressKey(controlKey):
    global longPressKeyList
    return controlKey.upper() in longPressKeyList

def isLongDirKey(controlKey):
    global longDirKeyList
    return controlKey.upper() in longDirKeyList

def isShortDirKey(controlKey):
    global shortDirKeyList
    return controlKey in shortDirKeyList

def isShootingKey(controlKey):
    return controlKey == "x"

In [33]:
#direction control with rxpy
import rx
from rx import operators as op
from rx.scheduler.eventloop import AsyncIOScheduler
aio_scheduler = AsyncIOScheduler(loop=loop)

dirControlObserver = 0
def dirControlObservable(observer, scheduler):
    global dirControlObserver
    dirControlObserver = observer
    
rx.create(dirControlObservable).pipe(
    op.filter(lambda val: val != "?")
    #op.debounce(0.25)
    #op.throttle_with_timeout(0.2)
).subscribe(
    #on_next = lambda val: publishControlKey(val), scheduler=aio_scheduler
    on_next = lambda val: publishControlKey(val),
)

longControlObserver = 0
def longControlObservable(observer, scheduler):
    global longControlObserver
    longControlObserver = observer

rx.create(longControlObservable).pipe(
    op.filter(lambda val: val != "?"),
    op.distinct_until_changed()
).subscribe(
    on_next = lambda val: publishControlKey(val),
)

shootingControlObserver = 0
def shootingControlObservable(observer, scheduler):
    global shootingControlObserver
    shootingControlObserver = observer

dummyCnt = 0
def onNextLog(val):
    global dataText_w
    global dummyCnt
    dummyCnt = dummyCnt + 1
    dataText_w.value = str(val) + ' ' + str(dummyCnt)

rx.create(shootingControlObservable).pipe(
    op.filter(lambda val: val != "?"),
    #op.throttle_with_timeout(1.0, scheduler=aio_scheduler)
).subscribe(
    #on_next = lambda val: publishControlKey(val), scheduler=aio_scheduler
    on_next = lambda val: publishControlKey(val),
)


<rx.disposable.disposable.Disposable at 0x7f1eeb4828>
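
The effect of the `filter` plus `distinct_until_changed` chain on the long-press stream can be imitated without rx.  The sketch below is a plain-Python stand-in operating on a finished list rather than a live stream; `filter_and_dedupe` is a hypothetical name.

```python
from itertools import groupby

def filter_and_dedupe(keys):
    """Drop "?" placeholders, then collapse runs of identical consecutive keys."""
    known = (k for k in keys if k != "?")
    return [key for key, _ in groupby(known)]

print(filter_and_dedupe(["o", "o", "?", "o", "p", "p", "0", "0"]))
```

A held direction is therefore sent once when it starts and once more only when it changes.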

In [34]:
isTargetingMode = False
isFlyingMode = False
isRunningMode = False

def getGameControlsFrom(flattenPeakQueue):
    global isTargetingMode
    global isFlyingMode
    global isRunningMode
    output = []
    if not isFlattennPeakQueueValid(flattenPeakQueue):
        return output
    
    if isTargetingMode:
        if isTargeting(flattenPeakQueue):
            lastElement = flattenPeakQueue[-1]
            output.append(getTargetingDirFromRightWrist([lastElement]))
        else:
            isTargetingMode = False
            output.append("x")
        return output

    if isFlyingMode:
        if isFlying(flattenPeakQueue):
            output.append(getFlyingDirWithWrists(flattenPeakQueue))
        else:
            isFlyingMode = False
            output.append("0") #clear direction at the end of flying
            output.append("E")
        return output

    if isPunching(flattenPeakQueue):
        output.append("b")

    if isRunningMode:
        if isRunning(flattenPeakQueue):
            output.append(getRunningDirWithKnee(flattenPeakQueue))
        else:
            isRunningMode = False
            output.append("0")
        return output
    
    if isTargeting(flattenPeakQueue):
        isTargetingMode = True
        output.append("x")
        return output
        
    if isFlying(flattenPeakQueue):
        isFlyingMode = True
        output.append("e")
        return output

    if isRunning(flattenPeakQueue):
        isRunningMode = True
        return output

    #if isJumping(flattenPeakQueue):
    #    output.append("a")

    output.append(getUpDownFrom(flattenPeakQueue))

    return output
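
Each mode in this function follows the same pattern: emit a key on entering the mode, emit a clearing key on leaving it.  The sketch below isolates that pattern for the flying branch.  `FlyingMode` is a hypothetical, simplified class that omits the direction keys the real code emits while the mode is active.

```python
class FlyingMode:
    """Toy version of the flying branch: emit keys only on mode transitions."""
    def __init__(self):
        self.active = False

    def update(self, arms_extended):
        if self.active:
            if arms_extended:
                return []        # the real code emits a direction key here
            self.active = False
            return ["0", "E"]    # clear direction, then release
        if arms_extended:
            self.active = True
            return ["e"]         # press and hold
        return []

mode = FlyingMode()
emitted = [mode.update(flag) for flag in (False, True, True, False, False)]
print(emitted)
```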

In [35]:
lastControlKey = "a"

buttonKeyMapDict = {
    "a": aButton,
    "b": bButton,
    "x": xButton,
    "y": yButton,
    "u": upButton,
    "d": downButton,
    "l": leftButton,
    "r": rightButton,
    "L": lButton,
    "R": rButton,
    "E": aButtonLong,
    "F": bButtonLong,
    "G": xButtonLong,
    "H": yButtonLong,
    "M": upButtonLong,
    "N": downButtonLong,
    "O": leftButtonLong,
    "P": rightButtonLong,
    "T": targetingModeButton
}

def displayAction(controlKey):
    global lastControlKey
    global buttonKeyMapDict
    global isTargetingMode
    if isLongPressKey(controlKey):
        if controlKey == "0":
            buttonKeyMapDict["M"].button_style = ''
            buttonKeyMapDict["N"].button_style = ''
            buttonKeyMapDict["O"].button_style = ''
            buttonKeyMapDict["P"].button_style = ''
        elif controlKey.isupper():
            buttonKeyMapDict[controlKey].button_style = ''
        else:
            buttonKeyMapDict[controlKey.upper()].button_style = 'success'
    elif isShortDirKey(controlKey):
        buttonKeyMapDict["u"].button_style = ''
        buttonKeyMapDict["d"].button_style = ''
        buttonKeyMapDict["l"].button_style = ''
        buttonKeyMapDict["r"].button_style = ''
        buttonKeyMapDict[controlKey].button_style = 'success'
    else:
        buttonKeyMapDict[lastControlKey].button_style = ''
        buttonKeyMapDict[controlKey].button_style = 'success'
        lastControlKey = controlKey
        if isTargetingMode:
            buttonKeyMapDict["T"].button_style = 'success'
        else:
            buttonKeyMapDict["T"].button_style = ''            


In [36]:
def publishControlKey(controlKey):
    global tcpWriter
    tcpWriter.write(controlKey.encode())


In [37]:
from collections import deque
N = 5
flattenPeakQueue = deque(maxlen=N)
lastDirKeyAt = 0

def processTrtData(object_counts, objects, normalized_peaks, timestamp):
    global N
    global flattenPeakQueue
    #global dataText_w
    global controllerCheckBox_w
    global tcpWriter
    global bodyHeight
    global dirControlObserver
    global longControlObserver
    global shootingControlObserver
    global lastDirKeyAt
    flattenPeaks = getFlattenPeaks(object_counts, objects, normalized_peaks, timestamp)
    if len(flattenPeaks) > 0:
        flattenPeak = flattenPeaks[0]
        flattenPeakQueue.append(flattenPeak)
        if len(flattenPeakQueue) == N:
            bodyHeight = getBodyHeight(flattenPeakQueue)
            controlKeys = getGameControlsFrom(flattenPeakQueue)
            if controllerCheckBox_w.value:
                for k in controlKeys:
                    if isLongPressKey(k):
                        longControlObserver.on_next(k)
                    elif isShortDirKey(k):
                        t1 = time.time()
                        dt = t1 - lastDirKeyAt
                        if dt > 0.15:
                            lastDirKeyAt = t1 
                            dirControlObserver.on_next(k)
                    elif isShootingKey(k):
                        shootingControlObserver.on_next(k)
                    elif k != "?":
                        publishControlKey(k)                    
            for k in controlKeys:
                if k != "?":
                    displayAction(k)
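
The `dt > 0.15` check on short direction keys is a simple rate limiter.  A minimal sketch of the same idea with the clock passed in, so it can be tested without waiting; `MinIntervalGate` is a hypothetical name.

```python
class MinIntervalGate:
    """Let an event through only if more than `interval` seconds passed since the last one."""
    def __init__(self, interval):
        self.interval = interval
        self.last_at = 0.0

    def allow(self, now):
        if now - self.last_at > self.interval:
            self.last_at = now
            return True
        return False

gate = MinIntervalGate(0.15)
timestamps = [100.00, 100.05, 100.10, 100.20, 100.30, 100.40]
passed = [t for t in timestamps if gate.allow(t)]
print(passed)
```

Events that arrive too soon are dropped rather than delayed, which matches how the notebook treats direction keys.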


For measuring the response time of the main processing loop.

In [38]:
timeSum = 0
timeCount = 0
timeCountN = 1000

def recordTime(deltaTime):
    global timeSum
    global timeCount
    global dataText_w
    global timeCountN
    timeSum = timeSum + deltaTime
    timeCount = timeCount + 1
    if timeCount >= timeCountN:
        fps = timeCountN / timeSum
        timeSum = 0
        timeCount = 0
        dataText_w.value = "fps {}".format(fps)

Finally, we'll define the main execution loop.  This will perform the following steps:

1.  Preprocess the camera image
2.  Execute the neural network
3.  Parse the objects from the neural network output
4.  Draw the objects onto the camera image
5.  Convert the image to JPEG format and stream to the display widget

In [39]:
def execute(change):
    t0 = time.time()
    image = change['new']
    data = preprocess(image)
    cmap, paf = model_trt(data)
    cmap, paf = cmap.detach().cpu(), paf.detach().cpu()
    counts, objects, peaks = parse_objects(cmap, paf)#, cmap_threshold=0.15, link_threshold=0.15)
    processTrtData(counts, objects, peaks, t0)
    draw_objects(image, counts, objects, peaks)
    image_w.value = bgr8_to_jpeg(image[:, ::-1, :])
    t1 = time.time()
    recordTime(t1-t0)

Running the cell below executes the function once on the current camera frame.

In [40]:
execute({'new': camera.value})

Run the cell below to attach the execution function to the camera's internal value.  This will cause the execute function to be called whenever a new camera frame is received.

In [41]:
camera.observe(execute, names='value')

Run the cell below to detach the camera frame callbacks.

In [None]:
camera.unobserve_all()