In [1]:
import functools
import math
import os

import matplotlib.pyplot as plt
import numpy as np
import scipy.io
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import DataParallel
from torch.nn import init as init
from torch.nn.modules.batchnorm import _BatchNorm
from torch.utils.data import Dataset, random_split
from torchvision import transforms

In [2]:
# Single compute device for the whole notebook: the default CUDA device when available, else CPU.
device2 = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [3]:
num_step = 20
num_window = 4
num_predict_window = 4; ##step can 1、2、4
img_n = 16
img_h = 8

In [4]:
class CustomDataset_plus(Dataset):
    """Samples assembled from per-quantity ``.mat`` folders, with rotation augmentation.

    ``label_dir`` must contain the sub-folders ``S`` (key ``Ssmatrix``), ``Well``
    (key ``wc_globale_refine``), ``K`` (key ``K``), ``Pore`` (key ``p``) and
    ``Pressure`` (key ``pressure``). Files are paired across folders by their
    position in the *sorted* directory listing, so every folder must hold the
    same number of files with consistently sortable names.

    Each file yields four samples: the volume rotated by k * 90 degrees (k = 0..3)
    in the first two spatial axes.

    Raises:
        ValueError: if the sub-folders contain different numbers of files.
    """

    def __init__(self, label_dir):
        self.label_dir = os.path.join(label_dir, "S")
        self.refine_well_dir = os.path.join(label_dir, "Well")
        self.K_dir = os.path.join(label_dir, "K")
        self.Pore_dir = os.path.join(label_dir, "Pore")
        self.Pressure_dir = os.path.join(label_dir, "Pressure")

        # Files are paired across folders by list position. os.listdir() returns
        # entries in arbitrary order, so sort every listing to keep sample i aligned.
        self.refine_file_list = sorted(os.listdir(self.refine_well_dir))
        self.K_file_list = sorted(os.listdir(self.K_dir))
        self.Pore_file_list = sorted(os.listdir(self.Pore_dir))
        self.Pressure_file_list = sorted(os.listdir(self.Pressure_dir))
        self.label_file_list = sorted(os.listdir(self.label_dir))

        counts = {
            "S": len(self.label_file_list),
            "Well": len(self.refine_file_list),
            "K": len(self.K_file_list),
            "Pore": len(self.Pore_file_list),
            "Pressure": len(self.Pressure_file_list),
        }
        if len(set(counts.values())) != 1:
            raise ValueError(
                f"sub-folders of {label_dir!r} contain different numbers of files: {counts}"
            )

    def __len__(self):
        # Four rotations per file.
        return len(self.label_file_list) * 4

    def __getitem__(self, index):
        """Return ``(label, condition, pressure, K_raw, pore)`` for one rotated sample.

        - label: ``label_array.transpose(3, 4, 0, 1, 2)[:, :num_step]``, rotated in dims (2, 3).
        - condition: 3 channels (normalised log K, normalised pore, well), rotated.
        - pressure: first ``num_step`` entries of the moved last axis, scaled by 1e-7 / 2, rotated.
        - K_raw / pore: un-normalised K (not log) and pore volumes, rotated.
        """
        file_index, rotation_index = divmod(index, 4)
        volume_shape = (1, img_n, img_n, img_h)

        def rotate(tensor, dims=(1, 2)):
            return torch.rot90(tensor, k=rotation_index, dims=dims)

        def load(directory, file_list, key):
            return scipy.io.loadmat(os.path.join(directory, file_list[file_index]))[key]

        # K is read once and used both log-transformed (model input) and raw (residual loss).
        K_raw = load(self.K_dir, self.K_file_list, 'K')
        K_tensor = torch.from_numpy(np.log(K_raw)).reshape(volume_shape).to(device2).float()
        K_RError = rotate(torch.from_numpy(K_raw).reshape(volume_shape).to(device2).float())

        Pore_array = load(self.Pore_dir, self.Pore_file_list, 'p')
        Pore_tensor = torch.from_numpy(Pore_array).reshape(volume_shape).to(device2).float()
        Pore_RError = rotate(Pore_tensor)

        # Move the last axis (sliced by num_step, presumably time) to the front and rescale.
        Pressure_array = load(self.Pressure_dir, self.Pressure_file_list, 'pressure')
        Pressure_tensor = torch.from_numpy(Pressure_array.transpose(3, 0, 1, 2))[:num_step].to(device2).float() * 1e-7 / 2.
        Pressure_tensor = rotate(Pressure_tensor)

        label_array = load(self.label_dir, self.label_file_list, 'Ssmatrix')
        label_tensor = torch.from_numpy(label_array.transpose(3, 4, 0, 1, 2))[:, :num_step].to(device2)
        label_tensor = rotate(label_tensor, dims=(2, 3))

        well_array = load(self.refine_well_dir, self.refine_file_list, "wc_globale_refine")
        input_refine_well = torch.from_numpy(well_array).to(device2).float().reshape(volume_shape)

        # NOTE(review): dim=0 has size 1 here, so F.normalize divides every voxel by its
        # own magnitude and the K / pore channels collapse to +-1 (except near-zero
        # entries). Kept unchanged for compatibility with trained weights; confirm
        # whether a per-volume normalisation was intended.
        intermediate_condition = torch.cat(
            [F.normalize(K_tensor, dim=0), F.normalize(Pore_tensor, dim=0), input_refine_well],
            dim=0,
        )
        intermediate_condition = rotate(intermediate_condition)

        return label_tensor.float(), intermediate_condition.float(), Pressure_tensor.float(), K_RError, Pore_RError

# Build the dataset and split it 90/10 into train / validation sets.
label_dir = ''  # your dataset
dataset = CustomDataset_plus(label_dir)

# CustomDataset_plus yields 4 rotated copies of each file (index = file * 4 + rotation).
# Split by *file* so rotated copies of a validation volume never land in training,
# and seed the permutation so the split is reproducible.
rotations_per_file = 4
num_files = len(dataset) // rotations_per_file
num_train_files = int(0.9 * num_files)
file_order = torch.randperm(num_files, generator=torch.Generator().manual_seed(42)).tolist()


def _rotated_indices(file_ids):
    """All dataset indices (every rotation) belonging to the given file ids."""
    return [f * rotations_per_file + r for f in file_ids for r in range(rotations_per_file)]


train_dataset = torch.utils.data.Subset(dataset, _rotated_indices(file_order[:num_train_files]))
valid_dataset = torch.utils.data.Subset(dataset, _rotated_indices(file_order[num_train_files:]))
train_size, valid_size = len(train_dataset), len(valid_dataset)

os.makedirs('./pt', exist_ok=True)  # torch.save does not create missing folders
torch.save(train_dataset,'./pt/train_dataset.pt')
torch.save(valid_dataset,'./pt/valid_dataset.pt')

batch_size = 512
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
valid_loader = torch.utils.data.DataLoader(valid_dataset, batch_size=batch_size, shuffle=False)

In [7]:
def default_init_weights(module_list, scale=1, bias_fill=0, **kwargs):
    """Initialize network weights in place.

    Every ``nn.Conv3d`` / ``nn.Linear`` found recursively gets Kaiming-normal
    weights multiplied by ``scale``. Every batch-norm layer gets weight 1. The
    biases of all of these are filled with ``bias_fill``. Other modules are left
    untouched.

    Args:
        module_list (list[nn.Module] | nn.Module): Modules to be initialized.
        scale (float): Scale initialized weights, especially for residual
            blocks. Default: 1.
        bias_fill (float): The value to fill bias. Default: 0
        kwargs (dict): Other arguments for ``init.kaiming_normal_``.
    """
    if not isinstance(module_list, list):
        module_list = [module_list]
    for module in module_list:
        for m in module.modules():
            if isinstance(m, (nn.Conv3d, nn.Linear)):
                init.kaiming_normal_(m.weight, **kwargs)
                m.weight.data *= scale
            elif isinstance(m, _BatchNorm):
                # affine=False batch-norm layers have no weight/bias to set.
                if m.weight is not None:
                    init.constant_(m.weight, 1)
            else:
                continue
            if m.bias is not None:
                m.bias.data.fill_(bias_fill)

In [8]:
class DoubleConv(nn.Module):
    """Two 3x3x3 conv + GroupNorm(1) layers with a ReLU in between.

    When ``residual`` is set, the input is added to the result and passed through
    a final ReLU (requires ``in_c == out_c``). ``mid_c`` defaults to ``out_c``.
    """

    def __init__(self, in_c, out_c, mid_c=None, residual=False):
        super().__init__()
        self.residual = residual
        mid_c = mid_c or out_c

        layers = [
            nn.Conv3d(in_c, mid_c, kernel_size=3, padding=1),
            nn.GroupNorm(1, mid_c),
            nn.ReLU(inplace=True),
            nn.Conv3d(mid_c, out_c, kernel_size=3, padding=1),
            nn.GroupNorm(1, out_c),
        ]
        self.double_conv = nn.Sequential(*layers)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        out = self.double_conv(x)
        if not self.residual:
            return out
        return self.relu(x + out)


class Down(nn.Module):
    """Encoder stage: 2x max-pool followed by two DoubleConvs, plus a time branch.

    The time embedding ``t`` (``num_window`` channels) is max-pooled by ``down_t``,
    projected to ``out_channels`` and added to the spatial features. ``down_t`` is
    expected to match the total downsampling of ``x`` at this stage so the shapes agree.
    """

    def __init__(self, in_channels, out_channels, down_t):
        super().__init__()
        spatial_path = [
            nn.MaxPool3d(2),
            DoubleConv(in_channels, in_channels, residual=True),
            DoubleConv(in_channels, out_channels),
        ]
        self.maxpool_conv = nn.Sequential(*spatial_path)

        time_path = [
            nn.MaxPool3d(down_t),
            DoubleConv(num_window, out_channels),
        ]
        self.maxpool_conv_t = nn.Sequential(*time_path)

    def forward(self, x, t):
        features = self.maxpool_conv(x)
        time_features = self.maxpool_conv_t(t)
        return features + time_features


class Up(nn.Module):
    """Decoder stage: 2x nearest upsample, concat with the skip connection, two DoubleConvs.

    As in ``Down``, the time embedding ``t`` is max-pooled by ``down_t``, projected
    to ``out_channels`` and added to the result.
    """

    def __init__(self, in_channels, out_channels, down_t):
        super().__init__()
        self.conv = nn.Sequential(
            DoubleConv(in_channels, in_channels, residual=True),
            DoubleConv(in_channels, out_channels, mid_c=in_channels // 2),
        )
        self.maxpool_conv_t = nn.Sequential(
            nn.MaxPool3d(down_t),
            DoubleConv(num_window, out_channels),
        )

    def forward(self, x, skip_x, t):
        upsampled = F.interpolate(x, scale_factor=2, mode='nearest')
        merged = torch.cat([skip_x, upsampled], dim=1)
        return self.conv(merged) + self.maxpool_conv_t(t)
    

class ConvAttentionBlock_2C(nn.Module):
    """Three-head convolutional channel-attention block.

    Each head projects ``x`` with 3x3x3 convolutions into Q, K and V token
    matrices of shape (batch, voxels, nf). It forms an ``nf x nf`` channel
    affinity ``relu(K^T Q / sqrt(nf))``, mixes V with it and maps the result
    back with its own 3x3x3 "turn" convolution. The head outputs are summed,
    scaled by 0.3, added to ``x`` and passed through ``convTurn`` and InstanceNorm3d.

    Works for any spatial shape (D1, D2, D3); earlier code assumed D1 == D2.

    Args:
        nf: number of input and output channels.
        bias: accepted for backward compatibility; currently unused.
    """

    def __init__(self, nf=64, bias=True):
        super(ConvAttentionBlock_2C, self).__init__()
        self.nf = nf
        # Registration order is unchanged so state_dict / parameters() order matches old checkpoints.
        self.convQ1 = nn.Conv3d(nf, nf, 3, 1, 1)
        self.convK1 = nn.Conv3d(nf, nf, 3, 1, 1)
        self.convV1 = nn.Conv3d(nf, nf, 3, 1, 1)
        self.convQ2 = nn.Conv3d(nf, nf, 3, 1, 1)
        self.convV3 = nn.Conv3d(nf, nf, 3, 1, 1)
        self.relu = nn.ReLU(inplace=True)
        self.instance_norm = nn.InstanceNorm3d(nf)
        self.convTurn1 = nn.Conv3d(nf, nf, 3, 1, 1)
        self.convTurn2 = nn.Conv3d(nf, nf, 3, 1, 1)
        self.convK2 = nn.Conv3d(nf, nf, 3, 1, 1)
        self.convV2 = nn.Conv3d(nf, nf, 3, 1, 1)
        self.convQ3 = nn.Conv3d(nf, nf, 3, 1, 1)
        self.convK3 = nn.Conv3d(nf, nf, 3, 1, 1)
        self.convTurn3 = nn.Conv3d(nf, nf, 3, 1, 1)
        self.convTurn = nn.Conv3d(nf, nf, 3, 1, 1)
        # Kaiming init scaled by 0.1 (each conv listed once; convTurn used to appear twice).
        default_init_weights([self.convTurn1, self.convTurn2, self.convTurn3, self.convTurn,
                              self.convQ1, self.convK1, self.convV1,
                              self.convQ2, self.convK2, self.convV2,
                              self.convQ3, self.convK3, self.convV3], 0.1)

    def _tokens(self, conv, x):
        """Apply ``conv`` and flatten to (batch, voxels, nf)."""
        return conv(x).permute(0, 2, 3, 4, 1).reshape(x.size(0), -1, self.nf)

    def _attend(self, x, conv_q, conv_k, conv_v, conv_turn):
        """One attention head; returns a tensor shaped like ``x``."""
        q = self._tokens(conv_q, x)                   # (B, N, nf)
        k = self._tokens(conv_k, x).permute(0, 2, 1)  # (B, nf, N)
        v = self._tokens(conv_v, x)                   # (B, N, nf)
        affinity = self.relu((torch.bmm(k, q) / math.sqrt(self.nf)).float())  # (B, nf, nf)
        mixed = torch.bmm(v, affinity).permute(0, 2, 1)                       # (B, nf, N)
        return conv_turn(mixed.reshape(x.size(0), self.nf, *x.shape[2:]))

    def forward(self, x):
        heads = (
            (self.convQ1, self.convK1, self.convV1, self.convTurn1),
            (self.convQ2, self.convK2, self.convV2, self.convTurn2),
            (self.convQ3, self.convK3, self.convV3, self.convTurn3),
        )
        B1, B2, B3 = (self._attend(x, q, k, v, turn) for q, k, v, turn in heads)
        return self.instance_norm(self.convTurn(x + (0.3 * (B1 + B2 + B3))))


class ConvAttention(nn.Module):
    '''Two stacked ConvAttentionBlock_2C stages.

    Each stage adds the attention output to its input (residual skip), then
    applies a 3x3x3 convolution and ReLU. Channel count ``nf`` is preserved.
    '''

    def __init__(self, nf):
        super(ConvAttention, self).__init__()
        self.CAB1 = ConvAttentionBlock_2C(nf)
        self.CAB2 = ConvAttentionBlock_2C(nf)
        self.conv1 = nn.Conv3d(nf, nf, 3, 1, 1)
        self.conv2 = nn.Conv3d(nf, nf, 3, 1, 1)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        for attention, conv in ((self.CAB1, self.conv1), (self.CAB2, self.conv2)):
            x = self.relu(conv(x + attention(x)))
        return x

In [9]:
class FCUNet(nn.Module):
    """3-D attention U-Net predicting the next saturation and pressure window.

    ``forward(x, t)``:
        x: (batch, c_in, D1, D2, D3). The training loop concatenates
           2*num_window saturation channels, num_window pressure channels and
           3 static condition channels, hence the default
           ``c_in = 3 * num_window + 3``.
        t: (batch, num_window, D1, D2, D3) time embedding, injected at every
           encoder/decoder stage.
    Returns (batch, c_out_s + c_out_p, D1, D2, D3): saturation channels first,
    then pressure. Spatial sizes must survive three 2x poolings (e.g. 16x16x8).
    """

    def __init__(self, c_in=3 * num_window + 3, c_out_s=2 * num_predict_window, c_out_p=num_predict_window):
        super().__init__()

        self.inc = DoubleConv(c_in, 8)
        self.down1 = Down(8, 16, 2)
        self.sa1 = ConvAttention(16)
        self.down2 = Down(16, 32, 4)
        self.sa2 = ConvAttention(32)
        self.down3 = Down(32, 64, 8)
        self.sa3 = ConvAttention(64)

        self.bot1 = DoubleConv(64, 128)
        self.bot2 = DoubleConv(128, 128)
        self.bot3 = DoubleConv(128, 64)

        self.up1 = Up(64 + 32, 32, 4)
        self.sa4 = ConvAttention(32)
        self.up2 = Up(32 + 16, 16, 2)
        self.sa5 = ConvAttention(16)

        # Separate final decoder stages / heads for saturation (s) and pressure (p).
        self.up3_s = Up(16 + 8, 16, 1)
        self.up3_p = Up(16 + 8, 8, 1)
        self.out_s = nn.Conv3d(16, c_out_s, kernel_size=1, stride=1, padding=0)
        self.out_p = nn.Conv3d(8, c_out_p, kernel_size=1, stride=1, padding=0)

    def forward(self, x, t):
        # Encoder
        x1 = self.inc(x)
        x2 = self.sa1(self.down1(x1, t))
        x3 = self.sa2(self.down2(x2, t))
        x4 = self.sa3(self.down3(x3, t))

        # Bottleneck
        x4 = self.bot3(self.bot2(self.bot1(x4)))

        # Decoder
        x = self.sa4(self.up1(x4, x3, t))
        x = self.sa5(self.up2(x, x2, t))
        x_s = self.out_s(self.up3_s(x, x1, t))
        x_p = self.out_p(self.up3_p(x, x1, t))
        return torch.cat([x_s, x_p], dim=1)

In [11]:
# Instantiate the network and wrap it for (multi-)GPU data parallelism.
model = FCUNet().to(device2)
# GPU ids to use, e.g. [0, 1]; None = every visible GPU. An empty list is not
# valid on a CUDA machine: nn.DataParallel takes device_ids[0] as the output device.
gpu_ids = None
model = nn.DataParallel(model, device_ids=gpu_ids or None)

In [14]:
# Optimisation setup.
lr = 2e-4  # Adam learning rate
optimizer = torch.optim.Adam(params=model.parameters(), lr=lr)
loss_fn = nn.L1Loss()  # mean absolute error between targets and predictions

In [None]:
def VDirect(Pressure, K, c):
    """Darcy-style velocity components ``-K * dP / mu`` along the three spatial axes.

    Uses forward differences ``P[i+1] - P[i]`` (unit spacing) on axes 1, 2 and 3.
    The last slice along each axis has no forward neighbour and its difference is 0.

    Args:
        Pressure: 4-D tensor indexed (batch, axis1, axis2, axis3).
        K: coefficient field (the dataset's K volume), broadcastable against ``Pressure``.
        c: phase index selecting the viscosity: 0 -> 0.0141, 1 -> 8e-4.

    Returns:
        ``(vx, vy, vz)`` on ``Pressure``'s device, each detached from autograd.

    Raises:
        ValueError: if ``c`` is not 0 or 1.
    """
    viscosity = {0: 0.0141, 1: 8e-4}
    if c not in viscosity:
        raise ValueError(f"phase index c must be 0 or 1, got {c!r}")
    mu = viscosity[c]

    # Results are returned detached, so no autograd graph is needed here; the caller's
    # Pressure tensor is deliberately left untouched (no requires_grad_ side effect).
    dx = torch.zeros_like(Pressure, dtype=torch.float32)
    dx[:, :-1, :, :] = Pressure[:, 1:, :, :] - Pressure[:, :-1, :, :]

    dy = torch.zeros_like(Pressure, dtype=torch.float32)
    dy[:, :, :-1, :] = Pressure[:, :, 1:, :] - Pressure[:, :, :-1, :]

    dz = torch.zeros_like(Pressure, dtype=torch.float32)
    dz[:, :, :, :-1] = Pressure[:, :, :, 1:] - Pressure[:, :, :, :-1]

    return (-K * dx / mu).detach(), (-K * dy / mu).detach(), (-K * dz / mu).detach()


In [None]:
def RErrorS(S_n_plus_one, S_n, Pressure, K, Pore):
    """Mean squared residual of a discretised saturation equation on a small patch.

    Evaluated at voxels i, j in 7..9 and k in 4..5 for both phases, using the last
    time step of the window: ``time_diff`` (S_{n+1} - S_n) plus central-difference
    advection terms weighted by the ``VDirect`` velocities, minus ``Pore``.

    Args:
        S_n_plus_one: predicted saturations, 2*num_predict_window channels laid out
            phase-major (phase, time), as produced by reshaping (batch, 2, time, ...).
        S_n: ground-truth saturations one step earlier, same layout
            (e.g. (batch, 2, num_predict_window, ...)).
        Pressure: (batch, num_predict_window, ...) unscaled pressure aligned with ``S_n``.
        K, Pore: per-sample volumes reshapeable to (batch, img_n, img_n, img_h).

    Returns:
        Scalar tensor: sum of squared residuals / (2 phases * 18 voxels * batch).

    NOTE(review): ``Pore`` enters as a subtracted "boundary_term"; porosity
    usually multiplies the time derivative in saturation equations. Confirm the
    intended formulation.
    """
    vol = (img_n, img_n, img_h)
    # Split channels back into (phase, time) and take the last time step of BOTH
    # phases. The previous flat ``[:, -2:]`` picked the last two time steps of
    # phase 1 whenever num_predict_window > 1.
    s_next = S_n_plus_one.reshape(-1, 2, num_predict_window, *vol)[:, :, -1]
    s_now = S_n.reshape(-1, 2, num_predict_window, *vol)[:, :, -1]
    p_now = Pressure.reshape(-1, num_predict_window, *vol)[:, -1]
    K = K.reshape(-1, *vol)
    Pore = Pore.reshape(-1, *vol)

    # Patch (i in 7..9, j in 7..9, k in 4..5) and its +1 / -1 neighbours per axis.
    i0, j0, k0 = slice(7, 10), slice(7, 10), slice(4, 6)
    ip, jp, kp = slice(8, 11), slice(8, 11), slice(5, 7)
    im, jm, km = slice(6, 9), slice(6, 9), slice(3, 5)

    total = 0.
    for c in range(2):
        vx, vy, vz = VDirect(p_now, K, c)
        s = s_now[:, c]
        time_diff = (s_next[:, c, i0, j0, k0] - s[:, i0, j0, k0]) / 2.  # / dt
        x_diff = vx[:, ip, j0, k0] * (s[:, ip, j0, k0] - s[:, im, j0, k0]) / 2.  # (2*dx)
        y_diff = vy[:, i0, jp, k0] * (s[:, i0, jp, k0] - s[:, i0, jm, k0]) / 2.  # (2*dy)
        z_diff = vz[:, i0, j0, kp] * (s[:, i0, j0, kp] - s[:, i0, j0, km]) / 2.  # (2*dz)
        boundary_term = Pore[:, i0, j0, k0]
        residual = time_diff + x_diff + y_diff + z_diff - boundary_term
        total = total + torch.pow(residual, 2).sum()
    return total / (2 * 3 * 3 * 2 * s_next.size(0))

In [15]:
def _fcunet_window_step(model, batch, seq_sat, seq_pres, begin, step_emb, use_truth):
    """Predict the window after ``begin``; returns (outputs, tar, tar_sat, tar_pres, r_loss)."""
    targets, conditions, pressure_true, K, Pore = batch
    npw = num_predict_window
    vol = (img_n, img_n, img_h)
    cur = slice(begin, begin + num_window)                               # input window
    nxt = slice(begin + num_window, begin + num_window + npw)            # predicted window
    prev = slice(begin + num_window - 1, begin + num_window + npw - 1)   # nxt shifted back one step

    t = step_emb[cur].repeat(targets.size(0), 1, 1, 1, 1)
    tar_sat = targets[:, :, nxt].reshape(-1, 2 * npw, *vol)
    tar_pres = pressure_true[:, nxt].reshape(-1, npw, *vol)
    tar = torch.cat([tar_sat, tar_pres], dim=1)

    # Teacher forcing reads the ground truth; otherwise the rolled-out predictions.
    sat_src, pres_src = (targets, pressure_true) if use_truth else (seq_sat, seq_pres)
    saturation = sat_src[:, :, cur].reshape(-1, 2 * num_window, *vol)
    pressure = pres_src[:, cur].reshape(-1, num_window, *vol)

    outputs = model(torch.cat([saturation, pressure, conditions], dim=1), t)
    # Pressure was scaled by 1e-7 / 2 in the dataset; undo that for the residual.
    r_loss = 0.1 * RErrorS(outputs[:, :2 * npw], targets[:, :, prev],
                           pressure_true[:, prev] * 2. / 1e-7, K, Pore)
    return outputs, tar, tar_sat, tar_pres, r_loss


def _fcunet_run_epoch(model, loader, optimizer, step_emb, num_iters, p, train):
    """One pass over ``loader``; returns summed [total, saturation, pressure, residual] losses."""
    sums = [0.0, 0.0, 0.0, 0.0]
    npw = num_predict_window
    for raw_batch in loader:
        batch = [item.to(device2) for item in raw_batch[:5]]
        targets, pressure_true = batch[0], batch[2]
        # Rolling buffers seeded with the ground-truth initial window.
        seq_sat = torch.zeros_like(targets)
        seq_pres = torch.zeros_like(pressure_true)
        seq_sat[:, :, :num_window] = targets[:, :, :num_window]
        seq_pres[:, :num_window] = pressure_true[:, :num_window]

        for s in range(num_iters):
            begin = s * npw
            if train:
                optimizer.zero_grad()
                teacher_forcing = bool(torch.rand(1) <= p)
                use_truth = teacher_forcing or s == 0
            else:
                use_truth = False
            outputs, tar, tar_sat, tar_pres, r_loss = _fcunet_window_step(
                model, batch, seq_sat, seq_pres, begin, step_emb, use_truth)
            step_total_loss = loss_fn(tar, outputs) + r_loss
            if train:
                step_total_loss.backward()
                optimizer.step()

            out_sat = outputs[:, :2 * npw].detach()
            out_pres = outputs[:, 2 * npw:].detach()
            sums[0] += step_total_loss.item()
            sums[1] += loss_fn(tar_sat, out_sat).item()
            sums[2] += loss_fn(tar_pres, out_pres).item()
            sums[3] += r_loss.item()

            nxt = slice(num_window + begin, num_window + begin + npw)
            seq_sat[:, :, nxt] = out_sat.reshape(-1, 2, npw, img_n, img_n, img_h)
            seq_pres[:, nxt] = out_pres
    return sums


def _fcunet_mean(sums, count):
    """Divide accumulated losses by ``count`` (NaN when the loader was empty)."""
    return [value / count if count else float("nan") for value in sums]


def trainFCUNet(model, train_loader, val_loader, optimizer, p=0.9, num_epochs=1000,
                p_decay=0.008, save_dir="..", save_every=5):
    """Train FCUNet autoregressively with scheduled teacher forcing.

    Each sequence is processed in ``(num_step - num_window) // num_predict_window``
    windows. The model sees ``num_window`` steps of saturation and pressure (plus
    static conditions) and predicts the next ``num_predict_window`` steps.

    - Training takes its input window from ground truth with probability ``p``
      (always for the first window), otherwise from its own earlier predictions.
    - ``p`` decreases by ``p_decay`` after each epoch while positive.
    - Validation always feeds predictions back (its first window is the
      ground-truth initial condition).

    Per-window loss = L1(target, prediction) + 0.1 * RErrorS(...).

    Args:
        model, train_loader, val_loader, optimizer: network, data and optimizer.
        p: initial teacher-forcing probability.
        num_epochs: number of epochs.
        p_decay: amount subtracted from ``p`` after each epoch.
        save_dir: directory for ``FCUNet_epoch_<n>.pth`` checkpoints (created if missing).
        save_every: checkpoint period in epochs; 0/None disables saving.

    Returns:
        dict mapping metric name to a list of per-epoch mean losses.
    """
    num_iters = (num_step - num_window) // num_predict_window
    # Time embedding: entry i is a constant volume filled with i / num_step.
    step_emb = torch.stack(
        [torch.full((img_n, img_n, img_h), i / num_step) for i in range(num_step - num_predict_window)]
    ).to(device2)

    history = {key: [] for key in (
        "train_loss", "train_saturation", "train_pressure", "train_R",
        "val_loss", "val_saturation", "val_pressure", "val_R")}

    for epoch in range(num_epochs):
        model.train()
        train_sums = _fcunet_run_epoch(model, train_loader, optimizer, step_emb, num_iters, p, train=True)

        model.eval()
        with torch.no_grad():
            val_sums = _fcunet_run_epoch(model, val_loader, optimizer, step_emb, num_iters, p, train=False)

        train_loss, train_saturation_loss, train_pressure_loss, train_R_loss = _fcunet_mean(
            train_sums, len(train_loader) * num_iters)
        val_loss, val_saturation_loss, val_pressure_loss, val_R_loss = _fcunet_mean(
            val_sums, len(val_loader) * num_iters)

        history["train_loss"].append(train_loss)
        history["train_saturation"].append(train_saturation_loss)
        history["train_pressure"].append(train_pressure_loss)
        history["train_R"].append(train_R_loss)
        history["val_loss"].append(val_loss)
        history["val_saturation"].append(val_saturation_loss)
        history["val_pressure"].append(val_pressure_loss)
        history["val_R"].append(val_R_loss)

        print(f"Epoch [{epoch+1}/{num_epochs}], Train Loss: {train_loss:.4f}, S: {train_saturation_loss:.4f}, P: {train_pressure_loss:.4f}, R:{train_R_loss:.4f}"
                                            f" Val Loss: {val_loss:.4f}, S: {val_saturation_loss:.4f}, P: {val_pressure_loss:.4f}, R:{val_R_loss:.4f} ")
        if p > 0:
            p -= p_decay
        # The old path expression "../" % (epoch + 1) raised TypeError at the first save.
        if save_every and (epoch + 1) % save_every == 0:
            os.makedirs(save_dir, exist_ok=True)
            torch.save(model.state_dict(), os.path.join(save_dir, f"FCUNet_epoch_{epoch + 1}.pth"))
    return history

In [None]:
# Launch training: start with 90% teacher forcing, run 1000 epochs.
trainFCUNet(
    model,
    train_loader,
    valid_loader,
    optimizer,
    p=0.9,
    num_epochs=1000,
)