In [None]:
from dataclasses import dataclass

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader 

In [None]:
class Regression(nn.Module):
    """A single linear layer with an optional summed-squared-error loss.

    ``config`` must provide ``n_in`` (number of input features) and
    ``n_out`` (number of output features); both are also kept as attributes.
    """

    def __init__(self, config):
        super().__init__()
        n_in, n_out = config.n_in, config.n_out
        self.n_in = n_in
        self.n_out = n_out
        self.lin = nn.Linear(n_in, n_out)

    def forward(self, X, Y=None):
        """Apply the linear map and, when targets are given, score it.

        Returns a pair ``(out, loss)``. ``loss`` is ``None`` if ``Y`` is
        omitted. Otherwise the squared error is summed over the last
        dimension and then averaged over every remaining element.
        """
        out = self.lin(X)
        if Y is None:
            return out, None
        squared_error = (out - Y).pow(2)
        per_sample = squared_error.sum(dim=-1)
        return out, per_sample.mean()

In [None]:
@dataclass
class Config:
    """Sizes shared by the model and the synthetic dataset."""

    # Input feature count; used as the linear layer's in_features.
    n_in: int = 4
    # Target count per sample; used as the linear layer's out_features.
    n_out: int = 6
    # Number of synthetic samples generated for the dataset.
    N: int = 10

In [None]:
class TensorData(Dataset):
    """Map-style dataset over pre-built tensors, optionally paired with targets.

    ``X`` and (if given) ``Y`` are indexed along their first dimension.
    Indexing returns ``(X[idx], Y[idx])`` when targets were supplied and
    ``X[idx]`` otherwise.
    """

    def __init__(self, X, Y=None):
        """Store inputs ``X`` and optional targets ``Y``.

        Raises:
            ValueError: if ``Y`` is given and its first dimension differs
                from that of ``X``.
        """
        super().__init__()
        # Validate once at construction rather than on every len() call.
        # Unlike `assert`, an explicit raise is not stripped by `python -O`.
        if Y is not None and X.size(0) != Y.size(0):
            raise ValueError('number of input does not match number of output')
        self.X = X
        self.Y = Y

    def __len__(self):
        """Number of samples, i.e. the size of ``X``'s first dimension."""
        return self.X.size(0)

    def __getitem__(self, idx):
        """Return ``(x, y)`` at ``idx`` when targets exist, else just ``x``."""
        if self.Y is None:
            return self.X[idx]
        return self.X[idx], self.Y[idx]

In [None]:
# Seed every RNG draw below (weight init, synthetic data, shuffle order) so a
# fresh Restart & Run All gives the same losses every time.
SEED = 0
BATCH_SIZE = 2
LR = 0.001
MOMENTUM = 0.9

torch.manual_seed(SEED)
config = Config()
reg = Regression(config)
dataset = TensorData(torch.randn(config.N, config.n_in), torch.randn(config.N, config.n_out))
# A dedicated generator keeps the shuffle order independent of other global RNG use.
dataloader = DataLoader(
    dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    generator=torch.Generator().manual_seed(SEED),
)
optimizer = torch.optim.SGD(reg.parameters(), lr=LR, momentum=MOMENTUM)

In [None]:
# One pass over the shuffled data, printing each mini-batch's loss.
reg.train()  # explicit: guards against an earlier reg.eval() left in the kernel
for X, Y in dataloader:
    optimizer.zero_grad()
    loss = reg(X, Y)[1]  # forward returns (predictions, loss); only the loss is needed
    loss.backward()
    optimizer.step()
    print(f'{loss.item():.4f}')

5.8237
13.6167
9.5706
10.6715
9.8730
