In [1]:
# !git clone https://github.com/vietawake/ERFModel
# !mv -v ERFModel/* .

In [1]:
import numpy as np
from PIL import Image
import torch
import torchvision.transforms as T
from torchvision import transforms
from models.roadseg import RoadSeg
import cv2
from inference import preprocess, segmentation, visualize

In [2]:
# Reproducibility and device selection.
# cuDNN benchmark mode lets cuDNN pick the fastest kernels; it pays off when the
# input size stays fixed (384x640 throughout this notebook).
SEED = 50
torch.backends.cudnn.enabled = True
torch.backends.cudnn.benchmark = True

np.random.seed(SEED)
torch.manual_seed(SEED)

use_cuda = torch.cuda.is_available()
if use_cuda:
    torch.cuda.manual_seed(SEED)

device = torch.device("cuda:0") if use_cuda else torch.device("cpu")

In [3]:
# Build the segmentation network. NUM_LABELS presumably must equal the number of
# classes the checkpoint loaded below was trained with (TODO confirm class list).
NUM_LABELS = 4
net = RoadSeg(num_labels=NUM_LABELS)

In [5]:
# Load the pretrained weights; map_location lets a GPU-saved checkpoint load on CPU.
# NOTE(review): torch.load unpickles the file -- only load checkpoints from a trusted
# source (recent torch versions accept weights_only=True for plain state dicts).
checkpoint = torch.load('./pretrained_models/RoadSeg_epoch_44_acc_0.7665.pt', map_location=device)

# torch.nn.DataParallel stores parameters under a "module." prefix. If every key carries
# that prefix, wrap the network the same way so key names line up for load_state_dict.
# (Checking the prefix, not a substring, avoids false positives such as "submodule.x".)
multigpus = bool(checkpoint) and all(key.startswith('module.') for key in checkpoint)
if multigpus:
    net = torch.nn.DataParallel(net)
net.load_state_dict(checkpoint)
net.to(device)
net.eval()

RoadSeg(
  (encoder_rgb_conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (encoder_rgb_bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (encoder_rgb_relu): ReLU(inplace=True)
  (encoder_rgb_maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (encoder_rgb_layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, 

In [10]:
# Segment a single test image and save the colour overlay to disk.
input_path = 'img_test/autopilot_test_0001_20201020_073306_000064.png'
# PNGs may open as RGBA/palette images; force 3-channel RGB so the arrays below
# have a predictable channel layout.
img = Image.open(input_path).convert('RGB')
origin_width, origin_height = img.size  # PIL reports (width, height)
preprocess_img = preprocess(img, height=384, width=640)

# seg_map comes back at the original image resolution (see printed shape).
seg_map = segmentation(preprocess_img, origin_height, origin_width, net, device)
print(seg_map.shape)
img_rgb = np.asarray(img)
overlaid_img = visualize(seg_map, img_rgb)

combined_img = np.concatenate((img_rgb, overlaid_img), axis=1)

# OpenCV writes BGR. visualize is fed an RGB array, so its output is presumably RGB
# as well -- convert before writing, otherwise red and blue are swapped on disk.
cv2.imwrite('output.jpg', cv2.cvtColor(overlaid_img, cv2.COLOR_RGB2BGR))


(1080, 1920)


True

In [None]:
# Run the model frame-by-frame over a video and write the colour overlays to a new .mp4.
video = 'cam60_trifocal.mp4'
OUT_WIDTH, OUT_HEIGHT = 640, 384  # network input size; also the output video frame size
DEFAULT_FPS = 24                  # used only if the container does not report its FPS

in_video = cv2.VideoCapture(video)
if not in_video.isOpened():
    raise IOError(f'Could not open video file: {video}')

# Keep the source frame rate so the output plays back at the original speed
# (CAP_PROP_FPS is 0.0 when unknown, hence the fallback).
fps = in_video.get(cv2.CAP_PROP_FPS) or DEFAULT_FPS
fourcc = cv2.VideoWriter_fourcc(*'mp4v')
seg_video = cv2.VideoWriter(video.replace('.mp4', '_out.mp4'),
                            fourcc, fps, (OUT_WIDTH, OUT_HEIGHT))

try:
    # Inference only: skip building an autograd graph for every frame.
    with torch.no_grad():
        while True:
            # 1. Read a single frame (OpenCV decodes to BGR); stop at end of video.
            result, frame = in_video.read()
            if not result:
                break

            # 2. BGR -> RGB and wrap as a PIL image, matching the single-image path.
            frame = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))

            # 3. Segment with the same preprocess/segmentation helpers used for the test
            #    image, so both paths feed the model identically; request the map at
            #    the output resolution.
            preprocess_frame = preprocess(frame, height=OUT_HEIGHT, width=OUT_WIDTH)
            seg_map = segmentation(preprocess_frame, OUT_HEIGHT, OUT_WIDTH, net, device)

            # 4. Overlay the map on the frame resized to the same resolution
            #    (PIL's resize takes (width, height)).
            frame_small = np.asarray(frame.resize((OUT_WIDTH, OUT_HEIGHT), Image.NEAREST))
            overlaid_img = visualize(seg_map, frame_small)

            # VideoWriter silently drops frames whose size differs from the one it
            # was opened with, so fail loudly instead of producing an empty video.
            assert overlaid_img.shape[:2] == (OUT_HEIGHT, OUT_WIDTH), overlaid_img.shape

            # 5. Back to BGR for OpenCV and append to the output video.
            seg_video.write(cv2.cvtColor(overlaid_img, cv2.COLOR_RGB2BGR))
finally:
    # Always release both handles so the .mp4 is finalised even if the loop raises
    # or the kernel is interrupted. (No HighGUI window is opened here, so there is
    # no keyboard polling or window cleanup to do.)
    seg_video.release()
    in_video.release()