In [1]:
import datetime
import os
import random
import torch
import torch.nn as nn
import matplotlib.pyplot as plt

import torchvision.models as models_vision

import core.config as conf
from torchvision.transforms import transforms

from fn_networks import BackboneAud, SubnetAud, SubnetVid, show_feature_map

network path /Users/umar_m/Projects/MSc-project/AV-spatial-coherence/results/checkpoints/MC_full_AVOL_vid


In [72]:
def init_weights(m):
    """Xavier-uniform initialise the weight of a Conv2d or Linear module, in place.

    Written to be passed to ``nn.Module.apply``. Modules of any other type are
    left untouched, and only ``m.weight`` is re-initialised (the bias is not
    modified). Returns ``None``.
    """
    if not isinstance(m, (nn.Conv2d, nn.Linear)):
        return
    nn.init.xavier_uniform_(m.weight)


class AVE_Net(nn.Module):
    """Two-stream audio-visual network that maps an embedding distance to two logits.

    The visual stream is a truncated VGG-11 feature extractor and the audio
    stream is ``BackboneAud``. Each stream is globally max-pooled to a 512-d
    vector, projected to a 128-d embedding and L2-normalised. The Euclidean
    distance between the two embeddings (a single scalar per sample) is fed to
    ``FC3``, which produces two outputs per sample.
    """

    def __init__(
        self,
        heatmap=conf.dnn_arch['heatmap'],
        inference=conf.training_param['inference'],
        cam_size=11,
    ):
        """Build both streams and the projection heads.

        Args:
            heatmap: stored on the instance as ``self.heatmap``; not read
                anywhere else in this class.
            inference: stored on the instance as ``self.inference``; not read
                anywhere else in this class.
            cam_size: length of the ``cam`` vector that ``forward`` concatenates
                to the 128-d audio embedding before ``FC2_aud``.
        """
        # Initialise nn.Module first so every attribute assignment below goes
        # through a fully set-up module.
        super().__init__()

        self.inference = inference
        self.heatmap = heatmap

        # First 20 layers of VGG-11 `features` (everything before the last max-pool).
        # NOTE(review): newer torchvision releases prefer `weights=` over
        # `pretrained=` -- confirm the installed version before changing this.
        self.VideoNet = models_vision.vgg11(pretrained=True).features[0:20]

        # Freeze the pretrained convolutions. `requires_grad_` is a method and
        # has to be *called*; assigning `False` to it only shadows the method
        # and leaves every parameter trainable.
        for layer in self.VideoNet:
            if isinstance(layer, nn.Conv2d):
                layer.requires_grad_(False)

        self.AudioNet = BackboneAud()
        # Not called in forward() below; kept so the module's attributes and
        # state_dict keys stay as they were.
        self.AudioMerge = SubnetAud(BN=False)

        self.FC1_aud = nn.Linear(512, 128).apply(init_weights)
        self.FC2_aud = nn.Linear(128 + cam_size, 128).apply(init_weights)

        self.FC1_img = nn.Linear(512, 128).apply(init_weights)
        self.FC2_img = nn.Linear(128, 128).apply(init_weights)

        # Maps the scalar embedding distance to two outputs.
        self.FC3 = nn.Linear(1, 2)

    @staticmethod
    def _global_max_pool(feat):
        """Max-pool over the whole spatial extent and drop the two singleton spatial dims."""
        h, w = feat.shape[-2], feat.shape[-1]
        return torch.max_pool2d(feat, (h, w), (h, w)).squeeze(-1).squeeze(-1)

    def forward(self, x_in, y_in, cam=None, BS_pos=None):
        """Compute two logits per sample from the audio-visual embedding distance.

        Args:
            x_in: input to the visual stream (``VideoNet``).
            y_in: input to the audio stream (``AudioNet``).
            cam: tensor concatenated to the 128-d audio embedding along the
                last dimension. Required despite the ``None`` default, because
                ``FC2_aud`` expects ``128 + cam_size`` input features.
            BS_pos: accepted for interface compatibility; unused here.

        Returns:
            Output of ``FC3``: a tensor whose last dimension is 2.

        Raises:
            TypeError: if ``cam`` is ``None``.
        """
        if cam is None:
            # torch.concat would raise a TypeError on None anyway; fail early
            # with a message that names the missing argument.
            raise TypeError(
                "AVE_Net.forward() requires `cam`: it is concatenated to the "
                "audio embedding before FC2_aud"
            )

        x1 = self.VideoNet(x_in)                # 512 x 14 x 14 for a 224 x 224 input
        y1 = self.AudioNet(y_in)                # 512 x H x W

        x1 = self._global_max_pool(x1)          # 512
        y1 = self._global_max_pool(y1)          # 512

        x1 = self.FC1_img(x1)                   # 128
        x1 = self.FC2_img(x1)                   # 128

        y1 = self.FC1_aud(y1)                   # 128
        y1 = torch.concat((y1, cam), dim=-1)    # 128 + cam_size
        y1 = self.FC2_aud(y1)                   # 128

        # L2-normalise both embeddings. normalize() clamps the norm from below
        # with a small eps, so an all-zero embedding yields zeros, not NaN.
        x1 = nn.functional.normalize(x1, p=2, dim=-1)
        y1 = nn.functional.normalize(y1, p=2, dim=-1)

        # Euclidean distance between the embeddings, kept as a trailing
        # dimension of size 1 so it can be fed to FC3.
        dist = torch.linalg.vector_norm(x1 - y1, dim=-1, keepdim=True)

        return self.FC3(dist)

In [73]:
# Build the model with its default arguments: heatmap / inference come from core.config, cam_size is 11.
net = AVE_Net()



Custom for Audio


In [74]:
# Seed the RNG so the dummy inputs below are identical on every run.
torch.manual_seed(0)

# Random stand-ins for the three forward() inputs; the leading dimension (2) is shared.
img = torch.randn((2, 3, 224, 224))   # passed as x_in (visual stream)
aud = torch.randn((2, 16, 960, 64))   # passed as y_in (audio stream)
cam = torch.randn((2, 11))            # last dim matches the default cam_size=11

In [75]:
out = net(img, aud, cam)

# Sanity check: FC3 yields two values per input sample. Show the shape as the cell output.
assert out.shape == (img.shape[0], 2)
out.shape

torch.Size([2, 2])


In [76]:
# Raw FC3 outputs as returned by forward().
out

tensor([[0.3272, 0.0439],
        [0.3286, 0.0488]], grad_fn=<AddmmBackward0>)

In [16]:
# Print the shape of each element obtained by iterating over `out`.
# NOTE(review): the recorded output of this cell (shapes [1] and [1, 14, 14]) does not match
# the current forward(), which returns a single tensor from FC3 -- presumably stale output
# from an earlier version that returned a tuple; re-run or remove this cell.
for o in out:
    print(o.shape)

torch.Size([1])
torch.Size([1, 14, 14])


In [4]:
def strip_cam_suffix(name):
    """Return `name` with everything from the first '-cam' onwards removed.

    If '-cam' does not occur, `name` is returned unchanged.
    """
    # partition() cuts at the separator itself. The previous slice,
    # name[:-name.find('cam')], used the forward index of 'cam' as a count from
    # the end, which is only right when the prefix and the suffix happen to
    # have matching lengths.
    return name.partition('-cam')[0]


word = 'abcde-cam08'

word = strip_cam_suffix(word)

word

'abcde'