In [1]:
from model_functions import SocialGNN, get_inputs_outputs
from collections import namedtuple
import pickle
import os 
from tqdm import tqdm

In [2]:
# --- Input ---
# Despite the "_dir" suffix, this path is handed straight to load_pickle below,
# i.e. it names a single pickle file holding the preprocessed graph data.
data_input_dir = 'Data/preprocess/graphs'

# --- Models ---
# Directory that is scanned for trained models (one sub-directory per model).
trained_model_dir = 'TrainedModels/'

# --- Output ---
# Destination directory for the per-model encoding pickles (the same location
# the extraction loop at the end of this notebook writes to).
encoding_output_dir = 'Data/encodings/'

In [3]:
def load_pickle(path):
    """Read the file at ``path`` and return the object pickled inside it.

    Args:
        path: Location of a pickle file, opened in binary read mode.

    Returns:
        Whatever object ``pickle.load`` reconstructs from the file.

    Note:
        Unpickling can execute arbitrary code; only use this on trusted files.
    """
    with open(path, 'rb') as handle:
        return pickle.load(handle)

def save_pickle(obj, path):
    """Serialize ``obj`` with pickle and write it to ``path``.

    Any missing parent directories of ``path`` are created first, so saving
    into an output folder that does not exist yet no longer fails with
    ``FileNotFoundError``.

    Args:
        obj: Any picklable object.
        path: Destination file path. An existing file is overwritten.
    """
    # Only path-like targets have a parent directory to create; anything else
    # that open() accepts is passed through untouched.
    if isinstance(path, (str, bytes, os.PathLike)):
        parent = os.path.dirname(path)
        if parent:
            os.makedirs(parent, exist_ok=True)
    with open(path, 'wb') as f:
        pickle.dump(obj, f)

In [4]:
# Immutable record of the hyper-parameters passed to SocialGNN; the fields are
# filled in per model by the extraction loop further down.
model_config = namedtuple(
    'model_config',
    'NUM_NODES NUM_AGENTS V_SPATIAL_SIZE E_SPATIAL_SIZE V_TEMPORAL_SIZE '
    'V_OUTPUT_SIZE BATCH_SIZE CLASS_WEIGHTS LEARNING_RATE LAMBDA',
)

# Preprocessed graph data, loaded from the pickle file at data_input_dir.
videos = load_pickle(data_input_dir)

In [5]:
# Sanity check: how many entries were loaded from the graphs pickle.
len(videos)

250

In [6]:
# Peek at the first entry to see its structure; the recorded output shows a
# dict with 'label' and 'graph_dicts' keys.
# NOTE(review): this displays the entire nested record, which is very long;
# consider showing only the keys or a small slice instead.
videos[0]

{'label': [],
 'graph_dicts': [{'nodes': [[-7.850303649902344,
     -1.6825066804885864,
     -4.208608150482178,
     -11.062053680419922,
     2.976902961730957,
     -21.33794403076172,
     -3.873410224914551,
     -4.370782375335693,
     4.302788734436035,
     3.327807664871216,
     -8.291187286376953,
     -6.290378570556641,
     -2.905923366546631,
     -6.54016637802124,
     -0.16700157523155212,
     5.006124496459961,
     4.564458847045898,
     4.927541255950928,
     4.128571510314941,
     1.09581458568573,
     0.0,
     101.0,
     94.0,
     228.0,
     0.0],
    [-14.308711051940918,
     0.29627636075019836,
     4.855468273162842,
     -9.988350868225098,
     -1.8918441534042358,
     -18.185495376586914,
     -6.089308261871338,
     9.820735931396484,
     -3.754924774169922,
     -1.39609694480896,
     -3.3341917991638184,
     -5.8742241859436035,
     -8.983794212341309,
     -0.9540156126022339,
     2.080803155899048,
     -3.7811877727508545,
     4.8

In [8]:
# When True, encodings are recomputed and overwritten for every model, even if
# an output file for that model already exists. Set to False to process only
# models that have no saved output yet.
restart = True

In [9]:
# Extract the activations of every trained model and pickle them, one file per
# model, into `encoding_output_dir`.
#
# Every sub-directory of `trained_model_dir` is treated as one saved model.
# Its hyper-parameters are parsed from the directory name: after dropping the
# first 7 underscore-separated tokens, token 5 is the output size and tokens
# 7, 8, ... are the per-class weights (one per output class).
# NOTE(review): this positional parsing assumes a fixed naming scheme for the
# model directories -- confirm against the training script.

# Make sure the destination exists so the first save cannot fail on a fresh checkout.
os.makedirs(encoding_output_dir, exist_ok=True)

old_output_size = None
for model_name in tqdm(sorted(os.listdir(trained_model_dir))):
    model_dir = os.path.join(trained_model_dir, model_name)
    if not os.path.isdir(model_dir):
        continue  # stray file next to the model directories

    # Built from the configured output directory rather than a hard-coded copy
    # of it. (`output_dir` is a file path, despite its name.)
    output_dir = os.path.join(encoding_output_dir, f'{model_name}.pkl')
    if os.path.exists(output_dir) and not restart:
        continue  # already extracted and no recomputation requested

    parameters = model_name.split('_')[7:]
    num_classes = int(parameters[5])
    config = model_config(
        NUM_NODES = 5,              # always 5
        NUM_AGENTS = 5,             # always 5
        V_SPATIAL_SIZE = 12,        # always 12
        E_SPATIAL_SIZE = 12,        # always 12
        V_TEMPORAL_SIZE = 6,        # always 6
        V_OUTPUT_SIZE = num_classes,             # 2 or 5
        BATCH_SIZE = 10,            # 20, but forced to be 10 so that I don't need to worry about the padded data
        CLASS_WEIGHTS = [float(parameters[7 + i]) for i in range(num_classes)],
        LEARNING_RATE = 0.001,      # always 0.001
        LAMBDA = 0.01               # always 0.01
    )

    # Constructing the network is the slow step (see the multi-minute
    # iterations in the progress log), so it is rebuilt only when the output
    # size changes; otherwise the existing instance is reused and only the
    # saved model is loaded into it.
    # NOTE(review): a reused instance keeps the CLASS_WEIGHTS of the config it
    # was built with -- presumably irrelevant for activations; confirm.
    if config.V_OUTPUT_SIZE != old_output_size:
        model = SocialGNN(videos, config, get_inputs_outputs(videos)[0])
        old_output_size = config.V_OUTPUT_SIZE

    model.load_model(model_dir)
    rnn_activations = model.get_activations()
    save_pickle(rnn_activations, output_dir)
        
        

  0%|          | 0/41 [00:00<?, ?it/s]

.............DEFINING INPUT PLACEHOLDERS..............

.............BUILDING GRAPH..............

.............INITIALIZATION SESSION..............

.............TESTING..............


  2%|▏         | 1/41 [04:08<2:45:28, 248.21s/it]


.............TESTING..............


  5%|▍         | 2/41 [04:45<1:20:38, 124.08s/it]


.............TESTING..............


  7%|▋         | 3/41 [05:03<47:58, 75.74s/it]   


.............TESTING..............


 10%|▉         | 4/41 [05:28<34:14, 55.53s/it]


.............TESTING..............


 12%|█▏        | 5/41 [05:44<24:47, 41.32s/it]


.............TESTING..............


 15%|█▍        | 6/41 [06:06<20:22, 34.92s/it]


.............TESTING..............


 17%|█▋        | 7/41 [06:42<19:58, 35.26s/it]


.............TESTING..............


 20%|█▉        | 8/41 [06:59<16:08, 29.36s/it]


.............TESTING..............


 22%|██▏       | 9/41 [07:23<14:44, 27.63s/it]


.............TESTING..............


 24%|██▍       | 10/41 [07:41<12:42, 24.59s/it]

.............DEFINING INPUT PLACEHOLDERS..............

.............BUILDING GRAPH..............

.............INITIALIZATION SESSION..............

.............TESTING..............


 27%|██▋       | 11/41 [12:20<51:19, 102.66s/it]


.............TESTING..............


 29%|██▉       | 12/41 [12:52<39:09, 81.02s/it] 


.............TESTING..............


 32%|███▏      | 13/41 [13:07<28:33, 61.19s/it]


.............TESTING..............


 34%|███▍      | 14/41 [13:29<22:10, 49.26s/it]


.............TESTING..............


 37%|███▋      | 15/41 [13:43<16:40, 38.49s/it]


.............TESTING..............


 39%|███▉      | 16/41 [14:08<14:27, 34.70s/it]


.............TESTING..............


 41%|████▏     | 17/41 [14:58<15:39, 39.13s/it]


.............TESTING..............


 44%|████▍     | 18/41 [15:20<12:59, 33.90s/it]


.............TESTING..............


 46%|████▋     | 19/41 [15:46<11:38, 31.74s/it]


.............TESTING..............


 49%|████▉     | 20/41 [16:01<09:20, 26.67s/it]

.............DEFINING INPUT PLACEHOLDERS..............

.............BUILDING GRAPH..............

.............INITIALIZATION SESSION..............

.............TESTING..............


 51%|█████     | 21/41 [20:19<32:00, 96.00s/it]


.............TESTING..............


 54%|█████▎    | 22/41 [21:27<27:46, 87.70s/it]


.............TESTING..............


 56%|█████▌    | 23/41 [21:38<19:22, 64.58s/it]


.............TESTING..............


 59%|█████▊    | 24/41 [21:55<14:16, 50.37s/it]


.............TESTING..............


 61%|██████    | 25/41 [22:19<11:19, 42.47s/it]


.............TESTING..............


 63%|██████▎   | 26/41 [22:44<09:18, 37.24s/it]


.............TESTING..............


 66%|██████▌   | 27/41 [23:07<07:42, 33.01s/it]


.............TESTING..............


 68%|██████▊   | 28/41 [23:45<07:27, 34.41s/it]


.............TESTING..............


 71%|███████   | 29/41 [24:26<07:15, 36.26s/it]


.............TESTING..............


 73%|███████▎  | 30/41 [24:44<05:38, 30.79s/it]

.............DEFINING INPUT PLACEHOLDERS..............

.............BUILDING GRAPH..............

.............INITIALIZATION SESSION..............

.............TESTING..............


 76%|███████▌  | 31/41 [30:16<20:12, 121.23s/it]


.............TESTING..............


 78%|███████▊  | 32/41 [30:39<13:45, 91.74s/it] 


.............TESTING..............


 80%|████████  | 33/41 [30:53<09:07, 68.42s/it]


.............TESTING..............


 83%|████████▎ | 34/41 [32:34<09:07, 78.19s/it]


.............TESTING..............


 85%|████████▌ | 35/41 [32:56<06:08, 61.46s/it]


.............TESTING..............


 88%|████████▊ | 36/41 [33:30<04:25, 53.15s/it]


.............TESTING..............


 90%|█████████ | 37/41 [33:50<02:52, 43.12s/it]


.............TESTING..............


 93%|█████████▎| 38/41 [34:17<01:54, 38.30s/it]


.............TESTING..............


 95%|█████████▌| 39/41 [34:59<01:18, 39.49s/it]


.............TESTING..............


100%|██████████| 41/41 [35:21<00:00, 51.74s/it]
