In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [2]:
class modelConv1(nn.Module):
    """Single 1-D convolution stage: Conv1d(k=3, s=1, p=1) -> [BatchNorm1d] -> ReLU.

    With kernel 3, stride 1 and padding 1 the sequence length is preserved;
    only the channel count changes from ``in_channels`` to ``out_channels``.

    Args:
        in_channels: number of channels of the input tensor.
        out_channels: number of channels produced by the convolution.
        is_batchnorm: if truthy, a ``BatchNorm1d`` is inserted between the
            convolution and the ReLU.
    """

    def __init__(self, in_channels, out_channels, is_batchnorm):
        super(modelConv1, self).__init__()
        stage = [nn.Conv1d(in_channels, out_channels, 3, 1, 1)]
        if is_batchnorm:
            stage.append(nn.BatchNorm1d(out_channels))
        stage.append(nn.ReLU(inplace=True))
        # Wrapped in an nn.Sequential called ``conv1`` so state_dict keys
        # (``conv1.0.*``, ``conv1.1.*``) keep the same layout.
        self.conv1 = nn.Sequential(*stage)

    def forward(self, inputs):
        """Apply the stage; assumes ``inputs`` is (N, C_in, L) — TODO confirm."""
        return self.conv1(inputs)

In [3]:
class modelResBlock(nn.Module):
    """1-D residual block.

    Computes ``relu(conv2(conv1(x)) + shortcut(x))`` where
      * ``conv1`` = Conv1d(k=3, s=1, p=1) -> [BatchNorm1d] -> ReLU
      * ``conv2`` = Conv1d(k=3, s=1, p=1) -> [BatchNorm1d]
      * ``shortcut`` = identity when ``in_channels == out_channels``, otherwise
        a 1x1 Conv1d (plus BatchNorm1d if ``is_batchnorm``) so the residual
        addition has matching channel counts.

    All convolutions preserve the sequence length.

    Args:
        in_channels: channels of the input tensor.
        out_channels: channels of the output tensor.
        is_batchnorm: if truthy, BatchNorm1d follows each convolution.
    """

    def __init__(self, in_channels, out_channels, is_batchnorm):
        super(modelResBlock, self).__init__()
        # Kernel size: 3, Stride: 1, Padding: 1 -> length-preserving.
        if is_batchnorm:
            self.conv1 = nn.Sequential(nn.Conv1d(in_channels, out_channels, 3, 1, 1),
                                       nn.BatchNorm1d(out_channels),
                                       nn.ReLU(inplace=True),)
            self.conv2 = nn.Sequential(nn.Conv1d(out_channels, out_channels, 3, 1, 1),
                                       nn.BatchNorm1d(out_channels),)
        else:
            self.conv1 = nn.Sequential(nn.Conv1d(in_channels, out_channels, 3, 1, 1),
                                       nn.ReLU(inplace=True),)
            # conv1 emits out_channels, so conv2 must consume out_channels
            # (previously in_channels, which failed whenever the two differed).
            self.conv2 = nn.Sequential(nn.Conv1d(out_channels, out_channels, 3, 1, 1),)

        # Identity skip when shapes already match: adds no parameters, so
        # existing checkpoints still load. Otherwise project with a 1x1 conv.
        if in_channels == out_channels:
            self.shortcut = nn.Identity()
        else:
            projection = [nn.Conv1d(in_channels, out_channels, 1, 1, 0)]
            if is_batchnorm:
                projection.append(nn.BatchNorm1d(out_channels))
            self.shortcut = nn.Sequential(*projection)

        self.relu = nn.ReLU(inplace=True)

    def forward(self, inputs):
        """Return ``relu(conv2(conv1(inputs)) + shortcut(inputs))``."""
        outputs = self.conv1(inputs)
        outputs = self.conv2(outputs) + self.shortcut(inputs)

        return self.relu(outputs)

In [4]:
class CLFCRN_Encoder(nn.Module):
    """Conv stem, three residual blocks and an output Conv1d + ReLU.

    In ``Network_clFCRN`` this branch maps seismic data to a velocity estimate.
    Every convolution uses kernel 3 / stride 1 / padding 1, so the sequence
    length is preserved and only the channel count changes.

    Args:
        in_channels: channels of the incoming tensor.
        out_channels: channels produced by the final Conv1d.
        is_batchnorm: forwarded to the stem and the residual blocks.
    """

    def __init__(self, in_channels, out_channels, is_batchnorm):
        super(CLFCRN_Encoder, self).__init__()
        self.in_channels = in_channels
        self.is_batchnorm = is_batchnorm
        self.out_channels = out_channels

        width = 16  # hidden channel width shared by every intermediate layer
        stages = [modelConv1(in_channels, width, is_batchnorm)]
        stages.extend(modelResBlock(width, width, is_batchnorm) for _ in range(3))
        stages.append(nn.Conv1d(width, out_channels, 3, 1, 1))
        stages.append(nn.ReLU(inplace=True))
        # Single Sequential named ``conv`` keeps state_dict keys ``conv.0`` .. ``conv.5``.
        self.conv = nn.Sequential(*stages)

    def forward(self, inputs):
        """Run the encoder stack on ``inputs``."""
        return self.conv(inputs)

In [5]:
class CLFCRN_Decoder(nn.Module):
    """Conv stem, three residual blocks and an output Conv1d + ReLU.

    In ``Network_clFCRN`` this branch maps a velocity model to seismic data.
    All convolutions are length-preserving (kernel 3, stride 1, padding 1).

    Args:
        in_channels: channels of the incoming tensor.
        out_channels: channels produced by the final Conv1d.
        is_batchnorm: forwarded to the stem and the residual blocks.
    """

    def __init__(self, in_channels, out_channels, is_batchnorm):
        super(CLFCRN_Decoder, self).__init__()
        self.in_channels, self.is_batchnorm, self.out_channels = (
            in_channels, is_batchnorm, out_channels)

        hidden = (16, 16, 16, 16)
        blocks = [modelConv1(in_channels, hidden[0], is_batchnorm)]
        # One residual block per consecutive pair of hidden widths (3 blocks).
        for c_in, c_out in zip(hidden[:-1], hidden[1:]):
            blocks.append(modelResBlock(c_in, c_out, is_batchnorm))
        blocks += [nn.Conv1d(hidden[-1], out_channels, 3, 1, 1),
                   nn.ReLU(inplace=True)]
        self.deconv = nn.Sequential(*blocks)

    def forward(self, inputs):
        """Run the decoder stack on ``inputs``."""
        result = self.deconv(inputs)
        return result

In [10]:
class Network_clFCRN(nn.Module):
    """Encoder/decoder pair that ``forward`` runs in both directions.

    ``encoder`` maps seismic data (``in_channels`` channels) to velocity
    (``out_channels`` channels). ``decoder`` is the inverse mapping, from
    velocity (``out_channels``) back to seismic data (``in_channels``).

    Args:
        in_channels: channels of the seismic data.
        out_channels: channels of the velocity model.
        is_batchnorm: forwarded to both sub-networks.
    """

    def __init__(self, in_channels, out_channels, is_batchnorm):
        super(Network_clFCRN, self).__init__()
        self.in_channels   = in_channels
        self.is_batchnorm  = is_batchnorm
        self.out_channels  = out_channels

        self.encoder = CLFCRN_Encoder(in_channels, out_channels, is_batchnorm)

        # The decoder consumes encoder output / label velocity (out_channels)
        # and must emit seismic data (in_channels). Passing the arguments as
        # (in_channels, out_channels) only worked when the two were equal.
        self.decoder = CLFCRN_Decoder(out_channels, in_channels, is_batchnorm)

    def encode(self, x):
        """Seismic data -> velocity."""
        return self.encoder(x)

    def decode(self, z):
        """Velocity -> seismic data."""
        return self.decoder(z)

    def forward(self, seismic_data, label_velocity):
        """Run both reconstruction cycles.

        Args:
            seismic_data: batch fed to the encoder.
            label_velocity: batch fed to the decoder.

        Returns:
            Tuple ``(pre_velocity, recon_seismic_data, pre_seismic_data,
            recon_velocity)`` = ``(encode(s), decode(encode(s)), decode(v),
            encode(decode(v)))``.
        """
        pre_velocity = self.encode(seismic_data)
        recon_seismic_data = self.decode(pre_velocity)

        pre_seismic_data = self.decode(label_velocity)
        recon_velocity = self.encode(pre_seismic_data)

        return pre_velocity, recon_seismic_data, pre_seismic_data, recon_velocity
        

In [11]:
# Seed torch's global RNG before building the model so the weight
# initialisation, and any random tensors drawn afterwards, are reproducible
# under Restart & Run All.
torch.manual_seed(0)
net = Network_clFCRN(in_channels=1, out_channels=1, is_batchnorm=True)

In [12]:
# Random stand-in batches laid out as (batch, channels, length): 20 samples
# for the seismic branch and 5 for the velocity branch.
s_input = torch.randn((20, 1, 93))
v_input = torch.randn((5, 1, 93))

In [13]:
# Shape smoke test: no gradients are needed, so skip building the autograd graph.
# NOTE: net has not been switched to eval(), so its BatchNorm layers still use
# batch statistics and update their running averages during this call.
with torch.no_grad():
    pre_velocity, recon_seismic_data, pre_seismic_data, recon_velocity = net(s_input, v_input)

# Every path keeps batch size and length; channels follow the mapping direction.
assert pre_velocity.shape == (s_input.shape[0], net.out_channels, s_input.shape[-1])
assert recon_seismic_data.shape == (s_input.shape[0], net.in_channels, s_input.shape[-1])
assert pre_seismic_data.shape == (v_input.shape[0], net.in_channels, v_input.shape[-1])
assert recon_velocity.shape == (v_input.shape[0], net.out_channels, v_input.shape[-1])

pre_velocity.shape, recon_seismic_data.shape, pre_seismic_data.shape, recon_velocity.shape

(torch.Size([20, 1, 93]),
 torch.Size([20, 1, 93]),
 torch.Size([5, 1, 93]),
 torch.Size([5, 1, 93]))