In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as data
import torchvision

In [2]:
class Residual(nn.Module):
    """Basic two-conv residual unit: conv3x3-BN-ReLU-conv3x3-BN, add shortcut, ReLU.

    Args:
        in_dim: Number of input channels.
        out_dim: Number of output channels.
        use_1d: If True, the shortcut is a 1x1 convolution (in_dim -> out_dim,
            same stride as conv1) so it can match the main path when the
            channel count or stride changes. If False, the shortcut is identity.
        strides: Stride of the first 3x3 conv and of the 1x1 shortcut conv.
    """

    def __init__(self, in_dim, out_dim, use_1d=False, strides=1):
        super().__init__()
        # Modules are created in the original order (conv1, conv2, conv3, bn1, bn2)
        # so parameter initialisation order and state_dict layout are unchanged.
        self.conv1 = nn.Conv2d(in_dim, out_dim, kernel_size=3, padding=1, stride=strides)
        self.conv2 = nn.Conv2d(out_dim, out_dim, kernel_size=3, padding=1)
        # Optional 1x1 projection shortcut; None means an identity shortcut.
        self.conv3 = (
            nn.Conv2d(in_dim, out_dim, kernel_size=1, stride=strides) if use_1d else None
        )
        self.bn1 = nn.BatchNorm2d(out_dim)
        self.bn2 = nn.BatchNorm2d(out_dim)

    def forward(self, X):
        Y = F.relu(self.bn1(self.conv1(X)))
        Y = self.bn2(self.conv2(Y))
        # Explicit None check instead of relying on nn.Module truthiness.
        shortcut = X if self.conv3 is None else self.conv3(X)
        # Out-of-place add: avoids mutating a tensor that is part of the autograd graph.
        return F.relu(Y + shortcut)

def resnet_block(in_dim, out_dim, blocks, first_block=False):
    """Return a list of `blocks` Residual units that form one ResNet stage.

    When `first_block` is False, unit 0 maps in_dim -> out_dim with stride 2
    and a 1x1 projection shortcut; every other unit (and every unit when
    `first_block` is True, in which case in_dim is not used) is an
    out_dim -> out_dim unit with an identity shortcut.
    """
    downsample_first = not first_block
    return [
        Residual(in_dim, out_dim, use_1d=True, strides=2)
        if idx == 0 and downsample_first
        else Residual(out_dim, out_dim)
        for idx in range(blocks)
    ]

class ResNet(nn.Module):
    """ResNet-18-like classifier: 7x7 stem, four stages of two Residual units, global pool, linear head.

    Args:
        in_channels: Number of input image channels (default 1, e.g. grayscale FashionMNIST).
        num_classes: Number of output logits (default 10).
    """

    def __init__(self, in_channels=1, num_classes=10):
        super().__init__()
        # Shape comments assume an (in_channels, 224, 224) input.
        self.b1 = nn.Sequential(
            nn.Conv2d(in_channels, 64, kernel_size=7, stride=2, padding=3),  # -> (64, 112, 112)
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1),  # -> (64, 56, 56)
        )
        self.b2 = nn.Sequential(*resnet_block(64, 64, 2, first_block=True))  # -> (64, 56, 56)
        self.b3 = nn.Sequential(*resnet_block(64, 128, 2))  # -> (128, 28, 28)
        self.b4 = nn.Sequential(*resnet_block(128, 256, 2))  # -> (256, 14, 14)
        self.b5 = nn.Sequential(*resnet_block(256, 512, 2))  # -> (512, 7, 7)

        # b1..b5 are registered both as attributes and inside `net`; keep both so
        # state_dict keys of existing checkpoints continue to match.
        self.net = nn.Sequential(
            self.b1, self.b2, self.b3, self.b4, self.b5,
            nn.AdaptiveAvgPool2d((1, 1)),
            nn.Flatten(),
            nn.Linear(512, num_classes),
        )

    def forward(self, X):
        return self.net(X)

In [3]:
import torchvision
from torchvision import transforms
from torch.utils import data
import pytorch_lightning as pl

# Data configuration: tune here rather than inside the calls below.
SEED = 42
BATCH_SIZE = 648
NUM_WORKERS = 12
TRAIN_FRACTION = 0.9  # remainder of the official train split becomes validation

# Upsample 28x28 FashionMNIST to 224 to match the ResNet stem's expected input size.
trans = transforms.Compose([transforms.Resize(224), transforms.ToTensor()])

# download=False: the dataset must already exist under ./dataset.
full_train_data = torchvision.datasets.FashionMNIST(
    "./dataset", train=True, transform=trans, download=False
)
test_data = torchvision.datasets.FashionMNIST(
    "./dataset", train=False, transform=trans, download=False
)

# Seed immediately before the split so the train/valid partition is reproducible.
pl.seed_everything(SEED)

train_size = int(TRAIN_FRACTION * len(full_train_data))
valid_size = len(full_train_data) - train_size
train_data, valid_data = data.random_split(full_train_data, [train_size, valid_size])


def make_loader(dataset, shuffle):
    """Build a DataLoader using the notebook-wide batch size and worker count."""
    return data.DataLoader(
        dataset, batch_size=BATCH_SIZE, shuffle=shuffle, num_workers=NUM_WORKERS
    )


train_loader = make_loader(train_data, shuffle=True)
valid_loader = make_loader(valid_data, shuffle=False)
test_loader = make_loader(test_data, shuffle=False)

Seed set to 42


In [4]:
class pl_resnet(pl.LightningModule):
    """LightningModule that trains and evaluates `ResNet` with cross-entropy loss.

    Args:
        model_name: Saved to hparams; not used to choose the architecture.
        model_params: Saved to hparams; not yet forwarded to `ResNet`.
        optimizer_name: "Adam" selects `torch.optim.AdamW`; any other value selects SGD.
        optimizer_params: Keyword arguments passed to the optimizer constructor.
    """

    def __init__(self, model_name, model_params, optimizer_name, optimizer_params):
        super().__init__()
        self.save_hyperparameters()
        self.model = ResNet()
        # TODO: pass `model_params` into ResNet once it accepts parameters.
        self.criterion = nn.CrossEntropyLoss()
        # Per-batch arrays of predicted class ids, filled by test_step.
        self.predictions = []

    def forward(self, X):
        return self.model(X)

    def configure_optimizers(self):
        """Build the optimizer plus a ReduceLROnPlateau scheduler driven by valid_acc."""
        if self.hparams.optimizer_name == "Adam":
            # NOTE: the "Adam" name maps to AdamW (decoupled weight decay).
            optimizer = torch.optim.AdamW(self.parameters(), **self.hparams.optimizer_params)
        else:
            optimizer = torch.optim.SGD(self.parameters(), **self.hparams.optimizer_params)

        scheduler = {
            "scheduler": torch.optim.lr_scheduler.ReduceLROnPlateau(
                optimizer=optimizer,
                # valid_acc is higher-is-better. mode="min" would treat every
                # accuracy gain as "no improvement" and cut the LR anyway.
                mode="max",
                factor=0.1,
                patience=5,
            ),
            "monitor": "valid_acc",
            "interval": "epoch",
            "frequency": 1,
        }
        return [optimizer], [scheduler]

    def _shared_step(self, batch):
        """Return (loss, accuracy) for one labelled batch."""
        X, y = batch
        y_hat = self.model(X)
        loss = self.criterion(y_hat, y)
        acc = (y_hat.argmax(dim=-1) == y).float().mean()
        return loss, acc

    def training_step(self, batch, batch_idx):
        loss, acc = self._shared_step(batch)
        self.log("train_acc", acc, on_step=False, on_epoch=True, prog_bar=True, logger=True)
        self.log("train_loss", loss, prog_bar=True, logger=True)
        return loss

    def validation_step(self, batch, batch_idx):
        loss, acc = self._shared_step(batch)
        self.log("valid_acc", acc, on_step=False, on_epoch=True, prog_bar=True, logger=True)
        self.log("valid_loss", loss, logger=True)
        return {"val_loss": loss, "val_acc": acc}

    def on_test_start(self):
        # Reset so repeated trainer.test() calls don't append to stale predictions.
        self.predictions = []

    def test_step(self, batch, batch_idx):
        # FashionMNIST's test split is labelled, but the labels are ignored here to
        # mimic an unlabelled test set (normally the batch would be just X).
        X, _ = batch
        y_hat = self.model(X).argmax(dim=-1).cpu().numpy()
        self.predictions.append(y_hat)
        return {"prediction": y_hat}

    def on_test_end(self):
        """Write all predictions collected during testing to ./result.csv."""
        import pandas as pd

        predictions = [item for batch_preds in self.predictions for item in batch_preds]
        df = pd.DataFrame({"prediction": predictions})
        df.to_csv("./result.csv")

In [5]:
from pytorch_lightning.loggers import TensorBoardLogger
# TensorBoard logger writing under logs/lightning_resnet/.
# NOTE(review): the `train(...)` call below does not pass this object to a Trainer.
logger = TensorBoardLogger(save_dir="logs", name="lightning_resnet")

In [6]:
import os
from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor
def train(model_dir, loader_dict, model_config, logger=None, max_epochs=10):
    """Fit a `pl_resnet`, reload its best checkpoint, then validate and test it.

    Args:
        model_dir: Directory where ModelCheckpoint writes checkpoint files.
        loader_dict: Mapping with "train_loader", "valid_loader" and "test_loader".
        model_config: Keyword arguments for `pl_resnet(...)`.
        logger: Optional Lightning logger. When None, a TensorBoardLogger is created
            under logs/lightning_resnet with graph logging on and the placeholder
            hp_metric off (configured via constructor args, not private attributes).
        max_epochs: Number of training epochs (default 10).

    Returns:
        (model, result): the model reloaded from the best checkpoint and
        {"valid": <valid_acc of that model on the validation loader>}.

    Raises:
        RuntimeError: if training finished without saving any checkpoint.
    """
    from pytorch_lightning.loggers import TensorBoardLogger

    checkpoint_callback = ModelCheckpoint(
        dirpath=model_dir,
        monitor="valid_acc",
        mode="max",
        verbose=True,
        save_weights_only=True,
        # The logged key is "valid_loss"; the former "{val_loss}" did not exist
        # and was rendered as 0.0000 in every file name.
        filename="resnet_epoch-{epoch:02d}-acc-{valid_acc:.4f}-loss-{valid_loss:.4f}",
        # Avoid duplicated prefixes such as "resnet_epoch-epoch=00-acc-valid_acc=...".
        auto_insert_metric_name=False,
    )
    lr_callback = LearningRateMonitor("epoch")

    if logger is None:
        logger = TensorBoardLogger(
            "logs", name="lightning_resnet", log_graph=True, default_hp_metric=False
        )

    trainer = pl.Trainer(
        max_epochs=max_epochs,
        accelerator="auto",
        devices=1,
        logger=logger,
        callbacks=[checkpoint_callback, lr_callback],
    )

    model = pl_resnet(**model_config)
    trainer.fit(model, loader_dict["train_loader"], loader_dict["valid_loader"])

    # Training finished: reload the weights from the best epoch by valid_acc.
    best_path = checkpoint_callback.best_model_path
    if not best_path:
        raise RuntimeError(
            "No checkpoint was saved, so the best model cannot be reloaded "
            "(was 'valid_acc' logged during validation?)"
        )
    model = pl_resnet.load_from_checkpoint(best_path)

    valid_result = trainer.validate(model, loader_dict["valid_loader"])
    # pl_resnet.on_test_end writes the test predictions to ./result.csv.
    trainer.test(model, loader_dict["test_loader"])
    result = {"valid": valid_result[0]["valid_acc"]}

    return model, result


In [7]:
model_dir = "./models/resnet"

# Loaders keyed by the names train() looks up.
loader_dict = dict(
    train_loader=train_loader,
    valid_loader=valid_loader,
    test_loader=test_loader,
)

# NOTE(review): optimizer_name "Adam" is mapped to torch.optim.AdamW inside
# pl_resnet.configure_optimizers. With lr=1e-2 the logged valid_acc swings
# widely between epochs (e.g. 0.765 -> 0.475); a smaller LR may be worth trying.
model_config = dict(
    model_name="resnet",
    model_params=None,
    optimizer_name="Adam",
    optimizer_params=dict(lr=1e-2, weight_decay=1e-4),
)

model, result = train(model_dir, loader_dict, model_config)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores


HPU available: False, using: 0 HPUs
You are using a CUDA device ('NVIDIA GeForce RTX 3090') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
/home/zhoujiefeng/miniconda3/envs/dive/lib/python3.12/site-packages/pytorch_lightning/callbacks/model_checkpoint.py:654: Checkpoint directory /home/zhoujiefeng/wjy/my_lightning/models/resnet exists and is not empty.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name      | Type             | Params | Mode 
-------------------------------------------------------
0 | model     | ResNet           | 11.2 M | train
1 | criterion | CrossEntropyLoss | 0      | train
-------------------------------------------------------
11.2 M    Trainable params
0         Non-trainable params
11.2 M    Total param

Epoch 0: 100%|██████████| 84/84 [00:41<00:00,  2.05it/s, v_num=1, train_loss=0.789, valid_acc=0.665, train_acc=0.552]

Epoch 0, global step 84: 'valid_acc' reached 0.66500 (best 0.66500), saving model to '/home/zhoujiefeng/wjy/my_lightning/models/resnet/resnet_epoch-epoch=00-acc-valid_acc=0.6650-loss-val_loss=0.0000.ckpt' as top 1


Epoch 1: 100%|██████████| 84/84 [00:41<00:00,  2.04it/s, v_num=1, train_loss=0.419, valid_acc=0.699, train_acc=0.821]

Epoch 1, global step 168: 'valid_acc' reached 0.69900 (best 0.69900), saving model to '/home/zhoujiefeng/wjy/my_lightning/models/resnet/resnet_epoch-epoch=01-acc-valid_acc=0.6990-loss-val_loss=0.0000.ckpt' as top 1


Epoch 2: 100%|██████████| 84/84 [00:41<00:00,  2.03it/s, v_num=1, train_loss=0.341, valid_acc=0.703, train_acc=0.857]

Epoch 2, global step 252: 'valid_acc' reached 0.70350 (best 0.70350), saving model to '/home/zhoujiefeng/wjy/my_lightning/models/resnet/resnet_epoch-epoch=02-acc-valid_acc=0.7035-loss-val_loss=0.0000.ckpt' as top 1


Epoch 3: 100%|██████████| 84/84 [00:41<00:00,  2.02it/s, v_num=1, train_loss=0.304, valid_acc=0.765, train_acc=0.880]

Epoch 3, global step 336: 'valid_acc' reached 0.76483 (best 0.76483), saving model to '/home/zhoujiefeng/wjy/my_lightning/models/resnet/resnet_epoch-epoch=03-acc-valid_acc=0.7648-loss-val_loss=0.0000.ckpt' as top 1


Epoch 4: 100%|██████████| 84/84 [00:41<00:00,  2.00it/s, v_num=1, train_loss=0.267, valid_acc=0.475, train_acc=0.893]

Epoch 4, global step 420: 'valid_acc' was not in top 1


Epoch 5: 100%|██████████| 84/84 [00:42<00:00,  2.00it/s, v_num=1, train_loss=0.275, valid_acc=0.743, train_acc=0.904]

Epoch 5, global step 504: 'valid_acc' was not in top 1


Epoch 6: 100%|██████████| 84/84 [00:41<00:00,  2.01it/s, v_num=1, train_loss=0.265, valid_acc=0.812, train_acc=0.910]

Epoch 6, global step 588: 'valid_acc' reached 0.81250 (best 0.81250), saving model to '/home/zhoujiefeng/wjy/my_lightning/models/resnet/resnet_epoch-epoch=06-acc-valid_acc=0.8125-loss-val_loss=0.0000.ckpt' as top 1


Epoch 7: 100%|██████████| 84/84 [00:41<00:00,  2.03it/s, v_num=1, train_loss=0.189, valid_acc=0.591, train_acc=0.919]

Epoch 7, global step 672: 'valid_acc' was not in top 1


Epoch 8: 100%|██████████| 84/84 [00:41<00:00,  2.02it/s, v_num=1, train_loss=0.222, valid_acc=0.819, train_acc=0.922]

Epoch 8, global step 756: 'valid_acc' reached 0.81867 (best 0.81867), saving model to '/home/zhoujiefeng/wjy/my_lightning/models/resnet/resnet_epoch-epoch=08-acc-valid_acc=0.8187-loss-val_loss=0.0000.ckpt' as top 1


Epoch 9: 100%|██████████| 84/84 [00:41<00:00,  2.04it/s, v_num=1, train_loss=0.221, valid_acc=0.878, train_acc=0.927]

Epoch 9, global step 840: 'valid_acc' reached 0.87817 (best 0.87817), saving model to '/home/zhoujiefeng/wjy/my_lightning/models/resnet/resnet_epoch-epoch=09-acc-valid_acc=0.8782-loss-val_loss=0.0000.ckpt' as top 1
`Trainer.fit` stopped: `max_epochs=10` reached.


Epoch 9: 100%|██████████| 84/84 [00:41<00:00,  2.03it/s, v_num=1, train_loss=0.221, valid_acc=0.878, train_acc=0.927]


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Validation DataLoader 0: 100%|██████████| 10/10 [00:01<00:00,  6.88it/s]
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
     Validate metric           DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
        valid_acc            0.878166675567627
       valid_loss           0.33242717385292053
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Testing DataLoader 0: 100%|██████████| 16/16 [00:02<00:00,  5.92it/s]
