In [2]:
import os
import torch
import pytorch_lightning as pl
import pandas as pd
import torch.nn as nn
from torchvision import transforms
from torch.utils.data import DataLoader,Dataset
from PIL import Image
from datasets import load_dataset
import torch.nn.functional as F
from pytorch_lightning.loggers import TensorBoardLogger

  from .autonotebook import tqdm as notebook_tqdm


In [21]:
class CustomDataset(Dataset):
    def __init__(self,data,transform=None):
        self.data = data
        self.transform=transform

    def __len__(self):
        return len(self.data)

    def __getitem__(self,idx):
        item = self.data[idx]
        image = item['image']
        label = int(item['label'])-1

        if self.transform:
            image = self.transform(image)

        return image,label

In [23]:
class ImageClassifier(pl.LightningModule):
    def __init__(self, num_classes):
        super(ImageClassifier, self).__init__()
        self.model = torch.hub.load('pytorch/vision:v0.10.0', 'resnet50', pretrained=True)
        num_ftrs = self.model.fc.in_features
        self.model.fc = nn.Linear(num_ftrs, num_classes)

    def forward(self, x):
        return self.model(x)

    def training_step(self, batch, batch_idx):
        images, labels = batch
        outputs = self(images)
        loss = F.cross_entropy(outputs, labels)
        self.log('train_loss', loss)
        return loss

    def validation_step(self, batch, batch_idx):
        images, labels = batch
        outputs = self(images)
        loss = F.cross_entropy(outputs, labels)
        _, preds = torch.max(outputs,1)
        acc = (preds == labels).float().mean()
        self.log('val_loss', loss, on_step=True, on_epoch=True, prog_bar=True)
        self.log('val_acc', acc, on_step=True, on_epoch=True, prog_bar=True)
        return {"val_loss":loss,"val_acc":acc}

#    def on_validation_epoch_end(self):
#        avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean()
#        avg_acc = torch.stack([x['val_acc']for x in outputs]).mean()
        
    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=0.001)
        return optimizer

In [24]:

if __name__ == "__main__":
    raw_datasets = load_dataset("Niche-Squad/mock-dots","regression-one-class", download_mode="force_redownload") # 使用你的数据集名称加载数据 

    transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

    train_dataset = CustomDataset(raw_datasets['train'], transform=transform)
    train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)

    val_dataset = CustomDataset(raw_datasets['validation'], transform=transform)
    val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)
    
    model = ImageClassifier(num_classes=15)
    logger = TensorBoardLogger("tb_logs", name="model")
    trainer = pl.Trainer(max_epochs=100,logger = logger)
    trainer.fit(model, train_loader,val_loader)

Downloading builder script: 100%|██████████| 9.99k/9.99k [00:00<?, ?B/s]
Downloading readme: 100%|██████████| 30.0/30.0 [00:00<?, ?B/s]
Downloading data: 100%|██████████| 2.60M/2.60M [00:00<00:00, 13.7MB/s]
Downloading data: 100%|██████████| 11.3k/11.3k [00:00<00:00, 11.0MB/s]
Downloading data: 100%|██████████| 3.77k/3.77k [00:00<00:00, 3.77MB/s]
Downloading data: 100%|██████████| 3.82k/3.82k [00:00<?, ?B/s]
Generating train split: 100%|██████████| 600/600 [00:00<00:00, 16891.40 examples/s]
Generating validation split: 100%|██████████| 200/200 [00:00<00:00, 13334.51 examples/s]
Generating test split: 100%|██████████| 200/200 [00:00<00:00, 13232.70 examples/s]
Using cache found in C:\Users\吴晓辉/.cache\torch\hub\pytorch_vision_v0.10.0
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
Missing logger folder: tb_logs\model


RuntimeError: CUDA error: device-side assert triggered
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.
