In [None]:
from ultralytics import YOLO

In [None]:
# Load the checkpoint at ../models/yolov8n-seg.pt with 'segment' as the second
# argument; this is the model object the next cell trains.
yolo = YOLO('../models/yolov8n-seg.pt', 'segment')

In [None]:
# Train with the dataset configuration in ./yolo-config.yaml for 5 epochs at
# batch size 4, then run validation and keep what it returns.
# NOTE(review): no random seed is set in this cell — confirm whether runs need
# to be reproducible.
yolo.train(data='./yolo-config.yaml', epochs=5, batch=4)
valid_results = yolo.val()

# Inference

In [None]:
import cv2, PIL
from PIL import Image
from torchvision.transforms import functional as F
import matplotlib.pyplot as plt
import torch
import numpy as np

In [None]:
image_path = '../data/images/clodding_train_005.jpg'
image = cv2.imread(image_path)
# cv2.imread reports a missing or unreadable file by returning None instead of
# raising, so fail here with a clear message rather than inside the model call.
if image is None:
    raise FileNotFoundError(f'Could not read image: {image_path}')

In [None]:
# Rebind `yolo` to the weights stored under the `train6` run directory.
# NOTE(review): the run directory name is hard-coded; presumably it has to match
# the run produced by the training cell — confirm before re-running the notebook.
yolo = YOLO('../runs/segment/train6/weights/best.pt', 'segment')

In [None]:
# Run the model on the loaded image. The return value is indexable (another
# cell in this notebook takes element [0] of the same kind of call).
res = yolo(image)


0: 544x640 13 ROCKs, 125.3ms
Speed: 5.5ms preprocess, 125.3ms inference, 26.1ms postprocess per image at shape (1, 3, 544, 640)


In [None]:
def get_mask(result):
    '''Rasterise the segmentation polygons of one YOLO result into a binary mask.

    Args:
        result: a single prediction result exposing ``orig_shape`` and ``masks``.
            Each item yielded by ``masks`` must provide ``xy``, whose first
            element is an array of polygon vertices (presumably (x, y) pixel
            coordinates in the original image — TODO confirm).

    Returns:
        ``np.uint8`` array of shape ``result.orig_shape`` holding 255 inside
        every polygon and 0 elsewhere. The array is all zeros when the result
        has no masks.
    '''
    canvas = np.zeros(result.orig_shape, dtype=np.uint8)

    # `masks` can be None (e.g. a frame with no detections); iterating None
    # would raise TypeError, so return the empty mask instead.
    if result.masks is None:
        return canvas

    for mask in result.masks:
        polygon = mask.xy[0].astype(np.int32)
        # Skip degenerate entries with no vertices rather than handing an
        # empty polygon to OpenCV.
        if len(polygon) == 0:
            continue
        # fillPoly takes a sequence of polygons, hence the extra leading axis.
        canvas = cv2.fillPoly(canvas, polygon[None], 255)
    return canvas

def merge_with_mask(image, mask, p=0.2, gamma=0):
    '''Blend a single-channel mask over `image` as a one-channel colour overlay.

    The mask is expanded to three channels with COLOR_GRAY2BGR and multiplied
    by [0, 1, 0], so only the middle channel keeps the mask values. The result
    is ``cv2.addWeighted(image, 1 - p, overlay, p, gamma)``.

    Args:
        image: image to draw on, passed unchanged to ``cv2.addWeighted``.
        mask: single-channel mask accepted by ``cv2.cvtColor(..., COLOR_GRAY2BGR)``.
        p: weight of the overlay; the image receives ``1 - p``.
        gamma: scalar forwarded as the last argument of ``cv2.addWeighted``.
    '''
    channel_selector = np.array([0, 1, 0], np.uint8)
    overlay = cv2.cvtColor(mask, cv2.COLOR_GRAY2BGR) * channel_selector
    image_weight = 1 - p
    return cv2.addWeighted(image, image_weight, overlay, p, gamma)

In [None]:
import cv2
import threading
from IPython.display import display, Image
from queue import Queue

def display_frames(queue, width=500, height=400):
    '''Show frames taken from `queue` in a single, continuously updated output.

    Runs until a ``None`` item is received, so it is suitable as the target of
    a worker thread fed by a producer.

    Args:
        queue: queue-like object with a blocking ``get()``. Items must expose
            ``tobytes()`` returning encoded image data (the producer in this
            notebook enqueues the buffers returned by ``cv2.imencode``);
            ``None`` is the stop sentinel.
        width: display width passed to ``Image``.
        height: display height passed to ``Image``.
    '''
    # `Image` is the IPython.display class imported in this cell, not PIL.Image.
    # The handle returned with display_id=True lets each frame replace the
    # previous one instead of appending a new output.
    display_handle = display(None, display_id=True)
    while True:
        frame = queue.get()
        # Explicit identity check: comparing an array to None with == would
        # be evaluated element-wise.
        if frame is None:
            break
        display_handle.update(Image(data=frame.tobytes(), width=width, height=height))

def process_and_display(video, func=None):
    '''Read a video, transform every frame and stream the result to the output.

    Frames are read and JPEG-encoded on the calling thread and handed through
    a queue to a `display_frames` worker thread. The capture is released and
    the worker is stopped and joined even if processing is interrupted.

    Args:
        video: source handed to ``cv2.VideoCapture`` (e.g. a file path).
        func: optional callable applied to each frame before encoding; frames
            are shown unchanged when it is None.
    '''
    transform = func if func is not None else (lambda frame: frame)

    # Open the source supplied by the caller.
    capture = cv2.VideoCapture(video)
    queue = Queue()

    display_thread = threading.Thread(target=display_frames, args=(queue,))
    display_thread.start()

    try:
        while True:
            ok, frame = capture.read()
            # read() yields no frame once the stream ends or cannot be read.
            if not ok or frame is None:
                break
            encoded_ok, encoded = cv2.imencode('.jpeg', transform(frame))
            # Drop a frame that could not be encoded instead of queueing it.
            if not encoded_ok:
                continue
            queue.put(encoded)
    except KeyboardInterrupt:
        print('Get keyboard interrupt')
    finally:
        capture.release()
        # None is the sentinel that makes display_frames leave its loop.
        queue.put(None)
        display_thread.join()

In [None]:
# Open the clip and read its first frame into `frame`.
# NOTE(review): the success flag returned by read() is discarded and the
# capture is not released in this cell — confirm that this is intentional.
video = cv2.VideoCapture('../data/clods.mp4')
_, frame = video.read()

In [None]:
def segment_frame(frame):
    '''Run the model on `frame` and return the frame blended with its mask overlay.'''
    first_result = yolo(frame)[0]
    frame_mask = get_mask(first_result)
    return merge_with_mask(frame, frame_mask)

# Stream the clip through segment_frame; the call returns once no more frames
# can be read or the kernel is interrupted.
process_and_display('../data/clods.mp4', segment_frame)