Skip to content

Commit

Permalink
fix(classifier): 在训练阶段Inception_v2输出2个分类器结果
Browse files Browse the repository at this point in the history
  • Loading branch information
zjZSTU committed Apr 9, 2020
1 parent 3ba787b commit f760bc5
Showing 1 changed file with 13 additions and 4 deletions.
17 changes: 13 additions & 4 deletions py/classifier_inception_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,11 +108,20 @@ def train_model(data_loaders, data_sizes, model_name, model, criterion, optimize
# track history if only in train
with torch.set_grad_enabled(phase == 'train'):
if phase == 'train':
outputs, aux2, aux1 = model(inputs)
if model_name == 'googlenet_bn':
outputs, aux2, aux1 = model(inputs)

# 仅使用最后一个分类器进行预测
_, preds = torch.max(outputs, 1)
loss = criterion(outputs, labels) + 0.3 * (
criterion(aux2, labels) + criterion(aux1, labels))
else:
outputs, aux = model(inputs)

# 仅使用最后一个分类器进行预测
_, preds = torch.max(outputs, 1)
loss = criterion(outputs, labels) + 0.3 * criterion(aux, labels)

# 仅使用最后一个分类器进行预测
_, preds = torch.max(outputs, 1)
loss = criterion(outputs, labels) + 0.3 * (criterion(aux2, labels) + criterion(aux1, labels))
else:
outputs = model(inputs)
# print(outputs.shape)
Expand Down

0 comments on commit f760bc5

Please sign in to comment.