-
Notifications
You must be signed in to change notification settings - Fork 1
/
save_dic.py
49 lines (39 loc) · 1.33 KB
/
save_dic.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
from os import listdir
import json
import cPickle
def rename_train(s):
    """Map an image-id string to its MSCOCO train2014 JPEG file name.

    The id is left-padded with '0' to 12 characters; ids that already have
    12 or more characters are used unchanged.
    e.g. '9' -> 'COCO_train2014_000000000009.jpg'.
    """
    zeros = '0' * (12 - len(s))  # a negative count gives '' (no padding)
    return ''.join(('COCO_train2014_', zeros, s, '.jpg'))
def rename_val(s):
    """Map an image-id string to its MSCOCO val2014 JPEG file name.

    The id is zero-padded on the left to 12 characters (longer ids are kept
    as-is). e.g. '42' -> 'COCO_val2014_000000000042.jpg'.
    """
    prefix, suffix = 'COCO_val2014_', '.jpg'
    pad_len = 12 - len(s)
    return prefix + '0' * pad_len + s + suffix
# Build capdict.pkl: a dict mapping each MSCOCO 2014 image file name
# (train and val splits merged into one dict) to the list of its captions.

# Cap on captions kept per image; None keeps every caption. Set to e.g. 5
# to keep only the first five captions encountered for each image.
MAX_CAPTIONS = None


def _add_captions(annotation_path, rename, dic):
    """Append the captions from one caption-annotation JSON file to *dic*.

    Args:
        annotation_path: path to a JSON file whose top-level 'annotations'
            list holds dicts with 'image_id' and 'caption' keys.
        rename: callable mapping str(image_id) to the image file name used
            as the key in *dic*.
        dic: file name -> list of caption strings; updated in place.

    Returns:
        The number of captions appended to *dic*.
    """
    with open(annotation_path) as fp:
        data = json.load(fp)
    added = 0
    for ann in data['annotations']:
        captions = dic.setdefault(rename(str(ann['image_id'])), [])
        if MAX_CAPTIONS is not None and len(captions) >= MAX_CAPTIONS:
            continue
        # Newlines are deleted (not replaced by a space) and double quotes
        # are stripped.
        # NOTE(review): deleting '\n' can fuse the words on either side of
        # it; confirm this is intended.
        captions.append(ann['caption'].replace('\n', '').replace('"', ''))
        added += 1
    return added


dic = {}
cnt = 0
for split_path, split_rename in (
        ('/tmp3/alvin/dataset/MSCOCO2014/captions_train2014.json',
         rename_train),
        ('/tmp3/alvin/dataset/MSCOCO2014/captions_val2014.json',
         rename_val),
):
    cnt += _add_captions(split_path, split_rename, dic)
    # Running total of stored captions after each split (train, then
    # train+val). print(x) with one argument behaves the same on Py2 and Py3.
    print(cnt)

with open('capdict.pkl', 'wb') as fp:
    cPickle.dump(dic, fp, protocol=cPickle.HIGHEST_PROTOCOL)