# 加载模型

导入必要的库

In [1]:
import torch
import numpy as np
from torch.utils.data import TensorDataset, DataLoader
from sklearn.model_selection import train_test_split
from torch import nn
import pandas as pd
import numpy as np
from sklearn.preprocessing import MinMaxScaler
import torch.nn.functional as F
from sklearn.base import BaseEstimator, TransformerMixin
import math
from sklearn.metrics import accuracy_score,confusion_matrix, precision_score, recall_score, f1_score
import scipy.optimize as opt
import torch.distributions as dist
from sklearn.metrics import accuracy_score
import argparse
import warnings
from tqdm.notebook import tqdm_notebook as tqdm
from collections import defaultdict

In [2]:
warnings.filterwarnings('ignore')

In [3]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

超参数设置

In [4]:
# 超参数
tem = 0.02
bs = 128
seed = 5009
seed_round = 5
epochs = 800
epoch_online=1
sample_interval = 2784
flip_percent = 0.05

## 加载数据集

In [5]:
def get_dataset():
    train_data=pd.read_csv('en_UNSW_NB15_train.csv')
    test_data=pd.read_csv('en_UNSW_NB15_test.csv')
    y_train=train_data['label']
    y_test=test_data['label']
    X_train=train_data.drop(columns=['label'])
    X_test=test_data.drop(columns=['label'])
    normalize=MinMaxScaler()
    X_train=normalize.fit_transform(X_train)
    X_test=normalize.fit_transform(X_test)
    return X_train,y_train,X_test,y_test

In [6]:
x_train,y_train,x_test,y_test=get_dataset()
# 转换为torch张量
x_train,y_train=torch.FloatTensor(x_train).to(device),torch.LongTensor(y_train).to(device)
x_test,y_test=torch.FloatTensor(x_test).to(device),torch.LongTensor(y_test).to(device)

In [7]:
def evaluate(y,y_pred):
    y= y.cpu().detach().numpy()
    y_pred= y_pred.cpu().detach().numpy()
    # 混淆矩阵
    print("Confusion matrix")
    print(confusion_matrix(y, y_pred))
    # Accuracy 
    print('Accuracy ',accuracy_score(y, y_pred))
    # Precision 
    print('Precision ',precision_score(y, y_pred))
    # Recall
    print('Recall ',recall_score(y, y_pred))
    # F1 score
    print('F1 score ',f1_score(y,y_pred))

In [8]:
import matplotlib.pyplot as plt
def plot_loss(losses,epoch):
    plt.figure(figsize=(10,5))
    plt.title(f"{epoch} Training Loss")
    plt.plot(losses, marker='o')
    plt.xlabel("Iterations")
    plt.ylabel("Loss")
    plt.grid(True)
    plt.show()

## 加载模型

AE自编码器模块

In [9]:
class AE(nn.Module):
    def __init__(self, input_dim):
        super(AE, self).__init__()
        # 计算输入维度的最近的2的幂次方，比如输入维度是206，则最近的2的幂次方是128
        nearest_power_of_2 = 2 ** round(math.log2(input_dim))
        
        second_fourth_layer_size = nearest_power_of_2 // 2
        third_layer_size = nearest_power_of_2 // 4
        
        self.encoder = nn.Sequential(
            nn.Linear(input_dim, second_fourth_layer_size),
            nn.ReLU(),
            nn.Linear(second_fourth_layer_size, third_layer_size)
        )
        self.decoder = nn.Sequential(
            nn.ReLU(),
            nn.Linear(third_layer_size, second_fourth_layer_size),
            nn.ReLU(),
            nn.Linear(second_fourth_layer_size, input_dim)
        )

    def forward(self, x):
        encoded = self.encoder(x)
        decoded = self.decoder(encoded)
        return encoded, decoded

CRC对比损失模块

In [10]:
class CRCLoss(nn.Module):
    def __init__(self, device, temperature=0.1, scale_by_temperature=True):
        super(CRCLoss, self).__init__()
        self.device = device
        self.temperature = temperature
        self.scale_by_temperature = scale_by_temperature

    def forward(self, features, labels=None, recon_features=None):
        # 计算特征的归一化表示
        features = F.normalize(features, p=2, dim=1)
        batch_size = features.shape[0]
        num_norm = len(labels==0)
        ## contiguous方法确保张量在内存是连续存储的，因为变换视图操作需要确保张量在内存中是连续存储的
        labels = labels.contiguous().view(-1, 1) # batch_size * 1
        
        if labels.shape[0] != batch_size:
            raise ValueError('Batch size of features and labels do not match.')
        
        # 计算余弦相似度，cosine_sim[i][j]表示features[i]和features[j]的余弦相似度，矩阵大小为batch_size * batch_size
        cosine_sim = torch.nn.functional.cosine_similarity(features.unsqueeze(1), features.unsqueeze(0), dim = -1)/self.temperature
        # 将余弦相似度的对角线元素设置为0
        mask_diag = torch.eye(batch_size, dtype=torch.bool)
        cosine_sim[mask_diag] = 0
        # 正样本对
        sim_pos = cosine_sim[(labels==0).squeeze()]
        # 正样本对间的余弦相似度
        sim_pos_pos = sim_pos[:,(labels==0).squeeze()] 
        # 正样本与负样本间的余弦相似度
        sim_pos_neg = sim_pos[:,(labels==1).squeeze()]
        # 计算正样本与负样本间的分数和
        # sum_pos_neg = torch.sum(torch.exp(sim_pos_neg), axis=1, keepdims=True)
        sum_pos_neg = torch.sum(torch.exp(sim_pos_neg))
        # 计算分母
        denominator = torch.exp(sim_pos_pos) + sum_pos_neg       
        loss = -(sim_pos_pos-torch.log(denominator))  
        
        
        if self.scale_by_temperature:
            loss = loss * self.temperature
        # 计算损失
        loss= loss.mean()
        return loss

## ADM自主决策

In [11]:
def gaussian_pdf(x, mean, std):
    pdf = 1/(std*np.sqrt(2*np.pi))*np.exp(-(x-mean)**2/(2*std**2))
    return pdf
def log_likelihood(params,data):
    data = data.cpu().detach().numpy()
    mean_pos_enc, std_pos_enc, mean_neg_enc, std_neg_enc = params
    pdf1 = gaussian_pdf(data, mean_pos_enc, std_pos_enc)
    pdf2 = gaussian_pdf(data, mean_neg_enc, std_neg_enc)
    mixture_pdf = 0.5*pdf1 + 0.5*pdf2
    log_likelihood = -np.sum(np.log(mixture_pdf))
    
    return log_likelihood
def predict(norm_enc, norm_dec, x_train, y_train, x_test, model):
    x_train_pos=x_train[(y_train==0).squeeze()]
    x_train_neg=x_train[(y_train==1).squeeze()]
    
    # 训练集全体样本编码和解码特征
    train_enc = F.normalize(model(x_train)[0], p=2, dim=1)
    train_dec = F.normalize(model(x_train)[1], p=2, dim=1)
    # 训练集正样本编码和解码特征
    train_enc_pos = F.normalize(model(x_train_pos)[0], p=2, dim=1)
    train_dec_pos = F.normalize(model(x_train_pos)[1], p=2, dim=1)
    # 训练集负样本编码和解码特征
    train_enc_neg = F.normalize(model(x_train_neg)[0], p=2, dim=1)
    train_dec_neg = F.normalize(model(x_train_neg)[1], p=2, dim=1)
    
    # 分别计算训练集正样本特征和平均正样本特征的余弦相似度 负样本同理 全体样本同理 
    sim_pos_norm_enc = F.cosine_similarity(train_enc_pos, norm_enc.unsqueeze(0), dim=1) 
    sim_pos_norm_dec = F.cosine_similarity(train_dec_pos, norm_dec.unsqueeze(0), dim=1)
    sim_neg_norm_enc = F.cosine_similarity(train_enc_neg, norm_enc.unsqueeze(0), dim=1)
    sim_neg_norm_dec = F.cosine_similarity(train_dec_neg, norm_dec.unsqueeze(0), dim=1)
    sim_all_norm_enc = F.cosine_similarity(train_enc, norm_enc.unsqueeze(0), dim=1)
    sim_all_norm_dec = F.cosine_similarity(train_dec, norm_dec.unsqueeze(0), dim=1)
    # 进行排序
    sort_sim_pos_norm_enc, indices = torch.sort(sim_pos_norm_enc)
    sort_sim_pos_norm_dec, indices = torch.sort(sim_pos_norm_dec)
    sort_sim_neg_norm_enc, indices = torch.sort(sim_neg_norm_enc)
    sort_sim_neg_norm_dec, indices = torch.sort(sim_neg_norm_dec)

    
    ## 初始化参数
    mean_pos_enc = torch.mean(sort_sim_pos_norm_enc).cpu().detach().numpy()
    std_pos_enc = torch.std(sort_sim_pos_norm_enc).cpu().detach().numpy()
    mean_pos_dec = torch.mean(sort_sim_pos_norm_dec).cpu().detach().numpy()
    std_pos_dec = torch.std(sort_sim_pos_norm_dec).cpu().detach().numpy()
    mean_neg_enc = torch.mean(sort_sim_neg_norm_enc).cpu().detach().numpy()
    std_neg_enc = torch.std(sort_sim_neg_norm_enc).cpu().detach().numpy()
    mean_neg_dec = torch.mean(sort_sim_neg_norm_dec).cpu().detach().numpy()
    std_neg_dec = torch.std(sort_sim_neg_norm_dec).cpu().detach().numpy()
    initial_params_enc = [mean_pos_enc, std_pos_enc, mean_neg_enc, std_neg_enc]
    initial_params_dec = [mean_pos_dec, std_pos_dec, mean_neg_dec, std_neg_dec]
    # 拟合高斯分布
    fit_enc = opt.minimize(log_likelihood, initial_params_enc, args=(sim_all_norm_enc,), method='Nelder-Mead')
    fit_dec = opt.minimize(log_likelihood, initial_params_dec, args=(sim_all_norm_dec,), method='Nelder-Mead') 
    # print("enc init")
    # print(initial_params_enc)
    # print("dec init")
    # print(initial_params_dec)
    mean1_enc, std1_enc, mean2_enc, std2_enc = fit_enc.x
    mean1_dec, std1_dec, mean2_dec, std2_dec = fit_dec.x
    # print("encoder:")
    # print(mean1_enc, std1_enc, mean2_enc, std2_enc)
    # print("decoder:")
    # print(mean1_dec, std1_dec, mean2_dec, std2_dec)
    # 选择均值小的作为正常样本的均值
    if mean1_enc < mean2_enc:
        mean_pos_enc, mean_neg_enc = mean1_enc, mean2_enc
        std_pos_enc, std_neg_enc = std1_enc, std2_enc
        gaussian_pos_enc = dist.Normal(mean_pos_enc, std_pos_enc)
        gaussian_neg_enc = dist.Normal(mean_neg_enc, std_neg_enc)
    else:
        mean_pos_enc, mean_neg_enc = mean2_enc, mean1_enc
        std_pos_enc, std_neg_enc = std2_enc, std1_enc
        gaussian_pos_enc = dist.Normal(mean_pos_enc, std_pos_enc)
        gaussian_neg_enc = dist.Normal(mean_neg_enc, std_neg_enc)
    if mean1_dec < mean2_dec:
        mean_pos_dec, mean_neg_dec = mean1_dec, mean2_dec
        std_pos_dec, std_neg_dec = std1_dec, std2_dec
        gaussian_pos_dec = dist.Normal(mean_pos_dec, std_pos_dec)
        gaussian_neg_dec = dist.Normal(mean_neg_dec, std_neg_dec)
    else:
        mean_pos_dec, mean_neg_dec = mean2_dec, mean1_dec
        std_pos_dec, std_neg_dec = std2_dec, std1_dec
        gaussian_pos_dec = dist.Normal(mean_pos_dec, std_pos_dec)
        gaussian_neg_dec = dist.Normal(mean_neg_dec, std_neg_dec)
    # gaussian_pos_enc = dist.Normal(mean1_enc, std1_enc)
    # gaussian_neg_enc = dist.Normal(mean2_enc, std2_enc)
    # gaussian_pos_dec = dist.Normal(mean1_dec, std1_dec)
    # gaussian_neg_dec = dist.Normal(mean2_dec, std2_dec)
    
    # 计算测试数据与正常样本的余弦相似度
    test_enc = F.cosine_similarity(F.normalize(model(x_test)[0], p=2, dim=1), norm_enc.unsqueeze(0), dim=1)
    test_dec = F.cosine_similarity(F.normalize(model(x_test)[1], p=2, dim=1), norm_dec.unsqueeze(0), dim=1)
    # 使用解码器和编码器分别预测
    y_pred_enc = torch.where(gaussian_pos_enc.log_prob(test_enc) > gaussian_neg_enc.log_prob(test_enc), 1, 0)
    y_pred_dec = torch.where(gaussian_pos_dec.log_prob(test_dec) > gaussian_neg_dec.log_prob(test_dec), 1, 0)
    # 投票预测
    diff_enc = torch.abs(gaussian_pos_enc.log_prob(test_enc)-gaussian_neg_enc.log_prob(test_enc))
    diff_dec = torch.abs(gaussian_pos_dec.log_prob(test_dec)-gaussian_neg_dec.log_prob(test_dec))
    y_pred = torch.where(diff_enc > diff_dec, y_pred_enc, y_pred_dec)
    return y_pred,y_pred_enc,y_pred_dec

初始训练模型  
用一小部分数据来训练模型  
训练的数据来自于 online_x_train

In [12]:
# 在线学习只使用20%的数据进行训练, 80%的数据用于在线学习和更新
online_x_train, online_x_test, online_y_train, online_y_test = train_test_split(x_train, y_train, test_size=0.8, random_state=42)
# 创建张量数据集
train_dataset = TensorDataset(online_x_train, online_y_train)
# 创建数据加载器
train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True)
# 创建一个损失函数
criterion = CRCLoss(device, tem)
# 输入维度
input_dim = x_train.shape[1]
# 创建自编码器模型
model = AE(input_dim).to(device)
# 创建优化器
optimizer = torch.optim.SGD(model.parameters(), lr=0.001)
# optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
# 调整模型进入训练模式
model.train()

AE(
  (encoder): Sequential(
    (0): Linear(in_features=196, out_features=128, bias=True)
    (1): ReLU()
    (2): Linear(in_features=128, out_features=64, bias=True)
  )
  (decoder): Sequential(
    (0): ReLU()
    (1): Linear(in_features=64, out_features=128, bias=True)
    (2): ReLU()
    (3): Linear(in_features=128, out_features=196, bias=True)
  )
)

## 训练模型

In [13]:
########################## 初始化训练模型 ##########################
for epoch in tqdm(range(epochs)):
    
    #  j是批次索引，data是一个元组，包含输入和标签
    for _ ,data in enumerate(train_loader,0):
        # 获取输入和标签
        inputs, labels = data # inputs.shape = (128, 206), labels.shape = (128,)
        # 将标签移动到设备上
        labels = labels.to(device)
        # 优化器梯度清零
        optimizer.zero_grad()
        # 前向传播
        enc_features, dec_features = model(inputs.to(device))     
        # 计算损失
        loss=criterion(enc_features, labels)+criterion(dec_features, labels)
        # 反向传播
        loss.backward()
        # 更新权重
        optimizer.step()
   


  0%|          | 0/800 [00:00<?, ?it/s]

In [14]:
#  将所有数据移动到设备上
# x_train = x_train.to(device)
# x_test = x_test.to(device)
online_x_train, online_y_train  = online_x_train.to(device), online_y_train.to(device)
# 克隆当前轮次的训练数据和测试数据
x_train_this_epoch, x_test_left_epoch, y_train_this_epoch, y_test_left_epoch = online_x_train.clone(), online_x_test.clone().to(device), online_y_train.clone(), online_y_test.clone()     

在线学习阶段

In [15]:
# 在线训练的过程，需要不断的测试新输入的数据，每次来一个新的数据，就需要测试一次，预测出标签，同时做一个累积，每当累积到一定数量的数据时，就需要微调模型
while len(x_test_left_epoch) > 0:
    # 如果剩余的数据少于采样间隔，则将所有剩余的数据作为这一轮的测试数据
    if len(x_test_left_epoch) < sample_interval:
        x_test_this_epoch = x_test_left_epoch.clone()
        x_test_left_epoch.resize_(0)
        y_test_this_epoch = y_test_left_epoch.clone()
        y_test_left_epoch.resize_(0)
        
    # 否则，从剩余的数据中采样一部分作为这一轮的测试数据
    else:
        x_test_this_epoch = x_test_left_epoch[:sample_interval].clone()
        x_test_left_epoch = x_test_left_epoch[sample_interval:]
        y_test_this_epoch = y_test_left_epoch[:sample_interval].clone()
        y_test_left_epoch = y_test_left_epoch[sample_interval:]
    
    # 据训练集里的正常样本的特征得出一个平均正常特征
    ## online_y_train == 0是正常样本的标签
    ## online_x_train[(online_y_train == 0)].shape   torch.Size([11190, 206]) 之所以要加squeeze()，是为了防止避免下面的情况[1,11190,206]
    ## 论文里有提到，即采用编码器的输出，也采用解码器的输出
    normal_enc = torch.mean(F.normalize(model(online_x_train[(online_y_train == 0).squeeze()])[0], p=2, dim=1), dim=0)
    normal_dec = torch.mean(F.normalize(model(online_x_train[(online_y_train == 0).squeeze()])[1], p=2, dim=1), dim=0)
    # 预测标签
    predict_label,_,_ = predict(normal_enc, normal_dec, x_train_this_epoch, y_train_this_epoch, x_test_this_epoch, model)
    # 评估准确性
    evaluate(y_test_this_epoch, predict_label)
    
    # 随机翻转
    num_flips = int(flip_percent * len(predict_label))
    shuffle_index = np.random.choice(len(predict_label), num_flips, replace=False)
    # 翻转标签
    flip_label = predict_label.clone()
    flip_label[shuffle_index] = 1 - flip_label[shuffle_index]
    flip_label = flip_label.to(device)
    # 更新数据集
    x_train_this_epoch = torch.cat((x_train_this_epoch, x_test_this_epoch), 0)
    y_train_this_epoch = torch.cat((y_train_this_epoch, flip_label), 0)
    
    train_ds=TensorDataset(x_train_this_epoch, y_train_this_epoch)
    train_dl=DataLoader(train_ds, batch_size=bs, shuffle=True)
    
    for epoch in tqdm(range(epoch_online)):
        for _ ,data in enumerate(train_dl,0):
            inputs, labels = data
            labels = labels.to(device)
            optimizer.zero_grad()
            enc_features, dec_features = model(inputs.to(device))
            loss = criterion(enc_features, labels) + criterion(dec_features, labels)
            loss.backward()
            optimizer.step()
        

Confusion matrix
[[ 869    4]
 [ 441 1470]]
Accuracy  0.8401580459770115
Precision  0.9972862957937585
Recall  0.7692307692307693
F1 score  0.8685376661742984


  0%|          | 0/1 [00:00<?, ?it/s]

Confusion matrix
[[ 900    5]
 [ 429 1450]]
Accuracy  0.8441091954022989
Precision  0.9965635738831615
Recall  0.7716870675891432
F1 score  0.8698260347930414


  0%|          | 0/1 [00:00<?, ?it/s]

Confusion matrix
[[ 875    6]
 [ 421 1482]]
Accuracy  0.8466235632183908
Precision  0.9959677419354839
Recall  0.7787703625853915
F1 score  0.8740784429371866


  0%|          | 0/1 [00:00<?, ?it/s]

Confusion matrix
[[ 881   10]
 [ 345 1548]]
Accuracy  0.8724856321839081
Precision  0.993581514762516
Recall  0.8177496038034865
F1 score  0.8971312662996233


  0%|          | 0/1 [00:00<?, ?it/s]

Confusion matrix
[[ 888   14]
 [ 277 1605]]
Accuracy  0.8954741379310345
Precision  0.9913526868437307
Recall  0.8528161530286928
F1 score  0.9168808911739503


  0%|          | 0/1 [00:00<?, ?it/s]

Confusion matrix
[[ 900    7]
 [ 312 1565]]
Accuracy  0.8854166666666666
Precision  0.9955470737913485
Recall  0.8337773042088439
F1 score  0.9075094230211656


  0%|          | 0/1 [00:00<?, ?it/s]

Confusion matrix
[[ 854   19]
 [ 294 1617]]
Accuracy  0.8875718390804598
Precision  0.9883863080684596
Recall  0.8461538461538461
F1 score  0.9117564138708768


  0%|          | 0/1 [00:00<?, ?it/s]

Confusion matrix
[[ 869   10]
 [ 335 1570]]
Accuracy  0.8760775862068966
Precision  0.9936708860759493
Recall  0.8241469816272966
F1 score  0.9010043041606887


  0%|          | 0/1 [00:00<?, ?it/s]

Confusion matrix
[[ 877   23]
 [ 255 1629]]
Accuracy  0.9001436781609196
Precision  0.9860774818401937
Recall  0.8646496815286624
F1 score  0.9213800904977375


  0%|          | 0/1 [00:00<?, ?it/s]

Confusion matrix
[[ 853    8]
 [ 303 1620]]
Accuracy  0.8882902298850575
Precision  0.995085995085995
Recall  0.8424336973478939
F1 score  0.9124190368910167


  0%|          | 0/1 [00:00<?, ?it/s]

Confusion matrix
[[ 869   15]
 [ 312 1588]]
Accuracy  0.8825431034482759
Precision  0.990642545227698
Recall  0.8357894736842105
F1 score  0.9066514416214673


  0%|          | 0/1 [00:00<?, ?it/s]

Confusion matrix
[[ 893   12]
 [ 305 1574]]
Accuracy  0.8861350574712644
Precision  0.9924337957124842
Recall  0.8376796168174561
F1 score  0.9085137085137085


  0%|          | 0/1 [00:00<?, ?it/s]

Confusion matrix
[[ 859   10]
 [ 298 1617]]
Accuracy  0.889367816091954
Precision  0.9938537185003073
Recall  0.8443864229765013
F1 score  0.9130434782608695


  0%|          | 0/1 [00:00<?, ?it/s]

Confusion matrix
[[ 898    9]
 [ 323 1554]]
Accuracy  0.8807471264367817
Precision  0.9942418426103646
Recall  0.8279168886521044
F1 score  0.9034883720930232


  0%|          | 0/1 [00:00<?, ?it/s]

Confusion matrix
[[ 848    7]
 [ 321 1608]]
Accuracy  0.882183908045977
Precision  0.9956656346749226
Recall  0.833592534992224
F1 score  0.90744920993228


  0%|          | 0/1 [00:00<?, ?it/s]

Confusion matrix
[[ 909   16]
 [ 284 1575]]
Accuracy  0.8922413793103449
Precision  0.9899434318038969
Recall  0.8472296933835395
F1 score  0.9130434782608695


  0%|          | 0/1 [00:00<?, ?it/s]

Confusion matrix
[[ 872    5]
 [ 319 1588]]
Accuracy  0.8836206896551724
Precision  0.9968612680477087
Recall  0.8327215521761929
F1 score  0.9074285714285715


  0%|          | 0/1 [00:00<?, ?it/s]

Confusion matrix
[[ 889    7]
 [ 324 1564]]
Accuracy  0.8811063218390804
Precision  0.9955442393380013
Recall  0.8283898305084746
F1 score  0.9043076033535704


  0%|          | 0/1 [00:00<?, ?it/s]

Confusion matrix
[[ 865   10]
 [ 350 1559]]
Accuracy  0.8706896551724138
Precision  0.9936265137029955
Recall  0.8166579360921948
F1 score  0.8964922369177688


  0%|          | 0/1 [00:00<?, ?it/s]

Confusion matrix
[[ 855    7]
 [ 314 1608]]
Accuracy  0.884698275862069
Precision  0.9956656346749226
Recall  0.8366285119667014
F1 score  0.909245122985581


  0%|          | 0/1 [00:00<?, ?it/s]

Confusion matrix
[[ 940    6]
 [ 317 1521]]
Accuracy  0.8839798850574713
Precision  0.9960707269155207
Recall  0.8275299238302503
F1 score  0.9040118870728083


  0%|          | 0/1 [00:00<?, ?it/s]

Confusion matrix
[[ 849    8]
 [ 342 1585]]
Accuracy  0.8742816091954023
Precision  0.9949780288763339
Recall  0.8225220550077841
F1 score  0.9005681818181818


  0%|          | 0/1 [00:00<?, ?it/s]

Confusion matrix
[[ 854   14]
 [ 337 1579]]
Accuracy  0.8739224137931034
Precision  0.9912115505335845
Recall  0.8241127348643006
F1 score  0.8999715018523796


  0%|          | 0/1 [00:00<?, ?it/s]

Confusion matrix
[[ 907   10]
 [ 317 1550]]
Accuracy  0.8825431034482759
Precision  0.9935897435897436
Recall  0.8302088912694162
F1 score  0.904581266413773


  0%|          | 0/1 [00:00<?, ?it/s]

Confusion matrix
[[ 894   11]
 [ 353 1526]]
Accuracy  0.8692528735632183
Precision  0.9928432010409889
Recall  0.8121341138903673
F1 score  0.8934426229508197


  0%|          | 0/1 [00:00<?, ?it/s]

Confusion matrix
[[ 862    2]
 [ 332 1588]]
Accuracy  0.8800287356321839
Precision  0.9987421383647799
Recall  0.8270833333333333
F1 score  0.9048433048433049


  0%|          | 0/1 [00:00<?, ?it/s]

Confusion matrix
[[ 844   12]
 [ 325 1603]]
Accuracy  0.8789511494252874
Precision  0.9925696594427245
Recall  0.8314315352697096
F1 score  0.9048828676263054


  0%|          | 0/1 [00:00<?, ?it/s]

Confusion matrix
[[ 899   12]
 [ 338 1535]]
Accuracy  0.8742816091954023
Precision  0.9922430510665805
Recall  0.819540843566471
F1 score  0.8976608187134503


  0%|          | 0/1 [00:00<?, ?it/s]

Confusion matrix
[[ 866   10]
 [ 346 1562]]
Accuracy  0.8721264367816092
Precision  0.9936386768447837
Recall  0.8186582809224319
F1 score  0.8977011494252873


  0%|          | 0/1 [00:00<?, ?it/s]

Confusion matrix
[[ 863    6]
 [ 362 1553]]
Accuracy  0.867816091954023
Precision  0.9961513790891597
Recall  0.8109660574412533
F1 score  0.8940702360391479


  0%|          | 0/1 [00:00<?, ?it/s]

Confusion matrix
[[ 901   15]
 [ 334 1534]]
Accuracy  0.8746408045977011
Precision  0.9903163331181407
Recall  0.8211991434689507
F1 score  0.8978636230611647


  0%|          | 0/1 [00:00<?, ?it/s]

Confusion matrix
[[ 824    7]
 [ 351 1602]]
Accuracy  0.8714080459770115
Precision  0.9956494717215661
Recall  0.8202764976958525
F1 score  0.8994946659180236


  0%|          | 0/1 [00:00<?, ?it/s]

Confusion matrix
[[ 874   15]
 [ 344 1551]]
Accuracy  0.8710488505747126
Precision  0.9904214559386973
Recall  0.8184696569920844
F1 score  0.8962727535394395


  0%|          | 0/1 [00:00<?, ?it/s]

Confusion matrix
[[ 945    9]
 [ 340 1490]]
Accuracy  0.8746408045977011
Precision  0.9939959973315544
Recall  0.8142076502732241
F1 score  0.8951637128266747


  0%|          | 0/1 [00:00<?, ?it/s]

Confusion matrix
[[ 875   12]
 [ 324 1573]]
Accuracy  0.8793103448275862
Precision  0.9924290220820189
Recall  0.8292040063257775
F1 score  0.903503733486502


  0%|          | 0/1 [00:00<?, ?it/s]

Confusion matrix
[[ 893   15]
 [ 296 1580]]
Accuracy  0.8882902298850575
Precision  0.9905956112852664
Recall  0.8422174840085288
F1 score  0.9104004609622587


  0%|          | 0/1 [00:00<?, ?it/s]

Confusion matrix
[[ 825    9]
 [ 312 1638]]
Accuracy  0.884698275862069
Precision  0.994535519125683
Recall  0.84
F1 score  0.9107589658048374


  0%|          | 0/1 [00:00<?, ?it/s]

Confusion matrix
[[ 860    9]
 [ 368 1547]]
Accuracy  0.8645833333333334
Precision  0.9942159383033419
Recall  0.8078328981723237
F1 score  0.8913857677902621


  0%|          | 0/1 [00:00<?, ?it/s]

Confusion matrix
[[ 883    9]
 [ 328 1564]]
Accuracy  0.8789511494252874
Precision  0.9942784488239034
Recall  0.8266384778012685
F1 score  0.9027417027417027


  0%|          | 0/1 [00:00<?, ?it/s]

Confusion matrix
[[ 904    7]
 [ 373 1500]]
Accuracy  0.8635057471264368
Precision  0.9953550099535501
Recall  0.800854244527496
F1 score  0.8875739644970414


  0%|          | 0/1 [00:00<?, ?it/s]

Confusion matrix
[[ 867    4]
 [ 361 1552]]
Accuracy  0.8688936781609196
Precision  0.9974293059125964
Recall  0.8112911657083115
F1 score  0.8947823580282502


  0%|          | 0/1 [00:00<?, ?it/s]

Confusion matrix
[[ 868    8]
 [ 350 1558]]
Accuracy  0.8714080459770115
Precision  0.9948914431673053
Recall  0.8165618448637316
F1 score  0.8969487622337363


  0%|          | 0/1 [00:00<?, ?it/s]

Confusion matrix
[[ 912    6]
 [ 333 1533]]
Accuracy  0.8782327586206896
Precision  0.9961013645224172
Recall  0.8215434083601286
F1 score  0.9004405286343612


  0%|          | 0/1 [00:00<?, ?it/s]

Confusion matrix
[[ 884   12]
 [ 349 1539]]
Accuracy  0.8703304597701149
Precision  0.9922630560928434
Recall  0.8151483050847458
F1 score  0.8950276243093923


  0%|          | 0/1 [00:00<?, ?it/s]

Confusion matrix
[[ 913   14]
 [ 354 1503]]
Accuracy  0.867816091954023
Precision  0.990771259063942
Recall  0.8093699515347335
F1 score  0.8909306461173682


  0%|          | 0/1 [00:00<?, ?it/s]

Confusion matrix
[[ 863   13]
 [ 338 1570]]
Accuracy  0.8739224137931034
Precision  0.9917877447883765
Recall  0.8228511530398323
F1 score  0.8994557433400172


  0%|          | 0/1 [00:00<?, ?it/s]

Confusion matrix
[[ 897   10]
 [ 363 1514]]
Accuracy  0.8660201149425287
Precision  0.9934383202099738
Recall  0.8066062866275973
F1 score  0.8903263745957072


  0%|          | 0/1 [00:00<?, ?it/s]

Confusion matrix
[[ 895    7]
 [ 318 1564]]
Accuracy  0.8832614942528736
Precision  0.9955442393380013
Recall  0.8310308182784272
F1 score  0.9058789458441935


  0%|          | 0/1 [00:00<?, ?it/s]

Confusion matrix
[[ 878   14]
 [ 321 1571]]
Accuracy  0.8796695402298851
Precision  0.9911671924290221
Recall  0.830338266384778
F1 score  0.9036525740580961


  0%|          | 0/1 [00:00<?, ?it/s]

Confusion matrix
[[ 897    5]
 [ 274 1608]]
Accuracy  0.8997844827586207
Precision  0.9969001859888407
Recall  0.8544102019128587
F1 score  0.9201716738197425


  0%|          | 0/1 [00:00<?, ?it/s]

Confusion matrix
[[354   2]
 [116 601]]
Accuracy  0.8900279589934762
Precision  0.9966832504145937
Recall  0.8382147838214784
F1 score  0.9106060606060606


  0%|          | 0/1 [00:00<?, ?it/s]

# 评估模型（在测试集上训练）

In [16]:
normal_enc = torch.mean(F.normalize(model(online_x_train[(online_y_train == 0).squeeze()])[0], p=2, dim=1), dim=0)
normal_dec = torch.mean(F.normalize(model(online_x_train[(online_y_train == 0).squeeze()])[1], p=2, dim=1), dim=0)
x_test = x_test.to(device)
y_pred,_,_ = predict(normal_enc, normal_dec, x_train_this_epoch,y_train_this_epoch , x_test , model)
evaluate(y_test, y_pred)

Confusion matrix
[[33766  3234]
 [ 5061 40271]]
Accuracy  0.8992493805567702
Precision  0.9256637168141593
Recall  0.8883570105003088
F1 score  0.9066267433614372
