# In [None]:
import sys
sys.path.append("./../")
from torch.nn import LayerNorm, Linear, Dropout, Softmax
from einops import rearrange, repeat
import ssl
import copy
from timm.models.layers import DropPath, trunc_normal_
from pathlib import Path
import re
import torch.backends.cudnn as cudnn
import record
import matplotlib.pyplot as plt
from torchsummary import summary
from sklearn.metrics import confusion_matrix, accuracy_score, classification_report, cohen_kappa_score
from operator import truediv
import math
from PIL import Image
import time
import torchvision.transforms.functional as TF
from torch.nn.parameter import Parameter
from sklearn.decomposition import PCA
from scipy.io import loadmat as loadmat
from scipy import io
import torch.utils.data as dataf
import torch.nn as nn
import torch
import torch.nn.functional as F
from torch import einsum
import random
import numpy as np
from sklearn import svm
import pickle
import pandas as pd
import seaborn as sns
import os
# Ask cuDNN for deterministic kernels and disable its benchmark autotuner so
# repeated runs with the same seed choose the same algorithms.
cudnn.deterministic = True
cudnn.benchmark = False


# In [None]:
from torch.nn import LayerNorm,Linear,Dropout,Softmax
import copy

def INF(B, H, W, device="cuda"):
    """Return a ``(B * W, H, H)`` stack of masks with ``-inf`` on the diagonal.

    Off-diagonal entries are (negative) zero. Such a mask is typically added to
    attention logits to forbid a position from attending to itself.

    Args:
        B: batch size.
        H: side of each square mask.
        W: multiplier for the number of stacked masks (``B * W`` in total).
        device: device the mask is created on. Defaults to ``"cuda"`` (the
            current CUDA device), matching the previous hard-coded ``.cuda()``.

    Returns:
        float32 tensor of shape ``(B * W, H, H)``.
    """
    diagonal = torch.diag(torch.full((H,), float("inf"), device=device), 0)
    return -diagonal.unsqueeze(0).repeat(B * W, 1, 1)

class HetConv(nn.Module):
    """Sum of an ACmix branch and a grouped 1x1 convolution branch.

    Only ``in_channels``, ``out_channels``, ``stride`` and ``p`` (the number of
    groups of the 1x1 conv) influence the module. ``kernel_size``, ``padding``,
    ``bias`` and ``g`` are accepted for signature compatibility but ignored.
    The ACmix branch is always built with its default stride of 1, so a
    ``stride`` other than 1 would give the two branches different spatial sizes.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=None, bias=None, p=64, g=64):
        super().__init__()
        # Construction order is kept (ACmix first) so parameter initialisation
        # consumes the RNG in the same sequence.
        self.ACmix = ACmix(in_channels, out_channels)
        self.pwc = nn.Conv2d(in_channels, out_channels, kernel_size=1, groups=p, stride=stride)

    def forward(self, x):
        attention_branch = self.ACmix(x)
        pointwise_branch = self.pwc(x)
        return attention_branch + pointwise_branch

class MCrossAttention(nn.Module):
    """Multi-head attention where only the first token forms the query.

    The input ``(B, N, C)`` is split into ``num_heads`` chunks of
    ``C // num_heads``. Each chunk is projected back up to ``C`` features by
    ``wq``/``wk``/``wv``. A learned positional embedding is added to the values.
    The attended first token plus its value is projected from ``C * num_heads``
    down to ``C``.

    Args:
        dim: token width ``C``.
        num_heads: number of heads.
        qkv_bias: whether the q/k/v projections have a bias.
        qk_scale: optional override for the logit scale (default ``head_dim ** -0.5``).
        attn_drop: accepted for API compatibility; currently NOT applied.
        proj_drop: dropout after the output projection.
        num_tokens: number of tokens ``N`` the value positional embedding is
            sized for (default 9, the previously hard-coded ``8 + 1``).

    Returns (forward):
        tensor of shape ``(B, 1, C)``.
    """

    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0.1, proj_drop=0.1, num_tokens=9):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim ** -0.5

        # Each head chunk (head_dim) is projected up to the full width ``dim``.
        self.wq = nn.Linear(head_dim, dim, bias=qkv_bias)
        self.wk = nn.Linear(head_dim, dim, bias=qkv_bias)
        self.wv = nn.Linear(head_dim, dim, bias=qkv_bias)
        self.proj = nn.Linear(dim * num_heads, dim)
        self.proj_drop = nn.Dropout(proj_drop)
        # Last axis must equal the value width (``dim``); it was hard-coded to 64,
        # which broke every ``dim != 64``. Identical shape for dim=64, num_tokens=9.
        self.position_embeddings = nn.Parameter(torch.randn(1, num_heads, num_tokens, dim))

    def forward(self, x):
        B, N, C = x.shape
        # Query from the first token only: (B, heads, 1, C).
        q = self.wq(x[:, 0:1, ...].reshape(B, 1, self.num_heads, C // self.num_heads)).permute(0, 2, 1, 3)
        # Keys / values from all tokens: (B, heads, N, C).
        k = self.wk(x.reshape(B, N, self.num_heads, C // self.num_heads)).permute(0, 2, 1, 3)
        v = self.wv(x.reshape(B, N, self.num_heads, C // self.num_heads)).permute(0, 2, 1, 3)
        v = v + self.position_embeddings
        attn = torch.einsum('bhid,bhjd->bhij', q, k) * self.scale

        attn = attn.softmax(dim=-1)

        x = torch.einsum('bhij,bhjd->bhid', attn, v)
        # Residual with the first token's value, then concatenate heads.
        x = ((x + v)[:, :, 0:1, ...]).transpose(1, 2)
        x = x.reshape(B, 1, C * self.num_heads)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x

class DWConv(nn.Module):
    """Grouped ``kernel_size`` conv plus a 1x1 conv, summed.

    Previously the constructor ignored its arguments and always built
    512 -> 64 convs with a 3x3 kernel. Its padding, however, used ``kernel_size``,
    so any other size produced mismatched branch shapes. The arguments are now
    honoured. The grouped conv uses ``gcd(in_channels, out_channels)`` groups,
    which is 64 for the (512, 64) configuration and reproduces the old layer
    exactly.

    Args:
        in_channels: input channels.
        out_channels: output channels of both branches.
        kernel_size: kernel side of the grouped branch (odd sizes keep the
            spatial size, via ``padding=kernel_size // 2``).
    """

    def __init__(self, in_channels, out_channels, kernel_size):
        super(DWConv, self).__init__()
        groups = math.gcd(in_channels, out_channels)
        self.depthwise_conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size,
                                        groups=groups, padding=kernel_size // 2)
        self.pointwise_conv = nn.Conv2d(in_channels, out_channels, 1)

    def forward(self, x):
        return self.depthwise_conv(x) + self.pointwise_conv(x)

def position(H, W, is_cuda=True):
    """Build a normalised coordinate grid of shape ``(1, 2, H, W)``.

    Channel 0 holds the x coordinate (varies along W), channel 1 the
    y coordinate (varies along H), both spanning ``[-1, 1]``.

    Args:
        H: grid height.
        W: grid width.
        is_cuda: create the grid on the current CUDA device when True.
    """
    xs = torch.linspace(-1.0, 1.0, W)
    ys = torch.linspace(-1.0, 1.0, H)
    if is_cuda:
        xs = xs.cuda()
        ys = ys.cuda()
    grid_x = xs.unsqueeze(0).repeat(H, 1)
    grid_y = ys.unsqueeze(1).repeat(1, W)
    return torch.stack([grid_x, grid_y], dim=0).unsqueeze(0)


def stride(x, stride):
    """Subsample the last two axes of a 4-D tensor with step ``stride``.

    A non-4-D input raises ``ValueError`` from the shape unpacking.
    """
    _batch, _channels, _height, _width = x.shape  # enforces a 4-D input
    step = slice(None, None, stride)
    return x[:, :, step, step]

def init_rate_half(tensor):
    """Fill ``tensor`` in place with 0.5; ``None`` is ignored. Returns None."""
    if tensor is None:
        return
    tensor.data.fill_(0.5)

def init_rate_0(tensor):
    """Zero ``tensor`` in place; ``None`` is ignored. Always returns None."""
    if tensor is None:
        return
    tensor.data.fill_(0.)


class ACmix(nn.Module):
    """Hybrid local self-attention + convolution block (ACmix-style).

    Both paths share the same three 1x1 projections (conv1/conv2/conv3 -> q, k, v).
    The attention path attends over a ``kernel_att x kernel_att`` reflection-padded
    neighbourhood of every pixel and adds a relative positional term. The
    convolution path mixes the projected q/k/v with ``fc`` and runs them through
    the grouped ``dep_conv``. The output is ``rate1 * attention + rate2 * conv``;
    both rates are learnable and initialised to 0.5.

    Args:
        in_planes: input channels.
        out_planes: output channels (also the width of q, k and v).
        kernel_att: side of the attention window.
        head: number of attention heads; ``head_dim = out_planes // head``.
        kernel_conv: kernel side of the convolution path. ``dep_conv`` uses a
            hard-coded ``padding=1``, which is "same" padding only for 3.
        stride: spatial stride of the queries, the unfold and ``dep_conv``.
        dilation: only used to size the reflection padding of the attention window.
    """
    def __init__(self, in_planes, out_planes, kernel_att=7, head=4, kernel_conv=3, stride=1, dilation=1):
        super(ACmix, self).__init__()
        self.in_planes = in_planes
        self.out_planes = out_planes
        self.head = head
        self.kernel_att = kernel_att
        self.kernel_conv = kernel_conv
        self.stride = stride
        self.dilation = dilation
        # Mixing weights of the two paths; filled with 0.5 in reset_parameters().
        self.rate1 = torch.nn.Parameter(torch.Tensor(1))
        self.rate2 = torch.nn.Parameter(torch.Tensor(1))
        self.head_dim = self.out_planes // self.head

        # Shared q / k / v projections.
        self.conv1 = nn.Conv2d(in_planes, out_planes, kernel_size=1)
        self.conv2 = nn.Conv2d(in_planes, out_planes, kernel_size=1)
        self.conv3 = nn.Conv2d(in_planes, out_planes, kernel_size=1)
        # Maps the 2-channel coordinate grid from position() to head_dim channels.
        self.conv_p = nn.Conv2d(2, self.head_dim, kernel_size=1)

        # Pad so that unfolding a kernel_att window keeps the spatial size.
        self.padding_att = (self.dilation * (self.kernel_att - 1) + 1) // 2
        self.pad_att = torch.nn.ReflectionPad2d(self.padding_att)
        self.unfold = nn.Unfold(kernel_size=self.kernel_att, padding=0, stride=self.stride)
        # Softmax over the window positions (dim=1 of the attention map).
        self.softmax = torch.nn.Softmax(dim=1)

        # Convolution path: mix the 3*head q/k/v slices into kernel_conv**2 maps,
        # then a grouped conv (one group per head_dim channel).
        self.fc = nn.Conv2d(3*self.head, self.kernel_conv * self.kernel_conv, kernel_size=1, bias=False)
        self.dep_conv = nn.Conv2d(self.kernel_conv * self.kernel_conv * self.head_dim, out_planes, kernel_size=self.kernel_conv, bias=True, groups=self.head_dim, padding=1, stride=stride)

        self.reset_parameters()
    
    def reset_parameters(self):
        """Set both rates to 0.5 and give dep_conv fixed one-hot "shift" kernels.

        Kernel ``i`` of each output channel has a single 1 at position
        ``(i // kernel_conv, i % kernel_conv)``.
        """
        init_rate_half(self.rate1)
        init_rate_half(self.rate2)
        kernel = torch.zeros(self.kernel_conv * self.kernel_conv, self.kernel_conv, self.kernel_conv)
        for i in range(self.kernel_conv * self.kernel_conv):
            kernel[i, i//self.kernel_conv, i%self.kernel_conv] = 1.
        # (kernel_conv**2, k, k) -> (out_planes, kernel_conv**2, k, k)
        kernel = kernel.squeeze(0).repeat(self.out_planes, 1, 1, 1)
        self.dep_conv.weight = nn.Parameter(data=kernel, requires_grad=True)
        # NOTE(review): init_rate_0 returns None, so this zero-fills the bias and
        # then assigns None, i.e. dep_conv ends up with NO bias parameter. Left
        # unchanged because fixing it changes the module's parameters/state_dict.
        self.dep_conv.bias = init_rate_0(self.dep_conv.bias)

    def forward(self, x):
        """Return ``rate1 * attention_out + rate2 * conv_out``, shape (b, out_planes, h', w')."""
        q, k, v = self.conv1(x), self.conv2(x), self.conv3(x)
        scaling = float(self.head_dim) ** -0.5
        b, c, h, w = q.shape
        h_out, w_out = h//self.stride, w//self.stride

        # Positional encoding from the normalised coordinate grid: (1, head_dim, h, w).
        pe = self.conv_p(position(h, w, x.is_cuda))

        # Split channels into heads: (b*head, head_dim, h, w).
        q_att = q.view(b*self.head, self.head_dim, h, w) * scaling
        k_att = k.view(b*self.head, self.head_dim, h, w)
        v_att = v.view(b*self.head, self.head_dim, h, w)

        if self.stride > 1:
            q_att = stride(q_att, self.stride)
            q_pe = stride(pe, self.stride)
        else:
            q_pe = pe

        # Keys and positional encodings of every kernel_att x kernel_att window:
        # (b*head or 1, head_dim, kernel_att**2, h_out, w_out).
        unfold_k = self.unfold(self.pad_att(k_att)).view(b*self.head, self.head_dim, self.kernel_att*self.kernel_att, h_out, w_out) 
        unfold_rpe = self.unfold(self.pad_att(pe)).view(1, self.head_dim, self.kernel_att*self.kernel_att, h_out, w_out) 
        # Dot product over head_dim with a relative positional term (q_pe - rpe):
        # (b*head, kernel_att**2, h_out, w_out), softmax over the window.
        att = (q_att.unsqueeze(2)*(unfold_k + q_pe.unsqueeze(2) - unfold_rpe)).sum(1)
        att = self.softmax(att)
        
        # Weighted sum of the windowed values, heads merged back into channels.
        out_att = self.unfold(self.pad_att(v_att)).view(b*self.head, self.head_dim, self.kernel_att*self.kernel_att, h_out, w_out)
        out_att = (att.unsqueeze(1) * out_att).sum(2).view(b, self.out_planes, h_out, w_out)

        # Convolution path: stack q/k/v head slices as 3*head channels, mix them
        # with fc into kernel_conv**2 maps, then regroup per head_dim channel.
        f_all = self.fc(torch.cat([q.view(b, self.head, self.head_dim, h*w), k.view(b, self.head, self.head_dim, h*w), v.view(b, self.head, self.head_dim, h*w)], 1))
        f_conv = f_all.permute(0, 2, 1, 3).reshape(x.shape[0], -1, x.shape[-2], x.shape[-1])
        
        out_conv = self.dep_conv(f_conv)

        return self.rate1 * out_att + self.rate2 * out_conv
    
class Mlp(nn.Module):
    """Transformer feed-forward sub-layer with an extra convolutional branch.

    ``fc1`` (dim -> 512) -> GELU -> dropout. A DWConv branch then runs over the
    token axis (512 -> 64 channels), and each of its channels is repeated to
    restore the 512 hidden width before being added back. Finally ``fc2``
    (512 -> dim) -> dropout.
    """

    def __init__(self, dim):
        super(Mlp, self).__init__()
        self.fc1 = Linear(dim, 512)
        self.fc2 = Linear(512, dim)
        self.act_fn = nn.GELU()
        self.dropout = Dropout(0.1)
        self.dwconv = DWConv(512, 64, 3)

        self._init_weights()

    def _init_weights(self):
        """Xavier-uniform weights and near-zero normal biases for fc1 / fc2."""
        nn.init.xavier_uniform_(self.fc1.weight)
        nn.init.xavier_uniform_(self.fc2.weight)
        nn.init.normal_(self.fc1.bias, std=1e-6)
        nn.init.normal_(self.fc2.bias, std=1e-6)

    def forward(self, x):
        # x: assumed (B, N, dim) token sequence — TODO confirm for other callers.
        x = self.fc1(x)
        x = self.act_fn(x)
        x = self.dropout(x)
        # (B, N, 512) -> (B, 512, N, 1): hidden features become conv channels.
        x1 = x.unsqueeze(1).transpose(1, 3)
        x1 = self.dwconv(x1)
        # (B, C_conv, N, 1) -> (B, N, C_conv)
        x1 = x1.squeeze(-1).transpose(1, 2)
        # Repeat every conv channel to get back to the hidden width (the torch
        # equivalent of np.repeat(..., axis=2)). The former NumPy round trip
        # detached x1 from autograd, so dwconv never received gradients, and it
        # forced the result onto 'cuda:0'. This stays on x's device and in the graph.
        x1 = x1.repeat_interleave(x.shape[-1] // x1.shape[-1], dim=2)
        x = x1 + x
        x = self.fc2(x)
        x = self.dropout(x)
        return x

class Block(nn.Module):
    """Pre-norm transformer block: cross-attention then MLP, each with a residual.

    ``conv2``, ``depthwise`` and ``pointwise`` are constructed (so they exist in
    the state dict) but are not used by ``forward``.
    """

    def __init__(self, dim):
        super().__init__()
        # Submodules are created in the original order so random initialisation
        # draws from the RNG in the same sequence.
        self.hidden_size = dim
        self.attention_norm = LayerNorm(dim, eps=1e-6)
        self.ffn_norm = LayerNorm(dim, eps=1e-6)
        self.ffn = Mlp(dim)
        self.attn = MCrossAttention(dim=dim)
        self.conv2 = nn.Sequential(
            nn.Conv2d(1, 64, 3, 1, 1),
            CBAM(64),
            nn.BatchNorm2d(64),
            nn.ReLU(),
        )
        self.depthwise = nn.Sequential(
            nn.Conv2d(64, 64, 3, 1, 1, groups=64),
            CBAM(64),
            nn.BatchNorm2d(64),
            nn.ReLU(),
        )
        self.pointwise = nn.Sequential(
            nn.Conv2d(64, 1, 1),
            nn.BatchNorm2d(1),
            nn.ReLU(),
        )

    def forward(self, x):
        attended = self.attn(self.attention_norm(x))
        x = attended + x
        transformed = self.ffn(self.ffn_norm(x))
        return transformed + x

class TransformerEncoder(nn.Module):
    """Four Blocks with learned skip fusion, followed by LayerNorm.

    Before block ``idx`` (for ``idx >= 2``), the current sequence is stacked with
    the input that block ``idx - 2`` received and fused by a
    ``Conv2d(9, 9, [1, 2])``. That conv treats the token axis as channels, so
    the sequence must have exactly 9 tokens. ``forward`` returns the normalised
    first token.

    Every constructor argument except ``dim`` is accepted for compatibility and
    ignored. ``mask`` in ``forward`` is likewise unused.
    """

    def __init__(self, dim, num_heads=8, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0.1, attn_drop=0.1,
                 drop_path=0.1, act_layer=nn.GELU, norm_layer=nn.LayerNorm, has_mlp=False):
        super().__init__()
        self.layer = nn.ModuleList()
        self.encoder_norm = LayerNorm(dim, eps=1e-6)
        for _ in range(4):
            self.layer.append(Block(dim))
        self.skipcat = nn.ModuleList([nn.Conv2d(9, 9, [1, 2], 1, 0) for _ in range(2)])

    def forward(self, x, mask=None):
        block_inputs = []
        for idx, block in enumerate(self.layer):
            block_inputs.append(x)
            if idx >= 2:
                paired = torch.cat([x.unsqueeze(3), block_inputs[idx - 2].unsqueeze(3)], dim=3)
                x = self.skipcat[idx - 2](paired).squeeze(3)
            x = block(x)
        return self.encoder_norm(x)[:, 0]

class NonLocalBlock(nn.Module):
    """Non-local (embedded-Gaussian style) block with a residual connection.

    Projects the input with three 1x1 convs (theta, phi, g) to
    ``inter_channels``. Every spatial position then attends over all positions
    via a softmax of theta·phi. The aggregated g features go through a 1x1
    conv back to ``in_channels`` and a ReLU, and the result is added to the input.

    Args:
        in_channels: channels of the input (and output).
        inter_channels: bottleneck width; defaults to ``in_channels // 2``.
    """

    def __init__(self, in_channels, inter_channels=None):
        super().__init__()
        inter_channels = in_channels // 2 if inter_channels is None else inter_channels

        def conv1x1(c_in, c_out):
            return nn.Conv2d(c_in, c_out, kernel_size=1, stride=1, padding=0)

        self.theta = conv1x1(in_channels, inter_channels)
        self.phi = conv1x1(in_channels, inter_channels)
        self.g = conv1x1(in_channels, inter_channels)
        self.operation_function = self._operation
        self.out_conv = conv1x1(inter_channels, in_channels)
        self.relu = nn.ReLU(inplace=True)

    def _operation(self, theta, phi, g):
        """Attend every position over all positions and project back to in_channels."""
        batch, channels, height, width = theta.size()
        positions = height * width
        queries = theta.view(batch, channels, positions).permute(0, 2, 1)
        keys = phi.view(batch, channels, positions)
        affinity = torch.softmax(torch.matmul(queries, keys), dim=-1)
        values = g.view(batch, channels, positions).permute(0, 2, 1)
        aggregated = torch.matmul(affinity, values)
        aggregated = aggregated.permute(0, 2, 1).contiguous().view(batch, channels, height, width)
        return self.relu(self.out_conv(aggregated))

    def forward(self, x):
        response = self.operation_function(self.theta(x), self.phi(x), self.g(x))
        return response + x

class ChannelAttention3D(nn.Module):
    """CBAM channel gate for 5-D (N, C, D, H, W) inputs.

    Average- and max-pooled descriptors go through a shared 1x1x1 bottleneck MLP
    (reduction ``ratio``). They are summed and squashed by a sigmoid into
    per-channel weights of shape ``(N, C, 1, 1, 1)``.
    """

    def __init__(self, in_planes, ratio):
        super().__init__()
        self.avg_pool = nn.AdaptiveAvgPool3d(1)
        self.max_pool = nn.AdaptiveMaxPool3d(1)
        hidden = in_planes // ratio
        self.fc = nn.Sequential(
            nn.Conv3d(in_planes, hidden, 1, bias=False),
            nn.ReLU(),
            nn.Conv3d(hidden, in_planes, 1, bias=False),
        )
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        gate_logits = self.fc(self.avg_pool(x)) + self.fc(self.max_pool(x))
        return self.sigmoid(gate_logits)
    
class SpatialAttention3D(nn.Module):
    """CBAM spatial gate for 5-D inputs.

    Channel-wise mean and max maps are concatenated, convolved with a
    ``kernel_size`` 3-D kernel (3 or 7) and passed through a sigmoid,
    giving a ``(N, 1, D, H, W)`` gate.
    """

    def __init__(self, kernel_size):
        super().__init__()
        assert kernel_size in (3, 7), 'kernel size must be 3 or 7'
        same_padding = {7: 3}.get(kernel_size, 1)
        self.conv = nn.Conv3d(2, 1, kernel_size, padding=same_padding, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        mean_map = torch.mean(x, dim=1, keepdim=True)
        max_map = torch.max(x, dim=1, keepdim=True)[0]
        return self.sigmoid(self.conv(torch.cat([mean_map, max_map], dim=1)))

class CBAM3D(nn.Module):
    """3-D CBAM: apply the channel gate, then the spatial gate, multiplicatively."""

    def __init__(self, in_planes, ratio, kernel_size):
        super().__init__()
        self.channel_att = ChannelAttention3D(in_planes, ratio)
        self.spatial_att = SpatialAttention3D(kernel_size)

    def forward(self, x):
        refined = x * self.channel_att(x)
        return refined * self.spatial_att(refined)


class ChannelAttention(nn.Module):
    """CBAM channel gate for 4-D (N, C, H, W) inputs; returns (N, C, 1, 1) weights."""

    def __init__(self, in_channels, reduction_ratio=16):
        super().__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)
        bottleneck = in_channels // reduction_ratio
        self.fc = nn.Sequential(
            nn.Conv2d(in_channels, bottleneck, 1, bias=False),
            nn.ReLU(),
            nn.Conv2d(bottleneck, in_channels, 1, bias=False)
        )
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        logits = self.fc(self.avg_pool(x)) + self.fc(self.max_pool(x))
        batch, channels = x.size(0), x.size(1)
        return self.sigmoid(logits).view(batch, channels, 1, 1)

class SpatialAttention(nn.Module):
    """CBAM spatial gate for 4-D inputs; returns (N, 1, H, W) weights."""

    def __init__(self, kernel_size=3):
        super().__init__()
        self.conv = nn.Conv2d(2, 1, kernel_size=kernel_size, padding=kernel_size // 2)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        mean_map = x.mean(dim=1, keepdim=True)
        max_map = x.max(dim=1, keepdim=True).values
        stacked = torch.cat((mean_map, max_map), dim=1)
        return self.sigmoid(self.conv(stacked))

class CBAM(nn.Module):
    """2-D CBAM: channel gating followed by spatial gating."""

    def __init__(self, in_channels, reduction_ratio=16, kernel_size=3):
        super().__init__()
        self.channel_attention = ChannelAttention(in_channels, reduction_ratio)
        self.spatial_attention = SpatialAttention(kernel_size)

    def forward(self, x):
        gated = x * self.channel_attention(x)
        return gated * self.spatial_attention(gated)

class MFT(nn.Module):
    """HSI + LiDAR fusion network.

    The HSI patch goes through a 3-D conv stem (``conv5``) and a HetConv
    (``conv6``) to ``FM * 4`` channels, and is tokenised into 8 tokens. The
    LiDAR patch goes through ``lidarConv`` and a non-local block and is pooled
    into 1 token. The 9 tokens feed ``TransformerEncoder`` and the first
    encoded token is classified.

    Args:
        FM: base width; the token / feature dimension is ``FM * 4``.
        NC: number of HSI bands.
        NCLidar: number of LiDAR channels.
        Classes: number of output classes of ``out3``.
        HSIOnly: stored on the module; ``forward`` does not read it and always
            returns ``(logits, features)``.
    """

    def __init__(self, FM, NC, NCLidar, Classes, HSIOnly):
        super(MFT, self).__init__()
        self.HSIOnly = HSIOnly
        embed_dim = FM * 4
        # 3-D conv with a 9-band kernel and no band padding: NC bands -> NC - 8.
        self.conv5 = nn.Sequential(
            nn.Conv3d(1, 8, (9, 3, 3), padding=(0, 1, 1), stride=1),
            CBAM3D(8, 4, 7),
            nn.BatchNorm3d(8),
            nn.ReLU(),
        )
        # 8 * (NC - 8) stacked maps -> embed_dim channels (HetConv ignores ``g``).
        self.conv6 = nn.Sequential(
            HetConv(8 * (NC - 8), embed_dim,
                    p=1,
                    g=embed_dim // 4 if (8 * (NC - 8)) % FM == 0 else embed_dim // 8,
                    ),
            nn.BatchNorm2d(embed_dim),
            nn.ReLU()
        )
        self.last_BandSize = NC // 2 // 2 // 2  # not used by forward()
        self.lidarConv = nn.Sequential(
            nn.Conv2d(NCLidar, 16, 3, 1, 1),
            nn.BatchNorm2d(16),
            nn.GELU(),
            nn.Conv2d(16, 32, 3, 1, 1),
            nn.BatchNorm2d(32),
            nn.GELU(),
            nn.Conv2d(32, embed_dim, 3, 1, 1),
            # Was hard-coded BatchNorm2d(64), which only matched the conv above for FM=16.
            nn.BatchNorm2d(embed_dim),
            nn.GELU(),
        )
        self.ca = TransformerEncoder(embed_dim)
        self.out3 = nn.Linear(embed_dim, Classes)
        self.out2 = nn.Linear(embed_dim, 32)  # not used by forward()
        self.nlocal = NonLocalBlock(embed_dim)
        self.dropout = nn.Dropout(0.1)  # not used by forward()
        torch.nn.init.xavier_uniform_(self.out3.weight)
        torch.nn.init.normal_(self.out3.bias, std=1e-6)
        # Tokeniser weights (formerly hard-coded to width 64 == FM*4 for FM=16):
        # token_wA scores 8 HSI tokens, token_wV projects HSI features;
        # token_wA_L / token_wV_L do the same for the single LiDAR token.
        self.token_wA = nn.Parameter(torch.empty(1, 8, embed_dim),
                                     requires_grad=True)
        torch.nn.init.xavier_normal_(self.token_wA)

        self.token_wV = nn.Parameter(torch.empty(1, embed_dim, embed_dim),
                                     requires_grad=True)
        torch.nn.init.xavier_normal_(self.token_wV)

        self.token_wA_L = nn.Parameter(torch.empty(1, 1, embed_dim),
                                       requires_grad=True)
        torch.nn.init.xavier_normal_(self.token_wA_L)
        self.token_wV_L = nn.Parameter(torch.empty(1, embed_dim, embed_dim),
                                       requires_grad=True)
        torch.nn.init.xavier_normal_(self.token_wV_L)

    @staticmethod
    def _patch_side(t):
        """Return the patch side P of ``t`` given as ``(B, C, P, P)`` or ``(B, C, P*P)``."""
        if t.dim() >= 4:
            return t.shape[-1]
        n_pixels = t.shape[-1]
        side = math.isqrt(n_pixels)
        if side * side != n_pixels:
            raise ValueError(f"last dimension {n_pixels} is not a square number of pixels")
        return side

    def forward(self, x1, x2):
        """Fuse HSI patches ``x1`` with LiDAR patches ``x2``.

        Args:
            x1: HSI patches, ``(B, NC, P*P)`` or ``(B, NC, P, P)``.
            x2: LiDAR patches with the same spatial size P.

        Returns:
            ``(logits, features)`` of shapes ``(B, Classes)`` and ``(B, FM*4)``.
        """
        # The patch side used to come from the notebook-global ``patchsize``;
        # derive it from the input so the model works on its own.
        patchsize = self._patch_side(x1)
        x1 = x1.reshape(x1.shape[0], -1, patchsize, patchsize)
        x2 = x2.reshape(x2.shape[0], -1, patchsize, patchsize)

        # HSI stem: (B, NC, P, P) -> (B, 1, NC, P, P) -> (B, 8, NC-8, P, P)
        x1 = x1.unsqueeze(1)
        x1 = self.conv5(x1)
        x1 = x1.reshape(x1.shape[0], -1, patchsize, patchsize)

        x1 = self.conv6(x1)

        # LiDAR branch -> (B, P*P, E)
        x2 = self.lidarConv(x2)
        x2 = self.nlocal(x2)
        x2 = x2.reshape(x2.shape[0], -1, patchsize ** 2)
        x2 = x2.transpose(-1, -2)

        # LiDAR tokeniser: softmax-weighted pooling of P*P pixels into 1 token.
        wa_L = self.token_wA_L.expand(x1.shape[0], -1, -1)
        wa_L = rearrange(wa_L, 'b h w -> b w h')
        A_L = torch.einsum('bij,bjk->bik', x2, wa_L)
        A_L = rearrange(A_L, 'b h w -> b w h')
        A_L = A_L.softmax(dim=-1)
        wv_L = self.token_wV_L.expand(x2.shape[0], -1, -1)
        VV_L = torch.einsum('bij,bjk->bik', x2, wv_L)
        x2 = torch.einsum('bij,bjk->bik', A_L, VV_L)  # (B, 1, E)

        # HSI tokeniser: P*P pixels pooled into 8 tokens.
        x1 = x1.flatten(2)
        x1 = x1.transpose(-1, -2)

        wa = self.token_wA.expand(x1.shape[0], -1, -1)
        wa = rearrange(wa, 'b h w -> b w h')
        A = torch.einsum('bij,bjk->bik', x1, wa)
        A = rearrange(A, 'b h w -> b w h')
        A = A.softmax(dim=-1)
        wv = self.token_wV.expand(x1.shape[0], -1, -1)
        VV = torch.einsum('bij,bjk->bik', x1, wv)
        T = torch.einsum('bij,bjk->bik', A, VV)  # (B, 8, E)
        # LiDAR token first: the encoder returns token 0.
        x = torch.cat((x2, T), dim=1)
        embeddings = x
        x = self.ca(embeddings)
        x = x.reshape(x.shape[0], -1)
        out2 = x
        out3 = self.out3(x)
        return out3, out2

# Module-level settings: mini-batch size and patch side (11 -> 121 pixels per patch).
batchsize = 64
patchsize = 11
# Smoke test: build MFT(FM=16, NC=144 bands, 1 LiDAR channel, 15 classes) on the
# GPU and print a layer summary for inputs of shape (144, 121) and (1, 121).
model = MFT(16, 144, 1, 15, False).to("cuda")
summary(model, [(144,121),(1,121)], device = 'cuda')


# In [None]:
from sklearn.model_selection import GridSearchCV

import warnings
# GridSearchCV(cv=5) is run on every training mini-batch; silence sklearn's
# warning when a class has fewer than 5 samples in that batch.
warnings.filterwarnings("ignore", message="The least populated class in y has only .* members, which is less than n_splits=5.", category=UserWarning)

# Secondary-modality names train() looks for inside dataset names other than
# Houston/MUUFL (empty: only Houston/MUUFL are handled).
DATA2_List = []
# NOTE(review): an earlier cell already initialised CUDA (model .to("cuda")),
# so setting CUDA_VISIBLE_DEVICES here probably has no effect in this process — confirm.
os.environ["CUDA_VISIBLE_DEVICES"]="1"

datasetNames = ['Houston']

patchsize = 11        # patch side; each sample has patchsize**2 pixels
batchsize = 64        # training mini-batch size
testSizeNumber = 500  # chunk size for test-time inference
EPOCH = 300
BandSize = 1          # shadowed by the loop variable inside train()
LR = 5e-4
FM = 16               # MFT base width (feature dimension FM*4)
HSIOnly = False
FileName = 'MFT'      # tag used in checkpoint / report file names

w1 = 0.10             # weight of the SVM margin term added to the cross-entropy loss

# Candidate C values for the per-batch linear-SVC grid search.
param_grid = {'C': [0.1, 1, 10, 100, 1000, 10000, 100000, 1000000]}
def AA_andEachClassAccuracy(confusion_matrix):
    """Per-class accuracy (recall) and their mean from a confusion matrix.

    Args:
        confusion_matrix: square NumPy array, rows = true class, columns = predicted.

    Returns:
        ``(each_acc, average_acc)``: per-class diagonal / row-sum (classes with an
        empty row get 0 instead of NaN) and the mean of those values.
    """
    correct = np.diag(confusion_matrix)
    totals = confusion_matrix.sum(axis=1)
    per_class = np.nan_to_num(correct / totals)
    return per_class, per_class.mean()

def reports(xtest, xtest2, ytest, name, model, iterNum):
    """Evaluate ``model`` with its saved SVM head on a test set.

    The backbone runs on the GPU without autograd, in chunks of
    ``testSizeNumber`` (module-level setting). Its second output is classified
    by the SVC pickled at ``<name>/best_model_HSIAMS_all_<iterNum>.h5``.

    Args:
        xtest, xtest2: HSI / LiDAR test patches (indexed along the first axis).
        ytest: true labels.
        name: directory holding the pickled SVC.
        model: network returning ``(logits, features)``.
        iterNum: run index used in the SVC file name.

    Returns:
        ``(confusion, OA %, per-class accuracy %, AA %, kappa * 100)``.
    """
    # Load the classifier once (it was re-read, via an unclosed file, for every chunk).
    # NOTE: pickle.load can execute arbitrary code; only load files this script wrote.
    with open(name + '/best_model_HSIAMS_all_' + str(iterNum) + '.h5', 'rb') as fh:
        svm_classifier = pickle.load(fh)

    n_samples = len(ytest)
    pred_y = np.empty(n_samples, dtype=np.float32)
    with torch.no_grad():
        # Stepping range covers the trailing partial chunk and also works when
        # n_samples < testSizeNumber (previously a NameError on the loop index).
        for start in range(0, n_samples, testSizeNumber):
            stop = min(start + testSizeNumber, n_samples)
            _, features = model(xtest[start:stop].cuda(), xtest2[start:stop].cuda())
            pred_y[start:stop] = svm_classifier.predict(features.cpu().numpy())
    pred_y = torch.from_numpy(pred_y).long()

    oa = accuracy_score(ytest, pred_y)
    confusion = confusion_matrix(ytest, pred_y)
    each_acc, aa = AA_andEachClassAccuracy(confusion)
    kappa = cohen_kappa_score(ytest, pred_y)
    return confusion, oa*100, each_acc*100, aa*100, kappa*100

def set_seed(seed):
    """Seed the torch CPU RNG, all CUDA RNGs and NumPy's global RNG with ``seed``."""
    for seeder in (torch.manual_seed, torch.cuda.manual_seed_all, np.random.seed):
        seeder(seed)

# Number of training samples per class (15 classes); inverse counts are used as class weights.
sample_counts = [198, 190, 227, 188, 186, 196, 196, 191, 193, 191, 234, 192, 246, 216, 227]

# {class_index: 1 / count}, passed as ``class_weight`` to the sklearn SVCs.
class_weights = {i: 1.0 / count for i, count in enumerate(sample_counts)}

class CustomLoss(nn.Module):
    """Cross-entropy weighted by inverse class frequency (``1 / count``).

    Args:
        counts: per-class sample counts; defaults to module-level ``sample_counts``.
        device: device that logits, targets and weights are moved to before the
            loss is computed (default ``'cuda:0'``, the previous hard-coded device).
    """

    def __init__(self, counts=None, device='cuda:0'):
        super(CustomLoss, self).__init__()
        if counts is None:
            counts = sample_counts
        # Computed once rather than rebuilt on every forward call.
        self.weight = torch.tensor([1 / count for count in counts])
        self.device = device

    def forward(self, output, target):
        """Return the mean weighted cross-entropy of logits ``output`` vs class indices ``target``."""
        return F.cross_entropy(
            output.to(self.device),
            target.to(self.device),
            weight=self.weight.to(self.device),
        )

def _predict_test_set(model, svm_classifier, test_hsi, test_lidar, n_samples):
    """Predict labels for the first ``n_samples`` test patches (GPU, no autograd).

    Runs in chunks of ``testSizeNumber``. With ``HSIOnly`` the arg-max of the
    model output is used; otherwise the model's second output is classified by
    ``svm_classifier``. Returns a float32 NumPy array of predicted class indices.
    """
    pred_y = np.empty(n_samples, dtype='float32')
    with torch.no_grad():
        # Stepping range handles the trailing partial chunk and n_samples < chunk size.
        for start in range(0, n_samples, testSizeNumber):
            stop = min(start + testSizeNumber, n_samples)
            hsi = test_hsi[start:stop].cuda()
            lidar = test_lidar[start:stop].cuda()
            if HSIOnly:
                # NOTE(review): MFT.forward returns a (logits, features) tuple, so
                # torch.max on it fails; this branch is unusable as written.
                logits = model(hsi, lidar)
                pred_y[start:stop] = torch.max(logits, 1)[1].squeeze().cpu()
            else:
                _, features = model(hsi, lidar)
                pred_y[start:stop] = svm_classifier.predict(features.cpu().numpy())
    return pred_y


def train():
    """Train and evaluate MFT with a linear-SVM head on every dataset in ``datasetNames``.

    For each dataset the 11x11 HSI / LiDAR train and test patches are loaded from
    ``./../<name>11x11/``, and 3 independent runs are made. Each run:

    * trains a fresh MFT for ``EPOCH`` epochs. On every batch a linear SVC is
      fitted on the model's feature output (C picked by 5-fold GridSearchCV).
      The loss is class-weighted cross-entropy plus ``w1`` times a margin term
      built from that SVC's coefficients;
    * every 50 steps, evaluates model + current SVC on the test set and
      checkpoints the weights and SVC whenever test accuracy improves;
    * reloads the best weights, computes OA/AA/kappa with ``reports`` and
      saves the whole model.

    Metrics of all runs are written with ``record.record_output``.
    """
    for BandSize in [1]:
        for datasetName in datasetNames:
            print("----------------------------------Training for ",datasetName," ---------------------------------------------")
            os.makedirs(datasetName, exist_ok=True)
            data1Name = ''
            data2Name = ''
            if datasetName in ["Houston", "MUUFL"]:
                data1Name = datasetName
                data2Name = "LIDAR"
            else:
                for dataName in DATA2_List:
                    dataNameToCheck = re.compile(dataName)
                    matchObj = dataNameToCheck.search(datasetName)
                    if matchObj:
                        data1Name = datasetName.replace(dataName, "")
                        data2Name = dataName

            HSI = io.loadmat('./../' + data1Name + '11x11/HSI_Train.mat')
            TrainPatch = HSI['block']
            TrainPatch = TrainPatch.astype(np.float32)
            NC = TrainPatch.shape[3]

            LIDAR = io.loadmat('./../' + data1Name + '11x11/' + data2Name + '_Train.mat')
            TrainPatch2 = LIDAR['block2']
            TrainPatch2 = TrainPatch2.astype(np.float32)
            TrainPatch2 = TrainPatch2.reshape(-1, 11, 11, 1)
            NCLIDAR = TrainPatch2.shape[3]

            label = io.loadmat('./../' + data1Name + '11x11/HSI_Train_label.mat')
            TrLabel = label['Tr_label']

            HSI = io.loadmat('./../' + data1Name + '11x11/HSI_Test.mat')
            TestPatch = HSI['block1']
            TestPatch = TestPatch.astype(np.float32)

            LIDAR = io.loadmat('./../' + data1Name + '11x11/' + data2Name + '_Test.mat')
            TestPatch2 = LIDAR['block3']
            TestPatch2 = TestPatch2.astype(np.float32)
            TestPatch2 = TestPatch2.reshape(-1, 11, 11, 1)

            label = io.loadmat('./../' + data1Name + '11x11/HSI_Test_label.mat')
            TsLabel = label['Te_label']

            # (N, H, W, C) -> (N, C, H*W) float tensors; labels shifted to 0-based.
            TrainPatch1 = torch.from_numpy(TrainPatch).to(torch.float32)
            TrainPatch1 = TrainPatch1.permute(0, 3, 1, 2)
            TrainPatch1 = TrainPatch1.reshape(TrainPatch1.shape[0], TrainPatch1.shape[1], -1).to(torch.float32)
            TrainPatch2 = torch.from_numpy(TrainPatch2).to(torch.float32)
            TrainPatch2 = TrainPatch2.permute(0, 3, 1, 2)
            TrainPatch2 = TrainPatch2.reshape(TrainPatch2.shape[0], TrainPatch2.shape[1], -1).to(torch.float32)
            TrainLabel1 = torch.from_numpy(TrLabel) - 1
            TrainLabel1 = TrainLabel1.long()
            TrainLabel1 = TrainLabel1.reshape(-1)

            TestPatch1 = torch.from_numpy(TestPatch).to(torch.float32)
            TestPatch1 = TestPatch1.permute(0, 3, 1, 2)
            TestPatch1 = TestPatch1.reshape(TestPatch1.shape[0], TestPatch1.shape[1], -1).to(torch.float32)
            TestPatch2 = torch.from_numpy(TestPatch2).to(torch.float32)
            TestPatch2 = TestPatch2.permute(0, 3, 1, 2)
            TestPatch2 = TestPatch2.reshape(TestPatch2.shape[0], TestPatch2.shape[1], -1).to(torch.float32)
            TestLabel1 = torch.from_numpy(TsLabel) - 1
            TestLabel1 = TestLabel1.long()
            TestLabel1 = TestLabel1.reshape(-1)

            Classes = len(np.unique(TrainLabel1))
            dataset = dataf.TensorDataset(TrainPatch1, TrainPatch2, TrainLabel1)
            if data1Name in ['Berlin']:
                train_loader = dataf.DataLoader(dataset, batch_size=batchsize, shuffle=True, num_workers=0)
            else:
                train_loader = dataf.DataLoader(dataset, batch_size=batchsize, shuffle=True, num_workers=4)  # data loader
            print("HSI Train data shape = ", TrainPatch1.shape)
            print(data2Name + " Train data shape = ", TrainPatch2.shape)
            print("Train label shape = ", TrainLabel1.shape)

            print("HSI Test data shape = ", TestPatch1.shape)
            print(data2Name + " Test data shape = ", TestPatch2.shape)
            print("Test label shape = ", TestLabel1.shape)

            print("Number of Classes = ", Classes)
            KAPPA = []
            OA = []
            AA = []
            ELEMENT_ACC = np.zeros((3, Classes))

            set_seed(42)
            for iterNum in range(3):
                model = MFT(FM, NC, NCLIDAR, Classes, HSIOnly).cuda()
                summary(model, [(NC, patchsize**2), (NCLIDAR, patchsize**2)])

                optimizer = torch.optim.Adam([
                    {'params': model.parameters(), 'lr': LR, 'weight_decay': 5e-3},
                ])
                criterion = CustomLoss()
                scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=50, gamma=0.9)
                BestAcc = 0
                best_epoch = 0  # defined even if test accuracy never rises above 0

                torch.cuda.synchronize()
                start = time.time()
                for epoch in range(EPOCH):

                    for step, (b_x1, b_x2, b_y) in enumerate(train_loader):

                        b_x1 = b_x1.cuda()
                        b_y = b_y.cuda()
                        b_x2 = b_x2.cuda()

                        out, intermediate_output = model(b_x1, b_x2)

                        # Fit a linear SVC on this batch's (detached) features.
                        features_np = intermediate_output.cpu().detach().numpy()
                        labels_np = b_y.cpu().detach().numpy()
                        svm_classifier = svm.SVC(kernel='linear', tol=1e-9, class_weight=class_weights)
                        grid_search = GridSearchCV(svm_classifier, param_grid, cv=5)
                        grid_search.fit(features_np, labels_np)
                        best_C = grid_search.best_params_['C']
                        svm_classifier = svm.SVC(kernel='linear', C=best_C, tol=1e-9, class_weight=class_weights)
                        svm_classifier.fit(features_np, labels_np)
                        weights = torch.tensor(svm_classifier.coef_, dtype=torch.float32).to('cuda')
                        bias = torch.tensor(svm_classifier.intercept_, dtype=torch.float32).to('cuda')
                        loss = criterion(out, b_y)
                        # Sum of |w.x + b| / ||w|| over every SVC decision function and
                        # every sample of the *actual* batch (the old loop hard-coded
                        # range(64), tying it to batchsize == 64).
                        decision = torch.matmul(weights, intermediate_output.transpose(0, 1)) + bias.unsqueeze(1)
                        total_hinge_loss = torch.sum(torch.abs(decision) / torch.norm(weights))
                        average_hinge_loss = torch.mean(torch.max(torch.tensor(0), 1 - total_hinge_loss))
                        loss = loss + w1 * average_hinge_loss

                        optimizer.zero_grad()
                        loss.backward()
                        optimizer.step()

                        if step % 50 == 0:
                            model.eval()
                            pred_y = _predict_test_set(model, svm_classifier, TestPatch1, TestPatch2, len(TestLabel1))
                            pred_y = torch.from_numpy(pred_y).long()
                            accuracy = torch.sum(pred_y == TestLabel1).type(torch.FloatTensor) / TestLabel1.size(0)

                            print('Epoch: ', epoch, '| train loss: %.4f' % loss.data.cpu().numpy(), '| test accuracy: %.4f' % (accuracy*100))
                            if accuracy > BestAcc:
                                BestAcc = accuracy
                                best_epoch = epoch

                                # Checkpoint the weights and the matching SVC together.
                                torch.save(model.state_dict(), datasetName + '/net_params_AMS_all' + FileName + '.pkl')
                                with open(datasetName + '/best_model_HSIAMS_all_' + str(iterNum) + '.h5', 'wb') as fh:
                                    pickle.dump(svm_classifier, fh)
                            model.train()
                    scheduler.step()
                print(f"The best epoch is {best_epoch+1} epoch , The accuracy is {BestAcc}")

                torch.cuda.synchronize()
                end = time.time()
                print(end - start)

                model.load_state_dict(torch.load(datasetName + '/net_params_AMS_all' + FileName + '.pkl'))

                model.eval()
                confusion, oa, each_acc, aa, kappa = reports(TestPatch1, TestPatch2, TestLabel1, datasetName, model, iterNum)
                KAPPA.append(kappa)
                OA.append(oa)
                AA.append(aa)
                ELEMENT_ACC[iterNum, :] = each_acc
                torch.save(model, datasetName + '/best_model_AMS_all' + FileName + '_BandSize' + str(BandSize) + '_Iter' + str(iterNum) + '.pt')
                # The best SVC was already saved next to the best weights. Re-dumping the
                # last-step svm_classifier here (as before) overwrote it with a head that
                # does not match the saved model, so that dump has been removed.

                print("OA = ", oa)
                print("AA = ", aa)
                print("KAPPA = ", kappa)

            print("----------" + datasetName + " Training Finished -----------")
            record.record_output(OA, AA, KAPPA, ELEMENT_ACC, './' + datasetName + '/' + FileName + '_BandSize' + str(BandSize) + '_Report_AMS_all' + datasetName + '.txt')

# Run the full training / evaluation pipeline when this cell executes.
train()
