In [44]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import numpy


In [61]:
class Conv_cell(nn.Module):
    """Conv2d -> BatchNorm2d -> ReLU -> 2x2 MaxPool2d building block.

    Args:
        input_channels: channels of the incoming tensor.
        output_channels: channels produced by the convolution (and normalised by BN).
        stride: stride of the max-pool only; the convolution always uses stride 1.
        padding: applied to BOTH the convolution and the max-pool. nn.MaxPool2d
            rejects padding larger than half its kernel, so values above 1 fail here.
        conv_kernel_size: kernel size of the convolution.
    """

    def __init__(self, input_channels, output_channels, stride=(1,1), padding=0, conv_kernel_size=(3,3)):
        super().__init__()
        # Registration order (conv, bn, pool) is kept so state_dict keys stay identical.
        self.conv = nn.Conv2d(
            input_channels,
            output_channels,
            kernel_size=conv_kernel_size,
            padding=padding,
        )
        self.bn = nn.BatchNorm2d(output_channels)
        self.pool = nn.MaxPool2d(kernel_size=(2, 2), stride=stride, padding=padding)

    def forward(self, input_tensor):
        """Apply conv, batch-norm, ReLU and max-pool in that order."""
        return self.pool(F.relu(self.bn(self.conv(input_tensor))))

In [62]:
class BCN(nn.Module):
    """Base convolutional network: three Conv_cell stages widening to ``output_channels``.

    Stage widths are ``output_channels // 4``, ``output_channels // 2`` and
    ``output_channels``. The first two stages max-pool with stride 2, the last with
    stride 1. In this notebook a (192, 3, 100, 100) input with output_channels=256
    yields a (192, 256, 20, 20) feature map.
    """

    def __init__(self, input_channels, output_channels):
        super().__init__()
        quarter = output_channels // 4
        half = output_channels // 2
        # Attribute names conv1..conv3 are kept so state_dict keys do not change.
        self.conv1 = Conv_cell(input_channels, quarter, stride=(2, 2))
        self.conv2 = Conv_cell(quarter, half, stride=(2, 2))
        self.conv3 = Conv_cell(half, output_channels)

    def forward(self, tensor):
        """Run the input through the three stages in order."""
        for stage in (self.conv1, self.conv2, self.conv3):
            tensor = stage(tensor)
        return tensor

In [63]:
# Dummy (N, C, H, W) batch for shape smoke-testing. torch.FloatTensor(*sizes) returns
# UNINITIALISED memory (possibly NaN/inf, which BatchNorm would propagate), so draw
# seeded random values instead.
torch.manual_seed(0)
test = torch.randn(192, 3, 100, 100)

In [64]:
bcn_layer = BCN(input_channels=3, output_channels=256)  # RGB in, 256-channel feature map out

In [65]:
hi = bcn_layer(test)  # BCN feature map for the dummy batch; recorded shape (192, 256, 20, 20)

In [66]:
hi.size()

torch.Size([192, 256, 20, 20])

In [135]:
class MCN(nn.Module):
    """Multi-orientation module applied to a BCN feature map.

    The input is rotated by k = 0..3 quarter-turns over its spatial dims. Each rotation
    goes through the same four Conv_cell stages (weights shared across rotations). A 1x1
    conv then gives one score per output column, and the scores are softmax-normalised
    across the four rotations.

    Args:
        bcn_output_channels: channel count C of the incoming BCN feature map; every
            stage keeps C channels.
    """

    def __init__(self, bcn_output_channels):
        super(MCN, self).__init__()
        self.channels = bcn_output_channels
        # Conv_cell's `stride` is the max-pool stride: (2, 1) downsamples height by ~2
        # while width only loses/gains what the 2x2 window and padding dictate.
        self.conv1 = Conv_cell(self.channels, self.channels, stride=(2, 1), padding=(1, 0))
        self.conv2 = Conv_cell(self.channels, self.channels, stride=(2, 1), padding=(1, 1))
        self.conv3 = Conv_cell(self.channels, self.channels, stride=(2, 1), padding=(1, 0))
        self.conv4 = Conv_cell(self.channels, self.channels)
        self.char_possibility = nn.Conv2d(self.channels, 1, kernel_size=(1, 1))

    def forward(self, bcn_output):
        """Encode each rotation of ``bcn_output`` and score character placement.

        Args:
            bcn_output: (N, C, H, W) feature map. Presumably H == W is required, so
                that every rotation yields same-shaped features (TODO confirm); the
                notebook uses 20x20.

        Returns:
            rotated_features: list of 4 tensors, one per rotation. For the 20x20 input
                used here each is (N, C, 1, 12).
            character_placement_possibility: (N, 4, W_out) tensor, softmax over the
                rotation axis (dim 1).
        """
        rotated_features = []
        rotation_scores = []
        for k in range(4):
            rotated = bcn_output.rot90(k=k, dims=[2, 3])
            output = self.conv1(rotated)
            output = self.conv2(output)
            output = self.conv3(output)
            rotated_feature = self.conv4(output)
            # The conv stack is expected to collapse height to 1; drop that axis -> (N, 1, W_out).
            rotated_char = self.char_possibility(rotated_feature).squeeze(2)
            rotated_features.append(rotated_feature)
            rotation_scores.append(rotated_char)

        # Normalise across rotations once, after all four scores are collected
        # (previously recomputed inside the loop on every iteration).
        character_placement_possibility = torch.cat(rotation_scores, dim=1).softmax(1)
        return rotated_features, character_placement_possibility
           

In [136]:
mcn_layer = MCN(bcn_output_channels=256)  # must match BCN's output_channels

In [137]:
rotated_features, cpp = mcn_layer(hi)  # per-rotation feature maps + softmax over the 4 rotations

In [159]:
cpp.size()

torch.Size([192, 4, 12])

In [180]:
class Encoder(nn.Module):
    """Shared 2-layer BiLSTM that fuses the rotated MCN feature maps.

    Each rotated feature map is read as a sequence along its width axis and encoded
    by the same BiLSTM. The encodings are weighted per time step by the matching
    character-placement probability and then summed over rotations.

    Args:
        bcn_output_channels: channel count C of the incoming feature maps. It is also
            the LSTM hidden size, so each output step has 2 * C features.
    """

    def __init__(self, bcn_output_channels):
        super(Encoder, self).__init__()
        # batch_first=True: sequences are built as (batch, width, channels). Without it
        # nn.LSTM treats dim 0 as time and runs the recurrence across the *batch*,
        # mixing information between unrelated samples.
        self.BiLSTM = nn.LSTM(
            bcn_output_channels,
            bcn_output_channels,
            num_layers=2,
            bidirectional=True,
            batch_first=True,
        )

    def forward(self, rotated_features, character_place_possibility):
        """Encode and fuse the rotated features.

        Args:
            rotated_features: sequence of R tensors, each (N, C, 1, W); the height
                axis must already be 1.
            character_place_possibility: (N, R, W) weights, one row per rotation.

        Returns:
            (N, W, 2 * C) tensor: sum over rotations of BiLSTM output * weight.

        Raises:
            ValueError: if R does not match ``character_place_possibility.size(1)``.
        """
        if len(rotated_features) != character_place_possibility.size(1):
            raise ValueError(
                f"got {len(rotated_features)} feature maps but "
                f"{character_place_possibility.size(1)} probability rows"
            )
        feature_codes = 0
        for idx, rotated_feature in enumerate(rotated_features):
            # (N, C, 1, W) -> (N, C, W) -> (N, W, C): one time step per column.
            sequence = rotated_feature.squeeze(2).permute(0, 2, 1)
            feature_vector, _ = self.BiLSTM(sequence)
            # (N, W) -> (N, W, 1) so the weight broadcasts over the 2*C features.
            p_i = character_place_possibility[:, idx, :].unsqueeze(2)
            feature_codes = feature_codes + feature_vector * p_i
        return feature_codes

In [181]:
encoder = Encoder(bcn_output_channels=256)  # must match the MCN/BCN channel count

In [182]:
feature_codes = encoder(rotated_features, cpp)  # probability-weighted sum of BiLSTM encodings over rotations

In [None]:
class Decoder()