In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class MultiModalRowTransformer(nn.Module):
    """Transformer encoder over a single table row of numeric, categorical and text data.

    Each row becomes a token sequence of length ``1 + N_num + N_cat + T``:

    * position 0: a learned ``[CLS]`` token whose final hidden state feeds the
      regression head;
    * one token per numeric column: ``num_proj[i](value) + feature_emb + type_emb[0]``;
    * one token per categorical column: ``cat_val_emb[i](id) + feature_emb + type_emb[1]``;
    * optionally one token per text id: ``text_tok_emb + text_pos_emb + type_emb[2]``.

    :meth:`forward` supports two modes:

    * ``"regress"``: predict one scalar per row from the ``[CLS]`` state (MSE loss);
    * ``"pretrain"``: reconstruct masked numeric values (MSE) and masked categorical
      ids (cross-entropy) from the hidden states at their own token positions.

    Args:
        num_feature_names: numeric column names, in the column order of ``num_x``.
        cat_feature_names: categorical column names, in the column order of ``cat_x``.
        cat_vocab_sizes: vocabulary size per categorical column (same order as
            ``cat_feature_names``).
        text_vocab_size: size of the text tokenizer vocabulary (e.g. 30522 for BERT).
        d_model: hidden size of every token.
        nhead: number of attention heads.
        num_layers: number of encoder layers.
        max_text_len: longest supported text sequence (size of the positional table).
        dropout: dropout inside the encoder layers.
        text_pad_id: if not ``None``, text positions whose id equals this value are
            excluded from attention via ``src_key_padding_mask``. ``None`` (default)
            attends to every text position.

    Raises:
        ValueError: if ``cat_vocab_sizes`` and ``cat_feature_names`` differ in length.
    """

    def __init__(
        self,
        num_feature_names,       # list[str]
        cat_feature_names,       # list[str]
        cat_vocab_sizes,         # list[int] matching order of cat_feature_names
        text_vocab_size,         # size of tokenizer vocab (e.g., 30522 for BERT)
        d_model: int = 128,
        nhead: int = 8,
        num_layers: int = 4,
        max_text_len: int = 128,
        dropout: float = 0.1,
        text_pad_id=None,
    ):
        super().__init__()

        # Copy into fresh lists: accepts tuples / dict views (which cannot be
        # concatenated with `+`) and decouples us from later caller mutation.
        num_feature_names = list(num_feature_names)
        cat_feature_names = list(cat_feature_names)
        cat_vocab_sizes = list(cat_vocab_sizes)
        if len(cat_vocab_sizes) != len(cat_feature_names):
            raise ValueError(
                f"cat_vocab_sizes has {len(cat_vocab_sizes)} entries but "
                f"cat_feature_names has {len(cat_feature_names)}"
            )

        # 1. Feature bookkeeping
        self.num_feature_names = num_feature_names
        self.cat_feature_names = cat_feature_names
        self.N_num = len(num_feature_names)
        self.N_cat = len(cat_feature_names)
        self.d_model = d_model
        self.max_text_len = max_text_len
        self.text_pad_id = text_pad_id

        # 2. Embeddings shared across tokens
        # Type embeddings: row 0 = numeric, 1 = categorical, 2 = text.
        self.type_emb = nn.Embedding(3, d_model)

        # Feature-ID embeddings: one row per column, numeric columns first, then
        # categorical. Rows are looked up by column position (see _build_sequence),
        # so two columns with the same name still get distinct embeddings.
        self.all_feat_names = num_feature_names + cat_feature_names
        self.feature_to_id = {name: i for i, name in enumerate(self.all_feat_names)}
        self.feature_emb = nn.Embedding(len(self.all_feat_names), d_model)

        # 3. Value projections
        # Numeric: one Linear(1 -> d_model) per column.
        self.num_proj = nn.ModuleList([nn.Linear(1, d_model) for _ in range(self.N_num)])

        # Categorical: one embedding table per column.
        self.cat_val_emb = nn.ModuleList([
            nn.Embedding(vsz, d_model) for vsz in cat_vocab_sizes
        ])

        # Text: token embedding + learned absolute position embedding.
        self.text_tok_emb = nn.Embedding(text_vocab_size, d_model)
        self.text_pos_emb = nn.Embedding(max_text_len, d_model)

        # 4. Global aggregator ([CLS] token, prepended to every row)
        self.cls_token = nn.Parameter(torch.randn(1, 1, d_model))

        # 5. Encoder
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model, nhead=nhead, dim_feedforward=d_model * 4,
            dropout=dropout, batch_first=True, activation="gelu"
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)

        # 6. Task heads
        # Regression head on the [CLS] state (used for the labeled subset).
        self.reg_head = nn.Sequential(
            nn.Linear(d_model, d_model // 2),
            nn.GELU(),
            nn.Linear(d_model // 2, 1)
        )

        # Reconstruction heads for self-supervised pre-training, one per column.
        self.num_recon = nn.ModuleList([nn.Linear(d_model, 1) for _ in range(self.N_num)])
        self.cat_recon = nn.ModuleList([nn.Linear(d_model, vsz) for vsz in cat_vocab_sizes])

    def _build_sequence(self, num_x, cat_x, text_ids):
        """Embed a batch into a (B, 1 + N_num + N_cat + T, d_model) token tensor.

        T is ``text_ids.size(1)``, or 0 when ``text_ids`` is None.

        Raises:
            ValueError: if the text is longer than ``max_text_len``.
        """
        B = num_x.shape[0]
        # Linear layers need a float input matching the parameters (numpy-derived
        # tensors are often float64).
        num_x = num_x.to(self.cls_token.dtype)

        type_w = self.type_emb.weight      # (3, d_model)
        feat_w = self.feature_emb.weight   # (N_num + N_cat, d_model), numeric rows first

        tokens = [self.cls_token.expand(B, -1, -1)]

        for i in range(self.N_num):
            val = self.num_proj[i](num_x[:, i].unsqueeze(-1))          # (B, d_model)
            tokens.append((val + feat_w[i] + type_w[0]).unsqueeze(1))

        for i in range(self.N_cat):
            val = self.cat_val_emb[i](cat_x[:, i])                      # (B, d_model)
            tokens.append((val + feat_w[self.N_num + i] + type_w[1]).unsqueeze(1))

        if text_ids is not None:
            T = text_ids.size(1)
            if T > self.max_text_len:
                # Fail loudly here; otherwise the positional lookup goes out of range
                # (a device-side assert on CUDA).
                raise ValueError(
                    f"text_ids has length {T} but max_text_len is {self.max_text_len}; "
                    "truncate the text or increase max_text_len"
                )
            pos = self.text_pos_emb.weight[:T]                          # (T, d_model)
            tokens.append(self.text_tok_emb(text_ids) + pos + type_w[2])

        return torch.cat(tokens, dim=1)

    def _text_padding_mask(self, text_ids):
        """Return a (B, L) bool key-padding mask (True = ignore), or None when unused."""
        if self.text_pad_id is None or text_ids is None:
            return None
        # The [CLS] and tabular tokens are never padding.
        n_fixed = 1 + self.N_num + self.N_cat
        fixed = torch.zeros(text_ids.size(0), n_fixed, dtype=torch.bool, device=text_ids.device)
        return torch.cat([fixed, text_ids.eq(self.text_pad_id)], dim=1)

    def forward(self, num_x, cat_x, text_ids=None, y=None,
                num_mask=None, num_targets=None, cat_mask=None, cat_targets=None,
                mode="regress"):
        """Encode a batch of rows and compute the requested task output.

        Args:
            num_x: (B, N_num) numeric features.
            cat_x: (B, N_cat) integer category ids.
            text_ids: optional (B, T) token ids with T <= max_text_len.
            y: optional regression target with B elements, e.g. shape (B,) or (B, 1).
            num_mask, cat_mask: (B, N_num) / (B, N_cat) masks that are True where an
                entry was hidden and must be reconstructed ("pretrain" mode). ``None``
                skips that modality's reconstruction loss.
            num_targets, cat_targets: original (unmasked) values, same shapes as the masks.
            mode: ``"regress"`` or ``"pretrain"``.

        Returns:
            ``{"y_hat": (B,), "loss": scalar or None}`` in "regress" mode.
            ``{"loss": scalar}`` in "pretrain" mode: the sum over columns of the per-column
            mean MSE / cross-entropy on masked positions. When nothing is masked, this is
            a zero that is still attached to the graph, so ``.backward()`` works.

        Raises:
            ValueError: on an unknown ``mode`` or text longer than ``max_text_len``.
        """
        if mode not in ("regress", "pretrain"):
            raise ValueError(f"mode must be 'regress' or 'pretrain', got {mode!r}")

        # 1. Tokenize and run the encoder
        X = self._build_sequence(num_x, cat_x, text_ids)
        H = self.transformer(X, src_key_padding_mask=self._text_padding_mask(text_ids))

        if mode == "regress":
            y_hat = self.reg_head(H[:, 0, :]).squeeze(-1)   # [CLS] is position 0
            loss = None
            if y is not None:
                # Align the target to y_hat's (B,) shape and dtype: a (B, 1) target would
                # otherwise broadcast to (B, B) inside mse_loss and give a wrong loss.
                loss = F.mse_loss(y_hat, y.reshape(y_hat.shape).to(y_hat.dtype))
            return {"y_hat": y_hat, "loss": loss}

        # mode == "pretrain"
        losses = []

        # Numeric reconstruction (tokens 1 .. N_num)
        if num_mask is not None:
            H_num = H[:, 1:1 + self.N_num, :]
            for i in range(self.N_num):
                mask = num_mask[:, i].bool()  # int 0/1 masks would otherwise index, not mask
                if mask.any():
                    pred = self.num_recon[i](H_num[mask, i, :]).squeeze(-1)
                    losses.append(F.mse_loss(pred, num_targets[mask, i].to(pred.dtype)))

        # Categorical reconstruction (tokens N_num+1 .. N_num+N_cat)
        if cat_mask is not None:
            H_cat = H[:, 1 + self.N_num:1 + self.N_num + self.N_cat, :]
            for i in range(self.N_cat):
                mask = cat_mask[:, i].bool()
                if mask.any():
                    logits = self.cat_recon[i](H_cat[mask, i, :])
                    losses.append(F.cross_entropy(logits, cat_targets[mask, i].long()))

        if losses:
            total_loss = torch.stack(losses).sum()
        else:
            # Nothing masked: return a graph-connected zero instead of the int 0.
            total_loss = H.sum() * 0.0
        return {"loss": total_loss}

In [None]:
import torch

def create_masks(num_x, cat_x, mask_prob=0.15, num_mask_value=0.0, cat_mask_id=0,
                 generator=None):
    """Randomly hide entries of a batch for masked-feature pre-training.

    Each numeric and each categorical entry is masked independently with probability
    ``mask_prob``. Masked numeric entries are overwritten with ``num_mask_value`` and
    masked categorical entries with ``cat_mask_id``. Unmodified copies of the inputs
    are returned as reconstruction targets.

    Args:
        num_x: (B, N_num) float tensor of numeric features.
        cat_x: (B, N_cat) int64 tensor of categorical ids.
        mask_prob: probability in [0, 1] that any single entry is masked.
        num_mask_value: value written into masked numeric entries. The default 0.0 is
            only a "neutral" value if numeric features are standardized to mean 0.
            TODO: confirm against the preprocessing notebook.
        cat_mask_id: id written into masked categorical entries. The default 0 assumes
            id 0 is reserved as a [MASK]/unknown id in every categorical vocabulary.
            Confirm this; otherwise a masked entry is indistinguishable from a genuine
            category 0.
        generator: optional ``torch.Generator`` (on the inputs' device) for
            reproducible masks.

    Returns:
        A 6-tuple ``(num_x_masked, num_mask, num_targets, cat_x_masked, cat_mask,
        cat_targets)``. The masks are bool tensors (True = hidden); the targets are
        clones of ``num_x`` / ``cat_x``.

    Raises:
        ValueError: if ``mask_prob`` is outside [0, 1] or an input is not 2-D.
    """
    if not 0.0 <= mask_prob <= 1.0:
        raise ValueError(f"mask_prob must be in [0, 1], got {mask_prob}")
    if num_x.dim() != 2 or cat_x.dim() != 2:
        raise ValueError(
            f"expected 2-D inputs, got num_x {tuple(num_x.shape)} and cat_x {tuple(cat_x.shape)}"
        )

    # 1. Numeric: Bernoulli(mask_prob) per entry; the model must not see hidden values.
    num_mask = torch.rand(num_x.shape, device=num_x.device, generator=generator) < mask_prob
    num_x_masked = num_x.masked_fill(num_mask, num_mask_value)

    # 2. Categorical: same scheme, replacing hidden ids with the [MASK] id.
    cat_mask = torch.rand(cat_x.shape, device=cat_x.device, generator=generator) < mask_prob
    cat_x_masked = cat_x.masked_fill(cat_mask, cat_mask_id)

    return num_x_masked, num_mask, num_x.clone(), cat_x_masked, cat_mask, cat_x.clone()

In [None]:
import torch
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset

# 0. CONFIGURATION
SEED = 42
PRETRAIN_EPOCHS = 50       # self-supervised phase runs longer
FINETUNE_EPOCHS = 20
PRETRAIN_BATCH_SIZE = 64
FINETUNE_BATCH_SIZE = 32
LR = 1e-4
WEIGHT_DECAY = 0.01
MASK_PROB = 0.15           # fraction of tabular entries hidden per pre-training batch
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

torch.manual_seed(SEED)

# 1. SETUP DATASETS
# NOTE(review): `total_dataset`, `labeled_dataset` and `metadata` are not defined in
# this notebook; they are assumed to come from the preprocessing notebook.
#   total_dataset:   all 4000 rows -> yields (num_x, cat_x, text_ids)     (pre-training)
#   labeled_dataset: the 400 labeled rows -> (num_x, cat_x, text_ids, y)  (fine-tuning)
# Batches are unpacked positionally below, so this element order matters.
train_loader_pre = DataLoader(total_dataset, batch_size=PRETRAIN_BATCH_SIZE, shuffle=True)
train_loader_fine = DataLoader(labeled_dataset, batch_size=FINETUNE_BATCH_SIZE, shuffle=True)

# 2. INITIALIZE MODEL
model = MultiModalRowTransformer(
    num_feature_names=metadata['num_features'],
    cat_feature_names=list(metadata['cat_features'].keys()),
    cat_vocab_sizes=list(metadata['cat_features'].values()),
    text_vocab_size=metadata['text_vocab_size'],
    d_model=128
).to(DEVICE)

# --- PHASE 1: SELF-SUPERVISED PRE-TRAINING (masked-feature reconstruction) ---
optimizer = optim.AdamW(model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)
print("Starting Pre-training...")
model.train()
for epoch in range(PRETRAIN_EPOCHS):
    epoch_loss, n_steps = 0.0, 0
    for batch in train_loader_pre:
        num_x, cat_x, text_ids = [b.to(DEVICE) for b in batch]

        # Hide ~MASK_PROB of the tabular entries. The model must receive the MASKED
        # inputs; feeding the originals would let it copy the answer from its input.
        (num_x_masked, num_mask, num_targets,
         cat_x_masked, cat_mask, cat_targets) = create_masks(num_x, cat_x, mask_prob=MASK_PROB)

        optimizer.zero_grad()
        out = model(
            num_x_masked, cat_x_masked, text_ids,
            num_mask=num_mask, num_targets=num_targets,
            cat_mask=cat_mask, cat_targets=cat_targets,
            mode="pretrain"
        )

        loss = out["loss"]
        # If nothing was masked in this batch there is no reconstruction signal
        # (and the loss may not be a differentiable tensor): skip the step.
        if not torch.is_tensor(loss) or not loss.requires_grad:
            continue
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()
        n_steps += 1
    print(f"pretrain epoch {epoch + 1}/{PRETRAIN_EPOCHS}: loss={epoch_loss / max(n_steps, 1):.4f}")

# --- PHASE 2: DOWNSTREAM FINE-TUNING (regression on the labeled subset) ---
print("Starting Fine-tuning...")
# Optional: freeze the encoder if the labeled subset is too small/noisy:
# model.transformer.requires_grad_(False)
# Fresh optimizer so AdamW moment estimates from the reconstruction objective
# do not carry over into the regression objective.
optimizer = optim.AdamW(model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)
model.train()
for epoch in range(FINETUNE_EPOCHS):
    epoch_loss, n_steps = 0.0, 0
    for batch in train_loader_fine:
        num_x, cat_x, text_ids, y = [b.to(DEVICE) for b in batch]
        # The model predicts shape (B,); flatten/cast y so mse_loss does not
        # broadcast a (B, 1) target to (B, B) or fail on a float64 target.
        y = y.reshape(-1).float()

        optimizer.zero_grad()
        out = model(num_x, cat_x, text_ids, y=y, mode="regress")

        loss = out["loss"]  # MSE on the labeled rows
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()
        n_steps += 1
    print(f"finetune epoch {epoch + 1}/{FINETUNE_EPOCHS}: loss={epoch_loss / max(n_steps, 1):.4f}")