In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import numpy as np

In [2]:
def project_simplex(v, z=1.0):
    """Euclidean projection of a 1-D tensor onto the simplex {w >= 0, sum(w) == z}.

    Uses the sort-based algorithm of Duchi et al. (2008): sort descending,
    find the largest index ``rho`` whose entry stays positive after the shift,
    then subtract the threshold ``tau`` and clip at zero.

    Args:
        v: 1-D floating-point tensor (any dtype/device).
        z: Target sum of the simplex; the default 1.0 gives the probability simplex.

    Returns:
        ``v`` itself when it already lies on the simplex (non-negative and its
        sum ``torch.isclose`` to ``z``), otherwise a new tensor with ``v``'s
        dtype and device.
    """
    total = torch.sum(v)
    # Build the comparison constant with v's dtype/device: torch.isclose rejects
    # mixed dtypes (e.g. a float64 input against a float32 literal), and a bare
    # torch.tensor(...) would always live on the CPU.
    if torch.all(v >= 0) and torch.isclose(total, total.new_tensor(z)):
        return v

    u, _ = torch.sort(v, descending=True)
    cssv = torch.cumsum(u, dim=0)
    ranks = torch.arange(1, len(v) + 1, device=v.device, dtype=v.dtype)
    # rho: last (0-based) position where the sorted entry exceeds its shift.
    rho = torch.where(u - (cssv - z) / ranks > 0)[0].max()
    tau = (cssv[rho] - z) / (rho + 1)

    return torch.clamp(v - tau, min=0)

In [3]:
def sigmoid(x):
    """Element-wise logistic sigmoid, 1 / (1 + exp(-x)), for scalars or arrays.

    Evaluated as exp(-log(1 + exp(-x))) through ``np.logaddexp`` so that very
    negative ``x`` does not overflow ``np.exp(-x)`` (which previously produced
    an overflow RuntimeWarning); the result still tends to 0 / 1 at the tails.
    """
    return np.exp(-np.logaddexp(0, -x))

def logistic(x):
    """Softplus, log(1 + exp(x)), element-wise and without overflow.

    ``np.logaddexp(0, x)`` equals log(exp(0) + exp(x)) but never materializes
    exp(x): large positive ``x`` returns ~x instead of inf, and very negative
    ``x`` keeps the tiny positive value that ``1 + exp(x)`` would round to 0.
    """
    return np.logaddexp(0, x)

def logit(p):
    """Log-odds log(p / (1 - p)) of a probability (scalar or array).

    Mathematically the inverse of the sigmoid. Expects 0 < p < 1; values at or
    outside the boundaries are not guarded here.
    """
    odds = p / (1 - p)
    return np.log(odds)

In [4]:
# Fixed seed so "Restart & Run All" reproduces the simulated data and the fit.
SEED = 0
np.random.seed(SEED)
torch.manual_seed(SEED)

n = 1000  # sample size per group
p = 3  # number of covariates per sample
k = 3  # number of groups/datasets (one true beta vector per group below)

# True coefficients of each group's logistic model; each group puts the largest
# weight (0.9) on a different covariate.
beta_1 = np.array([0.9, 0.4, 0.4])
beta_2 = np.array([0.4, 0.9, 0.4])
beta_3 = np.array([0.4, 0.4, 0.9])
assert all(b.shape == (p,) for b in (beta_1, beta_2, beta_3)), "each beta needs p entries"

In [5]:
# Binary cross-entropy on probabilities, averaged over the batch ("mean" is the default).
criterion = nn.BCELoss(reduction="mean")

In [6]:
# Single simulation replicate; raise the range to repeat the experiment.
for i in range(1):
    # Simulate n samples per group from three logistic models that differ only
    # in their true coefficient vectors (beta_1, beta_2, beta_3).
    X_1 = np.random.uniform(-2, 2, (n, p))
    X_2 = np.random.uniform(-2, 2, (n, p))
    X_3 = np.random.uniform(-2, 2, (n, p))

    logits_1 = X_1.dot(beta_1)
    probs_1 = sigmoid(logits_1)
    y_1 = np.random.binomial(1, probs_1)

    logits_2 = X_2.dot(beta_2)
    probs_2 = sigmoid(logits_2)
    y_2 = np.random.binomial(1, probs_2)

    logits_3 = X_3.dot(beta_3)
    probs_3 = sigmoid(logits_3)
    y_3 = np.random.binomial(1, probs_3)

    # Empirical positive rate per group. Assumes 0 < rate < 1; otherwise the
    # logs below produce -inf/nan.
    p1_bar = np.mean(y_1)
    p2_bar = np.mean(y_2)
    p3_bar = np.mean(y_3)

    # Mean log-likelihood of the intercept-only (null) model per group. Adding it
    # to the mean BCE (the model's negative mean log-likelihood) gives each
    # group's loss relative to that baseline: negative means the model beats it.
    null_1 = np.mean(np.log(p1_bar)*y_1 + np.log(1-p1_bar)*(1-y_1))
    null_2 = np.mean(np.log(p2_bar)*y_2 + np.log(1-p2_bar)*(1-y_2))
    null_3 = np.mean(np.log(p3_bar)*y_3 + np.log(1-p3_bar)*(1-y_3))

    X1_tensor = torch.tensor(X_1, dtype=torch.float32)
    X2_tensor = torch.tensor(X_2, dtype=torch.float32)
    X3_tensor = torch.tensor(X_3, dtype=torch.float32)

    y1_tensor = torch.tensor(y_1, dtype=torch.float32)
    y2_tensor = torch.tensor(y_2, dtype=torch.float32)
    y3_tensor = torch.tensor(y_3, dtype=torch.float32)

    # Shared linear-model coefficients (no intercept, matching the simulation).
    a = torch.randn((p, 1), requires_grad=True)
    # Group weights q. The raw start [0.233, 0.433, 0.433] sums to 1.099, so it
    # is projected onto the simplex (-> [0.2, 0.4, 0.4]) before the first
    # descent step uses it as a weighting.
    q = nn.Parameter(project_simplex(torch.tensor([0.233, 0.433, 0.433])))

    # `a` minimizes the q-weighted loss (SGD with momentum); `q` maximizes it
    # (plain SGD on the negated objective, then projection back onto the simplex).
    optimizer_mlp = optim.SGD([a], lr=0.1, momentum=0.9)
    optimizer_q = optim.SGD([q], lr=0.2)

    for epoch in range(2000):
        optimizer_mlp.zero_grad()

        # Linear logits for each group under the shared coefficients.
        logits1 = X1_tensor @ a
        logits2 = X2_tensor @ a
        logits3 = X3_tensor @ a

        # Predicted probabilities, squeezed from (n, 1) to (n,) to match the labels.
        outputs1 = torch.sigmoid(logits1).squeeze(1)
        outputs2 = torch.sigmoid(logits2).squeeze(1)
        outputs3 = torch.sigmoid(logits3).squeeze(1)

        # Per-group loss relative to the null model (criterion is the
        # nn.BCELoss defined in an earlier cell).
        loss1 = criterion(outputs1, y1_tensor) + null_1
        loss2 = criterion(outputs2, y2_tensor) + null_2
        loss3 = criterion(outputs3, y3_tensor) + null_3

        # Descent step for `a`. q is detached so this backward pass does not
        # accumulate gradients into q (optimizer_q.zero_grad() discarded them anyway).
        q_fixed = q.detach()
        weighted_loss_mlp = q_fixed[0] * loss1 + q_fixed[1] * loss2 + q_fixed[2] * loss3  # + torch.norm(a,1)/n if you need regularization
        weighted_loss_mlp.backward()
        optimizer_mlp.step()

        # Ascent step for q using the losses computed before a's update
        # (simultaneous descent-ascent). Detaching the losses keeps this
        # backward pass off a's graph.
        optimizer_q.zero_grad()
        loss1_detached = loss1.detach()
        loss2_detached = loss2.detach()
        loss3_detached = loss3.detach()
        weighted_loss_q = q[0] * loss1_detached + q[1] * loss2_detached + q[2] * loss3_detached  # + torch.norm(a,1)/n
        (-weighted_loss_q).backward()  # gradient ascent: SGD minimizes the negation
        optimizer_q.step()

        # Keep q a probability vector after the unconstrained SGD step.
        with torch.no_grad():
            q[:] = project_simplex(q)

        if epoch % 100 == 0:
            print(epoch, q.data, loss1.item(), loss2.item(), loss3.item())

0 tensor([0.2001, 0.3939, 0.4060]) -0.08286803960800171 -0.11398577690124512 -0.05364882946014404
100 tensor([0.2962, 0.3749, 0.3289]) -0.12488830089569092 -0.12698894739151 -0.13000208139419556
200 tensor([0.3230, 0.3770, 0.3000]) -0.12663298845291138 -0.12717992067337036 -0.12809354066848755
300 tensor([0.3306, 0.3789, 0.2905]) -0.1270904541015625 -0.12721580266952515 -0.1275426149368286
400 tensor([0.3328, 0.3798, 0.2874]) -0.12721657752990723 -0.12724590301513672 -0.12735843658447266
500 tensor([0.3335, 0.3801, 0.2864]) -0.12725287675857544 -0.12725985050201416 -0.12729781866073608
600 tensor([0.3337, 0.3803, 0.2860]) -0.127263605594635 -0.12726527452468872 -0.12727802991867065
700 tensor([0.3338, 0.3803, 0.2859]) -0.12726688385009766 -0.12726730108261108 -0.12727147340774536
800 tensor([0.3338, 0.3803, 0.2859]) -0.1272679567337036 -0.1272679567337036 -0.12726932764053345
900 tensor([0.3338, 0.3803, 0.2859]) -0.12726819515228271 -0.1272682547569275 -0.12726867198944092
1000 tensor(