In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np

In [2]:
# XOR truth table. Each row of XOR_INPUT is a bit pair (a, b); the matching
# row of XOR_TARGET is a XOR b. Both are float32 so they convert directly to
# float tensors for nn.RNN / nn.MSELoss.
XOR_INPUT = np.array(
    [
        [0, 0],
        [1, 0],
        [0, 1],
        [1, 1],
    ],
    dtype=np.float32,
)
# Derive the targets from the inputs (column 0 XOR column 1), keeping a
# trailing axis so the shape is (4, 1).
XOR_TARGET = np.logical_xor(XOR_INPUT[:, :1], XOR_INPUT[:, 1:]).astype(np.float32)

In [3]:
# Convert the NumPy arrays to PyTorch tensors and add a leading batch axis.
# The model uses batch_first=True, so this yields shapes (1, 4, 2) and (1, 4, 1):
# ONE sequence whose 4 time steps are the 4 rows of the truth table.
# NOTE(review): because the rows are time steps of a single sequence, the RNN's
# hidden state carries information from earlier rows into later ones, so each
# prediction may depend on row order and not only on the current input pair.
# unsqueeze(0) avoids hard-coding the row/feature counts (view(1, 4, 2)).
inputs = torch.from_numpy(XOR_INPUT).unsqueeze(0)
targets = torch.from_numpy(XOR_TARGET).unsqueeze(0)

In [4]:
# Define the RNN model for XOR
class XORRNN(nn.Module):
    """A single ``nn.RNN`` layer followed by a per-time-step linear read-out.

    Args:
        input_size: number of features per time step fed to ``nn.RNN``.
        hidden_size: width of the RNN hidden state.
        output_size: number of values produced at every time step.

    The RNN is built with ``batch_first=True``, so ``forward`` expects
    ``x`` laid out as ``(batch, seq_len, input_size)`` and returns
    ``(batch, seq_len, output_size)``.
    """

    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        # Attribute names `rnn` / `linear` define the state_dict keys.
        self.rnn = nn.RNN(input_size, hidden_size, batch_first=True)
        self.linear = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        # nn.RNN returns (per-step outputs, final hidden state); only the
        # per-step outputs are projected, giving a prediction at every step.
        per_step_hidden = self.rnn(x)[0]
        return self.linear(per_step_hidden)

In [5]:
# Seed PyTorch's RNG before building the model so the randomly initialised
# weights -- and therefore the logged losses and final predictions -- are
# reproducible on Restart & Run All.
SEED = 0
torch.manual_seed(SEED)

# Hyperparameters (named so they are easy to find and tune)
HIDDEN_SIZE = 5
LEARNING_RATE = 0.01

# Create an instance of the XORRNN model: 2 input features per step, 1 output per step
rnn_model = XORRNN(input_size=2, hidden_size=HIDDEN_SIZE, output_size=1)

# Define the loss function (mean squared error) and the optimizer (Adam)
criterion = nn.MSELoss()
optimizer = optim.Adam(rnn_model.parameters(), lr=LEARNING_RATE)

# Number of training iterations over the full `inputs` tensor
epochs = 100

In [6]:
# How often (in epochs) to report the training loss.
LOG_EVERY = 10

for epoch in range(epochs):
    # Forward pass over the full input tensor
    outputs = rnn_model(inputs)
    loss = criterion(outputs, targets)

    # Backward pass and optimization: clear stale gradients, backprop, update weights
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # Print the loss every LOG_EVERY epochs (the value is the loss computed
    # before this epoch's optimizer step).
    if epoch % LOG_EVERY == 0:
        print(f'Epoch {epoch}, Loss: {loss.item()}')

Epoch 0, Loss: 0.3835149109363556
Epoch 10, Loss: 0.21913310885429382
Epoch 20, Loss: 0.11365991830825806
Epoch 30, Loss: 0.016928473487496376
Epoch 40, Loss: 0.006078544072806835
Epoch 50, Loss: 0.00344679388217628
Epoch 60, Loss: 0.0003252499445807189
Epoch 70, Loss: 0.000544689770322293
Epoch 80, Loss: 3.52094357367605e-05
Epoch 90, Loss: 6.425708852475509e-05


In [7]:
# Evaluate the trained RNN on the XOR table; gradients are not needed here.
with torch.no_grad():
    test_outputs = rnn_model(inputs)
    # test_outputs is indexed as [batch 0, time step i, output 0], matching
    # the single-sequence layout used to build `inputs`.
    for i, (input_data, target_row) in enumerate(zip(XOR_INPUT, XOR_TARGET)):
        output = test_outputs[0, i, 0].item()
        target = target_row[0]
        print(f'Input: {input_data}, Output: {output}, Target: {target}')


Input: [0. 0.], Output: 0.0018776878714561462, Target: 0.0
Input: [1. 0.], Output: 0.9964309930801392, Target: 1.0
Input: [0. 1.], Output: 0.9981204271316528, Target: 1.0
Input: [1. 1.], Output: 0.0024109743535518646, Target: 0.0
