In [1]:
import csv
import glob
import os
import random

import import_ipynb
import torch
import torchvision
from PIL import Image
from torch import nn
from torch import optim
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
from torchvision import transforms
from torchvision.models import resnet18

# Pokemon类定义

In [2]:
class Pokemon(Dataset):
    """Image-folder dataset: every sub-directory of ``root`` is one class.

    All image paths are listed once into ``<root>/images.csv`` (shuffled when the
    file is first created). Every split reads that same file, so train/val/test
    are disjoint slices of one fixed ordering: train = first 60%, val = next 20%,
    test = last 20%.

    Args:
        root: dataset directory; sub-directory names become class labels,
            numbered in sorted-name order.
        resize: side length of the square image returned by ``__getitem__``.
        mode: ``'train'`` or ``'val'``; any other value selects the test split.
    """

    # Per-channel statistics used by ``Normalize`` and undone by ``denormalize``
    # (these match the usual ImageNet normalization values).
    MEAN = [0.485, 0.456, 0.406]
    STD = [0.229, 0.224, 0.225]

    def __init__(self, root, resize, mode):
        super(Pokemon, self).__init__()
        self.root = root
        self.resize = resize
        self.mode = mode
        # Label = position of the folder name in sorted order. Plain files
        # (e.g. images.csv itself) are skipped.
        self.name2label = {}
        for name in sorted(os.listdir(root)):
            if os.path.isdir(os.path.join(root, name)):
                self.name2label[name] = len(self.name2label)
        print(self.name2label)
        # image, label
        images, labels = self.load_csv('images.csv')
        start, stop = self._split_bounds(len(images), mode)
        self.images = images[start:stop]
        self.labels = labels[start:stop]
        # Build the pipeline once instead of on every __getitem__ call.
        self.transform = self._build_transform(train=(mode == 'train'))

    @staticmethod
    def _split_bounds(n, mode):
        """Return the ``(start, stop)`` slice of ``n`` shuffled items belonging to ``mode``."""
        if mode == 'train':  # 60%
            return 0, int(0.6 * n)
        if mode == 'val':  # 20% = 60% -> 80%
            return int(0.6 * n), int(0.8 * n)
        # Any other mode (normally 'test'): 20% = 80% -> 100%
        return int(0.8 * n), n

    def _build_transform(self, train):
        """Build the PIL image -> normalized tensor pipeline.

        Random rotation is added only when ``train`` is true.
        """
        size = int(self.resize * 1.25)
        steps = [transforms.Resize((size, size))]
        if train:
            # Augmentation is for training only; val/test must see deterministic inputs.
            steps.append(transforms.RandomRotation(15))
        steps += [
            transforms.CenterCrop(self.resize),
            transforms.ToTensor(),
            transforms.Normalize(mean=self.MEAN, std=self.STD),
        ]
        return transforms.Compose(steps)

    def load_csv(self, filename):
        """Return ``(image_paths, labels)`` listed in ``<root>/<filename>``.

        If the file does not exist, it is first created from every ``*.png``,
        ``*.jpg`` and ``*.jpeg`` file in the class folders. The rows are written
        in shuffled order, one ``path,label`` row per image.
        """
        csv_path = os.path.join(self.root, filename)
        if not os.path.exists(csv_path):
            images = []
            for name in self.name2label:
                # e.g. 'pokemon\\mewtwo\\00001.png'
                for pattern in ('*.png', '*.jpg', '*.jpeg'):
                    images += glob.glob(os.path.join(self.root, name, pattern))
            random.shuffle(images)
            with open(csv_path, mode='w', newline='') as f:
                writer = csv.writer(f)
                for img in images:
                    # The parent folder name is the class name.
                    label = self.name2label[os.path.basename(os.path.dirname(img))]
                    writer.writerow([img, label])
                print('writen into csv file:', filename)
        # read from csv file; newline='' is what the csv module expects for reading too.
        images, labels = [], []
        with open(csv_path, newline='') as f:
            for img, label in csv.reader(f):
                images.append(img)
                labels.append(int(label))
        assert len(images) == len(labels)
        return images, labels

    def __len__(self):
        return len(self.images)

    def denormalize(self, x_hat):
        """Undo ``Normalize``: return ``x_hat * std + mean`` per channel.

        Accepts a single image ``[3, h, w]`` or a batch ``[b, 3, h, w]``. The
        statistics are created on ``x_hat``'s device and dtype.
        """
        # [3] => [3, 1, 1] so they broadcast over spatial (and any leading batch) dims.
        mean = torch.tensor(self.MEAN, dtype=x_hat.dtype, device=x_hat.device).view(3, 1, 1)
        std = torch.tensor(self.STD, dtype=x_hat.dtype, device=x_hat.device).view(3, 1, 1)
        return x_hat * std + mean

    def __getitem__(self, idx):
        """Return ``(image_tensor, label_tensor)`` for sample ``idx``."""
        path, label = self.images[idx], self.labels[idx]
        # `with` releases the file handle once the pixels have been decoded.
        with Image.open(path) as im:
            img = self.transform(im.convert('RGB'))
        return img, torch.tensor(label)

In [3]:
# Training hyper-parameters: mini-batch size, Adam learning rate, number of epochs.
batchsz, lr, epochs = 32, 0.001, 10

In [4]:
# Use the GPU when present; otherwise fall back to the CPU instead of failing at model.to(device).
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Seed torch and Python's `random` (Pokemon.load_csv shuffles with it when images.csv is first built).
torch.manual_seed(1234)
random.seed(1234)

<torch._C.Generator at 0x1ff46fbb5f0>

# 数据加载

In [5]:
DATA_ROOT = './data/pokemon'  # one sub-folder per class
IMG_SIZE = 224  # side length of the square network input
train_db = Pokemon(DATA_ROOT, IMG_SIZE, mode='train')
val_db = Pokemon(DATA_ROOT, IMG_SIZE, mode='val')
test_db = Pokemon(DATA_ROOT, IMG_SIZE, mode='test')
# num_workers stays at its default (0). NOTE(review): num_workers=4/2 variants were
# deliberately disabled here; confirm multiprocess loading works in this environment
# before enabling them.
train_loader = DataLoader(train_db, batch_size=batchsz, shuffle=True)
val_loader = DataLoader(val_db, batch_size=batchsz)
test_loader = DataLoader(test_db, batch_size=batchsz)

{'bulbasaur': 0, 'charmander': 1, 'mewtwo': 2, 'pikachu': 3, 'squirtle': 4}
{'bulbasaur': 0, 'charmander': 1, 'mewtwo': 2, 'pikachu': 3, 'squirtle': 4}
{'bulbasaur': 0, 'charmander': 1, 'mewtwo': 2, 'pikachu': 3, 'squirtle': 4}


# 模型加载

In [6]:
# Flatten a tensor to 2-D
class Flatten(nn.Module):
    """Flatten every dimension after the first: ``[b, d1, d2, ...] -> [b, d1*d2*...]``."""

    def __init__(self):
        super(Flatten, self).__init__()

    def forward(self, x):
        # Product of the non-batch dims, taken from the Python shape tuple rather
        # than via a temporary tensor and an .item() round-trip.
        features = 1
        for dim in x.shape[1:]:
            features *= dim
        # reshape (unlike view) also accepts non-contiguous inputs.
        return x.reshape(-1, features)

In [7]:
# Load the pretrained backbone and replace its final fc layer with a new classifier head.
try:
    # torchvision >= 0.13: the `weights` enum replaces the deprecated `pretrained=True`
    # (IMAGENET1K_V1 is the checkpoint `pretrained=True` loaded).
    from torchvision.models import ResNet18_Weights
    trained_model = resnet18(weights=ResNet18_Weights.IMAGENET1K_V1)
except ImportError:
    trained_model = resnet18(pretrained=True)
num_classes = len(train_db.name2label)  # one output per class folder
model = nn.Sequential(*list(trained_model.children())[:-1],  # [b, 512, 1, 1]
                      Flatten(),  # [b, 512, 1, 1] => [b, 512]
                      nn.Linear(trained_model.fc.in_features, num_classes)  # [b, 512] => [b, num_classes]
                      ).to(device)
optimizer = optim.Adam(model.parameters(), lr=lr)
criteon = nn.CrossEntropyLoss()



# 模型训练

In [8]:
def train(model, loader, epoch, opt=None, loss_fn=None, dev=None):
    """Run one training epoch of ``model`` over ``loader``.

    Args:
        model: network to optimize; it is switched to train mode.
        loader: iterable of ``(x, y)`` batches.
        epoch: epoch index, used only in the progress log.
        opt: optimizer; defaults to the notebook's global ``optimizer``.
        loss_fn: loss module; defaults to the global ``criteon``.
        dev: device for the batches; defaults to the global ``device``.

    Returns:
        Mean training loss over the epoch's batches (``nan`` if ``loader`` is empty).
    """
    opt = optimizer if opt is None else opt
    loss_fn = criteon if loss_fn is None else loss_fn
    dev = device if dev is None else dev
    # Set once per epoch rather than per batch; evaluation switches the model to eval mode.
    model.train()
    total_loss, n_steps = 0.0, 0
    for step, (x, y) in enumerate(loader):  # the `loader` argument, not a global
        # x: [b, 3, 224, 224], y: [b]
        x, y = x.to(dev), y.to(dev)
        logits = model(x)
        loss = loss_fn(logits, y)
        loss_value = loss.item()
        print("epoch: {}, traing step: {}, training loss: {}".format(epoch, step, loss_value))
        opt.zero_grad()
        loss.backward()
        opt.step()
        total_loss += loss_value
        n_steps += 1
    return total_loss / n_steps if n_steps else float('nan')

# 模型评估

In [9]:
# Model evaluation
def evalute(model, loader, dev=None):
    """Return the top-1 accuracy of ``model`` on ``loader.dataset``.

    Args:
        model: classifier whose output has the class scores along dim 1; it is
            switched to eval mode.
        loader: DataLoader-like object (iterable of ``(x, y)`` with a ``.dataset``).
        dev: device for the batches; defaults to the notebook's global ``device``.

    Raises:
        ValueError: if ``loader.dataset`` is empty.
    """
    dev = device if dev is None else dev
    total = len(loader.dataset)
    if total == 0:
        raise ValueError('cannot evaluate on an empty dataset')
    model.eval()
    correct = 0
    with torch.no_grad():
        for x, y in loader:
            x, y = x.to(dev), y.to(dev)
            pred = model(x).argmax(dim=1)
            correct += (pred == y).sum().item()
    return correct / total

In [10]:
def main(ckpt_path='./model/best.mdl', eval_every=1):
    """Train for ``epochs`` epochs, keep the best validation weights and report test accuracy.

    Args:
        ckpt_path: file the best ``state_dict`` is saved to; its directory is
            created if missing.
        eval_every: validate every ``eval_every`` epochs (default: every epoch).
    """
    best_acc, best_epoch = 0, 0
    saved = False
    ckpt_dir = os.path.dirname(ckpt_path)
    if ckpt_dir:
        # torch.save does not create missing directories.
        os.makedirs(ckpt_dir, exist_ok=True)
    for epoch in range(epochs):
        train(model, train_loader, epoch)
        if epoch % eval_every == 0:
            val_acc = evalute(model, val_loader)
            # Always save on the first evaluation so a checkpoint exists even if
            # validation accuracy never rises above 0.
            if val_acc > best_acc or not saved:
                best_epoch = epoch
                best_acc = val_acc
                torch.save(model.state_dict(), ckpt_path)
                saved = True
    print('best acc:', best_acc, 'best epoch:', best_epoch)
    if saved:
        # map_location lets a checkpoint written on another device load on this one.
        model.load_state_dict(torch.load(ckpt_path, map_location=device))
        print('loaded from ckpt!')
    else:
        print('no checkpoint saved; evaluating the final weights')
    test_acc = evalute(model, test_loader)
    print('test acc:', test_acc)

In [11]:
# Run training and evaluation only when this code executes as the top-level module
# (a notebook kernel runs cells with __name__ == '__main__'). If the notebook is
# imported as a module (presumably via import_ipynb), __name__ differs and main()
# is skipped.
if __name__ == '__main__':
    main()

epoch: 0, traing step: 0, training loss: 1.765452265739441




epoch: 0, traing step: 1, training loss: 0.9145488142967224
epoch: 0, traing step: 2, training loss: 0.3724343180656433
epoch: 0, traing step: 3, training loss: 0.553935706615448
epoch: 0, traing step: 4, training loss: 0.1549140214920044
epoch: 0, traing step: 5, training loss: 0.35490691661834717
epoch: 0, traing step: 6, training loss: 0.4270232021808624
epoch: 0, traing step: 7, training loss: 0.19795778393745422
epoch: 0, traing step: 8, training loss: 0.11685838550329208
epoch: 0, traing step: 9, training loss: 0.44310757517814636
epoch: 0, traing step: 10, training loss: 0.1910984069108963
epoch: 0, traing step: 11, training loss: 1.0186823606491089
epoch: 0, traing step: 12, training loss: 0.1295214742422104
epoch: 0, traing step: 13, training loss: 0.19655130803585052
epoch: 0, traing step: 14, training loss: 0.11039809882640839
epoch: 0, traing step: 15, training loss: 0.15897434949874878
epoch: 0, traing step: 16, training loss: 0.3864029049873352
epoch: 0, traing step: 17, 

epoch: 6, traing step: 7, training loss: 0.12729626893997192
epoch: 6, traing step: 8, training loss: 0.07294700294733047
epoch: 6, traing step: 9, training loss: 0.05210501700639725
epoch: 6, traing step: 10, training loss: 0.039170973002910614
epoch: 6, traing step: 11, training loss: 0.03699680045247078
epoch: 6, traing step: 12, training loss: 0.012232406064867973
epoch: 6, traing step: 13, training loss: 0.004225059412419796
epoch: 6, traing step: 14, training loss: 0.23048371076583862
epoch: 6, traing step: 15, training loss: 0.08058521151542664
epoch: 6, traing step: 16, training loss: 0.022507579997181892
epoch: 6, traing step: 17, training loss: 0.002821570262312889
epoch: 6, traing step: 18, training loss: 0.4455980658531189
epoch: 6, traing step: 19, training loss: 0.07919394224882126
epoch: 6, traing step: 20, training loss: 0.003211475443094969
epoch: 6, traing step: 21, training loss: 0.008389432914555073
epoch: 7, traing step: 0, training loss: 0.146104633808136
epoch: 7