In [36]:
import torch

# Pick the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print("当前设备：", device)  # "Current device:"


当前设备： cuda


In [37]:
import os
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, random_split
from torchvision import transforms, models
from torchvision.datasets import ImageFolder
import lightning as L
from lightning.pytorch.callbacks import ModelCheckpoint, EarlyStopping
from lightning.pytorch.loggers import TensorBoardLogger

In [38]:
# Mount Google Drive (Colab only) so the dataset archive stored on Drive is
# readable under /content/drive in the following cells.
from google.colab import drive
drive.mount('/content/drive')


Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [39]:
# Confirm the dataset archive (dogs-vs-cats.zip) is present in the Drive root.
!ls /content/drive/MyDrive/


'Colab Notebooks'   dogs-vs-cats.zip


In [40]:
import os
import zipfile

zip_path = "/content/drive/MyDrive/dogs-vs-cats.zip"
extract_path = "/content/data"


def _already_extracted(archive, dest):
    """Return True when every member of ``archive`` already exists under ``dest``.

    Regular files must also match their uncompressed size, so an extraction that was
    interrupted part-way is detected and redone instead of silently reused.

    Args:
        archive: An open ``zipfile.ZipFile``.
        dest: Directory the archive is (or would be) extracted into.
    """
    for info in archive.infolist():
        target = os.path.join(dest, info.filename)
        if info.is_dir():
            if not os.path.isdir(target):
                return False
        elif not (os.path.isfile(target) and os.path.getsize(target) == info.file_size):
            return False
    return True


# Unpack the archive from Drive onto local disk. Re-running the cell
# (e.g. Restart & Run All) skips the copy when everything is already in place.
with zipfile.ZipFile(zip_path, "r") as zip_ref:
    if not _already_extracted(zip_ref, extract_path):
        zip_ref.extractall(extract_path)



In [41]:
# Inspect what the outer archive unpacked into /content/data.
!ls /content/data


sample_submission.csv  test  test.zip  train  train.zip


In [42]:
import os
import shutil
import zipfile

train_zip = "/content/data/train.zip"
test_zip = "/content/data/test.zip"

train_extract_path = "/content/data/train"
test_extract_path = "/content/data/test"

# Lower-case file-name prefixes that carry the label ("cat.123.jpg" -> "cat").
CLASS_NAMES = ("cat", "dog")


def arrange_by_class(root, class_names=CLASS_NAMES):
    """Move labelled images under ``root`` into ``root/<class>/`` folders for ImageFolder.

    ImageFolder (used by CatDogDataModule) takes labels from sub-directory names, so a
    flat folder of ``cat.*`` / ``dog.*`` files, or a nested folder such as an
    archive's own top-level directory, does not yield the two expected classes.
    Every file anywhere under ``root`` named ``<class>.<rest>`` is moved into
    ``root/<class>/``. Files already in their class folder are left alone, so
    re-running is safe. Directories left empty afterwards are deleted so ImageFolder
    does not treat them as extra (empty) classes.

    Args:
        root: Directory the training archive was extracted into.
        class_names: Recognised lower-case file-name prefixes.

    Returns:
        dict mapping each class name to the number of entries in ``root/<class>/``.
    """
    root = os.path.normpath(root)
    class_dirs = {name: os.path.join(root, name) for name in class_names}
    for class_dir in class_dirs.values():
        os.makedirs(class_dir, exist_ok=True)
    keep_dirs = {root, *class_dirs.values()}

    for dirpath, _, filenames in os.walk(root):
        for filename in filenames:
            prefix, dot, _ = filename.partition(".")
            target_dir = class_dirs.get(prefix.lower()) if dot else None
            if target_dir is None or dirpath == target_dir:
                continue
            shutil.move(os.path.join(dirpath, filename), os.path.join(target_dir, filename))

    # Bottom-up, so nested empty folders are removed before their parents are checked.
    for dirpath, _, _ in os.walk(root, topdown=False):
        if dirpath not in keep_dirs and not os.listdir(dirpath):
            os.rmdir(dirpath)

    return {name: len(os.listdir(path)) for name, path in class_dirs.items()}


with zipfile.ZipFile(train_zip, "r") as zip_ref:
    zip_ref.extractall(train_extract_path)

with zipfile.ZipFile(test_zip, "r") as zip_ref:
    zip_ref.extractall(test_extract_path)

class_counts = arrange_by_class(train_extract_path)
# Fail loudly here instead of silently training on a single (or empty) class.
assert all(class_counts.values()), f"Expected images for every class, got {class_counts}"
class_counts



In [None]:
# Inspect the extracted layout. ImageFolder (used in the data module below) builds
# its classes from the sub-directories of /content/data/train.
!ls /content/data/train
!ls /content/data/test


In [None]:
class CatDogDataModule(L.LightningDataModule):
    """Cat/dog images from ``<data_dir>/train``, split into train and validation sets.

    Labels come from ``ImageFolder``, i.e. one sub-directory per class under
    ``<data_dir>/train``.

    Args:
        data_dir: Directory containing the ``train`` image folder.
        batch_size: Batch size of both dataloaders.
        train_fraction: Share of images used for training; the rest is validation.
        seed: Seed of the train/val split. Lightning calls ``setup`` again for
            ``trainer.validate``, so an unseeded split would produce a different
            validation set that overlaps the data the model was trained on.
        num_workers: DataLoader worker processes for both loaders (previously
            hard-coded to 0 for training and 2 for validation).
    """

    def __init__(self, data_dir, batch_size=64, train_fraction=0.8, seed=42, num_workers=2):
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size
        self.train_fraction = train_fraction
        self.seed = seed
        self.num_workers = num_workers

        normalize = transforms.Normalize([0.485, 0.456, 0.406],
                                         [0.229, 0.224, 0.225])
        # Training pipeline: resize/crop to 224x224 plus random-rotation augmentation.
        self.transform = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.RandomRotation(30),
            transforms.ToTensor(),
            normalize,
        ])
        # Validation must be deterministic: same resize/crop, no augmentation.
        self.val_transform = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            normalize,
        ])
        self.train_dataset = None
        self.val_dataset = None

    def setup(self, stage=None):
        """Build the train/val subsets with a reproducible, non-overlapping split."""
        root = os.path.join(self.data_dir, "train")
        # Two views of the same files so each split gets its own transform.
        train_view = ImageFolder(root=root, transform=self.transform)
        val_view = ImageFolder(root=root, transform=self.val_transform)

        n_total = len(train_view)
        train_size = int(self.train_fraction * n_total)
        generator = torch.Generator().manual_seed(self.seed)
        order = torch.randperm(n_total, generator=generator).tolist()

        self.train_dataset = torch.utils.data.Subset(train_view, order[:train_size])
        self.val_dataset = torch.utils.data.Subset(val_view, order[train_size:])

    def train_dataloader(self):
        return DataLoader(self.train_dataset,
                          batch_size=self.batch_size,
                          shuffle=True,
                          num_workers=self.num_workers)

    def val_dataloader(self):
        return DataLoader(self.val_dataset,
                          batch_size=self.batch_size,
                          num_workers=self.num_workers)

class CatDog(L.LightningModule):
    """Small three-block CNN for binary cat/dog classification.

    ``forward`` returns the probability of label 1, with shape ``(batch, 1)``.
    Training and validation compute the loss on raw logits with ``BCEWithLogitsLoss``,
    which is the numerically stable equivalent of ``Sigmoid`` followed by ``BCELoss``.

    Args:
        lr: Adam learning rate. It is now actually used by ``configure_optimizers``;
            previously a hard-coded 1e-4 silently overrode it.
        input_shape: ``(C, H, W)`` of the input images. Sizes the first conv layer's
            input channels and the first fully-connected layer.
    """

    def __init__(self, lr=1e-3, input_shape=(3, 64, 64)):
        super().__init__()
        self.save_hyperparameters()

        C, H, W = input_shape

        self.conv = nn.Sequential(
            nn.Conv2d(C, 16, 3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),

            nn.Conv2d(16, 32, 3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),

            nn.Conv2d(32, 64, 3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
        )

        # Run a dummy input through the conv stack to size the first Linear layer.
        with torch.no_grad():
            n_size = self.conv(torch.zeros(1, C, H, W)).flatten(1).size(1)

        # Produces a raw logit. The sigmoid is applied only in ``forward``. Parameter
        # indices (0 and 3) are unchanged, so older checkpoints still load.
        self.fc = nn.Sequential(
            nn.Linear(n_size, 64),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(64, 1),
        )

        self.loss_fn = nn.BCEWithLogitsLoss()

    def _logits(self, x):
        """Raw (pre-sigmoid) scores of shape ``(batch, 1)``."""
        return self.fc(self.conv(x).flatten(1))

    def forward(self, x):
        return torch.sigmoid(self._logits(x))

    def _shared_step(self, batch, prefix):
        """Compute and log ``<prefix>_loss`` / ``<prefix>_acc`` for one batch; return the loss."""
        images, labels = batch
        logits = self._logits(images)
        targets = labels.to(logits.dtype).unsqueeze(1)
        loss = self.loss_fn(logits, targets)
        # logit > 0 is equivalent to sigmoid(logit) > 0.5.
        acc = ((logits > 0) == targets.bool()).float().mean()
        self.log(f"{prefix}_loss", loss, prog_bar=True)
        self.log(f"{prefix}_acc", acc, prog_bar=True)
        return loss

    def training_step(self, batch, batch_idx):
        return self._shared_step(batch, "train")

    def validation_step(self, batch, batch_idx):
        self._shared_step(batch, "val")

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.hparams.lr, weight_decay=1e-4)
        # Halve the LR after 3 epochs without val_loss improvement.
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', patience=3, factor=0.5)
        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": scheduler,
                "monitor": "val_loss"
            }
        }

In [None]:
from lightning.pytorch.callbacks import ModelCheckpoint, EarlyStopping
from torchvision.datasets import ImageFolder
from lightning.pytorch.loggers import TensorBoardLogger

if __name__ == "__main__":
    # Seed Python/NumPy/torch (and dataloader workers) so runs are repeatable.
    L.seed_everything(42, workers=True)

    data_module = CatDogDataModule(data_dir="/content/data", batch_size=64)
    # input_shape must match the 224x224 crops produced by the data module's transforms.
    model = CatDog(lr=0.001, input_shape=(3, 224, 224))

    logger = TensorBoardLogger("lightning_logs", name="catdog")
    checkpoint = ModelCheckpoint(monitor="val_loss", mode="min", save_top_k=1)
    # NOTE(review): checkpointing and the LR scheduler track val_loss, while early
    # stopping tracks val_acc. Confirm this mix is intended.
    early_stop = EarlyStopping(monitor="val_acc", patience=5, mode="max")

    trainer = L.Trainer(
        max_epochs=10,
        # "auto" uses the GPU when present instead of crashing on CPU-only runtimes.
        accelerator="auto",
        devices=1,
        precision=32,
        logger=logger,
        callbacks=[checkpoint, early_stop],
    )

    trainer.fit(model, datamodule=data_module)

    best_model_path = checkpoint.best_model_path
    print(f"Best checkpoint path: {best_model_path}")
    # best_model_path is empty when no checkpoint was written (e.g. fit was interrupted).
    if best_model_path:
        best_model = CatDog.load_from_checkpoint(best_model_path)
        trainer.validate(best_model, datamodule=data_module)
    else:
        print("No checkpoint was saved; skipping validation of the best model.")

In [None]:
# Show TensorBoard inline. The TensorBoardLogger used for training writes its runs
# under lightning_logs/catdog.
%load_ext tensorboard
%tensorboard --logdir lightning_logs