Skip to content

Commit

Permalink
feat(finetune): 微调AlexNet_SPP的全连接层
Browse files Browse the repository at this point in the history
  • Loading branch information
zjZSTU committed Mar 30, 2020
1 parent 8be3276 commit ff32913
Showing 1 changed file with 134 additions and 0 deletions.
134 changes: 134 additions & 0 deletions py/finetune.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
# -*- coding: utf-8 -*-

"""
@date: 2020/3/25 下午3:51
@file: finetune.py
@author: zj
@description:
"""

import os
import copy
import time
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
import torchvision.models as models

import models.alexnet_spp as alexnet_spp
from utils.data.custom_finetune_dataset import CustomFinetuneDataset
from utils.data.custom_batch_sampler import CustomBatchSampler
from utils.util import check_dir


def load_data(data_root_dir, model, device, s=688):
transform = transforms.Compose([
transforms.ToPILImage(),
transforms.Resize(s),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

data_loaders = {}
data_sizes = {}
for name in ['train', 'val']:
data_dir = os.path.join(data_root_dir, name)
data_set = CustomFinetuneDataset(data_dir, transform, model, device, s)
data_sampler = CustomBatchSampler(data_set.get_positive_num(), data_set.get_negative_num(), 32, 96)
data_loader = DataLoader(data_set, batch_size=128, sampler=data_sampler, num_workers=8, drop_last=True)

data_loaders[name] = data_loader
data_sizes[name] = data_sampler.__len__()

return data_loaders, data_sizes


def train_model(data_loaders, model, criterion, optimizer, lr_scheduler, num_epochs=25, device=None):
since = time.time()

best_model_weights = copy.deepcopy(model.state_dict())
best_acc = 0.0

for epoch in range(num_epochs):
print('Epoch {}/{}'.format(epoch, num_epochs - 1))
print('-' * 10)

# Each epoch has a training and validation phase
for phase in ['train', 'val']:
if phase == 'train':
model.train() # Set model to training mode
else:
model.eval() # Set model to evaluate mode

running_loss = 0.0
running_corrects = 0

# Iterate over data.
for inputs, labels, cache_dicts in data_loaders[phase]:
inputs = inputs.to(device)
labels = labels.to(device)

# zero the parameter gradients
optimizer.zero_grad()

# forward
# track history if only in train
with torch.set_grad_enabled(phase == 'train'):
outputs = model.classify(inputs)
_, preds = torch.max(outputs, 1)
loss = criterion(outputs, labels)

# backward + optimize only if in training phase
if phase == 'train':
loss.backward()
optimizer.step()

# statistics
running_loss += loss.item() * inputs.size(0)
running_corrects += torch.sum(preds == labels.data)
if phase == 'train':
lr_scheduler.step()

epoch_loss = running_loss / data_sizes[phase]
epoch_acc = running_corrects.double() / data_sizes[phase]

print('{} Loss: {:.4f} Acc: {:.4f}'.format(
phase, epoch_loss, epoch_acc))

# deep copy the model
if phase == 'val' and epoch_acc > best_acc:
best_acc = epoch_acc
best_model_weights = copy.deepcopy(model.state_dict())

print()

time_elapsed = time.time() - since
print('Training complete in {:.0f}m {:.0f}s'.format(
time_elapsed // 60, time_elapsed % 60))
print('Best val Acc: {:4f}'.format(best_acc))

# load best model weights
model.load_state_dict(best_model_weights)
return model


if __name__ == '__main__':
# device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
device = "cpu"
model = alexnet_spp.alexnet_spp(num_classes=2)
model = model.to(device)

data_loaders, data_sizes = load_data('./data/finetune_car', model, device, s=688)

criterion = nn.CrossEntropyLoss()
# optimizer = optim.SGD(model.parameters(), lr=1e-3, momentum=0.9)
optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=1e-4)
lr_scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=15, gamma=0.1)

best_model = train_model(data_loaders, model, criterion, optimizer, lr_scheduler, device=device, num_epochs=50)
# 保存最好的模型参数
check_dir('./models')
torch.save(best_model.state_dict(), 'models/alexnet_car.pth')

0 comments on commit ff32913

Please sign in to comment.