First, create the model. This must match the model used in the interactive training notebook.

In [None]:
import cv2
import torch
import torchvision

# Names of the targets the network regresses; the head emits two values per name.
CATEGORIES = ['apex']

device = torch.device('cuda')

# ResNet-18 backbone without downloaded weights; the trained weights are
# loaded from disk in a later cell.
model = torchvision.models.resnet18(pretrained=False)
# Replace the classification head with a regression head.
model.fc = torch.nn.Linear(in_features=512, out_features=2 * len(CATEGORIES))
# Move to the GPU, switch to inference mode, and cast to half precision.
model = model.to(device)
model = model.eval()
model = model.half()

Next, load the saved model.  Enter the model path you used to save.

In [None]:
# Read the checkpoint written by the training notebook, then copy it into the model.
state_dict = torch.load('road_following_model.pth')
model.load_state_dict(state_dict)

Convert and optimize the model using ``torch2trt`` for faster inference with TensorRT.  Please see the [torch2trt](https://github.com/NVIDIA-AI-IOT/torch2trt) readme for more details.

> This optimization process can take a couple of minutes to complete.

In [None]:
from torch2trt import torch2trt

# Example input handed to torch2trt: one half-precision 3x224x224 frame of
# zeros on the GPU.
data = torch.zeros((1, 3, 224, 224), dtype=torch.float16, device='cuda')

# Build the TensorRT-optimized module with FP16 enabled.
model_trt = torch2trt(model, [data], fp16_mode=True)

Save the optimized model using the cell below.

In [None]:
# Persist the optimized module's state so later sessions can skip the conversion.
trt_state = model_trt.state_dict()
torch.save(trt_state, 'road_following_model_trt.pth')

Load the optimized model by executing the cell below.

In [None]:
import torch
from torch2trt import TRTModule

# Start from an empty TRTModule and fill it from the saved optimized state.
model_trt = TRTModule()
trt_weights = torch.load('road_following_model_trt.pth')
model_trt.load_state_dict(trt_weights)

Create the racecar instance.

In [None]:
from jetracer.nvidia_racecar import NvidiaRacecar

# Handle to the car; the cells below set its steering and throttle attributes.
car = NvidiaRacecar()

Create the camera instance.

In [None]:
from jetcam.csi_camera import CSICamera

# 224x224 frames match the 1x3x224x224 example input used for the TensorRT conversion.
camera = CSICamera(width=224, height=224, capture_fps=65)

Create Live Widget

In [None]:
# with live widget
import threading
import time
import numpy as np
from utils import preprocess
import torch.nn.functional as F
import matplotlib.pyplot as plt

import cv2
import ipywidgets
from IPython.display import display
from jetcam.utils import bgr8_to_jpeg
from jupyter_clickable_image_widget import ClickableImageWidget

# PD controller gains read by live(): Kp scales the predicted x value itself,
# Kd scales its change since the previous frame.
Kp = 1.7
Kd = 6
car.steering_gain = 1.0
# NOTE(review): throttle is assigned as soon as this cell runs, not when the
# toggle is switched to 'live' -- confirm the car is safely positioned first.
car.throttle = 0.75
car.throttle_gain=1.0
# presumably starts continuous frame capture so camera.value stays current -- confirm in jetcam
camera.running = True

# unobserve all callbacks from camera in case we are running this cell for second time
camera.unobserve_all()

# create image preview
# The image widget shows the annotated frame; the toggle starts and stops the control loop.
prediction_widget = ipywidgets.Image(format='jpeg', width=camera.width, height=camera.height)
state_widget = ipywidgets.ToggleButtons(options=['stop', 'live'], description='state', value='stop')

# Telemetry appended by live() and read back by plot_data()/print_statistics().
# Defined here because live() appends to both lists on every iteration.
steering_data = []
fps_data = []


def live(state_widget, model, camera, prediction_widget):
    """Run the road-following loop for as long as the toggle reads 'live'.

    Each iteration takes the current camera frame, runs it through ``model``,
    steers the car with a PD controller on the predicted x value, records the
    steering angle and loop rate, and publishes an annotated preview frame.

    Args:
        state_widget: Object whose ``value`` attribute is polled; the loop
            exits once it is no longer ``'live'``.
        model: Callable applied to the preprocessed frame. The first element
            of its flattened output is used as the horizontal target x.
        camera: Object exposing ``value`` (current frame), ``width`` and
            ``height``.
        prediction_widget: Object whose ``value`` receives the JPEG-encoded
            annotated frame.

    Returns:
        None. Results are side effects on ``car``, ``steering_data``,
        ``fps_data`` and ``prediction_widget``.
    """
    last_x = 0
    while state_widget.value == 'live':
        start_time = time.time()  # start time of the loop
        image = camera.value
        # The network was cast to half precision, so the input must be too.
        preprocessed = preprocess(image).half()
        output = model(preprocessed).detach().cpu().numpy().flatten()
        x = float(output[0])

        # PD controller: proportional term on x, derivative term on its
        # change since the previous frame. The sign is inverted for the car.
        car.steering = -(x * Kp + (x - last_x) * Kd)
        last_x = x

        # Display-only value; assumes x = +/-1 corresponds to +/-30 degrees -- TODO confirm.
        steering_angle = x * 30.0
        steering_data.append(steering_angle)  # Store steering angle data

        # Guard against a zero elapsed time on coarse clocks; 0.0 then means
        # "too fast to measure" rather than raising ZeroDivisionError.
        elapsed = time.time() - start_time
        fps = 1.0 / elapsed if elapsed > 0 else 0.0
        fps_data.append(fps)  # Store FPS data

        str_info = "Deg %0.1f'" % steering_angle
        str_info += " FPS:%d" % int(fps)

        # Map x from [-1, 1] onto [0, width]; the marker sits on the vertical
        # middle of the frame.
        x_int = int(camera.width * (x / 2.0 + 0.5))
        y_int = int(camera.height * 0.5)
        target = (x_int, y_int)
        # Frame centre (equals (112, 112) for the 224x224 camera).
        center = (int(camera.width * 0.5), y_int)

        prediction = image.copy()
        prediction = cv2.circle(prediction, target, 8, (255, 0, 0), 3)
        prediction = cv2.line(prediction, target, center, (0, 0, 255), 3)
        prediction = cv2.circle(prediction, center, 8, (255, 0, 0), 3)
        prediction = cv2.putText(prediction, str_info, (0, 220), cv2.FONT_HERSHEY_PLAIN, 1, (255, 255, 255))
        prediction_widget.value = bgr8_to_jpeg(prediction)

def start_live(change):
    """Toggle observer: launch the control loop on 'live', otherwise stop the car.

    Args:
        change: Change dictionary passed by the widget; only ``change['new']``
            is read.
    """
    if change['new'] != 'live':
        # Short pause before zeroing the outputs; presumably gives the running
        # loop time to notice the state change -- confirm.
        time.sleep(0.1)
        car.steering = 0
        car.throttle = 0
        car.manual = 0
        return

    # Run the loop on its own thread so this callback returns immediately.
    worker = threading.Thread(
        target=live,
        args=(state_widget, model_trt, camera, prediction_widget),
    )
    worker.start()

# Call start_live whenever the toggle's value changes.
state_widget.observe(start_live, names='value')

# Stack the preview image above the stop/live toggle and show them.
live_execution_widget = ipywidgets.VBox(children=[prediction_widget, state_widget])
display(live_execution_widget)

# Plot data
def plot_data():
    """Plot the recorded steering angles and FPS values, then print their statistics.

    Draws two stacked panels in one 10x7 figure, one per recorded series,
    shows the figure, and finally calls ``print_statistics()``.
    """
    # (series, title, y-axis label) for each panel, top to bottom.
    panels = (
        (steering_data, 'Steering Angle Over Time', 'Steering Angle (degrees)'),
        (fps_data, 'FPS Over Time', 'FPS'),
    )

    plt.figure(figsize=(10, 7))
    for row, (series, title, ylabel) in enumerate(panels, start=1):
        plt.subplot(len(panels), 1, row)
        plt.plot(series)
        plt.title(title)
        plt.xlabel('Time')
        plt.ylabel(ylabel)
    plt.tight_layout()
    plt.show()
    print_statistics()

def print_statistics():
    """Print max, min and mean for the steering and FPS series; empty series are skipped."""
    if steering_data:
        hi, lo, avg = max(steering_data), min(steering_data), np.mean(steering_data)
        print(f'Steering Angle: Max = {hi:.2f}, Min = {lo:.2f}, Avg = {avg:.2f}')
    if fps_data:
        hi, lo, avg = max(fps_data), min(fps_data), np.mean(fps_data)
        print(f'FPS: Max = {hi:.2f}, Min = {lo:.2f}, Avg = {avg:.2f}')

def stop_live(change):
    """Toggle observer: plot the recorded data when the new value is 'stop'.

    Args:
        change: Change dictionary passed by the widget; only ``change['new']``
            is read.
    """
    if change['new'] != 'stop':
        return
    plot_data()

# Second observer on the same toggle: stop_live plots the data when it switches to 'stop'.
state_widget.observe(stop_live, names='value')

VBox(children=(Image(value=b'', format='jpeg', height='224', width='224'), ToggleButtons(description='state', …

In [None]:
import time

# Manual stop cell: detach camera callbacks, then zero the car's outputs.
camera.unobserve_all()

# Short pause before zeroing; presumably lets any in-flight loop iteration finish -- confirm.
time.sleep(0.1)
car.steering = 0
car.throttle = 0
car.manual = 0