In [1]:
import numpy as np
from sklearn.mixture import GaussianMixture
from sklearn.preprocessing import FunctionTransformer
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import adjusted_rand_score, normalized_mutual_info_score

In [2]:
# Load herbarium records and build the feature matrix used for clustering.
# NOTE: only the Latitude/Longitude columns are used as features here
# (the earlier heading mentioned morphological traits 1-16, which this cell does not use).

# Records whose putative species is excluded from the analysis.
EXCLUDED_SPP = ["Quercus sp.", "Quercus buckleyi"]

# Collapse putative species names to short codes; any name not listed -> "N/A".
SPP_CODES = {
    "Quercus shumardii": "S",
    "Quercus shumardii var. acerifolia first, Quercus shumardii later": "S",
    "Quercus rubra": "R",
    "Quercus shumardii var. acerifolia": "A",
}

df = pd.read_csv('herb_data.csv')
# .copy() so the column assignments below act on an independent frame
# (avoids pandas' SettingWithCopyWarning on a filtered view).
df = df[~df["Putative_spp"].isin(EXCLUDED_SPP)].copy()
df["spp"] = df["Putative_spp"].map(SPP_CODES).fillna("N/A")

# Use absolute longitude so every coordinate is positive before the log transform.
df["Longitude"] = df["Longitude"].abs()
data = df[["Latitude", "Longitude"]].dropna()

# Species codes aligned row-for-row with `data` (dropna may remove rows), so that
# cluster labels can later be compared against them (e.g. adjusted_rand_score).
spp_labels = df.loc[data.index, "spp"]

# np.log yields NaN/-inf for non-positive values; fail loudly instead.
assert (data > 0).all().all(), "log-transform requires strictly positive Latitude/Longitude"

# Log-transform the data
transformer = FunctionTransformer(np.log, validate=True)
data_transformed = transformer.fit_transform(data)

# Standardize the data (important for PCA)
scaler = StandardScaler()
data_standardized = scaler.fit_transform(data_transformed)

In [None]:
# Cluster the standardized log(lat/lon) features with a Gaussian mixture model,
# then visualise the clusters in principal-component space.
# To compare against k-means, swap in
# sklearn.cluster.KMeans(n_clusters=N_CLUSTERS, random_state=SEED).
N_CLUSTERS = 3
SEED = 0

gmm = GaussianMixture(n_components=N_CLUSTERS, random_state=SEED)
gmm.fit(data_standardized)
cluster_labels = gmm.predict(data_standardized)

# With only two input features, a 2-component PCA is an orthogonal rotation of
# the standardized data (no dimensions are dropped); it is used for display only.
pca = PCA(n_components=2)
data_pca = pca.fit_transform(data_standardized)
var_ratio = pca.explained_variance_ratio_

fig, ax = plt.subplots(figsize=(8, 6))
scatter = ax.scatter(
    data_pca[:, 0], data_pca[:, 1],
    c=cluster_labels, cmap='viridis', edgecolor='k', alpha=0.7,
)
ax.set_title('PCA of Log-Scaled Latitude/Longitude with GMM Clustering')
ax.set_xlabel(f'Principal Component 1 ({var_ratio[0]:.1%} of variance)')
ax.set_ylabel(f'Principal Component 2 ({var_ratio[1]:.1%} of variance)')
# Cluster ids are categorical, so show a discrete legend rather than a continuous colorbar.
ax.legend(*scatter.legend_elements(), title='Cluster')
plt.show()