<a href="https://colab.research.google.com/github/pengmy001/VideoSwin/blob/main/vslr.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Mount Google Drive so files under /content/drive/ (including the data
# directory used by the __main__ block at the end of this notebook) are accessible.
from google.colab import drive
drive.mount('/content/drive/')

Mounted at /content/drive/


In [None]:
# einops provides `rearrange`, imported in the next cell.
!pip install einops

Collecting einops
  Downloading einops-0.7.0-py3-none-any.whl (44 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 44.6/44.6 kB 1.1 MB/s eta 0:00:00
Installing collected packages: einops
Successfully installed einops-0.7.0


In [None]:
import tensorflow as tf
import keras
from keras import layers
from einops import rearrange
from keras.layers import Input
from typing import List, Union
import os
from osgeo import gdal
import numpy as np
import math
from keras.callbacks import ReduceLROnPlateau, EarlyStopping
from keras.utils import plot_model
import datetime
import cv2
import random


# Epsilon used by the BatchNormalization layers inside the building blocks
# (ConvBNReLU, MHCA, PatchEmbed, NCB, NTB).
EPSILON = 1e-5


# Model-size presets; nextvit_small_2b() reads the "SMALL" entry.
#   stem_chs  : output channels of the three ConvBNReLU stem layers.
#   depths    : number of blocks per stage.
#   drop_path : maximum stochastic-depth rate, ramped linearly from 0 over
#               all blocks (see the tf.linspace call in NextViT_2b).
# NOTE(review): NextViT_2b only defines per-stage channels/block types/
# sr_ratios for three stages, so the four-stage "BASE"/"LARGE" depths do not
# appear usable with it as written — confirm before selecting them.
CONFIG = {
    "SMALL": {
        "stem_chs": [32,16,32],
        "depths": [1,2,2],
        "drop_path": 0.1,
    },
    "BASE": {
        "stem_chs": [64, 32, 64],
        "depths": [3, 4, 20, 3],
        "drop_path": 0.1,
    },
    "LARGE": {
        "stem_chs": [64, 32, 64],
        "depths": [3, 4, 30, 3],
        "drop_path": 0.1,
    },
}
# Previous "SMALL" preset (four stages), kept for reference only:
#"SMALL": {
#        "stem_chs": [64, 32, 64],
#        "depths": [3, 4, 10, 3],
#        "drop_path": 0.1,
#    },
# Adapted from Hugging Face transformers:
# https://github.com/huggingface/transformers/blob/60d51ef5123d949fd8c59cd4d3254e711541d278/src/transformers/tf_utils.py#L26
def shape_list(tensor: Union[tf.Tensor, np.ndarray]) -> List[int]:
    """Return the shape of ``tensor``, preferring static dimensions.

    NumPy arrays yield their plain shape as a list. For TensorFlow tensors,
    every statically known dimension is returned as an int and every unknown
    one as the matching scalar from ``tf.shape``. If even the rank is unknown,
    the whole dynamic shape tensor is returned instead of a list.
    """
    if isinstance(tensor, np.ndarray):
        return list(tensor.shape)
    runtime_shape = tf.shape(tensor)
    if tensor.shape == tf.TensorShape(None):
        return runtime_shape
    return [
        runtime_shape[axis] if size is None else size
        for axis, size in enumerate(tensor.shape.as_list())
    ]


class E_MHSA(layers.Layer):
    """Efficient Multi-Head Self Attention (E-MHSA, as used by Next-ViT).

    ``call`` takes a token sequence ``x`` of shape (B, N, C) and returns
    (B, N, out_dim). Queries, keys and values are dense projections of ``x``,
    split into ``dim // head_dim`` heads.

    Note: when ``sr_ratio > 1`` the average-pooling / BatchNorm sub-layers are
    still created, but spatial reduction is not applied in ``call`` (it is
    disabled in this port). Keys and values therefore always use all N tokens,
    and the attention matrix has shape (B, heads, N, N).

    Args:
        dim: channel width of the q/k/v projections.
        out_dim: output width of the final projection (defaults to ``dim``).
        head_dim: channels per head.
        qkv_bias: whether the q/k/v projections use a bias.
        qk_scale: score scale; defaults to ``head_dim ** -0.5``.
        attn_drop, proj_drop: dropout rates on attention weights / output.
        sr_ratio: spatial-reduction ratio (see note above).
        use_softmax: normalize attention scores with softmax over the key
            axis (standard scaled dot-product attention). Set to False to get
            the previous behaviour, where raw scaled scores were used.
    """

    def __init__(
        self,
        dim,
        out_dim=None,
        head_dim=2,
        qkv_bias=True,
        qk_scale=None,
        attn_drop=0,
        proj_drop=0.0,
        sr_ratio=1,
        use_softmax=True,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.dim = dim
        self.out_dim = out_dim if out_dim is not None else dim
        self.num_heads = self.dim // head_dim
        self.scale = qk_scale or head_dim**-0.5
        self.use_softmax = use_softmax
        # Sub-layers are created in the same order as before so that
        # weight ordering (e.g. for .h5 checkpoints) is unchanged.
        self.q = tf.keras.layers.Dense(dim, use_bias=qkv_bias)
        self.k = tf.keras.layers.Dense(dim, use_bias=qkv_bias)
        self.v = tf.keras.layers.Dense(dim, use_bias=qkv_bias)
        self.proj = tf.keras.layers.Dense(self.out_dim)
        self.attn_drop = tf.keras.layers.Dropout(attn_drop)
        self.proj_drop = tf.keras.layers.Dropout(proj_drop)

        self.sr_ratio = sr_ratio
        self.N_ratio = sr_ratio**2
        if sr_ratio > 1:
            # Kept for attribute compatibility; not used in call() (see class doc).
            self.sr = tf.keras.layers.AveragePooling1D(
                pool_size=self.N_ratio, strides=self.N_ratio
            )
            self.norm = tf.keras.layers.BatchNormalization(epsilon=1e-5)

    def _split_heads(self, t, batch, head_ch):
        """Reshape (B, N, heads * head_ch) -> (B, N, heads, head_ch)."""
        return tf.reshape(t, (batch, -1, self.num_heads, head_ch))

    def call(self, x):
        dims = shape_list(x)
        B, N, C = dims[0], dims[1], dims[2]
        head_ch = C // self.num_heads

        # (B, heads, N, head_ch)
        q = tf.transpose(self._split_heads(self.q(x), B, head_ch), perm=[0, 2, 1, 3])
        # Keys are laid out as (B, heads, head_ch, N) so q @ k gives scores directly.
        k = tf.transpose(self._split_heads(self.k(x), B, head_ch), perm=[0, 2, 3, 1])
        v = tf.transpose(self._split_heads(self.v(x), B, head_ch), perm=[0, 2, 1, 3])

        attn = tf.matmul(q, k) * self.scale
        if self.use_softmax:
            # Standard attention: each query's weights over the keys sum to 1.
            attn = tf.nn.softmax(attn, axis=-1)
        attn = self.attn_drop(attn)

        out = tf.matmul(attn, v)
        out = tf.transpose(out, perm=[0, 2, 1, 3])
        out = tf.reshape(out, (B, N, C))
        out = self.proj(out)
        return self.proj_drop(out)


def _make_stem(stem_chs):
    """Build the three-layer, stride-1 ConvBNReLU stem used at the head of each branch."""
    return tf.keras.Sequential(
        [
            ConvBNReLU(stem_chs[0], kernel_size=3, strides=1),
            ConvBNReLU(stem_chs[1], kernel_size=3, strides=1),
            ConvBNReLU(stem_chs[2], kernel_size=3, strides=1),
        ]
    )


def _build_stage_blocks(depths, stage_out_channels, stage_block_types,
                        sr_ratios, path_dropout):
    """Instantiate the NCB/NTB blocks of all stages, in the order they are applied."""
    # Stochastic-depth rate ramps linearly from 0 to path_dropout over all
    # blocks; stored as plain Python floats rather than eager tensors.
    dpr = [float(rate) for rate in tf.linspace(0.0, path_dropout, sum(depths))]
    blocks = []
    idx = 0
    for stage_id, numrepeat in enumerate(depths):
        for block_id in range(numrepeat):
            output_channel = stage_out_channels[stage_id][block_id]
            block_type = stage_block_types[stage_id][block_id]
            drop_rate = dpr[idx + block_id]
            # Down-sampling is disabled in this variant: every block uses stride 1.
            if block_type is NCB:
                blocks.append(
                    NCB(
                        output_channel,
                        strides=1,
                        path_dropout=drop_rate,
                        drop=0,
                        head_dim=2,
                    )
                )
            elif block_type is NTB:
                blocks.append(
                    NTB(
                        output_channel,
                        path_dropout=drop_rate,
                        strides=1,
                        sr_ratio=sr_ratios[stage_id],
                        head_dim=2,
                        mix_block_ratio=0.75,
                        attn_drop=0,
                        drop=0,
                    )
                )
            else:
                raise ValueError(f"Unknown block type: {block_type!r}")
        idx += numrepeat
    return blocks


def _apply_branch(x, stem_chs, depths, stage_out_channels, stage_block_types,
                  sr_ratios, path_dropout):
    """Run stem -> stage blocks -> BatchNorm on ``x`` and return the result."""
    x = _make_stem(stem_chs)(x)
    for block in _build_stage_blocks(
        depths, stage_out_channels, stage_block_types, sr_ratios, path_dropout
    ):
        x = block(x)
    return layers.BatchNormalization(epsilon=1e-5)(x)


def NextViT_2b(
    input_shape1, input_shape2, stem_chs, depths, path_dropout,out_chans,
) -> keras.Model:
    """Build a two-input NextViT-style fusion network.

    Each input passes through its own branch (3-layer ConvBNReLU stem, the
    NCB/NTB stage blocks, BatchNorm). The two branch outputs are concatenated
    on the channel axis, passed through a third branch with the same
    structure, then a 3x3 Conv2D to ``out_chans`` channels and a final
    BatchNorm. Every convolution uses stride 1 with same padding, so the
    output keeps the spatial size of the inputs.

    Args:
        input_shape1, input_shape2: per-sample input shapes, e.g. (H, W, C).
        stem_chs: output channels of the three stem layers.
        depths: number of blocks in each of the three stages.
        path_dropout: maximum stochastic-depth rate (linearly ramped per branch).
        out_chans: number of output channels.

    Returns:
        A ``keras.Model`` taking ``[input1, input2]``.

    Raises:
        ValueError: if ``depths`` does not describe exactly three stages.
    """
    if len(depths) != 3:
        raise ValueError(
            f"NextViT_2b supports exactly 3 stages, got depths={list(depths)}"
        )
    sr_ratios = [4, 2, 1]
    input_layer1 = layers.Input(input_shape1)
    input_layer2 = layers.Input(input_shape2)

    # Per-stage output channels and block types. Stage 3 alternates NCB/NTB;
    # only its first depths[2] entries are consumed.
    stage_out_channels = [
        [4] * depths[0],
        [6] * depths[1],
        [8, 16] * depths[2],
    ]
    stage_block_types = [
        [NCB] * depths[0],
        [NCB] * depths[1],
        [NCB, NTB] * depths[2],
    ]
    branch_args = (depths, stage_out_channels, stage_block_types, sr_ratios, path_dropout)

    x1 = _apply_branch(input_layer1, stem_chs, *branch_args)
    x2 = _apply_branch(input_layer2, stem_chs, *branch_args)

    x = layers.Concatenate(axis=-1)([x1, x2])
    x = _apply_branch(x, stem_chs, *branch_args)

    x = layers.Conv2D(filters=out_chans, kernel_size=3, strides=1, padding='same')(x)
    x = layers.BatchNormalization(epsilon=1e-5)(x)

    return keras.Model([input_layer1, input_layer2], x)



def nextvit_small_2b(input_shape1=(None, None, 3),input_shape2=(None, None, 3),out_chans=103):
    """Build the two-input NextViT network with the "SMALL" preset from CONFIG.

    Args:
        input_shape1, input_shape2: per-sample shapes of the two inputs.
        out_chans: number of output channels.

    Returns:
        The ``keras.Model`` produced by ``NextViT_2b``.
    """
    preset = CONFIG["SMALL"]
    return NextViT_2b(
        input_shape1=input_shape1,
        input_shape2=input_shape2,
        stem_chs=preset["stem_chs"],
        depths=preset["depths"],
        path_dropout=preset["drop_path"],
        out_chans=out_chans,
    )

class StochasticDepth(layers.Layer):
    """Per-sample stochastic depth ("DropPath" in timm).

    While training, each sample is kept with probability ``1 - drop_path``.
    Dropped samples are zeroed, and surviving samples are scaled by
    ``1 / (1 - drop_path)``. Outside training the input is returned unchanged.

    Reference: github.com:rwightman/pytorch-image-models
    """

    def __init__(self, drop_path, **kwargs):
        super().__init__(**kwargs)
        self.drop_path = drop_path

    def call(self, x, training=None):
        if not training:
            return x
        keep_prob = 1 - self.drop_path
        # One draw per sample, broadcast over all remaining axes.
        mask_shape = (tf.shape(x)[0],) + (1,) * (len(tf.shape(x)) - 1)
        keep_mask = tf.floor(keep_prob + tf.random.uniform(mask_shape, 0, 1))
        return (x / keep_prob) * keep_mask


# https://github.com/bytedance/Next-ViT/blob/main/classification/nextvit.py
def _make_divisible(v, divisor, min_value=None):
    if min_value is None:
        min_value = divisor
    new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
    if new_v < 0.9 * v:
        new_v += divisor
    return new_v


class ConvBNReLU(tf.keras.Model):
    """Conv2D -> BatchNormalization -> ELU.

    Despite the class name, the activation is ELU, not ReLU. The convolution
    has no bias (BatchNorm supplies the shift) and uses "SAME" padding.
    """

    def __init__(self, filters, kernel_size, strides, groups=1, **kwargs):
        super().__init__(**kwargs)
        conv_options = dict(
            kernel_size=kernel_size,
            strides=strides,
            groups=groups,
            padding="SAME",
            use_bias=False,
        )
        self.conv = layers.Conv2D(filters, **conv_options)
        self.norm = layers.BatchNormalization(epsilon=EPSILON)
        self.act = tf.keras.layers.ELU()

    def call(self, x):
        return self.act(self.norm(self.conv(x)))


class MHCA(tf.keras.Model):
    """Multi-Head Convolutional Attention.

    A grouped 3x3 convolution with ``filters // head_dim`` groups, followed by
    BatchNorm, ELU and a 1x1 projection back to ``filters`` channels. Neither
    convolution uses a bias.
    """

    def __init__(self, filters, head_dim, **kwargs):
        super().__init__(**kwargs)
        n_groups = filters // head_dim
        # Number of groups; kept as a public attribute for compatibility.
        self.new = n_groups
        self.group_conv3x3 = layers.Conv2D(
            filters=filters,
            kernel_size=3,
            strides=1,
            padding="same",
            groups=n_groups,
            use_bias=False,
        )
        self.norm = layers.BatchNormalization(epsilon=EPSILON)
        self.act = layers.Activation("elu")
        self.projection = layers.Conv2D(filters, kernel_size=1, padding="same", use_bias=False)

    def call(self, x):
        mixed = self.act(self.norm(self.group_conv3x3(x)))
        return self.projection(mixed)


class mlp(tf.keras.Model):
    """Pointwise feed-forward block built from 1x1 convolutions.

    The block expands to a hidden width of ``filters * mlp_ratio``, rounded by
    ``_make_divisible`` to a multiple of 32. It then applies ELU and dropout,
    projects back to ``filters`` channels, and applies dropout again.
    """

    def __init__(self, filters, mlp_ratio=1, drop=0.0, **kwargs):
        super().__init__(**kwargs)
        expanded_width = _make_divisible(filters * mlp_ratio, 32)
        self.conv1 = layers.Conv2D(expanded_width, kernel_size=1, padding="SAME")
        self.act = layers.Activation("elu")
        self.drop1 = layers.Dropout(drop)
        self.conv2 = layers.Conv2D(filters, kernel_size=1, padding="SAME")
        self.drop2 = layers.Dropout(drop)

    def call(self, x):
        for stage in (self.conv1, self.act, self.drop1, self.conv2, self.drop2):
            x = stage(x)
        return x


class PatchEmbed(tf.keras.Model):
    """Optional channel projection in front of NCB/NTB blocks.

    If ``filters`` is set and the input's last dimension differs from it, the
    input goes through a bias-free 1x1 Conv2D followed by BatchNorm.
    Otherwise the input is returned unchanged.

    ``strides`` is stored but has no effect. Average-pool down-sampling is
    disabled in this port, so ``avg_pool`` is always None.
    """

    def __init__(self, filters, strides, **kwargs):
        super().__init__(**kwargs)
        self.filters = filters
        self.strides = strides
        self.avg_pool = None

        if self.filters is not None:
            self.conv = tf.keras.layers.Conv2D(
                filters=self.filters, kernel_size=1, strides=1, use_bias=False, padding="same"
            )
            self.bn = tf.keras.layers.BatchNormalization(epsilon=EPSILON)

    def call(self, inputs):
        x = inputs
        # The channel count is static here, so the check is resolved at trace time.
        if self.filters is not None and x.shape[-1] != self.filters:
            x = self.bn(self.conv(x))
        return x


class NCB(tf.keras.Model):
    """Next Convolution Block.

    The block applies, in order:
    1. PatchEmbed channel projection to ``filters``.
    2. A residual MHCA branch.
    3. BatchNorm.
    4. A residual MLP branch, added to the normalized tensor.

    Both residual branches pass through stochastic depth with rate
    ``path_dropout``. ``strides`` is forwarded to PatchEmbed.
    """

    def __init__(
        self,
        filters,
        strides=1,
        path_dropout=0,
        drop=0,
        head_dim=2,
        mlp_ratio=3,
        **kwargs
    ):
        super().__init__(**kwargs)
        self.filters = filters
        self.strides = strides
        self.patch_embed = PatchEmbed(filters, strides)
        self.mhca = MHCA(filters, head_dim)
        self.attention_path_dropout = StochasticDepth(path_dropout)
        self.norm = layers.BatchNormalization(epsilon=EPSILON)
        self.mlp = mlp(filters, mlp_ratio=mlp_ratio, drop=drop)
        self.mlp_path_dropout = StochasticDepth(path_dropout)

    def call(self, x):
        tokens = self.patch_embed(x)
        attended = tokens + self.attention_path_dropout(self.mhca(tokens))
        normed = self.norm(attended)
        return normed + self.mlp_path_dropout(self.mlp(normed))


class NTB(tf.keras.Model):
    """Next Transformer Block.

    The block splits its work between two paths:
    - A fraction ``mix_block_ratio`` of ``filters`` (rounded to an even count)
      goes through PatchEmbed, then a residual E-MHSA over the flattened
      H*W tokens.
    - That result is projected to the remaining channels and passed through a
      residual MHCA.

    The two parts are concatenated along channels and followed by a residual
    MLP over the full ``filters`` channels. Inputs are NHWC.
    """

    def __init__(
        self,
        filters,
        path_dropout=0,
        strides=1,
        sr_ratio=1,
        mlp_ratio=2,
        head_dim=2,
        mix_block_ratio=0.75,
        attn_drop=0,
        drop=0,
        **kwargs
    ):
        super().__init__(**kwargs)
        self.filters = filters
        self.strides = strides
        self.mix_block_ratio = mix_block_ratio

        # Channels routed through self-attention; the rest go through MHCA.
        self.mhsa_out_channels = _make_divisible(int(filters * mix_block_ratio), 2)
        self.mhca_out_channels = filters - self.mhsa_out_channels

        self.patch_embed = PatchEmbed(self.mhsa_out_channels, strides)
        self.norm1 = layers.BatchNormalization(epsilon=EPSILON)
        self.e_mhsa = E_MHSA(
            self.mhsa_out_channels,
            head_dim=head_dim,
            sr_ratio=sr_ratio,
            attn_drop=attn_drop,
            proj_drop=drop,
        )
        self.mhsa_path_dropout = StochasticDepth(path_dropout * mix_block_ratio)
        self.projection = PatchEmbed(self.mhca_out_channels, strides=1)
        self.mhca = MHCA(self.mhca_out_channels, head_dim=head_dim)
        self.mhca_path_dropout = StochasticDepth(path_dropout * (1 - mix_block_ratio))
        self.norm2 = layers.BatchNormalization(epsilon=EPSILON)
        self.mlp = mlp(filters, mlp_ratio=mlp_ratio, drop=drop)
        self.mlp_path_dropout = StochasticDepth(path_dropout)

    def call(self, x):
        x = self.patch_embed(x)
        # NHWC: unpack in the right order. (The old `B, C, H, W = x.shape` used
        # the width as the height and broke on non-square inputs.)
        B, H, W, C = shape_list(x)
        out = self.norm1(x)
        out = tf.reshape(out, (B, H * W, C))  # b h w c -> b (h w) c
        out = self.mhsa_path_dropout(self.e_mhsa(out))
        x = x + tf.reshape(out, (B, H, W, C))  # b (h w) c -> b h w c
        out = self.projection(x)
        out = out + self.mhca_path_dropout(self.mhca(out))
        x = tf.concat([x, out], axis=-1)
        out = self.norm2(x)
        return x + self.mlp_path_dropout(self.mlp(out))

# Disabled smoke test for NTB, kept as an unused string literal (never executed).
# NOTE(review): these arguments (strides=4, sr_ratio=5, head_dim=4) differ from
# what NextViT_2b passes; confirm they still build before re-enabling.
'''input = Input((256,256,24))
test = NTB(128,
           path_dropout=0.036363635,
           strides=4,
           sr_ratio=5,
           head_dim=4,
           mix_block_ratio=0.75,
           attn_drop=0,
           drop=0,
           )
x = test(input)
a=1'''



class GRID:
    """Raster I/O via GDAL plus helpers for tiling images into overlapping patches."""

    def read_img(self, filename):
        """Read a raster file with GDAL.

        Returns:
            (height, width, data) where ``data`` is (H, W, bands). Single-band
            rasters keep GDAL's dtype; multi-band rasters are converted to float64.

        Raises:
            IOError: if GDAL cannot open ``filename``.
        """
        dataset = gdal.Open(filename)
        if dataset is None:
            raise IOError(f"GDAL could not open {filename!r}")
        im_width = dataset.RasterXSize   # number of columns
        im_height = dataset.RasterYSize  # number of rows
        im_data = np.array(dataset.ReadAsArray(0, 0, im_width, im_height))
        del dataset
        if im_data.ndim == 2:
            return im_height, im_width, im_data[:, :, np.newaxis]
        # GDAL returns (bands, H, W); move bands last.
        return im_height, im_width, np.transpose(im_data, (1, 2, 0)).astype(np.float64)

    def write_img(self, filename, im_proj, im_geotrans, im_data2):
        """Write an (H, W) or (H, W, bands) array to an ENVI raster.

        The GDAL data type follows the array dtype: uint8/int8 -> Byte,
        uint16 -> UInt16, int16 -> Int16, anything else -> Float32.
        ``im_proj`` and ``im_geotrans`` are accepted but not written; the
        georeferencing calls are disabled.

        Raises:
            IOError: if GDAL cannot create the output file.
        """
        if im_data2.ndim > 2:
            # (H, W, bands) -> (bands, H, W); transpose keeps the original dtype.
            im_data = np.transpose(im_data2, (2, 0, 1))
        else:
            im_data = im_data2[np.newaxis, :, :]

        datatype = {
            "uint8": gdal.GDT_Byte,
            "int8": gdal.GDT_Byte,
            "uint16": gdal.GDT_UInt16,
            "int16": gdal.GDT_Int16,
        }.get(im_data.dtype.name, gdal.GDT_Float32)

        im_bands, im_height, im_width = im_data.shape

        driver = gdal.GetDriverByName("ENVI")
        dataset = driver.Create(filename, im_width, im_height, im_bands, datatype)
        if dataset is None:
            raise IOError(f"GDAL could not create {filename!r}")

        # Every band is written as a 2-D array (also fixes (H, W, 1) inputs).
        for band_index, band in enumerate(im_data, start=1):
            dataset.GetRasterBand(band_index).WriteArray(band)

        del dataset  # flushes and closes the file

    @staticmethod
    def _step(patchsize, jg):
        """Return the patch stride ``patchsize - 2 * jg``, validating it is positive."""
        step = patchsize - 2 * jg
        if step <= 0:
            raise ValueError(
                f"patchsize ({patchsize}) must exceed twice the margin jg ({jg})"
            )
        return step

    def restore4d_over(self, data4d, patchsize, sp, jg):
        """Stitch patches from ``prepare4d_over`` back into one image.

        Only the inner (patchsize - 2*jg)^2 region of each patch is kept, so
        the result is the source image with the outer margin trimmed.

        Args:
            data4d: (n_patches, patchsize, patchsize, C') patch predictions.
            patchsize: patch side length used when tiling.
            sp: shape of the image that was tiled (rows, cols, channels).
            jg: margin trimmed from each patch side.
        """
        step = self._step(patchsize, jg)
        n_patches = data4d.shape[0]
        per_row = math.floor((sp[1] - 2 * jg) / step)
        rex = np.zeros((math.floor(n_patches / per_row) * step, per_row * step, sp[2]))
        for ii in range(n_patches):
            mm, nn = divmod(ii, per_row)
            rex[mm * step:(mm + 1) * step, nn * step:(nn + 1) * step, :] = data4d[
                ii, jg:patchsize - jg, jg:patchsize - jg, :
            ]
        return rex

    def prepare4d_over(self, img, patchsize, jg):
        """Tile an (H, W, C) image into overlapping square patches.

        Consecutive patches start ``patchsize - 2*jg`` pixels apart, so
        neighbours overlap by ``2*jg`` pixels. Patches are ordered row-major.

        Returns:
            (patches, img.shape) where patches is a float64 array of shape
            (n_patches, patchsize, patchsize, C).
        """
        step = self._step(patchsize, jg)
        sp = img.shape
        mnum = math.floor((sp[0] - 2 * jg) / step)
        nnum = math.floor((sp[1] - 2 * jg) / step)
        result = np.zeros((mnum * nnum, patchsize, patchsize, sp[2]))
        for i in range(mnum):
            for j in range(nnum):
                result[i * nnum + j] = img[
                    step * i:step * i + patchsize, step * j:step * j + patchsize, :
                ]
        return result, sp

class ETmodel:
    def __init__(self,dirr,patchsize,jg,epochs,resname):
        self.patchsize = patchsize
        self.jg = jg
        self.grid = GRID()
        lr_train, vs_train, lr_pred = self.read_data(dirr,patchsize,jg)
        print('data read!')
        self.model = nextvit_small_2b(input_shape1=(lr_train.shape[1], lr_train.shape[2], 1),
                                      input_shape2=(vs_train.shape[1], vs_train.shape[2], 1),
                                      out_chans=1
                                      )
        print('model built!')
        plot_model(self.model, to_file='modelvs_auth.png',show_shapes=True)
        self.train_model(lr_train, vs_train, lr_pred, epochs)
        self.pred_model(dirr,patchsize,jg)
        #res = self.grid.restore4d_over(res4d, patchsize, self.sp, jg)
        #self.grid.write_img(resname,[],[],res)

    def gaussian_down_sample(self,data,w,mask=0):
      # masking mode
      if np.isscalar(mask):
        masking = 0
      else:
        masking = 1

      xdata = data.shape[0]
      ydata = data.shape[1]
      band = data.shape[2]
      hx = int(np.floor(xdata/w))
      hy = int(np.floor(ydata/w))
      HSI = np.zeros((hx, hy, band))
      sig = w/2.35482

      if masking == 0: # without mask
        if np.mod(w,2)==0:
            H1 = self.gaussian_filter2d((w,w),sig).reshape(w,w,1)
            H2 = self.gaussian_filter2d((w*2,w*2),sig).reshape(w*2,w*2,1)
            for x in range(hx):
                for y in range(hy):
                    if x==0 or x==hx-1 or y==0 or y==hy-1:
                        HSI[x,y,:] = (np.double( data[x*w:(x+1)*w,y*w:(y+1)*w,:] ) * np.tile(H1,(1,1,band))).sum(axis=0).sum(axis=0).reshape(1,1,band)
                    else:
                        HSI[x,y,:] = (np.double( data[x*w-int(w/2):(x+1)*w+int(w/2),y*w-int(w/2):(y+1)*w+int(w/2),:] ) * np.tile(H2,(1,1,band))).sum(axis=0).sum(axis=0).reshape(1,1,band)
        else:
            H1 = self.gaussian_filter2d((w,w),sig).reshape(w,w,1)
            H2 = self.gaussian_filter2d((w*2-1,w*2-1),sig).reshape(w*2-1,w*2-1,1)
            for x in range(hx):
                for y in range(hy):
                    if x==0 or x==hx-1 or y==0 or y==hy-1:
                        HSI[x,y,:] = (np.double( data[x*w:(x+1)*w,y*w:(y+1)*w,:] ) * np.tile(H1,(1,1,band)) ).sum(axis=0).sum(axis=0).reshape(1,1,band)
                    else:
                        HSI[x,y,:] = (np.double( data[x*w-int((w-1)/2):(x+1)*w+int((w-1)/2),y*w-int((w-1)/2):(y+1)*w+int((w-1)/2),:] ) * np.tile(H2,(1,1,band))).sum(axis=0).sum(axis=0).reshape(1,1,band)
      else: # with mask
        if np.mod(w,2)==0:
            H1 = self.gaussian_filter2d((w,w),sig).reshape(w,w,1)
            H2 = self.gaussian_filter2d((w*2,w*2),sig).reshape(w*2,w*2,1)
            for x in range(hx):
                for y in range(hy):
                    mask_tmp = mask[x*w:(x+1)*w,y*w:(y+1)*w]
                    if mask_tmp.sum() == w**2:
                        if x==0 or x==hx-1 or y==0 or y==hy-1:
                            HSI[x,y,:] = (np.double( data[x*w:(x+1)*w,y*w:(y+1)*w,:] ) * np.tile(H1,(1,1,band))).sum(axis=0).sum(axis=0).reshape(1,1,band)
                        else:
                            HSI[x,y,:] = (np.double( data[x*w-int(w/2):(x+1)*w+int(w/2),y*w-int(w/2):(y+1)*w+int(w/2),:] ) * np.tile(H2,(1,1,band))).sum(axis=0).sum(axis=0).reshape(1,1,band)
        else:
            H1 = self.gaussian_filter2d((w,w),sig).reshape(w,w,1)
            H2 = self.gaussian_filter2d((w*2-1,w*2-1),sig).reshape(w*2-1,w*2-1,1)
            for x in range(hx):
                for y in range(hy):
                    mask_tmp = mask[x*w:(x+1)*w,y*w:(y+1)*w]
                    if mask_tmp.sum() == w**2:
                        if x==0 or x==hx-1 or y==0 or y==hy-1:
                            HSI[x,y,:] = (np.double( data[x*w:(x+1)*w,y*w:(y+1)*w,:] ) * np.tile(H1,(1,1,band)) ).sum(axis=0).sum(axis=0).reshape(1,1,band)
                        else:
                            HSI[x,y,:] = (np.double( data[x*w-int((w-1)/2):(x+1)*w+int((w-1)/2),y*w-int((w-1)/2):(y+1)*w+int((w-1)/2),:] ) * np.tile(H2,(1,1,band))).sum(axis=0).sum(axis=0).reshape(1,1,band)

      return HSI

    def gaussian_filter2d(self,shape=(3,3),sigma=1):
      m,n = [(ss-1.)/2. for ss in shape]
      y,x = np.ogrid[-m:m+1,-n:n+1]
      h = np.exp( -(x**2 + y**2) / (2.*sigma**2) )
      h[ h < np.finfo(h.dtype).eps*h.max() ] = 0
      sumh = h.sum()
      if sumh != 0:
        h /= sumh
      return h

    def read_data(self,dirr,patchsize,jg):
        lrdata = []
        vsdata = []
        zdlr = {}
        zdvs = {}
        lr_train = []
        vs_train = []
        lr_pred = []
        lr_tar = []
        dirl = dirr + r'/lr'
        dirv = dirr + r'/vis2pan'
        listl = os.listdir(dirl)
        listv = os.listdir(dirv)
        nums = len(dirl)
        indd = np.arange(nums)[0:60]
        random.shuffle(indd)
        listl = [listl[i] for i in indd]
        listv = [listv[i] for i in indd]

        for ii in range(0,len(listl)):
          lfile = listl[ii]
          lfile = dirl + r'/' + lfile
          imgr = cv2.imread(lfile,0)
          gauss = np.random.normal(0,25,(imgr.shape[0],imgr.shape[1]))
          imgr = imgr + gauss

          vfile = listv[ii]
          vfile = dirv + r'/' + vfile
          imgv = cv2.imread(vfile,0)


          if imgr.shape != imgv.shape:
            sp = [min(imgr.shape[0],imgv.shape[0]),min(imgr.shape[1],imgv.shape[1])]
            imgr = imgr[:sp[0],:sp[1]]
            imgv = imgv[:sp[0],:sp[1]]

          imgr = imgr[:,:,np.newaxis]
          imgv = imgv[:,:,np.newaxis]

          #lr_train
          imgr1 = cv2.resize(cv2.resize(imgr,
           (int(imgr.shape[0]/9*4),int(imgr.shape[1]/9*4))),(int(imgr.shape[0]/3*2),int(imgr.shape[1]/3*2)))[:,:,np.newaxis]
          #lr_pred
          imgr2 = cv2.resize(imgr,(int(imgr.shape[0]/3*2),int(imgr.shape[1]/3*2)))[:,:,np.newaxis]


          #vs_train
          imgv1 = cv2.resize(imgv,(int(imgv.shape[0]/3*2),int(imgv.shape[1]/3*2)))[:,:,np.newaxis]


          if lr_train == []:
            lr_train,sp = self.grid.prepare4d_over(imgr1,patchsize,jg)
          else:
            temp,sp = self.grid.prepare4d_over(imgr1,patchsize,jg)
            lr_train = np.concatenate((lr_train,temp),axis=0)

          if lr_pred == []:
            lr_pred,sp = self.grid.prepare4d_over(imgr2,patchsize,jg)
          else:
            temp,sp = self.grid.prepare4d_over(imgr2,patchsize,jg)
            lr_pred = np.concatenate((lr_pred,temp),axis=0)


          if vs_train == []:
            vs_train,sp = self.grid.prepare4d_over(imgv1,patchsize,jg)
          else:
            temp,sp = self.grid.prepare4d_over(imgv1,patchsize,jg)
            vs_train = np.concatenate((vs_train,temp),axis=0)

        print(lr_train.shape)
        print(vs_train.shape)
        print(lr_pred.shape)

        indd = np.arange(0,lr_train.shape[0])
        np.random.shuffle(indd)
        lr_train = lr_train[indd,:,:,:]
        vs_train = vs_train[indd,:,:,:]
        lr_pred = lr_pred[indd,:,:,:]

        return lr_train, vs_train, lr_pred

    def train_model(self,x1,x2,y,epochs):
        """Train ``self.model`` on ``[x1, x2] -> y``, or load cached weights.

        If ``pres_model.h5`` exists in the working directory, its weights are
        loaded, ``self.status`` is set to 1 and no training happens. Otherwise
        the model is fit with all arrays scaled by 1/100. The weights are saved
        to ``pres_model.h5``. The per-epoch loss and val_loss are written to
        ``pres_loss.txt`` and ``pres_val_loss.txt``.

        Args:
            x1, x2: the two model inputs (indexed along axis 0 by sample).
            y: training targets, indexed along axis 0 by sample.
            epochs: maximum number of epochs; early stopping may end sooner.

        All three arrays are truncated to the shortest sample count, so the
        i-th input pair always stays aligned with the i-th target.
        """
        print(x1.shape)
        print(x2.shape)
        print(y.shape)
        # Truncate every array to the same length so samples stay paired.
        minn = min(x1.shape[0], x2.shape[0], y.shape[0])
        # `lr` is a deprecated alias; `learning_rate` is the supported keyword.
        optim = tf.keras.optimizers.Adam(learning_rate=0.001)
        # NOTE(review): 'accuracy' is not meaningful for an MSE regression
        # target; kept only so the training logs look the same.
        self.model.compile(loss='mse', optimizer=optim, metrics=['accuracy'])

        if os.path.isfile('pres_model.h5'):
            # Reuse previously trained weights instead of retraining.
            self.model.load_weights('pres_model.h5')
            self.status = 1
            return

        early_stop = EarlyStopping(monitor='loss', min_delta=0, patience=15, verbose=0, mode='auto',
                                   baseline=None, restore_best_weights=True)
        reduce_lr = ReduceLROnPlateau(monitor='loss', factor=0.2,
                                      patience=7, min_lr=0.00001)
        history = self.model.fit(x=[x1[:minn] / 100, x2[:minn] / 100], y=y[:minn] / 100,
                                 batch_size=4,
                                 callbacks=[early_stop, reduce_lr],
                                 epochs=epochs, validation_split=0.1)
        self.model.save_weights('pres_model.h5')
        np.savetxt('pres_loss.txt', np.array(history.history["loss"]), delimiter=",")
        np.savetxt('pres_val_loss.txt', np.array(history.history['val_loss']), delimiter=",")

    def pred_model(self,dirr,patchsize,jg):
        """Fuse each (lr, vis2pan) image pair under ``dirr`` and save the result.

        Steps for every pair:
        1. Read both images from ``<dirr>/lr`` and ``<dirr>/vis2pan`` as
           single-channel images, paired by sorted file name.
        2. Crop both to their common size.
        3. Degrade the lr image by shrinking it to 2/3 and resizing it back.
        4. Tile both images into overlapping patches and predict with
           ``self.model``.
        5. Stitch the patches and write ``<dirr>/res/<lr name>_fus.png``.

        Raises:
            ValueError: if the two folders hold different numbers of files.
            FileNotFoundError: if OpenCV cannot read one of the images.

        NOTE(review): tiling always uses a margin of 5 and ignores ``jg``;
        kept as-is, confirm whether that is intended.
        """
        dirl = dirr + r'/lr'
        dirv = dirr + r'/vis2pan'
        dirres = dirr + r'/res/'
        # os.listdir returns entries in arbitrary order; sort both listings so
        # the i-th lr file is deterministically paired with the i-th vis2pan file.
        listl = sorted(os.listdir(dirl))
        listv = sorted(os.listdir(dirv))
        if len(listl) != len(listv):
            raise ValueError(
                'lr and vis2pan contain different numbers of files: '
                f'{len(listl)} vs {len(listv)}')
        # cv2.imwrite may fail silently when the target folder is missing.
        os.makedirs(dirres, exist_ok=True)

        for lname, vname in zip(listl, listv):
            imgr = cv2.imread(dirl + r'/' + lname, 0)
            imgv = cv2.imread(dirv + r'/' + vname, 0)
            if imgr is None or imgv is None:
                raise FileNotFoundError(f'could not read image pair {lname!r} / {vname!r}')

            # Crop both images to their common extent.
            if imgr.shape != imgv.shape:
                rows = min(imgr.shape[0], imgv.shape[0])
                cols = min(imgr.shape[1], imgv.shape[1])
                imgr = imgr[:rows, :cols]
                imgv = imgv[:rows, :cols]

            imgr = imgr[:, :, np.newaxis]
            imgv = imgv[:, :, np.newaxis]

            # Degrade the lr image: shrink to 2/3 of the size, then resize back.
            # cv2.resize takes (width, height) and returns a 2-D array for a
            # single-channel input, hence the re-added channel axis.
            imgr2 = cv2.resize(imgr, (int(imgv.shape[1]/3*2), int(imgv.shape[0]/3*2)))[:, :, np.newaxis]
            imgr3 = cv2.resize(imgr2, (imgv.shape[1], imgv.shape[0]))[:, :, np.newaxis]

            lr_tar, sp = self.grid.prepare4d_over(imgr3, patchsize, 5)
            vsdata, sp = self.grid.prepare4d_over(imgv, patchsize, 5)

            print(lr_tar.shape)
            print(vsdata.shape)

            # Same /100 input scaling (and *100 output rescaling) as training.
            res4d = self.model.predict([lr_tar/100, vsdata/100], batch_size=8)*100
            res = self.grid.restore4d_over(res4d, patchsize, sp, 5)
            # Saturate explicitly to the 8-bit range before writing the PNG.
            fused = np.clip(np.squeeze(res), 0, 255).astype(np.uint8)
            cv2.imwrite(dirres + os.path.splitext(lname)[0] + '_fus.png', fused)


if __name__ == '__main__':
    # NOTE(review): tensorflow is imported at the top of this cell, so setting
    # the device mask here may be too late to take effect — confirm.
    os.environ["CUDA_VISIBLE_DEVICES"] = "0"
    dirr = r'/content/drive/MyDrive/new'
    os.chdir(dirr)
    print(os.getcwd())
    print(os.listdir())
    patchsize = 64
    jg = 0  # tile margin passed through to the patch-tiling helpers
    epochs = 1000

    starttime = datetime.datetime.now()
    ET = ETmodel(dirr, patchsize, jg, epochs, '')
    endtime = datetime.datetime.now()
    # timedelta.seconds ignores whole days, and a 1000-epoch run can exceed
    # 24 hours. total_seconds() reports the full elapsed time.
    elapsed = (endtime - starttime).total_seconds()
    print('total time is ' + str(elapsed) + 's')
    np.savetxt('time_cst.txt', np.array([elapsed]), delimiter=",")

/content/drive/.shortcut-targets-by-id/1NGc_AhpYbtXwsERLnGHUwsUHrAPuyqR5/new
['vis2pan', 'res', 'lr', 'pres_loss.txt', 'pres_val_loss.txt', 'modelvs_auth.png']


  if lr_train == []:
  if lr_pred == []:
  if vs_train == []:


(2320, 64, 64, 1)
(2320, 64, 64, 1)
(2320, 64, 64, 1)
data read!
[[4], [6, 6], [8, 16, 8, 16]]
[[<class '__main__.NCB'>], [<class '__main__.NCB'>, <class '__main__.NCB'>], [<class '__main__.NCB'>, <class '__main__.NTB'>, <class '__main__.NCB'>, <class '__main__.NTB'>]]
4
<class '__main__.NCB'>
4
6
<class '__main__.NCB'>
6
6
<class '__main__.NCB'>
6
8
<class '__main__.NCB'>
8
16
<class '__main__.NTB'>
12
12
4
4
6
6
8
12
12
4
4
6
6
8
12
12
4
model built!




(2320, 64, 64, 1)
(2320, 64, 64, 1)
(2320, 64, 64, 1)
Epoch 1/1000
Epoch 2/1000
Epoch 3/1000
Epoch 4/1000
Epoch 5/1000
Epoch 6/1000
Epoch 7/1000
Epoch 8/1000
Epoch 9/1000
Epoch 10/1000
Epoch 11/1000
Epoch 12/1000
Epoch 13/1000
Epoch 14/1000
Epoch 15/1000
Epoch 16/1000
Epoch 17/1000
Epoch 18/1000
Epoch 19/1000
Epoch 20/1000
Epoch 21/1000
Epoch 22/1000
Epoch 23/1000
Epoch 24/1000
Epoch 25/1000
Epoch 26/1000
Epoch 27/1000
Epoch 28/1000
Epoch 29/1000
Epoch 30/1000
Epoch 31/1000
Epoch 32/1000
Epoch 33/1000
Epoch 34/1000
Epoch 35/1000
Epoch 36/1000
Epoch 37/1000
Epoch 38/1000
Epoch 39/1000
Epoch 40/1000
Epoch 41/1000
Epoch 42/1000
Epoch 43/1000
Epoch 44/1000

In [None]:
!apt install psmisc
!sudo fuser /dev/nvidia*

Reading package lists... Done
Building dependency tree... Done
Reading state information... Done
psmisc is already the newest version (23.4-2build3).
0 upgraded, 0 newly installed, 0 to remove and 24 not upgraded.
/dev/nvidia0:          228m
/dev/nvidiactl:        228m
/dev/nvidia-uvm:       228m


In [None]:
!kill -9 228

ViT two-loss approach (vit 2loss 思路)

In [None]:
import tensorflow as tf
import keras
from keras import layers
import keras.backend as K
from einops import rearrange
from keras.layers import Input
from typing import List, Union
import os
from osgeo import gdal
import numpy as np
import math
from keras.callbacks import ReduceLROnPlateau, EarlyStopping
from keras.utils import plot_model
import datetime
import cv2
import random


# Epsilon shared by the BatchNormalization layers defined below.
EPSILON = 1e-5


# Model-size presets: stem channel widths, blocks per stage, and the maximum
# stochastic-depth (drop-path) rate. The Encoder/Decoder builders below expect
# exactly two stages, so only "SMALL" (two depths) builds as-is. "BASE" and
# "LARGE" list four stages.
CONFIG = dict(
    SMALL=dict(stem_chs=[32, 16, 32], depths=[1, 2], drop_path=0.1),
    BASE=dict(stem_chs=[64, 32, 64], depths=[3, 4, 20, 3], drop_path=0.1),
    LARGE=dict(stem_chs=[64, 32, 64], depths=[3, 4, 30, 3], drop_path=0.1),
)
# An earlier, larger SMALL preset that was tried:
#   stem_chs=[64, 32, 64], depths=[3, 4, 10, 3], drop_path=0.1
# Adapted from Hugging Face transformers:
# https://github.com/huggingface/transformers/blob/60d51ef5123d949fd8c59cd4d3254e711541d278/src/transformers/tf_utils.py#L26
def shape_list(tensor: Union[tf.Tensor, np.ndarray]) -> List[int]:
    """Return the shape of ``tensor`` as a list, preferring static sizes.

    NumPy arrays return their shape as a plain list. For TF tensors, each
    entry is the static size when it is known; otherwise it is the matching
    element of ``tf.shape(tensor)``. If even the rank is unknown, the dynamic
    shape tensor itself is returned.
    """
    if isinstance(tensor, np.ndarray):
        return list(tensor.shape)

    dynamic = tf.shape(tensor)
    if tensor.shape == tf.TensorShape(None):
        # Unknown rank: only the runtime shape is available.
        return dynamic

    return [
        size if size is not None else dynamic[axis]
        for axis, size in enumerate(tensor.shape.as_list())
    ]


class E_MHSA(layers.Layer):
    """Efficient multi-head self-attention over a token sequence.

    ``call`` receives ``x`` of shape (batch, tokens, channels); the final
    reshape to (B, N, C) fixes that rank. It returns (batch, tokens, out_dim).
    The head count is ``dim // head_dim``. The per-head size is computed from
    the input's channel count, so the input is assumed to have ``dim``
    channels.

    NOTE(review): two parts of the reference design are switched off here.
    The softmax over attention scores is commented out, so raw scaled scores
    weight the values. The spatial-reduction pooling ``self.sr`` is never
    applied, so ``sr_ratio`` does not change the output. ``self.sr`` and
    ``self.norm`` are still constructed when ``sr_ratio > 1`` but are unused.
    """

    def __init__(
        self,
        dim,
        out_dim=None,
        head_dim=2,
        qkv_bias=True,
        qk_scale=None,
        attn_drop=0,
        proj_drop=0.0,
        sr_ratio=1,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.dim = dim
        self.out_dim = dim if out_dim is None else out_dim
        self.num_heads = self.dim // head_dim
        # A truthy explicit qk_scale wins; otherwise use 1/sqrt(head_dim).
        self.scale = qk_scale or head_dim**-0.5

        # Query / key / value projections (all `dim` wide), then the output one.
        self.q = tf.keras.layers.Dense(dim, use_bias=qkv_bias)
        self.k = tf.keras.layers.Dense(dim, use_bias=qkv_bias)
        self.v = tf.keras.layers.Dense(dim, use_bias=qkv_bias)
        self.proj = tf.keras.layers.Dense(self.out_dim)
        self.attn_drop = tf.keras.layers.Dropout(attn_drop)
        self.proj_drop = tf.keras.layers.Dropout(proj_drop)

        self.sr_ratio = sr_ratio
        self.N_ratio = sr_ratio**2
        if sr_ratio > 1:
            self.sr = tf.keras.layers.AveragePooling1D(
                pool_size=self.N_ratio, strides=self.N_ratio
            )
            self.norm = tf.keras.layers.BatchNormalization(epsilon=1e-5)

    def call(self, x):
        dims = shape_list(x)
        batch, tokens, channels = dims[0], dims[1], dims[2]
        head_size = channels // self.num_heads

        # (B, N, C) -> (B, heads, N, head_size)
        queries = tf.reshape(self.q(x), (batch, tokens, self.num_heads, int(head_size)))
        queries = tf.transpose(queries, perm=[0, 2, 1, 3])

        kv_source = x
        if self.sr_ratio > 1:
            # The pooling step between these transposes is disabled, so the
            # pair cancels out and keys/values still see every token.
            kv_source = tf.transpose(kv_source, perm=[0, 2, 1])
            # kv_source = self.sr(kv_source)
            kv_source = tf.transpose(kv_source, perm=[0, 2, 1])

        # Keys are laid out (B, heads, head_size, N) so q @ k needs no transpose.
        keys = tf.reshape(self.k(kv_source), (batch, -1, self.num_heads, head_size))
        keys = tf.transpose(keys, perm=[0, 2, 3, 1])
        # Values: (B, heads, N, head_size).
        values = tf.reshape(self.v(kv_source), (batch, -1, self.num_heads, head_size))
        values = tf.transpose(values, perm=[0, 2, 1, 3])

        scores = tf.matmul(queries, keys) * self.scale
        # NOTE(review): softmax disabled — scores are not normalised.
        # scores = tf.nn.softmax(scores, axis=-1)
        scores = self.attn_drop(scores)

        # Merge heads back: (B, heads, N, d) -> (B, N, heads, d) -> (B, N, C).
        out = tf.matmul(scores, values)
        out = tf.transpose(out, perm=[0, 2, 1, 3])
        out = tf.reshape(out, (batch, tokens, channels))
        return self.proj_drop(self.proj(out))


def Encoder(
    input_shape1, stem_chs, depths, path_dropout,out_chans,
) -> keras.Model:
    """Build the convolution/attention encoder used by ``autoencoder``.

    The input passes through a three-layer ConvBNReLU stem and then two
    stages of blocks:
    - stage 0: ``depths[0]`` NCBs with 4 channels;
    - stage 1: ``depths[1]`` blocks alternating NCB(8) / NTB(16).

    Every block uses stride 1. The features are then batch-normalised,
    flattened, and fed to two dense heads.

    Args:
        input_shape1: per-sample input shape passed to ``layers.Input``.
        stem_chs: channel widths of the three stem layers.
        depths: number of blocks in each of the two stages.
        path_dropout: maximum stochastic-depth rate, ramped linearly over
            all blocks.
        out_chans: width of each dense head.

    Returns:
        keras.Model mapping the input to ``[z_mean, z_log_var]``.

    Raises:
        ValueError: if ``depths`` does not describe exactly two stages.
    """
    if len(depths) != 2:
        raise ValueError(
            f"Encoder supports exactly two stages, got depths={list(depths)}")
    # Sequence-reduction ratio handed to each stage's NTB blocks.
    sr_ratios = [4, 2, 1]
    input_layer1 = layers.Input(input_shape1)

    # Stem: three stride-1 ConvBNReLU layers (spatial size is preserved).
    stem1 = tf.keras.Sequential(
        [
            ConvBNReLU(stem_chs[0], kernel_size=3, strides=1),
            ConvBNReLU(stem_chs[1], kernel_size=3, strides=1),
            ConvBNReLU(stem_chs[2], kernel_size=3, strides=1),
        ]
    )
    x1 = stem1(input_layer1)

    stage_out_channels = [
        [4] * depths[0],
        [8, 16] * depths[1],
    ]
    stage_block_types = [
        [NCB] * depths[0],
        [NCB, NTB] * depths[1],
    ]

    # Stochastic-depth rate ramps linearly from 0 to path_dropout over all blocks.
    dpr = list(tf.linspace(0.0, path_dropout, sum(depths)))
    blocks = []
    idx = 0
    for stage_id, numrepeat in enumerate(depths):
        for block_id in range(numrepeat):
            output_channel = stage_out_channels[stage_id][block_id]
            block_type = stage_block_types[stage_id][block_id]
            if block_type is NCB:
                blocks.append(NCB(
                    output_channel,
                    strides=1,
                    path_dropout=dpr[idx + block_id],
                    drop=0,
                    head_dim=2,
                ))
            elif block_type is NTB:
                blocks.append(NTB(
                    output_channel,
                    path_dropout=dpr[idx + block_id],
                    strides=1,
                    sr_ratio=sr_ratios[stage_id],
                    head_dim=2,
                    mix_block_ratio=0.75,
                    attn_drop=0,
                    drop=0,
                ))
        idx += numrepeat
    for block in blocks:
        x1 = block(x1)

    x1 = layers.BatchNormalization(epsilon=1e-5)(x1)
    x1 = layers.Flatten()(x1)
    z_mean = layers.Dense(out_chans, name="z_mean")(x1)
    z_log_var = layers.Dense(out_chans, name="z_log_var")(x1)
    return keras.Model(input_layer1, [z_mean, z_log_var])

def Decoder(
    input_shape1,stem_chs, depths, path_dropout,out_chans,patchsize,
) -> keras.Model:
    """Build the decoder that maps a latent input back to a patch-sized image.

    Steps:
    1. A Dense layer projects the input to ``patchsize * patchsize * 64``
       values, which are reshaped to (patchsize, patchsize, 64).
    2. The encoder's two stages run in reverse order: NTB(8) / NCB(16) pairs,
       then 4-channel NCBs, all with stride 1.
    3. A ConvBNReLU stem ends in ``out_chans`` channels.
    4. A BatchNormalization named ``decoder_output`` is applied last.

    Args:
        input_shape1: per-sample shape of the ``decoder_input`` Input layer.
        stem_chs: stem widths; ``stem_chs[2]`` and ``stem_chs[1]`` are used.
        depths: the encoder's two stage depths, as (first, second).
        path_dropout: maximum stochastic-depth rate, ramped linearly.
        out_chans: channel count of the reconstruction.
        patchsize: spatial size of the reconstruction.
    """
    first_depth, second_depth = depths
    # Walk the encoder stages in reverse.
    depths = [second_depth, first_depth]
    sr_ratios = [1, 2, 4]

    decoder_in = layers.Input(input_shape1, name='decoder_input')
    x1 = layers.Dense(patchsize * patchsize * 64)(decoder_in)
    x1 = layers.Reshape((patchsize, patchsize, 64))(x1)

    # (channel widths, block classes) for each stage.
    stage_plan = [
        ([8, 16] * depths[0], [NTB, NCB] * depths[0]),
        ([4] * depths[1], [NCB] * depths[1]),
    ]

    rates = list(tf.linspace(0.0, path_dropout, sum(depths)))
    blocks = []
    start = 0
    for stage_id, (widths, kinds) in enumerate(stage_plan):
        count = depths[stage_id]
        for block_id in range(count):
            width = widths[block_id]
            rate = rates[start + block_id]
            kind = kinds[block_id]
            if kind is NCB:
                blocks.append(
                    NCB(width, strides=1, path_dropout=rate, drop=0, head_dim=2))
            elif kind is NTB:
                blocks.append(
                    NTB(width, path_dropout=rate, strides=1,
                        sr_ratio=sr_ratios[stage_id], head_dim=2,
                        mix_block_ratio=0.75, attn_drop=0, drop=0))
        start += count

    for block in blocks:
        x1 = block(x1)

    stem = tf.keras.Sequential([
        ConvBNReLU(stem_chs[2], kernel_size=3, strides=1),
        ConvBNReLU(stem_chs[1], kernel_size=3, strides=1),
        ConvBNReLU(out_chans, kernel_size=3, strides=1),
    ])
    x1 = stem(x1)

    x1 = layers.BatchNormalization(epsilon=1e-5, name='decoder_output')(x1)
    return keras.Model(decoder_in, x1)


def autoencoder(input_shape1=(None, None, 3),latent_dim=1024,out_chans=1,patchsize=64):
    """Build an (encoder, decoder) pair from the "SMALL" preset in CONFIG.

    The encoder maps ``input_shape1`` inputs to two ``latent_dim``-wide heads.
    The decoder reconstructs a (patchsize, patchsize, out_chans) output.

    NOTE(review): the decoder's per-sample input shape is
    ``(None, latent_dim)``, i.e. rank 3 including the batch axis. The encoder
    heads are Dense outputs of shape (batch, latent_dim). Confirm how the
    two are connected by the caller.
    """
    preset = CONFIG["SMALL"]
    shared = dict(
        stem_chs=preset["stem_chs"],
        depths=preset["depths"],
        path_dropout=preset["drop_path"],
    )
    en_model = Encoder(input_shape1=input_shape1, out_chans=latent_dim, **shared)
    de_model = Decoder(input_shape1=(None, latent_dim), out_chans=out_chans,
                       patchsize=patchsize, **shared)
    return en_model, de_model


def nextvit_small_2b(input_shape1=(None, None, 3),input_shape2=(None, None, 3),out_chans=103,patchsize=64):
    """Build one 64-d-latent autoencoder per input stream.

    ``patchsize`` is forwarded to both autoencoders; it was previously
    ignored in favour of a literal 64, which is also its default.

    Returns:
        (vau_e, vau_d, iau_e, iau_d): the encoder/decoder for
        ``input_shape1``, followed by the encoder/decoder for ``input_shape2``.

    NOTE(review): ``out_chans`` is still ignored and both decoders emit one
    channel. Wiring it through would change the default model, because its
    default is 103. Confirm before changing.
    """
    vau_e, vau_d = autoencoder(input_shape1=input_shape1, latent_dim=64, out_chans=1, patchsize=patchsize)
    iau_e, iau_d = autoencoder(input_shape1=input_shape2, latent_dim=64, out_chans=1, patchsize=patchsize)
    return vau_e, vau_d, iau_e, iau_d

class StochasticDepth(layers.Layer):
    """Stochastic depth, called "drop path" in ``timm``.

    While training, each sample's input is kept with probability
    ``1 - drop_path`` (and rescaled by its inverse) or zeroed otherwise.
    Outside training, the input is returned unchanged.

    References:
        (1) github.com:rwightman/pytorch-image-models
    """

    def __init__(self, drop_path, **kwargs):
        super().__init__(**kwargs)
        self.drop_path = drop_path

    def call(self, x, training=None):
        if not training:
            return x
        keep_prob = 1 - self.drop_path
        # One Bernoulli draw per sample, broadcast over every other axis:
        # floor(keep_prob + U[0, 1)) equals 1 with probability keep_prob.
        mask_shape = (tf.shape(x)[0],) + (1,) * (len(tf.shape(x)) - 1)
        mask = tf.floor(keep_prob + tf.random.uniform(mask_shape, 0, 1))
        return x / keep_prob * mask


# https://github.com/bytedance/Next-ViT/blob/main/classification/nextvit.py
def _make_divisible(v, divisor, min_value=None):
    if min_value is None:
        min_value = divisor
    new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
    if new_v < 0.9 * v:
        new_v += divisor
    return new_v


class ConvBNReLU(tf.keras.Model):
    """Conv2D ("SAME" padding, no bias) -> BatchNormalization -> ELU.

    Despite the class name, the activation is ELU, not ReLU.
    """

    def __init__(self, filters, kernel_size, strides, groups=1, **kwargs):
        super().__init__(**kwargs)
        conv_options = dict(
            kernel_size=kernel_size,
            strides=strides,
            groups=groups,
            padding="SAME",
            use_bias=False,  # the following BatchNormalization supplies the shift
        )
        self.conv = layers.Conv2D(filters, **conv_options)
        self.norm = layers.BatchNormalization(epsilon=EPSILON)
        self.act = tf.keras.layers.ELU()

    def call(self, x):
        return self.act(self.norm(self.conv(x)))


class MHCA(tf.keras.Model):
    """Multi-head convolutional attention.

    The pipeline is a grouped 3x3 conv (``filters // head_dim`` groups, so
    one group per head), then BN and ELU, then a 1x1 projection. All convs
    use "same" padding and no bias.
    """

    def __init__(self, filters, head_dim, **kwargs):
        super().__init__(**kwargs)
        num_groups = filters // head_dim
        self.new = num_groups  # stored on the instance; not read by call()
        self.group_conv3x3 = layers.Conv2D(
            filters=filters,
            kernel_size=3,
            strides=1,
            padding="same",
            groups=num_groups,
            use_bias=False,
        )
        self.norm = layers.BatchNormalization(epsilon=EPSILON)
        self.act = layers.Activation("elu")
        # "same" padding was added later; for a 1x1 kernel it does not change shapes.
        self.projection = layers.Conv2D(filters, kernel_size=1, padding="same", use_bias=False)

    def call(self, x):
        mixed = self.group_conv3x3(x)
        return self.projection(self.act(self.norm(mixed)))


class mlp(tf.keras.Model):
    """Pointwise feed-forward block: 1x1 conv -> ELU -> dropout -> 1x1 conv -> dropout.

    The hidden width is ``filters * mlp_ratio`` rounded by ``_make_divisible``
    to a multiple of 32. The block projects back to ``filters`` channels.
    """

    def __init__(self, filters, mlp_ratio=1, drop=0.0, **kwargs):
        super().__init__(**kwargs)
        hidden_dim = _make_divisible(filters * mlp_ratio, 32)
        # "SAME" padding (originally "VALID"); identical for 1x1, stride-1 convs.
        self.conv1 = layers.Conv2D(hidden_dim, kernel_size=1, padding="SAME")
        self.act = layers.Activation("elu")
        self.drop1 = layers.Dropout(drop)
        self.conv2 = layers.Conv2D(filters, kernel_size=1, padding="SAME")
        self.drop2 = layers.Dropout(drop)

    def call(self, x):
        hidden = self.drop1(self.act(self.conv1(x)))
        return self.drop2(self.conv2(hidden))


class PatchEmbed(tf.keras.Model):
    """Channel adapter: an optional 1x1 conv + BN that maps inputs to ``filters`` channels.

    The input is returned unchanged when ``filters`` is None or when the
    input already has ``filters`` channels. ``strides`` is stored but no
    down-sampling is performed; ``avg_pool`` is always None because the
    average-pooling path is disabled.
    """

    def __init__(self, filters, strides):
        super(PatchEmbed, self).__init__()
        self.filters = filters
        self.strides = strides
        # Down-sampling by average pooling is disabled.
        self.avg_pool = None

        if self.filters is not None:
            self.conv = tf.keras.layers.Conv2D(
                filters=self.filters, kernel_size=1, strides=1, use_bias=False, padding="same")
            self.bn = tf.keras.layers.BatchNormalization(epsilon=EPSILON)

    def call(self, inputs):
        # Project only when the channel count actually has to change
        # (an unknown static channel count also triggers the projection).
        if self.filters is None or inputs.shape[-1] == self.filters:
            return inputs
        return self.bn(self.conv(inputs))


class NCB(tf.keras.Model):
    """Next Convolution Block.

    The input is adapted to ``filters`` channels and goes through a residual
    MHCA branch. The sum is batch-normalised, and a residual MLP branch is
    added on top of the normalised tensor. Both branches use stochastic
    depth. ``strides`` is only handed to PatchEmbed.
    """

    def __init__(
        self,
        filters,
        strides=1,
        path_dropout=0,
        drop=0,
        head_dim=2,
        mlp_ratio=3,
        **kwargs):
        super().__init__(**kwargs)
        self.filters = filters
        self.strides = strides
        self.patch_embed = PatchEmbed(filters, strides)
        self.mhca = MHCA(filters, head_dim)
        self.attention_path_dropout = StochasticDepth(path_dropout)
        self.norm = layers.BatchNormalization(epsilon=EPSILON)
        self.mlp = mlp(filters, mlp_ratio=mlp_ratio, drop=drop)
        self.mlp_path_dropout = StochasticDepth(path_dropout)

    def call(self, x):
        embedded = self.patch_embed(x)
        attended = embedded + self.attention_path_dropout(self.mhca(embedded))
        normed = self.norm(attended)
        return normed + self.mlp_path_dropout(self.mlp(normed))


class NTB(tf.keras.Model):
    """Next Transformer Block: mixes self-attention and convolutional attention.

    1. The input is projected to ``mhsa_out_channels`` channels (roughly
       ``mix_block_ratio * filters``).
    2. A residual E_MHSA branch runs over the flattened H*W tokens.
    3. A second projection to ``filters - mhsa_out_channels`` channels
       receives a residual MHCA branch.
    4. Both parts are concatenated back to ``filters`` channels and finished
       with a residual MLP.
    """

    def __init__(
        self,
        filters,
        path_dropout=0,
        strides=1,
        sr_ratio=1,
        mlp_ratio=2,
        head_dim=2,
        mix_block_ratio=0.75,
        attn_drop=0,
        drop=0,
        **kwargs):
        super(NTB, self).__init__(**kwargs)
        self.filters = filters
        self.strides = strides
        self.mix_block_ratio = mix_block_ratio

        # Channels given to the self-attention branch, rounded to a multiple of 2.
        self.mhsa_out_channels = _make_divisible(
            int(filters * mix_block_ratio), 2)

        # The remaining channels go to the convolutional-attention branch.
        self.mhca_out_channels = filters - self.mhsa_out_channels
        self.patch_embed = PatchEmbed(self.mhsa_out_channels, strides)
        self.norm1 = layers.BatchNormalization(epsilon=EPSILON)
        self.e_mhsa = E_MHSA(
            self.mhsa_out_channels,
            head_dim=head_dim,
            sr_ratio=sr_ratio,
            attn_drop=attn_drop,
            proj_drop=drop)
        self.mhsa_path_dropout = StochasticDepth(path_dropout * mix_block_ratio)
        self.projection = PatchEmbed(self.mhca_out_channels, strides=1)
        self.mhca = MHCA(self.mhca_out_channels, head_dim=head_dim)
        self.mhca_path_dropout = StochasticDepth(
            path_dropout * (1 - mix_block_ratio))
        self.norm2 = layers.BatchNormalization(epsilon=EPSILON)
        self.mlp = mlp(filters, mlp_ratio=mlp_ratio, drop=drop)
        self.mlp_path_dropout = StochasticDepth(path_dropout)

    def call(self, x):
        x = self.patch_embed(x)
        # Feature maps are channels-last (B, H, W, C). The token sequence must
        # be folded back with the real height, not the width; unpacking the
        # shape as (B, C, H, W) used to pick the width and scrambled
        # non-square inputs.
        height = x.shape[1]
        out = self.norm1(x)
        out = rearrange(out, "b h w c -> b (h w) c")
        out = self.mhsa_path_dropout(self.e_mhsa(out))
        x = x + rearrange(out, "b (h w) c -> b h w c", h=height)
        out = self.projection(x)
        out = out + self.mhca_path_dropout(self.mhca(out))
        x = tf.concat([x, out], axis=-1)
        out = self.norm2(x)
        x = x + self.mlp_path_dropout(self.mlp(out))
        return x

# Disabled scratch snippet: it would build a single NTB on a (256, 256, 24)
# Input as a quick shape check. It is a bare string literal, so none of it
# executes.
'''input = Input((256,256,24))
test = NTB(128,
           path_dropout=0.036363635,
           strides=4,
           sr_ratio=5,
           head_dim=4,
           mix_block_ratio=0.75,
           attn_drop=0,
           drop=0,
           )
x = test(input)
a=1'''



class GRID():
    """GDAL raster I/O, plus helpers that tile an image into overlapping patches and back."""

    def read_img(self, filename):
        """Read a raster file with GDAL.

        Returns:
            (height, width, data): ``data`` is band-last with shape
            (height, width, bands). A single-band raster keeps GDAL's dtype;
            multi-band data is copied into a float64 array.
        """
        dataset = gdal.Open(filename)
        im_width = dataset.RasterXSize  # number of raster columns
        im_height = dataset.RasterYSize  # number of raster rows
        # Read the full extent; multi-band data arrives as (bands, rows, cols).
        im_data = np.array(dataset.ReadAsArray(0, 0, im_width, im_height))
        if im_data.ndim == 2:
            im_data2 = im_data[:, :, np.newaxis]
        else:
            n_bands = im_data.shape[0]
            im_data2 = np.zeros((im_height, im_width, n_bands))
            for band in range(n_bands):
                im_data2[:, :, band] = im_data[band, :, :]

        del dataset  # release the GDAL handle
        return im_height, im_width, im_data2

    def write_img(self, filename, im_proj, im_geotrans, im_data2):
        """Write an array to ``filename`` as an ENVI raster.

        Args:
            filename: output path.
            im_proj, im_geotrans: accepted but currently not written (the
                SetProjection / SetGeoTransform calls are disabled).
            im_data2: array of shape (rows, cols) or (rows, cols, bands).
                Multi-band input is copied to float64 first and therefore
                written as Float32.
        """
        sp = im_data2.shape
        if len(sp) > 2:
            # GDAL writes band-first: (bands, rows, cols).
            im_data = np.zeros((sp[2], sp[0], sp[1]))
            for band in range(sp[2]):
                im_data[band, :, :] = im_data2[:, :, band]
            print(im_data.shape)
        else:
            im_data = im_data2

        # Map the NumPy dtype to a GDAL data type. Signed int16 maps to
        # GDT_Int16 (not GDT_UInt16) so negative values are preserved.
        # NOTE(review): signed int8 is still stored as unsigned Byte.
        dtype_name = im_data.dtype.name
        if 'int8' in dtype_name:
            datatype = gdal.GDT_Byte
        elif dtype_name == 'int16':
            datatype = gdal.GDT_Int16
        elif dtype_name == 'uint16':
            datatype = gdal.GDT_UInt16
        else:
            datatype = gdal.GDT_Float32

        if im_data.ndim == 3:
            im_bands, im_height, im_width = im_data.shape
        else:
            im_bands, (im_height, im_width) = 1, im_data.shape

        # The driver needs the size and data type up front to allocate the file.
        driver = gdal.GetDriverByName("ENVI")
        dataset = driver.Create(filename, im_width, im_height, im_bands, datatype)

        #dataset.SetGeoTransform(im_geotrans)  # affine geotransform
        #dataset.SetProjection(im_proj)  # map projection

        if im_bands == 1:
            dataset.GetRasterBand(1).WriteArray(im_data)
        else:
            for i in range(im_bands):
                dataset.GetRasterBand(i + 1).WriteArray(im_data[i])

        del dataset  # flush and close the file

    def restore4d_over(self, data4d, patchsize, sp, jg):
        """Reassemble patches produced by ``prepare4d_over`` into one image.

        Each patch loses a ``jg``-pixel border, and its
        ``patchsize - 2*jg`` interior is placed row-major with ``sm`` patches
        per row. ``sm`` is derived from the original width ``sp[1]``. The
        result covers the original image minus a ``jg`` border, plus any
        remainder that did not fill a whole tile. Returns a float64 array
        with ``sp[2]`` channels.
        """
        spp = data4d.shape
        step = patchsize - 2 * jg
        sm = math.floor((sp[1] - 2 * jg) / step)
        rex = np.zeros((math.floor(spp[0] / sm) * step, sm * step, sp[2]))
        for ii in range(0, spp[0]):
            mm = math.floor(ii / sm)
            nn = ii % sm
            rex[mm * step:(mm + 1) * step, nn * step:(nn + 1) * step, :] = \
                data4d[ii, jg:patchsize - jg, jg:patchsize - jg, :]
        return rex

    def prepare4d_over(self, img, patchsize, jg):
        """Cut an (H, W, C) image into overlapping square patches.

        Patches of size ``patchsize`` are taken row-major on a grid with
        stride ``patchsize - 2*jg``, so neighbours overlap by ``2*jg`` pixels.

        Returns:
            (patches, img.shape): ``patches`` is float64 with shape
            (n_patches, patchsize, patchsize, C).
        """
        sp = img.shape
        step = patchsize - 2 * jg
        mnum = math.floor((sp[0] - 2 * jg) / step)
        nnum = math.floor((sp[1] - 2 * jg) / step)
        result = np.zeros((mnum * nnum, patchsize, patchsize, sp[2]))
        count = -1
        for i in range(0, mnum):
            for j in range(0, nnum):
                count += 1
                result[count, :, :, :] = img[step * i:step * i + patchsize, step * j:step * j + patchsize, :]
        return result, sp

class Scaler(layers.Layer):
    """Learnable per-channel scale layer.

    The output is ``inputs * sqrt(scale)``, where ``s`` is a trainable vector
    over the last axis, initialised to zeros:
    - in 'positive' mode, ``scale = tau + (1 - tau) * sigmoid(s)``;
    - in any other mode, ``scale = (1 - tau) * sigmoid(-s)``.

    The constructor's ``mode`` is the default used by ``call``. Passing
    ``mode`` to ``call`` overrides it for that call.
    """

    def __init__(self, tau=0.5, mode='positive', **kwargs):
        super(Scaler, self).__init__(**kwargs)
        self.tau = tau
        # Keep the constructor argument so it actually takes effect.
        self.mode = mode

    def build(self, input_shape):
        super(Scaler, self).build(input_shape)
        self.scale = self.add_weight(
            name='scale', shape=(input_shape[-1],), initializer='zeros'
        )

    def call(self, inputs, mode=None):
        if mode is None:
            mode = self.mode
        if mode == 'positive':
            scale = self.tau + (1 - self.tau) * K.sigmoid(self.scale)
        else:
            scale = (1 - self.tau) * K.sigmoid(-self.scale)
        return inputs * K.sqrt(scale)

    def get_config(self):
        config = {'tau': self.tau, 'mode': self.mode}
        base_config = super(Scaler, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))

class ETmodel:
    """Train and apply a two-input fusion network on paired lr/vis images.

    Constructing an instance runs the whole pipeline: it tiles the image
    pairs found under ``<dirr>/lr`` and ``<dirr>/vis`` (``read_data``),
    builds the network (``buildmodel``), trains it or loads cached weights
    from ``pres_model.h5`` in the current working directory
    (``train_model``), and writes fused images to ``<dirr>/res/``
    (``pred_model``).
    """

    def __init__(self, dirr, patchsize, jg, epochs, resname):
        """Run the full read / build / train / predict pipeline.

        Args:
            dirr: root directory containing the ``lr`` and ``vis`` folders.
            patchsize: side length of the square training tiles.
            jg: tile margin for training; tiles are cut with stride
                ``patchsize - 2 * jg``.
            epochs: maximum number of training epochs.
            resname: unused; kept for interface compatibility.
        """
        self.patchsize = patchsize
        self.jg = jg
        self.grid = GRID()
        lr_train, vs_train = self.read_data(dirr, patchsize, jg)
        print('data read!')
        model = self.buildmodel(
            input_shape1=tuple(lr_train.shape[1:]),
            input_shape2=tuple(vs_train.shape[1:]),
            out_chans=vs_train.shape[-1],
            patchsize=patchsize,
        )
        print('model built!')
        self.train_model(model, lr_train, vs_train, epochs, 'vs_model')
        # Inference model: same inputs, but it stops at the fused image
        # instead of the loss maps the training model outputs.
        nmodel = keras.Model(model.input, model.get_layer(name='output_layer').output)
        self.pred_model(nmodel, dirr, patchsize, jg)

    def buildmodel(self, input_shape1=(None, None, 1), input_shape2=(None, None, 3),
                   out_chans=1, patchsize=64):
        """Build the training graph.

        Each input passes through its own convolutional branch; the two
        16-channel results are concatenated and fed to a third branch whose
        last layer (``'output_layer'``) is the fused image ``z`` with
        ``out_chans`` channels. The model's output is the per-pixel loss map
        ``Lossi()([input1, z]) + Losst()([input2, z])``, which
        ``train_model`` regresses towards zeros.

        ``patchsize`` is unused.
        """
        input1 = layers.Input(shape=input_shape1, name='i_input')
        input2 = layers.Input(shape=input_shape2, name='v_input')
        x = self._branch(input1, 16)
        y = self._branch(input2, 16)
        z = layers.Concatenate(axis=-1)([x, y])
        z = self._branch(z, out_chans, out_name='output_layer')
        res = Lossi()([input1, z]) + Losst()([input2, z])
        return keras.Model([input1, input2], res)

    def _branch(self, inputs, out_filters, out_name=None):
        """1x1 conv(64) -> 3 residual blocks -> 1x1 conv(out_filters), each conv + BN + LeakyReLU."""
        # BN uses axis=-1 (channels) everywhere. The first BN used to be
        # axis=1 (image rows), which fails for (None, None, C) inputs;
        # weights trained with that layout should be retrained.
        t = layers.Conv2D(filters=64, kernel_size=1, data_format='channels_last', padding='same')(inputs)
        t = layers.BatchNormalization(axis=-1)(t)
        t = layers.LeakyReLU(alpha=0.3)(t)
        for i in range(0, 3):
            t = self._convolutional_block(t, [64, 32, 16], 'p', str(i))
        t = layers.Conv2D(filters=out_filters, kernel_size=1, data_format='channels_last', padding='same')(t)
        t = layers.BatchNormalization(axis=-1)(t)
        return layers.LeakyReLU(alpha=0.3, name=out_name)(t)

    def _residual_stack(self, X, filters, conv_layer):
        """Three 3x3 convs (F1, F2, F1) with BN + LeakyReLU, identity shortcut, then ReLU.

        The last conv uses F1 (F3 is ignored) so the result can be added to
        the shortcut; ``X`` must therefore already have F1 channels.
        """
        F1, F2, _ = filters
        X_shortcut = X
        for n_filters in (F1, F2, F1):
            X = conv_layer(n_filters, 3, strides=1, padding='same',
                           data_format='channels_last',
                           kernel_initializer='glorot_uniform')(X)
            X = layers.BatchNormalization(axis=-1)(X)
            X = layers.LeakyReLU(alpha=0.3)(X)
        X = layers.Add()([X, X_shortcut])
        return layers.Activation('relu')(X)

    def _convolutional_block(self, X, filters, stage, block):
        """Residual block built from Conv2D layers; ``stage``/``block`` are unused."""
        return self._residual_stack(X, filters, layers.Conv2D)

    def _convolutional_block_reverse(self, X, filters, stage, block):
        """Residual block built from stride-1 Conv2DTranspose layers; ``stage``/``block`` are unused."""
        return self._residual_stack(X, filters, layers.Conv2DTranspose)

    def gaussian_down_sample(self, data, w, mask=0):
        """Downsample ``data`` by an integer factor ``w`` with Gaussian weights.

        Output pixel ``(x, y)`` is a per-band weighted sum over the input.
        On the outer ring of output pixels the window is the ``w x w`` block
        itself. Everywhere else the window is that block grown by
        ``w // 2`` pixels on each side, i.e. ``2w`` (even ``w``) or ``2w - 1``
        (odd ``w``) wide. The weights come from ``gaussian_filter2d`` with
        ``sigma = w / 2.35482`` and sum to 1.

        Args:
            data: array indexed ``data[row, col, band]``.
            w: integer downsampling factor.
            mask: scalar (no masking) or a 2-D array. When an array is given,
                only output pixels whose ``w x w`` mask block sums to ``w**2``
                are computed; the others stay 0.

        Returns:
            float64 array of shape ``(rows // w, cols // w, bands)``.
        """
        use_mask = not np.isscalar(mask)
        hx = int(np.floor(data.shape[0] / w))
        hy = int(np.floor(data.shape[1] / w))
        band = data.shape[2]
        HSI = np.zeros((hx, hy, band))
        sig = w / 2.35482
        # w // 2 equals int(w/2) for even w and int((w-1)/2) for odd w.
        half = w // 2
        big = w + 2 * half
        H1 = self.gaussian_filter2d((w, w), sig).reshape(w, w, 1)
        H2 = self.gaussian_filter2d((big, big), sig).reshape(big, big, 1)
        for x in range(hx):
            for y in range(hy):
                if use_mask and mask[x * w:(x + 1) * w, y * w:(y + 1) * w].sum() != w ** 2:
                    continue
                if x == 0 or x == hx - 1 or y == 0 or y == hy - 1:
                    patch = data[x * w:(x + 1) * w, y * w:(y + 1) * w, :]
                    kernel = H1
                else:
                    patch = data[x * w - half:(x + 1) * w + half,
                                 y * w - half:(y + 1) * w + half, :]
                    kernel = H2
                # kernel has shape (k, k, 1) and broadcasts over the bands.
                HSI[x, y, :] = (np.double(patch) * kernel).sum(axis=0).sum(axis=0)
        return HSI

    def gaussian_filter2d(self, shape=(3, 3), sigma=1):
        """Return a normalized 2-D Gaussian kernel of the given shape.

        Entries below ``eps * max`` are zeroed before normalizing, and the
        kernel is divided by its sum when that sum is non-zero.
        """
        m, n = [(ss - 1.) / 2. for ss in shape]
        y, x = np.ogrid[-m:m + 1, -n:n + 1]
        h = np.exp(-(x ** 2 + y ** 2) / (2. * sigma ** 2))
        h[h < np.finfo(h.dtype).eps * h.max()] = 0
        sumh = h.sum()
        if sumh != 0:
            h /= sumh
        return h

    def read_data(self, dirr, patchsize, jg):
        """Load up to 60 lr/vis image pairs and cut them into aligned tiles.

        Pairs are matched by file name across ``<dirr>/lr`` and
        ``<dirr>/vis``. The lr image is read as grayscale. The vis image is
        read as colour and resized to 2/3 of its size. Both are cropped to
        their common extent and tiled with ``self.grid.prepare4d_over``.

        Returns:
            ``(lr_train, vs_train)``: 4-D tile arrays shuffled with the same
            permutation, so row i of each comes from the same location.

        Raises:
            ValueError: if no readable pair is found.
        """
        dirl = dirr + r'/lr'
        dirv = dirr + r'/vis'
        # The vis image is always looked up under the lr file name. Sorting
        # makes the selected subset independent of os.listdir order.
        names = sorted(set(os.listdir(dirl)) & set(os.listdir(dirv)))
        indd = np.arange(len(names))[0:60]
        random.shuffle(indd)

        lr_parts = []
        vs_parts = []
        for idx in indd:
            name = names[idx]
            imgr = cv2.imread(dirl + r'/' + name, 0)
            imgv = cv2.imread(dirv + r'/' + name, cv2.IMREAD_COLOR)
            if imgr is None or imgv is None:
                print('skipping unreadable pair: ' + name)
                continue
            # cv2.resize takes dsize as (width, height).
            imgv1 = cv2.resize(imgv, (int(imgv.shape[1] / 3 * 2), int(imgv.shape[0] / 3 * 2)))
            # Crop to the common extent so the two tilings line up.
            h = min(imgr.shape[0], imgv1.shape[0])
            w = min(imgr.shape[1], imgv1.shape[1])
            imgr = imgr[:h, :w, np.newaxis]
            imgv1 = imgv1[:h, :w, :]
            lr_parts.append(self.grid.prepare4d_over(imgr, patchsize, jg)[0])
            vs_parts.append(self.grid.prepare4d_over(imgv1, patchsize, jg)[0])

        if not lr_parts:
            raise ValueError('no readable lr/vis image pairs found under ' + dirr)
        # Concatenate once instead of growing the arrays inside the loop.
        lr_train = np.concatenate(lr_parts, axis=0)
        vs_train = np.concatenate(vs_parts, axis=0)
        print(lr_train.shape)
        print(vs_train.shape)

        indd = np.arange(0, lr_train.shape[0])
        np.random.shuffle(indd)
        return lr_train[indd], vs_train[indd]

    def compute_loss(self, model1, model2, x, epoch=None):
        """VAE-style loss: sigmoid cross-entropy reconstruction + Gaussian KL term.

        ``model1(x)`` must return ``(mean, logvar)``; ``model2`` maps a latent
        sample back to logits shaped like ``x``. ``epoch`` is unused; it
        defaults to None so ``compute_apply_gradients`` can call this.
        """
        x = tf.cast(x, dtype=tf.float32)
        mean, logvar = model1(x)
        z = self.sampling(mean, logvar)
        x_logit = model2(z)
        cross_ent = tf.reduce_sum(
            tf.nn.sigmoid_cross_entropy_with_logits(logits=x_logit, labels=x), axis=[1, 2, 3])
        KLD = -0.5 * tf.reduce_sum(1 + logvar - tf.pow(mean, 2) - tf.exp(logvar), axis=-1)
        return tf.reduce_mean(cross_ent + KLD)

    def compute_apply_gradients(self, model1, model2, model, x, optimizer):
        """Take one optimizer step on ``model`` using ``compute_loss``; returns the loss."""
        with tf.GradientTape() as tape:
            loss = self.compute_loss(model1, model2, x)
        # Gradients are computed after the tape stops recording.
        gradients = tape.gradient(loss, model.trainable_variables)
        optimizer.apply_gradients(zip(gradients, model.trainable_variables))
        return loss

    def sampling(self, z_mean, z_log_var):
        """Draw ``z_mean + exp(0.5 * z_log_var) * eps`` with ``eps ~ N(0, 1)``.

        The noise has shape ``(batch, dim)`` taken from ``z_mean``'s first two
        dims, so the inputs are expected to be 2-D.
        """
        batch = tf.shape(z_mean)[0]
        dim = tf.shape(z_mean)[1]
        epsilon = tf.random.normal(shape=(batch, dim))
        return z_mean + tf.exp(0.5 * z_log_var) * epsilon

    def train_model(self, model, x1, x2, epochs, modelname):
        """Compile ``model``, then load ``pres_model.h5`` if present or train and save it.

        Inputs are divided by 100 and the target is all zeros, because the
        model's output is itself the loss map. After training, the weights go
        to ``pres_model.h5`` and the per-epoch losses to ``pres_loss.txt`` /
        ``pres_val_loss.txt`` (all relative to the working directory).
        ``modelname`` is unused.
        """
        print(x1.shape)
        print(x2.shape)
        minn = min(x1.shape[0], x2.shape[0])
        # `learning_rate` is the supported keyword; `lr` is deprecated/removed.
        optim = tf.keras.optimizers.Adam(learning_rate=0.001)
        # NOTE(review): 'accuracy' is of little use for a real-valued zero target.
        model.compile(loss='mse', optimizer=optim, metrics=['accuracy'])
        if os.path.isfile('pres_model.h5'):
            model.load_weights('pres_model.h5')
            self.status = 1
            return

        early_stop = EarlyStopping(monitor='loss', min_delta=0, patience=15, verbose=0, mode='auto',
                                   baseline=None, restore_best_weights=True)
        reduce_lr = ReduceLROnPlateau(monitor='loss', factor=0.2,
                                      patience=7, min_lr=0.00001)
        targets = np.zeros((minn, x2.shape[1], x2.shape[2], x2.shape[3]))
        history = model.fit(x=[x1[:minn] / 100, x2[:minn] / 100], y=targets,
                            batch_size=8,
                            callbacks=[early_stop, reduce_lr],
                            epochs=epochs, validation_split=0.1)
        model.save_weights('pres_model.h5')
        np.savetxt('pres_loss.txt', np.array(history.history["loss"]), delimiter=",")
        np.savetxt('pres_val_loss.txt', np.array(history.history['val_loss']), delimiter=",")

    def pred_model(self, model, dirr, patchsize, jg):
        """Fuse every lr/vis pair and write ``<name>_fus.png`` into ``<dirr>/res/``.

        The lr image is upscaled by 3/2 and both images are cropped to their
        common extent. They are tiled with a fixed margin of 7 and stitched
        back with ``self.grid.restore4d_over``. Inputs are scaled by 1/100 and
        predictions by 100, mirroring ``train_model``.

        ``jg`` is not used: the inference margin is hard-coded to 7.
        """
        dirl = dirr + r'/lr'
        dirv = dirr + r'/vis'
        dirres = dirr + r'/res/'
        os.makedirs(dirres, exist_ok=True)
        overlap = 7
        for name in sorted(os.listdir(dirl)):
            imgr = cv2.imread(dirl + r'/' + name, 0)
            # The vis image is looked up under the same file name as the lr image.
            imgv = cv2.imread(dirv + r'/' + name, cv2.IMREAD_COLOR)
            if imgr is None or imgv is None:
                print('skipping unreadable pair: ' + name)
                continue

            # cv2.resize takes dsize as (width, height).
            imgr1 = cv2.resize(imgr, (int(imgr.shape[1] / 2 * 3), int(imgr.shape[0] / 2 * 3)))
            h = min(imgr1.shape[0], imgv.shape[0])
            w = min(imgr1.shape[1], imgv.shape[1])
            imgr1 = imgr1[:h, :w, np.newaxis]
            imgv = imgv[:h, :w, :]

            lr_tar, sp = self.grid.prepare4d_over(imgr1, patchsize, overlap)
            vsdata, sp = self.grid.prepare4d_over(imgv, patchsize, overlap)

            res4d = model.predict([lr_tar / 100, vsdata / 100], batch_size=8) * 100
            res = self.grid.restore4d_over(res4d, patchsize, sp, overlap)
            # Explicit 8-bit conversion: clip to [0, 255] then truncate.
            out = np.clip(np.squeeze(res), 0, 255).astype(np.uint8)
            cv2.imwrite(dirres + os.path.splitext(name)[0] + '_fus.png', out)


class Lossi(layers.Layer):
    """Intensity term of the fusion loss.

    Called on ``[ii, f]``. It returns ``|ii - mean_c(f)| / (H * W)`` per
    pixel, where ``mean_c`` averages ``f`` over its last axis (keepdims) and
    H, W are ``ii``'s dims 1 and 2. Assumes 4-D channels-last tensors —
    TODO confirm against callers.
    """

    def __init__(self, **kwargs):
        # Forward Keras kwargs (name, dtype, ...) so from_config works.
        super(Lossi, self).__init__(**kwargs)

    def call(self, inputs):
        ii, f = inputs
        fi = tf.reduce_mean(f, axis=-1, keepdims=True)
        # Dynamic sizes: static ii.shape[1] is None for (None, None, C) inputs.
        height = tf.cast(tf.shape(ii)[1], ii.dtype)
        width = tf.cast(tf.shape(ii)[2], ii.dtype)
        return tf.abs(ii - fi) / height / width

class Losst(layers.Layer):
    """Gradient term of the fusion loss.

    Called on ``[v, f]``. It compares the image gradients of the channel
    mean of ``v`` with the absolute image gradients of ``f``, per pixel,
    divided by H * W (dims 1 and 2). ``tf.image.image_gradients`` expects
    4-D ``(batch, h, w, c)`` tensors.
    """

    def __init__(self, **kwargs):
        # Forward Keras kwargs (name, dtype, ...) so from_config works.
        super(Losst, self).__init__(**kwargs)

    def call(self, inputs):
        v, f = inputs
        # image_gradients returns (dy, dx).
        dyf, dxf = tf.image.image_gradients(f)
        dyv, dxv = tf.image.image_gradients(tf.reduce_mean(v, axis=-1, keepdims=True))
        # NOTE(review): v's gradient is signed while f's is in absolute value;
        # |dv| - |df| may have been intended. Kept unchanged.
        diff_y = tf.abs(dyv - tf.abs(dyf))
        diff_x = tf.abs(dxv - tf.abs(dxf))
        # Dynamic sizes: static shape entries are None for (None, None, C) inputs.
        height = tf.cast(tf.shape(diff_y)[1], diff_y.dtype)
        width = tf.cast(tf.shape(diff_y)[2], diff_y.dtype)
        return diff_y / height / width + diff_x / height / width


if __name__ == '__main__':
    # Expose only GPU 0 (takes effect only if TensorFlow has not initialized its GPUs yet).
    os.environ["CUDA_VISIBLE_DEVICES"] = "0"
    dirr = r'/content/drive/MyDrive/new'
    # ETmodel reads/writes pres_model.h5 and the loss/time files relative to the cwd.
    os.chdir(dirr)
    print(os.getcwd())
    print(os.listdir())
    patchsize = 64
    jg = 0
    epochs = 1000

    starttime = datetime.datetime.now()
    ET = ETmodel(dirr, patchsize, jg, epochs, '')
    endtime = datetime.datetime.now()
    # timedelta.seconds drops whole days and fractions; total_seconds() is the full duration.
    elapsed = (endtime - starttime).total_seconds()
    print('total time is ' + str(elapsed) + 's')
    np.savetxt('time_cst.txt', np.array([elapsed]), delimiter=",")

/content/drive/.shortcut-targets-by-id/1NGc_AhpYbtXwsERLnGHUwsUHrAPuyqR5/new
['vis2pan', 'lr', 'vis', 'res']


  if lr_train == []:
  if vs_train == []:


(2320, 64, 64, 1)
(2320, 64, 64, 3)
data read!




model built!
(2320, 64, 64, 1)
(2320, 64, 64, 3)
Epoch 1/1000
Epoch 2/1000
Epoch 3/1000
Epoch 4/1000
Epoch 5/1000
Epoch 6/1000
Epoch 7/1000
Epoch 8/1000
Epoch 9/1000
Epoch 10/1000
Epoch 11/1000
Epoch 12/1000
Epoch 13/1000
Epoch 14/1000
Epoch 15/1000
Epoch 16/1000
Epoch 17/1000
Epoch 18/1000
Epoch 19/1000
Epoch 20/1000
Epoch 21/1000
Epoch 22/1000
Epoch 23/1000
Epoch 24/1000
Epoch 25/1000
Epoch 26/1000
Epoch 27/1000
Epoch 28/1000
Epoch 29/1000
Epoch 30/1000
Epoch 31/1000
Epoch 32/1000
Epoch 33/1000
Epoch 34/1000
Epoch 35/1000
Epoch 36/1000
Epoch 37/1000
Epoch 38/1000
Epoch 39/1000
Epoch 40/1000
Epoch 41/1000
Epoch 42/1000
Epoch 43/1000
Epoch 44/1000
Epoch 45/1000
Epoch 46/1000
Epoch 47/1000
Epoch 48/1000
Epoch 49/1000
Epoch 50/1000
Epoch 51/1000
Epoch 52/1000
Epoch 53/1000
Epoch 54/1000
Epoch 55/1000
Epoch 56/1000
Epoch 57/1000
Epoch 58/1000
Epoch 59/1000
Epoch 60/1000
Epoch 61/1000
Epoch 62/1000
Epoch 63/1000
Epoch 64/1000
Epoch 65/1000
Epoch 66/1000
Epoch 67/1000
Epoch 68/1000
Epoch 69