In [None]:
from PIL import Image
import numpy as np
import os, sys
from pathlib import Path
# Put the grandparent of the current working directory on sys.path so the
# project package (katacv) can be imported from this notebook.
root_path = Path.cwd().parents[1]
root_str = str(root_path)
if root_str not in sys.path:
  sys.path.append(root_str)
print(f"{root_path=}")

In [None]:
from katacv.yolov5.parser import get_args_and_writer
from katacv.yolov5.model import get_state
from katacv.utils.model_weights import load_weights

args = get_args_and_writer(no_writer=True, input_args="")
args.model_name = "YOLOv5_b32"
args.path_logs = root_path / "logs"
path_debug = args.path_logs / "debug"
# args.path_cp = args.path_logs / "YOLOv5-checkpoints"
args.path_cp = args.path_logs / "YOLOv5_b32-checkpoints"

def _load_checkpoint(load_id):
  """Set args.load_id, build a state via get_state(args, use_init=False) and load checkpoint weights into it."""
  args.load_id = load_id
  return load_weights(get_state(args, use_init=False), args)

# Two checkpoints of the same run to compare: 39 ("ok") and 47 ("bad").
# Loading order is preserved, so args.load_id ends at 47 as before.
state_ok = _load_checkpoint(39)
state_bad = _load_checkpoint(47)

In [None]:
# Dump the state loaded from checkpoint 39.
print(str(state_ok))

In [None]:
%load_ext autoreload
%autoreload 2
from katacv.yolov5.predict_debug import Predictor
from katacv.utils.yolo.utils import show_box
predict_ok = Predictor(args, state_ok)
predict_bad = Predictor(args, state_bad)

In [None]:
# Re-print the checkpoint-39 state after the Predictors have been constructed.
print(str(state_ok))

In [None]:
# img = []
# for p in Path("/home/wty/Pictures/model_test/test_image/8examples").glob('*.jpg'):
#   img.append(np.array(Image.open(str(p)).resize((640,640)).convert("RGB")))
# img = np.array(img)
# print(f"{len(img)=}, {img.dtype=}")
# for i in range(len(img)):
#   x = img[i:i+1].astype(np.float32) / 255.
#   pbox = predict_bad.update(x, nms_iou=0.65, nms_conf=0.01)
#   show_box(x[0], pbox[0])

In [None]:
from katacv.utils.yolo.build_dataset import DatasetBuilder
# Dataset root: override with the COCO_DATASET_PATH environment variable instead of
# editing this cell per machine (e.g. COCO_DATASET_PATH=/home/wty/Coding/datasets/coco).
args.path_dataset = Path(os.environ.get("COCO_DATASET_PATH", "/home/yy/Coding/datasets/coco"))
# Single-image batches: the cells below inspect/visualise one sample (x[0]).
args.batch_size = 1
ds_builder = DatasetBuilder(args)
train_ds = ds_builder.get_dataset(subset='train', use_cache=False)
val_ds = ds_builder.get_dataset(subset='val', use_cache=False)

In [None]:
# Take the first batch of a fresh train iterator and scale pixels to [0, 1] float32.
iter_ds = iter(train_ds)
x, tbox, tnum = next(iter_ds)
x = x.numpy().astype(np.float32) / 255.
tbox = tbox.numpy()
tnum = tnum.numpy()
# Run both checkpoints on the same sample and draw their predictions (ok first, then bad).
for predictor in (predict_ok, predict_bad):
  pbox = predictor.update(x, nms_conf=0.01)
  show_box(x[0], pbox[0])

In [None]:
# Re-run only the "bad" checkpoint on the current sample.
pbox = predict_bad.update(x, nms_conf=0.01)
first_image, first_pbox = x[0], pbox[0]
show_box(first_image, first_pbox)

In [None]:
# Save the current sample (image + targets) so this debug case can be reproduced offline.
path_debug.mkdir(parents=True, exist_ok=True)  # logs/debug is not created anywhere else
# x was built as uint8 / 255 in float32, so x * 255 can land just below an integer
# (e.g. 199.99998); round before casting so astype does not truncate it.
img_u8 = np.clip(np.rint(x[0] * 255), 0, 255).astype(np.uint8)
Image.fromarray(img_u8).save(path_debug / "origin.jpg")
with (path_debug / "sample_data.npy").open('wb') as file:
  # The dict is stored as a pickled 0-d object array;
  # reload with np.load(path, allow_pickle=True).item().
  np.save(
    file, {
      'x': x,
      'tbox': tbox,
      'tnum': tnum
    }, allow_pickle=True
  )

In [None]:
num_tbox = tnum[0]
print("tbox number:", num_tbox)
# Only the first tnum[0] rows of tbox[0] are drawn (remaining rows presumably padding — confirm).
show_box(x[0], tbox[0][:num_tbox])

In [None]:
from katacv.yolov5.loss_debug import ComputeLoss

# Loss helper built from the same args used for the checkpoints and dataset above.
compute_loss = ComputeLoss(args)

In [None]:
_, metrics = compute_loss.step(state_ok, x, tbox, tnum)
# NOTE(review): lbox/lobj/lcls are multiplied by 16 before printing while loss/l2 are not;
# the origin of this factor is not visible in the notebook — confirm.
for name, val in zip(('loss', 'lbox', 'lobj', 'lcls', 'l2'), metrics):
  if name in ('lbox', 'lobj', 'lcls'):
    val *= 16
  print(f"{name}:", val)

In [None]:
_, metrics = compute_loss.step(state_bad, x, tbox, tnum)
# Same report as for the "ok" checkpoint: lbox/lobj/lcls scaled by 16 (factor unexplained
# here — confirm), loss and l2 printed as-is.
for name, val in zip(('loss', 'lbox', 'lobj', 'lcls', 'l2'), metrics):
  if name in ('lbox', 'lobj', 'lcls'):
    val *= 16
  print(f"{name}:", val)

In [None]:
# Show the weight-decay setting held by the loss helper.
print(str(compute_loss.weight_decay))

In [None]:
# from katacv.utils.related_pkgs.jax_flax_optax_orbax import *
# from katacv.yolov5.loss import cell2pixel
# @jax.jit
# def predict(state: train_state.TrainState, x: jnp.ndarray):
#   logits = state.apply_fn(
#     {'params': state.params, 'batch_stats': state.batch_stats},
#     x, train=False
#   )
#   y, batch_size = [], x.shape[0]
#   for i in range(3):
#     xy = (jax.nn.sigmoid(logits[i][...,:2]) - 0.5) * 2.0 + 0.5
#     xy = cell2pixel(xy, scale=2**(i+3))
#     wh = (jax.nn.sigmoid(logits[i][...,2:4]) * 2) ** 2 * args.anchors[i].reshape(1,3,1,1,2)
#     conf = jax.nn.sigmoid(logits[i][...,4:5])
#     cls = jax.nn.sigmoid(logits[i][...,5:])
#     conf = conf * jnp.max(cls, axis=-1, keepdims=True)
#     cls = jnp.argmax(cls, axis=-1, keepdims=True)
#     y.append(jnp.concatenate([xy,wh,conf,cls], -1).reshape(batch_size,-1,6))
#   y = jnp.concatenate(y, 1)  # shape=(batch_size,all_pbox_num,6)
#   return y
# p_ok = jax.device_get(predict(state_ok, x))
# p_bad = jax.device_get(predict(state_bad, x))

In [None]:
from tqdm.auto import tqdm
# Test: evaluate the "bad" checkpoint on the full validation split.
from katacv.utils.yolo.build_dataset import DatasetBuilder
# Dataset root: override with the COCO_DATASET_PATH environment variable instead of
# editing this cell per machine (e.g. COCO_DATASET_PATH=/home/wty/Coding/datasets/coco).
args.path_dataset = Path(os.environ.get("COCO_DATASET_PATH", "/home/yy/Coding/datasets/coco"))
args.batch_size = 32
ds_builder = DatasetBuilder(args)
train_ds = ds_builder.get_dataset(subset='train', use_cache=False)
val_ds = ds_builder.get_dataset(subset='val', use_cache=False)
predict_bad.reset()  # presumably clears metric state from the earlier update() calls — confirm
# Distinct loop names: reusing x/tbox/tnum here would overwrite the single debug sample
# that the save/loss cells above operate on with the last validation batch.
for x_val, tbox_val, tnum_val in tqdm(val_ds, desc="val"):
  x_val = x_val.numpy().astype(np.float32) / 255.0
  predict_bad.update(x_val, tbox_val.numpy(), tnum_val.numpy())
print(predict_bad.p_r_ap50_ap75_map())