## CIFAR10/100 分类

In [1]:
import copy
from tqdm.notebook import tqdm
import torch
import torch.optim as optim
import torch.nn.functional as F
from torchvision.datasets import CIFAR10, CIFAR100
from torchvision.models import resnet18, ResNet18_Weights
from torchvision import transforms
import os
# Point TORCH_HOME at a shared directory before any pretrained weights are requested.
# NOTE(review): presumably torch.hub / torchvision cache their downloads under this path — confirm.
os.environ.update(TORCH_HOME='/sun/home_torch')

In [3]:
# Start from torchvision's default pretrained ResNet-18 checkpoint.
# To train from scratch instead, build the network with a bare resnet18() call.
# NOTE(review): the classifier head keeps its pretrained size, which is presumably
# larger than the 10 CIFAR10 classes used below; consider resizing model.fc — TODO confirm.
pretrained_weights = ResNet18_Weights.DEFAULT
model = resnet18(weights=pretrained_weights)

In [5]:
# Three-value mean / std tuples handed to Normalize, plus the shared data settings.
NORM_MEAN = (0.4914, 0.4822, 0.4465)
NORM_STD = (0.2023, 0.1994, 0.2010)
CIFAR_ROOT = '/sun/home_datasets/cifar'
BATCH_SIZE = 128
# Both splits get the same preprocessing: tensor conversion followed by normalisation.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(NORM_MEAN, NORM_STD),
])
trainset = CIFAR10(CIFAR_ROOT, train=True, transform=transform)
testset = CIFAR10(CIFAR_ROOT, train=False, transform=transform)
# Only the training loader shuffles; the test loader keeps a fixed order.
train_loader = torch.utils.data.DataLoader(trainset, batch_size=BATCH_SIZE, shuffle=True)
test_loader = torch.utils.data.DataLoader(testset, batch_size=BATCH_SIZE, shuffle=False)

## train

In [6]:
def accuracy(test_loader, model):
    """Return the top-1 accuracy of ``model`` over every batch in ``test_loader``.

    Args:
        test_loader: Iterable yielding ``(data, target)`` tensor pairs, where
            ``target`` holds one class index per sample.
        model: ``torch.nn.Module`` that maps ``data`` to per-class scores laid
            out along dimension 1.

    Returns:
        float: Fraction of samples whose highest-scoring class equals the
        target, in ``[0, 1]``.

    Raises:
        ZeroDivisionError: If ``test_loader`` yields no samples.

    Note:
        The model is switched to eval mode and left that way; callers that
        keep training afterwards must call ``model.train()`` again.
    """
    # Follow the model's own device instead of assuming CUDA, so the helper
    # also works on CPU-only machines and with CPU-resident models.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else None
    num_correct = 0
    num_samples = 0
    model.eval()
    # Gradients are not needed for evaluation.
    with torch.no_grad():
        for data, target in test_loader:
            if device is not None:
                data, target = data.to(device), target.to(device)
            output = model(data)
            # Index of the largest score along dim 1 is the predicted class.
            _, preds = output.max(1)
            num_correct += preds.eq(target).sum().item()
            num_samples += data.size(0)
    return float(num_correct) / num_samples

In [7]:
# Use the GPU when one is visible and fall back to the CPU otherwise, instead of
# failing outright on machines without CUDA.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)
# Adam with the library's default hyperparameters, built from the moved model's parameters.
optimizer = optim.Adam(model.parameters())

In [8]:
best_acc = 0
best_epoch = 0
epochs = 10
# Seed the "best" snapshot with the starting weights so best_model_wts is always
# bound, even if no epoch scores above zero accuracy or epochs is 0.
best_model_wts = copy.deepcopy(model.state_dict())
# Send batches to whichever device the model already lives on.
device = next(model.parameters()).device
for epoch in tqdm(range(epochs), desc='epochs'):
    # accuracy() leaves the model in eval mode, so training mode is re-enabled every epoch.
    model.train()
    for data, target in tqdm(train_loader, desc='batchs', leave=False):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = F.cross_entropy(output, target)
        loss.backward()
        optimizer.step()
    # Score the model after each epoch and remember the best-performing weights.
    # NOTE(review): the best epoch is chosen on the test split, so best_acc is an
    # optimistic estimate; a held-out validation split would be cleaner.
    acc = accuracy(test_loader, model)
    if acc > best_acc:
        best_acc = acc
        best_epoch = epoch
        # Take a deep copy so later optimizer steps cannot alter the stored snapshot.
        best_model_wts = copy.deepcopy(model.state_dict())
results = {
    "best_acc": best_acc,
    "best_model_wts": best_model_wts,
    "epoch": best_epoch
}
print(f"accuracy: {100*best_acc:.2f}% @{best_epoch}")

epochs:   0%|          | 0/10 [00:00<?, ?it/s]

batchs:   0%|          | 0/391 [00:00<?, ?it/s]

batchs:   0%|          | 0/391 [00:00<?, ?it/s]

batchs:   0%|          | 0/391 [00:00<?, ?it/s]

batchs:   0%|          | 0/391 [00:00<?, ?it/s]

batchs:   0%|          | 0/391 [00:00<?, ?it/s]

batchs:   0%|          | 0/391 [00:00<?, ?it/s]

batchs:   0%|          | 0/391 [00:00<?, ?it/s]

batchs:   0%|          | 0/391 [00:00<?, ?it/s]

batchs:   0%|          | 0/391 [00:00<?, ?it/s]

batchs:   0%|          | 0/391 [00:00<?, ?it/s]

accuracy: 80.99% @6


In [10]:
# Put the best-scoring weights back into the live model, then write them to disk.
model.load_state_dict(best_model_wts)
checkpoint_path = 'resnet18_cifar.pt'
torch.save(model.state_dict(), checkpoint_path)

In [11]:
# Read the checkpoint back with its tensors mapped to the CPU and copy it into the model;
# the final expression shows which keys matched.
saved_state = torch.load('resnet18_cifar.pt', map_location='cpu')
model.load_state_dict(saved_state)

<All keys matched successfully>

In [16]:
# Inspect the parameters currently held by the model.
# NOTE(review): this displays every tensor in the state dict, which produces a very
# long output; a summary of parameter names and shapes is probably enough here.
model.state_dict()

OrderedDict([('conv1.weight',
              tensor([[[[-5.6434e-03,  1.1770e-02,  5.5189e-03,  ...,  7.5806e-02,
                          2.2262e-02, -1.2375e-02],
                        [ 6.1034e-03, -2.4244e-03, -1.3475e-01,  ..., -2.8524e-01,
                         -1.5330e-01, -2.1479e-02],
                        [ 1.1635e-02,  5.5371e-02,  3.2172e-01,  ...,  5.4818e-01,
                          2.5599e-01,  7.3951e-02],
                        ...,
                        [-1.5583e-02,  2.7290e-02,  7.9628e-02,  ..., -3.0897e-01,
                         -4.0501e-01, -2.4802e-01],
                        [ 1.1526e-02,  2.8654e-02,  5.0473e-02,  ...,  4.0818e-01,
                          3.9956e-01,  1.7286e-01],
                        [-1.0594e-02,  4.2732e-03, -3.4698e-02,  ..., -2.0408e-01,
                         -1.2073e-01, -7.8641e-03]],
              
                       [[-9.6453e-04, -8.9982e-03, -3.0208e-02,  ...,  5.2855e-02,
                          1.8708

In [12]:
# Score the reloaded model on the test loader.
test_acc = accuracy(test_loader=test_loader, model=model)

In [14]:
# Show the test accuracy as a fraction (correct predictions / total samples).
test_acc

0.8099