In [12]:
import h5py
from tqdm import tqdm, trange

from chainer import cuda
from skimage import img_as_float
from skimage.transform import resize
import numpy as np
import cv2

# CUDA device that the model (Extractor.__init__ -> to_gpu) and all input batches are placed on.
# Change this single constant to run on a different GPU.
GPU_ID = 3
cuda.Device(GPU_ID).use()
class Extractor:
    """Wrap a Chainer CNN and turn lists of image paths into feature arrays.

    The wrapped model is moved to the currently selected CUDA device on
    construction; every batch is moved there before the forward pass and the
    result is copied back to host memory.
    """

    def __init__(self, model):
        """Store `model` and move its parameters to the current GPU."""
        self.model = model
        model.to_gpu()

    def transform(self, f):
        """Load the image at path `f` as a (1, 3, 224, 224) float32 array.

        The image is read with OpenCV (3-channel, BGR order), resized to
        224x224 and transposed from HWC to CHW; a leading batch axis is added.

        Raises:
            IOError: if OpenCV cannot read the file (cv2.imread returns None
                instead of raising).
        """
        raw = cv2.imread(f)
        if raw is None:
            raise IOError('could not read image: %s' % f)
        img = raw.astype(np.float32)
        # NOTE(review): no mean subtraction is applied (an earlier version had
        # BGR means [103.939, 116.779, 123.68] commented out) -- confirm this
        # matches the preprocessing the ResNet101 weights were trained with.
        img = cv2.resize(img, (224, 224)).transpose((2, 0, 1))
        return img[np.newaxis, :, :, :]

    def get_features(self, fs):
        """Run the wrapped model on the images at paths `fs`.

        Returns the model output for the whole batch as a host (NumPy) array,
        one row per entry of `fs`, in order.
        """
        # volatile=True (Chainer v1 API): inference only, no backprop graph is kept.
        x = Variable(np.vstack([self.transform(f) for f in fs]), volatile=True)
        x.to_gpu()
        # Use the model this extractor wraps (previously the global `resnet`
        # was called, silently ignoring self.model).
        # NOTE(review): the second argument (None) presumably is the label slot
        # of the chainer-ResNet __call__ -- confirm against resnet.ResNet101.
        conv = self.model(x, None)
        return conv.data.get()
    
# https://github.com/yasunorikudo/chainer-ResNet
from chainer import serializers, Variable
from resnet.ResNet101 import ResNet
# Tag used in the output feature filenames.
modelName = 'resnet101'
# Pretrained ResNet-101 weights in Chainer's HDF5 serialization format.
weights_path = 'resnet/ResNet101.model'

resnet = ResNet()
serializers.load_hdf5(weights_path, resnet)
# Later cells refer to this exact (misspelled) name, so it is kept as-is.
extrator = Extractor(resnet)

In [13]:
from tqdm import tqdm, trange
def process_features(filenames, outname, batchsize=32, extractor=None):
    """Extract conv features for every image in `filenames` and store them in HDF5.

    Creates (overwriting, mode 'w') the file `outname` with one dataset,
    'feats_conv', of shape (len(filenames), n_regions, n_channels). Row i holds
    the features of filenames[i]. n_regions and n_channels are taken from the
    extractor output, whose axis 1 is the channel axis and whose remaining
    trailing axes are flattened into regions (49 x 2048 in the recorded run).

    Args:
        filenames: sliceable sequence of image paths.
        outname: path of the HDF5 file to write.
        batchsize: number of images sent to the extractor per forward pass.
        extractor: object with a `get_features(paths)` method returning an array
            of shape (len(paths), C, ...). Defaults to the notebook-level
            `extrator`.
    """
    if extractor is None:
        extractor = extrator
    total = len(filenames)
    with h5py.File(outname, 'w') as f:
        feats_conv = None
        for i in trange(0, total, batchsize):
            batch = filenames[i:i + batchsize]
            size = len(batch)
            conv = extractor.get_features(batch)
            # (size, C, *spatial) -> (size, prod(spatial), C): one C-dim vector per spatial position.
            regions = conv.reshape(size, conv.shape[1], -1).transpose((0, 2, 1))
            if feats_conv is None:
                # Size the dataset from the first batch instead of hard-coding
                # (49, 2048), so other backbones / input sizes also work.
                feats_conv = f.create_dataset('feats_conv', (total,) + regions.shape[1:])
            feats_conv[i:i + size] = regions
        if feats_conv is None:
            # No images: still write an empty 'feats_conv' (original layout) so readers find the key.
            f.create_dataset('feats_conv', (0, 49, 2048))

In [14]:
import json
# Split metadata: each entry of 'images' has at least a 'filename' and a 'split' field.
# `with` closes the file handle (json.load(open(...)) leaked it).
with open('json/dataset_flickr8k.json', 'r') as fp:
    flickr8k = json.load(fp)
train = [item for item in flickr8k['images'] if item['split'] == 'train']
val = [item for item in flickr8k['images'] if item['split'] == 'val']
test = [item for item in flickr8k['images'] if item['split'] == 'test']
# Directory holding the image files named by item['filename'].
data_path = '../data/flickr8k/Flicker8k_Dataset/'

In [15]:
# One feature file per split, processed in the order train, test, val.
for split_name, split_items in (('train', train), ('test', test), ('val', val)):
    split_files = [data_path + item['filename'] for item in split_items]
    process_features(split_files, split_name + '_flickr8k_' + modelName + '.h5', batchsize=128)

100%|██████████| 47/47 [01:27<00:00,  1.21s/it]
100%|██████████| 8/8 [00:12<00:00,  1.52s/it]
100%|██████████| 8/8 [00:13<00:00,  1.64s/it]


In [16]:
import json
# Attach each split's JSON metadata to its feature file as the 'sents' attribute
# and show the feature array shape as a sanity check.
for split_name, split_items in (('train', train), ('test', test), ('val', val)):
    # Explicit 'r+' (read/write, file must exist): with no mode, h5py >= 3.0
    # opens read-only and the attrs assignment below would fail.
    with h5py.File(split_name + '_flickr8k_' + modelName + '.h5', 'r+') as f:
        print(f['feats_conv'].shape)
        f.attrs['sents'] = unicode(json.dumps(split_items))

(6000, 49, 2048)
(1000, 49, 2048)
(1000, 49, 2048)
