In [None]:
import csv
import io

import joblib
import pandas as pd
from sklearn.compose import ColumnTransformer
from sklearn.ensemble import RandomForestRegressor
from sklearn.metrics import mean_squared_error, r2_score
from sklearn.model_selection import train_test_split
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import OneHotEncoder

DATA_PATH = 'data/global_wf_175_crops_average_2010_2019.csv'


def load_footprint_table(path, marker='crop_code'):
    """Load the crop water-footprint CSV, skipping its free-text preamble.

    The file begins with descriptive rows (title, DOI, an empty row) before
    the actual column header. A hard-coded ``skiprows`` count is brittle: an
    off-by-a-few value turns the title line into the header and pushes the
    real column names into the data. Instead, the header is located by
    content as the first row whose first cell equals *marker*. Everything
    from that row onward is then parsed with ``pd.read_csv``.

    Args:
        path: Path to the CSV file.
        marker: Name of the first column of the real header row.

    Returns:
        A DataFrame whose columns are the names from the located header row.

    Raises:
        ValueError: If no row starts with *marker*.
    """
    with open(path, encoding='utf-8-sig', newline='') as handle:
        lines = handle.readlines()

    for index, line in enumerate(lines):
        # Parse just this line with the csv module so a quoted header cell
        # ("crop_code") matches too; blank lines yield an empty row.
        first_row = next(csv.reader([line]), [])
        if first_row and first_row[0].strip() == marker:
            # Hand pandas only the header + data, so preamble rows with a
            # different number of fields cannot confuse its tokenizer.
            return pd.read_csv(io.StringIO(''.join(lines[index:])))

    raise ValueError(f"No header row starting with {marker!r} found in {path!r}")


if __name__ == '__main__':  # true in a Jupyter kernel and when run as a script
    df = load_footprint_table(DATA_PATH)
    print(df.head())
    # Drop rows with missing values (or use df.fillna() if you prefer)
    df = df.dropna()

# Features and target.
# NOTE(review): wf_tot_m3_t is presumably the sum of wfg_m3_t + wfb_cr_m3_t +
# wfb_i_m3_t; if so, using those components as features lets the model simply
# learn an addition and the evaluation scores will be inflated. Confirm against
# the dataset documentation and drop the components if that holds.
target_column = 'wf_tot_m3_t'

# Identify categorical and numerical features
categorical_features = ['crop_name', 'crop_group']
numerical_features = ['production_t', 'wfg_m3_t', 'wfb_cr_m3_t', 'wfb_i_m3_t']
feature_columns = categorical_features + numerical_features

# Fail early with the available column names instead of a bare KeyError,
# which makes a mis-parsed header (wrong skip count, renamed column) obvious.
missing_columns = [c for c in feature_columns + [target_column] if c not in df.columns]
if missing_columns:
    raise KeyError(
        f"Columns {missing_columns} not found; available columns: {list(df.columns)}"
    )

# Select the features explicitly rather than "everything except the target",
# so identifier columns such as crop_code are not fed to the model as numbers.
X = df[feature_columns]
y = df[target_column]

# Preprocessing: one-hot encode the categorical columns and pass the declared
# numerical columns through unchanged. Any other column present in X is
# dropped, so an identifier such as crop_code cannot slip into the model as a
# numeric feature. handle_unknown='ignore' encodes categories unseen during
# fit as all zeros instead of raising at predict time.
preprocessor = ColumnTransformer(
    transformers=[
        ('cat', OneHotEncoder(handle_unknown='ignore'), categorical_features),
        ('num', 'passthrough', numerical_features),
    ],
    remainder='drop',
)

# Pipeline: preprocessing followed by a random-forest regressor; the fixed
# random_state makes training results reproducible.
model = Pipeline(steps=[
    ('preprocessor', preprocessor),
    ('regressor', RandomForestRegressor(n_estimators=100, random_state=42)),
])

# Hold out 20% of the rows for evaluation; the fixed random_state makes the
# split reproducible between runs.
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42,
)

# Fit preprocessing + regressor on the training rows only, then predict the
# held-out rows (Pipeline.fit returns the fitted pipeline itself).
y_pred = model.fit(X_train, y_train).predict(X_test)

# Report regression metrics on the held-out rows.
print("Mean Squared Error:", mean_squared_error(y_true=y_test, y_pred=y_pred))
print("R^2 Score:", r2_score(y_true=y_test, y_pred=y_pred))

# Persist the fitted `model` with joblib so it can be reloaded for prediction
# without retraining.
# NOTE: joblib files are pickles — only load them from trusted sources.
model_path = 'water_footprint_predictor.pkl'
joblib.dump(model, model_path)
print(f"Model saved as {model_path}")

  Global average production and unit water footprints of crops over 2010–2019  \
0  DOI: https://doi.org/10.4121/7b45bcc6-686b-404...                            
1                                                NaN                            
2                                          crop_code                            
3                                                 56                            
4                                                236                            

     Unnamed: 1  Unnamed: 2    Unnamed: 3   Unnamed: 4   Unnamed: 5  \
0           NaN         NaN           NaN          NaN          NaN   
1           NaN         NaN           NaN          NaN          NaN   
2     crop_name  crop_group  production_t     wfg_m3_t  wfb_cr_m3_t   
3  Maize (corn)     Cereals    1024541157  657.1664267  7.484206039   
4    Soya beans   Oil crops   303854111.5  1548.750728   30.2771554   

    Unnamed: 6   Unnamed: 7  
0          NaN          NaN  
1          NaN          Na

KeyError: "['wf_tot_m3_t'] not found in axis"