In [23]:
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset

In [37]:
# Seed torch's global RNG before anything random happens (noise in the series,
# DataLoader shuffling, weight initialisation) so Restart & Run All reproduces
# the same data and losses.
SEED = 0
torch.manual_seed(SEED)

# Length of the synthetic time series.
T = 200

In [38]:
# Window length: how many past values each prediction gets to see.
tau = 50

# Noisy sine wave sampled at time steps 1..T: sin(0.01 * t) plus Gaussian noise
# with standard deviation 0.2.
time = torch.arange(start=1, end=T + 1, dtype=torch.float32)
x = torch.sin(time * 0.01) + 0.2 * torch.randn(T)

In [39]:
# Sliding windows over the series: row i is (x[i], ..., x[i + tau - 1]) and its
# target is the value that follows, x[i + tau].
features = torch.stack([x[start:start + tau] for start in range(T - tau)], dim=0)
labels = x[tau:].reshape(-1, 1)

In [40]:
features.size()

torch.Size([150, 50])

In [41]:
labels.size()

torch.Size([150, 1])

In [42]:
# Chronological split: the earliest windows train the model and the last
# `n_val` are held out. The validation cell slices its own last 100 windows,
# so keep the two numbers in sync.
# NOTE(review): with 150 windows this trains on only 50 and validates on 100 —
# confirm that ratio is intended.
n_val = 100
n_train = len(features) - n_val
# A non-positive n_train would make features[:n_train] count from the end and
# silently overlap the validation windows, so fail loudly instead.
assert n_train > 0, f"need more than {n_val} windows to hold out {n_val}, got {len(features)}"
dataloader = DataLoader(
    TensorDataset(features[:n_train], labels[:n_train]),
    batch_size=1,
    shuffle=True,
)

In [43]:
# Hold-out set: the final 100 windows, visited in their original order.
val_ds = TensorDataset(features[-100:], labels[-100:])
val_loader = DataLoader(val_ds, batch_size=1, shuffle=False)

In [44]:
# Peek at one shuffled training sample. With batch_size=1 it prints as
# [features of shape (1, tau), label of shape (1, 1)].
first_batch = next(iter(dataloader))
print(first_batch)

[tensor([[ 0.0329, -0.1116, -0.0430,  0.2754,  0.5573,  0.2633,  0.3522,  0.1889,
          0.4153, -0.2685,  0.2520,  0.2593,  0.0136,  0.0028,  0.3988,  0.3097,
          0.2838,  0.3626,  0.3331,  0.1046,  0.3038, -0.0018,  0.5099,  0.4191,
          0.4650,  0.3839,  0.5900,  0.3472,  0.8398,  0.1206,  0.7195,  0.1391,
          0.7940,  0.2899,  0.4240,  0.3432,  0.3446,  0.3947,  0.8026,  0.1959,
          0.4477,  0.5315,  0.3475,  0.5878,  0.4623,  0.4390,  0.5442,  0.8262,
          0.3848,  0.4464]]), tensor([[0.5239]])]


Training

In [45]:
class LinearModel(nn.Module):
    """Linear autoregressive predictor: maps a window of past values to the next value.

    Args:
        n_inputs: Length of the input window. When omitted, the notebook-level
            ``tau`` is read at construction time, which matches the behaviour of
            ``LinearModel()`` before this parameter existed.
    """

    def __init__(self, n_inputs: "int | None" = None) -> None:
        super().__init__()
        if n_inputs is None:
            # Resolved here, not at class definition, so a later change to `tau` is honoured.
            n_inputs = tau
        # A single output unit: the predicted next value of the series.
        self.linear = nn.Linear(n_inputs, 1)

    def forward(self, x):
        """Apply the linear layer, e.g. ``(batch, n_inputs)`` -> ``(batch, 1)``."""
        return self.linear(x)

In [52]:
# Fresh model with newly initialised weights. If you re-run this cell, re-run the
# optimizer cell too: `opt` holds references to the parameters of whichever model
# existed when it was built, so it would keep updating the old one.
lin_model = LinearModel()

In [53]:
# Loss: mean-squared error between predicted and true next values.
criterion = nn.MSELoss()
# Plain stochastic gradient descent over every parameter of the model.
opt = torch.optim.SGD(params=lin_model.parameters(), lr=0.01)

In [54]:
# Train for a fixed number of epochs and report the mean per-batch MSE on both
# splits. Batches are unpacked into `xb`/`yb` so the global series `x` from the
# data cell is not overwritten (re-running the windowing cell stays safe).
n_epochs = 10
for epoch in range(n_epochs):
    lin_model.train()
    tr_loss = 0.0
    for xb, yb in dataloader:
        loss = criterion(lin_model(xb), yb)
        # .item() yields a plain float; summing the loss tensor itself would keep
        # each batch's autograd graph alive until the end of the epoch.
        tr_loss += loss.item()

        opt.zero_grad()
        loss.backward()
        opt.step()
    tr_loss /= len(dataloader)

    lin_model.eval()
    val_loss = 0.0
    with torch.no_grad():
        for xb, yb in val_loader:
            val_loss += criterion(lin_model(xb), yb).item()
    val_loss /= len(val_loader)

    print(f'Epoch: {epoch}, Tr_Loss: {tr_loss:.4f}, Val_Loss: {val_loss:.4f}')
    

Epoch: 0, Tr_Loss: 0.0594, Val_Loss: 0.2171
Epoch: 1, Tr_Loss: 0.0479, Val_Loss: 0.0884
Epoch: 2, Tr_Loss: 0.0432, Val_Loss: 0.1200
Epoch: 3, Tr_Loss: 0.0412, Val_Loss: 0.0728
Epoch: 4, Tr_Loss: 0.0360, Val_Loss: 0.0637
Epoch: 5, Tr_Loss: 0.0350, Val_Loss: 0.0805
Epoch: 6, Tr_Loss: 0.0307, Val_Loss: 0.0651
Epoch: 7, Tr_Loss: 0.0303, Val_Loss: 0.1033
Epoch: 8, Tr_Loss: 0.0266, Val_Loss: 0.0781
Epoch: 9, Tr_Loss: 0.0311, Val_Loss: 0.0668
