In [41]:
import pandas as pd
import xgboost as xgb
from sklearn.model_selection import train_test_split, GridSearchCV
from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score
import joblib
from sklearn.impute import SimpleImputer


# 讀取上傳的文件
'''
north_data = pd.read_csv('D:/DS_Prediction/Weather/north_weekly_averages.csv')
south_data = pd.read_csv('D:/DS_Prediction/Weather/south_weekly_averages.csv')
central_data = pd.read_csv('D:/DS_Prediction/Weather/central_weekly_averages.csv')
east_data = pd.read_csv('D:/DS_Prediction/Weather/east_weekly_averages.csv')
fuel_prices = pd.read_csv('D:/DS_Prediction/fuel_prices.csv')
cabbage_prices = pd.read_csv('D:/DS_Prediction/Domestic_Cabbage.csv')
'''

north_data = pd.read_csv('C:/Users/$EKH000-V5FHVTC5DRPM/DS_Prediction/weather-csv/north_weekly_averages.csv')
south_data = pd.read_csv('C:/Users/$EKH000-V5FHVTC5DRPM/DS_Prediction/weather-csv/south_weekly_averages.csv')
central_data = pd.read_csv('C:/Users/$EKH000-V5FHVTC5DRPM/DS_Prediction/weather-csv/central_weekly_averages.csv')
east_data = pd.read_csv('C:/Users/$EKH000-V5FHVTC5DRPM/DS_Prediction/weather-csv/east_weekly_averages.csv')
fuel_prices = pd.read_csv('C:/Users/$EKH000-V5FHVTC5DRPM/DS_Prediction/fuel_prices.csv')
cabbage_prices = pd.read_csv('C:/Users/$EKH000-V5FHVTC5DRPM/DS_Prediction/vegetable-csv/Domestic_Cabbage.csv')




In [None]:
## 
''' 
Data Processing

'''
## 

# Merge regional data into a single DataFrame
regional_data_1 = pd.concat([north_data, south_data, central_data, east_data], ignore_index=True)


# Check and rename date columns if necessary
def ensure_date_column(df, possible_names):
    for col in df.columns:
        if col in possible_names:
            df.rename(columns={col: 'date'}, inplace=True)
            break
    return df

# Rename the date columns where applicable
regional_data = ensure_date_column(regional_data_1, ['週', 'date'])
fuel_prices = ensure_date_column(fuel_prices, ['Date', 'date', '週', '日期'])
cabbage_prices = ensure_date_column(cabbage_prices, ['週', 'date'])

# Convert date columns to datetime
def parse_date(df, column_name):
    if column_name in df.columns:
        df[column_name] = pd.to_datetime(df[column_name], errors='coerce', utc=True)
    return df

regional_data = parse_date(regional_data, 'date')
fuel_prices = parse_date(fuel_prices, 'date')
cabbage_prices = parse_date(cabbage_prices, 'date')

# Drop rows with missing or invalid 'date' values
for df in [regional_data, fuel_prices, cabbage_prices]:
    if 'date' in df.columns:
        df.dropna(subset=['date'], inplace=True)

# Ensure 'date' columns are consistent and datetimelike
for df in [regional_data, fuel_prices, cabbage_prices]:
    if 'date' in df.columns:
        df['date'] = pd.to_datetime(df['date']).dt.normalize()

# Extract additional features from date
def extract_date_features(df, date_column):
    df['year'] = df[date_column].dt.year
    df['month'] = df[date_column].dt.month
    df['week'] = df[date_column].dt.day // 7
    return df

cabbage_prices = extract_date_features(cabbage_prices, 'date')
fuel_prices = extract_date_features(fuel_prices, 'date')

# Merge all the data into a single DataFrame
try:
    merged_data = pd.merge(cabbage_prices, regional_data, on='date', how='left')
    merged_data = pd.merge(merged_data, fuel_prices, on=['year', 'month', 'week'], how='left')
except KeyError as e:
    raise KeyError(f"Error during merging: {e}. Please check that all dataframes contain a 'date' column.")

# 根據 Group 分組計算均值並填補缺失值
merged_data['Fuel_92'] = merged_data.groupby('month')['Fuel_92'].transform(lambda x: x.fillna(x.mean()))
merged_data['Fuel_95'] = merged_data.groupby('month')['Fuel_95'].transform(lambda x: x.fillna(x.mean()))
merged_data['Fuel_High'] = merged_data.groupby('month')['Fuel_High'].transform(lambda x: x.fillna(x.mean()))

# Handle missing values
merged_data.fillna(method='ffill', inplace=True)

# Prepare features and target variable
y = merged_data[['平均價', '交易量']]
X = merged_data.drop(columns=['date_x', '平均價', '交易量', 'year', 'month', 'week', 'date_y'])

# Save the merged DataFrame to a CSV file
output_file = "All_X.csv"
X.to_csv(output_file, index=False, encoding='utf-8-sig')

print(f"Merged data saved to {output_file}")

print(merged_data)

Merged data saved to Train_X.csv
                       date_x    平均價       交易量  year  month  week  平均氣壓(hPa)  \
0   2019-01-01 00:00:00+00:00  22.22   4487.67  2019      1     0    1019.18   
1   2019-01-01 00:00:00+00:00  22.22   4487.67  2019      1     0    1019.90   
2   2019-01-01 00:00:00+00:00  22.22   4487.67  2019      1     0    1020.22   
3   2019-01-01 00:00:00+00:00  22.22   4487.67  2019      1     0    1017.62   
4   2019-01-08 00:00:00+00:00  20.40   4323.00  2019      1     1    1016.19   
..                        ...    ...       ...   ...    ...   ...        ...   
922 2024-11-12 00:00:00+00:00  47.13  14259.17  2024     11     1    1006.90   
923 2024-11-25 00:00:00+00:00  40.70  17676.00  2024     11     3    1015.60   
924 2024-11-25 00:00:00+00:00  40.70  17676.00  2024     11     3    1014.70   
925 2024-11-25 00:00:00+00:00  40.70  17676.00  2024     11     3    1016.55   
926 2024-11-25 00:00:00+00:00  40.70  17676.00  2024     11     3    1013.20   

     平

  df[column_name] = pd.to_datetime(df[column_name], errors='coerce', utc=True)
  merged_data.fillna(method='ffill', inplace=True)


## XGBOOST

In [None]:
## 
''' 
Building the Model

'''
## 

In [None]:
## 
''' 
Training the Model

'''
## 

In [None]:
## 
''' 
Evaluation
1. 哪種蔬果準確度最高，哪種最低，並分析原因

'''
## 

In [None]:
# Merge regional data into a single DataFrame
regional_data = pd.concat([north_data, south_data, central_data, east_data], ignore_index=True)

# Check and rename date columns if necessary
def ensure_date_column(df, possible_names):
    for col in df.columns:
        if col in possible_names:
            df.rename(columns={col: 'date'}, inplace=True)
            break
    return df

# Rename the date columns where applicable
regional_data = ensure_date_column(regional_data, ['週', 'date'])
fuel_prices = ensure_date_column(fuel_prices, ['Date', 'date', '週', '日期'])
cabbage_prices = ensure_date_column(cabbage_prices, ['週', 'date'])

# Convert date columns to datetime
def parse_date(df, column_name):
    if column_name in df.columns:
        df[column_name] = pd.to_datetime(df[column_name], errors='coerce', utc=True)
    return df

regional_data = parse_date(regional_data, 'date')
fuel_prices = parse_date(fuel_prices, 'date')
cabbage_prices = parse_date(cabbage_prices, 'date')

# Drop rows with missing or invalid 'date' values
for df in [regional_data, fuel_prices, cabbage_prices]:
    if 'date' in df.columns:
        df.dropna(subset=['date'], inplace=True)

# Ensure 'date' columns are consistent and datetimelike
for df in [regional_data, fuel_prices, cabbage_prices]:
    if 'date' in df.columns:
        df['date'] = pd.to_datetime(df['date']).dt.normalize()

# Extract additional features from date
def extract_date_features(df, date_column):
    df['year'] = df[date_column].dt.year
    df['month'] = df[date_column].dt.month
    df['week'] = df[date_column].dt.isocalendar().week
    df['day_of_week'] = df[date_column].dt.dayofweek
    return df

cabbage_prices = extract_date_features(cabbage_prices, 'date')

# Merge all the data into a single DataFrame
try:
    merged_data = pd.merge(cabbage_prices, regional_data, on='date', how='left')
    merged_data = pd.merge(merged_data, fuel_prices, on='date', how='left')
except KeyError as e:
    raise KeyError(f"Error during merging: {e}. Please check that all dataframes contain a 'date' column.")

# Handle missing values
merged_data.fillna(method='ffill', inplace=True)

# Prepare features and target variable
X = merged_data.drop(columns=['date', '平均價'])
y = merged_data['平均價']

# Impute missing values in features
imputer = SimpleImputer(strategy='mean')
X = imputer.fit_transform(X)

# # Split the data into training and testing sets
# X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# # Set up the parameter grid for GridSearchCV
# param_grid = {
#     'max_depth': [3, 5, 7],
#     'learning_rate': [0.01, 0.1, 0.2],
#     'n_estimators': [100, 200, 300],
#     'subsample': [0.8, 1.0],
#     'colsample_bytree': [0.8, 1.0]
# }

# # Set up the XGBoost model
# xgb_model = xgb.XGBRegressor(objective='reg:squarederror', random_state=42)

# # Set up GridSearchCV
# grid_search = GridSearchCV(estimator=xgb_model, param_grid=param_grid, cv=3, scoring='neg_mean_squared_error', verbose=0, n_jobs=-1)

# # Fit the model using GridSearchCV
# grid_search.fit(X_train, y_train)

# # Get the best estimator and parameters
# best_model = grid_search.best_estimator_
# best_params = grid_search.best_params_
# print(f'Best Parameters: {best_params}')

# # Make predictions
# y_pred = best_model.predict(X_test)

# # Evaluate the model
# rmse = mean_squared_error(y_test, y_pred, squared=False)
# mae = mean_absolute_error(y_test, y_pred)
# r2 = r2_score(y_test, y_pred)
# print(f'Root Mean Squared Error: {rmse}')
# print(f'Mean Absolute Error: {mae}')
# print(f'R2 Score: {r2}')

# # Save the best model
# joblib.dump(best_model, 'cabbage_price_xgboost_best_model.pkl')

# # Load and test the model
# loaded_model = joblib.load('cabbage_price_xgboost_best_model.pkl')
# loaded_y_pred = loaded_model.predict(X_test)
# loaded_rmse = mean_squared_error(y_test, loaded_y_pred, squared=False)
# print(f'Loaded Model Root Mean Squared Error: {loaded_rmse}')


Merged data saved to merged_data_1.csv


  df[column_name] = pd.to_datetime(df[column_name], errors='coerce', utc=True)
  merged_data.fillna(method='ffill', inplace=True)


## SVR

In [None]:
import pandas as pd
from sklearn.svm import SVR
from sklearn.model_selection import train_test_split, GridSearchCV
from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score
from sklearn.impute import SimpleImputer
import joblib

# Load the data
north_data = pd.read_csv('D:/DS_Prediction/Weather/north_weekly_averages.csv')
south_data = pd.read_csv('D:/DS_Prediction/Weather/south_weekly_averages.csv')
central_data = pd.read_csv('D:/DS_Prediction/Weather/central_weekly_averages.csv')
east_data = pd.read_csv('D:/DS_Prediction/Weather/east_weekly_averages.csv')
fuel_prices = pd.read_csv('D:/DS_Prediction/fuel_prices.csv')
cabbage_prices = pd.read_csv('D:/DS_Prediction/國產包心菜.csv')


In [42]:
# Merge regional data into a single DataFrame
regional_data = pd.concat([north_data, south_data, central_data, east_data], ignore_index=True)

# Check and rename date columns if necessary
def ensure_date_column(df, possible_names):
    for col in df.columns:
        if col in possible_names:
            df.rename(columns={col: 'date'}, inplace=True)
            break
    return df

# Rename the date columns where applicable
regional_data = ensure_date_column(regional_data, ['週', 'date'])
fuel_prices = ensure_date_column(fuel_prices, ['Date', 'date', '週', '日期'])
cabbage_prices = ensure_date_column(cabbage_prices, ['週', 'date'])

# Convert date columns to datetime
def parse_date(df, column_name):
    if column_name in df.columns:
        df[column_name] = pd.to_datetime(df[column_name], errors='coerce', utc=True)
    return df

regional_data = parse_date(regional_data, 'date')
fuel_prices = parse_date(fuel_prices, 'date')
cabbage_prices = parse_date(cabbage_prices, 'date')

# Drop rows with missing or invalid 'date' values
for df in [regional_data, fuel_prices, cabbage_prices]:
    if 'date' in df.columns:
        df.dropna(subset=['date'], inplace=True)

# Ensure 'date' columns are consistent and datetimelike
for df in [regional_data, fuel_prices, cabbage_prices]:
    if 'date' in df.columns:
        df['date'] = pd.to_datetime(df['date']).dt.normalize()

# Extract additional features from date
def extract_date_features(df, date_column):
    df['year'] = df[date_column].dt.year
    df['month'] = df[date_column].dt.month
    df['week'] = df[date_column].dt.isocalendar().week
    df['day_of_week'] = df[date_column].dt.dayofweek
    return df

cabbage_prices = extract_date_features(cabbage_prices, 'date')

# Merge all the data into a single DataFrame
try:
    merged_data = pd.merge(cabbage_prices, regional_data, on='date', how='left')
    merged_data = pd.merge(merged_data, fuel_prices, on='date', how='left')
except KeyError as e:
    raise KeyError(f"Error during merging: {e}. Please check that all dataframes contain a 'date' column.")

# Handle missing values
merged_data.fillna(method='ffill', inplace=True)

# Prepare features and target variable
X = merged_data.drop(columns=['date', '平均價'])
y = merged_data['平均價']

# Impute missing values in features
imputer = SimpleImputer(strategy='mean')
X = imputer.fit_transform(X)

# Split the data into training and testing sets
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# Set up the parameter grid for GridSearchCV
param_grid = {
    'C': [0.1, 1, 10, 100],
    'gamma': [0.001, 0.01, 0.1, 1],
    'kernel': ['rbf']
}

# Set up the SVR model
svr_model = SVR()

# Set up GridSearchCV
grid_search = GridSearchCV(estimator=svr_model, param_grid=param_grid, cv=3, scoring='neg_mean_squared_error', verbose=0, n_jobs=-1)

# Fit the model using GridSearchCV
grid_search.fit(X_train, y_train)

# Get the best estimator and parameters
best_model = grid_search.best_estimator_
best_params = grid_search.best_params_
print(f'Best Parameters: {best_params}')

# Make predictions
y_pred = best_model.predict(X_test)

# Evaluate the model
rmse = mean_squared_error(y_test, y_pred, squared=False)
mae = mean_absolute_error(y_test, y_pred)
r2 = r2_score(y_test, y_pred)
print(f'Root Mean Squared Error: {rmse}')
print(f'Mean Absolute Error: {mae}')
print(f'R2 Score: {r2}')

# Save the best model
joblib.dump(best_model, 'cabbage_price_svr_best_model.pkl')

# Load and test the model
loaded_model = joblib.load('cabbage_price_svr_best_model.pkl')
loaded_y_pred = loaded_model.predict(X_test)
loaded_rmse = mean_squared_error(y_test, loaded_y_pred, squared=False)
print(f'Loaded Model Root Mean Squared Error: {loaded_rmse}')


  df[column_name] = pd.to_datetime(df[column_name], errors='coerce', utc=True)
  merged_data.fillna(method='ffill', inplace=True)


Best Parameters: {'C': 100, 'gamma': 0.001, 'kernel': 'rbf'}
Root Mean Squared Error: 5.984999371397111
Mean Absolute Error: 3.033320808642719
R2 Score: 0.8938061673671138
Loaded Model Root Mean Squared Error: 5.984999371397111




## CNN+Transfermor

In [64]:
import numpy as np
import pandas as pd
import tensorflow as tf
from tensorflow.keras import layers, models, optimizers
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.impute import SimpleImputer
from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score
import joblib
# Load the data
north_data = pd.read_csv('D:/DS_Prediction/Weather/north_weekly_averages.csv')
south_data = pd.read_csv('D:/DS_Prediction/Weather/south_weekly_averages.csv')
central_data = pd.read_csv('D:/DS_Prediction/Weather/central_weekly_averages.csv')
east_data = pd.read_csv('D:/DS_Prediction/Weather/east_weekly_averages.csv')
fuel_prices = pd.read_csv('D:/DS_Prediction/fuel_prices.csv')
cabbage_prices = pd.read_csv('D:/DS_Prediction/國產包心菜.csv')

In [71]:
# 合并区域数据
regional_data = pd.concat([north_data, south_data, central_data, east_data], ignore_index=True)

# 处理日期列
def ensure_date_column(df, possible_names):
    for col in df.columns:
        if col in possible_names:
            df.rename(columns={col: 'date'}, inplace=True)
            break
    return df

# 重命名日期列
regional_data = ensure_date_column(regional_data, ['週', 'date'])
fuel_prices = ensure_date_column(fuel_prices, ['Date', 'date', '週', '日期'])
cabbage_prices = ensure_date_column(cabbage_prices, ['週', 'date'])

# 转换日期列
def parse_date(df, column_name):
    if column_name in df.columns:
        df[column_name] = pd.to_datetime(df[column_name], errors='coerce', utc=True)
    return df

regional_data = parse_date(regional_data, 'date')
fuel_prices = parse_date(fuel_prices, 'date')
cabbage_prices = parse_date(cabbage_prices, 'date')

# 去除缺失的日期值
for df in [regional_data, fuel_prices, cabbage_prices]:
    if 'date' in df.columns:
        df.dropna(subset=['date'], inplace=True)

# 确保日期列一致并标准化
for df in [regional_data, fuel_prices, cabbage_prices]:
    if 'date' in df.columns:
        df['date'] = pd.to_datetime(df['date']).dt.normalize()

# 提取日期特征
def extract_date_features(df, date_column):
    df['year'] = df[date_column].dt.year
    df['month'] = df[date_column].dt.month
    df['week'] = df[date_column].dt.isocalendar().week
    df['day_of_week'] = df[date_column].dt.dayofweek
    return df

cabbage_prices = extract_date_features(cabbage_prices, 'date')

# 合并数据
try:
    merged_data = pd.merge(cabbage_prices, regional_data, on='date', how='left')
    merged_data = pd.merge(merged_data, fuel_prices, on='date', how='left')
except KeyError as e:
    raise KeyError(f"合并数据时发生错误: {e}. 请检查所有数据框中是否包含 'date' 列.")

# 处理缺失值
merged_data.fillna(method='ffill', inplace=True)

# 准备特征和目标变量
X = merged_data.drop(columns=['date', '平均價'])
y = merged_data['平均價']

# 确保 X 和 y 的样本数量一致
assert X.shape[0] == y.shape[0], "X 和 y 的样本数量不一致!"

# 处理缺失的特征值
imputer = SimpleImputer(strategy='mean')
X = imputer.fit_transform(X)

# 标准化特征
scaler = StandardScaler()
X = scaler.fit_transform(X)

# 划分训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# 重塑输入形状以适应CNN
X_train = X_train.reshape(-1, X_train.shape[1], 1)
X_test = X_test.reshape(-1, X_test.shape[1], 1)
def build_simplified_model(input_shape):
    inputs = layers.Input(shape=input_shape)
    
    # CNN 层
    x = layers.Conv1D(64, 3, activation='relu')(inputs)
    x = layers.MaxPooling1D(2)(x)  # 池化操作
    x = layers.Conv1D(128, 3, activation='relu')(x)
    x = layers.MaxPooling1D(2)(x)  # 池化操作
    
    # Flatten 层：将卷积输出展平成一维
    x = layers.Flatten()(x)
    
    # 全连接层
    x = layers.Dense(64, activation='relu')(x)
    x = layers.Dense(1)(x)  # 输出预测值

    # 创建模型
    model = models.Model(inputs=inputs, outputs=x)
    
    return model


# 创建模型并编译
model = build_simplified_model((X_train.shape[1], 1))
optimizer = optimizers.Adam(learning_rate=1e-3)

model.compile(optimizer=optimizer, loss='mean_squared_error')

# 打印模型概述
model.summary()

# 训练模型
model.fit(X_train, y_train, epochs=200, batch_size=32, validation_data=(X_test, y_test))

# 预测
y_pred = model.predict(X_test)

# 评估模型
rmse = mean_squared_error(y_test, y_pred, squared=False)
mae = mean_absolute_error(y_test, y_pred)
r2 = r2_score(y_test, y_pred)

print(f"Root Mean Squared Error: {rmse}")
print(f"Mean Absolute Error: {mae}")
print(f"R2 Score: {r2}")

# 保存模型
model.save('cabbage_price_cnn_transformer_model.h5')

# 加载并测试模型
loaded_model = models.load_model('cabbage_price_cnn_transformer_model.h5')
loaded_y_pred = loaded_model.predict(X_test)
loaded_rmse = mean_squared_error(y_test, loaded_y_pred, squared=False)
print(f"Loaded Model Root Mean Squared Error: {loaded_rmse}")

  df[column_name] = pd.to_datetime(df[column_name], errors='coerce', utc=True)
  merged_data.fillna(method='ffill', inplace=True)


Epoch 1/200
[1m24/24[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 13ms/step - loss: 1826.0809 - val_loss: 809.1030
Epoch 2/200
[1m24/24[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 6ms/step - loss: 703.3657 - val_loss: 455.3844
Epoch 3/200
[1m24/24[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 5ms/step - loss: 403.6914 - val_loss: 423.7955
Epoch 4/200
[1m24/24[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 6ms/step - loss: 388.7868 - val_loss: 391.0662
Epoch 5/200
[1m24/24[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 5ms/step - loss: 370.1884 - val_loss: 380.6324
Epoch 6/200
[1m24/24[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 6ms/step - loss: 363.6079 - val_loss: 354.7948
Epoch 7/200
[1m24/24[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 5ms/step - loss: 326.3534 - val_loss: 331.4171
Epoch 8/200
[1m24/24[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 5ms/step - loss: 324.1828 - val_loss: 339.9662
Epoch 9/200
[



Root Mean Squared Error: 8.134504816459765
Mean Absolute Error: 5.8534705413285115
R2 Score: 0.8038296720186959




[1m6/6[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 35ms/step
Loaded Model Root Mean Squared Error: 8.134504816459765




## LSTM

In [73]:
from tensorflow.keras import layers, models, optimizers

# 构建 LSTM 模型
def build_lstm_model(input_shape):
    inputs = layers.Input(shape=input_shape)
    
    # 第一层 LSTM 层
    x = layers.LSTM(64, return_sequences=True)(inputs)
    x = layers.Dropout(0.2)(x)
    
    # 第二层 LSTM 层
    x = layers.LSTM(128, return_sequences=True)(x)
    x = layers.Dropout(0.2)(x)
    
    # 第三层 LSTM 层
    x = layers.LSTM(256)(x)
    x = layers.Dropout(0.2)(x)
    
    # 全连接层
    x = layers.Dense(128, activation='relu')(x)
    x = layers.Dense(64, activation='relu')(x)
    x = layers.Dense(1)(x)  # 输出预测值

    # 创建模型
    model = models.Model(inputs=inputs, outputs=x)
    
    return model

# 创建 LSTM 模型并编译
model = build_lstm_model((X_train.shape[1], 1))

# 使用较小的学习率
optimizer = optimizers.Adam(learning_rate=1e-3)

model.compile(optimizer=optimizer, loss='mean_squared_error')

# 打印模型概述
model.summary()

# 训练模型，epochs 设置为 200
model.fit(X_train, y_train, epochs=200, batch_size=32, validation_data=(X_test, y_test))

# 预测
y_pred = model.predict(X_test)

# 评估模型
rmse = mean_squared_error(y_test, y_pred, squared=False)
mae = mean_absolute_error(y_test, y_pred)
r2 = r2_score(y_test, y_pred)

print(f"Root Mean Squared Error: {rmse}")
print(f"Mean Absolute Error: {mae}")
print(f"R2 Score: {r2}")

# 保存模型
model.save('cabbage_price_lstm_model.h5')

# 加载并测试模型
loaded_model = models.load_model('cabbage_price_lstm_model.h5')
loaded_y_pred = loaded_model.predict(X_test)
loaded_rmse = mean_squared_error(y_test, loaded_y_pred, squared=False)
print(f"Loaded Model Root Mean Squared Error: {loaded_rmse}")


Epoch 1/200
[1m24/24[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m6s[0m 55ms/step - loss: 1424.7906 - val_loss: 381.2544
Epoch 2/200
[1m24/24[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 35ms/step - loss: 364.8336 - val_loss: 385.7947
Epoch 3/200
[1m24/24[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 35ms/step - loss: 366.1913 - val_loss: 342.1292
Epoch 4/200
[1m24/24[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 35ms/step - loss: 383.2879 - val_loss: 340.4164
Epoch 5/200
[1m24/24[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 36ms/step - loss: 348.7037 - val_loss: 342.9824
Epoch 6/200
[1m24/24[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 35ms/step - loss: 338.6981 - val_loss: 362.4497
Epoch 7/200
[1m24/24[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 37ms/step - loss: 380.7781 - val_loss: 337.0110
Epoch 8/200
[1m24/24[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 36ms/step - loss: 365.9332 - val_loss: 335.4888
Epoch 9



Root Mean Squared Error: 6.355207960706002
Mean Absolute Error: 4.716572159797915
R2 Score: 0.8802623832045587
[1m6/6[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 89ms/step
Loaded Model Root Mean Squared Error: 6.355207960706002


