In [None]:
import csv
import numpy as np
import matplotlib.pyplot as plt

#绘制图像 y_pred为预测值，y为实际值
def plot_series(series, y=None, y_pred=None, y_pred_std=None, x_label="$t$", y_label="$x$"):
    """Plot up to 15 series on a 3x5 grid of subplots sharing both axes.

    Args:
        series: Array indexed as ``series[ix, :]``; row ``ix`` is drawn in
            grid cell ``ix`` (row-major order) as a dotted line.
        y: Optional actual future values. ``y[ix]`` is drawn as blue crosses
            starting right after the last point of ``series[ix]``.
        y_pred: Optional predicted values, drawn as red dots at the same
            x positions as ``y``.
        y_pred_std: Optional standard deviation of the predictions. Two
            curves ``y_pred +/- y_pred_std`` are drawn. Ignored unless
            ``y_pred`` is also given, because the band is centred on it.
        x_label: Label for the x axis of the bottom row (falsy to disable).
        y_label: Label for the y axis of the first column (falsy to disable).

    Grid cells beyond ``len(series)`` are left empty instead of raising
    ``IndexError``.
    """
    # Grid layout: 3 rows x 5 columns
    r, c = 3, 5
    # fig is the whole figure; axes[row][col] is the canvas of each subplot
    fig, axes = plt.subplots(nrows=r, ncols=c, sharey=True, sharex=True, figsize=(20, 10))
    n_series = len(series)
    for row in range(r):
        for col in range(c):
            ax = axes[row][col]
            ix = col + row * c
            # Only draw data when a series exists for this grid cell
            if ix < n_series:
                history = series[ix, :]
                start = len(history)
                # Observed points, connected by a solid line
                ax.plot(history, ".-")
                # Actual targets: blue crosses placed right after the history
                if y is not None:
                    ax.plot(range(start, start + len(y[ix])), y[ix], "bx", markersize=10)
                if y_pred is not None:
                    horizon = range(start, start + len(y_pred[ix]))
                    # Predictions: red dots
                    ax.plot(horizon, y_pred[ix], "ro")
                    # Prediction plus/minus one standard deviation
                    if y_pred_std is not None:
                        ax.plot(horizon, y_pred[ix] + y_pred_std[ix])
                        ax.plot(horizon, y_pred[ix] - y_pred_std[ix])
            ax.grid(True)
            # Axes are shared, so label only the bottom row and first column
            if x_label and row == r - 1:
                ax.set_xlabel(x_label, fontsize=16)
            if y_label and col == 0:
                ax.set_ylabel(y_label, fontsize=16, rotation=0)
    plt.show()

def generate_time_series(temp_values, batch_size, n_steps):
    """Sample random fixed-length windows from per-station temperature lists.

    Each sample is a window of ``n_steps`` consecutive readings taken from a
    randomly chosen station: the first ``n_steps - 1`` readings are the model
    input and the last reading is the target.

    Args:
        temp_values: Sequence of per-station reading sequences
            (``[[...], [...], ...]``).
        batch_size: Number of windows to sample.
        n_steps: Window length, including the target reading.

    Returns:
        Tuple ``(X, y)`` of float32 arrays with shapes
        ``(batch_size, n_steps - 1, 1)`` and ``(batch_size, 1)``.

    Raises:
        ValueError: If no station holds at least ``n_steps`` readings.
    """
    # Only stations that can hold one full window may be sampled; shorter
    # ones would make the window start draw below fail.
    eligible = [i for i, temps in enumerate(temp_values) if len(temps) >= n_steps]
    if not eligible:
        raise ValueError(
            "no station has at least n_steps=%d readings" % n_steps
        )

    # One row per sample: n_steps - 1 inputs followed by the target
    series = np.zeros((batch_size, n_steps))
    # Randomly pick the source station of every sample
    sta_idx = np.asarray(eligible)[np.random.randint(0, len(eligible), batch_size)]
    print(sta_idx)

    for i, idx in enumerate(sta_idx):
        temps = temp_values[idx]
        # randint's upper bound is exclusive, so +1 keeps the last valid
        # window start (len(temps) - n_steps) reachable.
        rnd_idx = np.random.randint(0, len(temps) - n_steps + 1)
        series[i] = np.array(temps[rnd_idx:rnd_idx + n_steps])

    # X excludes the final column so the target never leaks into the input
    X = series[:, :-1, np.newaxis].astype(np.float32)
    y = series[:, -1, np.newaxis].astype(np.float32)
    return X, y

# Max-temperature readings of every weather station, grouped as
# {station id: [temperature, ...]}
stations_maxtemp = {}
# newline='' lets the csv module handle line endings itself
with open('Summary of Weather.csv', newline='') as f:
    reader = csv.DictReader(f)
    for item in reader:
        sta = item['STA']
        stations_maxtemp.setdefault(sta, []).append(float(item['MaxTemp']))


temp_lengths = [len(temps) for temps in stations_maxtemp.values()]

# Window length: n_steps - 1 readings as input plus 1 reading as target
n_steps = 21

# Drop extreme-cold outliers (readings not above -17) first ...
filted_maxtemps = [[temp for temp in temps if temp > -17] for temps in stations_maxtemp.values()]

# ... then keep only stations that still hold at least one full window.
# Filtering by length before removing outliers could leave a station with
# fewer than n_steps readings, which breaks window sampling.
filted_maxtemps = [temps for temps in filted_maxtemps if len(temps) >= n_steps]

# Structure: [[...], [...], ...] - one list of readings per station
max_temps = filted_maxtemps

X_train,y_train = generate_time_series(max_temps, 7000, n_steps)
X_valid,y_valid = generate_time_series(max_temps, 2000, n_steps)
X_test,y_test = generate_time_series(max_temps, 1000, n_steps)
X_train.shape, y_train.shape


[ 30  77  57 ... 124  15  79]


In [2]:
import csv
import matplotlib.pyplot as plt

# Max-temperature readings of every weather station, grouped as
# {station id: [temperature, ...]}
stations_maxtemp = {}
with open('Summary of Weather.csv') as f:
    reader = csv.DictReader(f)
    for item in reader:
        sta = item['STA']
        # setdefault creates the station's list on first sight
        stations_maxtemp.setdefault(sta, []).append(float(item['MaxTemp']))




# Keep only the stations that hold more than 20 readings
max_temps = list(filter(lambda temps: len(temps) > 20, stations_maxtemp.values()))
print(max_temps)

# Removal of extreme-cold outliers (readings not above -17) is disabled here
#filted_maxtemps = [[temp for temp in temps if temp > -17] for temps in max_temps]  



[[25.55555556, 28.88888889, 26.11111111, 26.66666667, 26.66666667, 26.66666667, 28.33333333, 26.66666667, 27.22222222, 25.55555556, 25.55555556, 24.44444444, 26.11111111, 26.11111111, 25.55555556, 26.11111111, 28.33333333, 26.66666667, 27.77777778, 26.11111111, 24.44444444, 24.44444444, 26.66666667, 27.22222222, 27.22222222, 25.55555556, 25.55555556, 26.66666667, 26.66666667, 27.22222222, 27.22222222, 26.11111111, 24.44444444, 24.44444444, 26.66666667, 26.66666667, 24.44444444, 24.44444444, 26.11111111, 26.11111111, 27.22222222, 25.0, 25.55555556, 25.55555556, 25.55555556, 25.55555556, 25.55555556, 25.55555556, 26.11111111, 27.77777778, 23.88888889, 32.77777778, 27.22222222, 25.0, 27.77777778, 26.66666667, 27.22222222, 27.22222222, 27.22222222, 25.0, 27.22222222, 27.22222222, 27.22222222, 27.22222222, 27.22222222, 27.22222222, 27.22222222, 26.66666667, 26.66666667, 27.22222222, 26.66666667, 28.33333333, 26.11111111, 27.22222222, 27.22222222, 27.77777778, 27.77777778, 30.0, 27.77777778,