In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm

from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import confusion_matrix, accuracy_score, roc_auc_score, f1_score, precision_recall_curve, roc_curve, RocCurveDisplay, PrecisionRecallDisplay
from scipy.special import softmax

from imblearn.ensemble import BalancedRandomForestClassifier
import xgboost as xgb
from sklearn.ensemble import RandomForestClassifier
from sklearn.linear_model import LogisticRegression, LogisticRegressionCV


import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from pytorch_metric_learning import losses
import copy

# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Let pandas display up to 999 rows before truncating a frame.
pd.options.display.max_rows = 999

In [None]:
# CSV files read by the loading cells below (paths are relative to the working directory).
dataset = "AD_patients.csv"
dataset1 = "AD_patients_subset.csv"

In [None]:
# Load the full table to inspect its shape; `patient_df` is re-read from the subset file in the next cell.
patient_df = pd.read_csv(dataset, index_col=0)
patient_df.shape

In [None]:
# Load the subset table, drop sparse columns and mean-impute the remaining gaps.
patient_df = pd.read_csv(dataset1, index_col=0)
missingness = patient_df.isna().mean(axis=0)
# Columns with more than 30% missing values are removed entirely.
patient_df = patient_df.drop(columns=missingness[missingness > 0.3].index)
# NOTE(review): the column means are computed over all rows, i.e. before the train/test split below.
patient_df = patient_df.apply(lambda x: x.fillna(x.mean()),axis=0)

patient_df["family history"] = patient_df["family history"].astype(int)

label = patient_df["AD"].to_numpy()
# Everything between the first column and the last 40 columns is treated as demographic input.
demographic_df = patient_df[patient_df.columns[1:-40]].copy()
demographic_df["age"] = 2021 - demographic_df["year of birth"]
demographic_df1 = demographic_df.drop(columns=["year of birth"])
# Feature names must come from the frame that is actually turned into the feature matrix
# (the one without "year of birth"); otherwise there is one name too many and every
# `pd.Series(coefs, index=colnames)` below fails with a length mismatch.
colnames = demographic_df1.columns
demographic_df = demographic_df1.to_numpy()

# The last 40 columns are the genotype features.
genotype_df = patient_df[patient_df.columns[-40:]].to_numpy()

print(label.shape, demographic_df.shape, genotype_df.shape)

# Stratified 80/20 split applied jointly to both feature groups and the label.
demo_train, demo_test, geno_train, geno_test, y_train, y_test = train_test_split(
    demographic_df, genotype_df, label, test_size=0.2, random_state=42, stratify=label)
print(y_train.mean(), y_test.mean())

# Scalers are fitted on the training part only and then applied to the test part.
demo_scaler = StandardScaler()
demo_train = demo_scaler.fit_transform(demo_train)
demo_test = demo_scaler.transform(demo_test)

geno_scaler = StandardScaler()
geno_train = geno_scaler.fit_transform(geno_train)
geno_test = geno_scaler.transform(geno_test)

[a.shape for a in [demo_train, demo_test, geno_train, geno_test, y_train, y_test]]

In [None]:
# Carve a stratified validation set (12.5% of the training rows) out of the scaled training data;
# the "*1" arrays are the remaining training rows used by the neural-network cells.
demo_train1, demo_val, geno_train1, geno_val, y_train1, y_val = train_test_split(
    demo_train, geno_train, y_train, test_size=0.125, random_state=42, stratify=y_train)

In [None]:
# Combined feature matrices: demographic columns first, genotype columns after them.
comb_train = np.concatenate([demo_train, geno_train], axis=1)
comb_test = np.concatenate([demo_test, geno_test], axis=1)

In [None]:
def get_performance(y_test, y_pred, plot=False):
    """
    Print classification metrics for two-column class scores and optionally plot curves.

    y_test -- true binary labels
    y_pred -- array of shape (n_samples, 2); column 1 is the score of the positive (AD) class
    plot   -- when True, draw the ROC and precision-recall curves side by side

    Prints accuracy, ROC AUC, F1 and the per-class accuracy (diagonal of the
    row-normalised confusion matrix). Returns None.
    """
    hard_pred = y_pred.argmax(axis=1)   # predicted class per sample
    pos_score = y_pred[:,1]             # score used for the threshold-free metrics

    acc = accuracy_score(y_test, hard_pred)
    auc = roc_auc_score(y_test, pos_score)
    f1 = f1_score(y_test, hard_pred)
    class_acc = confusion_matrix(y_test, hard_pred, normalize="true").diagonal()
    print(f"Accuracy: {acc}, AUC: {auc}, f1:{f1}, non-AD accuracy: {class_acc[0]}, AD accuracy: {class_acc[1]}")

    if not plot:
        return

    fpr, tpr, _ = roc_curve(y_test, pos_score)
    prec, recall, _ = precision_recall_curve(y_test, pos_score)
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 5))
    RocCurveDisplay(fpr=fpr, tpr=tpr).plot(ax=ax1)
    PrecisionRecallDisplay(precision=prec, recall=recall).plot(ax=ax2)
    plt.show()

## Baseline Majority class

In [None]:
# Majority-class baseline: every test sample gets score 1 for class 0 and score 0 for class 1.
# `axis=1` yields the (n_samples, 2) layout that get_performance expects.
maj_class = np.stack([np.ones(len(y_test)), np.zeros(len(y_test))], axis=1)
get_performance(y_test, maj_class)

## Logistic regression

In [None]:
# Elastic-net logistic regression (cross-validated) on the demographic features only.
demo_lr = LogisticRegressionCV(max_iter=3000, class_weight="balanced",  penalty="elasticnet", solver="saga", l1_ratios=[0.5, 0.8])
demo_lr.fit(demo_train, y_train)
y_pred = demo_lr.predict_proba(demo_test)
demo_logreg = y_pred  # kept for the final model-comparison cell
get_performance(y_test, y_pred, True)

In [None]:
# Horizontal bar chart of the fitted demographic-model coefficients, one bar per feature name.
coefs = demo_lr.coef_.flatten()
weights = pd.Series(coefs, index=colnames)
fig, ax = plt.subplots(figsize=(6,10))
weights.plot.barh(ax=ax)
plt.show()

In [None]:
# Same elastic-net logistic regression, fitted on demographic + genotype features.
comb_lr = LogisticRegressionCV(max_iter=3000, class_weight="balanced",  penalty="elasticnet", solver="saga", l1_ratios=[0.5, 0.8])
comb_lr.fit(comb_train, y_train)
y_pred = comb_lr.predict_proba(comb_test)
logreg_pred = y_pred  # kept for the final model-comparison cell
get_performance(y_test, y_pred, True)

In [None]:
coefs = comb_lr.coef_.flatten()
# Rebuild the feature names in the exact column order of `comb_train`:
# demographic columns without "year of birth", then the derived "age" column,
# then the 40 genotype columns. `patient_df.columns[1:]` has the same length
# but a different order, which would attach coefficients to the wrong names.
demo_names = [c for c in patient_df.columns[1:-40] if c != "year of birth"] + ["age"]
geno_names = list(patient_df.columns[-40:])
weights = pd.Series(coefs, index=demo_names + geno_names)
plt.figure(figsize=(12,8))
weights.plot.barh()
plt.show()

In [None]:
# L1-penalised variant on the demographic features; this rebinds `demo_lr` and `y_pred`.
demo_lr = LogisticRegressionCV(max_iter=3000, class_weight="balanced",  penalty="l1", solver="saga")
demo_lr.fit(demo_train, y_train)
y_pred = demo_lr.predict_proba(demo_test)
get_performance(y_test, y_pred, True)

In [None]:
# Horizontal bar chart of the L1 model's coefficients, one bar per demographic feature name.
coefs = demo_lr.coef_.flatten()
weights = pd.Series(coefs, index=colnames)
fig, ax = plt.subplots(figsize=(12,8))
weights.plot.barh(ax=ax)
plt.show()

In [None]:
# Plain (non-CV) class-balanced logistic regression on the combined features; rebinds `comb_lr`.
comb_lr = LogisticRegression(max_iter=1000, class_weight="balanced")
comb_lr.fit(comb_train, y_train)
y_pred = comb_lr.predict_proba(comb_test)
get_performance(y_test, y_pred, True)

In [None]:
# Plain class-balanced logistic regression on the genotype features only.
geno_lr = LogisticRegression(max_iter=1000, class_weight="balanced")
geno_lr.fit(geno_train, y_train)
y_pred = geno_lr.predict_proba(geno_test)
get_performance(y_test, y_pred, True)

## Random Forest

In [None]:
# Balanced random forest on the demographic features.
# NOTE: the variable is still called `demo_lr` although it now holds a forest; the next cell reads it.
demo_lr = BalancedRandomForestClassifier(class_weight="balanced_subsample", n_estimators=500)
demo_lr.fit(demo_train, y_train)
y_pred = demo_lr.predict_proba(demo_test)
randomforest_pred_demo = y_pred  # kept for the final model-comparison cell
get_performance(y_test, y_pred, True)

In [None]:
importances = demo_lr.feature_importances_
# Spread of each feature's importance across the individual trees of the forest.
# It has to be read from every tree; repeating the forest-level importances would always give 0.
std = np.std([tree.feature_importances_ for tree in demo_lr.estimators_], axis=0)
forest_importances = pd.Series(importances, index=colnames)
plt.figure(figsize=(12,8))
forest_importances.plot.barh(yerr=std)
plt.show()

In [None]:
# Balanced random forest on demographic + genotype features; rebinds `comb_lr`.
comb_lr = BalancedRandomForestClassifier(class_weight="balanced_subsample", n_estimators=500)
comb_lr.fit(comb_train, y_train)
y_pred = comb_lr.predict_proba(comb_test)
randomforest_pred_comb = y_pred  # kept for the final model-comparison cell
get_performance(y_test, y_pred, True)

## XGBoost

In [None]:
# XGBoost on the demographic features.
# NOTE(review): scale_pos_weight is a fixed 1000 here, while the cells below derive it from the class ratio.
demo_lr = xgb.XGBClassifier(n_estimators=200, scale_pos_weight=1000, objective="binary:logistic")
demo_lr.fit(demo_train, y_train)
y_pred = demo_lr.predict_proba(demo_test)
xgb_pred_demo = y_pred  # kept for the final model-comparison cell
get_performance(y_test, y_pred, True)

In [None]:
# XGBoost on the combined features; positive weight = 10 x (number of negatives / number of positives).
comb_lr = xgb.XGBClassifier(scale_pos_weight=(y_train==0).sum() / y_train.sum()*10, objective="binary:logistic")
comb_lr.fit(comb_train, y_train)
y_pred = comb_lr.predict_proba(comb_test)
xgb_pred_comb = y_pred  # kept for the final model-comparison cell
get_performance(y_test, y_pred, True)

In [None]:
# XGBoost on the genotype features only, with the same class-ratio based positive weight.
geno_lr = xgb.XGBClassifier(scale_pos_weight=(y_train==0).sum() / y_train.sum()*10, objective="binary:logistic")
geno_lr.fit(geno_train, y_train)
y_pred = geno_lr.predict_proba(geno_test)
get_performance(y_test, y_pred, True)

In [None]:
class ADDataset(Dataset):
    """
    Map-style dataset over a feature matrix and binary labels.

    Labels are stored one-hot encoded (two columns). Each item is moved to the
    notebook-level `device` when it is fetched: features as float32, the
    one-hot target as int64.
    """

    def __init__(self, data_x, data_y):
        # data_x: numpy feature matrix; data_y: numpy array of 0/1 labels
        self.data_x = torch.from_numpy(data_x)
        one_hot = np.eye(2)[data_y.astype(int)]
        self.data_y = torch.from_numpy(one_hot)

    def __getitem__(self, idx):
        features = self.data_x[idx].to(device, dtype=torch.float32)
        target = self.data_y[idx].to(device, dtype=torch.int64)
        return features, target

    def __len__(self):
        return len(self.data_x)

class ADModel(nn.Module):
    """
    Three-hidden-layer MLP with ReLU activations and a two-class softmax output.

    input_size -- number of input features
    hidden_1, hidden_2, hidden_3 -- widths of the three hidden layers

    NOTE(review): the output is already softmax-normalised, while the training
    cells feed it to CrossEntropyLoss; confirm that this is intended.
    """

    def __init__(self, input_size, hidden_1, hidden_2, hidden_3):
        super().__init__()
        widths = [input_size, hidden_1, hidden_2, hidden_3]
        layers = []
        # One Linear + ReLU pair per hidden layer, created in input-to-output order
        for n_in, n_out in zip(widths[:-1], widths[1:]):
            layers.append(nn.Linear(n_in, n_out))
            layers.append(nn.ReLU())
        layers.append(nn.Linear(hidden_3, 2))
        layers.append(nn.Softmax(dim=1))
        self.linear_relu_stack = nn.Sequential(*layers)

    def forward(self, x):
        return self.linear_relu_stack(x)


def categorical_accuracy(preds, y):
    """
    Fraction of rows whose highest-scoring class matches the one-hot target.

    preds -- (batch, n_classes) score tensor
    y     -- (batch, n_classes) one-hot target tensor

    Returns a one-element CPU float tensor, e.g. 0.8 for 8 hits out of 10.
    """
    predicted_class = preds.argmax(dim = 1)
    true_class = y.argmax(dim=1)
    n_hits = predicted_class.eq(true_class).sum().cpu()
    batch_size = torch.FloatTensor([true_class.shape[0]])
    return n_hits / batch_size

def train(model, iterator, optimizer, criterion):
    """
    Run one optimisation pass over `iterator`.

    model     -- network to train (put into training mode here)
    iterator  -- yields (features, one-hot target) batches; must support len()
    optimizer -- optimizer stepped once per batch
    criterion -- loss called as criterion(predictions, target.float())

    Returns (mean batch loss, mean batch accuracy); both are unweighted
    averages over the batches.
    """
    model.train()
    running_loss = 0
    running_acc = 0
    for features, labels in iterator:
        optimizer.zero_grad()
        outputs = model(features).float()
        batch_loss = criterion(outputs, labels.float())
        batch_acc = categorical_accuracy(outputs, labels)
        batch_loss.backward()
        optimizer.step()
        running_loss += batch_loss.item()
        running_acc += batch_acc.item()
    return running_loss / len(iterator), running_acc / len(iterator)

def evaluate(model, iterator, criterion):
    """
    Collect model outputs and targets over `iterator` without tracking gradients.

    model     -- network to evaluate (put into eval mode here)
    iterator  -- yields (features, target) batches
    criterion -- accepted for signature symmetry with `train`; not used

    Returns (truth, preds): numpy arrays with the batches concatenated along axis 0.
    """
    model.eval()
    batch_truth, batch_preds = [], []
    with torch.no_grad():
        for features, labels in iterator:
            outputs = model(features.float()).cpu()
            batch_preds.append(outputs.numpy())
            batch_truth.append(labels.cpu().numpy())
    return np.concatenate(batch_truth, axis=0), np.concatenate(batch_preds, axis=0)

In [None]:
# (train, validation, test) triples for each feature set used by the MLP cells.
demo = [demo_train1, demo_val, demo_test]
geno = [geno_train1, geno_val, geno_test]
# Combined = demographic columns followed by genotype columns, split by split.
comb = [
    np.concatenate([demo_part, geno_part], axis=1)
    for demo_part, geno_part in zip(demo, geno)
]

In [None]:
# MLP on the demographic features: build loaders, train, keep the checkpoint with the best validation AUC.
x_train, x_val, x_test = demo
BATCH_SIZE = 512
dataset_Train = ADDataset(x_train, y_train1)
dataset_Val = ADDataset(x_val, y_val)
dataset_Test = ADDataset(x_test, y_test)
dataloader_train = DataLoader(dataset = dataset_Train, batch_size = BATCH_SIZE, shuffle = True)
dataloader_val = DataLoader(dataset = dataset_Val, batch_size = BATCH_SIZE, shuffle = True)
dataloader_test = DataLoader(dataset = dataset_Test, batch_size = BATCH_SIZE, shuffle = False)

model = ADModel(x_train.shape[1], 50, 20, 10).to(device)
# Class weights: class 0 gets the fraction of positives, class 1 the fraction of negatives.
weight = torch.tensor([sum(y_train1==1)/len(y_train1), sum(y_train1==0)/len(y_train1)]).to(device)
# weight = torch.tensor([0.05, 0.95]).to(device)
criterion = nn.CrossEntropyLoss(weight=weight) 
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

N_EPOCHS = 20
best_auc = 0
for epoch in range(N_EPOCHS):

    train_loss, train_acc = train(model, dataloader_train, optimizer, criterion)
    truth, preds = evaluate(model, dataloader_val, criterion)
    # truth is one-hot, so column 1 is the binary label of the positive class
    val_auc = roc_auc_score(truth[:,1], preds[:,1])
    print(f'Epoch {epoch+1}  Train Loss: {train_loss:.3f}  Train accuracy {train_acc:.4f}  Validation metrics:', end=" ") 
    get_performance(truth[:,1], preds)
    # checkpoint whenever the validation AUC improves
    if val_auc > best_auc:
        torch.save(model.state_dict(), 'best-model-parameters.pt')
        best_auc = val_auc

In [None]:
# Restore the checkpoint saved by the training loop above and score it on the test set.
model.load_state_dict(torch.load('best-model-parameters.pt'))
truth, preds = evaluate(model, dataloader_test, criterion)
mlp_preds_demo = preds  # kept for the final model-comparison cell
get_performance(truth[:,1], preds)

In [None]:
# MLP on demographic + genotype features: same procedure as the demographic-only MLP above.
x_train, x_val, x_test = comb

# Define the batch size before the loaders that use it, so this cell does not
# depend on a value left behind by another cell.
BATCH_SIZE = 512
dataset_Train = ADDataset(x_train, y_train1)
dataset_Val = ADDataset(x_val, y_val)
dataset_Test = ADDataset(x_test, y_test)
dataloader_train = DataLoader(dataset = dataset_Train, batch_size = BATCH_SIZE, shuffle = True)
dataloader_val = DataLoader(dataset = dataset_Val, batch_size = BATCH_SIZE, shuffle = True)
dataloader_test = DataLoader(dataset = dataset_Test, batch_size = BATCH_SIZE, shuffle = False)


model = ADModel(x_train.shape[1], 50, 25, 10).to(device)
# Class weights: class 0 gets the fraction of positives, class 1 the fraction of negatives.
weight = torch.tensor([sum(y_train1==1)/len(y_train1), sum(y_train1==0)/len(y_train1)]).to(device)
# weight = torch.tensor([0.05, 0.95]).to(device)
criterion = nn.CrossEntropyLoss(weight=weight)
optimizer = torch.optim.Adam(model.parameters(), lr=0.0005)

N_EPOCHS = 20
best_auc = 0
for epoch in range(N_EPOCHS):

    train_loss, train_acc = train(model, dataloader_train, optimizer, criterion)
    truth, preds = evaluate(model, dataloader_val, criterion)
    # truth is one-hot, so column 1 is the binary label of the positive class
    val_auc = roc_auc_score(truth[:,1], preds[:,1])
    print(f'Epoch {epoch+1}  Train Loss: {train_loss:.3f}  Train accuracy {train_acc:.4f}  Validation metrics:', end=" ")
    get_performance(truth[:,1], preds)
    # checkpoint whenever the validation AUC improves
    if val_auc > best_auc:
        torch.save(model.state_dict(), 'best-model-parameters.pt')
        best_auc = val_auc

# Restore the best checkpoint and score it on the test set.
model.load_state_dict(torch.load('best-model-parameters.pt'))
truth, preds = evaluate(model, dataloader_test, criterion)
mlp_preds_comb = preds  # kept for the final model-comparison cell
get_performance(truth[:,1], preds)

In [None]:
class ImageGenoDataset(Dataset):
    """
    Dataset pairing a primary input tensor with any number of extra modality
    tensors and a target, all indexed along the first dimension.
    """

    def __init__(self, images, modes, targets, transform):
        """
        images    -- tensor of primary inputs, one row per sample
        modes     -- list of tensors, one per additional modality, aligned with `images`
        targets   -- tensor of targets, aligned with `images`
        transform -- optional callable applied to the primary input on access (or None)

        All tensors are converted to torch.FloatTensor.
        """
        float_type = torch.FloatTensor
        self.images = images.type(float_type)
        self.modes = [mode.type(float_type) for mode in modes]
        self.targets = targets.type(float_type)
        self.transform = transform

    def __len__(self):
        return len(self.images)

    def __getitem__(self, index):
        sample = self.images[index]
        if self.transform is not None:
            sample = self.transform(sample)
        return sample, [mode[index] for mode in self.modes], self.targets[index]

class SNPEncoder(nn.Module):
    """
    MLP encoder that maps a feature vector to an L2-normalised embedding.

    Every hidden layer is a Linear layer, optionally followed by BatchNorm1d,
    and always followed by a ReLU. A final Linear layer projects to `out_dim`
    and each output row is normalised to unit length.
    """
    def __init__(self, x_dim, out_dim, fc_dims, use_bn = True):
        """
            x_dim -- Size of the input feature vector
            out_dim -- Size of the output embedding
            fc_dims -- A list with the widths of the fully connected hidden layers
            use_bn -- If True, add BatchNorm1d after every hidden Linear layer
        """
        super(SNPEncoder, self).__init__()
        self.x_dim = x_dim
        self.use_bn = use_bn
        self.enc_shape = fc_dims
        layers = []
        in_dim = self.x_dim
        for width in self.enc_shape:
            layers.append(nn.Linear(in_dim, width))
            if self.use_bn:
                layers.append(nn.BatchNorm1d(width))
            in_dim = width
        self.enc = nn.ModuleList(layers)
        # in_dim is the last hidden width (or x_dim when fc_dims is empty)
        self.out = nn.Linear(in_dim, out_dim)

    def forward(self, x):
        # encode x through the hidden stack
        for l in self.enc:
            x = l(x)
            # The ReLU closes each hidden block: after BatchNorm when it is used,
            # directly after the Linear layer otherwise. Without the second case
            # a use_bn=False encoder would be a purely linear map.
            if isinstance(l, nn.BatchNorm1d) or not self.use_bn:
                x = F.relu(x)
        z = self.out(x)
        z = F.normalize(z, dim = 1)
        return z

class MultimodalNet(nn.Module):
    def __init__(self, img_enc, img_kwargs, mode_encs, mode_kwargs, latent_dim = 20, p = 0.3, device = 'cpu'):
        """
        Multimodal network combining a primary ("image") encoder with encoders
        for additional modalities, all projecting to the same latent size.

            img_enc -- Encoder class (or factory) for the primary input
            img_kwargs -- Parameter dictionary for the primary encoder
            mode_encs -- List of encoder classes, one per additional modality
            mode_kwargs -- List of parameter dictionaries for each modality encoder
                mode_kwargs should not have an out_dim key, but every model should have
                an out_dim argument
            latent_dim -- Dimension of shared latent space
            p -- Accepted for interface compatibility; not used by this class
            device -- Device every encoder is moved to on construction
        """
        super(MultimodalNet, self).__init__()
        self.im_model = img_enc(**img_kwargs, out_dim = latent_dim).to(device)
        # nn.ModuleList (instead of a plain list) registers the modality encoders
        # as submodules, so they appear in parameters()/state_dict() and follow
        # .to(), .train() and .eval() calls made on this network.
        self.model = nn.ModuleList(
            [mode_encs[i](**mode_kwargs[i], out_dim = latent_dim).to(device)
             for i in range(len(mode_encs))]
        )

    def forward(self, im, modes):
        """
            im -- primary input tensor (batch_size x ...)
            modes -- list of modality batch tensors, in the order of mode_encs

        Returns (z_im, z_modes): the primary embedding and a list with one
        embedding per entry of `modes`.
        """
        z_im = self.im_model(im)
        z_modes = [self.model[i](modes[i]) for i in range(len(modes))]

        return z_im, z_modes

In [None]:
LARGE_NUM = 1e9

class NTXentLoss(torch.nn.Module):
    """
    Cross-modal contrastive loss between two batches of embeddings.
        https://github.com/HealthML/ContIG/blob/main/models/cross_modal_loss.py

    Row i of the first batch is the positive partner of row i of the second
    batch; all other rows of the batch act as negatives. Only cross-modal
    similarities are used (no intra-modal terms).
    """
    def __init__(self, device, batch_size, temperature, alpha_weight):
        """
        device -- device the target matrix is moved to
        batch_size -- stored only; forward() uses the actual batch size of its inputs
        temperature -- a `floating` number for temperature scaling
        alpha_weight -- weight of the first->second direction; the reverse gets 1 - alpha_weight
        """
        super(NTXentLoss, self).__init__()
        self.device = device
        self.batch_size = batch_size
        self.temperature = temperature
        self.alpha_weight = alpha_weight
        self.softmax = torch.nn.Softmax(dim=-1)
        self.criterion = torch.nn.CrossEntropyLoss(reduction="sum")

    def softXEnt(self, target, logits):
        """
        Cross entropy against a (soft) target distribution, averaged over rows.
        From the pytorch discussion Forum:
            https://discuss.pytorch.org/t/soft-cross-entropy-loss-tf-has-it-does-pytorch-have-it/69501
        """
        log_probs = F.log_softmax(logits, dim=1)
        return -(target * log_probs).sum() / logits.shape[0]

    def forward(self, zis, zjs, norm=True):
        """
        zis, zjs -- (batch, dim) embeddings of the two modalities, row-aligned
        norm -- L2-normalise both inputs first

        Returns alpha * loss(zis->zjs) + (1 - alpha) * loss(zjs->zis) plus the
        mean squared difference between the two logit matrices.
        """
        if norm:
            zis = F.normalize(zis, p=2, dim=1)
            zjs = F.normalize(zjs, p=2, dim=1)

        n_rows = zis.shape[0]
        # Identity matrix as targets: sample i matches sample i of the other modality
        row_ids = torch.arange(start=0, end=n_rows, dtype=torch.int64)
        labels = F.one_hot(row_ids, num_classes=n_rows).float().to(self.device)

        # Temperature-scaled cross-modal similarities in both directions
        logits_ab = torch.matmul(zis, zjs.transpose(0, 1)) / self.temperature
        logits_ba = torch.matmul(zjs, zis.transpose(0, 1)) / self.temperature

        loss_a = self.softXEnt(labels, logits_ab)
        loss_b = self.softXEnt(labels, logits_ba)
        weighted = self.alpha_weight * loss_a + (1 - self.alpha_weight) * loss_b
        return weighted + F.mse_loss(logits_ab, logits_ba)

In [None]:
def epoch_contrastive(model, criterion, loader, epoch, 
                       w = 1, optimizer = None, device = 'cpu'):
    """
    Run one contrastive epoch over `loader`.

    model     -- called as model(im, modes); returns (z_im, list of modality embeddings)
    criterion -- contrastive loss called as criterion(z_im, z_mode)
    loader    -- yields (im, modes, y) batches
    epoch     -- epoch index, used only for the progress-bar text
    w         -- factor applied to every modality's loss term
    optimizer -- when given, the model is trained; otherwise it is only evaluated
    device    -- device the batch tensors are moved to

    Returns the mean batch loss.
    """
    training = bool(optimizer)
    if training:
        model.train()
        mode = 'Train'
    else:
        model.eval()
        mode = 'Val'

    train_loss = []
    batches = tqdm(enumerate(loader), total=len(loader))
    batches.set_description("Epoch NA: Loss (NA)")

    for batch_idx, (im, modes, y) in batches:
        im, modes, y = im.to(device), [mode.to(device) for mode in modes], y.to(device)
        # Gradients are only needed when an optimizer will consume them;
        # evaluation passes skip building the autograd graph.
        with torch.set_grad_enabled(training):
            z_im, z_modes = model(im, modes)
            loss = 0
            for z_mode in z_modes:
                loss = loss + w * criterion(z_im, z_mode)

        if training:
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        train_loss.append(loss.item())
        batches.set_description(
            "Epoch {:d}: {:s} Loss ({:.2e})".format(
                epoch, mode, loss.item()
            )
        )

    return np.mean(train_loss)

def epoch_standard(model, criterion, loader, epoch, optimizer = None, device = 'cpu'):
    """
    Run one supervised epoch over `loader`.

    model     -- called as model(x); returns (batch, n_classes) scores
    criterion -- loss called as criterion(scores, integer class targets)
    loader    -- yields (x, modes, y) batches; the modality tensors are ignored
    epoch     -- epoch index, used only for the progress-bar text
    optimizer -- when given, the model is trained; otherwise it is only evaluated
    device    -- device the batch tensors (and the criterion) are moved to

    Returns (mean batch loss, accuracy in percent as a numpy value).
    """
    training = bool(optimizer)
    if training:
        model.train()
        mode = 'Train'
    else:
        model.eval()
        mode = 'Val'

    # Use the criterion that was passed in instead of rebuilding class weights
    # from a notebook global. A loss module may carry a float64 weight buffer
    # (built from numpy values); cast it so it matches float32 model outputs.
    if isinstance(criterion, nn.Module):
        criterion = criterion.to(device=device, dtype=torch.float32)

    train_loss = []
    batches = tqdm(enumerate(loader), total=len(loader))
    batches.set_description("Epoch NA: Loss (NA) ACC (NA)")

    count = 0
    correct = 0
    for batch_idx, (x, _, y) in batches:
        x, y = x.to(device), y.to(device)
        # Skip building the autograd graph during evaluation passes.
        with torch.set_grad_enabled(training):
            z = model(x)
            loss = criterion(z, y.to(torch.int64))
        if training:
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # running count of samples whose top-scoring class equals the target
        correct += (z.max(axis = 1).indices == y).float().sum()
        count += y.shape[0]

        train_loss.append(loss.item())
        batches.set_description(
            "Epoch {:d}: {:s} Loss ({:.2e}) ACC ({:.2e})".format(
                epoch, mode, loss.item(), 100 * correct / count
            )
        )

    return np.mean(train_loss), (100 * correct/count).detach().cpu().numpy()


In [None]:
# Datasets for the contrastive model: the demographic matrix takes the "image" slot and the
# genotype matrix is the single extra modality.
train_dataset = ImageGenoDataset(torch.from_numpy(demo_train1), [torch.from_numpy(geno_train1)], torch.from_numpy(y_train1), transform = None)
val_dataset = ImageGenoDataset(torch.from_numpy(demo_val), [torch.from_numpy(geno_val)], torch.from_numpy(y_val), transform = None)
test_dataset = ImageGenoDataset(torch.from_numpy(demo_test), [torch.from_numpy(geno_test)], torch.from_numpy(y_test), transform = None)

# NOTE: this rebinds the notebook-level BATCH_SIZE used by the earlier MLP cells.
BATCH_SIZE = 128
train_loader = DataLoader(train_dataset, batch_size = BATCH_SIZE, shuffle = True)
val_loader = DataLoader(val_dataset, batch_size = BATCH_SIZE, shuffle = False)
test_loader = DataLoader(test_dataset, batch_size = BATCH_SIZE, shuffle = False)

# Print the tensor shapes of the first training batch as a sanity check.
for im, geno, y in train_loader:
    print(im.shape, geno[0].shape, y.shape)
    break

In [None]:
Z_DIM = 30  # size of the shared latent space
LR = 1e-3
# Input widths are read from the data so the encoders stay in sync with the
# preprocessing (the missingness filter can change the number of columns).
demo_kwargs = {
    'x_dim' : demo_train1.shape[1],
    'fc_dims' : [200, 50],
    'use_bn' : True
}
geno_kwargs = {
    'x_dim' : geno_train1.shape[1],
    'fc_dims' : [200, 50],
    'use_bn' : True
}
# The demographic encoder takes the "image" slot; the genotype encoder is the only extra modality.
multi_model = MultimodalNet(SNPEncoder, demo_kwargs, 
                            [SNPEncoder], [geno_kwargs],
                            latent_dim = Z_DIM,
                            device = device).to(device)
optimizer = optim.AdamW(multi_model.parameters(), lr = LR)
scheduler = optim.lr_scheduler.ExponentialLR(optimizer, gamma = 0.95)
criterion_cl = NTXentLoss(device, 
                          BATCH_SIZE, 
                          temperature = 0.1, 
                          alpha_weight = 0.5) # 0.5 weights both directions equally (an earlier note mentioned 0.75 towards images, as in the ContIG paper)

In [None]:
# fit model
NUM_EPOCHS = 25
###### Train Model ######
train_losses = []
val_losses = []
best_val = float("inf")
best_epoch = 0

for epoch in tqdm(range(NUM_EPOCHS)):

    # train 
    train_loss = epoch_contrastive(multi_model, criterion_cl, train_loader, epoch, w = 1, optimizer = optimizer, device = device)
    train_losses.append(train_loss)

    # eval 
    val_loss = epoch_contrastive(multi_model, criterion_cl, val_loader, epoch, w = 1, optimizer = None, device = device)
    val_losses.append(val_loss)

    scheduler.step()

    # retain best val
    if best_val > val_losses[-1]:
        print(f"Updating at epoch {epoch}")
        best_val = val_losses[-1]
        best_epoch = epoch
        # save model parameter/state dictionary
        best_model = copy.deepcopy(multi_model.state_dict())

# load best weights
print(f"Best epoch at {best_epoch} with {'NTXent'} loss: {best_val}")

In [None]:
# Training vs. validation contrastive loss per epoch (same layout as the supervised loss plot below).
plt.plot(list(range(NUM_EPOCHS)), train_losses, label='Training', c = 'b')
plt.plot(list(range(NUM_EPOCHS)), val_losses, label='Validation', c = 'm')
plt.title('Loss')
plt.legend()  # the labels above are only shown once a legend is drawn
plt.show()

In [None]:
# Loading the best contrastive snapshot is disabled, so the encoder keeps its last-epoch weights:
# multi_model.load_state_dict(best_model)

LR = 3e-5
gamma = 0.95
# Classifier = pre-trained demographic encoder followed by a fresh MLP head on the latent vector.
classifier = nn.Sequential(multi_model.im_model, 
                            ADModel(Z_DIM, 50, 20, 10)).to(device)
# Class weights: class 0 gets the fraction of positives, class 1 the fraction of negatives.
# `.float()` keeps them float32 (numpy division yields float64), consistent with the float32
# weights epoch_standard uses and with the float32 model outputs.
weight = torch.tensor([sum(y_train1==1)/len(y_train1), sum(y_train1==0)/len(y_train1)]).float().to(device)
criterion_ce = nn.CrossEntropyLoss(weight=weight)
optimizer = optim.AdamW(classifier.parameters(), lr = LR)
scheduler = optim.lr_scheduler.ExponentialLR(optimizer, gamma = gamma)

In [None]:
# fit model
NUM_EPOCHS = 20
###### Train Model ######
train_losses = []
train_accs = []
val_losses = []
val_accs = []
best_val = float("inf")
best_epoch = 0

for epoch in tqdm(range(NUM_EPOCHS)):

    # train 
    train_loss, train_acc = epoch_standard(classifier, criterion_ce, train_loader, epoch, optimizer = optimizer, device = device)
    train_losses.append(train_loss)
    train_accs.append(train_acc)

    # eval 
    val_loss, val_acc = epoch_standard(classifier, criterion_ce, val_loader, epoch, optimizer = None, device = device)
    val_losses.append(val_loss)
    val_accs.append(val_acc)
    scheduler.step()
    # retain best val
    if best_val > val_losses[-1]:
        print(f"Updating at epoch {epoch}")
        best_val = val_losses[-1]
        best_epoch = epoch
        # save model parameter/state dictionary
        best_model = copy.deepcopy(classifier.state_dict())
classifier.load_state_dict(best_model)
# load best weights
print(f"Best epoch at {best_epoch} with {'CrossEntropy'} loss: {best_val}")

In [None]:
# One figure per metric: training (blue) against validation (magenta) over the epochs.
epochs = list(range(NUM_EPOCHS))
for title, train_curve, val_curve in [
    ('Loss', train_losses, val_losses),
    ('Accuracy', train_accs, val_accs),
]:
    plt.plot(epochs, train_curve, label='Training', c = 'b')
    plt.plot(epochs, val_curve, label='Validation', c = 'm')
    plt.title(title)
    plt.legend()
    plt.show()

In [None]:
# Score the fine-tuned classifier on the test set.
y_pred = []
y_true = []
# Set inference mode explicitly (the encoder contains BatchNorm layers) instead of
# relying on the mode left behind by the last training-loop call, and skip gradient tracking.
classifier.eval()
with torch.no_grad():
    for im, geno, y in test_loader:
        y = y.type(torch.int64)
        im = im.to(device)
        y_pred += list(classifier(im).cpu().numpy())
        y_true += list(y.cpu().numpy())

get_performance(y_true, np.stack(y_pred, axis=0))

In [None]:
# Per-sample classifier outputs as one (n_samples, 2) array for the final comparison.
contrast_pred = np.stack(y_pred, axis=0)

In [None]:
# Test-set score arrays of every model, in the same order as `names` below.
# The first entry is the majority-class baseline (score 1 for class 0 on every sample).
predictions = [
    np.stack([np.ones(len(y_test)), np.zeros(len(y_test))], axis=1), 
    demo_logreg, logreg_pred, 
    randomforest_pred_demo, randomforest_pred_comb, 
    xgb_pred_demo, xgb_pred_comb, 
    mlp_preds_demo, mlp_preds_comb, 
    contrast_pred
]
# Legend labels; must stay aligned one-to-one with `predictions`.
names = [
    "MajorityClass", 
    "LogisticReg (lab)", "LogisticReg (lab+genotype)", 
    "BalancedRF (lab)", "BalancedRF (lab+genotype)", 
    "XGBoost (lab)", "XGBoost (lab+genotype)", 
    "MLP (lab)", "MLP (lab+genotype)",
    "ContIG"
]

In [None]:
# ROC curves of all models in one figure.
# Every prediction array was produced on the unshuffled test split, so `y_test` is the
# single ground truth for the printed metrics, the AUC and the curve.
plt.figure(figsize=(8,6))
for name, p in zip(names, predictions):
    print(name)
    get_performance(y_test, p)
    auc = roc_auc_score(y_test, p[:,1])
    fpr, tpr, _ = roc_curve(y_test, p[:,1])
    plt.plot(fpr, tpr, label=name + f" AUC: {auc:.3f}")
# axis labels only need to be set once
plt.ylabel("True Positive Rate")
plt.xlabel("False Positive Rate")
plt.legend()
plt.savefig("model_performance.png", dpi=400)
plt.show()

In [None]:
# Number of rows in each prediction array (expected to equal the test-set size for all of them).
[len(p) for p in predictions]

In [None]:
# NOTE(review): this import sits mid-notebook; consider moving it to the import cell at the top.
from statsmodels.miscmodels.ordinal_model import OrderedModel

# Appends the label as the last column of demographic_df1 (in-place change of the frame).
demographic_df1["AD"] = label

# Logit-link ordered model of AD on all demographic columns (everything except the last, "AD").
mod_log = OrderedModel(
    demographic_df1["AD"], demographic_df1[demographic_df1.columns[:-1]], distr="logit"
)
res_log = mod_log.fit(method="bfgs", disp=False)
res_log.summary()


In [None]:
# Predictor columns used by the ordered model above (all columns except the appended "AD").
demographic_df1.columns[:-1]