In [1]:
# pytorch
import torch
import torch.nn as nn
# Mobile Net
from torchvision.models.quantization.mobilenetv3 import mobilenet_v3_large
from torch.quantization import prepare_qat, get_default_qat_qconfig, convert
from torchvision.models import quantization
from torch.quantization import QuantStub, DeQuantStub, quantize_dynamic, prepare_qat, convert
# dataset
from torchvision import datasets
from torchvision import transforms
# dataloader
from torch.utils.data import DataLoader
# Util
import time
import datetime
import numpy as np
import matplotlib.pyplot as plt
import copy
# tensorboard
from torch.utils.tensorboard import SummaryWriter
# plt
import matplotlib
import matplotlib.pyplot as plt
plt.rcParams['font.family'] = 'Noto Sans CJK JP'
matplotlib.rcParams.update({'font.size': 18})

In [2]:
# 超参数
input_size = 224
batch_size = 32
n_worker = 8
lr = 0.001
epochs = 50

In [3]:
# 生成训练数据集
train_path = "image/train_image"
test_path = "image/test_image"
data_transform = transforms.Compose([
        transforms.Resize([input_size, input_size]),
        transforms.ToTensor(),
])
train_dataset = datasets.ImageFolder(train_path, transform=data_transform)
test_dataset = datasets.ImageFolder(test_path, transform=data_transform)

In [4]:
# 生成数据加载器
train_loader = DataLoader(
    train_dataset, batch_size=batch_size, shuffle=True,
    num_workers=n_worker, pin_memory=True)
test_loader = DataLoader(
    test_dataset, batch_size=20, shuffle=False, 
    num_workers=n_worker)

In [5]:
# 定义模型和优化器
model = mobilenet_v3_large()
model.classifier[3] = nn.Linear(1280, 90)

In [6]:
model.fuse_model()

In [7]:
model.qconfig = get_default_qat_qconfig("fbgemm")
prepare_qat(model, inplace=True)



QuantizableMobileNetV3(
  (features): Sequential(
    (0): Conv2dNormActivation(
      (0): ConvBn2d(
        3, 16, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False
        (bn): BatchNorm2d(16, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
        (weight_fake_quant): FusedMovingAvgObsFakeQuantize(
          fake_quant_enabled=tensor([1]), observer_enabled=tensor([1]), scale=tensor([1.]), zero_point=tensor([0], dtype=torch.int32), dtype=torch.qint8, quant_min=-128, quant_max=127, qscheme=torch.per_channel_symmetric, reduce_range=False
          (activation_post_process): MovingAveragePerChannelMinMaxObserver(min_val=tensor([]), max_val=tensor([]))
        )
        (activation_post_process): FusedMovingAvgObsFakeQuantize(
          fake_quant_enabled=tensor([1]), observer_enabled=tensor([1]), scale=tensor([1.]), zero_point=tensor([0], dtype=torch.int32), dtype=torch.quint8, quant_min=0, quant_max=127, qscheme=torch.per_tensor_affine, reduce_range=True


In [8]:
model = model.cuda(3)
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
criterion = nn.CrossEntropyLoss()
best_model_wts = None
# writer = SummaryWriter()

In [9]:
def train(epoch, model):
    model.train()
    train_loss = 0
    for data, label in train_loader:
        data, label = data.cuda(3), label.cuda(3)
        # clear the grad
        optimizer.zero_grad()
        output = model(data)
        # loss function
        loss = criterion(output, label)
        loss.backward()
        optimizer.step()
        # scheduler.step()
        train_loss += loss.item() * data.size(0)
    train_loss = train_loss / len(train_loader.dataset)
    # Re-quantize the model
    # model = quantize_dynamic(model, {'': torch.quantization.default_dynamic_qconfig}, dtype=torch.qint8)
    # writer.add_scalar("Loss/train", train_loss, epoch)
    # loss_vec32.append(train_loss)
    print('Epoch: {} \tTraining Loss: {:.6f}'.format(epoch, train_loss))

In [10]:
def test(epoch, model):
    quantized_model = torch.quantization.convert(model.cpu().eval(), inplace=False)
    quantized_model.eval()
    idx = 0
    ans = 0.0
    best_acc = 0.0
    with torch.no_grad():
        for data, label in test_loader:
            output = quantized_model(data)
            preds = torch.argmax(output, 1)
            unique_values, counts = torch.unique(preds, return_counts=True)
            pres = unique_values[counts.argmax()]
            if pres.item() == idx:
                ans += 1
            idx += 1
    acc = ans / 90
    if acc > best_acc:
        best_model_wts = copy.deepcopy(quantized_model.state_dict())
        # print(best_model_wts)
    print('Epoch: {} Accuracy: {:6f}'.format(epoch, acc))
    return best_model_wts

In [11]:
start_time = time.time()
for epoch in range(1, epochs + 1):
    train(epoch, model)
total_time = time.time() - start_time
total_time_str = str(datetime.timedelta(seconds=int(total_time)))
print(f"Training time {total_time_str}")

  return torch.fused_moving_avg_obs_fake_quant(
  return torch.fused_moving_avg_obs_fake_quant(


Epoch: 1 	Training Loss: 4.314790
Epoch: 2 	Training Loss: 3.867196
Epoch: 3 	Training Loss: 3.877579
Epoch: 4 	Training Loss: 3.649849
Epoch: 5 	Training Loss: 3.932557
Epoch: 6 	Training Loss: 3.877930
Epoch: 7 	Training Loss: 7.346864
Epoch: 8 	Training Loss: 3.705178
Epoch: 9 	Training Loss: 3.245525
Epoch: 10 	Training Loss: 2.982609
Epoch: 11 	Training Loss: 2.738164
Epoch: 12 	Training Loss: 2.533820
Epoch: 13 	Training Loss: 2.373858
Epoch: 14 	Training Loss: 2.236841
Epoch: 15 	Training Loss: 2.129011
Epoch: 16 	Training Loss: 1.948139
Epoch: 17 	Training Loss: 1.823758
Epoch: 18 	Training Loss: 1.720606
Epoch: 19 	Training Loss: 1.533189
Epoch: 20 	Training Loss: 1.545346
Epoch: 21 	Training Loss: 1.353186
Epoch: 22 	Training Loss: 1.282354
Epoch: 23 	Training Loss: 1.239254
Epoch: 24 	Training Loss: 1.184227
Epoch: 25 	Training Loss: 1.005921
Epoch: 26 	Training Loss: 0.940876
Epoch: 27 	Training Loss: 0.852463
Epoch: 28 	Training Loss: 0.916623
Epoch: 29 	Training Loss: 0.7

In [19]:
quantized_model = torch.quantization.convert(model.cpu().eval(), inplace=False)
quantized_model.eval()
print(quantized_model.state_dict())
s = time.strftime('%Y_%m_%d_%H_%M_%S', time.localtime())
torch.save(quantized_model.state_dict(), f'save_model/int8/ecgid_model_{s}.pt')

OrderedDict([('features.0.0.weight', tensor([[[[-0.6229, -0.3594,  1.5094],
          [ 0.4073, -0.0240,  2.4917],
          [-0.8146, -2.1083, -0.7427]],

         [[-2.9469, -0.1917,  0.4552],
          [-0.1198, -0.7427,  0.0240],
          [ 1.2458, -0.2156,  0.3354]],

         [[-0.4073,  0.3833, -1.6052],
          [-1.1500,  0.7188,  0.9823],
          [ 0.9583,  0.7667,  0.8625]]],


        [[[-0.7083,  0.2471,  1.9603],
          [-0.4118,  0.0000,  0.5271],
          [ 0.3624, -0.8895, -0.8566]],

         [[ 0.4283,  0.0988,  0.3789],
          [ 0.2306,  0.2800,  0.2800],
          [ 1.0543,  0.6095,  0.3954]],

         [[-0.1153, -0.6589,  0.2306],
          [-0.1483, -0.2471,  2.0921],
          [ 0.2141, -0.0494, -0.3130]]],


        [[[-0.4073,  0.8400, -0.0255],
          [ 1.0819, -0.4200,  0.0127],
          [ 0.0000, -0.3564,  0.9928]],

         [[ 0.4073,  0.0127, -0.1909],
          [ 0.1146,  0.2036,  0.3309],
          [-0.7637, -0.5855, -0.1909]],

       

In [7]:
model = mobilenet_v3_large()
state_dict = torch.load('/home/yangaowei/ecg_id/save_model/int8/ecgid_model_2024_05_18_15_11_23.pt')
model.qconfig = get_default_qat_qconfig("fbgemm")
# prepare_qat(model, inplace=True)
print(state_dict)

None


In [14]:
print(best_model_wts)
s = time.strftime('%Y_%m_%d_%H_%M_%S', time.localtime())
torch.save(best_model_wts, f'save_model/int8/ecgid_model_{s}.pt')

None
