In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import numpy as np
import random
import sys

sys.path.append('../tools')

import os

import torch

import pandas as pd
from sklearn.model_selection import train_test_split
import pytorch_lightning as pl
import pytorch_lightning.loggers as pl_loggers
import pytorch_lightning.callbacks as pl_callbacks
import data_utility, annotation_utility
from models.rns_dataloader import *
from active_learning_utility import get_strategy
from active_learning_data import Data
from active_learning_net import Net
import interrater_annotations.tools
from copy import deepcopy
from models.SwaV import SwaV
from models.LSTMDownStream import SupervisedDownstream
import warnings
import pickle
warnings.filterwarnings("ignore")


In [3]:
random_seed = 42
random.seed(random_seed)
torch.manual_seed(random_seed)
np.random.seed(random_seed)

if torch.cuda.is_available():
    torch.cuda.manual_seed(random_seed)
    # True ensures the algorithm selected by CUFA is deterministic
    torch.backends.cudnn.deterministic = True
    # torch.set_deterministic(True)
    # False ensures CUDA select the same algorithm each time the application is run
    torch.backends.cudnn.benchmark = False

import pytorch_lightning

pytorch_lightning.utilities.seed.seed_everything(seed=random_seed, workers=True)

Global seed set to 42


42

In [4]:
data_dir = "../../../user_data/"
log_folder_root = '../../../user_data/logs/'
ckpt_folder_root = '../../../user_data/checkpoints/'

In [6]:
raw_annotations = pd.read_csv('interrater_annotations/full_updated_anns (1).csv')
ids = list(raw_annotations['HUP_ID'].unique())
ids[:-1]

['HUP096',
 'HUP137',
 'HUP153',
 'HUP108',
 'HUP136',
 'HUP127',
 'HUP128',
 'HUP143',
 'HUP059',
 'HUP047',
 'HUP159',
 'HUP121',
 'HUP147',
 'HUP131',
 'HUP084',
 'RNS021',
 'HUP156',
 'HUP109',
 'HUP101',
 'RNS022',
 'HUP182',
 'HUP205',
 'RNS026',
 'HUP197']

In [7]:
from interrater_annotations.tools import data_utility
data_import = data_utility.read_files(path=data_dir + 'rns_data', patientIDs=ids,
                                      verbose=True)  # Import data with annotation

100%|██████████| 25/25 [00:41<00:00,  1.67s/it]


In [8]:
from interrater_annotations.tools import annotation_utility
annotations = annotation_utility.read_annotation(annotation_path = 'interrater_annotations/all_updated_anns.csv', annotation_catalog_path='interrater_annotations/test_datasets.csv', data = data_import)

42
42
42
43
43
43
41
41
41
41
41
41
41
41
41
41
41
41
41
41
41
48
48
48
41
41
41
42
42
42
41
41
41
41
41
41
42
42
42
41
41
41


In [7]:
annotations.annotation_dict['RNS_Test_Dataset_ErinConrad']

Unnamed: 0,Dataset,Annotation_Catalog_Index,Patient_ID,Alias_ID,Episode_Start_Timestamp,Episode_End_Timestamp,Episode_Start_UTC_Time,Episode_End_UTC_Time,Episode_Index,Episode_Start_Index,Episode_End_Index,Annotation_Start_Timestamp,Annotation_End_Timestamp,Annotation_Start_UTC_Time,Annotation_End_UTC_Time,Annotation_Start_Index,Annotation_End_Index,Type_Description,Class_Code,Annotation_Channel,Channel_Code,Binary_Channel_Code
59,RNS_Test_Dataset_ErinConrad,59,HUP096,RNS_Example_1_EC,1427742903476000,1427742993628000,2015-03-30 19:15:03.476,2015-03-30 19:16:33.628,10,250781,273318,[],[],[],[],[],[],No,0,[],[],[]
60,RNS_Test_Dataset_ErinConrad,60,HUP096,RNS_Example_1_EC,1442176442532000,1442176532616000,2015-09-13 20:34:02.532,2015-09-13 20:35:32.616,508,11181530,11204050,[1442176484602143],[1442176503681280],[2015-09-13 20:34:44.602143],[2015-09-13 20:35:03.681280],[11192047],[11196817],Yes,1,"[1,2,3,4]",[1111],[1111]
61,RNS_Test_Dataset_ErinConrad,61,HUP096,RNS_Example_1_EC,1446212334040000,1446212424128000,2015-10-30 13:38:54.040,2015-10-30 13:40:24.128,629,13870413,13892934,[1446212375851151],[1446212394966259],[2015-10-30 13:39:35.851151],[2015-10-30 13:39:54.966259],[13880865],[13885644],Yes,1,"[1,2,3,4]",[1111],[1111]
62,RNS_Test_Dataset_ErinConrad,62,HUP096,RNS_Example_1_EC,1449366435524000,1449366525608000,2015-12-06 01:47:15.524,2015-12-06 01:48:45.608,721,15909503,15932023,[1449366477065366],[1449366499097740],[2015-12-06 01:47:57.065366],[2015-12-06 01:48:19.097740],[15919888],[15925396],Yes,1,"[1,2,3,4]",[1111],[1111]
63,RNS_Test_Dataset_ErinConrad,63,HUP096,RNS_Example_1_EC,1449934755536000,1449934845628000,2015-12-12 15:39:15.536,2015-12-12 15:40:45.628,743,16390505,16413027,[1449934797422690],[1449934815379524],[2015-12-12 15:39:57.422690],[2015-12-12 15:40:15.379524],[16400976],[16405465],Yes,1,"[1,2,3,4]",[1111],[1111]
64,RNS_Test_Dataset_ErinConrad,64,HUP096,RNS_Example_1_EC,1455641459024000,1455641549116000,2016-02-16 16:50:59.024,2016-02-16 16:52:29.116,916,20226279,20248801,[1455641500632209],[1455641524046388],[2016-02-16 16:51:40.632209],[2016-02-16 16:52:04.046388],[20236681],[20242534],Yes,1,"[1,2,3,4]",[1111],[1111]
65,RNS_Test_Dataset_ErinConrad,65,HUP096,RNS_Example_1_EC,1455731984528000,1455732074628000,2016-02-17 17:59:44.528,2016-02-17 18:01:14.628,921,20338919,20361443,[1455732026202876],[1455732043242286],[2016-02-17 18:00:26.202876],[2016-02-17 18:00:43.242286],[20349337],[20353597],Yes,1,"[1,2,3,4]",[1111],[1111]
66,RNS_Test_Dataset_ErinConrad,66,HUP096,RNS_Example_1_EC,1461799060532000,1461799150616000,2016-04-27 23:17:40.532,2016-04-27 23:19:10.616,1105,24450741,24473261,[1461799101866975],[1461799126694561],[2016-04-27 23:18:21.866975],[2016-04-27 23:18:46.694561],[24461074],[24467281],Yes,1,"[1,2,3,4]",[1111],[1111]
67,RNS_Test_Dataset_ErinConrad,67,HUP096,RNS_Example_1_EC,1467873560512000,1467873650608000,2016-07-07 06:39:20.512,2016-07-07 06:40:50.608,1290,28581909,28604432,[],[],[],[],[],[],No,0,[],[],[]
68,RNS_Test_Dataset_ErinConrad,68,HUP096,RNS_Example_1_EC,1469814899464000,1469814989632000,2016-07-29 17:54:59.464,2016-07-29 17:56:29.632,1350,29921013,29943554,[],[],[],[],[],[],No,0,[],[],[]


In [14]:
data_list = os.listdir(data_dir+'rns_pred_cache')
# print(data_list)
# data_list = ['HUP047.npy', 'HUP084.npy', 'HUP096.npy', 'HUP109.npy', 'HUP121.npy', 'HUP129.npy', 'HUP131.npy',
#              'HUP137.npy', 'HUP147.npy', 'HUP156.npy', 'HUP159.npy', 'HUP182.npy', 'HUP197.npy', 'HUP199.npy',
#              'RNS026.npy', 'RNS029.npy']
# data_list = os.listdir(data_dir+'rns_test_cache')[1:]
data_list = ['HUP096.npy', 'HUP137.npy','HUP101.npy']
# data_list = ['HUP182.npy',   'HUP129.npy',   'HUP109.npy', 'HUP156.npy', 'HUP096.npy', 'RNS026.npy',  'HUP159.npy']
# data_list = ['RNS026.npy', 'HUP159.npy', 'HUP129.npy', 'HUP096.npy', 'HUP182.npy']
train_data, train_label, test_data, test_label, train_index, test_index = get_data_by_episode(data_list, file_path = 'rns_pred_cache', split=1)
# train_data, train_label, test_data, test_label, train_index, test_index = get_data_by_episode(data_list, split=0.8)
# data, label,_,_ = get_data(data_list, split=1)
# train_data, test_data, train_label, test_label = sklearn.model_selection.train_test_split(data, label, test_size=0.8, random_state=42)

print(train_data.shape)
print(train_label.shape)
print(test_data.shape)
print(test_label.shape)

3it [00:00, 15.71it/s]

(39,)
(39,)
(0,)
(0,)





In [213]:
 with open(log_folder_root + 'rns_active_selected/' + 'EntropySampling' + '/' + 'selected_indices.pkl', 'rb') as f:
    # Load the content of the file into a Python object
    selected_inds = pickle.load(f)

In [32]:
selected_ind_list = []
for items in selected_inds.items():
    selected_ind_list.append(np.array(items[1]))

In [37]:
selected_ind_list

np.concatenate(selected_ind_list[:-1])

array([ 5164,  5165,  5166, ..., 82181, 82182, 82183], dtype=int64)

In [66]:
np.concatenate(train_index)[np.concatenate(selected_ind_list[:-1])][np.concatenate(train_index)[np.concatenate(selected_ind_list[:-1])]['patient_index'] == b'HUP101']

array([],
      dtype=[('patient_index', 'S10'), ('episode_index', '<i4'), ('slice_index', '<i4'), ('start_index', '<i4')])

In [77]:
np.concatenate(train_index)[np.concatenate(selected_ind_list[:-1])]

In [142]:
train_list = np.array([ti[0] for ti in train_index])

In [143]:
matching_array = [None]*38
i = 0
for index, row in annotations.annotation_dict['RNS_Test_Dataset_BrianLitt'].iterrows():

    pt = row['Patient_ID'].encode('utf-8')
    si = row['Episode_Start_Index']
    filtered_1 = train_list[train_list['patient_index'] == pt]
    filtered_2 = filtered_1[filtered_1['start_index'] == si]
    matching_array[i] = np.where(train_list == filtered_2)[0]
    i+= 1




ValueError: operands could not be broadcast together with shapes (1262,) (0,) 

In [277]:
annot = annotations.annotation_dict['RNS_Test_Dataset_ErinConrad']
annot_nonseizure = annot[annot['Class_Code'] == 0]
annot_seizure = annot[annot['Class_Code'] == 1]
patient_list = ['HUP047',
       'HUP059',
       'HUP084',
       'HUP096',
       'HUP101',
       'HUP108',
       'HUP109',
       'HUP121',
       'HUP127',
       'HUP128',
       'HUP129',
       'HUP131',
       'HUP136',
       'HUP137',
       'HUP143',
       'HUP147',
       'HUP153',
       'HUP156',
       'HUP159',
       'HUP182',
       'HUP192',
       'HUP197',
       'HUP199',
       'HUP205',
       'RNS021',
       'RNS022',
       'RNS026',
       'RNS029']

# patient_list = [ 'HUP137',
#        'HUP153',]

# patient_list = ['RNS026', 'HUP159', 'HUP129', 'HUP096', 'HUP182']
clip_dict = {}
# for p in patient_list:
for p in patient_list:
    # print(p)
    seizure_start_index = np.array([])
    seizure_end_index = np.array([])
    nonseizure_start_index = np.array([])
    nonseizure_end_index = np.array([])
    global_episode_index_seizure = np.array([])
    global_episode_index_nonseizure = np.array([])

    annotation_list = []

    start_index = annot[annot['Patient_ID'] == p]['Episode_Start_Index']
    end_index = annot[annot['Patient_ID'] == p]['Episode_End_Index']
    annot_start_list = annot[annot['Patient_ID'] == p]['Annotation_Start_Index']
    annot_end_list = annot[annot['Patient_ID'] == p]['Annotation_End_Index']
    j = 0
    for i in range(len(start_index)):
        if end_index.iloc[i] - start_index.iloc[i] > 0:
            initial_arr = np.zeros(end_index.iloc[i] - start_index.iloc[i])
            if len(annot_start_list.iloc[i]) > 0:
                sl_order = np.argsort(annot_start_list.iloc[i])
                sl = np.array(annot_start_list.iloc[i])[sl_order]
                el = np.array(annot_end_list.iloc[i])[sl_order]

                for si, ei in zip(sl, el):
                    initial_arr[si - start_index.iloc[i]:ei - start_index.iloc[i]] = 1
        else:
            # print(i)
            initial_arr = np.zeros(1)

        annotation_list.append(initial_arr)

    ind_arr = np.vstack(
        [start_index,
         end_index,
         start_index.index]).astype(int)

    # print(annotation_list)

    valid = np.where((ind_arr[1] - ind_arr[0]) > 500)
    combined_clip = ind_arr[:, valid].squeeze()
    annotation_list = np.array(annotation_list, dtype=object)[valid]
    try:
        combined_clip = np.vstack((combined_clip, annotation_list))
    except:
        print(annotation_list)
        print(p)

    if combined_clip.shape[1]>0:

        clip_dict[p] = combined_clip

In [278]:
from scipy.stats import mode

window_len = 1
stride = 1
concat_n = 4
for id in clip_dict.keys():
    data_import[id].set_window_parameter(window_length=window_len, window_displacement=stride)
    data_import[id].set_concatenation_parameter(concatenate_window_n=concat_n)
    window_indices, _ = data_import[id].get_windowed_data(clip_dict[id][0], clip_dict[id][1])
    import_indices = []
    import_label = []
    import_clip_indices = []
    import_start_indicies = []
    import_patient_ID = []
    for i, ind in enumerate(window_indices):
        indices = window_indices[i]+1-clip_dict[id][0][i]
        offsets = np.arange(249)
        full_indices = indices[:,0][:, np.newaxis] + offsets
        slices_no_loop = clip_dict[id][3][i][full_indices]
        mode_result = mode(slices_no_loop, axis=1)
        mode_values = mode_result.mode

        # print(mode_values)
        import_label.append(mode_values)
        import_indices.append(np.repeat(clip_dict[id][2][i], len(ind)))
        import_clip_indices.append(np.arange(len(ind)))
        import_start_indicies.append(np.repeat(clip_dict[id][0][i], len(ind)))
        import_patient_ID.append(np.repeat(id, len(ind)))

    import_label = np.hstack(import_label)
    import_indices = np.hstack(import_indices)
    import_clip_indices = np.hstack(import_clip_indices)
    import_start_indicies = np.hstack(import_start_indicies)
    import_patient_ID = np.hstack(import_patient_ID)

    data_import[id].normalize_windowed_data()
    _, concatenated_data = data_import[id].get_concatenated_data(data_import[id].windowed_data, arrange='channel_stack')

    assert np.hstack(import_label).shape[0] == concatenated_data.shape[0]

    np.save(data_dir+'rns_pred_cache/' + id + '.npy', {'data': concatenated_data, 'label': import_label, 'patientID': import_patient_ID, 'indices': np.vstack([import_indices,import_clip_indices,import_start_indicies]).T})

In [153]:
# selected_episode = np.array(matching_array[:25])

In [188]:
X_train = np.concatenate(train_data[selected_episode])
y_train = np.concatenate(train_label[selected_episode])
# X_test = np.concatenate(test_data)
# y_test = np.concatenate(test_label)
index_train = np.concatenate(train_index[selected_episode])
# index_test = np.concatenate(test_index)
seq_len_train = np.array([y.shape[0] for y in train_label[selected_episode]])
# seq_len_test = np.array([y.shape[0] for y in test_label])

IndexError: index 145 is out of bounds for axis 0 with size 38

In [15]:
X_train = np.concatenate(train_data)
y_train = np.concatenate(train_label)
# X_test = np.concatenate(test_data)
# y_test = np.concatenate(test_label)
index_train = np.concatenate(train_index)
# index_test = np.concatenate(test_index)
seq_len_train = np.array([y.shape[0] for y in train_label])

In [16]:
args_task = {'n_epoch': 40,
             'transform_train': True,
             'strategy_name': 'EntropySampling',
             'transform': False,
             'loader_tr_args': {'batch_size': 2, 'num_workers': 4, 'collate_fn': collate_fn,
                                'drop_last': True, 'persistent_workers': True},
             'loader_te_args': {'batch_size': 2, 'num_workers': 4, 'collate_fn': collate_fn,
                                'drop_last': True, 'persistent_workers': True}
             }

In [17]:
swav = SwaV().load_from_checkpoint(
    ckpt_folder_root + 'rns_swav_50_12/rns_swav-epoch=82-swav_loss=2.58204.ckpt')
model = SupervisedDownstream(swav.backbone)
# initialize model and save the model state
device = "cuda" if torch.cuda.is_available() else "cpu"

In [19]:
trainer = pl.Trainer( accelerator='gpu', devices=1,precision=16)
from models.rns_dataloader import RNS_Active_by_episode_LSTM, collate_fn
train_dataset = RNS_Active_by_episode_LSTM(train_data, train_label, transform=False, astensor=True)

train_dataloader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=4,
    collate_fn=collate_fn,
    shuffle=False,
    drop_last=True,
)

Using 16bit native Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [6]:
strategy_name = 'RandomSampling'
with open(log_folder_root + 'rns_active_selected/' + strategy_name + '/' + 'selected_indices.pkl', 'rb') as f:
    # Load the content of the file into a Python object
        selected_inds = pickle.load(f)

selected_inds[1]

array([55685, 55686, 55687, ..., 77472, 77473, 77474])

In [None]:
for rd in range(1, NUM_ROUND + 1):
    print('round ' + str(rd))
    log_file_name = log_folder_root + 'rns_active/active_logs_' + strategy_name + '/logger_round_' + str(
        rd - 1) + '/version_0/metrics.csv'
    logs = pd.read_csv(log_file_name)
    max_ind = logs['val_acc'].argmax()
    max_row = logs.iloc[max_ind]
    ckpt_directory = ckpt_folder_root + 'rns_active/active_checkpoints_' + strategy_name
    ckpt_files = os.listdir(ckpt_directory)
    load_file_name = strategy_name + '_round_' + str(rd - 1) + '-step=' + str(int(max_row['step']+1))
    print(load_file_name)

    ind = next((i for i, s in enumerate(ckpt_files) if load_file_name in s), None)
    print(ind, ckpt_files[ind])
    strategy.net.net.load_from_checkpoint(ckpt_directory + '/' + ckpt_files[ind], backbone=swav.backbone)

    q_idxs = strategy.query(NUM_QUERY * 90)

    with open(log_folder_root + 'rns_active_selected/' + strategy_name + '/' + 'selected_indices.pkl', 'rb') as f:
    # Load the content of the file into a Python object
        selected_inds = pickle.load(f)
    selected_inds[rd] = q_idxs
    with open(log_folder_root + 'rns_active_selected/' + strategy_name + '/' + 'selected_indices.pkl', 'wb') as f:
        pickle.dump(selected_inds, f)
# Now you can use the dictionary object as usual
    strategy.update(q_idxs)
    strategy.net.round = rd
    strategy.net.net.load_state_dict(modelstate)
    torch.cuda.empty_cache()
    strategy.train()
    torch.cuda.empty_cache()

In [190]:
import os
from pathlib import Path


strategy_name = 'MarginSamplingDropout'
ckpt_directory = ckpt_folder_root + 'rns_active/active_checkpoints_' + strategy_name
# ckpt_files = os.listdir(ckpt_directory)
ckpt_files = sorted(Path(ckpt_directory).iterdir(), key=os.path.getmtime)[50:]
high_f1 = 0
high_class_f1 = 0
file_name = []
file_name_class_f1 = []
for cf in tqdm(ckpt_files):
    model = model.load_from_checkpoint(ckpt_folder_root + 'rns_active/active_checkpoints_' + strategy_name + '/' + cf.name, backbone=swav.backbone)
    predictions = trainer.predict(model,train_dataloader)
    output_list = []
    target_list = []
    emb_list = []
    m = nn.Softmax(dim=1)
    seq_len_list = []
    for pred, y, emb, emb2, seq_len in predictions:
        output_list.append(pred)
        target_list.append(y)
        emb_list.append(emb)
        seq_len_list.append(seq_len)
    pred_raw = torch.vstack(output_list)
    target = torch.concat(target_list)
    emb = torch.concat(emb_list)
    out = torch.argmax(pred_raw, dim=1)
    seq_len_arr = torch.tensor([item for sublist in seq_len_list for item in sublist])
    pred_episode = combine_window_to_episode(torch.argmax(pred_raw, dim=1),seq_len_arr)
    class_f1 = sklearn.metrics.f1_score([np.sign(tl.sum()) for tl in pred_episode], [np.sign(tl.sum()) for tl in train_label])
    f1 = sklearn.metrics.f1_score(torch.argmax(pred_raw, dim=1), target)
    if f1>high_f1:
        high_f1 = f1
        file_name.append(cf)
        print('high_f1', f1,cf)
    if class_f1>=high_class_f1:
        high_class_f1 = class_f1
        file_name_class_f1.append((class_f1, cf))
        print('high_class_f1', class_f1,cf)


  0%|          | 0/176 [00:00<?, ?it/s]LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Predicting: 0it [00:00, ?it/s]

  1%|          | 1/176 [00:05<16:46,  5.75s/it]LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


high_f1 0.6977950713359273 ..\..\..\user_data\checkpoints\rns_active\active_checkpoints_MarginSamplingDropout\MarginSamplingDropout_round_2-step=1500-train_loss=0.01494.ckpt
high_class_f1 0.8 ..\..\..\user_data\checkpoints\rns_active\active_checkpoints_MarginSamplingDropout\MarginSamplingDropout_round_2-step=1500-train_loss=0.01494.ckpt


Predicting: 0it [00:00, ?it/s]

  1%|          | 2/176 [00:11<16:14,  5.60s/it]LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


high_f1 0.7001287001287002 ..\..\..\user_data\checkpoints\rns_active\active_checkpoints_MarginSamplingDropout\MarginSamplingDropout_round_2-step=1575-train_loss=0.01105.ckpt
high_class_f1 0.8 ..\..\..\user_data\checkpoints\rns_active\active_checkpoints_MarginSamplingDropout\MarginSamplingDropout_round_2-step=1575-train_loss=0.01105.ckpt


Predicting: 0it [00:00, ?it/s]

  2%|▏         | 3/176 [00:16<16:08,  5.60s/it]LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


high_f1 0.7272727272727274 ..\..\..\user_data\checkpoints\rns_active\active_checkpoints_MarginSamplingDropout\MarginSamplingDropout_round_2-step=1650-train_loss=0.01088.ckpt
high_class_f1 0.846153846153846 ..\..\..\user_data\checkpoints\rns_active\active_checkpoints_MarginSamplingDropout\MarginSamplingDropout_round_2-step=1650-train_loss=0.01088.ckpt


Predicting: 0it [00:00, ?it/s]

  2%|▏         | 4/176 [00:22<15:56,  5.56s/it]LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Predicting: 0it [00:00, ?it/s]

  3%|▎         | 5/176 [00:27<15:46,  5.53s/it]LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Predicting: 0it [00:00, ?it/s]

  3%|▎         | 6/176 [00:33<15:43,  5.55s/it]LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Predicting: 0it [00:00, ?it/s]

  4%|▍         | 7/176 [00:38<15:34,  5.53s/it]LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


high_class_f1 0.846153846153846 ..\..\..\user_data\checkpoints\rns_active\active_checkpoints_MarginSamplingDropout\MarginSamplingDropout_round_3-step=225-train_loss=0.03723.ckpt


Predicting: 0it [00:00, ?it/s]

  5%|▍         | 8/176 [00:44<15:31,  5.55s/it]LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


high_class_f1 0.846153846153846 ..\..\..\user_data\checkpoints\rns_active\active_checkpoints_MarginSamplingDropout\MarginSamplingDropout_round_3-step=300-train_loss=0.01363.ckpt


Predicting: 0it [00:00, ?it/s]

  5%|▌         | 9/176 [00:49<15:23,  5.53s/it]LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Predicting: 0it [00:00, ?it/s]

  6%|▌         | 10/176 [00:55<15:16,  5.52s/it]LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Predicting: 0it [00:00, ?it/s]

  6%|▋         | 11/176 [01:01<15:11,  5.53s/it]LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


high_f1 0.7644110275689223 ..\..\..\user_data\checkpoints\rns_active\active_checkpoints_MarginSamplingDropout\MarginSamplingDropout_round_3-step=525-train_loss=0.01328.ckpt
high_class_f1 0.888888888888889 ..\..\..\user_data\checkpoints\rns_active\active_checkpoints_MarginSamplingDropout\MarginSamplingDropout_round_3-step=525-train_loss=0.01328.ckpt


Predicting: 0it [00:00, ?it/s]

  7%|▋         | 12/176 [01:06<15:07,  5.54s/it]LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Predicting: 0it [00:00, ?it/s]

  7%|▋         | 13/176 [01:12<15:11,  5.59s/it]LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


high_f1 0.7735849056603773 ..\..\..\user_data\checkpoints\rns_active\active_checkpoints_MarginSamplingDropout\MarginSamplingDropout_round_3-step=675-train_loss=0.01668.ckpt
high_class_f1 0.888888888888889 ..\..\..\user_data\checkpoints\rns_active\active_checkpoints_MarginSamplingDropout\MarginSamplingDropout_round_3-step=675-train_loss=0.01668.ckpt


Predicting: 0it [00:00, ?it/s]

  8%|▊         | 14/176 [01:17<15:04,  5.58s/it]LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Predicting: 0it [00:00, ?it/s]

  9%|▊         | 15/176 [01:23<14:58,  5.58s/it]LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


high_class_f1 0.888888888888889 ..\..\..\user_data\checkpoints\rns_active\active_checkpoints_MarginSamplingDropout\MarginSamplingDropout_round_3-step=825-train_loss=0.04724.ckpt


Predicting: 0it [00:00, ?it/s]

  9%|▉         | 16/176 [01:29<14:55,  5.60s/it]LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


high_f1 0.7773766546329722 ..\..\..\user_data\checkpoints\rns_active\active_checkpoints_MarginSamplingDropout\MarginSamplingDropout_round_3-step=900-train_loss=0.01129.ckpt
high_class_f1 0.888888888888889 ..\..\..\user_data\checkpoints\rns_active\active_checkpoints_MarginSamplingDropout\MarginSamplingDropout_round_3-step=900-train_loss=0.01129.ckpt


Predicting: 0it [00:00, ?it/s]

 10%|▉         | 17/176 [01:34<14:46,  5.58s/it]LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Predicting: 0it [00:00, ?it/s]

 10%|█         | 18/176 [01:40<14:45,  5.60s/it]LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


high_f1 0.780652418447694 ..\..\..\user_data\checkpoints\rns_active\active_checkpoints_MarginSamplingDropout\MarginSamplingDropout_round_3-step=1050-train_loss=0.01149.ckpt
high_class_f1 0.888888888888889 ..\..\..\user_data\checkpoints\rns_active\active_checkpoints_MarginSamplingDropout\MarginSamplingDropout_round_3-step=1050-train_loss=0.01149.ckpt


Predicting: 0it [00:00, ?it/s]

 11%|█         | 19/176 [01:45<14:40,  5.61s/it]LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


high_class_f1 0.888888888888889 ..\..\..\user_data\checkpoints\rns_active\active_checkpoints_MarginSamplingDropout\MarginSamplingDropout_round_3-step=1125-train_loss=0.02841.ckpt


Predicting: 0it [00:00, ?it/s]

 11%|█▏        | 20/176 [01:51<14:32,  5.59s/it]LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


high_class_f1 0.888888888888889 ..\..\..\user_data\checkpoints\rns_active\active_checkpoints_MarginSamplingDropout\MarginSamplingDropout_round_3-step=1200-train_loss=0.01092.ckpt


Predicting: 0it [00:00, ?it/s]

 12%|█▏        | 21/176 [01:57<14:28,  5.61s/it]LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Predicting: 0it [00:00, ?it/s]

 12%|█▎        | 22/176 [02:02<14:19,  5.58s/it]LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


high_class_f1 0.888888888888889 ..\..\..\user_data\checkpoints\rns_active\active_checkpoints_MarginSamplingDropout\MarginSamplingDropout_round_3-step=1350-train_loss=0.01721.ckpt


Predicting: 0it [00:00, ?it/s]

 13%|█▎        | 23/176 [02:08<14:16,  5.60s/it]LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Predicting: 0it [00:00, ?it/s]

 14%|█▎        | 24/176 [02:13<14:10,  5.60s/it]LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


high_class_f1 0.888888888888889 ..\..\..\user_data\checkpoints\rns_active\active_checkpoints_MarginSamplingDropout\MarginSamplingDropout_round_3-step=1500-train_loss=0.01141.ckpt


Predicting: 0it [00:00, ?it/s]

 14%|█▍        | 25/176 [02:19<14:03,  5.58s/it]LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Predicting: 0it [00:00, ?it/s]

 15%|█▍        | 26/176 [02:24<13:57,  5.58s/it]LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Predicting: 0it [00:00, ?it/s]

 15%|█▌        | 27/176 [02:30<13:48,  5.56s/it]LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Predicting: 0it [00:00, ?it/s]

 16%|█▌        | 28/176 [02:36<13:47,  5.59s/it]LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Predicting: 0it [00:00, ?it/s]

 16%|█▋        | 29/176 [02:41<13:45,  5.61s/it]LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Predicting: 0it [00:00, ?it/s]

 17%|█▋        | 30/176 [02:47<14:00,  5.76s/it]LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Predicting: 0it [00:00, ?it/s]

 18%|█▊        | 31/176 [02:53<13:53,  5.75s/it]LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Predicting: 0it [00:00, ?it/s]

 18%|█▊        | 32/176 [02:59<13:39,  5.69s/it]LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Predicting: 0it [00:00, ?it/s]

 19%|█▉        | 33/176 [03:04<13:26,  5.64s/it]LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Predicting: 0it [00:00, ?it/s]

 19%|█▉        | 34/176 [03:10<13:14,  5.60s/it]LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


high_class_f1 0.888888888888889 ..\..\..\user_data\checkpoints\rns_active\active_checkpoints_MarginSamplingDropout\MarginSamplingDropout_round_4-step=600-train_loss=0.01605.ckpt


Predicting: 0it [00:00, ?it/s]

 20%|█▉        | 35/176 [03:15<13:13,  5.63s/it]LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Predicting: 0it [00:00, ?it/s]

 20%|██        | 36/176 [03:21<13:05,  5.61s/it]LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Predicting: 0it [00:00, ?it/s]

 21%|██        | 37/176 [03:26<12:43,  5.50s/it]LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


high_class_f1 0.888888888888889 ..\..\..\user_data\checkpoints\rns_active\active_checkpoints_MarginSamplingDropout\MarginSamplingDropout_round_4-step=825-train_loss=0.02223.ckpt


Predicting: 0it [00:00, ?it/s]

 22%|██▏       | 38/176 [03:33<13:14,  5.76s/it]LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Predicting: 0it [00:00, ?it/s]

 22%|██▏       | 39/176 [03:38<12:58,  5.68s/it]LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


high_class_f1 0.896551724137931 ..\..\..\user_data\checkpoints\rns_active\active_checkpoints_MarginSamplingDropout\MarginSamplingDropout_round_4-step=975-train_loss=0.01749.ckpt


Predicting: 0it [00:00, ?it/s]

 23%|██▎       | 40/176 [03:44<12:52,  5.68s/it]LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Predicting: 0it [00:00, ?it/s]

 23%|██▎       | 41/176 [03:49<12:42,  5.65s/it]LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Predicting: 0it [00:00, ?it/s]

 24%|██▍       | 42/176 [03:55<12:40,  5.68s/it]LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Predicting: 0it [00:00, ?it/s]

 24%|██▍       | 43/176 [04:04<14:26,  6.52s/it]LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Predicting: 0it [00:00, ?it/s]

 25%|██▌       | 44/176 [04:09<13:44,  6.25s/it]LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Predicting: 0it [00:00, ?it/s]

 26%|██▌       | 45/176 [04:16<14:02,  6.43s/it]LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Predicting: 0it [00:00, ?it/s]

 26%|██▌       | 46/176 [04:23<14:19,  6.62s/it]LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Predicting: 0it [00:00, ?it/s]

 27%|██▋       | 47/176 [04:29<13:43,  6.39s/it]LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Predicting: 0it [00:00, ?it/s]

 27%|██▋       | 48/176 [04:36<13:53,  6.52s/it]LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Predicting: 0it [00:00, ?it/s]

 28%|██▊       | 49/176 [04:42<13:20,  6.31s/it]LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Predicting: 0it [00:00, ?it/s]

 28%|██▊       | 50/176 [04:48<13:08,  6.26s/it]LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Predicting: 0it [00:00, ?it/s]

 29%|██▉       | 51/176 [04:54<13:06,  6.29s/it]LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Predicting: 0it [00:00, ?it/s]

 30%|██▉       | 52/176 [05:00<12:38,  6.12s/it]LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Predicting: 0it [00:00, ?it/s]

 30%|███       | 53/176 [05:07<13:02,  6.36s/it]LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


high_f1 0.8406466512702079 ..\..\..\user_data\checkpoints\rns_active\active_checkpoints_MarginSamplingDropout\MarginSamplingDropout_round_5-step=750-train_loss=0.05351.ckpt
high_class_f1 0.9655172413793104 ..\..\..\user_data\checkpoints\rns_active\active_checkpoints_MarginSamplingDropout\MarginSamplingDropout_round_5-step=750-train_loss=0.05351.ckpt


Predicting: 0it [00:00, ?it/s]

 31%|███       | 54/176 [05:13<12:37,  6.21s/it]LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Predicting: 0it [00:00, ?it/s]

 31%|███▏      | 55/176 [05:18<12:10,  6.04s/it]LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Predicting: 0it [00:00, ?it/s]

 32%|███▏      | 56/176 [05:24<11:53,  5.95s/it]LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Predicting: 0it [00:00, ?it/s]

 32%|███▏      | 57/176 [05:30<11:58,  6.04s/it]LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Predicting: 0it [00:00, ?it/s]

 33%|███▎      | 58/176 [05:37<12:04,  6.14s/it]LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Predicting: 0it [00:00, ?it/s]

 34%|███▎      | 59/176 [05:43<12:09,  6.24s/it]LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Predicting: 0it [00:00, ?it/s]

 34%|███▍      | 60/176 [05:49<11:45,  6.08s/it]LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Predicting: 0it [00:00, ?it/s]

 35%|███▍      | 61/176 [05:55<11:42,  6.11s/it]LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Predicting: 0it [00:00, ?it/s]

 35%|███▌      | 62/176 [06:01<11:40,  6.15s/it]LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Predicting: 0it [00:00, ?it/s]

 36%|███▌      | 63/176 [06:07<11:17,  6.00s/it]LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Predicting: 0it [00:00, ?it/s]

 36%|███▋      | 64/176 [06:12<10:59,  5.89s/it]LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Predicting: 0it [00:00, ?it/s]

 37%|███▋      | 65/176 [06:18<10:52,  5.88s/it]LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Predicting: 0it [00:00, ?it/s]

 38%|███▊      | 66/176 [06:24<10:49,  5.90s/it]LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Predicting: 0it [00:00, ?it/s]

 38%|███▊      | 67/176 [06:30<10:35,  5.83s/it]LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Predicting: 0it [00:00, ?it/s]

 39%|███▊      | 68/176 [06:38<11:34,  6.43s/it]LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Predicting: 0it [00:00, ?it/s]

 39%|███▉      | 69/176 [06:43<11:02,  6.19s/it]LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Predicting: 0it [00:00, ?it/s]

 40%|███▉      | 70/176 [06:49<10:50,  6.14s/it]LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Predicting: 0it [00:00, ?it/s]

 40%|████      | 71/176 [06:55<10:30,  6.00s/it]LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Predicting: 0it [00:00, ?it/s]

 41%|████      | 72/176 [07:01<10:10,  5.87s/it]LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Predicting: 0it [00:00, ?it/s]

 41%|████▏     | 73/176 [07:06<10:00,  5.83s/it]LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Predicting: 0it [00:00, ?it/s]

 42%|████▏     | 74/176 [07:12<09:47,  5.76s/it]LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Predicting: 0it [00:00, ?it/s]

 43%|████▎     | 75/176 [07:18<09:39,  5.74s/it]LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Predicting: 0it [00:00, ?it/s]

 43%|████▎     | 76/176 [07:24<09:41,  5.82s/it]LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Predicting: 0it [00:00, ?it/s]

 44%|████▍     | 77/176 [07:29<09:34,  5.80s/it]LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Predicting: 0it [00:00, ?it/s]

 44%|████▍     | 78/176 [07:36<09:44,  5.97s/it]LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Predicting: 0it [00:00, ?it/s]

 45%|████▍     | 79/176 [07:43<10:07,  6.26s/it]LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Predicting: 0it [00:00, ?it/s]

 45%|████▌     | 80/176 [07:48<09:45,  6.10s/it]LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Predicting: 0it [00:00, ?it/s]

 46%|████▌     | 81/176 [07:54<09:29,  5.99s/it]LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Predicting: 0it [00:00, ?it/s]

 47%|████▋     | 82/176 [08:00<09:13,  5.89s/it]LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Predicting: 0it [00:00, ?it/s]

 47%|████▋     | 83/176 [08:06<09:12,  5.94s/it]LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Predicting: 0it [00:00, ?it/s]

 48%|████▊     | 84/176 [08:13<09:26,  6.16s/it]LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Predicting: 0it [00:00, ?it/s]

 48%|████▊     | 85/176 [08:18<09:13,  6.08s/it]LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Predicting: 0it [00:00, ?it/s]

 49%|████▉     | 86/176 [08:24<08:55,  5.95s/it]LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Predicting: 0it [00:00, ?it/s]

 49%|████▉     | 87/176 [08:30<08:55,  6.02s/it]LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Predicting: 0it [00:00, ?it/s]

 50%|█████     | 88/176 [08:36<08:49,  6.01s/it]LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Predicting: 0it [00:00, ?it/s]

 51%|█████     | 89/176 [08:42<08:35,  5.92s/it]LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Predicting: 0it [00:00, ?it/s]

 51%|█████     | 90/176 [08:48<08:23,  5.86s/it]LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Predicting: 0it [00:00, ?it/s]

 52%|█████▏    | 91/176 [08:53<08:12,  5.80s/it]LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Predicting: 0it [00:00, ?it/s]

 52%|█████▏    | 92/176 [08:59<08:10,  5.84s/it]LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Predicting: 0it [00:00, ?it/s]

 53%|█████▎    | 93/176 [09:05<08:01,  5.80s/it]LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Predicting: 0it [00:00, ?it/s]

 53%|█████▎    | 94/176 [09:12<08:17,  6.07s/it]LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Predicting: 0it [00:00, ?it/s]

 54%|█████▍    | 95/176 [09:18<08:12,  6.08s/it]LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Predicting: 0it [00:00, ?it/s]

 55%|█████▍    | 96/176 [09:24<08:06,  6.08s/it]LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Predicting: 0it [00:00, ?it/s]

 55%|█████▌    | 97/176 [09:30<07:50,  5.95s/it]LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Predicting: 0it [00:00, ?it/s]

 56%|█████▌    | 98/176 [09:35<07:37,  5.86s/it]LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Predicting: 0it [00:00, ?it/s]

 56%|█████▋    | 99/176 [09:42<07:59,  6.23s/it]LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Predicting: 0it [00:00, ?it/s]

 57%|█████▋    | 100/176 [09:49<08:02,  6.35s/it]LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Predicting: 0it [00:00, ?it/s]

 57%|█████▋    | 101/176 [09:55<07:46,  6.22s/it]LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Predicting: 0it [00:00, ?it/s]

 58%|█████▊    | 102/176 [10:01<07:28,  6.06s/it]LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Predicting: 0it [00:00, ?it/s]

 59%|█████▊    | 103/176 [10:07<07:28,  6.14s/it]LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Predicting: 0it [00:00, ?it/s]

 59%|█████▉    | 104/176 [10:13<07:23,  6.17s/it]LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Predicting: 0it [00:00, ?it/s]

 60%|█████▉    | 105/176 [10:19<07:08,  6.04s/it]LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Predicting: 0it [00:00, ?it/s]

 60%|██████    | 106/176 [10:25<06:56,  5.94s/it]LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Predicting: 0it [00:00, ?it/s]

 61%|██████    | 107/176 [10:31<06:55,  6.02s/it]LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Predicting: 0it [00:00, ?it/s]

 61%|██████▏   | 108/176 [10:37<06:54,  6.10s/it]LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Predicting: 0it [00:00, ?it/s]

 62%|██████▏   | 109/176 [10:43<06:41,  5.99s/it]LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Predicting: 0it [00:00, ?it/s]

 62%|██████▎   | 110/176 [10:49<06:30,  5.91s/it]LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Predicting: 0it [00:00, ?it/s]

 63%|██████▎   | 111/176 [10:55<06:36,  6.10s/it]LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Predicting: 0it [00:00, ?it/s]

 64%|██████▎   | 112/176 [11:01<06:29,  6.09s/it]LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Predicting: 0it [00:00, ?it/s]

 64%|██████▍   | 113/176 [11:07<06:21,  6.06s/it]LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Predicting: 0it [00:00, ?it/s]

 65%|██████▍   | 114/176 [11:13<06:18,  6.11s/it]LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Predicting: 0it [00:00, ?it/s]

 65%|██████▌   | 115/176 [11:19<06:13,  6.12s/it]LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Predicting: 0it [00:00, ?it/s]

 66%|██████▌   | 116/176 [11:26<06:10,  6.17s/it]LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Predicting: 0it [00:00, ?it/s]

 66%|██████▋   | 117/176 [11:32<06:07,  6.23s/it]LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Predicting: 0it [00:00, ?it/s]

 67%|██████▋   | 118/176 [11:38<05:51,  6.07s/it]LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Predicting: 0it [00:00, ?it/s]

 68%|██████▊   | 119/176 [11:44<05:44,  6.05s/it]LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Predicting: 0it [00:00, ?it/s]

 68%|██████▊   | 120/176 [11:49<05:31,  5.92s/it]LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Predicting: 0it [00:00, ?it/s]

 69%|██████▉   | 121/176 [11:55<05:20,  5.82s/it]LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Predicting: 0it [00:00, ?it/s]

 69%|██████▉   | 122/176 [12:01<05:10,  5.75s/it]LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Predicting: 0it [00:00, ?it/s]

 70%|██████▉   | 123/176 [12:06<05:06,  5.78s/it]LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Predicting: 0it [00:00, ?it/s]

 70%|███████   | 124/176 [12:12<04:56,  5.71s/it]LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Predicting: 0it [00:00, ?it/s]

 71%|███████   | 125/176 [12:18<04:52,  5.73s/it]LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Predicting: 0it [00:00, ?it/s]

 72%|███████▏  | 126/176 [12:23<04:44,  5.69s/it]LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Predicting: 0it [00:00, ?it/s]

 72%|███████▏  | 127/176 [12:29<04:38,  5.68s/it]LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Predicting: 0it [00:00, ?it/s]

 73%|███████▎  | 128/176 [12:35<04:38,  5.79s/it]LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Predicting: 0it [00:00, ?it/s]

 73%|███████▎  | 129/176 [12:41<04:35,  5.86s/it]LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Predicting: 0it [00:00, ?it/s]

 74%|███████▍  | 130/176 [12:47<04:26,  5.78s/it]LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Predicting: 0it [00:00, ?it/s]

 74%|███████▍  | 131/176 [12:52<04:19,  5.76s/it]LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Predicting: 0it [00:00, ?it/s]

 75%|███████▌  | 132/176 [12:58<04:16,  5.82s/it]LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Predicting: 0it [00:00, ?it/s]

 76%|███████▌  | 133/176 [13:04<04:10,  5.82s/it]LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Predicting: 0it [00:00, ?it/s]

 76%|███████▌  | 134/176 [13:11<04:17,  6.14s/it]LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Predicting: 0it [00:00, ?it/s]

 77%|███████▋  | 135/176 [13:18<04:19,  6.33s/it]LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Predicting: 0it [00:00, ?it/s]

 77%|███████▋  | 136/176 [13:24<04:13,  6.34s/it]LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Predicting: 0it [00:00, ?it/s]

 78%|███████▊  | 137/176 [13:31<04:12,  6.47s/it]LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Predicting: 0it [00:00, ?it/s]

 78%|███████▊  | 138/176 [13:38<04:06,  6.50s/it]LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Predicting: 0it [00:00, ?it/s]

 79%|███████▉  | 139/176 [13:44<03:55,  6.36s/it]LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Predicting: 0it [00:00, ?it/s]

 80%|███████▉  | 140/176 [13:50<03:47,  6.31s/it]LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Predicting: 0it [00:00, ?it/s]

 80%|████████  | 141/176 [13:56<03:44,  6.42s/it]LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Predicting: 0it [00:00, ?it/s]

 81%|████████  | 142/176 [14:03<03:37,  6.39s/it]LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Predicting: 0it [00:00, ?it/s]

 81%|████████▏ | 143/176 [14:09<03:33,  6.48s/it]LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Predicting: 0it [00:00, ?it/s]

 82%|████████▏ | 144/176 [14:16<03:30,  6.57s/it]LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Predicting: 0it [00:00, ?it/s]

 82%|████████▏ | 145/176 [14:22<03:19,  6.44s/it]LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Predicting: 0it [00:00, ?it/s]

 83%|████████▎ | 146/176 [14:28<03:08,  6.30s/it]LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


high_class_f1 0.9655172413793104 ..\..\..\user_data\checkpoints\rns_active\active_checkpoints_MarginSamplingDropout\MarginSamplingDropout_round_9-step=1125-train_loss=0.01908.ckpt


Predicting: 0it [00:00, ?it/s]

 84%|████████▎ | 147/176 [14:34<02:57,  6.12s/it]LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Predicting: 0it [00:00, ?it/s]

 84%|████████▍ | 148/176 [14:40<02:51,  6.14s/it]LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Predicting: 0it [00:00, ?it/s]

 85%|████████▍ | 149/176 [14:46<02:46,  6.16s/it]LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Predicting: 0it [00:00, ?it/s]

 85%|████████▌ | 150/176 [14:52<02:37,  6.06s/it]LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Predicting: 0it [00:00, ?it/s]

 86%|████████▌ | 151/176 [14:59<02:32,  6.12s/it]LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


high_class_f1 0.9655172413793104 ..\..\..\user_data\checkpoints\rns_active\active_checkpoints_MarginSamplingDropout\MarginSamplingDropout_round_9-step=1500-train_loss=0.01125.ckpt


Predicting: 0it [00:00, ?it/s]

 86%|████████▋ | 152/176 [15:04<02:24,  6.01s/it]LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


high_class_f1 0.9655172413793104 ..\..\..\user_data\checkpoints\rns_active\active_checkpoints_MarginSamplingDropout\MarginSamplingDropout_round_9-step=1575-train_loss=0.02680.ckpt


Predicting: 0it [00:00, ?it/s]

 87%|████████▋ | 153/176 [15:10<02:15,  5.89s/it]LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


high_class_f1 0.9655172413793104 ..\..\..\user_data\checkpoints\rns_active\active_checkpoints_MarginSamplingDropout\MarginSamplingDropout_round_9-step=1650-train_loss=0.01904.ckpt


Predicting: 0it [00:00, ?it/s]

 88%|████████▊ | 154/176 [15:16<02:09,  5.87s/it]LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Predicting: 0it [00:00, ?it/s]

 88%|████████▊ | 155/176 [15:22<02:06,  6.02s/it]LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


high_class_f1 0.9655172413793104 ..\..\..\user_data\checkpoints\rns_active\active_checkpoints_MarginSamplingDropout\MarginSamplingDropout_round_9-step=1800-train_loss=0.01269.ckpt


Predicting: 0it [00:00, ?it/s]

 89%|████████▊ | 156/176 [15:28<01:59,  5.96s/it]LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Predicting: 0it [00:00, ?it/s]

 89%|████████▉ | 157/176 [15:34<01:55,  6.05s/it]LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Predicting: 0it [00:00, ?it/s]

 90%|████████▉ | 158/176 [15:40<01:46,  5.94s/it]LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Predicting: 0it [00:00, ?it/s]

 90%|█████████ | 159/176 [15:47<01:46,  6.28s/it]LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Predicting: 0it [00:00, ?it/s]

 91%|█████████ | 160/176 [15:53<01:37,  6.10s/it]LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Predicting: 0it [00:00, ?it/s]

 91%|█████████▏| 161/176 [15:58<01:29,  5.99s/it]LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Predicting: 0it [00:00, ?it/s]

 92%|█████████▏| 162/176 [16:04<01:23,  5.96s/it]LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Predicting: 0it [00:00, ?it/s]

 93%|█████████▎| 163/176 [16:10<01:16,  5.91s/it]LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Predicting: 0it [00:00, ?it/s]

 93%|█████████▎| 164/176 [16:16<01:09,  5.81s/it]LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Predicting: 0it [00:00, ?it/s]

 94%|█████████▍| 165/176 [16:21<01:03,  5.76s/it]LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Predicting: 0it [00:00, ?it/s]

 94%|█████████▍| 166/176 [16:27<00:58,  5.85s/it]LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Predicting: 0it [00:00, ?it/s]

 95%|█████████▍| 167/176 [16:33<00:52,  5.84s/it]LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Predicting: 0it [00:00, ?it/s]

 95%|█████████▌| 168/176 [16:39<00:46,  5.78s/it]LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Predicting: 0it [00:00, ?it/s]

 96%|█████████▌| 169/176 [16:44<00:40,  5.73s/it]LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Predicting: 0it [00:00, ?it/s]

 97%|█████████▋| 170/176 [16:50<00:34,  5.76s/it]LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Predicting: 0it [00:00, ?it/s]

 97%|█████████▋| 171/176 [16:57<00:29,  5.93s/it]LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Predicting: 0it [00:00, ?it/s]

 98%|█████████▊| 172/176 [17:03<00:24,  6.07s/it]LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Predicting: 0it [00:00, ?it/s]

 98%|█████████▊| 173/176 [17:09<00:17,  5.97s/it]LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Predicting: 0it [00:00, ?it/s]

 99%|█████████▉| 174/176 [17:14<00:11,  5.90s/it]LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Predicting: 0it [00:00, ?it/s]

 99%|█████████▉| 175/176 [17:20<00:05,  5.84s/it]LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Predicting: 0it [00:00, ?it/s]

100%|██████████| 176/176 [17:26<00:00,  5.95s/it]


In [191]:
 file_name_class_f1

[(0.8,
  WindowsPath('../../../user_data/checkpoints/rns_active/active_checkpoints_MarginSamplingDropout/MarginSamplingDropout_round_2-step=1500-train_loss=0.01494.ckpt')),
 (0.8,
  WindowsPath('../../../user_data/checkpoints/rns_active/active_checkpoints_MarginSamplingDropout/MarginSamplingDropout_round_2-step=1575-train_loss=0.01105.ckpt')),
 (0.846153846153846,
  WindowsPath('../../../user_data/checkpoints/rns_active/active_checkpoints_MarginSamplingDropout/MarginSamplingDropout_round_2-step=1650-train_loss=0.01088.ckpt')),
 (0.846153846153846,
  WindowsPath('../../../user_data/checkpoints/rns_active/active_checkpoints_MarginSamplingDropout/MarginSamplingDropout_round_3-step=225-train_loss=0.03723.ckpt')),
 (0.846153846153846,
  WindowsPath('../../../user_data/checkpoints/rns_active/active_checkpoints_MarginSamplingDropout/MarginSamplingDropout_round_3-step=300-train_loss=0.01363.ckpt')),
 (0.888888888888889,
  WindowsPath('../../../user_data/checkpoints/rns_active/active_checkpoint

In [192]:
file_name

[WindowsPath('../../../user_data/checkpoints/rns_active/active_checkpoints_MarginSamplingDropout/MarginSamplingDropout_round_2-step=1500-train_loss=0.01494.ckpt'),
 WindowsPath('../../../user_data/checkpoints/rns_active/active_checkpoints_MarginSamplingDropout/MarginSamplingDropout_round_2-step=1575-train_loss=0.01105.ckpt'),
 WindowsPath('../../../user_data/checkpoints/rns_active/active_checkpoints_MarginSamplingDropout/MarginSamplingDropout_round_2-step=1650-train_loss=0.01088.ckpt'),
 WindowsPath('../../../user_data/checkpoints/rns_active/active_checkpoints_MarginSamplingDropout/MarginSamplingDropout_round_3-step=525-train_loss=0.01328.ckpt'),
 WindowsPath('../../../user_data/checkpoints/rns_active/active_checkpoints_MarginSamplingDropout/MarginSamplingDropout_round_3-step=675-train_loss=0.01668.ckpt'),
 WindowsPath('../../../user_data/checkpoints/rns_active/active_checkpoints_MarginSamplingDropout/MarginSamplingDropout_round_3-step=900-train_loss=0.01129.ckpt'),
 WindowsPath('../..

In [208]:
model = model.load_from_checkpoint(ckpt_folder_root + 'rns_active/active_checkpoints_' + strategy_name + '/' + 'MarginSamplingDropout_round_5-step=750-train_loss=0.05351.ckpt', backbone=swav.backbone)

In [209]:
predictions = trainer.predict(model,train_dataloader)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Predicting: 0it [00:00, ?it/s]

In [210]:
import torch.nn as nn
output_list = []
target_list = []
emb_list = []
m = nn.Softmax(dim=1)
seq_len_list = []
for pred, y, emb, emb2, seq_len in predictions:
    output_list.append(pred)
    target_list.append(y)
    emb_list.append(emb)
    seq_len_list.append(seq_len)

In [211]:
pred_raw = torch.vstack(output_list)
target = torch.concat(target_list)
emb = torch.concat(emb_list)
out = torch.argmax(pred_raw, dim=1)
seq_len_arr = torch.tensor([item for sublist in seq_len_list for item in sublist])

In [212]:
import sklearn
clf_report = sklearn.metrics.classification_report(torch.argmax(pred_raw, dim=1), target, digits=6)

print(f"Classification Report : \n{clf_report}")

Classification Report : 
              precision    recall  f1-score   support

           0   0.995551  0.958800  0.976830      3034
           1   0.744376  0.965517  0.840647       377

    accuracy                       0.959543      3411
   macro avg   0.869964  0.962159  0.908738      3411
weighted avg   0.967790  0.959543  0.961778      3411


In [213]:
pred_episode = combine_window_to_episode(torch.argmax(pred_raw, dim=1),seq_len_arr)

In [214]:

import matplotlib.pyplot as plt
plt.figure(figsize=(10,6))
plt.plot(target)
plt.plot(torch.argmax(pred_raw, dim=1))
plt.show()

<IPython.core.display.Javascript object>

In [349]:
i = 31
plt.figure()
plt.plot(train_data[i][:,:,4].flatten()+2,color = 'k')
plt.plot(train_data[i][:,:,13].flatten()+1,color = 'k')
plt.plot(train_data[i][:,:,22].flatten(),color = 'k')
plt.plot(train_data[i][:,:,31].flatten()-1,color = 'k')
label_start, label_end, pred_start, pred_end = plot_high_light(train_label[i],pred_episode[i])
if len(pred_start)>0:
    for i in range(len(pred_start)):
        plt.axvspan(pred_start[i]*249, pred_end[i]*249, color="blue", alpha=0.3)
if len(label_start)>0:
    for i in range(len(label_start)):
        plt.axvspan(label_start[i]*249, label_end[i]*249, color="yellow", alpha=0.3)
plt.show()

<IPython.core.display.Javascript object>

In [108]:
pred_start

array([], dtype=int64)

In [337]:
def plot_high_light(train_label, pred_episode):
    label_start = np.where(np.diff(train_label) == 1)[0]
    label_end = np.where(np.diff(train_label) == -1)[0]
    pred_start = np.where(np.diff(pred_episode) == 1)[0]
    pred_end = np.where(np.diff(pred_episode) == -1)[0]
    label_start, label_end = check_consistent(label_start,label_end, len(train_label))
    pred_start, pred_end = check_consistent(pred_start,pred_end, len(pred_episode))
    # if len(label_start)>0:
    #     plt.axvspan(label_start[0]*249, label_end[0]*249, color="yellow", alpha=0.3)

    return label_start, label_end, pred_start, pred_end


def check_consistent(start, end, total_len):
    if len(start) != len(end):
        if len(start)>0:
            end = [total_len]
        elif len(end)>0:
            start = [0]
    return start, end


In [27]:
i = 31
plot_high_light(train_label[i],pred_episode[i])

NameError: name 'pred_episode' is not defined

In [28]:
train_data[i]

array([[[0.49853372, 0.49853372, 0.49853372, ..., 0.50928641,
         0.50439883, 0.44868035],
        [0.50439883, 0.50439883, 0.50439883, ..., 0.49853372,
         0.50439883, 0.46236559],
        [0.50146628, 0.50146628, 0.50146628, ..., 0.49364614,
         0.49364614, 0.46236559],
        ...,
        [0.50342131, 0.50342131, 0.50342131, ..., 0.49462366,
         0.47507331, 0.47214076],
        [0.49853372, 0.49853372, 0.49853372, ..., 0.51026393,
         0.4740958 , 0.48680352],
        [0.49853372, 0.49853372, 0.49853372, ..., 0.51026393,
         0.46138807, 0.50146628]],

       [[0.49853372, 0.49853372, 0.49853372, ..., 0.50439883,
         0.44868035, 0.50342131],
        [0.50439883, 0.50439883, 0.50439883, ..., 0.50439883,
         0.46236559, 0.50733138],
        [0.50146628, 0.50146628, 0.50146628, ..., 0.49364614,
         0.46236559, 0.49755621],
        ...,
        [0.50342131, 0.50342131, 0.50342131, ..., 0.47507331,
         0.47214076, 0.49853372],
        [0.4

In [115]:
torch.tensor([item for sublist in seq_len_list for item in sublist])

tensor([90, 90, 90, 90, 90, 90, 90, 90, 90, 90, 90, 90, 90, 90, 90, 90, 90, 90,
        90, 90, 90, 90, 90, 90, 90, 90, 90, 90, 90, 90, 90, 90, 70, 90, 90, 90,
        90, 35, 66])

In [116]:
def combine_window_to_episode(data, seq_len, index=None):
    cum_sum_index = np.cumsum(seq_len)
    cum_sum_index = np.insert(cum_sum_index, 0, 0)

    assert len(data) == cum_sum_index[-1]

    data_out = [None] * (len(cum_sum_index) - 1)

    for i in range(1, len(cum_sum_index)):
        start_index, end_index = cum_sum_index[i - 1], cum_sum_index[i]
        episode_data = data[start_index:end_index]

        if index is None:
            out = episode_data
        else:
            episode_labeled = index[start_index:end_index]
            out = episode_data[episode_labeled]

        if len(out) > 0:
            data_out[i - 1] = out

    data_out = [segment for segment in data_out if segment is not None]

    return np.array(data_out, dtype=object)

In [117]:
[np.sign(tl.sum()) for tl in train_label]

[0.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 1.0,
 1.0,
 1.0,
 1.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 1.0,
 0.0,
 1.0,
 1.0,
 0.0,
 0.0,
 1.0,
 0.0,
 0.0,
 0.0]

In [118]:
[np.sign(tl.sum()) for tl in pred_episode]

[tensor(0),
 tensor(1),
 tensor(1),
 tensor(1),
 tensor(1),
 tensor(1),
 tensor(1),
 tensor(1),
 tensor(0),
 tensor(0),
 tensor(0),
 tensor(0),
 tensor(0),
 tensor(0),
 tensor(0),
 tensor(1),
 tensor(1),
 tensor(1),
 tensor(1),
 tensor(0),
 tensor(0),
 tensor(0),
 tensor(0),
 tensor(0),
 tensor(0),
 tensor(0),
 tensor(0),
 tensor(0),
 tensor(0),
 tensor(0),
 tensor(0),
 tensor(1),
 tensor(1),
 tensor(0),
 tensor(0),
 tensor(1),
 tensor(0),
 tensor(0),
 tensor(0)]

In [172]:
clf_report = sklearn.metrics.classification_report([np.sign(tl.sum()) for tl in pred_episode], [np.sign(tl.sum()) for tl in train_label], digits=6)

print(f"Classification Report : \n{clf_report}")

Classification Report : 
              precision    recall  f1-score   support

           0   1.000000  0.923077  0.960000        26
           1   0.866667  1.000000  0.928571        13

    accuracy                       0.948718        39
   macro avg   0.933333  0.961538  0.944286        39
weighted avg   0.955556  0.948718  0.949524        39


In [121]:
sklearn.metrics.f1_score([np.sign(tl.sum()) for tl in pred_episode], [np.sign(tl.sum()) for tl in train_label])

0.9655172413793104

In [395]:
np.where(np.diff(train_label[i]) == -1)[0]

0

In [329]:
machine_annot = annotations.annotation_dict['RNS_Test_Dataset_ErinConrad'].copy()

In [330]:
def get_start_stop(pred_episode):
    pred_start = np.where(np.diff(pred_episode) == 1)[0]
    pred_end = np.where(np.diff(pred_episode) == -1)[0]
    pred_start, pred_end = check_consistent(pred_start,pred_end, len(pred_episode))
    pred_start *= 249
    pred_end *= 249
    return pred_start, pred_end


def check_consistent(start, end, total_len):
    if len(start) != len(end):
        if len(start)>0:
            end = np.array([total_len], dtype=np.int64)
        elif len(end)>0:
            start = np.array([0], dtype=np.int64)
    return start, end

In [331]:
from datetime import datetime
def interpolate_time(index, start_index, end_index, start_timestamp, end_timestamp):
    return int((index-start_index)/(end_index-start_index)*(end_timestamp-start_timestamp)+start_timestamp)

def timestamp_to_utctime(ts):
    """
    :param ts: int - datetime timestamp
    :return: string - utc time
    """
    return datetime.utcfromtimestamp(ts * 1e-6)


In [332]:
machine_annot = machine_annot.reset_index(drop=True)
machine_annot.Dataset = 'RNS_Test_Dataset_DeepLearning'
machine_annot.Alias_ID = 'RNS_Example_DL'
machine_annot = machine_annot.drop(['Type_Description', 'Annotation_Channel', 'Channel_Code', 'Binary_Channel_Code','Annotation_Catalog_Index'], axis=1)

In [333]:
machine_annot

Unnamed: 0,Dataset,Patient_ID,Alias_ID,Episode_Start_Timestamp,Episode_End_Timestamp,Episode_Start_UTC_Time,Episode_End_UTC_Time,Episode_Index,Episode_Start_Index,Episode_End_Index,Annotation_Start_Timestamp,Annotation_End_Timestamp,Annotation_Start_UTC_Time,Annotation_End_UTC_Time,Annotation_Start_Index,Annotation_End_Index,Class_Code
0,RNS_Test_Dataset_DeepLearning,HUP096,RNS_Example_DL,1427742903476000,1427742993628000,2015-03-30 19:15:03.476,2015-03-30 19:16:33.628,10,250781,273318,[],[],[],[],[],[],0
1,RNS_Test_Dataset_DeepLearning,HUP096,RNS_Example_DL,1442176442532000,1442176532616000,2015-09-13 20:34:02.532,2015-09-13 20:35:32.616,508,11181530,11204050,[1442176484602143],[1442176503681280],[2015-09-13 20:34:44.602143],[2015-09-13 20:35:03.681280],[11192047],[11196817],1
2,RNS_Test_Dataset_DeepLearning,HUP096,RNS_Example_DL,1446212334040000,1446212424128000,2015-10-30 13:38:54.040,2015-10-30 13:40:24.128,629,13870413,13892934,[1446212375851151],[1446212394966259],[2015-10-30 13:39:35.851151],[2015-10-30 13:39:54.966259],[13880865],[13885644],1
3,RNS_Test_Dataset_DeepLearning,HUP096,RNS_Example_DL,1449366435524000,1449366525608000,2015-12-06 01:47:15.524,2015-12-06 01:48:45.608,721,15909503,15932023,[1449366477065366],[1449366499097740],[2015-12-06 01:47:57.065366],[2015-12-06 01:48:19.097740],[15919888],[15925396],1
4,RNS_Test_Dataset_DeepLearning,HUP096,RNS_Example_DL,1449934755536000,1449934845628000,2015-12-12 15:39:15.536,2015-12-12 15:40:45.628,743,16390505,16413027,[1449934797422690],[1449934815379524],[2015-12-12 15:39:57.422690],[2015-12-12 15:40:15.379524],[16400976],[16405465],1
5,RNS_Test_Dataset_DeepLearning,HUP096,RNS_Example_DL,1455641459024000,1455641549116000,2016-02-16 16:50:59.024,2016-02-16 16:52:29.116,916,20226279,20248801,[1455641500632209],[1455641524046388],[2016-02-16 16:51:40.632209],[2016-02-16 16:52:04.046388],[20236681],[20242534],1
6,RNS_Test_Dataset_DeepLearning,HUP096,RNS_Example_DL,1455731984528000,1455732074628000,2016-02-17 17:59:44.528,2016-02-17 18:01:14.628,921,20338919,20361443,[1455732026202876],[1455732043242286],[2016-02-17 18:00:26.202876],[2016-02-17 18:00:43.242286],[20349337],[20353597],1
7,RNS_Test_Dataset_DeepLearning,HUP096,RNS_Example_DL,1461799060532000,1461799150616000,2016-04-27 23:17:40.532,2016-04-27 23:19:10.616,1105,24450741,24473261,[1461799101866975],[1461799126694561],[2016-04-27 23:18:21.866975],[2016-04-27 23:18:46.694561],[24461074],[24467281],1
8,RNS_Test_Dataset_DeepLearning,HUP096,RNS_Example_DL,1467873560512000,1467873650608000,2016-07-07 06:39:20.512,2016-07-07 06:40:50.608,1290,28581909,28604432,[],[],[],[],[],[],0
9,RNS_Test_Dataset_DeepLearning,HUP096,RNS_Example_DL,1469814899464000,1469814989632000,2016-07-29 17:54:59.464,2016-07-29 17:56:29.632,1350,29921013,29943554,[],[],[],[],[],[],0


In [334]:
for index, row in machine_annot.iterrows():
    start_ind, end_ind = get_start_stop(pred_episode[index])
    machine_annot.at[index, 'Annotation_Start_Index'] = (row.Episode_Start_Index + start_ind).tolist()
    machine_annot.at[index, 'Annotation_End_Index'] = (row.Episode_Start_Index + end_ind).tolist()
    machine_annot.at[index, 'Class_Code'] = np.sign(pred_episode[index].sum()).item()
    machine_annot.at[index, 'Annotation_Start_Timestamp'] = [
        interpolate_time(st_ind, row.Episode_Start_Index, row.Episode_End_Index, row.Episode_Start_Timestamp,
                         row.Episode_End_Timestamp) for st_ind in (row.Episode_Start_Index + start_ind).tolist()]
    machine_annot.at[index, 'Annotation_End_Timestamp'] = [
        interpolate_time(ed_ind, row.Episode_Start_Index, row.Episode_End_Index, row.Episode_Start_Timestamp,
                         row.Episode_End_Timestamp) for ed_ind in (row.Episode_Start_Index + end_ind).tolist()]
    machine_annot.at[index, 'Annotation_Start_UTC_Time'] = [
        timestamp_to_utctime(interpolate_time(st_ind, row.Episode_Start_Index, row.Episode_End_Index, row.Episode_Start_Timestamp,
                         row.Episode_End_Timestamp)) for st_ind in (row.Episode_Start_Index + start_ind).tolist()]
    machine_annot.at[index, 'Annotation_End_UTC_Time'] = [
        timestamp_to_utctime(interpolate_time(ed_ind, row.Episode_Start_Index, row.Episode_End_Index, row.Episode_Start_Timestamp,
                         row.Episode_End_Timestamp)) for ed_ind in (row.Episode_Start_Index + end_ind).tolist()]


In [285]:
1442176484602143-interpolate_time(11192047, 11181530, 11204050 ,1442176442532000, 1442176532616000)

275.0

In [283]:
annotations.annotation_dict['RNS_Test_Dataset_ErinConrad']

Unnamed: 0,Dataset,Annotation_Catalog_Index,Patient_ID,Alias_ID,Episode_Start_Timestamp,Episode_End_Timestamp,Episode_Start_UTC_Time,Episode_End_UTC_Time,Episode_Index,Episode_Start_Index,Episode_End_Index,Annotation_Start_Timestamp,Annotation_End_Timestamp,Annotation_Start_UTC_Time,Annotation_End_UTC_Time,Annotation_Start_Index,Annotation_End_Index,Type_Description,Class_Code,Annotation_Channel,Channel_Code,Binary_Channel_Code
59,RNS_Test_Dataset_ErinConrad,59,HUP096,RNS_Example_1_EC,1427742903476000,1427742993628000,2015-03-30 19:15:03.476,2015-03-30 19:16:33.628,10,250781,273318,[],[],[],[],[],[],No,0,[],[],[]
60,RNS_Test_Dataset_ErinConrad,60,HUP096,RNS_Example_1_EC,1442176442532000,1442176532616000,2015-09-13 20:34:02.532,2015-09-13 20:35:32.616,508,11181530,11204050,[1442176484602143],[1442176503681280],[2015-09-13 20:34:44.602143],[2015-09-13 20:35:03.681280],[11192047],[11196817],Yes,1,"[1,2,3,4]",[1111],[1111]
61,RNS_Test_Dataset_ErinConrad,61,HUP096,RNS_Example_1_EC,1446212334040000,1446212424128000,2015-10-30 13:38:54.040,2015-10-30 13:40:24.128,629,13870413,13892934,[1446212375851151],[1446212394966259],[2015-10-30 13:39:35.851151],[2015-10-30 13:39:54.966259],[13880865],[13885644],Yes,1,"[1,2,3,4]",[1111],[1111]
62,RNS_Test_Dataset_ErinConrad,62,HUP096,RNS_Example_1_EC,1449366435524000,1449366525608000,2015-12-06 01:47:15.524,2015-12-06 01:48:45.608,721,15909503,15932023,[1449366477065366],[1449366499097740],[2015-12-06 01:47:57.065366],[2015-12-06 01:48:19.097740],[15919888],[15925396],Yes,1,"[1,2,3,4]",[1111],[1111]
63,RNS_Test_Dataset_ErinConrad,63,HUP096,RNS_Example_1_EC,1449934755536000,1449934845628000,2015-12-12 15:39:15.536,2015-12-12 15:40:45.628,743,16390505,16413027,[1449934797422690],[1449934815379524],[2015-12-12 15:39:57.422690],[2015-12-12 15:40:15.379524],[16400976],[16405465],Yes,1,"[1,2,3,4]",[1111],[1111]
64,RNS_Test_Dataset_ErinConrad,64,HUP096,RNS_Example_1_EC,1455641459024000,1455641549116000,2016-02-16 16:50:59.024,2016-02-16 16:52:29.116,916,20226279,20248801,[1455641500632209],[1455641524046388],[2016-02-16 16:51:40.632209],[2016-02-16 16:52:04.046388],[20236681],[20242534],Yes,1,"[1,2,3,4]",[1111],[1111]
65,RNS_Test_Dataset_ErinConrad,65,HUP096,RNS_Example_1_EC,1455731984528000,1455732074628000,2016-02-17 17:59:44.528,2016-02-17 18:01:14.628,921,20338919,20361443,[1455732026202876],[1455732043242286],[2016-02-17 18:00:26.202876],[2016-02-17 18:00:43.242286],[20349337],[20353597],Yes,1,"[1,2,3,4]",[1111],[1111]
66,RNS_Test_Dataset_ErinConrad,66,HUP096,RNS_Example_1_EC,1461799060532000,1461799150616000,2016-04-27 23:17:40.532,2016-04-27 23:19:10.616,1105,24450741,24473261,[1461799101866975],[1461799126694561],[2016-04-27 23:18:21.866975],[2016-04-27 23:18:46.694561],[24461074],[24467281],Yes,1,"[1,2,3,4]",[1111],[1111]
67,RNS_Test_Dataset_ErinConrad,67,HUP096,RNS_Example_1_EC,1467873560512000,1467873650608000,2016-07-07 06:39:20.512,2016-07-07 06:40:50.608,1290,28581909,28604432,[],[],[],[],[],[],No,0,[],[],[]
68,RNS_Test_Dataset_ErinConrad,68,HUP096,RNS_Example_1_EC,1469814899464000,1469814989632000,2016-07-29 17:54:59.464,2016-07-29 17:56:29.632,1350,29921013,29943554,[],[],[],[],[],[],No,0,[],[],[]


In [None]:
ind-start_ind/end_ind - start_ind = ind_ts-st_ts/endst-st_ts

In [335]:
machine_annot.to_csv('machineprediction.csv', index=True)