In [172]:
from dataset.rehab_dataset import RehabHands
from spock_dataclasses import *
import albumentations as A
import pandas as pd
from tqdm import tqdm

In [173]:
cfg = RehabCfg(
    seed_num=42,
    batch_size=8,
    num_workers=32,
    max_epochs=500,
    lr=0.01,
    device=7,
    num_frames=32,
    use_scheduler=True,
    scheduler_milestones=[10, 15],
    img_size=[256, 256],
    model_type='SlowFast',
    load_checkpoint=False,
    checkpoint_name='flowing-aardvark-15',
    optimizer_type='AdamW',
    early_stopping=25,
    cross_subject=False,
    subject='Subject_8',
    task='repetition_counting',
    project_name='Exercise_Repetition_Counting',
    bacbone_freeze=True,
)
transform = A.ReplayCompose([
    A.Resize(
        cfg.img_size[0], cfg.img_size[1]),
    # A.Normalize(mean=MEAN, std=STD),
    # ToTensorV2(),
])
test_dataset = RehabHands(cfg, 'test', transform)

In [174]:
# Open test_paths_pick2.csv into dataframe

to_open_name_val = f'val_pick_predicted_{cfg.subject}.csv'
to_open_name_test = f'test_pick_predicted_{cfg.subject}.csv'

df_gt = pd.read_csv(to_open_name_val)
df_gt2 = pd.read_csv(to_open_name_test)
df_gt = pd.read_csv('test_pick_predicted.csv')
df_gt2 = pd.read_csv('val_pick_predicted.csv')
df_gt3 = pd.read_csv('train_pick_predicted.csv')

df_gt = pd.concat([df_gt, df_gt2, df_gt3], ignore_index=True)

In [175]:
# get list of paths from df_gt and list of labels
paths = df_gt['path'].tolist()
labels = df_gt['label'].tolist()

# Make dictionary where each path is a key and each label is a value
path_label_dict = dict(zip(paths, labels))

In [176]:
# make filter function which loops over the binary list
def filter_first(x):
    if x[0] == 0 and x[1] == 1:
        x[0] = 1
    return x


def filter_function(x):
    #     Loop over the binary list. If the value is equal to 0 but next value and previous value is 1 then change the value to 1
    for i in range(len(x)):
        if x[i] == 0 and i > 0 and i < len(x) - 1:
            if x[i+1] == 1 and x[i-1] == 1:
                x[i] = 1

    return x


def filter_function2(x):
    #     Loop over the binary list. If the value is equal to 0 but next value and previous value is 1 then change the value to 1
    for i in range(len(x)):
        if x[i] == 0 and i > 0 and i < len(x) - 1:
            if x[i+1] == 0 and x[i-1] == 1 and x[i+2] == 1:
                x[i] = 1
                x[i+1] = 1

    return x


def filter_function_0_n(x, N):
    # Loop over the binary list
    for i in range(len(x)):
        # Only proceed if the current value is 0 and there's enough room to check N elements ahead
        if x[i] == 0 and i > 0 and i + N < len(x):
            # Check if the sequence of up to N 0s is surrounded by 1s
            if all(x[i+j] == 0 for j in range(1, N)) and x[i-1] == 1 and x[i+N] == 1:
                # Replace the sequence of 0s with 1s
                for j in range(N):
                    x[i+j] = 1
    return x


def filter_function_1_n(x, N):
    # Loop over the binary list
    for i in range(len(x)):
        # Only proceed if the current value is 0 and there's enough room to check N elements ahead
        if x[i] == 1 and i > 0 and i + N < len(x):
            # Check if the sequence of up to N 0s is surrounded by 1s
            if all(x[i+j] == 1 for j in range(1, N)) and x[i-1] == 0 and x[i+N] == 0:
                # Replace the sequence of 0s with 1s
                for j in range(N):
                    x[i+j] = 0
    return x

In [180]:
# Make function which take a list of binary values and count the number of changes from 0 to 1, if it starts with 1 it will count that as 1 change
def count_changes(binary_list):
    count = 0

   
    binary_list = filter_first(binary_list)
    binary_list = filter_function_1_n(binary_list, 1)
    binary_list = filter_function_1_n(binary_list, 2)
    binary_list = filter_function_1_n(binary_list, 3)
    binary_list = filter_function_1_n(binary_list, 4)
    binary_list = filter_function_1_n(binary_list, 5)

    binary_list = filter_function_0_n(binary_list, 1)
    binary_list = filter_function_0_n(binary_list, 2)
    binary_list = filter_function_0_n(binary_list, 3)
    binary_list = filter_function_0_n(binary_list, 4)
    binary_list = filter_function_0_n(binary_list, 5)
   
    for i in range(len(binary_list)-1):
        if binary_list[i] != binary_list[i-1] and binary_list[i] < binary_list[i-1]:
           
            count += 1
    return count

In [182]:
total = 0
correct = 0
diff_dict = {}
for data in tqdm(test_dataset):
  
    reps = data['repetition_count']
    all_paths = data['all_paths']
    action_class = data['action_label']

  
    if True:
        

        picks = []
        for path in all_paths:
            picks.append(path_label_dict[path])
       
        idxs = []
        for i in range(len(picks)):
            idxs.append(i)
      
        counted = count_changes(picks)
        if counted == reps:
            correct += 1
      
        total += 1

        diff = abs(reps - counted)

        if diff in diff_dict:
            diff_dict[diff] += 1
        else:
            diff_dict[diff] = 1



print('Total: ', total)
print('Correct: ', correct)
print('Accuracy: ', correct/total)


total_diff = sum(diff_dict.values())
e0 = (diff_dict[0] if 0 in diff_dict else 0) / total_diff * 100
e1 = (diff_dict[1] if 1 in diff_dict else 0) / total_diff * 100
e2 = (diff_dict[2] if 2 in diff_dict else 0) / total_diff * 100

# Sum up all values in self.diff_dict for keys larger than 2
sum_diff = 0
for key in diff_dict.keys():
    if key > 2:
        sum_diff += diff_dict[key]
e_gt_2 = sum_diff / total_diff * 100

mae = 0

for key in diff_dict.keys():
    mae += key * diff_dict[key]

mae = round((mae / total_diff), 2)

# round e0 to float numbers
e0 = round(e0, 2)
e1 = round(e1, 2)
e2 = round(e2, 2)
e_gt_2 = round(e_gt_2, 2)
print(f'e=0: {e0}, e=1: {e1}, e=2: {
    e2}, e >= 2: {e_gt_2}, MAE: {mae}')

100%|██████████| 180/180 [03:20<00:00,  1.11s/it]

Total:  180
Correct:  116
Accuracy:  0.6444444444444445
e=0: 64.44, e=1: 13.89, e=2: 6.67, e >= 2: 15.0, MAE: 1.43



