In [1]:
import os

from judo_footage_analysis.utils import get_spark

# Root of the judo data directory. Defaults to the shared location used for
# these experiments; set the JUDO_ROOT environment variable to run elsewhere.
root = os.environ.get("JUDO_ROOT", "/cs-share/pradalier/tmp/judo")

# Spark session requested with 4 cores and 1g of memory; shown as the cell output.
spark = get_spark(cores=4, mem="1g")
spark

24/04/15 14:31:26 WARN Utils: Your hostname, gtlpc108.georgiatech-metz.fr resolves to a loopback address: 127.0.1.1; using 192.93.8.108 instead (on interface enp0s31f6)
24/04/15 14:31:26 WARN Utils: Set SPARK_LOCAL_IP if you need to bind to another address
Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
24/04/15 14:31:27 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable
24/04/15 14:31:27 WARN SparkConf: Note that spark.local.dir will be overridden by the value set by the cluster manager (via SPARK_LOCAL_DIRS in mesos/standalone/kubernetes and LOCAL_DIRS in YARN).


In [2]:
from pyspark.sql import functions as F, Window

# Load the per-run evaluation summaries. The glob layout is
#   <root>/models/evaluation_embeddings_logistic_binary/v4/<feature>/<*>/perf/<file>.json
df = (
    spark.read.json(
        f"{root}/models/evaluation_embeddings_logistic_binary/v4/*/*/perf/*.json"
    )
    .withColumn("filename", F.input_file_name())
    # <feature> is the 4th path component counted from the end of the file URI.
    # Indexing from the end (instead of a fixed index from the start) keeps this
    # correct regardless of how many directories `root` itself contains.
    .withColumn("feature", F.element_at(F.split(F.col("filename"), "/"), -4))
)
df.show()

                                                                                

+--------------------+-----------+-----------+--------------------+-------------------+-------------+---------+-------------------+--------------+----------+------------------+--------------------+--------------------+
|         avg_metrics|      label|metric_name|         std_metrics|        test_metric|test_positive|test_size|          test_time|train_positive|train_size|        train_time|            filename|             feature|
+--------------------+-----------+-----------+--------------------+-------------------+-------------+---------+-------------------+--------------+----------+------------------+--------------------+--------------------+
|[0.7891663451795221]|  is_active|         f1|[0.7891663451795221]|0.37500000000000006|           45|       90|0.47534132993314415|           238|       480|12.309672495000996|file:///cs-share/...|emb_entity_detect...|
|[0.8351040324625366]|is_standing|         f1|[0.8351040324625366]| 0.6602469135802469|            3|       90|0.46770033007

In [13]:
# Human-readable feature names: shorten the embedding prefixes, then turn
# underscores into spaces (e.g. "emb_vanilla_yolov8n" -> "vanilla yolov8n").
pretty_feature = F.regexp_replace(
    F.regexp_replace(
        F.regexp_replace("feature", "emb_entity_detection", "fine_tune"),
        "emb_vanilla",
        "vanilla",
    ),
    "_",
    " ",
)

clean = df.select(
    F.regexp_replace("label", "_", " ").alias("label"),
    pretty_feature.alias("feature"),
    # avg_metrics is an array column; its first element is reported as avg_train_f1.
    F.round(F.col("avg_metrics")[0], 3).alias("avg_train_f1"),
    F.round("test_metric", 3).alias("test_f1"),
    F.col("train_time").cast("integer").alias("train_time"),
).orderBy(
    "label",
    F.desc("test_f1"),
    # test_f1 is rounded, so ties are common; sort by name to keep the order stable.
    "feature",
)

# Features without the lagged variants.
clean.where(~F.col("feature").contains("lag")).show(truncate=False, n=100)

+-----------+--------------------+------------+-------+----------+
|label      |feature             |avg_train_f1|test_f1|train_time|
+-----------+--------------------+------------+-------+----------+
|is active  |fine tune v3 dct d8 |0.611       |0.783  |7         |
|is active  |fine tune v2 dct d32|0.81        |0.576  |19        |
|is active  |fine tune v2        |0.843       |0.575  |100       |
|is active  |vanilla yolov8n     |0.843       |0.575  |93        |
|is active  |fine tune v2 dctn   |0.844       |0.533  |44        |
|is active  |fine tune v3        |0.817       |0.516  |23        |
|is active  |fine tune v3 dct d32|0.705       |0.504  |20        |
|is active  |fine tune v3 dct d16|0.672       |0.483  |14        |
|is active  |fine tune v2 dct d64|0.817       |0.479  |19        |
|is active  |fine tune v3 dctn   |0.755       |0.475  |53        |
|is active  |fine tune v2 dct d8 |0.743       |0.432  |8         |
|is active  |fine tune v1        |0.851       |0.416  |20     

In [14]:
import pandas as pd

# Top-5 non-lagged features per label, ranked by test F1.
# test_f1 is rounded to 3 decimals, so ties occur; "feature" is a secondary key
# so row_number() selects the same rows on every run.
res = (
    clean.where(~F.col("feature").contains("lag"))
    .withColumn(
        "rank",
        F.row_number().over(
            Window.partitionBy("label").orderBy(F.desc("test_f1"), "feature")
        ),
    )
    .where(F.col("rank") <= 5)
    # Row order after a window is not guaranteed; fix it before rank is dropped.
    .orderBy("label", "rank")
    .drop("rank", "train_time")
)

res.show(truncate=False)
# Global option kept as before (other cells may rely on it); the LaTeX export
# below specifies its float format explicitly.
pd.set_option("display.precision", 2)
print(res.toPandas().round(2).to_latex(index=False, float_format="%.2f"))

+-----------+--------------------+------------+-------+
|label      |feature             |avg_train_f1|test_f1|
+-----------+--------------------+------------+-------+
|is active  |fine tune v3 dct d8 |0.611       |0.783  |
|is active  |fine tune v2 dct d32|0.81        |0.576  |
|is active  |fine tune v2        |0.843       |0.575  |
|is active  |vanilla yolov8n     |0.843       |0.575  |
|is active  |fine tune v2 dctn   |0.844       |0.533  |
|is match   |fine tune v3 dctn   |0.959       |0.658  |
|is match   |fine tune v2 dct d64|0.908       |0.626  |
|is match   |fine tune v1        |0.964       |0.613  |
|is match   |fine tune v3 dct d16|0.884       |0.611  |
|is match   |fine tune v2        |0.954       |0.582  |
|is standing|fine tune v3 dctn   |0.853       |0.866  |
|is standing|fine tune v2 dct d8 |0.761       |0.853  |
|is standing|fine tune v3 dct d8 |0.749       |0.818  |
|is standing|fine tune v3 dct d32|0.83        |0.758  |
|is standing|fine tune v2 dctn   |0.851       |0

In [19]:
# Top-5 "dct d16" features (plain and lagged variants) per label, ranked by test F1.
# "feature" breaks ties between equal rounded test_f1 values so the selection
# is deterministic.
res = (
    clean.where(F.col("feature").contains("dct d16"))
    .drop("train_time")
    .withColumn(
        "rank",
        F.row_number().over(
            Window.partitionBy("label").orderBy(F.desc("test_f1"), "feature")
        ),
    )
    .where(F.col("rank") <= 5)
    # Make the output order explicit instead of relying on the window's layout.
    .orderBy("label", "rank")
    .drop("rank")
)
res.show(truncate=False)
# Explicit float format so this export does not depend on options set in other cells.
print(res.toPandas().round(2).to_latex(index=False, float_format="%.2f"))

+-----------+-------------------------+------------+-------+
|label      |feature                  |avg_train_f1|test_f1|
+-----------+-------------------------+------------+-------+
|is active  |fine tune v3 dct d16 lag1|0.686       |0.515  |
|is active  |fine tune v3 dct d16     |0.672       |0.483  |
|is active  |fine tune v2 dct d16 lag1|0.768       |0.408  |
|is active  |fine tune v2 dct d16 lag2|0.767       |0.383  |
|is active  |fine tune v3 dct d16 lag2|0.694       |0.381  |
|is match   |fine tune v3 dct d16 lag1|0.904       |0.621  |
|is match   |fine tune v3 dct d16     |0.884       |0.611  |
|is match   |fine tune v2 dct d16     |0.915       |0.558  |
|is match   |fine tune v3 dct d16 lag2|0.915       |0.557  |
|is match   |fine tune v2 dct d16 lag3|0.939       |0.533  |
|is standing|fine tune v3 dct d16 lag2|0.747       |0.745  |
|is standing|fine tune v3 dct d16 lag3|0.79        |0.737  |
|is standing|fine tune v3 dct d16 lag1|0.731       |0.731  |
|is standing|fine tune v