In [5]:
import os

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

In [4]:
class data():
    """Batch loader for the child-accident frame dataset.

    Each sample is a window of ``seq_length`` consecutive frames read from
    ``<path>/accident/NNNN/<frame>.jpg`` or ``<path>/no_accident/NNNN/<frame>.jpg``,
    where ``NNNN`` is the zero-padded 1-based directory index. A window is
    labelled ``[0, 1]`` when it comes from an accident video and the annotated
    accident frame is before ``window_start + seq_length + 30``; otherwise it
    is labelled ``[1, 0]``.

    All ``[label, dir_index, window_start]`` rows are shuffled once and split
    ``split`` / ``1 - split`` into ``self.train`` and ``self.valid``.
    """

    # Annotation file used when no ``annotation_file`` is passed to __init__.
    DEFAULT_ANNOTATION_FILE = '/media/user/Hard_Disk/Dataset/child_accident_2/annotation/accident_frame.txt'

    def __init__(self, path, batch_size, annotation_file=None):
        """Read the annotations and build the shuffled train/valid split.

        Args:
            path: dataset root containing ``accident/`` and ``no_accident/``.
            batch_size: number of samples per batch; must be >= 1.
            annotation_file: accident-frame annotation file; defaults to
                ``DEFAULT_ANNOTATION_FILE``.

        Raises:
            ValueError: if ``batch_size`` is smaller than 1.
        """
        if batch_size < 1:
            # A zero batch would make has_*_next() true forever.
            raise ValueError(f"batch_size must be >= 1, got {batch_size}")
        self.path = path
        self.annotation_file = (self.DEFAULT_ANNOTATION_FILE
                                if annotation_file is None else annotation_file)
        self.acc_num = 334       # accident directories 0001..0334
        self.no_acc_num = 392    # no_accident directories 0001..0392
        self.train_index = 0
        self.valid_index = 0
        self.seq_length = 50     # frames per sample window
        self.split = 0.9         # fraction of rows assigned to training
        self.batch_size = batch_size
        self.read_annotation()
        self.shuffle_data()

    def shuffle_data(self):
        """Enumerate every ``[label, dir_index, window_start]`` row, shuffle and split.

        Window starts run from 0 in steps of 30 while staying below
        ``300 - seq_length - 30``.
        """
        # NOTE(review): 300 looks like the frame count per video, and the
        # extra -30 keeps window_start + seq_length + 30 in range — TODO confirm.
        img_range = np.arange(0, 300 - self.seq_length - 30, 30)
        acc_list = np.arange(1, self.acc_num + 1)
        no_acc_list = np.arange(1, self.no_acc_num + 1)
        # meshgrid + .T + reshape(-1, 3) yields one [label, dir_index, window_start]
        # row per (directory, window start) combination.
        list1 = np.array(np.meshgrid(1, acc_list, img_range)).T.reshape(-1, 3)
        list2 = np.array(np.meshgrid(0, no_acc_list, img_range)).T.reshape(-1, 3)
        shuffle_list = np.concatenate([list1, list2], axis=0)
        np.random.shuffle(shuffle_list)
        # NOTE(review): windows are shuffled individually, so windows of the
        # same video can end up in both train and valid (possible leakage).
        cut = int(shuffle_list.shape[0] * self.split)
        self.train = shuffle_list[:cut]
        self.valid = shuffle_list[cut:]

    def read_annotation(self):
        """Load the annotated accident frame of each accident video.

        Each line is read as ``<name> <frame>``; lines with fewer than two
        whitespace-separated fields are skipped. The second fields are stored
        as an int32 array in ``self.annotation``; entry ``k`` is used for
        accident directory ``k + 1``.
        """
        annotation_data = []
        with open(self.annotation_file, "r") as fh:
            for line in fh:
                # split() tolerates repeated spaces, tabs and trailing newlines.
                fields = line.split()
                if len(fields) > 1:
                    annotation_data.append(fields[1])
        self.annotation = np.array(annotation_data).astype("int32")

    def _make_label(self, is_accident, dir_index, image_range):
        """Return the one-hot label ``[[0, 1]]`` (accident) or ``[[1, 0]]``."""
        # Positive when the annotated accident frame lies before the end of the
        # window plus 30 further frames. Short-circuiting on is_accident means
        # no_accident rows never index the annotation array.
        if is_accident and (image_range + self.seq_length + 30) > self.annotation[dir_index - 1]:
            return np.array([[0, 1]])
        return np.array([[1, 0]])

    def read_data(self, is_accident, dir_index, image_range):
        """Load one window of frames and its label.

        Args:
            is_accident: truthy for ``accident/``, falsy for ``no_accident/``.
            dir_index: 1-based video directory index (formatted ``%04d``).
            image_range: frame number of the first frame in the window.

        Returns:
            ``(frames, label)``: RGB frames resized to 150x150 and stacked
            along a leading axis of length ``seq_length``, and a ``(1, 2)``
            one-hot label.

        Raises:
            FileNotFoundError: if ``cv2.imread`` cannot read a frame.
        """
        # NOTE(review): cv2 is not among the imports visible here; presumably it
        # is imported in another notebook cell — confirm.
        sub_dir = "accident" if is_accident else "no_accident"
        clip_dir = os.path.join(self.path, sub_dir, "%04d" % dir_index)
        frames = []
        for frame_no in range(image_range, image_range + self.seq_length):
            img_path = os.path.join(clip_dir, str(frame_no) + ".jpg")
            img = cv2.imread(img_path)
            if img is None:
                # imread returns None instead of raising on a missing/unreadable file.
                raise FileNotFoundError(f"could not read frame image: {img_path}")
            img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
            frames.append(cv2.resize(img_rgb, (150, 150)))
        return np.array(frames), self._make_label(is_accident, dir_index, image_range)

    def has_train_next(self):
        """Return True if a full training batch remains from ``train_index``."""
        return self.train_index + self.batch_size <= self.train.shape[0]

    def has_valid_next(self):
        """Return True if a full validation batch remains from ``valid_index``."""
        return self.valid_index + self.batch_size <= self.valid.shape[0]

    def train_reset(self):
        """Reshuffle the training rows in place and rewind to the first batch."""
        np.random.shuffle(self.train)
        self.train_index = 0

    def valid_reset(self):
        """Reshuffle the validation rows in place and rewind to the first batch."""
        np.random.shuffle(self.valid)
        self.valid_index = 0

    @staticmethod
    def _stack_batch(batch_x, batch_y):
        """Stack per-sample frames and ``(1, 2)`` labels into batch arrays."""
        # Squeeze only the per-sample singleton axis so a batch of one keeps
        # shape (1, 2) rather than collapsing to (2,).
        return np.array(batch_x), np.squeeze(np.array(batch_y), axis=1)

    def _next_batch(self, rows, start):
        """Load the samples ``rows[start:start + batch_size]`` as one batch.

        Raises:
            IndexError: if fewer than ``batch_size`` rows remain after ``start``.
        """
        stop = start + self.batch_size
        if stop > rows.shape[0]:
            raise IndexError(
                f"batch [{start}:{stop}] exceeds the {rows.shape[0]} available samples")
        batch_x = []
        batch_y = []
        for is_accident, dir_index, image_range in rows[start:stop]:
            frames, label = self.read_data(is_accident, dir_index, image_range)
            batch_x.append(frames)
            batch_y.append(label)
        return self._stack_batch(batch_x, batch_y)

    def next_batch_train(self):
        """Return the next training batch ``(x, y)`` and advance ``train_index``."""
        batch = self._next_batch(self.train, self.train_index)
        self.train_index += self.batch_size
        return batch

    def get_shape(self):
        """Return ``(train.shape, valid.shape)`` of the split row arrays."""
        return self.train.shape, self.valid.shape

    def next_batch_valid(self):
        """Return the next validation batch ``(x, y)`` and advance ``valid_index``."""
        batch = self._next_batch(self.valid, self.valid_index)
        self.valid_index += self.batch_size
        return batch
                

In [None]:
class net(nn.Module)