# LSTM


### Imports

In [1]:
import yaml
import torch
import torch.nn as nn
import torch.nn.functional as F

In [2]:
import torch.optim as optim
import torch.utils.data as data
import math
import copy

### Data Loading

In [3]:
# cd Downloads

In [None]:
# cd b4c

In [None]:
# cd Brains4Cars

Straight from github

In [4]:
import os
from os.path import join
import copy

from multiprocessing import Pool
import tqdm

import numpy as np
import pickle
from torch.utils.data import Dataset, DataLoader

ACTION_TO_ID_MAP= {
    'end_action':   0,
    'lchange':      1,
    'lturn':        2,
    'rchange':      3,
    'rturn':        4
}

rng = np.random.default_rng(seed=42)
MAX_FRAMES=150


class B4CDataset(Dataset):
    """
    end_action - 0
    lchange - 1
    lturn - 2
    rchange - 3
    rturn - 4
    """
    def __init__(self, data_cfg, split="train", create_dataset=False):
        self.data_cfg   = data_cfg['DATALOADER_CONFIG']
        self.actions    = self.data_cfg['ACTIONS']
        self.cameras    = self.data_cfg['CAMERAS']
        self.data_dir   = self.data_cfg['DATA_DIR']
        self.split      = split

        videos_dict = self.read_videos_by_action()        

        if create_dataset:
            # self.create_gt_road_labels(videos_dict)
            self.generate_imagesets(videos_dict)
        
        self.image_sets = {}
        splits_process = [self.split] if self.split in ["train", "val", "test"] else ["train", "val", "test"]

        for camera in self.cameras:
            if camera not in self.image_sets:
                self.image_sets[camera] = []

            for curr_split in splits_process:
                imageset_path = join(self.data_dir, f'ImageSets_{camera}', f'{curr_split}.txt')
                self.image_sets[camera].extend([line.strip() for line in open(imageset_path, 'r')])

        # Drop nondivisible videos for fold splits from end
        if "fold" in self.split:
            num_folds = int(self.split.split("_")[0])
            num_to_drop = len(self) % num_folds
            self.image_sets = {k: v[:-num_to_drop] for k, v in self.image_sets.items()}

        print(f'Added {len(self)} videos to the dataset.')

    def read_videos_by_action(self):
        videos_dict = {}
        # Combine image sets for all cameras
        for camera in self.cameras:
            camera_dir = join(self.data_dir, camera+"_processed")

            for action in self.actions:
                action_dir = join(camera_dir, action)

                if action not in videos_dict:
                    videos_dict[action] = []

                # Take the set intersection between all the cameras sequentially
                video_action_set = set([f for f in os.listdir(action_dir) if os.path.isdir(join(action_dir, f))])

                if len(videos_dict[action])==0:
                    videos_dict[action] = video_action_set
                else:
                    videos_dict[action] = videos_dict[action].intersection(video_action_set)

        # Convert videos_dict back to list
        for action in self.actions:
            videos_dict[action] = list(videos_dict[action])

        # Check that files exist
        for camera in self.cameras:
            camera_dir = join(self.data_dir, camera+"_processed")

            for action in self.actions:
                action_dir = join(camera_dir, action)

                for video_dir in videos_dict[action]:
                    video_dir = join(action_dir, video_dir)
                    assert os.path.exists(video_dir), f"Video directory {video_dir} does not exist."

        print("Finished reading all valid videos by directory.")
        return videos_dict

    @staticmethod
    def write_list_to_file(file_path, data_list):
        """
        Create/overwrite a .txt file and write each line of the Python list to a new line in the file.

        Parameters:
            file_path : str
                The path to the .txt file.
            data_list : list
                The Python list containing data to write to the file.
        """
        with open(file_path, 'w') as file:
            file.writelines(f"{item}\n" for item in data_list)

    def check_data_quality(self, video_subdirs, camera):
        """
        Check that all videos have the same number of frames and that the number of frames is less than MAX_FRAMES.

        Parameters:
            videos_dict : dict
                Dictionary containing the list of videos for each action.
        """

        valid_videos_mask = np.zeros(len(video_subdirs), dtype=bool)
        for video_idx, video_subdir in enumerate(video_subdirs):
            video_path = join(self.data_dir, video_subdir)
            data_dict = {}

            # Check that all videos have the full frame set
            if camera=="face":
                data_dict['gt_gazepose'] = self.get_face_labels(video_path)
                valid_videos_mask[video_idx] = len(data_dict['gt_gazepose'])>=MAX_FRAMES-1 # gazepose has 149
            elif camera=="road":
                data_dict['gt_bbox'], data_dict['gt_lanes'] = self.get_road_labels(video_path)
                valid_videos_mask[video_idx] = len(data_dict['gt_bbox'])>=MAX_FRAMES and len(data_dict['gt_lanes'])>=MAX_FRAMES

            if valid_videos_mask[video_idx]==0:
                print(f'Video {video_subdir} does not have the full frame set, skipping...')

        return valid_videos_mask
            

    def generate_imagesets(self, videos_dict):    
        print("Generating imagesets...")
        facecam_imageset_dict = {'train': [], 'val': [], 'test': []}
        roadcam_imageset_dict = copy.deepcopy(facecam_imageset_dict)
        train_pct, val_pct, test_pct = 0.7, 0.15, 0.15
        for action, action_videos in videos_dict.items():
            road_cam_action_dir = join('road_camera_processed_combined', action)
            road_cam_video_labels = np.array([join(road_cam_action_dir, f) for f in action_videos])
            road_cam_video_mask = self.check_data_quality(road_cam_video_labels, "road")

            face_cam_action_dir = join('face_camera_processed', action)
            face_cam_video_labels = np.array([join(face_cam_action_dir, f) for f in action_videos])
            face_cam_video_mask = self.check_data_quality(face_cam_video_labels, "face")

            combined_cam_video_mask = np.logical_and(road_cam_video_mask, face_cam_video_mask)
            road_cam_video_labels = road_cam_video_labels[combined_cam_video_mask]
            face_cam_video_labels = face_cam_video_labels[combined_cam_video_mask]

            road_cam_video_labels_sort_idx = np.argsort(road_cam_video_labels)
            road_cam_video_labels = road_cam_video_labels[road_cam_video_labels_sort_idx]
            face_cam_video_labels = face_cam_video_labels[road_cam_video_labels_sort_idx]

            # Ensure file order matches for road and face camera
            for i in range(len(road_cam_video_labels)):
                assert os.path.basename(road_cam_video_labels[i])==os.path.basename(face_cam_video_labels[i]), 'Video labels do not match'
            num_videos = len(road_cam_video_labels)

            num_train       = int(num_videos * train_pct)
            num_val         = int(num_videos * val_pct)
            num_test        = int(num_videos * test_pct)

            indices = np.arange(0, num_videos, 1)
            rng.shuffle(indices)

            videos_indices = {"train": [], "val": [], "test": []}
            videos_indices['train'], videos_indices['val'], videos_indices['test'] = indices[:num_train], \
                indices[num_train:num_train+num_val],  indices[num_train+num_val:num_train+num_val+num_test]

            for split in facecam_imageset_dict.keys():
                roadcam_imageset_dict[split].extend(road_cam_video_labels[videos_indices[split]].tolist())
                facecam_imageset_dict[split].extend(face_cam_video_labels[videos_indices[split]].tolist())

        # Dump to imageset files for road and face camera
        roadcam_imagesets_dir = join(self.data_dir, "ImageSets_road_camera")
        facecam_imagesets_dir = join(self.data_dir, "ImageSets_face_camera")
        if not os.path.exists(roadcam_imagesets_dir):
            print(f'Video root directory {roadcam_imagesets_dir} does not exist. Creating...')
            os.mkdir(roadcam_imagesets_dir)
        if not os.path.exists(facecam_imagesets_dir):
            print(f'Video root directory {facecam_imagesets_dir} does not exist. Creating...')
            os.mkdir(facecam_imagesets_dir)

        for split_key in facecam_imageset_dict.keys():
            road_cam_split_path = join(roadcam_imagesets_dir, f'{split_key}.txt')
            face_cam_split_path = join(facecam_imagesets_dir, f'{split_key}.txt')
            print(f'Saving imageset file {split_key} for road {road_cam_split_path} and {face_cam_split_path}')
            self.write_list_to_file(road_cam_split_path, roadcam_imageset_dict[split_key])
            self.write_list_to_file(face_cam_split_path, facecam_imageset_dict[split_key])

    def __len__(self):
        num_files = 0
        for camera in self.cameras: 
            assert num_files==0 or num_files==len(self.image_sets[camera]), "Number files in dataset not correct"
            num_files=len(self.image_sets[camera])
        return num_files
    
    def collate_fn(self, data):
        # print(data[0][1])
        data_batch = [bi[0] for bi in data]
        action_batch = [bi[1] for bi in data]
        return data_batch, action_batch

    def get_face_labels(self, label_dir):  
        gazepose_path = join(label_dir, 'gazepose.npy')
        assert os.path.exists(gazepose_path), f'Label file {gazepose_path} does not exist'
        gt_gazepose = np.load(gazepose_path)

        if gt_gazepose.shape[0] < MAX_FRAMES:
            # print(f'Gaze pose {gazepose_path} has less than 150 frames, padding with zeros...')
            # Pad with zeros
            gt_gazepose = np.pad(gt_gazepose, ((0, MAX_FRAMES-gt_gazepose.shape[0]), (0, 0)), mode='constant')
        # Assume gt_gazepose is not smaller than MAX_FRAMES
        num_frames = min(MAX_FRAMES, gt_gazepose.shape[0])
        gt_gazepose = gt_gazepose[:num_frames, :]

        return gt_gazepose

    def get_road_labels(self, label_dir):
        
        bbox_file = join(label_dir, 'bbox_labels.pkl')
        lane_file = join(label_dir, 'lane_labels.pkl')

        assert os.path.exists(bbox_file), f'Label directory {bbox_file} does not exist'
        assert os.path.exists(lane_file), f'Label directory {lane_file} does not exist'

        # Load bbox detections
        with open(bbox_file, 'rb') as f:
            gt_bbox = pickle.load(f)
        # Load road labels
        with open(lane_file, 'rb') as f:
            gt_lanes = pickle.load(f)

        gt_bbox = gt_bbox[:MAX_FRAMES]
        gt_lanes = gt_lanes[:MAX_FRAMES]

        return gt_bbox, gt_lanes
    
    def get_action_label(self, video_dir):
        action_label = video_dir.split('/')[-2]
        assert action_label in ACTION_TO_ID_MAP.keys(), f'Action {action_label} not in action map'
        action_id = ACTION_TO_ID_MAP[action_label]
        return action_id

    @staticmethod
    def combine_img_labels(args):
        split_video_path, combined_video_path = args
        img_label_files = [f for f in os.listdir(split_video_path) if f.endswith('.pkl')]

        MAX_NUM_BBOXES = 5 
        full_img_label_np = np.ones((MAX_FRAMES+1, MAX_NUM_BBOXES*5)) * -1
        num_img_label_files = len(img_label_files)
        for img_label_idx, img_label_file in enumerate(img_label_files):
            img_label_path = join(split_video_path, img_label_file)
            with open(img_label_path, 'rb') as f:
                img_data = pickle.load(f)
                IMG_W, IMG_H = 720, 480

                #1 Convert to xc, yc, w, h
                x1, x2, y1, y2 = img_data['xyxy'][:, 0], img_data['xyxy'][:, 2], img_data['xyxy'][:, 1], img_data['xyxy'][:, 3]
                xc = (x1 + x2) / 2
                yc = (y1 + y2) / 2
                h = (x2 - x1)
                w = (y2 - y1)
            
                #2 Only keep bboxes with h and w < 0.33 (Ignore large bboxes of self)
                bbox_size_mask = np.logical_and(h < IMG_W*0.33, w < IMG_H*0.33)

                #3 Only keep labels with class_ids = 0, 1, 2, 3, 4 (Ignore 5 Date)
                class_ids = img_data['class_id'] 
                class_ids_mask = np.logical_and(class_ids >= 0, class_ids <= 4)

                bbox_mask = np.logical_and(bbox_size_mask, class_ids_mask)
                proc_img_label = np.ones((MAX_NUM_BBOXES, 5), dtype=int)*-1 # max of five bbox detections per image

                if np.sum(bbox_mask)>0:
                    xc = xc[bbox_mask].astype(int)
                    yc = yc[bbox_mask].astype(int)
                    h = h[bbox_mask].astype(int)
                    w = w[bbox_mask].astype(int)
                    class_ids = class_ids[bbox_mask]
                    gt_objs = np.stack((xc, yc, w, h, class_ids), axis=1)

                    #4 Select top 5 largest boxes
                    gt_obj_areas = gt_objs[:, 2] * gt_objs[:, 3]
                    gt_objs_sort_idx = np.argsort(-gt_obj_areas, kind='stable') # Sort high to low
                    num_objs = min(MAX_NUM_BBOXES, len(gt_objs_sort_idx))

                    proc_img_label[:num_objs, :] = gt_objs[gt_objs_sort_idx[:num_objs], :]

                proc_img_label = proc_img_label.flatten()
                proc_img_label = np.expand_dims(proc_img_label, axis=0)
                full_img_label_np[img_label_idx] = proc_img_label

        assert os.path.exists(combined_video_path), f'Video label directory {combined_video_path} does not exist'
        video_path = join(combined_video_path, 'bbox_labels.pkl')
        with open(video_path, 'wb') as f:
            pickle.dump(full_img_label_np, f)
        print("Saved combined bbox labels to: ", video_path)

        # Load, pad, save lane dets
        original_label_dir = split_video_path.replace('road_camera_processed', 'road_camera').replace('labels/', '')
        road_path = join(original_label_dir+".txt")
        gt_lanes = np.loadtxt(road_path, delimiter=',', dtype=int).reshape(1, -1)
        gt_lanes_padded = np.ones((MAX_FRAMES, 3)) * -1
        num_valid_gt_lanes = min(num_img_label_files, MAX_FRAMES)
        gt_lanes_padded[:num_valid_gt_lanes] = np.repeat(gt_lanes, [num_valid_gt_lanes], axis=0)

        lanes_path = join(combined_video_path, 'lane_labels.pkl')
        with open(lanes_path, 'wb') as f:
            pickle.dump(gt_lanes_padded, f)

    def create_gt_road_labels(self, videos_dict):
        split_video_path_list = []
        combined_video_path_list = []
        for action, action_videos in videos_dict.items():
            for video in action_videos:
                video_path = join(self.data_dir, 'road_camera_processed', 'labels', action, video)
                if not os.path.exists(video_path):
                    print(f'Video path {video_path} does not exist')
                    continue
                assert os.path.exists(video_path), f'Video path {video_path} does not exist'

                video_label_dir = join(self.data_dir, "road_camera_processed_combined", action, video)
                if not os.path.exists(video_label_dir):
                    print("Creating directory: ", video_label_dir)
                    os.makedirs(video_label_dir)
                
                split_video_path_list.append(video_path)
                combined_video_path_list.append(video_label_dir)
            # self.combine_img_labels((split_video_path_list[0], combined_video_path_list[0]))
        pool = Pool(processes=16)
        for _ in tqdm.tqdm(pool.imap_unordered(self.combine_img_labels, zip(split_video_path_list, 
            combined_video_path_list)), total=len(split_video_path_list)):
            pass

    def __getitem__(self, idx):
        data_dict = {}
        action_id=None

        for camera in self.cameras:
            video_subdir    = self.image_sets[camera][idx]
            video_fulldir   = join(self.data_dir, video_subdir)

            if camera == 'face_camera':
                data_dict['gt_gazepose'] = self.get_face_labels(video_fulldir)
            elif camera == 'road_camera':
                # Load all pickle files in the video directory
                data_dict['gt_bbox'], data_dict['gt_lanes'] = self.get_road_labels(video_fulldir)

            if action_id is None:
                action_id = self.get_action_label(video_fulldir)

        # Sort dict by key so that it is is consistent
        sorted_data_dict = dict(sorted(data_dict.items(), key=lambda item: item[0]))

        sorted_gt_np = np.empty((MAX_FRAMES, 0))
        # Stack all values from data_dict into a single np array

        for _, items in sorted_data_dict.items():
            sorted_gt_np = np.hstack((sorted_gt_np, items))

        # TOOD: Extract action label from Imageset file
        #  150 x [ (5x5) (2) (2) (3) ] # Pad if not enough frames obj_detections, gazepose, lane_detections # 150 x 32
        return sorted_gt_np, action_id # processed_input, action_label one hot vector 

config file

In [5]:
cfg=None
with open("../config/lstm_all.yaml", 'r') as file:
    cfg = yaml.safe_load(file)

In [6]:
batch = 50

make dataset (do once)

In [7]:
dataset = B4CDataset(cfg, split="5_fold", create_dataset=False)

Finished reading all valid videos by directory.
Added 585 videos to the dataset.


In [None]:
#dataloader = DataLoader(dataset, collate_fn=dataset.collate_fn, batch_size=batch, shuffle=True)

In [None]:
#valdataset = B4CDataset(cfg, split="val", create_dataset=False)
#val_loader = DataLoader(valdataset, collate_fn=dataset.collate_fn, batch_size=1, shuffle=True)

In [None]:
#testdataset = B4CDataset(cfg, split="test", create_dataset=False)
#test_loader = DataLoader(testdataset, collate_fn=dataset.collate_fn, batch_size=1, shuffle=True)

#### Define your network, loss function and optimizer

In [8]:
#multiheaded attention
class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super(MultiHeadAttention, self).__init__()
        #d_model must be divisible by num_heads
        
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads
        
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)
        nn.init.normal_(self.W_q.weight, mean=0., std=np.sqrt(2 / (d_model+d_model)))
        nn.init.normal_(self.W_k.weight, mean=0., std=np.sqrt(2 / (d_model+d_model)))
        nn.init.normal_(self.W_v.weight, mean=0., std=np.sqrt(2 / (d_model+d_model)))
        nn.init.normal_(self.W_o.weight, mean=0., std=np.sqrt(2 / (d_model+d_model)))
        # nn.init.xavier_uniform_(self.W_q.weight)
        # nn.init.xavier_uniform_(self.W_k.weight)
        # nn.init.xavier_uniform_(self.W_v.weight)
        # nn.init.xavier_uniform_(self.W_o.weight)
        
    def scaled_dot_product_attention(self, Q, K, V, mask=None):
        attn_scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)
        if mask is not None:
            attn_scores = attn_scores.masked_fill(mask == 0, -1e9)
        attn_probs = torch.softmax(attn_scores, dim=-1)
        output = torch.matmul(attn_probs, V)
        return output
        
    def split_heads(self, x):
        batch_size, seq_length, d_model = x.size()
        return x.view(batch_size, seq_length, self.num_heads, self.d_k).transpose(1, 2)
        
    def combine_heads(self, x):
        batch_size, _, seq_length, d_k = x.size()
        return x.transpose(1, 2).contiguous().view(batch_size, seq_length, self.d_model)
        
    def forward(self, Q, K, V, mask=None):
        Q = self.split_heads(self.W_q(Q))
        K = self.split_heads(self.W_k(K))
        V = self.split_heads(self.W_v(V))
        
        attn_output = self.scaled_dot_product_attention(Q, K, V, mask)
        output = self.W_o(self.combine_heads(attn_output))
        return output
    
#position feedforward
class PositionWiseFeedForward(nn.Module):
    def __init__(self, d_model, d_ff):
        super(PositionWiseFeedForward, self).__init__()
        self.fc1 = nn.Linear(d_model, d_ff)
        self.fc2 = nn.Linear(d_ff, d_model)
        self.relu = nn.ReLU()

        # nn.init.xavier_uniform_(self.fc1.weight)
        # nn.init.xavier_uniform_(self.fc2.weight)

    def forward(self, x):
        return self.fc2(self.relu(self.fc1(x)))
    
#position encoding
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_seq_length):
        super(PositionalEncoding, self).__init__()
        
        pe = torch.zeros(max_seq_length, d_model)
        position = torch.arange(0, max_seq_length, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model))
        
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        
        self.register_buffer('pe', pe.unsqueeze(0))
        
    def forward(self, x):
        return x + self.pe[:, :x.size(1)]

#encoder no decoder
class EncoderLayer(nn.Module):
    def __init__(self, d_model, num_heads, d_ff, dropout, d_hidden):
        super(EncoderLayer, self).__init__()
        self.self_attn = MultiHeadAttention(d_model, num_heads)
        self.feed_forward = PositionWiseFeedForward(d_model, d_ff)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)
        # self.fc1 = nn.Linear(d_model*150,5)
        self.pfc1 = nn.Linear(4,32)
        self.pfc2 = nn.Linear(25,16)
        self.pfc3 = nn.Linear(3,16)
        self.sig = nn.Sigmoid()

        self.x1norm = nn.LayerNorm(32)
        self.x2norm = nn.LayerNorm(16)
        self.x3norm = nn.LayerNorm(16)

        self.fc1 = nn.Linear(d_model*150,d_hidden)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(d_hidden, 5)

        # nn.init.xavier_uniform_(self.fc1.weight)
        # nn.init.xavier_uniform_(self.pfc1.weight)
        # nn.init.xavier_uniform_(self.pfc2.weight)
        # nn.init.xavier_uniform_(self.pfc3.weight)
        
    def forward(self, x):
        x1 = self.pfc1(x[:,:,25:29])
        x1 = self.sig(x1)
        x2 = self.pfc2(x[:,:,0:25])
        x2 = self.sig(x2)
        x3 = self.pfc3(x[:,:,29:32])
        x3 = self.sig(x3)

        x1 = self.x1norm(x1)
        x2 = self.x2norm(x2)
        x3 = self.x3norm(x3)
        #x1 = torch.flatten(x1, start_dim=1)
        #x2 = torch.flatten(x2, start_dim=1)
        #x3 = torch.flatten(x3, start_dim=1)
        x = torch.cat((x1,x2,x3), 2)
        attn_output = self.self_attn(x, x, x, mask=None)
        x = self.norm1(x + self.dropout(attn_output))
        ff_output = self.feed_forward(x)
        x = self.norm2(x + self.dropout(ff_output))
        x = torch.flatten(x, start_dim=1)
        # x = self.fc1(x)
        x = self.fc2(self.relu(self.fc1(x)))

        return x

In [9]:
# fine-violet-267
modeldim = 64
numheads = 8
hiddendim = 128 # modeldim*2
lr = 0.001
dropout = 0.20
num_epochs = 100
fchiddendim = 64
T_max = num_epochs / 2
final_lr = 0.1*lr

net1 = EncoderLayer(modeldim,numheads,hiddendim,dropout,fchiddendim)
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(net1.parameters(), lr=lr, momentum=0.9, weight_decay=0.001)
# lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max, final_lr)
lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.95)

In [None]:
# modeldim = 64
# numheads = 8
# hiddendim = 128 # modeldim*2
# lr = 0.001
# dropout = 0.2
# num_epochs = 100
# T_max = num_epochs / 2
# final_lr = 0.1*lr

# net1 = EncoderLayer(modeldim,numheads,hiddendim,dropout)
# loss_fn = nn.CrossEntropyLoss()
# optimizer = torch.optim.SGD(net1.parameters(), lr=lr, momentum=0.85, weight_decay=0.001)
# # lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max, final_lr)
# lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.95)

#### Implement the training loop function

In [10]:
import random

random.seed(42)
# k-split validation
totalsize = len(dataset)
indices = list(range(0,totalsize))
random.shuffle(indices)
seg = int(totalsize/10) # number of splits



In [None]:
split = 0 #change for each k-validation split

trainlefti = indices[0:split*seg]
trainrighti = indices[min(split*seg + seg,totalsize):totalsize]
traini = trainlefti + trainrighti
vali = indices[split*seg:min(totalsize,seg*split+seg)]

train_set = torch.utils.data.dataset.Subset(dataset,traini)
val_set = torch.utils.data.dataset.Subset(dataset,vali)

trainloader = DataLoader(train_set, collate_fn=dataset.collate_fn, batch_size=batch, shuffle=True)
valloader = DataLoader(val_set, collate_fn=dataset.collate_fn, batch_size=1, shuffle=True)

In [11]:
def train(net, trainloader, optimizer, loss_fn, lr_scheduler, num_epochs, epoch_count=0, split=0):
    import wandb
    # Put the network in training mode
    net.train()

    # Training loop
    for epoch in range(num_epochs):
        running_loss = 0
        for batch_idx, (data, targets) in enumerate(trainloader):
            # TODO: zero the parameter gradients + forward pass + loss computation + backward pass + weight update
            data = np.array(data)
            tdata = torch.tensor(data)
            optimizer.zero_grad()
            prediction = net(tdata.float())
            ttargets = torch.tensor(targets)

            loss = loss_fn(prediction, ttargets)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
        lr_scheduler.step()
        print(running_loss)

        ckpt_save_dir = f'/home/arthur/AMRL/Research/Driver-Intent-Prediction/outputs/default/ckpts_{split}'
        if not os.path.exists(ckpt_save_dir):
            os.makedirs(ckpt_save_dir)
        torch.save(net1.state_dict(), os.path.join(ckpt_save_dir, f'epoch_{epoch_count}.pth'))

        wandb.log({"loss": running_loss / (len(trainloader) * batch), "epoch": epoch, "lr": lr_scheduler.get_last_lr()[0]}) 


Evaluate

In [12]:
import sklearn
from sklearn import metrics

In [13]:
def eval(net, loader):
    net.eval()
    se = 0 # sum of correct pred
    total = 0
    f1input = np.zeros(100)
    f1target = np.zeros(100)
    i1 = 0
    for batch_idx, (data, targets) in enumerate(loader):
        data = np.array(data)
        tdata = torch.tensor(data)
        y = net(tdata.float())
        f1input[i1] = int(torch.argmax(y))
        f1target[i1] = int(targets[0])

        i1 += 1
        if int(torch.argmax(y)) == targets[0]:
            se += 1
        total += 1
        
    f1score = sklearn.metrics.f1_score(f1target[0:i1], f1input[0:i1], average = 'macro')
    return float(se/total), f1score

In [None]:
import wandb

wandb.init(
    project="Driver-Intent-Prediction-models",
    config={
        "optimizer": "SGD",
        "lr": lr,
        "dropout": dropout,
        "num_epochs": num_epochs,
        "batch_size": batch,
        "modeldim": modeldim,
        "numheads": numheads,
        "hiddendim": hiddendim,
    }
)

for i in range(num_epochs): # for train/val stationary time before manuever
    print(i)
    train(net1,trainloader,optimizer,loss_fn, lr_scheduler, 1, epoch_count=i)
    evalval, f1val = eval(net1,valloader)
    print('eval: '+str(evalval) + '         f1score: ' + str(f1val))
    wandb.log({"val": evalval, "f1val": f1val})
    if evalval > 0.92:
        break

wandb.finish()

In [None]:
# torch.save(net1.state_dict(), "C:/Users/ykung/Downloads/b4c/LSTMFull91")

net1.load_state_dict(torch.load("/home/arthur/AMRL/Research/Driver-Intent-Prediction/outputs/default/ckpts/epoch_20.pth"))

Cross Validation

In [14]:
import wandb
# Run 5 Fold split
splits = list(range(0,10))
print(splits)
max_evalval_list = []
max_f1val_list = []
for split in splits:
    trainlefti = indices[0:split*seg]
    trainrighti = indices[min(split*seg + seg,totalsize):totalsize]
    traini = trainlefti + trainrighti
    vali = indices[split*seg:min(totalsize,seg*split+seg)]

    train_set = torch.utils.data.dataset.Subset(dataset,traini)
    val_set = torch.utils.data.dataset.Subset(dataset,vali)

    trainloader = DataLoader(train_set, collate_fn=dataset.collate_fn, batch_size=batch, shuffle=True)
    valloader = DataLoader(val_set, collate_fn=dataset.collate_fn, batch_size=1, shuffle=True)

    wandb.init(
        project="Driver-Intent-Prediction-Models-Cross-Validation",
        name=f'cross_validation_split{split}',
        config={
            "optimizer": "SGD",
            "lr": lr,
            "dropout": dropout,
            "num_epochs": num_epochs,
            "batch_size": batch,
            "modeldim": modeldim,
            "numheads": numheads,
            "hiddendim": hiddendim,
            "validation_split": split
        }
    )


    net1 = EncoderLayer(modeldim,numheads,hiddendim,dropout,fchiddendim)
    loss_fn = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(net1.parameters(), lr=lr, momentum=0.9, weight_decay=0.001)
    lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.95)
    
    maxevalval, maxf1val = 0, 0
    for i in range(70): # for train/val stationary time before manuever
        print(i)
        train(net1,trainloader,optimizer,loss_fn, lr_scheduler, 1, epoch_count=i, split=split)
        evalval, f1val = eval(net1,valloader)
        print('eval: '+str(evalval) + '         f1score: ' + str(f1val))
        wandb.log({"val": evalval, "f1val": f1val})      
        maxevalval = max(maxevalval, evalval)
        maxf1val = max(maxf1val, f1val)

    max_evalval_list.append(maxevalval)
    max_f1val_list.append(maxf1val)

    wandb.finish()

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.


[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]


[34m[1mwandb[0m: Currently logged in as: [33martzha[0m. Use [1m`wandb login --relogin`[0m to force relogin


0
15.814381718635559
eval: 0.3793103448275862         f1score: 0.17759892689470153
1
12.981415569782257
eval: 0.4827586206896552         f1score: 0.3212088191493226
2
10.56434714794159
eval: 0.6551724137931034         f1score: 0.5480863649807749
3
8.77808392047882
eval: 0.6379310344827587         f1score: 0.6176073926073926
4
7.5812618136405945
eval: 0.7241379310344828         f1score: 0.6956190476190477
5
6.575533837080002
eval: 0.7586206896551724         f1score: 0.7323976488627653
6
6.038686901330948
eval: 0.6896551724137931         f1score: 0.6571658615136876
7
5.357866287231445
eval: 0.7758620689655172         f1score: 0.7457875457875458
8
4.951022207736969
eval: 0.7068965517241379         f1score: 0.6687232574189095
9
4.605223625898361
eval: 0.8103448275862069         f1score: 0.769180245042314
10
4.300177216529846
eval: 0.7931034482758621         f1score: 0.7586209634990123
11
3.9513242840766907
eval: 0.7413793103448276         f1score: 0.6975991140642304
12
3.7379772067070007
e



VBox(children=(Label(value='0.003 MB of 0.003 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
epoch,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
f1val,▁▃▆██▇██▇██▇██▇█▇██████▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇
loss,█▇▅▄▃▃▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
lr,██▇▆▆▆▅▅▄▄▄▄▃▃▃▃▃▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁
val,▁▃▅▇█▇██▇██▇██▇█▇██████▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇

0,1
epoch,0.0
f1val,0.68857
loss,0.00225
lr,3e-05
val,0.75862


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.016668636366618254, max=1.0…

0
16.705782175064087
eval: 0.5689655172413793         f1score: 0.3136173767752715
1
13.243318319320679
eval: 0.6551724137931034         f1score: 0.4087912087912088
2
11.48274677991867
eval: 0.6206896551724138         f1score: 0.37281262858639136
3
9.909279346466064
eval: 0.6551724137931034         f1score: 0.5055574229691876
4
8.664392113685608
eval: 0.6724137931034483         f1score: 0.5016304428454863
5
7.915194272994995
eval: 0.6896551724137931         f1score: 0.555287356321839
6
7.10442653298378
eval: 0.6724137931034483         f1score: 0.5377192982456139
7
6.336791843175888
eval: 0.7068965517241379         f1score: 0.6001990269792128
8
5.900714635848999
eval: 0.6896551724137931         f1score: 0.5911764705882353
9
5.448321968317032
eval: 0.6896551724137931         f1score: 0.5766666666666665
10
5.319511443376541
eval: 0.6724137931034483         f1score: 0.5368522637216555
11
5.060990691184998
eval: 0.6724137931034483         f1score: 0.5368522637216555
12
4.77200710773468
eval:



VBox(children=(Label(value='0.003 MB of 0.014 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=0.222379…

0,1
epoch,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
f1val,▁▃▅▅▆▆▅▆▆▆▇▆▇▆█▆███▇▇█▇███▇█████████████
loss,█▆▅▄▃▃▃▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
lr,██▇▆▆▆▅▅▄▄▄▄▃▃▃▃▃▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁
val,▁▄▄▅▆▅▅▆▆▆▇▆▇▆█▆▇▇▇▇▇▇▇▇▇▇▇▇▇██▇█▇██▇███

0,1
epoch,0.0
f1val,0.69026
loss,0.00296
lr,3e-05
val,0.75862


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.016669646783338977, max=1.0…

0
15.91739571094513
eval: 0.4482758620689655         f1score: 0.20054794520547947
1
13.021990299224854
eval: 0.5         f1score: 0.32775318578135487
2
10.724472165107727
eval: 0.5517241379310345         f1score: 0.4529192546583851
3
8.720082402229309
eval: 0.603448275862069         f1score: 0.5570075757575758
4
7.643564820289612
eval: 0.6724137931034483         f1score: 0.702048174048174
5
6.328321695327759
eval: 0.6379310344827587         f1score: 0.682718339861197
6
5.406238794326782
eval: 0.6896551724137931         f1score: 0.7184003152088259
7
4.711895316839218
eval: 0.6724137931034483         f1score: 0.7078132045088567
8
4.156149759888649
eval: 0.6724137931034483         f1score: 0.7009712509712509
9
3.9647527933120728
eval: 0.7068965517241379         f1score: 0.7324444444444443
10
3.550213932991028
eval: 0.7241379310344828         f1score: 0.7408565209807445
11
3.280630901455879
eval: 0.7241379310344828         f1score: 0.7412045088566828
12
2.9676313996315002
eval: 0.724137931



0,1
epoch,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
f1val,▁▂▅▇▇▇▇▇█▇█▇████████████████████████████
loss,█▇▅▄▃▃▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
lr,██▇▆▆▆▅▅▄▄▄▄▃▃▃▃▃▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁
val,▁▂▄▅▆▆▇▇▇▇▇▇████████████████████████████

0,1
epoch,0.0
f1val,0.78407
loss,0.00161
lr,3e-05
val,0.77586


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.016669814866691012, max=1.0…

0
15.013159155845642
eval: 0.4482758620689655         f1score: 0.28857808857808853
1
11.492665231227875
eval: 0.7068965517241379         f1score: 0.5614814814814815
2
8.993632912635803
eval: 0.7586206896551724         f1score: 0.7277106908685855
3
7.44922000169754
eval: 0.7586206896551724         f1score: 0.6164834168822205
4
6.800191700458527
eval: 0.8103448275862069         f1score: 0.7381988304093567
5
6.1469354927539825
eval: 0.7931034482758621         f1score: 0.7288244139761169
6
5.58998915553093
eval: 0.8103448275862069         f1score: 0.7083170163170164
7
5.104052573442459
eval: 0.8103448275862069         f1score: 0.7669010989010989
8
4.477040112018585
eval: 0.8103448275862069         f1score: 0.746953046953047
9
4.195429772138596
eval: 0.8620689655172413         f1score: 0.8131320754716983
10
3.7981453090906143
eval: 0.8620689655172413         f1score: 0.7887365967365968
11
3.525368183851242
eval: 0.8275862068965517         f1score: 0.7734887334887334
12
3.412018433213234
eva



VBox(children=(Label(value='0.003 MB of 0.003 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
epoch,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
f1val,▁▅▆▇█▇██████████████████████████████████
loss,█▆▄▄▃▃▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
lr,██▇▆▆▆▅▅▄▄▄▄▃▃▃▃▃▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁
val,▁▅▆▇▇▇██▇▇▇▇▇▇▇▇▇▇▇███▇█████████████████

0,1
epoch,0.0
f1val,0.77382
loss,0.00222
lr,3e-05
val,0.84483


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.016669028850083124, max=1.0…

0
15.804187774658203
eval: 0.2413793103448276         f1score: 0.1585581601061477
1
13.170600533485413
eval: 0.6551724137931034         f1score: 0.5607536725569513
2
10.492628931999207
eval: 0.7068965517241379         f1score: 0.6062049062049062
3
8.419887483119965
eval: 0.6896551724137931         f1score: 0.6030614169833186
4
7.197699189186096
eval: 0.7413793103448276         f1score: 0.6882018254969074
5
6.477713048458099
eval: 0.7586206896551724         f1score: 0.7142857142857142
6
6.091778457164764
eval: 0.7758620689655172         f1score: 0.733500417710944
7
5.621086448431015
eval: 0.7586206896551724         f1score: 0.7222356739305892
8
5.251305013895035
eval: 0.7413793103448276         f1score: 0.7025974025974024
9
4.787072077393532
eval: 0.7241379310344828         f1score: 0.6989872872225813
10
4.726896047592163
eval: 0.8103448275862069         f1score: 0.7762810369706921
11
4.409168839454651
eval: 0.7586206896551724         f1score: 0.7086728573570679
12
4.108072578907013
eva



0,1
epoch,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
f1val,▁▅▆▇▇▇▇▇▇▇█▇▇██▇████████████████████████
loss,█▇▄▃▃▃▃▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
lr,██▇▆▆▆▅▅▄▄▄▄▃▃▃▃▃▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁
val,▁▆▆▇▇▇█▇▇▇██▇███████████████████████████

0,1
epoch,0.0
f1val,0.82441
loss,0.00268
lr,3e-05
val,0.84483


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.016671020750072783, max=1.0…

0
15.979933619499207
eval: 0.3793103448275862         f1score: 0.13405572755417958
1
13.293981790542603
eval: 0.46551724137931033         f1score: 0.26095238095238094
2
10.626657664775848
eval: 0.5517241379310345         f1score: 0.46781874039938554
3
9.010565638542175
eval: 0.6551724137931034         f1score: 0.6592250712250712
4
7.5033769607543945
eval: 0.7241379310344828         f1score: 0.7038453500522467
5
6.430320620536804
eval: 0.7413793103448276         f1score: 0.7155038759689923
6
5.666096031665802
eval: 0.7413793103448276         f1score: 0.7416450216450217
7
5.030765861272812
eval: 0.7241379310344828         f1score: 0.6968647591098068
8
4.60918802022934
eval: 0.7586206896551724         f1score: 0.7560292580982236
9
4.259765192866325
eval: 0.7758620689655172         f1score: 0.744510582010582
10
3.987866222858429
eval: 0.7586206896551724         f1score: 0.7347813106433796
11
3.706470012664795
eval: 0.7413793103448276         f1score: 0.7108554061592322
12
3.48807530105114




0,1
epoch,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
f1val,▁▂▆▇▇▇▇▇▇▇▇█▇▇█▇▇███████████████████████
loss,█▇▅▃▃▃▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
lr,██▇▆▆▆▅▅▄▄▄▄▃▃▃▃▃▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁
val,▁▂▅▇▇▇▇▇▇▇▇█▇▇█▇▇███████████████████████

0,1
epoch,0.0
f1val,0.80409
loss,0.0024
lr,3e-05
val,0.81034


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.016669812583252983, max=1.0…

0
16.04127788543701
eval: 0.46551724137931033         f1score: 0.14905149051490513
1
13.982216596603394
eval: 0.603448275862069         f1score: 0.3728042328042328
2
11.757802248001099
eval: 0.6551724137931034         f1score: 0.4073346839125239
3
10.274856388568878
eval: 0.6724137931034483         f1score: 0.5639057239057239
4
9.268387258052826
eval: 0.8275862068965517         f1score: 0.7735294117647058
5
7.951789140701294
eval: 0.8275862068965517         f1score: 0.7609090909090909
6
7.314817368984222
eval: 0.8275862068965517         f1score: 0.8026610644257703
7
6.698518246412277
eval: 0.7758620689655172         f1score: 0.7711390759777856
8
6.0312526524066925
eval: 0.8448275862068966         f1score: 0.7931890331890332
9
5.687610059976578
eval: 0.8275862068965517         f1score: 0.7827597402597403
10
5.070540487766266
eval: 0.8448275862068966         f1score: 0.8245331465919701
11
4.840139180421829
eval: 0.8448275862068966         f1score: 0.8248832866479925
12
4.478961080312729




0,1
epoch,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
f1val,▁▃▅▇▇▇▇▇▇█▇█████████████████████████████
loss,█▇▅▄▃▃▃▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
lr,██▇▆▆▆▅▅▄▄▄▄▃▃▃▃▃▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁
val,▁▃▄▆▆▇▇▆▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇███████████████

0,1
epoch,0.0
f1val,0.88268
loss,0.00285
lr,3e-05
val,0.93103


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.01667033028328054, max=1.0)…

0
15.838720083236694
eval: 0.3793103448275862         f1score: 0.15859649122807018
1
12.849116921424866
eval: 0.5517241379310345         f1score: 0.4050739821756816
2
10.450303792953491
eval: 0.6379310344827587         f1score: 0.48278388278388285
3
8.792458951473236
eval: 0.6724137931034483         f1score: 0.5521654815772463
4
7.4720481634140015
eval: 0.7758620689655172         f1score: 0.7142995169082126
5
6.589536160230637
eval: 0.7931034482758621         f1score: 0.7501587301587301
6
5.784304440021515
eval: 0.8103448275862069         f1score: 0.7783047870004391
7
5.447883576154709
eval: 0.8103448275862069         f1score: 0.7802020202020201
8
4.874263256788254
eval: 0.7931034482758621         f1score: 0.7685513452955313
9
4.31824404001236
eval: 0.8448275862068966         f1score: 0.8429090909090909
10
3.949472576379776
eval: 0.7931034482758621         f1score: 0.7685513452955313
11
3.7104280292987823
eval: 0.8275862068965517         f1score: 0.8137062937062938
12
3.409832239151001



0,1
epoch,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
f1val,▁▃▅▇▇▇▇████████████████▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇
loss,█▇▅▄▃▃▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
lr,██▇▆▆▆▅▅▄▄▄▄▃▃▃▃▃▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁
val,▁▄▅▇▇▇▇█████████████████████████████████

0,1
epoch,0.0
f1val,0.8023
loss,0.00208
lr,3e-05
val,0.82759


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.01666942541669414, max=1.0)…

0
15.861315369606018
eval: 0.41379310344827586         f1score: 0.2704761904761905
1
12.663205981254578
eval: 0.5172413793103449         f1score: 0.376025641025641
2
10.026461124420166
eval: 0.5344827586206896         f1score: 0.45579550842708727
3
8.577489733695984
eval: 0.5517241379310345         f1score: 0.5213544625309332
4
7.466750204563141
eval: 0.6551724137931034         f1score: 0.5963333333333333
5
6.798234939575195
eval: 0.6551724137931034         f1score: 0.6311885404568331
6
6.268428921699524
eval: 0.7068965517241379         f1score: 0.6400136798905608
7
5.512543946504593
eval: 0.7068965517241379         f1score: 0.6676767676767676
8
5.142245322465897
eval: 0.7413793103448276         f1score: 0.672319783197832
9
4.835874229669571
eval: 0.7241379310344828         f1score: 0.6877260981912144
10
4.4754011034965515
eval: 0.6724137931034483         f1score: 0.6080698287220025
11
4.12048202753067
eval: 0.7586206896551724         f1score: 0.7465865699161364
12
4.048694223165512
ev



VBox(children=(Label(value='0.003 MB of 0.003 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
epoch,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
f1val,▁▂▄▆▆▆▅▇█▇█████▇████████████████████████
loss,█▆▄▄▃▃▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
lr,██▇▆▆▆▅▅▄▄▄▄▃▃▃▃▃▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁
val,▁▃▃▅▆▇▆▆▇▆▇████▇████████████████████████

0,1
epoch,0.0
f1val,0.8137
loss,0.00272
lr,3e-05
val,0.81034


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.016668562399960743, max=1.0…

0
16.093876719474792
eval: 0.4482758620689655         f1score: 0.22960358056265986
1
13.215928435325623
eval: 0.5517241379310345         f1score: 0.49828239490085985
2
11.024395942687988
eval: 0.5689655172413793         f1score: 0.4377257525083612
3
9.117416977882385
eval: 0.6896551724137931         f1score: 0.6703529411764706
4
7.903863370418549
eval: 0.6551724137931034         f1score: 0.6387878787878789
5
6.895468175411224
eval: 0.7241379310344828         f1score: 0.6934969994544463
6
6.202450394630432
eval: 0.7068965517241379         f1score: 0.6891943521594686
7
5.89722803235054
eval: 0.7068965517241379         f1score: 0.688223109289028
8
5.4230508506298065
eval: 0.7758620689655172         f1score: 0.7686080586080586
9
5.007880210876465
eval: 0.7413793103448276         f1score: 0.7212867274569403
10
4.5161310732364655
eval: 0.7241379310344828         f1score: 0.7274444314231024
11
4.214025169610977
eval: 0.7413793103448276         f1score: 0.7212867274569403
12
3.9864901304244995



0,1
epoch,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
f1val,▁▄▇▇▇█▇█▇██▇██▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇
loss,█▇▅▄▃▃▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
lr,██▇▆▆▆▅▅▄▄▄▄▃▃▃▃▃▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁
val,▁▃▆▇▆█▇█▇██▇██▇▇▇▇▇█▇▇██████████████████

0,1
epoch,0.0
f1val,0.72966
loss,0.00283
lr,3e-05
val,0.77586


In [15]:
print(np.average(max_evalval_list))
print(np.std(max_evalval_list))
print(np.average(max_f1val_list))
print(np.std(max_f1val_list))

0.8258620689655171
0.047811809047800156
0.8035354639736083
0.0508981653410157


In [None]:
# the values recorded from the best models from each split
recacc = [0.89655,0.89655,0.8793,0.862,0.8448,0.8793,0.8448,0.8621,0.9138,0.8448]
recf1 = [0.9106,0.84262,0.870255,0.8374,0.8449,0.88,0.874,0.8202,0.8905,0.79139]

In [None]:
print(np.average(recacc))
print(np.std(recacc))
print(np.average(recf1))
print(np.std(recf1))

#### Varying time to manuever

In [23]:
def train2(net, trainloader, optimizer, loss_fn, num_epochs):
    # Put the network in training mode
    net.train()

    # Training loop
    for epoch in range(num_epochs):
        running_loss = 0
        for batch_idx, (data, targets) in enumerate(trainloader):
            # TODO: zero the parameter gradients + forward pass + loss computation + backward pass + weight update
            data = np.array(data)
            tdata = torch.tensor(data, dtype=torch.float32)
            ttargets = torch.tensor(targets)
            zten = torch.zeros((len(data),30,32))
            
            optimizer.zero_grad()
            prediction = net(tdata)
            loss = loss_fn(prediction, ttargets)           
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
            
            optimizer.zero_grad()
            prediction = net(torch.cat((zten,tdata[:,0:120,:]), 1))
            loss = loss_fn(prediction, ttargets)           
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
                             
            optimizer.zero_grad()
            prediction = net(torch.cat((zten,zten,tdata[:,0:90,:]), 1))
            loss = loss_fn(prediction, ttargets)           
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
            
            optimizer.zero_grad()
            prediction = net(torch.cat((zten,zten,zten,tdata[:,0:60,:]), 1))
            loss = loss_fn(prediction, ttargets)           
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
                             
            optimizer.zero_grad()
            prediction = net(torch.cat((zten,zten,zten,zten,tdata[:,0:30,:]), 1))
            loss = loss_fn(prediction, ttargets)           
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
                             
        print(running_loss)

In [24]:
net1 = EncoderLayer(modeldim,numheads,hiddendim,dropout,fchiddendim)
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(net1.parameters(), lr=lr, momentum=0.9, weight_decay=0.001)
lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.95)

In [25]:
for i in range(50): # train/val for varying time to manuever
    print(i)
    train2(net1,trainloader,optimizer,loss_fn,1)
    evalval, f1val = eval(net1,valloader)
    print('eval: '+str(evalval) + '         f1score: ' + str(f1val))
    if evalval > 0.80:
        break

0
69.44520598649979
eval: 0.6724137931034483         f1score: 0.6376245210727969
1
52.68117117881775
eval: 0.5517241379310345         f1score: 0.564350400557297
2
43.11290442943573
eval: 0.6896551724137931         f1score: 0.6618929765886288
3
36.13350674510002
eval: 0.7758620689655172         f1score: 0.7739944602013568
4
35.403600201010704
eval: 0.7241379310344828         f1score: 0.7096498262496659
5
31.96283385157585
eval: 0.7586206896551724         f1score: 0.7068055833834236
6
31.050378516316414
eval: 0.6379310344827587         f1score: 0.5905820105820105
7
26.887646317481995
eval: 0.8448275862068966         f1score: 0.8193342888995062


In [39]:
def eval2(net, loader):
    net.eval()
    t2m = np.zeros((100))
    totalcorrect = np.zeros((5,100))
    i1 = 0
    for batch_idx, (data, targets) in enumerate(loader):
        data = np.array(data)
        tdata = torch.tensor(data, dtype=torch.float32)
        zten = torch.zeros((len(data),30,32))
        
        y = net(torch.cat((zten,zten,zten,zten,tdata[:,0:30,:]), 1))
        if int(torch.argmax(y)) == targets[0]:
            t2m[i1] = 5
            totalcorrect[0,i1] = 1
            
        y = net(torch.cat((zten,zten,zten,tdata[:,0:60,:]), 1))
        if int(torch.argmax(y)) == targets[0]:
            if t2m[i1] == 0:
                t2m[i1] = 4
            totalcorrect[1,i1] = 1
            
        y = net(torch.cat((zten,zten,tdata[:,0:90,:]), 1))
        if int(torch.argmax(y)) == targets[0]:
            if t2m[i1] == 0:
                t2m[i1] = 3
            totalcorrect[2,i1] = 1
            
        y = net(torch.cat((zten,tdata[:,0:120,:]), 1))
        if int(torch.argmax(y)) == targets[0]:
            if t2m[i1] == 0:
                t2m[i1] = 2
            totalcorrect[3,i1] = 1
            
        y = net(tdata)
        if int(torch.argmax(y)) == targets[0]:
            if t2m[i1] == 0:
                t2m[i1] = 1
            totalcorrect[4,i1] = 1
            
        i1+=1
        
    return t2m[0:i1], totalcorrect[:,0:i1]

In [40]:
time_to_maneuver, correct_expansion = eval2(net1, valloader)

In [42]:
lt2m = len(time_to_maneuver)
print('Total maneuvers: ' + str(lt2m))
instancescorrect = 0
for i in time_to_maneuver:
    if i == 5:
        instancescorrect += 1

per5 = instancescorrect/lt2m
avgtime = 5*instancescorrect
print('Percentage correct 5 s in advance: ' + str(per5))

numleft = lt2m - instancescorrect
instancescorrect = 0
for i in time_to_maneuver:
    if i == 4:
        instancescorrect += 1

per4 = instancescorrect/numleft
avgtime = avgtime + 4*instancescorrect
print('Percentage correct 4 s in advance: ' + str(per4))

numleft = numleft - instancescorrect
instancescorrect = 0
for i in time_to_maneuver:
    if i == 3:
        instancescorrect += 1

per3 = instancescorrect/numleft
avgtime = avgtime + 3*instancescorrect
print('Percentage correct 3 s in advance: ' + str(per3))

numleft = numleft - instancescorrect
instancescorrect = 0
for i in time_to_maneuver:
    if i == 2:
        instancescorrect += 1

per2 = instancescorrect/numleft
avgtime = avgtime + 2*instancescorrect
print('Percentage correct 2 s in advance: ' + str(per2))

numleft = numleft - instancescorrect
instancescorrect = 0
for i in time_to_maneuver:
    if i == 1:
        instancescorrect += 1

per1 = instancescorrect/numleft
avgtime = avgtime + 1*instancescorrect
print('Percentage correct 1 s in advance: ' + str(per1))

avgtime = avgtime/(lt2m-numleft)
print('Avg time for correct prediction: ' + str(avgtime))

Total maneuvers: 58
Percentage correct 5 s in advance: 0.7068965517241379
Percentage correct 4 s in advance: 0.35294117647058826
Percentage correct 3 s in advance: 0.45454545454545453
Percentage correct 2 s in advance: 0.5
Percentage correct 1 s in advance: 0.0
Avg time for correct prediction: 4.545454545454546


In [47]:
time_to_maneuver.shape

(58,)

In [60]:
fillerarray = np.zeros((len(time_to_maneuver)))
fillerarray[:] = time_to_maneuver
tim2mset = np.stack((time_to_maneuver, fillerarray), axis=0) #np.concatenate((time_to_maneuver,fillerarray),axis=0)

fillerarray = np.zeros((5,len(time_to_maneuver)))
fillerarray[:,:] = correct_expansion
expandedset = np.concatenate((tim2mset,fillerarray),axis=0)

In [61]:
np.shape(tim2mset)

(2, 58)

In [72]:
np.shape(expandedset)
# np.shape(time_to_maneuver)

(7, 58)

In [None]:
np.save("C:/Users/ykung/Downloads/b4c/timevaryingtim2mset", tim2mset)

In [None]:
np.save("C:/Users/ykung/Downloads/b4c/expandedset", expandedset)

In [73]:
num5s = 0
num4s = 0
num3s = 0
num2s = 0
num1s = 0
total = 580
tolcorr = 0
for i in range(2):
    for j in range(58):
        if tim2mset[i,j] == 5:
            num5s += 1
            tolcorr += 1
        elif tim2mset[i,j] == 4:
            num4s += 1
            tolcorr += 1
        elif tim2mset[i,j] == 3:
            num3s += 1
            tolcorr += 1
        elif tim2mset[i,j] == 2:
            num2s += 1
            tolcorr += 1
        elif tim2mset[i,j] == 1:
            num1s += 1
            tolcorr += 1
print('5s pecentage: ' + str(num5s/total))
print('4s pecentage: ' + str(num4s/(total-num5s)))
print('3s pecentage: ' + str(num3s/(total-num5s-num4s)))
print('2s pecentage: ' + str(num2s/(total-num5s-num4s-num3s)))
print('1s pecentage: ' + str(num1s/(total-num5s-num4s-num3s-num2s)))
print('total avg pred time: ' + str((num5s*5+num4s*4+num3s*3+num2s*2+num1s)/tolcorr))

5s pecentage: 0.1413793103448276
4s pecentage: 0.024096385542168676
3s pecentage: 0.0205761316872428
2s pecentage: 0.012605042016806723
1s pecentage: 0.0
total avg pred time: 4.545454545454546


In [None]:
expandedset = np.load("C:/Users/ykung/Downloads/b4c/expandedset.npy")

In [75]:
np.shape(expandedset)

(7, 58)

In [74]:
num5s = 0
num4s = 0
num3s = 0
num2s = 0
num1s = 0
total = 580
for i in range(10):
    for j in range(58):
        if expandedset[i,0,j] == 1:
            num5s += 1
            
        if expandedset[i,1,j] == 1:
            num4s += 1
            
        if expandedset[i,2,j] == 1:
            num3s += 1
            
        if expandedset[i,3,j] == 1:
            num2s += 1
            
        if expandedset[i,4,j] == 1:
            num1s += 1
            
print('5s pecentage: ' + str(num5s/total))
print('4s pecentage: ' + str(num4s/total))
print('3s pecentage: ' + str(num3s/total))
print('2s pecentage: ' + str(num2s/total))
print('1s pecentage: ' + str(num1s/total))

IndexError: too many indices for array: array is 2-dimensional, but 3 were indexed