### Import python modules 

In [None]:
import os
import math
import random
import numpy as np
import torch

import utils
import models
import draw
import tools

# Seed every random number generator the pipeline touches so runs are repeatable.
random_state = 12314
torch.manual_seed(random_state)   # PyTorch
np.random.seed(random_state)      # NumPy global RNG
random.seed(random_state)         # Python `random` (also used by the transforms)

# Expose only GPUs 0-3 to this process.
# NOTE: CUDA presumably reads this variable when it initialises, so it is set
# ahead of the device queries below.
os.environ['CUDA_VISIBLE_DEVICES'] = '0,1,2,3'

_cuda_available = torch.cuda.is_available()
device = torch.device("cuda" if _cuda_available else "cpu")

if not _cuda_available:
    # CPU-only run: count it as a single device.
    num_GPU = 1
else:
    num_GPU = torch.cuda.device_count()
    print(f"GPU 数量: {num_GPU}")

print(f"运行平台: {device}")
root_path = os.path.abspath(os.curdir)

### Load synthetic geologic models

In [None]:
# Load the train/validation sample lists of the synthetic dataset.
dataset_name = "0108-128x256x256"
dataset_path = os.path.join(root_path, "datasets", dataset_name)

train_index_file = os.path.join(dataset_path, 'samples_train.npy')
valid_index_file = os.path.join(dataset_path, 'samples_valid.npy')
# NOTE(review): allow_pickle=True unpickles arbitrary objects; only point this
# at trusted dataset files.
samples_train = np.load(train_index_file, allow_pickle=True)
samples_valid = np.load(valid_index_file, allow_pickle=True)

print(f"训练样本数量: {len(samples_train)}")
print(f"验证样本数量: {len(samples_valid)}")

# Directory holding the per-sample data consumed by utils.build_dataset.
train_sample_path = os.path.join(dataset_path, "data")

### Parameters of data simulation 

In [None]:
# Data-simulation settings. They are forwarded unchanged to
# utils.build_dataset; their exact semantics are defined there.
inshape = (128, 128, 128)          # size of each generated cube
num_hrzs_list = [2]
bit = 256
mask_grp_sel = [4, None]
bit_rate = 2
sample_rate_list = [50, 100]
fault_range = 1
bit_mute = 85
norm = utils.min_max_norm          # normalisation function applied to inputs
use_normal = False                 # whether surface normals are fed as inputs

# Network inputs/outputs; "normal" is only included when use_normal is set.
input_attr_list = ["scalar", "normal", "fault"] if use_normal else ["scalar", "fault"]
output_attr_list = ["rgt"]

print(f"输入属性:{input_attr_list}")
print(f"目标属性:{output_attr_list}")

### Automatic structural data generator

In [None]:
# Preview dataset built from the training samples, with point sets enabled so
# individual samples can be inspected and drawn below. It is built in 'Valid'
# mode (presumably the non-augmented path — confirm in utils.build_dataset).
# train_data is rebuilt in 'Train' mode before training.
_preview_kwargs = dict(
    num_hrzs_list=num_hrzs_list,
    bit=bit,
    mask_grp_sel=mask_grp_sel,
    bit_rate=bit_rate,
    sample_rate_list=sample_rate_list,
    fault_range=fault_range,
    bit_mute=bit_mute,
    norm=norm,
    point_set=True,
    use_normal=use_normal,
)
train_data = utils.build_dataset(inshape, samples_train, train_sample_path, 'Valid',
                                 **_preview_kwargs)

In [None]:
# Materialise samples 30..49 of the preview dataset and pick one to inspect.
# NOTE(review): all 20 samples are generated although only one is used; kept
# as-is in case generation order affects any random simulation.
batch_samples = [train_data[idx] for idx in range(30, 50)]
k = 15
_sample = batch_samples[k]

gt = _sample['rgt'][0]                  # target attribute "rgt", leading index 0
fl = _sample['fault'][0]                # fault input, leading index 0
ps = _sample['point_set_scalar']        # same entry as `nps` below; both names kept
fm = _sample['mask'][0]
nps = _sample['point_set_scalar']
fps = _sample['point_set_fault']
cvs = draw.get_horizon_scalar(nps, gt)

In [None]:
# Draw the fault volume on three orthogonal slices with both point sets overlaid;
# smin/smax presumably fix the colour limits of the 'jet' point map to this
# sample's RGT range.
_rgt_lo = np.min(gt)
_rgt_hi = np.max(gt)
draw.draw_slice_line_surf(
    fl,
    x_slices=[30],
    y_slices=[30],
    z_slices=[120],
    points=nps,
    points2=fps,
    smap='jet',
    smin=_rgt_lo,
    smax=_rgt_hi,
    cmap='fault',
)

In [None]:
# Build the training and validation datasets with identical simulation settings;
# only the sample list and the mode ('Train' vs 'Valid') differ.
_sim_kwargs = dict(
    num_hrzs_list=num_hrzs_list,
    bit=bit,
    mask_grp_sel=mask_grp_sel,
    bit_rate=bit_rate,
    sample_rate_list=sample_rate_list,
    fault_range=fault_range,
    bit_mute=bit_mute,
    norm=norm,
    use_normal=use_normal,
)
train_data = utils.build_dataset(inshape, samples_train, train_sample_path, 'Train', **_sim_kwargs)
valid_data = utils.build_dataset(inshape, samples_valid, train_sample_path, 'Valid', **_sim_kwargs)

### Training CNN

In [None]:
# Define the network.
param_model = {}
param_model['network'] = "ISMNet"
# Two base input channels; presumably one per entry of input_attr_list
# ("scalar", "fault"), with "normal" adding three more — confirm in
# utils.build_dataset.
param_model['input_channels'] = 1 + 1
if use_normal:
    param_model['input_channels'] += 3
param_model['output_channels'] = 1
param_model['inshape'] = inshape

# Resolve the architecture class by name from the `models` module.
model = getattr(models, param_model['network'])(param_model)

# Weighted composite loss; the weights are also encoded in the session name.
loss_type = {"mae": 0.24, "ms-ssim": 0.84}
loss_name = '+'.join(f"{weight:.2f}*{name}" for name, weight in loss_type.items())

session_name = '-'.join([param_model['network'], "dataset_" + dataset_name, loss_name])
if use_normal:
    session_name = '-'.join([session_name, "orientation"])

# Parallel mode: replicate the model across all visible GPUs.
# NOTE: a DataParallel-wrapped model's state_dict keys carry a "module." prefix,
# so checkpoints saved from it may not load into an unwrapped model directly.
if num_GPU > 1:
    print("多核模式")
    model = torch.nn.DataParallel(model, device_ids=list(range(num_GPU))).to(device)
else:
    print("单核模式")
    model = model.to(device)

# Directory where training checkpoints are saved (created if missing).
checkpoint_path = os.path.join('checkpoints', session_name)
os.makedirs(checkpoint_path, exist_ok=True)
print(f"模型保存路径: {checkpoint_path}")

In [None]:
# Training hyper-parameters consumed by utils.train_valid_net.
param = {
    'model': param_model,
    'epochs': 161,                        # number of training epochs
    'batch_size': 2 * num_GPU,            # two samples per device
    'lr': 1e-3,                           # initial learning rate
    'optimizer_type': 'Adam',             # optimizer type
    'weight_decay': 1e-4,                 # weight decay
    'decay_type': 'ReduceLROnPlateau',    # learning-rate decay strategy
    'gamma': 0.5,                         # learning-rate decay factor
    'lr_decay': 2,                        # decay period (presumably the plateau patience — confirm in utils)
    'loss_type': loss_type,
    'checkpoint_path': checkpoint_path,
    'disp_inter': 2,                      # display interval
    'save_inter': 10,                     # checkpoint save interval
}

In [None]:
# Train the network via utils.train_valid_net; the model it returns replaces
# the one built above.
model = utils.train_valid_net(
    param,
    model,
    train_data,
    valid_data,
    plot=True,
    device=device,
    use_normal=use_normal,
)

### Inference

In [None]:
# Load the best checkpoint of a previously trained session for inference.
# NOTE(review): this directory name has no loss-weight suffix, unlike the
# session names generated for training — confirm it points at the intended run.
load_checkpoint_path = os.path.join("checkpoints", "ISMNet-dataset_0108-128x256x256")
# map_location lets a checkpoint saved on GPU be restored on a CPU-only host,
# matching the CPU fallback used when selecting `device`.
model.load_state_dict(
    torch.load(os.path.join(load_checkpoint_path, 'checkpoint-best.pth'),
               map_location=device)['state_dict']
)
# Report the directory actually loaded from (not the training save directory).
print(f"模型读取路径: {load_checkpoint_path}")

In [None]:
# Data-simulation settings for inference. They rebind the training values;
# mask_grp_sel, sample_rate_list and bit_mute differ from training.
inshape = (128, 128, 128)
num_hrzs_list = [2]
bit = 256
mask_grp_sel = [2, None]
bit_rate = 2
sample_rate_list = [100]
fault_range = 1
bit_mute = 80
norm = utils.min_max_norm

In [None]:
# Test dataset: the single validation sample at index 3, with point sets
# enabled so they can be overlaid on the prediction.
_test_samples = samples_valid[3:4]
_test_kwargs = dict(
    num_hrzs_list=num_hrzs_list,
    bit=bit,
    mask_grp_sel=mask_grp_sel,
    bit_rate=bit_rate,
    sample_rate_list=sample_rate_list,
    fault_range=fault_range,
    bit_mute=bit_mute,
    norm=norm,
    use_normal=use_normal,
    point_set=True,
)
test_data = utils.build_dataset(inshape, _test_samples, train_sample_path, 'Valid', **_test_kwargs)

In [None]:
# Run the trained model over every test sample.
output_pred_samples = utils.pred(
    model,
    test_data,
    use_normal=use_normal,
    device=device,
)

In [None]:
# Unpack one prediction result for visualisation.
k = 0
_result = output_pred_samples[k]
gt = _result['rgt'][0]                  # ground-truth RGT
nps = _result['point_set_scalar']
fps = _result['point_set_fault']
mk = _result['mask'][0]
fl = _result['fault'][0]
pd = _result['pred'][0]                 # network prediction

In [None]:
# Draw the predicted volume on three orthogonal slices with the input point sets
# overlaid. isovol/isofs presumably request isosurfaces of `pd` at the values
# returned by draw.get_horizon_scalar; mute_edge's meaning is defined in `draw`.
# The function lives in the `draw` module (as used for the preview plot); the
# bare name is never imported and raised NameError.
draw.draw_slice_line_surf(pd,
                          x_slices=[20], y_slices=[20], z_slices=[120],
                          points=nps,
                          points2=fps,
                          cmap='model',
                          isovol=pd,
                          isofs=draw.get_horizon_scalar(nps, pd),
                          mute_edge=3,
                          )