In [1]:
import os
import torch
import pytorch_lightning as pl
import pandas as pd
import torch.nn as nn
from torchvision import transforms
from torch.utils.data import DataLoader,Dataset
from PIL import Image
from datasets import load_dataset
import torch.nn.functional as F
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.callbacks import ModelCheckpoint

  from .autonotebook import tqdm as notebook_tqdm


In [14]:
class CustomDataset(Dataset):
    def __init__(self, data, transform=None):
        self.data = data
        self.transform = transform

    def __len__(self):
        return len(self.data)

    def __getitem__(self,idx):
        item = self.data[idx]
        image = item['image']
        label = float(item['label'])#标签是连续值

        if self.transform:
            image = self.transform(image)
        
        return image, torch.tensor(label, dtype=torch.float32)

In [15]:

class ImageRegression(pl.LightningModule):
    def __init__(self):
        super(ImageRegression,self).__init__()
        self.model = torch.hub.load('pytorch/vision:v0.10.0','resnet50',pretrained=True)
        num_ftrs = self.model.fc.in_features
        self.model.fc = nn.Linear(num_ftrs,1)

    def forward(self, x):
        x = x.float()  # 转换数据类型为float32
        return self.model(x).squeeze(-1)

    def training_step(self, batch, batch_idx):
        images, labels = batch
        outputs = self(images)
        loss = F.mse_loss(outputs, labels)
        self.log('train_loss', loss)
        return loss

    def validation_step(self, batch, batch_idx):
        images, labels = batch
        outputs = self(images)
        loss = F.mse_loss(outputs, labels)
        self.log('val_loss', loss, on_step=True, on_epoch=True, prog_bar=True)
        return {"val_loss":loss}

#    def on_validation_epoch_end(self):
#        avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean()
#        avg_acc = torch.stack([x['val_acc']for x in outputs]).mean()
        
    def test_step(self,batch,batch_idx):
        images, labels = batch
        outputs = self(images)
        loss = F.mse_loss(outputs, labels)
        self.log('test_loss', loss, on_step=True, on_epoch=True, prog_bar=True)
        return {"test_loss":loss}

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=0.001)
        return optimizer


In [16]:
if __name__ == "__main__":
    raw_datasets = load_dataset("Niche-Squad/mock-dots","regression-one-class", download_mode="force_redownload") # 使用你的数据集名称加载数据 

    transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])
    checkpoint_callback = ModelCheckpoint(
    monitor='val_loss',  # 替换为您正在监视的适当指标
    mode='min',          # 如果你的指标是准确率或类似的，可能需要更改为'max'
    filename='best-model-{epoch:02d}-{val_loss:.2f}',
    save_top_k=1,
    verbose=True,
    save_last=True,      # 如果你还想保存最后一个checkpoint
)


    train_dataset = CustomDataset(raw_datasets['train'], transform=transform)
    train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)

    val_dataset = CustomDataset(raw_datasets['validation'], transform=transform)
    val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)

    test_dataset = CustomDataset(raw_datasets['test'], transform=transform)
    test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

    
    model = ImageRegression()
    logger = TensorBoardLogger("tb_logs", name="Resnet50_batch_size_32_epoch_100")
    trainer = pl.Trainer(max_epochs=100,logger = logger)#callbacks=[checkpoint_callback]
    trainer.fit(model, train_loader,val_loader)
    trainer.test(dataloaders=test_loader)  


Downloading builder script: 100%|█████████████████████████████████████████████████| 9.99k/9.99k [00:00<00:00, 5.64MB/s]
Downloading readme: 100%|███████████████████████████████████████████████████████████████████| 30.0/30.0 [00:00<?, ?B/s]
Downloading data: 100%|███████████████████████████████████████████████████████████| 2.60M/2.60M [00:00<00:00, 11.7MB/s]
Downloading data: 100%|███████████████████████████████████████████████████████████████████| 11.3k/11.3k [00:00<?, ?B/s]
Downloading data: 100%|███████████████████████████████████████████████████████████████████| 3.77k/3.77k [00:00<?, ?B/s]
Downloading data: 100%|███████████████████████████████████████████████████████████████████| 3.82k/3.82k [00:00<?, ?B/s]
Generating train split: 100%|███████████████████████████████████████████████| 600/600 [00:00<00:00, 5429.77 examples/s]
Generating validation split: 100%|██████████████████████████████████████████| 200/200 [00:00<00:00, 4951.19 examples/s]
Generating test split: 100%|████████████

Sanity Checking DataLoader 0:   0%|                                                              | 0/2 [00:00<?, ?it/s]

  rank_zero_warn(


                                                                                                                       

  rank_zero_warn(
  rank_zero_warn(


Epoch 0:   0%|                                                                                  | 0/19 [03:41<?, ?it/s]
Epoch 61:  74%|██████████████▋     | 14/19 [01:54<00:41,  8.21s/it, v_num=2, val_loss_step=0.754, val_loss_epoch=0.441]

  rank_zero_warn("Detected KeyboardInterrupt, attempting graceful shutdown...")
  rank_zero_warn(
Restoring states from the checkpoint path at tb_logs\Resnet50_batch_size_32_epoch_100\version_2\checkpoints\epoch=60-step=1159.ckpt
Loaded model weights from the checkpoint at tb_logs\Resnet50_batch_size_32_epoch_100\version_2\checkpoints\epoch=60-step=1159.ckpt
  rank_zero_warn(


Testing DataLoader 0:  71%|████████████████████████████████████████████▎                 | 5/7 [00:10<00:04,  2.17s/it]