In [1]:
import numpy as np
import os
import pickle
import torch
import pandas as pd
from model import VSE

In [2]:
# Deserialize the object stored in msvd_video_caps.pkl into `data`.
# NOTE(review): pickle can execute arbitrary code while loading; only open trusted files.
with open('../../data/msvd_video_caps.pkl', mode='rb') as caps_file:
    data = pickle.load(caps_file)

In [11]:
# Load the three split arrays (train / val / test) from the cpv20_rand folder in one pass.
train, val, test = (
    np.load(split_path)
    for split_path in (
        '../../data/cpv20_rand/train.npy',
        '../../data/cpv20_rand/val.npy',
        '../../data/cpv20_rand/test.npy',
    )
)

In [16]:
with open('vocab/vocab.pkl', 'rb') as f:
    data = pickle.load(f)

In [18]:
data.idx

11755

In [15]:
class Vocabulary:
    pass

## Search for the best model

In [2]:
# Root folder that holds one sub-folder of grid-search runs per feature type.
# Defaults to the original 'E:/' location; set the GS_RESULTS_ROOT environment
# variable to point the notebook at a different drive or machine.
RESULTS_ROOT = os.environ.get('GS_RESULTS_ROOT', 'E:/')

OBJ_DIR = os.path.join(RESULTS_ROOT, 'gs_obj_median')
ACT_DIR = os.path.join(RESULTS_ROOT, 'gs_act')
ACT2_DIR = os.path.join(RESULTS_ROOT, 'gs_act2')
ACTX_DIR = os.path.join(RESULTS_ROOT, 'gs_actx')
FLOW_DIR = os.path.join(RESULTS_ROOT, 'gs_flow')
FLOW_DTVL1_DIR = os.path.join(RESULTS_ROOT, 'gs_flow_dtvl1')

In [3]:
def collect_results(base_dir, result_name = 'val.stat'):
    """Gather the per-run statistics files under ``base_dir`` into one DataFrame.

    Every sub-directory of ``base_dir`` is treated as one run. Its
    ``result_name`` file is read as ``key=value`` lines; each value is parsed
    as a float, and ``epoch`` is then converted to int.

    Parameters
    ----------
    base_dir : str
        Directory containing one sub-directory per run.
    result_name : str, default 'val.stat'
        Name of the statistics file inside each run directory.

    Returns
    -------
    pandas.DataFrame
        One row per run. The columns ``name, epoch, rsum, r1, r5, r10, medr,
        meanr, r1i, r5i, r10i, medri, meanri`` come first (NaN where a file
        lacks a key); any additional keys found in the files follow. Rows are
        ordered by sorted run name so the integer index is reproducible.

    Raises
    ------
    FileNotFoundError
        If a run directory has no ``result_name`` file.
    KeyError
        If a statistics file has no ``epoch`` entry.
    ValueError
        If a non-blank line has no ``=`` or its value is not a number.
    """
    columns = ['name', 'epoch', 'rsum',
               'r1', 'r5', 'r10', 'medr', 'meanr',
               'r1i', 'r5i', 'r10i', 'medri', 'meanri']

    rows = []
    # os.listdir order is arbitrary; sort so the resulting index is stable.
    for model_name in sorted(os.listdir(base_dir)):
        model_dir = os.path.join(base_dir, model_name)
        if not os.path.isdir(model_dir):
            # Stray files next to the run folders are not runs.
            continue

        row = {'name': model_name}
        with open(os.path.join(model_dir, result_name), 'r') as f:
            for raw_line in f:
                if not raw_line.strip():
                    # Tolerate blank / trailing empty lines.
                    continue
                key, sep, value = raw_line.partition('=')
                if not sep:
                    raise ValueError(
                        'Malformed line in %s: %r'
                        % (os.path.join(model_dir, result_name), raw_line))
                row[key.strip()] = float(value)

        row['epoch'] = int(row['epoch'])
        rows.append(row)

    if not rows:
        # Keep the documented schema even when there is nothing to collect.
        df = pd.DataFrame(columns=columns, dtype='float')
        df['name'] = df['name'].astype('str')
        df['epoch'] = df['epoch'].astype('int')
        return df

    # Build the frame in one go: DataFrame.append was removed in pandas 2.0.
    df = pd.DataFrame(rows)
    extra_columns = [c for c in df.columns if c not in columns]
    df = df.reindex(columns=columns + extra_columns)
    df['epoch'] = df['epoch'].astype('int')
    return df

## Object

In [4]:
# Collect each OBJ_DIR run's val.stat and show the five rows with the highest rsum.
obj_val = collect_results(OBJ_DIR, 'val.stat')
obj_val.sort_values('rsum', ascending=False).iloc[:5]

Unnamed: 0,name,epoch,rsum,r1,r5,r10,medr,meanr,r1i,r5i,r10i,medri,meanri
495,objcpv20m1.0d0.5wd100es2048lr0.0001,36,332.959184,37.755102,65.306122,74.489796,2.0,21.020408,26.428571,56.581633,72.397959,4.0,10.282653
460,objcpv20m1.0d0.25wd100es1536lr0.0001,40,326.632653,37.755102,65.306122,72.44898,3.0,21.479592,25.204082,55.255102,70.663265,4.0,10.657143
145,objcpv20m0.4d0.1wd100es2048lr0.0001,39,322.857143,37.755102,64.285714,69.387755,2.0,18.397959,25.816327,55.102041,70.510204,5.0,10.491327
385,objcpv20m0.8d0.5wd100es1536lr0.0001,40,321.428571,31.632653,67.346939,73.469388,3.0,21.306122,25.102041,54.489796,69.387755,4.0,10.175
85,objcpv20m0.2d0.5wd100es1536lr0.0001,38,321.326531,38.77551,61.22449,67.346939,3.0,23.459184,26.836735,56.071429,71.071429,4.0,10.645918


In [5]:
# Collect each OBJ_DIR run's test.stat and show the five rows with the highest rsum.
obj_test = collect_results(OBJ_DIR, 'test.stat')
obj_test.sort_values('rsum', ascending=False).iloc[:5]

Unnamed: 0,name,epoch,rsum,r1,r5,r10,medr,meanr,r1i,r5i,r10i,medri,meanri
145,objcpv20m0.4d0.1wd100es2048lr0.0001,39,137.769347,6.980273,25.341426,39.453718,19.0,105.867982,7.283763,23.520486,35.189681,22.0,60.609105
390,objcpv20m0.8d0.5wd100es1792lr0.0001,39,136.714719,5.007587,26.403642,39.453718,20.0,104.308042,7.283763,23.573596,34.992413,23.0,62.439909
265,objcpv20m0.6d0.25wd100es1792lr0.0001,39,135.834598,7.738998,26.403642,37.329287,22.0,110.239757,6.904401,23.308042,34.150228,25.0,64.117071
495,objcpv20m1.0d0.5wd100es2048lr0.0001,36,135.341426,7.587253,25.493171,36.874052,22.0,99.411229,7.496206,23.095599,34.795144,22.0,60.147269
295,objcpv20m0.6d0.5wd100es2048lr0.0001,36,135.068285,7.283763,25.948407,37.632777,21.0,100.028832,6.919575,22.966616,34.317147,23.0,62.990288


## Activity 1

In [5]:
# Collect each ACT_DIR run's val.stat and show the five rows with the highest rsum.
act_val = collect_results(ACT_DIR, 'val.stat')
act_val.sort_values('rsum', ascending=False).iloc[:5]

Unnamed: 0,name,epoch,rsum,r1,r5,r10,medr,meanr,r1i,r5i,r10i,medri,meanri
550,actcpv20m1.0wc1e-08wd100es1536lr0.0001,37,259.234694,29.591837,57.142857,69.387755,4.0,22.653061,14.642857,37.397959,51.071429,10.0,17.770918
666,actcpv20m0.6wc1e-08wd100es1024lr0.0001,39,257.653061,29.591837,56.122449,66.326531,4.0,20.561224,13.826531,36.785714,55.0,9.0,16.991837
671,actcpv20m0.6wc1e-08wd100es1280lr0.0001,35,257.602041,31.632653,57.142857,69.387755,4.0,18.336735,12.704082,36.071429,50.663265,10.0,19.511224
422,actcpv20m0.4wc1e-08wd100es1536lr0.0001,39,256.989796,31.632653,57.142857,65.306122,4.0,17.183673,15.05102,37.091837,50.765306,10.0,19.285204
173,actcpv20m0.6wc1e-16wd100es1536lr0.0001,32,256.683673,30.612245,60.204082,68.367347,4.0,21.081633,13.265306,35.102041,49.132653,11.0,19.057653


In [6]:
# Collect each ACT_DIR run's test.stat and show the five rows with the highest rsum.
act_test = collect_results(ACT_DIR, 'test.stat')
act_test.sort_values('rsum', ascending=False).iloc[:5]

Unnamed: 0,name,epoch,rsum,r1,r5,r10,medr,meanr,r1i,r5i,r10i,medri,meanri
72,actcpv20m1.0wc1e-16wd100es1280lr0.0001,40,90.455235,5.159332,19.271624,28.376328,32.0,133.83915,3.846737,13.376328,20.424886,58.0,113.718134
152,actcpv20m0.8wc1e-16wd100es1280lr0.0001,40,89.764795,6.069803,19.119879,30.045524,35.0,121.323217,3.550835,11.98786,18.990895,56.0,111.731942
496,actcpv20m0.4wc1e-12wd100es1536lr0.0001,40,88.945372,5.614568,20.637329,33.23217,26.0,107.681335,2.587253,10.394537,16.479514,68.0,123.174583
554,actcpv20m0.2wc0wd100es1536lr0.0001,38,87.890744,4.248862,19.423369,30.045524,29.0,126.735964,3.338392,12.025797,18.808801,59.0,111.51176
422,actcpv20m0.4wc1e-08wd100es1536lr0.0001,39,87.253414,5.462822,18.816388,28.376328,32.0,113.400607,3.277693,12.207891,19.112291,65.0,116.898027


## Activity 2

In [4]:
# Collect each ACT2_DIR run's val.stat and show the five rows with the highest rsum.
act2_val = collect_results(ACT2_DIR, 'val.stat')
act2_val.sort_values('rsum', ascending=False).iloc[:5]

Unnamed: 0,name,epoch,rsum,r1,r5,r10,medr,meanr,r1i,r5i,r10i,medri,meanri
445,actcpv20m1.0d0.1wd100es2048lr0.0001,37,320.357143,33.673469,60.204082,74.489796,3.0,17.367347,24.744898,56.071429,71.173469,4.0,10.556633
235,actcpv20m0.6d0.1wd100es1536lr0.0001,40,315.510204,33.673469,63.265306,72.44898,3.0,12.806122,23.826531,53.214286,69.081633,5.0,11.281633
165,actcpv20m0.4d0.25wd100es1792lr0.0001,34,313.979592,36.734694,65.306122,72.44898,3.0,14.765306,24.030612,49.387755,66.071429,6.0,12.116837
380,actcpv20m0.8d0.5wd100es1280lr0.0001,34,312.346939,36.734694,61.22449,71.428571,3.0,17.306122,21.22449,52.346939,69.387755,5.0,10.966327
90,actcpv20m0.2d0.5wd100es1792lr0.0001,39,311.122449,38.77551,59.183673,68.367347,3.0,13.489796,23.214286,52.091837,69.489796,5.0,11.442857


In [8]:
# Collect each ACT2_DIR run's test.stat and show the five rows with the highest rsum.
act2_test = collect_results(ACT2_DIR, 'test.stat')
act2_test.sort_values('rsum', ascending=False).iloc[:5]

Unnamed: 0,name,epoch,rsum,r1,r5,r10,medr,meanr,r1i,r5i,r10i,medri,meanri
90,actcpv20m0.2d0.5wd100es1792lr0.0001,39,123.027314,7.435508,23.672231,34.294385,26.0,90.247344,5.940819,20.569044,31.115326,28.0,69.855994
190,actcpv20m0.4d0.5wd100es1792lr0.0001,40,122.124431,6.221548,22.76176,34.749621,23.0,99.174507,6.388467,20.637329,31.365706,28.0,72.1
45,actcpv20m0.2d0.1wd100es2048lr0.0001,37,120.751138,7.435508,22.913505,33.383915,24.0,113.763278,6.312595,20.500759,30.204856,28.0,69.532246
235,actcpv20m0.6d0.1wd100es1536lr0.0001,40,119.597876,6.525038,23.823976,33.383915,25.0,111.614568,6.153263,19.734446,29.977238,28.0,70.69264
470,actcpv20m1.0d0.25wd100es2048lr0.0001,40,119.294385,7.435508,22.45827,32.92868,24.0,103.578149,5.849772,20.060698,30.561457,28.0,70.416995


## Activity with flow

In [4]:
# Collect each ACTX_DIR run's val.stat and show the five rows with the highest rsum.
actx_val = collect_results(ACTX_DIR, 'val.stat')
actx_val.sort_values('rsum', ascending=False).iloc[:5]

Unnamed: 0,name,epoch,rsum,r1,r5,r10,medr,meanr,r1i,r5i,r10i,medri,meanri
56,actcpv20m0.6d0.5wd100es1280lr0.0001,40,315.459184,34.693878,67.346939,76.530612,3.0,16.561224,22.397959,50.102041,64.387755,5.0,12.315816
49,actcpv20m0.6d0.1wd100es2048lr0.0001,36,307.908163,36.734694,59.183673,69.387755,3.0,20.459184,23.877551,52.5,66.22449,5.0,12.116327
93,actcpv20m1.0d0.25wd100es1792lr0.0001,39,305.204082,32.653061,63.265306,74.489796,4.0,19.132653,19.94898,49.846939,65.0,6.0,12.185714
34,actcpv20m0.4d0.25wd100es2048lr0.0001,36,303.979592,36.734694,57.142857,66.326531,3.0,16.816327,23.826531,52.704082,67.244898,5.0,11.971429
58,actcpv20m0.6d0.5wd100es1792lr0.0001,36,303.520408,32.653061,59.183673,69.387755,3.0,17.857143,22.142857,52.142857,68.010204,5.0,11.389286


In [5]:
# Collect each ACTX_DIR run's test.stat and show the five rows with the highest rsum.
actx_test = collect_results(ACTX_DIR, 'test.stat')
actx_test.sort_values('rsum', ascending=False).iloc[:5]

Unnamed: 0,name,epoch,rsum,r1,r5,r10,medr,meanr,r1i,r5i,r10i,medri,meanri
34,actcpv20m0.4d0.25wd100es2048lr0.0001,36,118.254932,6.221548,23.823976,33.990895,24.0,109.23824,5.698027,19.301973,29.218513,31.0,73.982473
14,actcpv20m0.2d0.25wd100es2048lr0.0001,38,118.110774,7.132018,22.45827,33.080425,26.0,104.453718,5.493171,19.855842,30.091047,29.0,73.909181
37,actcpv20m0.4d0.5wd100es1536lr0.0001,37,117.488619,7.132018,23.216995,32.776935,27.0,107.986343,6.092564,19.339909,28.930197,32.0,76.051366
96,actcpv20m1.0d0.5wd100es1280lr0.0001,33,116.85129,6.525038,23.520486,32.92868,28.0,110.311077,5.576631,19.165402,29.135053,31.0,75.767071
59,actcpv20m0.6d0.5wd100es2048lr0.0001,34,116.46434,6.373293,22.15478,33.687405,25.0,107.262519,5.705615,19.021244,29.522003,30.0,72.179287


## Flow

In [9]:
# Collect each FLOW_DIR run's val.stat and show the five rows with the highest rsum.
flow_val = collect_results(FLOW_DIR, 'val.stat')
flow_val.sort_values('rsum', ascending=False).iloc[:5]

Unnamed: 0,name,epoch,rsum,r1,r5,r10,medr,meanr,r1i,r5i,r10i,medri,meanri
215,objcpv20m0.6d0.5wd100es1792lr0.0001,35,200.204082,18.367347,39.795918,53.061224,7.0,42.581633,9.846939,32.346939,46.785714,12.0,19.035204
195,objcpv20m0.6d0.25wd100es2048lr0.0001,40,194.387755,15.306122,36.734694,52.040816,10.0,49.642857,10.969388,32.040816,47.295918,12.0,18.743367
10,objcpv20m0.2d0.1wd100es1536lr0.0001,26,193.826531,19.387755,34.693878,53.061224,9.0,39.867347,10.153061,30.561224,45.969388,12.0,18.663776
135,objcpv20m0.4d0.5wd100es1536lr0.0001,35,191.785714,19.387755,39.795918,47.959184,14.0,39.214286,8.520408,29.693878,46.428571,12.0,19.419388
370,objcpv20m1.0d0.5wd100es2048lr0.0001,20,190.306122,13.265306,38.77551,48.979592,11.0,45.795918,8.928571,31.989796,48.367347,11.0,18.587755


In [11]:
# Collect each FLOW_DIR run's test.stat and show the five rows with the highest rsum.
flow_test = collect_results(FLOW_DIR, 'test.stat')
flow_test.sort_values('rsum', ascending=False).iloc[:5]

Unnamed: 0,name,epoch,rsum,r1,r5,r10,medr,meanr,r1i,r5i,r10i,medri,meanri
295,objcpv20m0.8d0.5wd100es2048lr0.0001,25,55.121396,1.820941,9.104704,16.084977,92.0,318.256449,2.655539,9.772382,15.682853,72.0,126.742109
365,objcpv20m1.0d0.5wd100es1792lr0.0001,27,54.301973,1.972686,9.559939,15.174507,96.0,355.682853,2.435508,9.430956,15.728376,71.0,126.675873
120,objcpv20m0.4d0.25wd100es2048lr0.0001,18,54.112291,1.972686,8.952959,14.871017,95.0,313.497724,2.655539,9.742033,15.918058,69.0,118.340819
145,objcpv20m0.4d0.5wd100es2048lr0.0001,39,53.793627,1.365706,9.408194,15.933232,84.0,295.036419,2.54173,9.203338,15.341426,71.0,124.314719
115,objcpv20m0.4d0.25wd100es1792lr0.0001,27,53.657056,2.427921,8.194234,15.326252,85.0,307.599393,2.359636,9.650986,15.698027,71.0,124.582398


## Flow (Dual TV-L1)

In [4]:
# Collect each FLOW_DTVL1_DIR run's val.stat and show the five rows with the highest rsum.
flow_dtvl1_val = collect_results(FLOW_DTVL1_DIR, 'val.stat')
flow_dtvl1_val.sort_values('rsum', ascending=False).iloc[:5]

Unnamed: 0,name,epoch,rsum,r1,r5,r10,medr,meanr,r1i,r5i,r10i,medri,meanri
93,objcpv20m1.0d0.25wd100es1792lr0.0001,40,264.132653,22.44898,53.061224,65.306122,5.0,27.612245,16.326531,45.459184,61.530612,7.0,14.681122
32,objcpv20m0.4d0.25wd100es1536lr0.0001,28,261.479592,22.44898,53.061224,61.22449,5.0,30.622449,16.785714,44.693878,63.265306,7.0,13.484694
48,objcpv20m0.6d0.1wd100es1792lr0.0001,33,259.744898,23.469388,47.959184,64.285714,6.0,21.591837,17.857143,46.122449,60.05102,7.0,14.157143
72,objcpv20m0.8d0.25wd100es1536lr0.0001,30,258.367347,23.469388,54.081633,65.306122,5.0,23.061224,14.642857,42.142857,58.72449,8.0,14.431122
77,objcpv20m0.8d0.5wd100es1536lr0.0001,11,258.163265,24.489796,55.102041,66.326531,5.0,23.071429,13.265306,40.867347,58.112245,8.0,14.611224


In [5]:
# Collect each FLOW_DTVL1_DIR run's test.stat and show the five rows with the highest rsum.
flow_dtvl1_test = collect_results(FLOW_DTVL1_DIR, 'test.stat')
flow_dtvl1_test.sort_values('rsum', ascending=False).iloc[:5]

Unnamed: 0,name,epoch,rsum,r1,r5,r10,medr,meanr,r1i,r5i,r10i,medri,meanri
89,objcpv20m1.0d0.1wd100es2048lr0.0001,27,82.62519,4.704097,13.657056,22.610015,46.0,195.279211,3.915023,14.719272,23.019727,47.0,92.577693
38,objcpv20m0.4d0.5wd100es1792lr0.0001,40,80.318665,3.490137,15.174507,23.06525,48.0,201.919575,3.687405,13.61912,21.282246,49.0,100.54393
47,objcpv20m0.6d0.1wd100es1536lr0.0001,35,79.871017,3.490137,16.236722,23.368741,46.0,223.028832,3.53566,12.534143,20.705615,48.0,98.346889
93,objcpv20m1.0d0.25wd100es1792lr0.0001,40,79.150228,3.338392,13.657056,21.396055,47.0,204.420334,4.00607,14.127466,22.62519,46.0,96.468892
71,objcpv20m0.8d0.25wd100es1280lr0.0001,38,79.08953,4.552352,15.326252,21.396055,51.0,186.963581,3.649469,13.186646,20.978756,48.0,97.673065
