In [2]:
%pip install pytorch-lightning

Collecting pytorch-lightning
  Downloading pytorch_lightning-2.5.0.post0-py3-none-any.whl.metadata (21 kB)
Collecting tqdm>=4.57.0 (from pytorch-lightning)
  Downloading tqdm-4.67.1-py3-none-any.whl.metadata (57 kB)
Collecting torchmetrics>=0.7.0 (from pytorch-lightning)
  Downloading torchmetrics-1.6.1-py3-none-any.whl.metadata (21 kB)
Collecting lightning-utilities>=0.10.0 (from pytorch-lightning)
  Downloading lightning_utilities-0.12.0-py3-none-any.whl.metadata (5.6 kB)
Collecting aiohttp!=4.0.0a0,!=4.0.0a1 (from fsspec[http]>=2022.5.0->pytorch-lightning)
  Downloading aiohttp-3.11.12-cp312-cp312-win_amd64.whl.metadata (8.0 kB)
Collecting aiohappyeyeballs>=2.3.0 (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]>=2022.5.0->pytorch-lightning)
  Downloading aiohappyeyeballs-2.4.6-py3-none-any.whl.metadata (5.9 kB)
Collecting aiosignal>=1.1.2 (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]>=2022.5.0->pytorch-lightning)
  Downloading aiosignal-1.3.2-py2.py3-none-any.whl.metadata (3.8 kB)
Co

### 简介
用 VGG-11 训练 Fashion-MNIST 数据集。
对 VGG-11 模型的调整:
    1. 5 个卷积块的输出通道数均缩小至原来的 1/4(即 N → N/4)

In [1]:
import torch
import pytorch_lightning as pl

### 数据集加载

In [None]:
from torchvision.datasets import FashionMNIST
from torchvision import transforms

class DataConfiguration:
    """Plain settings holder for the DataLoaders built by the data module.

    Attributes:
        batch_size: Number of samples per batch.
        num_workers: Number of DataLoader worker processes.
        pin_memory: Whether loaders pin host memory (intended to be True
            when a GPU is available).
    """

    def __init__(self, batch_size, num_workers, pin_memory):
        # Store all three loader settings in a single tuple assignment.
        self.batch_size, self.num_workers, self.pin_memory = (
            batch_size,
            num_workers,
            pin_memory,
        )

class LitLoadData_FashionMNist(pl.LightningDataModule):
    """LightningDataModule serving FashionMNIST resized to 224x224.

    Args:
        data_config: Object exposing ``batch_size``, ``num_workers`` and
            ``pin_memory``; these are forwarded to both DataLoaders.
        data_dir: Dataset root directory. The same root is used for the
            download in ``prepare_data`` and for loading in ``setup``, so the
            files that were downloaded are the files that get read.
    """

    def __init__(self, data_config, data_dir="./data"):
        super().__init__()
        self.data_config = data_config
        self.data_dir = data_dir

    def prepare_data(self):
        """Download both FashionMNIST splits if they are not already on disk."""
        FashionMNIST(root=self.data_dir, train=True, download=True)
        FashionMNIST(root=self.data_dir, train=False, download=True)

    def setup(self, stage=None):
        """Build the train/validation datasets with the shared transform."""
        transform = transforms.Compose([
            transforms.Resize((224, 224)),          # upscale the images to 224x224
            transforms.ToTensor(),
            transforms.Normalize((0.5,), (0.5,)),   # map [0, 1] to [-1, 1]
        ])

        # The FashionMNIST test split is used as the validation set here.
        self.train_dataset = FashionMNIST(root=self.data_dir, train=True, transform=transform)
        self.val_dataset = FashionMNIST(root=self.data_dir, train=False, transform=transform)

    def _make_loader(self, dataset, shuffle):
        """Create a DataLoader for ``dataset`` from the shared data_config."""
        workers = self.data_config.num_workers
        return torch.utils.data.DataLoader(
            dataset,
            batch_size=self.data_config.batch_size,
            num_workers=workers,
            pin_memory=self.data_config.pin_memory,
            # DataLoader raises ValueError for persistent_workers=True when
            # num_workers == 0, so only enable it when workers exist.
            persistent_workers=workers > 0,
            shuffle=shuffle,
        )

    def train_dataloader(self):
        """Return the shuffled training DataLoader."""
        return self._make_loader(self.train_dataset, shuffle=True)

    def val_dataloader(self):
        """Return the non-shuffled validation DataLoader."""
        return self._make_loader(self.val_dataset, shuffle=False)


### 模型定义

In [6]:


from typing import Any

class TrainingConfiguration:
    """Optimisation settings read by the model when it builds its optimizer.

    Attributes:
        learning_rate: Learning rate passed to the optimizer.
        optimizer: Optimizer name; the model compares it against "Adam"
            and "SGD".
    """

    def __init__(self, learning_rate, optimizer):
        # Keep both values exactly as given; no validation happens here.
        self.learning_rate, self.optimizer = learning_rate, optimizer

def make_layers(cfg, batch_norm=False, in_channels=1):
    """Build a VGG-style convolutional stack from a layer specification.

    Args:
        cfg: Iterable where the string ``'M'`` adds a 2x2, stride-2 max-pool
            and any other entry is the output channel count of a 3x3,
            padding-1 convolution followed by ReLU.
        batch_norm: If True, insert ``BatchNorm2d`` between each convolution
            and its ReLU.
        in_channels: Channel count of the input images. Defaults to 1
            (grayscale), which is the value this function previously
            hard-coded.

    Returns:
        A ``torch.nn.Sequential`` containing the layers in ``cfg`` order.
    """
    layers = []
    for v in cfg:
        if v == 'M':
            layers.append(torch.nn.MaxPool2d(kernel_size=2, stride=2))
            continue
        layers.append(torch.nn.Conv2d(in_channels, v, kernel_size=3, padding=1))
        if batch_norm:
            layers.append(torch.nn.BatchNorm2d(v))
        layers.append(torch.nn.ReLU(inplace=True))
        # The next convolution consumes what this one produced.
        in_channels = v
    return torch.nn.Sequential(*layers)
    

# Width-reduced VGG-11 classifier wrapped as a LightningModule.
class LitVgg11Model(pl.LightningModule):
    """VGG-11 with every convolution width divided by 4, for 10-class output.

    Args:
        training_config: Object exposing ``learning_rate`` and ``optimizer``
            (the string "Adam" or "SGD"); read in ``configure_optimizers``.
    """

    def __init__(self, training_config):
        super().__init__()
        self.training_config = training_config
        # Standard VGG-11 convolutional layout; 'M' marks a max-pool.
        vgg_convlayers = [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M']
        # FashionMNIST is far simpler than ImageNet (VGG-11's original training
        # set), so it does not need as many features: quarter every width.
        vgg_convlayers = [v // 4 if isinstance(v, int) else v for v in vgg_convlayers]

        # Input size of the first Linear layer: the last conv width is
        # 512 // 4 = 128 channels, and the 7x7 spatial size corresponds to a
        # 224x224 input halved by the five max-pools (224 / 2**5 = 7).
        flat_features = 128 * 7 * 7

        self.model = torch.nn.Sequential(
            make_layers(vgg_convlayers),
            torch.nn.Flatten(),
            torch.nn.Linear(flat_features, 4096),
            torch.nn.ReLU(inplace=True),
            torch.nn.Dropout(),
            torch.nn.Linear(4096, 4096),
            torch.nn.ReLU(inplace=True),
            torch.nn.Dropout(),
            torch.nn.Linear(4096, 10)
        )

    def forward(self, x):
        """Return the class logits for a batch of images."""
        return self.model(x)

    @staticmethod
    def _accuracy(logits, y):
        """Fraction of samples whose arg-max logit equals the label.

        Returned as a tensor so logging does not force a device-to-host
        synchronisation through ``.item()`` on every step.
        """
        return (torch.argmax(logits, dim=1) == y).float().mean()

    def training_step(self, batch, batch_idx):
        """Compute cross-entropy on one batch; log loss and accuracy."""
        x, y = batch
        logits = self(x)
        loss = torch.nn.functional.cross_entropy(logits, y)
        self.log("train_loss", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)
        self.log("train_acc", self._accuracy(logits, y), prog_bar=True, logger=True, on_epoch=True)
        return loss

    def validation_step(self, batch, batch_idx):
        """Compute cross-entropy on one validation batch; log loss and accuracy."""
        x, y = batch
        logits = self(x)
        loss = torch.nn.functional.cross_entropy(logits, y)
        self.log("val_loss", loss, prog_bar=True, logger=True)
        self.log("val_acc", self._accuracy(logits, y), prog_bar=True, logger=True)
        return loss

    def configure_optimizers(self):
        """Build the optimizer named by ``training_config.optimizer``.

        Raises:
            ValueError: If the name is neither "Adam" nor "SGD". Previously
                this case fell through to an UnboundLocalError.
        """
        name = self.training_config.optimizer
        lr = self.training_config.learning_rate
        if name == "Adam":
            return torch.optim.Adam(self.parameters(), lr=lr)
        if name == "SGD":
            return torch.optim.SGD(self.parameters(), lr=lr)
        raise ValueError(f"Unsupported optimizer {name!r}; expected 'Adam' or 'SGD'.")
    
    
    

### 工作流程

In [7]:
%%time 
class TrainerConfiguration:
    """Settings for the Lightning Trainer used by the training script."""

    # Number of epochs; passed to pl.Trainer(max_epochs=...).
    max_epochs: int

    def __init__(self, max_epochs):
        # Stored as given; no range or type check is performed.
        self.max_epochs = max_epochs

if __name__ == '__main__':
    # Seed Python, NumPy and torch (and DataLoader workers) so weight
    # initialisation and shuffling are repeatable across runs.
    pl.seed_everything(42, workers=True)

    data_config = DataConfiguration(batch_size=128, num_workers=2, pin_memory=torch.cuda.is_available())
    training_config = TrainingConfiguration(learning_rate=0.05, optimizer="SGD")
    trainer_config = TrainerConfiguration(max_epochs=10)

    model = LitVgg11Model(training_config)
    data = LitLoadData_FashionMNist(data_config)

    # TensorBoard logger for the metrics emitted via self.log(...).
    tb_logger = pl.loggers.TensorBoardLogger('tensorBoard-logs/', name='classModel_vgg11_v1', default_hp_metric=False)

    # Checkpoint on the highest validation accuracy. save_top_k is left at
    # its default, so only the single best checkpoint is kept.
    checkpoint_callback = pl.callbacks.ModelCheckpoint(
        monitor='val_acc',
        dirpath='checkPoint-logs/classModel_vgg11_v1',
        filename='classModel_vgg11_v1_{epoch:02d}_{val_acc:.2f}',
        mode='max',
    )

    trainer = pl.Trainer(
        max_epochs=trainer_config.max_epochs,
        logger=tb_logger,
        callbacks=[checkpoint_callback],
        # 'auto' uses the GPU when one is present and falls back to CPU
        # otherwise, matching the pin_memory setting derived from
        # torch.cuda.is_available() above (hard-coded 'gpu' fails on
        # CPU-only machines).
        accelerator='auto',
        enable_model_summary=True,
    )

    trainer.fit(model, data)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name  | Type       | Params | Mode 
---------------------------------------------
0 | model | Sequential | 43.1 M | train
---------------------------------------------
43.1 M    Trainable params
0         Non-trainable params
43.1 M    Total params
172.373   Total estimated model params size (MB)
31        Modules in train mode
0         Modules in eval mode


Epoch 9: 100%|██████████| 469/469 [01:59<00:00,  3.91it/s, v_num=0, train_loss_step=0.267, train_acc_step=0.917, val_loss=0.265, val_acc=0.900, train_loss_epoch=0.241, train_acc_epoch=0.910]  

`Trainer.fit` stopped: `max_epochs=10` reached.


Epoch 9: 100%|██████████| 469/469 [02:14<00:00,  3.49it/s, v_num=0, train_loss_step=0.267, train_acc_step=0.917, val_loss=0.265, val_acc=0.900, train_loss_epoch=0.241, train_acc_epoch=0.910]
CPU times: total: 21min 34s
Wall time: 21min 24s
