In [1]:
# pytorch
import torch
import torch.nn as nn
# Mobile Net
from torchvision.models.quantization.mobilenetv3 import mobilenet_v3_large
from torch.quantization import prepare_qat, get_default_qat_qconfig, convert
from torchvision.models import quantization
from torch.quantization import QuantStub, DeQuantStub, quantize_dynamic, prepare_qat, convert
# dataset
from torchvision import datasets
from torchvision import transforms
# dataloader
from torch.utils.data import DataLoader
# Util
import time
import datetime
import numpy as np
import matplotlib.pyplot as plt
import copy
# tensorboard
from torch.utils.tensorboard import SummaryWriter
# plt
import matplotlib
import matplotlib.pyplot as plt
# Plot styling: use the Noto Sans CJK JP font family and a larger default font size.
plt.rcParams.update({'font.family': 'Noto Sans CJK JP', 'font.size': 18})

In [2]:
# Hyperparameters
input_size = 224   # images are resized to input_size x input_size
batch_size = 32    # training batch size
n_worker = 8       # DataLoader worker processes
lr = 0.001         # Adam learning rate
epochs = 50        # number of training epochs
seed = 0           # RNG seed, so weight init and training-set shuffling are reproducible

# Seeds torch's global generator, which model initialisation and the shuffling
# DataLoader draw from; without it every run trains a different network.
torch.manual_seed(seed)

In [3]:
# Build the training and test ImageFolder datasets.
train_path = "image/train_image"
test_path = "image/test_image"
# Every image is resized to a square input_size x input_size and converted to a tensor.
data_transform = transforms.Compose(
    [transforms.Resize([input_size, input_size]), transforms.ToTensor()]
)
train_dataset, test_dataset = (
    datasets.ImageFolder(folder, transform=data_transform)
    for folder in (train_path, test_path)
)

In [4]:
# Data loaders. Training batches are shuffled; test batches (size 40) keep the
# dataset order because test() scores each batch against its batch index.
train_loader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=n_worker,
    pin_memory=True,
)
test_loader = DataLoader(
    test_dataset,
    batch_size=40,
    shuffle=False,
    num_workers=n_worker,
)

In [5]:
# Define the quantizable MobileNetV3 model (the optimizer is created in a later cell).
model = mobilenet_v3_large()
# Replace the last classifier layer so it emits one logit per class folder in the
# training set. in_features is read from the existing layer instead of hard-coding 1280,
# and the class count comes from the dataset instead of a hard-coded 40.
num_classes = len(train_dataset.classes)
model.classifier[3] = nn.Linear(model.classifier[3].in_features, num_classes)

In [6]:
# Fuse modules before QAT preparation; fusion is only done while the model is in
# training mode (asserted below).
# NOTE(review): torchvision's fuse_model() appears to pick QAT-style fusion from
# model.training when is_qat is not passed — confirm for the installed version.
in_training_mode = model.training
assert in_training_mode
model.fuse_model()

In [7]:
# Attach the default fbgemm QAT qconfig and prepare the model for quantization-aware
# training; prepare_qat(inplace=True) modifies and returns the same model object.
qat_qconfig = get_default_qat_qconfig("fbgemm")
model.qconfig = qat_qconfig
model = prepare_qat(model, inplace=True)



In [8]:
def get_parameter_number(model):
    """Count parameter elements of ``model``.

    Returns:
        dict: ``{'Total': n_all, 'Trainable': n_grad}``, where ``n_all`` counts the
        elements of every parameter and ``n_grad`` only those with ``requires_grad``.
    """
    total = 0
    trainable = 0
    for param in model.parameters():
        n_elements = param.numel()
        total += n_elements
        if param.requires_grad:
            trainable += n_elements
    return {'Total': total, 'Trainable': trainable}
# Display total vs. trainable parameter counts of the QAT-prepared model.
get_parameter_number(model)

{'Total': 4253272, 'Trainable': 4253272}

In [9]:
# Move the model to the training device and create the optimizer and loss.
# GPU index 3 is machine-specific; fall back to the default GPU, or the CPU,
# when this machine does not have a fourth GPU.
if torch.cuda.device_count() > 3:
    device = torch.device("cuda:3")
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
model = model.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
criterion = nn.CrossEntropyLoss()
# Snapshot of the current (untrained) weights, used as the initial "best" checkpoint.
best_model_wts = copy.deepcopy(model.state_dict())

In [10]:
def train(epoch, model):
    """Run one quantization-aware training epoch and print the mean loss.

    Uses the module-level ``train_loader``, ``optimizer`` and ``criterion``.
    Each batch is sent to the device the model's parameters live on, rather
    than to a hard-coded GPU index.

    Args:
        epoch: epoch number, used only in the log line.
        model: the QAT-prepared model to train (updated in place).

    Returns:
        float: sample-weighted mean training loss over the epoch.
    """
    model.train()
    device = next(model.parameters()).device
    running_loss = 0.0
    for data, label in train_loader:
        data, label = data.to(device), label.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, label)
        loss.backward()
        optimizer.step()
        # With CrossEntropyLoss's default 'mean' reduction, loss is a per-batch mean,
        # so weight it by batch size before averaging over the whole dataset.
        running_loss += loss.item() * data.size(0)
    train_loss = running_loss / len(train_loader.dataset)
    print('Epoch: {} \tTraining Loss: {:.6f}'.format(epoch, train_loss))
    return train_loss

In [11]:
def test(epoch, model):
    """Evaluate an int8-converted copy of the QAT model on ``test_loader``.

    The QAT model is deep-copied before conversion, so the caller's training
    model keeps its device and train/eval mode. (Calling ``model.cpu().eval()``
    directly would move the live model off the GPU and leave it in eval mode.)

    Scoring: for batch ``i``, the most frequent predicted class is counted as
    correct when it equals ``i``.
    NOTE(review): this assumes ``test_loader`` yields exactly one batch per class,
    in class-index order (batch_size=40, shuffle=False). Confirm that every class
    has exactly 40 test images.

    When the accuracy beats the best seen so far, the module-level ``best_acc``
    and ``best_model_wts`` are updated.

    Args:
        epoch: epoch number, used only in the log line.
        model: the QAT-prepared model (not modified).

    Returns:
        float: fraction of batches whose majority prediction matched.
    """
    global best_model_wts, best_acc
    # Quantized (fbgemm) kernels run on CPU; convert a CPU copy in eval mode.
    eval_copy = copy.deepcopy(model).cpu().eval()
    quantized_model = torch.quantization.convert(eval_copy, inplace=False)
    quantized_model.eval()
    n_batches = 0
    n_correct = 0
    with torch.no_grad():
        for batch_idx, (data, _label) in enumerate(test_loader):
            preds = torch.argmax(quantized_model(data), 1)
            unique_values, counts = torch.unique(preds, return_counts=True)
            majority = unique_values[counts.argmax()]
            if majority.item() == batch_idx:
                n_correct += 1
            n_batches += 1
    acc = n_correct / n_batches if n_batches else 0.0
    # best_acc may not exist yet on the first call, so read it with a default.
    if acc > globals().get("best_acc", 0.0):
        best_acc = acc
        best_model_wts = copy.deepcopy(model.state_dict())
    print('Epoch: {} Accuracy: {:6f}'.format(epoch, acc))
    return acc

In [None]:
# Run QAT training for `epochs` epochs and report the total wall-clock time.
# NOTE(review): test() is not called here, so best_model_wts is only updated if
# test() is run separately.
start_time = time.time()
for epoch in range(1, epochs + 1):
    train(epoch, model)
total_time = time.time() - start_time
total_time_str = str(datetime.timedelta(seconds=int(total_time)))
print(f"Training time {total_time_str}")


  return torch.fused_moving_avg_obs_fake_quant(
  return torch.fused_moving_avg_obs_fake_quant(


Epoch: 1 	Training Loss: 2.073403
Epoch: 2 	Training Loss: 0.503185
Epoch: 3 	Training Loss: 0.446780
Epoch: 4 	Training Loss: 0.402719
Epoch: 5 	Training Loss: 0.381998
Epoch: 6 	Training Loss: 0.227516
Epoch: 7 	Training Loss: 0.368299
Epoch: 8 	Training Loss: 0.202324
Epoch: 9 	Training Loss: 0.146318
Epoch: 10 	Training Loss: 0.266481
Epoch: 11 	Training Loss: 0.393295
Epoch: 12 	Training Loss: 0.184169
Epoch: 13 	Training Loss: 0.223331
Epoch: 14 	Training Loss: 0.154920
Epoch: 15 	Training Loss: 0.054038
Epoch: 16 	Training Loss: 0.022082
Epoch: 17 	Training Loss: 0.011152
Epoch: 18 	Training Loss: 0.048746
Epoch: 19 	Training Loss: 0.037339
Epoch: 20 	Training Loss: 0.118830
Epoch: 21 	Training Loss: 0.171060
Epoch: 22 	Training Loss: 0.201511
Epoch: 23 	Training Loss: 0.093984
Epoch: 24 	Training Loss: 0.041430
Epoch: 25 	Training Loss: 0.037189
Epoch: 26 	Training Loss: 0.069660
Epoch: 27 	Training Loss: 0.053549
Epoch: 28 	Training Loss: 0.026414
Epoch: 29 	Training Loss: 0.0

In [None]:
# Measure the wall-clock time of one evaluation pass (conversion + inference).
# test() requires the model as its second argument.
tic = time.perf_counter()
test(1, model)
toc = time.perf_counter()
print(toc - tic)

In [18]:
# Load the best weights into the model and save its state_dict under a timestamped name.
# load_state_dict() returns a missing/unexpected-keys report, not the weights, so it
# must not be passed to torch.save().
# NOTE(review): unless test() improved on it, best_model_wts is still the pre-training
# snapshot taken when the optimizer was created — confirm before relying on the file.
s = time.strftime('%Y_%m_%d_%H_%M_%S', time.localtime())
model.load_state_dict(best_model_wts)
torch.save(model.state_dict(), f'save_model/ecgid_model_{s}')