In [54]:
#libraries
import json
from collections import defaultdict
import time
import csv
import os

In [55]:
# Input / output locations.
# The defaults are the original absolute paths; set the GENOME_DATA_DIR and
# GENOME_OUTPUT_TSV environment variables to run the notebook on another
# machine without editing this cell.
# Folder that contains the json input files
BASE_PATH = os.environ.get("GENOME_DATA_DIR", "/nas/home/slnagark/Genome/datasets/")
# Write tsv to this destination
PATH_EXPORT = os.environ.get("GENOME_OUTPUT_TSV", "/nas/home/slnagark/Genome/output.tsv")

def get_region_graphs():
    """Parse region_graphs.json found under BASE_PATH and return its contents."""
    json_path = os.path.join(BASE_PATH, 'region_graphs.json')
    with open(json_path) as handle:
        return json.load(handle)

def get_relationship_synsets():
    """Parse relationship_synsets.json found under BASE_PATH and return its contents."""
    json_path = os.path.join(BASE_PATH, 'relationship_synsets.json')
    with open(json_path) as handle:
        return json.load(handle)

def get_question_answers():
    """Parse question_answers.json found under BASE_PATH and return its contents."""
    json_path = os.path.join(BASE_PATH, 'question_answers.json')
    with open(json_path) as handle:
        return json.load(handle)

def get_mapping_qa_region():
    """Parse qa_to_region_mapping.json found under BASE_PATH and return its contents."""
    json_path = os.path.join(BASE_PATH, 'qa_to_region_mapping.json')
    with open(json_path) as handle:
        return json.load(handle)

def get_json_data():
    """
    Load every json input and print how long the loading took.

    Files read, in this order:
        qa_to_region_mapping.json
        relationship_synsets.json
        question_answers.json
        region_graphs.json

    output: tuple of (qa to region mapping, relationship synsets,
            question answers, region graphs)
    """
    started_at = time.time()

    # The loader order matters only for the order of the returned tuple.
    loaders = (
        get_mapping_qa_region,
        get_relationship_synsets,
        get_question_answers,
        get_region_graphs,
    )
    loaded = tuple(load() for load in loaders)

    print("Successfully loaded data from json files!")
    elapsed_minutes = (time.time() - started_at) / 60.
    print("Time taken: {0:.3f} minutes".format(elapsed_minutes))

    return loaded

def region_to_question_mappings(qa_region_mappings):
    """
    Invert the question -> region mapping.

    input: mapping of question id to region id (ids may be ints or numeric strings)
    output: defaultdict(list) mapping each region id (int) to the list of
            question ids (int) that refer to it, in input order
    """
    questions_by_region = defaultdict(list)

    for question_key in qa_region_mappings:
        region_key = int(qa_region_mappings[question_key])
        questions_by_region[region_key].append(int(question_key))

    return questions_by_region

def _index_region_objects(objects):
    """Map object id -> name and object id -> synsets, skipping objects without synsets."""
    names = {}
    synsets = {}
    for obj in objects:
        ob_id = int(obj["object_id"])
        # Keep the first entry for an id; objects without synsets are unusable downstream.
        if ob_id in names or not obj["synsets"]:
            continue
        names[ob_id] = obj["name"]
        synsets[ob_id] = obj["synsets"]
    return names, synsets


def _resolve_relation_synsets(relation, rel_synsets):
    """Return the synset list for a relation, falling back to the predicate lookup table."""
    own_synsets = relation.get("synsets")
    if own_synsets:
        return list(own_synsets)

    predicate = relation["predicate"]
    # NOTE(review): relationship_synsets.json presumably maps a lower-cased
    # predicate to a synset name (str) - confirm against the data file. Both
    # the lower-cased and the raw predicate are tried, and both str and list
    # values are accepted.
    found = rel_synsets.get(predicate.lower(), rel_synsets.get(predicate))
    if not found:
        return []
    if isinstance(found, str):
        return [found]
    return list(found)


def get_info_from_region_graphs(region_graphs, rel_synsets=None, max_images=10):
    """
    Extract the relationships of every region, grouped by image.

    input:
        region_graphs: data from region_graphs.json (list of per-image dicts)
        rel_synsets: predicate -> synset lookup (data from
                     relationship_synsets.json), used when a relation has no
                     synsets of its own. When omitted, a module-level
                     `rel_syns_` / `rel_syns` is used if one exists (legacy
                     behaviour of this notebook), otherwise an empty table.
        max_images: number of leading entries of region_graphs to process
                    (default 10, the previous hard-coded limit);
                    None processes everything.

    output: dictionary of list of dictionaries of regions
            {
              image_id:
              [
                  {
                    "region_id" : int,
                    "phrase" : str,
                    "relationships" : {
                        rel_id : {
                                    "rel_id" : relationship id,
                                    "rel_name" : relationship name (lower-cased predicate),
                                    "rel_syns" : first relationship synset,
                                    "sub_id" : subject id,
                                    "sub_name" : subject name,
                                    "sub_syns" : first subject synset,
                                    "ob_id" : object id,
                                    "ob_name" : object name,
                                    "ob_syns" : first object synset
                                  },
                        ....
                    }
                  },
                  ....
              ]
            }

    Regions without any usable relationship are left out. A relationship is
    dropped when its subject or object has no name/synset in the region, when
    no synset can be found for it, or when it repeats an earlier
    (relation synsets, subject synsets, object synsets) combination in the
    same region. The input data is not modified.
    """
    if rel_synsets is None:
        # Legacy fallback: earlier versions read the table from module scope.
        module_scope = globals()
        rel_synsets = module_scope.get("rel_syns_") or module_scope.get("rel_syns") or {}

    rgraphs_image_to_relation = {}

    for ind, reg_graphs in enumerate(region_graphs):
        if max_images is not None and ind >= max_images:
            break

        image_id = int(reg_graphs["image_id"])
        region_data = []

        for region in reg_graphs['regions']:
            relationships = region.get("relationships")
            # Skip regions whose relationships field is missing or empty.
            if not relationships:
                continue

            obj_names, obj_synsets = _index_region_objects(region.get("objects", []))

            seen = set()
            rel = {}

            for relation in relationships:
                rel_id = int(relation["relationship_id"])
                obj_id = int(relation["object_id"])
                sub_id = int(relation["subject_id"])

                # Skip the relation if subject or object has no name/synset.
                if sub_id not in obj_names or obj_id not in obj_names:
                    continue

                relation_synsets = _resolve_relation_synsets(relation, rel_synsets)
                if not relation_synsets:
                    continue

                # Handle multiple entries of the same relation.
                key = (
                    tuple(relation_synsets),
                    tuple(obj_synsets[sub_id]),
                    tuple(obj_synsets[obj_id]),
                )
                if key in seen:
                    continue
                seen.add(key)

                rel[rel_id] = {
                    "rel_id" : rel_id,
                    "rel_name" : relation["predicate"].lower(),
                    "rel_syns" : relation_synsets[0],
                    "sub_id" : sub_id,
                    "sub_name" : obj_names[sub_id],
                    "sub_syns" : obj_synsets[sub_id][0],
                    "ob_id" : obj_id,
                    "ob_name" : obj_names[obj_id],
                    "ob_syns" : obj_synsets[obj_id][0]
                }

            if not rel:
                continue

            region_data.append({
                "region_id" : int(region["region_id"]),
                "phrase" : region["phrase"],
                "relationships" : rel
            })

        rgraphs_image_to_relation[image_id] = region_data

    return rgraphs_image_to_relation

def _question_type(question):
    """Classify a question as "location", "time" or "event"; None if it is none of these."""
    words = question.split()
    first_word = words[0].lower() if words else ""

    if first_word == "where":
        return "location"
    if first_word == "when":
        return "time"
    if "doing" in question:
        return "event"
    return None


def get_info_from_q_a(ques_ans, max_images=10):
    """
    reads question_answers.json and extracts question id, question, answer and type of question (location, time, event)

    input:
        ques_ans: data from question_answers.json (list of per-image dicts,
                  each with a "qas" list)
        max_images: number of leading entries of ques_ans to process
                    (default 10, the previous hard-coded limit);
                    None processes everything.
    output: question id to dict(type, question, answer) mappings

    A question is typed "location" when its first word is "where", "time"
    when its first word is "when" (both case-insensitive) and "event" when it
    contains "doing". Questions matching none of these are left out.
    """
    q_id2params = {}

    for ind, qa in enumerate(ques_ans):
        if max_images is not None and ind >= max_images:
            break

        for qas in qa["qas"]:
            question = qas["question"]
            q_type = _question_type(question)
            if q_type is None:
                continue

            q_id2params[int(qas["qa_id"])] = {
                "type" : q_type,
                "question" : question,
                "answer" : qas["answer"]
            }

    return q_id2params
    
def write_to_tsv(q_id2params, rgraphs_image_to_relation, r_id2qid=None, path_export=None):
    """
    generates tsv file with one row per relationship.

    input:
        q_id2params: question id -> {"type", "question", "answer"} mappings
        rgraphs_image_to_relation: image id -> list of region dicts, each with
                                   "region_id", "phrase" and "relationships"
        r_id2qid: region id -> list of question ids. When omitted, the
                  module-level `r_id2qid` is used (legacy behaviour of this
                  notebook).
        path_export: destination file; defaults to PATH_EXPORT.

    Every image present in rgraphs_image_to_relation is written, in ascending
    image id order. The location / time / event columns hold the list of
    {"question", "answer"} dicts attached to the row's region.

    raises: NameError if r_id2qid is neither passed nor defined at module level.
    """
    if r_id2qid is None:
        r_id2qid = globals().get("r_id2qid")
        if r_id2qid is None:
            raise NameError("r_id2qid was not passed and is not defined at module level")
    if path_export is None:
        path_export = PATH_EXPORT

    qa_types = ("location", "time", "event")

    # newline='' lets the csv module control line endings itself.
    with open(path_export, 'wt', newline='') as out_file:
        tsv_writer = csv.writer(out_file, delimiter='\t')
        tsv_writer.writerow(['image_id', 'region_id','rel_id' ,'relationship', 'subject',
                             'object','subject_synset','object_synset', 'relationship_synset', 'sentence',
                             'location', 'time', 'event'])

        for image_id in sorted(rgraphs_image_to_relation):
            for reg in rgraphs_image_to_relation[image_id]:
                reg_id = reg["region_id"]

                # Collect the region's questions and answers, bucketed by type.
                qa_by_type = {qa_type: [] for qa_type in qa_types}
                for q_id in r_id2qid.get(reg_id, []):
                    params = q_id2params.get(q_id)
                    if params is None or params["type"] not in qa_by_type:
                        continue
                    qa_by_type[params["type"]].append(
                        {"question": params["question"], "answer": params["answer"]}
                    )

                # One output row per relationship of the region.
                for rel in reg["relationships"].values():
                    tsv_writer.writerow([
                        image_id, reg_id, rel["rel_id"], rel["rel_name"], rel["sub_name"], rel["ob_name"],
                        rel["sub_syns"], rel["ob_syns"], rel["rel_syns"], reg["phrase"],
                        qa_by_type["location"], qa_by_type["time"], qa_by_type["event"]
                    ])

    print("Successfully generated {}".format(os.path.basename(path_export)))

def main():
    """
    Run the whole pipeline: load the json inputs, build the lookup tables and
    write the tsv file.
    """
    # get_info_from_region_graphs and write_to_tsv are called without these
    # tables and look them up in module scope, so they must be bound as
    # globals here; binding them as locals makes main() fail with NameError
    # on a fresh kernel.
    global rel_syns_, r_id2qid

    qa_region_mappings, rel_syns_, ques_ans, region_graphs = get_json_data()
    r_id2qid = region_to_question_mappings(qa_region_mappings)
    rgraphs_image_to_relation = get_info_from_region_graphs(region_graphs)
    q_id2params = get_info_from_q_a(ques_ans)
    write_to_tsv(q_id2params, rgraphs_image_to_relation)


In [48]:
# if __name__ == "__main__":
#     main()

In [42]:
# Load all four json inputs; the recorded run below took about two minutes.
qa_region_mappings, rel_syns, ques_ans, region_graphs = get_json_data()

Successfully loaded data from json files!
Time taken: 2.057 minutes


In [56]:
# Invert the question id -> region id mapping so questions can be looked up per region.
r_id2qid = region_to_question_mappings(qa_region_mappings)

In [57]:
# Extract per-image region relationships; only the first 10 entries of region_graphs are processed.
rgraphs_image_to_relation = get_info_from_region_graphs(region_graphs)

In [58]:
# Classify questions as location / time / event; only the first 10 entries of ques_ans are processed.
q_id2params = get_info_from_q_a(ques_ans)

In [59]:
# Write one tsv row per relationship to PATH_EXPORT.
write_to_tsv(q_id2params, rgraphs_image_to_relation)

Successfully generated output.tsv
