# Climate Prediction Model with K-Nearest Neighbors (KNN)


## 1. Imports and Setup
In this section, we will import the necessary libraries for developing the model.

In [None]:
import pandas as pd
from sklearn.preprocessing import StandardScaler
from sklearn.neighbors import KNeighborsClassifier
import logging

# Configure the root logger at INFO level so the info/error messages logged
# by the predictor's train() and predict() methods are displayed.
logging.basicConfig(level=logging.INFO)


## 2. KNNClimatePredictor Class
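The class below trains a separate KNN classifier (with its own feature scaler) for each distinct age value, and predicts a point's climate using the classifier trained for that point's age.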

In [None]:
class KNNClimatePredictor:
    """Per-age K-Nearest-Neighbors climate classifier.

    A separate ``StandardScaler`` / ``KNeighborsClassifier`` pair is fitted
    for every distinct value of the ``age`` column. Predictions are made with
    the pair that was fitted for the requested age.

    Climate labels are encoded with ``climate_mapping`` before fitting, so
    ``predict`` returns the encoded value (e.g. 1 for ``'h'``), not the
    original letter.
    """

    # Feature columns, in the order they are fed to the scaler and classifier.
    FEATURE_COLUMNS = [
        'lat',
        'long',
        'paleontology_weight',
        'lithology_weight',
        'palynomorphs_weight',
        'geochemistry_weight',
    ]

    def __init__(self, n_neighbors=3):
        """Create an untrained predictor.

        Args:
            n_neighbors: Number of neighbors used by every per-age classifier.
        """
        self.models_by_age = {}
        self.scalers_by_age = {}
        self.climate_mapping = {'h': 1, 's': 2, 't': 3, 'd': 4}
        self.n_neighbors = n_neighbors

    def train(self, climate_data):
        """Fit one scaler and one KNN classifier per age found in the data.

        The input frame is NOT modified: the encoded ``climate`` column is
        written to a new frame, so the same frame can be passed to ``train``
        again (or inspected afterwards) with its original labels intact.

        Rows whose climate label is not a key of ``climate_mapping`` are
        dropped. Rows with a missing age are skipped by the grouping.

        Any exception is logged and swallowed; models fitted before the
        failure remain stored.

        Args:
            climate_data: DataFrame with the ``FEATURE_COLUMNS`` columns plus
                ``age`` and ``climate``.
        """
        try:
            # assign() returns a new frame, leaving the caller's data untouched.
            encoded = climate_data.assign(
                climate=climate_data['climate'].map(self.climate_mapping)
            ).dropna(subset=['climate'])

            # sort=False keeps the ages in order of first appearance.
            for age, age_data in encoded.groupby('age', sort=False):
                features = age_data[self.FEATURE_COLUMNS]
                target = age_data['climate']

                scaler = StandardScaler()
                features_scaled = scaler.fit_transform(features)

                knn = KNeighborsClassifier(n_neighbors=self.n_neighbors)
                knn.fit(features_scaled, target)

                self.models_by_age[age] = knn
                self.scalers_by_age[age] = scaler
            logging.info('Model training completed successfully.')
        except Exception as e:
            logging.error(f'Error during training: {e}')

    def predict(self, lat, lon, paleontology_weight, lithology_weight, palynomorphs_weight, geochemistry_weight, age):
        """Predict the encoded climate for a single point.

        Args:
            lat: Latitude of the point.
            lon: Longitude of the point (stored under the ``long`` column).
            paleontology_weight: Value of the ``paleontology_weight`` feature.
            lithology_weight: Value of the ``lithology_weight`` feature.
            palynomorphs_weight: Value of the ``palynomorphs_weight`` feature.
            geochemistry_weight: Value of the ``geochemistry_weight`` feature.
            age: Age whose model should be used; must have been seen in
                training.

        Returns:
            The predicted climate as encoded by ``climate_mapping``, or
            ``None`` if no model exists for ``age`` or any other error occurs
            (the error is logged).
        """
        try:
            if age not in self.models_by_age:
                raise ValueError(f'No trained model found for age: {age}')

            scaler = self.scalers_by_age[age]
            # Build a one-row frame with the training column names so the
            # scaler sees the same feature layout it was fitted on.
            features = pd.DataFrame(
                [[lat, lon, paleontology_weight, lithology_weight, palynomorphs_weight, geochemistry_weight]],
                columns=self.FEATURE_COLUMNS,
            )
            features_scaled = scaler.transform(features)

            knn = self.models_by_age[age]
            prediction = knn.predict(features_scaled)
            return prediction[0]
        except Exception as e:
            logging.error(f'Error during prediction: {e}')
            return None


## 3. Load Climate Data
In this section, we will load the input data that will be used to train and test the KNN model.

In [None]:
# Read the training points from disk and preview the first rows
training_points_path = 'dataset/points.csv'
data = pd.read_csv(training_points_path)
display(data.head())


## 4. Train the Model
We will train the KNN model using the loaded climate data.

In [None]:
# Example usage
knn_predictor = KNNClimatePredictor(n_neighbors=3)
# Train on a copy so the loaded `data` frame keeps its original climate
# labels and this cell can be re-run without reloading the CSV.
knn_predictor.train(data.copy())


## 5. Predict with the Trained Model
After training, we can make climate predictions based on latitude, longitude, the four proxy weights (paleontology, lithology, palynomorphs, geochemistry), and age.

In [None]:
# Example prediction for a single point; every input is passed by keyword
example_point = dict(
    lat=10.0,
    lon=20.0,
    paleontology_weight=0,
    lithology_weight=3,
    palynomorphs_weight=2,
    geochemistry_weight=1,
    age=140,
)
prediction = knn_predictor.predict(**example_point)
print(f'Predicted climate: {prediction}')


## 6. Plotting Points on the Map
Let's plot the prediction points on a map.

In [None]:
import geopandas as gpd
import matplotlib.pyplot as plt
import pandas as pd

# Read the points to classify and preview them
points_to_predict = pd.read_csv('dataset/points_to_predict.csv')
display(points_to_predict.head())


def _predict_row(row):
    """Run the trained predictor on one row of the points table."""
    return knn_predictor.predict(
        lat=row['lat'],
        lon=row['long'],
        paleontology_weight=row['paleontology_weight'],
        lithology_weight=row['lithology_weight'],
        palynomorphs_weight=row['palynomorphs_weight'],
        geochemistry_weight=row['geochemistry_weight'],
        age=row['age'],
    )


# Predict row by row and store the result in a new column
points_to_predict['predicted_climate'] = points_to_predict.apply(_predict_row, axis=1)

# Preview the table again, now including the predictions
display(points_to_predict.head())

# Build point geometries (x = longitude, y = latitude) and wrap the table
# in a GeoDataFrame
point_geometry = gpd.points_from_xy(points_to_predict['long'], points_to_predict['lat'])
gdf = gpd.GeoDataFrame(points_to_predict, geometry=point_geometry)

# Draw country outlines and overlay the points
world = gpd.read_file('data/ne_110m_admin_0_countries.shp')  # Update the path to the local file
fig, ax = plt.subplots(figsize=(15, 10))
world.boundary.plot(ax=ax)
gdf.plot(ax=ax, color='red', markersize=5)
ax.set_title('Prediction Points on the Map')
ax.set_xlabel('Longitude')
ax.set_ylabel('Latitude')
plt.show()