In [10]:
import numpy as np
import os
import pickle
import torch
import pandas as pd
from model import VSE

In [85]:
# MSR-VTT validation captions: the pickle holds a 3-tuple that is unpacked into
# the captions, their lengths and the matching video ids.
with open('./data/msr-vtt/captions_pkl/msr-vtt_captions_val.pkl',
          'rb') as pkl_file:
    (captions, lengths, video_ids) = pickle.load(pkl_file)

In [10]:
class Vocabulary(object):
    """Minimal stand-in for the vocabulary class stored in ``vocab.pkl``.

    Presumably defined only so that ``pickle`` can resolve the class name when
    the vocabulary is unpickled below. Attributes such as ``word2idx`` come from
    the pickled instance; this class does not create them itself.
    """

    def __len__(self):
        # Size of the vocabulary = number of entries in ``word2idx``.
        word_index = self.word2idx
        return len(word_index)

In [42]:
# Load the word vocabulary. Presumably this needs the ``Vocabulary`` class
# defined above to be in scope so pickle can rebuild the object.
with open('./vocab/vocab.pkl', 'rb') as vocab_file:
    vocab = pickle.load(vocab_file)

In [111]:
# MSVD (video file, caption) pairs; the preview below shows a 2-column
# object array.
with open('../../data/msvd_video_caps.pkl', 'rb') as caps_file:
    data = pickle.load(caps_file)

In [112]:
# Preview the (video file, caption) pairs; numpy elides the middle rows.
data

array([['mv89psg6zh4_33_46.avi',
        'A bird in a sink keeps getting under the running water from a faucet.'],
       ['mv89psg6zh4_33_46.avi', 'A bird is bathing in a sink.'],
       ['mv89psg6zh4_33_46.avi',
        'A bird is splashing around under a running faucet.'],
       ...,
       ['m7x8uIdg2XU_67_73.avi',
        'The lady added a cream sauce to the pasta.'],
       ['m7x8uIdg2XU_67_73.avi', 'women are cooking her kichen'],
       ['m7x8uIdg2XU_67_73.avi',
        'The woman is pouring cream over the pasta.']], dtype=object)

In [117]:
# Load the train / val / test arrays (presumably the split definitions).
# NOTE(review): if these were saved as dtype=object arrays, numpy>=1.16.3 needs
# np.load(..., allow_pickle=True) — confirm the saved dtype.
train, val, test = [
    np.load(npy_path)
    for npy_path in ('../../data/train.npy',
                     '../../data/val.npy',
                     '../../data/test.npy')
]

In [43]:
# Checkpoints written by torch.save must be read back with torch.load.
# - Current ones are zip archives.
# - Legacy ones are several pickles written back to back.
# Either way a bare pickle.load fails or returns only the leading magic number.
# NOTE(review): torch>=2.6 defaults to weights_only=True, which may reject the
# pickled `opt` object stored in this checkpoint. If loading fails, pass
# weights_only=False (only for trusted files).
with open('./runs/lr00002wc8m02wd100_cpv20rand_wp/model_best.pth.tar', 'rb') as f:
    data = torch.load(f)

In [2]:
# Best checkpoint of run lr00002wc8m02wd100_cpv20rand_wp; its 'model' and
# 'opt' entries are unpacked in the next cell.
# NOTE(review): on torch>=2.6 this may need weights_only=False (trusted file).
data = torch.load(
    f='./runs/lr00002wc8m02wd100_cpv20rand_wp/model_best.pth.tar'
)

In [3]:
# Split the checkpoint into its saved state ('model') and its options ('opt').
# Note that `model` holds saved state, not a model object.
model, opt = data['model'], data['opt']
# NOTE(review): weight decay is zeroed before VSE(opt) is built below,
# presumably because it only matters for training — confirm against model.VSE.
opt.weight_decay = 0

In [4]:
# Rebuild the model from the checkpoint's options (weight_decay was zeroed in
# the previous cell), then load the checkpoint's 'model' entry into it.
# NOTE(review): the structure of that entry is whatever model.VSE's
# state_dict() produced — not visible in this notebook.
vse = VSE(opt)
vse.load_state_dict(model)

In [5]:
# Export the current state of the restored model (project-defined
# VSE.state_dict; its return structure is not visible here).
new_state_dict = vse.state_dict()

In [6]:
# Feed the just-exported state back into the same model. This is presumably a
# round-trip sanity check of state_dict/load_state_dict and is not expected to
# change anything. NOTE(review): can be dropped once verified.
vse.load_state_dict(new_state_dict)

## Search for the best model

In [8]:
# Root directories of the two grid searches (one sub-directory per run, read by
# collect_results below). Set GS_OBJ_DIR / GS_ACT_DIR in the environment to
# point elsewhere instead of editing these hard-coded Windows paths.
OBJ_DIR = os.environ.get('GS_OBJ_DIR', 'E:/gs_obj')
ACT_DIR = os.environ.get('GS_ACT_DIR', 'E:/gs_act')

In [40]:
def collect_results(base_dir, result_name='val.stat'):
    """Gather the metrics of every run directory under ``base_dir`` into a DataFrame.

    Each immediate sub-directory of ``base_dir`` is one run. Its
    ``result_name`` file is read as ``key=value`` lines (one metric per line)
    and becomes one row of the result.

    Parameters
    ----------
    base_dir : str
        Directory whose sub-directories are individual runs. Plain files
        directly inside it (e.g. notes, OS metadata) are ignored.
    result_name : str, default 'val.stat'
        Name of the metrics file inside each run directory, e.g. 'val.stat'
        or 'test.stat'.

    Returns
    -------
    pandas.DataFrame
        One row per run, ordered by run name.
        - ``name`` is the run directory name.
        - ``epoch`` is an int64.
        - Every other metric is a float64.
        The standard columns come first in a fixed order; any extra keys found
        in the files are appended after them. Metrics missing from a file are
        NaN. An empty ``base_dir`` yields an empty frame with the standard
        columns.

    Raises
    ------
    FileNotFoundError
        If a run directory has no ``result_name`` file.
    KeyError
        If a results file has no ``epoch`` entry.
    ValueError
        If a non-blank line is not of the form ``key=<float>``.
    """
    columns = ['name', 'epoch', 'rsum',
               'r1', 'r5', 'r10', 'medr', 'meanr',
               'r1i', 'r5i', 'r10i', 'medri', 'meanri']

    # Only directories are runs; sort because os.listdir order is arbitrary.
    model_names = sorted(
        entry for entry in os.listdir(base_dir)
        if os.path.isdir(os.path.join(base_dir, entry))
    )

    rows = []
    for model_name in model_names:
        path = os.path.join(base_dir, model_name, result_name)
        row = {'name': model_name}
        with open(path, 'r') as f:
            for line in f:
                if not line.strip():
                    continue  # tolerate blank / trailing empty lines
                key, sep, value = line.partition('=')
                if not sep:
                    raise ValueError(
                        '{}: expected "key=value", got {!r}'.format(path, line))
                row[key.strip()] = float(value)
        row['epoch'] = int(row['epoch'])
        rows.append(row)

    # DataFrame.append was removed in pandas 2.0; build the frame in one go,
    # keeping the standard columns first and any unexpected keys afterwards.
    df = pd.DataFrame(rows)
    extra = [c for c in df.columns if c not in columns]
    df = df.reindex(columns=columns + extra)
    dtypes = {c: 'float64' for c in df.columns if c not in ('name', 'epoch')}
    dtypes.update({'name': str, 'epoch': 'int64'})
    return df.astype(dtypes)

In [41]:
# Validation metrics of every run in the gs_obj grid search.
obj_val = collect_results(base_dir=OBJ_DIR, result_name='val.stat')

In [48]:
# Top-5 runs by validation rsum (sum of the six recall scores; higher is better).
obj_val.sort_values('rsum', ascending=False).iloc[:5]

Unnamed: 0,name,epoch,rsum,r1,r5,r10,medr,meanr,r1i,r5i,r10i,medri,meanri
382,objcpv20m0.6wc1e-08wd100es1280lr0.1,37,75.8,1.333333,6.666667,12.0,69.0,122.36,6.133333,18.333333,31.333333,19.0,24.000667
613,objcpv20m0.4wc1e-16wd100es1280lr1e-05,40,71.4,8.0,14.666667,24.0,35.0,111.56,1.266667,8.6,14.866667,31.0,34.675333
731,objcpv20m1.0wc1e-08wd100es1024lr1e-05,40,70.733333,4.0,13.333333,21.333333,46.0,80.426667,2.733333,9.866667,19.466667,30.0,32.894
464,objcpv20m0.8wc1e-16wd100es1280lr0.1,37,67.133333,1.333333,6.666667,13.333333,64.0,127.266667,4.133333,14.133333,27.533333,25.0,29.633333
81,objcpv20m0.4wc0wd100es1024lr1e-05,13,67.066667,2.666667,10.666667,22.666667,74.0,148.546667,2.133333,10.133333,18.8,36.0,36.018667


In [49]:
# Validation metrics of every run in the gs_act grid search.
act_val = collect_results(base_dir=ACT_DIR, result_name='val.stat')

In [55]:
# Top-10 runs by validation rsum (higher is better).
act_val.sort_values('rsum', ascending=False).iloc[:10]

Unnamed: 0,name,epoch,rsum,r1,r5,r10,medr,meanr,r1i,r5i,r10i,medri,meanri
391,actcpv20m0.8wc1e-12wd100es1536lr0.0001,39,271.326531,31.632653,60.204082,74.489796,3.0,13.806122,14.591837,38.010204,52.397959,9.0,19.265816
213,actcpv20m0.4wc1e-16wd100es1536lr0.0001,38,269.642857,31.632653,57.142857,68.367347,4.0,17.969388,18.520408,39.642857,54.336735,9.0,17.448469
169,actcpv20m1.0wc1e-16wd100es768lr0.001,40,261.734694,31.632653,56.122449,66.326531,4.0,16.27551,14.336735,39.081633,54.234694,9.0,16.415306
687,actcpv20m0.6wc1e-08wd100es256lr0.001,32,261.530612,26.530612,60.204082,70.408163,4.0,12.928571,11.632653,36.428571,56.326531,9.0,15.915306
591,actcpv20m0.4wc0wd100es1536lr0.0001,39,260.765306,30.612245,52.040816,68.367347,5.0,14.418367,15.816327,39.336735,54.591837,9.0,16.967857
223,actcpv20m0.8wc1e-08wd100es1024lr0.0001,38,259.744898,26.530612,57.142857,71.428571,5.0,18.908163,14.030612,38.265306,52.346939,10.0,18.595408
554,actcpv20m1.0wc1e-08wd100es1536lr0.0001,37,259.234694,29.591837,57.142857,69.387755,4.0,22.653061,14.642857,37.397959,51.071429,10.0,17.770918
671,actcpv20m0.6wc1e-08wd100es1024lr0.0001,39,257.653061,29.591837,56.122449,66.326531,4.0,20.561224,13.826531,36.785714,55.0,9.0,16.991837
676,actcpv20m0.6wc1e-08wd100es1280lr0.0001,35,257.602041,31.632653,57.142857,69.387755,4.0,18.336735,12.704082,36.071429,50.663265,10.0,19.511224
426,actcpv20m0.4wc1e-08wd100es1536lr0.0001,39,256.989796,31.632653,57.142857,65.306122,4.0,17.183673,15.05102,37.091837,50.765306,10.0,19.285204


In [52]:
# Test-set metrics of the same gs_act runs.
act_test = collect_results(base_dir=ACT_DIR, result_name='test.stat')

In [54]:
# Top-10 runs by *test* rsum.
# NOTE(review): choosing a model by its test score leaks the test set into
# model selection. Rank on act_val instead, and only report the chosen runs'
# test metrics.
act_test.sort_values('rsum', ascending=False).iloc[:10]

Unnamed: 0,name,epoch,rsum,r1,r5,r10,medr,meanr,r1i,r5i,r10i,medri,meanri
591,actcpv20m0.4wc0wd100es1536lr0.0001,39,96.502276,6.525038,20.182094,29.893778,32.0,122.855842,4.294385,13.92261,21.68437,52.0,106.500303
72,actcpv20m1.0wc1e-16wd100es1280lr0.0001,40,90.455235,5.159332,19.271624,28.376328,32.0,133.83915,3.846737,13.376328,20.424886,58.0,113.718134
213,actcpv20m0.4wc1e-16wd100es1536lr0.0001,38,90.424886,5.614568,18.816388,28.679818,32.0,126.766313,3.998483,13.277693,20.037936,55.0,109.169347
152,actcpv20m0.8wc1e-16wd100es1280lr0.0001,40,89.764795,6.069803,19.119879,30.045524,35.0,121.323217,3.550835,11.98786,18.990895,56.0,111.731942
500,actcpv20m0.4wc1e-12wd100es1536lr0.0001,40,88.945372,5.614568,20.637329,33.23217,26.0,107.681335,2.587253,10.394537,16.479514,68.0,123.174583
391,actcpv20m0.8wc1e-12wd100es1536lr0.0001,39,88.292868,5.614568,18.816388,28.376328,35.0,114.899848,3.505311,12.503794,19.47648,60.0,112.868892
558,actcpv20m0.2wc0wd100es1536lr0.0001,38,87.890744,4.248862,19.423369,30.045524,29.0,126.735964,3.338392,12.025797,18.808801,59.0,111.51176
426,actcpv20m0.4wc1e-08wd100es1536lr0.0001,39,87.253414,5.462822,18.816388,28.376328,32.0,113.400607,3.277693,12.207891,19.112291,65.0,116.898027
711,actcpv20m0.6wc1e-12wd100es1536lr0.0001,30,87.07132,4.552352,19.423369,30.197269,30.0,124.921093,3.725341,11.396055,17.776935,64.0,118.276024
626,actcpv20m0.8wc1e-08wd100es1536lr0.0001,36,86.919575,5.766313,19.878604,30.500759,31.0,119.515933,2.890744,10.493171,17.389985,68.0,124.031563
