In [1]:
import numpy as np
import pandas as pd 
from glob import glob
from os import path
import matplotlib.pyplot as plt
from torch.utils.data.dataset import Dataset
import torch
import torch.nn as nn
import torch.nn.functional as F
from sklearn import preprocessing
from torch.autograd import Variable
import torch.utils.data as utils

In [2]:
N = 1200
batch_size = 32
seq_len = 20
num_features = 1
window_size = 10

In [221]:
def generate_data(N, seq_len):
    N_train = int(N*0.8)
    X = np.zeros((N, seq_len), dtype=np.float32)
    y = np.zeros((N,1), dtype=np.float32)
    
    indices = np.random.randint(0, N, size=N//2)
    
    X[indices, 0] = 1.0
    y[indices, 0] = 1.0
    
    X_train, y_train = X[:N_train], y[:N_train]
    X_test, y_test = X[N_train:], y[N_train:]
    
    return X_train, y_train, X_test, y_test
    

In [222]:
def generate_synthetic_test_data(N, seq_len):
    N_train = int(N*0.8)
    X = np.reshape(np.arange(0, N*seq_len, dtype=np.float32),(N, seq_len))
    y = np.reshape(np.arange(0,N, dtype=np.float32), (N,1))
    
    
    X_train, y_train = X[:N_train], y[:N_train]
    X_test, y_test = X[N_train:], y[N_train:]
    
    return X_train, y_train, X_test, y_test

In [223]:
X_train, y_train, X_test, y_test = generate_data(N, seq_len)

In [224]:
class StatefulSequenceDataset(object):
    def __init__(self, X, y, batch_size, seq_len, window_size):
        self.X = X
        self.y = y
        self.batch_size = batch_size
        self.seq_len = seq_len
        self.window_size = window_size
        N = X.shape[0]
        
        assert(N%batch_size == 0)
        assert(seq_len >= window_size)
        self.N = N
        
        # index into X,y matrices, ranges from 0:(N - batch_size)
        self.idx = 0
        
        # index into subsequence, ranges from 0:window_size
        self.subseq_idx = 0
        self.subseq_len = self.seq_len - self.window_size
        
    
    def get_batch(self):
        if self.subseq_idx > self.num_subsequences_per_batch():
            self.idx = (self.idx + self.batch_size) % self.N
            self.subseq_idx = 0
        
        ii = self.idx
        bs = self.batch_size
        
        jj = self.subseq_idx
        xx = self.X[ii:ii+bs, jj:jj+self.window_size]        
        batch_X = torch.from_numpy(xx)
        batch_y = torch.from_numpy(self.y[ii:ii+bs])
        
        self.subseq_idx += 1
        return batch_X, batch_y
    
    def num_batches_per_epoch(self):
        return self.N//self.batch_size
    
    def num_subsequences_per_batch(self):
        return self.seq_len - self.window_size
    
    

In [317]:
class Net(nn.Module):
    def __init__(self, num_features, num_hidden, num_lstm_layers, batch_size):
        super(Net, self).__init__()        
        self.num_hidden = num_hidden
        self.num_lstm_layers = num_lstm_layers
        self.batch_size = batch_size
        self.lstm1 = nn.LSTM(input_size=1, hidden_size=num_hidden, 
                             num_layers=num_lstm_layers, batch_first=True)
        self.fc1 = nn.Linear(num_hidden, 1)
        self.sigmoid = nn.Sigmoid()
        self.hidden = self.init_hidden()
        
    def init_hidden(self):
        return (Variable(torch.zeros(self.num_lstm_layers, self.batch_size, self.num_hidden, dtype=torch.float32)), Variable(torch.zeros(self.num_lstm_layers, self.batch_size, self.num_hidden, dtype=torch.float32)))
    
    def detach(self):
        for h in self.hidden:
            h = h.detach()
    
    def forward(self, x):
        
        lstm_out, self.hidden = self.lstm1(x, self.hidden)
        
        y_pred = self.sigmoid(self.fc1(lstm_out[:,-1]))
        return y_pred

In [318]:
ds = StatefulSequenceDataset(X_train, y_train, batch_size, seq_len, window_size)
num_epochs = 20
net = Net(num_features=1, num_hidden=64, num_lstm_layers=2, batch_size=batch_size)

In [319]:
learning_rate = 1e-3
loss_fn = torch.nn.BCELoss()
optimizer = torch.optim.Adam(net.parameters(), lr=learning_rate)

In [320]:
hist = list()
for epoch in range(num_epochs):
    for _ in range(ds.num_batches_per_epoch()):
        net.init_hidden()
        for _ in range(ds.num_subsequences_per_batch()):
            optimizer.zero_grad()
            net.detach()
            
            batch_data, batch_labels = ds.get_batch()
            batch_data = batch_data.unsqueeze(-1)
            y_pred = net(batch_data)
            loss = loss_fn(y_pred, batch_labels)
            
            loss.backward()
            optimizer.step()

    
#         hist.append(loss.item())


#         loss.backward()

#         optimizer.step()

    print("Epoch ", epoch, "BCE: ", loss.item())
            




RuntimeError: Trying to backward through the graph a second time, but the buffers have already been freed. Specify retain_graph=True when calling backward the first time.

In [275]:
X_test_ = X_test[:224]
y_test_ = y_test[:224]

In [282]:
test_ds = StatefulSequenceDataset(X_test_, y_test_, batch_size=32,seq_len=seq_len, window_size=window_size)
               


In [283]:
xx, yy = test_ds.get_batch()

In [284]:
net(xx.unsqueeze(-1))

tensor([[0.2347],
        [0.9946],
        [0.9946],
        [0.2347],
        [0.9946],
        [0.9946],
        [0.9946],
        [0.9946],
        [0.9946],
        [0.2347],
        [0.9946],
        [0.9946],
        [0.9946],
        [0.9946],
        [0.9946],
        [0.9946],
        [0.9946],
        [0.9946],
        [0.2347],
        [0.9946],
        [0.2347],
        [0.2347],
        [0.9946],
        [0.9946],
        [0.2347],
        [0.2347],
        [0.9946],
        [0.2347],
        [0.9946],
        [0.9946],
        [0.9946],
        [0.9946]], grad_fn=<SigmoidBackward>)

In [285]:
yy

tensor([[0.],
        [0.],
        [1.],
        [0.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [0.],
        [1.],
        [0.],
        [1.],
        [1.],
        [0.],
        [1.],
        [1.],
        [1.],
        [0.],
        [1.],
        [0.],
        [0.],
        [1.],
        [1.],
        [0.],
        [0.],
        [1.],
        [0.],
        [0.],
        [0.],
        [0.],
        [1.]])