# In [4]:  (Jupyter cell marker left over from a notebook export)
import torch
from torch import nn
from torch.nn import functional as F

from .bam import BAM
from .cbam import CBAM
from .conv import Conv2d



class SyncNet_color(nn.Module):
    """Two-tower audio/visual sync network that outputs L2-normalized embeddings.

    ``face_encoder`` and ``audio_encoder`` are ``nn.Sequential`` stacks of the
    project-local ``Conv2d`` block. Both towers end in 512 channels. In
    :meth:`forward` each output is flattened per sample and L2-normalized.

    NOTE: the positional index of every layer in both ``nn.Sequential``s is
    load-bearing. It determines the ``state_dict`` keys
    (``face_encoder.<idx>.*`` / ``audio_encoder.<idx>.*``), and the attention
    subclasses in this file insert modules after hard-coded layer indices.
    Do not reorder, add or remove layers here.
    """
    def __init__(self):
        super(SyncNet_color, self).__init__()

        # Face tower. 15 input channels -- presumably 5 stacked RGB frames;
        # TODO confirm against the data loader.
        self.face_encoder = nn.Sequential(
            Conv2d(15, 32, kernel_size=(7, 7), stride=1, padding=3),  # idx 0

            # idx 1-3: 64 channels; stride (1, 2) presumably halves only the
            # width axis (assumes Conv2d forwards stride to nn.Conv2d).
            Conv2d(32, 64, kernel_size=5, stride=(1, 2), padding=1),
            Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),
            Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),

            # idx 4-7: 128 channels, spatial downsample by 2
            Conv2d(64, 128, kernel_size=3, stride=2, padding=1),
            Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),
            Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),
            Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),

            # idx 8-10: 256 channels, spatial downsample by 2
            Conv2d(128, 256, kernel_size=3, stride=2, padding=1),
            Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True),
            Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True),

            # idx 11-13: 512 channels, spatial downsample by 2
            Conv2d(256, 512, kernel_size=3, stride=2, padding=1),
            Conv2d(512, 512, kernel_size=3, stride=1, padding=1, residual=True),
            Conv2d(512, 512, kernel_size=3, stride=1, padding=1, residual=True),

            # idx 14-16: head. The unpadded 3x3 conv shrinks the map; for a
            # 48x96 input this arithmetic gives 1x1 (TODO confirm input size).
            Conv2d(512, 512, kernel_size=3, stride=2, padding=1),
            Conv2d(512, 512, kernel_size=3, stride=1, padding=0),
            Conv2d(512, 512, kernel_size=1, stride=1, padding=0),)

        # Audio tower. The first layer takes a single input channel
        # (presumably a mel-spectrogram window; TODO confirm).
        self.audio_encoder = nn.Sequential(
            # idx 0-2: 32 channels
            Conv2d(1, 32, kernel_size=3, stride=1, padding=1),
            Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True),
            Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True),

            # idx 3-5: 64 channels, stride (3, 1)
            Conv2d(32, 64, kernel_size=3, stride=(3, 1), padding=1),
            Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),
            Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),

            # idx 6-8: 128 channels, stride 3
            Conv2d(64, 128, kernel_size=3, stride=3, padding=1),
            Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),
            Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),

            # idx 9-11: 256 channels, stride (3, 2)
            Conv2d(128, 256, kernel_size=3, stride=(3, 2), padding=1),
            Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True),
            Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True),

            # idx 12-13: head to 512 channels
            Conv2d(256, 512, kernel_size=3, stride=1, padding=0),
            Conv2d(512, 512, kernel_size=1, stride=1, padding=0),)

    def forward(self, audio_sequences, face_sequences): # audio_sequences := (B, dim, T)
        """Embed an audio window and a face window.

        Args:
            audio_sequences: batch fed to ``audio_encoder``. Its first layer
                has 1 input channel, so the ``(B, dim, T)`` note above
                presumably omits a channel axis added upstream -- TODO confirm.
            face_sequences: batch fed to ``face_encoder`` (15 input channels).

        Returns:
            ``(audio_embedding, face_embedding)``, each reshaped to
            ``(batch, -1)`` and L2-normalized along dim 1.
        """
        face_embedding = self.face_encoder(face_sequences)
        audio_embedding = self.audio_encoder(audio_sequences)

        # Flatten everything except the batch dimension.
        audio_embedding = audio_embedding.view(audio_embedding.size(0), -1)
        face_embedding = face_embedding.view(face_embedding.size(0), -1)

        # Unit-length embeddings so their dot product is a cosine similarity.
        audio_embedding = F.normalize(audio_embedding, p=2, dim=1)
        face_embedding = F.normalize(face_embedding, p=2, dim=1)


        return audio_embedding, face_embedding
    
    
class _Syncnet_color_Attention(SyncNet_color):
    """Base class for SyncNet variants that splice attention blocks into the encoders.

    The constructor builds the plain ``SyncNet_color`` encoders and records
    the requested attention types. Nothing else changes until
    :meth:`build_model` is called. That method runs the subclass hook
    :meth:`_build_model`, which adapts the encoders to the subclass's input
    resolution and inserts attention via :meth:`add_attention`.
    """

    # Attention-type name -> module class inserted by ``add_attention``.
    # A value of ``None`` means "no attention". 'cbam_inter' currently maps
    # to the same CBAM class as 'cbam'.
    attention_dict = {
        'None': None,
        'none': None,  # the constructors' default attention type
        'bam': BAM,
        'cbam': CBAM,
        'cbam_inter': CBAM
    }

    def __init__(self, face_attention_type="none", audio_attention_type="none"):
        """Create the base encoders and store the attention configuration.

        Args:
            face_attention_type: key of ``attention_dict`` for the face tower.
            audio_attention_type: key of ``attention_dict`` for the audio tower.
        """
        self.initialize()
        self.face_attention_type = face_attention_type
        self.audio_attention_type = audio_attention_type

    def initialize(self):
        """Run ``SyncNet_color.__init__`` (fresh encoders) and mark the model unbuilt."""
        super(_Syncnet_color_Attention, self).__init__()
        self.built = False

    def from_pretrained(self, syncnet96):
        """Replace this model's encoders with those of an existing ``SyncNet_color``.

        The modules are shared, not copied. Parameters stay tied to
        ``syncnet96``, and in-place edits made later (e.g. ``add_module`` in a
        subclass's ``_build_model``) also change ``syncnet96``'s encoders.
        Call this before :meth:`build_model`; calling it afterwards would
        discard the layers and attention inserted by ``build_model``.

        Args:
            syncnet96: a model exposing ``face_encoder`` and ``audio_encoder``.
        """
        # Was ``self.face_encdoer`` (typo): it registered a stray submodule
        # and left the face encoder randomly initialised.
        self.face_encoder = syncnet96.face_encoder
        self.audio_encoder = syncnet96.audio_encoder

    def build_model(self):
        """Apply the subclass-specific changes once; later calls only print a notice."""
        if not self.built:
            self._build_model()
            self.built = True
        else:
            print('model has already been built')

    def _build_model(self):
        """Subclass hook: modify ``face_encoder`` / ``audio_encoder``. No-op here."""
        pass

    def add_attention(self, encoder, attention_type, attention_info):
        """Return ``encoder`` with attention modules inserted after given layers.

        Args:
            encoder: the ``nn.Sequential`` to extend. It is not modified.
            attention_type: key of ``attention_dict``. Types mapped to ``None``
                ('none' / 'None') leave the encoder unchanged.
            attention_info: ``{layer_index: in_channels}``. For each entry an
                attention module built with ``in_channels`` is inserted right
                after the layer at ``layer_index`` of the ORIGINAL encoder.
                Keys may be given in any order.

        Returns:
            A new ``nn.Sequential``, or ``encoder`` itself when nothing is inserted.

        Raises:
            KeyError: if ``attention_type`` is not a key of ``attention_dict``.
        """
        if not attention_info:
            return encoder

        try:
            attention_cls = self.attention_dict[attention_type]
        except KeyError:
            raise KeyError('unknown attention type {!r}; expected one of {}'.format(
                attention_type, sorted(self.attention_dict))) from None
        if attention_cls is None:
            return encoder

        layers = list(encoder)
        # Process indices in ascending order. Each earlier insertion shifts the
        # later original layers right by one, which ``offset`` compensates for.
        # Ascending order also keeps module construction order (and thus RNG
        # consumption) the same as before.
        for offset, (layer_idx, in_channels) in enumerate(sorted(attention_info.items()), start=1):
            layers.insert(layer_idx + offset, attention_cls(in_channels))

        return nn.Sequential(*layers)
        
#     def forward(self, audio_sequences, face_sequences): # audio_sequences := (B, dim, T)
#         # Encoding
#         face_embedding = face_sequences
#         for i in range(len(self.face_encoder)):
#             face_embedding = self.face_encoder[i](face_embedding)
#             if self.face_attention_type is not "none" and i in self.face_attentions.keys():
#                 face_embedding = self.face_attentions[i](face_embedding)
        
#         audio_embedding = audio_sequences
#         for i in range(len(self.audio_encoder)):
#             audio_embedding = self.audio_encoder[i](audio_embedding)
#             if self.audio_attention_type is not "none" and i in self.audio_attentions.keys():
#                 audio_embedding = self.audio_attentions[i](audio_embedding)

#         # Reshape        
#         audio_embedding = audio_embedding.view(audio_embedding.size(0), -1)
#         face_embedding = face_embedding.view(face_embedding.size(0), -1)
        
#         # Normalize
#         audio_embedding = F.normalize(audio_embedding, p=2, dim=1)
#         face_embedding = F.normalize(face_embedding, p=2, dim=1)

#         return audio_embedding, face_embedding
        
        
class Syncnet_color_96(_Syncnet_color_Attention):
    """SyncNet variant used by ``load_syncnet`` for ``img_size == 96``.

    Keeps the base encoders unchanged and only inserts attention.
    """
    def __init__(self, face_attention_type="none", audio_attention_type="none", freeze=False):
        super(Syncnet_color_96, self).__init__(face_attention_type, audio_attention_type)
        # Stored only; not read anywhere in the visible code.
        self.freeze = freeze

    def _build_model(self):
        """Insert attention after the last residual layer of each encoder stage."""
        # Use ``!=``: ``is not "none"`` tests identity and fails for equal
        # strings created at runtime (e.g. parsed from a config or CLI).
        if self.face_attention_type != "none":
            self.face_encoder = self.add_attention(self.face_encoder, self.face_attention_type,
                                                   {3: 64, 7: 128, 10: 256, 13: 512}) # layer index: in_channels

        if self.audio_attention_type != "none":
            self.audio_encoder = self.add_attention(self.audio_encoder, self.audio_attention_type,
                                                    {2: 32, 5: 64, 8: 128, 11: 256}) # layer index: in_channels
        
    
class Syncnet_color_120(_Syncnet_color_Attention):
    """SyncNet variant used by ``load_syncnet`` for ``img_size == 120``.

    Appends one extra conv to the face tower so its spatial output shrinks
    to 1x1, then inserts attention.
    """
    def __init__(self, face_attention_type="none", audio_attention_type="none", freeze=False):
        super(Syncnet_color_120, self).__init__(face_attention_type, audio_attention_type)
        # Stored only; not read anywhere in the visible code.
        self.freeze = freeze

    def _build_model(self):
        """Append the extra face conv, then insert the configured attention."""
        # Extra 2x2 conv: 2x2 -> 1x1 feature map (shape arithmetic assumes a
        # 60x120 face input -- TODO confirm). ``add_module`` mutates the
        # existing Sequential in place, including one shared via from_pretrained.
        self.face_encoder.add_module(str(len(self.face_encoder)), Conv2d(512, 512, kernel_size=2, stride=1, padding=0))

        # Use ``!=``: ``is not "none"`` tests identity and fails for equal
        # strings created at runtime (e.g. parsed from a config or CLI).
        if self.face_attention_type != "none":
            self.face_encoder = self.add_attention(self.face_encoder, self.face_attention_type,
                                                   {3: 64, 7: 128, 10: 256, 13: 512}) # layer index: in_channels

        if self.audio_attention_type != "none":
            # NOTE(review): these indices sit after each stage's channel-changing
            # conv, unlike Syncnet_color_96 (after the last residual layer).
            # Keep as-is for checkpoint compatibility; confirm intent.
            self.audio_encoder = self.add_attention(self.audio_encoder, self.audio_attention_type,
                                                    {3: 64, 6: 128, 9: 256, 12: 512}) # layer index: in_channels
        
        
class Syncnet_color_144(_Syncnet_color_Attention):
    """SyncNet variant used by ``load_syncnet`` for ``img_size == 144``.

    Appends two extra convs to the face tower so its spatial output shrinks
    to 1x1, then inserts attention.
    """
    def __init__(self, face_attention_type="none", audio_attention_type="none", freeze=False):
        super(Syncnet_color_144, self).__init__(face_attention_type, audio_attention_type)
        # Stored only; not read anywhere in the visible code.
        self.freeze = freeze

    def _build_model(self):
        """Append the two extra face convs, then insert the configured attention."""
        # Add two conv layers (shape arithmetic assumes a 72x144 face input --
        # TODO confirm). ``add_module`` mutates the existing Sequential in place,
        # including one shared via from_pretrained.
        self.face_encoder.add_module(str(len(self.face_encoder)), Conv2d(512, 512, kernel_size=3, stride=2, padding=1)) # 3,3 -> 2,2
        self.face_encoder.add_module(str(len(self.face_encoder)), Conv2d(512, 512, kernel_size=2, stride=1, padding=0)) # 2,2 -> 1,1

        # Use ``!=``: ``is not "none"`` tests identity and fails for equal
        # strings created at runtime (e.g. parsed from a config or CLI).
        if self.face_attention_type != "none":
            self.face_encoder = self.add_attention(self.face_encoder, self.face_attention_type,
                                                   {3: 64, 7: 128, 10: 256, 13: 512}) # layer index: in_channels
        if self.audio_attention_type != "none":
            self.audio_encoder = self.add_attention(self.audio_encoder, self.audio_attention_type,
                                                    {3: 64, 6: 128, 9: 256, 12: 512}) # layer index: in_channels
            
        
class Syncnet_color_192(_Syncnet_color_Attention):
    """SyncNet variant used by ``load_syncnet`` for ``img_size == 192``.

    Inserts an extra 512-channel downsampling stage (3 layers) into the face
    tower before its head, then inserts attention.
    """
    def __init__(self, face_attention_type="none", audio_attention_type="none", freeze=False):
        super(Syncnet_color_192, self).__init__(face_attention_type, audio_attention_type)
        # Stored only; not read anywhere in the visible code.
        self.freeze = freeze

    def _build_model(self):
        """Add the extra face stage, then insert the configured attention."""
        # New face layout: base idx 0-13, new stage at idx 14-16, base head
        # (old idx 14-16) at idx 17-19.
        self.face_encoder = nn.Sequential(
            *self.face_encoder[:14], # 96x192 input -> 12x12
            Conv2d(512, 512, kernel_size=3, stride=2, padding=1), # 6x6
            Conv2d(512, 512, kernel_size=3, stride=1, padding=1, residual=True),
            Conv2d(512, 512, kernel_size=3, stride=1, padding=1, residual=True),
            *self.face_encoder[14:] # -> 1x1
        )

        # Use ``!=``: ``is not "none"`` tests identity and fails for equal
        # strings created at runtime (e.g. parsed from a config or CLI).
        if self.face_attention_type != "none":
            self.face_encoder = self.add_attention(self.face_encoder, self.face_attention_type,
                                                   {3: 64, 7: 128, 10: 256, 13: 512, 16: 512}) # layer index: in_channels
        if self.audio_attention_type != "none":
            self.audio_encoder = self.add_attention(self.audio_encoder, self.audio_attention_type,
                                                    {3: 64, 6: 128, 9: 256, 12: 512}) # layer index: in_channels
            
            
class Syncnet_color_288(_Syncnet_color_Attention):
    """SyncNet variant used by ``load_syncnet`` for ``img_size == 288``.

    Inserts two extra 512-channel downsampling stages (6 layers) into the
    face tower before its head, then inserts attention.
    """
    def __init__(self, face_attention_type="none", audio_attention_type="none", freeze=False):
        super(Syncnet_color_288, self).__init__(face_attention_type, audio_attention_type)
        # Stored only; not read anywhere in the visible code.
        self.freeze = freeze

    def _build_model(self):
        """Add the two extra face stages, then insert the configured attention."""
        # New face layout: base idx 0-13, new stages at idx 14-16 and 17-19,
        # base head (old idx 14-16) at idx 20-22.
        self.face_encoder = nn.Sequential(
            *self.face_encoder[:14], # 144x288 input -> 18x18
            Conv2d(512, 512, kernel_size=3, stride=2, padding=1), # 9x9
            Conv2d(512, 512, kernel_size=3, stride=1, padding=1, residual=True),
            Conv2d(512, 512, kernel_size=3, stride=1, padding=1, residual=True),

            Conv2d(512, 512, kernel_size=3, stride=2, padding=2), # 6x6 (padding=2 keeps it even)
            Conv2d(512, 512, kernel_size=3, stride=1, padding=1, residual=True),
            Conv2d(512, 512, kernel_size=3, stride=1, padding=1, residual=True),
            *self.face_encoder[14:] # -> 1x1
        )

        # Use ``!=``: ``is not "none"`` tests identity and fails for equal
        # strings created at runtime (e.g. parsed from a config or CLI).
        if self.face_attention_type != "none":
            self.face_encoder = self.add_attention(self.face_encoder, self.face_attention_type,
                                                   {3: 64, 7: 128, 10: 256, 13: 512, 16: 512, 19: 512}) # layer index: in_channels
        if self.audio_attention_type != "none":
            self.audio_encoder = self.add_attention(self.audio_encoder, self.audio_attention_type,
                                                    {3: 64, 6: 128, 9: 256, 12: 512}) # layer index: in_channels
            
            
def load_syncnet(img_size, face_attention_type, audio_attention_type, **kwargs):
    """Instantiate the SyncNet variant registered for ``img_size``.

    Args:
        img_size: one of 96, 120, 144, 192, 288.
        face_attention_type: forwarded to the variant's constructor.
        audio_attention_type: forwarded to the variant's constructor.
        **kwargs: extra constructor arguments (e.g. ``freeze``).

    Returns:
        A new, unbuilt model instance. Call ``build_model()`` on it to apply
        the resolution-specific layers and attention.

    Raises:
        ValueError: if ``img_size`` has no registered variant.
    """
    syncnet_dict = {
        96: Syncnet_color_96,
        120: Syncnet_color_120,
        144: Syncnet_color_144,
        192: Syncnet_color_192,
        288: Syncnet_color_288
    }

    syncnet = syncnet_dict.get(img_size)
    if syncnet is None:
        # Previously ``raise '<str>'``, which in Python 3 raises a TypeError
        # ("exceptions must derive from BaseException") instead of this message.
        raise ValueError('img_size {} is not valid. Expected one of {}.'.format(
            img_size, sorted(syncnet_dict)))

    return syncnet(face_attention_type=face_attention_type, audio_attention_type=audio_attention_type, **kwargs)

'여'