# GAM modelling of MYDeviation

In [18]:
import pandas as pd
import numpy as np

import matplotlib.pyplot as plt
import seaborn as sns

import statsmodels.api as sm
import statsmodels.formula.api as smf
import scipy.stats as stats

from pygam import LinearGAM, s, f

In [19]:
# Load data
# Alternative input files used in earlier runs (kept for reference):
#   ../Data/MergedData/HeatApproachCleanedYieldDataTestQuantile61.csv
#   ../Data/MergedData/HeatApproachCleanedYieldDataTestQuantile67.csv

col_keep = ["Date", "FarmName_Pseudo", "SE_Number", "LactationNumber", "BreedName", "DaysInMilk", "DailyYield", "ExpectedYield", "MeanTHI_adj", "YearSeason"]

# Parse only the columns the analysis uses instead of loading every column and then
# discarding most of them. usecols keeps the file's column order, so the trailing
# [col_keep] restores the order given in col_keep.
milk_data = pd.read_csv(
    "../Data/MergedData/QuantileRerunTHI61.csv",
    usecols=col_keep,
    low_memory=False,
)[col_keep]

In [None]:
# Make HYS effect: the farm pseudonym followed directly by the year-season code,
# both converted to strings first (no separator between the two parts).
farm_label = milk_data["FarmName_Pseudo"].astype(str)
season_label = milk_data["YearSeason"].astype(str)
milk_data["HYS"] = farm_label.str.cat(season_label)
milk_data

In [None]:
# Lagged THI: the MeanTHI_adj value 1, 2 and 3 records earlier within the same
# cow (SE_Number) and lactation (LactationNumber).
#
# shift() works on row order, and nothing earlier in the notebook sorts the rows,
# so the lags are taken on a view sorted chronologically within each cow-lactation.
# Assigning the result back aligns on the index, which leaves milk_data's own row
# order untouched (this relies on milk_data having a unique index).
# NOTE(review): assumes "Date" parses with pd.to_datetime — confirm the CSV format.
# NOTE(review): a lag of k is k *records* back; that equals k days only if every
# cow-lactation has one row per consecutive day — confirm against the data.
thi_by_lactation = (
    milk_data
    .assign(_sort_date=pd.to_datetime(milk_data["Date"]))
    .sort_values(["SE_Number", "LactationNumber", "_sort_date"])
    .groupby(["SE_Number", "LactationNumber"])["MeanTHI_adj"]
)
for lag_days in (1, 2, 3):
    milk_data[f"THI{lag_days}d"] = thi_by_lactation.shift(lag_days)
milk_data

In [22]:
# Make Parity: lactation numbers below 3 keep their own value, every lactation
# number of 3 or more is pooled into class 3.
# The previous upper bound of 8 is dropped so that lactations above 8 are pooled too,
# instead of surviving as separate parity levels of their own.
milk_data["Parity"] = milk_data["LactationNumber"].clip(upper=3)

In [None]:
# MYDeviation: recorded daily yield minus the expected yield, row by row.
observed_yield = milk_data["DailyYield"]
milk_data["MYDeviation"] = observed_yield.sub(milk_data["ExpectedYield"])
milk_data

In [24]:
# Integer code per breed name; any breed name not listed here gets code 0.
breed_codes = {'NRDC': 1, 'SLB': 2, 'SJB': 3, 'Dairy crosses': 4}

# One boolean condition per listed breed, in the same order as the codes
conditions = [milk_data['BreedName'].eq(breed) for breed in breed_codes]
choices = list(breed_codes.values())

# Create 'ras' column
milk_data['ras'] = np.select(conditions, choices, default=0)

In [None]:
# Sanity check on MYDeviation: print the number of NaN values, then the number
# of infinite values.
deviation = milk_data['MYDeviation']
n_missing = deviation.isna().sum()
n_infinite = np.isinf(deviation).sum()

print(n_missing)
print(n_infinite)

## First simple model

In [None]:
# Single-predictor GAM: MYDeviation modelled as a smooth function of MeanTHI_adj.

# Verify that columns exist
assert 'MeanTHI_adj' in milk_data.columns, "'MeanTHI_adj' column missing"
assert 'MYDeviation' in milk_data.columns, "'MYDeviation' column missing"

# Design matrix (one column) and response as NumPy arrays
X = milk_data[['MeanTHI_adj']].values  # Independent variable (MeanTHI_adj)
y = milk_data['MYDeviation'].values    # Dependent variable (MYDeviation)

# Keep only rows where predictor and response are both finite (drops NaN and +/-Inf)
mask = np.isfinite(X).all(axis=1) & np.isfinite(y)
X_clean = X[mask]
y_clean = y[mask]

# Fit the GAM model
gam = LinearGAM(s(0))  # s(0): smoothing term on column 0, the only feature
gam.fit(X_clean, y_clean)

# Show the model summary. summary() prints the table itself and returns None,
# so wrapping it in print() would add a stray "None" line to the output.
gam.summary()

# Plot the results
plt.figure(figsize=(8, 6))
plt.plot(X_clean, y_clean, 'o', label='Observed data', alpha=0.5)  # Plot data points

# Evenly spaced THI grid over the observed range, as a column vector for predict()
XX = np.linspace(X_clean.min(), X_clean.max(), 1000).reshape(-1, 1)
plt.plot(XX, gam.predict(XX), label='Fitted curve', color='red')

plt.title('GAM Fit for MeanTHI_adj vs MYDeviation')
plt.xlabel('MeanTHI_adj')
plt.ylabel('MYDeviation')
plt.legend()
plt.show()


## Expanded models

In [None]:
# Two-predictor GAM: smooth terms for MeanTHI_adj (column 0) and DaysInMilk (column 1).

# Verify that columns exist
assert 'MeanTHI_adj' in milk_data.columns, "'MeanTHI_adj' column missing"
assert 'MYDeviation' in milk_data.columns, "'MYDeviation' column missing"
assert 'DaysInMilk' in milk_data.columns, "'DaysInMilk' column missing"

# Design matrix and response as NumPy arrays
X = milk_data[['MeanTHI_adj', 'DaysInMilk']].values  # Independent variables
y = milk_data['MYDeviation'].values    # Dependent variable (MYDeviation)

# Keep only rows where all predictors and the response are finite (drops NaN and +/-Inf)
mask = np.isfinite(X).all(axis=1) & np.isfinite(y)
X_clean = X[mask]
y_clean = y[mask]

# Fit the GAM model
gam = LinearGAM(s(0) + s(1))  # one smoothing term per feature column
gam.fit(X_clean, y_clean)

# Show the model summary. summary() prints the table itself and returns None,
# so wrapping it in print() would add a stray "None" line to the output.
gam.summary()

# Plot the observed response against MeanTHI_adj
plt.figure(figsize=(8, 6))
plt.plot(X_clean[:, 0], y_clean, 'o', label='Observed data', alpha=0.5)

# Fitted curve over the observed THI range with DaysInMilk held at its mean,
# so the line shows how the prediction changes with THI alone. Both axes of the
# figure are therefore THI (x) and MYDeviation (y), matching the model's response.
XX = np.linspace(X_clean[:, 0].min(), X_clean[:, 0].max(), 1000).reshape(-1, 1)  # MeanTHI_adj axis
dim_at_mean = np.full_like(XX, X_clean[:, 1].mean())  # DaysInMilk fixed at its mean
predicted = gam.predict(np.column_stack([XX, dim_at_mean]))
plt.plot(XX, predicted, label='Fitted curve (DaysInMilk at mean)', color='red')

plt.title('GAM Fit for MeanTHI_adj vs MYDeviation (DaysInMilk held at mean)')
plt.xlabel('MeanTHI_adj')
plt.ylabel('MYDeviation')
plt.legend()
plt.show()

In [None]:
# Check the distribution of DaysInMilk to make sure the range is covered by data.
# Missing values are dropped first because a NaN in the input makes the
# histogram's automatic range detection fail.
plt.hist(milk_data['DaysInMilk'].dropna(), bins=20, color='skyblue', edgecolor='black')
plt.title('Histogram of DaysInMilk')
plt.xlabel('DaysInMilk')
plt.ylabel('Frequency')
plt.show()

### Adding more predictors: Parity as a categorical term

In [None]:
# Three-predictor GAM: smooth terms for MeanTHI_adj and DaysInMilk plus Parity as a
# categorical (factor) term, then one fitted THI curve per parity level.

# Verify that columns exist
assert 'MeanTHI_adj' in milk_data.columns, "'MeanTHI_adj' column missing"
assert 'MYDeviation' in milk_data.columns, "'MYDeviation' column missing"
assert 'DaysInMilk' in milk_data.columns, "'DaysInMilk' column missing"
assert 'Parity' in milk_data.columns, "'Parity' column missing"  # Verify 'Parity' is in the dataset

# Drop rows with a missing value in any column the model needs.
# Note: this rebinds milk_data, so later cells see the filtered frame.
milk_data = milk_data.dropna(subset=['MeanTHI_adj', 'DaysInMilk', 'MYDeviation', 'Parity'])

# Prepare the data
X = milk_data[['MeanTHI_adj', 'DaysInMilk', 'Parity']].values  # Include Parity as a predictor
y = milk_data['MYDeviation'].values    # Dependent variable (MYDeviation)

# Keep only rows where all predictors and the response are finite (also removes +/-Inf)
mask = np.isfinite(X).all(axis=1) & np.isfinite(y)
X_clean = X[mask]
y_clean = y[mask]

# Smoothing terms for columns 0 (MeanTHI_adj) and 1 (DaysInMilk); f(2) treats
# column 2 (Parity) as a categorical variable
gam = LinearGAM(s(0) + s(1) + f(2))
gam.fit(X_clean, y_clean)

# Show the model summary. summary() prints the table itself and returns None,
# so wrapping it in print() would add a stray "None" line to the output.
gam.summary()

# Plot the observed data points, coloured by their own MYDeviation value
plt.figure(figsize=(10, 6))
scatter = plt.scatter(X_clean[:, 0], y_clean, c=y_clean, cmap='viridis', alpha=0.6)  # drop c/cmap for a plain single-colour scatter
plt.colorbar(scatter, label='MYDeviation')

# Create a range of values for MeanTHI_adj (THI), keeping DaysInMilk constant
mean_thi_range = np.linspace(X_clean[:, 0].min(), X_clean[:, 0].max(), 100)  # MeanTHI_adj (THI)
days_in_milk_constant = np.mean(X_clean[:, 1])  # Keep DaysInMilk constant at its mean value

# Create prediction lines for each Parity level
for parity in [1, 2, 3]:
    # Grid columns follow the training layout: THI, DaysInMilk, Parity
    X_grid = np.column_stack([
        mean_thi_range,
        np.full_like(mean_thi_range, days_in_milk_constant),
        np.full_like(mean_thi_range, parity),
    ])
    y_pred = gam.predict(X_grid)
    plt.plot(mean_thi_range, y_pred, label=f'Parity {parity}')

# Add labels and title
plt.title('Effect of Parity on MYDeviation with MeanTHI_adj as X-axis')
plt.xlabel('MeanTHI_adj (THI)')
plt.ylabel('MYDeviation')
plt.legend(title="Parity")
plt.show()


### Predictions at specific DIM (DaysInMilk) values

In [None]:
# Same model form as the parity GAM (s(THI) + s(DaysInMilk) + f(Parity)), refitted
# here so the cell is self-contained, then used to draw fitted THI curves at a few
# fixed DaysInMilk values for each parity level.

# Verify that columns exist
assert 'MeanTHI_adj' in milk_data.columns, "'MeanTHI_adj' column missing"
assert 'MYDeviation' in milk_data.columns, "'MYDeviation' column missing"
assert 'DaysInMilk' in milk_data.columns, "'DaysInMilk' column missing"
assert 'Parity' in milk_data.columns, "'Parity' column missing"  # Verify 'Parity' is in the dataset

# Drop rows with a missing value in any column the model needs.
# Note: this rebinds milk_data, so later cells see the filtered frame.
milk_data = milk_data.dropna(subset=['MeanTHI_adj', 'DaysInMilk', 'MYDeviation', 'Parity'])

# Prepare the data
X = milk_data[['MeanTHI_adj', 'DaysInMilk', 'Parity']].values  # Include Parity as a predictor
y = milk_data['MYDeviation'].values    # Dependent variable (MYDeviation)

# Keep only rows where all predictors and the response are finite (also removes +/-Inf)
mask = np.isfinite(X).all(axis=1) & np.isfinite(y)
X_clean = X[mask]
y_clean = y[mask]

# Smoothing terms for columns 0 (MeanTHI_adj) and 1 (DaysInMilk); f(2) treats
# column 2 (Parity) as a categorical variable
gam = LinearGAM(s(0) + s(1) + f(2))
gam.fit(X_clean, y_clean)

# Show the model summary. summary() prints the table itself and returns None,
# so wrapping it in print() would add a stray "None" line to the output.
gam.summary()

# Predictions at specific DIM (DaysInMilk) values: 30, 150, and 220
dim_values = [30, 150, 220]  # List of DIM values for predictions

# Create a range of values for MeanTHI_adj (THI)
mean_thi_range = np.linspace(X_clean[:, 0].min(), X_clean[:, 0].max(), 100)  # MeanTHI_adj (THI)

# Plot setup
plt.figure(figsize=(10, 6))

# Loop through DIM values and Parity levels to generate predictions
for dim in dim_values:
    for parity in [1, 2, 3]:  # Assuming Parity levels 1, 2, and 3
        # Grid columns follow the training layout: THI, DaysInMilk, Parity
        X_grid = np.column_stack([
            mean_thi_range,                        # MeanTHI_adj range
            np.full_like(mean_thi_range, dim),     # Set DIM to current value
            np.full_like(mean_thi_range, parity),  # Set Parity to current level
        ])
        # Predict MYDeviation for the grid
        y_pred = gam.predict(X_grid)

        # Plot predictions
        plt.plot(mean_thi_range, y_pred,
                 label=f'DIM {dim}, Parity {parity}')

# Add labels and title
plt.title('Predicted MYDeviation at Specific DIM Values')
plt.xlabel('MeanTHI_adj (THI)')
plt.ylabel('MYDeviation')
plt.legend(title="DIM and Parity")
plt.tight_layout()
plt.show()