In [1]:
import torch
from torchvision import datasets
from torchvision.transforms import ToTensor


**Creating the `Dataset` and `DataLoader`**

In [2]:
train_dataset = datasets.CIFAR10(
    root='data',
    train=True,
    download=True,
    transform=ToTensor(),
)

test_dataset = datasets.CIFAR10(
    root='data',
    train=False,
    download=True,
    transform=ToTensor(),
)

In [3]:
batch_size = 64

train_dataloader = torch.utils.data.DataLoader(
    dataset=train_dataset,
    batch_size=batch_size,
    shuffle=True,
)

test_dataloader = torch.utils.data.DataLoader(
    dataset=test_dataset,
    shuffle=True,
    batch_size=batch_size,
)

In [4]:
for X, y in train_dataloader:
    print(X.shape, y.shape)
    break

torch.Size([64, 3, 32, 32]) torch.Size([64])


In [5]:
device = torch.accelerator.current_accelerator().type
device

'mps'

**Building the Model**

In [6]:
class NeuralNet(torch.nn.Module):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.flatten = torch.nn.Flatten()
        self.linear_relu_stack = torch.nn.Sequential(
            torch.nn.Linear(3*32*32, 512),
            torch.nn.ReLU(),
            torch.nn.Linear(512, 256),  # So W (weight matrix) here is of shape (256, 512)
            torch.nn.ReLU(),
            torch.nn.Linear(256, 128),
            torch.nn.ReLU(),
            torch.nn.Linear(128, 10),
        )

    def forward(self, x):
        x = self.flatten(x)
        logits = self.linear_relu_stack(x)
        return logits
    
    '''
    `logits`: unnormalized scores a model generates before they 
    are converted to probabilites (using Softmax). We leave logits
    as logits instead of softmaxing them right away because loss functions
    expect raw logits as input. CrossEntropy loss also handles numerical stability
    issues that might arise from taking log manually. This is a standard practice 
    to make models return logits instead of probabilities.
    '''

In [7]:
model = NeuralNet().to(device=device) 
print(model)

NeuralNet(
  (flatten): Flatten(start_dim=1, end_dim=-1)
  (linear_relu_stack): Sequential(
    (0): Linear(in_features=3072, out_features=512, bias=True)
    (1): ReLU()
    (2): Linear(in_features=512, out_features=256, bias=True)
    (3): ReLU()
    (4): Linear(in_features=256, out_features=128, bias=True)
    (5): ReLU()
    (6): Linear(in_features=128, out_features=10, bias=True)
  )
)


**Optimizing Model Parameters**

In [8]:
loss_fn = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

In [14]:
# Main training loop
def train(dataloader, model, loss_fn, optimizer):
    size = len(dataloader.dataset)
    model.train()   # Sets the model in training phase
    for batch, (X,y) in enumerate(dataloader):
        X, y = X.to(device), y.to(device)

        # Get model predictions
        preds = model(X)

        # Apply loss function
        loss = loss_fn(preds, y)

        # Backpropagation
        loss.backward()      # Calculate gradients of the `loss_fn` with respect to model parameters. They are stored on `.grad` attribute of each parameter.
        optimizer.step()     # Update the model parameters using the gradients and optimizer (SGD here)
        optimizer.zero_grad() # Reset the `.grad` gradients to zero for next epoch

        # Calculate training metrics for the current epoch
        loss, current = loss.item(), (batch+1)*len(X)
        print(f"Loss: {loss}\n [{current}/{size}]")
        

In [16]:
# Main test loop
def test(dataloader, model, loss_fn):
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    model.eval()
    test_loss, correct = 0,0
    with torch.no_grad():
        for batch, (X,y) in enumerate(dataloader):
            X,y = X.to(device), y.to(device)
            preds = model(X)
            loss = loss_fn(preds, y)
            test_loss += loss.item()
            correct += (y == preds.argmax(1)).type(torch.float).sum().item()
    test_loss /= num_batches
    correct /= size
    print(f"Test Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")

In [17]:
epochs=10
for t in range(epochs):
    print(f'Epoch {t}---------------\n\n')
    train(train_dataloader, model, loss_fn, optimizer)
    test(test_dataloader, model, loss_fn)
print('Done!')

Epoch 0---------------


Loss: 2.2912445068359375
 [64/50000]
Loss: 2.302004098892212
 [128/50000]
Loss: 2.2992639541625977
 [192/50000]
Loss: 2.3029541969299316
 [256/50000]
Loss: 2.303590774536133
 [320/50000]
Loss: 2.2936155796051025
 [384/50000]
Loss: 2.303614616394043
 [448/50000]
Loss: 2.298265218734741
 [512/50000]
Loss: 2.302595376968384
 [576/50000]
Loss: 2.304133415222168
 [640/50000]
Loss: 2.2952446937561035
 [704/50000]
Loss: 2.2951064109802246
 [768/50000]
Loss: 2.301146984100342
 [832/50000]
Loss: 2.290217399597168
 [896/50000]
Loss: 2.293613910675049
 [960/50000]
Loss: 2.2981581687927246
 [1024/50000]
Loss: 2.291701078414917
 [1088/50000]
Loss: 2.298288345336914
 [1152/50000]
Loss: 2.2912752628326416
 [1216/50000]
Loss: 2.2988550662994385
 [1280/50000]
Loss: 2.296255588531494
 [1344/50000]
Loss: 2.311572551727295
 [1408/50000]
Loss: 2.3058462142944336
 [1472/50000]
Loss: 2.3001387119293213
 [1536/50000]
Loss: 2.293457269668579
 [1600/50000]
Loss: 2.290198564529419
 [1664