In [23]:
%load_ext autoreload
%autoreload 2

In [64]:
from src.dataset.Loader.main_loader import define_path
from src.dataset.trans.data import TransDataset, extract_pred_sequence
from src.dataset.trans.jaad_trans import JaadTransDataset
from src.dataset.intention.jaad_dataset import build_pedb_dataset_jaad, subsample_and_balance
import cv2
import numpy as np
import glob
from IPython.display import display
from PIL import Image
import matplotlib.pyplot as plt

from src.dataset.loader import PaddedSequenceDataset, IntentionSequenceDataset
from src.transform.preprocess import ImageTransform, Compose, CropBox
import torchvision
import torch

from sklearn.metrics import average_precision_score, f1_score, classification_report
from train_hybrid import unpack_batch
from tqdm import tqdm

In [65]:
anns_paths, image_dir = define_path(use_jaad=True, use_pie=False, use_titan=False)
anns_paths_val, image_dir_val = define_path(use_jaad=True, use_pie=False, use_titan=False)

In [66]:
from dataclasses import dataclass

@dataclass
class Args:
    encoder_type: str = 'CC'
    encoder_pretrained: bool = False
    epochs: int = 1
    lr: float = 0.001
    wd: float = 0.0
    batch_size: int = 4
    max_frames: int = 10
    pred: int = 10
    output: str = None
    fps: int = 5
    seed: int = 99
    jitter_ratio: float = -1.0
    mobilenetsmall: bool = False
    mobilenetbig: bool = False
    num_workers: int = 4

args = Args(num_workers=0)
max_frames = 10

In [67]:
image_set = "test"


intent_sequences = build_pedb_dataset_jaad(anns_paths["JAAD"]["anns"], anns_paths["JAAD"]["split"], image_set=image_set, fps=args.fps, prediction_frames=args.pred, verbose=True)
balance = False if image_set == "test" else True
intent_sequences_cropped = subsample_and_balance(intent_sequences, max_frames=args.max_frames, seed=args.seed, balance=balance)

jitter_ratio = None if args.jitter_ratio < 0 else args.jitter_ratio
crop_preprocess = CropBox(size=224, padding_mode='pad_resize', jitter_ratio=jitter_ratio)
if image_set == 'train':
    TRANSFORM = Compose([
        crop_preprocess,
        ImageTransform(torchvision.transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1))
        ])
else:
    TRANSFORM = crop_preprocess

ds = IntentionSequenceDataset(intent_sequences_cropped[:500], image_dir=image_dir, hflip_p = 0.5, preprocess=TRANSFORM)

----------------------------------------------------------------
JAAD:
Total number of crosses: 156
Total number of non-crosses: 35
Filtered samples: 1
Total number of samples before and after balancing: 131, 131


In [68]:
from src.model.models import build_encoder_res18, DecoderRNN_IMBS
from src.early_stopping import load_from_checkpoint

CP_PATH = '/Users/arinaruck/Desktop/courses/CIVIL459-PedestrianIntensionDetection/checkpoints/silver-sweep-5/Decoder_IMBS_lr0.001_wd0.0001_JAAD_mf10_pred10_bs16_202305241656.pt'

In [69]:
device = torch.device('cpu')


encoder_res18 = build_encoder_res18(args)
# freeze CNN-encoder during training
encoder_res18.freeze_backbone()

decoder_lstm = DecoderRNN_IMBS(CNN_embeded_size=256, h_RNN_0=256, h_RNN_1=64, h_RNN_2=16,
                                h_FC0_dim=128, h_FC1_dim=64, h_FC2_dim=86, drop_p=0.2).to(device)

model = {'encoder': encoder_res18, 'decoder': decoder_lstm}
load_from_checkpoint(model, CP_PATH)

Using resnet18 cnn encoder!!




In [70]:
@torch.no_grad()
def abl_eval_model(loader, model, device, abl_type='none'):
    # swith to evaluate mode
    encoder_CNN, decoder_RNN = model['encoder'], model['decoder']
    encoder_CNN.eval()
    decoder_RNN.eval()

    batch_size = loader.batch_size
    n_steps = len(loader)

    preds = np.zeros(n_steps * batch_size)
    tgts = np.zeros(n_steps * batch_size)

    for step, inputs in enumerate(tqdm(loader)):
        images, seq_len, pv, scene, behavior, targets = unpack_batch(inputs, device)
        outputs_CNN = encoder_CNN(images, seq_len)
        if abl_type == 'CNN':
            outputs_CNN = torch.zeros_like(outputs_CNN)
        elif abl_type == 'scene':
            scene = torch.zeros_like(scene)
        elif abl_type == 'behavior':
            behavior = torch.zeros_like(behavior)
        elif abl_type == 'pv':
            pv = torch.zeros_like(pv)
        outputs_RNN = decoder_RNN(xc_3d=outputs_CNN, xp_3d=pv, 
                                    xb_3d=behavior, xs_2d=scene, x_lengths=seq_len)
        
        preds[step * batch_size: (step + 1) * batch_size] = outputs_RNN.detach().cpu().squeeze()
        tgts[step * batch_size: (step + 1) * batch_size] = targets.detach().cpu().squeeze()

    ap_score = average_precision_score(tgts, preds)
    best_thr = decoder_RNN.threshold
    f1 = f1_score(tgts, preds > best_thr)
    preds = preds > best_thr
    print(classification_report(tgts, preds), flush=True)
    print(f"AP: {ap_score}, F1: {f1}", flush=True)

In [71]:
test_loader = torch.utils.data.DataLoader(ds, batch_size=1, shuffle=False, num_workers=args.num_workers, pin_memory=True)


In [72]:
abl_eval_model(test_loader, model, device)

100%|██████████| 131/131 [01:21<00:00,  1.61it/s]

              precision    recall  f1-score   support

         0.0       0.00      0.00      0.00         6
         1.0       0.95      1.00      0.98       125

    accuracy                           0.95       131
   macro avg       0.48      0.50      0.49       131
weighted avg       0.91      0.95      0.93       131

AP: 0.9640141143283294, F1: 0.9765625



  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


In [73]:
abl_eval_model(test_loader, model, device, abl_type='CNN')

100%|██████████| 131/131 [01:22<00:00,  1.59it/s]

              precision    recall  f1-score   support

         0.0       0.00      0.00      0.00         6
         1.0       0.95      1.00      0.98       125

    accuracy                           0.95       131
   macro avg       0.48      0.50      0.49       131
weighted avg       0.91      0.95      0.93       131

AP: 0.9572672649633381, F1: 0.9765625



  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


In [74]:
abl_eval_model(test_loader, model, device, abl_type='pv')

100%|██████████| 131/131 [01:23<00:00,  1.57it/s]

              precision    recall  f1-score   support

         0.0       0.00      0.00      0.00         6
         1.0       0.95      1.00      0.98       125

    accuracy                           0.95       131
   macro avg       0.48      0.50      0.49       131
weighted avg       0.91      0.95      0.93       131

AP: 0.9306057546581918, F1: 0.9765625



  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


In [75]:
abl_eval_model(test_loader, model, device, abl_type='behavior')

100%|██████████| 131/131 [01:31<00:00,  1.42it/s]

              precision    recall  f1-score   support

         0.0       0.00      0.00      0.00         6
         1.0       0.95      1.00      0.98       125

    accuracy                           0.95       131
   macro avg       0.48      0.50      0.49       131
weighted avg       0.91      0.95      0.93       131

AP: 0.9556074968299084, F1: 0.9765625



  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


In [76]:
abl_eval_model(test_loader, model, device, abl_type='scene')

100%|██████████| 131/131 [01:32<00:00,  1.42it/s]

              precision    recall  f1-score   support

         0.0       0.05      1.00      0.09         6
         1.0       0.00      0.00      0.00       125

    accuracy                           0.05       131
   macro avg       0.02      0.50      0.04       131
weighted avg       0.00      0.05      0.00       131

AP: 0.99430689692157, F1: 0.0



  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
