Skip to content

Commit

Permalink
lut & update vqa
Browse files Browse the repository at this point in the history
  • Loading branch information
ppwwyyxx committed Apr 22, 2016
1 parent b315a1a commit d5fe531
Show file tree
Hide file tree
Showing 6 changed files with 37 additions and 17 deletions.
1 change: 1 addition & 0 deletions scripts/dump_train_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@
if idx > NR_DP_TEST:
break
pbar.update()
from IPython import embed; embed()



2 changes: 1 addition & 1 deletion tensorpack/dataflow/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ def __init__(self, ds, batch_size, remainder=False):
"""
Group data in `ds` into batches.
:param ds: a DataFlow instance
:param ds: a DataFlow instance. Its component must be either a scalar or a numpy array
:param remainder: whether to return the remaining data smaller than a batch_size.
If set True, will possibly return a data point of a smaller 1st dimension.
Otherwise, all generated data are guaranteed to have the same size.
Expand Down
33 changes: 19 additions & 14 deletions tensorpack/dataflow/dataset/visualqa.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>

from ..base import DataFlow
from ...utils import *
from six.moves import zip, map
from collections import Counter
import json
Expand All @@ -17,16 +18,17 @@ class VisualQA(DataFlow):
Simply read q/a json file and produce q/a pairs in their original format.
"""
def __init__(self, question_file, annotation_file):
    """ Read questions and annotations from the VQA JSON files.

    :param question_file: path to the VQA question JSON file.
    :param annotation_file: path to the VQA annotation JSON file.
    """
    with timed_operation('Reading VQA JSON file'):
        # use context managers so the file handles are closed promptly
        # (bare json.load(open(...)) leaks the descriptor until GC)
        with open(question_file) as f:
            qobj = json.load(f)
        self.task_type = qobj['task_type']
        self.questions = qobj['questions']
        self._size = len(self.questions)

        with open(annotation_file) as f:
            aobj = json.load(f)
        self.anno = aobj['annotations']
        # one annotation per question is required so they can be paired up later
        assert len(self.anno) == len(self.questions), \
            "{}!={}".format(len(self.anno), len(self.questions))
        self._clean()

def _clean(self):
for a in self.anno:
Expand All @@ -42,15 +44,17 @@ def get_data(self):
yield [q, a]

def get_common_answer(self, n):
    """ Get the n most common answers (could be phrases).

    NOTE(review): per the original comment, n=3000 empirically corresponds
    to a count threshold of roughly 4 -- confirm against the dataset.

    :param n: how many of the most common answers to return.
    :returns: a list of answer strings, most common first.
    """
    # Counter can consume the generator directly; no manual counting loop needed
    cnt = Counter(anno['multiple_choice_answer'] for anno in self.anno)
    return [k[0] for k in cnt.most_common(n)]

def get_common_question_words(self, n):
"""
Get the n most common words in questions
""" Get the n most common words in questions
n=4600 ~= thresh 6
"""
from nltk.tokenize import word_tokenize # will need to download 'punckt'
cnt = Counter()
Expand All @@ -64,7 +68,8 @@ def get_common_question_words(self, n):
vqa = VisualQA('/home/wyx/data/VQA/MultipleChoice_mscoco_train2014_questions.json',
'/home/wyx/data/VQA/mscoco_train2014_annotations.json')
for k in vqa.get_data():
#print json.dumps(k)
print json.dumps(k)
break
vqa.get_common_question_words(100)
# vqa.get_common_question_words(100)
vqa.get_common_answer(100)
#from IPython import embed; embed()
1 change: 1 addition & 0 deletions tensorpack/dataflow/prefetch.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from six.moves import range
from .base import ProxyDataFlow
from ..utils.concurrency import ensure_procs_terminate
from ..utils import logger

__all__ = ['PrefetchData']

Expand Down
13 changes: 13 additions & 0 deletions tensorpack/utils/lut.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# File: lut.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>

import six

__all__ = ['LookUpTable']

class LookUpTable(object):
    """ Bidirectional lookup between objects and their integer indices.

    # idx2obj: dict mapping index -> object (indices follow objlist order)
    # obj2idx: dict mapping object -> index (a duplicate object keeps its last index)
    """
    def __init__(self, objlist):
        """:param objlist: an iterable of hashable objects; position defines the index."""
        self.idx2obj = dict(enumerate(objlist))
        # plain .items() works on both py2.7 and py3 -- the third-party
        # six.iteritems dependency is unnecessary here
        self.obj2idx = {v: k for k, v in self.idx2obj.items()}
4 changes: 2 additions & 2 deletions tensorpack/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@

from . import logger

__all__ = ['timed_operation', 'change_env', 'get_rng', 'memoized',
'get_nr_gpu']
__all__ = ['timed_operation', 'change_env',
'get_rng', 'memoized', 'get_nr_gpu']

#def expand_dim_if_necessary(var, dp):
# """
Expand Down

0 comments on commit d5fe531

Please sign in to comment.