# LSTM


### Imports

In [64]:
import yaml
import torch
import torch.nn as nn
import torch.nn.functional as F

In [65]:
import torch.optim as optim
import torch.utils.data as data
import math
import copy

### Data Loading

In [66]:
# cd Downloads

In [67]:
# cd b4c

In [68]:
# cd Brains4Cars

Straight from github

In [69]:
import os
from os.path import join
import copy

from multiprocessing import Pool
import tqdm

import numpy as np
import pickle
from torch.utils.data import Dataset, DataLoader

ACTION_TO_ID_MAP= {
    'end_action':   0,
    'lchange':      1,
    'lturn':        2,
    'rchange':      3,
    'rturn':        4
}

rng = np.random.default_rng(seed=42)
MAX_FRAMES=150


class B4CDataset(Dataset):
    """
    end_action - 0
    lchange - 1
    lturn - 2
    rchange - 3
    rturn - 4
    """
    def __init__(self, data_cfg, split="train", create_dataset=False):
        self.data_cfg   = data_cfg['DATALOADER_CONFIG']
        self.actions    = self.data_cfg['ACTIONS']
        self.cameras    = self.data_cfg['CAMERAS']
        self.data_dir   = self.data_cfg['DATA_DIR']
        self.split      = split

        videos_dict = self.read_videos_by_action()        

        if create_dataset:
            # self.create_gt_road_labels(videos_dict)
            self.generate_imagesets(videos_dict)
        
        self.image_sets = {}
        splits_process = [self.split] if self.split in ["train", "val", "test"] else ["train", "val", "test"]

        for camera in self.cameras:
            if camera not in self.image_sets:
                self.image_sets[camera] = []

            for curr_split in splits_process:
                imageset_path = join(self.data_dir, f'ImageSets_{camera}', f'{curr_split}.txt')
                self.image_sets[camera].extend([line.strip() for line in open(imageset_path, 'r')])

        # Drop nondivisible videos for fold splits from end
        if "fold" in self.split:
            num_folds = int(self.split.split("_")[0])
            num_to_drop = len(self) % num_folds
            self.image_sets = {k: v[:-num_to_drop] for k, v in self.image_sets.items()}

        print(f'Added {len(self)} videos to the dataset.')

    def read_videos_by_action(self):
        videos_dict = {}
        # Combine image sets for all cameras
        for camera in self.cameras:
            camera_dir = join(self.data_dir, camera+"_processed")

            for action in self.actions:
                action_dir = join(camera_dir, action)

                if action not in videos_dict:
                    videos_dict[action] = []

                # Take the set intersection between all the cameras sequentially
                video_action_set = set([f for f in os.listdir(action_dir) if os.path.isdir(join(action_dir, f))])

                if len(videos_dict[action])==0:
                    videos_dict[action] = video_action_set
                else:
                    videos_dict[action] = videos_dict[action].intersection(video_action_set)

        # Convert videos_dict back to list
        for action in self.actions:
            videos_dict[action] = list(videos_dict[action])

        # Check that files exist
        for camera in self.cameras:
            camera_dir = join(self.data_dir, camera+"_processed")

            for action in self.actions:
                action_dir = join(camera_dir, action)

                for video_dir in videos_dict[action]:
                    video_dir = join(action_dir, video_dir)
                    assert os.path.exists(video_dir), f"Video directory {video_dir} does not exist."

        print("Finished reading all valid videos by directory.")
        return videos_dict

    @staticmethod
    def write_list_to_file(file_path, data_list):
        """
        Create/overwrite a .txt file and write each line of the Python list to a new line in the file.

        Parameters:
            file_path : str
                The path to the .txt file.
            data_list : list
                The Python list containing data to write to the file.
        """
        with open(file_path, 'w') as file:
            file.writelines(f"{item}\n" for item in data_list)

    def check_data_quality(self, video_subdirs, camera):
        """
        Check that all videos have the same number of frames and that the number of frames is less than MAX_FRAMES.

        Parameters:
            videos_dict : dict
                Dictionary containing the list of videos for each action.
        """

        valid_videos_mask = np.zeros(len(video_subdirs), dtype=bool)
        for video_idx, video_subdir in enumerate(video_subdirs):
            video_path = join(self.data_dir, video_subdir)
            data_dict = {}

            # Check that all videos have the full frame set
            if camera=="face":
                data_dict['gt_gazepose'] = self.get_face_labels(video_path)
                valid_videos_mask[video_idx] = len(data_dict['gt_gazepose'])>=MAX_FRAMES-1 # gazepose has 149
            elif camera=="road":
                data_dict['gt_bbox'], data_dict['gt_lanes'] = self.get_road_labels(video_path)
                valid_videos_mask[video_idx] = len(data_dict['gt_bbox'])>=MAX_FRAMES and len(data_dict['gt_lanes'])>=MAX_FRAMES

            if valid_videos_mask[video_idx]==0:
                print(f'Video {video_subdir} does not have the full frame set, skipping...')

        return valid_videos_mask
            

    def generate_imagesets(self, videos_dict):    
        print("Generating imagesets...")
        facecam_imageset_dict = {'train': [], 'val': [], 'test': []}
        roadcam_imageset_dict = copy.deepcopy(facecam_imageset_dict)
        train_pct, val_pct, test_pct = 0.7, 0.15, 0.15
        for action, action_videos in videos_dict.items():
            road_cam_action_dir = join('road_camera_processed_combined', action)
            road_cam_video_labels = np.array([join(road_cam_action_dir, f) for f in action_videos])
            road_cam_video_mask = self.check_data_quality(road_cam_video_labels, "road")

            face_cam_action_dir = join('face_camera_processed', action)
            face_cam_video_labels = np.array([join(face_cam_action_dir, f) for f in action_videos])
            face_cam_video_mask = self.check_data_quality(face_cam_video_labels, "face")

            combined_cam_video_mask = np.logical_and(road_cam_video_mask, face_cam_video_mask)
            road_cam_video_labels = road_cam_video_labels[combined_cam_video_mask]
            face_cam_video_labels = face_cam_video_labels[combined_cam_video_mask]

            road_cam_video_labels_sort_idx = np.argsort(road_cam_video_labels)
            road_cam_video_labels = road_cam_video_labels[road_cam_video_labels_sort_idx]
            face_cam_video_labels = face_cam_video_labels[road_cam_video_labels_sort_idx]

            # Ensure file order matches for road and face camera
            for i in range(len(road_cam_video_labels)):
                assert os.path.basename(road_cam_video_labels[i])==os.path.basename(face_cam_video_labels[i]), 'Video labels do not match'
            num_videos = len(road_cam_video_labels)

            num_train       = int(num_videos * train_pct)
            num_val         = int(num_videos * val_pct)
            num_test        = int(num_videos * test_pct)

            indices = np.arange(0, num_videos, 1)
            rng.shuffle(indices)

            videos_indices = {"train": [], "val": [], "test": []}
            videos_indices['train'], videos_indices['val'], videos_indices['test'] = indices[:num_train], \
                indices[num_train:num_train+num_val],  indices[num_train+num_val:num_train+num_val+num_test]

            for split in facecam_imageset_dict.keys():
                roadcam_imageset_dict[split].extend(road_cam_video_labels[videos_indices[split]].tolist())
                facecam_imageset_dict[split].extend(face_cam_video_labels[videos_indices[split]].tolist())

        # Dump to imageset files for road and face camera
        roadcam_imagesets_dir = join(self.data_dir, "ImageSets_road_camera")
        facecam_imagesets_dir = join(self.data_dir, "ImageSets_face_camera")
        if not os.path.exists(roadcam_imagesets_dir):
            print(f'Video root directory {roadcam_imagesets_dir} does not exist. Creating...')
            os.mkdir(roadcam_imagesets_dir)
        if not os.path.exists(facecam_imagesets_dir):
            print(f'Video root directory {facecam_imagesets_dir} does not exist. Creating...')
            os.mkdir(facecam_imagesets_dir)

        for split_key in facecam_imageset_dict.keys():
            road_cam_split_path = join(roadcam_imagesets_dir, f'{split_key}.txt')
            face_cam_split_path = join(facecam_imagesets_dir, f'{split_key}.txt')
            print(f'Saving imageset file {split_key} for road {road_cam_split_path} and {face_cam_split_path}')
            self.write_list_to_file(road_cam_split_path, roadcam_imageset_dict[split_key])
            self.write_list_to_file(face_cam_split_path, facecam_imageset_dict[split_key])

    def __len__(self):
        num_files = 0
        for camera in self.cameras: 
            assert num_files==0 or num_files==len(self.image_sets[camera]), "Number files in dataset not correct"
            num_files=len(self.image_sets[camera])
        return num_files
    
    def collate_fn(self, data):
        # print(data[0][1])
        data_batch = [bi[0] for bi in data]
        action_batch = [bi[1] for bi in data]
        return data_batch, action_batch

    def get_face_labels(self, label_dir):  
        gazepose_path = join(label_dir, 'gazepose.npy')
        assert os.path.exists(gazepose_path), f'Label file {gazepose_path} does not exist'
        gt_gazepose = np.load(gazepose_path)

        if gt_gazepose.shape[0] < MAX_FRAMES:
            # print(f'Gaze pose {gazepose_path} has less than 150 frames, padding with zeros...')
            # Pad with zeros
            gt_gazepose = np.pad(gt_gazepose, ((0, MAX_FRAMES-gt_gazepose.shape[0]), (0, 0)), mode='constant')
        # Assume gt_gazepose is not smaller than MAX_FRAMES
        num_frames = min(MAX_FRAMES, gt_gazepose.shape[0])
        gt_gazepose = gt_gazepose[:num_frames, :]

        return gt_gazepose

    def get_road_labels(self, label_dir):
        
        bbox_file = join(label_dir, 'bbox_labels.pkl')
        lane_file = join(label_dir, 'lane_labels.pkl')

        assert os.path.exists(bbox_file), f'Label directory {bbox_file} does not exist'
        assert os.path.exists(lane_file), f'Label directory {lane_file} does not exist'

        # Load bbox detections
        with open(bbox_file, 'rb') as f:
            gt_bbox = pickle.load(f)
        # Load road labels
        with open(lane_file, 'rb') as f:
            gt_lanes = pickle.load(f)

        gt_bbox = gt_bbox[:MAX_FRAMES]
        gt_lanes = gt_lanes[:MAX_FRAMES]

        return gt_bbox, gt_lanes
    
    def get_action_label(self, video_dir):
        action_label = video_dir.split('/')[-2]
        assert action_label in ACTION_TO_ID_MAP.keys(), f'Action {action_label} not in action map'
        action_id = ACTION_TO_ID_MAP[action_label]
        return action_id

    @staticmethod
    def combine_img_labels(args):
        split_video_path, combined_video_path = args
        img_label_files = [f for f in os.listdir(split_video_path) if f.endswith('.pkl')]

        MAX_NUM_BBOXES = 5 
        full_img_label_np = np.ones((MAX_FRAMES+1, MAX_NUM_BBOXES*5)) * -1
        num_img_label_files = len(img_label_files)
        for img_label_idx, img_label_file in enumerate(img_label_files):
            img_label_path = join(split_video_path, img_label_file)
            with open(img_label_path, 'rb') as f:
                img_data = pickle.load(f)
                IMG_W, IMG_H = 720, 480

                #1 Convert to xc, yc, w, h
                x1, x2, y1, y2 = img_data['xyxy'][:, 0], img_data['xyxy'][:, 2], img_data['xyxy'][:, 1], img_data['xyxy'][:, 3]
                xc = (x1 + x2) / 2
                yc = (y1 + y2) / 2
                h = (x2 - x1)
                w = (y2 - y1)
            
                #2 Only keep bboxes with h and w < 0.33 (Ignore large bboxes of self)
                bbox_size_mask = np.logical_and(h < IMG_W*0.33, w < IMG_H*0.33)

                #3 Only keep labels with class_ids = 0, 1, 2, 3, 4 (Ignore 5 Date)
                class_ids = img_data['class_id'] 
                class_ids_mask = np.logical_and(class_ids >= 0, class_ids <= 4)

                bbox_mask = np.logical_and(bbox_size_mask, class_ids_mask)
                proc_img_label = np.ones((MAX_NUM_BBOXES, 5), dtype=int)*-1 # max of five bbox detections per image

                if np.sum(bbox_mask)>0:
                    xc = xc[bbox_mask].astype(int)
                    yc = yc[bbox_mask].astype(int)
                    h = h[bbox_mask].astype(int)
                    w = w[bbox_mask].astype(int)
                    class_ids = class_ids[bbox_mask]
                    gt_objs = np.stack((xc, yc, w, h, class_ids), axis=1)

                    #4 Select top 5 largest boxes
                    gt_obj_areas = gt_objs[:, 2] * gt_objs[:, 3]
                    gt_objs_sort_idx = np.argsort(-gt_obj_areas, kind='stable') # Sort high to low
                    num_objs = min(MAX_NUM_BBOXES, len(gt_objs_sort_idx))

                    proc_img_label[:num_objs, :] = gt_objs[gt_objs_sort_idx[:num_objs], :]

                proc_img_label = proc_img_label.flatten()
                proc_img_label = np.expand_dims(proc_img_label, axis=0)
                full_img_label_np[img_label_idx] = proc_img_label

        assert os.path.exists(combined_video_path), f'Video label directory {combined_video_path} does not exist'
        video_path = join(combined_video_path, 'bbox_labels.pkl')
        with open(video_path, 'wb') as f:
            pickle.dump(full_img_label_np, f)
        print("Saved combined bbox labels to: ", video_path)

        # Load, pad, save lane dets
        original_label_dir = split_video_path.replace('road_camera_processed', 'road_camera').replace('labels/', '')
        road_path = join(original_label_dir+".txt")
        gt_lanes = np.loadtxt(road_path, delimiter=',', dtype=int).reshape(1, -1)
        gt_lanes_padded = np.ones((MAX_FRAMES, 3)) * -1
        num_valid_gt_lanes = min(num_img_label_files, MAX_FRAMES)
        gt_lanes_padded[:num_valid_gt_lanes] = np.repeat(gt_lanes, [num_valid_gt_lanes], axis=0)

        lanes_path = join(combined_video_path, 'lane_labels.pkl')
        with open(lanes_path, 'wb') as f:
            pickle.dump(gt_lanes_padded, f)

    def create_gt_road_labels(self, videos_dict):
        split_video_path_list = []
        combined_video_path_list = []
        for action, action_videos in videos_dict.items():
            for video in action_videos:
                video_path = join(self.data_dir, 'road_camera_processed', 'labels', action, video)
                if not os.path.exists(video_path):
                    print(f'Video path {video_path} does not exist')
                    continue
                assert os.path.exists(video_path), f'Video path {video_path} does not exist'

                video_label_dir = join(self.data_dir, "road_camera_processed_combined", action, video)
                if not os.path.exists(video_label_dir):
                    print("Creating directory: ", video_label_dir)
                    os.makedirs(video_label_dir)
                
                split_video_path_list.append(video_path)
                combined_video_path_list.append(video_label_dir)
            # self.combine_img_labels((split_video_path_list[0], combined_video_path_list[0]))
        pool = Pool(processes=16)
        for _ in tqdm.tqdm(pool.imap_unordered(self.combine_img_labels, zip(split_video_path_list, 
            combined_video_path_list)), total=len(split_video_path_list)):
            pass

    def __getitem__(self, idx):
        data_dict = {}
        action_id=None

        for camera in self.cameras:
            video_subdir    = self.image_sets[camera][idx]
            video_fulldir   = join(self.data_dir, video_subdir)

            if camera == 'face_camera':
                data_dict['gt_gazepose'] = self.get_face_labels(video_fulldir)
            elif camera == 'road_camera':
                # Load all pickle files in the video directory
                data_dict['gt_bbox'], data_dict['gt_lanes'] = self.get_road_labels(video_fulldir)

            if action_id is None:
                action_id = self.get_action_label(video_fulldir)

        # Perform Data Augmentations
        

        # Sort dict by key so that it is is consistent
        sorted_data_dict = dict(sorted(data_dict.items(), key=lambda item: item[0]))

        sorted_gt_np = np.empty((MAX_FRAMES, 0))
        # Stack all values from data_dict into a single np array

        for _, items in sorted_data_dict.items():
            sorted_gt_np = np.hstack((sorted_gt_np, items))

        # TOOD: Extract action label from Imageset file
        #  150 x [ (5x5) (2) (2) (3) ] # Pad if not enough frames obj_detections, gazepose, lane_detections # 150 x 32
        return sorted_gt_np, action_id # processed_input, action_label one hot vector 

config file

In [70]:
cfg=None
with open("../config/lstm_all.yaml", 'r') as file:
    cfg = yaml.safe_load(file)

In [71]:
batch = 50

make dataset (do once)

In [72]:
dataset = B4CDataset(cfg, split="5_fold", create_dataset=False)

Finished reading all valid videos by directory.
Added 585 videos to the dataset.


In [73]:
#dataloader = DataLoader(dataset, collate_fn=dataset.collate_fn, batch_size=batch, shuffle=True)

In [74]:
#valdataset = B4CDataset(cfg, split="val", create_dataset=False)
#val_loader = DataLoader(valdataset, collate_fn=dataset.collate_fn, batch_size=1, shuffle=True)

In [75]:
#testdataset = B4CDataset(cfg, split="test", create_dataset=False)
#test_loader = DataLoader(testdataset, collate_fn=dataset.collate_fn, batch_size=1, shuffle=True)

#### Define your network, loss function and optimizer

In [76]:
#multiheaded attention
class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super(MultiHeadAttention, self).__init__()
        #d_model must be divisible by num_heads
        
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads
        
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)
        nn.init.normal_(self.W_q.weight, mean=0., std=np.sqrt(2 / (d_model+d_model)))
        nn.init.normal_(self.W_k.weight, mean=0., std=np.sqrt(2 / (d_model+d_model)))
        nn.init.normal_(self.W_v.weight, mean=0., std=np.sqrt(2 / (d_model+d_model)))
        nn.init.normal_(self.W_o.weight, mean=0., std=np.sqrt(2 / (d_model+d_model)))
        # nn.init.xavier_uniform_(self.W_q.weight)
        # nn.init.xavier_uniform_(self.W_k.weight)
        # nn.init.xavier_uniform_(self.W_v.weight)
        # nn.init.xavier_uniform_(self.W_o.weight)
        
    def scaled_dot_product_attention(self, Q, K, V, mask=None):
        attn_scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)
        if mask is not None:
            attn_scores = attn_scores.masked_fill(mask == 0, -1e9)
        attn_probs = torch.softmax(attn_scores, dim=-1)
        output = torch.matmul(attn_probs, V)
        return output
        
    def split_heads(self, x):
        batch_size, seq_length, d_model = x.size()
        return x.view(batch_size, seq_length, self.num_heads, self.d_k).transpose(1, 2)
        
    def combine_heads(self, x):
        batch_size, _, seq_length, d_k = x.size()
        return x.transpose(1, 2).contiguous().view(batch_size, seq_length, self.d_model)
        
    def forward(self, Q, K, V, mask=None):
        Q = self.split_heads(self.W_q(Q))
        K = self.split_heads(self.W_k(K))
        V = self.split_heads(self.W_v(V))
        
        attn_output = self.scaled_dot_product_attention(Q, K, V, mask)
        output = self.W_o(self.combine_heads(attn_output))
        return output
    
#position feedforward
class PositionWiseFeedForward(nn.Module):
    def __init__(self, d_model, d_ff):
        super(PositionWiseFeedForward, self).__init__()
        self.fc1 = nn.Linear(d_model, d_ff)
        self.fc2 = nn.Linear(d_ff, d_model)
        self.relu = nn.ReLU()

        # nn.init.xavier_uniform_(self.fc1.weight)
        # nn.init.xavier_uniform_(self.fc2.weight)

    def forward(self, x):
        return self.fc2(self.relu(self.fc1(x)))
    
#position encoding
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_seq_length, dropout):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        pe = torch.zeros(max_seq_length, d_model)
        position = torch.arange(0, max_seq_length, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model))

        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        
        self.register_buffer('pe', pe.unsqueeze(0))
        
    def forward(self, x):
        x =  x + self.pe[:, :x.size(1)]
        return self.dropout(x)

#encoder no decoder
class EncoderLayer(nn.Module):
    def __init__(self, d_model, num_heads, d_ff, dropout, d_hidden):
        super(EncoderLayer, self).__init__()
        self.self_attn = MultiHeadAttention(d_model, num_heads)
        self.feed_forward = PositionWiseFeedForward(d_model, d_ff)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)
        # self.fc1 = nn.Linear(d_model*150,5)
        self.pfc1 = nn.Linear(4,32)
        self.pfc2 = nn.Linear(25,16)
        self.pfc3 = nn.Linear(3,16)
        self.sig = nn.Sigmoid()

        self.x1norm = nn.LayerNorm(32)
        self.x2norm = nn.LayerNorm(16)
        self.x3norm = nn.LayerNorm(16)

        self.pos1 = PositionalEncoding(32, 150, dropout)
        self.pos2 = PositionalEncoding(16, 150, dropout)
        self.pos3 = PositionalEncoding(16, 150, dropout)

        self.mlp_head = nn.Sequential(
            nn.Linear(d_model*150, d_hidden),
            nn.ReLU(),
            nn.Linear(d_hidden, 5)
        )

        # self.fc1 = nn.Linear(d_model*150,d_hidden)
        # self.relu = nn.ReLU()
        # self.fc2 = nn.Linear(d_hidden, 5)

        # nn.init.xavier_uniform_(self.fc1.weight)
        # nn.init.xavier_uniform_(self.pfc1.weight)
        # nn.init.xavier_uniform_(self.pfc2.weight)
        # nn.init.xavier_uniform_(self.pfc3.weight)
        
    def forward(self, x):
        x1 = self.pfc1(x[:,:,25:29])
        x1 = self.sig(x1)
        x2 = self.pfc2(x[:,:,0:25])
        x2 = self.sig(x2)
        x3 = self.pfc3(x[:,:,29:32])
        x3 = self.sig(x3)

        x1 = self.pos1(x1)
        x2 = self.pos2(x2)
        x3 = self.pos3(x3)
        # x1 = self.pos1(self.x1norm(x1))
        # x2 = self.pos2(self.x2norm(x2))
        # x3 = self.pos3(self.x3norm(x3))
        
        #x1 = torch.flatten(x1, start_dim=1)
        #x2 = torch.flatten(x2, start_dim=1)
        #x3 = torch.flatten(x3, start_dim=1)
        x = torch.cat((x1,x2,x3), 2)
        
        attn_output = self.self_attn(x, x, x, mask=None)
        x = self.norm1(x + self.dropout(attn_output))
        ff_output = self.feed_forward(x)
        x = self.norm2(x + self.dropout(ff_output))
        x = torch.flatten(x, start_dim=1)
        x = self.mlp_head(x)

        return x

In [77]:
# Linear ramp up optimizer then inverse sqrt decay
class NoamOpt:
    "Optim wrapper that implements rate."
    def __init__(self, model_size, factor, warmup, optimizer):
        self.optimizer = optimizer
        self._step = 0
        self.warmup = warmup
        self.factor = factor
        self.model_size = model_size
        self._rate = 0
        
    def step(self):
        "Update parameters and rate"
        self._step += 1
        rate = self.rate()
        for p in self.optimizer.param_groups:
            p['lr'] = rate
        self._rate = rate
        self.optimizer.step()
        
    def rate(self, step = None):
        "Implement `lrate` above"
        if step is None:
            step = self._step
        return self.factor * \
            (self.model_size ** (-0.5) *
            min(step ** (-0.5), step * self.warmup ** (-1.5)))
   

In [78]:
# fine-violet-267
modeldim = 64
numheads = 8
hiddendim = 128 # modeldim*2
lr = 0.001
dropout = 0.10
num_epochs = 100
fchiddendim = 256
T_max = num_epochs / 2
final_lr = 0.1*lr

net1 = EncoderLayer(modeldim,numheads,hiddendim,dropout,fchiddendim)
loss_fn = nn.CrossEntropyLoss()
baseoptimizer = torch.optim.SGD(net1.parameters(), lr=lr, momentum=0.9, weight_decay=0.0001)
optimizer = NoamOpt(modeldim, 1, 100, torch.optim.Adam(net1.parameters(), lr=0, betas=(0.9, 0.98), eps=1e-9))
lr_scheduler=None
# lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max, final_lr)
# lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.95)

In [79]:
# modeldim = 64
# numheads = 8
# hiddendim = 128 # modeldim*2
# lr = 0.001
# dropout = 0.2
# num_epochs = 100
# T_max = num_epochs / 2
# final_lr = 0.1*lr

# net1 = EncoderLayer(modeldim,numheads,hiddendim,dropout)
# loss_fn = nn.CrossEntropyLoss()
# optimizer = torch.optim.SGD(net1.parameters(), lr=lr, momentum=0.85, weight_decay=0.001)
# # lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max, final_lr)
# lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.95)

#### Implement the training loop function

In [80]:
import random

random.seed(42)
# k-split validation
totalsize = len(dataset)
indices = list(range(0,totalsize))
random.shuffle(indices)
seg = int(totalsize/10) # number of splits



In [81]:
split = 0 #change for each k-validation split

trainlefti = indices[0:split*seg]
trainrighti = indices[min(split*seg + seg,totalsize):totalsize]
traini = trainlefti + trainrighti
vali = indices[split*seg:min(totalsize,seg*split+seg)]

train_set = torch.utils.data.dataset.Subset(dataset,traini)
val_set = torch.utils.data.dataset.Subset(dataset,vali)

trainloader = DataLoader(train_set, collate_fn=dataset.collate_fn, batch_size=batch, shuffle=True)
valloader = DataLoader(val_set, collate_fn=dataset.collate_fn, batch_size=1, shuffle=True)

In [82]:
def train(net, trainloader, optimizer, loss_fn, lr_scheduler, num_epochs, epoch_count=0, split=0, trial=0):

    import wandb
    # Put the network in training mode
    net.train()

    # Training loop
    for epoch in range(num_epochs):
        running_loss = 0

        running_correct = 0
        running_total = 0
        for batch_idx, (data, targets) in enumerate(trainloader):
            # TODO: zero the parameter gradients + forward pass + loss computation + backward pass + weight update
            data = np.array(data)
            tdata = torch.tensor(data)
            # optimizer.zero_grad()
            optimizer.optimizer.zero_grad()
            prediction = net(tdata.float())
            ttargets = torch.tensor(targets)

            with torch.no_grad():
                max_batch_preds = torch.argmax(prediction, dim=-1)
                preds_masked = max_batch_preds.cpu().numpy()
                gt_masked = ttargets.cpu().numpy()
                running_correct += np.sum(preds_masked == gt_masked)
                running_total += preds_masked.shape[0]

            loss = loss_fn(prediction, ttargets)
            loss.backward()
            optimizer.step()
            wandb.log({"loss": loss.item(), "lr": optimizer._rate})
            # optimizer.step()
            running_loss += loss.item()

        if lr_scheduler is not None:
            lr_scheduler.step()
        print(running_loss)

        ckpt_save_dir = f'/home/arthur/AMRL/Research/Driver-Intent-Prediction/outputs/default/ckpts_{split}_{trial}'
        if not os.path.exists(ckpt_save_dir):
            os.makedirs(ckpt_save_dir)
        torch.save(net1.state_dict(), os.path.join(ckpt_save_dir, f'epoch_{epoch_count}.pth'))

        if lr_scheduler is not None:
            wandb.log({"loss": running_loss / (len(trainloader) * batch), "trainacc": running_correct / running_total, "epoch": epoch, "lr": lr_scheduler.get_last_lr()[0]}) 
        else:
            wandb.log({"loss": running_loss / (len(trainloader) * batch), "trainacc": running_correct / running_total, "epoch": epoch}) 


Evaluate

In [83]:
import sklearn
from sklearn import metrics

In [84]:
def eval(net, loader):
    net.eval()
    se = 0 # sum of correct pred
    total = 0
    f1input = np.zeros(100)
    f1target = np.zeros(100)
    i1 = 0
    for batch_idx, (data, targets) in enumerate(loader):
        data = np.array(data)
        tdata = torch.tensor(data)
        y = net(tdata.float())

        # Do softmax afterwards since no softmax in model
        y = torch.nn.functional.softmax(y, dim=-1)
        f1input[i1] = int(torch.argmax(y))
        f1target[i1] = int(targets[0])
        i1 += 1
        if int(torch.argmax(y)) == targets[0]:
            se += 1
        total += 1
        
    f1score = sklearn.metrics.f1_score(f1target[0:i1], f1input[0:i1], average = 'macro')
    return float(se/total), f1score

In [85]:
import wandb

wandb.init(
    project="Driver-Intent-Prediction-transpos-models",
    config={
        "optimizer": "SGD",
        "lr": lr,
        "dropout": dropout,
        "num_epochs": num_epochs,
        "batch_size": batch,
        "modeldim": modeldim,
        "numheads": numheads,
        "hiddendim": hiddendim,
    }
)

for i in range(num_epochs): # for train/val stationary time before manuever
    print(i)
    train(net1,trainloader,optimizer,loss_fn, lr_scheduler, 1, epoch_count=i)
    evalval, f1val = eval(net1,valloader)
    print('eval: '+str(evalval) + '         f1score: ' + str(f1val))
    wandb.log({"val": evalval, "f1val": f1val})
    if evalval > 0.92:
        break

wandb.finish()

0
15.88800048828125
eval: 0.4827586206896552         f1score: 0.334025974025974
1
11.470966219902039
eval: 0.4827586206896552         f1score: 0.35525641025641025
2
10.569494009017944
eval: 0.6206896551724138         f1score: 0.6046846846846847
3
10.723693549633026
eval: 0.6379310344827587         f1score: 0.41750528541226223
4
9.339870929718018
eval: 0.5862068965517241         f1score: 0.5479537205081669
5
8.893221974372864
eval: 0.6379310344827587         f1score: 0.5363916083916084
6
6.299803227186203
eval: 0.7241379310344828         f1score: 0.6925752296484003
7
6.590177625417709
eval: 0.7241379310344828         f1score: 0.7285547785547786
8
8.483438402414322
eval: 0.603448275862069         f1score: 0.5705401945724526
9
7.6425401866436005
eval: 0.6724137931034483         f1score: 0.6089509572515079
10
9.903942167758942
eval: 0.6551724137931034         f1score: 0.562023562023562
11
7.841390877962112
eval: 0.6896551724137931         f1score: 0.6
12
7.501221835613251
eval: 0.724137931

KeyboardInterrupt: 

In [None]:
# torch.save(net1.state_dict(), "C:/Users/ykung/Downloads/b4c/LSTMFull91")

net1.load_state_dict(torch.load("/home/arthur/AMRL/Research/Driver-Intent-Prediction/outputs/default/ckpts/epoch_20.pth"))

Cross Validation

In [None]:
import wandb
# Run 5 Fold split
splits = list(range(0,10))
print(splits)
max_evalval_list = []
max_f1val_list = []
for split in splits:
    trainlefti = indices[0:split*seg]
    trainrighti = indices[min(split*seg + seg,totalsize):totalsize]
    traini = trainlefti + trainrighti
    vali = indices[split*seg:min(totalsize,seg*split+seg)]

    train_set = torch.utils.data.dataset.Subset(dataset,traini)
    val_set = torch.utils.data.dataset.Subset(dataset,vali)

    trainloader = DataLoader(train_set, collate_fn=dataset.collate_fn, batch_size=batch, shuffle=True)
    valloader = DataLoader(val_set, collate_fn=dataset.collate_fn, batch_size=1, shuffle=True)


    net1 = EncoderLayer(modeldim,numheads,hiddendim,dropout,fchiddendim)
    loss_fn = nn.CrossEntropyLoss()
    baseoptimizer = torch.optim.SGD(net1.parameters(), lr=lr, momentum=0.9, weight_decay=0.0001)
    optimizer = NoamOpt(modeldim, 1, 100, torch.optim.Adam(net1.parameters(), lr=0, betas=(0.9, 0.98), eps=1e-9))
    lr_scheduler=None
    
    maxevalval, maxf1val = 0, 0
    for trial in range(5): # Repeat each split 5 times and take the best 
        wandb.init(
        project="Driver-Intent-Prediction-Models-Cross-Validation",
            name=f'cross_validation_split{split}_{trial}',
            config={
                "optimizer": "SGD",
                "lr": lr,
                "dropout": dropout,
                "num_epochs": num_epochs,
                "batch_size": batch,
                "modeldim": modeldim,
                "numheads": numheads,
                "hiddendim": hiddendim,
                "validation_split": split
            }
        )

        for i in range(80): # for train/val stationary time before manuever
            print(i)
            train(net1,trainloader,optimizer,loss_fn, lr_scheduler, 1, epoch_count=i, split=split, trial=trial)
            evalval, f1val = eval(net1,valloader)
            print('eval: '+str(evalval) + '         f1score: ' + str(f1val))
            wandb.log({"val": evalval, "f1val": f1val})      
            maxevalval = max(maxevalval, evalval)
            maxf1val = max(maxf1val, f1val)

    max_evalval_list.append(maxevalval)
    max_f1val_list.append(maxf1val)

    wandb.finish()

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.


[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]


[34m[1mwandb[0m: Currently logged in as: [33martzha[0m. Use [1m`wandb login --relogin`[0m to force relogin


0
17.54335355758667
eval: 0.39655172413793105         f1score: 0.21339729921819478
1
13.501889824867249
eval: 0.5344827586206896         f1score: 0.45403050108932463
2
13.662025511264801
eval: 0.5172413793103449         f1score: 0.33367289190718735
3
12.86321645975113
eval: 0.603448275862069         f1score: 0.5492919254658386
4
11.158831298351288
eval: 0.603448275862069         f1score: 0.47913043478260875
5
8.794578284025192
eval: 0.6551724137931034         f1score: 0.6272289698605488
6
7.7462573647499084
eval: 0.7413793103448276         f1score: 0.7431503267973858
7
6.845777601003647
eval: 0.7068965517241379         f1score: 0.69
8
7.447414010763168
eval: 0.7413793103448276         f1score: 0.7594545454545456
9
9.342264115810394
eval: 0.5689655172413793         f1score: 0.5290293040293041
10
11.448580503463745
eval: 0.6379310344827587         f1score: 0.5818346732632447
11
12.613381385803223
eval: 0.603448275862069         f1score: 0.4053140096618358
12
10.520887315273285
eval: 0.72



0,1
epoch,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
f1val,▄▁▆██▇▇▇▇▆█▆▇▆▇▇▇█▇▇▇▆█▆▇█▆▆▇▇▇▇▆▇▇▇▆▇██
loss,█▇▅▁▂▂▁▂▂▁▂▁▁▁▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
lr,▆█▆▄▄▄▃▃▃▃▂▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
trainacc,▂▁▅▆▇▇▇▇▇▇██████████████████████████████
val,▁▁▄▇▇▆▆▆▆▃▇▅▆▅▇▇▅▇▇▅▆▆█▅▆▇▄▅▆▇▇▇▅▅▇▇▄▇██

0,1
epoch,0.0
f1val,0.73993
loss,0.00018
lr,0.00188
trainacc,0.9962
val,0.7931


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.016670187233345738, max=1.0…

0
17.841603755950928
eval: 0.5344827586206896         f1score: 0.2717863673087554
1
12.830268025398254
eval: 0.6551724137931034         f1score: 0.4063340168603326
2
11.262140393257141
eval: 0.6896551724137931         f1score: 0.559511666454163
3
8.72877722978592
eval: 0.7586206896551724         f1score: 0.5659003831417625
4
7.860845863819122
eval: 0.6379310344827587         f1score: 0.4384895104895105
5
9.329293608665466
eval: 0.7241379310344828         f1score: 0.5837162837162837
6
10.874084711074829
eval: 0.6551724137931034         f1score: 0.544862155388471
7
7.593812733888626
eval: 0.5689655172413793         f1score: 0.5234277491446779
8
8.312459737062454
eval: 0.7758620689655172         f1score: 0.7107855107855109
9
10.37014353275299
eval: 0.7241379310344828         f1score: 0.6789457303389193
10
7.668558120727539
eval: 0.6896551724137931         f1score: 0.5216862233811386
11
8.676576107740402
eval: 0.5862068965517241         f1score: 0.43801242236024845
12
13.540565609931946
ev



VBox(children=(Label(value='0.003 MB of 0.003 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
epoch,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
f1val,▃▁▃▆▃█▅▅▆█▆▇▇█▅█▇▆▇▆▇▅▇▆▅▆▇█▆▇▇▇█▆▇▆▇██▆
loss,█▆▇▁▂▂▁▁▁▁▂▁▁▁▁▃▂▁▁▁▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
lr,▆█▆▄▄▄▃▃▃▃▂▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
trainacc,▁▂▁▄▅▆▇▇▇█▆████▇▇█▇█████████████████████
val,▆▁▄▇▄█▅▆▅█▆▇█▇▅██▇█▆▆▆▇▆▆▆▇█▇█▇▆█▆▇▆▇██▆

0,1
epoch,0.0
f1val,0.75344
loss,3e-05
lr,0.00188
trainacc,1.0
val,0.77586


0
18.239070534706116
eval: 0.41379310344827586         f1score: 0.11851851851851851
1
13.884239435195923
eval: 0.5517241379310345         f1score: 0.3795774647887324
2
12.983364999294281
eval: 0.6724137931034483         f1score: 0.5035151515151515
3
11.81034642457962
eval: 0.6379310344827587         f1score: 0.6038961038961038
4
9.577074408531189
eval: 0.7241379310344828         f1score: 0.5996090390179061
5
7.5956122279167175
eval: 0.8103448275862069         f1score: 0.7649696969696969
6
10.353747516870499
eval: 0.6379310344827587         f1score: 0.6222222222222222
7
7.345434099435806
eval: 0.7068965517241379         f1score: 0.6628638028638029
8
7.58065727353096
eval: 0.6896551724137931         f1score: 0.6404761904761904
9
8.978977501392365
eval: 0.603448275862069         f1score: 0.6157679479774172
10
10.390652596950531
eval: 0.5862068965517241         f1score: 0.44969954047366556
11
8.623873502016068
eval: 0.6896551724137931         f1score: 0.6401309145079029
12
6.41179379820823



0,1
epoch,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
f1val,▁▂▄▄▆▇▅▄▅▅▅▅▇▆▇▇▆▇▇▄▃▇▇▆▆▅▅▅██▇▇▇▇▇▇▄▅▆▄
loss,▇█▂▁▃▂▁▂▁▁▁▁▁▁▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▂▁▁
lr,▆█▆▄▄▄▃▃▃▃▂▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
trainacc,▁▃▆▇▇▇▇█████████████████████████████████
val,▁▃▄▄▄▇▅▄▇▅▄▃█▇▆▇▇▇▇▅▄▇▇▇▆▆▅▅█▇▇▆▇▆▇▆▅▄▇▃

0,1
epoch,0.0
f1val,0.70642
loss,0.00144
lr,0.00188
trainacc,0.99241
val,0.68966


0
16.062702894210815
eval: 0.5344827586206896         f1score: 0.2649275362318841
1
12.34622836112976
eval: 0.5344827586206896         f1score: 0.458557801498978
2
8.793315529823303
eval: 0.603448275862069         f1score: 0.47846153846153844
3
8.026492685079575
eval: 0.6724137931034483         f1score: 0.5165837884840087
4
7.740876764059067
eval: 0.7413793103448276         f1score: 0.5493333333333332
5
9.843568503856659
eval: 0.603448275862069         f1score: 0.40905713481564876
6
11.19086492061615
eval: 0.7586206896551724         f1score: 0.6852018157281315
7
11.067300200462341
eval: 0.6896551724137931         f1score: 0.6404545454545454
8
10.478792548179626
eval: 0.6206896551724138         f1score: 0.563909423909424
9
10.57986468076706
eval: 0.8793103448275862         f1score: 0.8266576560694208
10
7.533821940422058
eval: 0.8275862068965517         f1score: 0.774129493694711
11
6.084637105464935
eval: 0.8448275862068966         f1score: 0.7894319131161237
12
6.443358272314072
eval:



VBox(children=(Label(value='0.003 MB of 0.003 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
epoch,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
f1val,▁▇▅▅▆▃▆▆▆▆▆▆▇█▆▇▅▇▇▆▇▆████▇▇▇▇▇█▇▇█▇▇▇▇▅
loss,▅▇█▁▄▂▁▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
lr,▆█▆▄▄▄▃▃▃▃▂▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
trainacc,▁▃▄▄▆▇▇▇█▇████▇█████████████████████████
val,▁▆▄▄▅▄▆▅▆▆▆▆▆▇▅▇▅▇▇▆▇▅▇▇▇█▆▇▇▇▇█▇▇▇▇▇▇▇▅

0,1
epoch,0.0
f1val,0.72175
loss,7e-05
lr,0.00188
trainacc,0.9981
val,0.81034


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.016668503599794347, max=1.0…

0
18.750595927238464
eval: 0.5         f1score: 0.22999999999999998
1
13.033669471740723
eval: 0.5344827586206896         f1score: 0.5201909201909203
2
11.042026042938232
eval: 0.6551724137931034         f1score: 0.49018945760122234
3
9.066430509090424
eval: 0.7586206896551724         f1score: 0.6860846560846561
4
6.549089342355728
eval: 0.6379310344827587         f1score: 0.6497729618163054
5
8.39994215965271
eval: 0.6379310344827587         f1score: 0.48457663767261294
6
8.940994501113892
eval: 0.5344827586206896         f1score: 0.578443256090315
7
8.374999672174454
eval: 0.6379310344827587         f1score: 0.43452515226708777
8
10.396580696105957
eval: 0.7586206896551724         f1score: 0.7272723039672193
9
9.546269953250885
eval: 0.7758620689655172         f1score: 0.7662643239113827
10
8.630525052547455
eval: 0.7413793103448276         f1score: 0.6923076923076923
11
9.682791769504547
eval: 0.6206896551724138         f1score: 0.44572263993316624
12
8.829031825065613
eval: 0.68965



VBox(children=(Label(value='0.003 MB of 0.003 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
epoch,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
f1val,▅▁▅▆▇▇▇▇▇▆▇▇▆▇▇▆▆▆▇▇▆▇▇▇▇▇▇▇▇▆█▇▇▇▇▇▆▇▇▆
loss,█▅▄▁▃▁▁▁▁▁▁▁▁▂▁▃▁▁▁▁▁▁▁▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
lr,▆█▆▄▄▄▃▃▃▃▂▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
trainacc,▁▁▅▆▆▇▇▇████████████████████████████████
val,▄▁▅▆▇▇▆▆▆▆▆▆▅▆▇▆▆▆▇▆▅▇▆▆▆▇▆▆▆▆█▇▇▇▇▇▅▇▆▆

0,1
epoch,0.0
f1val,0.78808
loss,0.00015
lr,0.00188
trainacc,0.9962
val,0.82759


0
19.771220564842224
eval: 0.4482758620689655         f1score: 0.2514007502593982
1
13.5880486369133
eval: 0.5344827586206896         f1score: 0.3186480186480186
2
11.112175107002258
eval: 0.5862068965517241         f1score: 0.4095873015873016
3
9.23244959115982
eval: 0.5517241379310345         f1score: 0.4975056689342405
4
11.953075468540192
eval: 0.5172413793103449         f1score: 0.4015604395604395
5
12.420594334602356
eval: 0.6379310344827587         f1score: 0.48128012338538656
6
9.403253257274628
eval: 0.603448275862069         f1score: 0.5681481481481481
7
7.858071029186249
eval: 0.5172413793103449         f1score: 0.47857142857142854
8
10.327807426452637
eval: 0.43103448275862066         f1score: 0.2949832775919733
9
10.241826236248016
eval: 0.6551724137931034         f1score: 0.4209800362976407
10
9.056163728237152
eval: 0.603448275862069         f1score: 0.4361009475560559
11
10.043043196201324
eval: 0.5         f1score: 0.37551956815114707
12
9.741655707359314
eval: 0.67241



VBox(children=(Label(value='0.003 MB of 0.003 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
epoch,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
f1val,▃▁▆▆▇▅█▆▇▆▇▇▆▇▆▇▅▆▆▇▆▇▇▆▇▆▆▆█▆▅▆▇▇▇▇▇▆▇▆
loss,▆█▅▁▂▃▁▁▁▁▂▁▁▁▁▁▁▁▁▁▁▁▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
lr,▆█▆▄▄▄▃▃▃▃▂▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
trainacc,▂▁▄▅▇▇▇▇█████████████▇██████████████████
val,▂▁▆▆▇▅█▆▇▅▆▆▇▇▆▇▅▇▅▆▆▆▆▆▇▇▆▆█▇▆▆█▇▇▆▇▆▆▇

0,1
epoch,0.0
f1val,0.73243
loss,0.0
lr,0.00188
trainacc,1.0
val,0.7931


0
17.106112003326416
eval: 0.6379310344827587         f1score: 0.40166969147005444
1
13.183642208576202
eval: 0.5689655172413793         f1score: 0.4185338345864661
2
13.314109444618225
eval: 0.5862068965517241         f1score: 0.3411764705882353
3
11.29901260137558
eval: 0.7758620689655172         f1score: 0.49238189814158034
4
9.455598413944244
eval: 0.5517241379310345         f1score: 0.4007534966448379
5
11.539756178855896
eval: 0.7931034482758621         f1score: 0.6093656343656344
6
10.462510704994202
eval: 0.6896551724137931         f1score: 0.426994301994302
7
11.21316283941269
eval: 0.6551724137931034         f1score: 0.5688963585434174
8
9.590573847293854
eval: 0.7241379310344828         f1score: 0.5212947339398952
9
8.378383934497833
eval: 0.7758620689655172         f1score: 0.48358974358974366
10
9.526303440332413
eval: 0.7586206896551724         f1score: 0.5645421245421245
11
8.534927546977997
eval: 0.7068965517241379         f1score: 0.5074208754208754
12
8.90245074033737



VBox(children=(Label(value='0.003 MB of 0.003 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
epoch,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
f1val,▁▁▄▆▅█▅▅▆▅▇▇▆▇█▆██▆▇▇▇▇▆▇▇▇▇▇▇▇█▇▇▇█▇▆▇▇
loss,██▄▁▃▂▁▂▂▁▂▁▂▂▁▁▁▁▁▃▁▁▁▁▁▁▁▁▁▂▁▁▁▁▁▁▁▁▁▁
lr,▆█▆▄▄▄▃▃▃▃▂▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
trainacc,▁▃▆▆▆▇█▇█▇██████████████████████████████
val,▃▁▃▅▄█▄▄▅▃▆▆▄▆▇▅▇█▅▆▇▇▇▅▆▆▇▇▆▇▆▇▆▇▆▇▆▅▇▆

0,1
epoch,0.0
f1val,0.86087
loss,0.00017
lr,0.00188
trainacc,0.9962
val,0.87931


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.016668366533364558, max=1.0…

0
19.181662678718567
eval: 0.3448275862068966         f1score: 0.1262321144674086
1
12.66187447309494
eval: 0.5517241379310345         f1score: 0.41487719298245607
2
11.362105786800385
eval: 0.5862068965517241         f1score: 0.4057786357786358
3
11.20077532529831
eval: 0.5172413793103449         f1score: 0.41018858379472195
4
9.167570352554321
eval: 0.6551724137931034         f1score: 0.44226190476190474
5
8.66179370880127
eval: 0.6551724137931034         f1score: 0.5469730269730271
6
11.045786798000336
eval: 0.5862068965517241         f1score: 0.5033929162528964
7
11.329230546951294
eval: 0.603448275862069         f1score: 0.575993046501521
8
11.032628834247589
eval: 0.5862068965517241         f1score: 0.3897435897435898
9
12.523427844047546
eval: 0.7068965517241379         f1score: 0.5676600985221675
10
10.241974234580994
eval: 0.6724137931034483         f1score: 0.44523809523809527
11
9.361264765262604
eval: 0.7068965517241379         f1score: 0.5200840336134454
12
7.3726410865783



0,1
epoch,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
f1val,▁▃▆▇██▇▇▇▆▆█▆▆▇▆▆▇█▇▇▇▇▆▇▆▆▇▇▇▇▇█▇█▇▇▇▇▇
loss,▇█▆▁▄▁▁▄▂▁▁▁▁▁▁▁▁▁▁▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
lr,▆█▆▄▄▄▃▃▃▃▂▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
trainacc,▁▂▅▆▇▇▇█████████████████████████████████
val,▁▅▆▇██▇▇▇▇▆█▆▇▇▆▆▇█▇█▇▇▆▇▇▆▇█▇▇████▇▇▇▇▇

0,1
epoch,0.0
f1val,0.81238
loss,0.00012
lr,0.00188
trainacc,0.9981
val,0.84483


0
18.149694561958313
eval: 0.4482758620689655         f1score: 0.3048850574712644
1
11.998329043388367
eval: 0.5517241379310345         f1score: 0.37598978288633467
2
10.68610167503357
eval: 0.7413793103448276         f1score: 0.7156267806267806
3
8.243279695510864
eval: 0.7241379310344828         f1score: 0.6552855924978687
4
6.5665357410907745
eval: 0.5689655172413793         f1score: 0.40921166120407104
5
8.4600670337677
eval: 0.7413793103448276         f1score: 0.6823790945896209
6
7.27653169631958
eval: 0.7931034482758621         f1score: 0.8039026915113873
7
6.517304241657257
eval: 0.6379310344827587         f1score: 0.4747707131386427
8
18.298232555389404
eval: 0.3275862068965517         f1score: 0.11515151515151514
9
16.653787970542908
eval: 0.1724137931034483         f1score: 0.058823529411764705
10
16.960249304771423
eval: 0.39655172413793105         f1score: 0.2
11
16.872535824775696
eval: 0.3275862068965517         f1score: 0.0987012987012987
12
16.053465723991394
eval: 0.3



VBox(children=(Label(value='0.003 MB of 0.003 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
epoch,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
f1val,█▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
loss,▄█▇▁▇▇▁▇▇▁▇▁▇▇▁█▇▁▇▇▇█▁▇▇▁▇▇▁▇▁▇█▁▇▇▁▇▇▁
lr,▆█▆▄▄▄▃▃▃▃▂▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
trainacc,█▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
val,█▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
epoch,0.0
f1val,0.0987
loss,0.02909
lr,0.00188
trainacc,0.40607
val,0.32759


0
18.98145830631256
eval: 0.3793103448275862         f1score: 0.2352520597681888
1
12.744089841842651
eval: 0.5862068965517241         f1score: 0.3614118677948466
2
8.719172894954681
eval: 0.7068965517241379         f1score: 0.5393056130596794
3
7.880806893110275
eval: 0.5862068965517241         f1score: 0.42908620362815153
4
7.388656675815582
eval: 0.5862068965517241         f1score: 0.39842093390480493
5
8.802099287509918
eval: 0.6896551724137931         f1score: 0.4445565862708721
6
8.353820770978928
eval: 0.7413793103448276         f1score: 0.5749235741537265
7
9.849240928888321
eval: 0.6896551724137931         f1score: 0.5959108818683287
8
13.575280219316483
eval: 0.4827586206896552         f1score: 0.23927272727272725
9
13.65108835697174
eval: 0.5         f1score: 0.32115550942743465
10
14.337719798088074
eval: 0.6379310344827587         f1score: 0.406221198156682
11
9.679144740104675
eval: 0.5         f1score: 0.4134025974025974
12
9.396648049354553
eval: 0.603448275862069      



VBox(children=(Label(value='0.003 MB of 0.003 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
epoch,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
f1val,▁▁▁▆▇▇█▆▅▆█▅▇▇▇▇▆█▇▇▆▆▆█▆█▆▆▅▆▅▇▆▆▆▆▇▇▆▅
loss,▅██▁▂▂▁▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▂▁▁▁▁▁▁▁▁▁▁
lr,▆█▆▄▄▄▃▃▃▃▂▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
trainacc,▂▁▃▆▇▇▇▇▇█████████▇█████████████████████
val,▃▁▃▆▇▇█▇▇▇█▇█▇██▇█▇▇▆▆▆█▇█▆▆▆▆▅▇▆▇▆▆█▇▇▆

0,1
epoch,0.0
f1val,0.60021
loss,0.00014
lr,0.00188
trainacc,0.9981
val,0.68966


In [None]:
print(np.average(max_evalval_list))
print(np.std(max_evalval_list))
print(np.average(max_f1val_list))
print(np.std(max_f1val_list))

0.8758620689655172
0.04925123054167482
0.8623788128662836
0.045269280547075366


In [None]:
# the values recorded from the best models from each split
recacc = [0.89655,0.89655,0.8793,0.862,0.8448,0.8793,0.8448,0.8621,0.9138,0.8448]
recf1 = [0.9106,0.84262,0.870255,0.8374,0.8449,0.88,0.874,0.8202,0.8905,0.79139]

In [None]:
print(np.average(recacc))
print(np.std(recacc))
print(np.average(recf1))
print(np.std(recf1))

#### Varying time to manuever

In [None]:
def train2(net, trainloader, optimizer, loss_fn, lr_scheduler, num_epochs, curr_epoch=0, split=0, trial=0):
    # Put the network in training mode
    net.train()

    # Training loop
    for epoch in range(num_epochs):
        running_loss = 0
        for batch_idx, (data, targets) in enumerate(trainloader):
            # TODO: zero the parameter gradients + forward pass + loss computation + backward pass + weight update
            data = np.array(data)
            tdata = torch.tensor(data, dtype=torch.float32)
            ttargets = torch.tensor(targets)
            zten = torch.zeros((len(data),30,32))
            
            optimizer.optimizer.zero_grad()
            prediction = net(tdata)
            loss = loss_fn(prediction, ttargets)           
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
            
            optimizer.optimizer.zero_grad()
            prediction = net(torch.cat((zten,tdata[:,0:120,:]), 1))
            loss = loss_fn(prediction, ttargets)           
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
                             
            optimizer.optimizer.zero_grad()
            prediction = net(torch.cat((zten,zten,tdata[:,0:90,:]), 1))
            loss = loss_fn(prediction, ttargets)           
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
            
            optimizer.optimizer.zero_grad()
            prediction = net(torch.cat((zten,zten,zten,tdata[:,0:60,:]), 1))
            loss = loss_fn(prediction, ttargets)           
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
                             
            optimizer.optimizer.zero_grad()
            prediction = net(torch.cat((zten,zten,zten,zten,tdata[:,0:30,:]), 1))
            loss = loss_fn(prediction, ttargets)           
            loss.backward()
            optimizer.step()
            running_loss += loss.item()

        if lr_scheduler is not None:
            lr_scheduler.step()
                             
        print(running_loss)
        wandb.log({"loss": running_loss, "lr": optimizer._rate})

        ckpt_save_dir = f'/home/arthur/AMRL/Research/Driver-Intent-Prediction/outputs/default/ts2m_ckpts/split{split}_trial{trial}'
        if not os.path.exists(ckpt_save_dir):
            os.makedirs(ckpt_save_dir)
        torch.save(net1.state_dict(), os.path.join(ckpt_save_dir, f'epoch_{curr_epoch}.pth'))

In [None]:
import wandb

max_evalval_list = []
max_f1val_list = []
best_epoch_list = []
net_list = []
for split in range(10):
    maxevalval = 0
    maxf1val = 0
    bestepoch = 0
    for trial in range(3):

        trainlefti = indices[0:split*seg]
        trainrighti = indices[min(split*seg + seg,totalsize):totalsize]
        traini = trainlefti + trainrighti
        vali = indices[split*seg:min(totalsize,seg*split+seg)]

        train_set = torch.utils.data.dataset.Subset(dataset,traini)
        val_set = torch.utils.data.dataset.Subset(dataset,vali)

        trainloader = DataLoader(train_set, collate_fn=dataset.collate_fn, batch_size=batch, shuffle=True)
        valloader = DataLoader(val_set, collate_fn=dataset.collate_fn, batch_size=1, shuffle=True)
        
        net1 = EncoderLayer(modeldim,numheads,hiddendim,dropout,fchiddendim)
        loss_fn = nn.CrossEntropyLoss()
        baseoptimizer = torch.optim.SGD(net1.parameters(), lr=lr, momentum=0.9, weight_decay=0.0001)
        optimizer = NoamOpt(modeldim, 1, 500, torch.optim.Adam(net1.parameters(), lr=0, betas=(0.9, 0.98), eps=1e-9))
        lr_scheduler=None

        wandb.init(
            project="Driver-Intent-Prediction-Models-T2M",
            name=f'cross_validation_split{split}_{trial}',
            config={
                "optimizer": "SGD",
                "lr": lr,
                "dropout": dropout,
                "num_epochs": num_epochs,
                "batch_size": batch,
                "modeldim": modeldim,
                "numheads": numheads,
                "hiddendim": hiddendim,
                "validation_split": split,
            }
        )

        for epoch in range(70): # train/val for varying time to manuever
            print(epoch)
            train2(net1,trainloader,optimizer,loss_fn, lr_scheduler, 1, epoch, split, trial)
            evalval, f1val = eval(net1,valloader)
            print('eval: '+str(evalval) + '         f1score: ' + str(f1val))
            if evalval > 0.80:
                break
            
            wandb.log({"val": evalval, "f1val": f1val})  
            if maxevalval < evalval:
                bestepoch = epoch    
            maxevalval = max(maxevalval, evalval)
            maxf1val = max(maxf1val, f1val)

        wandb.finish()

    max_evalval_list.append(maxevalval)
    max_f1val_list.append(maxf1val)
    best_epoch_list.append(bestepoch)

VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.016668438850077412, max=1.0…

0
79.31614553928375
eval: 0.43103448275862066         f1score: 0.26829346092503986
1
70.74294972419739
eval: 0.5172413793103449         f1score: 0.3348148148148148
2
60.83284765481949
eval: 0.5517241379310345         f1score: 0.3591836734693878
3
58.624989211559296
eval: 0.5689655172413793         f1score: 0.4674822695035461
4
51.41205805540085
eval: 0.5862068965517241         f1score: 0.4257241379310345
5
51.63296288251877
eval: 0.603448275862069         f1score: 0.617261335156072
6
45.14423358440399
eval: 0.6379310344827587         f1score: 0.5791617207269382
7
43.767731457948685
eval: 0.46551724137931033         f1score: 0.3738085255066387
8
53.48502451181412
eval: 0.7413793103448276         f1score: 0.7373626373626372
9
43.07899710536003
eval: 0.6896551724137931         f1score: 0.5396655518394649
10
44.14044263958931
eval: 0.7586206896551724         f1score: 0.7147262647262648
11
38.63112282752991
eval: 0.7586206896551724         f1score: 0.7335864135864136
12
40.32494083046913
ev



0,1
f1val,▁▂▂▄▃▆▅▂▇▅▇▇▅▅█▇▇▄▅▇█▇▇▇▆█▆▅▇▇█
loss,█▇▆▅▄▅▄▄▅▃▄▃▃▃▃▃▂▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁
lr,▁▂▃▄▄▅▆▇██▇▇▇▇▆▆▆▆▆▅▅▅▅▅▅▅▅▅▅▅▄▄
val,▁▃▃▄▄▄▅▂▇▆▇▇▆▅██▇▄▅█▇▇▇▇▆▇▇▆▇▇█

0,1
f1val,0.79555
loss,23.78324
lr,0.00298
val,0.7931


0
80.65652370452881
eval: 0.43103448275862066         f1score: 0.25518995929443694
1
69.38551652431488
eval: 0.5172413793103449         f1score: 0.41974691974691974
2
64.2894337773323
eval: 0.5517241379310345         f1score: 0.42439393939393943
3
60.36004877090454
eval: 0.603448275862069         f1score: 0.5790476190476189
4
52.59756636619568
eval: 0.6724137931034483         f1score: 0.592745185848634
5
46.084359139204025
eval: 0.5862068965517241         f1score: 0.5687753899518606
6
46.976432770490646
eval: 0.6379310344827587         f1score: 0.5987450980392157
7
50.715818643569946
eval: 0.4482758620689655         f1score: 0.390391363022942
8
52.33826997876167
eval: 0.603448275862069         f1score: 0.45474747474747473
9
49.76789799332619
eval: 0.5862068965517241         f1score: 0.6145977011494252
10
49.46023780107498
eval: 0.6379310344827587         f1score: 0.5182805429864252
11
44.47809165716171
eval: 0.7586206896551724         f1score: 0.7337474747474748
12
38.76064220070839
ev



VBox(children=(Label(value='0.003 MB of 0.003 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
f1val,▁▃▃▅▆▅▆▃▄▆▅▇▇██▇▃▇█▆▆█▆█▇█▇▇█████▆▇▇▇▇▇▇
loss,█▇▆▆▅▄▄▅▅▄▄▄▃▃▃▃▃▃▂▂▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁
lr,▁▂▃▄▄▅▆▇██▇▇▇▇▆▆▆▆▆▅▅▅▅▅▅▅▅▅▅▅▄▄▄▄▄▄▄▄▄▄
val,▁▃▃▄▆▄▅▁▄▄▅▇▇██▆▄▇▇▆▆█▅▇▇█▆▇▇██▇█▅▇▇▇▇█▇

0,1
f1val,0.71187
loss,20.22185
lr,0.00263
val,0.75862


0
76.97384697198868
eval: 0.43103448275862066         f1score: 0.2360341151385928
1
64.50747120380402
eval: 0.5862068965517241         f1score: 0.457728654324399
2
55.41509699821472
eval: 0.6206896551724138         f1score: 0.5348546315077755
3
47.40600526332855
eval: 0.6896551724137931         f1score: 0.6348262548262549
4
44.48768177628517
eval: 0.7413793103448276         f1score: 0.6889786683904331
5
44.13673770427704
eval: 0.6206896551724138         f1score: 0.5891977807767281
6
47.149319499731064
eval: 0.6724137931034483         f1score: 0.6620197837589142
7
48.733472883701324
eval: 0.6551724137931034         f1score: 0.5335531135531135
8
40.412008225917816
eval: 0.7413793103448276         f1score: 0.6774776334776336
9
43.29548963904381
eval: 0.5689655172413793         f1score: 0.46452991452991454
10
46.738986641168594
eval: 0.4827586206896552         f1score: 0.28105351170568565
11
39.600195467472076
eval: 0.5172413793103449         f1score: 0.48730158730158724
12
41.910242915153



0,1
f1val,▁▄▅▆▇▅▆▅▆▄▂▄▇▇▆▆▇▇█▇▇
loss,█▆▅▄▃▃▄▄▃▃▄▃▃▂▂▂▂▂▁▁▁▁
lr,▁▂▃▄▄▅▆▇██▇▇▇▇▆▆▆▆▆▅▅▅
val,▁▄▅▆▇▅▆▅▇▄▂▃▇▆▆▆▇▇█▇▇

0,1
f1val,0.73252
loss,28.03641
lr,0.00359
val,0.75862


0
77.59100806713104
eval: 0.3448275862068966         f1score: 0.1935294117647059
1
68.39963799715042
eval: 0.5862068965517241         f1score: 0.38928571428571435
2
59.56242907047272
eval: 0.3448275862068966         f1score: 0.37940238623165456
3
52.86066049337387
eval: 0.6206896551724138         f1score: 0.5685795685795686
4
50.01531320810318
eval: 0.5517241379310345         f1score: 0.5064029304029305
5
52.76946473121643
eval: 0.6896551724137931         f1score: 0.674349698535745
6
53.536397099494934
eval: 0.5517241379310345         f1score: 0.45931313131313134
7
47.579118728637695
eval: 0.6206896551724138         f1score: 0.5133126133126134
8
49.1009963452816
eval: 0.6724137931034483         f1score: 0.5593846153846154
9
50.346918016672134
eval: 0.603448275862069         f1score: 0.47653111449666
10
44.94558838009834
eval: 0.6551724137931034         f1score: 0.5551282051282052
11
43.006695330142975
eval: 0.6379310344827587         f1score: 0.6019047619047619
12
39.46895608305931
eva



VBox(children=(Label(value='0.003 MB of 0.003 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
f1val,▁▃▃▆▅▇▄▅▆▅▅▆██▆▆▅▇
loss,█▇▅▄▄▄▄▃▃▄▃▂▂▁▂▁▁▁▁
lr,▁▂▃▄▄▅▆▇██▇▇▇▇▆▆▆▆▆
val,▁▅▁▆▅▇▅▆▇▆▆▆██▆▆▆█

0,1
f1val,0.68595
loss,34.09327
lr,0.00387
val,0.72414


0
73.86231237649918
eval: 0.5172413793103449         f1score: 0.34781838316722036
1
67.92077612876892
eval: 0.4827586206896552         f1score: 0.3355750487329434
2
61.762631833553314
eval: 0.603448275862069         f1score: 0.40101680266402295
3
64.5381350517273
eval: 0.5344827586206896         f1score: 0.503572701807996
4
51.008303463459015
eval: 0.6206896551724138         f1score: 0.5908836261777438
5
46.44215381145477
eval: 0.6551724137931034         f1score: 0.5957149758454107
6
48.208981573581696
eval: 0.603448275862069         f1score: 0.589017649017649
7
50.86739635467529
eval: 0.7068965517241379         f1score: 0.6232142857142857
8
42.951719015836716
eval: 0.6551724137931034         f1score: 0.5752307692307692
9
42.08325186371803
eval: 0.6724137931034483         f1score: 0.6577281414237935
10
39.36307176947594
eval: 0.7241379310344828         f1score: 0.6773161831782522
11
34.843561828136444
eval: 0.6896551724137931         f1score: 0.5763157894736842
12
33.393526554107666
ev



VBox(children=(Label(value='0.003 MB of 0.003 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
f1val,▁▁▂▄▅▅▅▆▅▆▆▅▇▆▅▇▇▆▆▅█▇███
loss,█▇▆▇▅▄▄▅▄▄▃▃▂▂▂▂▂▂▂▃▂▁▂▁▁▁
lr,▁▂▃▄▄▅▆▇██▇▇▇▇▆▆▆▆▆▅▅▅▅▅▅▅
val,▂▁▄▂▄▅▄▆▅▅▆▆▆▆▆▇▇▆▆▅▇████

0,1
f1val,0.76139
loss,23.93361
lr,0.00331
val,0.7931


In [None]:
print(max_evalval_list)
print(max_f1val_list)
print(best_epoch_list)

[0.7931034482758621]
[0.8016239316239316]
[14]


In [None]:
def eval2(net, loader):
    net.eval()
    t2m = np.zeros((100))
    totalcorrect = np.zeros((5,100))
    i1 = 0
    for batch_idx, (data, targets) in enumerate(loader):
        data = np.array(data)
        tdata = torch.tensor(data, dtype=torch.float32)
        zten = torch.zeros((len(data),30,32))
        
        y = net(torch.cat((zten,zten,zten,zten,tdata[:,0:30,:]), 1))
        if int(torch.argmax(y)) == targets[0]:
            t2m[i1] = 5
            totalcorrect[0,i1] = 1
            
        y = net(torch.cat((zten,zten,zten,tdata[:,0:60,:]), 1))
        if int(torch.argmax(y)) == targets[0]:
            if t2m[i1] == 0:
                t2m[i1] = 4
            totalcorrect[1,i1] = 1
            
        y = net(torch.cat((zten,zten,tdata[:,0:90,:]), 1))
        if int(torch.argmax(y)) == targets[0]:
            if t2m[i1] == 0:
                t2m[i1] = 3
            totalcorrect[2,i1] = 1
            
        y = net(torch.cat((zten,tdata[:,0:120,:]), 1))
        if int(torch.argmax(y)) == targets[0]:
            if t2m[i1] == 0:
                t2m[i1] = 2
            totalcorrect[3,i1] = 1
            
        y = net(tdata)
        if int(torch.argmax(y)) == targets[0]:
            if t2m[i1] == 0:
                t2m[i1] = 1
            totalcorrect[4,i1] = 1
            
        i1+=1
        
    return t2m[0:i1], totalcorrect[:,0:i1]

In [None]:
time_to_maneuver, correct_expansion = eval2(net1, valloader)
time_to_maneuver = time_to_maneuver.reshape(1, 58)
correct_expansion_changed = correct_expansion.reshape(1, 5, 58)

In [None]:
print(time_to_maneuver.shape)
print(correct_expansion.shape)
print(np.linalg.norm(correct_expansion_changed - correct_expansion))
correct_expansion=correct_expansion_changed
print(correct_expansion.shape)

(1, 58)
(5, 58)
0.0
(1, 5, 58)


In [None]:
# Model Sanity Check
ckpt_save_dir = "/home/arthur/AMRL/Research/Driver-Intent-Prediction/outputs/default/ts2m_ckpts_allsplits"

best_splits = {split: {"trial": 0, "epoch": 0, "evalval": 0} for split in range(10)}

for testsplit in range(10):
    for testtrial in range(3):
        split_save_path = os.path.join(ckpt_save_dir, f"split{testsplit}_trial{testtrial}")
                                       
        num_epochs = len(os.listdir(split_save_path))
        for testepoch in range(num_epochs):
            ckpt_save_path = os.path.join(split_save_path, f'epoch_{testepoch}.pth')

            test_net = EncoderLayer(modeldim,numheads,hiddendim,dropout,fchiddendim)
            test_net.load_state_dict(torch.load(ckpt_save_path))

            vali = indices[testsplit*seg:min(totalsize,seg*testsplit+seg)]
            val_set = torch.utils.data.dataset.Subset(dataset,vali)
            valloader = DataLoader(val_set, collate_fn=dataset.collate_fn, batch_size=1, shuffle=True)

            evalval, f1val = eval(test_net,valloader)
            print('eval: '+str(evalval) + '         f1score: ' + str(f1val))

            if evalval > best_splits[testsplit]["evalval"]:
                best_splits[testsplit]["evalval"] = evalval
                best_splits[testsplit]["trial"] = testtrial
                best_splits[testsplit]["epoch"] = testepoch

print(best_splits)

eval: 0.41379310344827586         f1score: 0.22325581395348837
eval: 0.3620689655172414         f1score: 0.21870401810979062
eval: 0.6379310344827587         f1score: 0.5712500749086116
eval: 0.6206896551724138         f1score: 0.6150628456510809
eval: 0.6206896551724138         f1score: 0.41726495726495727
eval: 0.6379310344827587         f1score: 0.6358418923124806
eval: 0.6379310344827587         f1score: 0.6045145794449714
eval: 0.5862068965517241         f1score: 0.407032967032967
eval: 0.6724137931034483         f1score: 0.5727195225916454
eval: 0.6896551724137931         f1score: 0.6774103944835652
eval: 0.6724137931034483         f1score: 0.5510559510559511
eval: 0.6724137931034483         f1score: 0.6148595848595848
eval: 0.7068965517241379         f1score: 0.6834439710079098
eval: 0.7758620689655172         f1score: 0.7668058968058968
eval: 0.6379310344827587         f1score: 0.6110284470532918
eval: 0.7068965517241379         f1score: 0.7058201058201058
eval: 0.7413793103448

In [None]:
# lt2m = len(time_to_maneuver)
# print('Total maneuvers: ' + str(lt2m))
# instancescorrect = 0
# for i in time_to_maneuver:
#     if i == 5:
#         instancescorrect += 1

# per5 = instancescorrect/lt2m
# avgtime = 5*instancescorrect
# print('Percentage correct 5 s in advance: ' + str(per5))

# numleft = lt2m - instancescorrect
# instancescorrect = 0
# for i in time_to_maneuver:
#     if i == 4:
#         instancescorrect += 1

# per4 = instancescorrect/numleft
# avgtime = avgtime + 4*instancescorrect
# print('Percentage correct 4 s in advance: ' + str(per4))

# numleft = numleft - instancescorrect
# instancescorrect = 0
# for i in time_to_maneuver:
#     if i == 3:
#         instancescorrect += 1

# per3 = instancescorrect/numleft
# avgtime = avgtime + 3*instancescorrect
# print('Percentage correct 3 s in advance: ' + str(per3))

# numleft = numleft - instancescorrect
# instancescorrect = 0
# for i in time_to_maneuver:
#     if i == 2:
#         instancescorrect += 1

# per2 = instancescorrect/numleft
# avgtime = avgtime + 2*instancescorrect
# print('Percentage correct 2 s in advance: ' + str(per2))

# numleft = numleft - instancescorrect
# instancescorrect = 0
# for i in time_to_maneuver:
#     if i == 1:
#         instancescorrect += 1

# per1 = instancescorrect/numleft
# avgtime = avgtime + 1*instancescorrect
# print('Percentage correct 1 s in advance: ' + str(per1))

# avgtime = avgtime/(lt2m-numleft)
# print('Avg time for correct prediction: ' + str(avgtime))

In [None]:
# time_to_maneuver.shape

In [None]:
# # 
# fillerarray = np.zeros((len(time_to_maneuver)))
# fillerarray[:] = time_to_maneuver
# tim2mset = np.stack((time_to_maneuver, fillerarray), axis=0) #np.concatenate((time_to_maneuver,fillerarray),axis=0)

# fillerarray = np.zeros((5,len(time_to_maneuver)))
# fillerarray[:,:] = correct_expansion
# expandedset = np.concatenate((tim2mset,fillerarray),axis=0)

In [None]:
# np.shape(tim2mset)

In [None]:
# np.shape(expandedset)
# np.shape(time_to_maneuver)

In [None]:
# np.save("C:/Users/ykung/Downloads/b4c/timevaryingtim2mset", tim2mset)

In [None]:
# np.save("C:/Users/ykung/Downloads/b4c/expandedset", expandedset)

In [None]:
# num5s = 0
# num4s = 0
# num3s = 0
# num2s = 0
# num1s = 0
# total = 580
# tolcorr = 0
# for i in range(2):
#     for j in range(58):
#         if tim2mset[i,j] == 5:
#             num5s += 1
#             tolcorr += 1
#         elif tim2mset[i,j] == 4:
#             num4s += 1
#             tolcorr += 1
#         elif tim2mset[i,j] == 3:
#             num3s += 1
#             tolcorr += 1
#         elif tim2mset[i,j] == 2:
#             num2s += 1
#             tolcorr += 1
#         elif tim2mset[i,j] == 1:
#             num1s += 1
#             tolcorr += 1
# print('5s pecentage: ' + str(num5s/total))
# print('4s pecentage: ' + str(num4s/(total-num5s)))
# print('3s pecentage: ' + str(num3s/(total-num5s-num4s)))
# print('2s pecentage: ' + str(num2s/(total-num5s-num4s-num3s)))
# print('1s pecentage: ' + str(num1s/(total-num5s-num4s-num3s-num2s)))
# print('total avg pred time: ' + str((num5s*5+num4s*4+num3s*3+num2s*2+num1s)/tolcorr))

In [None]:
# expandedset = np.load("C:/Users/ykung/Downloads/b4c/expandedset.npy")

In [None]:
# expandedset = correct_expansion

In [None]:
# np.shape(expandedset)

In [None]:
# expandedset = np.

In [98]:
num5s = 0
num4s = 0
num3s = 0
num2s = 0
num1s = 0
total = 580
for split, split_dict in best_splits.items():
    trial = split_dict["trial"]
    epoch = split_dict["epoch"]
    ckpt_save_path = os.path.join(ckpt_save_dir, f"split{split}_trial{trial}", f"epoch_{epoch}.pth")

    t2m_net = EncoderLayer(modeldim,numheads,hiddendim,dropout,fchiddendim)
    t2m_net.load_state_dict(torch.load(ckpt_save_path))

    vali = indices[split*seg:min(totalsize,seg*split+seg)]
    val_set = torch.utils.data.dataset.Subset(dataset,vali)
    valloader = DataLoader(val_set, collate_fn=dataset.collate_fn, batch_size=1, shuffle=True)
    
    time_to_maneuver, correct_expansion = eval2(t2m_net, valloader)
    time_to_maneuver = time_to_maneuver.reshape(1, 58)
    expandedset = correct_expansion.reshape(1, 5, 58)

    for i in range(1):
        for j in range(58): # number of either train/test for each split
            if expandedset[i,0,j] == 1:
                num5s += 1
                
            if expandedset[i,1,j] == 1:
                num4s += 1
                
            if expandedset[i,2,j] == 1:
                num3s += 1
                
            if expandedset[i,3,j] == 1:
                num2s += 1
                
            if expandedset[i,4,j] == 1:
                num1s += 1
    

                
print('5s pecentage: ' + str(num5s/total))
print('4s pecentage: ' + str(num4s/total))
print('3s pecentage: ' + str(num3s/total))
print('2s pecentage: ' + str(num2s/total))
print('1s pecentage: ' + str(num1s/total))

5s pecentage: 0.6224137931034482
4s pecentage: 0.6931034482758621
3s pecentage: 0.7724137931034483
2s pecentage: 0.8086206896551724
1s pecentage: 0.843103448275862


In [None]:
np.save("/home/arthur/AMRL/Research/Driver-Intent-Prediction/outputs_t2m/split0.npy", expandedset)