In [1]:
import torch
import torch.nn as nn
import torch.optim as optim

In [2]:
class RNN_model(nn.Module):
    """Sequence classifier: a single-layer ``nn.RNN`` followed by a linear head.

    The RNN is run over the whole input sequence; only its output at the final
    time step is fed to the linear layer, which returns one unnormalised score
    per output class.

    Args:
        n_inputs: number of features per time step (the RNN's ``input_size``).
        n_neurons: size of the RNN hidden state (the RNN's ``hidden_size``).
        n_outputs: number of scores produced by the linear head.
    """

    def __init__(self, n_inputs, n_neurons, n_outputs):
        super().__init__()
        # The RNN is created before the Linear layer so that, under a fixed seed,
        # parameter initialisation draws from the RNG in the same order.
        self.rnn = nn.RNN(
            input_size=n_inputs,
            hidden_size=n_neurons,
            batch_first=True,  # batched input is laid out as (batch, seq, features)
        )
        self.fc = nn.Linear(in_features=n_neurons, out_features=n_outputs)

    def forward(self, x):
        """Map a batch of sequences to scores of shape (batch, n_outputs).

        ``x`` is expected as (batch, seq_len, n_inputs) because the RNN is built
        with ``batch_first=True``.
        """
        seq_outputs = self.rnn(x)[0]        # per-step outputs; final hidden state is discarded
        last_step = seq_outputs[:, -1, :]   # output of the last time step only
        return self.fc(last_step)


In [7]:
# Hyperparameters
seed = 0  # fixes weight init and the dummy data so Restart & Run All reproduces the log
n_inputs = 3  # Input features per time step
n_neurons = 5  # Number of RNN neurons (hidden size)
n_outputs = 2  # Number of classes
learning_rate = 0.01
weight_decay = 1e-3
n_epochs = 100
n_samples = 10
seq_len = 1  # NOTE: with a single time step the recurrent weights are never exercised
log_every = 1  # positive int: print every `log_every` epochs (the last epoch is always printed)

torch.manual_seed(seed)

# Create the model, loss function, and optimizer
model = RNN_model(n_inputs, n_neurons, n_outputs)
loss_fn = nn.CrossEntropyLoss()  # expects raw scores; applies log-softmax internally
# NOTE(review): Adam's `weight_decay` is an L2 term added to the gradient;
# use optim.AdamW if decoupled weight decay was intended.
optimizer = optim.Adam(model.parameters(), lr=learning_rate, weight_decay=weight_decay)

# Dummy data: n_samples sequences of seq_len steps with n_inputs features each.
# Labels are random, so any rise in accuracy reflects memorising noise, not learning.
X = torch.randn(n_samples, seq_len, n_inputs)
y = torch.randint(0, n_outputs, size=(n_samples,))  # class indices in [0, n_outputs)

# Training loop (full batch: every epoch uses all samples)
for epoch in range(n_epochs):
    optimizer.zero_grad()

    # Forward pass
    outputs = model(X)  # (n_samples, n_outputs) scores
    loss = loss_fn(outputs, y)

    # Metrics below describe the parameters *before* this epoch's update.
    predictions = outputs.argmax(dim=1)
    accuracy = (predictions == y).float().mean().item()

    # Backward pass and weight update
    loss.backward()
    optimizer.step()

    if epoch % log_every == 0 or epoch == n_epochs - 1:
        print(f"Epoch: {epoch}, Loss: {loss.item():.4f}, Accuracy: {accuracy:.2f}")


Epoch: 0, Loss: 0.8006, Accuracy: 0.30
Epoch: 1, Loss: 0.7857, Accuracy: 0.20
Epoch: 2, Loss: 0.7718, Accuracy: 0.20
Epoch: 3, Loss: 0.7587, Accuracy: 0.20
Epoch: 4, Loss: 0.7465, Accuracy: 0.20
Epoch: 5, Loss: 0.7349, Accuracy: 0.30
Epoch: 6, Loss: 0.7239, Accuracy: 0.30
Epoch: 7, Loss: 0.7133, Accuracy: 0.40
Epoch: 8, Loss: 0.7031, Accuracy: 0.40
Epoch: 9, Loss: 0.6931, Accuracy: 0.50
Epoch: 10, Loss: 0.6834, Accuracy: 0.70
Epoch: 11, Loss: 0.6739, Accuracy: 0.70
Epoch: 12, Loss: 0.6646, Accuracy: 0.70
Epoch: 13, Loss: 0.6556, Accuracy: 0.80
Epoch: 14, Loss: 0.6469, Accuracy: 0.80
Epoch: 15, Loss: 0.6385, Accuracy: 0.70
Epoch: 16, Loss: 0.6303, Accuracy: 0.60
Epoch: 17, Loss: 0.6224, Accuracy: 0.70
Epoch: 18, Loss: 0.6146, Accuracy: 0.70
Epoch: 19, Loss: 0.6069, Accuracy: 0.70
Epoch: 20, Loss: 0.5994, Accuracy: 0.70
Epoch: 21, Loss: 0.5921, Accuracy: 0.70
Epoch: 22, Loss: 0.5849, Accuracy: 0.60
Epoch: 23, Loss: 0.5780, Accuracy: 0.70
Epoch: 24, Loss: 0.5713, Accuracy: 0.70
Epoch: 25,