In [1]:
import os
import torch
import pytorch_lightning as pl
import pandas as pd
import torch.nn as nn
from torchvision import transforms
from torch.utils.data import DataLoader,Dataset
from PIL import Image
from datasets import load_dataset
import torch.nn.functional as F
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.callbacks import ModelCheckpoint



In [2]:
class CustomDataset(Dataset):
    def __init__(self, data, transform=None):
        self.data = data
        self.transform = transform

    def __len__(self):
        return len(self.data)

    def __getitem__(self,idx):
        item = self.data[idx]
        image = item['image']
        label = float(item['label'])#标签是连续值

        if self.transform:
            image = self.transform(image)
        
        return image, torch.tensor(label, dtype=torch.float32)

In [3]:

class ImageRegression(pl.LightningModule):
    def __init__(self):
        super(ImageRegression,self).__init__()
        self.model = torch.hub.load('pytorch/vision:v0.10.0','resnet50',pretrained=True)
        num_ftrs = self.model.fc.in_features
        self.model.fc = nn.Linear(num_ftrs,1)

    def forward(self, x):
        x = x.float()  # 转换数据类型为float32
        return self.model(x).squeeze(-1)

    def training_step(self, batch, batch_idx):
        images, labels = batch
        outputs = self(images)
        loss = F.mse_loss(outputs, labels)
        self.log('train_loss', loss)
        return loss

    def validation_step(self, batch, batch_idx):
        images, labels = batch
        outputs = self(images)
        rounded_outputs = torch.round(outputs)
        loss = F.mse_loss(rounded_outputs, labels)
        self.log('val_loss', loss, on_step=True, on_epoch=True, prog_bar=True)
        correct = (rounded_outputs == labels).sum().item()
        total = len(labels)
        accuracy = correct / total
        self.log('val_acc', accuracy, on_epoch=True, prog_bar=True)
        return {"val_loss": loss, "val_acc": accuracy}

#    def on_validation_epoch_end(self):
#        avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean()
#        avg_acc = torch.stack([x['val_acc']for x in outputs]).mean()
        
    def test_step(self,batch,batch_idx):
        images, labels = batch
        outputs = self(images)
        rounded_outputs = torch.round(outputs)
        loss = F.mse_loss(rounded_outputs, labels)
        self.log('test_loss', loss, on_step=True, on_epoch=True, prog_bar=True)
        correct = (rounded_outputs == labels).sum().item()
        total = len(labels)
        accuracy = correct / total
        self.log('test_acc', accuracy, on_epoch=True, prog_bar=True)
        return {"test_loss": loss, "test_acc": accuracy}

 #   def on_test_epoch_end(self):
 #       test_results = self.trainer.callback_metrics
 #       test_loss = test_results['test_loss_epoch'] 
#        self.log('test_loss', test_loss, on_epoch=True, prog_bar=True)

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=0.001)
        return optimizer


In [4]:
#if __name__ == "__main__":
raw_datasets = load_dataset("Niche-Squad/mock-dots","regression-one-class", download_mode="force_redownload") # 使用你的数据集名称加载数据 

transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
checkpoint_callback = ModelCheckpoint(monitor='val_acc',
dirpath = "E:/Files/Checkpoint",  # 设置存储路径
mode='max',          # 如果你的指标是准确率或类似的，可能需要更改为'max'
filename='best-model-{epoch:02d}-{val_acc:.2f}',
save_top_k=1,
verbose=True,# 打印检查点保存信息
save_last=False,      # 如果你还想保存最后一个checkpoint
)


train_dataset = CustomDataset(raw_datasets['train'], transform=transform)
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)

val_dataset = CustomDataset(raw_datasets['validation'], transform=transform)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)

test_dataset = CustomDataset(raw_datasets['test'], transform=transform)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

    
model = ImageRegression()
logger = TensorBoardLogger("tb_logs", name="Resnet50_Regression_Classification_batch_size_32_epoch_100_lr_0.001_callbacks_T")#save_dir="your/log/directory"
trainer = pl.Trainer(callbacks=[checkpoint_callback],max_epochs=100,logger = logger)#callbacks=[checkpoint_callback]
trainer.fit(model, train_loader,val_loader)




Downloading builder script:   0%|          | 0.00/9.99k [00:00<?, ?B/s]

Downloading readme:   0%|          | 0.00/30.0 [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/2.60M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/11.3k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/3.77k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/3.82k [00:00<?, ?B/s]

Generating train split:   0%|          | 0/600 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/200 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/200 [00:00<?, ? examples/s]

Using cache found in C:\Users\34691/.cache\torch\hub\pytorch_vision_v0.10.0
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
You are using a CUDA device ('NVIDIA GeForce RTX 4060') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
  rank_zero_warn(f"Checkpoint directory {dirpath} exists and is not empty.")
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name  | Type   | Params
---------------------------------
0 | model | ResNet | 23.5 M
---------------------------------
23.5 M    Trainable params
0         Non-trainable params
23.5 M    Total params
94.040    Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]

  rank_zero_warn(
  rank_zero_warn(
  rank_zero_warn(


Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Epoch 0, global step 19: 'val_acc' reached 0.04500 (best 0.04500), saving model to 'E:\\Files\\Checkpoint\\best-model-epoch=00-val_acc=0.05.ckpt' as top 1


Validation: 0it [00:00, ?it/s]

Epoch 1, global step 38: 'val_acc' reached 0.27500 (best 0.27500), saving model to 'E:\\Files\\Checkpoint\\best-model-epoch=01-val_acc=0.28.ckpt' as top 1


Validation: 0it [00:00, ?it/s]

Epoch 2, global step 57: 'val_acc' was not in top 1


Validation: 0it [00:00, ?it/s]

Epoch 3, global step 76: 'val_acc' reached 0.39000 (best 0.39000), saving model to 'E:\\Files\\Checkpoint\\best-model-epoch=03-val_acc=0.39.ckpt' as top 1


Validation: 0it [00:00, ?it/s]

Epoch 4, global step 95: 'val_acc' was not in top 1


Validation: 0it [00:00, ?it/s]

Epoch 5, global step 114: 'val_acc' reached 0.50500 (best 0.50500), saving model to 'E:\\Files\\Checkpoint\\best-model-epoch=05-val_acc=0.50.ckpt' as top 1


Validation: 0it [00:00, ?it/s]

Epoch 6, global step 133: 'val_acc' was not in top 1


Validation: 0it [00:00, ?it/s]

Epoch 7, global step 152: 'val_acc' reached 0.56000 (best 0.56000), saving model to 'E:\\Files\\Checkpoint\\best-model-epoch=07-val_acc=0.56.ckpt' as top 1


Validation: 0it [00:00, ?it/s]

Epoch 8, global step 171: 'val_acc' was not in top 1


Validation: 0it [00:00, ?it/s]

Epoch 9, global step 190: 'val_acc' was not in top 1


Validation: 0it [00:00, ?it/s]

Epoch 10, global step 209: 'val_acc' was not in top 1


Validation: 0it [00:00, ?it/s]

Epoch 11, global step 228: 'val_acc' was not in top 1


Validation: 0it [00:00, ?it/s]

Epoch 12, global step 247: 'val_acc' was not in top 1


Validation: 0it [00:00, ?it/s]

Epoch 13, global step 266: 'val_acc' was not in top 1


Validation: 0it [00:00, ?it/s]

Epoch 14, global step 285: 'val_acc' was not in top 1


Validation: 0it [00:00, ?it/s]

Epoch 15, global step 304: 'val_acc' reached 0.58000 (best 0.58000), saving model to 'E:\\Files\\Checkpoint\\best-model-epoch=15-val_acc=0.58.ckpt' as top 1


Validation: 0it [00:00, ?it/s]

Epoch 16, global step 323: 'val_acc' was not in top 1


Validation: 0it [00:00, ?it/s]

Epoch 17, global step 342: 'val_acc' reached 0.60000 (best 0.60000), saving model to 'E:\\Files\\Checkpoint\\best-model-epoch=17-val_acc=0.60.ckpt' as top 1


Validation: 0it [00:00, ?it/s]

Epoch 18, global step 361: 'val_acc' reached 0.65000 (best 0.65000), saving model to 'E:\\Files\\Checkpoint\\best-model-epoch=18-val_acc=0.65.ckpt' as top 1


Validation: 0it [00:00, ?it/s]

Epoch 19, global step 380: 'val_acc' was not in top 1


Validation: 0it [00:00, ?it/s]

Epoch 20, global step 399: 'val_acc' was not in top 1


Validation: 0it [00:00, ?it/s]

Epoch 21, global step 418: 'val_acc' was not in top 1


Validation: 0it [00:00, ?it/s]

Epoch 22, global step 437: 'val_acc' was not in top 1


Validation: 0it [00:00, ?it/s]

Epoch 23, global step 456: 'val_acc' was not in top 1


Validation: 0it [00:00, ?it/s]

Epoch 24, global step 475: 'val_acc' was not in top 1


Validation: 0it [00:00, ?it/s]

Epoch 25, global step 494: 'val_acc' was not in top 1


Validation: 0it [00:00, ?it/s]

Epoch 26, global step 513: 'val_acc' was not in top 1


Validation: 0it [00:00, ?it/s]

Epoch 27, global step 532: 'val_acc' was not in top 1


Validation: 0it [00:00, ?it/s]

Epoch 28, global step 551: 'val_acc' was not in top 1


Validation: 0it [00:00, ?it/s]

Epoch 29, global step 570: 'val_acc' was not in top 1


Validation: 0it [00:00, ?it/s]

Epoch 30, global step 589: 'val_acc' was not in top 1


Validation: 0it [00:00, ?it/s]

Epoch 31, global step 608: 'val_acc' was not in top 1


Validation: 0it [00:00, ?it/s]

Epoch 32, global step 627: 'val_acc' was not in top 1


Validation: 0it [00:00, ?it/s]

Epoch 33, global step 646: 'val_acc' was not in top 1


Validation: 0it [00:00, ?it/s]

Epoch 34, global step 665: 'val_acc' reached 0.66000 (best 0.66000), saving model to 'E:\\Files\\Checkpoint\\best-model-epoch=34-val_acc=0.66.ckpt' as top 1


Validation: 0it [00:00, ?it/s]

Epoch 35, global step 684: 'val_acc' reached 0.67500 (best 0.67500), saving model to 'E:\\Files\\Checkpoint\\best-model-epoch=35-val_acc=0.68.ckpt' as top 1


Validation: 0it [00:00, ?it/s]

Epoch 36, global step 703: 'val_acc' was not in top 1


Validation: 0it [00:00, ?it/s]

Epoch 37, global step 722: 'val_acc' was not in top 1


Validation: 0it [00:00, ?it/s]

Epoch 38, global step 741: 'val_acc' was not in top 1


Validation: 0it [00:00, ?it/s]

Epoch 39, global step 760: 'val_acc' was not in top 1


Validation: 0it [00:00, ?it/s]

Epoch 40, global step 779: 'val_acc' reached 0.68500 (best 0.68500), saving model to 'E:\\Files\\Checkpoint\\best-model-epoch=40-val_acc=0.69.ckpt' as top 1


Validation: 0it [00:00, ?it/s]

Epoch 41, global step 798: 'val_acc' was not in top 1


Validation: 0it [00:00, ?it/s]

Epoch 42, global step 817: 'val_acc' was not in top 1


Validation: 0it [00:00, ?it/s]

Epoch 43, global step 836: 'val_acc' reached 0.70500 (best 0.70500), saving model to 'E:\\Files\\Checkpoint\\best-model-epoch=43-val_acc=0.70.ckpt' as top 1


Validation: 0it [00:00, ?it/s]

Epoch 44, global step 855: 'val_acc' was not in top 1


Validation: 0it [00:00, ?it/s]

Epoch 45, global step 874: 'val_acc' was not in top 1


Validation: 0it [00:00, ?it/s]

Epoch 46, global step 893: 'val_acc' was not in top 1


Validation: 0it [00:00, ?it/s]

Epoch 47, global step 912: 'val_acc' was not in top 1


Validation: 0it [00:00, ?it/s]

Epoch 48, global step 931: 'val_acc' was not in top 1


Validation: 0it [00:00, ?it/s]

Epoch 49, global step 950: 'val_acc' was not in top 1


Validation: 0it [00:00, ?it/s]

Epoch 50, global step 969: 'val_acc' was not in top 1


Validation: 0it [00:00, ?it/s]

Epoch 51, global step 988: 'val_acc' was not in top 1


Validation: 0it [00:00, ?it/s]

Epoch 52, global step 1007: 'val_acc' was not in top 1


Validation: 0it [00:00, ?it/s]

Epoch 53, global step 1026: 'val_acc' was not in top 1


Validation: 0it [00:00, ?it/s]

Epoch 54, global step 1045: 'val_acc' was not in top 1


Validation: 0it [00:00, ?it/s]

Epoch 55, global step 1064: 'val_acc' was not in top 1


Validation: 0it [00:00, ?it/s]

Epoch 56, global step 1083: 'val_acc' was not in top 1


Validation: 0it [00:00, ?it/s]

Epoch 57, global step 1102: 'val_acc' was not in top 1


Validation: 0it [00:00, ?it/s]

Epoch 58, global step 1121: 'val_acc' was not in top 1


Validation: 0it [00:00, ?it/s]

Epoch 59, global step 1140: 'val_acc' was not in top 1


Validation: 0it [00:00, ?it/s]

Epoch 60, global step 1159: 'val_acc' was not in top 1


Validation: 0it [00:00, ?it/s]

Epoch 61, global step 1178: 'val_acc' was not in top 1


Validation: 0it [00:00, ?it/s]

Epoch 62, global step 1197: 'val_acc' was not in top 1


Validation: 0it [00:00, ?it/s]

Epoch 63, global step 1216: 'val_acc' was not in top 1


Validation: 0it [00:00, ?it/s]

Epoch 64, global step 1235: 'val_acc' was not in top 1


Validation: 0it [00:00, ?it/s]

Epoch 65, global step 1254: 'val_acc' was not in top 1


Validation: 0it [00:00, ?it/s]

Epoch 66, global step 1273: 'val_acc' was not in top 1


Validation: 0it [00:00, ?it/s]

Epoch 67, global step 1292: 'val_acc' was not in top 1


Validation: 0it [00:00, ?it/s]

Epoch 68, global step 1311: 'val_acc' was not in top 1


Validation: 0it [00:00, ?it/s]

Epoch 69, global step 1330: 'val_acc' was not in top 1


Validation: 0it [00:00, ?it/s]

Epoch 70, global step 1349: 'val_acc' was not in top 1


Validation: 0it [00:00, ?it/s]

Epoch 71, global step 1368: 'val_acc' was not in top 1


Validation: 0it [00:00, ?it/s]

Epoch 72, global step 1387: 'val_acc' was not in top 1


Validation: 0it [00:00, ?it/s]

Epoch 73, global step 1406: 'val_acc' was not in top 1


Validation: 0it [00:00, ?it/s]

Epoch 74, global step 1425: 'val_acc' was not in top 1


Validation: 0it [00:00, ?it/s]

Epoch 75, global step 1444: 'val_acc' was not in top 1


Validation: 0it [00:00, ?it/s]

Epoch 76, global step 1463: 'val_acc' was not in top 1


Validation: 0it [00:00, ?it/s]

Epoch 77, global step 1482: 'val_acc' was not in top 1


Validation: 0it [00:00, ?it/s]

Epoch 78, global step 1501: 'val_acc' was not in top 1


Validation: 0it [00:00, ?it/s]

Epoch 79, global step 1520: 'val_acc' was not in top 1


Validation: 0it [00:00, ?it/s]

Epoch 80, global step 1539: 'val_acc' was not in top 1


Validation: 0it [00:00, ?it/s]

Epoch 81, global step 1558: 'val_acc' was not in top 1


Validation: 0it [00:00, ?it/s]

Epoch 82, global step 1577: 'val_acc' was not in top 1


Validation: 0it [00:00, ?it/s]

Epoch 83, global step 1596: 'val_acc' was not in top 1


Validation: 0it [00:00, ?it/s]

Epoch 84, global step 1615: 'val_acc' was not in top 1


Validation: 0it [00:00, ?it/s]

Epoch 85, global step 1634: 'val_acc' was not in top 1


Validation: 0it [00:00, ?it/s]

Epoch 86, global step 1653: 'val_acc' was not in top 1


Validation: 0it [00:00, ?it/s]

Epoch 87, global step 1672: 'val_acc' was not in top 1


Validation: 0it [00:00, ?it/s]

Epoch 88, global step 1691: 'val_acc' was not in top 1


Validation: 0it [00:00, ?it/s]

Epoch 89, global step 1710: 'val_acc' was not in top 1


Validation: 0it [00:00, ?it/s]

Epoch 90, global step 1729: 'val_acc' was not in top 1


Validation: 0it [00:00, ?it/s]

Epoch 91, global step 1748: 'val_acc' was not in top 1


Validation: 0it [00:00, ?it/s]

Epoch 92, global step 1767: 'val_acc' was not in top 1


Validation: 0it [00:00, ?it/s]

Epoch 93, global step 1786: 'val_acc' was not in top 1


Validation: 0it [00:00, ?it/s]

Epoch 94, global step 1805: 'val_acc' was not in top 1


Validation: 0it [00:00, ?it/s]

Epoch 95, global step 1824: 'val_acc' was not in top 1


Validation: 0it [00:00, ?it/s]

Epoch 96, global step 1843: 'val_acc' was not in top 1


Validation: 0it [00:00, ?it/s]

Epoch 97, global step 1862: 'val_acc' was not in top 1


Validation: 0it [00:00, ?it/s]

Epoch 98, global step 1881: 'val_acc' was not in top 1


Validation: 0it [00:00, ?it/s]

Epoch 99, global step 1900: 'val_acc' was not in top 1
`Trainer.fit` stopped: `max_epochs=100` reached.


In [5]:
best_model_path = trainer.checkpoint_callback.best_model_path
best_model = ImageRegression.load_from_checkpoint(best_model_path)
print("Best model path:", best_model_path)
print("Best model:", best_model)

Using cache found in C:\Users\34691/.cache\torch\hub\pytorch_vision_v0.10.0


Best model path: E:\Files\Checkpoint\best-model-epoch=43-val_acc=0.70.ckpt
Best model: ImageRegression(
  (model): ResNet(
    (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (layer1): Sequential(
      (0): Bottleneck(
        (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_runnin

In [6]:
trainer.test(dataloaders=test_loader,ckpt_path=best_model_path) 

Restoring states from the checkpoint path at E:\Files\Checkpoint\best-model-epoch=43-val_acc=0.70.ckpt
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Loaded model weights from the checkpoint at E:\Files\Checkpoint\best-model-epoch=43-val_acc=0.70.ckpt
  rank_zero_warn(


Testing: 0it [00:00, ?it/s]

────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
Runningstage.testing metric      DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
        test_acc            0.6100000143051147
     test_loss_epoch        0.7099999785423279
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────


[{'test_loss_epoch': 0.7099999785423279, 'test_acc': 0.6100000143051147}]