In [13]:
import os
from datasets import load_dataset
import pandas as pd
import shutil

#### Load dataset

In [2]:
# Folder that holds the Visual7W annotation JSON read by this notebook.
data_dir = "visual7w"

In [3]:
# Parse the records stored under the top-level "images" field of the telling
# annotations; everything is loaded into one `train` split at this point.
data = load_dataset(
    "json",
    data_files=os.path.join(data_dir, "dataset_v7w_telling.json"),
    field="images",
    split="train",
)

Downloading data files: 100%|██████████| 1/1 [00:00<00:00, 347.61it/s]
Extracting data files: 100%|██████████| 1/1 [00:00<00:00, 23.19it/s]
Generating train split: 28653 examples [00:02, 10340.18 examples/s]


In [4]:
def _split_is(name):
    """Build a predicate that keeps records whose 'split' field equals `name`."""
    return lambda record: record['split'] == name


# Partition the loaded records according to the 'split' value each one carries.
train_dataset = data.filter(_split_is('train'))
test_dataset = data.filter(_split_is('test'))
val_dataset = data.filter(_split_is('val'))

Filter: 100%|██████████| 28653/28653 [00:02<00:00, 12820.24 examples/s]
Filter: 100%|██████████| 28653/28653 [00:02<00:00, 10482.72 examples/s]
Filter: 100%|██████████| 28653/28653 [00:02<00:00, 10687.06 examples/s]


In [5]:
def preprocess(dataset):
    """Flatten per-image QA annotations into one row per question.

    Parameters
    ----------
    dataset : iterable of mapping
        Each record must provide 'filename' and 'qa_pairs'; every entry of
        'qa_pairs' must provide 'question', 'answer' and 'type'.

    Returns
    -------
    pandas.DataFrame
        Columns 'questions', 'answers', 'image_id' and 'types', with one row
        per QA pair in input order. Records with no QA pairs contribute no
        rows; an empty input yields an empty frame with the same columns.
    """
    questions = []
    answers = []
    image_id = []
    types = []

    # Walk the records directly so each one is fetched once, instead of
    # re-indexing `dataset[i]` for every QA pair it contains.
    for record in dataset:
        filename = record['filename']
        for qa in record['qa_pairs']:
            questions.append(qa['question'])
            # Answers are coerced to text and stripped of leading/trailing periods.
            answers.append(str(qa['answer']).strip('.'))
            types.append(qa['type'])
            image_id.append(filename)

    return pd.DataFrame({
        'questions': questions,
        'answers': answers,
        'image_id': image_id,
        'types': types
    })

In [6]:
# Flatten each split to one row per question and persist it as CSV.
# Paths are built from `data_dir` so the output lands next to the JSON that
# was loaded from the same folder.
data_train = preprocess(train_dataset)
data_train.to_csv(os.path.join(data_dir, "train.csv"), index=False)

data_test = preprocess(test_dataset)
data_test.to_csv(os.path.join(data_dir, "test.csv"), index=False)

data_val = preprocess(val_dataset)
data_val.to_csv(os.path.join(data_dir, "val.csv"), index=False)

#### Filter data

In [None]:
# Reload the three flattened splits from disk, in train / test / val order.
data_train, data_test, data_val = map(
    pd.read_csv,
    ("visual7w/train.csv", "visual7w/test.csv", "visual7w/val.csv"),
)

In [8]:
# Tally training questions per type in a single pass. The loop variable is
# deliberately not called `type`, which would shadow the builtin for every
# later cell in the kernel.
type_counts = data_train['types'].value_counts()
for question_type in ["what", "where", "how", "who", "why", "when"]:
    # A type with no rows is absent from the tally, so fall back to 0.
    print(f"{question_type} question has {type_counts.get(question_type, 0)} samples")

what question has 33293 samples
where question has 11421 samples
how question has 10305 samples
who question has 7075 samples
why question has 4470 samples
when question has 3253 samples


In [33]:
# Empty template frame carrying only the four column names.
d = {column: [] for column in ("questions", "answers", "image_id", "types")}
sample = pd.DataFrame(d)

In [34]:
# Draw 25% of the validation questions of each type (rounded down) and append
# them to `sample`, keeping the types in the listed order.
# NOTE: rerunning this cell appends again to whatever `sample` already holds.
sub_samples = []
for question_type in ["what", "where", "how", "who", "why", "when"]:
    sub_dataset = data_val[data_val['types'] == question_type]

    # Fixed seed so a rerun selects the same rows.
    sub_samples.append(
        sub_dataset.sample(n=int(len(sub_dataset) * 0.25), random_state=42)
    )

# Concatenate once rather than regrowing the frame on every iteration.
sample = pd.concat([sample, *sub_samples], ignore_index=True)

#### Check data

In [7]:
# Sampled splits, in train / test / val order.
data_train, data_test, data_val = map(
    pd.read_csv,
    (
        "visual7w/sample_1_4/sample_train.csv",
        "visual7w/sample_1_4/sample_test.csv",
        "visual7w/sample_1_4/sample_val.csv",
    ),
)

# Augmentation table, compared against the sampled training set below.
data_augment = pd.read_csv("visual7w/augment/augment.csv")

# Full training table from the `original` folder.
train = pd.read_csv("visual7w/original/train.csv")

In [38]:
# Anti-join: keep the augment rows whose (questions, answers) pair has no
# match in the sampled training set, then drop exact duplicate rows.
df_not_in = (
    data_augment
    .merge(
        data_train[['questions', 'answers']],
        how='left',
        on=['questions', 'answers'],
        indicator=True,
    )
    .loc[lambda merged: merged['_merge'] == 'left_only']
    .drop(columns='_merge')
    .drop_duplicates(ignore_index=True)
)

In [None]:
# Look up the full training rows for every (questions, answers) pair that is
# missing from the sampled set. An inner merge only returns rows present on
# both sides, so no `_merge` indicator or filter is needed.
df_need_add = (
    train
    .merge(df_not_in[['questions', 'answers']], on=['questions', 'answers'], how='inner')
    .drop_duplicates(ignore_index=True)
)
df_need_add

In [None]:
# Stack the recovered rows under the sampled training set and renumber the index.
tmp = pd.concat(objs=[data_train, df_need_add], axis=0, ignore_index=True)
tmp

In [34]:
order = ["what", "where", "how", "who", "why", "when"]

# Make 'types' an ordered categorical so sorting follows `order` instead of
# alphabetical order. This rewrites the column on `tmp` itself.
tmp['types'] = tmp['types'].astype(pd.CategoricalDtype(categories=order, ordered=True))

df_sorted = tmp.sort_values(by='types', ignore_index=True)

In [36]:
# NOTE: this overwrites the same file that `data_train` was read from in this
# section, so rerunning the section starts from the already-extended set.
df_sorted.to_csv(path_or_buf="visual7w/sample_1_4/sample_train.csv", index=False)

In [26]:
# Assign sequential file names gen1.jpg ... genN.jpg in row order.
# A whole-column assignment is positional, so it does not rely on the index
# being 0..N-1 the way per-row `.loc[i, ...]` writes do, and avoids N
# separate scalar writes.
data_augment['image_id'] = [
    f"gen{counter}.jpg" for counter in range(1, len(data_augment) + 1)
]

In [30]:
# Write the augment table back to the CSV it is loaded from in this section.
data_augment.to_csv(path_or_buf="visual7w/augment/augment.csv", index=False)