In [None]:
import torch
import torch.nn as nn
import math

class dRoFE(nn.Module):
    """Enrich per-band tokens with band, age and gender embeddings.

    Despite the name, no rotation is applied: learned band, age and gender
    embeddings are summed with the input tokens, and the same enriched tensor
    is returned as both Q and K.

    Args:
        token_dim: Size ``D`` of each token embedding.
        num_bands: Number of frequency bands, i.e. rows of the band embedding
            table. Defaults to 9.
        num_genders: Number of gender categories, i.e. rows of the gender
            embedding table. Defaults to 2 (gender encoded as 0 or 1).
    """

    def __init__(self, token_dim, num_bands=9, num_genders=2):
        super().__init__()
        self.token_dim = token_dim

        # Learnable projections for demographic injection.
        self.age_proj = nn.Linear(1, token_dim)
        self.gender_proj = nn.Embedding(num_genders, token_dim)

        # One learnable embedding per frequency band.
        self.band_proj = nn.Embedding(num_bands, token_dim)

    def forward(self, X, band_indices, age, gender):
        """Add band, age and gender embeddings to every token.

        Args:
            X: Token sequence of shape [B, N, D]. In the original setup N is 9,
                one token per frequency band. D must equal ``token_dim``.
            band_indices: Integer tensor of shape [N] holding each token's band
                index in ``0 .. num_bands - 1``.
            age: Tensor of shape [B] with one age per sample. Integer ages are
                cast to the age projection's floating dtype.
            gender: Integer tensor of shape [B] with values in
                ``0 .. num_genders - 1``.

        Returns:
            Tuple ``(Q, K)``. Both are the same enriched tensor of shape
            [B, N, D].
        """
        B = X.shape[0]
        device = X.device

        # Band embedding is identical for every sample: [1, N, D] broadcasts over B.
        band_emb = self.band_proj(band_indices.to(device)).unsqueeze(0)

        # Linear expects a floating input whose last dim is 1. Cast so integer
        # ages work, and reshape (not view) so non-contiguous input works.
        # The result [B, 1, D] broadcasts over the N tokens.
        # NOTE(review): ages are fed un-normalised (e.g. 25.0, 40.0); consider
        # scaling them before the projection.
        age = age.to(device=device, dtype=self.age_proj.weight.dtype).reshape(B, 1, 1)
        age_emb = self.age_proj(age)

        # Gender embedding [B, D] -> [B, 1, D]. Move indices to X's device,
        # as is done for band_indices.
        gender_emb = self.gender_proj(gender.to(device)).unsqueeze(1)

        # Same summation order as before, so results are numerically identical.
        enriched = X + band_emb + age_emb + gender_emb

        # No rotation is applied; Q and K are the same enriched tensor.
        return enriched, enriched




In [None]:
# Toy inputs: a single batch of 2 samples, each with 9 band tokens of width 128.
X, band_indices = torch.randn(2, 9, 128), torch.arange(9)
age, gender = torch.tensor([25.0, 40.0]), torch.tensor([1, 0])  # per-sample demographics

# Instantiate the module and run one forward pass.
drofe_module = dRoFE(token_dim=128)
Q, K = drofe_module(X, band_indices, age, gender)

# Both outputs are expected to be torch.Size([2, 9, 128]).
print("Q shape:", Q.shape)
print("K shape:", K.shape)


Q shape: torch.Size([2, 9, 128])
K shape: torch.Size([2, 9, 128])


What it does: Enriches the tokens produced by the ConnectomeTokenizer by adding learned embeddings of demographic information (age, gender) and of each token's frequency band.

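This ensures that the tokens carry both graph-level information and demographic context. Despite the module's name, no rotation is applied: the band, age and gender embeddings are simply summed with the input tokens.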

Input:

Tokens from the ConnectomeTokenizer ([B, 9, D]).
Demographic data: age ([B]), gender ([B]).
Frequency band indices: integer tensor of shape [9] (values 0–8).

Output: a tuple (Q, K) of enriched tokens, each of shape [B, 9, D]. Both are the same tensor.