In [1]:
import pytorch_lightning as pl
import torch
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
from torch import nn
from torch.utils.data import DataLoader, random_split

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
transform = transforms.Compose([
    transforms.Resize(256),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

train_set = torchvision.datasets.CIFAR10(
    root="./data", train=True, download=True, transform=transform
)
train_size = int(len(train_set) * 0.8)
val_size = len(train_set) - train_size
train_set, val_set = torch.utils.data.random_split(
    train_set, [train_size, val_size], generator=torch.Generator().manual_seed(42)
)
test_set = torchvision.datasets.CIFAR10(
    root="./data", train=False, download=True, transform=transform
)

classes = (
    "plane",
    "car",
    "bird",
    "cat",
    "deer",
    "dog",
    "frog",
    "horse",
    "ship",
    "truck",
)

Files already downloaded and verified
Files already downloaded and verified


In [5]:
class LitResNet(pl.LightningModule):
    def __init__(self):
        super().__init__()
        original_model = torch.hub.load('pytorch/vision:v0.10.0', 'resnet18')
        self.model = nn.Sequential(*list(original_model.children())[:-1])
        self.fc = nn.Linear(in_features=512, out_features=10, bias=True)
        self.softmax = nn.Softmax(1)
    
    def forward(self, x):
        out = self.model(x)
        out = torch.flatten(out, start_dim=1, end_dim=-1)
        out = self.fc(out)
        out = self.softmax(out)
        return out

    def training_step(self, batch, batch_idx):
        # training_step defines the train loop.
        x, y = batch
        y_hat = self.model(x)
        loss = F.mse_loss(y_hat.type(torch.FloatTensor), y.type(torch.FloatTensor))
        return loss

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
        return optimizer

In [10]:
train_loader = DataLoader(train_set, batch_size=32, num_workers=8)
val_loader = DataLoader(val_set, batch_size=32, num_workers=8)
test_loader = DataLoader(test_set, batch_size=32, num_workers=8)

model = LitResNet()
trainer = pl.Trainer(accelerator="gpu")
trainer.fit(model=model, train_dataloaders=train_loader, val_dataloaders=val_loader)

Using cache found in /home/clx/.cache/torch/hub/pytorch_vision_v0.10.0
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
  rank_zero_warn("You passed in a `val_dataloader` but have no `validation_step`. Skipping val loop.")
You are using a CUDA device ('NVIDIA GeForce RTX 3080') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name    | Type       | Params
---------------------------------------
0 | model   | Sequential | 11.2 M
1 | fc      | Linear     | 5.1 K 
2 | softmax | Softmax    | 0     
---------------------------------------
11.2 M    Trainable params
0         Non-trainable 

Epoch 0:   0%|          | 2/1250 [00:00<07:20,  2.83it/s, loss=24.1, v_num=3]



Epoch 0:   0%|          | 70/40000 [02:32<24:06:42,  2.17s/it, loss=21.5, v_num=1]
Epoch 0:   3%|▎         | 1250/40000 [01:43<53:39, 12.03it/s, loss=5.01, v_num=2]
Epoch 1:   8%|▊         | 98/1250 [00:05<00:59, 19.38it/s, loss=8.78, v_num=3]  

  rank_zero_warn("Detected KeyboardInterrupt, attempting graceful shutdown...")
