In [3]:

import torch.nn as nn

class key_embedd(nn.Module):
  """Two-stage 3-D convolutional embedding that returns flattened key features.

  Each stage is Conv3d -> BatchNorm3d -> ReLU. The first stage always runs at
  stride (1,1,1) and produces ``out_channels`` maps; the second stage uses the
  caller's ``stride`` and widens the result to ``out_channels + 32`` maps.

  Args:
    in_channels: number of channels in the incoming tensor.
    out_channels: width of the first stage (kept on the module as ``key_dim``).
    kernel_size: kernel size shared by both convolutions.
    stride: stride of the second convolution only.
    padding: padding used by both convolutions.
    bias: whether both convolutions learn a bias term.
  """

  def __init__(self, in_channels, out_channels, kernel_size, stride=(1,2,2), padding=(1,1,1), bias=True):
    super().__init__()

    self.key_dim = out_channels
    widened = self.key_dim + 32  # the second stage always adds 32 channels

    # Stage 1: unit-stride convolution.
    self.spatial_conv = nn.Conv3d(
      in_channels, self.key_dim, kernel_size,
      stride=(1, 1, 1), padding=padding, bias=bias,
    )
    self.bn = nn.BatchNorm3d(self.key_dim)
    self.relu = nn.ReLU()

    # Stage 2: strided convolution that also widens the channel axis.
    self.spatial_conv1 = nn.Conv3d(
      self.key_dim, widened, kernel_size,
      stride=stride, padding=padding, bias=bias,
    )
    self.bn1 = nn.BatchNorm3d(widened)
    self.relu1 = nn.ReLU()

  def forward(self, x):
    """Embed ``x`` and flatten everything after the channel axis.

    Args:
      x: 5-D tensor in the (batch, channels, depth, height, width) layout that
        nn.Conv3d / nn.BatchNorm3d operate on.

    Returns:
      A pair ``(flat, shape)``: ``flat`` has shape (batch, channels, -1) and
      ``shape`` is the 5-D shape of the features before flattening.
    """
    first = self.spatial_conv(x)
    first = self.bn(first)
    first = self.relu(first)

    second = self.spatial_conv1(first)
    second = self.bn1(second)
    second = self.relu1(second)

    shape = second.shape
    batch, channels = shape[0], shape[1]
    return second.view(batch, channels, -1), shape

In [4]:

import torch.nn as nn

class value_embedd(nn.Module):
  """Two-stage 3-D convolutional embedding that returns flattened value features.

  Stage one (Conv3d -> BatchNorm3d -> ReLU) runs at stride (1,1,1) and outputs
  ``out_channels`` maps. Stage two repeats the pattern with the supplied
  ``stride`` and outputs ``out_channels + 32`` maps.

  Args:
    in_channels: number of channels in the incoming tensor.
    out_channels: width of the first stage; stored under the attribute name
      ``key_dim``.
    kernel_size: kernel size used by both convolutions.
    stride: stride of the second convolution only.
    padding: padding used by both convolutions.
    bias: whether both convolutions learn a bias term.
  """

  def __init__(self, in_channels, out_channels, kernel_size, stride=(1,2,2), padding=(1,1,1), bias=True):
    super().__init__()

    self.key_dim = out_channels
    stage2_channels = self.key_dim + 32  # fixed widening of 32 channels

    # Options shared by the two convolutions; only the stride differs.
    shared = {"padding": padding, "bias": bias}

    self.spatial_conv = nn.Conv3d(in_channels, self.key_dim, kernel_size, stride=(1, 1, 1), **shared)
    self.bn = nn.BatchNorm3d(self.key_dim)
    self.relu = nn.ReLU()

    self.spatial_conv1 = nn.Conv3d(self.key_dim, stage2_channels, kernel_size, stride=stride, **shared)
    self.bn1 = nn.BatchNorm3d(stage2_channels)
    self.relu1 = nn.ReLU()

  def forward(self, x):
    """Run both stages and flatten every axis after the channel axis.

    Args:
      x: 5-D tensor in the (batch, channels, depth, height, width) layout that
        nn.Conv3d / nn.BatchNorm3d operate on.

    Returns:
      ``(flat, shape)`` where ``flat`` is (batch, channels, -1) and ``shape``
      is the feature shape before flattening.
    """
    stages = (
      (self.spatial_conv, self.bn, self.relu),
      (self.spatial_conv1, self.bn1, self.relu1),
    )
    for conv, norm, act in stages:
      x = act(norm(conv(x)))

    shape = x.shape
    flat = x.view(shape[0], shape[1], -1)
    return flat, shape

In [5]:
import torch.nn as nn

class query_embedd(nn.Module):
  """Two-stage 3-D convolutional embedding that returns flattened query features.

  Both stages are Conv3d -> BatchNorm3d -> ReLU. The first uses stride (1,1,1)
  and ``out_channels`` output maps; the second uses the supplied ``stride`` and
  ``out_channels + 32`` output maps.

  Args:
    in_channels: number of channels in the incoming tensor.
    out_channels: width of the first stage; stored under the attribute name
      ``key_dim``.
    kernel_size: kernel size used by both convolutions.
    stride: stride of the second convolution only.
    padding: padding used by both convolutions.
    bias: whether both convolutions learn a bias term.
  """

  def __init__(self, in_channels, out_channels, kernel_size, stride=(1,2,2), padding=(1,1,1), bias=True):
    super().__init__()

    self.key_dim = out_channels
    narrow = self.key_dim
    wide = narrow + 32  # the second stage is always 32 channels wider

    self.spatial_conv = nn.Conv3d(in_channels, narrow, kernel_size,
                                  stride=(1, 1, 1), padding=padding, bias=bias)
    self.bn = nn.BatchNorm3d(narrow)
    self.relu = nn.ReLU()

    self.spatial_conv1 = nn.Conv3d(narrow, wide, kernel_size,
                                   stride=stride, padding=padding, bias=bias)
    self.bn1 = nn.BatchNorm3d(wide)
    self.relu1 = nn.ReLU()

  def forward(self, x):
    """Embed ``x`` and collapse every axis after the channel axis.

    Args:
      x: 5-D tensor in the (batch, channels, depth, height, width) layout that
        nn.Conv3d / nn.BatchNorm3d operate on.

    Returns:
      ``(flat, shape)`` where ``flat`` is (batch, channels, -1) and ``shape``
      is the feature shape before flattening.
    """
    hidden = self.relu(self.bn(self.spatial_conv(x)))
    feats = self.relu1(self.bn1(self.spatial_conv1(hidden)))

    shape = feats.shape
    n_batch, n_channels = shape[:2]
    return feats.view(n_batch, n_channels, -1), shape

In [6]:
#x =  torch.ones((16,12,3,8,64,64))


import torch.nn as nn

class key_embedd(nn.Module):
  """Two-stage 3-D convolutional embedding returning flattened features.

  NOTE: this cell re-declares ``key_embedd``, which an earlier cell of the
  notebook already defines; once this cell runs, this definition is the one
  bound to the name.

  Stage one (Conv3d -> BatchNorm3d -> ReLU) uses stride (1,1,1) and outputs
  ``out_channels`` maps; stage two uses the supplied ``stride`` and outputs
  ``out_channels + 32`` maps.

  Args:
    in_channels: number of channels in the incoming tensor.
    out_channels: width of the first stage (kept on the module as ``key_dim``).
    kernel_size: kernel size used by both convolutions.
    stride: stride of the second convolution only.
    padding: padding used by both convolutions.
    bias: whether both convolutions learn a bias term.
  """

  def __init__(self, in_channels, out_channels, kernel_size, stride=(1,2,2), padding=(1,1,1), bias=True):
    super().__init__()

    self.key_dim = out_channels
    out_width = self.key_dim + 32  # fixed widening applied by the second stage

    # First stage: unit stride.
    self.spatial_conv = nn.Conv3d(in_channels, self.key_dim, kernel_size,
                                  stride=(1, 1, 1), padding=padding, bias=bias)
    self.bn = nn.BatchNorm3d(self.key_dim)
    self.relu = nn.ReLU()

    # Second stage: caller-selected stride, wider output.
    self.spatial_conv1 = nn.Conv3d(self.key_dim, out_width, kernel_size,
                                   stride=stride, padding=padding, bias=bias)
    self.bn1 = nn.BatchNorm3d(out_width)
    self.relu1 = nn.ReLU()

  def forward(self, x):
    """Apply both stages, then flatten every axis after the channel axis.

    Args:
      x: 5-D tensor in the (batch, channels, depth, height, width) layout that
        nn.Conv3d / nn.BatchNorm3d operate on.

    Returns:
      ``(flat, shape)`` where ``flat`` is (batch, channels, -1) and ``shape``
      is the 5-D feature shape before flattening, so callers can restore it.
    """
    y = self.spatial_conv(x)
    y = self.relu(self.bn(y))

    y = self.spatial_conv1(y)
    y = self.relu1(self.bn1(y))

    shape = y.shape
    return y.view(shape[0], shape[1], -1), shape

class Attention_head(nn.Module):
  """Single dot-product attention block over flattened 3-D conv features.

  Three ``key_embedd`` modules embed the same input into keys, queries and
  values. Keys and queries use ``out_channels``; values use
  ``out_channels // 2``. Each embedding returns its features flattened to
  (batch, channels, positions).

  Args:
    in_channels: number of channels in the incoming tensor.
    out_channels: width passed to the key/query embeddings; the value
      embedding receives ``out_channels // 2``.
    kernel_size: kernel size passed to every embedding.
    stride: passed through to every embedding.
    padding: passed through to every embedding.
    bias: passed through to every embedding.
  """

  def __init__(self,in_channels, out_channels, kernel_size, stride=(1,2,2), padding=(1,1,1), bias=True):
    super(Attention_head,self).__init__()

    # Pass the convolution settings on to the embeddings. They used to be
    # accepted here and then dropped, so non-default values had no effect.
    conv_opts = dict(stride=stride, padding=padding, bias=bias)

    self.key_embed = key_embedd(in_channels, out_channels, kernel_size, **conv_opts)
    self.query_embed = key_embedd(in_channels, out_channels, kernel_size, **conv_opts)
    self.value_embed = key_embedd(in_channels, out_channels//2, kernel_size, **conv_opts)

  def forward(self,x):
    """Attend over the flattened positions of ``x``.

    Args:
      x: input tensor accepted by the embedding modules.

    Returns:
      The attended values, reshaped to the 5-D shape reported by the value
      embedding.
    """
    k, _ = self.key_embed(x)
    v, value_shape = self.value_embed(x)
    q, _ = self.query_embed(x)

    # Only tensor methods and the ``@`` operator are used below, so this block
    # does not need the bare ``torch`` name, which this cell never imports.

    # score[b, i, j] = <k[b, :, i], q[b, :, j]>  -> (batch, positions, positions)
    score = k.transpose(1, 2) @ q

    # Normalise each row over its last axis.
    weights = score.softmax(dim=2)

    # (batch, positions, positions) @ (batch, positions, channels)
    attended = weights @ v.transpose(1, 2)

    # Back to channels-first, then restore the value embedding's 5-D shape.
    return attended.transpose(1, 2).view(value_shape)


class MHAttention(nn.Module):
  """Four ``Attention_head`` blocks applied one after another.

  Despite the name, the heads are chained (the output of one feeds the next)
  rather than run side by side. Head widths are fixed at 32, 64, 128 and 256.
  A head built with width ``w`` emits ``w // 2 + 32`` channels, so that is the
  input channel count given to the following head.

  Args:
    in_channels: channel count of the tensor fed to the first head.
    out_channels: accepted for interface compatibility but not used; the head
      widths are fixed (see above).
    kernel_size: kernel size passed to every head.
    stride: passed through to every head.
    padding: passed through to every head.
    bias: passed through to every head.
  """

  # ``out_channels`` argument given to attention1..attention4, in order.
  _HEAD_WIDTHS = (32, 64, 128, 256)

  def __init__(self,in_channels, out_channels, kernel_size, stride=(1,2,2), padding=(1,1,1), bias=True):

    super(MHAttention,self).__init__()

    # The constructor arguments used to be ignored in favour of hard-coded
    # literals (3 input channels, kernel 3, default stride/padding/bias).
    # They are now forwarded; the old literals are what MHAttention(3, _, 3)
    # with default keyword arguments still produces.
    head_in = in_channels
    for index, width in enumerate(self._HEAD_WIDTHS, start=1):
      head = Attention_head(head_in, width, kernel_size,
                            stride=stride, padding=padding, bias=bias)
      # Same attribute names as before (attention1..attention4), so existing
      # state_dict keys continue to match.
      setattr(self, f"attention{index}", head)
      head_in = width // 2 + 32  # channel count this head hands to the next

  def forward(self,x):
    """Pass ``x`` through attention1 .. attention4 in order and return the result."""
    for index in range(1, len(self._HEAD_WIDTHS) + 1):
      x = getattr(self, f"attention{index}")(x)
    return x




  


In [7]:
import torch
# Smoke test: build an all-ones input, merge its first two axes into one
# batch axis, and run it through the attention stack to inspect the output shape.
B, N, C, T, H, W = 8, 8, 3, 4, 32, 32
x = torch.ones(B, N, C, T, H, W).view(B * N, C, T, H, W)
f = MHAttention(3, 16, 3)(x)

f.shape

torch.Size([64, 160, 4, 2, 2])