In [32]:
import os
import torch
import numpy as np
import matplotlib.pyplot as plt
import torch.nn as nn
from sklearn.mixture import GaussianMixture
from inference.compute_test_features import load_scenario_features, get_array_features, get_ego_features

os.environ['PLANTF'] = '/home/sgwang/planTF'
class EncoderFeatureAnalyzer:
    def __init__(self):
        dim =128
        self.norm = nn.LayerNorm(dim)
        
    def load_encoder_features(self, folder_path):
        encoder_features = []
        
        for file_name in os.listdir(folder_path):
            if file_name.endswith('.pt'):
                file_path = os.path.join(folder_path, file_name)
                try:
                    features = torch.load(file_path)
                    encoder_features.append(features)
                except Exception as e:
                    print(f"Failed to load {file_path}: {e}")

        return encoder_features
    
    def get_ego_features(self, features):
        ego_features = []
        for feature in features:
            feature = feature[:,0]
            ego_features.append(feature)
        return ego_features
    
    def split_and_concat_features(self, features):
        # Convert to tensor
        # Split the tensor into a list of tensors each of size [128]
        all_features = []
        for feature in features:
            split_features = feature.split(1, dim=0)
            split_features = [f.squeeze(0) for f in split_features]
            all_features.extend(split_features)
        # Concatenate all tensors together as [:, 128]
        concatenated_features = torch.stack(all_features)
        return concatenated_features
    
    def get_other_features(self, features):
        other_features = []
        for feature in features:
            feature = feature[:,1:]
            other_features.append(feature)
        return other_features
    
    def compute_mean(self, features):
        means = []
        for feature in features:
            mean = torch.mean(feature, dim=0)
            means.append(mean)
        return means
    
    def compute_std(self, features):
        stds = []
        for feature in features:
            std = torch.std(feature, dim=0)
            stds.append(std)
        return stds
    
    def compute_min(self, features):
        mins = []
        for feature in features:
            min_val = torch.min(feature, dim=0).values
            mins.append(min_val)
        return mins
    
    def compute_max(self, features):
        maxs = []
        for feature in features:
            max_val = torch.max(feature, dim=0).values
            maxs.append(max_val)
        return maxs
    
    def concat_features(self, features):
        concat_features = []
        for feature in features:
            feature = feature.view(-1)
            concat_features.append(feature)
        return concat_features

    def compute_norm(self, features):
        norms = []
        for feature in features:
            encoder_feature = feature.clone().detach().cpu()
            x = self.norm(encoder_feature)
            norms.append(x)
        return norms
    
    def compute_gmm(self, features):
        gmm = GaussianMixture(n_components=2)
        gmm.fit(features)
        return gmm
    
    def colculate_matrix(self, features):
        mean = np.mean(features, axis=0)
        cov_matrix = np.cov(features, rowvar=False)
        inv_cov_matrix = np.linalg.inv(cov_matrix)
        return mean, cov_matrix, inv_cov_matrix
    
    
    def calculate_mahalanobis_distance(self,new_sample, mean, inv_cov_matrix):
        mahalanobis_dist = distance.mahalanobis(new_sample, mean, inv_cov_matrix)
        return mahalanobis_dist

In [14]:
from scipy.spatial import distance

In [22]:
analyzer = EncoderFeatureAnalyzer()
plantf_path = os.getenv('PLANTF')
norm_path = os.path.join(plantf_path, 'inference_x')
norm_features = analyzer.load_encoder_features(norm_path)
norm_features = analyzer.get_ego_features(norm_features)
concatenated_features = analyzer.split_and_concat_features(norm_features)

In [26]:
scenario_path = os.path.join(plantf_path, 'encoder_features')
scenarios =load_scenario_features(scenario_path)
first_scenario = scenarios[0]
scenario = list(first_scenario.values())[0]
array_features = get_array_features(scenario)
ego_features = get_ego_features(array_features)

In [27]:
print("Ego features", ego_feature.shape)

Ego features torch.Size([32, 128])


In [29]:
dim=128
for ego_feature in ego_features:
    ego_feature_tensor = torch.tensor(ego_feature).squeeze()
    if ego_feature_tensor.ndim == 1 and ego_feature_tensor.shape[0] == dim:
        dist = analyzer.calculate_mahalanobis_distance(concatenated_features.numpy(), ego_feature_tensor.numpy())
        print(dist)
    else:
        print(f"Skipping ego_feature with shape {ego_feature_tensor.shape}")

(128,) (128, 128) (128, 128)
725282.3632961452
(128,) (128, 128) (128, 128)
696808.9751484675
(128,) (128, 128) (128, 128)
676735.583125948
(128,) (128, 128) (128, 128)
682646.4328845876
(128,) (128, 128) (128, 128)
687601.3775549483
(128,) (128, 128) (128, 128)
680217.761981475
(128,) (128, 128) (128, 128)
700311.5794941292
(128,) (128, 128) (128, 128)
701244.655915094
(128,) (128, 128) (128, 128)
748879.9454690809
(128,) (128, 128) (128, 128)
692615.0388931336
(128,) (128, 128) (128, 128)
819474.7893184191
(128,) (128, 128) (128, 128)
712718.3800084519
(128,) (128, 128) (128, 128)
716089.2288661709
(128,) (128, 128) (128, 128)
724276.6243667427
(128,) (128, 128) (128, 128)
732007.0982167782
(128,) (128, 128) (128, 128)
734760.6432620665
(128,) (128, 128) (128, 128)
715089.3908039466
(128,) (128, 128) (128, 128)
718295.5400701498
(128,) (128, 128) (128, 128)
732936.9926380501
(128,) (128, 128) (128, 128)
772844.4630086473
(128,) (128, 128) (128, 128)
758483.4916312696
(128,) (128, 128

  return np.sqrt(m)


(128,) (128, 128) (128, 128)
nan
(128,) (128, 128) (128, 128)
nan
(128,) (128, 128) (128, 128)
nan
(128,) (128, 128) (128, 128)
nan
(128,) (128, 128) (128, 128)
nan
(128,) (128, 128) (128, 128)
nan
(128,) (128, 128) (128, 128)
nan
(128,) (128, 128) (128, 128)
364254.7207662308
(128,) (128, 128) (128, 128)
nan
(128,) (128, 128) (128, 128)
nan
(128,) (128, 128) (128, 128)
365846.6805675579
(128,) (128, 128) (128, 128)
396325.24092332233
(128,) (128, 128) (128, 128)
412002.03797212883
(128,) (128, 128) (128, 128)
407398.72875309247
(128,) (128, 128) (128, 128)
nan
(128,) (128, 128) (128, 128)
438748.9153974204
(128,) (128, 128) (128, 128)
428832.1371311852
(128,) (128, 128) (128, 128)
428583.34313878155
(128,) (128, 128) (128, 128)
409240.32806824596
(128,) (128, 128) (128, 128)
442112.4987801989
(128,) (128, 128) (128, 128)
448021.5145996695
(128,) (128, 128) (128, 128)
448484.2110488034
(128,) (128, 128) (128, 128)
455373.40385423444
(128,) (128, 128) (128, 128)
443259.71825558227
(128,

In [33]:
dim=128
analyzer = EncoderFeatureAnalyzer()
mean, cov_matrix, inv_cov_matrix = analyzer.colculate_matrix(concatenated_features.numpy())
for ego_feature in ego_features:
    ego_feature_tensor = torch.tensor(ego_feature).squeeze()
    if ego_feature_tensor.ndim == 1 and ego_feature_tensor.shape[0] == dim:
        dist = analyzer.calculate_mahalanobis_distance(ego_feature_tensor.numpy(),mean, inv_cov_matrix)
        print(dist)
    else:
        print(f"Skipping ego_feature with shape {ego_feature_tensor.shape}")

725282.3632961452
696808.9751484675
676735.583125948
682646.4328845876
687601.3775549483
680217.761981475
700311.5794941292
701244.655915094
748879.9454690809
692615.0388931336
819474.7893184191
712718.3800084519
716089.2288661709
724276.6243667427
732007.0982167782
734760.6432620665
715089.3908039466
718295.5400701498
732936.9926380501
772844.4630086473
758483.4916312696
760878.1260428292
770793.2347129629
778202.8548702886
784561.6525143145
787579.1286539536
760124.7534432492
779821.2317154
779803.2015974339
747760.7636221213
740340.605991406
732696.739322277
725579.9806718619
708518.7542859985
732138.9233577993
734693.1124706749
743338.7386834027
750746.00857115
795020.5101983593
757724.8596569014
732898.3662910346
743812.7894521984
793729.6358585089
701088.8389103388
676671.3297894157
674688.1488109217
666118.8674462708
714646.4277489141
636498.8539597132
649013.4105025795
675057.4310800601
676804.9769574918
712805.2797622046
702039.9537887048
703728.1503949104
757071.1435103102
67

  return np.sqrt(m)
