In [30]:
import torch
import torch.nn as nn
import numpy as np
import os
import sys
import cv2
from PIL import Image
import easydict
sys.path.append('../Whatiswrong')
import Extract
import utils
import torch.nn.functional as F

In [5]:
import torchvision.models as models

In [None]:
class Basemodel(nn.Module):
    """Top-level recognition model: a ResNet encoder plus a (not yet wired) decoder.

    The original cell was a syntactically invalid stub (dangling
    ``self.decoder =`` and a ``forward`` header without colon or body), which
    prevented the notebook from running top to bottom. This version keeps the
    same constructor and ``forward`` signatures but is importable.

    Args:
        opt: Option container; stored on the instance for later use.
    """

    def __init__(self, opt):
        super(Basemodel, self).__init__()
        self.opt = opt
        # The encoder class defined in this notebook is ``Resnet_encoder``;
        # the earlier name ``Resnet_EFIFSTR`` (and its ``with_lstm`` flag) is
        # not defined in any visible cell.
        self.encoder = Resnet_encoder()
        # TODO: build the decoder once its hyper-parameters are decided.
        self.decoder = None

    def forward(self, img):
        """Placeholder forward pass.

        Raises:
            NotImplementedError: always, until the decoder has been wired in.
        """
        raise NotImplementedError(
            "Basemodel.forward is not implemented yet: the decoder has not been configured"
        )

In [92]:
def conv3x3(in_planes, out_planes, stride=1):
    """Return a bias-free 3x3 ``nn.Conv2d`` with padding 1 and the given stride."""
    layer = nn.Conv2d(
        in_channels=in_planes,
        out_channels=out_planes,
        kernel_size=3,
        stride=stride,
        padding=1,
        bias=False,
    )
    return layer

def conv1x1(in_planes, out_planes, stride=1):
    """Return a bias-free 1x1 ``nn.Conv2d`` (no padding) with the given stride."""
    layer = nn.Conv2d(
        in_channels=in_planes,
        out_channels=out_planes,
        kernel_size=1,
        stride=stride,
        bias=False,
    )
    return layer


class ResnetBlock(nn.Module):
    """Residual block: strided 1x1 conv -> BN -> ReLU -> 3x3 conv -> BN, plus a skip path.

    Args:
        inplanes: Number of input channels.
        planes: Number of output channels.
        stride: Stride applied by the first (1x1) convolution.
        downsample: Optional module applied to the input to produce the skip
            tensor; when ``None`` the input itself is used.
    """

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super().__init__()
        self.conv1 = conv1x1(inplanes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        """Run the main branch, add the skip tensor in place, and apply ReLU."""
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))

        # The skip path is evaluated after the main branch, as in the original.
        shortcut = x if self.downsample is None else self.downsample(x)
        out += shortcut
        return self.relu(out)

class Resnet_encoder(nn.Module):
    """ResNet-style convolutional encoder followed by a bidirectional LSTM.

    ``forward`` returns the final convolutional feature map together with the
    final hidden states of the LSTM run over the height-pooled feature map.

    Args:
        n_group: Stored on the instance; not otherwise used in this class.
    """

    def __init__(self, n_group=1):
        super().__init__()
        self.n_group = n_group

        in_channels = 3
        stem = [
            nn.Conv2d(in_channels, 32, kernel_size=(3, 3), stride=1, padding=1, bias=False),
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True),
        ]
        self.layer0 = nn.Sequential(*stem)

        # ``_make_layer`` reads and updates ``self.inplanes`` as stages are built.
        self.inplanes = 32
        self.layer1 = self._make_layer(32, 3, [2, 2])
        self.layer2 = self._make_layer(64, 4, [2, 2])
        self.layer3 = self._make_layer(128, 6, [2, 1])
        self.layer4 = self._make_layer(256, 6, [1, 1])
        self.layer5 = self._make_layer(512, 3, [1, 1])

        self.rnn = nn.LSTM(512, 256, bidirectional=True, num_layers=2, batch_first=True)
        self.out_planes = 2 * 256

        # Re-initialise every conv / batch-norm layer registered so far.
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')
                continue
            if isinstance(module, nn.BatchNorm2d):
                nn.init.constant_(module.weight, 1)
                nn.init.constant_(module.bias, 0)

    def _make_layer(self, planes, blocks, stride):
        """Build one stage of ``blocks`` residual blocks with ``planes`` output channels.

        Only the first block receives ``stride``; it also gets a 1x1-conv + BN
        projection on its skip path unless ``stride`` equals ``[1, 1]`` and the
        channel count is unchanged.
        """
        keeps_shape = stride == [1, 1] and self.inplanes == planes
        projection = None
        if not keeps_shape:
            projection = nn.Sequential(
                conv1x1(self.inplanes, planes, stride),
                nn.BatchNorm2d(planes),
            )

        stage = [ResnetBlock(self.inplanes, planes, stride, projection)]
        self.inplanes = planes
        stage.extend(ResnetBlock(self.inplanes, planes) for _ in range(blocks - 1))
        return nn.Sequential(*stage)

    def forward(self, x):
        """Return ``(feature_map, holistic_feature)`` for the input batch ``x``.

        ``feature_map`` is the output of ``layer5``; ``holistic_feature`` is the
        final hidden state tensor (``h_n``) returned by ``self.rnn``.
        """
        feature_map = x
        for stage in (self.layer0, self.layer1, self.layer2,
                      self.layer3, self.layer4, self.layer5):
            feature_map = stage(feature_map)

        _, _, feature_h, _ = feature_map.shape
        # Max-pool over the full height so a single row remains, then reorder
        # to (batch, width, channels) for the batch-first LSTM.
        pooled = F.max_pool2d(feature_map, (feature_h, 1))
        sequence = pooled.permute(0, 3, 1, 2).squeeze(3)

        _, (holistic_feature, _) = self.rnn(sequence)
        return feature_map, holistic_feature
                   

In [93]:
# Instantiate the encoder defined above. The class is named ``Resnet_encoder``;
# the previous name ``Resnet_EFIFSTR`` is not defined in any visible cell and
# raises NameError on a fresh kernel.
extract = Resnet_encoder()

In [100]:
# Dummy image batch (1 x 3 x 48 x 160) for a smoke test. torch.randn is used
# instead of torch.FloatTensor(sizes), which returns uninitialized memory that
# may contain NaN/inf values.
inputt = torch.randn(1, 3, 48, 160)

In [102]:
# Smoke test: run the encoder on the dummy batch and unpack its two outputs.
# Relies on `extract` and `inputt` created in the cells above.
x5, rnn_feat = extract(inputt)

In [None]:
class Decoder(nn.Module):
    """LSTM decoder with 2-D attention over the encoder feature map.

    The holistic feature is fed to the LSTM first to initialise its state;
    afterwards, at every step, the previous label is embedded and fed to the
    LSTM, the new hidden state attends over the feature map, and the hidden
    state concatenated with the attended glimpse is projected to class logits.

    Fixes relative to the earlier draft (which could not be instantiated):
      * ``Attention_unit`` is constructed with its required dimensions.
      * ``num_classes`` is read from ``self`` instead of an undefined global.
      * The embedding width equals the LSTM input size (``enc_dim``), and the
        output layer takes ``dec_dim + enc_dim`` inputs because the glimpse is
        a weighted sum of feature-map channels.
      * LSTM inputs are shaped ``(batch, 1, enc_dim)`` for ``batch_first=True``.
      * The greedy prediction takes the argmax over the class dimension.
      * Label tensors are created on the feature map's device.

    Args:
        num_classes: Number of real classes. Index ``num_classes`` is used as
            <BOS> on the input side and as <EOS> on the output side.
        enc_dim: Channel count of the feature map and width of the holistic
            feature (also the LSTM input size).
        dec_dim: Hidden size of the decoder LSTM.
        att_dim: Width of the attention projection space.
        opt: Option container; ``opt.max_length`` bounds greedy decoding.
    """

    def __init__(self, num_classes, enc_dim, dec_dim, att_dim, opt):
        super(Decoder, self).__init__()
        self.num_classes = num_classes
        self.enc_dim = enc_dim
        self.dec_dim = dec_dim
        self.opt = opt
        self.attention_unit = Attention_unit(enc_dim, dec_dim, att_dim)
        self.input_embedding = nn.Embedding(num_classes + 1, enc_dim)  # including <BOS>
        self.lstm = nn.LSTM(enc_dim, dec_dim, batch_first=True)
        self.fc = nn.Linear(dec_dim + enc_dim, num_classes + 1)  # including <EOS>

    def _holistic_as_input(self, holistic_feature):
        """Convert the holistic feature to a ``(batch, 1, enc_dim)`` LSTM input.

        Accepts either a ``(batch, enc_dim)`` tensor or an LSTM ``h_n`` tensor
        laid out as ``(num_layers * num_directions, batch, hidden)``. For the
        latter, the last entry is used when ``hidden == enc_dim``; when
        ``2 * hidden == enc_dim`` the last two entries (final layer, both
        directions) are concatenated.

        Raises:
            ValueError: if the tensor cannot be mapped to width ``enc_dim``.
        """
        if holistic_feature.dim() == 3:
            hidden = holistic_feature.size(2)
            if hidden == self.enc_dim:
                holistic_feature = holistic_feature[-1]
            elif 2 * hidden == self.enc_dim and holistic_feature.size(0) >= 2:
                holistic_feature = torch.cat(
                    [holistic_feature[-2], holistic_feature[-1]], dim=1)
            else:
                raise ValueError(
                    "cannot map holistic feature of shape {} to width {}".format(
                        tuple(holistic_feature.shape), self.enc_dim))
        if holistic_feature.dim() != 2 or holistic_feature.size(1) != self.enc_dim:
            raise ValueError(
                "holistic feature must have shape (batch, {}), got {}".format(
                    self.enc_dim, tuple(holistic_feature.shape)))
        return holistic_feature.unsqueeze(1)

    def forward(self, feature_map, holistic_feature, Input, is_train):
        """Decode a label sequence.

        Args:
            feature_map: ``(batch, enc_dim, H, W)`` encoder feature map.
            holistic_feature: see ``_holistic_as_input``.
            Input: tuple ``(x, target, length)``. ``target`` (``(batch, T)``
                long tensor) and ``length`` (number of decoding steps; an int
                or a tensor whose maximum is used) are only read when
                ``is_train`` is true. ``x`` is not used.
            is_train: teacher forcing with ``target`` when true; greedy
                decoding for ``opt.max_length`` steps otherwise.

        Returns:
            list of per-step logits, each of shape ``(batch, num_classes + 1)``.
        """
        _, target, length = Input
        batch_size = feature_map.size(0)

        # Prime the LSTM state with the holistic feature.
        _, state = self.lstm(self._holistic_as_input(holistic_feature))

        if is_train:
            num_steps = int(torch.as_tensor(length).max())
        else:
            num_steps = self.opt.max_length

        # The last embedding index is used as <BOS>.
        prev_label = torch.full(
            (batch_size,), self.num_classes,
            dtype=torch.long, device=feature_map.device)

        logits = []
        for step in range(num_steps):
            if is_train and step > 0:
                prev_label = target[:, step - 1]  # teacher forcing
            step_input = self.input_embedding(prev_label).unsqueeze(1)
            _, state = self.lstm(step_input, state)
            hidden = state[0][-1]  # (batch, dec_dim)
            glimpse = self.attention_unit(feature_map, hidden)
            logit = self.fc(torch.cat([hidden, glimpse], dim=1))
            logits.append(logit)
            if not is_train:
                # argmax of the logits equals argmax of their softmax.
                prev_label = logit.argmax(dim=1)

        return logits


class Attention_unit(nn.Module):
    """Additive 2-D attention of a decoder hidden state over a feature map.

    Args:
        fmap_dim: Channel count of the feature map.
        lstm_dim: Width of the decoder hidden state.
        attn_dim: Width of the shared projection space.
    """

    def __init__(self, fmap_dim, lstm_dim, attn_dim):
        super(Attention_unit, self).__init__()
        self.fmap_dim = fmap_dim
        self.lstm_dim = lstm_dim
        self.Nin = nn.Conv2d(lstm_dim, attn_dim, kernel_size=1)
        self.Fmap_conv = nn.Conv2d(fmap_dim, attn_dim, kernel_size=3, padding=1)
        # Maps the combined projection to one attention energy per location.
        self.score = nn.Conv2d(attn_dim, 1, kernel_size=1)

    def forward(self, fmap, hidden_state):
        """Return the attended glimpse of shape ``(batch, fmap_dim)``.

        Args:
            fmap: ``(batch, fmap_dim, H, W)`` feature map.
            hidden_state: decoder hidden state holding ``batch * lstm_dim``
                values, e.g. ``(batch, lstm_dim)`` or ``(1, batch, lstm_dim)``.
        """
        batch_size, channel, height, width = fmap.shape
        query = self.Nin(hidden_state.reshape(batch_size, self.lstm_dim, 1, 1))
        keys = self.Fmap_conv(fmap)
        # Broadcasting the (batch, attn_dim, 1, 1) query over H x W replaces
        # explicit tiling of the hidden-state projection.
        energy = self.score(torch.tanh(keys + query))
        weights = torch.softmax(energy.reshape(batch_size, -1), dim=1)
        weights = weights.reshape(batch_size, 1, height, width)
        return (fmap * weights).sum(dim=(2, 3))
        
        
        

In [116]:
# Scratch check of Tensor.repeat: repeat(1, 2) keeps the two rows and tiles
# the columns twice, as the recorded output below shows.
torch.Tensor([[1,2,3], [4,5,6]]).repeat(1, 2)

tensor([[1., 2., 3., 1., 2., 3.],
        [4., 5., 6., 4., 5., 6.]])

In [109]:
# Dummy 1 x 3 x 65 x 65 batch for the 1x1-convolution shape check below.
# torch.randn replaces torch.FloatTensor(sizes), which returns uninitialized
# memory that may contain NaN/inf values.
inputt = torch.randn(1, 3, 65, 65)

In [110]:
# 1x1 convolution mapping 3 input channels to 256 output channels (scratch test).
nin = nn.Conv2d(3, 256, kernel_size=1)

In [111]:
# Apply the 1x1 convolution to the dummy batch `inputt` from the cell above.
nin_res = nin(inputt)

In [112]:
# Recorded output below: 256 channels, with the 65 x 65 spatial size unchanged.
nin_res.shape

torch.Size([1, 256, 65, 65])

In [87]:
# NOTE(review): `output` is not defined in any visible cell, so this fails on
# a fresh kernel; the recorded shape below comes from an earlier session.
output.shape

torch.Size([16, 40, 512])

In [88]:
# NOTE(review): `sub` is not defined in any visible cell; presumably it is the
# (h_n, c_n) pair returned by an LSTM call — confirm and restore that cell.
h_0, c_0 = sub

In [89]:
# Shape of the first element unpacked from `sub`; recorded output is (1, 16, 512).
h_0.shape

torch.Size([1, 16, 512])

In [90]:
# Shape of the second element unpacked from `sub`; recorded output is (1, 16, 512).
c_0.shape

torch.Size([1, 16, 512])

In [67]:
# Scratch embedding table with 120 + 2 rows of 512-dimensional vectors.
# Presumably the two extra rows are special tokens — TODO confirm.
input_embedding = nn.Embedding(120+2, 512)

In [70]:
# NOTE(review): this call passes no index tensor and raises, which stops
# Restart & Run All. The recorded TypeError below reports a non-Tensor
# argument, so the cell was apparently edited after it last ran.
input_embedding()

TypeError: embedding(): argument 'indices' (position 2) must be Tensor, not int

In [85]:
# Batch of 16 index sequences of length 1, all zeros (long dtype).
# NOTE(review): the name `input` shadows the Python builtin for the rest of
# the notebook, and the recorded output below does not match this (16, 1) shape.
input = torch.zeros(16,1, dtype= torch.long)

torch.Size([16])

In [81]:
# Look up embeddings for the index batch; the lookup appends the embedding
# width as a trailing dimension (recorded output: (16, 1, 512)).
input_embedding(input)

torch.Size([16, 1, 512])