In [None]:
import torch
import torch.nn as nn
import torchvision
from torchvision import models

In [None]:
class ImageEncoder(nn.Module):
    """CNN image encoder that emits a position-aware feature map.

    A torchvision ResNet is truncated to its convolutional trunk, its output
    is projected to ``hidden_dim`` channels with a 1x1 convolution, and a
    fixed (non-learned) 2-D sinusoidal positional encoding is added.

    Args:
        backbone_name: Name of a ResNet-family constructor in
            ``torchvision.models`` (e.g. ``'resnet50'``, ``'resnet18'``).
        pretrained: Forwarded to the torchvision constructor to request
            pretrained weights.
        hidden_dim: Number of output channels after the 1x1 projection.

    Raises:
        ValueError: If ``backbone_name`` is not a callable attribute of
            ``torchvision.models`` or does not build a ``models.ResNet``.
    """

    def __init__(self, backbone_name='resnet50', pretrained=True, hidden_dim=256):
        super().__init__()

        # Resolve the backbone by name instead of silently ignoring the
        # argument and always building ResNet-50.
        factory = getattr(models, backbone_name, None)
        if not callable(factory):
            raise ValueError(
                f"Unknown torchvision backbone: {backbone_name!r}"
            )
        # NOTE: newer torchvision releases deprecate `pretrained=` in favour
        # of `weights=`; the legacy keyword is kept so the call matches the
        # torchvision version this file was written against.
        backbone = factory(pretrained=pretrained)
        if not isinstance(backbone, models.ResNet):
            # The `[:-2]` truncation and the `.fc.in_features` lookup below
            # both rely on the ResNet module layout.
            raise ValueError(
                f"Backbone {backbone_name!r} is not a ResNet-family model"
            )

        # Drop the last two children (for torchvision ResNets: the global
        # average pool and the classifier) to keep a spatial feature map.
        self.backbone = nn.Sequential(*list(backbone.children())[:-2])

        self.hidden_dim = hidden_dim
        self.conv1x1 = nn.Conv2d(backbone.fc.in_features, hidden_dim, kernel_size=1)

        # Registered as a buffer so it follows `.to()` / `.cuda()` / `.half()`;
        # non-persistent so state_dict keys stay the same as before.
        self.register_buffer(
            "positional_encoding",
            self._get_positional_encoding(),
            persistent=False,
        )

    def forward(self, x):
        """Encode a batch of images.

        Args:
            x: Input batch fed to the backbone.

        Returns:
            Tensor of shape ``(batch, hidden_dim, H, W)`` where ``(H, W)`` is
            the spatial size of the backbone output.
        """
        features = self.conv1x1(self.backbone(x))

        height, width = features.shape[-2:]
        pe = self.positional_encoding
        if tuple(pe.shape[-2:]) != (height, width):
            # The cached grid is 32x32; rebuild for any other feature-map
            # size instead of failing on a broadcast mismatch.
            pe = self._get_positional_encoding(
                height, width, device=features.device
            )

        return features + pe.to(device=features.device, dtype=features.dtype)

    @staticmethod
    def _axis_encoding(positions, channels):
        """Return a ``(channels, len(positions))`` sin/cos table for one axis.

        Even rows hold ``sin`` and odd rows hold ``cos`` of the position
        scaled by geometrically spaced frequencies.
        """
        table = torch.zeros(channels, positions.numel(), device=positions.device)
        if channels == 0:
            return table
        div_term = torch.exp(
            torch.arange(0., channels, 2, device=positions.device)
            * -(torch.log(torch.tensor(10000.0)) / channels)
        )
        angles = div_term.unsqueeze(1) * positions.unsqueeze(0)
        table[0::2] = torch.sin(angles)
        # With an odd channel count there is one fewer cos row than sin rows.
        table[1::2] = torch.cos(angles[: channels // 2])
        return table

    def _get_positional_encoding(self, height=32, width=32, device=None):
        """Build a 2-D sinusoidal positional encoding.

        The first ``hidden_dim // 2`` channels encode the row index and the
        remaining channels encode the column index, so every grid cell gets
        a distinct code along both axes.

        Args:
            height: Number of rows in the grid.
            width: Number of columns in the grid.
            device: Optional device on which to create the tensor.

        Returns:
            Float tensor of shape ``(1, hidden_dim, height, width)``.
        """
        y_channels = self.hidden_dim // 2
        x_channels = self.hidden_dim - y_channels

        rows = torch.arange(height, dtype=torch.float32, device=device)
        cols = torch.arange(width, dtype=torch.float32, device=device)

        y_table = self._axis_encoding(rows, y_channels)  # (y_channels, H)
        x_table = self._axis_encoding(cols, x_channels)  # (x_channels, W)

        pe = torch.cat(
            [
                y_table.unsqueeze(2).expand(-1, -1, width),
                x_table.unsqueeze(1).expand(-1, height, -1),
            ],
            dim=0,
        )

        return pe.unsqueeze(0)