In [1]:
%%time
from functools import partial
from collections import defaultdict
import numpy as np
import pandas as pd
import scipy
from sklearn.model_selection import KFold
from tqdm.notebook import tqdm
from itertools import product
import time

from video699.screen.semantic_segmentation.fastai_detector import *
from video699.screen.semantic_segmentation.common import *
from video699.screen.semantic_segmentation.postprocessing import *
from video699.screen.semantic_segmentation.evaluation import *

CPU times: user 5.9 s, sys: 1.14 s, total: 7.03 s
Wall time: 19.9 s


In [2]:
# Probe a default detector only to learn the parameter names it expects.
# The candidate lists below are paired with these names *positionally*
# (via zip), so they must follow the same order as the detector's keys.
detector = FastAIScreenDetector()
method_params = list(detector.methods.keys())
train_params = list(detector.train_params.keys())
all_params = train_params + method_params

# --- Post-processing (method) search space: one list of candidates per parameter ---
base = [True]
base_lower_bounds = [7, 10]
base_upper_bounds = [30, 40, 50, 60, 70]
base_factors = [[0.1, 0.01]]  # a single candidate, which is itself a list

erode_dilate = [True]
erode_dilate_lower_bounds = [5]
erode_dilate_upper_bounds = [40]
erode_dilate_factors = [[0.1, 0.01]]  # a single candidate, which is itself a list
erode_dilate_iterations = [40, 100]

ratio_split = [True]
ratio_split_lower_bounds = [0.7, 0.9]
ratio_split_upper_bounds = [1.5]

methods_values = [
    base, erode_dilate, ratio_split,
    base_lower_bounds, base_upper_bounds, base_factors,
    erode_dilate_lower_bounds, erode_dilate_upper_bounds, erode_dilate_factors,
    erode_dilate_iterations,
    ratio_split_lower_bounds, ratio_split_upper_bounds,
]

# --- Training search space ---
batch_size = [8]
resize_factor = [2]
frozen_epochs = [2, 6, 9]
unfrozen_epochs = [3, 7, 10]
frozen_lr = [1e-3]
unfrozen_lr = [slice(1e-4, 2e-4)]  # discriminative learning-rate range

train_params_values = [batch_size, resize_factor, frozen_epochs, unfrozen_epochs, frozen_lr, unfrozen_lr]

# zip() would silently drop trailing parameters on a length mismatch, so fail loudly instead.
assert len(methods_values) == len(method_params), (len(methods_values), method_params)
assert len(train_params_values) == len(train_params), (len(train_params_values), train_params)

In [3]:
# Cartesian products of the candidate lists: one tuple per configuration,
# with values ordered like method_params / train_params respectively.
method_settings = list(product(*methods_values))
train_settings = list(product(*train_params_values))

# Read the frames of every video exactly once; the flat list and the
# per-lecture mapping below share the same frame objects.
frames_per_video = [(video.filename, list(video)) for video in ALL_VIDEOS]
all_lectures = [filename for filename, _ in frames_per_video]
all_frames = [frame for _, frames in frames_per_video for frame in frames]
all_frames_grouped_by_videos = dict(frames_per_video)

# Hard-coded hold-out lecture for ad-hoc inspection; the cross-validation
# cell below builds its own test folds and does not use these two names.
test_lectures = ['PB069-D2-20140305.mp4']
test_frames = [frame for lecture in test_lectures for frame in all_frames_grouped_by_videos[lecture]]

# Ground-truth screens come from the manual annotations.
actual_detector = AnnotatedSampledVideoScreenDetector()

In [4]:
%%time
kf = KFold(n_splits=5, shuffle=True, random_state=123)
df_best_models = pd.DataFrame(columns=all_params + ['iou', 'wrong_count'])
for i, split in tqdm(enumerate(kf.split(all_lectures))):
    other_lectures = [all_lectures[index] for index in split[0]]
    test_lectures = [all_lectures[index] for index in split[1]]
    
    # Model selection
    df_all = pd.DataFrame(columns=all_params + ['iou', 'wrong_count', 'kfold_split'])
    
    for train_setting in tqdm(train_settings):
        train_params_dict = dict(zip(train_params, train_setting))
        for j, split in enumerate(kf.split(other_lectures)):
            train_lectures = [other_lectures[index] for index in split[0]]
            valid_lectures = [other_lectures[index] for index in split[1]]
            valid_frames = [frame for lecture in valid_lectures for frame in all_frames_grouped_by_videos[lecture]]

            filtered_by = lambda name: any([lecture in str(name) for lecture in train_lectures + valid_lectures])  \
                            and 'frame' in str(name)
            split_by = lambda name: any([lecture in str(name) for lecture in valid_lectures])
            
            detector = FastAIScreenDetector(train_params=train_params_dict, methods=None, filtered_by=filtered_by,
                                        valid_func=split_by, progressbar=False, device='cuda')
        
            detector.train()
            
            actuals = [actual_detector.detect(frame) for frame in valid_frames]
            sem_preds = detector.semantic_segmentation_batch(valid_frames)
            
            for i, method_setting in enumerate(method_settings):    
                preds = detector.post_processing_batch(sem_preds, valid_frames, dict(zip(method_params, method_setting)))
                wrong_count, ious, _ = evaluate(actuals, preds)
                
                iou_score = np.nanmean(ious)
                wrong_count = len(wrong_count)
                df_all.loc[len(df_all)] = train_setting + method_setting + (iou_score, wrong_count, j)
    
    unhashable_columns = ['frozen_lr', 'unfrozen_lr', 'base_factors', 'erode_dilate_factors']
    df_all[unhashable_columns] = df_all[unhashable_columns].astype(str)
    df_all['wrong_count'] = df_all['wrong_count'].astype(int)
    
    best_params = df_all.groupby(train_params + method_params).mean().sort_values(by=['wrong_count', 'iou']).iloc[0].name
    converted_params = []
    for i, par in enumerate(best_params):
        if isinstance(par, np.int64) or isinstance(par, np.float64):
            converted_params.append(par.item())
        else:
            converted_params.append(par)
    best_params = tuple(converted_params)
    
    best_methods = dict(zip(method_params, best_params[-len(method_params):]))
    best_train_params_dict = dict(zip(train_params, best_params[:len(train_params)]))
    best_train_params_dict['frozen_lr'] = eval(best_train_params_dict['frozen_lr'])
    best_train_params_dict['unfrozen_lr'] = eval(best_train_params_dict['unfrozen_lr'])
    best_methods['base_factors'] = eval(best_methods['base_factors'])
    best_methods['erode_dilate_factors'] = eval(best_methods['erode_dilate_factors'])
    
    
    filtered_by = lambda name: 'frame' in str(name)
    split_by = lambda name: any([lecture in str(name) for lecture in test_lectures])
    
    best_detector = FastAIScreenDetector(train_params=best_train_params_dict, methods=best_methods, 
                                         filtered_by=filtered_by, valid_func=split_by, progressbar=True, device='cuda')
    best_detector.train()

    actuals = [actual_detector.detect(frame) for frame in valid_frames]
    preds = [best_detector.detect(frame) for frame in valid_frames]
    wrong_count, iou, _ = evaluate(actuals, preds)
    iou_score = np.nanmean(ious)
    wrong_count = len(wrong_count)
    df_best_models.loc[len(df_best_models)] = train_setting + method_setting + (iou_score, wrong_count)
    df_best_models.to_csv('cross_validation_results.csv', index=False)
df_best_models.to_csv('cross_validation_results.csv')

0it [00:00, ?it/s]
  0%|          | 0/9 [00:00<?, ?it/s][A

█


 11%|█         | 1/9 [07:40<1:01:24, 460.55s/it][A

█


 22%|██▏       | 2/9 [19:33<1:02:34, 536.30s/it][A

█


 33%|███▎      | 3/9 [34:46<1:04:55, 649.25s/it][A

█


 44%|████▍     | 4/9 [46:45<55:51, 670.21s/it]  [A

█


 56%|█████▌    | 5/9 [1:03:21<51:12, 768.01s/it][A

█


 67%|██████▋   | 6/9 [1:23:22<44:53, 897.76s/it][A

█


 78%|███████▊  | 7/9 [1:38:44<30:09, 904.98s/it][A

█


 89%|████████▉ | 8/9 [1:58:34<16:30, 990.68s/it][A

█


100%|██████████| 9/9 [2:21:31<00:00, 943.48s/it] [A


epoch,train_loss,valid_loss,acc,dice,iou,time
0,0.276048,0.18829,0.9517,0.883129,0.815287,00:19
1,0.213969,0.185705,0.936252,0.806059,0.734579,00:16


epoch,train_loss,valid_loss,acc,dice,iou,time
0,0.134757,0.145559,0.947237,0.849309,0.77971,00:17
1,0.082689,0.110258,0.957836,0.89805,0.834278,00:17
2,0.058024,0.059513,0.976537,0.94846,0.90854,00:16


1it [2:23:03, 8583.88s/it]
  0%|          | 0/9 [00:00<?, ?it/s][A

█


 11%|█         | 1/9 [07:18<58:24, 438.04s/it][A

█


 22%|██▏       | 2/9 [19:08<1:00:37, 519.71s/it][A

█


 33%|███▎      | 3/9 [34:17<1:03:38, 636.49s/it][A

█


 44%|████▍     | 4/9 [45:58<54:39, 655.97s/it]  [A

█


 56%|█████▌    | 5/9 [1:02:11<50:04, 751.07s/it][A

█


 67%|██████▋   | 6/9 [1:21:47<43:55, 878.39s/it][A

█


 78%|███████▊  | 7/9 [1:36:40<29:26, 883.02s/it][A

█


 89%|████████▉ | 8/9 [1:56:09<16:08, 968.55s/it][A

█


100%|██████████| 9/9 [2:19:07<00:00, 927.47s/it] [A


epoch,train_loss,valid_loss,acc,dice,iou,time
0,0.288883,0.176154,0.928801,0.886271,0.798361,00:19
1,0.148883,0.121931,0.951135,0.916517,0.848964,00:16
2,0.125821,0.046052,0.982866,0.969636,0.941763,00:16
3,0.076314,0.023623,0.98943,0.980982,0.962797,00:16
4,0.0507,0.01896,0.991736,0.985081,0.970684,00:16
5,0.033555,0.018326,0.992159,0.985309,0.971161,00:16


epoch,train_loss,valid_loss,acc,dice,iou,time
0,0.0203,0.022165,0.990598,0.981637,0.964339,00:16
1,0.02692,0.021192,0.991865,0.985455,0.97145,00:16
2,0.02391,0.017667,0.992418,0.986198,0.972853,00:16
3,0.027716,0.048512,0.986876,0.974463,0.950902,00:16
4,0.021645,0.018773,0.991926,0.98551,0.97151,00:16
5,0.025183,0.019446,0.992341,0.98532,0.971183,00:16
6,0.024716,0.014681,0.993899,0.988805,0.977904,00:16
7,0.016468,0.016065,0.993312,0.987918,0.976162,00:16
8,0.013915,0.016104,0.993485,0.988132,0.976583,00:16
9,0.010914,0.01776,0.992915,0.987258,0.974879,00:16


2it [4:46:48, 8596.10s/it]
  0%|          | 0/9 [00:00<?, ?it/s][A

█


 11%|█         | 1/9 [07:37<1:01:03, 457.90s/it][A

█


 22%|██▏       | 2/9 [20:02<1:03:27, 543.87s/it][A

█


 33%|███▎      | 3/9 [35:54<1:06:37, 666.27s/it][A

█


 44%|████▍     | 4/9 [48:12<57:19, 687.82s/it]  [A

█


 56%|█████▌    | 5/9 [1:05:23<52:43, 790.83s/it][A

█


 67%|██████▋   | 6/9 [1:26:10<46:22, 927.66s/it][A

█


 78%|███████▊  | 7/9 [1:42:10<31:14, 937.40s/it][A

█


 89%|████████▉ | 8/9 [2:02:54<17:09, 1029.44s/it][A

█


100%|██████████| 9/9 [2:27:19<00:00, 982.19s/it] [A


epoch,train_loss,valid_loss,acc,dice,iou,time
0,0.310782,0.121355,0.948854,0.895898,0.817388,00:17
1,0.161234,0.043402,0.983315,0.967219,0.938411,00:17
2,0.093186,0.038647,0.98738,0.975257,0.95551,00:17
3,0.05686,0.029298,0.988633,0.977736,0.958718,00:17
4,0.039098,0.018202,0.992794,0.985662,0.971914,00:17
5,0.028582,0.018146,0.99291,0.985873,0.972282,00:17


epoch,train_loss,valid_loss,acc,dice,iou,time
0,0.024208,0.02833,0.989141,0.978985,0.961118,00:17
1,0.018554,0.015262,0.993736,0.987878,0.976115,00:17
2,0.017293,0.026035,0.991016,0.982874,0.967192,00:17


3it [7:16:48, 8717.29s/it]
  0%|          | 0/9 [00:00<?, ?it/s][A

█


 11%|█         | 1/9 [07:49<1:02:37, 469.74s/it][A

█


 22%|██▏       | 2/9 [20:23<1:04:44, 554.86s/it][A

█


 33%|███▎      | 3/9 [36:30<1:07:51, 678.63s/it][A

█


 44%|████▍     | 4/9 [49:01<58:21, 700.39s/it]  [A

█


 56%|█████▌    | 5/9 [1:06:19<53:25, 801.45s/it][A

█


 67%|██████▋   | 6/9 [1:27:11<46:49, 936.61s/it][A

█


 78%|███████▊  | 7/9 [1:43:12<31:28, 944.03s/it][A

█


 89%|████████▉ | 8/9 [2:03:57<17:14, 1034.48s/it][A

█


100%|██████████| 9/9 [2:28:18<00:00, 988.70s/it] [A


epoch,train_loss,valid_loss,acc,dice,iou,time
0,0.267262,0.066469,0.973211,0.950454,0.908248,00:17
1,0.13413,0.049202,0.981159,0.941849,0.90531,00:17
2,0.288057,0.05149,0.985224,0.957237,0.932019,00:17
3,0.219998,0.033496,0.988991,0.977147,0.955663,00:17
4,0.130161,0.021372,0.991697,0.982628,0.967303,00:17
5,0.084979,0.021068,0.991715,0.982574,0.967209,00:17


epoch,train_loss,valid_loss,acc,dice,iou,time
0,0.032645,0.01513,0.994044,0.987929,0.976215,00:17
1,0.036641,0.012872,0.994928,0.989748,0.979716,00:17
2,0.029573,0.010858,0.996092,0.991987,0.984122,00:17
3,0.02815,0.011825,0.99507,0.989575,0.979448,00:17
4,0.020419,0.008886,0.996153,0.992134,0.984419,00:17
5,0.015815,0.008707,0.996552,0.993022,0.986162,00:17
6,0.013224,0.009043,0.995949,0.991798,0.983749,00:17
7,0.013315,0.009153,0.996109,0.991875,0.984061,00:17
8,0.010423,0.010462,0.99563,0.990688,0.98209,00:17
9,0.010114,0.014815,0.994332,0.98825,0.978793,00:17


4it [9:49:53, 8857.61s/it]
  0%|          | 0/9 [00:00<?, ?it/s][A

█


 11%|█         | 1/9 [07:44<1:01:55, 464.42s/it][A

█


 22%|██▏       | 2/9 [20:17<1:04:16, 550.97s/it][A

█


 33%|███▎      | 3/9 [36:22<1:07:31, 675.17s/it][A

█


 44%|████▍     | 4/9 [48:46<57:58, 695.78s/it]  [A

█


 67%|██████▋   | 6/9 [1:26:42<46:33, 931.28s/it][A

█


 78%|███████▊  | 7/9 [1:42:31<31:13, 936.63s/it][A

█


 89%|████████▉ | 8/9 [2:03:06<17:06, 1026.35s/it][A

█


100%|██████████| 9/9 [2:27:14<00:00, 981.64s/it] [A


epoch,train_loss,valid_loss,acc,dice,iou,time
0,0.337639,0.16616,0.904206,0.834766,0.759246,00:17
1,0.191064,0.079999,0.979892,0.951992,0.91981,00:17


epoch,train_loss,valid_loss,acc,dice,iou,time
0,0.084899,0.053519,0.983677,0.958577,0.932177,00:17
1,0.060548,0.032533,0.989308,0.96668,0.94776,00:17
2,0.047564,0.064651,0.980075,0.955695,0.926819,00:17
3,0.042504,0.02239,0.990919,0.967391,0.94946,00:17
4,0.034587,0.086912,0.977897,0.95031,0.917816,00:17
5,0.025196,0.023109,0.992688,0.972955,0.959688,00:17
6,0.02122,0.017067,0.99443,0.977163,0.96789,00:17
7,0.016796,0.011445,0.995355,0.978281,0.970058,00:17
8,0.014509,0.010906,0.995517,0.978336,0.970179,00:17
9,0.013038,0.010151,0.995869,0.97912,0.971711,00:17


5it [12:20:44, 8888.85s/it]

CPU times: user 8h 44min 5s, sys: 2h 14min 28s, total: 10h 58min 33s
Wall time: 12h 20min 44s



