In [None]:
import os
# os.environ['CUDA_VISIBLE_DEVICES'] = ','.join(map(str, [0]))
# print('using GPU %s' % ','.join(map(str, [0])))

import torch
from thop import profile, clever_format
from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy
from sklearn.metrics import classification_report,cohen_kappa_score, accuracy_score
import numpy as np
import matplotlib.pyplot as plt

import json
import csv
from datetime import datetime
import time

from models import FDGC
from option import opt
from loadData import data_pipe, data_reader
from utlis import tools, trainer

In [None]:
# Run configuration: dataset, split settings, data location and output directory.
args = opt.get_args()
# Other datasets used with this notebook: "IndianPines", "PaviaU", "LongKou",
# "HanChuan", "HongHu", "PaviaC", "Salinas".
args.dataset_name = "Dioni"

args.train_ratio = 0.1
args.split_type = "number"
# The data root can be overridden with the HSI_DATA_DIR environment variable;
# the machine-specific default below is kept for backward compatibility.
args.path_data = os.environ.get("HSI_DATA_DIR",
                                r"C:\Users\jc962911\Project\datasets\HSI\\")
args.result_dir = (args.path_head + 'results\\'
                   + datetime.now().strftime("%Y%m%d-%H%M-D"))
print(args.result_dir)

# makedirs also creates a missing 'results' parent and tolerates an existing directory.
os.makedirs(args.result_dir, exist_ok=True)
# default=str lets the dump succeed if args holds a non-JSON value
# (NOTE(review): e.g. a torch.device in args.device — confirm which fields need it).
with open(os.path.join(args.result_dir, 'args.json'), 'w', encoding='utf-8') as fid:
    json.dump(args.__dict__, fid, indent=2, default=str)

# Dataset

In [None]:
# Build the train/test loaders for the configured dataset.
# NOTE(review): remove_zero_labels presumably drops label-0 (background) pixels — confirm in data_pipe.
# data_pipe.set_deterministic(seed = 666)
loader_overrides = {
    "print_data_info": False,
    "show_gt": False,
    "remove_zero_labels": True,
}
for option_name, option_value in loader_overrides.items():
    setattr(args, option_name, option_value)

train_loader, test_loader, train_label, test_label = data_pipe.get_data(args)
len(train_loader.dataset)

In [None]:
# Inspect the shapes of the first (inputs, labels) batch; prints nothing if the loader is empty.
first_batch = next(iter(train_loader), None)
if first_batch is not None:
    batch_inputs, batch_labels = first_batch
    print(batch_inputs.shape, batch_labels.shape)

# Model

In [None]:
# Build the FDGC model, report its parameter count / FLOPs, then set up the
# loss, optimizer and learning-rate schedule.
# Class count assumes labels are 0..max(test_label); int() avoids handing a
# NumPy scalar to the model constructor.
num_classes = int(np.max(test_label)) + 1
net = FDGC(input_channels=args.components, num_nodes=num_classes * args.num_nodes,
           num_classes=num_classes, patch_size=args.patch_size).to(args.device)

# Create the profiling input on the model's own device so CPU-only runs and
# non-default GPUs work (a bare .cuda() always targets the current CUDA device).
dummy_input = torch.randn(2, 1, args.components, args.patch_size, args.patch_size,
                          device=args.device)
flops, params = profile(net, inputs=(dummy_input,))
flops, params = clever_format([flops, params])
print('# Model Params: {} FLOPs: {}'.format(params, flops))

# Alternatives tried: nn.CrossEntropyLoss, SoftTargetCrossEntropy,
# nn.MultiLabelSoftMarginLoss, and SGD(lr=0.001, momentum=0.9).
criterion = LabelSmoothingCrossEntropy(smoothing=args.lb_smooth)
optimizer = torch.optim.Adam(net.parameters(), lr=args.learning_rate,
                             weight_decay=args.weight_decay)

# The epoch budget must be fixed BEFORE building the scheduler: its milestones are
# derived from args.epochs, and overriding args.epochs afterwards (as the training
# cell does) would leave them tied to the old default, so the LR might never decay.
args.epochs = 50
scheduler = torch.optim.lr_scheduler.MultiStepLR(
    optimizer, milestones=[args.epochs // 2, (5 * args.epochs) // 6], gamma=0.1)

In [None]:
# Train the model and record total wall-clock training time.
# NOTE: the MultiStepLR milestones are computed from args.epochs when the scheduler
# is created in the model cell; keep this value in sync with that cell.
args.epochs = 50
# perf_counter is monotonic, so the duration is unaffected by system clock changes.
tic1 = time.perf_counter()
# Presumably (per the unpacked names): per-epoch losses/accuracies and timings — confirm in trainer.train.
train_losses, train_accuracy, test_accuracy, epoch_time, test_time = trainer.train(
    net, args.epochs, train_loader, test_loader, criterion, optimizer, scheduler, args)
train_time = time.perf_counter() - tic1

In [None]:
# Optionally restore weights from a checkpoint, then evaluate on the test split.
# Example: args.resume = "/home/liuquanwei/code/FDGC/results/20230812-2110-FDGCF-D/best_model_loss.pth"
if args.resume:  # treats both '' and None as "no checkpoint"
    # map_location lets a checkpoint saved on GPU load onto the current device (e.g. CPU).
    checkpoint = torch.load(args.resume, map_location=args.device)
    net.load_state_dict(checkpoint['model'])
    epoch_start = checkpoint['epoch'] + 1
    print('Loaded from: {}'.format(args.resume))
else:
    # No checkpoint: the model currently in memory is evaluated as-is.
    print("start new")

tic2 = time.perf_counter()
test_losses, test_preds, test_accuracy = tools.test(net, criterion, test_loader, args)
test_time = time.perf_counter() - tic2

In [None]:
# Plot the per-epoch training loss history returned by trainer.train.
args.plot_loss_curve = True
if args.plot_loss_curve:
    fig, ax = plt.subplots()
    # Derive the x-axis from the data so it still plots if the history length
    # differs from args.epochs.
    epochs_axis = range(1, len(train_losses) + 1)
    ax.plot(epochs_axis, train_losses, color='blue', label='Train Loss')
    # Only the training curve is drawn, so the legend lists only that entry.
    ax.legend(loc='upper right')
    ax.set_xlabel('epoch')
    ax.set_ylabel('training loss')
    ax.set_title('Training loss per epoch')
    plt.show()

# Show results

In [None]:
# Flatten the nested predictions into a single sequence for sklearn's metrics.
# NOTE(review): test_preds is presumably a list of per-batch predictions in test_label order.
y_pred_test = []
for batch_preds in test_preds:
    y_pred_test.extend(batch_preds)
classification = classification_report(test_label, y_pred_test, digits=4)
kappa = cohen_kappa_score(test_label, y_pred_test)
print(classification)

In [None]:
# Save run information: append this run's settings and metrics to
# <result_dir>/<dataset_name>_results.txt.
results_path = os.path.join(args.result_dir, args.dataset_name + '_results.txt')
str_results = ('\n ======================'
               + "\nsamples_type = " + str(args.split_type)
               + "\ntrain ratio = " + str(args.train_ratio)
               + "\nbatch_size = " + str(args.batch_size)
               + "\npatch_size = " + str(args.patch_size)
               + "\nnum_components = " + str(args.components)
               + '\n' + classification
               + "kappa = \t\t" + str(kappa)
               + '\ntrain time = ' + str(train_time)
               + '\ntest time = ' + str(test_time)
               + '\n')

# 'with' closes the file even if the write raises; 'a+' keeps earlier runs' entries.
with open(results_path, 'a+', encoding='utf-8') as f:
    f.write(str_results)

# Visualization

In [None]:
# Reload the dataset with background labels kept and no training split.
# NOTE(review): presumably this yields loaders over every pixel for the prediction map — confirm in data_pipe.
# NOTE(review): these overrides mutate the shared args in place; re-running earlier
# cells after this one will see train_ratio = 0.
# data_pipe.set_deterministic(seed = 666)
vis_overrides = {
    "print_data_info": False,
    "show_gt": False,
    "remove_zero_labels": False,
    "train_ratio": 0,
}
for option_name, option_value in vis_overrides.items():
    setattr(args, option_name, option_value)

data, data_gt = data_reader.load_data(args.dataset_name, path_data=args.path_data)
train_loader, data_loader, _, _ = data_pipe.get_data(args)
len(train_loader.dataset)

In [None]:
# Optionally restore weights from a checkpoint, then run the model over data_loader
# with visualization enabled (tools.test receives data_gt and visulation=True).
# Example: args.resume = "/home/liuquanwei/code/FDGC/results/20230812-2110-FDGCF-D/best_model_loss.pth"
if args.resume:  # treats both '' and None as "no checkpoint"
    # map_location lets a checkpoint saved on GPU load onto the current device (e.g. CPU).
    checkpoint = torch.load(args.resume, map_location=args.device)
    net.load_state_dict(checkpoint['model'])
    epoch_start = checkpoint['epoch'] + 1
    print('Loaded from: {}'.format(args.resume))
else:
    # No checkpoint: the model currently in memory is used as-is.
    print("start new")

tic2 = time.perf_counter()
test_losses, test_preds, test_accuracy = tools.test(net, criterion, data_loader, args, data_gt, visulation=True)
test_time = time.perf_counter() - tic2