# In [None]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

# ----- 1) DataLoader (toy) -----
# Synthetic classification set: standard-normal features paired with
# uniformly drawn integer labels, served in shuffled mini-batches of 32.
X = torch.randn((128, 10))  # feature matrix: 128 rows (samples) x 10 columns
y = torch.randint(low=0, high=3, size=(128,))  # one label per row, in {0, 1, 2}
loader = DataLoader(
    dataset=TensorDataset(X, y),
    batch_size=32,
    shuffle=True,
)

# ----- 2) Model (nn.Module or Sequential) -----
class MLP(nn.Module):
    """Small feed-forward classifier: Linear -> ReLU -> Linear.

    The final layer's output is returned as-is (no activation is applied
    after it).

    Args:
        in_dim: Number of input features per sample.
        hidden: Width of the single hidden layer.
        num_classes: Number of output values produced per sample.
    """

    def __init__(self, in_dim=10, hidden=32, num_classes=3):
        super().__init__()
        # Build the layers in the order they are applied, then chain them.
        input_proj = nn.Linear(in_dim, hidden)
        activation = nn.ReLU()
        output_proj = nn.Linear(hidden, num_classes)
        self.net = nn.Sequential(input_proj, activation, output_proj)

    def forward(self, x):
        """Pass ``x`` through the stacked layers and return the result."""
        scores = self.net(x)
        return scores

model = MLP()

# ----- 3) Loss / Optimizer / Scheduler -----
# Cross-entropy on the model outputs, Adam over every model parameter, and a
# step schedule that multiplies the learning rate by 0.5 every 5 scheduler steps.
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(params=model.parameters(), lr=0.001)
scheduler = torch.optim.lr_scheduler.StepLR(
    optimizer=optimizer,
    step_size=5,
    gamma=0.5,
)

# ----- 4) Train -----
# Makes 3 passes over `loader`. Each mini-batch gets one optimizer update, and
# the learning-rate scheduler is advanced once at the end of every epoch.
for epoch in range(3):
    model.train()  # switch the module to training mode before iterating batches
    for xb, yb in loader:
        optimizer.zero_grad()  # clear gradients from the previous step
        logits = model(xb)  # forward pass on this mini-batch
        loss = criterion(logits, yb)
        loss.backward()  # populate gradients for the model parameters
        optimizer.step()  # apply the parameter update

    # NOTE(review): the scheduler is a StepLR with step_size=5, so across only
    # 3 epochs the learning rate is never actually decayed — confirm intended.
    scheduler.step()

# ----- 5) Eval -----
# Accuracy is computed over the same `loader` that was used for training.
model.eval()
correct = 0
total = 0
with torch.no_grad():  # scoring only; gradient tracking is not needed
    for xb, yb in loader:
        logits = model(xb)
        pred = torch.argmax(logits, dim=1)  # index of the largest score per row
        correct += torch.sum(pred == yb).item()
        total += yb.shape[0]

print("accuracy:", correct / total)