In [1]:
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LinearRegression
from sklearn.metrics import mean_squared_error
import pickle

# Read the raw dataset. Its first two columns ("Unnamed: 0.1", "Unnamed: 0")
# are auto-named integer columns that the cleaning cell below discards.
DATA_PATH = 'energy.csv'
df = pd.read_csv(DATA_PATH)

df.head()



Unnamed: 0.1,Unnamed: 0,Country,Energy_type,Year,Energy_consumption,Energy_production,GDP,Population,Energy_intensity_per_capita,Energy_intensity_by_GDP,CO2_emission
0,0,World,all_energy_types,1980,292.89979,296.337228,27770.910281,4298127.0,68.145921,10.547,4946.62713
1,1,World,coal,1980,78.656134,80.114194,27770.910281,4298127.0,68.145921,10.547,1409.790188
2,2,World,natural_gas,1980,53.865223,54.761046,27770.910281,4298127.0,68.145921,10.547,1081.593377
3,3,World,petroleum_n_other_liquids,1980,132.064019,133.111109,27770.910281,4298127.0,68.145921,10.547,2455.243565
4,4,World,nuclear,1980,7.5757,7.5757,27770.910281,4298127.0,68.145921,10.547,0.0


In [None]:

# Drop the auto-named "Unnamed: ..." columns, then drop every row that has a
# missing value in any remaining column. Chained (no inplace=True) so the
# cell reads as one expression and re-running it yields the same frame.
# NOTE(review): dropna() considers every column, so a row missing e.g. GDP or
# CO2_emission is discarded even though the model below only uses Year,
# Country and Energy_consumption. Consider dropna(subset=[...]) if that
# loses useful rows.
df = df.loc[:, ~df.columns.str.contains('^Unnamed')].dropna()

# Sanity check: every column should now report zero missing values.
missing_counts = df.isnull().sum()
assert missing_counts.eq(0).all(), "missing values remain after cleaning"
missing_counts


Country                        0
Energy_type                    0
Year                           0
Energy_consumption             0
Energy_production              0
GDP                            0
Population                     0
Energy_intensity_per_capita    0
Energy_intensity_by_GDP        0
CO2_emission                   0
dtype: int64


In [None]:
# Keep only rows whose Energy_type is 'all_energy_types' (presumably the
# aggregate across energy sources). .copy() gives an independent frame so
# later edits to df_all never write back into df.
is_total = df['Energy_type'].eq('all_energy_types')
df_all = df.loc[is_total].copy()

df_all.head()


Unnamed: 0,Country,Energy_type,Year,Energy_consumption,Energy_production,GDP,Population,Energy_intensity_per_capita,Energy_intensity_by_GDP,CO2_emission
0,World,all_energy_types,1980,292.89979,296.337228,27770.910281,4298127.0,68.145921,10.547,4946.62713
1320,United States,all_energy_types,1980,78.021113,67.146595,7080.75,227119.0,343.525258,11.018764,4946.62713
1386,World,all_energy_types,1981,289.401022,291.287773,28665.819138,4377060.0,66.11768,10.095683,18701.97439
1434,Argentina,all_energy_types,1981,1.669763,1.671288,430.09,28398.0,58.798626,3.882358,93.029914
1452,Australia,all_energy_types,1981,3.049886,4.142083,389.0065,14957.07,203.909295,7.840192,216.00701


In [None]:

# Target: energy consumption. Features: the year plus one indicator column
# per country (pd.get_dummies on Country -> 'Country_<name>').
y = df_all['Energy_consumption']
X = df_all[['Year', 'Country']]
X_encoded = pd.get_dummies(X, columns=['Country'])

# Training-time column order; predict_energy_consumption reindexes new inputs
# to exactly these columns.
# NOTE(review): every country gets its own column and LinearRegression also
# fits an intercept, so the design matrix is rank-deficient. Predictions are
# still well-defined, but individual coefficients are not identifiable.
# Consider drop_first=True if coefficients matter.
dummy_columns = X_encoded.columns

print("Encoded feature columns:", dummy_columns, sep="\n")


Encoded feature columns:
Index(['Year', 'Country_Afghanistan', 'Country_Albania', 'Country_Algeria',
       'Country_American Samoa', 'Country_Angola',
       'Country_Antigua and Barbuda', 'Country_Argentina', 'Country_Armenia',
       'Country_Aruba',
       ...
       'Country_United States', 'Country_Uruguay', 'Country_Uzbekistan',
       'Country_Vanuatu', 'Country_Venezuela', 'Country_Vietnam',
       'Country_World', 'Country_Yemen', 'Country_Zambia', 'Country_Zimbabwe'],
      dtype='object', length=200)


In [None]:

# Hold out 20% of rows (seeded for reproducibility) for evaluation.
# NOTE(review): this is a random row split, so test rows come from the same
# years as training rows. If the goal is forecasting future years, a
# time-based split would give a more honest error estimate.
X_train, X_test, y_train, y_test = train_test_split(
    X_encoded, y, test_size=0.2, random_state=42
)

# LinearRegression.fit returns the fitted estimator itself.
model = LinearRegression().fit(X_train, y_train)

y_pred = model.predict(X_test)
mse = mean_squared_error(y_test, y_pred)
print("Mean Squared Error on test set:", mse)


Mean Squared Error on test set: 78.58470168639914


In [None]:
def predict_energy_consumption(year, country, model=None, dummy_columns=None):
    """Predict energy consumption for one country in one year.

    The input row is one-hot encoded the same way as the training features
    (``pd.get_dummies`` on ``Country``). It is then aligned to the training
    column order, so the model sees exactly the columns it was fitted on.

    Parameters
    ----------
    year : int
        Year to predict for.
    country : str
        Country name exactly as it appears in the training ``Country`` column
        (e.g. ``'World'``).
    model : fitted regressor, optional
        Any object with a scikit-learn style ``predict`` method. Defaults to
        the notebook-level ``model``, looked up at call time so a retrained
        model is used without re-running this cell.
    dummy_columns : sequence of str, optional
        Training feature columns in order (``'Year'`` plus
        ``'Country_<name>'``). Defaults to the notebook-level
        ``dummy_columns``, looked up at call time.

    Returns
    -------
    The model's prediction for the single input row.

    Raises
    ------
    ValueError
        If ``country`` has no indicator column in ``dummy_columns``. Without
        this check, an unseen country would be encoded as all zeros and
        silently receive a meaningless prediction.
    """
    # Resolve defaults lazily: binding `model=model` in the signature would
    # freeze whatever model existed when this cell was last executed.
    if model is None:
        model = globals()['model']
    if dummy_columns is None:
        dummy_columns = globals()['dummy_columns']

    # get_dummies names indicator columns '<column>_<value>'.
    country_column = f'Country_{country}'
    if country_column not in dummy_columns:
        raise ValueError(
            f'Unknown country {country!r}: no {country_column!r} column '
            'in the training features.'
        )

    input_df = pd.DataFrame({'Year': [year], 'Country': [country]})
    input_encoded = pd.get_dummies(input_df, columns=['Country'])
    # Add every other country's indicator as 0 and match the training order.
    input_encoded = input_encoded.reindex(columns=dummy_columns, fill_value=0)

    prediction = model.predict(input_encoded)
    return prediction[0]


# In-sample spot check against the first data row (World, 1980, consumption
# ~292.9). NOTE(review): the recorded prediction is far from that value. This
# is likely because the model shares one Year slope across all countries,
# which underfits World's much steeper trend.
pred = predict_energy_consumption(year=1980, country='World')
print("Predicted Energy Consumption for World in 1980:", pred)


Predicted Energy Consumption for World in 1980: 418.6786193847656


In [None]:

# Persist the fitted model and the training column order. Both are needed to
# rebuild predict_energy_consumption's input encoding in another session.
# Only unpickle these files from a trusted source, because pickle.load can
# execute arbitrary code.
artifacts = (
    ('energy_model.pkl', model),
    ('dummy_columns.pkl', dummy_columns),
)
for path, obj in artifacts:
    with open(path, 'wb') as f:
        pickle.dump(obj, f)


NameError: name 'df' is not defined