In [1]:
import numpy as np
import scipy.io
import matplotlib.pyplot as plt
import functools
import torch
import torch.nn as nn
import torch.nn.functional as F
import os
from torch.utils.data import Dataset, random_split
from torchvision import transforms
from torch.nn import init as init
import math

In [2]:
# Single compute device for data and model: the default CUDA device when one is
# visible to PyTorch, otherwise the CPU.
device2 = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [3]:
# Temporal and spatial layout of one simulation sample.
num_step = 20            # time steps kept per simulation (original note: "five years")
num_window = 4           # ground-truth steps fed to the model as input
num_predict_window = 16  # steps predicted by one forward pass
img_n = 16               # grid cells along each of the two rotated (in-plane) axes
img_h = 8                # grid cells along the third spatial axis (presumably depth — confirm)

# trainUNet rolls the model out in blocks of num_predict_window after the first
# num_window steps; a remainder would be silently dropped, so fail fast instead.
assert (num_step - num_window) % num_predict_window == 0, (
    "num_step - num_window must be a multiple of num_predict_window"
)
# UNet halves every spatial axis three times (three MaxPool3d(2) stages) and
# upsamples back by 2 each time, so each grid size must be divisible by 8.
assert img_n % 8 == 0 and img_h % 8 == 0, "img_n and img_h must be divisible by 8"

In [4]:
class CustomDataset_plus(Dataset):
    """Simulation samples read from ``.mat`` files, with 4-fold rotation augmentation.

    ``label_dir`` must contain the sub-directories ``S`` (saturation labels,
    variable ``Ssmatrix``), ``Well`` (``wc_globale_refine``), ``K`` (``K``),
    ``Pore`` (``p``) and ``Pressure`` (``pressure``). Files are paired across
    sub-directories by their position in the *sorted* directory listing, so
    every sub-directory must hold the same number of files under a consistent
    naming scheme.

    Sample ``index`` reads simulation ``index // 4`` and rotates it by
    ``(index % 4) * 90`` degrees in the plane of the two ``img_n``-sized axes.
    """

    # Number of rotated copies produced per simulation file (0/90/180/270 degrees).
    _NUM_ROTATIONS = 4

    def __init__(self, label_dir):
        self.label_dir = os.path.join(label_dir, "S")
        self.refine_well_dir = os.path.join(label_dir, "Well")
        self.K_dir = os.path.join(label_dir, "K")
        self.Pore_dir = os.path.join(label_dir, "Pore")
        self.Pressure_dir = os.path.join(label_dir, "Pressure")

        # Files are matched across folders by position, and os.listdir order is
        # arbitrary (it can differ between folders), so sort every listing.
        self.refine_file_list = sorted(os.listdir(self.refine_well_dir))
        self.K_file_list = sorted(os.listdir(self.K_dir))
        self.Pore_file_list = sorted(os.listdir(self.Pore_dir))
        self.Pressure_file_list = sorted(os.listdir(self.Pressure_dir))
        self.label_file_list = sorted(os.listdir(self.label_dir))

        counts = {
            "S": len(self.label_file_list),
            "Well": len(self.refine_file_list),
            "K": len(self.K_file_list),
            "Pore": len(self.Pore_file_list),
            "Pressure": len(self.Pressure_file_list),
        }
        if len(set(counts.values())) > 1:
            raise ValueError(
                f"sub-directories of {label_dir!r} hold different numbers of files: {counts}"
            )

    def __len__(self):
        return len(self.label_file_list) * self._NUM_ROTATIONS

    @staticmethod
    def _normalize_volume(volume):
        """L2-normalize ``volume`` over all of its elements, keeping its shape.

        Replaces ``F.normalize(volume, dim=0)``: with a leading dimension of
        size 1 that call divides every voxel by its own magnitude, collapsing
        the field to its sign (+1 / -1). NOTE(review): volume-wide L2 scaling
        was chosen as the closest working form of the original intent; confirm
        it against the normalization used when evaluating the model.
        """
        return F.normalize(volume.reshape(1, -1), dim=1).reshape(volume.shape)

    @staticmethod
    def _load_mat(directory, file_list, file_index, key):
        """Return variable ``key`` from the ``file_index``-th file of ``directory``."""
        return scipy.io.loadmat(os.path.join(directory, file_list[file_index]))[key]

    def __getitem__(self, index):
        """Return ``(label, condition, pressure)`` float tensors on ``device2``.

        - label: ``Ssmatrix`` transposed with axes (3, 4, 0, 1, 2), truncated to
          ``num_step`` along dim 1 and rotated in dims (2, 3).
        - condition: (3, img_n, img_n, img_h) stack of normalized log-K,
          normalized porosity and the well map, rotated in dims (1, 2).
        - pressure: ``pressure`` with its last axis moved to the front
          (presumably time), truncated to ``num_step``, scaled by 1e-7 / 2 and
          rotated in dims (1, 2).
        """
        file_index, rotation_index = divmod(index, self._NUM_ROTATIONS)

        K_array = np.log(self._load_mat(self.K_dir, self.K_file_list, file_index, 'K'))
        K_tensor = torch.from_numpy(K_array).reshape(1, img_n, img_n, img_h).to(device2).float()

        Pore_array = self._load_mat(self.Pore_dir, self.Pore_file_list, file_index, 'p')
        Pore_tensor = torch.from_numpy(Pore_array).reshape(1, img_n, img_n, img_h).to(device2).float()

        Pressure_array = self._load_mat(self.Pressure_dir, self.Pressure_file_list, file_index, 'pressure')
        Pressure_tensor = torch.from_numpy(Pressure_array.transpose(3, 0, 1, 2))[:num_step].to(device2).float() * 1e-7 / 2.
        Pressure_tensor = torch.rot90(Pressure_tensor, k=rotation_index, dims=(1, 2))

        label_array = self._load_mat(self.label_dir, self.label_file_list, file_index, 'Ssmatrix')

        # Well map, used here as the third static input channel.
        well_array = self._load_mat(self.refine_well_dir, self.refine_file_list, file_index, 'wc_globale_refine')
        input_refine_well = torch.from_numpy(well_array).to(device2).float().reshape(1, img_n, img_n, img_h)

        intermediate_condition = torch.cat(
            [self._normalize_volume(K_tensor), self._normalize_volume(Pore_tensor), input_refine_well],
            dim=0,
        )
        intermediate_condition = torch.rot90(intermediate_condition, k=rotation_index, dims=(1, 2))

        label_tensor = torch.from_numpy(label_array.transpose(3, 4, 0, 1, 2))[:, :num_step].to(device2)
        label_tensor = torch.rot90(label_tensor, k=rotation_index, dims=(2, 3))
        return label_tensor.float(), intermediate_condition.float(), Pressure_tensor.float()

label_dir = '../CO2Datasets/refine/'
dataset = CustomDataset_plus(label_dir)

# Train/validation split, cached on disk so every run reuses the same split.
# NOTE(review): these files pickle whole Subset/Dataset objects. PyTorch >= 2.6
# defaults torch.load to weights_only=True, which refuses such objects; pass
# weights_only=False there, and only load split files you created yourself.
# NOTE(review): splits produced by the earlier random_split recipe (over
# augmented indices) can put rotations of one simulation in both sets; delete
# the cached files to regenerate a per-simulation split.
train_split_path = './pt/train_dataset.pt'
valid_split_path = './pt/valid_dataset.pt'

if os.path.exists(train_split_path) and os.path.exists(valid_split_path):
    train_dataset = torch.load(train_split_path)
    valid_dataset = torch.load(valid_split_path)
else:
    # Split by simulation file rather than by augmented index so that the
    # rotated copies of one simulation never land in both train and validation.
    rotations_per_file = 4  # must match CustomDataset_plus's augmentation factor
    num_files = len(dataset) // rotations_per_file
    file_order = torch.randperm(num_files, generator=torch.Generator().manual_seed(0)).tolist()
    num_train_files = int(0.9 * num_files)

    def _augmented_indices(file_ids):
        """Expand simulation-file ids into all of their rotated sample indices."""
        return [f * rotations_per_file + r for f in file_ids for r in range(rotations_per_file)]

    train_dataset = torch.utils.data.Subset(dataset, _augmented_indices(file_order[:num_train_files]))
    valid_dataset = torch.utils.data.Subset(dataset, _augmented_indices(file_order[num_train_files:]))
    os.makedirs(os.path.dirname(train_split_path), exist_ok=True)
    torch.save(train_dataset, train_split_path)
    torch.save(valid_dataset, valid_split_path)

batch_size = 512
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
valid_loader = torch.utils.data.DataLoader(valid_dataset, batch_size=batch_size, shuffle=False)

In [7]:
def default_init_weights(module_list, scale=1, bias_fill=0, **kwargs):
    """Initialize network weights.

    Every ``nn.Conv3d`` / ``nn.Linear`` found by ``module.modules()`` gets
    Kaiming-normal weights multiplied by ``scale``. Every BatchNorm layer gets
    weight 1 (when it has one). All of these layers get their bias, when
    present, filled with ``bias_fill``. Other module types are left untouched.

    Args:
        module_list (list[nn.Module] | tuple[nn.Module] | nn.Module): Modules
            to be initialized.
        scale (float): Scale initialized weights, especially for residual
            blocks. Default: 1.
        bias_fill (float): The value to fill bias. Default: 0
        kwargs (dict): Other arguments for ``init.kaiming_normal_``.
    """
    # Common base class of BatchNorm1d/2d/3d. The original referenced it without
    # importing it, so any non-Conv3d/Linear module raised NameError.
    from torch.nn.modules.batchnorm import _BatchNorm

    if not isinstance(module_list, (list, tuple)):
        module_list = [module_list]
    for module in module_list:
        for m in module.modules():
            if isinstance(m, (nn.Conv3d, nn.Linear)):
                init.kaiming_normal_(m.weight, **kwargs)
                m.weight.data *= scale
                if m.bias is not None:
                    m.bias.data.fill_(bias_fill)
            elif isinstance(m, _BatchNorm):
                # affine=False BatchNorm layers have weight (and bias) set to None.
                if m.weight is not None:
                    init.constant_(m.weight, 1)
                if m.bias is not None:
                    m.bias.data.fill_(bias_fill)

In [8]:
class DoubleConv(nn.Module):
    def __init__(self,in_c,out_c,mid_c=None,residual=False):
        super().__init__()
        self.residual = residual
        
        if not mid_c:
            mid_c = out_c
            
        self.double_conv = nn.Sequential(
            nn.Conv3d(in_c,mid_c,kernel_size=3,padding=1),
            nn.GroupNorm(1,mid_c),
            nn.ReLU(inplace=True),
            nn.Conv3d(mid_c,out_c,kernel_size=3,padding=1),
            nn.GroupNorm(1,out_c)
        )
        self.relu = nn.ReLU(inplace=True)
    def forward(self,x):
        if self.residual:
            return self.relu(x + self.double_conv(x))
        else:
            return self.double_conv(x)

class Down(nn.Module):
    """Encoder stage: 2x max-pool, a residual DoubleConv, then a channel-changing DoubleConv.

    Args:
        in_channels: channels of the incoming feature map.
        out_channels: channels produced by the stage.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        pool = nn.MaxPool3d(2)
        keep_width = DoubleConv(in_channels, in_channels, residual=True)
        change_width = DoubleConv(in_channels, out_channels)
        self.maxpool_conv = nn.Sequential(pool, keep_width, change_width)

    def forward(self, x):
        return self.maxpool_conv(x)

class Up(nn.Module):
    """Decoder stage: 2x nearest upsampling, concatenation with a skip tensor, two DoubleConvs.

    Args:
        in_channels: skip channels plus upsampled channels, because the
            concatenated tensor feeds the first DoubleConv directly.
        out_channels: channels produced by the stage (the second DoubleConv
            uses ``in_channels // 2`` hidden channels).
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        stages = [
            DoubleConv(in_channels, in_channels, residual=True),
            DoubleConv(in_channels, out_channels, in_channels // 2),
        ]
        self.conv = nn.Sequential(*stages)

    def forward(self, x, skip_x):
        upsampled = F.interpolate(x, scale_factor=2, mode='nearest')
        # Skip features come first along the channel axis.
        merged = torch.cat((skip_x, upsampled), dim=1)
        return self.conv(merged)

In [9]:
class UNet(nn.Module):
    """3-level 3-D UNet with a shared encoder and separate saturation/pressure heads.

    Args:
        c_in: input channels; the default 8 + 4 + 3 matches the training input
            of 2*num_window saturation, num_window pressure and 3 static
            condition channels when num_window == 4.
        c_out_s: saturation output channels (default 2*num_predict_window).
        c_out_p: pressure output channels (default num_predict_window).

    ``forward`` returns the saturation channels followed by the pressure
    channels, concatenated along dim 1. Each spatial size must be divisible
    by 8 (three 2x poolings).
    """

    def __init__(self, c_in=8+4+3, c_out_s=2*num_predict_window, c_out_p=num_predict_window):
        super().__init__()

        # Encoder: 8 -> 16 -> 32 -> 64 channels, halving the grid in each Down.
        self.inc = DoubleConv(c_in, 8)
        self.down1 = Down(8, 16)
        self.down2 = Down(16, 32)
        self.down3 = Down(32, 64)

        # Bottleneck: widen to 128 channels, then come back to 64.
        self.bot1 = DoubleConv(64, 128)
        self.bot2 = DoubleConv(128, 128)
        self.bot3 = DoubleConv(128, 64)

        # Shared decoder stages; in_channels = upsampled + skip channels.
        self.up1 = Up(64 + 32, 32)
        self.up2 = Up(32 + 16, 16)

        # Task-specific last decoder stages and their 1x1x1 projection heads.
        self.up3_s = Up(16 + 8, 16)
        self.up3_p = Up(16 + 8, 8)
        self.out_s = nn.Conv3d(16, c_out_s, kernel_size=1, stride=1, padding=0)
        self.out_p = nn.Conv3d(8, c_out_p, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        skip1 = self.inc(x)
        skip2 = self.down1(skip1)
        skip3 = self.down2(skip2)
        deep = self.down3(skip3)

        for bottleneck in (self.bot1, self.bot2, self.bot3):
            deep = bottleneck(deep)

        shared = self.up2(self.up1(deep, skip3), skip2)
        saturation = self.out_s(self.up3_s(shared, skip1))
        pressure = self.out_p(self.up3_p(shared, skip1))
        return torch.cat([saturation, pressure], dim=1)

In [12]:
model = UNet().to(device2)
from torch.nn import DataParallel
# device_ids=None replicates over all visible GPUs and leaves the wrapper as a
# pass-through on CPU-only machines. The previous device_ids=[] failed whenever
# CUDA was available, because DataParallel reads device_ids[0] for the output
# device. The wrapper is kept so saved state_dict keys keep their "module." prefix.
model = nn.DataParallel(model)

In [15]:
# Optimisation setup: Adam with a fixed learning rate, and an L1 (mean absolute
# error) objective that trainUNet applies to both saturation and pressure outputs.
lr = 2e-4
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
loss_fn = nn.L1Loss()

In [16]:
def _window_targets(targets, pressure, begin):
    """Slice the ground-truth prediction window that starts ``num_window`` steps after ``begin``.

    Returns ``(tar_saturation, tar_pressure, tar)``. Saturation's two channels
    are folded into the time axis (``2 * num_predict_window`` channels),
    pressure has ``num_predict_window`` channels, and ``tar`` concatenates them
    in the model's output channel order.
    """
    start = num_window + begin
    stop = start + num_predict_window
    tar_saturation = targets[:, :, start:stop].reshape(-1, 2 * num_predict_window, img_n, img_n, img_h)
    tar_pressure = pressure[:, start:stop].reshape(-1, num_predict_window, img_n, img_n, img_h)
    return tar_saturation, tar_pressure, torch.cat([tar_saturation, tar_pressure], dim=1)


def _window_inputs(saturation_seq, pressure_seq, conditions, begin):
    """Build the model input from the ``num_window`` steps starting at ``begin``.

    Channel order: ``2 * num_window`` saturation channels, ``num_window``
    pressure channels, then the static condition channels.
    """
    stop = begin + num_window
    saturation = saturation_seq[:, :, begin:stop].reshape(-1, 2 * num_window, img_n, img_n, img_h)
    pressure = pressure_seq[:, begin:stop].reshape(-1, num_window, img_n, img_n, img_h)
    return torch.cat([saturation, pressure, conditions], dim=1)


def trainUNet(model, train_loader, val_loader, optimizer, num_epochs=1000,
              checkpoint_fmt='./pt/UNet_epoch_%d.pth', save_every=5):
    """Train ``model`` with the global ``loss_fn`` and report per-epoch losses.

    Training uses teacher forcing: every input window comes from the ground
    truth. Validation rolls the model out autoregressively. Buffers are seeded
    with the first ``num_window`` true steps; every later input window is read
    from the model's own earlier predictions.

    Args:
        model: network whose output is ``2*num_predict_window`` saturation
            channels followed by ``num_predict_window`` pressure channels.
        train_loader, val_loader: yield ``(saturation, conditions, pressure)``
            batches in the layout produced by ``CustomDataset_plus``.
        optimizer: optimizer over ``model``'s parameters.
        num_epochs: number of epochs to run.
        checkpoint_fmt: ``%``-format path with one integer slot (the 1-based
            epoch) for ``model.state_dict()`` checkpoints; ``None`` disables saving.
        save_every: save a checkpoint every this many epochs.

    Returns:
        dict mapping ``"train"``, ``"train_saturation"``, ``"train_pressure"``,
        ``"val"``, ``"val_saturation"`` and ``"val_pressure"`` to per-epoch
        lists of the mean loss per (batch, window) step.

    Raises:
        ValueError: if no prediction window fits in ``num_step``.
    """
    # Number of prediction windows after the first num_window known steps.
    num_iters = (num_step - num_window) // num_predict_window
    if num_iters < 1:
        raise ValueError("num_step - num_window must be at least num_predict_window")

    history = {key: [] for key in ("train", "train_saturation", "train_pressure",
                                   "val", "val_saturation", "val_pressure")}

    for epoch in range(num_epochs):
        model.train()
        train_loss = train_saturation_loss = train_pressure_loss = 0.0
        for batch in train_loader:
            targets, conditions, pressure = batch[0].to(device2), batch[1].to(device2), batch[2].to(device2)
            for s in range(num_iters):
                begin = s * num_predict_window
                tar_saturation, tar_pressure, tar = _window_targets(targets, pressure, begin)
                # Teacher forcing: the input window is taken from the ground truth.
                inputs = _window_inputs(targets, pressure, conditions, begin)

                optimizer.zero_grad()
                outputs = model(inputs)
                step_total_loss = loss_fn(tar, outputs)
                step_total_loss.backward()
                optimizer.step()

                output_saturation = outputs[:, :2 * num_predict_window].detach()
                output_pressure = outputs[:, 2 * num_predict_window:].detach()
                train_loss += step_total_loss.item()
                train_saturation_loss += loss_fn(tar_saturation, output_saturation).item()
                train_pressure_loss += loss_fn(tar_pressure, output_pressure).item()

        model.eval()
        val_loss = val_saturation_loss = val_pressure_loss = 0.0
        with torch.no_grad():
            for batch in val_loader:
                targets, conditions, pressure = batch[0].to(device2), batch[1].to(device2), batch[2].to(device2)
                # Autoregressive buffers: seeded with the first num_window true
                # steps, later windows are filled with the model's predictions.
                sequence_saturation = torch.zeros_like(targets)
                sequence_pressure = torch.zeros_like(pressure)
                sequence_saturation[:, :, :num_window] = targets[:, :, :num_window]
                sequence_pressure[:, :num_window] = pressure[:, :num_window]
                for s in range(num_iters):
                    begin = s * num_predict_window
                    tar_saturation, tar_pressure, tar = _window_targets(targets, pressure, begin)
                    outputs = model(_window_inputs(sequence_saturation, sequence_pressure, conditions, begin))
                    output_saturation = outputs[:, :2 * num_predict_window]
                    output_pressure = outputs[:, 2 * num_predict_window:]

                    val_loss += loss_fn(tar, outputs).item()
                    val_saturation_loss += loss_fn(tar_saturation, output_saturation).item()
                    val_pressure_loss += loss_fn(tar_pressure, output_pressure).item()

                    start = num_window + begin
                    stop = start + num_predict_window
                    sequence_saturation[:, :, start:stop] = output_saturation.reshape(
                        -1, 2, num_predict_window, img_n, img_n, img_h)
                    sequence_pressure[:, start:stop] = output_pressure

        # Mean per (batch, window) step; max(..., 1) guards against empty loaders.
        train_steps = max(len(train_loader) * num_iters, 1)
        val_steps = max(len(val_loader) * num_iters, 1)
        train_loss /= train_steps
        train_pressure_loss /= train_steps
        train_saturation_loss /= train_steps
        val_loss /= val_steps
        val_pressure_loss /= val_steps
        val_saturation_loss /= val_steps

        history["train"].append(train_loss)
        history["train_pressure"].append(train_pressure_loss)
        history["train_saturation"].append(train_saturation_loss)
        history["val"].append(val_loss)
        history["val_pressure"].append(val_pressure_loss)
        history["val_saturation"].append(val_saturation_loss)

        print(f"Epoch [{epoch+1}/{num_epochs}], Train Loss: {train_loss:.4f}, S: {train_saturation_loss:.4f}, P: {train_pressure_loss:.4f},"
              f" Val Loss: {val_loss:.4f}, S: {val_saturation_loss:.4f}, P: {val_pressure_loss:.4f} ")

        # The original formatted the path as " " % (epoch + 1), which raises
        # TypeError at the first save; the path is now a parameter.
        if checkpoint_fmt is not None and save_every > 0 and (epoch + 1) % save_every == 0:
            checkpoint_path = checkpoint_fmt % (epoch + 1)
            os.makedirs(os.path.dirname(checkpoint_path) or '.', exist_ok=True)
            torch.save(model.state_dict(), checkpoint_path)

    return history

In [None]:
# Full training run; the trailing semicolon keeps the cell from echoing the return value.
trainUNet(
    model=model,
    train_loader=train_loader,
    val_loader=valid_loader,
    optimizer=optimizer,
    num_epochs=1000,
);