In [1]:
import os
# os.environ['CUDA_VISIBLE_DEVICES'] = ','.join(map(str, [3]))
# print('using GPU %s' % ','.join(map(str, [3])))

import torch
import torch.optim as optim
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

from thop import profile, clever_format

import csv
import time
import numpy as np
import json
from datetime import datetime
from sklearn.manifold import TSNE
import matplotlib.pyplot as plt
plt.rc('font',family='Times New Roman') 

from option import opt
from loadData import data_pipe
from loadData.dataAugmentation import dataAugmentation
from models import cnns, vision_transformer, mamba, S2ENet, FusAtNet
from models.MS2CANet import pymodel
from models.CrossHL import CrossHL
from models.HCTNet import HCTNet
from models.DSHFNet import DSHF
from models.MIViT import MMA

from utils import trainer, tester

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Parse the default experiment arguments, then override them for the chosen
# dataset and backbone. Attributes not set below keep the values returned by
# opt.get_args().
args = opt.get_args()
# args.dataset_name = "PaviaU"
args.dataset_name = "Houston_2013"


# Backbone selection: the LAST uncommented assignment wins, so "vit" and "cnn"
# below are immediately overridden by "DSHFNet". Comment/uncomment to switch.
args.backbone = "vit"   
args.backbone = "cnn"   
# args.backbone = "mamba" 

# args.backbone = "MS2CANet"
# args.backbone = "S2ENet"    
# args.backbone = "FusAtNet"  
# args.backbone = "CrossHL"  
# args.backbone = "HCTNet"
# args.backbone = "MIViT"  
args.backbone = "DSHFNet"

# PCA is on by default; CrossHL and DSHFNet switch it off in their branches.
args.pca = True

# Per-backbone hyper-parameters (epochs, batch/patch size, PCA components,
# learning rate, optimizer/scheduler settings read by the optimizer cell).
if args.backbone == "FusAtNet":
    # Long schedule with a very small learning rate.
    args.epochs = 1000
    args.patch_size = 11
    args.components = 20
    args.learning_rate = 0.000005

elif  args.backbone == "CrossHL":
    args.epochs = 200
    args.patch_size = 11
    args.batch_size = 64
    args.pca = False
    args.learning_rate = 0.0005

elif args.backbone == "MS2CANet":
    args.epochs = 100
    args.patch_size = 11
    args.batch_size = 64
    args.learning_rate = 0.001
    args.components = 20    # Houston

elif args.backbone == "S2ENet":
    args.epochs = 128
    args.batch_size = 64
    args.patch_size = 7
    args.components = 20    # Houston
    args.learning_rate = 0.001

elif args.backbone == "HCTNet":
    args.epochs = 100
    args.batch_size = 64
    args.patch_size = 11
    args.components = 20    # Houston
    args.learning_rate = 0.001
    # args.components = 30    # Trento_train2
    
elif args.backbone == "DSHFNet":
    # weight_decay and gamma are consumed by the Adam/StepLR setup in the optimizer cell.
    args.epochs = 500
    args.batch_size = 64
    args.patch_size = 6
    args.pca = False
    args.learning_rate = 5e-4
    args.weight_decay = 0
    args.gamma = 0.9

elif args.backbone == "MIViT":
    # fusion is passed to MMA.MMA in the model cell; pred_flag is read elsewhere
    # (presumably by the trainer/tester utilities -- TODO confirm).
    args.epochs = 500
    # args.epochs = 11
    args.patch_size = 8
    args.batch_size = 64
    args.learning_rate = 1e-4
    args.weight_decay = 0.001
    args.gamma = 0.9
    args.fusion = 'TTOA'
    args.pred_flag = 'o_fuse'
    
else:
    # Defaults for the cnn / vit / mamba feature-extractor backbones.
    args.epochs = 100
    args.batch_size = 64
    args.patch_size = 11
    # args.patch_size = 12
    # args.randomCrop = 8
    args.learning_rate = 0.001
    args.components = 15    

# Applied to every backbone (after the branch above); vit and mamba use it as
# their input image size in the model cell.
args.randomCrop = 11
print("args.backbone", args.backbone)

args.backbone DSHFNet


In [3]:
# Configure data loading, build the train/test datasets for the selected
# backbone family and wrap them in DataLoaders.
# data_pipe.set_deterministic(seed = 666)
args.print_data_info = True
args.data_info_start = 1
args.show_gt = False
args.remove_zero_labels = True
args.split_type = "disjoint"
# transform = dataAugmentation(args.randomCrop)
transform = None


# create dataloader
if args.dataset_name in args.SD:
    # get_data returns a single image for these datasets (no img2).
    args.train_ratio = 0.1
    args.path_data = "/home/leo/DatasetSMD"
    img1, train_gt, test_gt, data_gt = data_pipe.get_data(args)
elif args.dataset_name in args.MD:
    # get_data returns two images (img1, img2) for these datasets.
    args.train_ratio = 1
    args.path_data = "/home/leo/DatasetMMF"
    img1, img2, train_gt, test_gt, data_gt = data_pipe.get_data(args)
else:
    raise ValueError("Unknown dataset_name {!r}: not in args.SD or args.MD".format(args.dataset_name))


if args.backbone in args.MMISO or args.backbone in args.MMIMO:
    print("mutlisacle multimodality")
    # The datasets below read img2. Fail fast instead of hitting a NameError or,
    # worse, silently reusing a stale img2 left in the kernel by an earlier run.
    if args.dataset_name not in args.MD:
        raise ValueError("Backbone {} needs a two-image dataset (args.MD), got {}".format(
            args.backbone, args.dataset_name))
    # Multi-scale image patches are produced directly by the dataset here.
    train_dataset = data_pipe.HyperXMM(img1, data2=img2, gt=train_gt, 
                                    transform=None, patch_size=args.patch_size, 
                                    remove_zero_labels=args.remove_zero_labels)
    test_dataset = data_pipe.HyperXMM(img1, data2=img2, gt=test_gt, 
                                    transform=None, patch_size=args.patch_size, 
                                    remove_zero_labels=args.remove_zero_labels)
    
    height, wigth, data1_bands = train_dataset.data1.shape
    height, wigth, data2_bands = train_dataset.data2.shape
    print("data1", train_dataset.data1.shape, "data2", train_dataset.data2.shape)


# elif args.backbone in args.SSSM:
#     print("singlescale singlemodality")
#     transform = dataAugmentation(args.randomCrop)
#     train_dataset = data_pipe.HyperX(img1, gt=train_gt, transform=transform, patch_size=args.patch_size, 
#                             remove_zero_labels=args.remove_zero_labels)
#     test_dataset = data_pipe.HyperX(img1, gt=test_gt, transform=None, patch_size=args.patch_size, 
#                             remove_zero_labels=args.remove_zero_labels)
    
#     height, wigth, data1_bands = train_dataset.data1.shape
#     print("data1", train_dataset.data1.shape)
    

elif args.backbone in args.SSISO or args.backbone in args.SMIMO or args.backbone in args.SMISO:
    print("singlescale multimodality")
    # Same img2 requirement as the multi-scale branch above.
    if args.dataset_name not in args.MD:
        raise ValueError("Backbone {} needs a two-image dataset (args.MD), got {}".format(
            args.backbone, args.dataset_name))
    # Only the training split receives the (optional) augmentation transform.
    train_dataset = data_pipe.HyperX(img1, data2=img2, gt=train_gt, 
                                    transform=transform, patch_size=args.patch_size, 
                                    remove_zero_labels=args.remove_zero_labels)
    test_dataset = data_pipe.HyperX(img1, data2=img2, gt=test_gt, 
                                    transform=None, patch_size=args.patch_size, 
                                    remove_zero_labels=args.remove_zero_labels)
    
    height, wigth, data1_bands = train_dataset.data1.shape
    height, wigth, data2_bands = train_dataset.data2.shape
    print("data1", train_dataset.data1.shape, "data2", train_dataset.data2.shape)

else:
    raise NotImplementedError("No dataset pipeline for backbone: {}".format(args.backbone))


train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, drop_last=True)
test_loader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False)

# Assumes labels are 1..N (0 = unlabeled, removed via remove_zero_labels), so the
# max label is the class count. int() yields a plain Python int for the model
# constructors instead of a numpy scalar.
class_num = int(np.max(train_gt))
# print(class_num, train_gt.shape, len(train_loader.dataset))
# print(class_num, train_gt.shape, len(train_loader.dataset))

pca is not used
split_type:  disjoint train_ratio:  1
split_type:  disjoint train_ratio:  1
split_type:  disjoint train_ratio:  1
print_data_info : ---->
class 1 	 198 	 1053
class 2 	 190 	 1064
class 3 	 192 	 505
class 4 	 188 	 1056
class 5 	 186 	 1056
class 6 	 182 	 143
class 7 	 196 	 1072
class 8 	 191 	 1053
class 9 	 193 	 1059
class 10 	 191 	 1036
class 11 	 181 	 1054
class 12 	 192 	 1041
class 13 	 184 	 285
class 14 	 181 	 247
class 15 	 187 	 473
total     	 2832 	 12197
mutlisacle multimodality
data1 (367, 1923, 144) data2 (367, 1923, 1)


In [4]:
# Sanity check: pull the first training batch and print its tensor shapes and
# dtypes, then stop.
if any(args.backbone in group for group in (args.MMISO, args.MMIMO)):
    # Multi-scale loader: three scales for each of the two inputs, plus the label.
    for data11, data12, data13, data21, data22, data23, label in train_loader:
        for first, second, third in ((data11, data12, data13), (data21, data22, data23)):
            print("x.shape, y.shape", first.shape, second.shape, third.shape)
            print("x.shape, y.shape", first.dtype, second.dtype, third.dtype)
        break
elif any(args.backbone in group for group in (args.SSISO, args.SMISO, args.SMIMO)):
    # Single-scale loader: presumably one patch per input image, plus the label.
    for data11, data12, label in train_loader:
        print("x.shape, y.shape", data11.shape, data12.shape, label.shape)
        print("x.shape, y.shape", data11.dtype, data12.dtype, label.dtype)
        break
# elif args.backbone in args.SSSM:
#     for x, z in train_loader:
#         print("x.shape, y.shape", x.shape, z.shape)
#         break

x.shape, y.shape torch.Size([64, 144, 6, 6]) torch.Size([64, 144, 12, 12]) torch.Size([64, 144, 18, 18])
x.shape, y.shape torch.float32 torch.float32 torch.float32
x.shape, y.shape torch.Size([64, 1, 6, 6]) torch.Size([64, 1, 12, 12]) torch.Size([64, 1, 18, 18])
x.shape, y.shape torch.float32 torch.float32 torch.float32


In [5]:
# print((data21[0, 0, :, :].shape))
# print((data22[0, 0, 3:9, 3:9].shape))
# print((data23[0, 0, 6:12, 6:12].shape))
# print(data21.shape)

# # data21[20, 0, :, :], data22[20, 0, 3:9, 3:9], data23[20, 0, 6:12, 6:12]

# Result directory (timestamped output folder for this run)

In [6]:
# Create a timestamped result directory for this run and snapshot the arguments.
# TODO: the absolute results root is machine-specific; consider making it an arg.
args.result_dir = os.path.join("/home/leo/Multimodal_Classification/MyMultiModal/result",
                               datetime.now().strftime("%m-%d-%H-%M-") + args.backbone)
print(args.result_dir)

# Path for loading existing weights
# args.result_dir = "/home/liuquanwei/code/DMVL_joint_MNDIS/results_final/08-09-17-05-vit_D8"

# makedirs also creates missing parent directories and tolerates an existing
# directory (the name only has minute resolution).
os.makedirs(args.result_dir, exist_ok=True)
with open(os.path.join(args.result_dir, 'args.json'), 'w') as fid:
    # default=str stringifies any non-JSON value (e.g. a device object) instead of
    # aborting the dump half-way through the file.
    json.dump(args.__dict__, fid, indent=2, default=str)

/home/leo/Multimodal_Classification/MyMultiModal/result/03-25-17-51-DSHFNet


# Build the model

In [10]:
# Instantiate the network for args.backbone and collect the parameters to optimize.
# - cnn / vit / mamba: a backbone `model` plus a separate FDGC classification head
#   (`super_head`); both parameter sets are optimized together.
# - all other backbones: a single model built from data1_bands / data2_bands and
#   class_num computed in the data cell.
# `params` is always materialized as a list. A bare model.parameters() generator is
# exhausted by the first optimizer built from it, so re-running the optimizer cell
# would otherwise hand Adam an empty parameter list.
if args.backbone == "cnn":
    model = cnns.Model_base(args.components).to(args.device)
    # Presumably the backbone's output feature width (input size of FDGC_head).
    args.feature_dim = 512
    # args.feature_dim = 2048
    super_head = cnns.FDGC_head(args.feature_dim, class_num=class_num).to(args.device)
    params = list(super_head.parameters()) + list(model.parameters())

elif args.backbone == "vit":
    model = vision_transformer.vit_hsi(args.components, args.randomCrop).to(args.device)
    # encoder = vision_transformer.vit_small(args.components, args.randomCrop).to(args.device)
    args.feature_dim = 126
    super_head = cnns.FDGC_head(args.feature_dim, class_num=class_num).to(args.device)
    params = list(super_head.parameters()) + list(model.parameters())

elif args.backbone == "mamba":
    # NOTE(review): num_classes is hard-coded to 10 while classification is done by
    # super_head with class_num outputs; confirm whether Vim's own head is used.
    model = mamba.Vim(
        dim=64,  # Dimension of the transformer model
        # heads=8,  # Number of attention heads
        dt_rank=32,  # Rank of the dynamic routing matrix
        dim_inner=64,  # Inner dimension of the transformer model
        d_state=64,  # Dimension of the state vector
        num_classes=10,  # Number of output classes
        image_size=args.randomCrop,  # Size of the input image
        patch_size=4,  # Size of each image patch
        channels=args.components,  # Number of input channels
        dropout=0.1,  # Dropout rate
        depth=4,  # Depth of the transformer model
    ).to(args.device)
    args.feature_dim = 64
    super_head = cnns.FDGC_head(args.feature_dim, class_num=class_num).to(args.device)
    params = list(super_head.parameters()) + list(model.parameters())

elif args.backbone == "MS2CANet":
    FM = 64
    para_tune = False
    if args.dataset_name == "Houston_2013":
        para_tune = True                # para_tune improves Houston results by about two points!!
    model = pymodel.pyCNN(FM=FM, NC=data1_bands,
                          Classes=class_num, para_tune=para_tune).to(args.device)
    params = list(model.parameters())

elif args.backbone == 'S2ENet':
    model = S2ENet.S2ENet(data1_bands, data2_bands, class_num,
                          patch_size=args.patch_size).to(args.device)
    params = list(model.parameters())

elif args.backbone == "FusAtNet":
    model = FusAtNet.FusAtNet(data1_bands, data2_bands, class_num).to(args.device)
    params = list(model.parameters())

elif args.backbone == "CrossHL":
    FM = 16
    model = CrossHL.CrossHL_Transformer(FM, data1_bands, data2_bands, class_num,
                                        args.patch_size).to(args.device)
    params = list(model.parameters())

elif args.backbone == "HCTNet":
    model = HCTNet(in_channels=1, num_classes=class_num).to(args.device)
    params = list(model.parameters())

elif args.backbone == "DSHFNet":
    model = DSHF(l1=data1_bands, l2=data2_bands,
                 num_classes=class_num, encoder_embed_dim=64).to(args.device)
    params = list(model.parameters())

elif args.backbone == "MIViT":
    model = MMA.MMA(l1=data1_bands, l2=data2_bands, patch_size=args.patch_size,
                    num_patches=64, num_classes=class_num,
                    encoder_embed_dim=64, decoder_embed_dim=32, en_depth=5,
                    en_heads=4, de_depth=5, de_heads=4, mlp_dim=8, dropout=0.1,
                    emb_dropout=0.1, fusion=args.fusion).to(args.device)
    params = list(model.parameters())

else:
    raise NotImplementedError("No models")
print("backbone: ", args.backbone)

backbone:  DSHFNet


In [11]:
# flops, params = profile(net, inputs=(torch.randn(1, 
#                                                 args.components, 
#                                                 args.patch_size, 
#                                                 args.patch_size).cuda(),))
# flops, params = clever_format([flops, params])
# print('# Model Params: {} FLOPs: {}'.format(params, flops))

# # contra_head
# flops, params = profile(contra_head, inputs=(torch.randn(1, args.feature_dim).cuda(),))
# flops, params = clever_format([flops, params])
# print('# Model Params: {} FLOPs: {}'.format(params, flops))

# # super_head
# flops, params = profile(super_head, inputs=(torch.randn(1, args.feature_dim).cuda(),))
# flops, params = clever_format([flops, params])
# print('# Model Params: {} FLOPs: {}'.format(params, flops))

In [12]:
# Loss, optimizer and learning-rate scheduler. Each backbone keeps its own
# optimizer/schedule settings; `params` comes from the model cell.
criterion = torch.nn.CrossEntropyLoss()

if args.backbone == "CrossHL":
    optimizer = optim.Adam(params, lr=args.learning_rate, weight_decay=args.weight_decay)
    # Multiply the LR by 0.9 every 50 epochs.
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=50, gamma=0.9)

elif args.backbone == "S2ENet":
    optimizer = optim.Adam(params, lr=args.learning_rate)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.epochs,
                                                     eta_min=1e-5, last_epoch=-1)

elif args.backbone == "DSHFNet" or args.backbone == "MIViT":
    optimizer = optim.Adam(params, lr=args.learning_rate, weight_decay=args.weight_decay)
    # Decay every ~10% of the run. Clamp to 1: with fewer than 10 epochs the
    # integer division gives step_size=0, and StepLR then raises
    # ZeroDivisionError on the first scheduler.step().
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=max(1, args.epochs // 10),
                                          gamma=args.gamma)

else:
    # optimizer = optim.Adam(params, lr=args.learning_rate, weight_decay=1e-6, amsgrad=True)
    optimizer = optim.Adam(params, lr=args.learning_rate, weight_decay=args.weight_decay)
    # scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max = args.epochs, \
    #                                                 eta_min = 1e-5, last_epoch = -1)
    # scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=60, T_mult=3, \
    #                                                            eta_min=1e-5)
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=40, gamma=0.1)

# 训练前加载权重 (optionally load weights before training)

In [13]:
# Optionally resume training from a checkpoint written by the training cell.
# args.resume = os.path.join(args.result_dir, "joint_oa_model.pth")
if args.resume:  # handles both '' and None
    checkpoint = torch.load(args.resume, map_location=args.device)
    # The training cell stores weights under "model"; fall back to "base" for
    # checkpoints that use that key.
    state_dict = checkpoint['model'] if 'model' in checkpoint else checkpoint['base']
    model.load_state_dict(state_dict, strict=False)
    epoch_start = checkpoint['epoch'] + 1
    print('Loaded from: {}'.format(args.resume))
    # NOTE(review): optimizer/scheduler state is not restored, so Adam moments and
    # the LR schedule restart from scratch on resume.
else:
    epoch_start = 0

In [14]:
# Train from epoch_start to args.epochs, append per-epoch metrics to log.csv and
# write three checkpoints:
#   model_loss.pth - lowest training loss so far
#   model_acc.pth  - highest test accuracy so far
#   model_last.pth - state after the final epoch


def _save_checkpoint(path, epoch, model, optimizer):
    """Save the epoch index plus model and optimizer state dicts to ``path``."""
    torch.save({
        "epoch": epoch,
        "model": model.state_dict(),
        "optimizer": optimizer.state_dict()}, path)


best_loss = float('inf')
best_acc = 0
train_losses = []

total_train_time = time.time()
# Pre-set so the final checkpoint still has a valid epoch when the loop body never
# runs (epoch_start >= args.epochs); it then records the last completed epoch.
epoch = epoch_start - 1

for epoch in range(epoch_start, args.epochs):
    if args.backbone in args.SSISO:
        # SSISO training also receives the separate classification head.
        train_loss, train_accuracy, test_accuracy, train_time \
            = trainer.train_SSISO(epoch, model, super_head, criterion, train_loader, test_loader, optimizer, args)
    elif args.backbone in args.SMIMO:
        train_loss, train_accuracy, test_accuracy, train_time \
            = trainer.train_SMIMO(epoch, model, criterion, train_loader, test_loader, optimizer, args)
    elif args.backbone in args.SMISO:
        train_loss, train_accuracy, test_accuracy, train_time \
            = trainer.train_SMISO(epoch, model, criterion, train_loader, test_loader, optimizer, args)
    elif args.backbone in args.MMISO:
        train_loss, train_accuracy, test_accuracy, train_time \
            = trainer.train_MMISO(epoch, model, criterion, train_loader, test_loader, optimizer, args)
    elif args.backbone in args.MMIMO:
        train_loss, train_accuracy, test_accuracy, train_time \
            = trainer.train_MMIMO(epoch, model, criterion, train_loader, test_loader, optimizer, args)
    else:
        raise NotImplementedError("NO this model")
    train_losses.append(train_loss)

    scheduler.step()

    # One CSV record per epoch; newline='' lets csv.writer own the line endings.
    with open(os.path.join(args.result_dir, "log.csv"), 'a+', encoding='gbk', newline='') as f:
        csv.writer(f).writerow([
            "epoch", epoch,
            "loss", train_loss,
            "train_accuracy", round(train_accuracy, 2),
            "test_accuracy", round(test_accuracy, 2),
            "train_time", round(train_time, 2)])

    if train_loss < best_loss:
        best_loss = train_loss
        _save_checkpoint(os.path.join(args.result_dir, "model_loss.pth"), epoch, model, optimizer)

    # NOTE(review): the checkpoint is selected on test accuracy and then evaluated
    # on the same test set in the next cell, which biases the reported numbers
    # upward; a held-out validation split would avoid this.
    if best_acc < test_accuracy:
        best_acc = test_accuracy
        _save_checkpoint(os.path.join(args.result_dir, "model_acc.pth"), epoch, model, optimizer)

total_train_time = time.time() - total_train_time

_save_checkpoint(os.path.join(args.result_dir, "model_last.pth"), epoch, model, optimizer)

Train Epoch: [0/500] Loss: 2.6147 TRA: 23.7641 TEA: 22.2186 TIME: 17.1800
Train Epoch: [1/500] Loss: 2.5542 TRA: 46.6455 TEA: 40.6329 TIME: 15.7900
Train Epoch: [2/500] Loss: 2.3607 TRA: 51.4477 TEA: 42.6908 TIME: 15.3600
Train Epoch: [3/500] Loss: 2.2985 TRA: 53.8136 TEA: 42.6088 TIME: 15.4600
Train Epoch: [4/500] Loss: 2.2943 TRA: 55.2260 TEA: 42.8466 TIME: 15.4400
Train Epoch: [5/500] Loss: 2.3404 TRA: 60.5226 TEA: 51.6848 TIME: 15.5600
Train Epoch: [6/500] Loss: 2.1852 TRA: 65.0777 TEA: 51.7504 TIME: 15.2000
Train Epoch: [7/500] Loss: 2.1091 TRA: 66.5254 TEA: 57.5879 TIME: 15.2600
Train Epoch: [8/500] Loss: 2.0768 TRA: 67.9732 TEA: 57.7437 TIME: 15.5600
Train Epoch: [9/500] Loss: 2.1103 TRA: 70.1271 TEA: 64.5487 TIME: 15.4400
Train Epoch: [10/500] Loss: 2.0711 TRA: 75.0353 TEA: 64.3519 TIME: 15.3900
Train Epoch: [11/500] Loss: 2.0100 TRA: 82.3799 TEA: 69.7385 TIME: 15.3500
Train Epoch: [12/500] Loss: 2.0311 TRA: 89.1949 TEA: 74.7889 TIME: 15.3400
Train Epoch: [13/500] Loss: 1.9696 

# 保存最优结果 (evaluate the best checkpoint and save the results)

In [15]:
# Reload the checkpoint with the best test accuracy, evaluate it on the test set
# and append a summary (classification report, kappa, timings) to log_final.csv.
# total_train_time comes from the training cell above.
# args.result_dir = "/home/leo/Multimodal_Classification/MyMultiModal/result/03-25-12-26-MIViT"
args.resume = os.path.join(args.result_dir, "model_acc.pth")
# args.resume = os.path.join(args.result_dir, "model_loss.pth")
if args.resume:
    checkpoint = torch.load(args.resume, map_location=args.device)
    model.load_state_dict(checkpoint['model'], strict=False)
    # 1-based index of the epoch that produced this checkpoint; written to the log below.
    epoch = checkpoint['epoch'] + 1
    print('Loaded from: {}'.format(args.resume))
else:
    epoch_start = 0

tic = time.time()

# A single if/elif chain (matching the training cell): exactly one test routine
# runs, and an unsupported backbone fails here rather than with a NameError on
# test_preds below.
if args.backbone in args.SSISO:
    test_losses, test_preds, correct, targets = \
        tester.linear_test_SSISO(model, super_head, criterion, test_loader, args)
elif args.backbone in args.SMIMO:
    test_losses, test_preds, correct, targets = \
        tester.linear_test_SMIMO(model, criterion, test_loader, args)
elif args.backbone in args.SMISO:
    test_losses, test_preds, correct, targets = \
        tester.linear_test_SMISO(model, criterion, test_loader, args)
elif args.backbone in args.MMISO:
    test_losses, test_preds, correct, targets = \
        tester.linear_test_MMISO(model, criterion, test_loader, args)
elif args.backbone in args.MMIMO:
    test_losses, test_preds, correct, targets = \
        tester.linear_test_MMIMO(model, criterion, test_loader, args)
else:
    raise NotImplementedError("No test routine for backbone: {}".format(args.backbone))
classification, kappa = tester.get_results(test_preds, targets)

test_time = time.time() - tic

with open(os.path.join(args.result_dir, "log_final.csv"), 'a+', encoding='gbk') as f:
    row=[["training",
        "\nepoch", epoch, 
        "\ndata_name = " + str(args.dataset_name),
        "\nbatch_size = " + str(args.batch_size),
        "\npatch_size = " + str(args.patch_size),
        "\nnum_components = " + str(args.components),
        '\n' + classification,
        "\nkappa = \t\t\t" + str(round(kappa, 4)),
        "\ntotal_time = ", round(total_train_time, 2),
        '\ntest time = \t' + str(round(test_time, 2)),
        ]]
    write=csv.writer(f)
    for i in range(len(row)):
        write.writerow(row[i])

Loaded from: /home/leo/Multimodal_Classification/MyMultiModal/result/03-25-17-51-DSHFNet/model_acc.pth
Accuracy: 10924/12197 (89.56%)



# 绘图 (plotting)

In [None]:
# args.resume = os.path.join(args.result_dir, "joint_model_oa.pth")
# if args.resume != '':
#     checkpoint = torch.load(args.resume)
#     net.load_state_dict(checkpoint['base'], strict=False)
#     epoch_start = checkpoint['epoch'] + 1
#     print('Loaded from: {}'.format(args.resume))
# else:
#     epoch_start = 0

# _, groundTruth = data_reader.load_data(args.dataset_name, path_data=args.path_data, type_data="Houston")

# args.remove_zero_labels = False
# args.train_ratio = 1
# img1, img2, train_gt, test_gt, data_gt = data_pipe_D.get_data(args)

# # create dataloader
# data_dataset = HyperX(img1, img2, data_gt, transform=None, 
#                        patch_size=args.patch_size, 
#                        flip_augmentation=False, 
#                         radiation_augmentation=False, mixture_augmentation=False, 
#                         remove_zero_labels=args.remove_zero_labels)
# data_loader = torch.utils.data.DataLoader(data_dataset, batch_size=args.batch_size, \
#                                            shuffle=False, drop_last=False)

# KNNCA, KNNOA, KNNAA, KNNKA = KNN.test(net, memory_loader, data_loader, class_num, \
#                                       groundTruth, args, visulation=True)

# 画 loss 曲线 (plot the loss curves)

In [None]:
# # args.plot_loss_curve = True
# if args.plot_loss_curve:
#     fig = plt.figure()
#     plt.plot(range(args.epochs), train_losses, color='blue')
#     plt.plot(range(args.epochs), loss_contras, color='red')
#     plt.plot(range(args.epochs), loss_supers, color='pink')

#     # test_counter is defined beforehand; test_losses is recorded once per training epoch
#     # plt.scatter(test_counter, test_losses, color='red')
#     plt.legend(['train_losses', 'loss_contras', 'loss_supers'], loc='upper right')
#     plt.xlabel('number of training examples seen')
#     plt.ylabel('negative log likelihood loss')
#     plt.show()

# TSNE

In [None]:
# with torch.no_grad():
#     for idx_o, (data, _, target) in enumerate(test_loader):
#         target = target - 1
#         data = data.to(args.device)
        
#         output = super_head(net(data))
#         target = target.to(args.device)

#         for idx, _ in enumerate(output.cpu().numpy()):
#             # print(output.cpu().numpy().shape)
#             if idx == 0 and idx_o == 0:
#                 list_output = output.cpu().numpy()[idx]
#                 list_target = target.cpu().numpy()[idx]
#             # print(list_output.shape, list_target.shape)

#             if idx < 100:
#                 list_output = np.vstack((list_output, output.cpu().numpy()[idx]))
#                 list_target = np.append(list_target, target.cpu().numpy()[idx])

In [None]:
# tsne = TSNE()

# out = tsne.fit_transform(list_output)
# fig, ax = plt.subplots()
# ax.set_axis_off()
# ax.xaxis.set_visible(False)
# ax.yaxis.set_visible(False)
# # fig.set_size_inches(label.shape[1] * scale / dpi, label.shape[0] * scale / dpi)
# plt.gca().xaxis.set_major_locator(plt.NullLocator())
# plt.gca().yaxis.set_major_locator(plt.NullLocator())
# plt.subplots_adjust(top=1, bottom=0, right=1, left=0, hspace=0, wspace=0)

# label = ["Broccoli-weeds-1","Broccoli-weeds-2","Fallow","Fallow-rough-plow",
#          "Fallow-smooth","Stubble","Celery","Grapes-untrained",
#          "Soil-senesced-develop","Corn-weeds","Lettuce-4wk","Lettuce-5wk",
#          "Lettuce-6wk","Lettuce-7wk","Vinyard-untrained",
#          "Vinyard-vertical-trellis"]
# for i in range(class_num):
#     indices = list_target  == i
#     x, y = out[indices].T
#     # plt.scatter(x, y, s=5, label=label[i])
#     plt.scatter(x, y, s=5, label=str(i+1))
# plt.legend(loc=2,bbox_to_anchor=(1.05,1.0),borderaxespad = 0., fontsize=12)
# plt.savefig(args.result_dir  + '/tsneFull' + '.png', dpi=400, bbox_inches="tight")