In [None]:
import os, time, random, logging , datetime , cv2 , csv , subprocess
import numpy as np

import tensorflow as tf
print("tf",tf.version.VERSION)
from tensorflow import keras

from utils import globo ,  xdv , tfh5


''' GPU CONFIGURATION '''

# CUDA_VISIBLE_DEVICES and TF_FORCE_GPU_ALLOW_GROWTH are only honoured if they are
# set before TensorFlow initialises the CUDA runtime, so they must come before any
# call that enumerates or configures GPUs (set_memory_growth presumably does).
# NOTE(review): if something executed earlier (e.g. importing utils.tfh5) already
# touched the GPUs, these still have no effect — safest is to set them above
# `import tensorflow`.
os.environ["CUDA_VISIBLE_DEVICES"] = "1,2"
os.environ['TF_FORCE_GPU_ALLOW_GROWTH'] = 'true'

tfh5.set_tf_loglevel(logging.ERROR)
# Enabling device placement logging causes any Tensor allocations or operations to be printed.
tfh5.tf.debugging.set_log_device_placement(False)
tfh5.set_memory_growth()

In [None]:
''' TRAIN & VALDT '''
# Paths and labels of the training and validation videos.
train_fn, train_labels, valdt_fn, valdt_labels = xdv.train_valdt_files()

# Position ranges over each split (consumed by the legacy generate_input helper).
update_index_train, update_index_valdt = (range(len(split)) for split in (train_fn, valdt_fn))

In [None]:
''' CONFIGS '''

# Settings consumed by DataGen and passed to tfh5.form_model.
train_config = dict(
    # keep one frame out of every `frame_step` (24 fps -> 12)
    frame_step=2,

    # size every decoded frame is resized to
    in_height=120,
    in_width=160,

    # data pipeline
    batch_size=1,
    augment=True,   # training batches also get a horizontally flipped copy
    shuffle=False,

    # presumably activation / optimiser names interpreted by tfh5 — confirm there
    ativa='leakyrelu',
    optima='sgd',
    # 0: every batch has frame_max frames or the whole video
    # 1: last batch has frame_max frames
    # 2: last batch has no repeated frames
    batch_type=0,
    frame_max=8000,
    # used in train_model: "00000000" starts from scratch, anything else
    # resumes from that checkpoint with this config
    ckpt_start=f"{0:0>8}",

    epochs=1,
)


In [None]:
class PrintDataCallback(tf.keras.callbacks.Callback):
    """Debug callback that prints the static output shape of a named layer after each batch.

    Keras does not hand the batch data to callbacks, so the callback reports the
    symbolic shape of the layer's output. Dimensions that vary between batches
    (typically the batch size and, for variable-length clips, the frame count)
    show up as ``None``.

    Args:
        layer_name: name of the layer whose output shape is printed.
    """

    def __init__(self, layer_name='print_input_shape'):
        super().__init__()
        self.layer_name = layer_name

    def on_batch_end(self, batch, logs=None):
        # Read the shape symbolically; materialising it (e.g. np.zeros) fails on
        # the None dimensions, and running predict here would be expensive.
        layer_output = self.model.get_layer(self.layer_name).output
        print("Input shape:\n", layer_output.shape)

In [None]:
class DataGen(keras.utils.Sequence):
    """Keras ``Sequence`` whose every batch is built from exactly one video.

    For item ``idx`` the video ``vpath_list[idx]`` is decoded keeping one frame
    every ``frame_step`` frames. Each kept frame is resized to
    (``in_height``, ``in_width``) and divided by 255 as float32. Videos whose
    label is 0 and that have more than ``frame_max`` frames are cropped to a
    random window of ``frame_max`` source frames; all other videos are read in
    full.

    Batch contents:
      * ``valdt=True`` or ``augment=False``: X holds a single clip (batch dim 1).
      * otherwise: X stacks the clip and its horizontal flip (batch dim 2) and
        the label is repeated.

    ``batch_size`` and ``shuffle`` are read from ``config`` but not used:
    ``__len__`` is always the number of videos.
    """

    def __init__(self, vpath_list, label_list, config , valdt=False):
        """
        Args:
            vpath_list: video file paths.
            label_list: labels aligned with ``vpath_list``.
            config: dict providing batch_size, frame_max, in_height, in_width,
                augment, shuffle and frame_step.
            valdt: True for the validation generator (never augments).

        Raises:
            ValueError: if ``vpath_list`` and ``label_list`` differ in length.
        """
        super().__init__()

        self.valdt = valdt
        self.vpath_list = vpath_list
        self.label_list = label_list

        print(len(vpath_list),(len(label_list)))
        if len(vpath_list) != len(label_list):
            raise ValueError(
                f"got {len(vpath_list)} video paths but {len(label_list)} labels")

        self.batch_size = config["batch_size"]
        self.frame_max = config["frame_max"]

        self.in_height = config["in_height"]
        self.in_width = config["in_width"]

        self.augment = config["augment"]
        self.shuffle = config["shuffle"]

        self.len_vpath_list = len(self.vpath_list)

        self.frame_step = config["frame_step"]

    def skip_ms(self, cap):
        """Advance ``cap`` by ``frame_step`` frames, decoding only the last one.

        ``grab`` moves the capture forward; only the final frame is ``retrieve``d.

        Returns:
            (success, image, position): ``position`` is CAP_PROP_POS_FRAMES after
            the last grab (the index of the next frame to be read). When grabbing
            fails, ``image`` is None and ``position`` is the intended target.
        """
        start_frame = cap.get(cv2.CAP_PROP_POS_FRAMES)
        while True:
            success = cap.grab()
            curr_frame = cap.get(cv2.CAP_PROP_POS_FRAMES)
            if not success or curr_frame - start_frame >= self.frame_step:
                break

        if not success:
            return success, None, start_frame + self.frame_step

        success, image = cap.retrieve()
        return success, image, curr_frame

    def __len__(self):
        # One batch per video, independent of batch_size.
        return self.len_vpath_list

    def _window(self, label, tframes):
        """Return the (start_index, end_index) source-frame range to decode."""
        # label-0 videos longer than frame_max: random frame_max-long window
        if label == 0 and tframes > self.frame_max:
            start_index = random.randint(0, tframes - self.frame_max)
            return start_index, start_index + self.frame_max
        # otherwise ingest the full video
        return 0, tframes

    def _read_frames(self, video, start_index, end_index):
        """Decode every ``frame_step``-th frame up to ``end_index``, resized and scaled to [0, 1]."""
        frames = []
        curr_frame = 0
        success, frame = video.read()
        for _ in range(end_index - start_index):
            if not success or curr_frame > end_index:
                break
            frame = cv2.resize(frame, (self.in_width, self.in_height))
            # float32 straight away: avoids a float64 copy of the whole clip
            frames.append(frame.astype(np.float32) / 255.0)
            # jump to the next kept frame without decoding the skipped ones
            success, frame, curr_frame = self.skip_ms(video)
        return frames

    def __getitem__(self, idx):
        """Return ``(X, y)`` for video ``idx``; see the class docstring for shapes.

        Raises:
            ValueError: if no frame at all could be decoded from the video.
        """
        print("\n\nbatch_indx",idx)

        vpath = self.vpath_list[idx]
        label = self.label_list[idx]

        video = cv2.VideoCapture(vpath)
        try:
            tframes = int(video.get(cv2.CAP_PROP_FRAME_COUNT))
            print("\n*************",vpath , label , tframes)

            start_index, end_index = self._window(label, tframes)
            if start_index > 0:
                video.set(cv2.CAP_PROP_POS_FRAMES, start_index)
            print("sstart_index,end_index",start_index , end_index)

            frames = self._read_frames(video, start_index, end_index)
        finally:
            video.release()

        if not frames:
            raise ValueError(
                f"could not decode any frame from {vpath!r} (reported frame count {tframes})")

        frames_arr = np.stack(frames)
        frames_arr_flip = np.flip(frames_arr, axis=2)  # axis 2 = width -> horizontal flip
        print("frames",frames_arr.shape,frames_arr_flip.shape)

        XN = frames_arr[np.newaxis]
        y = np.array([label], dtype=np.float32)

        if self.valdt or not self.augment:
            print("valdt")
            print("XN ",XN.dtype,XN.shape )
            print("y",y.shape)
            return XN , y

        print("augment , train")
        XF = frames_arr_flip[np.newaxis]
        X = np.concatenate([XN, XF], axis=0)
        Y = np.concatenate([y, y], axis=0)
        print("XN ",XN.dtype,XN.shape )
        print("XF ",XF.dtype,XF.shape )
        print("X ",X.dtype,X.shape )
        print("Y ",Y.dtype,Y.shape)
        return X , Y
    


If only 8 videos are fed to the training phase, with batch_size set to 1 and augment enabled, the generator yields 8 batches per epoch, because `__len__` returns the number of videos. Each of those batches contains two clips stacked along the batch axis: the original video and its horizontally flipped copy. Every video is therefore still seen twice per epoch — once in its original orientation and once flipped — but both copies arrive together in the same batch (batch dimension 2), not as two separate batches.

After all of the training batches have been processed, the fit method moves on to the validation data, which is produced separately by a different generator (valdt_generator).

If augment is set to False, then each video will only yield one batch, regardless of the batch_size. So in this case, with a batch_size of 1, the generator would yield 8 batches for training before moving onto the validation data.

In [None]:
# Optional subset sizes for quick experiments; None uses every video.
N_TRAIN = None
N_VALDT = None

train_generator = DataGen(train_fn[:N_TRAIN], train_labels[:N_TRAIN], train_config)

valdt_generator = DataGen(valdt_fn[:N_VALDT], valdt_labels[:N_VALDT], train_config, valdt=True)

## __len__ == number of videos: each batch is built from exactly one video
## (batch_size is not used). With augment=True a training batch holds 2 clips
## (original + horizontal flip); validation batches always hold 1 clip.

model, model_name = tfh5.form_model(train_config)

history = model.fit(
    train_generator,
    epochs=train_config["epochs"],
    # step counts come from the generators so they stay in sync with any subset
    steps_per_epoch=len(train_generator),

    verbose=2,

    validation_data=valdt_generator,
    validation_steps=len(valdt_generator),

    # NOTE(review): each item can hold thousands of float32 frames; 32 worker
    # processes prefetching such items may need a lot of host RAM.
    use_multiprocessing=True,
    workers=32,
    # callbacks=[PrintDataCallback()],
)

# Save the per-epoch metrics to a CSV file (one column per metric).
hist_csv_file = globo.HIST_PATH + model_name + '_history.csv'
with open(hist_csv_file, 'w', newline='') as file:
    writer = csv.writer(file)
    writer.writerow(history.history.keys())
    writer.writerows(zip(*history.history.values()))
    

NO CLASS — function-based alternative to the `DataGen` class

In [None]:
def video_generator(video_paths, labels, fps, frame_max, in_height, in_width, frame_step=2):
    """Load one randomly chosen video as a single-clip batch.

    Despite its name this is a plain function, not a Python generator: it
    shuffles ``(video_paths, labels)`` together and decodes only the first video
    of the shuffled order.

    Args:
        video_paths: sequence of video file paths.
        labels: labels aligned with ``video_paths``.
        fps: currently unused; kept for backward compatibility.
        frame_max: label-0 videos with more than ``frame_max`` frames are cropped
            to a random window of ``frame_max`` frames.
        in_height, in_width: size each kept frame is resized to.
        frame_step: keep one frame every ``frame_step`` frames (default 2,
            the value that used to be hard-coded).

    Returns:
        (X, y): X is the stacked frames (values divided by 255) with a leading
        batch axis of size 1; y is a 1-element array holding the label.

    Raises:
        ValueError: if ``video_paths`` is empty or its length differs from ``labels``.
    """
    if len(video_paths) == 0:
        raise ValueError("video_paths is empty")
    if len(video_paths) != len(labels):
        raise ValueError(f"got {len(video_paths)} video paths but {len(labels)} labels")

    def skip_ms(cap, step):
        """Advance ``cap`` by ``step`` frames, decoding only the last one."""
        start_frame = cap.get(cv2.CAP_PROP_POS_FRAMES)
        while True:
            success = cap.grab()
            curr_frame = cap.get(cv2.CAP_PROP_POS_FRAMES)
            if not success or curr_frame - start_frame >= step:
                break

        if not success:
            return success, None, start_frame + step

        success, image = cap.retrieve()
        return success, image, curr_frame

    # Shuffle paths and labels together, then take the first pair.
    zipped = list(zip(video_paths, labels))
    random.shuffle(zipped)
    video_path, label = zipped[0]

    video = cv2.VideoCapture(video_path)
    try:
        tframes = int(video.get(cv2.CAP_PROP_FRAME_COUNT))
        print(video_path , label , tframes)

        if label == 0 and tframes > frame_max:
            start_index = random.randint(0, tframes - frame_max)
            end_index = start_index + frame_max
            video.set(cv2.CAP_PROP_POS_FRAMES, start_index)
        else:
            start_index = 0
            end_index = tframes
        print(start_index , end_index)

        frames = []
        curr_frame = 0
        success, frame = video.read()
        while success and curr_frame <= end_index:
            frame = cv2.resize(frame, (in_width, in_height))
            frames.append(np.array(frame) / 255.0)
            success, frame, curr_frame = skip_ms(video, frame_step)
    finally:
        video.release()

    X = np.array(frames)
    return np.expand_dims(X, 0), np.array([label])
        

In [None]:
#x , y = video_generator(train_fn,train_labels,12,8000,120,160)

#print( np.shape(x) , np.shape(y))

ORIGINAL — legacy per-video loading functions (`input_train_video_data` / `generate_input`)

In [None]:
""" INPUT DATA"""

# Frame size used by input_train_video_data. Read from train_config so it cannot
# drift from the size DataGen resizes to.
in_height = train_config["in_height"]
in_width = train_config["in_width"]

def input_train_video_data(file_name):
    """Decode up to ``frame_max`` frames of a video together with their horizontal flips.

    If the path contains 'Normal' and the video has more than ``frame_max``
    frames, it is split into ceil(total_frame / frame_max) slices. One slice is
    picked at random and decoding starts there; the last slice is aligned so it
    ends at the final frame. All other videos are read from the first frame.

    Reads the notebook-level ``train_config["frame_max"]``, ``in_height`` and
    ``in_width``.

    Args:
        file_name: path of the video file.

    Returns:
        (frames, frames_flip, total_frame): ``frames`` and ``frames_flip`` carry a
        leading batch axis of size 1 and hold at most ``frame_max`` frames with
        values divided by 255. ``total_frame`` is the frame count reported by
        OpenCV's CAP_PROP_FRAME_COUNT.
    """
    print("\n\ninput_train_video_data\n")

    frame_max = int(train_config["frame_max"])

    video = cv2.VideoCapture(file_name)
    try:
        total_frame = video.get(cv2.CAP_PROP_FRAME_COUNT)

        # number of frame_max-long slices the video is divided into (ceil division)
        divid_no = 1
        if total_frame > frame_max:
            divid_no = (int(total_frame) + frame_max - 1) // frame_max

        # 'Normal' videos: start at a random slice of frame_max frames
        if 'Normal' in file_name:
            print("\n\nNORMAL\n\n")
            if divid_no != 1:
                slice_no = int(random.random() * divid_no)
                if slice_no != divid_no - 1:
                    n_skip = frame_max * slice_no
                else:
                    # last slice: keep exactly the final frame_max frames
                    n_skip = total_frame - frame_max
                passby = 0
                while video.isOpened() and passby < n_skip:
                    passby += 1
                    # grab() advances without producing the skipped image
                    if not video.grab():
                        break

        batch_frames = []
        batch_frames_flip = []
        while video.isOpened() and len(batch_frames) < frame_max:
            success, image = video.read()
            if not success:
                break

            image = cv2.resize(image, (in_width, in_height))
            image_flip = cv2.flip(image, 1)  # flipCode 1 = around the vertical axis

            batch_frames.append(np.array(image) / 255.0)
            batch_frames_flip.append(np.array(image_flip) / 255.0)
    finally:
        video.release()

    batch_frames = np.array(batch_frames)
    batch_frames_flip = np.array(batch_frames_flip)
    print(batch_frames.shape)

    return np.expand_dims(batch_frames, 0), np.expand_dims(batch_frames_flip, 0), total_frame



def generate_input(data, update_index, validation):
    """Load one video clip and its binary label.

    Only ``data[update_index[0]]`` is loaded; the function returns on its first
    call. The label is 0 when the path contains 'label_A' (normal) and 1
    otherwise. The horizontally flipped clip produced by
    ``input_train_video_data`` is not returned.

    Args:
        data: sequence of video file paths.
        update_index: sequence of indices into ``data``; only the first is used.
        validation: currently has no effect — training and validation both
            return the un-flipped clip.

    Returns:
        (batch_frames, label): the clip from ``input_train_video_data`` and a
        1-element array holding 0 or 1.
    """
    index = update_index[0]

    batch_frames, _batch_frames_flip, total_frames = input_train_video_data(data[index])
    print("\n\tdata[",index,"]=",data[index],"\n\ttotal_frames=",total_frames,"\n\tbatch_frames.shape=",batch_frames.shape,"\n")

    label = 0 if 'label_A' in data[index] else 1
    return batch_frames, np.array([label])
    
    
#generate_input(train_fn[:4] , update_index_train[:4] , False) 