Skip to content

Commit

Permalink
update resnet readme, add basic vqa
Browse files Browse the repository at this point in the history
  • Loading branch information
ppwwyyxx committed Apr 22, 2016
1 parent f67dc2a commit b315a1a
Show file tree
Hide file tree
Showing 4 changed files with 75 additions and 0 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -59,3 +59,4 @@ docs/_build/

# PyBuilder
target/
*.dat
3 changes: 3 additions & 0 deletions examples/ResNet/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,7 @@
Implements the paper "Deep Residual Learning for Image Recognition", [http://arxiv.org/abs/1512.03385](http://arxiv.org/abs/1512.03385)
with the variants proposed in "Identity Mappings in Deep Residual Networks", [https://arxiv.org/abs/1603.05027](https://arxiv.org/abs/1603.05027).

The train error shown here is a moving average of the error rate of each batch in training.
The validation error here is computed on the test set.

![cifar10](https://github.com/ppwwyyxx/tensorpack/raw/master/examples/ResNet/cifar10-resnet.png)
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,4 @@ pillow
scipy
tqdm
h5py
nltk
70 changes: 70 additions & 0 deletions tensorpack/dataflow/dataset/visualqa.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# File: visualqa.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>

from ..base import DataFlow
from six.moves import zip, map
from collections import Counter
import json

__all__ = ['VisualQA']

# TODO shuffle
class VisualQA(DataFlow):
    """
    Visual QA dataset. See http://visualqa.org/
    Simply read q/a json file and produce q/a pairs in their original format.
    """
    def __init__(self, question_file, annotation_file):
        """
        :param question_file: path to the VQA questions json file.
        :param annotation_file: path to the VQA annotations json file.
        """
        # use context managers so file handles are closed deterministically
        with open(question_file) as f:
            qobj = json.load(f)
        self.task_type = qobj['task_type']
        self.questions = qobj['questions']
        self._size = len(self.questions)

        with open(annotation_file) as f:
            aobj = json.load(f)
        self.anno = aobj['annotations']
        # questions and annotations must pair up one-to-one
        assert len(self.anno) == len(self.questions), \
            "{}!={}".format(len(self.anno), len(self.questions))
        self._clean()

    def _clean(self):
        # strip the redundant 'answer_id' field from every answer dict
        for a in self.anno:
            for aa in a['answers']:
                del aa['answer_id']

    def size(self):
        """ Number of q/a pairs in the dataset. """
        return self._size

    def get_data(self):
        """ Yield ``[question_dict, annotation_dict]`` pairs, aligned by question_id. """
        for q, a in zip(self.questions, self.anno):
            assert q['question_id'] == a['question_id']
            yield [q, a]

    def get_common_answer(self, n):
        """ Get the n most common answers (could be phrases) """
        # Counter accepts an iterable directly; no need for a manual count loop
        cnt = Counter(a['multiple_choice_answer'] for a in self.anno)
        return [k[0] for k in cnt.most_common(n)]

    def get_common_question_words(self, n):
        """
        Get the n most common words in questions
        """
        from nltk.tokenize import word_tokenize  # will need to download 'punkt'
        cnt = Counter()
        for q in self.questions:
            cnt.update(word_tokenize(q['question'].lower()))
        del cnt['?']  # the question mark is punctuation, not a word
        return [k[0] for k in cnt.most_common(n)]

if __name__ == '__main__':
    # quick smoke test against a local copy of the VQA training split
    ds = VisualQA('/home/wyx/data/VQA/MultipleChoice_mscoco_train2014_questions.json',
                  '/home/wyx/data/VQA/mscoco_train2014_annotations.json')
    # pull a single q/a pair to exercise the generator (and its alignment assert)
    first_pair = next(ds.get_data())
    # exercise the nltk-based word statistics as well; result is discarded
    ds.get_common_question_words(100)

0 comments on commit b315a1a

Please sign in to comment.