## Import Libraries

In [1]:
from __future__ import print_function

import glob
from itertools import chain
import os
import random
import zipfile
import pickle

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from linformer import Linformer
from PIL import Image
from sklearn.model_selection import train_test_split
from torch.optim.lr_scheduler import StepLR
from torch.utils.data import DataLoader, Dataset
from torchvision import datasets, transforms
from tqdm.notebook import tqdm

from vit_pytorch.vit import ViT

In [2]:
# Record the installed PyTorch version alongside the results
print(f"Torch: {torch.__version__}")

Torch: 1.11.0


In [3]:
# Training settings
seed = 42                   # passed to seed_everything and to every train_test_split
batch_size, epochs = 16, 1  # DataLoader batch size / passes over the training set
lr, gamma = 3e-5, 0.7       # Adam learning rate / StepLR decay factor

In [4]:
def seed_everything(seed):
    """Seed every random number generator this notebook relies on.

    Seeds Python's `random`, NumPy and PyTorch (CPU and all CUDA devices),
    exports PYTHONHASHSEED, and configures cuDNN for deterministic behaviour.

    Args:
        seed: Integer seed applied to every generator.
    """
    random.seed(seed)
    # Read at interpreter start-up, so this only affects processes spawned later
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # The cuDNN autotuner may pick different kernels from run to run; pin it off
    # so the deterministic flag above is not undermined.
    torch.backends.cudnn.benchmark = False

# Seed all generators before any data is shuffled or split
seed_everything(seed)

In [5]:
# Prefer the GPU, but fall back to the CPU so the notebook still runs on
# machines without CUDA instead of failing at the first .to(device).
device = 'cuda' if torch.cuda.is_available() else 'cpu'

## Load Data

In [6]:
# Root folders of the preprocessed data, one per class
PATH_none_crash_prep = '../../dataset/data_preprocessed/None-crash'
PATH_vulner_prep = '../../dataset/data_preprocessed/Vulner'

# Every entry directly below a class root counts as one drive
none_crash_drive_list, vulner_drive_list = (
    glob.glob(class_root + "/*")
    for class_root in (PATH_none_crash_prep, PATH_vulner_prep)
)

In [7]:
# Collect every sample file of each class across all of its drives.
# glob returns entries in an arbitrary, filesystem-dependent order, so the
# combined lists are sorted: without a fixed input order the seeded splits
# and shuffles further down would not be reproducible.
none_crash_data_list = sorted(
    sample_path
    for PATH_drive in none_crash_drive_list
    for sample_path in glob.glob(PATH_drive + "/*")
)

vulner_data_list = sorted(
    sample_path
    for PATH_drive in vulner_drive_list
    for sample_path in glob.glob(PATH_drive + "/*")
)

In [9]:
# Drive-level counts: number of entries found under each class root
print("Number of Drives")
print("None-crash:", len(none_crash_drive_list))
print("Vulner:", len(vulner_drive_list))

# Sample-level counts: number of files collected across all drives of each class
print(f"None-crash: {len(none_crash_data_list)}")
print(f"Vulner: {len(vulner_data_list)}")

Number of Drives
None-crash: 51
Vulner: 91
None-crash: 30498
Vulner: 3788


### Split data into Train & Test

In [10]:
# One class-name string per sample, used as the stratify labels of the per-class splits
none_crash_labels = ['None-crash'] * len(none_crash_data_list)
vulner_labels = ['Vulner'] * len(vulner_data_list)

In [11]:
# Hold out 10% of each class as test data. The classes are split separately,
# so both keep the same train/test ratio.
# NOTE(review): samples are split per file, so files of one drive can land on
# both sides of the split — confirm this is intended.
none_crash_train_list, none_crash_test_list = train_test_split(
    none_crash_data_list,
    stratify=none_crash_labels,
    test_size=0.1,
    random_state=seed,
)

vulner_train_list, vulner_test_list = train_test_split(
    vulner_data_list,
    stratify=vulner_labels,
    test_size=0.1,
    random_state=seed,
)

In [12]:
# Per-class sample counts after the 90/10 train/test split
print("Train None-crash:", len(none_crash_train_list))
print("Test None-crash:", len(none_crash_test_list))
print("Train Vulner:", len(vulner_train_list))
print("Test Vulner:", len(vulner_test_list))

Train None-crash: 27448
Test None-crash: 3050
Train Vulner: 3409
Test Vulner: 379


In [None]:
# Merge both classes into a single train list and a single test list.
# e.g. train_list[i]: "../../dataset/data_preprocessed/None-crash\21-12-01-11-07-44_end_extract_drive26\00449.pickle"
train_list = [*none_crash_train_list, *vulner_train_list]
test_list = [*none_crash_test_list, *vulner_test_list]

# Shuffle in place so the two classes are interleaved
random.shuffle(train_list)
random.shuffle(test_list)

### Split Train data into Train & Validation

In [None]:
# Class name of every training sample, taken from its path. Samples are stored
# as <class dir>/<drive dir>/<file>, so the class is the third component from
# the end. Counting from the end (instead of a fixed index from the start)
# keeps this correct if the dataset root moves; backslashes are normalised
# first because the paths may mix '/' and '\' separators.
train_labels = [path.replace('\\', '/').split('/')[-3] for path in train_list]

In [None]:
# Carve 20% of the training samples off as the validation set, keeping the
# class ratio. train_list is rebound to the remaining 80%.
train_list, valid_list = train_test_split(
    train_list,
    stratify=train_labels,
    test_size=0.2,
    random_state=seed,
)

In [None]:
# Final sizes of the three splits
print(f"Train Data: {len(train_list)}")
print(f"Validation Data: {len(valid_list)}")
print(f"Test Data: {len(test_list)}")

## Load Dataset

In [None]:
class lidar_dataset(Dataset):
    """Dataset over pickled lidar samples.

    Each file is unpickled on access and must provide a 'tensor' entry (the
    model input) and a 'label' entry (the class name).
    """

    def __init__(self, file_list):
        # Paths of the pickle files, one per sample
        self.file_list = file_list

    def __len__(self):
        # The sample count is also kept on the instance as `filelength`
        count = len(self.file_list)
        self.filelength = count
        return count

    def __getitem__(self, idx):
        """Return (tensor, label) for the sample at position idx."""
        with open(self.file_list[idx], "rb") as handle:
            sample = pickle.load(handle)

        tensor = sample['tensor']
        # 0 is None-crash
        # 1 is Vulnerable (any label other than 'None-crash')
        label = 0 if sample['label'] == 'None-crash' else 1

        return tensor, label

In [None]:
# Wrap each split's file list in a dataset
train_data, valid_data, test_data = (
    lidar_dataset(split_files)
    for split_files in (train_list, valid_list, test_list)
)

In [None]:
# Only the training loader reshuffles every epoch. The evaluation loaders keep
# a fixed order: shuffling them adds nothing to the metrics and only makes the
# batches differ from run to run.
train_loader = DataLoader(dataset=train_data, batch_size=batch_size, shuffle=True)
valid_loader = DataLoader(dataset=valid_data, batch_size=batch_size, shuffle=False)
test_loader = DataLoader(dataset=test_data, batch_size=batch_size, shuffle=False)

## Efficient Attention

### Linformer

In [None]:
# Linformer encoder configuration.
# NOTE(review): in the visible cells this object is never passed to the ViT
# model — confirm whether it is still needed.
efficient_transformer = Linformer(
    k=64,
    heads=8,
    depth=12,
    dim=128,
    seq_len=196+1,  # patches + 1
)

### ViT

In [None]:
# Vision Transformer classifier, moved to the selected device
model = ViT(
    # input layout
    image_size=28,
    patch_size=2,
    channels=14,
    # transformer size
    dim=128,
    depth=12,
    heads=8,
    mlp_dim=2048,
    # two output classes: 0 = None-crash, 1 = Vulner
    num_classes=2,
).to(device)

### Training

In [None]:
# Optimizer: Adam over all model parameters
optimizer = optim.Adam(model.parameters(), lr=lr)
# Scheduler: each scheduler step multiplies the learning rate by gamma
scheduler = StepLR(optimizer, step_size=1, gamma=gamma)
# Loss: cross-entropy between the class scores and the integer labels
criterion = nn.CrossEntropyLoss()

In [None]:
# Train for `epochs` passes; after each pass evaluate on the validation set,
# decay the learning rate, save the model and print the epoch summary.
for epoch in range(epochs):
    epoch_loss = 0
    epoch_accuracy = 0

    # Validation below switches to eval mode, so re-enable training mode here
    model.train()
    for data, label in tqdm(train_loader):
        data = data.to(device)
        label = label.to(device)

        output = model(data)
        loss = criterion(output, label)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        acc = (output.argmax(dim=1) == label).float().mean()
        # .item() turns the running means into plain floats; accumulating the
        # tensors themselves would keep autograd history alive across batches.
        epoch_accuracy += acc.item() / len(train_loader)
        epoch_loss += loss.item() / len(train_loader)

    # Evaluate in eval mode and without building autograd graphs
    model.eval()
    with torch.no_grad():
        epoch_val_accuracy = 0
        epoch_val_loss = 0
        for data, label in valid_loader:
            data = data.to(device)
            label = label.to(device)

            val_output = model(data)
            val_loss = criterion(val_output, label)

            acc = (val_output.argmax(dim=1) == label).float().mean()
            epoch_val_accuracy += acc.item() / len(valid_loader)
            epoch_val_loss += val_loss.item() / len(valid_loader)

    # Advance the LR schedule once per epoch; without this call the scheduler
    # never changes the learning rate.
    scheduler.step()

    # Save a model each epoch
    torch.save(model, "model" + str(epoch) + ".pt")
    print(
        f"Epoch : {epoch+1} - loss : {epoch_loss:.4f} - acc: {epoch_accuracy:.4f} - val_loss : {epoch_val_loss:.4f} - val_acc: {epoch_val_accuracy:.4f}\n"
    )

In [None]:
# Put the last trained model into evaluation mode before testing.
# eval() returns the module, so as the cell's last expression it is also displayed.
model.eval()

## Test (Confusion matrix)

In [None]:
# Confusion-matrix counters, plus the file paths that fell into each cell.
# Positive class = 1 (Vulner), negative class = 0 (None-crash).
TP, FP, FN, TN = 0, 0, 0, 0
TP_list = list()
FP_list = list()
FN_list = list()
TN_list = list()

print("#Test-data:", len(test_list))

# Inference only: eval mode, and no autograd graph per sample
model.eval()
with torch.no_grad():
    for test_path in test_list:
        with open(test_path, "rb") as fr:
            tmp_data = pickle.load(fr)

        # 0 is None-crash, 1 is anything else
        test_label = 0 if tmp_data['label'] == "None-crash" else 1

        # Make tensor as input of model: add a batch dimension of size 1
        test_tensor = np.asarray(tmp_data['tensor'])
        test_tensor = np.asarray([test_tensor])
        test_tensor = torch.from_numpy(test_tensor)
        # Follow the notebook's `device` instead of assuming CUDA
        test_tensor = test_tensor.to(torch.float32).to(device)

        pred = model(test_tensor)
        # argmax always yields a class (same rule as the training-loop
        # accuracy), so a tie can no longer leave pred_idx stale or undefined.
        pred_idx = int(pred.argmax(dim=1)[0])

        if pred_idx == 1 and test_label == 1:
            TP += 1
            TP_list.append(test_path)
        elif pred_idx == 1 and test_label == 0:
            FP += 1
            FP_list.append(test_path)
        elif pred_idx == 0 and test_label == 1:
            FN += 1
            FN_list.append(test_path)
        else:
            TN += 1
            TN_list.append(test_path)

In [None]:
def _ratio(numerator, denominator):
    """Return numerator / denominator, or 0.0 when the denominator is zero."""
    return numerator / denominator if denominator else 0.0


# A metric whose denominator is empty (e.g. precision when nothing was
# predicted positive) is reported as 0.0 instead of raising ZeroDivisionError.
acc = _ratio(TP + TN, TP + TN + FP + FN)
spec = _ratio(TN, FP + TN)
prec = _ratio(TP, TP + FP)
recall = _ratio(TP, TP + FN)

print("Accuracy:", acc)
print("Specificity:", spec)
print("Precision:", prec)
print("Recall:", recall)
print("F1 Score:", _ratio(2 * prec * recall, prec + recall))