In [13]:
import torch
from torchvision import models

# Build an untrained AlexNet, then replace the last classifier layer with a
# 2-output head (road follow / obstacle) that keeps the same input width.
model = models.alexnet(pretrained=False)
model.classifier[6] = torch.nn.Linear(
    in_features=model.classifier[6].in_features,
    out_features=2,
)


In [14]:
# Prefer the GPU when CUDA is usable; otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')


In [15]:
# Restore the trained weights from "best_model.pth" into the 2-class AlexNet,
# remapping stored tensors onto `device` while loading.
# NOTE(review): torch.load unpickles the file, so only load checkpoints that
# come from a trusted source.
model.load_state_dict(torch.load("best_model.pth", map_location=device))


<All keys matched successfully>

In [16]:
# Move the model onto the selected device; the inference code sends its input
# tensors to the same `device`.
model = model.to(device)

In [17]:
# Put the model in evaluation mode for inference (this affects layers such as
# Dropout). As the last expression of the cell, the returned module is echoed
# as the cell output.
model.eval()

AlexNet(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(11, 11), stride=(4, 4), padding=(2, 2))
    (1): ReLU(inplace=True)
    (2): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
    (3): Conv2d(64, 192, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    (4): ReLU(inplace=True)
    (5): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
    (6): Conv2d(192, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (7): ReLU(inplace=True)
    (8): Conv2d(384, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (9): ReLU(inplace=True)
    (10): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU(inplace=True)
    (12): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (avgpool): AdaptiveAvgPool2d(output_size=(6, 6))
  (classifier): Sequential(
    (0): Dropout(p=0.5, inplace=False)
    (1): Linear(in_features=9216, out_features=4096, bias=True)
 

In [18]:
from jetbot import Camera
from PIL import Image
from torchvision import transforms

# Open the shared JetBot camera at 224x224.
camera = Camera.instance(width=224, height=224)

# Preprocessing meant to match what the model saw during training:
# array -> PIL image -> 224x224 -> float tensor -> per-channel normalisation.
transform = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Resize(size=(224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
])


In [19]:
import traitlets
import ipywidgets.widgets as widgets
from IPython.display import display
from jetbot import Camera, bgr8_to_jpeg

In [20]:
# Live preview widget; its width/height need not equal the camera resolution.
image = widgets.Image(format='jpeg', width=224, height=224)

# One-way link: every new camera frame is JPEG-encoded and pushed to the widget.
camera_link = traitlets.dlink(
    source=(camera, 'value'),
    target=(image, 'value'),
    transform=bgr8_to_jpeg,
)

display(image)

Image(value=b'\xff\xd8\xff\xe0\x00\x10JFIF\x00\x01\x01\x00\x00\x01\x00\x01\x00\x00\xff\xdb\x00C\x00\x02\x01\x0…

In [21]:
import time

def predict_jetbot_frame():
    """Classify the latest camera frame and return the predicted class index.

    Reads ``camera.value``, preprocesses it with ``transform``, runs ``model``
    on ``device`` without gradient tracking, and returns the argmax of the
    model output as a plain Python ``int``.

    The frame's channel axis is reversed (BGR -> RGB) before preprocessing:
    the same ``camera.value`` is passed to ``bgr8_to_jpeg`` for the preview
    widget, so it is in BGR order, while ``transform`` applies per-channel
    normalisation constants listed in RGB order.
    NOTE(review): this assumes the training images were loaded as RGB (for
    example from saved JPEGs via PIL) -- confirm against the training notebook.

    Returns:
        int: index of the highest-scoring class for the current frame.
    """
    frame_bgr = camera.value  # numpy image array (H x W x C), BGR channel order
    # Reverse the channel axis and copy so the result is a contiguous array.
    frame_rgb = frame_bgr[:, :, ::-1].copy()
    batch = transform(frame_rgb).unsqueeze(0).to(device)

    with torch.no_grad():
        logits = model(batch)
    return logits.argmax(1).item()


In [22]:
from jetbot import Robot

# Motor-control handle used by the drive loop (forward / left / stop).
robot = Robot()

In [25]:



# Drive loop: classify a frame, act on the prediction, repeat until the
# kernel is interrupted. The motors are stopped on every exit path so a
# failure inside the loop cannot leave the robot driving unattended.
try:
    while True:
        pred = predict_jetbot_frame()
        print(pred)
        if pred == 0:
            # Class 0 = road → go forward
            robot.forward(0.2)
        else:
            # Class 1 = obstacle → stop/turn
            robot.left(0.3)
            time.sleep(0.7)  # Tweak this value if it over- or under-rotates
            robot.stop()
            time.sleep(0.2)

        time.sleep(0.1)
except KeyboardInterrupt:
    # Interrupting the kernel is the normal way to end the demo.
    robot.stop()
    camera.stop()
except Exception:
    # Any other failure (camera, model, device): cut the motors first, then
    # surface the error instead of hiding it.
    robot.stop()
    raise


0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0


In [None]:
# Manual stop for the motors; can be run on its own after the drive loop cell.
robot.stop()

In [26]:
# Stop the JetBot camera once the demo is finished.
camera.stop()