## Covariate-dependent learning framework

In [1]:
import torch
import torchvision
import torch.nn as nn
from torch.nn import functional as F
from torchmetrics.functional import accuracy
import os
import matplotlib.pylab as plt
%matplotlib inline
import numpy as np
import pandas as pd
import pickle
import seaborn as sns
import io_

In [2]:
atlas = 'BNA'
# data_dir = 'D:/ShareFolder/BNA/Proc'
# out_dir = 'D:/ShareFolder/BNA/Result'
data_dir = '/media/shuoz/MyDrive/HCP/BNA/Proc'
out_dir = '/media/shuoz/MyDrive/HCP/BNA/Results'


# atlas = 'AICHA'
# data_dir = 'D:/ShareFolder/AICHA_VolFC/Proc'
# out_dir = 'D:/ShareFolder/AICHA_VolFC/Result'

sessions = ['REST1', 'REST2']  

runs = ['RL', 'LR']
# connection_type = 'both'  # inter, intra, or both
connection_type = 'intra'
random_state = 144
clf = 'SVC'

info = dict()
data = dict()

session = 'REST1'
run_ = 'LR'
half = 'Left'

info_fname = 'HCP_%s_half_brain_%s_%s.csv' % (atlas, session, run_)
info[run_] = io_.read_table(os.path.join(data_dir, info_fname), index_col='ID')
data[run_] = io_.load_half_brain(data_dir, atlas, session, run_, connection_type)

In [3]:
from _base import _pick_half
from sklearn.preprocessing import label_binarize

x, y = _pick_half(data[run_])
y = label_binarize(y, classes=[1, -1])
# y = y.reshape((-1, 1))
genders = info[run_]['gender'].values

idx_male = np.where(genders==0)[0]
idx_female = np.where(genders==1)[0]


x = torch.from_numpy(x)
x = x.float()
y = torch.from_numpy(y)
y = y.long()
genders = torch.from_numpy(genders.reshape((-1, 1)))
genders = genders.float()

In [2]:
class LR(nn.Module):
    def __init__(self, n_features, n_classes, l1_hparam=0.0, l2_hparam=1.0,):
        super().__init__()
        self.l1_hparam = l1_hparam
        self.l2_hparam = l2_hparam
        self.linear = nn.Linear(n_features, n_classes)

    def forward(self, x):
        pred = torch.sigmoid(self.linear(x))
        return pred

    def training_step(self, batch):
        images, labels = batch
        out = self(images)  # Generate predictions
        pred_loss = F.cross_entropy(out, labels)  # Calculate loss
        
        # L1 regularizer
        if self.l1_hparam > 0:
            l1_reg = sum(param.abs().sum() for param in self.parameters())
            loss += self.l1_hparam * l1_reg

        # L2 regularizer
        if self.l2_hparam > 0:
            l2_reg = sum(param.pow(2).sum() for param in self.parameters())
            loss += self.l2_hparam * l2_reg
        
        return loss

    def validation_step(self, batch):
        images, labels = batch
        out = self(images)  # Generate predictions
        loss = F.cross_entropy(out, labels)  # Calculate loss
        acc = accuracy(out, labels)  # Calculate accuracy
        return {'val_loss': loss, 'val_acc': acc}

    def validation_epoch_end(self, outputs):
        batch_losses = [x['val_loss'] for x in outputs]
        epoch_loss = torch.stack(batch_losses).mean()  # Combine losses
        batch_accs = [x['val_acc'] for x in outputs]
        epoch_acc = torch.stack(batch_accs).mean()  # Combine accuracies
        return {'val_loss': epoch_loss.item(), 'val_acc': epoch_acc.item()}

    def epoch_end(self, epoch, result):
        print("Epoch [{}], val_loss: {:.4f}, val_acc: {:.4f}".format(epoch, result['val_loss'], result['val_acc']))


In [None]:
l1_hparam = 0.0
l2_hparam = 1.0


# L1 regularizer
# if l1_hparam > 0:
#     l1_reg = sum(param.abs().sum() for param in self.parameters())
#     loss += self.hparams.l1_strength * l1_reg

# # L2 regularizer
# if l2_hparam > 0:
#     l2_reg = sum(param.pow(2).sum() for param in self.parameters())
#     loss += self.hparams.l2_strength * l2_reg

optimizer = torch.optim.SGD(model.parameters(), lr=0.01, weight_decay=1.0)

In [9]:
# from torch.linalg import multi_dot

def hsic(x, y):
    
    kx = torch.mm(x, x.T)
    ky = torch.mm(y, y.T)
    
    n = x.shape[0]
    ctr_mat = torch.eye(n) - torch.ones((n, n)) / n
    
    return torch.trace(torch.mm(torch.mm(torch.mm(kx, ctr_mat), ky), ctr_mat)) / (n ** 2)

In [6]:
genders = genders.long()

In [13]:
x[idx_male[0]].shape

torch.Size([470, 7503])

In [7]:
torch.manual_seed(144)

<torch._C.Generator at 0x7f7384061bd0>

In [30]:
# train_idx = idx_male
# test_idx = idx_female
train_idx = idx_female
test_idx = idx_male
n_train = train_idx.shape[0]
n_hold = int(0.2 * n_train)
n_test = test_idx.shape[0]

In [32]:
train_idx.shape[0]

565

In [31]:
num_epochs = 500
batch_size = 100
learning_rate = 0.001

model = nn.Linear(x.shape[1], 2)

# Loss and optimizer
# nn.CrossEntropyLoss() computes softmax internally
# criterion = F.cross_entropy() 
# optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate, weight_decay=0.1)
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=0.1)  
lambda_ = 8.0

# Train the model
# total_step = len(train_loader)
for epoch in range(num_epochs):
#     for i, (images, labels) in enumerate(train_loader):
#         # Reshape images to (batch_size, input_size)
#         images = images.reshape(-1, input_size)
        
    # Forward pass
    y_pred = torch.sigmoid(model(x[train_idx][n_hold:]))
#     loss = F.cross_entropy(y_pred, y[idx_male].view(-1)) + nn.MSELoss(model(x)[], genders)
#     loss = F.cross_entropy(y_pred, y[idx_male].view(-1)) + F.cross_entropy(torch.sigmoid(model(x)), genders.view(-1))
    loss = F.cross_entropy(y_pred, y[train_idx][n_hold:].view(-1)) + lambda_ * (1 - torch.sigmoid(hsic(model(x), genders.float())))

    # Backward and optimize
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    
    if (epoch+1) % 10 == 0:
        print ('Epoch [{}/{}], Loss: {:.4f}'.format(epoch+1, num_epochs, loss.item()))

# Test the model
# In test phase, we don't need to compute gradients (for memory efficiency)
# with torch.no_grad():
#     correct = 0
#     total = 0
#     for images, labels in test_loader:
#         images = images.reshape(-1, input_size)
#         outputs = model(images)
#         _, predicted = torch.max(outputs.data, 1)
#         total += labels.size(0)
#         correct += (predicted == labels).sum()

Epoch [10/500], Loss: 4.4146
Epoch [20/500], Loss: 3.9380
Epoch [30/500], Loss: 2.6179
Epoch [40/500], Loss: 1.1742
Epoch [50/500], Loss: 0.7223
Epoch [60/500], Loss: 0.6666
Epoch [70/500], Loss: 0.6982
Epoch [80/500], Loss: 0.7414
Epoch [90/500], Loss: 0.7505
Epoch [100/500], Loss: 0.7323
Epoch [110/500], Loss: 0.7145
Epoch [120/500], Loss: 0.7069
Epoch [130/500], Loss: 0.7049
Epoch [140/500], Loss: 0.7031
Epoch [150/500], Loss: 0.6999
Epoch [160/500], Loss: 0.6966
Epoch [170/500], Loss: 0.6941
Epoch [180/500], Loss: 0.6923
Epoch [190/500], Loss: 0.6907
Epoch [200/500], Loss: 0.6893
Epoch [210/500], Loss: 0.6880
Epoch [220/500], Loss: 0.6869
Epoch [230/500], Loss: 0.6860
Epoch [240/500], Loss: 0.6852
Epoch [250/500], Loss: 0.6845
Epoch [260/500], Loss: 0.6839
Epoch [270/500], Loss: 0.6833
Epoch [280/500], Loss: 0.6829
Epoch [290/500], Loss: 0.6825
Epoch [300/500], Loss: 0.6821
Epoch [310/500], Loss: 0.6818
Epoch [320/500], Loss: 0.6815
Epoch [330/500], Loss: 0.6813
Epoch [340/500], Lo

In [32]:
F.cross_entropy(y_pred, y[train_idx][n_hold:].view(-1))

tensor(0.4589, grad_fn=<NllLossBackward>)

In [33]:
hsic(model(x), genders.float())

tensor(3.5701, grad_fn=<DivBackward0>)

In [34]:
pred = torch.sigmoid(model(x[train_idx][:n_hold]))
_, target = torch.max(pred, 1)

In [35]:
target = target.view((-1, 1))

In [36]:
accuracy(y[train_idx][:n_hold], target)

tensor(0.8053)

In [37]:
_, pred_f = torch.max(torch.sigmoid(model(x[test_idx])), 1)

In [38]:
accuracy(y[test_idx], pred_f)

tensor(0.6723)

In [39]:
num_epochs = 500
batch_size = 100
learning_rate = 0.001

log_reg = nn.Linear(x.shape[1], 2)

# Loss and optimizer
# nn.CrossEntropyLoss() computes softmax internally
# criterion = F.cross_entropy() 
optimizer = torch.optim.Adam(log_reg.parameters(), lr=learning_rate, weight_decay=1.0)  

# Train the model
# total_step = len(train_loader)
for epoch in range(num_epochs):
#     for i, (images, labels) in enumerate(train_loader):
#         # Reshape images to (batch_size, input_size)
#         images = images.reshape(-1, input_size)
        
    # Forward pass
    y_pred = torch.sigmoid(log_reg(x[train_idx][n_hold:]))
    loss = F.cross_entropy(y_pred, y[train_idx][n_hold:].view(-1))

    # Backward and optimize
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    
    if (epoch+1) % 10 == 0:
        print ('Epoch [{}/{}], Loss: {:.4f}'.format(epoch+1, num_epochs, loss.item()))

# Test the model
# In test phase, we don't need to compute gradients (for memory efficiency)
# with torch.no_grad():
#     correct = 0
#     total = 0
#     for images, labels in test_loader:
#         images = images.reshape(-1, input_size)
#         outputs = model(images)
#         _, predicted = torch.max(outputs.data, 1)
#         total += labels.size(0)
#         correct += (predicted == labels).sum()

Epoch [10/500], Loss: 0.5544
Epoch [20/500], Loss: 0.5187
Epoch [30/500], Loss: 0.5266
Epoch [40/500], Loss: 0.5339
Epoch [50/500], Loss: 0.5342
Epoch [60/500], Loss: 0.5333
Epoch [70/500], Loss: 0.5332
Epoch [80/500], Loss: 0.5333
Epoch [90/500], Loss: 0.5332
Epoch [100/500], Loss: 0.5332
Epoch [110/500], Loss: 0.5332
Epoch [120/500], Loss: 0.5332
Epoch [130/500], Loss: 0.5332
Epoch [140/500], Loss: 0.5332
Epoch [150/500], Loss: 0.5332
Epoch [160/500], Loss: 0.5332
Epoch [170/500], Loss: 0.5332
Epoch [180/500], Loss: 0.5332
Epoch [190/500], Loss: 0.5332
Epoch [200/500], Loss: 0.5332
Epoch [210/500], Loss: 0.5332
Epoch [220/500], Loss: 0.5332
Epoch [230/500], Loss: 0.5332
Epoch [240/500], Loss: 0.5332
Epoch [250/500], Loss: 0.5332
Epoch [260/500], Loss: 0.5332
Epoch [270/500], Loss: 0.5332
Epoch [280/500], Loss: 0.5332
Epoch [290/500], Loss: 0.5332
Epoch [300/500], Loss: 0.5332
Epoch [310/500], Loss: 0.5332
Epoch [320/500], Loss: 0.5332
Epoch [330/500], Loss: 0.5332
Epoch [340/500], Lo

In [91]:
hsic(log_reg(x), genders)

tensor(1.1462e-06, grad_fn=<DivBackward0>)

In [13]:
hsic(log_reg(x), genders.float())

tensor(0.0005, grad_fn=<DivBackward0>)

In [15]:
proba_ = torch.sigmoid(log_reg(x[train_idx][:n_hold]))
_, pred_lr_f = torch.max(proba_, 1)

In [16]:
accuracy(y[train_idx][:n_hold], pred_lr_f)

tensor(0.9787)

In [17]:
proba_test = torch.sigmoid(log_reg(x[test_idx]))
_, pred_lr_test = torch.max(proba_test, 1)

In [18]:
accuracy(y[test_idx], pred_lr_test)

tensor(0.9947)

In [62]:
model.weight.data.numpy().T

array([[ 0.02604132, -0.02598074],
       [-0.00755683,  0.00788315],
       [-0.02889398,  0.02911591],
       ...,
       [-0.01741947,  0.01761098],
       [-0.01815327,  0.01827596],
       [-0.01183975,  0.01192712]], dtype=float32)

In [40]:
from scipy.stats import pearsonr
corr, _ = pearsonr(model.weight.data.numpy().T[:, 0], log_reg.weight.data.numpy().T[:, 0])

In [41]:
corr

0.39115288207116156