In [None]:
import sys
sys.path.append('..')

import os
import matplotlib.pyplot as plt
import numpy as np
from skimage.filters import gaussian, threshold_otsu
from skimage.transform import resize
from skimage import io
import pickle
from skimage.measure import regionprops

import matplotlib.pyplot as plt
import torch
import time
import model.model_cnnvae_conditional
import model.optimizer as optimizer
from sklearn.decomposition import PCA
import pandas as pd

In [None]:
# Hyper-parameter grid swept by the training cell.
# Each value is iterated as `currLatentSize`; train() uses it as the number of leading
# latent dimensions that are matched and fed to the shared decoder.
sharedSizes=[1024]
# (dSpecificSize, dfilterSize) pairs; dfilterSize is subtracted from hidden5 to obtain
# the shared channel count given to the model constructors.
dSpecific_filter=[(200,16)]
# Embedding-type string for the model constructors (same value, 'randInit', as the
# argument they receive in the training cell).
pID_type='randInit'
# Embedding size passed as the last argument to the *_pIDemb model constructors.
pIDemb_size=64

In [None]:
# IDs (the `pID` prefix of the per-sample directory name) that the data-loading loop
# skips; it prints 'hold out: <pID>' and continues.
holdOutSamples=['HV1','P22','P14','P27']

In [None]:
# Load the pre-cropped image stacks of every condition / stain / sample and pair
# channel 0 of each image (the 'dna' input of the training cell) with exactly one
# of the remaining channels, chosen by a seeded shuffle.
sourceDir='/data/xinyi/c2p/data/chromark'
segDir=os.path.join(sourceDir,'nuclear_masks')
imgDir=os.path.join(sourceDir,'raw_data')
conditions=['controls','headneck','meningioma', 'glioma']

outSize=128
savename='pathCentered_'+str(outSize)

# Per-sample pieces are collected in lists and concatenated once after the loop:
# calling np.concatenate on the growing arrays for every sample re-copies all
# previously loaded images each time.
imgsC_parts=[]
imgsP_parts=[]
imgNames_parts=[]
proteinNames_parts=[]
pID_parts=[]
for condition_i in conditions:
    print(condition_i)
    segDir_i=os.path.join(segDir,condition_i)
    imgDir_i=os.path.join(imgDir,condition_i)
    for stain in os.listdir(segDir_i):
        print(stain)
        segDir_i_stain=os.path.join(segDir_i,stain)
        imgDir_i_stain=os.path.join(imgDir_i,stain)

        # Map the ID (text before the first '_') to the full directory name.
        segPID2name={pID_dir.split('_')[0]:pID_dir for pID_dir in os.listdir(segDir_i_stain)}
        imgPID2name={pID_dir.split('_')[0]:pID_dir for pID_dir in os.listdir(imgDir_i_stain)}
        for pID in segPID2name.keys():
            if condition_i=='meningioma' and stain=='dapi_gh2ax_lamin_cd3' and pID=='P33': #skipping incorrect images
                continue
            if pID in holdOutSamples:
                print('hold out: '+pID)
                continue
            print(pID)
            if pID not in imgPID2name:
                print('img not found '+pID)
                continue
            imgDir_i_stain_p=os.path.join(imgDir_i_stain,imgPID2name[pID])
            segDir_i_stain_p=os.path.join(segDir_i_stain,segPID2name[pID])

            # NOTE: pickle.load can execute arbitrary code stored in the file; only
            # point this at files written by the project's own preprocessing.
            with open(os.path.join(imgDir_i_stain_p,savename+'_imgNames'), 'rb') as fh:
                imgNames=pickle.load(fh)
            with open(os.path.join(imgDir_i_stain_p,savename+'_img'), 'rb') as fh:
                img=pickle.load(fh)

            # The stain directory name lists the channels in order ('_'-separated);
            # channel 0 is kept for every image, channels 1.. are split evenly over
            # the shuffled images so each image contributes one protein channel.
            imgP=np.zeros((img.shape[0],1,img.shape[2],img.shape[3]))
            proteinNames_curr=np.array([])
            stain_list=stain.split('_')
            nImgPerStain=int(img.shape[0]/(len(stain_list)-1))
            np.random.seed(3)
            allIdx_all=np.arange(img.shape[0])
            np.random.shuffle(allIdx_all)
            for s in range(1,len(stain_list)):
                s_start=(s-1)*nImgPerStain
                if s==len(stain_list)-1:
                    # the last channel also takes the remainder of the integer division
                    s_end=img.shape[0]
                else:
                    s_end=s*nImgPerStain
                imgP[s_start:s_end]=img[allIdx_all[s_start:s_end],s].reshape(s_end-s_start,1,img.shape[2],img.shape[3])
                proteinNames_curr=np.concatenate((proteinNames_curr,np.repeat(stain_list[s],s_end-s_start)))

            pID_parts.append(np.repeat(pID,img.shape[0]))
            imgsC_parts.append(img[allIdx_all,[0]])
            imgNames_parts.append(imgNames[allIdx_all])
            proteinNames_parts.append(proteinNames_curr)
            imgsP_parts.append(imgP)

if not pID_parts:
    raise RuntimeError('no images were loaded from '+imgDir)
pID_all=np.concatenate(pID_parts)
imgsC_all=np.concatenate(imgsC_parts,axis=0)
imgNames_all=np.concatenate(imgNames_parts)
proteinNames=np.concatenate(proteinNames_parts)
imgsP_all=np.concatenate(imgsP_parts,axis=0)
# Drop the per-sample pieces so the kernel does not keep two copies of every image.
del pID_parts,imgsC_parts,imgNames_parts,proteinNames_parts,imgsP_parts
# Insert a singleton channel axis so imgsC_all has the same rank as imgsP_all.
imgsC_all=imgsC_all.reshape(imgsC_all.shape[0],1,imgsC_all.shape[1],imgsC_all.shape[2])

In [None]:
# Encode every protein name as an integer class label: `pnames` holds the sorted
# unique names, `revIdx` the index into `pnames` for each entry of proteinNames,
# and `pCounts` the number of images per name.
pnames,revIdx,pCounts=np.unique(proteinNames,return_counts=True,return_inverse=True)
nProt=pnames.size
plabels=torch.tensor(revIdx,dtype=torch.long)

In [None]:
# Set CUDA_LAUNCH_BLOCKING=1 and restrict CUDA_VISIBLE_DEVICES to device index 2.
# NOTE(review): these variables are presumably only honoured if set before CUDA is
# initialised in this process -- confirm no earlier cell touches the GPU.
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
os.environ["CUDA_VISIBLE_DEVICES"] = "2"

In [None]:
#VAE settings
seed=3 #seed given to the numpy / torch RNGs before the models are built
epochs=5001 #maximum number of training epochs
saveFreq=100 #a checkpoint is stored every saveFreq epochs
lr=0.00001 #initial learning rate
weight_decay=0 #weight_decay argument of the Adam optimizers

batchsize=256
kernel_size=4
stride=2
padding=1

hidden1=64 #Number of channels in hidden layer 1
hidden2=128
hidden3=256
hidden4=256
hidden5=96
hidden5_xy=4
#derived from hidden5 (previously a hard-coded 96) so the two values cannot drift apart
fc_dim1=hidden5*hidden5_xy*hidden5_xy
fc_dim2=6000

dropout=0.01
kl_weight=0.0000001 #multiplier on the KL term of the training loss





In [None]:
printFreq=1
def train(epoch,imgs,trainIdx,valIdx,latentp,matchSize_p,model_train,optimizer_train,model_trainShared,optimizer_trainShared,sharedOnly,pLabels):
    """Run one training epoch, one validation pass, and encode every image.

    Reads the notebook-level globals batchsize, kl_weight, match_weight,
    sharedWeight and printFreq, and the loss callables loss_kl, loss_x and
    loss_match. All batches are moved to the GPU with .cuda().

    Args:
        epoch: epoch number; only used to decide whether to print and in the print.
        imgs: image array, indexed along axis 0 by trainIdx / valIdx.
        trainIdx: index array of the training rows of imgs.
        valIdx: index array of the validation rows of imgs.
        latentp: target latent codes, one row per row of imgs, or None to
            switch the match loss off (it is then reported as 0).
        matchSize_p: number of leading latent dimensions compared with latentp
            and passed to model_trainShared.
        model_train: model called as model_train(x, labels) and returning
            (recon, z, mu, logvar); must expose pIDemb(labels).
        optimizer_train: optimizer stepping model_train.
        model_trainShared: decoder called as
            model_trainShared(z[:, :matchSize_p], model_train.pIDemb(labels)).
        optimizer_trainShared: optimizer stepping model_trainShared.
        sharedOnly: if True, the reconstruction loss of model_train is left out
            of the optimised loss and of the validation total.
        pLabels: label tensor aligned with the rows of imgs.

    Returns:
        Tuple of
        (latent_curr, loss_all, loss_kl_train_all, loss_x_train_all,
         loss_x_trainShared_all, loss_match_train_all, loss_val_all,
         loss_x_val_all, loss_x_valShared_all, loss_match_val_all)
        where latent_curr stacks z for all rows of imgs in order (None if imgs
        is empty) and every loss is the mean over its batches.

    Raises:
        ZeroDivisionError: if trainIdx or valIdx is empty.
    """
    t = time.time()
    model_train.train()
    model_trainShared.train()

    loss_kl_train_all=0
    loss_x_train_all=0
    loss_x_trainShared_all=0
    loss_match_train_all=0
    loss_all=0
    ntrainBatches=int(np.ceil(trainIdx.shape[0]/batchsize))
    for i in range(ntrainBatches):
        trainIdx_i=trainIdx[i*batchsize:min((i+1)*batchsize,trainIdx.shape[0])]
        trainInput=torch.tensor(imgs[trainIdx_i])
        trainInput_ID=pLabels[trainIdx_i].cuda()

        trainInput=trainInput.cuda().float()
        optimizer_train.zero_grad()
        optimizer_trainShared.zero_grad()

        recon, z, mu, logvar = model_train(trainInput,trainInput_ID)
        # the shared decoder only sees the first matchSize_p latent dimensions
        reconShared=model_trainShared(z[:,:matchSize_p],model_train.pIDemb(trainInput_ID))

        loss_kl_train=loss_kl(mu, logvar)
        loss_x_train=loss_x(recon, trainInput)
        if latentp is not None:
            loss_match_train=loss_match(z[:,:matchSize_p],latentp[trainIdx_i,:matchSize_p])
        else:
            loss_match_train=0
        loss_xShared_train=loss_x(reconShared,trainInput)
        if sharedOnly:
            loss=loss_kl_train*kl_weight+loss_match_train*match_weight+loss_xShared_train*sharedWeight
        else:
            loss=loss_kl_train*kl_weight+loss_x_train+loss_match_train*match_weight+loss_xShared_train*sharedWeight

        loss_kl_train_all+=loss_kl_train.item()
        loss_x_train_all+=loss_x_train.item()
        if latentp is not None:
            loss_match_train_all+=loss_match_train.item()
        loss_x_trainShared_all+=loss_xShared_train.item()
        loss_all+=loss.item()

        loss.backward()
        optimizer_train.step()
        optimizer_trainShared.step()


    loss_kl_train_all=loss_kl_train_all/ntrainBatches
    loss_x_train_all=loss_x_train_all/ntrainBatches
    loss_match_train_all=loss_match_train_all/ntrainBatches
    loss_x_trainShared_all=loss_x_trainShared_all/ntrainBatches
    loss_all=loss_all/ntrainBatches

    with torch.no_grad():
        model_train.eval()
        model_trainShared.eval()

        loss_val_all=0
        loss_x_val_all=0
        loss_match_val_all=0
        loss_x_valShared_all=0
        nvalBatches=int(np.ceil(valIdx.shape[0]/batchsize))
        for i in range(nvalBatches):
            valIdx_i=valIdx[i*batchsize:min((i+1)*batchsize,valIdx.shape[0])]
            valInput=torch.tensor(imgs[valIdx_i])
            valInput=valInput.cuda().float()
            valInput_ID=pLabels[valIdx_i].cuda()
            recon,z, mu, logvar = model_train(valInput,valInput_ID)
            reconShared=model_trainShared(z[:,:matchSize_p],model_train.pIDemb(valInput_ID))

            loss_x_val=loss_x(recon, valInput).item()
            loss_x_valShared=loss_x(reconShared,valInput).item()
            if latentp is not None:
                loss_match_val=loss_match(z[:,:matchSize_p],latentp[valIdx_i,:matchSize_p]).item()
            else:
                loss_match_val=0
            # NOTE: unlike the training loss, the validation total has no KL term
            # and does not apply sharedWeight to the shared reconstruction loss.
            if sharedOnly:
                loss_val=loss_match_val*match_weight+loss_x_valShared
            else:
                loss_val=loss_x_val+loss_match_val*match_weight+loss_x_valShared

            loss_x_val_all+=loss_x_val
            loss_x_valShared_all+=loss_x_valShared
            if latentp is not None:
                loss_match_val_all+=loss_match_val
            loss_val_all+=loss_val

        loss_x_val_all=loss_x_val_all/nvalBatches
        loss_x_valShared_all=loss_x_valShared_all/nvalBatches
        loss_match_val_all=loss_match_val_all/nvalBatches
        loss_val_all=loss_val_all/nvalBatches

        # Encode all images in their original order. The per-batch latents are
        # gathered in a list and concatenated once, instead of calling torch.cat
        # on the growing tensor for every batch.
        latent_batches=[]
        nplotBatches=int(np.ceil(imgs.shape[0]/batchsize))
        for i in range(nplotBatches):
            plotInput=torch.tensor(imgs[i*batchsize:min((i+1)*batchsize,imgs.shape[0])])
            plotInput=plotInput.cuda().float()
            plotInput_ID=pLabels[i*batchsize:min((i+1)*batchsize,imgs.shape[0])].cuda()
            recon,z, mu, logvar = model_train(plotInput,plotInput_ID)
            latent_batches.append(z.detach())
        latent_curr=torch.cat(latent_batches,0) if latent_batches else None
    if epoch%printFreq==0:
        print('Epoch: {:04d}'.format(epoch),
              'loss_train: {:.4f}'.format(loss_all),
              'loss_kl_train: {:.4f}'.format(loss_kl_train_all),
              'loss_x_train: {:.4f}'.format(loss_x_train_all),
              'loss_xShared_train: {:.4f}'.format(loss_x_trainShared_all),
              'loss_match_train: {:.4f}'.format(loss_match_train_all),
              'loss_x_val: {:.4f}'.format(loss_x_val_all),
              'loss_xShared_val: {:.4f}'.format(loss_x_valShared_all),
              'loss_match_val: {:.4f}'.format(loss_match_val_all),
              'time: {:.4f}s'.format(time.time() - t))
    return latent_curr,loss_all,loss_kl_train_all,loss_x_train_all,loss_x_trainShared_all,loss_match_train_all,loss_val_all,loss_x_val_all,loss_x_valShared_all,loss_match_val_all

In [None]:
match_weight=1 #multiplier on the latent match loss inside train()
sharedWeight=1 #multiplier on the shared-decoder reconstruction loss inside train()
name_train='splitChannels_conditional_bce'
modelname_train='cnn_vae_pbmc_cvae'
logsavepath_train=os.path.join('/data/xinyi/c2p/log/',modelname_train,name_train)
modelsavepath_train=os.path.join('/data/xinyi/c2p/models/',modelname_train,name_train)
plotsavepath_train=os.path.join('/data/xinyi/c2p/plots/',modelname_train,name_train)

# Create <root>/<modelname>/<name>/{dna,protein} under the log, models and plots roots.
# os.makedirs(..., exist_ok=True) builds missing parents and leaves existing
# directories alone, so the cell also works when only some of them exist
# (the former chain of os.mkdir calls raised in that situation).
for savepath_i in (logsavepath_train,modelsavepath_train,plotsavepath_train):
    for modality_i in ('dna','protein'):
        os.makedirs(os.path.join(savepath_i,modality_i),exist_ok=True)





In [None]:

# Output directories of the 'dna' and 'protein' models.
logsavepath_p_dna=os.path.join(logsavepath_train,'dna')
modelsavepath_p_dna=os.path.join(modelsavepath_train,'dna')
plotsavepath_p_dna=os.path.join(plotsavepath_train,'dna')
logsavepath_p_protein=os.path.join(logsavepath_train,'protein')
modelsavepath_p_protein=os.path.join(modelsavepath_train,'protein')
plotsavepath_p_protein=os.path.join(plotsavepath_train,'protein')
# exist_ok=True makes this a no-op for directories that are already there
for savepath_i in (logsavepath_p_dna,modelsavepath_p_dna,plotsavepath_p_dna,
                   logsavepath_p_protein,modelsavepath_p_protein,plotsavepath_p_protein):
    os.makedirs(savepath_i,exist_ok=True)

#train-test split
np.random.seed(3)
pctVal=0.05
pctTest=0.1

nVal=int(pctVal*proteinNames.size)
nTest=int(pctTest*proteinNames.size)
allIdx_all=np.arange(proteinNames.size)
np.random.shuffle(allIdx_all)
valIdx_all=allIdx_all[:nVal]
testIdx_all=allIdx_all[nVal:nVal+nTest]
trainIdx_all=allIdx_all[nVal+nTest:]


for currLatentSize in sharedSizes:
    for dSpecificSize,dfilterSize in dSpecific_filter:
        print(currLatentSize)
        print(dSpecificSize)
        # suffix shared by every file written for this configuration
        runTag=str(currLatentSize)+'_'+str(dSpecificSize)+'_'+str(dfilterSize)
        dna_cShared=hidden5-dfilterSize
        p_cShared=dna_cShared

        # loss callables read as globals by train()
        loss_match=torch.nn.MSELoss()
        loss_kl=optimizer.optimizer_kl
        loss_x=torch.nn.BCELoss()

        np.random.seed(seed)
        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)
        torch.backends.cudnn.enabled = True
        if modelname_train=='cnn_vae_pbmc_cvae':
            modelcnn_dna = model.model_cnnvae_conditional.CNN_VAE_split_pIDemb(kernel_size, stride, padding, 1, hidden1, hidden2, hidden3, hidden4, hidden5, dna_cShared, dna_cShared*hidden5_xy*hidden5_xy,(hidden5-dna_cShared)*hidden5_xy*hidden5_xy,currLatentSize,dSpecificSize,pnames.size,pID_type,pIDemb_size)
            modelcnn_protein = model.model_cnnvae_conditional.CNN_VAE_split_pIDemb(kernel_size, stride, padding, 1, hidden1, hidden2, hidden3, hidden4, hidden5,p_cShared,p_cShared*hidden5_xy*hidden5_xy, (hidden5-p_cShared)*hidden5_xy*hidden5_xy,currLatentSize,dSpecificSize,pnames.size,pID_type,pIDemb_size)
            modelcnn_dnaShared=model.model_cnnvae_conditional.CNN_VAE_decode_pIDemb(kernel_size, stride, padding, 1, hidden1, hidden2, hidden3, hidden4, hidden5, fc_dim1,currLatentSize,pIDemb_size)
            modelcnn_pShared=model.model_cnnvae_conditional.CNN_VAE_decode_pIDemb(kernel_size, stride, padding, 1, hidden1, hidden2, hidden3, hidden4, hidden5, fc_dim1,currLatentSize,pIDemb_size)
        else:
            # without this branch an unknown name only surfaced later as a NameError
            raise ValueError('unsupported modelname_train: '+modelname_train)
        modelcnn_dna.cuda()
        modelcnn_protein.cuda()
        modelcnn_dnaShared.cuda()
        modelcnn_pShared.cuda()

        optimizer_dna = torch.optim.Adam(modelcnn_dna.parameters(), lr=lr, weight_decay=weight_decay)
        optimizer_protein = torch.optim.Adam(modelcnn_protein.parameters(), lr=lr, weight_decay=weight_decay)
        optimizer_dnaShared = torch.optim.Adam(modelcnn_dnaShared.parameters(), lr=lr, weight_decay=weight_decay)
        optimizer_pShared = torch.optim.Adam(modelcnn_pShared.parameters(), lr=lr, weight_decay=weight_decay)

        # Per-epoch loss histories; entries of epochs that never ran stay at inf.
        train_loss_dna=[np.inf]*(epochs)
        train_loss_kl_dna=[np.inf]*(epochs)
        train_loss_x_dna=[np.inf]*(epochs)
        train_loss_xShared_dna=[np.inf]*(epochs)
        train_loss_match_dna=[np.inf]*(epochs)
        val_loss_dna=[np.inf]*(epochs)
        val_loss_x_dna=[np.inf]*(epochs)
        val_loss_xShared_dna=[np.inf]*(epochs)
        val_loss_match_dna=[np.inf]*(epochs)

        train_loss_protein=[np.inf]*(epochs)
        train_loss_kl_protein=[np.inf]*(epochs)
        train_loss_x_protein=[np.inf]*(epochs)
        train_loss_xShared_protein=[np.inf]*(epochs)
        train_loss_match_protein=[np.inf]*(epochs)
        val_loss_protein=[np.inf]*(epochs)
        val_loss_x_protein=[np.inf]*(epochs)
        val_loss_xShared_protein=[np.inf]*(epochs)
        val_loss_match_protein=[np.inf]*(epochs)

        t_ep=time.time()

        # epoch index -> CPU state_dict snapshot
        stateDict_train_dna={}
        stateDict_train_protein={}
        stateDict_train_dnaShared={}
        stateDict_train_proteinShared={}

        latent_curr=None
        epCounts=0
        for ep in range(epochs):
            # The two models are trained alternately; each call receives the latents
            # returned by the previous call as its match target (None on the very first call).
            latent_curr,train_loss_dna[ep],train_loss_kl_dna[ep],train_loss_x_dna[ep],train_loss_xShared_dna[ep],train_loss_match_dna[ep],val_loss_dna[ep],val_loss_x_dna[ep],val_loss_xShared_dna[ep],val_loss_match_dna[ep]=train(ep,imgsC_all,trainIdx_all,valIdx_all,latent_curr,currLatentSize,modelcnn_dna,optimizer_dna,modelcnn_dnaShared,optimizer_dnaShared,False,plabels)
            latent_curr,train_loss_protein[ep],train_loss_kl_protein[ep],train_loss_x_protein[ep],train_loss_xShared_protein[ep],train_loss_match_protein[ep],val_loss_protein[ep],val_loss_x_protein[ep],val_loss_xShared_protein[ep],val_loss_match_protein[ep]=train(ep,imgsP_all,trainIdx_all,valIdx_all,latent_curr,currLatentSize,modelcnn_protein,optimizer_protein,modelcnn_pShared,optimizer_pShared,False,plabels)

            # Early stopping: count consecutive epochs (after epoch 200) in which any of
            # the three validation losses is not below its value 200 epochs earlier.
            if ep>200 and (val_loss_x_dna[ep]>=val_loss_x_dna[ep-200] or val_loss_x_protein[ep]>=val_loss_x_protein[ep-200] or val_loss_match_dna[ep]>=val_loss_match_dna[ep-200]):
                epCounts+=1
            else:
                epCounts=0

            if epCounts>100:
                break

            if ep%saveFreq == (saveFreq-1):
                stateDict_train_dna[ep]=modelcnn_dna.cpu().state_dict()
                stateDict_train_protein[ep]=modelcnn_protein.cpu().state_dict()
                stateDict_train_dnaShared[ep]=modelcnn_dnaShared.cpu().state_dict()
                stateDict_train_proteinShared[ep]=modelcnn_pShared.cpu().state_dict()
                # the snapshot moved the models to the CPU; bring them back for training
                modelcnn_dna.cuda()
                modelcnn_protein.cuda()
                modelcnn_dnaShared.cuda()
                modelcnn_pShared.cuda()
            torch.cuda.empty_cache()
        print(' total time: {:.4f}s'.format(time.time() - t_ep))

        # Checkpoints: one pickle per model, holding the whole {epoch: state_dict} dict.
        for savepath_i,stem_i,stateDict_i in (
                (modelsavepath_p_dna,'stateDict_',stateDict_train_dna),
                (modelsavepath_p_protein,'stateDict_',stateDict_train_protein),
                (modelsavepath_p_dna,'stateDictShared_',stateDict_train_dnaShared),
                (modelsavepath_p_protein,'stateDictShared_',stateDict_train_proteinShared)):
            with open(os.path.join(savepath_i,stem_i+runTag), 'wb') as output:
                pickle.dump(stateDict_i, output, pickle.HIGHEST_PROTOCOL)

        # Loss histories: one pickle per curve, named <stem><runTag>.
        lossLogs_dna=(
            ('train_loss_',train_loss_dna),
            ('train_loss_kl_',train_loss_kl_dna),
            ('train_loss_x_',train_loss_x_dna),
            ('train_loss_xShared_',train_loss_xShared_dna),
            ('train_loss_match_',train_loss_match_dna),
            ('val_loss_',val_loss_dna),
            ('val_loss_x_',val_loss_x_dna),
            ('val_loss_xShared_',val_loss_xShared_dna),
            ('val_loss_match_',val_loss_match_dna))
        lossLogs_protein=(
            ('train_loss_',train_loss_protein),
            ('train_loss_kl_',train_loss_kl_protein),
            ('train_loss_x_',train_loss_x_protein),
            ('train_loss_match_',train_loss_match_protein),
            ('val_loss_',val_loss_protein),
            ('val_loss_x_',val_loss_x_protein),
            ('val_loss_match_',val_loss_match_protein),
            ('train_loss_xShared_',train_loss_xShared_protein),
            ('val_loss_xShared_',val_loss_xShared_protein))
        for savepath_i,lossLogs_i in ((logsavepath_p_dna,lossLogs_dna),(logsavepath_p_protein,lossLogs_protein)):
            for stem_i,history_i in lossLogs_i:
                with open(os.path.join(savepath_i,stem_i+runTag), 'wb') as output:
                    pickle.dump(history_i, output, pickle.HIGHEST_PROTOCOL)

        # NOTE(review): the epoch is selected with the protein match loss while the
        # print below reports the dna match loss -- confirm which one is intended.
        totalepoch=np.argmin(np.array(val_loss_x_dna)+np.array(val_loss_x_protein)+np.array(val_loss_match_protein))

        print('loss_val_p: {:.4f}'.format(val_loss_x_protein[totalepoch]),
              'loss_val_c: {:.4f}'.format(val_loss_x_dna[totalepoch]),
              'loss_val_match: {:.4f}'.format(val_loss_match_dna[totalepoch]))

        # Train / validation curves, one figure per loss and model. The x-recon legend
        # used to list a third 'training kl loss' label although only two curves are drawn.
        matchLegend=['training match loss','validation match loss']
        xLegend=['training x recon loss','validation x recon loss']
        sharedLegend=['training shared recon','validation shared recon']
        for trainHist_i,valHist_i,legend_i,savepath_i,stem_i in (
                (train_loss_match_dna,val_loss_match_dna,matchLegend,plotsavepath_p_dna,'loss_seed3_match'),
                (train_loss_x_dna,val_loss_x_dna,xLegend,plotsavepath_p_dna,'loss_seed3_x'),
                (train_loss_match_protein,val_loss_match_protein,matchLegend,plotsavepath_p_protein,'loss_seed3_match'),
                (train_loss_x_protein,val_loss_x_protein,xLegend,plotsavepath_p_protein,'loss_seed3_x'),
                (train_loss_xShared_dna,val_loss_xShared_dna,sharedLegend,plotsavepath_p_dna,'loss_seed3_xShared'),
                (train_loss_xShared_protein,val_loss_xShared_protein,sharedLegend,plotsavepath_p_protein,'loss_seed3_xShared')):
            plt.plot(np.arange(epochs),trainHist_i)
            plt.plot(np.arange(epochs),valHist_i)
            plt.legend(legend_i,loc='upper right')
            plt.savefig(os.path.join(savepath_i,stem_i+runTag+'.jpg'))
            plt.show()










In [None]:
# Snapshot the models at the last epoch index `ep` reached by the training loop,
# then (re)write the checkpoints, loss histories and loss plots of the last
# configuration. This repeats the save / plot section of the training cell,
# using the loop variables it left behind.
runTag=str(currLatentSize)+'_'+str(dSpecificSize)+'_'+str(dfilterSize)

# .cpu() leaves the models on the CPU; a later cell moves them back to the GPU.
stateDict_train_dna[ep]=modelcnn_dna.cpu().state_dict()
stateDict_train_protein[ep]=modelcnn_protein.cpu().state_dict()
stateDict_train_dnaShared[ep]=modelcnn_dnaShared.cpu().state_dict()
stateDict_train_proteinShared[ep]=modelcnn_pShared.cpu().state_dict()

# Checkpoints: one pickle per model, holding the whole {epoch: state_dict} dict.
for savepath_i,stem_i,stateDict_i in (
        (modelsavepath_p_dna,'stateDict_',stateDict_train_dna),
        (modelsavepath_p_protein,'stateDict_',stateDict_train_protein),
        (modelsavepath_p_dna,'stateDictShared_',stateDict_train_dnaShared),
        (modelsavepath_p_protein,'stateDictShared_',stateDict_train_proteinShared)):
    with open(os.path.join(savepath_i,stem_i+runTag), 'wb') as output:
        pickle.dump(stateDict_i, output, pickle.HIGHEST_PROTOCOL)

# Loss histories: one pickle per curve, named <stem><runTag>.
lossLogs_dna=(
    ('train_loss_',train_loss_dna),
    ('train_loss_kl_',train_loss_kl_dna),
    ('train_loss_x_',train_loss_x_dna),
    ('train_loss_xShared_',train_loss_xShared_dna),
    ('train_loss_match_',train_loss_match_dna),
    ('val_loss_',val_loss_dna),
    ('val_loss_x_',val_loss_x_dna),
    ('val_loss_xShared_',val_loss_xShared_dna),
    ('val_loss_match_',val_loss_match_dna))
lossLogs_protein=(
    ('train_loss_',train_loss_protein),
    ('train_loss_kl_',train_loss_kl_protein),
    ('train_loss_x_',train_loss_x_protein),
    ('train_loss_match_',train_loss_match_protein),
    ('val_loss_',val_loss_protein),
    ('val_loss_x_',val_loss_x_protein),
    ('val_loss_match_',val_loss_match_protein),
    ('train_loss_xShared_',train_loss_xShared_protein),
    ('val_loss_xShared_',val_loss_xShared_protein))
for savepath_i,lossLogs_i in ((logsavepath_p_dna,lossLogs_dna),(logsavepath_p_protein,lossLogs_protein)):
    for stem_i,history_i in lossLogs_i:
        with open(os.path.join(savepath_i,stem_i+runTag), 'wb') as output:
            pickle.dump(history_i, output, pickle.HIGHEST_PROTOCOL)

# NOTE(review): the epoch is selected with the protein match loss while the print
# below reports the dna match loss -- confirm which one is intended.
totalepoch=np.argmin(np.array(val_loss_x_dna)+np.array(val_loss_x_protein)+np.array(val_loss_match_protein))

print('loss_val_p: {:.4f}'.format(val_loss_x_protein[totalepoch]),
      'loss_val_c: {:.4f}'.format(val_loss_x_dna[totalepoch]),
      'loss_val_match: {:.4f}'.format(val_loss_match_dna[totalepoch]))

# Train / validation curves, one figure per loss and model. The x-recon legend used
# to list a third 'training kl loss' label although only two curves are drawn.
matchLegend=['training match loss','validation match loss']
xLegend=['training x recon loss','validation x recon loss']
sharedLegend=['training shared recon','validation shared recon']
for trainHist_i,valHist_i,legend_i,savepath_i,stem_i in (
        (train_loss_match_dna,val_loss_match_dna,matchLegend,plotsavepath_p_dna,'loss_seed3_match'),
        (train_loss_x_dna,val_loss_x_dna,xLegend,plotsavepath_p_dna,'loss_seed3_x'),
        (train_loss_match_protein,val_loss_match_protein,matchLegend,plotsavepath_p_protein,'loss_seed3_match'),
        (train_loss_x_protein,val_loss_x_protein,xLegend,plotsavepath_p_protein,'loss_seed3_x'),
        (train_loss_xShared_dna,val_loss_xShared_dna,sharedLegend,plotsavepath_p_dna,'loss_seed3_xShared'),
        (train_loss_xShared_protein,val_loss_xShared_protein,sharedLegend,plotsavepath_p_protein,'loss_seed3_xShared')):
    plt.plot(np.arange(epochs),trainHist_i)
    plt.plot(np.arange(epochs),valHist_i)
    plt.legend(legend_i,loc='upper right')
    plt.savefig(os.path.join(savepath_i,stem_i+runTag+'.jpg'))
    plt.show()



In [None]:
# Put all four networks on the GPU. The last call stays a bare expression so the
# cell still displays the same module repr as before.
for _net in (modelcnn_dna,modelcnn_protein,modelcnn_dnaShared):
    _net.cuda()
modelcnn_pShared.cuda()

In [None]:
# Reconstruction check on the validation indices of the training arrays
# (valIdx_all indexes imgsC_all / imgsP_all / plabels).
# For each modality it accumulates the loss of the full reconstruction and of the
# reconstruction decoded from the first currLatentSize latent dimensions only, and
# shows up to three example images per batch.
# The final losses are means over batches (a smaller last batch weighs the same as a full one).
with torch.no_grad():
    modelcnn_dna.eval()
    modelcnn_dnaShared.eval()
    modelcnn_protein.eval()
    modelcnn_pShared.eval()

    loss_x_val_allC=0
    loss_x_valShared_allC=0
    loss_x_val_allP=0
    loss_x_valShared_allP=0
    nvalBatches=int(np.ceil(valIdx_all.shape[0]/batchsize))
    for i in range(nvalBatches):
        valIdx_i=valIdx_all[i*batchsize:min((i+1)*batchsize,valIdx_all.shape[0])]
        valInputC=torch.tensor(imgsC_all[valIdx_i]).cuda().float()
        valInputP=torch.tensor(imgsP_all[valIdx_i]).cuda().float()
        valInput_ID=plabels[valIdx_i].cuda()

        # DNA branch: full reconstruction and shared-latent-only reconstruction.
        reconC,z, mu, logvar = modelcnn_dna(valInputC,valInput_ID)
        reconSharedC=modelcnn_dnaShared(z[:,:currLatentSize],modelcnn_dna.pIDemb(valInput_ID))

        loss_x_val=loss_x(reconC, valInputC).item()
        loss_x_valShared=loss_x(reconSharedC,valInputC).item()

        loss_x_val_allC+=loss_x_val
        loss_x_valShared_allC+=loss_x_valShared

        # Protein branch, same two reconstructions.
        reconP,z, mu, logvar = modelcnn_protein(valInputP,valInput_ID)
        reconSharedP=modelcnn_pShared(z[:,:currLatentSize],modelcnn_protein.pIDemb(valInput_ID))

        loss_x_val=loss_x(reconP, valInputP).item()
        loss_x_valShared=loss_x(reconSharedP,valInputP).item()

        loss_x_val_allP+=loss_x_val
        loss_x_valShared_allP+=loss_x_valShared

        # Preview at most three samples; a separate index (not the batch counter `i`) is used
        # and it is capped by the batch size so a short final batch cannot raise IndexError.
        for j in range(min(3,valInputP.shape[0])):
            print(j)
            print(proteinNames[valIdx_i][j])
            # Order: protein input, shared recon, full recon, then the same three for DNA.
            for panel in (valInputP,reconSharedP,reconP,valInputC,reconSharedC,reconC):
                plt.imshow(panel[j][0].cpu().detach().numpy())
                plt.show()

    loss_x_val_allC=loss_x_val_allC/nvalBatches
    loss_x_valShared_allC=loss_x_valShared_allC/nvalBatches

    loss_x_val_allP=loss_x_val_allP/nvalBatches
    loss_x_valShared_allP=loss_x_valShared_allP/nvalBatches


In [None]:
# Load the held-out patients (the IDs listed in holdOutSamples) into the *_val arrays.
# Walks segDir/<condition>/<stain>/<patient dir> and reads the pickled image stack
# saved under the matching imgDir path. Per image it keeps:
#   imgsC_val            channel 0 of the stack (presumably the DNA stain -- TODO confirm)
#   imgsP_val            one protein channel, chosen by the chunking below
#   proteinNames_val     name of that chosen protein
#   imgsP_val_all        every non-zero channel (up to 3), zero padded
#   imgsP_val_all_names  names of those channels, padded with 'None'
#   pID_val, imgNames_val  patient ID and image name
imgsC_val=None
imgsP_val=None
imgNames_val=None
proteinNames_val=None
pID_val=None
imgsP_val_all=None
imgsP_val_all_names=None
for condition_i in conditions:
    print(condition_i)
    segDir_i=os.path.join(segDir,condition_i)
    imgDir_i=os.path.join(imgDir,condition_i)
    for stain in os.listdir(segDir_i):
        print(stain)
        segDir_i_stain=os.path.join(segDir_i,stain)
        imgDir_i_stain=os.path.join(imgDir_i,stain)
        
        # Map patient ID (text before the first '_' of the directory name) to directory name,
        # separately for the segmentation tree and the image tree.
        segPID2name={}
        for pID_dir in os.listdir(segDir_i_stain):
            pID=pID_dir.split('_')
            segPID2name[pID[0]]=pID_dir
        imgPID2name={}
        for pID_dir in os.listdir(imgDir_i_stain):
            pID=pID_dir.split('_')
            imgPID2name[pID[0]]=pID_dir
        for pID in segPID2name.keys():
            if condition_i=='meningioma' and stain=='dapi_gh2ax_lamin_cd3' and pID=='P33': #skipping incorrect images
                continue
            # Only the held-out patients are loaded here.
            if pID not in holdOutSamples:
                continue
            print(pID)
            if pID not in imgPID2name:
                print('img not found '+pID)
                continue
            imgDir_i_stain_p=os.path.join(imgDir_i_stain,imgPID2name[pID])
            segDir_i_stain_p=os.path.join(segDir_i_stain,segPID2name[pID])
            
            # NOTE(review): pickle.load executes arbitrary code from the file; only use trusted data.
            with open(os.path.join(imgDir_i_stain_p,savename+'_imgNames'), 'rb') as output:
                imgNames=pickle.load(output)
            with open(os.path.join(imgDir_i_stain_p,savename+'_img'), 'rb') as output:
                img=pickle.load(output)

            # img is indexed as (image, channel, row, col) below.
            imgP=np.zeros((img.shape[0],1,img.shape[2],img.shape[3]))
            imgP_all=np.zeros((img.shape[0],3,img.shape[2],img.shape[3]))
            proteinNames_val_curr=np.array([])
            imgsP_val_all_names_curr=None
            # Stain names come from the directory name; entry 0 is skipped, entries 1.. are
            # treated as the protein stains and index the channel axis of img.
            stain_list=stain.split('_')
            nImgPerStain=int(img.shape[0]/(len(stain_list)-1))
            # Fixed seed: the same shuffled order is produced on every run for a given stack size.
            np.random.seed(3)
            allIdx_all=np.arange(img.shape[0])
            np.random.shuffle(allIdx_all)
            # Split the shuffled images into one contiguous chunk per protein stain; images in
            # chunk s use channel s as their protein image. The last chunk absorbs the remainder.
            for s in range(1,len(stain_list)):
                s_start=(s-1)*nImgPerStain
                if s==len(stain_list)-1:
                    s_end=img.shape[0]
                else:
                    s_end=s*nImgPerStain
                imgP[s_start:s_end]=img[allIdx_all[s_start:s_end],s].reshape(s_end-s_start,1,img.shape[2],img.shape[3])
                proteinNames_val_curr=np.concatenate((proteinNames_val_curr,np.repeat(stain_list[s],s_end-s_start)))
                # Also keep all channels after channel 0 for these images (unused slots stay zero).
                imgP_all[s_start:s_end,:img.shape[1]-1]=img[allIdx_all[s_start:s_end],1:].reshape(s_end-s_start,img.shape[1]-1,img.shape[2],img.shape[3])
                if imgsP_val_all_names_curr is None:
                    imgsP_val_all_names_curr=np.tile(stain_list[1:],(s_end-s_start,1))
                else:
                    imgsP_val_all_names_curr=np.concatenate((imgsP_val_all_names_curr,np.tile(stain_list[1:],(s_end-s_start,1))),axis=0)
            # Panels with only two protein stains get a third name column filled with 'None'.
            if imgsP_val_all_names_curr.shape[1]==2:
                imgsP_val_all_names_curr=np.hstack((imgsP_val_all_names_curr,np.repeat('None',imgsP_val_all_names_curr.shape[0]).reshape(-1,1)))
            # First patient initialises the accumulators; later ones are appended.
            # Everything is stored in the shuffled order (allIdx_all).
            if pID_val is None:
                pID_val=np.repeat(pID,img.shape[0])
                imgsC_val=img[allIdx_all,[0]]
                imgNames_val=imgNames[allIdx_all]
                proteinNames_val=proteinNames_val_curr
                imgsP_val=imgP
                imgsP_val_all=imgP_all
                imgsP_val_all_names=imgsP_val_all_names_curr
            else:
                pID_val=np.concatenate((pID_val,np.repeat(pID,img.shape[0])))
                imgsC_val=np.concatenate((imgsC_val,img[allIdx_all,[0]]),axis=0)
                imgNames_val=np.concatenate((imgNames_val,imgNames[allIdx_all]))
                proteinNames_val=np.concatenate((proteinNames_val,proteinNames_val_curr))
                imgsP_val=np.concatenate((imgsP_val,imgP),axis=0)
                imgsP_val_all=np.concatenate((imgsP_val_all,imgP_all),axis=0)
                imgsP_val_all_names=np.concatenate((imgsP_val_all_names,imgsP_val_all_names_curr),axis=0)
# Re-insert a singleton channel axis: (N, H, W) -> (N, 1, H, W).
# Fails with AttributeError if no held-out patient was found (imgsC_val is still None).
imgsC_val=imgsC_val.reshape(imgsC_val.shape[0],1,imgsC_val.shape[1],imgsC_val.shape[2])

In [None]:
# Encode the hold-out protein names as integer class labels.
# pnames_val holds the sorted unique names, revIdx_val the index of each image's name in
# pnames_val, and pCounts_val the number of images per name.
# NOTE(review): the labels index pnames_val; they only match the training label order if
# pnames_val equals pnames -- confirm.
pnames_val,revIdx_val,pCounts_val=np.unique(proteinNames_val,return_inverse=True,return_counts=True)
nProt_val=pnames_val.size
plabels_val=torch.tensor(revIdx_val).long()

In [None]:
# Per-modality sub-folders ('dna' / 'protein') of the training log, model and plot directories.
logsavepath_p_dna,modelsavepath_p_dna,plotsavepath_p_dna=(
    os.path.join(_root,'dna') for _root in (logsavepath_train,modelsavepath_train,plotsavepath_train))
logsavepath_p_protein,modelsavepath_p_protein,plotsavepath_p_protein=(
    os.path.join(_root,'protein') for _root in (logsavepath_train,modelsavepath_train,plotsavepath_train))

# Architecture sizes: first configured shared-latent size and (specific size, filter size) pair.
currLatentSize=sharedSizes[0]
dSpecificSize,dfilterSize=dSpecific_filter[0]
print(currLatentSize)
print(dSpecificSize)
# Both encoders use the same number of shared channels.
dna_cShared=hidden5-dfilterSize
p_cShared=dna_cShared

# Loss functions used by the evaluation cells.
loss_match=torch.nn.MSELoss()
loss_kl=optimizer.optimizer_kl
loss_x=torch.nn.BCEWithLogitsLoss()
loss_mse=torch.nn.MSELoss()

# Seed numpy and torch (CPU and CUDA) before the networks are instantiated.
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.backends.cudnn.enabled = True
# The four networks are only created for this model name; any other value leaves them undefined.
if modelname_train=='cnn_vae_pbmc_cvae':
    _cvae=model.model_cnnvae_conditional
    # Leading constructor arguments shared by all four networks.
    _convArgs=(kernel_size, stride, padding, 1, hidden1, hidden2, hidden3, hidden4, hidden5)
    modelcnn_dna = _cvae.CNN_VAE_split_pIDemb(
        *_convArgs,
        dna_cShared,
        dna_cShared*hidden5_xy*hidden5_xy,
        (hidden5-dna_cShared)*hidden5_xy*hidden5_xy,
        currLatentSize,dSpecificSize,pnames.size,'randInit',pIDemb_size)
    modelcnn_protein = _cvae.CNN_VAE_split_pIDemb(
        *_convArgs,
        p_cShared,
        p_cShared*hidden5_xy*hidden5_xy,
        (hidden5-p_cShared)*hidden5_xy*hidden5_xy,
        currLatentSize,dSpecificSize,pnames.size,'randInit',pIDemb_size)
    modelcnn_dnaShared=_cvae.CNN_VAE_decode_pIDemb(*_convArgs, fc_dim1,currLatentSize,pIDemb_size)
    modelcnn_pShared=_cvae.CNN_VAE_decode_pIDemb(*_convArgs, fc_dim1,currLatentSize,pIDemb_size)


# Checkpoint (epoch key) restored from the pickled state-dict collections.
ep=184

_ckptTag=str(currLatentSize)+'_'+str(dSpecificSize)+'_'+str(dfilterSize)


def _loadPickle(path):
    """Return the object unpickled from `path` (trusted files only: unpickling runs code)."""
    with open(path, 'rb') as fh:
        return pickle.load(fh)


stateDict_train_dna=_loadPickle(os.path.join(modelsavepath_p_dna,'stateDict_'+_ckptTag))
stateDict_train_protein=_loadPickle(os.path.join(modelsavepath_p_protein,'stateDict_'+_ckptTag))
stateDict_train_dnaShared=_loadPickle(os.path.join(modelsavepath_p_dna,'stateDictShared_'+_ckptTag))
stateDict_train_proteinShared=_loadPickle(os.path.join(modelsavepath_p_protein,'stateDictShared_'+_ckptTag))

# Restore the weights saved under key `ep`, then move every network to the GPU.
for _net,_states in ((modelcnn_dna,stateDict_train_dna),
                     (modelcnn_protein,stateDict_train_protein),
                     (modelcnn_dnaShared,stateDict_train_dnaShared),
                     (modelcnn_pShared,stateDict_train_proteinShared)):
    _net.load_state_dict(_states[ep])

for _net in (modelcnn_dna,modelcnn_protein,modelcnn_dnaShared):
    _net.cuda()
modelcnn_pShared.cuda()

In [None]:
# Evaluation batch size and one index per hold-out image (row of imgsC_val).
batchsize=328
valIdx=np.arange(len(imgsC_val))

In [None]:
# Reconstruction check on the held-out patients (imgsC_val / imgsP_val).
# For each modality it accumulates the loss of the full reconstruction and of the
# reconstruction decoded from the first currLatentSize latent dimensions only, and
# shows up to ten normalised example images per batch.
# Requires percentileNorm and modeSub, which are defined in a later cell of this notebook.
valIdx=np.arange(imgsC_val.shape[0])
batchsize=328

with torch.no_grad():
    modelcnn_dna.eval()
    modelcnn_dnaShared.eval()
    modelcnn_protein.eval()
    modelcnn_pShared.eval()

    loss_x_val_allC=0
    loss_x_valShared_allC=0
    loss_x_val_allP=0
    loss_x_valShared_allP=0
    nvalBatches=int(np.ceil(valIdx.shape[0]/batchsize))
    for i in range(nvalBatches):
        valIdx_i=valIdx[i*batchsize:min((i+1)*batchsize,valIdx.shape[0])]
        valInputC=torch.tensor(imgsC_val[valIdx_i]).cuda().float()
        valInputP=torch.tensor(imgsP_val[valIdx_i]).cuda().float()
        # Hold-out images need the hold-out labels: indexing the training-set `plabels`
        # with hold-out indices returns the labels of unrelated training images.
        # NOTE(review): plabels_val indexes pnames_val; assumes it matches the training pnames order.
        valInput_ID=plabels_val[valIdx_i].cuda()

        # DNA branch: full reconstruction and shared-latent-only reconstruction.
        reconC,z, mu, logvar = modelcnn_dna(valInputC,valInput_ID)
        reconSharedC=modelcnn_dnaShared(z[:,:currLatentSize],modelcnn_dna.pIDemb(valInput_ID))

        loss_x_val=loss_x(reconC, valInputC).item()
        loss_x_valShared=loss_x(reconSharedC,valInputC).item()

        loss_x_val_allC+=loss_x_val
        loss_x_valShared_allC+=loss_x_valShared

        # Protein branch, same two reconstructions.
        reconP,z, mu, logvar = modelcnn_protein(valInputP,valInput_ID)
        reconSharedP=modelcnn_pShared(z[:,:currLatentSize],modelcnn_protein.pIDemb(valInput_ID))

        loss_x_val=loss_x(reconP, valInputP).item()
        loss_x_valShared=loss_x(reconSharedP,valInputP).item()

        loss_x_val_allP+=loss_x_val
        loss_x_valShared_allP+=loss_x_valShared

        # Preview at most ten samples; a separate index (not the batch counter `i`) is used
        # and it is capped by the batch size so a short final batch cannot raise IndexError.
        for j in range(min(10,valInputP.shape[0])):
            print(j)
            print(proteinNames_val[valIdx_i][j])
            # Protein input, shared recon, full recon (percentile-normalised).
            for panel in (valInputP,reconSharedP,reconP):
                plt.imshow(percentileNorm(panel[j][0].cpu().detach().numpy()))
                plt.show()
            # DNA input, shared recon, full recon (mode-subtracted).
            for panel in (valInputC,reconSharedC,reconC):
                plt.imshow(modeSub(panel[j][0].cpu().detach().numpy()))
                plt.show()

    # Means over batches (a smaller last batch weighs the same as a full one).
    loss_x_val_allC=loss_x_val_allC/nvalBatches
    loss_x_valShared_allC=loss_x_valShared_allC/nvalBatches

    loss_x_val_allP=loss_x_val_allP/nvalBatches
    loss_x_valShared_allP=loss_x_valShared_allP/nvalBatches


In [None]:
# Hold-out evaluation with MSE per modality plus a per-protein L1 error.
# The per-protein error compares the mode-subtracted, max-normalised protein reconstruction
# (full and shared-latent-only) with the raw protein input: summed absolute error divided by
# the number of pixels per image, then averaged over the images of that protein.
# Requires modeSub_torch, which is defined in a later cell of this notebook.
loss_x_noReduction=torch.nn.L1Loss(reduction='sum')
with torch.no_grad():
    modelcnn_dna.eval()
    modelcnn_dnaShared.eval()
    modelcnn_protein.eval()
    modelcnn_pShared.eval()

    loss_x_val_allC=0
    loss_x_valShared_allC=0
    loss_x_val_allP=0
    loss_x_valShared_allP=0
    # Accumulators for the six known proteins, plus any other protein present in the
    # hold-out set (previously an unknown name raised KeyError).
    loss_shared={name:0 for name in ('gh2ax','cd16','cd3','cd4','cd8','lamin')}
    loss_full={name:0 for name in ('gh2ax','cd16','cd3','cd4','cd8','lamin')}
    for p in np.unique(proteinNames_val):
        loss_shared.setdefault(p,0)
        loss_full.setdefault(p,0)
    nvalBatches=int(np.ceil(valIdx.shape[0]/batchsize))
    for i in range(nvalBatches):
        valIdx_i=valIdx[i*batchsize:min((i+1)*batchsize,valIdx.shape[0])]
        valInputC=torch.tensor(imgsC_val[valIdx_i]).cuda().float()
        valInputP=torch.tensor(imgsP_val[valIdx_i]).cuda().float()
        # Hold-out images need the hold-out labels: indexing the training-set `plabels`
        # with hold-out indices returns the labels of unrelated training images.
        # NOTE(review): plabels_val indexes pnames_val; assumes it matches the training pnames order.
        valInput_ID=plabels_val[valIdx_i].cuda()

        reconC,z, mu, logvar = modelcnn_dna(valInputC,valInput_ID)
        reconSharedC=modelcnn_dnaShared(z[:,:currLatentSize],modelcnn_dna.pIDemb(valInput_ID))

        loss_x_val=loss_mse(reconC, valInputC).item()
        loss_x_valShared=loss_mse(reconSharedC,valInputC).item()

        loss_x_val_allC+=loss_x_val
        loss_x_valShared_allC+=loss_x_valShared

        reconP,z, mu, logvar = modelcnn_protein(valInputP,valInput_ID)
        reconSharedP=modelcnn_pShared(z[:,:currLatentSize],modelcnn_protein.pIDemb(valInput_ID))

        loss_x_val=loss_mse(reconP, valInputP).item()
        loss_x_valShared=loss_mse(reconSharedP,valInputP).item()

        loss_x_val_allP+=loss_x_val
        loss_x_valShared_allP+=loss_x_valShared

        # Per-protein L1 error for this batch; the mask is computed once per protein.
        batchNames=proteinNames_val[valIdx_i]
        nPixels=valInputP.shape[2]*valInputP.shape[3]
        for p in np.unique(batchNames):
            isP=batchNames==p
            targetP=valInputP[isP]
            loss_full[p]+=loss_x_noReduction(modeSub_torch(reconP[isP]), targetP).item()/nPixels
            loss_shared[p]+=loss_x_noReduction(modeSub_torch(reconSharedP[isP]),targetP).item()/nPixels

    # Means over batches (a smaller last batch weighs the same as a full one).
    loss_x_val_allC=loss_x_val_allC/nvalBatches
    loss_x_valShared_allC=loss_x_valShared_allC/nvalBatches

    loss_x_val_allP=loss_x_val_allP/nvalBatches
    loss_x_valShared_allP=loss_x_valShared_allP/nvalBatches

    # Convert the per-protein sums into per-image means and report them.
    for p in np.unique(proteinNames_val):
        print(p)
        nImagesP=np.sum(proteinNames_val==p)
        loss_shared[p]=loss_shared[p]/nImagesP
        loss_full[p]=loss_full[p]/nImagesP
        print('shared',loss_shared[p])
        print('full',loss_full[p])

In [None]:
# Display the batch-averaged loss of the full DNA reconstruction, as left by whichever
# evaluation cell was run last.
loss_x_val_allC

In [None]:
# Display the batch-averaged loss of the shared-latent-only DNA reconstruction, as left by
# whichever evaluation cell was run last.
loss_x_valShared_allC

In [None]:
def percentileNorm(img_c, percentile=80):
    """Subtract a percentile background from an image and rescale it to [0, 1].

    The `percentile`-th percentile of `img_c` is subtracted, negative values are set
    to 0 and the result is divided by its maximum. If nothing remains above the
    background (e.g. a constant image) the all-zero image is returned instead of
    dividing by zero, which used to produce an all-NaN image.

    img_c: numpy array of intensities; it is not modified.
    percentile: background percentile in [0, 100]; defaults to 80.
    Returns a new array with the same shape as `img_c`.
    Side effect: prints the subtracted background value.
    """
    modeint=np.percentile(img_c,percentile)
    shifted=np.clip(img_c-modeint,0,None)
    peak=np.max(shifted)
    print(modeint)
    # Dividing by 1 keeps the zero image (and the float result type) when there is no signal.
    return shifted/(peak if peak>0 else 1)

def modeSub(img_c):
    """Subtract the most frequent intensity from an image and rescale it to [0, 1].

    The mode of `img_c` (the most frequent value; the smallest one on ties) is
    subtracted, negative values are set to 0 and the result is divided by its maximum.
    If nothing remains above the mode (e.g. a constant image) the all-zero image is
    returned instead of dividing by zero, which used to produce an all-NaN image.

    img_c: numpy array of intensities; it is not modified.
    Returns a new array with the same shape as `img_c`.
    Side effect: prints the subtracted mode.
    """
    intensity,intCounts=np.unique(img_c,return_counts=True)
    modeint=intensity[np.argmax(intCounts)]
    shifted=np.clip(img_c-modeint,0,None)
    peak=np.max(shifted)
    print(modeint)
    # Dividing by 1 keeps the zero image (and the float result type) when there is no signal.
    return shifted/(peak if peak>0 else 1)
def modeSub_torch(img_c):
    """Torch version of modeSub: subtract the most frequent value and rescale to [0, 1].

    The mode is taken over ALL elements of `img_c` (so over the whole batch when a
    batch is passed), subtracted, negative values are set to 0 and the result is
    divided by its global maximum. If nothing remains above the mode (e.g. a constant
    tensor) the all-zero tensor is returned instead of dividing by zero, which used to
    produce an all-NaN tensor.

    img_c: non-empty torch tensor; it is not modified.
    Returns a new tensor with the same shape and device as `img_c`.
    """
    intensity,intCounts=torch.unique(img_c,return_counts=True)
    modeint=intensity[torch.argmax(intCounts)]
    shifted=torch.clamp(img_c-modeint,min=0)
    peak=torch.max(shifted)
    # Dividing by 1 keeps the zero tensor when there is no signal above the mode.
    return shifted/(peak if peak>0 else 1)

In [None]:
#plotting prediction of  all proteins
# For selected hold-out images: show the DNA image, the true protein channels of that image,
# and then, for every protein in pnames, the protein image predicted from the DNA latent by
# the shared protein decoder (shown once percentile-normalised and once mode-subtracted).
# Requires percentileNorm and modeSub to be defined.

# The first 10 images of each of the first 11 blocks of 328 hold-out images
# (0..9, 328..337, ..., 3280..3289).
plottingIdx=(328*np.arange(11)[:,None]+np.arange(10)).ravel()
with torch.no_grad():
    # modelcnn_pShared is the decoder used below, so it is switched to eval mode and moved
    # to the GPU together with the two encoders instead of relying on an earlier cell.
    for _net in (modelcnn_dna,modelcnn_protein,modelcnn_pShared):
        _net.eval()
        _net.cuda()

    for i in range(plottingIdx.size):
        print(i)
        print('input img',proteinNames_val[plottingIdx][i])

        valtarget_protein=torch.tensor(imgsP_val[[plottingIdx[i]]]).cuda().float()
        valtarget_dna=torch.tensor(imgsC_val[[plottingIdx[i]]]).cuda().float()
        valInput_ID_orig=plabels_val[[plottingIdx[i]]].cuda()
        valAllProteins=imgsP_val_all[plottingIdx[i]]
        valAllProteins_names=imgsP_val_all_names[plottingIdx[i]]

        plt.imshow(modeSub(valtarget_dna[0][0].cpu().detach().numpy()))
        plt.show()

        # Ground-truth protein channels of this image; 'None' marks a padded (absent) channel.
        for trueName,trueImg in zip(valAllProteins_names,valAllProteins):
            if trueName=='None':
                continue
            print('True ',trueName)
            plt.imshow(percentileNorm(trueImg))
            plt.show()

        # The encoder passes do not depend on the protein being predicted, so they run once
        # per image rather than once per protein.
        # NOTE(review): reconC, reconP and z_p are not used in this cell.
        reconC,z_c, mu, logvar= modelcnn_dna(valtarget_dna,valInput_ID_orig)
        reconP,z_p, mu, logvar = modelcnn_protein(valtarget_protein,valInput_ID_orig)

        for pidx in range(pnames.size):
            print(pnames[pidx])
            valInput_ID=torch.tensor([pidx]).cuda()

            # Decode the first currLatentSize DNA latent dimensions with the embedding of protein pidx.
            reconShared_protein=modelcnn_pShared(z_c[:,:currLatentSize],modelcnn_protein.pIDemb(valInput_ID))

            predicted=reconShared_protein[0][0].cpu().detach().numpy()
            plt.imshow(percentileNorm(predicted))
            plt.show()
            plt.imshow(modeSub(predicted))
            plt.show()

            



In [None]:
# Build, for the held-out patients, one evaluation set per TARGET protein.
# Every dictionary below is keyed by a protein name (stain_list[sother]); its entry collects
# the images whose selected input protein (stain_list[s]) is a different stain of the same panel:
#   imgsC_val_allProt        channel 0 of those images (presumably DNA -- TODO confirm)
#   imgsP_val_allProt        the key protein's channel (the image to be predicted)
#   imgsP_val_allProt_input  the selected input protein's channel
#   proteinNames_val_allProt name of the selected input protein
#   pID_val_allProt, imgNames_val_allProt, conditions_val_allProt  patient ID, image name, condition
# The reshape calls read the image height/width from imgsC_val, so the hold-out loading
# cell must have been run first.
imgsC_val_allProt={}
imgsP_val_allProt={}
imgsP_val_allProt_input={}
imgNames_val_allProt={}
pID_val_allProt={}
conditions_val_allProt={}
proteinNames_val_allProt={}
for condition_i in conditions:
    print(condition_i)
    segDir_i=os.path.join(segDir,condition_i)
    imgDir_i=os.path.join(imgDir,condition_i)
    for stain in os.listdir(segDir_i):
        print(stain)
        segDir_i_stain=os.path.join(segDir_i,stain)
        imgDir_i_stain=os.path.join(imgDir_i,stain)
        
        # Map patient ID (text before the first '_' of the directory name) to directory name.
        segPID2name={}
        for pID_dir in os.listdir(segDir_i_stain):
            pID=pID_dir.split('_')
            segPID2name[pID[0]]=pID_dir
        imgPID2name={}
        for pID_dir in os.listdir(imgDir_i_stain):
            pID=pID_dir.split('_')
            imgPID2name[pID[0]]=pID_dir
        for pID in segPID2name.keys():
            if condition_i=='meningioma' and stain=='dapi_gh2ax_lamin_cd3' and pID=='P33': #skipping incorrect images
                continue
            # Only the held-out patients are used.
            if pID not in holdOutSamples:
                continue
            print(pID)
            if pID not in imgPID2name:
                print('img not found '+pID)
                continue
            imgDir_i_stain_p=os.path.join(imgDir_i_stain,imgPID2name[pID])
            segDir_i_stain_p=os.path.join(segDir_i_stain,segPID2name[pID])
            
            # NOTE(review): pickle.load executes arbitrary code from the file; only use trusted data.
            with open(os.path.join(imgDir_i_stain_p,savename+'_imgNames'), 'rb') as output:
                imgNames=pickle.load(output)
            with open(os.path.join(imgDir_i_stain_p,savename+'_img'), 'rb') as output:
                img=pickle.load(output)
                
#             imgP=np.zeros((img.shape[0],1,img.shape[2],img.shape[3]))
#             proteinNames_val_curr=np.array([])
            # Same seed-3 shuffle and equal-size chunking per protein stain as used when the
            # hold-out arrays were loaded; the last chunk absorbs the remainder.
            stain_list=stain.split('_')
            nImgPerStain=int(img.shape[0]/(len(stain_list)-1))
            np.random.seed(3)
            allIdx_all=np.arange(img.shape[0])
            np.random.shuffle(allIdx_all)
            for s in range(1,len(stain_list)):
                s_start=(s-1)*nImgPerStain
                if s==len(stain_list)-1:
                    s_end=img.shape[0]
                else:
                    s_end=s*nImgPerStain
                # Images of this chunk use channel s (stain_list[s]) as their input protein.
                proteinNames_val_curr=np.repeat(stain_list[s],s_end-s_start)
                imgP=img[allIdx_all[s_start:s_end],s].reshape(s_end-s_start,1,img.shape[2],img.shape[3])

                # File the chunk under every OTHER protein of the panel, storing that
                # protein's channel as the prediction target.
                for sother in range(1,len(stain_list)):
                    if sother==s:
                        continue
                    # First chunk for this target protein creates the entries; later chunks are appended.
                    if stain_list[sother] not in imgsP_val_allProt.keys():
                        pID_val_allProt[stain_list[sother]]=np.repeat(pID,s_end-s_start)
                        imgsC_val_allProt[stain_list[sother]]=img[allIdx_all[s_start:s_end],[0]].reshape(s_end-s_start,1,imgsC_val.shape[2],imgsC_val.shape[3])
                        imgNames_val_allProt[stain_list[sother]]=imgNames[allIdx_all[s_start:s_end]]
                        imgsP_val_allProt[stain_list[sother]]=img[allIdx_all[s_start:s_end],[sother]].reshape(s_end-s_start,1,imgsC_val.shape[2],imgsC_val.shape[3])
                        conditions_val_allProt[stain_list[sother]]=np.repeat(condition_i,s_end-s_start)
                        proteinNames_val_allProt[stain_list[sother]]=proteinNames_val_curr
                        imgsP_val_allProt_input[stain_list[sother]]=imgP
                    else:
                        pID_val_allProt[stain_list[sother]]=np.concatenate((pID_val_allProt[stain_list[sother]],np.repeat(pID,s_end-s_start)))
                        imgsC_val_allProt[stain_list[sother]]=np.concatenate((imgsC_val_allProt[stain_list[sother]],img[allIdx_all[s_start:s_end],[0]].reshape(s_end-s_start,1,imgsC_val.shape[2],imgsC_val.shape[3])),axis=0)
                        imgNames_val_allProt[stain_list[sother]]=np.concatenate((imgNames_val_allProt[stain_list[sother]],imgNames[allIdx_all[s_start:s_end]]))
                        imgsP_val_allProt[stain_list[sother]]=np.concatenate((imgsP_val_allProt[stain_list[sother]],img[allIdx_all[s_start:s_end],[sother]].reshape(s_end-s_start,1,imgsC_val.shape[2],imgsC_val.shape[3])),axis=0)
                        conditions_val_allProt[stain_list[sother]]=np.concatenate((conditions_val_allProt[stain_list[sother]],np.repeat(condition_i,s_end-s_start)))
                        proteinNames_val_allProt[stain_list[sother]]=np.concatenate((proteinNames_val_allProt[stain_list[sother]],proteinNames_val_curr))
                        imgsP_val_allProt_input[stain_list[sother]]=np.concatenate((imgsP_val_allProt_input[stain_list[sother]],imgP),axis=0)


In [None]:
#prediction loss of all proteins - l1 + thresh
# For every protein in pnames: predict that protein from the DNA latent with the shared
# protein decoder and report the mean L1 error between the mode-subtracted, max-normalised
# prediction and the true protein channel. The reported value is a mean over batches
# (a smaller last batch weighs the same as a full one).
# Requires modeSub_torch and the *_val_allProt dictionaries to be defined.
loss_l1=torch.nn.L1Loss()
with torch.no_grad():
    modelcnn_dna.eval()
    modelcnn_protein.eval()
    # The shared decoder produces the prediction, so it must be in eval mode as well.
    modelcnn_pShared.eval()

    for pidx in range(pnames.size):
        print(pnames[pidx])
        # A protein that never occurs as a co-stain in the hold-out set has no entry in the
        # dictionaries; skip it instead of failing with KeyError.
        if pnames[pidx] not in imgsP_val_allProt:
            print('no hold-out images for this protein, skipping')
            continue

        # Integer label (index into pnames) of the input protein of every image in this set.
        inputNames=proteinNames_val_allProt[pnames[pidx]]
        plabels_orig=torch.zeros(inputNames.size,dtype=torch.long)
        for pidx_label in range(pnames.size):
            plabels_orig[inputNames==pnames[pidx_label]]=pidx_label

        valInput_ID_single=torch.tensor([pidx]).cuda()

        valIdx_p=np.arange(imgsP_val_allProt[pnames[pidx]].shape[0])
        loss_x_valShared_all_protein=0
        nvalBatches=int(np.ceil(valIdx_p.shape[0]/batchsize))
        for i in range(nvalBatches):
            valIdx_i=valIdx_p[i*batchsize:min((i+1)*batchsize,valIdx_p.shape[0])]
            valtarget_protein=torch.tensor(imgsP_val_allProt_input[pnames[pidx]][valIdx_i]).cuda().float()
            valtarget_protein_pred=torch.tensor(imgsP_val_allProt[pnames[pidx]][valIdx_i]).cuda().float()
            valtarget_dna=torch.tensor(imgsC_val_allProt[pnames[pidx]][valIdx_i]).cuda().float()
            # Label of the protein to predict, repeated for every image of the batch.
            valInput_ID=torch.repeat_interleave(valInput_ID_single,valIdx_i.shape[0]).cuda()
            valInput_ID_orig=plabels_orig[valIdx_i].cuda()

            # NOTE(review): reconC, reconP and z_p are not used in this cell.
            reconC,z_c, mu, logvar= modelcnn_dna(valtarget_dna,valInput_ID_orig)
            reconP,z_p, mu, logvar= modelcnn_protein(valtarget_protein,valInput_ID_orig)

            # Decode the first currLatentSize DNA latent dimensions with the target protein's embedding.
            reconShared_protein=modelcnn_pShared(z_c[:,:currLatentSize],modelcnn_protein.pIDemb(valInput_ID))

            loss_xShared_val_protein=loss_l1(modeSub_torch(reconShared_protein),valtarget_protein_pred)

            loss_x_valShared_all_protein+=loss_xShared_val_protein.item()

        loss_x_valShared_all_protein=loss_x_valShared_all_protein/nvalBatches
        print(loss_x_valShared_all_protein)
