# Save prediction as jet over image

This notebook runs on a video, but the same per-frame steps also apply to still images.

In [None]:
import cv2 as cv
from UNet.unet import UNet
import torch
import albumentations as A
from albumentations.pytorch import ToTensorV2
import numpy as np

In [None]:
# Input clip to run the model on. Uncomment the alternative below to switch clips.
VIDEO_PATH = 'input/jetson2_video_20180628-234507.mp4'      # "scale" clip (presumably Italian for "stairs")
# VIDEO_PATH = 'input/jetson4_video_20180628-170007.mp4'    # "vicolo" clip (presumably Italian for "alley")

# Checkpoint file holding the trained weights.
MODEL_PATH = 'models/700.pth'

# Run on the GPU when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = 'cuda'
else:
    DEVICE = 'cpu'

In [None]:
# Per-frame preprocessing: ToTensorV2 turns the HxWxC numpy image into a
# CxHxW torch tensor. Resizing is done with cv.resize in the frame loop,
# so the albumentations resize stays disabled.
transform = A.Compose([
    # A.Resize(752, 423),
    ToTensorV2()
])

# Open the input video and fail loudly on a bad path or unsupported codec;
# otherwise the frame loop would silently do nothing.
cap = cv.VideoCapture(VIDEO_PATH)
if not cap.isOpened():
    raise OSError(f'Could not open video: {VIDEO_PATH}')

# Load model
model = UNet()
# map_location remaps the stored tensors onto DEVICE, so a checkpoint that was
# saved on a GPU can still be loaded on a CPU-only machine.
checkpoint = torch.load(MODEL_PATH, map_location=DEVICE, weights_only=True)
# The state dict is stored under the 'model' key of the checkpoint.
model.load_state_dict(checkpoint['model'])
model.to(DEVICE)
model.eval()  # inference mode: disables dropout / batch-norm statistic updates
print('Model loaded')

In [None]:
# Run the model on every frame of the video and show the sigmoid output as a
# JET heat map blended 50/50 over the original frame.

# Size the frames are resized to before inference, as (width, height) --
# the order cv.resize expects.
INPUT_SIZE = (752, 423)

try:
    while cap.isOpened():
        ret, orig_frame = cap.read()
        if not ret:
            # End of the stream (or a read failure): stop processing.
            break

        cv.imshow('orig frame', orig_frame)

        # Preprocess the frame.
        # NOTE(review): the frame is fed as OpenCV delivers it (BGR channel
        # order, 0-255 values cast to float) -- confirm this matches training.
        frame = cv.resize(orig_frame, INPUT_SIZE)
        frame = transform(image=frame)['image']
        # Optional debug: frame.permute(1, 2, 0).cpu().numpy() gives the
        # preprocessed frame back as an HxWxC image that can be saved.
        frame = frame.unsqueeze(0).float().to(DEVICE)  # add the batch dimension

        # Get prediction from the model
        with torch.no_grad():
            pred = torch.sigmoid(model(frame))
            # For a hard mask instead of probabilities, threshold here:
            # pred = (pred > 0.5).float()
            # Assumes a single-channel output, so squeeze() leaves an HxW map
            # -- TODO confirm against the UNet definition.
            pred = pred.squeeze().cpu().numpy()

        # Visualize the prediction by drawing it over the original frame
        # (alpha blending).
        height, width = orig_frame.shape[:2]
        pred_viz = cv.resize(pred, (width, height))
        # Scale the [0, 1] probabilities to 8-bit; applyColorMap accepts a
        # single-channel uint8 image and returns a 3-channel uint8 BGR image.
        pred_viz = (pred_viz * 255).astype(np.uint8)
        pred_viz = cv.applyColorMap(pred_viz, cv.COLORMAP_JET)
        # Both images are uint8 with 3 channels, as addWeighted requires
        # (cap.read() yields 8-bit BGR frames by default).
        pred_viz = cv.addWeighted(orig_frame, 0.5, pred_viz, 0.5, 0)
        cv.imshow('pred', pred_viz)

        # waitKey(0) blocks until a key is pressed, so the video advances one
        # frame per key press; 'q' quits.
        if cv.waitKey(0) & 0xFF == ord('q'):
            break
finally:
    # Always free the capture and close the windows, even if inference fails.
    cap.release()
    cv.destroyAllWindows()