In [None]:
import numpy as np
from PIL import Image
import skimage
from skimage import io, measure
import random
import scipy.io as sio
import matplotlib
import matplotlib.pyplot as plt
from preclassify import del2, srad, dicomp, FCM, hcluster
import torch
import torchvision
from torchvision import transforms
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import cv2
from collections import  Counter

# Input files: the two images that are compared (im1/im2) and the reference
# change map that is only used for evaluation at the end of the notebook.
im1_path  = 'im1.bmp'
im2_path  = 'im2.bmp'
imgt_path = 'im3.bmp'

# Side length, in pixels, of the square patch extracted around every pixel.
patch_size = 9

In [None]:
def image_normalize(data):
  """Standardize an image: subtract its mean and divide by its standard deviation.

  The divisor is floored at ``1 / sqrt(N)`` (``N`` = number of elements) so
  that a constant image does not trigger a division by zero.
  """
  import math
  mean_value = np.mean(data)
  std_value = np.std(data)
  element_count = float(np.size(data))
  std_floor = 1.0 / math.sqrt(element_count)
  scale = max(std_value, std_floor)
  centered = data - mean_value
  return centered / scale

def image_padding(data, r):
  """Zero-pad the two spatial axes of an image by ``r`` pixels on every side.

  Args:
    data: 2-D array ``(H, W)`` or 3-D array ``(H, W, C)``.
    r: pad width in pixels, applied before and after each spatial axis.

  Returns:
    A new array of shape ``(H + 2r, W + 2r)`` or ``(H + 2r, W + 2r, C)``;
    the channel axis of a 3-D input is never padded.

  Raises:
    ValueError: if ``data`` is neither 2-D nor 3-D (previously this case
      silently returned ``None``).
  """
  data = np.asarray(data)
  if data.ndim == 3:
    pad_width = ((r, r), (r, r), (0, 0))
  elif data.ndim == 2:
    pad_width = ((r, r), (r, r))
  else:
    raise ValueError("image_padding expects a 2D or 3D array, got %dD" % data.ndim)
  # np.pad is the public entry point; the np.lib.pad alias is gone in NumPy 2.x.
  return np.pad(data, pad_width, mode='constant', constant_values=0)

def arr(length):
  """Return the indices ``0 .. length-1`` as a 1-D array in random order.

  The previous implementation built ``np.arange(length-1)`` and therefore
  never produced the last index, so the final sample could never be drawn.

  The shuffle uses Python's ``random`` module, so ``random.seed`` controls it.
  A ``length`` of zero (or less) yields an empty array.
  """
  indices = np.arange(length)
  # random.shuffle swaps elements in place, which works on a 1-D ndarray.
  random.shuffle(indices)
  return indices


def createTrainingCubes(X, y, patch_size, train_len=10000, num_label1=7000):
  """Build a randomly sampled training set of patches around confidently labelled pixels.

  Only pixels whose label is exactly 1 or 2 are candidates; every other label
  value is skipped. ``num_label1`` patches are drawn from the label-1 pixels
  and ``train_len - num_label1`` from the label-2 pixels. Only the sampled
  patches are extracted, instead of one patch per labelled pixel.

  Args:
    X: image stack of shape ``(H, W, C)``.
    y: label map of shape ``(H, W)``.
    patch_size: side length of the square patch. The extracted window is
      ``2 * int((patch_size - 1) / 2) + 1`` wide, which equals ``patch_size``
      only for odd values.
    train_len: total number of training patches (default 10000).
    num_label1: how many of them come from label-1 pixels (default 7000).

  Returns:
    ``(pdata, plabels)``: ``pdata`` has shape
    ``(train_len, patch_size, patch_size, C)``; ``plabels`` has shape
    ``(train_len,)`` and holds 0 for label-1 patches and 1 for label-2
    patches. Label-1 patches come first.

  Raises:
    ValueError: if the shapes of ``X`` and ``y`` disagree, if the requested
      split is negative, or if a class has too few pixels to sample from
      (previously an ``IndexError`` from deep inside the copy loop).
  """
  y = np.asarray(y)
  if y.shape != X.shape[:2]:
    raise ValueError("y must have shape %s, got %s" % (X.shape[:2], y.shape))
  num_label2 = train_len - num_label1
  if num_label1 < 0 or num_label2 < 0:
    raise ValueError("num_label1 must lie between 0 and train_len")

  # Pad X so that border pixels get full-size patches.
  margin = int((patch_size - 1) / 2)
  zeroPaddedX = image_padding(X, margin)
  window = 2 * margin + 1

  # Row-major coordinates of the candidate pixels of each class.
  coords_1 = np.argwhere(y == 1)
  coords_2 = np.argwhere(y == 2)

  # Random visiting order per class (label 1 first, then label 2).
  order_1 = arr(len(coords_1))
  order_2 = arr(len(coords_2))
  if len(order_1) < num_label1 or len(order_2) < num_label2:
    raise ValueError(
      "not enough labelled pixels to sample from: need %d/%d (label 1/label 2), "
      "but only %d/%d can be drawn" % (num_label1, num_label2, len(order_1), len(order_2)))

  pdata = np.zeros((train_len, patch_size, patch_size, X.shape[2]))
  plabels = np.zeros(train_len)

  sampled = (
    (coords_1[order_1[:num_label1]], 0),  # label 1 -> class 0
    (coords_2[order_2[:num_label2]], 1),  # label 2 -> class 1
  )
  out_index = 0
  for coords, class_id in sampled:
    for row, col in coords:
      # (row, col) is the top-left corner of the patch in the padded image.
      pdata[out_index, :, :, :] = zeroPaddedX[row:row + window, col:col + window]
      plabels[out_index] = class_id
      out_index += 1

  return pdata, plabels

def createTestingCubes(X,patch_size):
  """Extract one patch per pixel of ``X`` (shape ``(H, W, C)``), in row-major order.

  ``X`` is zero-padded first so that border pixels also get a full patch.
  The result has shape ``(H * W, patch_size, patch_size, C)``.
  """
  margin = int((patch_size - 1) / 2)
  padded = image_padding(X, margin)
  n_rows = padded.shape[0] - 2 * margin
  n_cols = padded.shape[1] - 2 * margin
  span = 2 * margin + 1
  cubes = np.zeros((X.shape[0]*X.shape[1], patch_size, patch_size, X.shape[2]))
  # (top, left) is the patch's top-left corner in the padded image.
  for k, (top, left) in enumerate(np.ndindex(n_rows, n_cols)):
    cubes[k, :, :, :] = padded[top:top + span, left:left + span]
  return cubes


#  Inputs:  gtImg  = ground truth image
#           tstImg = change map
#  Outputs: FA  = False alarms
#           MA  = Missed alarms
#           OE  = Overall error
#           PCC = Overall accuracy
def evaluate(gtImg, tstImg):
  """Score a change map against the ground truth, print the results and return them.

  A pixel counts as "changed" when its value is >= 128 and as "unchanged"
  otherwise. Neither input array is modified (earlier versions thresholded
  both arrays in place).

  Args:
    gtImg: 2-D ground-truth image.
    tstImg: 2-D change map of the same shape.

  Returns:
    dict with keys ``'FA'`` (false alarms), ``'MA'`` (missed alarms),
    ``'OE'`` (overall error = FA + MA), ``'PCC'`` (overall accuracy in
    [0, 1]) and ``'KC'`` (kappa coefficient; NaN when it is undefined
    because the expected agreement is exactly 1).

  Raises:
    ValueError: if an input is not 2-D, the shapes differ, or the images
      are empty.
  """
  gtImg = np.asarray(gtImg)
  tstImg = np.asarray(tstImg)
  if gtImg.ndim != 2:
    raise ValueError("gtImg should be a 2D array")
  if tstImg.ndim != 2:
    raise ValueError("tstImg should be a 2D array")
  if gtImg.shape != tstImg.shape:
    raise ValueError("gtImg and tstImg should have the same shape")
  if gtImg.size == 0:
    raise ValueError("gtImg should not be empty")

  # Boolean masks replace the in-place thresholding and the per-pixel loop.
  gt_changed = gtImg >= 128
  tst_changed = tstImg >= 128

  total = gtImg.size
  label_1 = int(np.count_nonzero(gt_changed))
  label_0 = total - label_1
  print(label_0)
  print(label_1)

  FA = int(np.count_nonzero(~gt_changed & tst_changed))
  MA = int(np.count_nonzero(gt_changed & ~tst_changed))

  OE = FA+MA
  PCC = 1-OE/total
  # Expected agreement by chance, from the marginal class counts.
  PRE = (float(label_1 + FA - MA) * label_1 + float(label_0 + MA - FA) * label_0)
  PRE = PRE / (float(total) * total)
  KC = (PCC-PRE)/(1-PRE) if PRE != 1 else float('nan')
  print(' Change detection results ==>')
  print(' ... ... FP:  ', FA)
  print(' ... ... FN:  ', MA)
  print(' ... ... OE:  ', OE)
  print(' ... ... PCC: ', format(PCC*100, '.2f'))
  print(' ... ... KC: ', format(KC*100, '.2f'))
  return {'FA': FA, 'MA': MA, 'OE': OE, 'PCC': PCC, 'KC': KC}

def postprocess1(res, max_size=20):
  """Return a copy of ``res`` in which small connected components are set to 0.

  Components are found with ``measure.label`` using 8-connectivity
  (``connectivity=2``); label 0 is treated as background and left alone.

  Args:
    res: 2-D change map. It is not modified (the previous version aliased
      the input and overwrote it in place).
    max_size: components with at most this many pixels are removed
      (default 20).

  Returns:
    A new array with the same shape and dtype as ``res``.
  """
  res_new = np.array(res, copy=True)
  if res_new.size == 0:
    return res_new
  labels = measure.label(res, connectivity=2)
  # One pass for all component sizes instead of one full-image scan per component.
  sizes = np.bincount(labels.ravel())
  small = sizes <= max_size
  small[0] = False  # never touch the background label
  res_new[small[labels]] = 0
  return res_new

def postprocess(res, max_size=20, fill_value=0.5):
  """Return a copy of ``res`` in which small connected components are overwritten.

  Components are found with ``measure.label`` using 8-connectivity
  (``connectivity=2``); label 0 is treated as background and left alone.

  Args:
    res: 2-D change map. It is not modified (the previous version aliased
      the input and overwrote it in place).
    max_size: components with at most this many pixels are overwritten
      (default 20).
    fill_value: value written into those components (default 0.5). With an
      integer-typed ``res`` a fractional value is truncated on assignment.

  Returns:
    A new array with the same shape and dtype as ``res``.
  """
  res_new = np.array(res, copy=True)
  if res_new.size == 0:
    return res_new
  labels = measure.label(res, connectivity=2)
  # One pass for all component sizes instead of one full-image scan per component.
  sizes = np.bincount(labels.ravel())
  small = sizes <= max_size
  small[0] = False  # never touch the background label
  res_new[small[labels]] = fill_value
  return res_new

In [None]:
# Read the images and convert them to float32.
from skimage.color import rgb2gray
import numpy as np

# Load the two input images as float32 arrays.
im1 = io.imread(im1_path).astype(np.float32)
im2 = io.imread(im2_path).astype(np.float32)

print(im1.shape)
print(im2.shape)
print(type(im1))


# Reference change map, only used by evaluate() at the end of the notebook.
im_gt = io.imread(imgt_path).astype(np.float32)

# Collapse 3-channel inputs to a single channel.
# NOTE(review): images with another channel count (e.g. RGBA) are left 3-D and
# will fail at the mdata assignment below.
if im1.ndim == 3 and im1.shape[2] == 3:
    im1 = rgb2gray(im1)
if im2.ndim == 3 and im2.shape[2] == 3:
    im2 = rgb2gray(im2)
if im_gt.ndim == 3:
    im_gt = rgb2gray(im_gt).astype(np.uint8)

# dicomp comes from the project's preclassify module; presumably it builds the
# difference image of im1 and im2 -- confirm against preclassify.
im_di = dicomp(im1, im2)
ylen, xlen = im_di.shape
pix_vec = im_di.reshape([ylen*xlen, 1])

# hcluster (also from preclassify) returns the pre-classification label map.
# Judging by how the labels are consumed below, it holds 1 and 2 for confident
# pixels and 1.5 for uncertain ones -- TODO confirm.
preclassify_lab = hcluster(pix_vec, im_di)
print('... ... hiearchical clustering finished !!!')


# Stack both images and the difference image into a 3-channel array (H, W, 3).
mdata = np.zeros([im1.shape[0], im1.shape[1], 3], dtype=np.float32)
mdata[:,:,0] = im1
mdata[:,:,1] = im2
mdata[:,:,2] = im_di
mlabel = preclassify_lab

# Training patches; transpose from (N, H, W, C) to the (N, C, H, W) layout used by the model.
x_train, y_train = createTrainingCubes(mdata, mlabel, patch_size)
x_train = x_train.transpose(0, 3, 1, 2)
print('... x train shape: ', x_train.shape)
print('... y train shape: ', y_train.shape)


# One patch per pixel (row-major), same (N, C, H, W) layout.
x_test = createTestingCubes(mdata, patch_size)
x_test = x_test.transpose(0, 3, 1, 2)
print('... x test shape: ', x_test.shape)

In [None]:
""" 训练数据集"""
class TrainDS(torch.utils.data.Dataset):
  """In-memory training dataset of ``(patch, label)`` pairs.

  Args:
    x: array-like of patches with the samples along the first axis.
      Defaults to the notebook-level ``x_train``.
    y: array-like of labels, one per patch. Defaults to the notebook-level
      ``y_train``.

  Patches are stored as float32 tensors and labels as int64 tensors.

  Raises:
    ValueError: if ``x`` and ``y`` do not hold the same number of samples.
  """
  def __init__(self, x=None, y=None):
    # Fall back to the notebook globals so that ``TrainDS()`` keeps working.
    if x is None:
      x = x_train
    if y is None:
      y = y_train
    if len(x) != len(y):
      raise ValueError("x and y must contain the same number of samples")
    self.len = len(x)
    self.x_data = torch.tensor(x, dtype=torch.float32)
    self.y_data = torch.tensor(y, dtype=torch.long)
  def __getitem__(self, index):
    return self.x_data[index], self.y_data[index]
  def __len__(self):
    return self.len

# Wrap the training patches in a dataset and iterate over them in shuffled
# mini-batches of 128 samples, loaded in the main process (num_workers=0).
trainset = TrainDS()
train_loader = torch.utils.data.DataLoader(dataset=trainset, batch_size=128, shuffle=True, num_workers=0)

In [None]:
from __future__ import print_function

import glob
from itertools import chain
import os
import random
import zipfile
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from PIL import Image
from sklearn.model_selection import train_test_split
from torch.optim.lr_scheduler import StepLR
from torch.utils.data import DataLoader, Dataset
from torchvision import datasets, transforms
from tqdm.notebook import tqdm

# Training hyper-parameters.
# NOTE(review): the training loop in this notebook iterates over range(20)
# and does not read `epochs`.
epochs = 10
lr = 1e-3  # learning rate handed to the Adam optimizer
gamma = 0.7  # decay factor handed to StepLR
seed = 42  # seed handed to seed_everything
def seed_everything(seed):
  """Seed every random number generator used in this notebook.

  Seeds Python's ``random``, NumPy and PyTorch (CPU and all CUDA devices),
  exports ``PYTHONHASHSEED`` and puts cuDNN into deterministic mode.

  Args:
    seed: integer seed shared by all generators.
  """
  random.seed(seed)
  os.environ['PYTHONHASHSEED'] = str(seed)
  np.random.seed(seed)
  torch.manual_seed(seed)
  torch.cuda.manual_seed(seed)
  torch.cuda.manual_seed_all(seed)
  torch.backends.cudnn.deterministic = True
  # The cuDNN auto-tuner may pick different algorithms from run to run;
  # switch it off as well so that runs are repeatable.
  torch.backends.cudnn.benchmark = False

# Seed all generators, then use the GPU when one is available.
# NOTE(review): when the cells run top to bottom, the random patch sampling in
# createTrainingCubes has already happened before this seed is set.
seed_everything(seed)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
import torch
from torch import nn
from einops import rearrange, repeat
from einops.layers.torch import Rearrange
def pair(t):
  """Return ``t`` unchanged when it is a tuple, otherwise the 2-tuple ``(t, t)``."""
  if isinstance(t, tuple):
    return t
  return (t, t)

class Attention(nn.Module):
  """Multi-head scaled dot-product self-attention over a token sequence.

  Input and output have shape ``(batch, tokens, dim)``. The output projection
  (linear layer + dropout) is skipped only when a single head of width
  ``dim`` is used.
  """
  def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
    super().__init__()
    self.heads = heads
    self.scale = dim_head ** -0.5
    self.attend = nn.Softmax(dim = -1)

    inner_dim = heads * dim_head
    self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)

    needs_projection = heads != 1 or dim_head != dim
    if needs_projection:
      self.to_out = nn.Sequential(
          nn.Linear(inner_dim, dim),
          nn.Dropout(dropout)
      )
    else:
      self.to_out = nn.Identity()

  def forward(self, x):
    # Split the joint projection into queries, keys and values, one slice per head.
    q, k, v = (
      rearrange(part, 'b n (h d) -> b h n d', h = self.heads)
      for part in self.to_qkv(x).chunk(3, dim = -1)
    )
    scores = torch.matmul(q, k.transpose(-1, -2)) * self.scale
    weights = self.attend(scores)
    mixed = rearrange(torch.matmul(weights, v), 'b h n d -> b n (h d)')
    return self.to_out(mixed)

class ViT(nn.Module):
  """Patch embedding followed by a single self-attention layer.

  The image is cut into ``patch_size`` patches, each patch is flattened and
  passed through a linear layer of unchanged width, and one ``Attention``
  layer mixes the patch tokens. The result gets a singleton channel axis,
  giving shape ``(batch, 1, num_patches, patch_dim)``.

  Of the constructor arguments only ``image_size``, ``patch_size``, ``heads``,
  ``pool`` (validated only), ``channels``, ``dim_head`` and ``dropout`` are
  read; the remaining ones are accepted but have no effect.
  """
  def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, pool = 'cls', channels = 1, dim_head = 64, dropout = 0., emb_dropout = 0.):
    super().__init__()
    image_height, image_width = pair(image_size)
    patch_height, patch_width = pair(patch_size)
    # Computed as before but not used by any layer.
    token_count = (image_height // patch_height) * (image_width // patch_width)
    patch_dim = channels * patch_height * patch_width
    assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'

    to_tokens = Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_height, p2 = patch_width)
    self.to_patch_embedding = nn.Sequential(
        to_tokens,
        nn.Linear(patch_dim, patch_dim),
    )
    self.transformer = Attention(patch_dim, heads=heads, dim_head=dim_head, dropout=dropout)

  def forward(self, img):
    tokens = self.to_patch_embedding(img)
    attended = self.transformer(tokens)
    return attended.unsqueeze(1)
# Module-level ViT instance (CAMixer picks it up as self.vit): 9x9 inputs cut
# into 3x3 patches, 4 attention heads, on the selected device.
# NOTE(review): ViT.__init__ does not read num_classes, dim, depth or mlp_dim.
vit = ViT(
  image_size = 9,
  patch_size = 3,
  num_classes = 2,
  dim = 27,
  depth = 6,
  heads = 4,
  mlp_dim = 32,
).to(device)

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class BiAttention(nn.Module):
    """Bilinear interaction between two single-channel feature maps.

    Both inputs are projected to ``feature_dim`` channels by a 3x3
    convolution. For every sample and channel the ``(H, W)`` map of the first
    input is multiplied with the transposed map of the second input, giving an
    ``(H, H)`` interaction map. The maps are average-pooled to ``pool_size``
    and merged back into one channel by a 1x1 convolution.

    Args:
        feature_dim: number of channels of the intermediate projections.
        pool_size: output size of the adaptive average pooling (default
            ``(7, 7)``).

    Shapes:
        inputs ``(B, 1, H, W)`` each; output ``(B, 1, pool_h, pool_w)``.
    """
    def __init__(self, feature_dim=1, pool_size=(7, 7)):
        super(BiAttention, self).__init__()
        self.feature_dim = feature_dim
        self.proj1 = nn.Conv2d(1, self.feature_dim, kernel_size=3, padding=1)
        self.proj2 = nn.Conv2d(1, self.feature_dim, kernel_size=3, padding=1)
        self.bilinear_pool = nn.AdaptiveAvgPool2d(pool_size)
        self.proj_back = nn.Conv2d(self.feature_dim, 1, kernel_size=1)

    def forward(self, img1, img2):
        img1_proj = self.proj1(img1)
        img2_proj = self.proj2(img2)

        B, C, H, W = img1_proj.size()
        img1_flat = img1_proj.reshape(B * C, H, W)
        img2_flat = img2_proj.reshape(B * C, H, W)

        # (H, W) @ (W, H) for every (sample, channel) pair.
        bilinear_features = torch.bmm(img1_flat, img2_flat.transpose(1, 2))
        # Restore the batch/channel axes from the actual sizes; a hard-coded
        # (B, 1, 9, 9) only fits feature_dim == 1 with 9x9 inputs.
        bilinear_features = bilinear_features.view(B, C, H, H)
        bilinear_features = self.bilinear_pool(bilinear_features)

        enhanced_feature = self.proj_back(bilinear_features)

        return enhanced_feature

In [None]:
class attention2d(nn.Module):
    """Computes softmax weights over ``K`` candidates from a feature map.

    The input is globally average-pooled, passed through two 1x1 convolutions
    with a ReLU in between, and the ``K`` resulting logits are divided by the
    temperature before the softmax. The hidden width is 16 for 3-channel
    inputs and 2 otherwise; ``ratios`` is accepted but not read.
    ``temperature`` must satisfy ``temperature % 3 == 1``.
    """
    def __init__(self, in_planes, ratios, K, temperature, init_weight=True):
        super(attention2d, self).__init__()
        assert temperature%3==1
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        hidden_planes = 16 if in_planes == 3 else 2
        self.fc1 = nn.Conv2d(in_planes, hidden_planes, 1, bias=False)
        self.fc2 = nn.Conv2d(hidden_planes, K, 1, bias=True)
        self.temperature = temperature
        if init_weight:
            self._initialize_weights()

    def _initialize_weights(self):
        """Kaiming-normal init for conv weights, zeros for conv biases."""
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)
            elif isinstance(module, nn.BatchNorm2d):
                nn.init.constant_(module.weight, 1)
                nn.init.constant_(module.bias, 0)

    def updata_temperature(self):
        """Lower the temperature by 3 until it has reached 1."""
        if self.temperature == 1:
            return
        self.temperature -= 3
        print('Change temperature to:', str(self.temperature))

    def forward(self, x):
        pooled = self.avgpool(x)
        hidden = F.relu(self.fc1(pooled))
        logits = self.fc2(hidden).view(hidden.size(0), -1)
        return F.softmax(logits/self.temperature, 1)



class Dynamic_conv2d(nn.Module):
    """Convolution whose kernel is a per-sample mixture of ``K`` learned kernels.

    An ``attention2d`` module produces ``K`` mixing weights for every sample;
    the ``K`` kernels (and biases) are blended with these weights and applied
    to the whole batch in one grouped convolution.
    """
    def __init__(self, in_planes, out_planes, kernel_size, ratio=0.25, stride=1, padding=0, dilation=1, groups=1, bias=True, K=4,temperature=22, init_weight=True):
        super(Dynamic_conv2d, self).__init__()
        assert in_planes % groups==0
        self.in_planes = in_planes
        self.out_planes = out_planes
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.dilation = dilation
        self.groups = groups
        self.K = K
        self.attention = attention2d(in_planes, ratio, K, temperature)

        kernel_shape = (K, out_planes, in_planes//groups, kernel_size, kernel_size)
        self.weight = nn.Parameter(torch.randn(*kernel_shape), requires_grad=True)
        # One bias row per candidate kernel, or no bias at all.
        self.bias = nn.Parameter(torch.zeros(K, out_planes)) if bias else None
        if init_weight:
            self._initialize_weights()

    def _initialize_weights(self):
        """Kaiming-uniform init, applied to each of the K kernels separately."""
        for index in range(self.K):
            nn.init.kaiming_uniform_(self.weight[index])

    def update_temperature(self):
        """Forward the temperature update to the attention module."""
        self.attention.updata_temperature()

    def forward(self, x):
        kernel_attention = self.attention(x)

        batch_size, in_planes, height, width = x.size()
        # Fold the batch into the channel axis so that one grouped convolution
        # applies a different kernel to every sample.
        folded_input = x.view(1, -1, height, width)
        flat_kernels = self.weight.view(self.K, -1)

        mixed_kernels = torch.mm(kernel_attention, flat_kernels).view(batch_size*self.out_planes, self.in_planes//self.groups, self.kernel_size, self.kernel_size)
        mixed_bias = None
        if self.bias is not None:
            mixed_bias = torch.mm(kernel_attention, self.bias).view(-1)

        output = F.conv2d(folded_input, weight=mixed_kernels, bias=mixed_bias, stride=self.stride, padding=self.padding,
                          dilation=self.dilation, groups=self.groups*batch_size)

        return output.view(batch_size, self.out_planes, output.size(-2), output.size(-1))

In [None]:
import math
def shift(y, n = 6):
  """Shift groups of channels of a ``(B, C, H, W)`` tensor by one pixel.

  With ``num = C // n``, the first four groups of ``num`` channels are moved
  down, up, left and right respectively (vacated pixels are zero); all
  remaining channels are copied unchanged. The input is not modified.
  """
  _, channels, _, _ = y.shape
  num = channels // n
  shifted = torch.zeros_like(y)
  everything = slice(None)
  drop_first = slice(1, None)
  drop_last = slice(None, -1)
  # (group index, destination (rows, cols), source (rows, cols))
  moves = (
    (0, (drop_first, everything), (drop_last, everything)),   # shift down
    (1, (drop_last, everything), (drop_first, everything)),   # shift up
    (2, (everything, drop_last), (everything, drop_first)),   # shift left
    (3, (everything, drop_first), (everything, drop_last)),   # shift right
  )
  for group, (dst_rows, dst_cols), (src_rows, src_cols) in moves:
    chans = slice(num * group, num * (group + 1))
    shifted[:, chans, dst_rows, dst_cols] = y[:, chans, src_rows, src_cols]
  shifted[:, num * 4:, :, :] = y[:, num * 4:, :, :]  # no shift
  return shifted
class CSC(nn.Module):
  """Dynamic 1x1 convolution -> channel shift -> dynamic 1x1 convolution.

  Maps a ``(B, 1, H, W)`` input to 12 channels, shifts channel groups by one
  pixel and maps back to one channel. ``conv1``/``conv2`` are plain
  convolutions that are created but not used in ``forward``.
  """
  def __init__(self):
    super(CSC, self).__init__()
    dynamic_options = dict(kernel_size=1, ratio=0.25, stride=1, padding=0, dilation=1, groups=1, bias=True, K=4, temperature=37, init_weight=True)
    self.conv1 = nn.Conv2d(1,12,1,1)
    self.shift = shift
    self.conv2 = nn.Conv2d(12,1,1,1)
    self.dconv1 = Dynamic_conv2d(in_planes=1, out_planes=12, **dynamic_options)
    self.dconv2 = Dynamic_conv2d(in_planes=12, out_planes=1, **dynamic_options)
  def forward(self, x):
    expanded = self.dconv1(x)
    moved = shift(expanded)
    return self.dconv2(moved)
class SGU(nn.Module):
  """Gated unit on a single-channel patch.

  A 1x1 convolution output is layer-normalised, expanded to 16 channels and
  passed through a depthwise 3x3 convolution. One half of those channels
  (after GELU) gates the other half; a 1x1-convolved copy of the first
  convolution's output is added as a skip path before projecting back to one
  channel. ``dim`` is accepted but not read; the LayerNorm shape uses the
  notebook-level ``patch_size``.
  """
  def __init__(self, dim = 3):
    super(SGU, self).__init__()
    self.catConv =  nn.Conv2d(1, 1, kernel_size=1)
    self.norm1 = nn.LayerNorm([1, patch_size, patch_size])
    self.conv = nn.Conv2d(1,8,1,1)
    self.project_in = nn.Conv2d(1, 16, kernel_size=1)
    self.dwconv = nn.Conv2d(16, 16, kernel_size=3, stride=1, padding=1, groups=16, bias=False)
    self.project_out = nn.Conv2d(8, 1, kernel_size=1)
  def forward(self, x):
    mixed = self.catConv(x)
    expanded = self.project_in(self.norm1(mixed))
    value, gate = self.dwconv(expanded).chunk(2, dim=1)
    gated = F.gelu(value) * gate
    skip = self.conv(mixed)
    return self.project_out(gated + skip)

class CAMixer(nn.Module):
  """Two-branch patch classifier.

  Channels 0 and 1 of the input patch are processed by the same sub-modules:
  the ViT and CSC outputs are concatenated and fused by a 1x1 convolution,
  then combined with the SGU output through ``BiAttention``. The two branch
  results go through ``BiAttention`` once more and a two-layer linear head
  produces the two class logits. Any further input channels are not read.

  The width of ``linear1`` follows the pooled output size of ``BiAttention``
  (whose result has shape ``(B, 1, pool_h, pool_w)``) instead of assuming
  ``patch_size * patch_size`` features.
  """
  def __init__(self):
    super(CAMixer, self).__init__()
    self.vit = vit
    self.csc = CSC()
    feature_dimension = 4
    self.BiAttention = BiAttention(feature_dim=feature_dimension)
    self.sgu = SGU()
    # Created once (a second, identical assignment used to overwrite this layer).
    self.conv1x1 = nn.Conv2d(in_channels=2, out_channels=1, kernel_size=1)
    pooled = self.BiAttention.bilinear_pool.output_size
    pooled_h, pooled_w = pooled if isinstance(pooled, (tuple, list)) else (pooled, pooled)
    self.linear1=nn.Linear(pooled_h * pooled_w * 1, 20)
    self.linear2=nn.Linear(20, 2)

  def _branch(self, channel):
    """Run the ViT/CSC fusion and the SGU path on one ``(B, 1, H, W)`` channel."""
    fused = torch.cat((self.vit(channel), self.csc(channel)), 1)
    fused = self.conv1x1(fused)
    return self.BiAttention(fused, self.sgu(channel))

  def forward(self, img):
    """Return ``(in_x, logits)``: the flattened input and the ``(B, 2)`` class scores."""
    first_channel = img[:, 0, :, :].unsqueeze(1)
    second_channel = img[:, 1, :, :].unsqueeze(1)
    in_x = img.reshape(img.shape[0],-1)

    BiAttentionout11 = self._branch(first_channel)
    BiAttentionout21 = self._branch(second_channel)
    x = self.BiAttention(BiAttentionout11,BiAttentionout21)
    out = x.reshape(x.size(0), -1)
    out1 = self.linear1(out)
    out1 = self.linear2(out1)
    return in_x,out1

In [None]:
from ptflops import get_model_complexity_info
# Loss, model, optimizer and learning-rate schedule.
# NOTE(review): the training loop in this notebook never calls
# scheduler.step(), so the StepLR decay does not take effect there.
criterion = nn.CrossEntropyLoss()
model = CAMixer().to(device)
optimizer = optim.Adam(model.parameters(), lr=lr) # Adam over all model parameters
scheduler = StepLR(optimizer, step_size=1, gamma=gamma)

In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE
from tqdm import tqdm
import time

# Training loop
outputs = np.zeros((ylen, xlen))
# The inference cell switches the model to eval mode; restore training mode
# so that re-running this cell behaves like the first run.
model.train()
num_batches = len(train_loader)
for epoch in range(20):
    epoch_loss = 0.0
    epoch_accuracy = 0.0
    all_in_x = []
    for data, label in tqdm(train_loader):
        data = data.to(device)
        label = label.to(device)

        in_x, output = model(data)
        loss = criterion(output, label)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Accumulate plain Python floats: summing the tensors themselves would
        # keep every batch's autograd graph referenced until the epoch ends.
        acc = (output.argmax(dim=1) == label).float().mean()
        epoch_accuracy += acc.item() / num_batches
        epoch_loss += loss.item() / num_batches

        # Keep the flattened inputs of this epoch
        all_in_x.append(in_x.detach().cpu().numpy())
    print(f"Epoch : {epoch+1} - loss : {epoch_loss:.4f} - acc: {epoch_accuracy:.4f}\n")

In [None]:
istrain=False
INFER_BATCH_SIZE = 2048  # number of uncertain pixels classified per forward pass
model.eval()

# Confident pre-classification labels are kept as they are; only pixels
# labelled 1.5 are classified by the network.
lab = preclassify_lab[:ylen, :xlen]
outputs = np.array(lab, dtype=np.float64)
uncertain_rows, uncertain_cols = np.nonzero(lab == 1.5)

with torch.no_grad():
  # Classify the uncertain pixels in batches instead of one forward pass per pixel.
  for start in range(0, len(uncertain_rows), INFER_BATCH_SIZE):
    rows = uncertain_rows[start:start + INFER_BATCH_SIZE]
    cols = uncertain_cols[start:start + INFER_BATCH_SIZE]
    # x_test holds one patch per pixel in row-major order.
    img_patch = torch.as_tensor(x_test[rows * xlen + cols], dtype=torch.float32).to(device)
    in_x, prediction = model(img_patch)
    # Class 0/1 -> label 1/2, matching the pre-classification labels.
    outputs[rows, cols] = prediction.argmax(dim=1).cpu().numpy() + 1

# Labels 1/2 -> 0/1
outputs = outputs-1


plt.imshow(outputs, 'gray')


# Scale to 0/255, post-process small connected regions, then score against the ground truth.
res = outputs*255
res = postprocess(res)
evaluate(im_gt, res)
plt.imshow(res, 'gray')
plt.imsave('result.png', res, cmap='gray')