# Feature Screening

## Settings
### Imports

In [1]:
import pandas as pd
import numpy as np
import pickle
import os
import gc
import warnings
from collections import defaultdict
import matplotlib.pyplot as plt
import seaborn as sns

import statsmodels.api as sm
import statsmodels.formula.api as smf
from statsmodels.stats.outliers_influence import variance_inflation_factor
from scipy.stats import pearsonr

warnings.filterwarnings("ignore")

### Configuration

In [10]:
TARGET_VARIABLES = ['co2', 'gdp', 'primary_energy_consumption']
G20_COUNTRIES = [
    'United States', 'China', 'Japan', 'Germany', 
    'United Kingdom', 'France', 'Italy', 'Canada',
    'Brazil', 'Russia', 'India', 'Australia', 
    'Mexico', 'Indonesia', 'Turkey', 'Saudi Arabia',
    'South Africa', 'Argentina', 'South Korea'
]
# Top N number of features to consider
TOP_N = 40
# Minimum number of appearance in TOP N across metrics
MIN_APPEAR = 3
# To remove highly correlated features
CORR_THRESHOLD = 0.90
# Variance Inflation Factor threshold
VIF_THRESHOLD = 10.0
# Feature counts to analyse at the end
feature_counts_steps = {}

### Load data

In [3]:
# Load scores
def load_importance_scores(save_dir='data_export'):

    input_dir = os.path.join(save_dir, '02_results')
    screen_path = os.path.join(input_dir, 'importance_scores.pkl')

    try:
        with open(screen_path, 'rb') as f:
            importance_scores = pickle.load(f)
        print(f"loaded pickle")

        return importance_scores
    
    except FileNotFoundError:
        print(f"Run Stage 1 once again")
        return None

In [4]:
# Loading lagged data
def load_lag_data(save_dir='data_export'):

    lag_path = os.path.join(save_dir, 'lag_df_1965.pkl')

    try:
        lag_df = pd.read_pickle(lag_path)
        print(f"loaded lag data")
        return lag_df
    except FileNotFoundError:
        try:
            lag_csv = lag_path.replace('.pkl', '.csv')
            lag_df = pd.read_csv(lag_csv)
            return lag_df
        except FileNotFoundError as e:
            print(f"Error in loading {e}")
            return None

In [5]:
importance_scores = load_importance_scores()
importance_scores.keys()

loaded pickle


dict_keys(['co2', 'gdp', 'primary_energy_consumption'])

In [6]:
lag_df = load_lag_data()
lag_df.head(5)

loaded lag data


Unnamed: 0,country,year,primary_energy_consumption,population,energy_per_gdp,energy_per_capita,gdp,temperature_change_from_ghg,land_use_change_co2_per_capita,other_co2_per_capita,...,hydro_cons_change_pct_lag3,hydro_cons_change_pct_lag4,solar_share_energy_lag1,solar_share_energy_lag2,solar_share_energy_lag3,solar_share_energy_lag4,hydro_consumption_lag1,hydro_consumption_lag2,hydro_consumption_lag3,hydro_consumption_lag4
2567,Argentina,1965,328.528,22112635.0,1.452,14896.737,226284900000.0,0.009,5.968,,...,,,,,,,,,,
2568,Argentina,1966,339.175,22453898.0,1.489,15139.622,227834500000.0,0.009,5.804,,...,,,,,,,3.623,,,
2569,Argentina,1967,349.032,22799062.0,1.492,15337.348,233929100000.0,0.009,8.292,,...,,,,,,,3.669,3.623,,
2570,Argentina,1968,360.844,23150579.0,1.48,15612.201,243886500000.0,0.01,5.502,,...,,,,,,,3.757,3.669,3.623,
2571,Argentina,1969,372.911,23508715.0,1.409,15887.447,264721500000.0,0.01,4.045,,...,1.291,,,,,,4.429,3.757,3.669,3.623


### Utility Functions

In [7]:
def clear_memory():
    gc.collect()

## Removing highly correlated features

### Find features in Top N rank for all targets
wonder if i should include, there are some domain relevant features such like population are removed

In [None]:
"""# Choosing top n-number of features by metric
def get_top_features(importance_scores, top_n=20):

    feature_appearance = {}

    for target in TARGET_VARIABLES:
        if target not in importance_scores:
            continue

        features = importance_scores[target]
        metrics = ['avg_abs_correlation', 'granger_significance_rate', 'avg_mutual_info']
        
        for metric in metrics:
            # Sort and rank
            sorted_features = sorted(features.items(), key=lambda x: x[1][metric])
            for rank, (feature, _) in enumerate(sorted_features[:top_n], 1):

                if feature not in feature_appearance:
                    feature_appearance[feature] = {
                        'ranks': [],
                        'targets': set()
                        }
                    
                feature_appearance[feature]['ranks'].append(rank)
                feature_appearance[feature]['targets'].add(target)

    feature_scores = {}

    for feature, data in feature_appearance.items():
        
        avg_rank = np.mean(data['ranks'])
        n_targets = len(data['targets'])
        
        feature_scores[feature] = {
            'avg_rank': avg_rank,
            'n_targets': n_targets
        }
    
    # Sort by combined score
    sorted_features = sorted(feature_scores.items(), key=lambda x: x[1]['n_targets'], reverse=True)
    
    print(f"\nTop features across all targets:")
    for i, (feat, scores) in enumerate(sorted_features[:], 1):
        print(f"{i:2d}. {feat:35} Targets: {scores['n_targets']}/3,   Avg Rank: {scores['avg_rank']:3.1f}")
    
    return sorted_features, feature_scores"""

In [None]:
"""sorted_features, feature_scores = get_top_features(importance_scores, 20)"""


Top features across all targets:
 1. hydro_cons_change_twh               Targets: 3/3,   Avg Rank: 4.7
 2. renewables_cons_change_pct          Targets: 3/3,   Avg Rank: 3.7
 3. co2_including_luc_growth_abs        Targets: 3/3,   Avg Rank: 4.4
 4. coal_cons_change_pct                Targets: 3/3,   Avg Rank: 4.7
 5. nuclear_cons_change_twh             Targets: 3/3,   Avg Rank: 9.0
 6. coal_cons_change_twh                Targets: 3/3,   Avg Rank: 10.4
 7. co2_including_luc_growth_prct       Targets: 3/3,   Avg Rank: 5.9
 8. oil_prod_change_twh                 Targets: 3/3,   Avg Rank: 12.5
 9. coal_prod_change_twh                Targets: 3/3,   Avg Rank: 12.4
10. low_carbon_cons_change_pct          Targets: 3/3,   Avg Rank: 5.9
11. gas_cons_change_twh                 Targets: 3/3,   Avg Rank: 10.5
12. co2_growth_abs                      Targets: 3/3,   Avg Rank: 14.1
13. low_carbon_cons_change_twh          Targets: 3/3,   Avg Rank: 14.2
14. oil_cons_change_twh                 Targets: 3

In [None]:
"""scores_appear_all = sorted_features[:24]"""

In [None]:
"""features_appear_all = []
for i, (feature, _) in enumerate(scores_appear_all):
    features_appear_all.append(feature)"""

### Feature appearance
to check how many targets for each feature appear

In [8]:
def get_feature_app(importance_scores):
    
    feature_counts = {}
    for target in TARGET_VARIABLES:
        if target in importance_scores:
            for feature in importance_scores[target]:
                if feature not in feature_counts:
                    feature_counts[feature] = {'n_targets': 0, 'targets': []}
                feature_counts[feature]['n_targets'] += 1
                feature_counts[feature]['targets'].append(target)

    return feature_counts

### Domain relevant features
Features theoretically important to be kept

In [9]:
def domain_rel_features():

    domain = {
        'co2': [
            # direct emission
            'coal_consumption', 'oil_consumption', 'gas_consumption', 'fossil_fuel_consumption',
            # reduction
            'renewables_consumption', 'nuclear_consumption', 'low_carbon_consumption',
            # others
            'population', 'gdp', 'primary_energy_consumption', 'energy_per_capita', 'co2_per_unit_energy'
        ],
        'gdp': [
            'population',
            # energy - economy
            'primary_energy_consumption', 'electricity_generation', 'energy_per_gdp',
            # co2 - economy
            'trade_co2', 'consumption_co2', 'fossil_fuel_consumption', 'renewables_consumption'
        ],
        'primary_energy_consumption': [
            # sources
            'coal_consumption', 'oil_consumption', 'gas_consumption', 'renewables_consumption', 'nuclear_consumption', 'hydro_consumption',
            # economy
            'population', 'gdp',
            # intensity
            'energy_per_capita', 'energy_per_gdp', 'electricity_generation'
        ]
    }

    # Union all features
    union = set()
    for dom in domain.values():
        union.update(dom)

    return domain, list(union)

In [11]:
# All features from importance scores
all_features = set()
for target in TARGET_VARIABLES:
    if target in importance_scores:
        all_features.update(importance_scores[target].keys())
        
all_features = list(all_features)
feature_counts_steps['Initial'] = len(all_features)
feature_counts = get_feature_app(importance_scores)

In [12]:
domain, all_relevant_features = domain_rel_features()

### Correlations Feature - Feature
Finding feature - feature correlations to remove highly correlated features

In [13]:
def ff_correlation(lag_df, countries, features):

    all_data = []

    for country in countries:
        country_data = lag_df[lag_df['country'] == country][features]
        all_data.append(country_data)

    combined_data = pd.concat(all_data, ignore_index=True)
    corr_matrix = combined_data.corr()

    return corr_matrix

In [14]:
def find_collinear_groups(lag_df, features, corr_threshold):

    correlation_matrix = ff_correlation(lag_df, G20_COUNTRIES, features)
    
    groups = []
    visited = set()
    
    for i, feat1 in enumerate(features):
        if feat1 in visited:
            continue
            
        group = [feat1]
        visited.add(feat1)
        
        for j, feat2 in enumerate(features):
            if feat2 not in visited and abs(correlation_matrix.loc[feat1, feat2]) > corr_threshold:
                group.append(feat2)
                visited.add(feat2)
        
        if len(group) > 1:
            groups.append(group)
    
    print("\nHighly correlated feature groups:")
    for i, group in enumerate(groups, 1):
        print(f"\nGroup {i}:")
        for feat in group:
            print(f" - {feat}")
    
    return groups

In [15]:
find_collinear_groups(lag_df, all_features, CORR_THRESHOLD)


Highly correlated feature groups:

Group 1:
 - share_global_cumulative_cement_co2
 - coal_electricity
 - share_global_co2

Group 2:
 - electricity_generation
 - total_ghg
 - low_carbon_electricity
 - fossil_electricity
 - co2_including_luc
 - total_ghg_excluding_lucf
 - low_carbon_consumption
 - fossil_fuel_consumption

Group 3:
 - nuclear_share_elec
 - nuclear_energy_per_capita
 - nuclear_elec_per_capita
 - nuclear_share_energy

Group 4:
 - wind_share_energy
 - wind_share_elec

Group 5:
 - cumulative_cement_co2
 - coal_co2
 - cement_co2
 - coal_consumption

Group 6:
 - hydro_energy_per_capita
 - renewables_energy_per_capita
 - hydro_elec_per_capita
 - renewables_elec_per_capita

Group 7:
 - oil_consumption
 - share_global_cumulative_co2
 - oil_co2
 - cumulative_co2_including_luc
 - temperature_change_from_ghg
 - share_global_oil_co2
 - cumulative_co2
 - temperature_change_from_co2

Group 8:
 - fossil_share_energy
 - low_carbon_share_elec
 - low_carbon_share_energy
 - fossil_share_ele

[['share_global_cumulative_cement_co2',
  'coal_electricity',
  'share_global_co2'],
 ['electricity_generation',
  'total_ghg',
  'low_carbon_electricity',
  'fossil_electricity',
  'co2_including_luc',
  'total_ghg_excluding_lucf',
  'low_carbon_consumption',
  'fossil_fuel_consumption'],
 ['nuclear_share_elec',
  'nuclear_energy_per_capita',
  'nuclear_elec_per_capita',
  'nuclear_share_energy'],
 ['wind_share_energy', 'wind_share_elec'],
 ['cumulative_cement_co2', 'coal_co2', 'cement_co2', 'coal_consumption'],
 ['hydro_energy_per_capita',
  'renewables_energy_per_capita',
  'hydro_elec_per_capita',
  'renewables_elec_per_capita'],
 ['oil_consumption',
  'share_global_cumulative_co2',
  'oil_co2',
  'cumulative_co2_including_luc',
  'temperature_change_from_ghg',
  'share_global_oil_co2',
  'cumulative_co2',
  'temperature_change_from_co2'],
 ['fossil_share_energy',
  'low_carbon_share_elec',
  'low_carbon_share_energy',
  'fossil_share_elec'],
 ['co2_growth_abs', 'energy_cons_change

In [16]:
corr_matrix = ff_correlation(lag_df, G20_COUNTRIES, all_features)

In [17]:
def remove_high_corr(features, corr_matrix, feature_counts, features_to_keep, threshold):

    features_to_remove = set()
    corr_abs = corr_matrix.abs()

    # Finding highly correlated pairs
    for i in range(len(corr_abs.columns)):
        if corr_abs.columns[i] in features_to_remove:
            continue
        for j in range(i+1, len(corr_abs.columns)):
            if corr_abs.columns[j] in features_to_remove:
                continue

            if corr_abs.iloc[i, j] > threshold:
                
                feat1 = corr_abs.columns[i]
                feat2 = corr_abs.columns[j]

                # Keep the features
                if feat1 in features_to_keep and feat2 in features_to_keep:
                    continue

                if feat1 in features_to_keep:
                    features_to_remove.add(feat2)
                elif feat2 in features_to_keep:
                    features_to_remove.add(feat1)
                else:
                    # Keeping the one appears more in the targets
                    n_target1 = feature_counts.get(feat1, {}).get('n_targets', 0)
                    n_target2 = feature_counts.get(feat2, {}).get('n_targets', 0)

                    if n_target1 >= n_target2:
                        features_to_remove.add(feat2)
                    else:
                        features_to_remove.add(feat1)
    
    selected_features = [f for f in features if f not in features_to_remove]

    return selected_features, features_to_remove

In [18]:
selected_features, removed = remove_high_corr(all_features, corr_matrix, feature_counts, all_relevant_features, CORR_THRESHOLD)
feature_counts_steps['After Correlation'] = len(selected_features)
selected_features

['share_global_cumulative_cement_co2',
 'electricity_generation',
 'nuclear_share_elec',
 'gas_prod_change_pct',
 'cumulative_flaring_co2',
 'wind_share_energy',
 'hydro_energy_per_capita',
 'oil_consumption',
 'fossil_share_energy',
 'co2_growth_abs',
 'oil_co2_per_capita',
 'solar_energy_per_capita',
 'gas_prod_per_capita',
 'temperature_change_from_n2o',
 'solar_consumption',
 'coal_share_elec',
 'low_carbon_cons_change_twh',
 'share_global_luc_co2',
 'fossil_elec_per_capita',
 'energy_cons_change_pct',
 'renewables_cons_change_pct',
 'oil_cons_change_pct',
 'co2_including_luc_per_unit_energy',
 'renewables_consumption',
 'share_global_flaring_co2',
 'co2_per_gdp',
 'nitrous_oxide',
 'flaring_co2_per_capita',
 'coal_prod_per_capita',
 'wind_cons_change_twh',
 'coal_elec_per_capita',
 'low_carbon_cons_change_pct',
 'coal_cons_change_pct',
 'share_of_temperature_change_from_ghg',
 'wind_elec_per_capita',
 'other_renewables_elec_per_capita',
 'co2_including_luc_per_capita',
 'hydro_con

## Statistical Significance Testing

### Regression analysis
Testing each feature's effectiveness for each target with OLS regression

and remove features which are not signicant in enough countries or below 0.5 R-squared

In [19]:
def features_reg(lag_df, selected_features, p_threshold=0.05):

    reg_results = defaultdict(lambda: defaultdict(dict))

    for feature in selected_features:
        for target in TARGET_VARIABLES:

            sig_countries = 0
            r_squared_values = []
            p_values = []

            for country in G20_COUNTRIES:
                country_data = lag_df[lag_df['country'] == country]

                data_subset = country_data[[feature, target]].dropna()

                X = data_subset[feature].values
                y = data_subset[target].values

                Xc = sm.add_constant(X)

                model = sm.OLS(y, Xc).fit()

                p_value = model.pvalues[1]
                r_squared = model.rsquared

                if p_value < p_threshold:
                    sig_countries += 1

                r_squared_values.append(r_squared)
                p_values.append(p_value)


            if len(p_values) > 0:
                reg_results[feature][target] = {
                    'n_sig_countries': sig_countries,
                    'avg_r_squared': np.mean(r_squared_values),
                    'avg_p_value': np.mean(p_values)
                }

    features_to_remove = set()
    
    for feature in selected_features:
        keep_feature = False
        
        for target in TARGET_VARIABLES:
            if target in reg_results[feature]:
                results = reg_results[feature][target]
                
                # Keep if significant >= 10 countries
                if results['n_sig_countries'] >= 10:
                    keep_feature = True
                    break
                
                # Or if R-squared > 0.5
                if results['avg_r_squared'] >= 0.5:
                    keep_feature = True
                    break

        
        if not keep_feature:
            features_to_remove.add(feature)
            
            print(f"\n Removing {feature}:")
            for target in TARGET_VARIABLES:
                if target in reg_results[feature]:
                    res = reg_results[feature][target]
                    print(f"{target:27}: significant {res['n_sig_countries']}/{len(G20_COUNTRIES)} countries, "
                          f"avg R-sqaured = {res['avg_r_squared']:.3f}, avg p-value = {res['avg_p_value']}")
                    
    valid_features = [f for f in selected_features if f not in features_to_remove]

    summary_df = []
    for feature in valid_features:
        for target in TARGET_VARIABLES:
            if target in reg_results[feature]:
                res = reg_results[feature][target]
                summary_df.append({
                    'feature': feature,
                    'target': target,
                    'countries_significant': res['n_sig_countries'],
                    'avg_r_squared': res['avg_r_squared'],
                    'avg_p_value': res['avg_p_value']
                })
    
    summary_df = pd.DataFrame(summary_df)

    del reg_results, model, features_to_remove
    clear_memory()

    return valid_features, summary_df

In [20]:
valid_features_, reg_summary_df = features_reg(lag_df, selected_features, p_threshold=0.05)


 Removing gas_prod_change_pct:
co2                        : significant 9/19 countries, avg R-sqaured = 0.073, avg p-value = 0.17421660238860998
gdp                        : significant 9/19 countries, avg R-sqaured = 0.102, avg p-value = 0.18905244840930555
primary_energy_consumption : significant 8/19 countries, avg R-sqaured = 0.080, avg p-value = 0.17697722035617824

 Removing co2_growth_abs:
co2                        : significant 8/19 countries, avg R-sqaured = 0.072, avg p-value = 0.2760782069417647
gdp                        : significant 8/19 countries, avg R-sqaured = 0.081, avg p-value = 0.2193307274502376
primary_energy_consumption : significant 8/19 countries, avg R-sqaured = 0.070, avg p-value = 0.25425630400279986

 Removing low_carbon_cons_change_twh:
co2                        : significant 6/19 countries, avg R-sqaured = 0.110, avg p-value = 0.268747243499257
gdp                        : significant 8/19 countries, avg R-sqaured = 0.132, avg p-value = 0.234581412116

In [21]:
# Add all_relevant_features if they were removed
valid_features = valid_features_.copy()
for feat in all_relevant_features:
    if feat in selected_features and feat not in valid_features:

        # If it has importance
        importance = False

        for target in TARGET_VARIABLES:
            if target in importance_scores and feat in importance_scores[target]:
                
                scores = importance_scores[target][feat]
                if (scores.get('avg_abs_correlation', 0) > 0.1 or
                    scores.get('granger_significance_rate', 0) > 0.1):

                    importance = True
        
        if importance:
            valid_features.append(feat)
            print(f"{feat} is added back")

In [22]:
feature_counts_steps['After Regression'] = len(valid_features)

### Variance Inflation Factor
Use VIF to check multicollinearity

In [23]:
"""
VIF_i = 1 / (1 - R^2_i)
Generally, if vif >= 5 or 10, this might indicate there is a strong multicollinearity, 1 indicates no multicollinearity
"""
def check_vif(lag_df, features):
    
    all_data = []
    for country in G20_COUNTRIES:
        country_data = lag_df[lag_df['country'] == country][features].dropna()
        all_data.append(country_data)

    combined_data = pd.concat(all_data, ignore_index=True)

    # Calculating VIF
    vif_data = pd.DataFrame()
    vif_data["Feature"] = features
    vif_values = []

    for i in range(len(features)):
        try:
            vif = variance_inflation_factor(combined_data.values, i)
            vif_values.append(vif)

        except:
            vif_values.append(np.inf)

    vif_data["VIF"] = vif_values

    del all_data, combined_data
    clear_memory()

    return vif_data

In [24]:
vif_data = check_vif(lag_df, valid_features)

In [25]:
high_vif_features = vif_data[vif_data['VIF'] > VIF_THRESHOLD]['Feature'].tolist()
high_vif_features

['share_global_cumulative_cement_co2',
 'electricity_generation',
 'nuclear_share_elec',
 'cumulative_flaring_co2',
 'wind_share_energy',
 'hydro_energy_per_capita',
 'oil_consumption',
 'fossil_share_energy',
 'oil_co2_per_capita',
 'solar_energy_per_capita',
 'gas_prod_per_capita',
 'temperature_change_from_n2o',
 'solar_consumption',
 'coal_share_elec',
 'share_global_luc_co2',
 'fossil_elec_per_capita',
 'co2_including_luc_per_unit_energy',
 'renewables_consumption',
 'share_global_flaring_co2',
 'co2_per_gdp',
 'nitrous_oxide',
 'flaring_co2_per_capita',
 'coal_prod_per_capita',
 'coal_elec_per_capita',
 'share_of_temperature_change_from_ghg',
 'wind_elec_per_capita',
 'other_renewables_elec_per_capita',
 'co2_including_luc_per_capita',
 'hydro_consumption',
 'low_carbon_energy_per_capita',
 'methane_per_capita',
 'oil_production',
 'oil_prod_per_capita',
 'gas_energy_per_capita',
 'gas_share_energy',
 'share_global_gas_co2',
 'land_use_change_co2_per_capita',
 'energy_per_capita'

In [26]:
features_to_remove = []
for feat in high_vif_features:

    vif_value = vif_data[vif_data['Feature'] == feat]['VIF'].values[0]
    print(f"{feat}: {vif_value:.2f}")

    if feat not in all_relevant_features:
        features_to_remove.append(feat)

    else:
        print(f"\nKeeping {feat}\n")

valid_features = [f for f in valid_features if f not in features_to_remove]

share_global_cumulative_cement_co2: 172.22
electricity_generation: 2849.23

Keeping electricity_generation

nuclear_share_elec: 1443.65
cumulative_flaring_co2: 509.13
wind_share_energy: 50.18
hydro_energy_per_capita: 914.98
oil_consumption: 27294543196184.82

Keeping oil_consumption

fossil_share_energy: 28056.21
oil_co2_per_capita: 1201.05
solar_energy_per_capita: 18.32
gas_prod_per_capita: 127.49
temperature_change_from_n2o: 200.78
solar_consumption: 45.28
coal_share_elec: 7920.51
share_global_luc_co2: 34.59
fossil_elec_per_capita: 71485708370960.25
co2_including_luc_per_unit_energy: 229.64
renewables_consumption: 5465533528362.25

Keeping renewables_consumption

share_global_flaring_co2: 28.96
co2_per_gdp: 1754.05
nitrous_oxide: 445.10
flaring_co2_per_capita: 26.56
coal_prod_per_capita: 58.87
coal_elec_per_capita: 26807140639110.09
share_of_temperature_change_from_ghg: 1926.50
wind_elec_per_capita: 39.79
other_renewables_elec_per_capita: 57.68
co2_including_luc_per_capita: 34827.63


In [27]:
valid_features

['electricity_generation',
 'oil_consumption',
 'energy_cons_change_pct',
 'oil_cons_change_pct',
 'renewables_consumption',
 'wind_cons_change_twh',
 'hydro_consumption',
 'gas_prod_change_twh',
 'energy_per_capita',
 'coal_consumption',
 'renewables_cons_change_twh',
 'low_carbon_consumption',
 'other_renewables_cons_change_twh',
 'co2_per_unit_energy',
 'fossil_fuel_consumption',
 'nuclear_consumption',
 'co2_growth_prct',
 'gas_cons_change_pct',
 'gas_consumption',
 'energy_per_gdp',
 'population']

In [28]:
feature_counts_steps['After VIF'] = len(valid_features)
feature_counts_steps['Final'] = len(valid_features)

## Analysis

In [29]:
def feature_summary(valid_features, importance_scores, reg_summary_df):

    summary = []

    for feature in valid_features:
        feature_info = {
            'feature': feature,
            'appears_in_targets': 0,
            'targets': [],
            'avg_correlation': 0,
            'avg_granger_sig': 0,
            'avg_mi': 0,
            'regression_sig': 0
        }

        corr_values = []
        granger_values = []
        mi_values = []

        for target in TARGET_VARIABLES:
            if target in importance_scores and feature in importance_scores[target]:
                feature_info['appears_in_targets'] += 1
                feature_info['targets'].append(target)

                scores = importance_scores[target][feature]
                corr_values.append(scores.get('avg_abs_correlation', 0))
                granger_values.append(scores.get('granger_significance_rate', 0))
                mi_values.append(scores.get('avg_mutual_info', 0))

        feature_info['avg_correlation'] = np.mean(corr_values) if corr_values else 0
        feature_info['avg_granger_sig'] = np.mean(granger_values) if granger_values else 0
        feature_info['avg_mi'] = np.mean(mi_values) if mi_values else 0

        reg_count = len(reg_summary_df[reg_summary_df['feature'] == feature])
        feature_info['regression_sig'] = reg_count

        summary.append(feature_info)

    summary_df = pd.DataFrame(summary)
    summary_df = summary_df.sort_values(['appears_in_targets', 'avg_correlation'], ascending=[False, False])

    return summary_df

In [30]:
summary_df = feature_summary(valid_features, importance_scores, reg_summary_df)

In [31]:
summary_df

Unnamed: 0,feature,appears_in_targets,targets,avg_correlation,avg_granger_sig,avg_mi,regression_sig
14,fossil_fuel_consumption,3,"[co2, gdp, primary_energy_consumption]",0.920757,0.473684,1.818021,3
20,population,3,"[co2, gdp, primary_energy_consumption]",0.877415,0.614035,1.702223,3
0,electricity_generation,3,"[co2, gdp, primary_energy_consumption]",0.874056,0.421053,1.355025,3
18,gas_consumption,3,"[co2, gdp, primary_energy_consumption]",0.847214,0.45614,1.475571,3
8,energy_per_capita,3,"[co2, gdp, primary_energy_consumption]",0.846111,0.245614,1.555621,3
9,coal_consumption,3,"[co2, gdp, primary_energy_consumption]",0.832639,0.368421,1.37031,3
1,oil_consumption,3,"[co2, gdp, primary_energy_consumption]",0.824997,0.350877,1.410547,3
11,low_carbon_consumption,3,"[co2, gdp, primary_energy_consumption]",0.818156,0.508772,1.249805,3
4,renewables_consumption,3,"[co2, gdp, primary_energy_consumption]",0.732484,0.491228,1.061657,3
13,co2_per_unit_energy,3,"[co2, gdp, primary_energy_consumption]",0.661884,0.263158,0.854495,3


## Visualisation

In [32]:
def plot_results(valid_features, corr_matrix, summary_df, save_dir='data_export/02_plots'):

    output_dir = os.path.join(save_dir, '02_feature_screen')
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)

    # Summary plot
    fig, axes = plt.subplots(2, 2, figsize=(20, 12))

    # Correlation
    ax = axes[0, 0]
    top20_corr = summary_df.nlargest(20, 'avg_correlation')
    ax.barh(range(len(top20_corr)), top20_corr['avg_correlation'])
    ax.set_yticks(range(len(top20_corr)))
    ax.set_yticklabels(top20_corr['feature'], fontsize=7)
    ax.set_xlabel('Average Correlation')
    ax.set_title('Top 20 Features by Correlation')
    ax.invert_yaxis()

    # Granger significance
    ax = axes[0, 1]
    top20_granger = summary_df.nlargest(20, 'avg_granger_sig')
    ax.barh(range(len(top20_granger)), top20_granger['avg_granger_sig'])
    ax.set_yticks(range(len(top20_granger)))
    ax.set_yticklabels(top20_granger['feature'], fontsize=7)
    ax.set_xlabel('Average Granger Significance Rate')
    ax.set_title('Top 20 Features by Granger Causality')
    ax.invert_yaxis()

    # Feature counts for each step
    ax = axes[1, 0]
    steps = list(feature_counts_steps.keys())
    counts = list(feature_counts_steps.values())

    ax.bar(range(len(steps)), counts)
    ax.set_xticks(range(len(steps)))
    ax.set_xticklabels(steps, rotation=30, ha='right')
    ax.set_ylabel('Number of Features')
    ax.set_title('Feature Counts for each step')

    for i, count in enumerate(counts):
        ax.text(i, count + 1, str(count), ha='center')

    plt.tight_layout()
    plt.savefig(os.path.join(output_dir, 'screening_summary.png'), bbox_inches='tight')
    plt.close()

    # Correlation heatmap sampled 30
    if len(valid_features) > 30:
        sample_features = summary_df.head(30)['feature'].tolist()
    else:
        sample_features = valid_features
    
    plt.figure(figsize=(24, 20))
    selected_corr = corr_matrix.loc[sample_features, sample_features]
    
    mask = np.triu(np.ones_like(selected_corr, dtype=bool))
    sns.heatmap(selected_corr, mask=mask, annot=True, cmap='coolwarm', 
                center=0, square=True, linewidths=0.5, annot_kws={"fontsize": 6})
    plt.xticks(rotation=30)
    plt.title('Feature Correlation Matrix After Screening')
    plt.tight_layout()
    plt.savefig(os.path.join(output_dir, 'screen_correlation_matrix.png'), bbox_inches='tight')
    plt.close()

In [33]:
plot_results(valid_features, corr_matrix, summary_df)

In [34]:
# Save valid features and summary df
def save_features_df(valid_features, summary_df, save_dir='data_export'):

    output_dir = os.path.join(save_dir, '02_results')
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)

    with open(os.path.join(output_dir, 'selected_features.pkl'), 'wb') as f:
        pickle.dump(valid_features, f)

    summary_df.to_csv(os.path.join(output_dir, 'screen_summary.csv'), index=False)

In [35]:
save_features_df(valid_features, summary_df)