In [1]:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, random_split
from torchvision import transforms, models
from torchvision.datasets import ImageFolder
import lightning as L
from lightning.pytorch.callbacks import ModelCheckpoint, EarlyStopping
from lightning.pytorch.loggers import TensorBoardLogger

In [2]:
class CatDogDataModule(L.LightningDataModule):
    def __init__(self,data_dir=r"D:\zyy\data\train", batch_size=32):
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size
        self.transform = transforms.Compose([
            transforms.RandomHorizontalFlip(),
            transforms.RandomRotation(30),
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(), 
            transforms.Normalize([0.485, 0.456, 0.406], 
                                 [0.229, 0.224, 0.225])])
    
    def setup(self,stage=None):
        full_dataset = ImageFolder(root=os.path.join(self.data_dir,"train"),
                                                     transform = self.transform)
        train_size = int(0.8 * len(full_dataset))
        val_size = len(full_dataset) - train_size
        self.train_dataset, self.val_dataset = random_split(full_dataset, [train_size, val_size])

    def train_dataloader(self):
         return DataLoader(self.train_dataset,
                           batch_size=self.batch_size, 
                           shuffle=True,num_workers=2)

    def val_dataloader(self):
        return DataLoader(self.val_dataset, 
                          batch_size=self.batch_size,
                          num_workers=2)

class CatDog(L.LightningModule):
    def __init__(self, lr=1e-3):
        super().__init__()
        self.save_hyperparameters()
        
        self.conv_layers = nn.Sequential(
            nn.Conv2d(3, 16, kernel_size=3, padding=1), 
            nn.ReLU(),
            nn.MaxPool2d(2,2),
            
            nn.Conv2d(16, 32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2,2), 
        )

        dummy_input = torch.zeros(1, 3, 224, 224)
        flatten_dim = self.conv_layers(dummy_input).view(1, -1).size(1)

        self.fc_layers = nn.Sequential(
            nn.Linear(flatten_dim, 128),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(128, 2)
        )

    def forward(self, x):
        x = self.conv_layers(x)
        x = x.view(x.size(0), -1)
        return self.fc_layers(x)
       
    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.cross_entropy(logits, y)
        acc = (logits.argmax(dim=1) == y).float().mean()
        self.log("train_loss", loss)
        self.log("train_acc", acc, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.cross_entropy(logits, y)
        acc = (logits.argmax(dim=1) == y).float().mean()
        self.log("val_loss", loss)
        self.log("val_acc", acc, prog_bar=True,on_epoch=True)
    
    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-4,weight_decay=1e-4)
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', patience=3, factor=0.5)
        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": scheduler,
                "monitor": "val_loss"
            }
        }
        


In [None]:
from lightning.pytorch.callbacks import ModelCheckpoint, EarlyStopping
from torchvision.datasets import ImageFolder
from lightning.pytorch.loggers import TensorBoardLogger

if __name__ == "__main__":
    data_module = CatDogDataModule(data_dir=r"D:\zyy\data", batch_size=32)
    model = CatDog(lr=0.001)

    logger = TensorBoardLogger("lightning_logs", name="catdog")
    checkpoint = ModelCheckpoint(monitor="val_loss", mode="min", save_top_k=1)
    early_stop = EarlyStopping(monitor="val_acc", patience=3, mode="max")
    
    trainer = L.Trainer(
        max_epochs=3,
        accelerator="auto",
        logger=logger,
        callbacks=[checkpoint, early_stop]
    )
    
    trainer.fit(model, datamodule=data_module)
    
    best_model_path = checkpoint.best_model_path
    print(f"Best checkpoint path: {best_model_path}")
    best_model = CatDog .load_from_checkpoint(best_model_path)
    trainer.validate(best_model, datamodule=data_module)

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs

  | Name        | Type       | Params | Mode 
---------------------------------------------------
0 | conv_layers | Sequential | 5.1 K  | train
1 | fc_layers   | Sequential | 12.8 M | train
---------------------------------------------------
12.9 M    Trainable params
0         Non-trainable params
12.9 M    Total params
51.402    Total estimated model params size (MB)
12        Modules in train mode
0         Modules in eval mode


Sanity Checking: |                                                                               | 0/? [00:00<?, ?it/s]

D:\mini\envs\myenv\lib\site-packages\lightning\pytorch\trainer\connectors\data_connector.py:420: Consider setting `persistent_workers=True` in 'val_dataloader' to speed up the dataloader worker initialization.


                                                                                                                       

D:\mini\envs\myenv\lib\site-packages\lightning\pytorch\trainer\connectors\data_connector.py:420: Consider setting `persistent_workers=True` in 'train_dataloader' to speed up the dataloader worker initialization.


Epoch 0: 100%|████████████████████████████████████████████| 625/625 [06:42<00:00,  1.55it/s, v_num=19, train_acc=1.000]
Validation: |                                                                                    | 0/? [00:00<?, ?it/s][A
Validation:   0%|                                                                              | 0/157 [00:00<?, ?it/s][A
Validation DataLoader 0:   0%|                                                                 | 0/157 [00:00<?, ?it/s][A
Validation DataLoader 0:   1%|▎                                                        | 1/157 [00:00<00:43,  3.61it/s][A
Validation DataLoader 0:   1%|▋                                                        | 2/157 [00:00<00:44,  3.51it/s][A
Validation DataLoader 0:   2%|█                                                        | 3/157 [00:00<00:44,  3.47it/s][A
Validation DataLoader 0:   3%|█▍                                                       | 4/157 [00:01<00:44,  3.45it/s][A
Validation DataLoad

`Trainer.fit` stopped: `max_epochs=3` reached.


Epoch 2: 100%|█████████████████████████████| 625/625 [09:04<00:00,  1.15it/s, v_num=19, train_acc=1.000, val_acc=1.000]
Best checkpoint path: lightning_logs\catdog\version_19\checkpoints\epoch=0-step=625.ckpt


In [None]:
!tensorboard --logdir lightning_logs