In [None]:
from statsmodels.tsa.arima.model import ARIMA
import seaborn as sns
import matplotlib.pyplot as plt
import pandas as pd
import warnings

# Suppress every Python warning for the rest of the session (this also hides
# any warnings emitted by the model fitting further down in the notebook).
warnings.filterwarnings("ignore")

# Load the raw records, parse the Date column, then derive calendar fields.
ds = pd.read_csv("crop_yield_dataset.csv")
ds["Date"] = pd.to_datetime(ds["Date"])
ds = ds.assign(
    Year=ds["Date"].dt.year,
    Month=ds["Date"].dt.month,
)

# Crop types to analyse; each one gets its own colour from the palette.
crops = [
    "Tomato",
    "Wheat",
    "Corn",
    "Rice",
    "Barley",
    "Soybean",
    "Cotton",
    "Sugarcane",
    "Potato",
    "Sunflower",
]

palette = sns.color_palette("tab10", n_colors=len(crops))

# --------------------
# ARIMA
# --------------------
# Each crop's mean yield per year is modelled independently with
# ARIMA(p, d, q) = (1, 1, 1):
#   AR (p=1): auto-regression on the previous year's yield.
#   I  (d=1): difference once to remove a trend (e.g. steady growth each year).
#   MA (q=1): use one lag of the model's past forecast errors.
# Steps, per crop:
#   - draw the observed series;
#   - draw a FORECAST_STEPS-year forecast with its 95% confidence band on the
#     same figure;
#   - print the forecast values.

ARIMA_ORDER = (1, 1, 1)
FORECAST_STEPS = 3

fitted = {}  # crop name -> fitted statsmodels results, summarised at the end

plt.figure(figsize=(14, 7))

for i, crop in enumerate(crops):
    # Mean yield per calendar year for this crop.
    series = ds[ds["Crop_Type"] == crop].groupby("Year")["Crop_Yield"].mean()
    # Integer years -> Jan-1 timestamps so matplotlib draws a date axis.
    series.index = pd.to_datetime(series.index.astype(str))
    color = palette[i]

    # Observed
    plt.plot(series.index, series.values, label=f"{crop} (observed)", color=color)

    # Best effort: a crop whose model cannot be fitted/forecast is reported
    # and skipped so the remaining crops are still plotted.
    try:
        model = ARIMA(series, order=ARIMA_ORDER)
        results = model.fit()
        forecast_res = results.get_forecast(steps=FORECAST_STEPS)
    except Exception as e:
        print(f"ARIMA failed for {crop}: {e}")
        continue
    fitted[crop] = results

    mean_forecast = forecast_res.predicted_mean
    conf_int = forecast_res.conf_int()  # default alpha=0.05 -> 95% interval

    # Forecast step k is the k-th year AFTER the last observed year.
    # Year-start stamps match the Jan-1 timestamps of the observed index.
    # Do not roll to the end of the last observed year: that labels the first
    # forecast with the last observed year.
    future_index = pd.date_range(
        start=series.index[-1] + pd.DateOffset(years=1),
        periods=FORECAST_STEPS,
        freq="YS",
    )

    for year, pred, lower, upper in zip(
        future_index.year, mean_forecast, conf_int.iloc[:, 0], conf_int.iloc[:, 1]
    ):
        print(f"{crop} {year}: {pred:.2f} (95% CI: {lower:.2f} – {upper:.2f})")

    # Prepend the last observation so the dashed forecast line and the
    # confidence band start exactly where the observed line ends.
    last_date, last_value = series.index[-1], series.values[-1]
    forecast_index = [last_date, *future_index]
    forecast_values = [last_value, *mean_forecast]
    lower_ci = [last_value, *conf_int.iloc[:, 0]]
    upper_ci = [last_value, *conf_int.iloc[:, 1]]

    # Forecast + interval
    plt.plot(forecast_index, forecast_values, linestyle="--", marker="o",
             color=color, label=f"{crop} (forecast)")
    plt.fill_between(forecast_index, lower_ci, upper_ci, color=color, alpha=0.2)

plt.title("Crop Yields with ARIMA Forecasts (Statsmodels)")
plt.xlabel("Year")
plt.ylabel("Crop Yield")
plt.legend(loc="upper left", bbox_to_anchor=(1, 1))
plt.grid(True)
plt.tight_layout()  # keep the legend (anchored outside the axes) inside the figure
plt.show()

# Fit summary for every crop whose model was fitted, each labelled with its
# crop name; nothing is printed when every fit failed.
for crop, res in fitted.items():
    print(f"===== {crop} =====")
    print(res.summary())