In [1]:
import sys

sys.path.append('../')

import os
import torch
import torch.optim as optim
from classify_model import LeNet
from classify_model import LeNetGray
from torchvision import transforms
from torch.utils.data import DataLoader
from sklearn.model_selection import train_test_split
from sklearn.metrics import precision_score, recall_score, f1_score
from tqdm import tqdm
import pandas as pd
# from data.data import ImageDataset
from sklearn.metrics import f1_score
from data.data import ImageDataset
import numpy as np
from torch import nn
import time

from classify_model import VGG16


datasets_fold='/home/fulin.xc/ride-route/data/cluster'
label='/home/fulin.xc/ride-route/data/cluster/labels_cluster.csv'
checkpoints_dir='../../model'
epoch_num=5
batch_size=32
learning_rate=0.001
seed=10

In [2]:
model = VGG16()
criterion = nn.BCELoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
# transform = transforms.Compose([
#     transforms.Resize((512, 512)),
#     transforms.ToTensor(),
#     transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
# ])
transform = transforms.Compose([
    transforms.Grayscale(num_output_channels=1),  # 添加灰度转换
    transforms.RandomHorizontalFlip(),      # 随机水平翻转
    transforms.RandomRotation(10),          # 随机旋转
    # transforms.Resize((512, 512)),
    transforms.RandomResizedCrop(512),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5], std=[0.5]),  # 注意均值和标准差的维度
])

In [3]:
data = pd.read_csv(label)
train_data, test_data = train_test_split(data, test_size=0.2, random_state=42)

train_dataset = ImageDataset(datasets_fold, train_data, transform=transform)
valid_dataset = ImageDataset(datasets_fold, test_data, transform=transform)
test_dataset = ImageDataset(datasets_fold, data, transform=transform)

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
valid_loader = DataLoader(valid_dataset, batch_size=batch_size, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

In [9]:
def evaluate(dataloader, epoch_th, dataset_name):
    model.eval()
    all_preds = []
    all_labels = []
    total=0
    correct=0
    start = time.time()
    with torch.no_grad():
        for batch in dataloader:
            images, valid_label_ids = batch
            # batch_text_ids, batch_text_mask, batch_label = batch
            y_hat = model(images)
            valid_predicted = (y_hat.reshape(-1) > 0.5).int()
            all_preds.extend(valid_predicted.tolist())
            all_labels.extend(valid_label_ids.tolist())

            total += len(valid_label_ids)
            correct += valid_predicted.eq(valid_label_ids).sum().item()

    test_acc = correct/total
    precision = precision_score(all_labels, all_preds)
    recall = recall_score(all_labels, all_preds)
    f1 = f1_score(all_labels, all_preds)
    end = time.time()
    print("Epoch:{}, Acc:{:.2f}, Precision: {:.2f}, Recall: {:.2f}, F1: {:.2f} on {}, Spend: {:.2f} minutes for evaluation".format(
                        epoch_th, 100.*test_acc, 100.*precision, 100.*recall, 100.*f1, dataset_name, (end-start)/60.0))
    return test_acc, f1

model.load_state_dict(torch.load('/home/fulin.xc/ride-route/checkpoints/classify_epoch_20_f1_0.78_checkpoint.pt', weights_only=False)['model_state'])

valid_f1_prev = 0
epoch_num = 10
for epoch in range(21, epoch_num + 21):
    model.train()
    tr_loss = 0

    for step, batch in enumerate(train_loader):
        images, labels = batch
        outputs = model(images)  # 前向传播
        loss = criterion(outputs.squeeze(), labels.float())  # 计算损失
        tr_loss += loss.item()

        optimizer.zero_grad()
        loss.backward()  # 反向传播
        optimizer.step()  # 更新权重

        if step % 5 == 0:
            print("Epoch:{}-{}/{}, loss: {:.6f} ".format(epoch, step, len(train_loader), loss.item()))

    print("Epoch:{} completed, Total training's Loss: {}".format(epoch, tr_loss))

    valid_acc, valid_f1 = evaluate(valid_loader, epoch, 'Valid_set')

    # if valid_f1 > valid_f1_prev:
    torch.save({'epoch': epoch, 'model_state': model.state_dict(), 'valid_acc': valid_acc, 'valid_f1': valid_f1}, os.path.join('../../checkpoints', f'classify_epoch_{epoch}_f1_{valid_f1:.2f}_checkpoint.pt'))
        # valid_f1_prev = valid_f1


evaluate(test_loader, epoch_num, "Totol_set")


Epoch:21-0/44, loss: 0.480280 
Epoch:21-5/44, loss: 0.157334 
Epoch:21-10/44, loss: 0.293924 
Epoch:21-15/44, loss: 0.269012 
Epoch:21-20/44, loss: 0.391951 
Epoch:21-25/44, loss: 0.293093 
Epoch:21-30/44, loss: 0.223626 
Epoch:21-35/44, loss: 0.092261 
Epoch:21-40/44, loss: 0.123194 
Epoch:21 completed, Total training's Loss: 9.93380505591631
Epoch:21, Acc:91.62, Precision: 97.96, Recall: 63.16, F1: 76.80 on Valid_set, Spend: 0.12 minutes for evaluation
Epoch:22-0/44, loss: 0.150135 
Epoch:22-5/44, loss: 0.148114 
Epoch:22-10/44, loss: 0.186896 
Epoch:22-15/44, loss: 0.141209 
Epoch:22-20/44, loss: 0.205482 
Epoch:22-25/44, loss: 0.222861 
Epoch:22-30/44, loss: 0.112183 
Epoch:22-35/44, loss: 0.464005 
Epoch:22-40/44, loss: 0.137915 
Epoch:22 completed, Total training's Loss: 9.876605287194252
Epoch:22, Acc:93.06, Precision: 98.15, Recall: 69.74, F1: 81.54 on Valid_set, Spend: 0.12 minutes for evaluation
Epoch:23-0/44, loss: 0.497555 
Epoch:23-5/44, loss: 0.160612 
Epoch:23-10/44, los

(0.9352601156069364, 0.8127090301003345)

In [None]:
torch.save({'epoch': epoch_num, 'model_state': model.state_dict(), 'valid_acc': valid_acc, 'valid_f1': valid_f1}, os.path.join('../../checkpoints', 'classify_checkpoint.pt'))