In [2]:
import os

from judo_footage_analysis.utils import get_spark

# Base directory of the judo dataset. The JUDO_ROOT environment variable
# overrides it so the notebook can run where the share is mounted elsewhere;
# when it is unset or empty, the original hard-coded location is used.
root = os.environ.get("JUDO_ROOT") or "/cs-share/pradalier/tmp/judo"

# Spark session from the project helper, requested with 4 cores and 1g of
# memory (how these arguments are applied is defined by get_spark).
spark = get_spark(cores=4, mem="1g")
spark

24/04/05 15:38:38 WARN Utils: Your hostname, gtlpc108.georgiatech-metz.fr resolves to a loopback address: 127.0.1.1; using 192.93.8.108 instead (on interface enp0s31f6)
24/04/05 15:38:38 WARN Utils: Set SPARK_LOCAL_IP if you need to bind to another address
Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
24/04/05 15:38:39 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable
24/04/05 15:38:39 WARN SparkConf: Note that spark.local.dir will be overridden by the value set by the cluster manager (via SPARK_LOCAL_DIRS in mesos/standalone/kubernetes and LOCAL_DIRS in YARN).


In [3]:
from pyspark.sql import functions as F

# Directory (relative to `root`) that holds one sub-directory per feature set.
# It is used both in the read glob and in the regex below, so the two cannot
# drift apart. It contains no regex metacharacters, so it can be embedded in
# the pattern verbatim.
eval_dir = "models/evaluation_embeddings_logistic_binary/v3"

df = (
    spark.read.json(f"{root}/{eval_dir}/*/*/perf/*.json")
    # Keep the source path of every record so the feature name can be derived.
    .withColumn("filename", F.input_file_name())
    # The feature name is the first path component after `eval_dir`. Anchoring
    # on that prefix (instead of a fixed position in the "/"-split path) keeps
    # the extraction correct when `root` has a different depth or URI scheme.
    .withColumn(
        "feature",
        F.regexp_extract(F.col("filename"), f"/{eval_dir}/([^/]+)/", 1),
    )
)
df.show()

+--------------------+-----------+-----------+--------------------+-------------------+-------------+---------+------------------+--------------+----------+------------------+--------------------+--------------------+
|         avg_metrics|      label|metric_name|         std_metrics|        test_metric|test_positive|test_size|         test_time|train_positive|train_size|        train_time|            filename|             feature|
+--------------------+-----------+-----------+--------------------+-------------------+-------------+---------+------------------+--------------+----------+------------------+--------------------+--------------------+
|[0.7217889872996653]|  is_active|         f1|[0.7217889872996653]|0.28322548531607766|           45|       90|1.2950853689981159|           238|       480|27.456793420002214|file:///cs-share/...|emb_entity_detect...|
|[0.7118405659227064]|is_standing|         f1|[0.7118405659227064]| 0.9037475345167653|            3|       90|1.213938665998284

In [19]:
# Dataset sizes and positive rates per (label, lag), de-duplicated because
# several feature sets share the same split sizes.
(
    df.select(
        F.col("label"),
        # Parse the number following "lag" in the feature name; coalesce
        # falls back to 0 when no lag number is parsed.
        F.coalesce(
            F.regexp_extract("feature", r"lag(\d+)", 1).cast("int"),
            F.lit(0),
        ).alias("lag"),
        F.col("train_size"),
        F.col("train_positive"),
        F.col("test_size"),
        F.col("test_positive"),
        # Fraction of positive examples in each split, rounded to 2 decimals.
        F.round(F.col("train_positive") / F.col("train_size"), 2).alias("train_rate"),
        F.round(F.col("test_positive") / F.col("test_size"), 2).alias("test_rate"),
    )
    .dropDuplicates()
    .sort("label", "lag")
    .show()
)

+-----------+---+----------+--------------+---------+-------------+----------+---------+
|      label|lag|train_size|train_positive|test_size|test_positive|train_rate|test_rate|
+-----------+---+----------+--------------+---------+-------------+----------+---------+
|  is_active|  0|       480|           238|       90|           45|       0.5|      0.5|
|  is_active|  1|       464|           232|       87|           43|       0.5|     0.49|
|  is_active|  2|       432|           219|       81|           39|      0.51|     0.48|
|  is_active|  3|       384|           200|       72|           34|      0.52|     0.47|
|  is_active|  5|       240|           130|       45|           23|      0.54|     0.51|
|   is_match|  0|       480|           378|       90|           60|      0.79|     0.67|
|   is_match|  1|       464|           367|       87|           58|      0.79|     0.67|
|   is_match|  2|       432|           341|       81|           54|      0.79|     0.67|
|   is_match|  3|    

In [20]:
# Compact per-(label, feature) summary: the first entry of `avg_metrics`
# and the test metric rounded to 3 decimals, plus `train_time` cast to an
# integer. Rows are sorted by label, best test score first.
clean = df.select(
    F.col("label"),
    F.col("feature"),
    F.round(F.col("avg_metrics").getItem(0), 3).alias("avg_train_f1"),
    F.round(F.col("test_metric"), 3).alias("test_f1"),
    F.col("train_time").cast("int").alias("train_time"),
).sort(F.col("label"), F.col("test_f1").desc())

clean.show(n=100, truncate=False)

+-----------+------------------------------------+------------+-------+----------+
|label      |feature                             |avg_train_f1|test_f1|train_time|
+-----------+------------------------------------+------------+-------+----------+
|is_active  |emb_entity_detection_v2_lag1        |0.884       |0.592  |137       |
|is_active  |emb_entity_detection_v2             |0.843       |0.575  |88        |
|is_active  |emb_vanilla_yolov8n                 |0.843       |0.575  |82        |
|is_active  |emb_entity_detection_v2_dct_d64     |0.796       |0.486  |95        |
|is_active  |emb_entity_detection_v2_dct_d32     |0.781       |0.418  |85        |
|is_active  |emb_entity_detection_v2_dct_d32_lag2|0.731       |0.391  |25        |
|is_active  |emb_entity_detection_v2_dct_d16_lag1|0.721       |0.359  |23        |
|is_active  |emb_entity_detection_v2_dct_d16     |0.757       |0.357  |66        |
|is_active  |emb_entity_detection_v2_dct_d32_lag1|0.785       |0.346  |22        |
|is_

In [24]:
# Restrict the summary to features whose name contains "dct_d16".
clean.filter(F.col("feature").contains("dct_d16")).show(truncate=False)

+-----------+------------------------------------+------------+-------+----------+
|label      |feature                             |avg_train_f1|test_f1|train_time|
+-----------+------------------------------------+------------+-------+----------+
|is_active  |emb_entity_detection_v2_dct_d16_lag1|0.721       |0.359  |23        |
|is_active  |emb_entity_detection_v2_dct_d16     |0.757       |0.357  |66        |
|is_active  |emb_entity_detection_v2_dct_d16_lag5|0.795       |0.319  |18        |
|is_active  |emb_entity_detection_v2_dct_d16_lag3|0.713       |0.293  |26        |
|is_active  |emb_entity_detection_v2_dct_d16_lag2|0.722       |0.27   |26        |
|is_match   |emb_entity_detection_v2_dct_d16_lag2|0.877       |0.542  |39        |
|is_match   |emb_entity_detection_v2_dct_d16_lag1|0.905       |0.511  |37        |
|is_match   |emb_entity_detection_v2_dct_d16     |0.895       |0.506  |78        |
|is_match   |emb_entity_detection_v2_dct_d16_lag3|0.861       |0.484  |20        |
|is_