In [1]:
from tqdm import tqdm
from scipy.special import comb  # 排列组合中的组合公式
from sklearn.metrics import hamming_loss
from sklearn.decomposition import PCA
from sklearn.metrics import f1_score
import numpy as np
import torch.utils.data as Data
import torch
import time

In [104]:
def f_k(dataSet, Labels, d, q, device):
    """
    :param dataSet: 某一个样本的特征集
    :param Labels: 某一个样本的标签集
    :param d: 样本的维度，即一个样本含有的特征数
    :param q: 标签的维度，即标签集中标签的个数
    :return: 返回的是fk(x,y)
    """
    F_k = []
    for l in range(d):
        for j in range(q):
            if Labels[j] == 1:
                try:
                    F_k.append(dataSet[l])
                except:
                    # print(dataSet)
                    # print(l, dataSet)
                    raise IndexError

            else:
                F_k.append(
                    torch.tensor(
                        0, dtype=torch.float, requires_grad=True, device=device
                    )
                )

    for j1 in range(q - 1):
        for j2 in range(j1 + 1, q):
            y_j1 = Labels[j1]
            y_j2 = Labels[j2]
            if y_j1 == 1 and y_j2 == 1:
                F_k.append(
                    torch.tensor(
                        1, dtype=torch.float, requires_grad=True, device=device
                    )
                )
            else:
                F_k.append(
                    torch.tensor(
                        0, dtype=torch.float, requires_grad=True, device=device
                    )
                )
            if y_j1 == 1 and y_j2 == 0:
                F_k.append(
                    torch.tensor(
                        1, dtype=torch.float, requires_grad=True, device=device
                    )
                )
            else:
                F_k.append(
                    torch.tensor(
                        0, dtype=torch.float, requires_grad=True, device=device
                    )
                )
            if y_j1 == 0 and y_j2 == 1:
                F_k.append(
                    torch.tensor(
                        1, dtype=torch.float, requires_grad=True, device=device
                    )
                )
            else:
                F_k.append(
                    torch.tensor(
                        0, dtype=torch.float, requires_grad=True, device=device
                    )
                )
            if y_j1 == 0 and y_j2 == 0:
                F_k.append(
                    torch.tensor(
                        1, dtype=torch.float, requires_grad=True, device=device
                    )
                )
            else:
                F_k.append(
                    torch.tensor(
                        0, dtype=torch.float, requires_grad=True, device=device
                    )
                )
    # print(len(F_k))
    return torch.tensor(F_k, requires_grad=True, device=device)

In [105]:
def basic_rand_labels(len):
    """
    变成辅助函数
    #关于这个函数的for循环的嵌套次数，Y标签集中，有几个标签就嵌套几层。（y1,y2,...,yq）
    :return: 返回的是q维的标签集的所有组合情况
    """
    """
    randLabels=[]
    for i in range(2):
        randLabels.append([i])
    return randLabels
    """
    randLabels = []
    for i in range(2 ** len):
        randLabel = np.zeros(shape=len)
        for j in range(len):
            randLabel[len - j - 1] = i % 2
            i = i // 2
            if i == 0:
                break
        print(randLabel)
        randLabels.append(randLabel)
    np.save("./basic_rand_Labels.npy", np.array(randLabels))

In [106]:
def supported_rand_labels(train_label):
    """
    这是个辅助函数，用来生成support_rand_Labels
    """
    """
    randLabels=[]
    for i in range(2):
        randLabels.append([i])
    return randLabels
    """
    # for _, y in train_iter:
    labels = train_label.tolist()
    label_set = []
    for label in labels:
        if label in label_set:
            continue
        else:
            label_set.append(label)
    randLables = np.array(label_set)
    print(label_set)
    np.save("./supported_rand_labels.npy", randLables)
    print("finish")

**用于产生randLabels，输入的mode为basic或者是supported**

In [107]:
def generate_rand_Labels(mode):
    if mode == "supported":
        randLabels = np.load("./supported_rand_labels.npy")
        randLabels = randLabels.tolist()
        return randLabels
    elif mode == "basic":
        randLabels = np.load("./basic_rand_Labels.npy")
        randLabels = randLabels.tolist()
        return randLabels

In [108]:
def Z(dataSet, d, q, Lambda, device, randLabels):  # 对于某一个样本的Z
    """
    :param dataSet: 某一个样本的特征集
    :param d: 样本的维度，即特征的个数
    :param q: 标签集的个数
    :param Lambda: Lambda是一个1*K维向量
    :return: 归一化范数，以及所有标签集组合的对应f_k
    """
    Z = torch.tensor(0.0, requires_grad=True, device=device)
    for i in range(len(randLabels)):
        fk = f_k(dataSet, randLabels[i], d, q, device)
        temp_sum = torch.exp((Lambda * fk).sum())
        Z = Z + temp_sum
    return Z

In [109]:
# 求目标函数l(Lambda|D)
def obj_func(DataSets, Labels, thegma, Lambda, randLabels, device):
    """
    :param q:标签集的维度
    :param DataSets:所有训练样本的特征集
    :param Labels:所有训练样本的标签集
    :param thegma:自己给定的参数值，2**-6,2**-5,2**-4,2**-3,2**-2,2**-1,2**1,2**2,2**3,2**4,2**5,2**6逐个取值，参数寻优
    :return:目标函数，以及待定参数Lambda
    """
    samples = len(DataSets)
    d = len(DataSets[0])
    q = len(Labels[0])
    temp_sum = torch.tensor(0.0, requires_grad=True, device=device)
    for i in range(samples):
        fk = f_k(DataSets[i], Labels[i], d, q, device)
        z = Z(DataSets[i], d, q, Lambda, device, randLabels)
        temp_sum = temp_sum + (Lambda * fk).sum() - torch.log2(z)
        temp_div = (
            (Lambda * Lambda) / (2 * thegma ** 2)
        ).sum()  # temp_div=sum(Lambda**2/(2*thegma**2))

    l = -(temp_sum - temp_div)
    return l  # 求解l的最大值，可以转化为求-l的最小值问题

In [110]:
def Train(
    objfunc,
    train_iter,
    test_iter,
    test_target,
    num_epochs,
    optimizer,
    thegma,
    Lambda,
    device,
    randLabels,
):

    for epoch in range(num_epochs):
        train_l_sum, train_acc_sum, n, batch_count, start = 0.0, 0.0, 0, 0, time.time()
        Lambda.to(device)
        for X, y in tqdm(train_iter, total=len(train_iter), ascii=True, desc="train"):
            
            # def closure():
            #     loss = objfunc(X, y, thegma, Lambda, randLabels, device)
            #     optimizer.zero_grad()
            #     loss.backward()
            #     return loss
            
            X = X.to(device)
            y = y.to(device)
            loss = objfunc(X, y, thegma, Lambda, randLabels, device)
            train_l_sum += loss
            optimizer.zero_grad()
            loss.backward()
            # optimizer.step(closure)
            optimizer.step()
            optimizer.zero_grad()
            n += y.shape[0]
            batch_count += 1
        hamming, f1_macro, f1_micro, acc = test(
            test_iter, test_target, Lambda, d, q, randLabels
        )
        print(
            "epoch %d, time %.1f sec, train_loss %.2f, hamming %.2f, f1_macro %.2f, f1_micro %.2f, subset acc %.4f acc"
            % (epoch + 1, time.time() - start, train_l_sum/batch_count, hamming, f1_macro, f1_micro, acc)
        )
        # if epoch == 30:
        #     optimizer = torch.optim.Adam([Lambda], lr=0.002,betas=(0.9, 0.999), eps=1e-08)
        # elif epoch == 50:
        #     optimizer = torch.optim.Adam([Lambda], lr=0.001,betas=(0.9, 0.999), eps=1e-08)
        # scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.9)
    #     best_acc = 0.5791
    #     temp = torch.tensor([-1.3607e-02, -5.3982e-02, -1.6628e-02,  2.5557e-02,  4.0150e-02,
    #      1.1365e-01, -6.7485e-02, -4.5096e-02,  7.0721e-02,  3.9045e-02,
    #      1.6875e-01, -1.0130e-01, -7.0357e-02,  1.7599e-01,  2.8182e-02,
    #      8.5130e-03, -6.4480e-02, -1.5023e-02, -6.6069e-02, -3.4371e-03,
    #     -1.4052e+01, -1.4002e+01, -1.4399e+01, -1.4142e+01, -1.4136e+01,
    #     -1.3876e+01, -1.4346e+01, -1.4041e+01, -1.4566e+01, -1.3794e+01,
    #     -1.4097e+01, -1.4222e+01, -1.4476e+01, -1.4082e+01, -1.4156e+01,
    #     -1.3897e+01, -1.4380e+01, -1.4143e+01, -1.4241e+01, -1.3817e+01,
    #     -1.4295e+01, -1.4292e+01, -1.4363e+01, -1.3645e+01],
    #    requires_grad=True)
    #     with torch.no_grad():
    #         if acc <= best_acc:
    #             Lambda.data.copy_(temp.data)
    #         else:
    #             best_acc = acc
    #             temp.data.copy_(Lambda.data)
    return Lambda

In [111]:
def test(Test_iter, test_target, Lambda, d, q, randLabels):
    with torch.no_grad():
        preLabels = []
        for X in tqdm(Test_iter, total=len(Test_iter), ascii=True, desc="test"):
            preLabels.append(Pred(X[0][0], Lambda, d, q, randLabels))
        preLabels = np.array(preLabels)
        hamming = hamming_loss(test_target, preLabels)  # 汉明损失，越低越好
        f1_macro = f1_score(test_target, preLabels, average="macro")  # 0.6
        f1_micro = f1_score(test_target, preLabels, average="micro")

        temp = preLabels == test_target
        acc_list = []
        for data in temp:
            acc = 1
            for x in data:
                acc *= x
            acc_list.append(acc)
        acc = sum(acc_list) / len(acc_list)
    return hamming, f1_macro, f1_micro, acc

In [112]:
def Pred(test_data, Lambda, d, q, randLabels):
    bestLabels = None
    z = Z(test_data, d, q, Lambda, device, randLabels)
    bestP = -1.0
    for i in range(len(randLabels)):
        fk = f_k(test_data, randLabels[i], d, q, device)
        temp_P = torch.exp((Lambda * fk).sum()) / z
        if temp_P > bestP:
            bestP = temp_P
            bestLabels = randLabels[i]
    return np.array(bestLabels)

In [113]:
"""
    dataSet=[0.5,0.1,0.3]
    Labels=[1,0]
    f=f_k(dataSet,Labels,3,2)
    print(f)
    """
# 训练集的处理
train_data_path = "./traindataReuters_all.npy"
train_label_path = "./trainlabelReuters_all.npy"
test_data_path = "./testdataReuters_all.npy"
test_label_path = "./testlabelReuters_all.npy"
# train_data_path = "./yeast_train_data.npy"
# train_label_path = "./yeast_train_label.npy"
# test_data_path = "./yeast_test_data.npy"
# test_label_path = "./yeast_test_label.npy"

train_data = np.load(train_data_path)
train_data = train_data[:, :]
train_label = np.load(train_label_path)
train_label = train_label[:, :]
test_data = np.load(test_data_path)
test_data = test_data[:, :]
test_label = np.load(test_label_path)
test_target = test_label[:, :]

train_target = torch.tensor(train_label, dtype=torch.float, requires_grad=True)

# 主成分降维
pca = PCA(n_components=5)  # 保留5个主成分
train_data = torch.tensor(pca.fit_transform(train_data), requires_grad=True)
# train_data = torch.tensor(train_data, requires_grad=True)
# print(train_data[0])

d = len(train_data[0])
q = len(train_target[0])
K = int(d * q + 4 * comb(q, 2))
thegma = 2 ** (1)  # 参数寻优，-6,-5,-4,-3,-2,-1,0,1,2,3,4,5,6



In [114]:
train_data.shape

torch.Size([10497, 5])

In [115]:
train_label.shape

(10497, 4)

In [116]:
# 训练数据集
device = torch.device("cpu")
batch_size = 32
dataset = Data.TensorDataset(train_data, train_target)
# 测试集的处理

test_data = Data.TensorDataset(torch.tensor((pca.fit_transform(test_data))))
# test_data = Data.TensorDataset(torch.tensor(test_data, requires_grad=True))
# 随机读取小批量
train_iter = Data.DataLoader(dataset, batch_size, shuffle=True)
test_iter = Data.DataLoader(test_data, 1, shuffle=False)
#
# supported_rand_labels(train_label)
# basic_rand_labels(q)
mode = "basic"
randLabels = generate_rand_Labels(mode)

In [117]:
test_data

<torch.utils.data.dataset.TensorDataset at 0x2615917fd08>

In [119]:
# Lambda = torch.tensor(
#     np.random.rand(K), requires_grad=True, device=device, dtype=torch.float
# )  # 初始点 
# Lambda =torch.tensor([ 1.1060e-01,  1.2891e-01, -3.2866e-02,  1.1960e-01, -5.9192e-02,
#         -6.9003e-02,  2.1254e-01, -3.6856e-02,  2.3860e-01,  1.4376e-01,
#          3.2683e-01, -4.0302e-01,  9.6339e-02,  1.1465e-01,  1.3273e-01,
#         -1.4790e-01,  6.8920e-02,  1.0594e-01,  2.8951e-02, -1.8953e-01,
#          3.3756e-02,  1.8381e-02,  7.7972e-03,  1.3185e-01,  1.0120e-01,
#          2.3428e-02, -1.6687e-01, -7.5685e-02,  2.3136e-03,  1.0234e-01,
#          2.8742e-02,  5.6123e-02,  5.3597e-02, -4.2912e-03, -2.4144e-01,
#         -4.1165e-02,  1.5462e-02, -5.8708e-02, -6.6897e-02,  1.7060e-01,
#          4.9510e-02,  1.1143e-01,  1.4395e-01,  1.1102e-01,  8.3381e-02,
#         -3.3308e-02,  1.5899e-01,  1.4644e-01, -7.0203e-02,  5.3325e-02,
#          1.1100e-01,  1.5264e-03, -2.2287e-03, -5.9585e-02,  1.1739e-01,
#         -1.2248e-01,  2.3899e-02, -3.8095e-02, -5.1161e-02, -1.2378e-01,
#          6.5217e-02,  1.0395e-01, -2.1152e-01,  2.2577e-01, -6.4882e-03,
#          7.8403e-02, -7.7026e-02,  2.0222e-01,  8.1410e-02,  7.4142e-02,
#         -1.4224e-01,  3.6684e-02,  1.0633e-01,  9.9144e-02, -1.0876e-01,
#          3.8941e-01,  1.0463e-01,  6.1642e-02,  5.3251e-02, -1.3368e-01,
#          2.0914e-02, -1.7902e-02,  4.5028e-02, -1.3880e-02,  6.8724e-02,
#          3.8901e-02, -5.4046e-02,  6.8841e-02,  1.2888e-01,  1.1804e-01,
#         -1.6146e-01,  1.5285e-01,  8.5081e-02,  1.3078e-01,  5.3087e-03,
#          2.2880e-01,  1.9101e-01,  1.9246e-01, -5.4693e-02,  1.1620e-02,
#          8.0848e-02,  5.8956e-02, -9.5426e-02, -1.8407e-01,  1.5768e-01,
#          1.6044e-01, -1.4881e-01, -9.7151e-02, -4.3380e-02, -3.3622e-03,
#          5.6034e-02, -2.0823e-01,  1.0806e-01,  1.2957e-01,  3.2643e-02,
#         -1.2202e-02,  1.1217e-01,  2.1918e-01,  5.8355e-06, -1.6637e-01,
#          1.9173e-01,  1.0840e-01, -8.7356e-02, -1.1701e-01, -1.1002e-02,
#          1.6890e-02,  1.4050e-01, -2.7408e-01,  1.1422e-02,  4.1725e-02,
#         -1.2099e-01,  5.8066e-01, -5.6434e-03,  1.2136e-03, -3.4210e-02,
#          6.8016e-01, -1.6435e-02, -4.4914e-02,  2.0760e-03,  4.6706e-01,
#         -3.0609e-02, -5.2306e-02,  2.1807e-02, -5.3149e-02,  3.6104e-02,
#          4.0167e-02,  1.7140e-01,  2.3728e-02,  4.4643e-02,  3.3928e-02,
#         -1.1360e-02,  4.8301e-02,  8.4145e-03,  9.0057e-03,  9.4930e-02,
#         -2.2399e-01, -3.7842e-03, -2.9024e-02,  2.6847e-02, -1.9178e-02,
#         -3.1477e-02,  1.6695e-02, -2.4708e-04,  8.7243e-02,  1.1352e-01,
#          7.5380e-02,  1.0744e-01,  1.1491e-01,  5.0565e-02,  2.8792e-02,
#         -5.1602e-02, -2.5803e-01,  2.6961e-01,  2.8810e-01, -8.9398e-02,
#          7.9644e-02,  2.3039e-01,  2.7814e-01, -9.8623e-02,  5.1343e-03,
#          1.5740e-01,  1.7090e-01,  1.7396e-02, -1.3507e-01,  1.5925e-01,
#          1.9609e-01,  3.8724e-02, -1.3261e-01,  1.2253e-01,  1.3421e-01,
#          7.3776e-02,  1.2411e-01,  1.1991e-01,  1.2462e-01,  3.2155e-01,
#          2.1812e-01,  1.2465e-01,  1.3106e-01,  1.4936e-01,  4.5185e-01,
#          8.9912e-02,  4.1353e-02,  3.7427e-01,  2.0763e-01, -1.0323e-02,
#          4.7101e-02,  1.5794e-01,  2.1201e-01,  2.6795e-02,  4.8086e-02,
#          7.3452e-02,  1.5899e-01,  2.7094e-01,  1.9950e-01,  6.4753e-02,
#         -8.4400e-02,  5.9152e-02,  2.3613e-03,  1.1030e-01, -6.7860e-02,
#          2.5184e-03,  5.2043e-02, -6.2118e-02, -2.0245e-01, -1.6326e-02,
#          2.6731e-02,  4.2211e-01,  6.3297e-01, -5.2351e-02, -2.8716e-02,
#          2.7612e-01,  5.2174e-01,  9.7902e-02,  8.7654e-02,  1.0300e-01,
#          7.2343e-03,  6.5939e-02,  1.3197e-01,  4.0094e-01, -4.1652e-02,
#          5.3036e-02,  3.2880e-02,  4.8731e-01,  3.0857e-01,  4.0441e-02,
#          4.8622e-02,  1.7460e-01,  5.2026e-01,  5.1831e-02, -4.6035e-02,
#          4.8584e-01,  2.2022e-01,  2.0922e-02,  4.5327e-02,  1.7822e-01,
#          4.1559e-01, -1.9413e-02,  3.0169e-04, -1.7722e-01, -7.8817e-02,
#          1.4249e-01,  1.6098e-01, -3.4642e-01,  1.6393e-01,  7.2104e-02,
#          6.4422e-02,  1.2601e-01, -2.8712e-01, -8.2173e-02, -3.0818e-02,
#          2.9887e-01, -3.3247e-02,  5.9212e-02,  4.7733e-02,  2.9471e-01,
#         -1.1061e-01, -5.1678e-02, -1.1294e-01,  1.5822e-01,  3.2727e-01,
#          9.2444e-02,  8.8324e-02, -2.7437e-02,  2.3568e-01, -1.9502e-02,
#          2.1329e-02,  2.9853e-02,  6.3013e-01, -4.4068e-02, -1.5697e-01,
#         -1.1179e-02,  4.5506e-03,  4.6263e-02,  6.3008e-02,  2.9863e-02,
#          5.4568e-02, -8.7568e-02, -8.3089e-02,  1.4185e-01, -6.3543e-02,
#          1.0549e-01,  1.5917e-01,  3.0460e-01, -3.3359e-01,  2.0656e-01,
#          1.9506e-01,  3.0794e-01, -4.1998e-01,  3.8008e-02,  4.1834e-02,
#          2.6383e-01,  5.3604e-02,  1.3181e-01,  1.2294e-01,  4.3972e-01,
#          1.1709e-01,  4.7870e-02,  2.2732e-02,  4.9303e-02, -4.7343e-02,
#         -1.9305e-02,  4.8193e-02, -7.8233e-02,  4.7806e-02,  9.7495e-05,
#         -5.0551e-02,  2.5404e-01, -5.3597e-02,  1.3268e-01,  2.2429e-01,
#          2.2779e-02,  3.1488e-01,  3.3134e-03, -2.2103e-02,  3.0937e-01,
#         -2.1294e-01,  1.5814e-02,  4.4241e-02,  1.8270e-01,  1.4547e-01,
#          1.5248e-01,  5.4174e-02,  1.9122e-01,  2.2837e-01,  1.0581e-01,
#          1.7039e-01,  3.0470e-01, -2.1793e-01, -2.6159e-02, -5.7892e-02,
#          3.4461e-01, -5.9059e-01, -1.1255e-01, -3.2743e-02,  1.0925e-01,
#         -7.4039e-02,  1.2495e-01,  1.6548e-01, -7.8420e-02,  3.3073e-01,
#          9.3158e-02,  8.4230e-02, -9.2579e-02,  4.6507e-01,  5.4323e-02,
#          1.0333e-01,  8.7364e-02,  6.1410e-01,  1.7659e-01,  2.2725e-01,
#         -2.4356e-01,  1.3661e-01,  6.1566e-02,  5.9560e-02,  6.2763e-02,
#          4.9676e-01,  1.6572e-01,  1.5559e-01, -2.5082e-01, -1.1748e-01,
#          1.7051e-01,  1.1768e-01,  1.8481e-01, -4.3656e-01,  7.5439e-02,
#          9.4334e-02,  3.3729e-02, -1.6214e-02,  4.4877e-03,  4.3966e-02,
#         -2.2959e-02,  9.1131e-02,  1.0819e-01,  1.7791e-01,  1.3177e-01,
#          1.7466e-01, -1.5815e-02, -2.5539e-02,  3.1724e-01, -2.1843e-01,
#          1.0146e-01,  1.1613e-01,  1.8322e-01, -1.3279e-01,  1.1094e-01,
#          1.2434e-01, -1.9698e-01,  1.8375e-01, -9.0627e-02, -7.5274e-02,
#          1.4163e-01,  6.4854e-01, -2.9890e-02, -1.0156e+00, -1.0947e+00,
#         -1.8915e-02, -1.0928e+00, -1.0490e+00, -1.3483e+00, -1.7844e+00,
#         -8.5044e-01, -6.9808e-01, -1.0177e+00, -9.4430e-01, -7.1589e-01,
#         -7.0497e-01, -1.0243e+00, -1.2275e+00, -1.1581e+00, -5.1944e-01,
#         -1.2679e+00, -9.4034e-01, -8.7963e-01, -1.6425e+00, -8.5794e-01,
#         -1.2049e+00], requires_grad=True)


Lambda = torch.tensor([-1.1988e-02, -5.2650e-02, -1.8131e-02,  2.4598e-02,  3.9671e-02,
         1.2995e-01, -6.7338e-02, -4.9823e-02,  7.0689e-02,  2.7791e-02,
         1.7456e-01, -1.0386e-01, -6.9977e-02,  1.6798e-01,  2.6399e-02,
         7.9110e-03, -6.1223e-02, -1.1949e-02, -6.6362e-02, -5.5660e-03,
        -1.4056e+01, -1.3988e+01, -1.4395e+01, -1.4157e+01, -1.4139e+01,
        -1.3862e+01, -1.4340e+01, -1.4056e+01, -1.4579e+01, -1.3779e+01,
        -1.4088e+01, -1.4238e+01, -1.4488e+01, -1.4080e+01, -1.4152e+01,
        -1.3899e+01, -1.4379e+01, -1.4144e+01, -1.4238e+01, -1.3818e+01,
        -1.4294e+01, -1.4291e+01, -1.4360e+01, -1.3648e+01],
       requires_grad=True)
num_epochs = 10
# optimizer = torch.optim.LBFGS([Lambda], lr=0.001)
optimizer = torch.optim.Adam([Lambda], lr=0.001,betas=(0.9, 0.999), eps=1e-08)

Lambda = Train(
    obj_func,
    train_iter,
    test_iter,
    test_target,
    num_epochs,
    optimizer,
    thegma,
    Lambda,
    device,
    randLabels,
)

train:   7%|6         | 23/329 [00:03<00:40,  7.58it/s]


KeyboardInterrupt: 

In [None]:
print(Lambda)
# [ 3.5124e-02, -8.5055e-02,  3.3453e-02, -1.7653e-02, -7.1441e-02,
#         -1.5750e-01, -7.4084e-03,  1.1770e-01,  8.3856e-02,  2.6715e-01,
#         -1.0405e-02,  4.3997e-02, -2.9593e-02,  3.3307e-01,  8.2613e-02,
#         -2.1751e-02, -2.1640e-01, -1.2033e-01, -3.2102e-02,  8.6221e-02,
#         -1.4092e+01, -1.4235e+01, -1.4270e+01, -1.3893e+01, -1.4207e+01,
#         -1.4120e+01, -1.4363e+01, -1.3801e+01, -1.4291e+01, -1.4036e+01,
#         -1.4195e+01, -1.3968e+01, -1.4245e+01, -1.4117e+01, -1.4324e+01,
#         -1.3804e+01, -1.4239e+01, -1.4123e+01, -1.4247e+01, -1.3882e+01,
#         -1.4161e+01, -1.4408e+01, -1.4324e+01, -1.3597e+01]
# 目前最好的参数

tensor([ 0.4219,  0.3179,  0.3100,  0.3262,  0.5201,  0.6068,  0.2907,  0.1044,
         0.3724,  0.2115,  0.3200,  0.4062,  0.0607,  0.1945,  0.2309,  0.3505,
         0.1622,  0.4449,  0.1262,  0.3358, -0.2075,  0.0766,  0.5853,  0.0017,
         0.6564,  0.5548, -0.1624,  0.6698,  0.5878,  0.4539,  0.5941,  0.3025,
         0.1586,  0.4095,  0.6774,  0.6503, -0.1377,  0.6092,  0.1168,  0.1416,
         0.2743,  0.5307,  0.8794,  0.3395], requires_grad=True)


In [None]:
hamming, f1_macro, f1_micro, acc = test(test_iter, test_target, Lambda, d, q, randLabels)
print(
    "hamming %.3f, f1_macro %.3f, f1_micro %.3f, subset acc %.3f"
    % (hamming, f1_macro, f1_micro, acc)
)

test: 100%|##########| 4500/4500 [00:23<00:00, 188.83it/s]

hamming 0.459, f1_macro 0.434, f1_micro 0.565, subset acc 0.030





hamming 0.209, f1_macro 0.155, f1_micro 0.208, subset acc 0.356,
测试值

In [2]:
Lambda = torch.tensor(
        [
            1.3684e-03,
            1.0080e-01,
            2.1870e-03,
            -1.9966e-02,
            9.5491e-03,
            -3.5565e-01,
            5.6930e-02,
            8.2445e-02,
            -9.0433e-03,
            7.7105e-02,
            5.3024e-02,
            -5.3038e-02,
            -7.3672e-02,
            1.1402e-01,
            -1.4662e-01,
            9.3514e-02,
            -4.7597e-02,
            -1.3051e-03,
            2.2363e-02,
            -2.6729e-02,
            -1.4027e01,
            -1.3943e01,
            -1.4388e01,
            -1.4196e01,
            -1.4172e01,
            -1.3829e01,
            -1.4311e01,
            -1.4092e01,
            -1.4597e01,
            -1.3747e01,
            -1.4072e01,
            -1.4278e01,
            -1.4439e01,
            -1.4082e01,
            -1.4181e01,
            -1.3875e01,
            -1.4323e01,
            -1.4137e01,
            -1.4269e01,
            -1.3813e01,
            -1.4222e01,
            -1.4327e01,
            -1.4379e01,
            -1.3603e01,
        ],
        requires_grad=True,
    )

In [3]:
torch.save(Lambda, "./Lambda/Lambda.pt")