In [9]:
%%time

import ast
import time
from collections import defaultdict, namedtuple
from functools import partial
from itertools import product

import numpy as np
import pandas as pd
import scipy
from sklearn.model_selection import KFold
from tqdm.notebook import tqdm

from video699.screen.semantic_segmentation.fastai_detector import *
from video699.screen.semantic_segmentation.common import *
from video699.screen.semantic_segmentation.postprocessing import *
from video699.screen.semantic_segmentation.evaluation import *

CPU times: user 4 ms, sys: 0 ns, total: 4 ms
Wall time: 203 µs


In [10]:
detector = FastAIScreenDetector()
method_params = list(detector.methods.keys())
train_params = list(detector.train_params.keys())
all_params = train_params + method_params

# Post-processing ("method") search space: each list holds the candidate values of one
# method parameter. Settings are later paired with `method_params` positionally (via
# `zip`), so `methods_values` must list the grids in the same order as `method_params`.
base = [True]
base_lower_bounds = [7]
base_upper_bounds = [70]
base_factors = [[0.1, 0.01]]

erode_dilate = [True]
erode_dilate_lower_bounds = [7]
erode_dilate_upper_bounds = [70]
erode_dilate_factors = [[0.1, 0.01]]
erode_dilate_iterations = [40, 100]

ratio_split = [True]
ratio_split_lower_bounds = [0.7, 0.9]
ratio_split_upper_bounds = [1.5]

methods_values = [
    base, erode_dilate, ratio_split,
    base_lower_bounds, base_upper_bounds, base_factors,
    erode_dilate_lower_bounds, erode_dilate_upper_bounds, erode_dilate_factors, erode_dilate_iterations,
    ratio_split_lower_bounds, ratio_split_upper_bounds,
]

# Training search space; must follow the order of `train_params` for the same reason.
batch_size = [8]
resize_factor = [2]
frozen_epochs = [2, 6, 9]
unfrozen_epochs = [3, 7, 10]
frozen_lr = [1e-3, 1e-4]
unfrozen_lr = [slice(1e-4, 2e-4)]

train_params_values = [batch_size, resize_factor, frozen_epochs, unfrozen_epochs, frozen_lr, unfrozen_lr]

# `zip` silently truncates on a length mismatch, which would mislabel every result row,
# so check that there is exactly one grid per parameter name.
assert len(methods_values) == len(method_params), (len(methods_values), len(method_params))
assert len(train_params_values) == len(train_params), (len(train_params_values), len(train_params))

In [11]:
# Full search grids: every combination of the candidate values defined above.
method_settings = list(product(*methods_values))
train_settings = list(product(*train_params_values))

# Lecture videos: their file names, all their frames, and their frames keyed by file name.
all_lectures = [lecture_video.filename for lecture_video in ALL_VIDEOS]
all_frames = [frame for lecture_video in ALL_VIDEOS for frame in lecture_video]
all_frames_grouped_by_videos = {
    lecture_video.filename: [frame for frame in lecture_video]
    for lecture_video in ALL_VIDEOS
}

# Frames of the lecture(s) listed in `test_lectures`.
test_lectures = ['PB069-D2-20140305.mp4']
test_frames = [
    frame
    for lecture in test_lectures
    for frame in all_frames_grouped_by_videos[lecture]
]

# Detector whose `detect` output is used as the reference ("actuals") when scoring predictions.
actual_detector = AnnotatedSampledVideoScreenDetector()

In [12]:
# Five-fold cross-validation over lecture videos: the unit being split is the lecture
# (its file name), not the individual frame. `splits` maps fold index -> Split.
Split = namedtuple('Split', ['train', 'valid'])
kf = KFold(n_splits=5, shuffle=True, random_state=123)
splits = {}
# NOTE(review): `all_lectures` still contains the lecture(s) in `test_lectures`, so they
# take part in model selection; exclude them here if they are meant to be held out.
for j, (train_indices, valid_indices) in enumerate(kf.split(all_lectures)):
    train_lectures = [all_lectures[index] for index in train_indices]
    valid_lectures = [all_lectures[index] for index in valid_indices]
    splits[j] = Split(train=train_lectures, valid=valid_lectures)

In [None]:
%%time
# Model selection
method_settings = list(product(*methods_values))
train_settings = list(product(*train_params_values))
df_all = pd.DataFrame(columns=all_params + ['iou', 'wrong_count', 'kfold_split'])

for train_setting in tqdm(train_settings):
    train_params_dict = dict(zip(train_params, train_setting))
    for i in splits.keys():
        train_lectures = splits[i].train
        valid_lectures = splits[i].valid
        valid_frames = [frame for lecture in valid_lectures for frame in all_frames_grouped_by_videos[lecture]]

        filtered_by = lambda name: any([lecture in str(name) for lecture in train_lectures + valid_lectures])  \
                        and 'frame' in str(name)
        split_by = lambda name: any([lecture in str(name) for lecture in valid_lectures])
        
        detector = FastAIScreenDetector(train_params=train_params_dict, methods=None, filtered_by=filtered_by,
                                    valid_func=split_by, device='cuda')
    
        detector.train()
        
        actuals = [actual_detector.detect(frame) for frame in valid_frames]
        sem_preds = detector.semantic_segmentation_batch(valid_frames)
        
        print(f"Iterating through {len(method_settings)} methods in split {i}.")    
        for i, method_setting in enumerate(method_settings):    
            preds = detector.post_processing_batch(sem_preds, valid_frames, dict(zip(method_params, method_setting)))
            wrong_count, ious, _ = evaluate(actuals, preds)
            
            iou_score = np.nanmean(ious)
            wrong_count = len(wrong_count)
            df_all.loc[len(df_all)] = train_setting + method_setting + (iou_score, wrong_count, i)

  0%|          | 0/18 [00:00<?, ?it/s]

epoch,train_loss,valid_loss,acc,dice,iou,time
0,0.211407,0.086621,0.983923,0.961771,0.936068,00:31
1,0.148264,0.080905,0.979986,0.955173,0.924167,00:34


epoch,train_loss,valid_loss,acc,dice,iou,time
0,0.0866,0.036479,0.990924,0.974885,0.960532,00:34
1,0.068892,0.024189,0.992091,0.977909,0.966236,00:34
2,0.05125,0.020661,0.992514,0.977682,0.965821,00:34


Iterating through 4 methods in split 0.


epoch,train_loss,valid_loss,acc,dice,iou,time
0,0.36537,0.119078,0.966666,0.949328,0.910215,00:37
1,0.194564,0.101981,0.970766,0.95544,0.920599,00:35


epoch,train_loss,valid_loss,acc,dice,iou,time
0,0.085014,0.106929,0.974942,0.962017,0.932093,00:34
1,0.056184,0.120329,0.97709,0.964966,0.936775,00:29
2,0.036207,0.102159,0.979539,0.96915,0.943131,00:31


Iterating through 4 methods in split 1.


epoch,train_loss,valid_loss,acc,dice,iou,time
0,0.234823,0.208547,0.899154,0.854123,0.763105,00:35
1,0.161659,0.176045,0.911114,0.86655,0.781339,00:35


epoch,train_loss,valid_loss,acc,dice,iou,time
0,0.098864,0.111806,0.951535,0.921805,0.862089,00:24
1,0.068176,0.06529,0.977381,0.962092,0.927939,00:35
2,0.052098,0.050593,0.982813,0.971357,0.945505,00:36


Iterating through 4 methods in split 2.


epoch,train_loss,valid_loss,acc,dice,iou,time
0,0.224163,0.21522,0.910513,0.784442,0.675071,00:34
1,0.145221,0.157777,0.931271,0.814783,0.721708,00:24


epoch,train_loss,valid_loss,acc,dice,iou,time
0,0.084081,0.072765,0.972372,0.926584,0.87994,00:34
1,0.056449,0.040046,0.98338,0.967537,0.938843,00:36
2,0.043531,0.037235,0.98577,0.969463,0.944124,00:36


Iterating through 4 methods in split 3.


epoch,train_loss,valid_loss,acc,dice,iou,time
0,0.279712,0.141419,0.950328,0.912075,0.858801,00:25
1,0.174126,0.143575,0.956999,0.92393,0.877563,00:33


epoch,train_loss,valid_loss,acc,dice,iou,time
0,0.092077,0.084875,0.977943,0.963991,0.938523,00:36
1,0.064768,0.083404,0.977731,0.963789,0.937051,00:37
2,0.048396,0.074836,0.97947,0.966167,0.941873,00:36


Iterating through 4 methods in split 4.


  6%|▌         | 1/18 [16:20<4:37:51, 980.70s/it]

epoch,train_loss,valid_loss,acc,dice,iou,time
0,0.367051,0.111482,0.966836,0.931437,0.88545,00:35
1,0.183671,0.059455,0.982982,0.960431,0.933319,00:36


epoch,train_loss,valid_loss,acc,dice,iou,time
0,0.099978,0.03955,0.985123,0.963252,0.938841,00:35
1,0.069814,0.081718,0.96422,0.923132,0.879791,00:34
2,0.049313,0.065888,0.970071,0.9323,0.892387,00:23


Iterating through 4 methods in split 0.


epoch,train_loss,valid_loss,acc,dice,iou,time
0,0.330892,0.080468,0.971694,0.957261,0.921172,00:34
1,0.158082,0.073792,0.9751,0.962954,0.930407,00:36


epoch,train_loss,valid_loss,acc,dice,iou,time
0,0.1066,0.103649,0.965769,0.948187,0.90921,00:36
1,0.065341,0.072424,0.978791,0.968642,0.941129,00:35
2,0.043673,0.069538,0.980019,0.970452,0.943617,00:35


Iterating through 4 methods in split 1.


epoch,train_loss,valid_loss,acc,dice,iou,time
0,0.217588,0.050906,0.982754,0.97073,0.946416,00:32
1,0.133282,0.057945,0.979197,0.965152,0.933717,00:35


epoch,train_loss,valid_loss,acc,dice,iou,time
0,0.08128,0.029799,0.987939,0.980126,0.961359,00:36
1,0.070545,0.031768,0.988454,0.981101,0.963164,00:36
2,0.055424,0.043914,0.98847,0.980756,0.962392,00:36


Iterating through 4 methods in split 2.


epoch,train_loss,valid_loss,acc,dice,iou,time
0,0.196091,0.082716,0.972919,0.929939,0.883134,00:37
1,0.10913,0.054096,0.97947,0.943049,0.905737,00:29


epoch,train_loss,valid_loss,acc,dice,iou,time
0,0.107513,0.062877,0.976974,0.937694,0.896561,00:29
1,0.081625,0.051251,0.983236,0.964803,0.935188,00:35
2,0.056299,0.040632,0.986073,0.971087,0.946357,00:35


Iterating through 4 methods in split 3.


epoch,train_loss,valid_loss,acc,dice,iou,time
0,0.272145,0.148448,0.957446,0.928688,0.879301,00:34
1,0.13208,0.090647,0.97293,0.956966,0.925,00:34


epoch,train_loss,valid_loss,acc,dice,iou,time
0,0.086217,0.156031,0.960982,0.93197,0.888801,00:36
1,0.055968,0.112545,0.980678,0.968451,0.945878,00:30
2,0.041755,0.111891,0.979927,0.967058,0.943318,00:25


Iterating through 4 methods in split 4.


 11%|█         | 2/18 [32:44<4:21:47, 981.70s/it]

epoch,train_loss,valid_loss,acc,dice,iou,time
0,0.309176,0.147955,0.968817,0.935812,0.889996,00:33
1,0.192883,0.15069,0.967709,0.93255,0.88385,00:34


epoch,train_loss,valid_loss,acc,dice,iou,time
0,0.103959,0.066952,0.969969,0.931801,0.889624,00:34
1,0.083892,0.040943,0.991575,0.975556,0.961886,00:33
2,0.078729,0.02849,0.988557,0.973137,0.957478,00:35
3,0.062176,0.036805,0.984327,0.958811,0.931975,00:33
4,0.046684,0.027801,0.987037,0.964308,0.941477,00:22
5,0.034862,0.020509,0.99077,0.971846,0.955047,00:34
6,0.029226,0.017959,0.992374,0.975785,0.962272,00:34


Iterating through 4 methods in split 0.


epoch,train_loss,valid_loss,acc,dice,iou,time
0,0.307615,0.123104,0.968666,0.952548,0.913787,00:34
1,0.189219,0.093392,0.971121,0.956658,0.921692,00:34


epoch,train_loss,valid_loss,acc,dice,iou,time
0,0.109034,0.10256,0.972752,0.958822,0.926147,00:34
1,0.072675,0.122512,0.972895,0.959842,0.926798,00:34
2,0.053946,0.093711,0.979443,0.969622,0.943185,00:25
3,0.039252,0.108704,0.970069,0.953817,0.91612,00:34
4,0.035584,0.076555,0.980589,0.970506,0.945206,00:33
5,0.026839,0.10318,0.980634,0.970212,0.946005,00:34
6,0.022269,0.104741,0.980496,0.970079,0.945463,00:34


Iterating through 4 methods in split 1.


epoch,train_loss,valid_loss,acc,dice,iou,time
0,0.22655,0.184355,0.910243,0.870259,0.788928,00:34
1,0.134162,0.075143,0.975703,0.958374,0.922096,00:27


epoch,train_loss,valid_loss,acc,dice,iou,time
0,0.061929,0.035252,0.988124,0.980816,0.963656,00:28
1,0.055666,0.052689,0.978365,0.963699,0.931276,00:35
2,0.046439,0.042693,0.985655,0.976197,0.954224,00:35
3,0.038279,0.071887,0.974408,0.957829,0.923268,00:36
4,0.029486,0.021349,0.991399,0.98603,0.972587,00:36
5,0.024997,0.036536,0.98714,0.977859,0.957053,00:35
6,0.023195,0.025564,0.991364,0.985715,0.971886,00:35


Iterating through 4 methods in split 2.


epoch,train_loss,valid_loss,acc,dice,iou,time
0,0.214178,0.204847,0.907058,0.772723,0.661454,00:28
1,0.143133,0.198358,0.91486,0.780981,0.674412,00:35


epoch,train_loss,valid_loss,acc,dice,iou,time
0,0.101841,0.062237,0.976157,0.952499,0.911823,00:34
1,0.080065,0.066112,0.971913,0.940701,0.893316,00:36
2,0.061724,0.123014,0.938239,0.842056,0.754394,00:36
3,0.048131,0.111035,0.956895,0.888918,0.821734,00:35
4,0.037914,0.093912,0.966843,0.914427,0.860688,00:35
5,0.030551,0.048524,0.979866,0.952703,0.917491,00:30
6,0.028555,0.052106,0.978174,0.949522,0.911606,00:24


Iterating through 4 methods in split 3.


epoch,train_loss,valid_loss,acc,dice,iou,time
0,0.701357,0.113829,0.970752,0.947714,0.917628,00:35
1,0.314019,0.140215,0.952198,0.915394,0.863836,00:34


epoch,train_loss,valid_loss,acc,dice,iou,time
0,0.119349,0.101196,0.9695,0.948063,0.914837,00:35
1,0.08378,0.13681,0.979875,0.967431,0.943654,00:35
2,0.075194,0.082264,0.97988,0.967425,0.943976,00:35
3,0.051828,0.138914,0.984055,0.973713,0.955094,00:34
4,0.035777,0.155154,0.98337,0.972664,0.952946,00:24
5,0.03121,0.133538,0.983958,0.973446,0.954571,00:35
6,0.025359,0.117522,0.98479,0.974717,0.956955,00:35


Iterating through 4 methods in split 4.


 17%|█▋        | 3/18 [1:00:03<4:54:42, 1178.83s/it]

epoch,train_loss,valid_loss,acc,dice,iou,time
0,0.408809,0.124927,0.969737,0.941303,0.899319,00:33
1,0.207885,0.058381,0.984721,0.963736,0.939684,00:33


epoch,train_loss,valid_loss,acc,dice,iou,time
0,0.084345,0.042564,0.986851,0.965988,0.943844,00:33
1,0.07848,0.04128,0.986103,0.96425,0.941811,00:33
2,0.068596,0.098506,0.960461,0.916878,0.871098,00:22
3,0.053127,0.045239,0.979695,0.948866,0.916349,00:33
4,0.040211,0.018113,0.993691,0.980023,0.970332,00:33
5,0.030022,0.015225,0.994315,0.981264,0.972738,00:33
6,0.028,0.014156,0.994658,0.981699,0.973568,00:33


Iterating through 4 methods in split 0.


epoch,train_loss,valid_loss,acc,dice,iou,time
0,0.414199,0.081822,0.973564,0.960574,0.924706,00:34
1,0.198976,0.085514,0.972026,0.957494,0.922231,00:34


epoch,train_loss,valid_loss,acc,dice,iou,time
0,0.079831,0.077482,0.975205,0.963924,0.933397,00:34
1,0.066038,0.104854,0.970463,0.955716,0.920669,00:35
2,0.053621,0.115971,0.967001,0.946041,0.902867,00:23
3,0.050652,0.129413,0.968429,0.952101,0.914899,00:30
4,0.03691,0.084609,0.978515,0.967733,0.940459,00:34
5,0.027802,0.098555,0.978069,0.966961,0.939148,00:34
6,0.02277,0.096723,0.978536,0.967657,0.940207,00:34


Iterating through 4 methods in split 1.


epoch,train_loss,valid_loss,acc,dice,iou,time
0,0.374784,0.107381,0.966532,0.937739,0.895567,00:34
1,0.182787,0.079553,0.970766,0.950553,0.90697,00:35


epoch,train_loss,valid_loss,acc,dice,iou,time
0,0.070252,0.046652,0.985159,0.975891,0.954754,00:34
1,0.073216,0.053888,0.980595,0.967487,0.937392,00:36
2,0.061185,0.157929,0.924661,0.885608,0.8062,00:29
3,0.052368,0.036168,0.987779,0.980018,0.960901,00:24
4,0.037734,0.023808,0.992136,0.987262,0.974924,00:35
5,0.027114,0.022912,0.992305,0.987451,0.975258,00:35
6,0.023867,0.020736,0.99258,0.987896,0.97613,00:35


Iterating through 4 methods in split 2.


epoch,train_loss,valid_loss,acc,dice,iou,time
0,0.234586,0.081501,0.976496,0.936812,0.894609,00:35
1,0.130349,0.04976,0.982999,0.950376,0.918481,00:35


epoch,train_loss,valid_loss,acc,dice,iou,time
0,0.073461,0.043488,0.984531,0.954027,0.925277,00:35
1,0.069348,0.134814,0.96934,0.923396,0.874922,00:35
2,0.061942,0.069223,0.969841,0.9394,0.893145,00:34
3,0.047111,0.049407,0.982557,0.962369,0.931677,00:24
4,0.03452,0.03406,0.986726,0.972279,0.948449,00:32
5,0.025894,0.03621,0.986274,0.972061,0.947824,00:35
6,0.021599,0.036028,0.986247,0.972209,0.947999,00:36


Iterating through 4 methods in split 3.


epoch,train_loss,valid_loss,acc,dice,iou,time
0,0.599163,0.107141,0.966564,0.945203,0.906206,00:35
1,0.26193,0.112564,0.964079,0.941152,0.898811,00:36


epoch,train_loss,valid_loss,acc,dice,iou,time
0,0.085882,0.071495,0.978245,0.963972,0.938832,00:36
1,0.073359,0.188449,0.951846,0.914971,0.862763,00:36
2,0.061958,0.193305,0.971747,0.952854,0.919909,00:35
3,0.044997,0.071026,0.983847,0.973557,0.954874,00:33
4,0.038309,0.086603,0.977218,0.961598,0.935331,00:30
5,0.029626,0.105346,0.982507,0.971111,0.95063,00:36
6,0.023348,0.084467,0.982907,0.971848,0.951777,00:35


Iterating through 4 methods in split 4.


 22%|██▏       | 4/18 [1:27:35<5:08:12, 1320.92s/it]

epoch,train_loss,valid_loss,acc,dice,iou,time
0,0.415668,0.122708,0.967281,0.928659,0.885526,00:34
1,0.217403,0.056162,0.982755,0.957098,0.928063,00:34


epoch,train_loss,valid_loss,acc,dice,iou,time
0,0.089771,0.049531,0.981053,0.953541,0.921951,00:35
1,0.068567,0.08222,0.971825,0.934459,0.894902,00:35
2,0.065036,0.08981,0.961539,0.928027,0.881383,00:35
3,0.050668,0.018385,0.993495,0.989765,0.979798,00:35
4,0.039685,0.028705,0.988512,0.966389,0.945499,00:23
5,0.033758,0.214409,0.958276,0.924319,0.879159,00:35
6,0.027733,0.046361,0.979146,0.95743,0.925354,00:35
7,0.024338,0.042359,0.978682,0.946798,0.915129,00:35
8,0.022866,0.018279,0.992176,0.984362,0.969938,00:34
9,0.020612,0.014872,0.993794,0.988101,0.976761,00:34


epoch,train_loss,valid_loss,acc,dice,iou,time
0,0.077471,0.095997,0.974299,0.961932,0.930617,00:34
1,0.063517,0.113826,0.973747,0.961255,0.930147,00:24
2,0.050023,0.085432,0.977092,0.96572,0.937373,00:32
3,0.051585,0.081299,0.977416,0.965434,0.936733,00:35
4,0.038435,0.079536,0.977522,0.966225,0.939179,00:35
5,0.028787,0.083515,0.980231,0.970508,0.9454,00:36
6,0.02297,0.097134,0.97933,0.968525,0.942117,00:33
7,0.020163,0.093103,0.98013,0.970354,0.94485,00:34
8,0.017968,0.092762,0.980714,0.971541,0.947269,00:34
9,0.016196,0.093268,0.980849,0.971511,0.947561,00:34


Iterating through 4 methods in split 1.


epoch,train_loss,valid_loss,acc,dice,iou,time
0,0.231398,0.194647,0.906323,0.85872,0.764922,00:29
1,0.14729,0.128983,0.938628,0.901864,0.832601,00:26


epoch,train_loss,valid_loss,acc,dice,iou,time
0,0.093406,0.131684,0.944651,0.91239,0.847863,00:36
1,0.080159,0.084837,0.957876,0.93093,0.876588,00:35
2,0.057643,0.076911,0.962484,0.937567,0.887359,00:35
3,0.055348,0.09003,0.960157,0.934535,0.881524,00:36
4,0.045682,0.024406,0.991485,0.986636,0.973813,00:35
5,0.035274,0.051104,0.981914,0.969765,0.943337,00:34
6,0.036191,0.061745,0.973909,0.955378,0.9172,00:35
7,0.027315,0.054866,0.97607,0.95939,0.924765,00:36
8,0.021993,0.036666,0.987201,0.978202,0.957783,00:36
9,0.018926,0.045467,0.981334,0.967989,0.939489,00:25


Iterating through 4 methods in split 2.


epoch,train_loss,valid_loss,acc,dice,iou,time
0,0.309795,0.175546,0.918101,0.805996,0.699412,00:35
1,0.18531,0.178997,0.920754,0.810073,0.706429,00:36


epoch,train_loss,valid_loss,acc,dice,iou,time
0,0.10904,0.197811,0.914828,0.791057,0.683657,00:36
1,0.090498,0.188375,0.924364,0.79,0.694048,00:35
2,0.072123,0.079672,0.980278,0.94486,0.9107,00:35
3,0.061268,0.039854,0.98386,0.953057,0.925347,00:36
4,0.049749,0.063792,0.976903,0.95442,0.91556,00:35
5,0.03628,0.038014,0.98522,0.971119,0.945753,00:37
6,0.029959,0.036123,0.986422,0.973164,0.949022,00:35
7,0.023726,0.033772,0.987848,0.977189,0.956902,00:25
8,0.020329,0.030711,0.98955,0.980008,0.962331,00:37
9,0.017597,0.032051,0.989509,0.980174,0.96268,00:36


Iterating through 4 methods in split 3.


epoch,train_loss,valid_loss,acc,dice,iou,time
0,0.291019,0.120003,0.968584,0.943725,0.911143,00:36
1,0.178459,0.129014,0.959528,0.928713,0.885144,00:35


epoch,train_loss,valid_loss,acc,dice,iou,time


In [None]:
# Persist the cross-validation results (without the row index) so they can be reloaded from CSV.
df_all.to_csv(path_or_buf="model_selection.csv", index=False)

In [None]:
# Reload the saved cross-validation results. Non-scalar parameter values (lists, slices)
# were written with their string representation and come back as strings.
df_all = pd.read_csv(filepath_or_buffer="model_selection.csv")

In [None]:
# Select the best configuration: fewest wrongly detected frames first, ties broken by
# the highest IoU, both averaged over the cross-validation splits of each configuration.
unhashable_columns = ['frozen_lr', 'unfrozen_lr', 'base_factors', 'erode_dilate_factors']
# Lists and slices cannot be used as group keys, so group on their string forms.
df_all[unhashable_columns] = df_all[unhashable_columns].astype(str)
df_all['wrong_count'] = df_all['wrong_count'].astype(int)

# Average only the score columns: averaging the rest (e.g. `kfold_split`) is meaningless,
# and non-numeric leftovers are no longer dropped silently by pandas >= 2.0.
mean_scores = df_all.groupby(train_params + method_params)[['wrong_count', 'iou']].mean()
# IoU is higher-is-better, so it has to be sorted in descending order.
best_params = mean_scores.sort_values(by=['wrong_count', 'iou'], ascending=[True, False]).index[0]


def _to_python_scalar(value):
    """Convert a NumPy scalar (np.int64, np.float64, np.bool_, ...) to the built-in Python type."""
    return value.item() if isinstance(value, np.generic) else value


def _parse_param(text):
    """Parse a parameter value that was stored with ``str()``.

    Accepts any Python literal (numbers, lists, None, ...) and ``slice(...)`` calls whose
    arguments are literals. Used instead of ``eval`` so that an edited CSV file cannot
    execute arbitrary code.

    Raises:
        SyntaxError: if ``text`` is not a Python expression.
        ValueError: if the expression is neither a literal nor a literal ``slice`` call.
    """
    node = ast.parse(text, mode='eval').body
    if (isinstance(node, ast.Call) and isinstance(node.func, ast.Name) and node.func.id == 'slice'
            and not node.keywords):
        return slice(*(ast.literal_eval(argument) for argument in node.args))
    return ast.literal_eval(node)


best_params = tuple(_to_python_scalar(value) for value in best_params)

# The group keys are `train_params + method_params`, in that order.
best_train_params_dict = dict(zip(train_params, best_params[:len(train_params)]))
best_methods = dict(zip(method_params, best_params[len(train_params):]))
for name in ['frozen_lr', 'unfrozen_lr']:
    best_train_params_dict[name] = _parse_param(best_train_params_dict[name])
for name in ['base_factors', 'erode_dilate_factors']:
    best_methods[name] = _parse_param(best_methods[name])

