In [None]:
'''audio encode'''
class AudioEncode(nn.Module):
    """Encode an audio feature sequence into a 256-d embedding and a GRU output.

    ``forward`` runs the 1-D conv stack conv2..conv6 (16 -> 512 channels, each
    with stride 2), flattens the result, and projects it to 256 values with
    ``fc`` (Linear + Tanh). It then feeds that vector to a 2-layer, batch-first
    GRU as a length-1 sequence.

    Input: presumably a (batch, 16, T) tensor, since conv2 has 16 input channels.
    ``fc`` expects 512 * 6 features, so T must shrink to 6 after the five
    stride-2 convs, i.e. T in 161..192. TODO confirm against the data pipeline.
    """

    def __init__(self):
        super(AudioEncode, self).__init__()

        # NOTE(review): conv1 (1 -> 16 channels, kernel 300, stride 20) is built
        # but not used by forward(); presumably an alternative raw-waveform
        # front end. It is kept so checkpoints containing ``conv1.*`` keys
        # still load.
        self.conv1 = nn.Sequential(nn.Conv1d(1, 16, kernel_size=300, stride=20, padding=1),
                                   nn.BatchNorm1d(16),
                                   nn.ReLU(inplace=True))
        self.conv2 = nn.Sequential(nn.Conv1d(16, 32, kernel_size=3, stride=2, padding=1),
                                   nn.BatchNorm1d(32),
                                   nn.ReLU(inplace=True))
        self.conv3 = nn.Sequential(nn.Conv1d(32, 64, kernel_size=3, stride=2, padding=1),
                                   nn.BatchNorm1d(64),
                                   nn.ReLU(inplace=True))
        self.conv4 = nn.Sequential(nn.Conv1d(64, 128, kernel_size=3, stride=2, padding=1),
                                   nn.BatchNorm1d(128),
                                   nn.ReLU(inplace=True))
        self.conv5 = nn.Sequential(nn.Conv1d(128, 256, kernel_size=3, stride=2, padding=1),
                                   nn.BatchNorm1d(256),
                                   nn.ReLU(inplace=True))
        self.conv6 = nn.Sequential(nn.Conv1d(256, 512, kernel_size=3, stride=2, padding=1),
                                   nn.BatchNorm1d(512),
                                   nn.ReLU(inplace=True))

        # 512 channels x 6 time steps out of conv6 -> 256-d embedding in [-1, 1].
        self.fc = nn.Sequential(nn.Linear(512 * 6, 256),
                                nn.Tanh())
        # 2 stacked GRU layers, 256 inputs -> 256 hidden units, (batch, seq, feature) layout.
        self.gru = nn.GRU(256, 256, 2, batch_first=True)

    def forward(self, x):
        """Encode ``x``.

        Returns:
            ``(embedding, gru_out)``, both of shape (batch, 256). ``embedding``
            is the Tanh-projected conv feature; ``gru_out`` is the GRU output
            for the single time step.
        """
        # This module defines no conv7. conv6 already yields the 512 channels
        # that fc expects, so the conv stack ends here.
        x = self.conv2(x)
        x = self.conv3(x)
        x = self.conv4(x)
        x = self.conv5(x)
        x = self.conv6(x)
        x = x.view(x.shape[0], -1)
        x = self.fc(x)
        # Present the embedding to the batch-first GRU as a length-1 sequence.
        x = x.view(x.shape[0], 1, -1)
        x1, _ = self.gru(x)
        # squeeze(1) removes only the sequence axis. A bare squeeze() would
        # also remove the batch axis when batch == 1.
        return x.squeeze(1), x1.squeeze(1)

    def save(self, PATH):
        """Save this module's state_dict to ``PATH`` with ``torch.save``."""
        torch.save(self.state_dict(), PATH)

    def load(self, PATH):
        """Load a state_dict from ``PATH`` (via ``torch.load``) into this module.

        NOTE: torch.load unpickles the file; only load trusted checkpoints.
        """
        self.load_state_dict(torch.load(PATH))

In [None]:
class ResBlock(nn.Module):
    """Conv2d -> BatchNorm2d -> ReLU unit.

    NOTE(review): despite the name, there is no residual/skip connection; the
    block is a plain conv-bn-relu stack. The attribute names (conv1, bn1, relu)
    form the state_dict layout and must not be renamed.

    Args:
        in_planes: number of input channels.
        out_planes: number of output channels.
        kernel_size, stride, padding: passed unchanged to ``nn.Conv2d``.
    """

    def __init__(self, in_planes, out_planes, kernel_size=3, stride=1, padding=1):
        super(ResBlock, self).__init__()
        conv_kwargs = dict(kernel_size=kernel_size, stride=stride, padding=padding)
        self.conv1 = nn.Conv2d(in_planes, out_planes, **conv_kwargs)
        self.bn1 = nn.BatchNorm2d(out_planes)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        """Apply the convolution, batch normalization and an in-place ReLU."""
        normalized = self.bn1(self.conv1(x))
        return self.relu(normalized)

In [None]:
class Img_Encoder(nn.Module):
    """Convolutional image encoder returning multi-scale features plus a code.

    Six ``ResBlock`` stages (channels 3 -> 32 -> 64 -> 128 -> 256 -> 512 -> 1024)
    are followed by two fully connected layers. These map the flattened final
    feature map to ``num_output_length`` values, returned as a
    (batch, num_output_length, 1, 1) tensor.

    ``fc1`` expects 36864 = 1024 * 6 * 6 input features, so block6 must output
    a 6x6 map. With the kernels and strides below, a 200x200 input gives that
    (200 -> 100 -> 50 -> 25 -> 12 -> 6).
    NOTE(review): the 200x200 size is derived from the layer arithmetic only;
    confirm it matches the data loader. ``fc1`` uses BatchNorm1d, which needs
    batch > 1 in training mode.

    Args:
        num_output_length: size of the final code (``fc2`` output features).
        if_tanh: if True, squash the code with tanh.
    """

    def __init__(self, num_output_length, if_tanh=False):
        super(Img_Encoder, self).__init__()
        self.if_tanh = if_tanh
        self.block1 = ResBlock(3, 32, kernel_size=3, stride=1)
        self.block2 = ResBlock(32, 64, kernel_size=4, stride=2)
        self.block3 = ResBlock(64, 128, kernel_size=4, stride=2)
        self.block4 = ResBlock(128, 256, kernel_size=4, stride=2)
        self.block5 = ResBlock(256, 512, kernel_size=4, stride=2)
        self.block6 = ResBlock(512, 1024, kernel_size=3, stride=2)
        # 36864 == 1024 channels * 6 * 6 spatial positions out of block6.
        self.fc1 = nn.Sequential(nn.Linear(36864, 512), nn.BatchNorm1d(512), nn.LeakyReLU(0.2, inplace=True))
        self.fc2 = nn.Linear(512, num_output_length)

    def forward(self, x):
        """Encode image batch ``x``.

        Returns:
            Tuple ``(x_2, x_3, x_4, x_5, x_6, code)``: the outputs of blocks
            2-6 and the (batch, num_output_length, 1, 1) code. The output of
            block1 is not returned.
        """
        x_1 = self.block1(x)
        x_2 = self.block2(x_1)
        x_3 = self.block3(x_2)
        x_4 = self.block4(x_3)
        x_5 = self.block5(x_4)
        x_6 = self.block6(x_5)
        code = self.fc2(self.fc1(torch.flatten(x_6, 1)))
        # Add two singleton spatial dims so the code can be concatenated
        # channel-wise with other (B, C, 1, 1) tensors.
        code = code[:, :, None, None]
        if self.if_tanh:
            # torch.tanh replaces the deprecated F.tanh (identical values).
            code = torch.tanh(code)
        return x_2, x_3, x_4, x_5, x_6, code

    def save(self, PATH):
        """Save this module's state_dict to ``PATH`` with ``torch.save``."""
        torch.save(self.state_dict(), PATH)

    def load(self, PATH):
        """Load a state_dict from ``PATH`` (via ``torch.load``) into this module.

        NOTE: torch.load unpickles the file; only load trusted checkpoints.
        """
        self.load_state_dict(torch.load(PATH))

In [None]:
class Img_Decoder(nn.Module):
    """U-Net-style decoder from a 384-channel latent to a 3-channel image.

    ``ct2d1`` expands the latent (1x1 -> 6x6 when the latent is 1x1). Each
    later stage upsamples with a stride-2 transposed conv (``ct2dK``) and
    concatenates a skip feature on the channel axis. A 3x3 conv (``cov2dK``)
    then fuses the result back down. ``ct2d6`` produces the Tanh-bounded
    3-channel output.

    With a 1x1 latent the layer arithmetic gives spatial sizes
    6 -> 12 -> 25 -> 50 -> 100 -> 200. The skip tensors must therefore be:
    x_4: 512 ch @ 12x12, x_3: 256 ch @ 25x25, x_2: 128 ch @ 50x50,
    x_1: 64 ch @ 100x100.
    """

    def __init__(self):
        super(Img_Decoder, self).__init__()

        def _up(c_in, c_out, kernel, stride=2, padding=1):
            # Bias-free ConvTranspose2d -> BatchNorm2d -> ReLU.
            return nn.Sequential(
                nn.ConvTranspose2d(c_in, c_out, kernel_size=kernel, stride=stride,
                                   padding=padding, bias=False),
                nn.BatchNorm2d(c_out),
                nn.ReLU(inplace=True),
            )

        def _fuse(c_in, c_out):
            # Bias-free 3x3 "same" Conv2d -> BatchNorm2d -> ReLU.
            return nn.Sequential(
                nn.Conv2d(c_in, c_out, kernel_size=3, stride=1, padding=1, bias=False),
                nn.BatchNorm2d(c_out),
                nn.ReLU(inplace=True),
            )

        # Creation order matters: it fixes parameter init order and state_dict order.
        self.ct2d1 = _up(384, 1024, kernel=6, stride=1, padding=0)
        self.ct2d2 = _up(1024, 512, kernel=4)
        self.cov2d2 = _fuse(1024, 512)
        self.ct2d3 = _up(512, 256, kernel=5)  # kernel 5: 12x12 -> 25x25
        self.cov2d3 = _fuse(512, 256)
        self.ct2d4 = _up(256, 128, kernel=4)
        self.cov2d4 = _fuse(256, 128)
        self.ct2d5 = _up(128, 64, kernel=4)
        self.cov2d5 = _fuse(128, 64)
        self.ct2d6 = nn.Sequential(
            nn.ConvTranspose2d(64, 3, kernel_size=4, stride=2, padding=1, bias=False),
            nn.Tanh(),
        )

    def forward(self, x, x_1, x_2, x_3, x_4, x_5):
        """Decode latent ``x`` using skip features ``x_4``, ``x_3``, ``x_2``, ``x_1``.

        ``x_5`` is accepted for interface compatibility but is not used.
        """
        out = self.ct2d1(x)
        stages = (
            (self.ct2d2, x_4, self.cov2d2),
            (self.ct2d3, x_3, self.cov2d3),
            (self.ct2d4, x_2, self.cov2d4),
            (self.ct2d5, x_1, self.cov2d5),
        )
        for upsample, skip, fuse in stages:
            # Upsample, append the skip feature channel-wise, then fuse.
            out = fuse(torch.cat((upsample(out), skip), 1))
        return self.ct2d6(out)

    def save(self, PATH):
        """Write this module's state_dict to ``PATH`` using ``torch.save``."""
        state = self.state_dict()
        torch.save(state, PATH)

    def load(self, PATH):
        """Read a state_dict from ``PATH`` with ``torch.load`` and apply it.

        NOTE: torch.load unpickles the file; only load trusted checkpoints.
        """
        weights = torch.load(PATH)
        self.load_state_dict(weights)

In [None]:
class Generator(nn.Module):
    """Audio-conditioned image generator.

    The image is encoded by ``Img_Encoder(128)`` and the audio by
    ``AudioEncode``. The 128-channel image code and the 256-d audio embedding
    are fused by ``fc`` (384 -> 128). The 256-d GRU output is appended to form
    a 384-channel latent, which ``Img_Decoder`` decodes together with the
    encoder's feature maps.
    """

    def __init__(self):
        super(Generator, self).__init__()
        self.img_encoder = Img_Encoder(128)
        self.audencoder = AudioEncode()
        self.img_decoder = Img_Decoder()
        # 384 = 128 (image code) + 256 (audio embedding).
        self.fc = nn.Linear(384, 128)

    def forward(self, x, v):
        """Generate an image batch from image batch ``x`` and audio batch ``v``.

        Args:
            x: image batch, presumably (B, 3, 200, 200); TODO confirm.
            v: audio input in whatever form ``AudioEncode.forward`` accepts.

        Returns:
            The output of ``Img_Decoder``.
        """
        batch = x.shape[0]
        # Encoder returns five feature maps followed by the code.
        f_1, f_2, f_3, f_4, f_5, img_code = self.img_encoder(x)
        aud_emb, aud_seq = self.audencoder(v)
        # Reshape with an explicit batch axis instead of squeeze/unsqueeze, so
        # batch == 1 works even if the audio encoder squeezed that axis away.
        aud_emb = aud_emb.reshape(batch, -1, 1, 1)
        aud_seq = aud_seq.reshape(batch, -1, 1, 1)
        # flatten(1) keeps the batch axis; the old squeeze() dropped it for B == 1.
        fused = torch.cat((img_code, aud_emb), 1).flatten(1)   # (B, 384)
        fused = self.fc(fused)[:, :, None, None]                # (B, 128, 1, 1)
        latent = torch.cat((fused, aud_seq), 1)                 # (B, 384, 1, 1)
        return self.img_decoder(latent, f_1, f_2, f_3, f_4, f_5)

    def save(self, PATH):
        """Save this module's state_dict to ``PATH`` with ``torch.save``."""
        torch.save(self.state_dict(), PATH)

    def load(self, PATH):
        """Load a state_dict from ``PATH`` (via ``torch.load``) into this module.

        NOTE: torch.load unpickles the file; only load trusted checkpoints.
        """
        self.load_state_dict(torch.load(PATH))