This is an example of how to use an LSTM inside a PyTorch network module (`nn.Module`).

In [None]:
import torch
import torch.nn as nn
from torch.nn import Conv1d, Dropout, Linear, LSTM, ELU, ReLU, GroupNorm
# more info about these different layers can be found here https://pytorch.org/docs/stable/nn.html 
 
class zModel(nn.Module):
    """Small Conv1d -> MLP -> LSTM regressor, written as an LSTM usage example.

    The layer sizes hard-code the input shape: ``forward`` expects ``x`` of
    shape ``(1, 3, 10)`` (batch of one, 3 channels, 10 steps) and returns a
    tensor of shape ``(1, 1, 1)``.

    The LSTM hidden/cell state is carried across ``forward`` calls. It is
    detached after every call, so gradients never flow into earlier calls.

    Args:
        device: device on which the initial LSTM hidden/cell state is created.
    """

    def __init__(self, device):
        super().__init__()
        # Conv1d(in_channels, out_channels, ...); Linear(in_features, out_features).
        self.conv1 = Conv1d(3, 16, kernel_size=1, stride=1)   # length 10 -> 10
        self.conv2 = Conv1d(16, 16, kernel_size=2, stride=2)  # length 10 -> 5
        self.conv3 = Conv1d(16, 16, kernel_size=3, stride=2)  # length 5 -> 2
        # Skip-connection projections: in_features equal the sequence length
        # after conv1, conv2 and conv3 respectively.
        self.fc1 = Linear(10, 32)
        self.fc2 = Linear(5, 32)
        self.fc3 = Linear(2, 32)
        self.fc4 = Linear(32, 128)  # 16 channels * 2 steps left after conv3
        self.fc5 = Linear(128, 64)
        self.fc6 = Linear(64, 32)
        # Output head reads the LSTM output, so in_features = LSTM hidden_size (16).
        self.fc7 = Linear(16, 1)
        # LSTM(input_size, hidden_size, num_layers): 32 input features,
        # 16 hidden units, 32 stacked recurrent layers.
        self.lstm1 = LSTM(32, 16, 32)
        # Initial LSTM hidden and cell state, shape (num_layers, batch, hidden_size).
        # Non-persistent buffers: model.to()/.cuda() move them with the
        # parameters, but they are not saved in state_dict().
        # NOTE: initialised with torch.rand (uniform on [0, 1)), not zeros.
        self.register_buffer("h1", torch.rand(32, 1, 16).to(device), persistent=False)
        self.register_buffer("c1", torch.rand(32, 1, 16).to(device), persistent=False)
        # p=0.1: in training mode each element is zeroed with probability 0.1.
        self.drop = Dropout(.1)
        self.elu = ELU()
        self.relu = ReLU()
        # GroupNorm(num_groups, num_channels): num_channels is the total channel
        # count of the input; a single group normalises all 32 channels together.
        self.laynorm = GroupNorm(1, 32)

    def forward(self, x):
        """Run one step of the model.

        Args:
            x: tensor of shape (1, 3, 10).

        Returns:
            Tensor of shape (1, 1, 1).

        Side effects:
            Replaces ``self.h1`` / ``self.c1`` with the (detached) LSTM state
            produced by this step.
        """
        x = self.drop(self.conv1(x))
        # Skip terms: project the last channel of each conv stage to 32 features.
        res1 = self.fc1(x[:, -1:, :])

        x = self.drop(self.conv2(x))
        res2 = self.fc2(x[:, -1:, :])

        x = self.drop(self.conv3(x))
        res3 = self.fc3(x[:, -1:, :])

        # Flatten (1, 16, 2) -> (1, 1, 32); fc4 only accepts this for batch size 1.
        x = self.drop(self.relu(self.fc4(x.view(1, 1, -1))))
        x = self.drop(self.relu(self.fc5(x)))
        x = self.fc6(x)

        # GroupNorm expects (batch, channels, ...): move the 32 features onto
        # the channel axis before normalising.
        x = self.laynorm(self.elu(x + res1 + res2 + res3).view(1, 32, 1))
        x = self.drop(x)

        # LSTM input layout is (seq_len=1, batch=1, features=32) (batch_first=False).
        out, (h, c) = self.lstm1(x.view(1, 1, -1), (self.h1, self.c1))
        # Carry the state to the next call but cut the autograd graph here,
        # so this step's graph is not kept alive by the stored state.
        self.h1 = h.detach()
        self.c1 = c.detach()

        # Feed the LSTM *output* to the head. Previously fc7 received the LSTM
        # input, so the LSTM had no effect on the result and got no gradients.
        return self.fc7(out)