In [1]:
import sys

sys.path.append('../..')
%load_ext autoreload
%autoreload 2

In [2]:
from torch.utils.data import DataLoader
from datasets import load_from_disk, Dataset
import torch
from src.data_utils import generate_ravel_dataset, get_ravel_collate_fn, filter_dataset

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# Load the train/test splits stored under ./data/city_country_*; the counting
# and per-attribute cells further down iterate over these two datasets.
city_train, city_test = (
    load_from_disk('./data/city_country_train'),
    load_from_disk('./data/city_country_test'),
)

In [7]:
import random
from datasets import Dataset

def _get_causal_size(dataset):
    
    causal_dataset = [d for d in dataset if d["attribute_type"] == "causal"]
    return len(causal_dataset)

def _split_by_attribute_type(dataset):
    
    causal_dataset, isolate_dataset = [], []
    for d in dataset:
        if d["attribute_type"] == "causal":
            causal_dataset.append(d)
        else:
            isolate_dataset.append(d)
            
    return causal_dataset, isolate_dataset


def _resample_to_size(rows, size, label="rows"):
    """Return a list of exactly ``size`` rows drawn from ``rows``.

    If ``rows`` is shorter than ``size``, every original row is kept and the
    shortfall is filled by sampling *with* replacement. Otherwise ``size`` rows
    are drawn *without* replacement (this also shuffles when the lengths match).

    Raises:
        ValueError: if ``rows`` is empty while ``size`` is positive, because
            there is nothing to draw from.
    """
    shortfall = size - len(rows)
    if shortfall > 0:
        if not rows:
            raise ValueError(f"no {label} available to resample up to {size} examples")
        return list(rows) + random.choices(rows, k=shortfall)
    return random.sample(rows, size)


def _enlarge_dataset(dataset, target_causal_size, target_isolate_size, target_attribute,
                     all_attributes=("Country", "Continent", "Language", "Latitude", "Longitude", "Timezone"),
                     upsample_causal=False):
    """Balance one attribute's RAVEL examples and return them as a ``Dataset``.

    Causal rows are kept unchanged by default. Isolate rows are grouped by their
    ``"attribute"`` field. For every attribute in ``all_attributes`` other than
    ``target_attribute``, that group is resampled to exactly
    ``target_isolate_size`` rows. Isolate rows whose attribute is
    ``target_attribute``, or is not in ``all_attributes``, are dropped.

    Args:
        dataset: iterable of dict-like rows (e.g. a HF ``Dataset``) with
            ``"attribute_type"`` and ``"attribute"`` keys.
        target_causal_size: desired number of causal rows. It is only applied
            when ``upsample_causal`` is true; by default causal rows keep their
            natural count.
        target_isolate_size: number of isolate rows to emit per other attribute.
        target_attribute: the attribute this dataset is built for.
        all_attributes: the full attribute vocabulary.
        upsample_causal: if true, also resample the causal rows to
            ``target_causal_size``.

    Returns:
        A ``Dataset`` containing the causal rows first, followed by the isolate
        rows grouped in ``all_attributes`` order.

    Raises:
        ValueError: if some other attribute has no isolate rows while
            ``target_isolate_size`` is positive. The same applies to causal rows
            when ``upsample_causal`` is set.
    """
    other_attributes = [attribute for attribute in all_attributes if attribute != target_attribute]

    causal_dataset, isolate_dataset = _split_by_attribute_type(dataset)

    if upsample_causal:
        causal_dataset = _resample_to_size(causal_dataset, target_causal_size,
                                           label=f"causal {target_attribute} rows")

    new_isolate_dataset = []
    for attribute in other_attributes:
        attribute_pool = [d for d in isolate_dataset if d["attribute"] == attribute]
        new_isolate_dataset.extend(
            _resample_to_size(attribute_pool, target_isolate_size,
                              label=f"isolate {attribute} rows for {target_attribute}")
        )

    return Dataset.from_list(causal_dataset + new_isolate_dataset)



# Build one merged, attribute-balanced city dataset per split and save it to
# ./data/city_{split}. For each attribute, its causal rows are kept and, for each
# of the other five attributes, isolate rows are resampled to a fixed quota.
SEED = 0  # fixed seed so the resampled datasets written to disk are reproducible
CITY_ATTRIBUTES = ["Country", "Continent", "Language", "Latitude", "Longitude", "Timezone"]

random.seed(SEED)

for split in ['train', 'test']:

    data_dict = {
        attribute: load_from_disk(f'./data/city_{attribute.lower()}_{split}')
        for attribute in CITY_ATTRIBUTES
    }

    # Largest causal count across attributes. The per-attribute isolate quota
    # is a fifth of it, presumably so the five "other" isolate groups together
    # roughly match the causal size. NOTE(review): confirm that intent.
    target_dataset_size = max(_get_causal_size(ds) for ds in data_dict.values())
    target_isolate_dataset_size = target_dataset_size // 5

    merged_rows = []
    for attribute, dataset in data_dict.items():
        enlarged_dataset = _enlarge_dataset(dataset, target_dataset_size, target_isolate_dataset_size, attribute)
        merged_rows.extend(enlarged_dataset)

    merged_dataset = Dataset.from_list(merged_rows)
    merged_dataset.save_to_disk(f'./data/city_{split}')

1676 1650
1255 1650
756 1650
665 1650
1029 1650
1589 1650
1236 1650
760 1650
668 1650
1022 1650
1596 1650
1677 1650
770 1650
668 1650
1026 1650
1587 1650
1692 1650
1263 1650
665 1650
1034 1650
1588 1650
1684 1650
1269 1650
781 1650
1051 1650
1587 1650
1687 1650
1272 1650
777 1650
657 1650


Saving the dataset (1/1 shards): 100%|██████████| 84399/84399 [00:00<00:00, 1092799.96 examples/s]


343 331
233 331
168 331
146 331
188 331
332 331
221 331
159 331
129 331
186 331
331 331
335 331
167 331
139 331
190 331
331 331
337 331
237 331
141 331
193 331
320 331
341 331
236 331
152 331
195 331
331 331
335 331
244 331
139 331
125 331


Saving the dataset (1/1 shards): 100%|██████████| 16946/16946 [00:00<00:00, 960755.28 examples/s]


In [20]:
# Tally city_train rows per "<attribute>_<attribute_type>" key (e.g.
# "Country_causal"); keys appear in first-seen order.
stat_dict = {}
for row in city_train:
    bucket = row["attribute"] + "_" + row["attribute_type"]
    stat_dict[bucket] = stat_dict.get(bucket, 0) + 1
stat_dict

{'Country_causal': 3929,
 'Continent_causal': 4161,
 'Language_causal': 3151,
 'Latitude_causal': 1887,
 'Longitude_causal': 1737,
 'Timezone_causal': 2636,
 'Country_isolate': 3905,
 'Continent_isolate': 4178,
 'Language_isolate': 3196,
 'Latitude_isolate': 1937,
 'Longitude_isolate': 1662,
 'Timezone_isolate': 2685}

In [28]:
import random

# Re-seed here so re-running just this cell rewrites identical files.
SEED = 0
random.seed(SEED)

# For each attribute, save a train dataset of its causal rows plus at most as
# many isolate rows whose edit instruction mentions that attribute.
for attribute in ["Country", "Continent", "Language", "Longitude", "Latitude", "Timezone"]:

    causal_rows = []
    isolate_rows = []

    for data in city_train:
        if data["attribute"] == attribute and data["attribute_type"] == "causal":
            causal_rows.append(data)
        # NOTE(review): plain substring match on the instruction text; confirm an
        # attribute name never appears incidentally in unrelated instructions.
        elif attribute in data["edit_instruction"] and data["attribute_type"] == "isolate":
            isolate_rows.append(data)

    # Cap isolate rows at the causal count: keep all of them when there are
    # fewer, otherwise down-sample without replacement.
    if len(isolate_rows) < len(causal_rows):
        kept_isolate_rows = isolate_rows
    else:
        kept_isolate_rows = random.sample(isolate_rows, len(causal_rows))

    attribute_rows = causal_rows + kept_isolate_rows
    print(attribute, len(attribute_rows))
    attribute_dataset = Dataset.from_list(attribute_rows)
    attribute_dataset.save_to_disk(f"./data/ravel_city_{attribute}_train")

Country 6852


Saving the dataset (1/1 shards): 100%|██████████| 6852/6852 [00:00<00:00, 731188.68 examples/s]


Continent 7120


Saving the dataset (1/1 shards): 100%|██████████| 7120/7120 [00:00<00:00, 756285.47 examples/s]


Language 6078


Saving the dataset (1/1 shards): 100%|██████████| 6078/6078 [00:00<00:00, 733526.49 examples/s]


Longitude 3474


Saving the dataset (1/1 shards): 100%|██████████| 3474/3474 [00:00<00:00, 582467.70 examples/s]


Latitude 3774


Saving the dataset (1/1 shards): 100%|██████████| 3774/3774 [00:00<00:00, 613586.45 examples/s]


Timezone 5272


Saving the dataset (1/1 shards): 100%|██████████| 5272/5272 [00:00<00:00, 704260.48 examples/s]


In [29]:
# Re-seed here so re-running just this cell rewrites identical files.
SEED = 0
random.seed(SEED)

# Same construction as the train cell, applied to city_test: causal rows plus
# at most as many isolate rows whose edit instruction mentions the attribute.
for attribute in ["Country", "Continent", "Language", "Longitude", "Latitude", "Timezone"]:

    causal_rows = []
    isolate_rows = []

    for data in city_test:
        if data["attribute"] == attribute and data["attribute_type"] == "causal":
            causal_rows.append(data)
        # NOTE(review): plain substring match on the instruction text; confirm an
        # attribute name never appears incidentally in unrelated instructions.
        elif attribute in data["edit_instruction"] and data["attribute_type"] == "isolate":
            isolate_rows.append(data)

    # Cap isolate rows at the causal count: keep all of them when there are
    # fewer, otherwise down-sample without replacement.
    if len(isolate_rows) < len(causal_rows):
        kept_isolate_rows = isolate_rows
    else:
        kept_isolate_rows = random.sample(isolate_rows, len(causal_rows))

    attribute_rows = causal_rows + kept_isolate_rows
    print(attribute, len(attribute_rows))
    attribute_dataset = Dataset.from_list(attribute_rows)
    attribute_dataset.save_to_disk(f"./data/ravel_city_{attribute}_test")
    

Country 685


Saving the dataset (1/1 shards): 100%|██████████| 685/685 [00:00<00:00, 208846.28 examples/s]


Continent 703


Saving the dataset (1/1 shards): 100%|██████████| 703/703 [00:00<00:00, 203099.31 examples/s]


Language 621


Saving the dataset (1/1 shards): 100%|██████████| 621/621 [00:00<00:00, 188143.80 examples/s]


Longitude 324


Saving the dataset (1/1 shards): 100%|██████████| 324/324 [00:00<00:00, 99462.38 examples/s] 

Latitude 370


Saving the dataset (1/1 shards): 100%|██████████| 370/370 [00:00<00:00, 113583.58 examples/s]


Timezone 571


Saving the dataset (1/1 shards): 100%|██████████| 571/571 [00:00<00:00, 166825.55 examples/s]
