<a href="https://colab.research.google.com/github/tsanoop887-hash/AIF360/blob/main/cross_attention_metric_head.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [19]:
class CrossAttentionMetricHead(nn.Module):


 def __init__(self,d_model=256 ,nhead=4,dim_feedforward=512):
  super().__init__()
  self.cross_attn=nn.MultiheadAttention(d_model,nhead,batch_first=True)
  self.ff=nn.Sequential(
      nn.LayerNorm(d_model),
      nn.Linear(d_model,dim_feedforward),
      nn.GELU(),
      nn.Linear(dim_feedforward,d_model)
  )
  self.score_head=nn.Sequential(
      nn.LayerNorm(d_model),
      nn.Linear(d_model,d_model//2),
      nn.ReLU(),
      nn.Linear(d_model//2,1)
  )

 def forward(self,prompt,response,prompt_mask=None,response_mask=None):
  attn_output, _=self.cross_attn(response,prompt,prompt,
                                 key_padding_mask=~prompt_mask if prompt_mask is not None else None)
  x=response+attn_output
  x=self.ff(x)

  if response_mask is not None:
    mask=response_mask.unsqueeze(-1).float()
    pooled=(x*mask).sum(dim=1)/mask.sum(dim=1).clamp_min(1e-6)
  else:
    pooled=x.mean(dim=1)

  score = self.score_head(pooled).squeeze(-1)
  return score

In [15]:
class CAMT(nn.Module):


  def __init__(self, d_model=256, nhead=8, num_layers=2, num_metrics=3):
        super().__init__()
        encoder_layer=nn.TransformerEncoderLayer(d_model, nhead, 512, batch_first=True)
        self.encoder=nn.TransformerEncoder(encoder_layer,num_layers=num_layers)
        self.pos_enc=nn.Parameter(torch.randn(1,512,d_model))

        self.metric_heads=nn.ModuleList([CrossAttentionMetricHead(d_model,nhead//2,512)for _ in range(num_metrics)])

  def forward(self, prompt_embeds, response_embeds, prompt_mask=None, response_mask=None):

    prompt=prompt_embeds+self.pos_enc[:,:prompt_embeds.size(1),:]
    response=response_embeds+self.pos_enc[:,:response_embeds.size(1),:]

    prompt_encoded = self.encoder(prompt, src_key_padding_mask=~prompt_mask if prompt_mask is not None else None)
    response_encoded = self.encoder(response, src_key_padding_mask=~response_mask if response_mask is not None else None)

    scores= []
    for head in self.metric_heads:
      s=head(prompt_encoded, response_encoded, prompt_mask, response_mask)
      scores.append(s)

    scores = torch.stack(scores, dim=1);
    probs = torch.sigmoid(scores)
    return {
              "raw": scores,
              "prob": probs,
          }

In [23]:
if __name__ == "__main__" :
  torch.manual_seed(0)
  B, P, R, D = 2, 10, 15, 256
  model = CAMT(d_model=D, nhead=8, num_layers=2, num_metrics=3)
  prompt_embeds = torch.randn(B, P, D)
  response_embeds = torch.randn(B, R, D)
  prompt_mask = torch.ones(B, P, dtype=torch.bool)
  response_mask = torch.ones(B, R, dtype=torch.bool)

  out = model(prompt_embeds, response_embeds, prompt_mask, response_mask)
  print("Raw metric scores:", out["raw"])
  print("Prob metric scores (0–1):", out["prob"])
  print("AGI Index (example):", (out["prob"] @ torch.tensor([0.5, 0.3, 0.2])).detach().cpu())



Raw metric scores: tensor([[-0.0744,  0.2996, -0.0325],
        [-0.0189,  0.3758,  0.0651]], grad_fn=<StackBackward0>)
Prob metric scores (0–1): tensor([[0.4814, 0.5744, 0.4919],
        [0.4953, 0.5929, 0.5163]], grad_fn=<SigmoidBackward0>)
AGI Index (example): tensor([0.5114, 0.5287])
