In [None]:
import torch
import torch.nn as nn

In [None]:
class LSTMModel(nn.Module):
    """Stacked LSTM followed by a linear head applied to the final time step.

    Args:
        input_size: Number of features in each time step of the input.
        hidden_size: Width of the LSTM hidden state (and of the head's input).
        num_layers: Number of stacked LSTM layers.
        output_size: Number of values the linear head produces per sequence.
    """

    def __init__(self, input_size, hidden_size, num_layers, output_size):
        super().__init__()
        self.lstm = nn.LSTM(
            input_size=input_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            batch_first=True,
        )
        self.fc = nn.Linear(in_features=hidden_size, out_features=output_size)

    def forward(self, x):
        """Run the LSTM over ``x`` and project its last time step.

        The LSTM is built with ``batch_first=True``, so ``x`` is indexed as
        ``[batch, seq_len, input_size]`` and the result is
        ``[batch, output_size]``.
        """
        hidden_seq, _state = self.lstm(x)  # hidden_seq: [batch, seq_len, hidden]
        last_step = hidden_seq[:, -1, :]  # keep only the final time step
        return self.fc(last_step)


In [None]:
def train(model, dataloader, criterion, optimizer, device, epochs=10):
    """Train ``model`` in place and print the mean batch loss after every epoch.

    Args:
        model: Module to optimise; it is moved to ``device`` and switched to
            training mode before the first epoch.
        dataloader: Re-iterable object yielding ``(inputs, targets)`` batches.
            It does not need to implement ``__len__``: batches are counted as
            they are consumed.
        criterion: Callable invoked as ``criterion(outputs, targets)`` that
            returns a scalar loss tensor.
        optimizer: Optimizer whose ``zero_grad``/``step`` drive the updates.
        device: Device the model and every batch are moved to.
        epochs: Number of full passes over ``dataloader``.

    Returns:
        None. Progress is reported with one printed line per epoch.

    Note:
        The printed value is the mean of the per-batch loss values. An epoch
        that yields no batches is reported as ``nan`` instead of raising
        ``ZeroDivisionError``.
    """
    model.to(device)
    model.train()

    for epoch in range(epochs):
        running_loss = 0.0
        num_batches = 0  # counted here so loaders without __len__ also work
        for inputs, targets in dataloader:
            inputs, targets = inputs.to(device), targets.to(device)

            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            num_batches += 1

        # Guard the empty-epoch case: there is no meaningful average to show.
        avg_loss = running_loss / num_batches if num_batches else float("nan")
        print(f"Epoch [{epoch+1}/{epochs}], Loss: {avg_loss:.4f}")


In [None]:
from torch.utils.data import Dataset, DataLoader
import numpy as np

class DummySequenceDataset(Dataset):
    """Synthetic regression dataset of standard-normal sequences and targets.

    Inputs and targets are drawn independently with ``torch.randn``, so there
    is no learnable relationship between them.

    Args:
        num_samples: Number of (sequence, target) pairs to generate.
        seq_len: Number of time steps in each sequence.
        input_size: Number of features per time step.
        output_size: Number of target values per sample.
        seed: Optional integer seed. When given, the data is drawn from a
            private ``torch.Generator`` so it is reproducible and the global
            RNG state is left untouched. When ``None`` the global RNG is used.

    Attributes:
        x: Tensor of shape ``[num_samples, seq_len, input_size]``.
        y: Tensor of shape ``[num_samples, output_size]``.
    """

    def __init__(self, num_samples=1000, seq_len=20, input_size=1, output_size=1, seed=None):
        generator = None
        if seed is not None:
            generator = torch.Generator().manual_seed(seed)

        self.x = torch.randn(num_samples, seq_len, input_size, generator=generator)
        self.y = torch.randn(num_samples, output_size, generator=generator)  # regression target

    def __len__(self):
        """Return the number of samples."""
        return len(self.x)

    def __getitem__(self, idx):
        """Return the ``(sequence, target)`` pair stored at ``idx``."""
        return self.x[idx], self.y[idx]


In [None]:
# --- Run configuration: everything a reader might want to tune ---
seed = 42          # fixed so data, weight init and shuffling repeat across runs
input_size = 1     # features per time step
hidden_size = 64   # LSTM hidden width
num_layers = 2     # stacked LSTM layers
output_size = 1    # regression outputs per sequence
batch_size = 32
epochs = 10
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Seed the global RNG before any random tensors or weights are created so that
# Restart & Run All reproduces the same run. NOTE(review): some CUDA kernels
# are non-deterministic, so GPU runs may still differ slightly.
torch.manual_seed(seed)

# Pass input_size through so the data and the model cannot drift apart.
dataset = DummySequenceDataset(input_size=input_size)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

model = LSTMModel(input_size, hidden_size, num_layers, output_size)
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

train(model, dataloader, criterion, optimizer, device, epochs)
