一般来说迁移学习都服从下面的框架：
+ 初始化预先训练的模型
+ 重构最后的网络层，使得网络输出的类别数和新数据集的类别数相同
+ 定义希望优化的参数
+ 训练

In [11]:
import torch
import torch.utils.data as data
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from torchvision import datasets, transforms, models
import matplotlib.pyplot as plt
import torch.optim as optim
import os
# Root folder of the dataset; it is joined with the split names 'train' and
# 'val' when the image datasets are built.
data_dir = "./data/hymenoptera_data"
# Label for the backbone family; not read by any of the visible cells.
model_name = "resnet"
# Number of target classes for the new task (presumably the two classes of
# hymenoptera_data -- TODO confirm against the folder layout).
num_class = 2
# Mini-batch size setting.
batch_size = 8
# Number of full passes over the training split made by the training loop.
num_epoch = 10

In [12]:
"""
    在训练数据集中，采用了RandomResizedCrop和RandomHorizontalFlip方法，由于其固有的随机性，dataloader每次
获取的数据都具有随机性，实现了对数据的增强。
"""
# Per-channel mean / std handed to Normalize in both pipelines (the values
# commonly paired with ImageNet-pretrained torchvision models).
_norm_mean = [0.485, 0.456, 0.406]
_norm_std = [0.229, 0.224, 0.225]

# Training pipeline: random crop + random flip give a different view of each
# image on every pass, which acts as data augmentation.
_train_steps = [
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(_norm_mean, _norm_std),
]

# Validation pipeline: deterministic resize + centre crop, same normalisation.
_val_steps = [
    transforms.Resize(224),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(_norm_mean, _norm_std),
]

# Transform pipelines keyed by split name.
data_transform = dict(
    train=transforms.Compose(_train_steps),
    val=transforms.Compose(_val_steps),
)


# One ImageFolder per split: the images under <data_dir>/<split> are paired
# with the transform pipeline registered under the same split name.
image_datasets = {
    split: datasets.ImageFolder(
        os.path.join(data_dir, split),
        transform=data_transform[split],
    )
    for split in ('train', 'val')
}
# One shuffled DataLoader per split. The batch size is taken from the shared
# `batch_size` setting instead of a duplicated literal, so changing the
# setting actually changes the loaders.
data_loader = {
    split: data.DataLoader(
        image_datasets[split],
        batch_size=batch_size,
        shuffle=True,
    )
    for split in ('train', 'val')
}
# Run on the first CUDA device when one is available, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')
# ResNet-101 backbone built by torchvision, requested with pretrained=True so
# that it starts from pretrained weights rather than random initialisation.
model = models.resnet101(pretrained=True)

In [13]:

"""
Finetune
"""
# Replace the classification head BEFORE moving the model to the target
# device. A layer assigned after `.to(device)` stays on the CPU and causes a
# device mismatch on the first forward pass when running on CUDA.
# The input width is read from the existing head and the output width from
# `num_class`, rather than repeating the literals.
model.fc = nn.Linear(model.fc.in_features, num_class)
model_ft = model.to(device)
# Only the new head's parameters are handed to the optimizer, so only the
# head is updated by optimizer_ft.step().
optimizer_ft = optim.SGD(params=model_ft.fc.parameters(), lr=0.001, momentum=0.9)
criterion = nn.CrossEntropyLoss()

In [19]:
def test(model_ft):
    """Evaluate ``model_ft`` on the validation loader and print its accuracy.

    The model is switched to eval mode and is left in eval mode on return.
    Every batch of ``data_loader['val']`` is moved to ``device``, the class
    with the highest score is taken as the prediction, and the fraction of
    correct predictions is printed together with the raw counts.

    Args:
        model_ft: The model to evaluate; called as ``model_ft(x)`` on each
            input batch.

    Raises:
        ZeroDivisionError: If the validation loader yields no samples.
    """
    correct = 0.0
    num = 0
    model_ft.eval()
    # Inference only: without no_grad() autograd would record a graph for
    # every validation batch and hold on to its activations for nothing.
    with torch.no_grad():
        for x, y in data_loader['val']:
            x = x.to(device)
            y = y.to(device)
            num = num + x.size(0)
            output = model_ft(x)  # (batch, class) tensor of scores
            _, prediction = output.max(1)  # index of the best score per sample
            correct += prediction.eq(y).sum().item()
    accuracy = correct / num
    print("Accuracy = ", accuracy, correct, num)

In [18]:

# Train for num_epoch passes over the training split, reporting validation
# accuracy after every epoch.
for epoch in range(num_epoch):
    print("Epoch:", epoch + 1)
    # test() leaves the model in eval mode, so training mode is re-entered
    # once per epoch rather than on every step.
    model_ft.train()
    for x_data, label in data_loader['train']:
        x_data = x_data.to(device)
        label = label.to(device)

        # Standard step: clear old gradients, forward, loss, backward, update.
        optimizer_ft.zero_grad()
        output = model_ft(x_data)
        loss = criterion(output, label)
        loss.backward()
        optimizer_ft.step()
    # test() switches the model to eval mode itself before scoring.
    test(model_ft)



Epoch: 1


KeyboardInterrupt: 

In [None]:
# Freeze every parameter currently registered on the model so that autograd
# no longer computes gradients for them.
for param in model.parameters():
    param.requires_grad_(False)

In [None]:
# Attach a fresh, trainable classification head. The input width is read from
# the head being replaced: this backbone is ResNet-101, whose pooled feature
# width is 2048, so a hard-coded 512 would fail on the first forward pass.
# The new layer is also moved to `device` so it sits with the rest of the
# model.
# NOTE(review): an optimizer created over the previous head's parameters will
# not update this new layer -- rebuild the optimizer after running this cell.
model.fc = nn.Linear(model.fc.in_features, num_class).to(device)

In [None]:
# Bare expression: as the last statement of a notebook cell it displays the
# module tree, which makes it easy to inspect the replaced head.
model