In [1]:
import pandas as pd
import os

# Feature flag: when True, the chart type is prepended to every target_text
# and the exported CSV file name gets a "_with_type" suffix.
INCLUDE_CHART_TYPE_IN_TEXT = False

In [2]:
import numpy as np
import random
import os

def seed_everything(seed=3407):
    """Seed the global random number generators used by this notebook.

    Seeds Python's ``random`` module and NumPy's legacy global generator, and
    exports ``PYTHONHASHSEED`` into the process environment.

    Args:
        seed: Integer seed applied to every generator (default 3407).

    Note:
        Python reads ``PYTHONHASHSEED`` at interpreter start-up, so exporting
        it here influences child processes rather than the running kernel.
    """
    os.environ['PYTHONHASHSEED'] = str(seed)
    for seeder in (random.seed, np.random.seed):
        seeder(seed)


# Seed once, up front, before any random draw happens in later cells.
seed_everything()

In [3]:
# Start from an empty output directory: remove any previous copy, then recreate it.
# NOTE: `rm -rf` is destructive, and the absolute path is specific to one machine.
!rm -rf /home/pawel/Projects/benetech-making-graphs-accessible/data/train_with_extra_data
!mkdir /home/pawel/Projects/benetech-making-graphs-accessible/data/train_with_extra_data

In [4]:
# Load the extra-data metadata, then keep only rows whose "validation" value
# is not 1 and whose "count" is at most 20.
extra_data_df = pd.read_csv("/home/pawel/Projects/benetech-making-graphs-accessible/data/extra_data/metadata.csv")
keep_rows = extra_data_df["validation"].ne(1) & extra_data_df["count"].le(20)
extra_data_df = extra_data_df[keep_rows]

In [5]:
# Number of rows per chart type in the filtered metadata.
extra_data_df["chart_type"].value_counts()

chart_type
dot               100000
horizontal_bar    100000
line              100000
vertical_bar      100000
scatter            21561
Name: count, dtype: int64

In [6]:
def _sample_chart_type(chart_type, n_rows):
    """Return ``n_rows`` randomly drawn rows of one chart type, re-indexed from 0."""
    matching = extra_data_df[extra_data_df["chart_type"] == chart_type]
    return matching.sample(n_rows).reset_index(drop=True)


# The draw order (dot, horizontal_bar, scatter, vertical_bar) is kept because
# no random_state is passed, so each draw advances the shared random state.
# "line" charts are not sampled here.
dot_df = _sample_chart_type("dot", 13000)
horizontal_df = _sample_chart_type("horizontal_bar", 19000)
scatter_df = _sample_chart_type("scatter", 10000)
vertical_df = _sample_chart_type("vertical_bar", 2000)


In [7]:
# Build image_path from file_name by replacing the chart-type folder token
# (e.g. "graphs_d") with its absolute location under extra_data/.
extra_root = "/home/pawel/Projects/benetech-making-graphs-accessible/data/extra_data/"

for frame, folder in (
    (dot_df, "graphs_d"),
    (horizontal_df, "graphs_h"),
    (scatter_df, "graphs_s"),
    (vertical_df, "graphs_v"),
):
    frame["image_path"] = frame["file_name"].str.replace(folder, extra_root + folder, regex=False)

In [8]:
from shutil import copy

# Copy every sampled image into train_with_extra_data/, turning the per-type
# folder into a file-name prefix (".../graphs_d/<name>" -> ".../d_<name>", etc.).
extra_root = "/home/pawel/Projects/benetech-making-graphs-accessible/data/extra_data/"
train_root = "/home/pawel/Projects/benetech-making-graphs-accessible/data/train_with_extra_data/"

for frame, letter in ((dot_df, "d"), (horizontal_df, "h"), (scatter_df, "s"), (vertical_df, "v")):
    source_prefix = f"{extra_root}graphs_{letter}/"
    target_prefix = f"{train_root}{letter}_"
    for image_path in frame["image_path"].values:
        copy(image_path, image_path.replace(source_prefix, target_prefix))

In [9]:
# Rewrite image_path with the same source -> destination prefix mapping that
# the copy step uses, so each path refers to the file in train_with_extra_data/.
#
# regex=False is passed explicitly: the prefixes are literal paths, and the
# default of `regex` differs between pandas versions (True before 2.0). This
# also matches the cell that first builds image_path.
extra_root = "/home/pawel/Projects/benetech-making-graphs-accessible/data/extra_data/"
train_root = "/home/pawel/Projects/benetech-making-graphs-accessible/data/train_with_extra_data/"

for frame, letter in ((dot_df, "d"), (horizontal_df, "h"), (scatter_df, "s"), (vertical_df, "v")):
    frame["image_path"] = frame["image_path"].str.replace(
        f"{extra_root}graphs_{letter}/",
        f"{train_root}{letter}_",
        regex=False,
    )

In [10]:
# Stack the four per-type samples into one frame with a fresh 0..n-1 index.
extra_data_df = pd.concat(
    [dot_df, horizontal_df, scatter_df, vertical_df],
    ignore_index=True,
)

# Every target starts with the "x | y" header followed by the <0x0A> separator;
# when the flag is set, the chart type is placed in front of that header.
target_text = "x | y <0x0A> " + extra_data_df["text"]
if INCLUDE_CHART_TYPE_IN_TEXT:
    target_text = extra_data_df["chart_type"] + " <0x0A> " + target_text
extra_data_df["target_text"] = target_text

In [11]:
# Keep the chart type under the name "target_type"; the "chart_type" column
# itself is dropped before export.
extra_data_df["target_type"] = extra_data_df["chart_type"]

In [12]:
# Tag every row with the constant source label "generated".
extra_data_df["source"] = "generated"

In [13]:
# Split each series text into two ';'-joined strings: all x values and all
# y values. Each "<0x0A>"-separated point is expected to look like "x | y";
# a point without " | " raises IndexError, so malformed rows fail loudly.
output = []
for series_text in extra_data_df["text"].values:
    x_values, y_values = [], []
    for point in series_text.split("<0x0A>"):
        fields = point.split(" | ")
        # Points are split on "<0x0A>" without surrounding spaces, so when the
        # text separates points with " <0x0A> " (the form this notebook uses
        # when building target_text) the pieces keep stray padding, e.g.
        # "a; b" / "1 ;2". Strip it so the joined values are clean.
        x_values.append(fields[0].strip())
        y_values.append(fields[1].strip())

    output.append([";".join(x_values), ";".join(y_values)])

In [14]:
# Store the parsed [x, y] pairs (one pair per row, in row order) in the
# target_text_x / target_text_y columns.
extra_data_df.loc[:, ["target_text_x", "target_text_y"]] = output

In [15]:
def get_axis_type(axis):
    """Classify a ';'-joined axis string as "numerical" or "categorical".

    Only the first ';'-separated entry is inspected.

    Args:
        axis: String of axis values joined with ';'.

    Returns:
        "numerical" when ``float()`` accepts the first entry, otherwise
        "categorical".
    """
    first_entry = axis.partition(";")[0]
    try:
        float(first_entry)
    except Exception:
        return "categorical"
    return "numerical"

In [16]:
# Derive x_axis_type / y_axis_type from the matching target_text_<axis> column.
for axis_name in ("x", "y"):
    extra_data_df[f"{axis_name}_axis_type"] = extra_data_df[f"target_text_{axis_name}"].apply(get_axis_type)

In [17]:
# Remove the raw metadata columns that are not exported.
extra_data_df = extra_data_df.drop(columns=["file_name", "text", "validation", "chart_type", "count"])

In [18]:
# The file name records whether the chart type was embedded in target_text.
# The path is relative, so it resolves against the kernel's working directory.
output_suffix = "_with_type" if INCLUDE_CHART_TYPE_IN_TEXT else ""
extra_data_df.to_csv(f"../data/extra_data{output_suffix}.csv", index=False)