In [None]:
import os
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

# Load the crop-yield dataset from the working directory.
ds = pd.read_csv("crop_yield_dataset.csv")

# All figures below are written into this folder; create it up front.
os.makedirs("plots", exist_ok=True)

# Quick look at the data: first rows, schema, and summary statistics.
print(ds.head())
# DataFrame.info() prints its report itself and returns None, so it must not
# be wrapped in print() (that would emit a stray "None" line).
ds.info()
print(ds.describe())

# Set a nice style
sns.set_theme(style="whitegrid")

# Make all plots bigger by default
plt.rcParams["figure.figsize"] = (10, 6)

# Line plot of crop yield against temperature.
# Note: the printed label mentions time, but the x-axis here is Temperature.
print("1. LINE PLOT (trend over time)")
line_fig, line_ax = plt.subplots()
sns.lineplot(data=ds, x="Temperature", y="Crop_Yield", ax=line_ax)
line_ax.set_title("Crop Yield vs Temperature (line)")
# Save before show(): some backends release the figure once it is displayed.
line_fig.savefig(
    "plots/lineplot_yield_vs_temp.png",
    dpi=300,
    bbox_inches="tight",
)
plt.show()
plt.close(line_fig)

# Scatter plot of crop yield against humidity, one colour per crop type.
print("2. SCATTER PLOT (relationship between two variables)")
scatter_fig, scatter_ax = plt.subplots()
sns.scatterplot(
    data=ds,
    x="Humidity",
    y="Crop_Yield",
    hue="Crop_Type",
    ax=scatter_ax,
)
scatter_ax.set_title("Crop Yield vs Humidity (scatter, colored by crop type)")
# Write the image to disk first, then display and release the figure.
scatter_fig.savefig(
    "plots/scatter_yield_vs_humidity.png",
    dpi=300,
    bbox_inches="tight",
)
plt.show()
plt.close(scatter_fig)

# Histogram of soil pH (20 bins) with a kernel-density curve overlaid.
print("3. HISTOGRAM / DISTRIBUTION")
hist_fig, hist_ax = plt.subplots()
soil_ph = ds["Soil_pH"]
sns.histplot(soil_ph, bins=20, kde=True, ax=hist_ax)
hist_ax.set_title("Distribution of Soil pH")
# Write the image to disk first, then display and release the figure.
hist_fig.savefig("plots/hist_soil_pH.png", dpi=300, bbox_inches="tight")
plt.show()
plt.close(hist_fig)

# Box plot of crop yield grouped by soil type.
print("4. BOX PLOT (spread + outliers by category)")
box_fig, box_ax = plt.subplots()
sns.boxplot(data=ds, x="Soil_Type", y="Crop_Yield", ax=box_ax)
box_ax.set_title("Crop Yield by Soil Type")
# Write the image to disk first, then display and release the figure.
box_fig.savefig(
    "plots/boxplot_yield_by_soiltype.png",
    dpi=300,
    bbox_inches="tight",
)
plt.show()
plt.close(box_fig)

# Bar plot of soil quality per crop type; error bars are switched off.
print("5. BAR PLOT (average values across categories)")
bar_fig, bar_ax = plt.subplots()
sns.barplot(
    data=ds,
    x="Crop_Type",
    y="Soil_Quality",
    errorbar=None,
    ax=bar_ax,
)
bar_ax.set_title("Average Soil Quality by Crop Type")
# Write the image to disk first, then display and release the figure.
bar_fig.savefig(
    "plots/barplot_soilquality_by_crop.png",
    dpi=300,
    bbox_inches="tight",
)
plt.show()
plt.close(bar_fig)

# Annotated heatmap of pairwise correlations between the selected columns.
print("6. HEATMAP (correlation matrix)")
heatmap_columns = [
    "Soil_pH",
    "Temperature",
    "Humidity",
    "Wind_Speed",
    "N",
    "P",
    "K",
    "Crop_Yield",
    "Soil_Quality",
]
corr = ds[heatmap_columns].corr()
# This figure is larger than the default so the cell annotations stay legible.
heatmap_fig, heatmap_ax = plt.subplots(figsize=(12, 8))
sns.heatmap(corr, annot=True, cmap="coolwarm", fmt=".2f", ax=heatmap_ax)
heatmap_ax.set_title("Correlation Heatmap")
heatmap_fig.savefig(
    "plots/heatmap_correlations.png",
    dpi=300,
    bbox_inches="tight",
)
plt.show()
plt.close(heatmap_fig)

# Pairwise scatter plots for a small subset of numeric features.
print("7. PAIRPLOT (scatterplots for all numeric relationships)")
pair_columns = ["Temperature", "Humidity", "Soil_pH", "Crop_Yield"]
# Colour by the categorical Crop_Type rather than the continuous Crop_Yield:
# a continuous hue creates one level per distinct value (very slow and
# unreadable on this many rows) and removes Crop_Yield from the plotted grid.
# `vars` pins the plotted axes so Crop_Yield stays in the grid.
pair = sns.pairplot(
    ds[pair_columns + ["Crop_Type"]],
    vars=pair_columns,
    hue="Crop_Type",
)
# `.figure` is the supported accessor; `.fig` is a deprecated alias.
pair.figure.suptitle("Pairplot Example (subset of features)", y=1.02)
pair.savefig("plots/pairplot_features.png", dpi=300, bbox_inches="tight")
plt.show()
plt.close(pair.figure)


         Date Crop_Type Soil_Type  Soil_pH  Temperature   Humidity  \
0  2014-01-01     Wheat     Peaty     5.50     9.440599  80.000000   
1  2014-01-01      Corn     Loamy     6.50    20.052576  79.947424   
2  2014-01-01      Rice     Peaty     5.50    12.143099  80.000000   
3  2014-01-01    Barley     Sandy     6.75    19.751848  80.000000   
4  2014-01-01   Soybean     Peaty     5.50    16.110395  80.000000   

   Wind_Speed     N     P     K  Crop_Yield  Soil_Quality  
0   10.956707  60.5  45.0  31.5    0.000000     22.833333  
1    8.591577  84.0  66.0  50.0  104.871310     66.666667  
2    7.227751  71.5  54.0  38.5    0.000000     27.333333  
3    2.682683  50.0  40.0  30.0   58.939796     35.000000  
4    7.696070  49.5  45.0  38.5   32.970413     22.166667  
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 36520 entries, 0 to 36519
Data columns (total 12 columns):
 #   Column        Non-Null Count  Dtype  
---  ------        --------------  -----  
 0   Date          36520