In [None]:
from src.rnn.architecture import RnnArch
from src.rnn.model import Rnn, RnnConstructorArgs, RnnModelInitializeArgs, RnnTrainArgs,\
    RnnMultiRunTrainArgs, RnnTestArgs
from src.common.helpers import read_dataframe
from src.rnn.data import WindowGenerator
from src.rnn.kfold import ExtendedStratifiedGroupKFold, iterate_group_splits

# --- Experiment configuration -------------------------------------------------
# Window geometry, shared by the WindowGenerator and the model initializer.
input_width = 30
spacing = 1

# Model architecture and the run name it is registered under.
# NOTE(review): the run name says "arch1" while the architecture is ARCH6 —
# confirm the name is not stale before results are logged.
arch = RnnArch.ARCH6
name = "arch1-test"

# Data options forwarded to RnnTrainArgs.
balanced = True
augmented = False


In [None]:
# Load the precomputed feature table used for RNN training.
# NOTE(review): the ".pkl" extension suggests read_dataframe unpickles the file —
# only load files from a trusted source.
df = read_dataframe("data/df/rnn/cvs_features.pkl")

# Build every (train, val, test) fold, then pick one for this experiment.
# Change fold_index to evaluate a different fold instead of hard-coding 0.
fold_index = 0
splitter = ExtendedStratifiedGroupKFold()
splits = list(iterate_group_splits(df, splitter))
# Guard against negative indices silently selecting a fold from the end.
assert 0 <= fold_index < len(splits), f"fold_index out of range ({len(splits)} folds)"
train, val, test = splits[fold_index]

# Windowed dataset over the chosen fold; input_width/spacing come from the config cell.
wg = WindowGenerator(df, train, val, test, input_width, spacing)
print(wg)

In [None]:
# Train the configured architecture for `runs` training runs of `epochs` epochs each,
# then evaluate on the test split of `wg` (write_to_wandb=True: results are
# presumably pushed to Weights & Biases under `name`).
epochs = 5  # epochs per training run
runs = 2    # number of training runs executed by execute_train_runs

model = Rnn(
    args=RnnConstructorArgs(
        name=name,
        model_initialize_args=RnnModelInitializeArgs(
            model_arch=arch,
            input_width=input_width,
            spacing=spacing,
        ),
    )
)
model.execute_train_runs(args=RnnMultiRunTrainArgs(
    train_args=RnnTrainArgs(
        window_generator=wg,
        epochs=epochs,
        balanced=balanced,
        augmented=augmented,
    ),
    runs=runs,
))

model.test_model(args=RnnTestArgs(
    window_generator=wg,
    write_to_wandb=True,
))

In [None]:
# Inspect the architecture of the network wrapped by `Rnn` (its `.model` attribute).
# NOTE(review): assumes `model.model` is a Keras-style model exposing `.summary()` — confirm.
model.model.summary()

In [None]:
from src.hpe.common.landmarks import YoloLabels, MyLandmark

def count_landmarks(labels: YoloLabels) -> int:
    """Count the landmarks of ``labels`` that are not missing.

    Walks every ``MyLandmark`` member, fetches its keypoint with
    ``labels.get_keypoint`` and counts those whose ``is_missing()`` is false.

    :param labels: keypoint labels, presumably for a single detection.
    :return: number of present keypoints, between 0 and the number of
        ``MyLandmark`` members.
    """
    return sum(
        1
        for landmark in MyLandmark
        if not labels.get_keypoint(landmark).is_missing()
    )

In [None]:
from os import listdir
from os.path import join

from src.hpe.common.landmarks import build_yolo_labels, get_most_central

root_label_dir = "data/hpe/img/test/labels"
totals = 0

# Sum the present landmarks over every label file in the test directory.
# When a file holds several detections, only the one returned by
# get_most_central contributes. The per-file result is bound to
# `image_labels` (not `df`) so this cell does not clobber the feature
# DataFrame loaded earlier in the notebook. os.listdir order is arbitrary,
# so sort it to keep the iteration deterministic.
for label_name in sorted(listdir(root_label_dir)):
    label_path = join(root_label_dir, label_name)
    image_labels = build_yolo_labels(label_path)

    if len(image_labels) == 0:
        # No detections in this file: it contributes nothing.
        continue
    if len(image_labels) == 1:
        totals += count_landmarks(image_labels[0])
    else:
        totals += count_landmarks(get_most_central(image_labels))



In [None]:
# Presumably the average number of present landmarks per test image.
# NOTE(review): 41 is a hard-coded magic number — presumably the number of
# test images/label files; TODO derive it from the label directory instead.
totals / 41