In [16]:
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor, Lambda, Compose

In [19]:
# Download training data from open datasets

mean_t = torch.tensor([0.485, 0.456, 0.406])
std_t = torch.tensor([0.229, 0.224, 0.225])

from torchvision.transforms import v2

transform = v2.Compose([
    v2.ToImage(),
    v2.ToDtype(torch.float32, scale=True),
    v2.Normalize(mean=mean_t, std=std_t),
])

training_data = datasets.CIFAR100(
    root="data/cifar/",
    train=True,
    download=True,
    transform=transform,
)

# Download test data from open datasets
test_data = datasets.CIFAR100(
    root="data/cifar/",
    train=False,
    download=True,
    transform=transform,
)

Files already downloaded and verified
Files already downloaded and verified


In [20]:
batch_size = 256

l_kws = dict(
    num_workers = 1,
    pin_memory = True,
    persistent_workers=True,
    prefetch_factor=32,
)

# Create data loaders
train_dataloader = DataLoader(training_data, batch_size=batch_size, **l_kws)
test_dataloader = DataLoader(test_data, batch_size=batch_size, **l_kws)

for X, y in test_dataloader:
    print("Shape of X [N, C, H, W]: ", X.shape)
    print("Shape of y: ", y.shape, y.dtype)
    break

Shape of X [N, C, H, W]:  torch.Size([256, 3, 32, 32])
Shape of y:  torch.Size([256]) torch.int64


In [21]:
# Get cpu or gpu device for training
def get_device(use_cpu=False):
    return torch.device("cuda:0" if torch.cuda.is_available() else ("mps" if torch.backends.mps.is_available() else "cpu"))

device = get_device()
print("Using {} device".format(device))

# Define model
# class NeuralNetwork(nn.Module):
#     def __init__(self):
#         super(NeuralNetwork, self).__init__()
#         self.flatten = nn.Flatten()
#         self.linear_relu_stack = nn.Sequential(
#             nn.Linear(32*32*3, 512),
#             nn.ReLU(),
#             nn.Linear(512, 512),
#             nn.ReLU(),
#             nn.Linear(512, 100)
#         )

#     def forward(self, x):
#         x = self.flatten(x)
#         logits = self.linear_relu_stack(x)
#         return logits

# model = NeuralNetwork().to(device)
# print(model)


import torch.nn as nn
from torchvision import models

model = models.resnet50(
            weights=models.ResNet50_Weights.IMAGENET1K_V1)
num_ftrs = model.fc.in_features
model.fc = nn.Linear(num_ftrs, 100)
model = model.to(device)

Using mps device


In [8]:
loss_fn = nn.CrossEntropyLoss().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
# SGD momentum 0.9

In [9]:
scaler = torch.cuda.amp.GradScaler()

def train(dataloader, model, loss_fn, optimizer):
    size = len(dataloader.dataset)
    model.train()
    for batch, (X, y) in enumerate(dataloader):
        X, y = X.to(device), y.to(device)


        # Compute prediction error
        with torch.cuda.amp.autocast():
          # loss = model(data)
            pred = model(X)
            loss = loss_fn(pred, y)

        scaler.scale(loss).backward()
        # loss.backward()
        # torch.nn.utils.clip_grad_norm_(model.parameters(), 10)
        # optimizer.step()

        if (batch + 1) % 20 == 0:
            # Unscales gradients and calls
            # or skips optimizer.step()
            scaler.step(optimizer)

            # Updates the scale for next iteration
            scaler.update()

            # Backpropagation
            optimizer.zero_grad()

        if batch % 100 == 0:
            loss, current = loss.item(), batch * len(X)
            print(f"loss: {loss:>7f}  [{current:>5d}/{size:>5d}]")



In [10]:
import time
def test(dataloader, model, loss_fn, st):
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    model.eval()
    test_loss, correct = 0, 0
    with torch.no_grad():
        for X, y in dataloader:
            X, y = X.to(device), y.to(device)
            pred = model(X)
            test_loss += loss_fn(pred, y).item()
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()
    test_loss /= num_batches
    correct /= size
    print(f"Test Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f}. ", time.time() - st)

In [11]:
epochs = 200
for t in range(epochs):
    print(f"Epoch {t+1}\n-------------------------------")
    st = time.time()
    train(train_dataloader, model, loss_fn, optimizer)
    test(test_dataloader, model, loss_fn, st)
print("Done!")

Epoch 1
-------------------------------




loss: 4.740006  [    0/50000]
loss: 2.963702  [25600/50000]


KeyboardInterrupt: 

### Lightning

In [23]:
import lightning as L
import torch.nn.functional as F

loss_fn = nn.CrossEntropyLoss()

class LightModel(L.LightningModule):
    def __init__(self, output_features, lr=1e-1,
                 weights=models.ResNet50_Weights.IMAGENET1K_V1):
        super().__init__()
        model = models.resnet50(
            weights=weights)
        num_ftrs = model.fc.in_features
        model.fc = nn.Linear(num_ftrs, output_features)
        self.model = model
        self.lr = lr

    def forward(self, x):
        return self.model(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        output = self.model(x)
        loss = loss_fn.to(x.device)(output, y)
        self.log("train_loss", loss)
        return loss

    # def test_step(self, batch, batch_idx):
    #     x, y = batch
    #     output = self.model(x)
    #     loss = F.mse_loss(torch.argmax(output), y)
    #     self.log("test_loss", loss, prog_bar=True)

    def validation_step(self, batch, batch_idx):
        x, y = batch
        output = self.model(x)
        loss = loss_fn.to(x.device)(output, y)
        self.log("val_loss", loss)

    def configure_optimizers(self):
        # optimizer = torch.optim.SGD(
        #     self.parameters(), lr=self.lr, momentum=0.9)
        optimizer = torch.optim.Adam(self.model.parameters(), lr=self.lr)
        return optimizer

In [24]:
model = LightModel(output_features=100, lr=0.001)

# train model
trainer = L.Trainer()
# tuner = Tuner(trainer)
# lr_finder = tuner.lr_find(model, attr_name="lr",
            #   train_dataloaders=train_loader, val_dataloaders=valid_loader)

train_dataloader = DataLoader(training_data, batch_size=batch_size, **l_kws)
test_dataloader = DataLoader(test_data, batch_size=batch_size, **l_kws)

# Plot with
# fig = lr_finder.plot(suggest=True)
# fig.show()
trainer.fit(model=model,
                train_dataloaders=train_dataloader, val_dataloaders=test_dataloader)

GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
/Users/yinnnyou/anaconda3/envs/data_aug_3115/lib/python3.11/site-packages/lightning/pytorch/loops/utilities.py:73: `max_epochs` was not set. Setting it to 1000 epochs. To train without an epoch limit, set `max_epochs=-1`.
Missing logger folder: /Users/yinnnyou/workspace/medical_imaging_imbalancing/lightning_logs

  | Name  | Type   | Params
---------------------------------
0 | model | ResNet | 23.7 M
---------------------------------
23.7 M    Trainable params
0         Non-trainable params
23.7 M    Total params
94.852    Total estimated model params size (MB)


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

/Users/yinnnyou/anaconda3/envs/data_aug_3115/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:441: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=9` in the `DataLoader` to improve performance.


                                                                           

/Users/yinnnyou/anaconda3/envs/data_aug_3115/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:441: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=9` in the `DataLoader` to improve performance.


Epoch 7:  84%|████████▎ | 164/196 [00:27<00:05,  5.98it/s, v_num=0]

/Users/yinnnyou/anaconda3/envs/data_aug_3115/lib/python3.11/site-packages/lightning/pytorch/trainer/call.py:54: Detected KeyboardInterrupt, attempting graceful shutdown...
