In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F


In [2]:
class CrossModalityEncoderBlock(nn.Module):
    def __init__(self, feature_dim, num_heads):
        super(CrossModalityEncoderBlock, self).__init__()
        self.cross_attention = nn.MultiheadAttention(embed_dim=feature_dim, num_heads=num_heads)
        self.ffn = nn.Sequential(
            nn.Linear(feature_dim, feature_dim),
            nn.ReLU(),
            nn.Linear(feature_dim, feature_dim)
        )
        self.layer_norm1 = nn.LayerNorm(feature_dim)
        self.layer_norm2 = nn.LayerNorm(feature_dim)
    
    def forward(self, x1, x2):
        # Cross-Attention: x1 attends to x2
        attn_output, _ = self.cross_attention(x1, x2, x2)
        x1 = self.layer_norm1(x1 + attn_output)  # Residual connection
        
        # Feed-Forward Network
        ffn_output = self.ffn(x1)
        x1 = self.layer_norm2(x1 + ffn_output)  # Residual connection
        return x1


In [3]:
class QuestionCaptionAlignment(nn.Module):
    def __init__(self, feature_dim, num_heads, num_layers):
        super(QuestionCaptionAlignment, self).__init__()
        self.layers = nn.ModuleList([CrossModalityEncoderBlock(feature_dim, num_heads) for _ in range(num_layers)])
    
    def forward(self, Fq, Fc):
        for layer in self.layers:
            Fc = layer(Fc, Fq)  # Align caption with question feature
            Fq = layer(Fq, Fc)  # Align question with enhanced caption feature
        return Fq, Fc


In [4]:
Fc = torch.load('D:/Project_phase_1/text modality/caption_features.pt')  # Caption feature encoding
Fq = torch.load('D:/Project_phase_1/text modality/question_features.pt')  # Question feature encoding


  Fc = torch.load('D:/Project_phase_1/text modality/caption_features.pt')  # Caption feature encoding
  Fq = torch.load('D:/Project_phase_1/text modality/question_features.pt')  # Question feature encoding


In [None]:
# Hyperparameters
feature_dim = 768  # Adjust based on your encoding dimensions
num_heads = 8
num_layers_QC = 5  # NQC layers for Question-Caption Alignment

# Instantiate alignment modules
qc_alignment = QuestionCaptionAlignment(feature_dim, num_heads, num_layers_QC)


# Forward pass for Question-Caption alignment
Fq_enhanced, Fc_q = qc_alignment(Fq, Fc)


In [6]:
Fc.shape

torch.Size([15, 1, 768])

In [7]:
Fq.shape

torch.Size([20, 8, 768])

In [9]:
# Ensure matching batch size
Fc = Fc.repeat(1, 8, 1)  # Repeat the caption feature tensor for 8 batches

# Ensure matching sequence length
Fc = Fc[:, :20, :]  # Truncate the caption feature tensor if needed

# Forward pass
Fq_enhanced, Fc_q = qc_alignment(Fq, Fc)


In [10]:
Fq_enhanced

tensor([[[-0.7723, -1.3168, -0.3231,  ...,  0.2139, -0.3807, -0.3354],
         [-0.8277, -1.0029, -0.4432,  ...,  0.3182, -0.5649, -0.3708],
         [-0.8892, -1.1870, -0.2554,  ...,  0.3752, -1.3767, -0.5385],
         ...,
         [-1.0942, -1.2227, -0.6341,  ...,  0.6075, -0.9951, -0.5390],
         [-1.0568, -0.9494, -0.7056,  ...,  0.6367, -0.9933,  0.1635],
         [-1.0929, -1.3234, -0.4280,  ...,  0.5968, -0.4450,  0.0582]],

        [[ 0.0341, -1.5676, -1.3659,  ...,  1.7105, -1.5064, -0.2091],
         [-0.8501, -0.8165, -1.0338,  ...,  1.1966, -1.1694, -0.3077],
         [-0.6272, -0.4620, -0.0289,  ...,  0.7102, -1.3024,  0.0893],
         ...,
         [-0.8615, -0.6658, -0.8772,  ...,  1.4900, -1.7007, -0.2039],
         [-0.5709, -0.5890, -0.7153,  ...,  1.3896, -1.3399,  0.3017],
         [ 0.0382, -1.8895, -0.5701,  ..., -0.2208, -0.3327,  0.8098]],

        [[-0.3329, -1.8132, -1.5388,  ...,  1.5561, -0.4964, -0.3360],
         [-0.8636, -1.2953, -0.7040,  ...,  1

In [11]:
Fc_q

tensor([[[-1.0801,  0.2052, -0.6069,  ...,  2.0831,  0.3651, -0.9781],
         [-1.0933,  0.1673, -0.5936,  ...,  2.0562,  0.0903, -0.7492],
         [-1.1618,  0.4331, -0.3505,  ...,  2.0747,  0.4261, -0.7308],
         ...,
         [-0.9664,  0.5298, -0.3527,  ...,  2.1870,  0.1597, -0.9395],
         [-1.0288,  0.4768, -0.4282,  ...,  2.2127,  0.1822, -0.8281],
         [-1.1626,  0.2946, -0.5607,  ...,  2.1647,  0.3652, -0.9521]],

        [[ 0.0869,  0.6148,  0.1197,  ...,  0.4189,  0.2110,  0.3955],
         [ 0.1331,  0.5334,  0.0828,  ...,  0.4501, -0.0303,  0.7086],
         [ 0.0352,  0.8132,  0.3428,  ...,  0.5198,  0.3541,  0.6548],
         ...,
         [ 0.2056,  0.8969,  0.3434,  ...,  0.5657,  0.0664,  0.5001],
         [ 0.1231,  0.8877,  0.2921,  ...,  0.4595,  0.1375,  0.6870],
         [ 0.0416,  0.7131,  0.2431,  ...,  0.4048,  0.3038,  0.4431]],

        [[-0.4620, -0.3657, -0.0812,  ...,  0.3024,  0.4879, -1.1299],
         [-0.3730, -0.3311, -0.0841,  ...,  0