In [185]:
import torch
import torch.nn as nn
from torch.optim import Adam
import random

import numpy as np
from torchsummary import summary


In [186]:
train_data = []
train_label = []
for _ in range(50):
    start = random.randint(0, 97)  # Random start to generate diverse sequences
    seq = [start, start+1, start+2]
    train_data.append(seq)
    train_label.append(start+3)  # Next number in sequence

In [196]:
print(len(train_data))
print(len(train_label))
print(train_data[1])
print(train_label[1])

50
50
[96, 97, 98]
99


In [187]:
# Convert to PyTorch tensors
X = torch.FloatTensor(train_data).unsqueeze(-1)  # Shape: [50, 3, 1]
y = torch.FloatTensor(train_label)

# Simple RNN model
class SimpleRNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.rnn = nn.RNN(input_size=1, hidden_size=16, batch_first=True)
        self.linear = nn.Linear(16, 1)
        
    def forward(self, x):
        output, _ = self.rnn(x)
        return self.linear(output[:, -1, :]).squeeze()

# Create and train model
model = SimpleRNN()

In [188]:
loss_fn = nn.MSELoss()
optimizer = Adam(model.parameters(), lr=0.01)

In [189]:
# Train
for epoch in range(500):
    model.train()
    optimizer.zero_grad()
    
    prediction = model(X)
    loss = loss_fn(prediction, y)
    
    loss.backward()
    optimizer.step()
    
    if epoch % 50 == 0:
        print(f'Epoch: {epoch} | Loss: {loss.item():.4f}')

Epoch: 0 | Loss: 3660.5479
Epoch: 50 | Loss: 2864.2576
Epoch: 100 | Loss: 2230.7095
Epoch: 150 | Loss: 1732.7688
Epoch: 200 | Loss: 1344.2004
Epoch: 250 | Loss: 1040.6659
Epoch: 300 | Loss: 805.0283
Epoch: 350 | Loss: 621.6252
Epoch: 400 | Loss: 480.4293
Epoch: 450 | Loss: 373.4489


In [190]:
# Test several sequences
test_sequences = [
    [7,8,9],
    [11,12,13],
    [20,21,22]
]

model.eval()
with torch.no_grad():
    for seq in test_sequences:
        test_input = torch.FloatTensor(seq).reshape(1, 3, 1)
        pred = model(test_input)
        print(f'\nSequence {seq} -> Predicted: {pred.item():.4f}, Expected: {seq[-1]+1}')


Sequence [7, 8, 9] -> Predicted: 10.0128, Expected: 10

Sequence [11, 12, 13] -> Predicted: 13.9411, Expected: 14

Sequence [20, 21, 22] -> Predicted: 22.9613, Expected: 23
