In [1]:
import os, sys

# Make the project's `src/` directory importable so `rps` resolves from this checkout.
# Assumes the kernel's working directory is a sibling of `src/` (e.g. `notebooks/`).
# The membership check keeps the cell idempotent: re-running it no longer appends a
# duplicate entry to sys.path on every execution.
SRC_DIR = os.path.join(os.path.dirname(os.getcwd()), 'src')
if SRC_DIR not in sys.path:
    sys.path.append(SRC_DIR)

In [2]:
from types import SimpleNamespace

from rps.augment import Augmenter
from rps.dataset import RPSDataset
from rps.utils.data import load_yaml, get_image_paths, make_class_map, train_val_split


# Load the YAML experiment config and wrap it so its top-level keys read as attributes
# (cfg.data_dir, cfg.val_ratio, cfg.seed, cfg.image_size are used below).
# SimpleNamespace is shallow: nested mappings remain plain dicts.
cfg_dict = load_yaml(path="../configs/config.yaml")
cfg = SimpleNamespace(**cfg_dict)

# Gather the image files and build the class mapping from the same data directory.
# NOTE(review): the exact format of class_map is defined in rps.utils.data — confirm there.
img_paths = get_image_paths(cfg.data_dir)
class_map = make_class_map(cfg.data_dir)

# Hold out cfg.val_ratio of the images for validation; cfg.seed is forwarded so the
# split can be reproduced (assuming train_val_split honours the seed).
train_img_paths, val_img_paths = train_val_split(
    img_paths,
    val_ratio=cfg.val_ratio,
    seed=cfg.seed,
)

# One dataset per split, differing only in the Augmenter's `train` flag.
# NOTE(review): train=True presumably enables random augmentation and train=False an
# eval-style transform — confirm in rps.augment.
train_dataset = RPSDataset(
    img_paths=train_img_paths,
    class_map=class_map,
    transform=Augmenter(train=True, image_size=cfg.image_size),
)
val_dataset = RPSDataset(
    img_paths=val_img_paths,
    class_map=class_map,
    transform=Augmenter(train=False, image_size=cfg.image_size),
)

In [None]:
from rps.utils.visualize import show_image_grid

# Preview a grid of samples from each split, training split first.
for preview_dataset, preview_title in (
    (train_dataset, 'Train dataset images'),
    (val_dataset, 'Validation dataset images'),
):
    show_image_grid(preview_dataset, class_map, title=preview_title)

In [None]:
from rps.inference import RPSInference
from rps.utils.capture import crop_square


# Exported ONNX model, relative to the notebook's working directory.
# NOTE(review): `run_1` is hard-coded — update it when evaluating a different training run.
ONNX_PATH = "../runs/run_1/best_model.onnx"

# Fail fast with an actionable message (including the resolved absolute path) instead of
# whatever error RPSInference would raise when handed a missing file.
if not os.path.isfile(ONNX_PATH):
    raise FileNotFoundError(
        f"ONNX model not found at {os.path.abspath(ONNX_PATH)!r}; "
        "export a model first or point ONNX_PATH at an existing run."
    )

model = RPSInference(model_path=ONNX_PATH, class_map=class_map)

In [None]:
import random
from PIL import Image
import matplotlib.pyplot as plt
from rps.utils.data import read_image


# Set to an int (e.g. cfg.seed) to pick the same validation image on every run.
# None keeps the previous behaviour: sample from the global `random` module state.
SAMPLE_SEED = None
sample_rng = random if SAMPLE_SEED is None else random.Random(SAMPLE_SEED)

img_path = sample_rng.choice(val_img_paths)
image = read_image(img_path)

# Predict on the square crop and display that same crop, so the figure shows exactly
# what the model received rather than only the uncropped source image.
model_input = crop_square(image)
prediction = model.predict(model_input)

fig, (ax_full, ax_crop) = plt.subplots(1, 2, figsize=(8, 4))
ax_full.imshow(image)
ax_full.set_title("Random Validation Image")
ax_crop.imshow(model_input)
ax_crop.set_title(f"Model input -> {prediction}")
for ax in (ax_full, ax_crop):
    ax.axis('off')
fig.suptitle(os.path.basename(img_path))
plt.show()

print(f"Prediction: {prediction}")