In [1]:
%%time
import warnings
warnings.filterwarnings('ignore')
from functools import partial
from collections import defaultdict, namedtuple
import numpy as np
import pandas as pd
import scipy
from sklearn.model_selection import KFold
from tqdm.notebook import tqdm
from itertools import product

from video699.screen.semantic_segmentation.fastai_detector import *
from video699.screen.semantic_segmentation.common import *
from video699.screen.semantic_segmentation.postprocessing import *
from video699.screen.semantic_segmentation.evaluation import *

CPU times: user 8.02 s, sys: 1.13 s, total: 9.15 s
Wall time: 6.6 s


In [2]:
# Candidate values searched during model selection; every combination is tried.

# Training hyper-parameters: a separate detector is trained for each combination.
resize_factor = [8]
frozen_epochs = [4, 7]
unfrozen_epochs = [4, 7]

# Post-processing hyper-parameters: applied to one shared segmentation output.
base_lower_bound = [5, 7, 10, 15]
erosion_dilation_kernel_size = [20, 50, 80, 150]
ratio_split_lower_bound = [0.3, 0.4, 0.5, 0.7, 0.8, 0.9]

In [3]:
# resize_factor = [8]
# frozen_epochs = [0]
# unfrozen_epochs = [0]
# base_lower_bound = [15]
# erosion_dilation_kernel_size = [20,25]
# ratio_split_lower_bound = [0.7]

In [4]:
# Names of the values inside each grid tuple (same order as the product below).
train_names = ['resize_factor', 'frozen_epochs', 'unfrozen_epochs']
post_processing_names = ['base_lower_bound', 'erosion_dilation_kernel_size', 'ratio_split_lower_bound']

# Cartesian products of the candidate values -> one tuple per configuration.
train = list(product(resize_factor, frozen_epochs, unfrozen_epochs))
post_processing = list(product(base_lower_bound, erosion_dilation_kernel_size, ratio_split_lower_bound))

# Every frame of every video, flattened into a single list.
all_frames = [frame for video_frames in ALL_VIDEOS for frame in video_frames]

# NOTE(review): this global `detector` is not used by the cells below
# (model_selection builds its own per-fold detectors); the commented-out
# cell that reads `detector.videos_path` is its only consumer.
detector = FastAIScreenDetector()
# Returns the ground-truth (annotated) screens for a frame.
actual_detector = AnnotatedSampledVideoScreenDetector()

In [5]:
def filtered_by(name, used):
    """Return True if ``name`` is a frame image whose path is one of ``used``.

    Parameters
    ----------
    name : str or path-like
        Image path; compared through its ``str()`` form.
    used : iterable
        Objects with a ``pathname`` attribute (compared through ``str()``).

    Returns
    -------
    bool
        True only when ``'frame'`` occurs in ``str(name)`` AND some element of
        ``used`` has the same string pathname.
    """
    name = str(name)
    # Cheap substring test first: non-frame images skip the scan entirely.
    if 'frame' not in name:
        return False
    # any() stops at the first match instead of building a list of every
    # pathname on each call (this predicate is evaluated once per image).
    return any(name == str(frame.pathname) for frame in used)

def split_by(name, validation):
    """Tell whether image ``name`` is one of the ``validation`` frames.

    Both ``name`` and each ``frame.pathname`` are compared by their ``str()``
    form, so ``str`` and path-like objects are interchangeable.
    """
    image_path = str(name)
    validation_paths = {str(frame.pathname) for frame in validation}
    return image_path in validation_paths

In [6]:
def model_selection(frames, train_names, post_processing_names, default_filtered_by, default_split_by,
                    train_grid=None, post_processing_grid=None, n_splits=5, random_state=123):
    """Grid-search training and post-processing hyper-parameters with K-fold CV.

    For every tuple in the training grid a ``FastAIScreenDetector`` is trained
    once per fold. Its semantic-segmentation output on that fold's validation
    frames is then post-processed with every tuple of the post-processing grid
    and scored with ``evaluate``.

    Parameters
    ----------
    frames : list
        Frames to split into folds (each exposes ``.pathname``).
    train_names : list of str
        Column names for the values of each training-grid tuple.
    post_processing_names : list of str
        Keyword names passed to ``post_processing_batch``; also used as columns.
    default_filtered_by, default_split_by : callable
        Predicates bound per fold via ``partial(..., used=...)`` and
        ``partial(..., validation=...)``.
    train_grid, post_processing_grid : list of tuple, optional
        Grids to search. Default to the notebook-level ``train`` and
        ``post_processing`` lists (the original behaviour).
    n_splits, random_state : int, optional
        Inner K-fold settings; defaults reproduce the original 5-fold, seed 123.

    Returns
    -------
    pandas.DataFrame
        One row per (training tuple, fold, post-processing tuple) with columns
        ``train_names + post_processing_names + ['iou', 'wrong_count', 'kfold_split']``.
        ``iou`` is ``np.nanmean`` of the IoUs from ``evaluate``; ``wrong_count``
        is ``len`` of its first return value.

    Notes
    -----
    Side effect: writes ``resize_factor``, ``frozen_epochs`` and
    ``unfrozen_epochs`` (as strings) into the global ``CONFIGURATION`` before
    each training run, and leaves the last grid point's values there.
    """
    if train_grid is None:
        train_grid = train
    if post_processing_grid is None:
        post_processing_grid = post_processing
    columns = train_names + post_processing_names + ['iou', 'wrong_count', 'kfold_split']

    def make_splits(frames):
        # Fold index -> Split(train=[frames...], valid=[frames...]).
        Split = namedtuple('Split', ['train', 'valid'])
        kf = KFold(n_splits=n_splits, shuffle=True, random_state=random_state)
        return {
            j: Split(train=[frames[index] for index in train_idx],
                     valid=[frames[index] for index in valid_idx])
            for j, (train_idx, valid_idx) in enumerate(kf.split(frames))
        }

    splits = make_splits(frames)
    records = []

    for train_values in tqdm(train_grid):
        resize_factor, frozen_epochs, unfrozen_epochs = train_values
        # Training reads these from the global configuration (train() gets no args).
        CONFIGURATION['resize_factor'] = str(resize_factor)
        CONFIGURATION['frozen_epochs'] = str(frozen_epochs)
        CONFIGURATION['unfrozen_epochs'] = str(unfrozen_epochs)

        for j, split in splits.items():
            fold_filtered_by = partial(default_filtered_by, used=split.train + split.valid)
            fold_split_by = partial(default_split_by, validation=split.valid)

            fold_detector = FastAIScreenDetector(filtered_by=fold_filtered_by, valid_func=fold_split_by)
            fold_detector.train()

            valid_frames = [frame for frame in all_frames if fold_split_by(frame.pathname)]
            actuals = [actual_detector.detect(frame) for frame in valid_frames]
            # Segment once per fold; every post-processing tuple reuses this output.
            sem_preds = fold_detector.semantic_segmentation_batch(valid_frames)

            for post_processing_values in post_processing_grid:
                post_processing_kwargs = dict(zip(post_processing_names, post_processing_values))
                preds = fold_detector.post_processing_batch(sem_preds, valid_frames, **post_processing_kwargs)
                wrong, ious, _ = evaluate(actuals, preds)
                records.append(tuple(train_values) + tuple(post_processing_values)
                               + (np.nanmean(ious), len(wrong), j))

    # Build the frame once (row-by-row .loc enlargement is quadratic). dtype=float
    # matches the all-float64 columns that enlargement produced, which callers
    # rely on (e.g. float.is_integer() on the groupby index).
    return pd.DataFrame(records, columns=columns, dtype=float)

In [7]:
def convert_params(best_params):
    """Convert a tuple of hyper-parameter values to plain Python scalars.

    numpy scalars are unwrapped with ``.item()``. Floats with an integral value
    (e.g. ``8.0`` coming from a float64 groupby index) become ``int``. Every
    other value is returned unchanged.

    Parameters
    ----------
    best_params : iterable
        Values such as the ``.name`` tuple of a grouped DataFrame row.

    Returns
    -------
    tuple
        The converted values, in the same order.
    """
    converted_params = []
    for par in best_params:
        # Unwrap numpy scalars first: np.int64 / Python int lack is_integer()
        # on older numpy / Python (<3.12), which made the original crash.
        if isinstance(par, np.generic):
            par = par.item()
        if isinstance(par, float) and par.is_integer():
            par = int(par)
        converted_params.append(par)
    return tuple(converted_params)

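### Image-wise 5-fold cross-validation

Frames from all lectures are pooled and split frame by frame. Inside each outer fold, an inner 5-fold cross-validation selects the hyper-parameters. Because the split is per frame, frames from the same lecture can fall into both the training and the test folds.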

In [8]:
# from fastai.vision import *
# paths = get_image_files(detector.videos_path, recurse=True)
# filtered_by()

In [9]:
# Outer image-wise K-fold. For each outer fold:
#   1. pick hyper-parameters with the inner CV in `model_selection` (other frames only),
#   2. retrain a detector with them,
#   3. score it on the held-out frames.
kf = KFold(n_splits=5, shuffle=True, random_state=123)
df_best_models = pd.DataFrame(columns=train_names + post_processing_names + ['iou', 'wrong_count'])
results_path = 'cross_validation_results_image_wise.csv'

for i, split in tqdm(enumerate(kf.split(all_frames)), total=kf.get_n_splits()):
    print(f"###################### Split No. {i}")
    other_frames = [all_frames[index] for index in split[0]]
    held_out_frames = [all_frames[index] for index in split[1]]

    # Model selection on the non-held-out frames only.
    df_all = model_selection(other_frames, train_names, post_processing_names, filtered_by, split_by)
    df_all['wrong_count'] = df_all['wrong_count'].astype(int)
    # Rank by the inner-fold mean: fewest wrong detections first, then the
    # HIGHEST IoU. A plain ascending sort would pick the lowest IoU among ties.
    ranked = (
        df_all.groupby(train_names + post_processing_names).mean()
        .sort_values(by=['wrong_count', 'iou'], ascending=[True, False])
    )
    best_params = convert_params(ranked.iloc[0].name)
    best_params = dict(zip(train_names + post_processing_names, best_params))

    # model_selection trains via the global CONFIGURATION and leaves it at the
    # last grid point. Set it to the selected values so the final model cannot
    # inherit them. This is harmless if train(**best_params) applies its kwargs
    # too — TODO confirm which one train() honours.
    CONFIGURATION['resize_factor'] = str(best_params['resize_factor'])
    CONFIGURATION['frozen_epochs'] = str(best_params['frozen_epochs'])
    CONFIGURATION['unfrozen_epochs'] = str(best_params['unfrozen_epochs'])

    test_filtered_by = partial(filtered_by, used=all_frames)
    # NOTE(review): the held-out frames double as fastai's validation set here;
    # fine only if train() does not use validation loss to select/stop — verify.
    test_split_by = partial(split_by, validation=held_out_frames)

    best_detector = FastAIScreenDetector(filtered_by=test_filtered_by, valid_func=test_split_by)
    best_detector.train(**best_params)

    test_frames = [frame for frame in all_frames if test_split_by(frame.pathname)]
    actuals = [actual_detector.detect(frame) for frame in test_frames]
    preds = [best_detector.detect(frame) for frame in test_frames]

    wrong, ious, _ = evaluate(actuals, preds)
    iou_score = np.nanmean(ious)
    wrong_count = len(wrong)
    df_best_models.loc[len(df_best_models)] = tuple(best_params.values()) + (iou_score, wrong_count)
    # Checkpoint after every outer fold so an interrupted run keeps finished folds.
    df_best_models.to_csv(results_path, index=False)

df_best_models.to_csv(results_path, index=False)

0it [00:00, ?it/s]
  0%|          | 0/4 [00:00<?, ?it/s][A

###################### Split No. 0


epoch,train_loss,valid_loss,acc,dice,iou_sem_seg,time
0,0.388863,0.152626,0.940612,0.874231,0.795262,00:45
1,0.201911,0.057415,0.979826,0.960039,0.928533,00:49
2,0.125444,0.033549,0.987479,0.974957,0.95256,00:51
3,0.091997,0.03896,0.984975,0.971454,0.948013,00:49


epoch,train_loss,valid_loss,acc,dice,iou_sem_seg,time
0,0.046139,0.03109,0.989721,0.979137,0.961982,00:51
1,0.040536,0.021564,0.991793,0.984543,0.970247,00:49
2,0.029451,0.018319,0.993226,0.987354,0.975723,00:49
3,0.024288,0.014892,0.994108,0.989054,0.978675,00:55


epoch,train_loss,valid_loss,acc,dice,iou_sem_seg,time
0,0.207827,0.074476,0.970398,0.933303,0.891105,00:48
1,0.16714,0.043669,0.984561,0.953854,0.927349,00:45
2,0.117859,0.040899,0.984196,0.952868,0.925816,00:45
3,0.07673,0.033948,0.986787,0.973103,0.950082,00:44


epoch,train_loss,valid_loss,acc,dice,iou_sem_seg,time
0,0.033695,0.035033,0.988777,0.977282,0.958119,00:49
1,0.035502,0.038239,0.989099,0.962851,0.94431,00:51
2,0.036344,0.039413,0.986277,0.965257,0.948285,00:57
3,0.032984,0.033628,0.987196,0.965805,0.949881,00:55


epoch,train_loss,valid_loss,acc,dice,iou_sem_seg,time
0,0.352898,0.084821,0.969243,0.934244,0.882312,00:49
1,0.221717,0.052774,0.981795,0.956359,0.923999,00:54
2,0.129403,0.02703,0.99019,0.980374,0.961897,00:49
3,0.085495,0.022374,0.990515,0.981478,0.963808,00:44


epoch,train_loss,valid_loss,acc,dice,iou_sem_seg,time
0,0.040265,0.018878,0.992077,0.984386,0.969389,00:46
1,0.040818,0.018013,0.99287,0.985981,0.972493,00:46
2,0.031393,0.01294,0.994653,0.989577,0.979463,00:46
3,0.027462,0.01222,0.995278,0.99089,0.982013,00:45


epoch,train_loss,valid_loss,acc,dice,iou_sem_seg,time
0,0.277143,0.15603,0.961344,0.917779,0.85884,00:44
1,0.169078,0.029188,0.988713,0.97711,0.955524,00:42
2,0.105603,0.020035,0.99254,0.985206,0.970939,00:44
3,0.070531,0.019143,0.991954,0.983687,0.968016,00:42


epoch,train_loss,valid_loss,acc,dice,iou_sem_seg,time
0,0.03172,0.016493,0.993409,0.987062,0.974531,00:48
1,0.034355,0.013426,0.99482,0.989914,0.980105,00:48
2,0.025166,0.010172,0.996322,0.992833,0.985811,00:51
3,0.023524,0.00975,0.996638,0.993473,0.987072,00:48


epoch,train_loss,valid_loss,acc,dice,iou_sem_seg,time
0,0.294647,0.084784,0.969855,0.93406,0.884023,00:42
1,0.150863,0.033892,0.986007,0.968151,0.941321,00:44
2,0.088311,0.0253,0.990064,0.979194,0.961441,00:42
3,0.057004,0.025496,0.990532,0.980304,0.963811,00:46


epoch,train_loss,valid_loss,acc,dice,iou_sem_seg,time
0,0.027073,0.036198,0.989271,0.976622,0.957129,00:46
1,0.030209,0.020559,0.992324,0.9842,0.969134,00:50
2,0.02614,0.011954,0.995247,0.990337,0.981028,01:04
3,0.021383,0.011729,0.995453,0.990426,0.981268,01:04



 25%|██▌       | 1/4 [35:49<1:47:29, 2149.72s/it][A

epoch,train_loss,valid_loss,acc,dice,iou_sem_seg,time
0,0.254534,0.118333,0.954386,0.902973,0.837448,00:44
1,0.154229,0.057016,0.979548,0.959605,0.924521,00:44
2,0.09761,0.034317,0.987841,0.975428,0.956572,00:45
3,0.06567,0.038868,0.988127,0.9758,0.957389,00:43


epoch,train_loss,valid_loss,acc,dice,iou_sem_seg,time
0,0.029477,0.045828,0.98613,0.971974,0.950269,00:48
1,0.031501,0.0434,0.982461,0.966827,0.937099,00:48
2,0.038318,0.027962,0.987376,0.973605,0.953893,00:47
3,0.032261,0.025773,0.990914,0.982184,0.968065,00:49
4,0.027092,0.024614,0.991108,0.98224,0.968483,00:46
5,0.024295,0.013473,0.994683,0.990152,0.981069,00:49
6,0.020695,0.011442,0.995742,0.992152,0.984604,00:49


epoch,train_loss,valid_loss,acc,dice,iou_sem_seg,time
0,0.23428,0.07938,0.972491,0.930526,0.886103,00:43
1,0.124973,0.063665,0.976693,0.943445,0.908625,00:46
2,0.0835,0.028457,0.988674,0.962247,0.942829,00:43
3,0.054141,0.027427,0.988933,0.962475,0.943464,00:46


epoch,train_loss,valid_loss,acc,dice,iou_sem_seg,time
0,0.02969,0.029822,0.989646,0.964352,0.9473,00:48
1,0.023857,0.024094,0.992639,0.9704,0.958651,00:49
2,0.029234,0.02813,0.992059,0.969529,0.95705,00:49
3,0.030098,0.033546,0.991594,0.983295,0.969635,00:46
4,0.023265,0.030628,0.992782,0.98519,0.973772,00:50
5,0.019346,0.030993,0.992798,0.985181,0.973744,00:49
6,0.015757,0.030453,0.992934,0.98556,0.974324,00:47


epoch,train_loss,valid_loss,acc,dice,iou_sem_seg,time
0,0.265349,0.095145,0.963977,0.925189,0.870921,00:44
1,0.200588,0.129734,0.960824,0.91531,0.858323,00:46
2,0.139635,0.029354,0.988305,0.976644,0.954727,00:45
3,0.092892,0.029176,0.988611,0.977457,0.956226,00:47


epoch,train_loss,valid_loss,acc,dice,iou_sem_seg,time
0,0.049886,0.022425,0.991192,0.982942,0.966614,00:48
1,0.04961,0.02691,0.988751,0.978004,0.957375,00:50
2,0.040848,0.014108,0.994454,0.988925,0.978299,00:50
3,0.032104,0.011923,0.995465,0.991156,0.982576,00:52
4,0.029625,0.01113,0.99557,0.99127,0.982774,00:51
5,0.025041,0.00997,0.996211,0.992663,0.985492,00:51
6,0.021741,0.009447,0.99647,0.993135,0.98642,00:48


epoch,train_loss,valid_loss,acc,dice,iou_sem_seg,time
0,0.230592,0.068135,0.972384,0.938668,0.888582,00:45
1,0.17162,0.055475,0.983122,0.962369,0.929327,00:47
2,0.108283,0.023049,0.991161,0.98213,0.965055,00:45
3,0.072601,0.019722,0.992367,0.984865,0.970312,00:47


epoch,train_loss,valid_loss,acc,dice,iou_sem_seg,time
0,0.032984,0.016583,0.993445,0.987018,0.974466,00:52
1,0.033185,0.015526,0.99363,0.987248,0.97488,00:49
2,0.036069,0.015016,0.993868,0.987784,0.975944,00:52
3,0.029167,0.010796,0.99589,0.992009,0.984197,00:53
4,0.024309,0.008869,0.996569,0.993325,0.986775,01:04
5,0.020579,0.008194,0.996757,0.993705,0.987517,01:05
6,0.017865,0.007908,0.99699,0.994134,0.988365,00:58


epoch,train_loss,valid_loss,acc,dice,iou_sem_seg,time
0,0.282706,0.180396,0.92128,0.849866,0.75536,00:49
1,0.146113,0.037722,0.985126,0.96977,0.943216,00:47
2,0.093205,0.026802,0.989715,0.978337,0.95965,00:48
3,0.062941,0.023155,0.990897,0.9815,0.965064,00:49


epoch,train_loss,valid_loss,acc,dice,iou_sem_seg,time
0,0.030949,0.016013,0.993575,0.9869,0.974313,00:50
1,0.037076,0.017529,0.993029,0.986083,0.972984,00:51
2,0.040837,0.018187,0.992631,0.984992,0.970899,00:53
3,0.035193,0.014665,0.994155,0.987944,0.97652,00:52
4,0.026466,0.011617,0.995416,0.99063,0.98171,00:49
5,0.021049,0.01062,0.995909,0.991548,0.983508,00:52
6,0.018391,0.010056,0.996156,0.992024,0.984398,00:53



 50%|█████     | 2/4 [1:24:15<1:19:13, 2376.63s/it][A

epoch,train_loss,valid_loss,acc,dice,iou_sem_seg,time
0,0.366039,0.168217,0.927991,0.865513,0.773124,00:50
1,0.209097,0.051722,0.981182,0.962821,0.932967,00:48
2,0.12469,0.037735,0.983796,0.969282,0.943064,00:51
3,0.086363,0.034184,0.985449,0.972826,0.947932,00:50
4,0.066887,0.01945,0.992841,0.986488,0.973481,00:48
5,0.046792,0.016788,0.993418,0.987625,0.975731,00:49
6,0.035198,0.016219,0.993194,0.987111,0.974706,00:50


epoch,train_loss,valid_loss,acc,dice,iou_sem_seg,time
0,0.021563,0.018911,0.992593,0.986047,0.972718,00:51
1,0.027772,0.016594,0.993512,0.98768,0.975775,00:52
2,0.023121,0.012044,0.995151,0.990883,0.982066,00:54
3,0.018044,0.011335,0.995847,0.992353,0.984958,00:54


epoch,train_loss,valid_loss,acc,dice,iou_sem_seg,time
0,0.31132,0.130084,0.941423,0.88312,0.807537,00:52
1,0.17425,0.046208,0.983086,0.951099,0.921896,00:47
2,0.118429,0.050997,0.983474,0.952504,0.925034,00:50
3,0.089279,0.039058,0.986001,0.956566,0.933078,00:51
4,0.060602,0.039509,0.987224,0.962804,0.944241,00:49
5,0.044007,0.033619,0.988924,0.966059,0.950614,00:49
6,0.032417,0.032736,0.989244,0.966915,0.952243,01:10


epoch,train_loss,valid_loss,acc,dice,iou_sem_seg,time
0,0.021317,0.036205,0.988868,0.965256,0.949431,01:06
1,0.024572,0.036029,0.990367,0.980086,0.964653,01:10
2,0.027724,0.030428,0.99166,0.983021,0.969447,00:55


Traceback (most recent call last):
Traceback (most recent call last):
  File "/home/xbankov/anaconda3/envs/video699/lib/python3.7/multiprocessing/queues.py", line 242, in _feed
    send_bytes(obj)
  File "/home/xbankov/anaconda3/envs/video699/lib/python3.7/multiprocessing/connection.py", line 200, in send_bytes
    self._send_bytes(m[offset:offset + size])
  File "/home/xbankov/anaconda3/envs/video699/lib/python3.7/multiprocessing/connection.py", line 404, in _send_bytes
    self._send(header + buf)
  File "/home/xbankov/anaconda3/envs/video699/lib/python3.7/multiprocessing/connection.py", line 368, in _send
    n = write(self._handle, buf)
BrokenPipeError: [Errno 32] Broken pipe
  File "/home/xbankov/anaconda3/envs/video699/lib/python3.7/multiprocessing/queues.py", line 242, in _feed
    send_bytes(obj)
  File "/home/xbankov/anaconda3/envs/video699/lib/python3.7/multiprocessing/connection.py", line 200, in send_bytes
    self._send_bytes(m[offset:offset + size])
  File "/home/xbankov/

KeyboardInterrupt: 