In [10]:
import geopandas as gpd
import pandas as pd
import matplotlib.pyplot as plt
import os
import json
from tqdm import tqdm

In [11]:
# --- Dataset selection -------------------------------------------------------
dataset_name = "Brazil_Election_2018"
dataset_folder = "Original"
meshblock_filename = "meshblocks.shp"

# --- Identifier columns ------------------------------------------------------
# Column of data.csv that is used as the DataFrame index when it is loaded.
data_id = "INDEX"
# Column of the meshblock shapefile used as its index ("GEOID" is the
# alternative that was previously configured here).
meshblocks_id = "code_muni"

# --- Derived locations -------------------------------------------------------
# Every path below lives under the same dataset root, so build that root once
# and derive the others from it.
output_path = f"/home/tpinho/IJGIS/Datasets/{dataset_name}/{dataset_folder}"
data_path = output_path + "/data.csv"
fold_path = output_path + "/folds/"
meshblock_path = output_path + "/meshblocks/" + meshblock_filename

In [12]:
# Build the full path of every entry found directly under `fold_path` and keep
# only those whose path contains "TraditionalSCV". Note that the substring test
# is applied to the joined path, not just to the entry name.
methods_path = list(
    filter(
        lambda candidate: "TraditionalSCV" in candidate,
        (os.path.join(fold_path, entry) for entry in os.listdir(fold_path)),
    )
)

In [13]:
def map_color(row):
    """Return the hex fill colour associated with a row's ``'Type'`` value.

    Parameters
    ----------
    row : mapping-like
        Any object supporting ``row['Type']`` (for example a DataFrame row).

    Returns
    -------
    str
        The colour for the known labels ``'discarded'``, ``'test'``,
        ``'removing_buffer'``, ``'train'`` and ``'missing'``; opaque white
        (``'#ffffffff'``) for any other value.
    """
    kind = row['Type']
    # Ordered (label, colour) pairs; the first label equal to `kind` wins.
    palette = (
        ('discarded', '#AFABAB'),
        ('test', '#9FC5E8'),
        ('removing_buffer', '#EA9999'),
        ('train', '#B6D7A8'),
        ("missing", "#ffffffff"),
    )
    for label, colour in palette:
        if kind == label:
            return colour
    # Unknown label: fall back to white.
    return '#ffffffff'


In [14]:

# Render, for every fold of every selected method, a map of the meshblocks
# coloured by the role each one plays in that fold, and save it as
# 'new_train_test_split.png' inside the fold's folder.

# The geometries, the tabular data and the list of meshblocks without data do
# not depend on the method or the fold, so they are loaded once up front.
meshblocks = gpd.read_file(meshblock_path)
# Index by the configured id column when the shapefile has it; otherwise keep
# the default index (same outcome as ignoring the KeyError from set_index).
if meshblocks_id in meshblocks.columns:
    meshblocks = meshblocks.set_index(meshblocks_id)
data = pd.read_csv(data_path, index_col=data_id)
# Meshblock ids that have no matching row in the data table.
missing = [idx for idx in meshblocks.index if idx not in data.index]

for method in tqdm(methods_path):
    # Every sub-directory of a method folder is treated as one fold.
    fold_folders = [
        os.path.join(method, c)
        for c in os.listdir(method)
        if not os.path.isfile(os.path.join(method, c))
    ]
    # NOTE: the loop variable is deliberately NOT called `fold_path`; reusing
    # that name would overwrite the configuration variable of the same name and
    # break a later re-run of the cell that lists the method folders.
    for fold_dir in tqdm(fold_folders):
        # Reset the label column for this fold. A string placeholder keeps the
        # column non-numeric, so assigning string labels below does not rely on
        # an implicit integer-to-object upcast. Rows left with the placeholder
        # take the fallback colour of map_color.
        meshblocks["Type"] = "unassigned"

        with open(os.path.join(fold_dir, "split_data.json"), 'r') as fp:
            split_data = json.load(fp)
        split_data["missing"] = missing

        # Each key of split_data is a label; its value lists the ids that get it.
        for key in split_data.keys():
            # Ignore ids that are not present in the meshblock index.
            joiner_index = [idx for idx in split_data[key] if idx in meshblocks.index]
            meshblocks.loc[joiner_index, "Type"] = key

        fig, ax = plt.subplots(1, 1)
        color_list = meshblocks.apply(map_color, axis=1)
        # Colours are given explicitly per geometry and no `column` is plotted,
        # so the categorical/legend options do not apply and are omitted.
        meshblocks.plot(
            color=color_list,
            linewidth=.05,
            edgecolor='white',
            markersize=.1,
            ax=ax,
        )
        ax.set_axis_off()
        fig.savefig(os.path.join(fold_dir, 'new_train_test_split.png'), dpi=1000)
        # Close this specific figure so figures do not accumulate across folds.
        plt.close(fig)

100%|██████████| 27/27 [03:22<00:00,  7.49s/it]
100%|██████████| 1/1 [03:33<00:00, 213.40s/it]
