In [1]:
# Importing necessary libraries
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import models

import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
from torchvision.transforms import transforms
from PIL import Image
import os
from tqdm import tqdm

# torch.backends.cuda.matmul.allow_tf32 = False
# Debug setting: CUDA_LAUNCH_BLOCKING=1 makes CUDA kernel launches synchronous so an
# error is reported at the call that caused it (slower; disable for real runs).
# It only takes effect if set before CUDA is initialised - no GPU work has run yet here.
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
# NOTE(review): TORCH_USE_CUDA_DSA appears to be a PyTorch build-time option for
# device-side assertions; setting it at runtime likely has no effect - confirm.
os.environ['TORCH_USE_CUDA_DSA'] = '1'

import albumentations as A
from albumentations.pytorch import ToTensorV2
import cv2
from ultralytics import YOLO
import torch.nn.functional as F

In [2]:
class CFG:
    """Run configuration: input video, model checkpoints, device and preprocessing."""

    ######### important: edit these per run #########
    video_name = 'GX050090_resize.mp4'
    binary_model_addr = '../model/model56.pth'  # weights loaded into the UNet
    video_root = 'D:/data/cow_teat_segmentation/2022_02_08_Aim2-selected'
    yolo_model_path = '../model/yolov8m.pt'
    ######### important #########

    # Fall back to CPU so the notebook still runs on machines without CUDA.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # Evaluated when the class body runs, so video_root must exist at that point.
    video_list = os.listdir(video_root)
    # Part of each file name before the first '_' (e.g. 'GX050090').
    video_name_list = [name.split('_')[0] for name in video_list]

    video_path = os.path.join(video_root, video_name)
    # Segmentation preprocessing: resize to 224x224, A.Normalize() defaults, to tensor.
    transform = A.Compose([
        A.Resize(224, 224),
        A.Normalize(),
        ToTensorV2(),
    ])


cfg = CFG()

In [3]:
import sys
sys.path.append('../utils/')
from generateMaskArea import generate_mask_area
from bboxAnalyse import get_score, record, get_GT_label,convert_scores,cal_correct_count,show_class_item

from getFrameNumber import get_total_frame_number
from model.unet import UNet

In [4]:
import cv2
import os
import albumentations as A
from albumentations.pytorch import ToTensorV2
import re
# Preprocessing for the segmentation model (same pipeline as CFG.transform):
# resize to 224x224, normalise with A.Normalize() defaults, convert via ToTensorV2.
Atransform = A.Compose([A.Resize(224, 224), A.Normalize(), ToTensorV2()])

In [6]:
def yolov8_detection(model, image):
    """Run a YOLOv8 detector on one image and unpack its first result.

    The image is passed to the model unchanged (no BGR->RGB conversion); this
    presumably relies on Ultralytics treating NumPy input as OpenCV/BGR order.

    Args:
        model: Callable detector (e.g. ``ultralytics.YOLO``). Its return value
            is indexed with ``[0]``, and ``[0].boxes`` must expose ``xyxy`` and
            ``cls`` tensors.
        image: Image handed straight to ``model``.

    Returns:
        tuple: ``(bbox, class_id, image, results)``:
            * ``bbox``: list of ``[x1, y1, x2, y2]`` corners, each truncated
              to ``int``.
            * ``class_id``: list of integer class ids, in the same order.
            * ``image``: the unchanged input.
            * ``results``: the raw model output.
    """
    results = model(image)
    boxes = results[0].boxes  # Boxes object for the first (only) image
    class_id = boxes.cls.long().tolist()
    bbox = [[int(coord) for coord in box] for box in boxes.xyxy.tolist()]
    return bbox, class_id, image, results

def label_mask(SAM_mask, bboxes, classes):
    """Write each box's class id onto the foreground pixels inside that box.

    Args:
        SAM_mask: 2-D array whose truthy entries mark foreground pixels
            (in this notebook, a boolean "not blacked out" mask).
        bboxes: Iterable of ``[x1, y1, x2, y2]`` pixel corners. ``x`` indexes
            columns and ``y`` indexes rows; the end coordinates are exclusive.
        classes: Class ids parallel to ``bboxes``, stored as ``uint8``.

    Returns:
        np.ndarray: ``uint8`` array with the same shape as ``SAM_mask``.
        A pixel that is foreground and inside a box holds that box's class id;
        when boxes overlap, later boxes win. Every other pixel is 0, including
        when ``bboxes`` is empty.
    """
    foreground = np.asarray(SAM_mask).astype(bool)
    # Start from an all-background canvas instead of the 0/1 mask. This way no
    # "clear the 1s" pass is needed, and class id 1 is never erased.
    labeled_mask = np.zeros(foreground.shape, dtype=np.uint8)
    for bbox, cls in zip(bboxes, classes):
        # Clamp negative corners: a negative slice bound would wrap to the far edge.
        x1, y1, x2, y2 = (max(0, int(v)) for v in bbox)
        region = labeled_mask[y1:y2, x1:x2]  # basic slice -> view into labeled_mask
        region[foreground[y1:y2, x1:x2]] = cls
    return labeled_mask


# Box colours in OpenCV BGR order, keyed by integer class id.
# Ids 1-60 share one colour. They are listed one by one because a range object
# used as a dict key only matches the range itself, never an int like 5.
class_colors = {
    **{class_id: (255, 0, 255) for class_id in range(1, 61)},  # Purple
    61: (0, 0, 255),  # Red
    62: (0, 255, 0),  # Green
    63: (255, 0, 0),  # Blue
    64: (0, 255, 255),  # Yellow
}

def plot_bbox(image, yolov8_boxex, yolov8_class_id, if_teat=False):
    """Draw detection boxes and their class ids onto ``image`` (in place).

    Args:
        image: BGR image drawn on with ``cv2.rectangle`` / ``cv2.putText``.
        yolov8_boxex: Iterable of ``[x1, y1, x2, y2]`` integer corners.
        yolov8_class_id: Class ids parallel to ``yolov8_boxex``.
        if_teat: If False, only ids below 61 are drawn (in purple). If True,
            ids of 61 and above are drawn too, coloured from ``class_colors``;
            ids with no entry there are drawn in white.

    Returns:
        The annotated ``image``.
    """
    default_color = (255, 255, 255)  # white, for ids missing from class_colors
    for bbox, class_id in zip(yolov8_boxex, yolov8_class_id):
        if class_id < 61:
            color = (255, 0, 255)
        elif if_teat:
            # .get avoids a KeyError for ids the colour table does not cover.
            color = class_colors.get(class_id, default_color)
        else:
            continue
        x1, y1, x2, y2 = bbox[:4]
        cv2.rectangle(image, (x1, y1), (x2, y2), color, 1)
        # Class id label just above the box's top-left corner.
        cv2.putText(image, str(class_id), (x1, y1 - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 1)
    return image

In [7]:
# Overlay helpers: tint the mask pixels of each target class with a distinct colour.

# Amount (0-255 intensity units) that the draw_* helpers below add to (or, for blue
# in draw_4_yellow, subtract from) a channel; results are clipped to [0, 255].
enhancement = 50
def draw_1_red(image, mask, target):
    """Brighten channel 2 (red in BGR) by ``enhancement`` where ``mask == target[0]``.

    Modifies ``image`` in place and returns it.
    """
    selected = mask == target[0]
    boosted = image[selected, 2].astype(int) + enhancement
    image[selected, 2] = np.clip(boosted, 0, 255)
    return image

def draw_2_green(image, mask, target):
    """Brighten channel 1 (green in BGR) by ``enhancement`` where ``mask == target[1]``.

    Modifies ``image`` in place and returns it.
    """
    selected = mask == target[1]
    boosted = image[selected, 1].astype(int) + enhancement
    image[selected, 1] = np.clip(boosted, 0, 255)
    return image

def draw_3_blue(image, mask, target):
    """Brighten channel 0 (blue in BGR) by ``enhancement`` where ``mask == target[2]``.

    Modifies ``image`` in place and returns it.
    """
    selected = mask == target[2]
    boosted = image[selected, 0].astype(int) + enhancement
    image[selected, 0] = np.clip(boosted, 0, 255)
    return image

def draw_4_yellow(image, mask, target):
    """Tint pixels where ``mask == target[3]`` yellow.

    Red and green (BGR channels 2 and 1) are raised by ``enhancement``, and
    blue (channel 0) is lowered by it; each result is clipped to [0, 255].
    Modifies ``image`` in place and returns it.
    """
    selected = mask == target[3]
    # Channels processed in the order red, green, blue.
    for channel, delta in ((2, enhancement), (1, enhancement), (0, -enhancement)):
        shifted = image[selected, channel].astype(int) + delta
        image[selected, channel] = np.clip(shifted, 0, 255)
    return image



def draw_image_color(image, mask, only_teat=False):
    """Colour the four target classes of ``mask`` red, green, blue and yellow.

    Args:
        image: BGR image, modified in place by the draw_* helpers.
        mask: Label array aligned with ``image``'s first two axes.
        only_teat: If True the target labels are 1-4, otherwise 61-64.

    Returns:
        The tinted image.
    """
    target = [1, 2, 3, 4] if only_teat else [61, 62, 63, 64]
    for painter in (draw_1_red, draw_2_green, draw_3_blue, draw_4_yellow):
        image = painter(image, mask, target)
    return image

In [8]:
# Dataset root. Aliases cfg.video_root so the path is configured in one place.
root_dir = cfg.video_root

In [9]:
# Segmentation network used by get_extracted_image (via play_video).
# UNet(2): presumably 2 output classes (background / foreground) - confirm.
model = UNet(2)
# map_location lets a checkpoint saved on another device load onto cfg.device.
model.load_state_dict(torch.load(cfg.binary_model_addr, map_location=cfg.device))
model.to(cfg.device)
# The UNet contains BatchNorm2d layers, and frames are fed one at a time.
# eval() makes BatchNorm use its running statistics instead of per-batch ones.
model.eval()

UNet(
  (inc): DoubleConv(
    (conv): Sequential(
      (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
      (3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (5): ReLU(inplace=True)
      (6): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (7): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (8): ReLU(inplace=True)
    )
  )
  (down1): Down(
    (mpconv): Sequential(
      (0): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
      (1): DoubleConv(
        (conv): Sequential(
          (0): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_

In [10]:

# Detector used by play_video(); the weights path is set in CFG.yolo_model_path.
# Previously used weights: 'D:/data/cow_teat_segmentation/model/yolo/best.pt'
yolo_model = YOLO(cfg.yolo_model_path)

In [13]:
def put_predict(yolov8_class_id, yolov8_boxex, image):
    """Overlay the stall's per-teat score text on ``image``.

    ``get_score`` (from ``bboxAnalyse``) is run on the detections. The text is
    drawn at (70, 250) only when it returns exactly one entry; otherwise
    ``image`` is returned untouched.
    """
    scores = get_score(yolov8_class_id, yolov8_boxex)
    if len(scores) != 1:
        return image
    entry = scores[0]
    text = (
        f"{entry['stall']}: LH: {entry['left-hind'][1]}"
        f" RH: {entry['right-hind'][1]}"
        f" LF: {entry['left-front'][1]}"
        f" RF: {entry['right-front'][1]}"
    )
    return cv2.putText(image, text, (70, 250), cv2.FONT_HERSHEY_PLAIN, 1, (255, 255, 0), 1)

In [14]:
def color_GT_sample_score(image):
    """Draw a colour legend (swatches labelled 1-4 under a 'score' heading).

    Each swatch is a 15x15 patch at columns 10-25 whose listed BGR channels are
    brightened by 50 (clipped to 255). Presumably it mirrors the colours that
    draw_image_color uses - confirm. Returns the annotated image.
    """
    boost = 50
    cols = slice(10, 25)
    # (swatch rows, BGR channels to brighten, label text, label origin, label colour)
    legend = (
        (slice(200, 215), (2,), '1', (35, 212), (0, 0, 255)),
        (slice(220, 235), (1,), '2', (35, 232), (0, 255, 0)),
        (slice(240, 255), (0,), '3', (35, 252), (255, 0, 0)),
        (slice(260, 275), (2, 1), '4', (35, 272), (0, 255, 255)),
    )
    for rows, channels, label, origin, label_color in legend:
        for ch in channels:
            image[rows, cols, ch] = np.clip(image[rows, cols, ch].astype(int) + boost, 0, 255)
        image = cv2.putText(image, label, origin, cv2.FONT_HERSHEY_PLAIN, 1, label_color, 1)
    image = cv2.putText(image, 'score', (10, 195), cv2.FONT_HERSHEY_PLAIN, 1, (255, 255, 255), 1)
    return image

In [15]:
def get_extracted_image(image, Atransform, model):
    """Black out every pixel the segmentation model assigns to class 0.

    Args:
        image: BGR frame (as read by OpenCV); converted to RGB before ``Atransform``.
        Atransform: Albumentations pipeline whose output ``'image'`` is a torch
            tensor; a batch dimension is added before it goes to the model.
        model: Network whose output is arg-maxed over dim 1; presumably per-class
            logits shaped ``(1, C, h, w)`` - confirm.

    Returns:
        np.ndarray: New array shaped like ``image``, with predicted-background
        pixels set to 0 and all other pixels copied from ``image``.
    """
    height, width = image.shape[:2]
    rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)  # returns a new array; image is untouched
    batch = Atransform(image=rgb)['image'].unsqueeze(0).to(cfg.device)
    with torch.no_grad():
        outputs = model(batch)

    predicted = torch.argmax(outputs, dim=1)
    mask = predicted.squeeze(0).cpu().numpy().astype(np.float32)
    # Resize to the input frame's own size so bitwise_and sees matching shapes.
    # cv2.resize takes (width, height).
    mask = cv2.resize(mask, (width, height), interpolation=cv2.INTER_LINEAR)
    # Linear interpolation leaves fractional values along edges; any nonzero counts as foreground.
    mask[mask > 0] = 255
    binary_mask = cv2.cvtColor(mask, cv2.COLOR_GRAY2BGR).astype(np.uint8)
    return cv2.bitwise_and(image, binary_mask)

In [17]:
# Per-channel lower bound: a pixel of the extracted image counts as foreground
# only if all three channels exceed it (i.e. it was not blacked out).
threshold = [0, 0, 0]


def _annotate_frame(original_image):
    """Build the output frame for one BGR video frame (modifies it in place).

    Segmentation (get_extracted_image) and YOLO detection run on the extracted
    image. With detections, boxes, class colours, the score text and the
    'Prediction' title are drawn onto ``original_image``. The legend is always drawn.
    """
    extracted = get_extracted_image(original_image.copy(), Atransform, model)
    foreground = np.all(extracted > threshold, axis=2)
    boxes, class_ids, _, _ = yolov8_detection(yolo_model, extracted)
    if len(boxes) == 0:
        color_GT_sample_score(original_image)
        return original_image
    labeled = label_mask(foreground, boxes, class_ids)
    frame = plot_bbox(original_image, boxes, class_ids)
    frame = draw_image_color(frame, labeled)
    frame = put_predict(class_ids, boxes, frame)
    frame = cv2.putText(frame, 'Prediction', (170, 30), cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 0, 0), 2)
    return color_GT_sample_score(frame)


def play_video(output_path='filename.avi', fps=10):
    """Annotate ``cfg.video_path`` frame by frame, preview it, and record it.

    Press 'q' in the preview window to stop early.

    Args:
        output_path: Destination of the MJPG-encoded output video.
        fps: Frame rate written into the output file.
    """
    video = cv2.VideoCapture(cfg.video_path)
    if not video.isOpened():
        print("Error opening video file")
        return

    size = (int(video.get(cv2.CAP_PROP_FRAME_WIDTH)), int(video.get(cv2.CAP_PROP_FRAME_HEIGHT)))
    writer = cv2.VideoWriter(output_path, cv2.VideoWriter_fourcc(*'MJPG'), fps, size)
    try:
        while True:
            ret, original_image = video.read()
            if not ret:  # end of stream or read error
                break
            combined_image = _annotate_frame(original_image)
            writer.write(combined_image)
            cv2.imshow("stall number", combined_image)
            if cv2.waitKey(1) & 0xFF == ord("q"):
                break
    finally:
        # Release the writer so the output file is finalised, even if a frame raises.
        writer.release()
        video.release()
        cv2.destroyAllWindows()

In [18]:
# Run the full pipeline on cfg.video_path: writes filename.avi and opens a preview window ('q' quits).
play_video()


0: 384x640 1 54, 83.9ms
Speed: 2.5ms preprocess, 83.9ms inference, 1.5ms postprocess per image at shape (1, 3, 640, 640)

0: 384x640 1 54, 14.9ms
Speed: 1.0ms preprocess, 14.9ms inference, 1.0ms postprocess per image at shape (1, 3, 640, 640)

0: 384x640 1 54, 12.8ms
Speed: 1.0ms preprocess, 12.8ms inference, 2.5ms postprocess per image at shape (1, 3, 640, 640)

0: 384x640 1 51, 1 54, 12.4ms
Speed: 1.0ms preprocess, 12.4ms inference, 2.5ms postprocess per image at shape (1, 3, 640, 640)

0: 384x640 1 54, 12.3ms
Speed: 1.0ms preprocess, 12.3ms inference, 2.0ms postprocess per image at shape (1, 3, 640, 640)

0: 384x640 1 54, 1 s2, 11.4ms
Speed: 1.0ms preprocess, 11.4ms inference, 2.0ms postprocess per image at shape (1, 3, 640, 640)

0: 384x640 1 54, 1 s2, 12.9ms
Speed: 1.0ms preprocess, 12.9ms inference, 1.0ms postprocess per image at shape (1, 3, 640, 640)

0: 384x640 1 54, 1 s2, 12.0ms
Speed: 0.5ms preprocess, 12.0ms inference, 1.5ms postprocess per image at shape (1, 3, 640, 640)
