In [1]:
%%time
from functools import partial
from collections import defaultdict
import numpy as np
import pandas as pd
import scipy
from sklearn.model_selection import KFold
from tqdm.notebook import tqdm
from itertools import product
import time

from video699.screen.semantic_segmentation.fastai_detector import *
from video699.screen.semantic_segmentation.common import *
from video699.screen.semantic_segmentation.postprocessing import *
from video699.screen.semantic_segmentation.evaluation import *

CPU times: user 5.76 s, sys: 860 ms, total: 6.62 s
Wall time: 6.09 s


In [2]:
# Parameter names are read from a default detector. Every value grid below is matched to these names
# purely by position (via `zip`), so the lists must appear in exactly the same order as the keys.
detector = FastAIScreenDetector()
method_params = list(detector.methods.keys())
train_params = list(detector.train_params.keys())
all_params = train_params + method_params

# Post-processing grid: one list of candidate values per post-processing parameter.
base = [True]
base_lower_bounds = [7]
base_upper_bounds = [70]
base_factors = [[0.1, 0.01]]

erode_dilate = [True]
erode_dilate_lower_bounds = [7]
erode_dilate_upper_bounds = [70]
erode_dilate_factors = [[0.1, 0.01]]
erode_dilate_iterations = [40, 100]

ratio_split = [True]
ratio_split_lower_bounds = [0.7, 0.9]
ratio_split_upper_bounds = [1.5]

# Order must match `method_params`; the grid has 2 (iterations) x 2 (ratio lower bounds) = 4 settings.
methods_values = [
    base, erode_dilate, ratio_split,
    base_lower_bounds, base_upper_bounds, base_factors,
    erode_dilate_lower_bounds, erode_dilate_upper_bounds, erode_dilate_factors, erode_dilate_iterations,
    ratio_split_lower_bounds, ratio_split_upper_bounds,
]

# Training grid: 3 (frozen epochs) x 3 (unfrozen epochs) x 2 (frozen lr) = 18 settings.
batch_size = [8]
resize_factor = [2]
frozen_epochs = [2, 6, 9]
unfrozen_epochs = [3, 7, 10]
frozen_lr = [1e-3, 1e-4]
unfrozen_lr = [slice(1e-4, 2e-4)]

# Order must match `train_params`.
train_params_values = [batch_size, resize_factor, frozen_epochs, unfrozen_epochs, frozen_lr, unfrozen_lr]

# `zip(names, setting)` silently drops trailing items on a length mismatch, which would pair values with
# the wrong parameter names; fail fast instead.
assert len(methods_values) == len(method_params), (len(methods_values), method_params)
assert len(train_params_values) == len(train_params), (len(train_params_values), train_params)

In [3]:
# Cartesian products of the configured grids: each element is one tuple of parameter values,
# ordered like `method_params` / `train_params` respectively.
method_settings = [*product(*methods_values)]
train_settings = [*product(*train_params_values)]

# Lecture filenames, every frame of every lecture, and the same frames grouped per lecture filename.
all_lectures = [video.filename for video in ALL_VIDEOS]
all_frames = [frame for video in ALL_VIDEOS for frame in video]
all_frames_grouped_by_videos = {video.filename: list(video) for video in ALL_VIDEOS}

# A single default hold-out lecture and its frames (cross-validation cells may rebind these names).
test_lectures = ['PB069-D2-20140305.mp4']
test_frames = []
for lecture in test_lectures:
    test_frames.extend(all_frames_grouped_by_videos[lecture])

# Detector that returns the annotated screens of a frame; used as ground truth in evaluation.
actual_detector = AnnotatedSampledVideoScreenDetector()

In [4]:
%%time
kf = KFold(n_splits=5, shuffle=True, random_state=123)
df_best_models = pd.DataFrame(columns=all_params + ['iou', 'wrong_count'])
for i, split in tqdm(enumerate(kf.split(all_lectures))):
    other_lectures = [all_lectures[index] for index in split[0]]
    test_lectures = [all_lectures[index] for index in split[1]]
    
    # Model selection
    df_all = pd.DataFrame(columns=all_params + ['iou', 'wrong_count', 'kfold_split'])
    
    for train_setting in tqdm(train_settings):
        train_params_dict = dict(zip(train_params, train_setting))
        for j, split in enumerate(kf.split(other_lectures)):
            train_lectures = [other_lectures[index] for index in split[0]]
            valid_lectures = [other_lectures[index] for index in split[1]]
            valid_frames = [frame for lecture in valid_lectures for frame in all_frames_grouped_by_videos[lecture]]

            filtered_by = lambda name: any([lecture in str(name) for lecture in train_lectures + valid_lectures])  \
                            and 'frame' in str(name)
            split_by = lambda name: any([lecture in str(name) for lecture in valid_lectures])
            
            detector = FastAIScreenDetector(train_params=train_params_dict, methods=None, filtered_by=filtered_by,
                                        valid_func=split_by, progressbar=False, device='cuda')
        
            detector.train()
            
            actuals = [actual_detector.detect(frame) for frame in valid_frames]
            sem_preds = detector.semantic_segmentation_batch(valid_frames)
            
            for i, method_setting in enumerate(method_settings):    
                preds = detector.post_processing_batch(sem_preds, valid_frames, dict(zip(method_params, method_setting)))
                wrong_count, ious, _ = evaluate(actuals, preds)
                
                iou_score = np.nanmean(ious)
                wrong_count = len(wrong_count)
                df_all.loc[len(df_all)] = train_setting + method_setting + (iou_score, wrong_count, j)
    
    unhashable_columns = ['frozen_lr', 'unfrozen_lr', 'base_factors', 'erode_dilate_factors']
    df_all[unhashable_columns] = df_all[unhashable_columns].astype(str)
    df_all['wrong_count'] = df_all['wrong_count'].astype(int)
    
    best_params = df_all.groupby(train_params + method_params).mean().sort_values(by=['wrong_count', 'iou']).iloc[0].name
    converted_params = []
    for i, par in enumerate(best_params):
        if isinstance(par, np.int64) or isinstance(par, np.float64):
            converted_params.append(par.item())
        else:
            converted_params.append(par)
    best_params = tuple(converted_params)
    
    best_methods = dict(zip(method_params, best_params[-len(method_params):]))
    best_train_params_dict = dict(zip(train_params, best_params[:len(train_params)]))
    best_train_params_dict['frozen_lr'] = eval(best_train_params_dict['frozen_lr'])
    best_train_params_dict['unfrozen_lr'] = eval(best_train_params_dict['unfrozen_lr'])
    best_methods['base_factors'] = eval(best_methods['base_factors'])
    best_methods['erode_dilate_factors'] = eval(best_methods['erode_dilate_factors'])
    
    
    filtered_by = lambda name: 'frame' in str(name)
    split_by = lambda name: any([lecture in str(name) for lecture in test_lectures])
    
    best_detector = FastAIScreenDetector(train_params=best_train_params_dict, methods=best_methods, 
                                         filtered_by=filtered_by, valid_func=split_by, progressbar=True, device='cuda')
    best_detector.train()

    actuals = [actual_detector.detect(frame) for frame in valid_frames]
    preds = [best_detector.detect(frame) for frame in valid_frames]
    wrong_count, iou, _ = evaluate(actuals, preds)
    iou_score = np.nanmean(ious)
    wrong_count = len(wrong_count)
    df_best_models.loc[len(df_best_models)] = train_setting + method_setting + (iou_score, wrong_count)
    df_best_models.to_csv('cross_validation_results.csv', index=False)
df_best_models.to_csv('cross_validation_results.csv')

0it [00:00, ?it/s]
  0%|          | 0/18 [00:00<?, ?it/s][A

█


  6%|▌         | 1/18 [06:37<1:52:41, 397.76s/it][A

█


 11%|█         | 2/18 [12:53<1:44:20, 391.30s/it][A

█


 28%|██▊       | 5/18 [48:43<2:14:38, 621.40s/it][A

█


 33%|███▎      | 6/18 [1:02:58<2:18:18, 691.56s/it][A

█


 39%|███▉      | 7/18 [1:13:41<2:04:07, 677.07s/it][A

█


 44%|████▍     | 8/18 [1:24:25<1:51:09, 666.95s/it][A

█


 50%|█████     | 9/18 [1:39:40<1:51:12, 741.41s/it][A

█


 56%|█████▌    | 10/18 [1:54:56<1:45:50, 793.87s/it][A

█


 61%|██████    | 11/18 [2:13:39<1:44:07, 892.56s/it][A

█


 67%|██████▋   | 12/18 [2:32:25<1:36:16, 962.75s/it][A

█


 72%|███████▏  | 13/18 [2:46:32<1:17:19, 927.94s/it][A

█


 78%|███████▊  | 14/18 [3:00:39<1:00:15, 903.76s/it][A

█


 83%|████████▎ | 15/18 [3:19:18<48:24, 968.26s/it]  [A

█


 89%|████████▉ | 16/18 [3:37:55<33:45, 1012.83s/it][A

█


 94%|█████████▍| 17/18 [3:59:59<18:26, 1106.04s/it][A

█


100%|██████████| 18/18 [4:22:13<00:00, 874.06s/it] [A


epoch,train_loss,valid_loss,acc,dice,iou,time
0,0.339358,0.137502,0.950038,0.861033,0.796631,00:17
1,0.199227,0.064911,0.977297,0.941446,0.90337,00:17
2,0.132205,0.042738,0.987287,0.966501,0.945082,00:17
3,0.10123,0.034727,0.987904,0.966345,0.945406,00:17
4,0.088352,0.03824,0.989803,0.971224,0.953785,00:17
5,0.070047,0.030055,0.99033,0.971995,0.955211,00:17
6,0.053798,0.02986,0.988796,0.967685,0.947864,00:17
7,0.044685,0.029722,0.989419,0.969491,0.950762,00:17
8,0.040183,0.02902,0.989774,0.970418,0.952354,00:17


epoch,train_loss,valid_loss,acc,dice,iou,time
0,0.057786,0.042605,0.98913,0.968152,0.950916,00:17
1,0.068453,0.066378,0.980555,0.95695,0.927205,00:17
2,0.052656,0.035328,0.987416,0.965015,0.943267,00:17


1it [4:25:47, 15947.99s/it]
  0%|          | 0/18 [00:00<?, ?it/s][A

█


  6%|▌         | 1/18 [06:29<1:50:22, 389.56s/it][A

█


 11%|█         | 2/18 [12:59<1:43:56, 389.81s/it][A

█


 17%|█▋        | 3/18 [24:08<1:58:22, 473.47s/it][A

█


 22%|██▏       | 4/18 [35:12<2:03:49, 530.67s/it][A

█


 28%|██▊       | 5/18 [49:42<2:17:02, 632.54s/it][A

█


 33%|███▎      | 6/18 [1:04:16<2:20:56, 704.69s/it][A

█


 39%|███▉      | 7/18 [1:15:20<2:06:58, 692.56s/it][A

█


 44%|████▍     | 8/18 [1:26:30<1:54:17, 685.77s/it][A

█


 50%|█████     | 9/18 [1:42:17<1:54:39, 764.37s/it][A

█


 56%|█████▌    | 10/18 [1:58:01<1:49:05, 818.22s/it][A

█


 61%|██████    | 11/18 [2:17:12<1:47:06, 918.05s/it][A

█


 67%|██████▋   | 12/18 [2:36:18<1:38:38, 986.46s/it][A

█


 72%|███████▏  | 13/18 [2:50:36<1:18:58, 947.77s/it][A

█


 78%|███████▊  | 14/18 [3:04:52<1:01:20, 920.19s/it][A

█


 83%|████████▎ | 15/18 [3:23:50<49:16, 985.66s/it]  [A

█


 89%|████████▉ | 16/18 [3:42:47<34:21, 1030.94s/it][A

█


 94%|█████████▍| 17/18 [4:05:11<18:44, 1124.81s/it][A

█


100%|██████████| 18/18 [4:27:37<00:00, 892.07s/it] [A


epoch,train_loss,valid_loss,acc,dice,iou,time
0,0.169093,0.079754,0.970421,0.954109,0.914919,00:20
1,0.146467,0.141869,0.935264,0.902468,0.834138,00:17
2,0.131756,0.161775,0.929699,0.883426,0.803574,00:17
3,0.114882,0.18291,0.91984,0.895189,0.836094,00:17
4,0.095437,0.152145,0.933923,0.908522,0.852732,00:17
5,0.084418,0.122772,0.944264,0.912572,0.846358,00:17
6,0.060537,0.068189,0.969096,0.950149,0.907134,00:17
7,0.043727,0.064354,0.969312,0.950676,0.909452,00:17
8,0.03812,0.061064,0.971764,0.954564,0.916245,00:17


epoch,train_loss,valid_loss,acc,dice,iou,time
0,0.025072,0.10817,0.961391,0.939132,0.892083,00:17
1,0.035551,0.056559,0.977005,0.962433,0.929998,00:17
2,0.039263,0.039236,0.985789,0.977735,0.957058,00:17
3,0.02895,0.079233,0.970914,0.952856,0.91343,00:17
4,0.022265,0.031251,0.988307,0.982094,0.965449,00:17
5,0.02234,0.025484,0.991263,0.987104,0.974845,00:17
6,0.019071,0.028257,0.989955,0.985143,0.971016,00:17


2it [8:58:14, 16067.61s/it]
  0%|          | 0/18 [00:00<?, ?it/s][A

█


  6%|▌         | 1/18 [06:52<1:56:44, 412.04s/it][A

█


 11%|█         | 2/18 [13:44<1:49:56, 412.25s/it][A

█


 17%|█▋        | 3/18 [25:38<2:05:40, 502.67s/it][A

█


 22%|██▏       | 4/18 [37:34<2:12:14, 566.75s/it][A

█


 28%|██▊       | 5/18 [53:24<2:27:42, 681.70s/it][A

█


 33%|███▎      | 6/18 [1:09:18<2:32:41, 763.46s/it][A

█


 39%|███▉      | 7/18 [1:21:21<2:17:41, 751.09s/it][A

█


 44%|████▍     | 8/18 [1:33:21<2:03:37, 741.78s/it][A

█


 50%|█████     | 9/18 [1:50:29<2:04:10, 827.80s/it][A

█


 56%|█████▌    | 10/18 [2:07:42<1:58:34, 889.31s/it][A

█


 61%|██████    | 11/18 [2:28:40<1:56:40, 1000.03s/it][A

█


 67%|██████▋   | 12/18 [2:49:38<1:47:43, 1077.25s/it][A

█


 72%|███████▏  | 13/18 [3:05:25<1:26:31, 1038.35s/it][A

█


 78%|███████▊  | 14/18 [3:21:13<1:07:23, 1010.98s/it][A

█


 83%|████████▎ | 15/18 [3:42:03<54:08, 1082.73s/it]  [A

█


 89%|████████▉ | 16/18 [4:03:04<37:52, 1136.23s/it][A

█


100%|██████████| 18/18 [4:52:29<00:00, 975.00s/it] [A


epoch,train_loss,valid_loss,acc,dice,iou,time
0,0.279195,0.12218,0.969078,0.92567,0.890459,00:20
1,0.159708,0.090428,0.974153,0.939431,0.908416,00:18


epoch,train_loss,valid_loss,acc,dice,iou,time
0,0.065353,0.099961,0.975049,0.956458,0.92577,00:18
1,0.051377,0.111108,0.973217,0.954959,0.921679,00:18
2,0.046192,0.097865,0.976712,0.960689,0.932324,00:18
3,0.038218,0.107283,0.980153,0.965331,0.940639,00:18
4,0.033161,0.111029,0.977238,0.960648,0.933313,00:18
5,0.027052,0.106592,0.978978,0.9482,0.922166,00:18
6,0.022411,0.116105,0.980262,0.965605,0.941277,00:18
7,0.018189,0.118714,0.981329,0.967623,0.944699,00:18
8,0.015753,0.145748,0.981535,0.967931,0.945203,00:18
9,0.013692,0.133455,0.981784,0.968256,0.945863,00:18


3it [13:54:33, 16580.88s/it]
  0%|          | 0/18 [00:00<?, ?it/s][A

█


  6%|▌         | 1/18 [06:56<1:57:56, 416.27s/it][A

█


 11%|█         | 2/18 [13:48<1:50:42, 415.13s/it][A

█


 17%|█▋        | 3/18 [25:38<2:05:52, 503.52s/it][A

█


 22%|██▏       | 4/18 [37:26<2:11:48, 564.88s/it][A

█


 28%|██▊       | 5/18 [53:01<2:26:27, 675.98s/it][A

█


 33%|███▎      | 6/18 [1:08:37<2:30:46, 753.86s/it][A

█


 39%|███▉      | 7/18 [1:20:20<2:15:26, 738.78s/it][A

█


 44%|████▍     | 8/18 [1:32:05<2:01:26, 728.60s/it][A

█


 50%|█████     | 9/18 [1:48:47<2:01:34, 810.55s/it][A

█


 56%|█████▌    | 10/18 [2:05:33<1:55:54, 869.30s/it][A

█


 61%|██████    | 11/18 [2:26:03<1:54:01, 977.40s/it][A

█


 67%|██████▋   | 12/18 [2:46:28<1:45:10, 1051.73s/it][A

█


 72%|███████▏  | 13/18 [3:01:49<1:24:22, 1012.52s/it][A

█


 78%|███████▊  | 14/18 [3:17:06<1:05:35, 983.92s/it] [A

█


 83%|████████▎ | 15/18 [3:37:21<52:39, 1053.06s/it] [A

█


 89%|████████▉ | 16/18 [3:57:37<36:43, 1101.89s/it][A

█


 94%|█████████▍| 17/18 [4:21:40<20:04, 1204.46s/it][A

█


100%|██████████| 18/18 [4:45:44<00:00, 952.49s/it] [A


epoch,train_loss,valid_loss,acc,dice,iou,time
0,0.255325,0.127602,0.954977,0.93494,0.885137,00:17
1,0.166541,0.124251,0.9529,0.931847,0.879311,00:17


epoch,train_loss,valid_loss,acc,dice,iou,time
0,0.092396,0.091333,0.971807,0.962533,0.931492,00:18
1,0.058454,0.08482,0.978133,0.970878,0.946812,00:18
2,0.043485,0.101154,0.976331,0.968286,0.94256,00:18


4it [18:41:54, 16778.98s/it]
  0%|          | 0/18 [00:00<?, ?it/s][A

█


  6%|▌         | 1/18 [06:54<1:57:21, 414.18s/it][A

█


 11%|█         | 2/18 [13:47<1:50:20, 413.80s/it][A

█


 17%|█▋        | 3/18 [25:34<2:05:28, 501.93s/it][A

█


 22%|██▏       | 4/18 [37:21<2:11:25, 563.25s/it][A

█


 28%|██▊       | 5/18 [52:49<2:25:48, 672.93s/it][A

█


 33%|███▎      | 6/18 [1:08:17<2:29:51, 749.29s/it][A

█


 39%|███▉      | 7/18 [1:19:56<2:14:37, 734.30s/it][A

█


 44%|████▍     | 8/18 [1:31:33<2:00:29, 722.95s/it][A

█


 50%|█████     | 9/18 [1:48:06<2:00:37, 804.18s/it][A

█


 56%|█████▌    | 10/18 [2:04:41<1:54:50, 861.30s/it][A

█


 61%|██████    | 11/18 [2:24:56<1:52:51, 967.40s/it][A

█


 67%|██████▋   | 12/18 [2:45:12<1:44:12, 1042.02s/it][A

█


 72%|███████▏  | 13/18 [3:00:31<1:23:44, 1004.96s/it][A

█


 78%|███████▊  | 14/18 [3:15:47<1:05:14, 978.52s/it] [A

█


 83%|████████▎ | 15/18 [3:35:47<52:14, 1044.99s/it] [A

█


 89%|████████▉ | 16/18 [3:55:39<36:18, 1089.11s/it][A

█


 94%|█████████▍| 17/18 [4:19:16<19:47, 1187.21s/it][A

█


100%|██████████| 18/18 [4:43:16<00:00, 944.24s/it] [A


epoch,train_loss,valid_loss,acc,dice,iou,time
0,0.232944,0.109097,0.946056,0.89264,0.819051,00:18
1,0.131544,0.060144,0.974024,0.942271,0.895718,00:17


epoch,train_loss,valid_loss,acc,dice,iou,time
0,0.082832,0.156206,0.951223,0.903428,0.8394,00:18
1,0.068399,0.124386,0.954879,0.90886,0.845913,00:18
2,0.048719,0.053915,0.979603,0.953867,0.914846,00:18


5it [23:26:47, 16881.43s/it]

CPU times: user 12h 19min 6s, sys: 3h 15min 10s, total: 15h 34min 17s
Wall time: 23h 26min 47s



