# Data Inspections Demo

In [None]:
import csv
import os
import sys
import re
import string
import json
import xml.etree.ElementTree as ET
from IPython.display import Image
import random
import numpy as np
import glob
import shutil
from termcolor import colored


from IPython.display import HTML as html_print
from IPython.display import Markdown

from transformers import AutoTokenizer

from nltk.tokenize import word_tokenize
from nltk.tokenize import sent_tokenize

import cv2
import pandas as pd
from textwrap import wrap

import warnings
warnings.filterwarnings('ignore')

import seaborn as sns
sns.set()

from matplotlib import pyplot as plt
import matplotlib.patches as patches
%pylab inline

# for auto-reloading external modules
# see http://stackoverflow.com/questions/1907993/autoreload-of-modules-in-ipython
%load_ext autoreload
%autoreload 2

## Some Useful Functions

In [None]:
def show_image_with_path(img_path, title=None, title_max_len=None, img_size=6):
    """Display the image stored at ``img_path`` in a new square figure.

    Args:
        img_path: Path of the image file, read with ``cv2.imread``.
        title: Optional caption drawn above the image. It is wrapped to
            60 characters per line.
        title_max_len: If given, captions longer than this many characters
            are cut to that length and suffixed with ``"... ..."``.
            Captions that already fit are shown unchanged.
        img_size: Width and height of the matplotlib figure, in inches.

    Raises:
        FileNotFoundError: If ``cv2.imread`` could not load ``img_path``.
    """
    img = cv2.imread(img_path)
    if img is None:
        # cv2.imread reports failure by returning None rather than raising;
        # fail here with the offending path instead of on the slice below.
        raise FileNotFoundError("Could not read image: {}".format(img_path))

    plt.figure(figsize=(img_size, img_size))
    # Reverse the channel axis (OpenCV loads BGR, matplotlib expects RGB).
    plt.imshow(img[:, :, ::-1])
    plt.axis('off')

    if title is not None:
        # Only mark the caption as truncated when something was cut off.
        if title_max_len is not None and len(title) > title_max_len:
            title = title[:title_max_len] + "... ..."
        title = "\n".join(wrap(title, 60))
        plt.title(title)


def show_one_sampled_data(data, scrambled=False, resized_version=False,
                          show_label=False, unimodal=None, img_size=6,
                          title_max_len=None, show_url=False, id_mappings=None,
                          step_id=None, order=None):
    """Plot the steps (image + text) of one sequence example.

    Args:
        data: Example object exposing ``text_seq`` and ``img_path_seq``
            (indexed in parallel, one entry per step) and, when
            ``show_url`` is set, a ``guid`` string.
        scrambled: Shuffle the display order with ``np.random.shuffle`` and
            return the corresponding label (see Returns).
        resized_version: When False, ``"jpg_resized_256"`` in every image
            path is replaced by ``"jpg"`` before loading.
        show_label: Prefix each caption with its 1-based step number and,
            when ``scrambled``, print the label.
        unimodal: ``None`` for image + text, ``"image"`` for images without
            captions, ``"text"`` for captions over a blank canvas.
        img_size: Figure width/height in inches (used in every mode).
        title_max_len: Forwarded to ``show_image_with_path``; not applied
            in ``unimodal="text"`` mode, where the full text is shown.
        show_url: Print the part of ``data.guid`` before ``"###"``.
        id_mappings: Not used by this function; kept so existing calls
            that pass it continue to work.
        step_id: 1-based step number; when given, only that step is drawn.
        order: Explicit display order. Values are shifted so the smallest
            becomes index 0. Overrides the shuffled display order, but not
            the label returned for ``scrambled=True``.

    Returns:
        If ``scrambled``: a list where entry ``k`` is the 1-based position
        of original step ``k`` in the shuffled order. Otherwise
        ``list(range(len(data.text_seq)))``.
    """
    if step_id is not None:
        # Callers count steps from 1; indices below are 0-based.
        step_id -= 1

    if unimodal is not None:
        assert unimodal in ["image", "text"]

    text_seq = data.text_seq[:]
    img_path_seq = data.img_path_seq[:]
    idx_seq = np.arange(len(text_seq))

    if show_url:
        # split() always yields at least one element, so this is the whole
        # guid when it contains no "###" separator.
        print(data.guid.split("###")[0])

    label = None
    if scrambled:
        np.random.shuffle(idx_seq)
        arg_sort_idx_seq = np.argsort(idx_seq)
        label = list(arg_sort_idx_seq + 1)
        if show_label:
            print("Label: {}".format(arg_sort_idx_seq + 1))

    if order is not None:
        # Compute the offset once instead of once per element.
        offset = min(order) if len(order) > 0 else 0
        idx_seq = [x - offset for x in order]

    # Canvas used in text-only mode when the step image cannot be read.
    fallback_shape = (256, 256, 3)

    for seq_idx in idx_seq:
        if step_id is not None and seq_idx != step_id:
            continue

        prefix = str(seq_idx + 1) + ". " if show_label else ""
        text = prefix + text_seq[seq_idx]

        img_path = img_path_seq[seq_idx]
        if not resized_version:
            img_path = img_path.replace("jpg_resized_256", "jpg")

        if unimodal == "text":
            # The image only determines the size of the blank canvas, so a
            # missing or unreadable file must not abort a text-only view.
            img = cv2.imread(img_path)
            canvas_shape = img.shape if img is not None else fallback_shape
            plt.figure(figsize=(img_size, img_size))
            plt.imshow(np.zeros(canvas_shape))
            plt.axis('off')
            plt.title("\n".join(wrap(text, 60)))
        else:
            # Image-only mode drops the caption; default mode keeps it.
            caption = None if unimodal == "image" else text
            show_image_with_path(img_path, title=caption,
                                 title_max_len=title_max_len,
                                 img_size=img_size)

    if scrambled:
        return label

    return list(range(len(text_seq)))

## WikiHow Data Demo

### Read In WikiHow Data

In [None]:
# Project-local loader for the WikiHow data.
from datasets.wikihow import WikiHowGeneralProcessor

# Value handed to the processor's `version_text` argument.
# NOTE(review): presumably selects the dataset release; confirm in the processor.
version_text_to_use = "acl22"

# Load each split separately, then concatenate them for inspection.
wikihow_processor = WikiHowGeneralProcessor(version_text=version_text_to_use)
data_wikihow_train = wikihow_processor.get_train_examples()
data_wikihow_dev = wikihow_processor.get_dev_examples()
data_wikihow_test = wikihow_processor.get_test_examples()
data_wikihow = data_wikihow_train + data_wikihow_dev + data_wikihow_test

# Size of the combined train + dev + test collection.
print("Total Valid Data Sequences: {}".format(len(data_wikihow)))

### WikiHow Category Information (Can Skip)

In [None]:
# Category titles flagged for exclusion. Nothing in this cell reads the list;
# "Root" is also the fallback title used by read_in_wikihow_categories.
categories_to_exclude = [
    "Youth",
    "Relationships",
    "Family Life",
    "Holidays and Traditions",
    "Personal Care and Style",
    "Philosophy and Religion",
    "Screenplays",
    "Health",
    "Work World",
    "Root",
]

# URL to data mappings
def get_url_data_mappings(data):
    """Map each example's URL to its position in ``data``.

    The URL is the part of ``guid`` before the first ``"###"``. When several
    examples share a URL, the last position wins.
    """
    return {
        example.guid.split("###")[0]: position
        for position, example in enumerate(data)
    }

# Build the URL -> index lookup over the combined WikiHow data.
url_data_mappings = get_url_data_mappings(data_wikihow)

# Get category mappings
def read_in_wikihow_categories(url_data_mappings, cat_path=None, cat_level=1,
                               restrict_to_data=False):
    """Read the WikiHow category file (one JSON object per line).

    Each record must provide ``"url"`` and ``"categories"``, a list of dicts
    carrying a ``"category title"`` key.

    Args:
        url_data_mappings: Container of known URLs (e.g. the dict returned
            by ``get_url_data_mappings``). Only consulted when
            ``restrict_to_data`` is True.
        cat_path: Path of the category file. Defaults to
            ``data/wikihow/wikihow-categories-output.json``.
        cat_level: Index into each record's ``categories`` list. Records
            with fewer levels are filed under ``"Root"``.
        restrict_to_data: When True, records whose URL is not in
            ``url_data_mappings`` are skipped. Defaults to False, which
            keeps every record, as this function always has.

    Returns:
        ``(url2cat, cat2url)``: URL -> category title, and category title ->
        list of URLs in file order.
    """
    json_f = "data/wikihow/wikihow-categories-output.json" if cat_path is None else cat_path

    url2cat = {}
    cat2url = {}
    # The context manager closes the file even if a record fails to parse.
    with open(json_f, "r", encoding="utf-8") as json_in:
        for line in json_in:
            line = line.strip()
            if not line:
                # Blank lines would make json.loads raise.
                continue
            cat = json.loads(line)
            url = cat["url"]
            categories = cat["categories"]
            # NOTE(review): the original check here was a no-op
            # (`if url not in url_data_mappings: pass`). Filtering is
            # opt-in so existing results do not change silently.
            if restrict_to_data and url not in url_data_mappings:
                continue
            if len(categories) - 1 >= cat_level:
                cat_level_desc = categories[cat_level]["category title"]
            else:
                cat_level_desc = "Root"
            url2cat[url] = cat_level_desc
            cat2url.setdefault(cat_level_desc, []).append(url)
    return url2cat, cat2url

# Index into each record's "categories" list used for grouping.
cat_level = 1
url2cat, cat2url = read_in_wikihow_categories(url_data_mappings, cat_level=cat_level)

# Report per-category counts in alphabetical order; `total` sums them.
total = 0
for cat in sorted(cat2url):
    num_in_cat = len(cat2url[cat])
    total += num_in_cat
    print("Category: {}  Num of Data: {}".format(cat, num_in_cat))
print("Number of categories: {}".format(len(cat2url)))

### Show One Sample

In [None]:
# Draw one WikiHow sequence at random and show its steps in original order.
rand_idx = np.random.randint(0, len(data_wikihow))

show_one_sampled_data(
    data_wikihow[rand_idx],
    scrambled=False,
    title_max_len=200,
    show_url=True,
    img_size=4,
)

## RecipeQA Data Demo

### Read In RecipeQA Data

In [None]:
# Project-local loader for the RecipeQA data.
from datasets.recipeqa import RecipeQAGeneralProcessor

# Value handed to the processor's `version_text` argument.
# NOTE(review): presumably selects the dataset release; confirm in the processor.
version_text_to_use = "acl22"

# Load each split separately, then concatenate them for inspection.
recipeqa_processor = RecipeQAGeneralProcessor(version_text=version_text_to_use)
data_recipeqa_train = recipeqa_processor.get_train_examples()
data_recipeqa_dev = recipeqa_processor.get_dev_examples()
data_recipeqa_test = recipeqa_processor.get_test_examples()
data_recipeqa = data_recipeqa_train + data_recipeqa_dev + data_recipeqa_test

# Size of the combined train + dev + test collection.
print("Total Valid Data Sequences: {}".format(len(data_recipeqa)))

### Show One Sample

In [None]:
# Draw one RecipeQA sequence at random and show its steps in original order.
rand_idx = np.random.randint(len(data_recipeqa))

show_one_sampled_data(
    data_recipeqa[rand_idx],
    scrambled=False,
    title_max_len=200,
    show_url=True,
    img_size=4,
)

## End