<a href="https://colab.research.google.com/github/softmurata/colab_notebooks/blob/main/sports/soccerXAI.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Fetch the yt-dlp binary from the 2023.03.04 release and make it executable
# so the next cell can run it as ./yt-dlp.
!wget https://github.com/yt-dlp/yt-dlp/releases/download/2023.03.04/yt-dlp
!chmod +x yt-dlp

In [None]:
!./yt-dlp https://www.youtube.com/watch?v=mqdt2FEE22o -o test.mp4 -f bestvideo[ext=mp4]+bestaudio[ext=m4a] -S vcodec:h264

In [None]:
import cv2
from PIL import Image
# Grab frame 2200 of the downloaded video; it is saved as test.jpg and shown
# inline.
cap = cv2.VideoCapture("test.mp4")
try:
    cap.set(cv2.CAP_PROP_POS_FRAMES, 2200)
    ret, frame = cap.read()
finally:
    # Release the capture handle even if seeking or reading raises.
    cap.release()

# read() reports failure through `ret` (e.g. missing file or frame index past
# the end); stop here with a clear message instead of failing inside imwrite.
if not ret:
    raise RuntimeError("Could not read frame 2200 from test.mp4")

cv2.imwrite("test.jpg", frame)
# OpenCV frames are BGR; convert to RGB before handing the array to PIL.
display(Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)))

Grounding DINO installation

In [None]:
# Clone GroundingDINO under /content and install it in editable mode from the
# checkout directory.
%cd /content
!git clone https://github.com/IDEA-Research/GroundingDINO.git
%cd /content/GroundingDINO
!pip install -q -e .
# NOTE(review): roboflow is installed here, but no cell visible in this notebook imports it.
!pip install -q roboflow

Import libraries

In [1]:
import argparse
from functools import partial
import cv2
import requests

from io import BytesIO
from PIL import Image
import numpy as np
from pathlib import Path
import random


import warnings
warnings.filterwarnings("ignore")


import torch
from torchvision.ops import box_convert

from groundingdino.models import build_model
from groundingdino.util.slconfig import SLConfig
from groundingdino.util.utils import clean_state_dict
from groundingdino.util.inference import annotate, load_image, predict
import groundingdino.datasets.transforms as T

from huggingface_hub import hf_hub_download

In [2]:
# utils function
def load_model_hf(repo_id, filename, ckpt_config_filename, device='cpu'):
    """Build a GroundingDINO model from files hosted on the Hugging Face Hub.

    Args:
        repo_id: Hub repository that holds the config and the checkpoint.
        filename: Name of the checkpoint file in that repository.
        ckpt_config_filename: Name of the model config file in that repository.
        device: Device the returned model is placed on (default ``'cpu'``).

    Returns:
        The model, in eval mode, with the checkpoint weights loaded and moved
        to ``device``.
    """
    cache_config_file = hf_hub_download(repo_id=repo_id, filename=ckpt_config_filename)

    args = SLConfig.fromfile(cache_config_file)
    model = build_model(args)
    args.device = device

    cache_file = hf_hub_download(repo_id=repo_id, filename=filename)
    # Deserialize on the CPU regardless of where the checkpoint was saved; the
    # model is moved to the requested device once the weights are in place.
    checkpoint = torch.load(cache_file, map_location='cpu')
    # strict=False: key mismatches are reported in `log` instead of raising.
    log = model.load_state_dict(clean_state_dict(checkpoint['model']), strict=False)
    print("Model loaded from {} \n => {}".format(cache_file, log))
    model.eval()
    # Honour the `device` argument: without this the model always stays where
    # build_model created it, whatever the caller asked for.
    return model.to(device)

In [3]:
# load detection model
# Hub repository, checkpoint name and config name handed to load_model_hf.
# No device is passed, so the helper's default ('cpu') applies.
ckpt_repo_id = "ShilongLiu/GroundingDINO"
# NOTE(review): "filenmae" is a typo for "filename"; left unchanged because it is a notebook-level variable.
ckpt_filenmae = "groundingdino_swint_ogc.pth"
ckpt_config_filename = "GroundingDINO_SwinT_OGC.cfg.py"
dino_model = load_model_hf(ckpt_repo_id, ckpt_filenmae, ckpt_config_filename)

final text_encoder_type: bert-base-uncased


Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.bias', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.bias', 'cls.predictions.transform.dense.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


Model loaded from /root/.cache/huggingface/hub/models--ShilongLiu--GroundingDINO/snapshots/a94c9b567a2a374598f05c584e96798a170c56fb/groundingdino_swint_ogc.pth 
 => _IncompatibleKeys(missing_keys=[], unexpected_keys=['label_enc.weight'])


Inference

In [None]:
import os
import supervision as sv

# Image to run detection on.
# NOTE(review): presumably the test.jpg written by the frame-extraction cell, which
# requires that cell to have run with /content as the working directory — confirm.
local_image_path = "/content/test.jpg"

# Text query and the two thresholds forwarded to predict().
TEXT_PROMPT = "player"
BOX_TRESHOLD = 0.2
TEXT_TRESHOLD = 0.2

# load_image returns a pair: `image_source` is later indexed as an H x W x C
# array for cropping, `image` is what gets fed to the model.
image_source, image = load_image(local_image_path)

boxes, logits, phrases = predict(
    model=dino_model, 
    image=image, 
    caption=TEXT_PROMPT, 
    box_threshold=BOX_TRESHOLD, 
    text_threshold=TEXT_TRESHOLD
)

# Draw the detections on the source image for visual inspection.
annotated_frame = annotate(image_source=image_source, boxes=boxes, logits=logits, phrases=phrases)

%matplotlib inline  
sv.plot_image(annotated_frame, (16, 16))

In [5]:
import os
import cv2
imgh, imgw = image_source.shape[:2]

player_dir = "/content/playerimages"
os.makedirs(player_dir, exist_ok=True)

# Each row of `boxes` is unpacked as (x_center, y_center, width, height) and
# scaled by the image width/height, i.e. treated as fractions of the image.
for idx, bbox in enumerate(boxes.numpy()):
  xc, yc, w, h = bbox
  # Clamp the corners to the image. A box that overshoots the border would
  # otherwise give a negative start index, which numpy reads as "count from the
  # end" and which yields an empty or wrong crop.
  xmin = max(int((xc - w * 0.5) * imgw), 0)
  ymin = max(int((yc - h * 0.5) * imgh), 0)
  xmax = min(int((xc + w * 0.5) * imgw), imgw)
  ymax = min(int((yc + h * 0.5) * imgh), imgh)
  if xmax <= xmin or ymax <= ymin:
    # Degenerate box: there is nothing to save, and cv2 cannot convert or
    # write an empty array.
    continue
  crop_img = image_source[ymin:ymax, xmin:xmax, :]
  # File names keep the detection index (zero-padded to 3 digits) so a crop can
  # be traced back to its box even when an earlier box was skipped.
  cv2.imwrite(player_dir + "/{}.jpg".format(str(idx).zfill(3)), cv2.cvtColor(crop_img, cv2.COLOR_RGB2BGR))


Player classification

Install transformers

In [None]:
# Install transformers straight from the GitHub repository (unpinned).
# NOTE(review): presumably needed because the SAM classes imported later were not yet
# in a PyPI release — confirm, and pin a released version if possible.
!pip install -q git+https://github.com/huggingface/transformers.git

Utility functions

In [6]:
import matplotlib.pyplot as plt
import numpy as np

def show_mask(mask, ax, random_color=False):
    """Draw ``mask`` on ``ax`` as a translucent RGBA overlay.

    The overlay uses a fixed blue unless ``random_color`` is true, in which
    case a random RGB triple is drawn. The alpha channel is 0.6 either way.
    """
    alpha = np.array([0.6])
    rgb = np.random.random(3) if random_color else np.array([30/255, 144/255, 255/255])
    color = np.concatenate([rgb, alpha], axis=0)
    height, width = mask.shape[-2:]
    # Broadcasting (H, W, 1) against (1, 1, 4) gives an (H, W, 4) image that is
    # the RGBA colour where the mask is set and all zeros elsewhere.
    overlay = mask.reshape(height, width, 1) * color.reshape(1, 1, -1)
    ax.imshow(overlay)


def show_box(box, ax):
    """Outline ``box`` on ``ax`` with a green, unfilled rectangle.

    ``box`` is indexed as (x0, y0, x1, y1); the rectangle's width and height
    are the differences between the opposite corners.
    """
    left, top, right, bottom = box[0], box[1], box[2], box[3]
    outline = plt.Rectangle(
        (left, top),
        right - left,
        bottom - top,
        edgecolor='green',
        facecolor=(0,0,0,0),
        lw=2,
    )
    ax.add_patch(outline)

def show_boxes_on_image(raw_image, boxes):
    """Show ``raw_image`` in a new 10x10 figure with every box in ``boxes`` outlined."""
    plt.figure(figsize=(10,10))
    plt.imshow(raw_image)
    # Each box is drawn on whichever axes is current at that moment.
    for current_box in boxes:
        target_ax = plt.gca()
        show_box(current_box, target_ax)
    plt.axis('on')
    plt.show()

def show_points_on_image(raw_image, input_points, input_labels=None):
    """Show ``raw_image`` in a new 10x10 figure with the prompt points marked.

    ``input_points`` is converted to an array; when ``input_labels`` is omitted
    every point gets label 1.
    """
    plt.figure(figsize=(10,10))
    plt.imshow(raw_image)
    point_array = np.array(input_points)
    # Default labels: an array of ones with one entry per point.
    labels = np.ones_like(point_array[:, 0]) if input_labels is None else np.array(input_labels)
    show_points(point_array, labels, plt.gca())
    plt.axis('on')
    plt.show()

def show_points_and_boxes_on_image(raw_image, boxes, input_points, input_labels=None):
    """Show ``raw_image`` in a new 10x10 figure with prompt points and boxes drawn.

    Args:
        raw_image: Image handed to ``plt.imshow``.
        boxes: Iterable of boxes, each passed to ``show_box``.
        input_points: Point coordinates; converted to an array.
        input_labels: Optional per-point labels. When omitted, every point
            gets label 1.
    """
    plt.figure(figsize=(10,10))
    plt.imshow(raw_image)
    input_points = np.array(input_points)
    if input_labels is None:
        labels = np.ones_like(input_points[:, 0])
    else:
        labels = np.array(input_labels)
    # Points are drawn first, then one rectangle per box, all on the current axes.
    show_points(input_points, labels, plt.gca())
    for box in boxes:
        show_box(box, plt.gca())
    plt.axis('on')
    plt.show()


def show_points(coords, labels, ax, marker_size=375):
    """Scatter prompt points on ``ax`` as stars: label 1 in green, label 0 in red.

    ``coords`` is indexed with a boolean mask built from ``labels`` and then by
    column, so both are expected to be arrays.
    """
    for label_value, colour in ((1, 'green'), (0, 'red')):
        selected = coords[labels==label_value]
        ax.scatter(
            selected[:, 0],
            selected[:, 1],
            color=colour,
            marker='*',
            s=marker_size,
            edgecolor='white',
            linewidth=1.25,
        )



def show_masks_on_image(raw_image, masks, scores=None):
    """Plot predicted masks over ``raw_image``.

    With ``scores`` the masks are shown side by side, one subplot per mask,
    each titled with its score. Without ``scores`` every mask is overlaid on a
    single axes in a random colour, which suits a plain list of masks.

    Args:
        raw_image: Image drawn under the masks; converted with ``np.array``.
        masks: Tensor/array of masks (a 4-D input is squeezed), or, when
            ``scores`` is omitted, any iterable of per-mask arrays.
        scores: Optional scores, one per mask; flattened before use.
    """
    if scores is None:
        _, ax = plt.subplots(figsize=(15, 15))
        ax.imshow(np.array(raw_image))
        for mask in masks:
            if hasattr(mask, "cpu"):
                mask = mask.cpu().detach()
            show_mask(mask, ax, random_color=True)
        ax.axis("off")
        plt.show()
        return

    if len(masks.shape) == 4:
        masks = masks.squeeze()
    if len(masks.shape) == 2:
        # A single prediction: squeeze() removed the mask axis as well, so put
        # it back to keep "one item per mask" when iterating.
        masks = masks[None]
    # Flatten so that any number of leading singleton axes (including the
    # single-score case, where squeeze() would give a 0-d value) is handled.
    scores = scores.reshape(-1)

    nb_predictions = scores.shape[-1]
    # squeeze=False keeps `axes` two-dimensional even for one subplot; by
    # default a lone Axes object is returned, which cannot be indexed.
    _, axes = plt.subplots(1, nb_predictions, figsize=(15, 15), squeeze=False)
    axes_row = axes[0]

    for i, (mask, score) in enumerate(zip(masks, scores)):
        if hasattr(mask, "cpu"):
            mask = mask.cpu().detach()
        axes_row[i].imshow(np.array(raw_image))
        show_mask(mask, axes_row[i])
        axes_row[i].title.set_text(f"Mask {i+1}, Score: {score.item():.3f}")
        axes_row[i].axis("off")
    plt.show()

Segment Anything

In [7]:
import torch
from PIL import Image
import requests
from transformers import SamModel, SamProcessor

Load model

In [8]:
# Use the GPU when torch reports one, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
# Load the "facebook/sam-vit-huge" model onto that device, plus the matching
# processor that prepares inputs and post-processes outputs.
model = SamModel.from_pretrained("facebook/sam-vit-huge").to(device)
processor = SamProcessor.from_pretrained("facebook/sam-vit-huge")

In [None]:
# input player image
# NOTE(review): the crop index is hard-coded; 011.jpg only exists when the detection
# step saved a crop for detection index 11 — adjust for other frames.
img_path = "/content/playerimages/011.jpg"
img = Image.open(img_path).convert("RGB")
display(img)

In [48]:
# Prompt SAM with a single point at the centre of the player crop.
w, h = img.size
input_points = [[[int(w * 0.5), int(h * 0.5)]]]  # one (x, y) prompt point: the centre of the crop

inputs = processor(img, input_points=input_points, return_tensors="pt").to(device)
# Inference only, so gradient tracking is switched off.
with torch.no_grad():
  outputs = model(**inputs)

# NOTE(review): presumably rescales the predicted masks back to the original image
# size using the two size tensors — confirm against the processor documentation.
masks = processor.image_processor.post_process_masks(
    outputs.pred_masks.cpu(), inputs["original_sizes"].cpu(), inputs["reshaped_input_sizes"].cpu()
)
# Scores reported by the model for the candidate masks (its `iou_scores` output).
scores = outputs.iou_scores

In [None]:
# Plot the candidate masks of the first (and only) input image with their scores.
show_masks_on_image(img, masks[0], scores)

In [50]:
# Zero the score of candidate masks that cover more than 80% or less than 5%
# of the crop, so they cannot win the argmax in the next cell.
MAX_MASK_FRACTION = 0.8
MIN_MASK_FRACTION = 0.05

# .copy(): on a CPU runtime .numpy() shares memory with `scores`, so zeroing
# entries would otherwise silently modify the model output as well.
scores_list = scores.cpu().numpy()[0][0].copy()
# Convert once instead of on every loop iteration.
candidate_masks = masks[0].cpu().numpy().squeeze(0)

for idx in range(len(scores_list)):
  # Deliberately not named `sum`/`prod`: rebinding `sum` at notebook level
  # would hide the builtin for every later cell.
  mask_area = np.sum(candidate_masks[idx])
  total_pixels = np.prod(candidate_masks[idx].shape)
  if mask_area > total_pixels * MAX_MASK_FRACTION or mask_area < total_pixels * MIN_MASK_FRACTION:
    scores_list[idx] = 0

print(scores_list)

[0.93344164 0.94545054 0.        ]


In [None]:
import cv2
from PIL import Image

# Pick the highest-scoring candidate that survived the size filter.
midx = np.argmax(scores_list)
pred_mask = masks[0].cpu().numpy().squeeze(0)[midx]
# Single-channel 0/255 mask, kept in memory for the masking step below.
mask = pred_mask.astype(np.uint8) * 255
# Written for inspection only. JPEG is lossy, so a mask read back from this
# file is no longer exactly 0/255 and would corrupt colours near the edges.
cv2.imwrite("mask.jpg", mask)
rgb = cv2.cvtColor(cv2.imread(img_path), cv2.COLOR_BGR2RGB)
# Keep the pixels where the mask is non-zero; everything else becomes black.
rgb_and = cv2.bitwise_and(rgb, rgb, mask=mask)
rh, rw = rgb_and.shape[:2]
# center crop: rows 10%-50% and columns 30%-70% of the masked image
center_rgb_and = rgb_and[int(0.1 * rh):int(0.5 * rh), int(0.3 * rw):int(0.7 * rw), :]
Image.fromarray(center_rgb_and)

In [52]:
# Binarise the cropped region once per candidate colour using HSV ranges
# (lower bound, upper bound) given as [H, S, V].
hsv = cv2.cvtColor(center_rgb_and, cv2.COLOR_RGB2HSV)
# NOTE(review): OpenCV's 8-bit hue presumably spans 0-179 and red wraps around 0, so
# reds with a hue below 10 would not match this range — confirm the thresholds.
red = cv2.inRange(hsv, np.array([145, 70, 0]), np.array([180, 255, 255]))
yellow = cv2.inRange(hsv, np.array([10, 80, 0]), np.array([50, 255, 255]))
blue = cv2.inRange(hsv, np.array([108, 121, 0]), np.array([120, 255, 255]))

In [None]:
bin_imgs = {'red': red, 'yellow': yellow, 'blue': blue}

# Visualise the binarised results, one row per candidate colour.
fig, axes_list = plt.subplots(len(bin_imgs), 1, figsize=(10, 18))

# The colour whose binary image has the largest pixel sum wins; the label
# stays empty when no colour matches any pixel.
player_color = ""
max_sum = 0
for ax, (label, bin_img) in zip(axes_list.ravel(), bin_imgs.items()):
    ax.axis('off')
    ax.set_title(label)
    ax.imshow(bin_img, cmap=plt.cm.gray)
    # Not named `sum`: rebinding that name at notebook level would hide the
    # builtin for every later cell.
    pixel_sum = np.sum(bin_img)
    if pixel_sum > max_sum:
      player_color = label
      max_sum = pixel_sum

print(player_color)

plt.show()

Tips

Load model

In [None]:
# player classification
from transformers import pipeline

# Use the first GPU when one is available and the CPU (-1) otherwise, so this
# cell also runs on a CPU-only runtime, like the SAM loading cell earlier on.
pipeline_device = 0 if torch.cuda.is_available() else -1
generator = pipeline("mask-generation", model="facebook/sam-vit-huge", device=pipeline_device)

In [None]:
from PIL import Image
# Player crop used for the automatic mask-generation demo (detection index 0).
img_path = "/content/playerimages/000.jpg"
img = Image.open(img_path).convert("RGB")

Inference

In [None]:
# Run the mask-generation pipeline on the crop.
# NOTE(review): points_per_batch is presumably the number of prompt points processed
# per forward pass — confirm against the pipeline documentation.
outputs = generator(img, points_per_batch=64)

In [None]:
# display image
# The generated masks are read from the "masks" entry of the pipeline output.
masks = outputs["masks"]
# NOTE(review): no scores are passed here, so this call relies on
# show_masks_on_image accepting just the image and the masks.
show_masks_on_image(img, masks)

In [None]:
# Print the second generated mask cast to uint8; requires at least two masks.
print(masks[1].astype(np.uint8))

In [None]:
import cv2
import numpy as np
# Add all generated masks into one image-sized array, so each pixel holds the
# sum of the mask values that cover it.
cv2_img = cv2.imread(img_path)
h, w = cv2_img.shape[:2]
# Float accumulator with the same height and width as the crop.
base_img = np.zeros((h, w))
for mask in masks:
  # Each mask is reshaped to its own last two dimensions before being added.
  # NOTE(review): assumes every mask has the same height/width as the crop.
  h, w = mask.shape[-2:]
  mask_image = mask.reshape(h, w).astype(np.uint8)
  base_img += mask_image
print(base_img)

動画切り出し

In [None]:
# Cut out just 10 seconds starting at the 2-minute mark; the video and audio
# streams are copied without re-encoding (-c:v copy -c:a copy).
!ffmpeg -ss 00:02:00 -i test.mp4 -ss 0 -t 10 -c:v copy -c:a copy -async 1 -strict -2 output.mp4