In [18]:
import csv
import cv2
import os
import numpy as np
import matplotlib.pyplot as plt
import random
from random import shuffle
from sklearn.model_selection import train_test_split

In [3]:
import tensorflow as tf

In [15]:
DATA_PATH = "data/driving_log.csv"
DATA_IMGS = "data/"
BATCH_SIZE = 64

In [12]:
class ProcessData() :
    """Classe pour traiter la dataset"""
    def __init__(self, data_path, data_imgs):
        """Constructeur de la classe qui permet d'initialiser les différents variables"""
        self.data_path = data_path
        self.data_imgs = data_imgs
        
    def index_data(self):
        """Fonction qui scinde les index de la dataset : set de train et set de validation"""
        with open(self.data_path, "r") as file :
            content = [line for line in csv.reader(file)]
        random_index = np.array(range(len(content)))
        train_index, valid_index = train_test_split(random_index, test_size=0.15)
        return content, train_index, valid_index
    
    def get_data(self, log_content, list_index, batch_size) :
        """Fonction qui fournit la batch à au model d'entraitement"""
        images, rotations = [], []
        for index in list_index :
            # Correction des angles
            angle_correction = [0., 0.25, -.25]
            # Selection randomiser d'une image
            i = random.choice([0,1,2])
            img = cv2.imread(os.path.join(self.data_imgs, log_content[index][i]).replace(" ", ""))
            if img is None: continue
            # Conversion de l'image
            img = cv2.cvtColor(img,cv2.COLOR_BGR2RGB)
            # Récupération de la rotation associé à l'image
            rotation = float(log_content[index][3])
            # Correction de l'angle
            rotation = rotation + angle_correction[i]
            if random.random() > 0.5 :
                img = cv2.flip(img,1)
                rotation = rotation * -1
            # Ajout dans les tableaux images et rotations
            images.append(img)
            rotations.append(rotation)
            if len(images) >= batch_size :
                yield np.array(images), np.array(rotations)
                images, rotations = [], [] 

In [13]:
data_process = ProcessData(DATA_PATH,DATA_IMGS)
content, train_index, valid_index = data_process.index_data()

In [None]:
images, rotation = next(data_process.get_data(content, train_index, BATCH_SIZE))
plt.imshow(images[20])
print(rotation[20])

[1935 6948 1795 ... 4395 5653 4131]
