In [1]:
%load_ext autoreload
%autoreload 2

import shutil

from torch.utils.data import DataLoader

from diveslowlearnfast.datasets import Diving48Dataset
from diveslowlearnfast.train import StatsDB

In [2]:
stats_db = StatsDB('../results/run11/stats.db')

In [11]:
stats_db.add_loss(0.1, 100, 'loss', 'abc', 'train')
stats_db.add_loss(0.025, 100, 'slow', 'abc', 'train')
stats_db.add_loss(0.025, 100, 'fast', 'abc', 'train')
stats_db.add_loss(0.05, 100, 'ce', 'abc', 'train')

In [12]:
stats_db.execute_query("""SELECT  value, tag, created_at FROM losses ORDER BY created_at""")

[[0.1, 'loss', 1740489007060],
 [0.05, 'slow', 1740489107161],
 [0.05, 'fast', 1740489107162],
 [0.05, 'ce', 1740489107163],
 [0.1, 'loss', 1740489123371],
 [0.025, 'slow', 1740489124273],
 [0.025, 'fast', 1740489124273],
 [0.05, 'ce', 1740489124274]]

In [8]:
query = """
WITH acc AS
(
    SELECT gt, (correct_n / n) as acc FROM
        (
        SELECT
            gt,
            epoch,
            CAST(SUM(CASE WHEN pred = gt THEN 1 ELSE 0 END) as REAL) as correct_n,
            CAST(COUNT(*) as REAL) as n
        FROM stats
        WHERE epoch > 0 AND epoch <= 100
        GROUP BY gt
        )
) SELECT * FROM acc
"""
stats_db.execute_query(query)


[[0, 0.0],
 [1, 0.0],
 [2, 0.0],
 [3, 0.0],
 [4, 0.125],
 [5, 0.0],
 [6, 0.0625],
 [7, 0.0],
 [8, 0.0],
 [9, 0.0625],
 [10, 0.0625],
 [11, 0.0],
 [12, 0.0],
 [13, 0.0],
 [14, 0.0],
 [15, 0.0625],
 [16, 0.0],
 [17, 0.0],
 [18, 0.0],
 [19, 0.0],
 [20, 0.0],
 [21, 0.0],
 [22, 0.0],
 [23, 0.0],
 [24, 0.0],
 [25, 0.0],
 [26, 0.0],
 [27, 0.0],
 [28, 0.0625],
 [29, 0.0],
 [31, 0.0],
 [32, 0.0],
 [33, 0.0],
 [34, 0.0],
 [35, 0.0],
 [36, 0.0],
 [37, 0.0],
 [38, 0.0625],
 [39, 0.0625],
 [40, 0.0],
 [41, 0.0],
 [42, 0.0],
 [43, 0.0],
 [44, 0.0625],
 [45, 0.0625],
 [46, 0.0],
 [47, 0.0625]]

In [4]:
difficult_classes_query = """
WITH
    acc AS (
        SELECT gt, (correct_n / n) as acc FROM(
            SELECT
                gt,
                epoch,
                CAST(SUM(CASE WHEN pred = gt THEN 1 ELSE 0 END) as REAL) as correct_n,
                CAST(COUNT(*) as REAL) as n
            FROM stats
            WHERE epoch == 99
            GROUP BY gt
        )
    ),
    median AS (
        SELECT AVG(acc) as median FROM(
            SELECT * FROM acc
            ORDER BY acc
            LIMIT 2 - (SELECT COUNT(*) FROM acc) % 2
            OFFSET (SELECT (COUNT(*) - 1) / 2 FROM acc)
        )
    )
SELECT *, (SELECT median FROM median) as median FROM acc
WHERE acc < median
ORDER BY acc
"""

stats_db.execute_query(difficult_classes_query)

[[21, 0.64, 0.73],
 [17, 0.6466666666666666, 0.73],
 [24, 0.65, 0.73],
 [44, 0.6666666666666666, 0.73],
 [26, 0.67, 0.73],
 [7, 0.6933333333333334, 0.73],
 [8, 0.6933333333333334, 0.73],
 [34, 0.71, 0.73],
 [12, 0.7133333333333334, 0.73],
 [19, 0.7266666666666667, 0.73]]

In [5]:
max_epoch = 50
data = [40]
if max_epoch == 'max_epoch':
    max_epoch = ''
else:
    data.append(max_epoch)
    max_epoch = 'AND epoch <= ?'
difficult_videos_query = f"""
WITH
    acc AS (
        SELECT video_id, gt, (correct_n / n) as acc FROM(
            SELECT
                video_id,
                gt,
                epoch,
                CAST(SUM(CASE WHEN pred = gt THEN 1 ELSE 0 END) as REAL) as correct_n,
                CAST(COUNT(*) as REAL) as n
            FROM stats
            WHERE epoch > ? {max_epoch}
            GROUP BY video_id, gt
        )
    ),
    median AS (
        SELECT AVG(acc) as median FROM(
            SELECT * FROM acc
            ORDER BY acc
            LIMIT 2 - (SELECT COUNT(*) FROM acc) % 2
            OFFSET (SELECT (COUNT(*) - 1) / 2 FROM acc)
        )
    )
SELECT *, (SELECT median FROM median) as median FROM acc
WHERE acc < median
ORDER BY acc
"""

data = tuple(data)
stats_db.execute_query(difficult_videos_query, data=data)

[['-mmq0PT-u8k_00000', 31, 0.0, 0.1],
 ['-mmq0PT-u8k_00007', 31, 0.0, 0.1],
 ['-mmq0PT-u8k_00027', 31, 0.0, 0.1],
 ['-mmq0PT-u8k_00039', 24, 0.0, 0.1],
 ['-mmq0PT-u8k_00061', 24, 0.0, 0.1],
 ['-mmq0PT-u8k_00132', 17, 0.0, 0.1],
 ['-mmq0PT-u8k_00133', 17, 0.0, 0.1],
 ['-mmq0PT-u8k_00134', 17, 0.0, 0.1],
 ['-mmq0PT-u8k_00138', 17, 0.0, 0.1],
 ['-mmq0PT-u8k_00139', 17, 0.0, 0.1],
 ['-mmq0PT-u8k_00143', 36, 0.0, 0.1],
 ['2x00lRzlTVQ_00000', 15, 0.0, 0.1],
 ['2x00lRzlTVQ_00001', 15, 0.0, 0.1],
 ['2x00lRzlTVQ_00002', 15, 0.0, 0.1],
 ['2x00lRzlTVQ_00011', 31, 0.0, 0.1],
 ['2x00lRzlTVQ_00013', 15, 0.0, 0.1],
 ['2x00lRzlTVQ_00018', 15, 0.0, 0.1],
 ['2x00lRzlTVQ_00019', 15, 0.0, 0.1],
 ['2x00lRzlTVQ_00021', 36, 0.0, 0.1],
 ['2x00lRzlTVQ_00032', 46, 0.0, 0.1],
 ['2x00lRzlTVQ_00033', 46, 0.0, 0.1],
 ['2x00lRzlTVQ_00037', 36, 0.0, 0.1],
 ['2x00lRzlTVQ_00040', 46, 0.0, 0.1],
 ['2x00lRzlTVQ_00041', 46, 0.0, 0.1],
 ['2x00lRzlTVQ_00042', 46, 0.0, 0.1],
 ['2x00lRzlTVQ_00046', 46, 0.0, 0.1],
 ['2x00lRzlT

In [6]:
stats_db.execute_query('SELECT run_id FROM stats LIMIT 10')

[['/home/s2871513/Projects/diveslowlearnfast/results/run18'],
 ['/home/s2871513/Projects/diveslowlearnfast/results/run18'],
 ['/home/s2871513/Projects/diveslowlearnfast/results/run18'],
 ['/home/s2871513/Projects/diveslowlearnfast/results/run18'],
 ['/home/s2871513/Projects/diveslowlearnfast/results/run18'],
 ['/home/s2871513/Projects/diveslowlearnfast/results/run18'],
 ['/home/s2871513/Projects/diveslowlearnfast/results/run18'],
 ['/home/s2871513/Projects/diveslowlearnfast/results/run18'],
 ['/home/s2871513/Projects/diveslowlearnfast/results/run18'],
 ['/home/s2871513/Projects/diveslowlearnfast/results/run18']]

In [4]:
a = stats_db.get_below_median_samples(
    epoch_start=0,
    epoch_end='max_epoch',
    run_id='results/run11',
    split='train'
)
len(a)

176

In [10]:
difficult_samples = list(map(lambda x: x[0], a))

In [7]:
query = """
WITH
    acc AS (
        SELECT video_id, gt, (correct_n / n) as acc FROM(
            SELECT
                video_id,
                gt,
                epoch,
                CAST(SUM(CASE WHEN pred = gt THEN 1 ELSE 0 END) as REAL) as correct_n,
                CAST(COUNT(*) as REAL) as n
            FROM stats
            WHERE epoch > ?
            AND run_id = ?
            AND split = ?
            GROUP BY video_id, gt
        )
    ),
    median AS (
        SELECT AVG(acc) as median FROM(
            SELECT * FROM acc
            ORDER BY acc
            LIMIT 2 - (SELECT COUNT(*) FROM acc) % 2
            OFFSET (SELECT (COUNT(*) - 1) / 2 FROM acc)
        )
    )
SELECT *, (SELECT median FROM median) as median FROM acc
WHERE acc <= median
ORDER BY acc
"""
stats_db.execute_query(query, (0, 'results/run11', 'train'))

[['-mmq0PT-u8k_00043', 24, 0.0, 0.0],
 ['3N1kUtqJ25A_00073', 44, 0.0, 0.0],
 ['3N1kUtqJ25A_00131', 34, 0.0, 0.0],
 ['3PLiUG_DuC8_00051', 21, 0.0, 0.0],
 ['3PLiUG_DuC8_00202', 22, 0.0, 0.0],
 ['5V-dKBtmKLI_00030', 5, 0.0, 0.0],
 ['5V-dKBtmKLI_00132', 28, 0.0, 0.0],
 ['5i1begTTucc_00072', 2, 0.0, 0.0],
 ['9BC6ssCjyfg_00134', 45, 0.0, 0.0],
 ['9BC6ssCjyfg_00181', 36, 0.0, 0.0],
 ['9BC6ssCjyfg_00213', 7, 0.0, 0.0],
 ['9jZYYtzYqwE_00046', 24, 0.0, 0.0],
 ['Bb0ZiYVNtDs_00094', 23, 0.0, 0.0],
 ['Bb0ZiYVNtDs_00095', 23, 0.0, 0.0],
 ['D6zILEKIJbk_00056', 12, 0.0, 0.0],
 ['D6zILEKIJbk_00105', 35, 0.0, 0.0],
 ['D8YKHC5hmUs_00096', 0, 0.0, 0.0],
 ['D8YKHC5hmUs_00164', 32, 0.0, 0.0],
 ['D8YKHC5hmUs_00173', 24, 0.0, 0.0],
 ['D8YKHC5hmUs_00218', 4, 0.0, 0.0],
 ['D8YKHC5hmUs_00224', 25, 0.0, 0.0],
 ['D8YKHC5hmUs_00254', 32, 0.0, 0.0],
 ['DB4lpBDPnTY_00043', 1, 0.0, 0.0],
 ['DB4lpBDPnTY_00046', 21, 0.0, 0.0],
 ['JzOshOJgofw_00042', 47, 0.0, 0.0],
 ['JzOshOJgofw_00163', 18, 0.0, 0.0],
 ['LNMISxO35S0_000