Let's analyze what the output of an RNN layer looks like.

In [1]:
import torch
import torchvision
import torch.nn as nn  # All neural network modules, nn.Linear, nn.Conv2d, BatchNorm, Loss functions
import torch.optim as optim  # For all Optimization algorithms, SGD, Adam, etc.
import torch.nn.functional as F  # All functions that don't have any parameters
from torch.utils.data import (
    DataLoader,
)  # Gives easier dataset managment and creates mini batches
import torchvision.datasets as datasets  # Has standard datasets we can import in a nice way
import torchvision.transforms as transforms  # Transformations we can perform on our dataset

We will feed the RNN a 28x28 image one row at a time. Each row is one time step containing 28 features (the columns), and the 28 rows give a total sequence length of 28.

In [2]:
# Hyperparameters
input_size = 28  # features per time step: the 28 pixels of one image row
hidden_size = 256
num_layers = 1
num_classes = 10
sequence_length = 28  # time steps: the 28 rows of one image
learning_rate = 0.005
batch_size = 64
num_epochs = 2

# Seed PyTorch's global RNG so the shuffled batch order (and the randomly
# initialised model weights) are the same after Restart Kernel -> Run All.
torch.manual_seed(0)

# MNIST is downloaded into dataset/ on first use (download=True).
train_dataset = datasets.MNIST(
    root="dataset/", train=True, transform=transforms.ToTensor(), download=True
)

test_dataset = datasets.MNIST(
    root="dataset/", train=False, transform=transforms.ToTensor(), download=True
)

train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
# Shuffling brings no benefit for evaluation; keep the test order fixed.
test_loader = DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False)

Let's send one batch of images through this RNN. In our case, the batch size is set to 64.

In [3]:
# Pull a single mini-batch (inputs and targets) from the training loader.
images, labels = next(iter(train_loader))

In [4]:
# Dimensions of the image batch (displayed as a torch.Size).
images.size()

torch.Size([64, 1, 28, 28])

In [5]:
# Dimensions of the label batch (displayed as a torch.Size).
labels.size()

torch.Size([64])

In [17]:
class RNN(nn.Module):
    """Thin wrapper around ``nn.RNN`` that returns the output at every time step.

    Args:
        input_size: Number of features in each time step of the input.
        hidden_size: Number of features in the hidden state.
        num_classes: Accepted for interface compatibility; no layer in this
            class uses it.
        num_layers: Number of stacked recurrent layers. Defaults to 1, which
            is also ``nn.RNN``'s own default, so existing three-argument
            callers get the same network as before.
    """

    def __init__(self, input_size, hidden_size, num_classes, num_layers=1):
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        # Taken from the constructor argument rather than a notebook-level
        # global, so the class also works in a fresh kernel or module.
        self.num_layers = num_layers

        self.rnn = nn.RNN(
            input_size=input_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            batch_first=True,
        )

    def forward(self, x):
        """Run ``x`` through the RNN and return the per-time-step outputs.

        Args:
            x: Input of shape (batch, seq_length, input_size); the batch
                dimension comes first because of ``batch_first=True``.

        Returns:
            Tensor of shape (batch, seq_length, hidden_size). The final
            hidden state returned by ``nn.RNN`` is discarded.
        """
        out, _ = self.rnn(x)
        return out

In [18]:
# Build the model from the hyperparameters defined earlier in the notebook.
rnn = RNN(
    input_size=input_size,
    hidden_size=hidden_size,
    num_classes=num_classes,
)

In [19]:
 images=images.squeeze(1)   ## Our RNN expects input to be of shape. Batch x seq_length x features. We will get rid of channels dimension
images.shape

torch.Size([64, 28, 28])

In [20]:
# Forward the batch through the RNN and display the size of its output.
rnn(images).size()


torch.Size([64, 28, 256])

This shows that the RNN output has shape batch × seq_length × hidden_size. If you want to apply a fully connected (FC) layer after the RNN, you can simply flatten the output with `out.reshape(out.shape[0], -1)`. This converts it to batch × (seq_length · hidden_size), i.e. 64 × 7168 in this example.