In [None]:
import os
import copy
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from matplotlib import pyplot as plt
from tqdm import tqdm
from torch import optim
from utils import *
from nn_for_fwi_80_8_attention_3 import UNet_conditional,EMA  #depth=80 channel=8

import logging
from torch.utils.tensorboard import SummaryWriter
import torchvision.transforms as transforms
import math
from torch.utils import data
# Expose only the first physical GPU to this process and address it as cuda:0.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
device = torch.device("cuda:0")

# Timestamped INFO-level logging for the rest of the notebook.
logging.basicConfig(
    level=logging.INFO,
    datefmt="%I:%M:%S",
    format="%(asctime)s - %(levelname)s: %(message)s",
)

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

def gaussian_kernel(size, sigma, num_channels, device):
    """Build a depthwise 2-D Gaussian filter bank.

    Args:
        size: side length of the square window.
        sigma: standard deviation of the Gaussian (in pixels).
        num_channels: number of identical copies (one per input channel).
        device: device on which the kernel is created.

    Returns:
        Tensor of shape (num_channels, 1, size, size); each (size, size) slice
        sums to 1, suitable for ``F.conv2d(..., groups=num_channels)``.
    """
    # Offsets of every tap from the window centre.
    offsets = torch.arange(size, dtype=torch.float32, device=device) - (size - 1) / 2.0
    profile = torch.exp(-offsets.pow(2) / (2 * sigma ** 2))
    profile = profile / profile.sum()

    # Separable 2-D kernel = outer product of the 1-D profile with itself.
    kernel = torch.outer(profile, profile)
    kernel = kernel / kernel.sum()

    return kernel.expand(num_channels, 1, size, size).contiguous()

def _ssim(img1, img2, window, data_range, C1, C2):
    channels = img1.shape[1]
    window = window.to(img1.device)

    mu1 = F.conv2d(img1, window, padding='valid', groups=channels)
    mu2 = F.conv2d(img2, window, padding='valid', groups=channels)

    mu1_sq = mu1.pow(2)
    mu2_sq = mu2.pow(2)
    mu1_mu2 = mu1 * mu2

    sigma1_sq = F.conv2d(img1 * img1, window, padding='valid', groups=channels) - mu1_sq
    sigma2_sq = F.conv2d(img2 * img2, window, padding='valid', groups=channels) - mu2_sq
    sigma12 = F.conv2d(img1 * img2, window, padding='valid', groups=channels) - mu1_mu2

    luminance_map = (2 * mu1_mu2 + C1) / (mu1_sq + mu2_sq + C1).clamp(min=1e-8)
    contrast_structure_map = (2 * sigma12 + C2) / (sigma1_sq + sigma2_sq + C2).clamp(min=1e-8)
    
    return luminance_map, contrast_structure_map

def msssim(img1, img2, window_size=11, data_range=1.0, size_average=True,
           scales=5, weights=None, sigma=1.5):
    """Multi-scale structural similarity (MS-SSIM) of two (N, C, H, W) batches.

    At every scale except the coarsest, the mean contrast-structure term is
    used; at the coarsest scale the mean full SSIM (luminance *
    contrast-structure) is used. Images are 2x average-pooled between scales,
    so with ``scales=1`` this reduces to ordinary single-scale SSIM.

    Args:
        img1, img2: tensors of identical shape (N, C, H, W).
        window_size: side length of the Gaussian window.
        data_range: dynamic range of the inputs; sets C1 = (0.01*R)^2 and
            C2 = (0.03*R)^2.
        size_average: if True return the batch mean, else one value per image.
        scales: number of scales.
        weights: per-scale exponents, finest scale first; a default set is
            chosen from ``scales`` when None.
        sigma: standard deviation of the Gaussian window.

    Returns:
        Scalar tensor if ``size_average`` else a tensor of shape (N,).

    Raises:
        ValueError: mismatched or non-4D inputs, wrong number of weights, or
            images too small for reflect padding at some scale.
    """
    if not img1.shape == img2.shape:
        raise ValueError("Input images must have the same dimensions.")
    if img1.ndim != 4:
        raise ValueError("Input images must be 4D tensors (N, C, H, W).")

    num_channels = img1.shape[1]
    device = img1.device

    if weights is None:
        # Default per-scale exponents, chosen by the number of scales.
        if scales == 5:
            weights = [0.0448, 0.2856, 0.3001, 0.2363, 0.1333]
        elif scales == 3:
            weights = [0.5, 0.3, 0.2]  # example weights, sum to 1
        elif scales == 2:
            weights = [0.7, 0.3]  # example: larger weight for the finer scale
        elif scales == 1:
            weights = [1.0]  # single scale
        else:
            weights = [1.0 / scales] * scales
            print(f"Warning: Using uniform weights for {scales} scales as no specific weights are defined.")

    if len(weights) != scales:
        raise ValueError(f"Number of weights ({len(weights)}) must match number of scales ({scales}).")
    weights = torch.tensor(weights, dtype=torch.float32, device=device).view(1, scales)

    C1 = (0.01 * data_range) ** 2
    C2 = (0.03 * data_range) ** 2

    pad = (window_size - 1) // 2
    window = gaussian_kernel(window_size, sigma, num_channels, device)

    # Reflect padding needs every spatial dimension > pad at every scale;
    # check up front so the error names the offending scale.
    current_h = img1.shape[2]
    current_w = img1.shape[3]
    for i in range(scales):
        if current_h <= pad or current_w <= pad:
            raise ValueError(
                f"Image dimension ({current_h}x{current_w}) is too small for "
                f"reflect padding ({pad}) with window_size={window_size} at scale {i+1}/{scales}. "
                f"Consider reducing 'scales' or 'window_size', or increasing initial image size."
            )
        current_h //= 2
        current_w //= 2

    components = []
    for i in range(scales):
        padded_img1 = F.pad(img1, (pad, pad, pad, pad), mode='reflect')
        padded_img2 = F.pad(img2, (pad, pad, pad, pad), mode='reflect')

        L_map, CS_map = _ssim(padded_img1, padded_img2, window, data_range, C1, C2)

        if i < scales - 1:
            # Finer scales: contrast-structure only, averaged over H, W, then C.
            component = CS_map.mean(dim=(-1, -2)).mean(dim=-1)
            img1 = F.avg_pool2d(img1, kernel_size=2, stride=2)
            img2 = F.avg_pool2d(img2, kernel_size=2, stride=2)
        else:
            # Coarsest scale: full SSIM = luminance * contrast-structure.
            # Using the luminance map alone would make scales=1 compare only
            # local means and ignore structure entirely.
            component = (L_map * CS_map).mean(dim=(-1, -2)).mean(dim=-1)

        if scales > 1:
            # Negative bases with fractional exponents give NaN in the
            # multi-scale product, so clamp at 0. Single-scale keeps the
            # signed SSIM so anti-correlated inputs still get gradients.
            component = torch.relu(component)
        components.append(component)

    ms_ssim_val = components[-1].pow(weights[0, -1])
    for i in range(scales - 1):
        ms_ssim_val = ms_ssim_val * components[i].pow(weights[0, i])

    return ms_ssim_val.mean() if size_average else ms_ssim_val

class MSSSIMLoss(nn.Module):
    """Loss module computing ``1 - msssim(img1, img2, ...)``.

    Every constructor argument is stored as an attribute of the same name and
    forwarded unchanged to ``msssim`` on each call.
    """

    def __init__(self, window_size=11, data_range=1.0, size_average=True,
                 scales=5, weights=None, sigma=1.5):
        super().__init__()
        options = dict(
            window_size=window_size,
            data_range=data_range,
            size_average=size_average,
            scales=scales,
            weights=weights,
            sigma=sigma,
        )
        for name, value in options.items():
            setattr(self, name, value)

    def forward(self, img1, img2):
        """Return 1 minus the MS-SSIM of the two batches (0 when identical)."""
        similarity = msssim(
            img1,
            img2,
            window_size=self.window_size,
            data_range=self.data_range,
            size_average=self.size_average,
            scales=self.scales,
            weights=self.weights,
            sigma=self.sigma,
        )
        return 1.0 - similarity

In [None]:
def chang_np_array_size(raw_data,resize_x=224,resize_z=224):
    """Resize a 2-D NumPy array to shape (resize_x, resize_z).

    The array is wrapped in a PIL image (``Image.fromarray``) and resized with
    ``transforms.Resize((resize_x, resize_z))`` using its default
    interpolation, then converted back to an ndarray.
    """
    resizer = transforms.Resize((resize_x, resize_z))
    return np.array(resizer(Image.fromarray(raw_data)))
def get_positional_encoding(seq_len, d_model):
    """Sinusoidal positional-encoding table.

        PE[pos, 2k]   = sin(pos / 10000 ** (2k / d_model))
        PE[pos, 2k+1] = cos(pos / 10000 ** (2k / d_model))

    Each sin/cos column pair shares one frequency. The previous loop used
    exponent (2k+1)/d_model for the cosine column and indexed past the last
    column when ``d_model`` was odd.

    Args:
        seq_len: number of positions (rows).
        d_model: encoding width (columns); may be odd.

    Returns:
        float64 ndarray of shape (seq_len, d_model).
    """
    positions = np.arange(seq_len, dtype=np.float64)[:, None]        # (seq_len, 1)
    pair_start = np.arange(0, d_model, 2, dtype=np.float64)          # 2k for each pair
    angles = positions / (10000 ** (pair_start / d_model))           # (seq_len, ceil(d/2))

    encoding = np.zeros((seq_len, d_model))
    encoding[:, 0::2] = np.sin(angles)
    # For odd d_model the final pair has no cosine column.
    encoding[:, 1::2] = np.cos(angles[:, : d_model // 2])
    return encoding
def make_vp_net(data_raw,n_channel=16):
    """Build a sliding window of ``n_channel`` consecutive rows for every row.

    Row ``i`` of the result is ``data_raw[s:s + n_channel]`` with
    ``s = i - n_channel // 2`` clamped to ``[0, xn - n_channel]``. Interior
    rows are therefore centred on ``i``. Rows near either edge reuse the
    first / last full window. Previously the leading rows got a shrinking,
    zero-padded window, odd ``n_channel`` failed for interior rows, and
    errors were swallowed by a bare ``except``.

    Args:
        data_raw: 2-D array of shape (xn, zn).
        n_channel: window length along axis 0.

    Returns:
        float64 ndarray of shape (xn, n_channel, zn).

    Raises:
        ValueError: if ``data_raw`` has fewer than ``n_channel`` rows.
    """
    data_raw = np.asarray(data_raw)
    xn, zn = data_raw.shape
    if xn < n_channel:
        raise ValueError(
            f"make_vp_net needs at least n_channel={n_channel} rows, got {xn}."
        )

    half = n_channel // 2
    starts = np.clip(np.arange(xn) - half, 0, xn - n_channel)        # (xn,)
    row_idx = starts[:, None] + np.arange(n_channel)[None, :]        # (xn, n_channel)
    return data_raw[row_idx].astype(np.float64, copy=False)


import numpy as np
from scipy.interpolate import interp1d

def generate_migrated_data(resize_model, factor=4, fc=0.1):
    """
    Generate a synthetic "migrated" section from a model using a 1-D
    convolutional model (reflectivity * Ricker wavelet), trace by trace.

    Parameters:
        resize_model (ndarray): 2D model [traces, time_samples]
        factor (int): Upsampling factor along the time axis (default=4)
        fc (float): Normalized wavelet frequency [cycles/sample] on the
            original grid (default=0.1)

    Returns:
        migrated (ndarray): Synthetic section [traces, time_samples].
            Floating-point inputs keep their dtype; integer inputs give
            float64 (writing into an integer array would truncate the
            small reflection amplitudes to 0).
    """
    resize_model = np.asarray(resize_model)
    n_traces, n_samples = resize_model.shape
    if np.issubdtype(resize_model.dtype, np.floating):
        out_dtype = resize_model.dtype
    else:
        out_dtype = np.float64

    # --- 1. Upsample the model along time (linear interpolation) ---
    orig_axis = np.arange(n_samples)
    new_len = (n_samples - 1) * factor + 1
    new_axis = np.linspace(0, n_samples - 1, new_len)

    model_up = np.zeros((n_traces, new_len))
    for i in range(n_traces):
        f = interp1d(orig_axis, resize_model[i], kind='linear',
                     bounds_error=False, fill_value="extrapolate")
        model_up[i] = f(new_axis)

    # --- 2. Reflection coefficients (m[k+1] - m[k]) / (m[k+1] + m[k]) ---
    num = model_up[:, 1:] - model_up[:, :-1]
    den = model_up[:, 1:] + model_up[:, :-1]
    eps = 1e-9
    den = np.where(np.abs(den) < eps, eps, den)
    refl = num / den  # [traces, new_len-1]

    # --- 3. Reflectivity series (first sample has no interface above it) ---
    refl_series = np.zeros_like(model_up)
    refl_series[:, 1:] = refl

    # --- 4. Ricker wavelet: 51 samples on the original grid, then upsampled ---
    wv_len_orig = 51
    center = (wv_len_orig - 1) // 2
    t_orig = np.arange(wv_len_orig) - center
    wavelet_orig = (1 - 2 * (np.pi * fc * t_orig) ** 2) * np.exp(-(np.pi * fc * t_orig) ** 2)

    wv_up_axis = np.linspace(0, wv_len_orig - 1, (wv_len_orig - 1) * factor + 1)
    f_wavelet = interp1d(np.arange(wv_len_orig), wavelet_orig,
                         kind='linear', bounds_error=False, fill_value=0)
    wavelet_up = f_wavelet(wv_up_axis)
    wavelet_up /= np.max(np.abs(wavelet_up))  # unit peak amplitude

    # --- 5. Convolve each trace with the wavelet ---
    # Full convolution, then keep the window centred on the trace. The wavelet
    # length (50*factor + 1) is odd, so this equals np.convolve(mode='same')
    # for traces at least as long as the wavelet. It also keeps the trace
    # length for shorter traces, where mode='same' would return the wavelet's
    # length and break the assignment.
    half = (wavelet_up.size - 1) // 2
    synthetic = np.zeros_like(model_up)
    for i in range(n_traces):
        full = np.convolve(refl_series[i], wavelet_up, mode='full')
        synthetic[i] = full[half:half + new_len]

    # --- 6. Resample back to the original time grid ---
    migrated = np.zeros(resize_model.shape, dtype=out_dtype)
    for i in range(n_traces):
        f = interp1d(new_axis, synthetic[i], kind='linear',
                     bounds_error=False, fill_value="extrapolate")
        migrated[i] = f(orig_axis)

    return migrated
import torch
import torch.nn.functional as F


def multi_migrated_data_cuda(impedance_model_3d, factor=4, fc=0.1, device='cuda'):
    """Apply ``generate_migrated_data_cuda`` to every 2-D slice along dim 0.

    Args:
        impedance_model_3d (tensor): stack of 2-D models; each slice
            ``impedance_model_3d[i]`` is passed as [traces, time_samples].
        factor, fc: forwarded unchanged to ``generate_migrated_data_cuda``.
        device: device of the output and of every per-slice computation.
            Previously it was not forwarded, so each slice always used the
            callee's own default ('cuda').

    Returns:
        Tensor shaped and typed like ``impedance_model_3d`` on ``device``.
    """
    image_3d_cuda = torch.zeros_like(impedance_model_3d, device=device)
    for i in range(impedance_model_3d.shape[0]):
        image_3d_cuda[i] = generate_migrated_data_cuda(
            impedance_model_3d[i], factor=factor, fc=fc, device=device)
    return image_3d_cuda
        


def generate_migrated_data_cuda(impedance_model, factor=4, fc=0.1, device='cuda'):
    """
    Generate a synthetic migrated section with a convolution model and a
    Ricker wavelet in PyTorch, so it can be used inside a differentiable loss.

    All traces are processed in one batched interpolate / conv1d call.

    Parameters:
        impedance_model (ndarray or tensor): 2D model [traces, time_samples].
            A tensor input stays in the autograd graph: gradients of the
            output flow back into it.
        factor (int): Upsampling factor along time (default=4)
        fc (float): Normalized central frequency of the Ricker wavelet
            [cycles/sample] on the original grid (default=0.1)
        device (str): device for the computation (default='cuda')

    Returns:
        migrated (tensor): float32 synthetic section [traces, time_samples]
            on ``device``.
    """
    # torch.as_tensor (unlike torch.tensor) does not detach an incoming
    # tensor, so losses on the migrated image can backpropagate to whatever
    # produced `impedance_model`.
    impedance = torch.as_tensor(impedance_model, dtype=torch.float32, device=device)
    n_samples = impedance.shape[1]

    # --- 1. Upsample along time (linear, end points aligned) ---
    new_len = (n_samples - 1) * factor + 1
    model_up = F.interpolate(impedance.unsqueeze(1), size=new_len,
                             mode='linear', align_corners=True)  # (traces, 1, new_len)

    # --- 2. Reflection coefficients ---
    num = model_up[:, :, 1:] - model_up[:, :, :-1]
    den = model_up[:, :, 1:] + model_up[:, :, :-1]
    eps = 1e-9
    den = torch.where(den.abs() < eps, torch.full_like(den, eps), den)  # avoid division by zero
    refl = num / den  # (traces, 1, new_len - 1)

    # --- 3. Reflectivity series: zero at the first sample ---
    refl_series = F.pad(refl, (1, 0))  # (traces, 1, new_len)

    # --- 4. Ricker wavelet (51 samples on the original grid), upsampled ---
    wv_len_orig = 51
    center = (wv_len_orig - 1) // 2
    t_orig = torch.arange(wv_len_orig, device=device) - center
    arg = (torch.pi * fc * t_orig) ** 2
    wavelet_orig = (1 - 2 * arg) * torch.exp(-arg)

    wv_up_len = (wv_len_orig - 1) * factor + 1
    wavelet_up = F.interpolate(wavelet_orig.view(1, 1, -1), size=wv_up_len,
                               mode='linear', align_corners=True)  # (1, 1, kernel)
    wavelet_up = wavelet_up / wavelet_up.abs().max()  # unit peak amplitude

    # --- 5. Convolve all traces at once ---
    # conv1d is a cross-correlation. The wavelet is symmetric about its
    # centre (t_orig is symmetric and the Ricker formula is even), so this
    # equals convolution.
    synthetic = F.conv1d(refl_series, wavelet_up, padding='same')  # (traces, 1, new_len)

    # --- 6. Back to the original time sampling ---
    migrated = F.interpolate(synthetic, size=n_samples,
                             mode='linear', align_corners=True).squeeze(1)
    return migrated


# fc is in cycles per sample: e.g. fc = 0.1 means one full wavelet cycle spans 10 grid samples,
# i.e. a wavelength of 10 samples * 20 m grid size = 200 m, so freq = velocity / wavelength = 2000 m/s / 200 m = 10 Hz.
# In general: true freq = velocity / (grid_size / fc)

In [None]:

class RFlow:
    """Rectified-flow (flow-matching) utilities.

    Builds straight-line interpolants between Gaussian noise x_0 (t = 0) and
    data x_1 (t = 1), provides the velocity-regression loss, and integrates a
    trained velocity model with the explicit Euler method.
    """

    def __init__(self, step=20, img_size=160, device=torch.device("cuda:0")):
        """
        Args:
            step: number of Euler steps used by `sample_for_gen`.
            img_size: length of the last axis of generated samples.
            device: device for noise and interpolants. The default equals the
                notebook's global `device` (cuda:0) without requiring that
                global to exist when the class is defined.
        """
        self.step = step
        self.img_size = img_size
        self.device = device

    def euler(self, x_t, v, dt):
        """One explicit Euler step: x_{t+dt} = x_t + v * dt.

        Args:
            x_t: current state, [B, C, H, W].
            v: velocity at the current state, same shape as x_t.
            dt: step size.
        """
        return x_t + v * dt

    # v1.2: reflow accepts an explicit x_0.
    def create_flow(self, x_1, t, x_0=None):
        """Build the linear flow x_t = t * x_1 + (1 - t) * x_0.

        x_1 is the data sample; x_0 is standard Gaussian noise.

        Args:
            x_1: data, [B, C, H, W].
            t: times in [0, 1], shape [B].
            x_0: noise, [B, C, H, W]; sampled with randn_like(x_1) if None.

        Returns:
            x_t: interpolant at time t, [B, C, H, W].
            x_0: the noise that was used, [B, C, H, W].
        """
        if x_0 is None:
            x_0 = torch.randn_like(x_1).to(self.device)

        t = t[:, None, None, None].to(self.device)  # [B, 1, 1, 1] for broadcasting
        x_t = t * x_1 + (1 - t) * x_0
        return x_t, x_0

    def mse_loss(self, v, x_1, x_0):
        """Rectified-flow loss: MSE between the predicted velocity v and the
        constant straight-line velocity x_1 - x_0 (mean over all elements).

        Args:
            v: predicted velocity, [B, C, H, W].
            x_1: data, [B, C, H, W].
            x_0: noise, [B, C, H, W].
        """
        return F.mse_loss(x_1 - x_0, v)

    def sample_for_gen(self,model=None,n=5,
        y=None,z=None,w=None,n_channel=None,
        cfg_scale=7.0,
        save_path='./results',
        save_noise_path=None,
        device='cuda'):
        """Generate samples by Euler-integrating the learned velocity field.

        Starts from Gaussian noise of shape (n, 2, n_channel, self.img_size)
        on ``self.device`` and takes ``self.step`` steps of size 1/step.

        Args:
            model: callable ``model(x=..., t=..., y=..., z=..., w=...)``
                returning a velocity with the same shape as x.
            n: number of samples to generate.
            y, z, w: conditioning tensors (moved to ``device``); None is
                passed through unchanged (unconditional generation).
            n_channel: size of axis 2 of the generated samples (required).
            cfg_scale: accepted for API compatibility; classifier-free
                guidance is not applied.
            save_path: directory created if missing; nothing is written.
            save_noise_path: directory created if given; nothing is written.
            device: device for the conditions and time tensor.

        Returns:
            Tensor (n, 2, n_channel, self.img_size) at t = 1.
        """
        os.makedirs(save_path, exist_ok=True)
        if save_noise_path is not None:
            os.makedirs(save_noise_path, exist_ok=True)

        def _to_device(cond):
            return None if cond is None else cond.to(device)

        with torch.no_grad():
            dt = 1.0 / self.step

            # Initial state x_0: standard Gaussian noise.
            x_t = torch.randn((n, 2, n_channel, self.img_size)).to(self.device)

            y_i = _to_device(y)
            z_i = _to_device(z)
            w_i = _to_device(w)

            for j in range(self.step):
                t = (torch.ones(n) * j * dt).to(device)
                v_pred = model(x=x_t, t=t, y=y_i, z=z_i, w=w_i)
                # Use this instance's integrator, not a notebook-global `rf`.
                x_t = self.euler(x_t, v_pred, dt)

        return x_t



In [None]:
# ---- Model / acquisition geometry ------------------------------------------
NZ, NX = 200, 70      # each model*.bin is reshaped to [NZ, NX]
nz = NZ
NR = 200
Tn = 1500             # time samples per raw CMP trace
dt = 0.002            # raw time-sample interval
DX = 20               # grid spacing (20 m per the fc note above)
pao_interval = 5

# ---- Trace windows ---------------------------------------------------------
n_cmp_raw = 2
n_channel_raw = 8     # raw CMP files hold n_channel_raw * n_cmp_raw traces
n_channel = 8
n_cmp = 2
img_size = 64

# ---- Smoothing and synthetic migration -------------------------------------
smooth_sigma = 5
wavelet_freq = 30
migrate_factor = 4

root_dir = '/data/wsl/model_openfwi/'
iter_i = 200

import scipy.ndimage
from scipy.ndimage import gaussian_filter


def make_model_big(model_0,resize_x=100,resize_z=140,zoom_factor=3):
    """Upsample a model, resize it, and smooth it with anisotropic diffusion.

    Args:
        model_0: 2-D array.
        resize_x, resize_z: output shape after resizing.
        zoom_factor: spline-zoom factor applied before resizing (default 3,
            as before).

    Returns:
        Output of ``anisotropic_diffusion`` on the resized array.
    """
    # order=4 is a 4th-order spline (bicubic would be order=3). Uses the
    # explicitly imported `scipy.ndimage`; a bare `ndimage` name is not
    # imported in this notebook.
    upsampled = scipy.ndimage.zoom(model_0, zoom_factor, order=4)
    resized = chang_np_array_size(upsampled, resize_x=resize_x, resize_z=resize_z)
    # NOTE(review): `anisotropic_diffusion` is assumed to come from
    # `from utils import *`; confirm.
    smoothed = anisotropic_diffusion(resized, kappa=30, iterations=15)
    return smoothed

In [None]:


# Load every OpenFWI family: read each velocity model and its CMP gathers,
# derive smoothed / synthetic-migration companions, and cut every section
# into sliding windows of `n_channel` traces.
import scipy.signal

vp_arr = []    # windows of the Gaussian-smoothed velocity
rtm_arr = []   # windows of the synthetic migration of the true model
cmp_arr = []   # windows of the high-pass-filtered, resized CMP data
rtm0_arr = []  # windows of the synthetic migration of the smoothed model
vp0_arr = []   # windows of the true (unsmoothed) velocity

root_dir = '/data/wsl/model_openfwi/'

# (family sub-directory, number of models), loaded in this order.
DATASETS = [('flatfault', 800), ('curvevelB', 400), ('flat_b', 200)]

# High-pass filter for the resized CMP traces. `dt_resampled` is the sample
# interval assumed for the 500-sample resized traces. It has its own name:
# the old code assigned 6e-3 to the global `dt` inside the loop, so every
# model after the first was time-windowed with dt=6e-3 instead of the raw dt.
# NOTE(review): the window now uses the configured raw `dt` for every model;
# confirm this is the intended window.
dt_resampled = 6e-3
f_low = 2
N_filter = 2
aa, bb = scipy.signal.butter(N_filter, 2 * f_low / (1 / dt_resampled), btype='highpass',
                             analog=False, output='ba', fs=None)  # cutoff / Nyquist = f_low / (1 / (2*dt))

# Normalised Ricker frequency [cycles/sample]: true freq = velocity / (grid_size / fc),
# so this gives `wavelet_freq` Hz at 3000 m/s.
migrate_fc = DX / (3000 / wavelet_freq)


def _interior(windows):
    """Drop the first and last n_channel/2 windows (those touching an edge)."""
    return windows[int(n_channel / 2):-int(n_channel / 2)]


def _load_model_sample(family, iter_i):
    """Load and preprocess model `iter_i` of `family`.

    Returns:
        (vp_net, rtm_net, cmp_net, rtm0_net, vp0_net), each trimmed by `_interior`.
    """
    model_dir = root_dir + family + '/model_200_70/'
    cmp_dir = root_dir + family + '/cmp_data_200_70_add_layer/{}_cmp/'.format(iter_i)
    n_traces_raw = n_channel_raw * n_cmp_raw

    # Velocity: true model resized to (NZ, img_size), plus a smoothed copy.
    vp_real_i0 = np.fromfile(model_dir + "model{}.bin".format(iter_i), dtype=np.float32).reshape([NZ, NX])
    vp_real_i0 = chang_np_array_size(vp_real_i0, resize_x=NZ, resize_z=img_size)
    vp_real_i = scipy.ndimage.gaussian_filter(vp_real_i0, sigma=smooth_sigma)
    vp_net_i0 = make_vp_net(vp_real_i0, n_channel=n_channel)
    vp_net_i = make_vp_net(vp_real_i, n_channel=n_channel)

    cmp_i = np.fromfile(cmp_dir + "cmp_{}_{}_{}.bin".format(NZ, n_traces_raw, Tn),
                        dtype=np.float32).reshape(NZ, n_traces_raw, Tn)

    # Time window in raw samples (raw sample interval `dt`).
    max_vp_top_layer = vp_real_i0[int(NZ / 2), 0]
    mean_vp_all_layer = np.mean(vp_real_i0)
    dx_2 = 2 * DX * 20 * 3  # per original notes: 20 = top-layer thickness, 2 = two-way path
    t_path_2 = dx_2 / max_vp_top_layer
    last_T = (4.3 * (NX + 20) * DX) / mean_vp_all_layer
    time_cut_arrival = int(t_path_2 / dt)
    last_time = int(1 / dt + last_T / dt)

    # Cut the window, resize every raw trace to 500 samples, then high-pass.
    cmp_net_i = np.zeros([NZ, n_traces_raw, 500])
    for channel_i in range(n_traces_raw):
        cmp_net_i[:, channel_i, :] = chang_np_array_size(
            cmp_i[:, channel_i, time_cut_arrival:time_cut_arrival + last_time],
            resize_x=NZ, resize_z=500)
    cmp_net_i = scipy.signal.filtfilt(aa, bb, cmp_net_i)

    # Synthetic migrated images of the true and of the smoothed model.
    rtm_i = generate_migrated_data(vp_real_i0, factor=migrate_factor, fc=migrate_fc)
    rtm_i = chang_np_array_size(rtm_i, resize_x=NZ, resize_z=img_size)
    rtm_net_i = make_vp_net(rtm_i, n_channel=n_channel)

    rtm_i0 = generate_migrated_data(vp_real_i, factor=migrate_factor, fc=migrate_fc)
    rtm_i0 = chang_np_array_size(rtm_i0, resize_x=NZ, resize_z=img_size)
    rtm_net_i0 = make_vp_net(rtm_i0, n_channel=n_channel)

    return (_interior(vp_net_i), _interior(rtm_net_i), _interior(cmp_net_i),
            _interior(rtm_net_i0), _interior(vp_net_i0))


for family, n_models in DATASETS:
    for iter_i in range(n_models):
        vp_w, rtm_w, cmp_w, rtm0_w, vp0_w = _load_model_sample(family, iter_i)
        rtm0_arr.append(rtm0_w)
        vp_arr.append(vp_w)
        rtm_arr.append(rtm_w)
        cmp_arr.append(cmp_w)
        vp0_arr.append(vp0_w)


In [None]:
def _stack_windows(per_model_windows):
    """Turn the list of per-model window stacks into one (N, ...) array."""
    return np.concatenate(np.array(per_model_windows), axis=0)


def _standardize(values):
    """Z-score over axes 0 and -1; return (normalised, mean, std)."""
    mean = np.mean(values, axis=(0, -1), keepdims=True)
    std = np.std(values, axis=(0, -1), keepdims=True)
    return (values - mean) / (std + 1e-12), mean, std


vp0_arr = _stack_windows(vp0_arr)
vp_arr = _stack_windows(vp_arr)
rtm_arr = _stack_windows(rtm_arr)
cmp_arr = _stack_windows(cmp_arr)
rtm0_arr = _stack_windows(rtm0_arr)

# Velocity and CMP arrays are replaced by their normalised versions.
vp_arr, vp_mean, vp_std = _standardize(vp_arr)
vp0_arr, vp0_mean, vp0_std = _standardize(vp0_arr)
cmp_arr, cmp_mean, cmp_std = _standardize(cmp_arr)

# The migration arrays keep their raw values; normalised copies go to *_net.
rtm_net, rtm_mean, rtm_std = _standardize(rtm_arr)
rtm0_net, rtm0_mean, rtm0_std = _standardize(rtm0_arr)

# To persist the statistics, uncomment:
# for stat_name, stat in [("rtm0_mean", rtm0_mean), ("rtm0_std", rtm0_std),
#                         ("rtm_mean", rtm_mean), ("rtm_std", rtm_std),
#                         ("vp0_mean", vp0_mean), ("vp0_std", vp0_std),
#                         ("vp_mean", vp_mean), ("vp_std", vp_std),
#                         ("cmp_mean", cmp_mean), ("cmp_std", cmp_std)]:
#     np.save(root_dir + "{}{}.npy".format(stat_name, smooth_sigma), stat)



In [None]:


NR = vp_arr.shape[0]
resize_x = n_channel * n_cmp
resize_z = 256

# Conditioning channels, each resampled to (resize_x, resize_z):
#   0 = CMP windows, 1 = vp_arr (smoothed velocity), 2 = rtm0_net.
conditioning_sources = (cmp_arr, vp_arr, rtm0_net)
resized_seismic_data = np.zeros([NR, len(conditioning_sources), resize_x, resize_z])
for i in range(NR):
    for chan, source in enumerate(conditioning_sources):
        resized_seismic_data[i, chan] = chang_np_array_size(
            source[i], resize_x=int(resize_x), resize_z=resize_z)

nx = NR
print(nx)

# Flow target: channel 0 = normalised RTM of the true model, channel 1 = vp0_arr.
input_data = np.zeros([nx, 2, n_channel, img_size])
input_data[:, 0] = rtm_net.reshape([nx, n_channel, img_size])
input_data[:, 1] = vp0_arr.reshape([nx, n_channel, img_size])

np.save(root_dir + "cmp_vp0_rtm0_sigma{}.npy".format(smooth_sigma), resized_seismic_data)



In [None]:


# Show the final array shapes and refresh the size globals used below.
print(input_data.shape)
print(resized_seismic_data.shape)

NR = input_data.shape[0]
nx = NR
resize_x, resize_z = n_channel * n_cmp, 256

In [None]:
def _as_cuda_float(array):
    """Copy an array to the current CUDA device as a float32 tensor."""
    return torch.tensor(array).float().cuda()


# Normalisation statistics on the GPU, used to de/re-normalise inside the loss.
vp0_mean_cuda, vp0_std_cuda = _as_cuda_float(vp0_mean), _as_cuda_float(vp0_std)
rtm_mean_cuda, rtm_std_cuda = _as_cuda_float(rtm_mean), _as_cuda_float(rtm_std)


In [None]:

batch_size = 400  # NOTE(review): reassigned to 300 in the DataLoader cell below
image_size = img_size
device = "cuda"
# args.device = "cpu"

lr = 3e-4
# NOTE(review): "SAVE_DIRS" contains no {} placeholders, so .format(...) ignores
# its arguments and run_name is always "SAVE_DIRS". Add placeholders if the run
# directory is meant to encode wavelet_freq / smooth_sigma.
run_name = "SAVE_DIRS".format(wavelet_freq,smooth_sigma) 


# Presumably creates the run's output directories (setup_logging comes from
# `from utils import *`; confirm).
setup_logging(run_name)
device = device  # no-op


# Conditional U-Net: 2 input channels (the two channels of input_data /
# x_t) and 2 output channels (the predicted flow velocity v_pred).
model = UNet_conditional(c_in=2, c_out=2,number_traces=resize_x).to(device)


optimizer = optim.AdamW(model.parameters(), lr=lr)
mse = nn.MSELoss()

# Mean absolute value (not a minimum, despite the names) of each normalised
# target channel; used as the SSIM `data_range` below.
rtm_min=np.mean(np.abs(input_data[:,0]))
vp_min=np.mean(np.abs(input_data[:,1]))
# Single-scale (scales=1) SSIM losses, one per target channel.
ms_ssim_loss_fn_rtm = MSSSIMLoss(data_range=rtm_min, scales=1).to(device) # data is z-scored, not in [0, 1]; data_range = mean |RTM|

ms_ssim_loss_fn_vp = MSSSIMLoss(data_range=vp_min, scales=1).to(device) # data is z-scored, not in [0, 1]; data_range = mean |velocity|


# Rectified-flow helper with 20 Euler steps for sampling.
rf = RFlow(step=20,img_size=image_size, device=device)


logger = SummaryWriter(os.path.join("runs", run_name))
l = nx  # NOTE(review): not referenced in the visible cells
ema = EMA(0.995)  # EMA helper from the model module; 0.995 presumably the decay rate — confirm
# Frozen copy of the model that holds the EMA weights.
ema_model = copy.deepcopy(model).eval().requires_grad_(False)

from torch.optim.lr_scheduler import StepLR
lr_adjust_epoch=500
# NOTE(review): step_size (500) exceeds `epochs` (300); if scheduler.step() is
# called once per epoch, the learning rate never decays.
scheduler = StepLR(optimizer, step_size=lr_adjust_epoch, gamma=0.1)


epochs = 300


# Whole dataset as float32 tensors: flow targets (N, 2, n_channel, img_size)
# and conditioning (N, 3, resize_x, resize_z).
train_rtm_data=torch.tensor(input_data).float()
train_seismic_data=torch.tensor(resized_seismic_data).float()
loss_list=[]

In [None]:

batch_size = 300
# Each batch yields (input_data sample, cmp, vp_arr velocity, rtm0): the three
# conditioning channels are split out of train_seismic_data along axis 1.
dataloader = data.DataLoader(
    data.TensorDataset(
        train_rtm_data,
        *(train_seismic_data[:, channel] for channel in range(3)),
    ),
    batch_size=batch_size,
    shuffle=True,
    pin_memory=True,  # page-locked host memory speeds up host-to-GPU copies
)
# Idea: the conditions could also be injected layer by layer, e.g. only the
# shallow part while t < 5, with the conditioned depth growing over time.

In [None]:
# resized_seismic_data=torch.tensor(resized_seismic_data).float()

# Training loop for the conditional rectified-flow UNet.
#
# The model output v_pred has two channels: channel 0 is compared with the RTM
# channel of the inputs, and channel 1 is de-normalized and re-migrated as a
# velocity model. Losses:
#   loss1_1  rf.mse_loss on channel 0 (RTM)
#   loss1_2  rf.mse_loss on channel 1 (velocity)
#   loss2_1  MSE(RTM migrated from predicted velocity, predicted RTM)
#            + MSE(RTM migrated from predicted velocity, input RTM)
#   loss2_2  the same two comparisons measured with the MS-SSIM loss
# The loss2_* consistency terms are only computed on samples with
# t >= consistency_t_min.

early_stop_patience = 100  # consecutive epochs without improvement tolerated (not used below yet)
early_stop_delta = 0.000000001  # minimum change counted as improvement (not used below yet)
best_loss = float('inf')
early_stop_counter = 0
torch.backends.cudnn.benchmark = True

accumulation_steps = 10  # NOTE(review): not used below
scaler = torch.cuda.amp.GradScaler()  # NOTE(review): not used below (no AMP in this loop)

consistency_t_min = 0.4
trainable_params = [p for p in model.parameters() if p.requires_grad]


def _mean_grad_norm(loss_value, params):
    """Return the mean per-parameter L2 gradient norm of ``loss_value``.

    Uses torch.autograd.grad so the probe gradients are neither accumulated into
    each other nor into ``p.grad`` (which still holds the previous step's
    gradients at this point). The graph is retained for the final backward.
    """
    grads = torch.autograd.grad(loss_value, params, retain_graph=True, allow_unused=True)
    norms = [g.norm() for g in grads if g is not None]
    if not norms:
        return loss_value.new_zeros(())
    return torch.stack(norms).mean()


for epoch in tqdm(range(epochs)):
    model.train()
    epoch_loss = 0.0
    num_batches = 0
    for images_raw, labels1, labels2, labels3 in dataloader:
        images = images_raw.to(device, non_blocking=True)
        labels1 = labels1.to(device, non_blocking=True)
        labels2 = labels2.to(device, non_blocking=True)
        labels3 = labels3.to(device, non_blocking=True)

        t = torch.rand(images.size(0)).to(device, non_blocking=True)
        x_t, x_0 = rf.create_flow(x_1=images, t=t)

        # Drop all conditioning for ~20% of batches (presumably for
        # classifier-free guidance at sampling time).
        if np.random.random() < 0.2:
            labels1 = labels2 = labels3 = None

        v_pred = model(x_t, t, labels1, labels2, labels3)

        loss1_1 = rf.mse_loss(v_pred[:, 0], images[:, 0], x_0[:, 0])
        loss1_2 = rf.mse_loss(v_pred[:, 1], images[:, 1], x_0[:, 1])

        # 1-D index tensor of the samples used for the consistency losses.
        # as_tuple=True keeps it 1-D even for a single hit (.squeeze() made it
        # 0-D and dropped the batch dim), and it can legitimately be empty, so
        # test numel(): torch.nonzero never returns None.
        indices = torch.nonzero(t >= consistency_t_min, as_tuple=True)[0]
        if indices.numel() > 0:
            vel_pred = v_pred[indices, 1] * (vp0_std_cuda + 1e-12) + vp0_mean_cuda
            rtm_from_pred_cuda = multi_migrated_data_cuda(
                vel_pred, factor=migrate_factor, fc=(DX / (3000 / wavelet_freq))
            )
            rtm_from_pred_cuda = (rtm_from_pred_cuda - rtm_mean_cuda) / (rtm_std_cuda + 1e-12)

            rtm_pred_sel = v_pred[indices, 0]
            rtm_true_sel = images[indices, 0]
            loss2_1 = mse(rtm_from_pred_cuda, rtm_pred_sel) + mse(rtm_from_pred_cuda, rtm_true_sel)

            img_shape = (-1, 1, n_channel, img_size)
            rtm_mig_img = rtm_from_pred_cuda.reshape(img_shape)
            loss2_2 = (
                ms_ssim_loss_fn_rtm(rtm_mig_img, rtm_pred_sel.reshape(img_shape))
                + ms_ssim_loss_fn_rtm(rtm_mig_img, rtm_true_sel.reshape(img_shape))
            )

            # Inverse mean-gradient-norm weights. The consistency terms reuse
            # grad_norm1_1 (their own gradients are not probed) and are scaled by
            # the fraction of the *current* batch that was selected; the last
            # batch may be smaller than batch_size.
            grad_norm1_1 = _mean_grad_norm(loss1_1, trainable_params)
            grad_norm1_2 = _mean_grad_norm(loss1_2, trainable_params)
            sel_frac = indices.numel() / images.size(0)

            raw_weights = (
                1.0 / (grad_norm1_1 + 1e-12),
                1.0 / (grad_norm1_2 + 1e-12),
                5.0 / (grad_norm1_1 + 1e-12) * sel_frac,
                5.0 / (grad_norm1_1 + 1e-12) * sel_frac,
            )
            total_weight = sum(raw_weights)
            weight1_1, weight1_2, weight2_1, weight2_2 = (w / total_weight for w in raw_weights)

            # Dividing each term by a detached ratio makes its value equal to
            # loss1_1, while its gradient keeps the term's own direction, scaled
            # by loss1_1 / term.
            loss = (
                weight1_1 * loss1_1
                + weight1_2 * loss1_2 / (loss1_2 / loss1_1).detach()
                + weight2_1 * loss2_1 / (loss2_1 / loss1_1).detach()
                + 0.01 * weight2_2 * loss2_2 / (loss2_2 / loss1_1).detach()
            )
        else:
            # No sample selected: self-normalized flow-matching losses only.
            loss = loss1_1 / loss1_1.detach() + loss1_2 / loss1_2.detach()

        epoch_loss += loss1_1.item() + loss1_2.item()
        num_batches += 1

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        ema.step_ema(ema_model, model)

    avg_epoch_loss = epoch_loss / max(num_batches, 1)
    loss_list.append(avg_epoch_loss)

    # Adjust the learning rate (once per epoch).
    scheduler.step()
    if epoch % 100 == 0:
        torch.save(model.state_dict(), os.path.join("models", run_name, "{}_ckpt.pt".format(epoch)))
        torch.save(model.state_dict(), os.path.join("models", run_name, "ckpt.pt"))




In [None]:
# Save the final weights of the online model and of its EMA copy.
torch.save(model.state_dict(), os.path.join("models", run_name, "ckpt.pt"))
torch.save(ema_model.state_dict(), os.path.join("models", run_name, "ckpt_ema.pt"))


In [None]:
# Per-epoch average training loss, written to models/<run_name>_loss_<last epoch>.txt.
np.savetxt(f"models/{run_name}_loss_{epoch}.txt", np.array(loss_list))