In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import sys
sys.path.append("..")
                
import numpy as np
import pandas as pd
pd.set_option('display.max_columns', 100)

# viz
import matplotlib.pyplot as plt
import seaborn as sns
sns.set(rc={'figure.figsize':(12.7,10.27)})

# notebook settings
from IPython.core.interactiveshell import InteractiveShell
InteractiveShell.ast_node_interactivity = "all"
from IPython.display import set_matplotlib_formats
set_matplotlib_formats('retina')

import os
os.environ["CUDA_VISIBLE_DEVICES"]="3"

In [3]:
def getTCGA(disease):
    """
    Build the file paths of the per-disease TCGA count matrices.

    Parameters
    ----------
    disease : str or iterable of str
        Disease code(s) inserted into the file-name template (e.g. ``'BRCA'``).
        A single string is treated as one code rather than being iterated
        character by character.

    Returns
    -------
    list of str
        One formatted path per disease code, in input order. The paths are
        only formatted here; nothing checks that the files exist.
    """
    path = "/srv/nas/mk2/projects/pan-cancer/TCGA_CCLE_GCP/TCGA/TCGA_{}_counts.tsv.gz"
    # A bare string is itself iterable, which would yield one path per letter.
    if isinstance(disease, str):
        disease = [disease]
    return [path.format(d) for d in disease]

In [4]:
def readGCP(files):
    """
    Load tab-separated count matrices and key them by the second
    ``_``-separated token of each file name.

    For every file the first column is used as the row index. Each row label
    is split on ``"|"``; the last piece of the split is discarded and the
    remaining eight pieces are named ENST, ENSG, OTTHUMG, OTTHUMT, GENE-NUM,
    GENE, NUM and TYPE. The GENE field replaces the row index and the matrix
    is transposed before being stored.

    Parameters
    ----------
    files : iterable of str
        Paths to count matrices.

    Returns
    -------
    dict
        ``{file-name token: transposed DataFrame}``.
    """
    id_fields = ['ENST', 'ENSG', 'OTTHUMG', 'OTTHUMT', 'GENE-NUM', 'GENE', 'NUM', 'TYPE']
    matrices = {}
    for path in files:
        # e.g. "TCGA_BRCA_counts.tsv.gz" -> "BRCA"
        token = os.path.basename(path).split("_")[1]
        counts = pd.read_csv(path, sep='\t', index_col=0)
        # Drop the final piece of every split identifier
        # (presumably empty because of a trailing "|" — TODO confirm).
        pieces = [parts[:-1] for parts in counts.index.str.split("|")]
        annotation = pd.DataFrame(pieces, columns=id_fields)
        counts.index = annotation['GENE']
        matrices[token] = counts.T
    return matrices

In [5]:
def renameTCGA(data_dict, mapper):
    """
    Relabel the rows of every DataFrame stored in ``data_dict``.

    Each value is replaced, in place, by ``value.rename(mapper)``; the same
    dictionary object is returned for convenience.
    """
    # Only existing keys are reassigned, so the dict size never changes mid-loop.
    for name, frame in data_dict.items():
        data_dict[name] = frame.rename(mapper)
    return data_dict

In [6]:
def uq_norm(df, q=0.75):
    """
    Quantile-scale every row of ``df``.

    The ``q``-quantile of each row (upper quartile by default) is computed
    and every value in that row is divided by it.
    """
    row_scale = df.quantile(q=q, axis=1)
    return df.div(row_scale, axis=0)

In [7]:
# Project root; used below only to locate the ID-map file.
# NOTE(review): getTCGA carries its own hard-coded copy of this directory.
base = "/srv/nas/mk2/projects/pan-cancer/TCGA_CCLE_GCP"
# Disease codes whose count matrices are loaded.
disease = ['BRCA', 'LUAD', 'KIRC', 'THCA', 'PRAD', 'SKCM']

tcga_files = getTCGA(disease)
# Tab-separated ID map; later cells read its 'CGHubAnalysisID', 'Sample ID' and 'Case ID' columns.
tcga_meta = pd.read_csv(os.path.join(base, "TCGA/TCGA_GDC_ID_MAP.tsv"), sep="\t")
# dict: one transposed count matrix per file, keyed by the file-name token (see readGCP).
tcga_raw = readGCP(tcga_files)

KeyboardInterrupt: 

In [None]:
# rename samples to reflect canonical IDs
# The mapper goes from 'CGHubAnalysisID' to 'Sample ID'; renameTCGA applies it to
# every per-disease frame via DataFrame.rename(mapper), i.e. to the row labels.
tcga_raw = renameTCGA(tcga_raw, mapper=dict(zip(tcga_meta['CGHubAnalysisID'], tcga_meta['Sample ID'])))

In [None]:
# combine samples
# Stacks the per-disease frames row-wise into a single DataFrame.
# NOTE(review): this rebinds tcga_raw from a dict to a DataFrame, so the cell
# cannot be re-run on its own without re-running the loading cells first.
tcga_raw = pd.concat(tcga_raw.values())

## Normalization

In [None]:
# Upper quartile normalization
# Each row is divided by its own 0.75 quantile (uq_norm default).
# NOTE(review): after this line the name tcga_raw holds normalized, not raw, values.
tcga_raw = uq_norm(tcga_raw)

In [None]:
# log norm
# log(1 + x) of the upper-quartile-normalized values.
tcga = tcga_raw.transform(np.log1p)

In [None]:
# downsample
# Keep a random subset of at most 18000 columns (columns are the GENE labels
# produced by readGCP). A fixed random_state makes the selected subset
# reproducible across kernel restarts; min() avoids the ValueError that
# DataFrame.sample raises when asked for more columns than are present.
tcga = tcga.sample(n=min(18000, tcga.shape[1]), axis=1, random_state=0)

In [None]:
# Display the ID-map rows belonging to one specific sample ID (spot check).
tcga_meta[tcga_meta['Sample ID'] == 'TCGA-A7-A26F-01B']

# Model

### Experimental Setup

In [None]:
from collections import OrderedDict 
# Column name -> allowed values. Experiment.categorize keeps only rows whose
# value in each column is listed here and joins the columns, in this order,
# with ':' to form the composite class label (e.g. 'BRCA:Primary Tumor').
hierarchy = OrderedDict({'Disease':['BRCA', 'LUAD', 'KIRC', 'THCA', 'PRAD', 'SKCM'],
                         'Sample Type':['Primary Tumor', 'Solid Tissue Normal', 'Metastatic']})

In [None]:
class Experiment():
    """
    Defines an experimental class hierarchy object.

    The sample metadata is restricted to the values listed in ``hierarchy``,
    every remaining sample receives a composite class label (column
    ``'meta'``) built by joining its hierarchy columns with ``':'``, and
    classes that are too small are dropped.

    Parameters
    ----------
    meta_data : pandas.DataFrame
        Sample metadata; must contain every key of ``hierarchy`` and the
        ``cases`` column. The caller's frame is never modified.
    hierarchy : collections.OrderedDict
        Column name -> list of allowed values.
    cases : str
        Name of the column identifying cases.
    min_samples : int
        A class is kept only if it has strictly more than ``min_samples``
        samples.

    Attributes
    ----------
    meta_data : pandas.DataFrame
        Filtered metadata with the categorical ``'meta'`` column.
    cases : numpy.ndarray
        Unique values of the ``cases`` column at construction time.
    labels : numpy.ndarray
        Integer category code of every row of ``meta_data`` at construction
        time (not refreshed by ``holdout``).
    labels_dict : dict
        Category code -> class name.
    holdout_classes : list
        Class names held out so far (set by ``holdout``).
    holdout_samples : pandas.DataFrame
        Metadata rows held out so far (set by ``holdout``).
    """
    def __init__(self, meta_data, hierarchy, cases, min_samples):
        self.hierarchy = hierarchy
        self.meta_data = self.categorize(meta_data, self.hierarchy, min_samples)
        self.cases = self.meta_data[cases].unique()
        self.labels = self.meta_data['meta'].cat.codes.values.astype('int')
        self.labels_dict = dict(enumerate(self.meta_data['meta'].cat.categories.values))

    def categorize(self, meta_data, hierarchy, min_samples):
        """
        Filter ``meta_data`` by ``hierarchy`` and attach the categorical
        ``'meta'`` class column.

        Returns a new DataFrame; ``meta_data`` itself is left untouched.
        Raises AssertionError if ``hierarchy`` is not an OrderedDict.
        """
        assert isinstance(hierarchy, OrderedDict), "Argument of wrong type."
        # downsample data
        for key, val in hierarchy.items():
            meta_data = meta_data[meta_data[key].isin(val)]
        # Work on an explicit copy: assigning a column to a filtered slice is
        # chained assignment (SettingWithCopyWarning), and with an empty
        # hierarchy it would write into the caller's frame.
        meta_data = meta_data.copy()
        # unique meta classes
        meta_data['meta'] = meta_data[list(hierarchy.keys())].apply(lambda row: ':'.join(row.values.astype(str)), axis=1)
        # filter meta classes (strictly more than min_samples samples)
        counts = meta_data['meta'].value_counts()
        keep = counts[counts > min_samples].index
        meta_data = meta_data[meta_data['meta'].isin(keep)].copy()
        # generate class categories
        meta_data['meta'] = meta_data['meta'].astype('category')
        return meta_data

    def holdout(self, holdout):
        """
        Move every sample of the given classes out of ``meta_data``.

        Parameters
        ----------
        holdout : str or iterable of str
            Class name(s) as they appear in the ``'meta'`` column.

        The removed rows are stored in ``holdout_samples`` and the class
        names in ``holdout_classes``; repeated calls accumulate. The class
        names are no longer stored as ``self.holdout``, because that
        attribute replaced this method on the instance and made a second
        call fail. The categories of ``'meta'`` are not pruned, so existing
        codes and ``labels_dict`` stay valid.
        """
        if isinstance(holdout, str):
            holdout = [holdout]
        holdout = list(holdout)
        held = self.meta_data['meta'].isin(holdout)
        newly_held = self.meta_data[held]
        previous = getattr(self, 'holdout_samples', None)
        if previous is None:
            self.holdout_samples = newly_held
        else:
            self.holdout_samples = pd.concat([previous, newly_held])
        self.holdout_classes = getattr(self, 'holdout_classes', []) + holdout
        self.meta_data = self.meta_data[~held]

In [None]:
from dutils import train_test_split_case
# Build the class hierarchy from the ID map: classes need more than 20 samples
# to be kept (Experiment.categorize uses a strict '>').
exp = Experiment(meta_data=tcga_meta,
                 hierarchy=hierarchy,
                 cases='Case ID',
                 min_samples=20)
# Remove the 'SKCM:Metastatic' class from exp.meta_data; its rows end up in exp.holdout_samples.
exp.holdout(holdout=['SKCM:Metastatic'])

In [None]:
# Samples per class that remain after the hold-out ...
exp.meta_data['meta'].value_counts()
# ... and samples per class among the held-out rows.
exp.holdout_samples['meta'].value_counts()

In [None]:
# Define Train / Test sample split
target = 'meta'

# train_test_split_case comes from the project module dutils; presumably it
# splits by 'Case ID' so that all samples of a case land on the same side — TODO confirm.
train, test = train_test_split_case(exp.meta_data, cases='Case ID')
# stratification is not quite perfect but close
# in order to preserve matched samples for each case together
# in train or test set
case_counts = exp.meta_data[target].value_counts()
# Fraction of each class that ended up in the train / test split.
train[target].value_counts()[case_counts.index.to_numpy()] / case_counts
test[target].value_counts()[case_counts.index.to_numpy()] / case_counts

In [None]:
# split data
# Select expression rows by sample ID and store them as half precision.
# Rows keep the order of tcga, not the order of train / test.
train_data = tcga[tcga.index.isin(train['Sample ID'])].astype(np.float16)
test_data = tcga[tcga.index.isin(test['Sample ID'])].astype(np.float16)

In [None]:
import torch
from torch.optim import lr_scheduler
import torch.optim as optim
from torch.autograd import Variable
#torch.manual_seed(123)

from trainer import fit
import visualization as vis
import numpy as np
# Detect GPU support once; later cells branch on this flag.
cuda = torch.cuda.is_available()
print("Cuda is available: {}".format(cuda))

In [None]:
import torch
from torch.utils.data import Dataset

class SiameseDataset(Dataset):
    """
    Train: For each sample creates randomly a positive or a negative pair
    Test: Creates fixed pairs for testing

    Parameters
    ----------
    experiment : object
        Must expose ``meta_data``, a DataFrame with a ``'Sample ID'`` column
        and a categorical ``'meta'`` column (see ``Experiment``).
    data : pandas.DataFrame
        Feature matrix, one row per sample, indexed by sample ID.
    train : bool
        Selects random pair sampling (True) or fixed pairs (False).

    Attributes
    ----------
    labels : numpy.ndarray
        Integer class code of every row of ``data``, in ``data`` row order.
    train_data / train_labels, test_data / test_labels
        Float tensor of the features and the labels, stored under the name
        matching the mode.
    label_to_indices : dict
        Class code -> row positions of that class.
    test_pairs : list
        Only in test mode: ``[first_row, second_row, target]`` entries with
        target 1 for a same-class pair and 0 for a different-class pair.

    Items are ``((sample_1, sample_2), target)``.
    """

    def __init__(self, experiment, data, train=False):
        self.train = train
        self.labels = self._aligned_labels(experiment.meta_data, data.index)
        assert len(data) == len(self.labels)

        features = torch.from_numpy(data.values).float()
        self.labels_set = set(self.labels)
        self.label_to_indices = {label: np.where(self.labels == label)[0]
                                 for label in self.labels_set}

        if self.train:
            self.train_labels = self.labels
            self.train_data = features
        else:
            # generate fixed pairs for testing
            self.test_labels = self.labels
            self.test_data = features
            self.test_pairs = self._fixed_pairs(seed=29)

    @staticmethod
    def _aligned_labels(meta_data, sample_ids):
        """Return the class code of every entry of ``sample_ids``, in that order."""
        meta = meta_data[meta_data['Sample ID'].isin(sample_ids)]
        codes = pd.Series(meta['meta'].cat.codes.values, index=meta['Sample ID'].values)
        # A sample ID listed several times in the metadata keeps its first label.
        codes = codes[~codes.index.duplicated(keep='first')]
        # Look labels up by sample ID: the metadata rows are not guaranteed to
        # be in the same order as the rows of the feature matrix.
        aligned = codes.reindex(sample_ids)
        assert not aligned.isna().any(), "Every row of data needs a label in experiment.meta_data."
        return aligned.values.astype('int')

    def _fixed_pairs(self, seed):
        """Build the deterministic test pairs: even rows positive, odd rows negative."""
        # Every draw uses this seeded generator, so the pairs are reproducible.
        rng = np.random.RandomState(seed)
        n_rows = len(self.labels)

        positive_pairs = [[i,
                           rng.choice(self.label_to_indices[self.labels[i].item()]),
                           1]
                          for i in range(0, n_rows, 2)]

        negative_pairs = []
        for i in range(1, n_rows, 2):
            # sorted() fixes the candidate order independently of set iteration order
            other_labels = sorted(self.labels_set - set([self.labels[i].item()]))
            other_label = rng.choice(other_labels)
            negative_pairs.append([i, rng.choice(self.label_to_indices[other_label]), 0])
        return positive_pairs + negative_pairs

    def __getitem__(self, index):
        if self.train:
            target = np.random.randint(0, 2)
            img1, label1 = self.train_data[index], self.train_labels[index].item()
            if target == 1:
                same_class = self.label_to_indices[label1]
                partners = same_class[same_class != index]
                # A class with a single sample has no other partner; pair the
                # sample with itself instead of retrying forever.
                siamese_index = np.random.choice(partners) if len(partners) else index
            else:
                siamese_label = np.random.choice(list(self.labels_set - set([label1])))
                siamese_index = np.random.choice(self.label_to_indices[siamese_label])
            img2 = self.train_data[siamese_index]
        else:
            img1 = self.test_data[self.test_pairs[index][0]]
            img2 = self.test_data[self.test_pairs[index][1]]
            target = self.test_pairs[index][2]

        return (img1, img2), target

    def __len__(self):
        if self.train:
            return len(self.train_data)
        else:
            return len(self.test_data)

# Siamese Network

In [None]:
# Training set: a positive or negative partner is drawn at random on every access.
siamese_train_dataset = SiameseDataset(experiment=exp,
                                       data=train_data,
                                       train=True)
# Test set: the pairs are generated once, in the constructor.
siamese_test_dataset = SiameseDataset(experiment=exp,
                                       data=test_data,
                                       train=False)

In [None]:
# Peek at the first few [first_row, second_row, target] test pairs instead of
# dumping the whole list (one entry per test sample) into the notebook output.
siamese_test_dataset.test_pairs[:10]

In [None]:
batch_size = 16
# Same worker count on CPU and GPU; pinned host memory is only requested with CUDA.
kwargs = {'num_workers': 10, 'pin_memory': True} if cuda else {'num_workers': 10}
# Training batches are reshuffled every epoch; the test order stays fixed.
siamese_train_loader = torch.utils.data.DataLoader(siamese_train_dataset, batch_size=batch_size, shuffle=True, **kwargs)
siamese_test_loader = torch.utils.data.DataLoader(siamese_test_dataset, batch_size=batch_size, shuffle=False, **kwargs)

# Set up the network and training parameters
from tcga_networks import EmbeddingNet, SiameseNet
from losses import ContrastiveLoss, TripletLoss
from metrics import AccumulatedAccuracyMetric

# Embedding network sized to the number of input features; the second
# argument (2) is presumably the embedding dimension — TODO confirm in tcga_networks.
n_samples, n_features = siamese_train_dataset.train_data.shape
embedding_net = EmbeddingNet(n_features, 2)
# Siamese wrapper around the embedding network
model = SiameseNet(embedding_net)
if cuda:
    model.cuda()

# Loss, optimizer and learning-rate schedule
margin = 1.
loss_fn = ContrastiveLoss(margin)
lr = 1e-3
optimizer = optim.Adam(model.parameters(), lr=lr)
# StepLR multiplies the learning rate by gamma every 8 scheduler steps.
scheduler = lr_scheduler.StepLR(optimizer, 8, gamma=0.1, last_epoch=-1)
n_epochs = 10
# print training metrics every log_interval * batch_size
# (about four times per epoch). round() yields 0 for very small training sets,
# so clamp to 1 to keep the interval usable
# (e.g. if fit uses it as a modulus — TODO confirm in trainer.fit).
log_interval = max(1, round(len(siamese_train_dataset)/4/batch_size))

In [None]:
# Report which GPU PyTorch will use. torch.cuda.current_device() raises on a
# machine without CUDA, so it is only queried when CUDA is available.
if torch.cuda.is_available():
    print('Active CUDA Device: GPU', torch.cuda.current_device())

    print ('Available devices ', torch.cuda.device_count())
    print ('Current cuda device ', torch.cuda.current_device())
else:
    print('CUDA is not available; running on CPU.')

In [None]:
# Train the siamese model. fit comes from the project module trainer; the two
# return values are plotted per epoch in the next cell, so they are presumably
# per-epoch training / validation losses — TODO confirm.
train_loss, val_loss = fit(siamese_train_loader, siamese_test_loader, model, loss_fn, optimizer, scheduler, 
    n_epochs, cuda, log_interval)

In [None]:
# Loss curves over epochs; assumes train_loss and val_loss each hold n_epochs values.
plt.plot(range(0, n_epochs), train_loss, 'rx-', label='train')
plt.plot(range(0, n_epochs), val_loss, 'bx-', label='validation')
plt.legend()

In [None]:
    
def extract_embeddings(samples, target, model):
    """
    Embed all ``samples`` in one forward pass, without gradients.

    Parameters
    ----------
    samples : torch.Tensor
        Inputs accepted by the model's ``get_embedding`` method.
    target : sequence
        One label per sample.
    model : torch.nn.Module
        Model exposing ``get_embedding``; a ``torch.nn.DataParallel`` wrapper
        is unwrapped first.

    Returns
    -------
    embeddings : numpy.ndarray
        float64 array with the shape returned by ``get_embedding`` (the
        embedding width is no longer assumed to be 2).
    labels : numpy.ndarray
        float64 copy of ``target``.

    Raises
    ------
    AssertionError
        If ``samples`` and ``target`` differ in length.

    Notes
    -----
    The model is switched to eval mode and left in that mode.
    """
    assert len(samples) == len(target)
    net = model.module if isinstance(model, torch.nn.DataParallel) else model
    # Send the input to wherever the model's weights live; fall back to the
    # CUDA-if-available rule when the model exposes no parameters.
    try:
        device = next(net.parameters()).device
    except (StopIteration, AttributeError):
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    if isinstance(samples, torch.Tensor):
        samples = samples.to(device)

    model.eval()
    with torch.no_grad():
        embedded = net.get_embedding(samples)

    embeddings = embedded.detach().cpu().numpy().astype(np.float64)
    labels = np.array(target, dtype=np.float64)
    return embeddings, labels

In [None]:
# Embed the training samples with the trained model and plot them.
# sns_plot_embeddings is a project helper (visualization module); exp.labels_dict
# maps label codes back to class names — how hue/style are used is defined there.
train_embeddings_cl, train_labels_cl = extract_embeddings(siamese_train_dataset.train_data, siamese_train_dataset.labels, model)
vis.sns_plot_embeddings(train_embeddings_cl, train_labels_cl, exp.labels_dict, 
                        hue='meta', style='Sample Type', alpha=0.5)
plt.title('PanCancer Train: Siamese')
# Place the legend outside the axes, to the right of the plot.
plt.legend(bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0.)

In [None]:
# Same plot for the test samples.
# NOTE(review): despite the "_baseline" suffix these embeddings come from the
# same trained siamese model as the training plot.
val_embeddings_baseline, val_labels_baseline = extract_embeddings(siamese_test_dataset.test_data, siamese_test_dataset.labels, model)
vis.sns_plot_embeddings(val_embeddings_baseline, val_labels_baseline, exp.labels_dict, 
                        hue='meta', style='Sample Type', alpha=0.5)
plt.title('PanCancer Test: Siamese')
# Place the legend outside the axes, to the right of the plot.
plt.legend(bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0.)