In [1]:
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, Subset
from sklearn.model_selection import train_test_split, KFold
from sklearn.metrics import confusion_matrix,roc_curve,auc,recall_score, precision_score,accuracy_score,matthews_corrcoef, f1_score
import matplotlib.pyplot as plt
from matplotlib.ticker import MultipleLocator
import warnings
warnings.filterwarnings("ignore")
from sklearn import svm
from sklearn.svm import SVC
from sklearn.linear_model import LogisticRegression
from sklearn.tree import DecisionTreeClassifier
from sklearn.ensemble import RandomForestClassifier
from sklearn.decomposition import PCA
from tqdm import tqdm
from torch.utils.data import random_split
from scipy.interpolate import make_interp_spline
from scipy.signal import savgol_filter

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
torch.manual_seed(3407)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
#device = torch.device('cpu')
torch.cuda.get_device_name()

'NVIDIA GeForce RTX 3090 Ti'

In [3]:
cli_data_path = '/data/minwenwen/lixiaoyu/TCGA/data/BRCA/BRCA-300/BRCA_Clinical.csv'
cna_data_path = '/data/minwenwen/lixiaoyu/TCGA/data/BRCA/BRCA-300/BRCA_CNA.csv'
rna_data_path = '/data/minwenwen/lixiaoyu/TCGA/data/BRCA/BRCA-300/BRCA_RNA_V2_EXP.csv'
mic_data_path = '/data/minwenwen/lixiaoyu/TCGA/data/BRCA/BRCA-300/BRCA_microbiome.csv'
label_data_path = '/data/minwenwen/lixiaoyu/TCGA/data/BRCA/BRCA-300/BRCA_label.csv'
class MyDataset(Dataset):
    def __init__(self, data_path, label_path):
        super(MyDataset, self).__init__()
        self.data = pd.read_csv(data_path, index_col=0)
        self.label = pd.read_csv(label_path)
        self.x_data = torch.tensor(self.data.values, dtype=torch.float32)
        self.y_data = torch.tensor(self.label.values, dtype=torch.float32)
    def __getitem__(self, idx):
        return self.x_data[idx], self.y_data[idx]
    def __len__(self):
        return len(self.x_data)

cli_data = MyDataset(cli_data_path, label_data_path)
cna_data = MyDataset(cna_data_path, label_data_path)
rna_data = MyDataset(rna_data_path, label_data_path)
mic_data = MyDataset(mic_data_path, label_data_path)

train_size = int(0.8 * len(cli_data))
test_size = len(cli_data) - train_size
cli_train_dataset, cli_test_dataset = random_split(cli_data, [train_size, test_size], generator=torch.Generator().manual_seed(3407))
valid_size = int(0.25 * len(cli_train_dataset))
train_size = len(cli_train_dataset) - valid_size
cli_train_dataset, cli_valid_dataset = random_split(cli_train_dataset, [train_size, valid_size], generator=torch.Generator().manual_seed(3407))

cli_train_dataloader = DataLoader(cli_train_dataset, batch_size=len(cli_train_dataset))
cli_valid_dataloader = DataLoader(cli_valid_dataset, batch_size=len(cli_valid_dataset))
cli_test_dataloader = DataLoader(cli_test_dataset, batch_size=len(cli_test_dataset))

train_size = int(0.8 * len(cna_data))
test_size = len(cna_data) - train_size
cna_train_dataset, cna_test_dataset = random_split(cna_data, [train_size, test_size], generator=torch.Generator().manual_seed(3407))
valid_size = int(0.25 * len(cna_train_dataset))
train_size = len(cna_train_dataset) - valid_size
cna_train_dataset, cna_valid_dataset = random_split(cna_train_dataset, [train_size, valid_size], generator=torch.Generator().manual_seed(3407))
cna_train_dataloader = DataLoader(cna_train_dataset, batch_size=len(cna_train_dataset))
cna_valid_dataloader = DataLoader(cna_valid_dataset, batch_size=len(cna_valid_dataset))
cna_test_dataloader = DataLoader(cna_test_dataset, batch_size=len(cna_test_dataset))

train_size = int(0.8 * len(rna_data))
test_size = len(rna_data) - train_size
rna_train_dataset, rna_test_dataset = random_split(rna_data, [train_size, test_size], generator=torch.Generator().manual_seed(3407))
valid_size = int(0.25 * len(rna_train_dataset))
train_size = len(rna_train_dataset) - valid_size
rna_train_dataset, rna_valid_dataset = random_split(rna_train_dataset, [train_size, valid_size], generator=torch.Generator().manual_seed(3407))
rna_train_dataloader = DataLoader(rna_train_dataset, batch_size=len(rna_train_dataset))
rna_valid_dataloader = DataLoader(rna_valid_dataset, batch_size=len(rna_valid_dataset))
rna_test_dataloader = DataLoader(rna_test_dataset, batch_size=len(rna_test_dataset))

train_size = int(0.8 * len(mic_data))
test_size = len(mic_data) - train_size
mic_train_dataset, mic_test_dataset = random_split(mic_data, [train_size, test_size], generator=torch.Generator().manual_seed(3407))
valid_size = int(0.25 * len(mic_train_dataset))
train_size = len(mic_train_dataset) - valid_size
mic_train_dataset, mic_valid_dataset = random_split(mic_train_dataset, [train_size, valid_size], generator=torch.Generator().manual_seed(3407))
mic_train_dataloader = DataLoader(mic_train_dataset, batch_size=len(mic_train_dataset))
mic_valid_dataloader = DataLoader(mic_valid_dataset, batch_size=len(mic_valid_dataset))
mic_test_dataloader = DataLoader(mic_test_dataset, batch_size=len(mic_test_dataset))

In [4]:
x = pd.read_csv('/data/minwenwen/lixiaoyu/TCGA/data/BRCA/BRCA-300/BRCA_CNA.csv', index_col=0)
x = torch.tensor(x.values, dtype=torch.float32)
x.shape[1]

300

In [5]:
#MultiHeadAttention
class Attention(nn.Module):
    def __init__(self, in_size, hidden_size=16):
        super(Attention, self).__init__()
        self.project = nn.Sequential(
            nn.Linear(in_size, hidden_size),
            nn.Tanh(),
            nn.Linear(hidden_size, 1, bias=False)
        )
    def forward(self, z):
        w = self.project(z)
        beta = torch.softmax(w, dim=1)
        return (beta * z).sum(1), beta

class ScaledDotProductAttention(nn.Module):
    """ Scaled Dot-Product Attention """
    def __init__(self, scale):
        super().__init__()
        self.scale = scale
        self.softmax = nn.Softmax(dim=1)
    def forward(self, q, k, v, mask=None):
        u = torch.matmul(q, k.transpose(-2, -1))
        u = u / self.scale
        if mask is not None:
            u = u.masked_fill(mask, -np.inf)
        attn = self.softmax(u)
        output = torch.matmul(attn, v)
        return attn, output

class MultiHeadAttention(nn.Module):
    def __init__(self, n_head, d_k_, d_v_, d_k, d_v, d_o):
        super().__init__()
        self.n_head = n_head
        self.d_k = d_k
        self.d_v = d_v
        self.fc_q = nn.Linear(d_k_, n_head * d_k)
        self.fc_k = nn.Linear(d_k_, n_head * d_k)
        self.fc_v = nn.Linear(d_v_, n_head * d_v)
        self.attention = ScaledDotProductAttention(scale=np.power(d_k, 0.5))
        self.fc_o = nn.Linear(n_head * d_v, d_o)
    def forward(self, q, k, v, mask=None):
        n_head, d_q, d_k, d_v = self.n_head, self.d_k, self.d_k, self.d_v
        n_q, d_q_ = q.size()
        n_k, d_k_ = k.size()
        n_v, d_v_ = v.size()
        q = self.fc_q(q)
        k = self.fc_k(k)
        v = self.fc_v(v)
        q = q.view(n_q, n_head, d_q).permute(1, 0, 2).contiguous().view(-1, n_q, d_q)
        k = k.view(n_k, n_head, d_k).permute(1, 0, 2).contiguous().view(-1, n_k, d_k)
        v = v.view(n_v, n_head, d_v).permute(1, 0, 2).contiguous().view(-1, n_v, d_v)
        if mask is not None:
            mask = mask.repeat(n_head, 1, 1)
        attn, output = self.attention(q, k, v, mask=mask)
        output = output.view(n_head, n_q, d_v).permute(0, 1, 2).contiguous().view(n_q, -1)
        output = self.fc_o(output)
        return attn, output

class MultiSelfAttention(nn.Module):
    def __init__(self, n_head, d_k, d_v, d_x, d_o):
        super().__init__()
        self.wq = nn.Parameter(torch.Tensor(d_x, d_k))
        self.wk = nn.Parameter(torch.Tensor(d_x, d_k))
        self.wv = nn.Parameter(torch.Tensor(d_x, d_v))
        self.mha = MultiHeadAttention(n_head=n_head, d_k_=d_k, d_v_=d_v, d_k=d_k, d_v=d_v, d_o=d_o)
        self.init_parameters()
    def init_parameters(self):
        for param in self.parameters():
            stdv = 1. / np.power(param.size(-1), 0.5)
            param.data.uniform_(-stdv, stdv)
    def forward(self, x, mask=None):
        q = torch.matmul(x, self.wq)
        k = torch.matmul(x, self.wk)
        v = torch.matmul(x, self.wv)
        attn, output = self.mha(q, k, v, mask=mask)
        return output

class Sample(nn.Module):
    def __init__(self):
        super(Sample, self).__init__()
    def forward(self, z_mean, z_log_var):
        eps = torch.randn(z_mean.shape)
        eps = eps.to(device)
        std = torch.exp(z_log_var / 2)
        out = z_mean + std*eps
        return out

class Encoder(nn.Module):
    def __init__(self):
        super(Encoder, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(400, 256),
            nn.Dropout(p=0.15, inplace=False),
            nn.LeakyReLU(),
            nn.Linear(256, 64),
            nn.Dropout(p=0.15, inplace=False),
            nn.LeakyReLU(),
            nn.Linear(64, 32),
            nn.Dropout(p=0.15, inplace=False),
            nn.LeakyReLU()
        )
        self.z_mean = nn.Linear(32, 6)
        self.z_log_var = nn.Linear(32, 6)
        self.sample = Sample()

    def forward(self, x):
        out = self.model(x)
        z_mean = self.z_mean(out)
        z_log_var = self.z_log_var(out)
        out = self.sample(z_mean, z_log_var)
        return out, z_mean, z_log_var

class Decoder(nn.Module):
    def __init__(self):
        super(Decoder, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(6, 32),
            nn.Dropout(p=0.15, inplace=False),
            nn.LeakyReLU(),
            nn.Linear(32, 64),
            nn.Dropout(p=0.15, inplace=False),
            nn.LeakyReLU(),
            nn.Linear(64, 256),
            nn.Dropout(p=0.15, inplace=False),
            nn.LeakyReLU(),
            nn.Linear(256, 50),
            nn.Dropout(p=0.15, inplace=False),
            nn.LeakyReLU()
        )

    def forward(self, z):
        out = self.model(z)
        return out

class Classfier(nn.Module):
    def __init__(self):
        super(Classfier, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(6, 100, bias = True), #input layer
            nn.BatchNorm1d(100, eps = 1e-3),
            nn.LeakyReLU(),
            nn.Dropout(p=0.7, inplace=False),
            nn.Linear(100, 2, bias = True), # 2 hidden layer
            nn.Softmax(dim = 1)    #outpust layer
        )
    def forward(self, x):
        x = self.model(x)
        return x

class MAVC(nn.Module):
    def __init__(self):
        super(MAVC, self).__init__()
        self.encoder = Encoder()
        self.decoder = Decoder()
        self.classfier = Classfier()
        self.attn1 = MultiSelfAttention(30, 64, 32, 22, 100)
        self.attn2 = MultiSelfAttention(30, 64, 32, 300, 100)

    def forward(self, clinical_input, cna_input, micro_input, rna_input):
        cli_in = self.attn1(clinical_input)
        cna_in = self.attn2(cna_input)
        mic_in = self.attn2(micro_input)
        rna_in = self.attn2(rna_input)
        x = torch.cat((cli_in, cna_in, mic_in, rna_in), dim = 1)
        x, mean, var = self.encoder(x)
        x = self.classfier(x)
        return x


In [6]:
model = MAVC()
model.to(device)
learning_rate = 1e-3
optimizer = torch.optim.RAdam(model.parameters(), lr=learning_rate)
criterion = nn.CrossEntropyLoss()
criterion.to(device)
num_epoch = 100
max_accuracy = 0
vae_acc = 0
vae_pre = 0
vae_sen = 0
vae_f1s = 0

In [7]:
for epoch in tqdm(range(num_epoch), leave=False):
    model.train()
    for i, data in enumerate(zip(cli_train_dataloader,cna_train_dataloader,mic_train_dataloader,rna_train_dataloader)):
        clinical_data, cna_data, mic_data, rna_data, y_t = data[0][0], data[1][0], data[2][0], data[3][0], data[0][1]
        clinical_data = clinical_data.to(device)
        cna_data = cna_data.to(device)
        mic_data = mic_data.to(device)
        rna_data = rna_data.to(device)
        y_t = y_t.to(device)
        y_pred = model(clinical_data, cna_data, mic_data, rna_data)
        y_t = y_t.squeeze(dim = 1)
        loss = criterion(y_pred, y_t.long())
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    model.eval()
    total_test_loss = 0
    total_accuracy = 0
    with torch.no_grad():
        for i, data in enumerate(zip(cli_valid_dataloader,cna_valid_dataloader,mic_valid_dataloader,rna_valid_dataloader)):
            clinical_data, cna_data, mic_data, rna_data, y_t = data[0][0], data[1][0], data[2][0], data[3][0], data[0][1]
            clinical_data = clinical_data.to(device)
            cna_data = cna_data.to(device)
            mic_data = mic_data.to(device)
            rna_data = rna_data.to(device)
            y_t = y_t.to(device)
            y_pred = model(clinical_data, cna_data, mic_data, rna_data)
            y_t = y_t.squeeze(dim = 1)
            loss = criterion(y_pred, y_t.long())
            total_test_loss = total_test_loss + loss.item()
            accuracy = (y_pred.argmax(1) == y_t).sum()
            total_accuracy = total_accuracy + accuracy
    if((total_accuracy/len(cna_valid_dataset)) > max_accuracy):
        max_accuracy = total_accuracy/len(cna_valid_dataset)
        torch.save(model, "/data/minwenwen/lixiaoyu/TCGA/Model/MAVC_Best.pth")

model = torch.load('/data/minwenwen/lixiaoyu/TCGA/Model/MAVC_Best.pth')
model.eval()
with torch.no_grad():
    for i, data in enumerate(zip(cli_test_dataloader,cna_test_dataloader,mic_test_dataloader,rna_test_dataloader)):
        clinical_data, cna_data, mic_data, rna_data, mavc_y_t = data[0][0],data[1][0],data[2][0],data[3][0],data[0][1]
        clinical_data = clinical_data.to(device)
        cna_data = cna_data.to(device)
        mic_data = mic_data.to(device)
        rna_data = rna_data.to(device)
        mavc_y_t = y_t.to(device)
        mavc_y_pred = model(clinical_data, cna_data, mic_data, rna_data)
        # mavc_y_t = mavc_y_t.squeeze(dim = 1)

        mavc_y_t = mavc_y_t.detach().cpu().numpy()
        mavc_y_pred = mavc_y_pred.detach().cpu().numpy()
        mavc_acc = accuracy_score(mavc_y_t, mavc_y_pred.argmax(1))
        mavc_pre = precision_score(mavc_y_t, mavc_y_pred.argmin(1))
        mavc_sen = recall_score(mavc_y_t, mavc_y_pred.argmin(1))
        mavc_f1s = f1_score(mavc_y_t, mavc_y_pred.argmin(1))
        mavc_y_pred = np.amax(mavc_y_pred, axis=1)



                                                                                                                                                                                                                                                             

In [8]:
model = torch.load('/data/minwenwen/lixiaoyu/TCGA/Model/MAVC_Best.pth')
model.eval()
with torch.no_grad():
    for i, data in enumerate(zip(cli_test_dataloader,cna_test_dataloader,mic_test_dataloader,rna_test_dataloader)):
        clinical_data, cna_data, mic_data, rna_data, mavc_y_t = data[0][0],data[1][0],data[2][0],data[3][0],data[0][1]
        clinical_data = clinical_data.to(device)
        cna_data = cna_data.to(device)
        mic_data = mic_data.to(device)
        rna_data = rna_data.to(device)
        mavc_y_t = y_t.to(device)
        mavc_y_pred = model(clinical_data, cna_data, mic_data, rna_data)
        # mavc_y_t = mavc_y_t.squeeze(dim = 1)

        mavc_y_t = mavc_y_t.detach().cpu().numpy()
        mavc_y_pred = mavc_y_pred.detach().cpu().numpy()
        mavc_acc = accuracy_score(mavc_y_t, mavc_y_pred.argmax(1))
        mavc_pre = precision_score(mavc_y_t, mavc_y_pred.argmin(1))
        mavc_sen = recall_score(mavc_y_t, mavc_y_pred.argmin(1))
        mavc_f1s = f1_score(mavc_y_t, mavc_y_pred.argmin(1))
        mavc_y_pred = np.amax(mavc_y_pred, axis=1)


In [9]:
mavc_fpr, mavc_tpr, mavc_threshold = roc_curve(mavc_y_t, mavc_y_pred)
mavc_roc_auc = auc(mavc_fpr, mavc_tpr)

ValueError: Input contains NaN.