In [1]:
import pandas as pd
import numpy as np
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
import torch.nn.functional as F
import torch
import torch.nn as nn

In [2]:
class WindDataSet(Dataset):
    def __init__(self,path,num_steps=50):
        self.data = []
        file = pd.read_csv(path,skiprows=1)
        self.wind = np.array(file['wind speed at 10m (m/s)']).reshape(1,-1)[0]
        wind_len = len(self.wind)
        for i in range(wind_len-num_steps-1):
            self.data.append((self.wind[i:i+num_steps],self.wind[i+1:i+num_steps+1]))
        self.data = self.data[:int(len(self.data)/50)*50]
    def __len__(self):
        return len(self.data)
    def __getitem__(self,index):
        seq,pre = self.data[index]
        return seq,pre


In [3]:
def try_gpu(i=0):
    if torch.cuda.device_count() >= i + 1:
        return torch.device(f'cuda:{i}')
    return torch.device('cpu')
try_gpu()

device(type='cpu')

In [4]:
INPUT_SIZE=1
HIDDEN_SIZE=128
BATCH_SIZE=50
DROP_RATE=0.2

class lstm(nn.Module):
    def __init__(self):
        super(lstm,self).__init__()
        self.rnn = nn.LSTM(input_size=INPUT_SIZE,hidden_size=HIDDEN_SIZE)
        self.fc = nn.Linear(HIDDEN_SIZE,1)
    def forward(self,x,state):
        out,state = self.rnn(x.T.reshape((x.shape[1],-1,1)),state)
        out = self.fc(out)
        return out,state
    def begin_state(self,batch_size,device):
        return (torch.zeros((1,batch_size,HIDDEN_SIZE),device=device),torch.zeros((1,batch_size, HIDDEN_SIZE), device=device))

In [5]:
net = lstm()
lr = 0.001
device = try_gpu()
print(device)
net = net.to(device)
optimizer = torch.optim.Adam(net.parameters(),lr=lr)
loss = nn.MSELoss()
epochs = 1
num_steps=50

cpu


In [8]:
def grad_clipping(net,theta):
    if isinstance(net,nn.Module):
        params = [p for p in net.parameters() if p.requires_grad]
    else:
        params = net.params
    norm = torch.sqrt(sum(torch.sum((p.grad ** 2)) for p in params))
    if norm > theta:
        for param in params:
            param.grad[:] *= theta/norm

In [9]:

def train_epoch(epoch,net,train_loader,device,train_loss):
    net.train()
    state=net.begin_state(batch_size=BATCH_SIZE, device=device)
    runing_loss=0
    for batch_idx,(X,y) in enumerate(train_loader):
        for s in state:
            s.detach_()
        optimizer.zero_grad()
        X,y = X.to(torch.float32).to(device),y.to(torch.float32).T.to(device)
        y_hat,state = net(X,state)
        y_hat = y_hat.reshape(num_steps,BATCH_SIZE)
        l = loss(y_hat,y)
        l.backward()
        grad_clipping(net, 1)
        optimizer.step()
        runing_loss += l.item()
        if batch_idx%100 == 99:
            print(f'epoch:{epoch+1},batch_idx:{batch_idx+1},running_loss:{runing_loss/100}')
            runing_loss = 0
        train_loss.append(l.item())


In [10]:
dataset = WindDataSet('../data/wind_dataset144-2014/wind_dataset144/1.csv',num_steps)
print(dataset.__len__())

105050


In [11]:
def train(epochs):
    train_loss = []
    for epoch in range(epochs):
        for i in range(120):
            dataset = WindDataSet(f'./datasets/{i}.csv',num_steps)
            train_loader = DataLoader(dataset,batch_size=BATCH_SIZE,shuffle=False,num_workers=0)
            train_epoch(epoch,net,train_loader,device,train_loss)
        print(f'###epoch:{epoch+1},train_loss:{train_loss[-1]}')

In [31]:
train(epochs)

epoch:1,batch_idx:100,running_loss:6.337180415913463
epoch:1,batch_idx:200,running_loss:0.6713938944879919
epoch:1,batch_idx:300,running_loss:0.42530034283641727
epoch:1,batch_idx:400,running_loss:0.28629369055619464
epoch:1,batch_idx:500,running_loss:0.20516959247179328
epoch:1,batch_idx:600,running_loss:0.15068694863002746
epoch:1,batch_idx:700,running_loss:0.152924257008126
epoch:1,batch_idx:800,running_loss:0.1843894370063208
epoch:1,batch_idx:900,running_loss:0.21094414785504342
epoch:1,batch_idx:1000,running_loss:0.2295802732463926
epoch:1,batch_idx:1100,running_loss:0.17408983214292675
epoch:1,batch_idx:1200,running_loss:0.23392418023664505
epoch:1,batch_idx:1300,running_loss:0.0941151354345493
epoch:1,batch_idx:1400,running_loss:0.12239223188487813
epoch:1,batch_idx:1500,running_loss:0.07022963202325627
epoch:1,batch_idx:1600,running_loss:0.07245184284052812
epoch:1,batch_idx:1700,running_loss:0.021008826259057967
epoch:1,batch_idx:1800,running_loss:0.01834531005937606
epoch:1,

In [32]:
def save_net(path = 'wind.pt', net=None):
    torch.save(net.state_dict(),path)
save_net(net=net)

In [12]:
def load_net(path='wind_cpu.pt',net=None):
    net.load_state_dict(torch.load(path))
load_net(net=net)

In [15]:
def predict(prefix,num_preds,net,device):
    state = net.begin_state(batch_size=1,device=device)
    outputs = [prefix[0]]
    get_input = lambda:torch.tensor([outputs[-1]],device=device).reshape(1,1)
    for y in prefix[1:]:
        _,state = net(get_input(),state)
        outputs.append(y)
    for _ in range(num_preds):
        y,state = net(get_input(),state)
        outputs.append(y.reshape(1).detach()[0])
    tmp = []
    for t in outputs:
        tmp.append(t.item())
    outputs =tmp
    return outputs#torch.cat(outputs,dim=1)


In [22]:
test_dataset = WindDataSet('../data/wind_dataset144-2014/wind_dataset144/142.csv',num_steps=50)
test_loader = DataLoader(test_dataset,batch_size=1,shuffle=False)
preds,truth=None,None
for batch_idx,(X, y) in enumerate(test_loader):
    if batch_idx == 0:
        X = (X.reshape(-1).to(torch.float32))
        preds = predict(X,50,net,device)
    if batch_idx == 1:
        truth = np.array(X)
        break
print((preds[50:]))
print((truth.reshape(50)))



[6.985098838806152, 7.161799430847168, 7.4930925369262695, 7.9666666984558105, 8.546405792236328, 9.145689964294434, 9.714990615844727, 10.25847339630127, 10.766324996948242, 11.239696502685547, 11.693398475646973, 12.125823974609375, 12.534660339355469, 12.923495292663574, 13.292140007019043, 13.64033031463623, 13.969825744628906, 14.281046867370605, 14.57377815246582, 14.848335266113281, 15.104844093322754, 15.34335708618164, 15.56422233581543, 15.767910957336426, 15.954975128173828, 16.12609100341797, 16.282011032104492, 16.42353630065918, 16.551515579223633, 16.66683006286621, 16.770366668701172, 16.863008499145508, 16.945615768432617, 17.019012451171875, 17.0839786529541, 17.141246795654297, 17.191499710083008, 17.23537254333496, 17.273448944091797, 17.3062686920166, 17.33432388305664, 17.358062744140625, 17.377893447875977, 17.394187927246094, 17.40727424621582, 17.417449951171875, 17.424985885620117, 17.43012809753418, 17.433107376098633, 17.43414878845215]
[9.1  9.14 9.15 9.13 