In [1]:
import os
import sys
import geopandas as gpd
import pandas as pd
import glob
import seaborn as sns

In [2]:
def get_error_plots_geodataframe(geojson_path):
    """Load a plot GeoJSON and keep only the plots flagged as errors.

    Args:
        geojson_path: Path to a file readable by ``geopandas.read_file``.

    Returns:
        The rows whose ``mismatch`` column equals ``'error'``, restricted to
        the plot-metadata columns listed in ``wanted_columns`` (in that order).
        The geometry column is not among them.
    """
    wanted_columns = [
        'year', 'range', 'row', 'plot', 'type', 'rep', 'treatment',
        'species', 'accession', 'Entry_ID', 'ID', 'error_ID', 'mismatch',
    ]
    plots = gpd.read_file(geojson_path)[wanted_columns]
    is_error = plots['mismatch'] == 'error'
    return plots[is_error]


def get_error_plots_list(gdf):
    """Return the distinct ``error_ID`` values of ``gdf`` as a plain Python list.

    Values appear in order of first occurrence and are converted to built-in
    Python scalars.
    """
    distinct_ids = gdf['error_ID'].drop_duplicates()
    return distinct_ids.tolist()


def fix_error_plots(geojson_path, data_path, output_directory, error_plot_list, gdf, columns_to_replace):
    """Correct mislabeled plots in every CSV matched by ``data_path`` and write the results.

    For each CSV file:

    1. Rows whose ``plot`` value (compared as a string) appears in
       ``error_plot_list`` are treated as mislabeled. Their
       ``columns_to_replace`` values are overwritten with the row of ``gdf``
       whose ``error_ID`` equals that plot label, and their ``replaced`` flag is
       set to True. All other rows get ``replaced = False``.
    2. ``accession`` is re-derived for every row from the GeoJSON at
       ``geojson_path`` by joining on ``plot``. This is an inner join, so rows
       whose plot does not occur in the GeoJSON are dropped.
    3. Rows are sorted by ``plot``. ``plot`` is a string, so the order is
       lexicographic. The CSV's column order (plus ``replaced``) is restored.
    4. The result is written to ``output_directory`` as ``<stem>_corrected<ext>``.

    Args:
        geojson_path: Path to the field-map GeoJSON. Only its ``plot`` and
            ``accession`` columns are used.
        data_path: Glob pattern selecting the input CSV files.
        output_directory: Directory for the corrected CSVs. It is created if
            missing.
        error_plot_list: Wrong plot labels to correct, matched as strings.
        gdf: Error rows, e.g. from ``get_error_plots_geodataframe``. Must have an
            ``error_ID`` column plus every column in ``columns_to_replace``.
        columns_to_replace: Columns overwritten on mislabeled rows.
    """
    if output_directory:
        os.makedirs(output_directory, exist_ok=True)

    # Loop-invariant inputs: build them once instead of once per CSV.
    # Keys are cast to str because the CSV ``plot`` column is cast to str;
    # otherwise numeric IDs would silently never match.
    error_ids = [str(error_id) for error_id in error_plot_list]
    error_lookup = gdf.assign(error_ID=gdf['error_ID'].astype(str)).set_index('error_ID')
    plot_accessions = _read_plot_accessions(geojson_path)

    for csv in sorted(glob.glob(data_path)):
        corrected = _correct_plot_frame(
            pd.read_csv(csv), error_lookup, error_ids, columns_to_replace, plot_accessions
        )
        stem, ext = os.path.splitext(os.path.basename(csv))
        corrected.to_csv(os.path.join(output_directory, f'{stem}_corrected{ext}'), index=False)


def _read_plot_accessions(geojson_path):
    """Return the distinct (plot, accession) pairs of the GeoJSON, with ``plot`` as str."""
    plot_accessions = gpd.read_file(geojson_path)[['plot', 'accession']]
    # str keys match the CSV side of the join. Exact duplicate pairs are dropped
    # so that the join cannot multiply CSV rows.
    return plot_accessions.astype({'plot': str}).drop_duplicates()


def _correct_plot_frame(plot_df, error_lookup, error_ids, columns_to_replace, plot_accessions):
    """Return a corrected copy of one CSV's DataFrame (steps 1-3 of ``fix_error_plots``)."""
    plot_df = plot_df.copy()
    plot_df['plot'] = plot_df['plot'].astype(str)
    plot_df['replaced'] = False

    is_error = plot_df['plot'].isin(error_ids)
    # One lookup row per mislabeled CSV row, in the same order as those rows.
    replacement = error_lookup.loc[plot_df.loc[is_error, 'plot']]

    # Overwrite (not just fill NA) every listed column on the mislabeled rows.
    for col in columns_to_replace:
        plot_df.loc[is_error, col] = replacement[col].values
    plot_df.loc[is_error, 'replaced'] = True
    # A replaced ``plot`` may carry the GeoJSON's dtype; re-cast so the join key
    # and the sort stay uniformly string-typed.
    plot_df['plot'] = plot_df['plot'].astype(str)

    column_order = plot_df.columns.tolist()
    corrected = (
        plot_accessions
        .merge(plot_df.drop(columns='accession'), on='plot')
        # Sort after the merge: an inner merge follows the left frame's key
        # order, which would discard a sort applied beforehand.
        .sort_values('plot')
    )
    return corrected[column_order]

In [3]:
# Ensure the output directory for corrected CSVs exists. exist_ok=True makes
# this a no-op when the directory is already there; a regular file at this path
# still raises FileExistsError.
os.makedirs('./data/drone_greenness_corrected', exist_ok=True)

In [4]:
# Driver: find the mislabeled plots recorded in the GeoJSON, then rewrite every
# drone-greenness CSV with those plots corrected.
GEOJSON_PATH = './sorghum/season14_multi_latlon_geno_correction_labeled.geojson'
REPLACED_COLUMNS = ['plot', 'year', 'range', 'row', 'species', 'treatment', 'type', 'rep', 'accession']

error_gdf = get_error_plots_geodataframe(geojson_path=GEOJSON_PATH)
gdf = error_gdf
error_plot_list = get_error_plots_list(gdf=error_gdf)

fix_error_plots(
    geojson_path=GEOJSON_PATH,
    data_path='./data/drone_greenness/*.csv',
    output_directory='./data/drone_greenness_corrected',
    error_plot_list=error_plot_list,
    gdf=error_gdf,
    columns_to_replace=REPLACED_COLUMNS,
)