In [None]:
import io
import h5py
import json
import torch
from matplotlib import pyplot
import PIL.Image
import numpy as np

from collections import defaultdict as ddict


%matplotlib inline

In [34]:
data_root = '/data/milatmp1/bahdanau/flatqa-letters/long-tail/'
model_output_path = '/u/murtyjay/nmn-iwp/outputs/output_mac4.h5'

def parse_dat(data_root, part):
  features_path = data_root + part + 'features.h5'
  with h5py.File(features_path) as src:
    features = src['features'][:]
    
  data_path = data_root + part + 'questions.h5'
  with h5py.File(data_path) as src:
    questions = src['questions'][:]
    answers = src['answers'][:]
    image_idxs = src['image_idxs'][:]
  
  with open(data_root + 'vocab.json') as src:
    vocab = json.load(src)
    question_idx_to_token = {v: k for k, v in vocab['question_token_to_idx'].items()}
    answer_idx_to_token = {v: k for k, v in vocab['answer_token_to_idx'].items()}  
  return features, questions, answers, image_idxs, vocab, question_idx_to_token, answer_idx_to_token

def get_question_distribution(train_questions):
  relation_count = ddict(int)
  object_count = ddict(int)
  object_pairwise_count = ddict(int)

  for question in train_questions:
    relation_count[question[5]] += 1.0
    object_count[(question[3], question[4])] += 1.0
    object_count[(question[6], question[7])] += 1.0
    object_pairwise_count[(question[3], question[4], question[6], question[7])] += 1.0

  return relation_count, object_count, object_pairwise_count

train_features, train_questions, train_answers, train_image_idxs, vocab, question_idx_to_token, answer_idx_to_token = parse_dat(data_root, 'train_')
relation_count, object_count, object_pairwise_count = get_question_distribution(train_questions)

In [35]:
def pprint_question(questions, vocab):
  for q in questions:
    print(' '.join([vocab[idx] for idx in q]))

print(pprint_question(train_questions[:10], question_idx_to_token))
print(relation_count)





def print_image(img):
  image = np.array(PIL.Image.open(io.BytesIO(features[img]) ))
  pyplot.figure(figsize=(5, 5))
  pyplot.imshow(image, origin='lower')
  pyplot.show()

    
    
features, questions, answers, image_idxs, _, _, _ = parse_dat(data_root, part)


  


def get_confusion_matrix(model_path, data_root):
  with h5py.File(model_path) as src:
    correct = src['correct'].value
    
  features, questions, answers, image_idxs, vocab, question_idx_to_token, answer_idx_to_token = parse_dat(data_root,'val_')
  confusion = {'TP' : 0.0, 'FP' : 0.0, 'FN' : 0.0, 'TN' : 0.0}
  acc = 0.0
  for i in range(1000):

    if answer_idx_to_token[answers[i]] == 'false':
      confusion['TN'] += 1.0
      if not correct[i]: confusion['FP'] += 1.0 
          
    else:
      confusion['TP'] += 1.0
      if not correct[i]: 
        confusion['FN'] += 1.0
        #print_image(image_idxs[i])
        #question_tokens = ' '.join([question_idx_to_token[idx] for idx in questions[i]])
        #print(question_tokens,  answer_idx_to_token[answers[i]])
  
        
    if correct[i]: 
      acc += 1.0
      continue

  return confusion


def get_pr(confusion_matrix):
    p = confusion_matrix['TP'] / (confusion_matrix['TP'] + confusion_matrix['FP'])
    r = confusion_matrix['TP'] / (confusion_matrix['TP'] + confusion_matrix['FN'])
    return p, r

is there a blue D right_of blue A
is there a blue D below purple A
is there a yellow A left_of purple C
is there a gray A left_of purple B
is there a purple N below yellow C
is there a brown C right_of purple A
is there a green E below purple D
is there a purple D right_of yellow F
is there a yellow D below gray C
is there a red F left_of yellow D
None
defaultdict(<class 'int'>, {40: 250051.0, 41: 250109.0, 42: 249960.0, 43: 249880.0})


In [25]:
# CONFUSION MATRIX ANALYSIS

model_output_path = '/u/murtyjay/nmn-iwp/outputs/output_mac'

for i in range(1, 6):
    print(i)
    file="%s%s.h5" %(model_output_path, i)
    print(file)
    cm = get_confusion_matrix(file , data_root )
    print(get_pr(cm))


1
/u/murtyjay/nmn-iwp/outputs/output_mac1.h5
(0.7668711656441718, 0.9823182711198428)
2
/u/murtyjay/nmn-iwp/outputs/output_mac2.h5
(0.7874015748031497, 0.9689922480620154)
3
/u/murtyjay/nmn-iwp/outputs/output_mac3.h5
(0.78125, 0.9920634920634921)
4
/u/murtyjay/nmn-iwp/outputs/output_mac4.h5
(0.8561643835616438, 0.9727626459143969)
5
/u/murtyjay/nmn-iwp/outputs/output_mac5.h5
(0.7923930269413629, 0.9523809523809523)


In [None]:
# OOV analysis 

model_output_path = '/u/murtyjay/nmn-iwp/outputs/output_film'
def get_object_counts(model_path, data_root):
  with h5py.File(model_path) as src:
    correct = src['correct'].value
    
  features, questions, answers, image_idxs, vocab, question_idx_to_token, answer_idx_to_token = parse_dat(data_root,'val_')
  object_pair_counts = []
  inv_object_pair_counts = []
  for i in range(1000): 
    question = questions[i]
    if not correct[i]: 
      lobject = (question[3], question[4])
      robject = (question[6], question[7])
      object_pair = (*lobject, *robject)
      inv_object_pair = (*robject, *lobject)
      object_pair_counts.append(object_pairwise_count[object_pair])
      inv_object_pair_counts.append(object_pairwise_count[inv_object_pair])
    

  return object_pair_counts,inv_object_pair_counts   

from collections import Counter

for i in range(1, 6):
    print(i)
    file="%s%s.h5" %(model_output_path, i)
    oc, inv_oc = get_object_counts(file , data_root )
    #oc = Counter(oc)
    #inv_oc = Counter(inv_oc)
    
    #bins, items = (list(zip(*sorted(oc.items()))))
    #print(bins, items)
    #pyplot.hist(items, [int(bin) for bin in bins])
    print(oc)
    print(inv_oc)
    #pyplot.show()
    
