### Import Python modules

In [None]:
import os
import math
import random
import numpy as np
import torch

# Absolute, normalized path of the current working directory; the generated
# dataset folder is created beneath it.
root_path = os.path.abspath(os.curdir)

### Raw data processing

In [None]:
# Dimensions of every raw cube, in the order (samples, inlines, crosslines).
num_sample, num_inline, num_crossline = 128, 256, 256
data_name = "0108-128x256x256"     # raw-data folder name under ../DATA
dataset_name = "0108-128x256x256"  # generated-dataset folder name under ./datasets

# Set up paths.
raw_data_path = os.path.join(os.path.abspath('..'), "DATA", data_name)
# The output folder is named by dataset_name (previously data_name was used
# here and dataset_name was dead), so the two names can differ if needed.
dataset_path = os.path.join(root_path, "datasets", dataset_name)

seis_path = os.path.join(raw_data_path, "seis")
rgt_path = os.path.join(raw_data_path, "rgt")
fault_path = os.path.join(raw_data_path, "fault")

# Folder that receives the per-sample .npy files. makedirs creates the missing
# parent (dataset_path) too, and exist_ok avoids the check-then-create race.
os.makedirs(dataset_path, exist_ok=True)
train_dataset_path = os.path.join(dataset_path, "data")
os.makedirs(train_dataset_path, exist_ok=True)

### Generating training and validation datasets

In [None]:
# Raw files in seis_path are expected to be named "<integer id>.<ext>".
# Sort them numerically (so "10" follows "9"), and skip entries whose stem is
# not a plain integer (e.g. hidden files such as ".DS_Store"), which would
# otherwise make int() raise ValueError.
data_file_list = sorted(
    (name for name in os.listdir(seis_path) if name.split('.')[0].isdecimal()),
    key=lambda name: int(name.split('.')[0]),
)
num_data_file = len(data_file_list)
# Shape passed to utils.read_data for every seis/rgt/fault cube.
size = (num_sample, num_inline, num_crossline)

In [None]:
# Read each seis cube together with its rgt and fault cubes, bundle them into
# a dict, and save one .npy file per sample into train_dataset_path.
# NOTE(review): `utils` is not imported in any cell shown here; a module that
# provides read_data(size, path) must be imported in the first cell, otherwise
# the first read below raises NameError.
sample_count = 0
train_sample_list = []
for index, data_file_name in enumerate(data_file_list):
    print(f"Processing raw data file {index+1}/{num_data_file} ...")

    # Pair the rgt/fault cubes with this seis file by its numeric id, not by
    # the loop index: with gaps in the numbering (or ids not starting at 0)
    # the index-based lookup silently paired cubes from different files.
    file_id = int(data_file_name.split('.')[0])

    seis_file_path = os.path.join(seis_path, data_file_name)
    seis_cube = utils.read_data(size, seis_file_path)

    rgt_file_path = os.path.join(rgt_path, f"{file_id}.dat")
    rgt_cube = utils.read_data(size, rgt_file_path)

    fault_file_path = os.path.join(fault_path, f"{file_id}.dat")
    fault_cube = utils.read_data(size, fault_file_path)

    sample_file_name = f'sample_{sample_count}.npy'
    sample = {
        "seis": seis_cube,
        "rgt": rgt_cube,
        "fault": fault_cube,
        "fname": sample_file_name,
    }

    # np.save stores the dict as a pickled 0-d object array; read it back with
    # np.load(path, allow_pickle=True).item().
    np.save(os.path.join(train_dataset_path, sample_file_name), sample)
    train_sample_list.append(sample_file_name)
    sample_count += 1

In [None]:
#训练集比例
train_set_ratio = 0.9

num_train_sample = len(train_sample_list)
print(f"样本库总数: {num_train_sample}")

# 混乱数据集
random.shuffle(train_sample_list)

# 训练集/验证集划分
valid_num = int(num_train_sample * (1-train_set_ratio))
valid_sample_list = random.sample(train_sample_list, valid_num)

samples_train,samples_valid = [],[]

for i_sample in train_sample_list[:num_train_sample]:
    if i_sample not in valid_sample_list:
        samples_train.append(i_sample)
    else:
        samples_valid.append(i_sample)

print(f'训练样本数量：{len(samples_train)}')
print(f'验证样本数量：{len(samples_valid)}')
np.save(os.path.join(dataset_path, 'samples_train.npy'), samples_train)
np.save(os.path.join(dataset_path, 'samples_valid.npy'), samples_valid)