In [None]:
# Install einops and ipdb, both of which are imported by the import cell of this notebook.
!pip install einops
!pip install ipdb

In [None]:
# Mount Google Drive at /content/drive (Colab only); the dataset and checkpoint
# paths used later in this notebook live under that mount point.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
import numpy as np
from PIL import Image
import skimage
from skimage import io, measure
import random
import scipy.io as sio
import matplotlib
import matplotlib.pyplot as plt
import torch
import torchvision
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import cv2
from collections import  Counter
from __future__ import print_function
import glob
from itertools import chain
import os
import zipfile
import pandas as pd
from sklearn.model_selection import train_test_split
from torch.optim.lr_scheduler import StepLR
from torch.utils.data import DataLoader, Dataset
from torchvision import datasets, transforms
from tqdm.notebook import tqdm

import math
import numbers
from torch.nn import init
from einops import rearrange
from functools import partial
from ipdb import set_trace as st

# The flag is spelled `enabled`; assigning to `enable` only creates an unused attribute.
torch.backends.cudnn.enabled = True

In [None]:
def acc_score(pre, lab):
  """Return the fraction of positions at which `pre` and `lab` agree.

  The first `pre.shape[0]` entries are compared one by one.

  Args:
    pre: predicted labels, indexable and exposing `.shape`.
    lab: reference labels, indexable, at least as long as `pre`.

  Returns:
    Accuracy as a float in [0, 1].
  """
  total = pre.shape[0]
  mismatches = sum(1 for idx in range(total) if pre[idx] != lab[idx])
  return (total - mismatches) / total

def evaluate(gtImg, tstImg):
  """Print change-detection metrics for a predicted change map.

  Args:
    gtImg:  ground truth image (array-like of labels, 0 / 1).
    tstImg: change map to score, with the same number of elements as `gtImg`.

  Prints (nothing is returned):
    FP  - false alarms: ground truth is 0 but the change map is not 0
    FN  - missed alarms: ground truth is not 0 but the change map is 0
    OE  - overall error, FP + FN
    PCC - overall accuracy, in percent
    KC  - kappa coefficient, in percent (nan when chance agreement equals 1)

  Raises:
    ValueError: if the inputs are empty or differ in size.
  """
  gt = np.asarray(gtImg).ravel()
  tst = np.asarray(tstImg).ravel()
  if gt.size == 0:
    raise ValueError("evaluate() needs at least one sample")
  if gt.size != tst.size:
    raise ValueError(
        "evaluate() got {} ground-truth samples but {} predictions".format(gt.size, tst.size))

  ylen = gt.size
  # Plain Python ints keep the products in the kappa formula free of fixed-width overflow.
  label_0 = int(np.sum(gt == 0))
  label_1 = int(np.sum(gt == 1))
  print("label_0:", label_0)
  print("label_1:", label_1)

  # Vectorized counting instead of a per-pixel Python loop.
  gt_is_zero = gt == 0
  tst_is_zero = tst == 0
  FA = int(np.count_nonzero(gt_is_zero & ~tst_is_zero))
  MA = int(np.count_nonzero(~gt_is_zero & tst_is_zero))

  OE = FA + MA
  PCC = 1 - OE / ylen
  # Expected chance agreement: (predicted count * true count) summed over both classes.
  PRE = ((label_1 + FA - MA) * label_1 + (label_0 + MA - FA) * label_0) / (ylen * ylen)
  # Kappa is undefined when chance agreement is already 1; report nan instead of dividing by zero.
  KC = (PCC - PRE) / (1 - PRE) if PRE != 1 else float('nan')
  print(' Change detection results ==>')
  print(' ... ... FP:  ', FA)
  print(' ... ... FN:  ', MA)
  print(' ... ... OE:  ', OE)
  print(' ... ... PCC: ', format(PCC*100, '.2f'))
  print(' ... ... KC: ', format(KC*100, '.2f'))

In [None]:
# Training settings
windowSize = 6  # side length of the square patch cut around every sampled pixel
epochs = 60  # number of passes over the training loader
lr = 3e-5  # learning rate handed to the Adam optimizer
gamma = 0.7  # decay factor handed to the StepLR scheduler
seed = 11  # value passed to seed_everything
batch_size = 32  # batch size used for both the training and the test loader
def seed_everything(seed):
  """Seed every random number generator this notebook relies on.

  Seeds Python's `random`, NumPy's global generator and PyTorch (CPU and all
  CUDA devices), exports PYTHONHASHSEED, and asks cuDNN for deterministic,
  non-autotuned kernels.

  Args:
    seed: integer seed shared by all generators.
  """
  random.seed(seed)
  # Only affects interpreters started after this point; hashing in the current process is unchanged.
  os.environ['PYTHONHASHSEED'] = str(seed)
  np.random.seed(seed)
  torch.manual_seed(seed)
  torch.cuda.manual_seed(seed)
  torch.cuda.manual_seed_all(seed)
  torch.backends.cudnn.deterministic = True
  # The cuDNN autotuner may select different algorithms from run to run; keep it off.
  torch.backends.cudnn.benchmark = False
seed_everything(seed)
# Use the GPU when one is visible; otherwise fall back to the CPU so the notebook still runs.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [None]:
# # Collect the pixel indices and save them as .mat files; each entry is stored as (x, y)
# # im_gt = sio.loadmat('/content/drive/MyDrive/wzy/datasets/GTSa.mat')['GT']
# im_gt = sio.loadmat('/content/drive/MyDrive/wzy/datasets/GTFarm.mat')['label']
# im_gt = sio.loadmat('/content/drive/MyDrive/wzy/datasets/GTRiver.mat')['gt']

# def getIndexMat(im_gt):

#   ele_num1 = np.sum(im_gt==1)
#   ele_num2 = np.sum(im_gt==2)

#   index_1 = []
#   index_2 = []
#   index_all = []
#   index_num1 = 0
#   index_num2 = 0
#   for i in range(im_gt.shape[0]):
#     for j in range(im_gt.shape[1]):
#       if(im_gt[i][j] == 1):
#         index_1.append([i, j])
#         index_num1 += 1
#       elif im_gt[i][j] == 2:
#         index_2.append([i, j])
#         index_num2 += 1
#       index_all.append([i,j])
#   mat_path1 = '/content/drive/MyDrive/wzy/datasets/River_index_1.mat'
#   mat_path2 = '/content/drive/MyDrive/wzy/datasets/River_index_2.mat'
#   mat_path3 = '/content/drive/MyDrive/wzy/datasets/River_index_all.mat'
#   io.savemat(mat_path1, {'index_1': index_1})
#   io.savemat(mat_path2, {'index_2': index_2})
#   io.savemat(mat_path3, {'index_all': index_all})
# getIndexMat(im_gt)

In [None]:
""" Training dataset"""
from torchvision.transforms import ToTensor
class TrainTestDS(torch.utils.data.Dataset):
  """Dataset of co-located patches cut from two zero-padded image cubes.

  Every sample is addressed by one (row, col) entry of `pos`. It yields the
  `windowSize` x `windowSize` window that starts at that position in each
  padded cube, plus the ground-truth value at (row, col) minus one.
  """

  def __init__(self, hsi_1, hsi_2, gt, pos, windowSize):
    self.windowSize = windowSize
    self.pad = windowSize // 2
    self.pos = pos
    self.gt = gt
    # Zero-pad the first two axes by `pad` on each side; the third axis is left untouched.
    pad_spec = ((self.pad, self.pad), (self.pad, self.pad), (0, 0))
    self.im1 = np.pad(hsi_1, pad_spec, 'constant', constant_values=0)
    self.im2 = np.pad(hsi_2, pad_spec, 'constant', constant_values=0)

  def __len__(self):
    """Number of samples, i.e. the number of rows in `pos`."""
    return self.pos.shape[0]

  def __getitem__(self, index):
    """Return (patch_1, patch_2, label) for the `index`-th position."""
    row, col = self.pos[index, :]
    size = self.windowSize
    to_tensor = ToTensor()
    patch_1 = to_tensor(self.im1[row: row + size, col: col + size]).float()
    patch_2 = to_tensor(self.im2[row: row + size, col: col + size]).float()
    # The stored ground-truth value is shifted down by one before it is returned.
    label = torch.tensor(self.gt[row, col] - 1).long()
    return patch_1, patch_2, label

In [None]:
def getData(dataName, hsi_path, windowSize, keys, num_class1=6000, num_class2=3000):
  """Load a two-image dataset from .mat files and build the train/test loaders.

  Files read from `hsi_path`:
    <dataName>1.mat, <dataName>2.mat   the two image cubes
    GT<dataName>.mat                   the ground-truth map
    <dataName>_index_1.mat             variable 'index_1'
    <dataName>_index_2.mat             variable 'index_2'
    <dataName>_index_all.mat           variable 'index_all'
  The index files hold (row, col) positions; presumably 'index_1' / 'index_2'
  list the pixels whose ground-truth value is 1 / 2 and 'index_all' lists every
  pixel (see the commented-out getIndexMat cell) - TODO confirm for each dataset.

  Args:
    dataName: dataset prefix used in the file names.
    hsi_path: directory containing the files, with or without a trailing slash.
    windowSize: patch side length forwarded to TrainTestDS.
    keys: variable names inside the .mat files, ordered
      (first image, second image, ground truth).
    num_class1: number of training positions taken from 'index_1'.
    num_class2: number of training positions taken from 'index_2'.

  Returns:
    (train_loader, test_loader, h, w, c), where h and w are the first two
    dimensions of the ground-truth map and c is the third dimension of the
    first image.

  Raises:
    ValueError: if an index file holds fewer positions than requested, or if
      'index_all' does not have h * w rows.
  """
  # os.path.join gives the same result whether or not hsi_path ends with a separator.
  im1 = sio.loadmat(os.path.join(hsi_path, dataName + '1.mat'))[keys[0]]
  im2 = sio.loadmat(os.path.join(hsi_path, dataName + '2.mat'))[keys[1]]
  im_gt = sio.loadmat(os.path.join(hsi_path, 'GT' + dataName + '.mat'))[keys[2]]
  im_index_1 = sio.loadmat(os.path.join(hsi_path, dataName + '_index_1.mat'))['index_1']
  im_index_2 = sio.loadmat(os.path.join(hsi_path, dataName + '_index_2.mat'))['index_2']
  im_index_all = sio.loadmat(os.path.join(hsi_path, dataName + '_index_all.mat'))['index_all']
  h = im_gt.shape[0]
  w = im_gt.shape[1]
  c = im1.shape[2]

  test_len = h * w  # the test set covers every pixel position
  print(test_len)

  if im_index_all.shape[0] != test_len:
    raise ValueError(
        "index_all has {} rows but the ground truth has {} pixels".format(im_index_all.shape[0], test_len))
  if im_index_1.shape[0] < num_class1:
    raise ValueError(
        "index_1 has only {} positions, {} requested".format(im_index_1.shape[0], num_class1))
  if im_index_2.shape[0] < num_class2:
    raise ValueError(
        "index_2 has only {} positions, {} requested".format(im_index_2.shape[0], num_class2))

  # Shuffle each pool, take the requested number from the front, then shuffle the
  # combined training positions.
  np.random.shuffle(im_index_1)
  np.random.shuffle(im_index_2)
  pdata = np.concatenate((im_index_1[:num_class1, :], im_index_2[:num_class2, :]), axis=0)
  np.random.shuffle(pdata)

  pdata = pdata.astype(np.int64)
  tdata = im_index_all.astype(np.int64)

  trainset = TrainTestDS(im1, im2, im_gt, pdata, windowSize)
  train_loader = torch.utils.data.DataLoader(dataset=trainset, batch_size=batch_size, shuffle=True, num_workers=2)
  testset = TrainTestDS(im1, im2, im_gt, tdata, windowSize)
  test_loader = torch.utils.data.DataLoader(dataset=testset, batch_size=batch_size, shuffle=False, num_workers=2)

  return train_loader, test_loader, h, w, c

In [None]:
# Variable names stored inside each dataset's .mat files, ordered
# (first image, second image, ground truth) - the order in which getData indexes `keys`.
keysFarm = ['imgh', 'imghl', 'label']
keysRiver = ['river_before', 'river_after', 'gt']
keysSa = ['T1', 'T2', 'GT']

# Folder that holds the dataset .mat files handed to getData.
hsi_path = '/content/drive/MyDrive/wzy/HSI_CD/model/datasets/'

# Pick exactly one (dataKey, dataName) pair; the other pairs stay commented out.
# dataKey = keysFarm
# dataName = 'Farm'

dataKey = keysSa
dataName = 'SSa'

# dataKey = keysRiver
# dataName = 'River'

# H, W and C are reused by later cells (model construction and the change-map plot).
train_loader, test_loader, H, W, C = getData(dataName, hsi_path, windowSize, dataKey)

In [None]:
class PixelEmbedding(nn.Module):
    """Linear embedding of each token plus a learned additive positional term.

    The positional parameter has shape (1, windowSize*windowSize, embed_dim),
    where `windowSize` is the notebook-level patch size, so the sequence length
    of the input must be broadcast-compatible with windowSize*windowSize.

    Args:
        num_bands: size of the last input dimension.
        embed_dim: size of the last output dimension.
    """

    def __init__(self, num_bands, embed_dim):
        super().__init__()
        self.fc = nn.Linear(num_bands, embed_dim)
        initial_pos = torch.randn(1, windowSize * windowSize, embed_dim)
        self.positional_encoding = nn.Parameter(initial_pos)
        nn.init.trunc_normal_(self.positional_encoding, std=0.02)

    def forward(self, x):
        # (batch, seq_len, num_bands) -> (batch, seq_len, embed_dim)
        embedded = self.fc(x)
        embedded += self.positional_encoding
        return embedded

In [None]:
def drop_path(x, drop_prob: float = 0., training: bool = False):
    """Randomly zero whole samples of `x` and rescale the survivors.

    When `training` is False or `drop_prob` is 0 the input is returned as is.
    Otherwise each sample along dimension 0 is kept with probability
    1 - drop_prob, and kept samples are divided by that probability.
    """
    if not training or drop_prob == 0.:
        return x
    keep_prob = 1 - drop_prob
    # One random draw per sample, broadcast over every remaining dimension.
    mask_shape = (x.shape[0],) + (1,) * (x.ndim - 1)
    mask = torch.rand(mask_shape, dtype=x.dtype, device=x.device)
    # floor(keep_prob + U[0, 1)) is 1 with probability keep_prob, else 0.
    mask = torch.floor(mask + keep_prob)
    return x.div(keep_prob) * mask


class DropPath(nn.Module):
    """Module wrapper around `drop_path`, active only in training mode.

    Args:
        drop_prob: probability of dropping a sample; `None` (the default) or 0
            disables dropping altogether.
    """

    def __init__(self, drop_prob=None):
        super(DropPath, self).__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        # With the default drop_prob=None, drop_path would evaluate `1 - None`
        # in training mode and raise TypeError; treat None/0 as the identity.
        if not self.drop_prob or not self.training:
            return x
        return drop_path(x, self.drop_prob, self.training)

In [None]:
class GLA(nn.Module):
    """Attention layer whose heads are split into a windowed and a pooled group.

    `h_heads` heads run self-attention inside non-overlapping ws x ws windows
    (`loc`); `l_heads` heads let every position attend to keys/values computed
    from an average-pooled copy of the input (`glo`). When both groups exist,
    their outputs are concatenated along the channel axis.

    Args:
        dim: channel size of the input and output tokens; must be divisible by `num_heads`.
        num_heads: total number of attention heads.
        qkv_bias: whether the query/key/value projections carry a bias.
        qk_scale: optional override for the attention scale (default head_dim ** -0.5).
        attn_drop: accepted but not used by this implementation.
        proj_drop: accepted but not used by this implementation.
        window_size: window side for `loc` and pooling kernel/stride for `glo`;
            a value of 1 assigns every head to `glo` without pooling.
        alpha: fraction of heads given to `glo` (l_heads = int(num_heads * alpha)).
    """

    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., window_size=3, alpha=0.5):
        super().__init__()
        assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}."
        head_dim = int(dim/num_heads)
        self.dim = dim

        # Heads and channels of the pooled ("glo") group.
        self.l_heads = int(num_heads * alpha)
        self.l_dim = self.l_heads * head_dim

        # The remaining heads and channels form the windowed ("loc") group.
        self.h_heads = num_heads - self.l_heads
        self.h_dim = self.h_heads * head_dim

        self.ws = window_size

        # With 1x1 windows everything is routed through the "glo" group.
        if self.ws == 1:
            self.h_heads = 0
            self.h_dim = 0
            self.l_heads = num_heads
            self.l_dim = dim

        self.scale = qk_scale or head_dim ** -0.5

        if self.l_heads > 0:
            if self.ws != 1:
                # Pooling applied to the input before keys/values are computed in `glo`.
                self.sr = nn.AvgPool2d(kernel_size=window_size, stride=window_size)
            self.l_q = nn.Linear(self.dim, self.l_dim, bias=qkv_bias)
            self.l_kv = nn.Linear(self.dim, self.l_dim * 2, bias=qkv_bias)
            self.l_proj = nn.Linear(self.l_dim, self.l_dim)

        if self.h_heads > 0:
            self.h_qkv = nn.Linear(self.dim, self.h_dim * 3, bias=qkv_bias)
            self.h_proj = nn.Linear(self.h_dim, self.h_dim)

    def loc(self, x):
        """Self-attention inside each ws x ws window; returns (B, H, W, h_dim).

        H and W must be multiples of `self.ws`, otherwise the reshape below fails.
        """
        B, H, W, C = x.shape
        h_group, w_group = H // self.ws, W // self.ws

        total_groups = h_group * w_group

        # (B, H, W, C) -> (B, h_group, w_group, ws, ws, C): one ws x ws window per group.
        x = x.reshape(B, h_group, self.ws, w_group, self.ws, C).transpose(2, 3)

        qkv = self.h_qkv(x).reshape(B, total_groups, -1, 3, self.h_heads, self.h_dim // self.h_heads).permute(3, 0, 1, 4, 2, 5)
        q, k, v = qkv[0], qkv[1], qkv[2]  # B, hw, n_head, ws*ws, head_dim
        attn = (q @ k.transpose(-2, -1)) * self.scale  # B, hw, n_head, ws*ws, ws*ws
        attn = attn.softmax(dim=-1)
        attn = (attn @ v).transpose(2, 3).reshape(B, h_group, w_group, self.ws, self.ws, self.h_dim)
        # Undo the window partition to get back to the (B, H, W, h_dim) layout.
        x = attn.transpose(2, 3).reshape(B, h_group * self.ws, w_group * self.ws, self.h_dim)

        x = self.h_proj(x)
        return x

    def glo(self, x):
        """Attention from every position to pooled keys/values; returns (B, H, W, l_dim)."""
        B, H, W, C = x.shape

        # Queries come from every position: (B, l_heads, H*W, head_dim).
        q = self.l_q(x).reshape(B, H * W, self.l_heads, self.l_dim // self.l_heads).permute(0, 2, 1, 3)

        if self.ws > 1:
            # Keys/values come from the average-pooled map, flattened to (B, n_pooled, C).
            x_ = x.permute(0, 3, 1, 2)
            x_ = self.sr(x_).reshape(B, C, -1).permute(0, 2, 1)
            kv = self.l_kv(x_).reshape(B, -1, 2, self.l_heads, self.l_dim // self.l_heads).permute(2, 0, 3, 1, 4)
        else:
            kv = self.l_kv(x).reshape(B, -1, 2, self.l_heads, self.l_dim // self.l_heads).permute(2, 0, 3, 1, 4)
        k, v = kv[0], kv[1]
        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)

        x = (attn @ v).transpose(1, 2).reshape(B, H, W, self.l_dim)
        x = self.l_proj(x)
        return x

    def forward(self, x, H, W):
        """Apply the layer to tokens `x` of shape (B, N, C) with N == H * W; returns (B, N, C)."""
        B, N, C = x.shape

        x = x.reshape(B, H, W, C)

        # Only the pooled group exists.
        if self.h_heads == 0:
            x = self.glo(x)
            return x.reshape(B, N, C)

        # Only the windowed group exists.
        if self.l_heads == 0:
            x = self.loc(x)
            return x.reshape(B, N, C)

        loc_out = self.loc(x)
        glo_out = self.glo(x)

        # Channel-wise concatenation: windowed channels first, pooled channels second.
        x = torch.cat((loc_out, glo_out), dim=-1)
        x = x.reshape(B, N, C)

        return x

In [None]:
# Cross Gated Feed-Forward Network
class FeedForward(nn.Module):
    """Cross-gated convolutional feed-forward network for token sequences.

    Tokens are laid out on a windowSize x windowSize grid (notebook-level
    `windowSize`), expanded by a 1x1 convolution to two groups of hidden
    channels, filtered by a depthwise 3x3 convolution, gated against each other
    with GELU and projected back to `dim` channels.

    Args:
        dim: channel size of the input and output tokens.
        ffn_expansion_factor: hidden width as a multiple of `dim` (truncated to int).
        bias: whether the convolutions carry a bias.
    """

    def __init__(self, dim, ffn_expansion_factor=2.66, bias=True):
        super().__init__()
        hidden = int(dim * ffn_expansion_factor)
        doubled = hidden * 2
        self.project_in = nn.Conv2d(dim, doubled, kernel_size=1, bias=bias)
        self.dwconv = nn.Conv2d(doubled, doubled, kernel_size=3, stride=1, padding=1,
                                groups=doubled, bias=bias)
        self.project_out = nn.Conv2d(hidden, dim, kernel_size=1, bias=bias)

    def forward(self, x):
        # (B, h*w, C) -> (B, C, h, w) so the convolutions see a spatial grid.
        grid = rearrange(x, 'b (h w) c -> b c h w', h=windowSize, w=windowSize)

        first, second = self.dwconv(self.project_in(grid)).chunk(2, dim=1)
        # Each half gates the other one.
        gated = F.gelu(second) * first + F.gelu(first) * second
        projected = self.project_out(gated)

        # Back to the token layout (B, h*w, C).
        return rearrange(projected, 'b c h w -> b (h w) c')

In [None]:
class Block(nn.Module):
    """Pre-norm transformer block: GLA attention, then a gated convolutional FFN.

    Both sub-layers sit on residual paths with optional stochastic depth. GLA is
    always built with window_size=3 and alpha=0.5, and the token grid is taken
    to be windowSize x windowSize (notebook-level `windowSize`). `act_layer` is
    accepted for signature compatibility but is not used.
    """

    def __init__(self,
                 dim,
                 num_heads,
                 mlp_ratio=2.66,
                 qkv_bias=False,
                 qk_scale=None,
                 drop_ratio=0.,
                 attn_drop_ratio=0.,
                 drop_path_ratio=0.,
                 act_layer=nn.GELU,
                 norm_layer=nn.LayerNorm):
        super().__init__()
        self.norm1 = norm_layer(dim)
        # GLA takes the place of plain multi-head attention in this block.
        self.gla = GLA(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            attn_drop=attn_drop_ratio,
            proj_drop=drop_ratio,
            window_size=3,
            alpha=0.5,
        )

        if drop_path_ratio > 0.:
            self.drop_path = DropPath(drop_path_ratio)
        else:
            self.drop_path = nn.Identity()

        self.norm2 = norm_layer(dim)
        # The gated convolutional FFN takes the place of a plain MLP.
        self.ffn = FeedForward(dim, ffn_expansion_factor=mlp_ratio)

    def forward(self, x):
        # norm1 -> attention -> drop_path, added to the residual
        attended = self.gla(self.norm1(x), windowSize, windowSize)
        x = x + self.drop_path(attended)
        # norm2 -> FFN -> drop_path, added to the residual
        transformed = self.ffn(self.norm2(x))
        return x + self.drop_path(transformed)

In [None]:
def _init_vit_weights(m):
    """Initialise one module in place; intended for use with `nn.Module.apply`.

    - nn.Linear:    truncated-normal weights (std 0.01), zero bias if present
    - nn.Conv2d:    Kaiming-normal weights (fan_out), zero bias if present
    - nn.LayerNorm: unit weight, zero bias
    Any other module type is left untouched.

    :param m: module
    """
    def _zero_bias(layer):
        if layer.bias is not None:
            nn.init.zeros_(layer.bias)

    if isinstance(m, nn.Linear):
        nn.init.trunc_normal_(m.weight, std=.01)
        _zero_bias(m)
        return
    if isinstance(m, nn.Conv2d):
        nn.init.kaiming_normal_(m.weight, mode="fan_out")
        _zero_bias(m)
        return
    if isinstance(m, nn.LayerNorm):
        nn.init.zeros_(m.bias)
        nn.init.ones_(m.weight)

In [None]:

class GLAFormer(nn.Module):
    """Two-input transformer classifier built from PixelEmbedding and Block.

    Each input is a token sequence of shape (B, windowSize*windowSize, img_c),
    where `windowSize` is the notebook-level patch size. Both inputs are
    encoded, the absolute difference of the two encodings is reshaped to a
    windowSize x windowSize map, reduced by two 3x3 convolutions to a single
    channel and classified by a small MLP.

    Two encoder branches (img1_* and img2_*) are always constructed. With
    `share_weights=True` (the default, matching the behaviour this notebook
    has always had) both inputs go through the img1 branch and the img2 branch
    receives no gradient. With `share_weights=False` the second input uses the
    img2 branch.

    Args:
        img_c: size of the last dimension of each input token.
        num_classes: number of output logits.
        embed_dim: token width inside the encoder.
        depth: number of Block layers per branch.
        num_heads: attention heads per Block.
        mlp_ratio: FFN expansion factor forwarded to Block.
        qkv_bias, qk_scale, drop_ratio, attn_drop_ratio: forwarded to Block.
        drop_path_ratio: maximum stochastic-depth rate, reached linearly over `depth`.
        embed_layer: embedding class, called as embed_layer(num_bands=..., embed_dim=...).
        norm_layer: normalisation factory (default LayerNorm with eps=1e-6).
        act_layer: forwarded to Block.
        share_weights: route both inputs through the img1 branch (see above).
    """

    def __init__(self,  img_c, num_classes=2,
                 embed_dim=128, depth=4, num_heads=8, mlp_ratio=4.0, qkv_bias=True,
                 qk_scale=None, drop_ratio=0.,
                 attn_drop_ratio=0., drop_path_ratio=0., embed_layer=PixelEmbedding, norm_layer=None,
                 act_layer=None, share_weights=True):

        super(GLAFormer, self).__init__()
        self.num_classes = num_classes
        self.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models
        self.num_tokens = 1
        self.share_weights = share_weights
        img1_norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
        img1_act_layer = act_layer or nn.GELU
        img2_norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
        img2_act_layer = act_layer or nn.GELU

        # Module creation order is kept as it was so that seeded initialisation
        # and state_dict key order do not change.
        self.img1_patch_embed = embed_layer(num_bands = img_c, embed_dim = embed_dim)
        self.img2_patch_embed = embed_layer(num_bands = img_c, embed_dim = embed_dim)

        img1_dpr = [x.item() for x in torch.linspace(0, drop_path_ratio, depth)]  # stochastic depth decay rule
        self.img1_blocks = nn.Sequential(*[
            Block(dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
                  drop_ratio=drop_ratio, attn_drop_ratio=attn_drop_ratio, drop_path_ratio=img1_dpr[i],
                  norm_layer=img1_norm_layer, act_layer=img1_act_layer)
            for i in range(depth)
        ])
        self.img1_norm = img1_norm_layer(embed_dim)

        img2_dpr = [x.item() for x in torch.linspace(0, drop_path_ratio, depth)]  # stochastic depth decay rule
        self.img2_blocks = nn.Sequential(*[
            Block(dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
                  drop_ratio=drop_ratio, attn_drop_ratio=attn_drop_ratio, drop_path_ratio=img2_dpr[i],
                  norm_layer=img2_norm_layer, act_layer=img2_act_layer)
            for i in range(depth)
        ])
        self.img2_norm = img2_norm_layer(embed_dim)

        # Classifier head(s)
        self.conv1 = nn.Sequential(nn.Conv2d(in_channels=embed_dim, out_channels=16, kernel_size=3, padding=1),nn.ReLU())
        self.conv2 = nn.Sequential(nn.Conv2d(in_channels=16, out_channels=1, kernel_size=3, padding=1))
        self.fc = nn.Sequential(
            nn.Linear(windowSize*windowSize, 20),
            nn.ReLU(),
            nn.Linear(20, num_classes)
        )

        self.apply(_init_vit_weights)

    def forward_features(self, img, branch=1):
        """Encode one token sequence with the chosen branch (1 or 2).

        Returns the normalised tokens, shape (B, N, embed_dim).
        """
        if branch == 1:
            embed, blocks, norm = self.img1_patch_embed, self.img1_blocks, self.img1_norm
        elif branch == 2:
            embed, blocks, norm = self.img2_patch_embed, self.img2_blocks, self.img2_norm
        else:
            raise ValueError("branch must be 1 or 2, got {!r}".format(branch))
        x = embed(img)
        x = blocks(x)
        x = norm(x)
        return x

    def forward(self, img1, img2):
        """Return class logits of shape (B, num_classes) for a pair of token sequences."""
        # getattr keeps models pickled before `share_weights` existed usable.
        shared = getattr(self, 'share_weights', True)
        feat1 = self.forward_features(img1, branch=1)
        feat2 = self.forward_features(img2, branch=1 if shared else 2)

        fuse_feature = torch.abs(torch.sub(feat1, feat2))
        B, _, C = fuse_feature.shape
        # (B, N, C) -> (B, C, windowSize, windowSize) for the convolutional head.
        x = torch.permute(fuse_feature.view(B, windowSize, windowSize, C), (0, 3, 1, 2))
        x = self.conv1(x)
        x = self.conv2(x)

        # Single-channel map flattened to (B, windowSize*windowSize).
        x = x.view(B, -1)
        x = self.fc(x)

        return x

In [None]:
# Build the network. C is the third-axis size of the first image cube, as returned by getData.
model = GLAFormer(img_c=C,
                embed_dim=256,
                depth=6,
                num_heads=8,
                num_classes=2)

In [None]:
def save():
  """Evaluate the model on the whole test loader, checkpoint it and plot the result.

  Uses the notebook-level `model`, `test_loader`, `device`, `H` and `W`.
  Steps:
    1. predict a label for every test sample (no gradients are recorded);
    2. compute the accuracy and save the whole model if it beats the best
       accuracy seen by earlier calls (kept on `save.max_acc`);
    3. print the change-detection metrics via `evaluate`;
    4. reshape the predictions to (H, W) and show them as a gray image.
  The model is put back into whatever mode (train/eval) it was in on entry.
  """
  was_training = model.training
  model.eval()
  try:
    pred_batches = []
    label_batches = []
    with torch.no_grad():
      for im1_loader, im2_loader, label in tqdm(test_loader):
        im1_loader = im1_loader.to(device)
        im2_loader = im2_loader.to(device)
        # (B, C, h, w) -> (B, h*w, C): one token per pixel of the patch.
        im1_loader = torch.transpose(torch.flatten(im1_loader, start_dim=2,end_dim=3),1,2)
        im2_loader = torch.transpose(torch.flatten(im2_loader, start_dim=2,end_dim=3),1,2)
        predict = model(im1_loader,im2_loader)
        pred_batches.append(np.argmax(predict.cpu().numpy(), axis=1))
        label_batches.append(label.cpu().numpy())

    # Concatenate once instead of growing an array batch by batch.
    y_pred_test = np.concatenate(pred_batches)
    gty = np.concatenate(label_batches)

    acc1 = acc_score(y_pred_test, gty)
    # The best accuracy must outlive a single call, so it is stored on the function.
    max_acc = getattr(save, 'max_acc', 0.)
    if acc1 > max_acc:
      ckpt_path = '/content/drive/MyDrive/wzy/HSI_CD/model/Sa/H_D/model_H_D_W6{:.4f}.pth'.format(acc1)
      os.makedirs(os.path.dirname(ckpt_path), exist_ok=True)
      torch.save(model, ckpt_path)
      save.max_acc = acc1

    # evaluate expects (ground truth, prediction), in that order.
    evaluate(gty, y_pred_test)

    res = (y_pred_test.reshape(H, W) * 255).astype(np.int64)
    plt.imshow(res, 'gray')
  finally:
    model.train(was_training)

In [None]:
# loss function
criterion = nn.CrossEntropyLoss()
# optimizer
model = model.to(device)
optimizer = optim.Adam(model.parameters(), lr=lr)
# scheduler
# NOTE(review): the scheduler is created but scheduler.step() is never called, so
# the learning rate stays at `lr` for the whole run. With step_size=1 and
# gamma=0.7, stepping once per epoch would multiply the rate by 0.7**60 over 60
# epochs - decide deliberately before enabling it.
scheduler = StepLR(optimizer, step_size=1, gamma=gamma)

num_batches = len(train_loader)
for epoch in range(epochs):
  # Re-enter training mode every epoch, because the evaluation in save() calls model.eval().
  model.train()
  epoch_loss = 0.0
  epoch_accuracy = 0.0
  for im1_loader, im2_loader, label in tqdm(train_loader):
    im1_loader = im1_loader.to(device)
    im2_loader = im2_loader.to(device)

    # (B, C, h, w) -> (B, h*w, C): one token per pixel of the patch
    im1_loader = torch.transpose(torch.flatten(im1_loader, start_dim=2,end_dim=3),1,2)
    im2_loader = torch.transpose(torch.flatten(im2_loader, start_dim=2,end_dim=3),1,2)
    label = label.to(device)
    output = model(im1_loader, im2_loader)

    loss = criterion(output, label)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    acc = (output.argmax(dim=1) == label).float().mean()
    # .item() yields plain floats, so the running sums do not hold on to autograd history.
    epoch_accuracy += acc.item() / num_batches
    epoch_loss += loss.item() / num_batches
  print(
      f"Epoch : {epoch+1} - loss : {epoch_loss:.4f} - acc: {epoch_accuracy:.4f}\n"
  )
  if (epoch+1)%20==0:
    save()