# Goal:
Use the `nn.Module` abstractions (`nn.Linear` model, `nn.MSELoss` loss) after batching the dataset with `Dataset`/`DataLoader`

# Notes:
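- When reporting epoch loss, average it across all batches (`epoch_loss / len(dataloader)`); if the last batch can be smaller than the others, weight each batch loss by its batch size and divide by the number of samples instead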
- Reporting only the last batch's loss makes the curve oscillate — it reflects one random batch, not overall convergence
- Mini-batch SGD uses noisier gradient estimates than full-dataset gradient descent, so convergence can be slower and less smooth, but it scales to large datasets

In [10]:
import torch
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn

# Seed the global RNG first: the model's initial weights, the inputs and the
# noise below are all drawn from it, so the run is reproducible.
torch.manual_seed(42)

# Single input feature -> single output, i.e. y_hat = w * x + b.
model = nn.Linear(1, 1)

# Synthetic data: n_molecules one-feature samples whose targets follow
# y = 0.5 * x + 0.1 plus normal noise scaled to standard deviation 0.1.
n_molecules = 50
molecules = torch.randn(n_molecules, 1)
noise = torch.randn(n_molecules, 1) * 0.1
y_true = 0.5 * molecules + 0.1 + noise

class MoleculeDataset(Dataset):
    """Map-style dataset pairing descriptors with their regression labels.

    Both inputs are indexed along their first dimension; item ``idx`` is the
    tuple ``(descriptors[idx], labels[idx])``.

    Raises:
        ValueError: if ``descriptors`` and ``labels`` differ in length.
    """

    def __init__(self, descriptors, labels):
        # Fail fast on mismatched inputs: otherwise surplus labels would be
        # silently ignored, or __getitem__ would raise IndexError mid-epoch.
        if len(descriptors) != len(labels):
            raise ValueError(
                "descriptors and labels must have the same length, "
                f"got {len(descriptors)} and {len(labels)}"
            )
        self.descriptors = descriptors
        self.labels = labels

    def __len__(self):
        """Return the number of (descriptor, label) pairs."""
        return len(self.descriptors)

    def __getitem__(self, idx):
        """Return the ``(descriptor, label)`` pair at position ``idx``."""
        return self.descriptors[idx], self.labels[idx]

# Hyperparameters for mini-batch training.
BATCH_SIZE = 10
LEARNING_RATE = 0.01

# Wrap the tensors in a map-style Dataset and serve shuffled mini-batches.
dataset = MoleculeDataset(molecules, y_true)
dataloader = DataLoader(dataset=dataset, batch_size=BATCH_SIZE, shuffle=True)

# Mean-squared-error objective, minimised with SGD over the model's parameters.
criterion = nn.MSELoss()
optimizer = torch.optim.SGD(params=model.parameters(), lr=LEARNING_RATE)

# Train for 100 epochs of mini-batch SGD, logging progress every 10 epochs.
for epoch in range(100):
    epoch_loss = 0.0
    n_samples_seen = 0
    for descriptors, labels in dataloader:
        y_pred = model(descriptors)
        loss = criterion(y_pred, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # The loss is a per-batch mean; weight it by the batch's size so the
        # epoch figure is a true per-sample mean even when the final batch
        # is smaller than the others (dividing by len(dataloader) is only
        # exact when every batch has the same size).
        n_batch = labels.size(0)
        epoch_loss += loss.item() * n_batch
        n_samples_seen += n_batch
    if epoch % 10 == 0:
        print(f'Epoch {epoch}: loss={epoch_loss / n_samples_seen:.4f}, w={model.weight.item():.4f}, b={model.bias.item():.4f}')
        


Epoch 0: loss=0.5362, w=0.7402, b=0.7620
Epoch 10: loss=0.0793, w=0.5909, b=0.3536
Epoch 20: loss=0.0169, w=0.5332, b=0.2038
Epoch 30: loss=0.0085, w=0.5111, b=0.1490
Epoch 40: loss=0.0073, w=0.5026, b=0.1287
Epoch 50: loss=0.0071, w=0.4992, b=0.1214
Epoch 60: loss=0.0071, w=0.4980, b=0.1186
Epoch 70: loss=0.0071, w=0.4975, b=0.1177
Epoch 80: loss=0.0071, w=0.4974, b=0.1173
Epoch 90: loss=0.0071, w=0.4973, b=0.1172
