In [2]:
import os
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [3]:
class NeuralNetwork(nn.Module):
    def __init__(self):
        super().__init__()
        self.flatten = nn.Flatten()
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(28*28, 512),
            nn.ReLU(),
            nn.Linear(512, 512),
            nn.ReLU(),
            nn.Linear(512, 10),
        )

    def forward(self, x):
        x = self.flatten(x)
        logits = self.linear_relu_stack(x)
        return logits

In [4]:
model = NeuralNetwork().to(device)
print(model)

NeuralNetwork(
  (flatten): Flatten(start_dim=1, end_dim=-1)
  (linear_relu_stack): Sequential(
    (0): Linear(in_features=784, out_features=512, bias=True)
    (1): ReLU()
    (2): Linear(in_features=512, out_features=512, bias=True)
    (3): ReLU()
    (4): Linear(in_features=512, out_features=10, bias=True)
  )
)


In [5]:
X = torch.rand(1, 28, 28, device=device)

In [6]:
logits = model(X)

In [7]:
torch.argmax(nn.Softmax(dim=1)(logits))

tensor(5, device='cuda:0')

In [8]:
list(model.named_parameters())

[('linear_relu_stack.0.weight',
  Parameter containing:
  tensor([[-0.0310,  0.0208,  0.0055,  ...,  0.0110,  0.0133,  0.0091],
          [ 0.0140,  0.0327, -0.0173,  ..., -0.0018,  0.0154, -0.0157],
          [-0.0237,  0.0213, -0.0269,  ...,  0.0308, -0.0231,  0.0160],
          ...,
          [-0.0357, -0.0246,  0.0300,  ..., -0.0325,  0.0190, -0.0280],
          [ 0.0171, -0.0111, -0.0267,  ..., -0.0070, -0.0279, -0.0292],
          [-0.0029, -0.0185, -0.0169,  ..., -0.0157, -0.0269, -0.0294]],
         device='cuda:0', requires_grad=True)),
 ('linear_relu_stack.0.bias',
  Parameter containing:
  tensor([ 2.2661e-02, -5.5666e-03, -1.4075e-02, -1.4687e-02,  1.0062e-02,
          -3.3937e-02,  3.8157e-03,  2.5960e-02, -3.0138e-02, -1.1059e-02,
           2.9186e-02, -2.6695e-02,  1.0879e-02, -8.4445e-03,  2.8941e-02,
          -1.1882e-02,  7.1831e-03,  2.3959e-02,  1.8371e-02,  3.0254e-02,
           3.0497e-02, -2.8459e-02, -1.8877e-02, -1.5487e-02, -3.5036e-02,
           2.8933e-