In [None]:
import os
from glob import glob

import cv2

import matplotlib.pyplot as plt

def read_image(path):
    """Load an image file from disk and return it in RGB channel order.

    OpenCV decodes colour images as BGR; the channels are swapped to RGB so
    the result can be passed straight to ``plt.imshow``.

    Args:
        path: Path of the image file to load.

    Returns:
        The decoded image converted with ``cv2.COLOR_BGR2RGB``.

    Raises:
        FileNotFoundError: If OpenCV could not read the file (it is missing,
            unreadable, or not a decodable image).
    """
    image = cv2.imread(path, cv2.IMREAD_COLOR)
    # cv2.imread reports failure by returning None rather than raising; fail
    # here with the offending path instead of a cryptic error in cvtColor.
    if image is None:
        raise FileNotFoundError(f'Could not read image: {path!r}')
    return cv2.cvtColor(image, cv2.COLOR_BGR2RGB)


In [None]:
# Root of the competition data; the trailing slash is part of the value.
data_dir = '../input/state-farm-distracted-driver-detection/'
# Directory of the training images for class c0.
train_path = f'{data_dir}imgs/train/c0/'
filename = 'img_100026.jpg'

# Load one sample training image and display it.
image = read_image(f'{train_path}{filename}')
plt.imshow(image)

In [None]:
# Class directory names of the training set, in display order.
labels = ['c0','c1','c2','c3','c4','c5','c6','c7','c8','c9']
# Japanese description printed for each class.
col_to_jp = {
    'c0':'安全運転',              # safe driving
    'c1':'右手で携帯操作',        # operating phone with right hand
    'c2':'右手で電話',            # phone call with right hand
    'c3':'左手で携帯操作',        # operating phone with left hand
    'c4':'左手で電話',            # phone call with left hand
    'c5':'ラジオ操作',            # operating the radio
    'c6':'飲み物摂取',            # drinking
    'c7':'後部座席に手を伸ばす',  # reaching to the back seat
    'c8':'顔、髪に触れる',        # touching face / hair
    'c9':'助手席と対話'           # talking to the passenger
}

# Show up to nine sample images per class in a 3x3 grid.
for label in labels:
    # Sort so the preview is reproducible; glob order depends on the filesystem.
    files = sorted(glob(os.path.join(data_dir, 'imgs', 'train', label, '*.jpg')))

    # Start from an empty figure: plt.subplots() would add a full-size Axes
    # that the grid cells overlap and that newer matplotlib no longer removes.
    fig = plt.figure(figsize=(12, 10))
    for x, path in enumerate(files[:9]):
        ax = fig.add_subplot(3, 3, x + 1)
        ax.imshow(read_image(path))
        ax.axis('off')

    # The caption is printed rather than drawn on the figure.
    print(f'\t\t\t\t# {label} : {col_to_jp[label]}')
    plt.show()
    print('#'*100)

In [None]:
# Preview up to 18 test images in a 3x6 grid.
# Sort so the preview is reproducible; glob order depends on the filesystem.
files = sorted(glob(os.path.join(data_dir, 'imgs', 'test', '*.jpg')))

# Start from an empty figure: plt.subplots() would add a full-size Axes that
# the grid cells overlap and that newer matplotlib no longer removes.
fig = plt.figure(figsize=(24, 10))
# Slicing (instead of range(18)) avoids an IndexError with fewer than 18 files.
for x, path in enumerate(files[:18]):
    ax = fig.add_subplot(3, 6, x + 1)
    ax.imshow(read_image(path))
    ax.axis('off')
plt.show()

In [None]:
import pandas as pd
# Table read from driver_imgs_list.csv under the data directory.
driver_list = pd.read_csv(f'{data_dir}driver_imgs_list.csv')

In [None]:
# Peek at the first five rows of the driver/image table.
driver_list.head(n=5)

In [None]:
import numpy as np
# Number of distinct drivers (values of the 'subject' column).
driver_list['subject'].nunique()

In [None]:
# driver_to_img: driver id -> list of image file names, in CSV row order.
# img_to_driver: image file name -> driver id (first occurrence wins), used
# as a constant-time lookup instead of filtering the DataFrame per file.
driver_to_img = {}
img_to_driver = {}
for driver, image_path in zip(driver_list['subject'], driver_list['img']):
    driver_to_img.setdefault(driver, []).append(image_path)
    img_to_driver.setdefault(image_path, driver)

# List each class directory once instead of once per driver; sorted so the
# preview is reproducible (glob order depends on the filesystem).
files_by_label = {
    label: sorted(glob(os.path.join(data_dir, 'imgs', 'train', label, '*.jpg')))
    for label in labels
}

# For every driver and class, show up to nine of that driver's images.
for driver in np.unique(driver_list['subject']).tolist():
    for label in labels:
        # Files missing from the CSV map to None and are simply skipped.
        print_files = [
            fl for fl in files_by_label[label]
            if img_to_driver.get(os.path.basename(fl)) == driver
        ]

        # Start from an empty figure: plt.subplots() would add a full-size
        # Axes that the grid overlaps and newer matplotlib no longer removes.
        fig = plt.figure(figsize=(12, 10))
        for x, path in enumerate(print_files[:9]):
            ax = fig.add_subplot(3, 3, x + 1)
            ax.imshow(read_image(path))
            ax.axis('off')

        # The caption is printed rather than drawn on the figure.
        print(f'\t\t\t\t# ドライバー：{driver}|クラス：{label}（{col_to_jp[label]}）')
        plt.show()
        print('#'*100)


# 例外的なデータ

In [None]:
# Hand-picked unusual examples from the c0 training class.
label = "c0"
imgs = [21155, 31121]

# Same heading pattern as the other exception cells; col_to_jp['c0'] yields
# the text that was previously hard-coded here.
print(f"{col_to_jp[label]}の例外")
# Create the 1x2 grid in one call so no extra full-figure Axes is left behind
# the images; squeeze=False always returns a 2-D array of Axes.
fig, axes = plt.subplots(1, 2, figsize=(12, 10), squeeze=False)
for ax, img_id in zip(axes.flat, imgs):
    path = os.path.join(data_dir, 'imgs', 'train', label, f'img_{img_id}.jpg')
    ax.imshow(read_image(path))
    ax.axis("off")
plt.show()
    

In [None]:
# Hand-picked unusual examples from the c3 training class.
label = "c3"
imgs = [38563,45874,49269,62784]

print(f"{col_to_jp[label]}の例外")
# Create the 2x2 grid in one call so no extra full-figure Axes is left behind
# the images; squeeze=False always returns a 2-D array of Axes.
fig, axes = plt.subplots(2, 2, figsize=(12, 10), squeeze=False)
for ax, img_id in zip(axes.flat, imgs):
    path = os.path.join(data_dir, 'imgs', 'train', label, f'img_{img_id}.jpg')
    ax.imshow(read_image(path))
    ax.axis("off")
plt.show()
    

In [None]:
# Hand-picked unusual examples from the c4 training class.
label = "c4"
imgs = [92769,38427,41743,69998,77347,16077]

print(f"{col_to_jp[label]}の例外")
# Create the 2x3 grid in one call so no extra full-figure Axes is left behind
# the images; squeeze=False always returns a 2-D array of Axes.
fig, axes = plt.subplots(2, 3, figsize=(18, 10), squeeze=False)
for ax, img_id in zip(axes.flat, imgs):
    path = os.path.join(data_dir, 'imgs', 'train', label, f'img_{img_id}.jpg')
    ax.imshow(read_image(path))
    ax.axis("off")
plt.show()

In [None]:
# Hand-picked unusual examples from the c9 training class.
label = "c9"
imgs = [28068,37708,73663]

print(f"{col_to_jp[label]}の例外")
# Create the 1x3 grid in one call so no extra full-figure Axes is left behind
# the images; squeeze=False always returns a 2-D array of Axes.
fig, axes = plt.subplots(1, 3, figsize=(18, 10), squeeze=False)
for ax, img_id in zip(axes.flat, imgs):
    path = os.path.join(data_dir, 'imgs', 'train', label, f'img_{img_id}.jpg')
    ax.imshow(read_image(path))
    ax.axis("off")
plt.show()

# 間違ったラベル

In [None]:
# Images that look like class c0 but are stored under another class folder.
# Each entry is (folder the image is stored in, image number).
label = "c0"
imgs = [('c5',30288),('c7',46617),('c8',3835)]

print(f"本来は{label}：{col_to_jp[label]}のはずが、それ以外にラベルされたもの ")

# Create the 1x3 grid in one call so no extra full-figure Axes is left behind
# the images; squeeze=False always returns a 2-D array of Axes.
fig, axes = plt.subplots(1, 3, figsize=(18, 10), squeeze=False)
for ax, (given_label, img_id) in zip(axes.flat, imgs):
    path = os.path.join(data_dir, 'imgs', 'train', given_label, f'img_{img_id}.jpg')
    ax.imshow(read_image(path))
    ax.axis("off")
plt.show()

In [None]:
# Images that look like class c1 but are stored under another class folder.
# Each entry is (folder the image is stored in, image number).
label = "c1"
imgs = [('c0',29923),('c0',79819),('c2',32934)]

print(f"本来は{label}：{col_to_jp[label]}のはずが、それ以外にラベルされたもの ")

# Create the 1x3 grid in one call so no extra full-figure Axes is left behind
# the images; squeeze=False always returns a 2-D array of Axes.
fig, axes = plt.subplots(1, 3, figsize=(18, 10), squeeze=False)
for ax, (given_label, img_id) in zip(axes.flat, imgs):
    path = os.path.join(data_dir, 'imgs', 'train', given_label, f'img_{img_id}.jpg')
    ax.imshow(read_image(path))
    ax.axis("off")
plt.show()

In [None]:
# Images that look like class c8 but are stored under another class folder.
# Each entry is (folder the image is stored in, image number).
label = "c8"
imgs = [('c0',34380),('c3',423),('c5',78504)]

print(f"本来は{label}：{col_to_jp[label]}のはずが、それ以外にラベルされたもの ")

# Create the 1x3 grid in one call so no extra full-figure Axes is left behind
# the images; squeeze=False always returns a 2-D array of Axes.
fig, axes = plt.subplots(1, 3, figsize=(18, 10), squeeze=False)
for ax, (given_label, img_id) in zip(axes.flat, imgs):
    path = os.path.join(data_dir, 'imgs', 'train', given_label, f'img_{img_id}.jpg')
    ax.imshow(read_image(path))
    ax.axis("off")
plt.show()