In [None]:
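# One-time setup: uncomment both lines on a fresh checkout to fetch the
# VietNet sources and move them into the current working directory.
# !git clone https://github.com/vietawake/VietNet
# !mv -v VietNet/* .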

In [2]:
import numpy as np
from PIL import Image
import torch
import torchvision.transforms as T
from torchvision import transforms
from models.vietnet import VietNet
import cv2
import torchvision
from inference import preprocess, segmentation, visualize
import os

In [3]:
# Runtime configuration: cuDNN flags, random seeds and the compute device.
SEED = 50

# Turn on cuDNN together with its benchmark mode.
torch.backends.cudnn.benchmark = True
torch.backends.cudnn.enabled = True

# Seed NumPy and PyTorch with the same value; the CUDA generator is seeded
# only when a GPU is actually present.
np.random.seed(SEED)
torch.manual_seed(SEED)

cuda_available = torch.cuda.is_available()
if cuda_available:
    torch.cuda.manual_seed(SEED)

# Use the first GPU when available, otherwise fall back to the CPU.
device = torch.device("cuda:0") if cuda_available else torch.device("cpu")

In [4]:
# Instantiate VietNet configured for NUM_CLASSES output classes.
NUM_CLASSES = 4
net = VietNet(num_classes=NUM_CLASSES)

In [6]:
# Restore the pretrained weights into `net` and switch it to evaluation mode.
CHECKPOINT_PATH = './pretrained_models/VietNet_epoch_33_acc_0.8496.pt'
DATA_PARALLEL_PREFIX = 'module.'

# NOTE(review): torch.load unpickles the file, which can execute arbitrary
# code -- only load checkpoints that come from a trusted source.
checkpoint = torch.load(CHECKPOINT_PATH, map_location=device)

# Weights saved from a torch.nn.DataParallel wrapper have every key prefixed
# with "module.". Test for that prefix on every key rather than for the
# substring "module" anywhere, so a layer that merely has "module" in its
# name does not trigger the wrap. An empty checkpoint is never treated as
# multi-GPU.
multigpus = len(checkpoint) > 0 and all(
    key.startswith(DATA_PARALLEL_PREFIX) for key in checkpoint
)

# Guard against wrapping twice when this cell is re-run in the same kernel.
if multigpus and not isinstance(net, torch.nn.DataParallel):
    net = torch.nn.DataParallel(net)

net.load_state_dict(checkpoint)
net.to(device)
net.eval()

VietNet(
  (feature_extaction): Sequential(
    (0): ConvBNReLU(
      (0): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU6(inplace=True)
    )
    (1): InvertedResidual(
      (conv): Sequential(
        (0): ConvBNReLU(
          (0): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
          (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU6(inplace=True)
        )
        (1): Conv2d(32, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (2): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (2): InvertedResidual(
      (conv): Sequential(
        (0): ConvBNReLU(
          (0): Conv2d(16, 96, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (1): BatchNorm2d(96, eps=1e-05, momentum=0.1, affi

In [7]:
# Segment every image found in img_test/ and save the overlay, under the same
# file name, in out_test/.
input_dir = 'img_test/'
output_dir = 'out_test/'

# cv2.imwrite does not create missing directories; it only returns False.
os.makedirs(output_dir, exist_ok=True)

# Keep regular files only (os.listdir also returns sub-directories such as
# .ipynb_checkpoints) and sort them for a deterministic processing order.
list_images = sorted(
    name for name in os.listdir(input_dir)
    if os.path.isfile(os.path.join(input_dir, name))
)

for path_img in list_images:
    # The network's first convolution takes 3 input channels, so grayscale,
    # palette and RGBA files are converted to RGB. The context manager closes
    # the underlying file once the pixel data has been copied.
    with Image.open(input_dir + path_img) as img_file:
        img = img_file.convert('RGB')

    # PIL reports size as (width, height).
    origin_width, origin_height = img.size
    preprocess_img = preprocess(img, height=384, width=640)

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        seg_map = segmentation(preprocess_img, origin_height, origin_width, net, device)
    overlaid_img = visualize(seg_map, np.asarray(img))

    # NOTE(review): cv2.imwrite expects BGR channel order while the array
    # handed to visualize() is RGB -- confirm which order visualize() returns.
    if not cv2.imwrite(output_dir + path_img, overlaid_img):
        print('WARNING: could not write ' + output_dir + path_img)


In [None]:
# Run the network on every frame of a video and write the overlaid frames to
# "<name>_out.mp4" next to the input.
video = 'cam60_trifocal.mp4'
FRAME_HEIGHT, FRAME_WIDTH = 384, 640  # size every frame is resized to
DEFAULT_FPS = 24                      # used when the source reports no fps

# 1. Open the input video; fail loudly instead of silently writing an empty
#    output when the file is missing or unreadable.
in_video = cv2.VideoCapture(video)
if not in_video.isOpened():
    raise IOError('Could not open input video: ' + video)

# 2. Create the output writer. The source frame rate is reused so playback
#    speed is preserved; `not fps > 0` also catches NaN.
fps = in_video.get(cv2.CAP_PROP_FPS)
if not fps > 0:
    fps = DEFAULT_FPS

# Build the output name from the stem so the input is never overwritten,
# even when its name does not contain ".mp4".
video_stem, _ = os.path.splitext(video)
fourcc    = cv2.VideoWriter_fourcc('m', 'p', '4', 'v')
seg_video = cv2.VideoWriter(video_stem + '_out.mp4',
                            fourcc, fps, (FRAME_WIDTH, FRAME_HEIGHT))

# The transforms do not depend on the frame, so build them once.
# NOTE(review): the image cell goes through preprocess()/segmentation()
# instead of these transforms -- confirm both paths feed the network
# identically prepared input.
resize    = transforms.Resize((FRAME_HEIGHT, FRAME_WIDTH),
                              interpolation=Image.NEAREST)
to_tensor = T.ToTensor()

try:
    # Inference only: skip building an autograd graph for every frame.
    with torch.no_grad():
        while True:
            # 3.1. Read a single frame from the video
            result, frame_bgr = in_video.read()

            # Quit if end of video is reached
            if not result:
                break

            # Convert to the RGB format
            frame_rgb = cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2RGB)

            # 3.2. Convert the frame to a PIL image and resize it
            frame = resize(Image.fromarray(frame_rgb))

            # 3.3. Convert the image to a PyTorch tensor with a batch
            #      dimension of 1 and move it to the compute device
            frame_tensor = to_tensor(frame).unsqueeze(dim=0).to(device)

            # 3.4. Perform a forward pass
            logits = net(frame_tensor)

            # 3.5. Produce a segmentation map from the logits: drop the
            #      batch dimension, move to a NumPy array, and take the
            #      argmax over the first remaining axis
            logits  = logits.squeeze(0).cpu().detach().numpy()
            seg_map = np.argmax(logits, axis=0)

            # 3.6. Visualize the segmentation map
            overlaid_img = visualize(seg_map, np.asarray(frame))

            # 3.7. Save the output frame
            # NOTE(review): VideoWriter.write expects BGR frames of exactly
            # (FRAME_WIDTH, FRAME_HEIGHT); the frame handed to visualize()
            # is RGB -- confirm the channel order visualize() returns.
            seg_video.write(overlaid_img)

            # 3.8. (Optional) Early break if ESC is pressed
            if (cv2.waitKey(1) & 0xff) == 27:
                break
finally:
    # 4. Close input and output video files, even if a frame failed
    seg_video.release()
    in_video.release()
    cv2.destroyAllWindows()