In [None]:
import nibabel as nib
import os
import numpy as np
import matplotlib.pyplot as plt
import nibabel.processing as nibp
from scipy import signal
from itertools import combinations_with_replacement
from numpy import savetxt
import nibabel as nib
import math
from numpy import random
import sklearn.preprocessing  
import matplotlib.pyplot as plt 
from sklearn.metrics import mean_squared_error
import torch
from torch import nn
from torch.nn import functional as F
from torch import Tensor
from typing import List, Callable, Union, Any, TypeVar, Tuple
import torch.optim as optim
from sklearn.decomposition import PCA
from sklearn import linear_model

In [None]:
import pickle

In [None]:
# Setting a random seed for reproducibility
import os
import random
import numpy as np
import torch

def set_seed(seed: int):
    """Seed every random number generator used in this notebook.

    Seeds Python's ``random``, NumPy's global generator and PyTorch (CPU and,
    when available, every CUDA device). It also switches cuDNN to
    deterministic, non-autotuned kernels.

    Setting ``PYTHONHASHSEED`` at runtime only affects child processes;
    the running interpreter's hash seed is fixed at start-up.

    :param seed: integer seed shared by all generators.
    """
    os.environ['PYTHONHASHSEED'] = str(seed)

    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)

    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)

    cudnn = torch.backends.cudnn
    cudnn.benchmark = False
    cudnn.deterministic = True


# Any integer works; changing it changes every stochastic result downstream.
SEED = 42
set_seed(SEED)

In [None]:
def correlation(x, y):
  """Pearson correlation coefficient between two equally shaped arrays.

  Uses population (ddof=0) standard deviations. The result therefore equals
  ``np.corrcoef(x.ravel(), y.ravel())[0, 1]``. Inputs of any shape are
  treated as flat samples.

  :param x: array-like of numbers.
  :param y: array-like with the same shape as ``x``.
  :return: scalar correlation. It is ``nan`` (with a RuntimeWarning) when
           either input is constant.
  """
  x = np.asarray(x, dtype=float)
  y = np.asarray(y, dtype=float)
  # Broadcasting the scalar means replaces explicit np.repeat copies, which
  # only worked for 1-D inputs.
  cov = (x - x.mean()) * (y - y.mean())
  return cov.sum() / (x.std() * y.std() * x.size)

def remove_std0(arr):
    """Return a copy of a 2-D array without its constant rows.

    A row is dropped when its standard deviation along axis 1 is exactly 0.

    :param arr: array of shape (n_rows, n_samples), e.g. voxels x volumes.
    :return: new array holding only the rows that vary.
    """
    row_std = np.std(arr, axis=1)
    keep = ~(row_std == 0.0)
    return arr[keep]

def compute_in(x):
  """Unrounded output length of one kernel-3, stride-2, unpadded convolution."""
  return 1 + (x - 3) / 2


def compute_in_size(x):
  """Apply :func:`compute_in` four times, once per encoder layer."""
  size = x
  for _ in range(4):
    size = compute_in(size)
  return size


def compute_out_size(x):
  """Length after four ``2 * x + 1`` steps.

  Each step is a kernel-3, stride-2, unpadded transposed convolution.
  """
  size = x
  for _ in range(4):
    size = 2 * size + 1
  return size
def compute_padding(x):
  """Choose per-layer paddings so the conv encoder/decoder fit length ``x``.

  Four kernel-3, stride-2 convolutions map a length ``v`` to
  ``compute_in_size(v) == (v - 15) / 16``, which is fractional for most ``v``.
  ``y`` is the number of extra samples that make ``x + y`` reduce to an
  integer length. Its 4-bit binary form says which layers get one sample of
  padding.

  :param x: input length (number of time points), an integer.
  :return: ``(pad, final_size, pad_out)``:
      * ``pad`` -- 4-character binary string of ``y``. cVAE reads it from
        the last character backwards, one bit per encoder layer.
      * ``final_size`` -- float length after the four encoder layers,
        ``compute_in_size(x + y)``.
      * ``pad_out`` -- 4-character binary string of
        ``compute_out_size(final_size) - x`` (algebraically equal to ``y``).
        cVAE reads it first character first, one bit per decoder layer.
  Example: ``compute_padding(156) == ('0011', 9.0, '0011')``.
  """
  # Fractional part missing from the encoded length, scaled by 2**4 to count
  # the input samples needed; for integer x this is an integer in [0, 15].
  rounding = np.ceil(compute_in_size(x))-compute_in_size(x)
  y = ((((rounding*2)*2)*2)*2)
  pad = bin(int(y)).replace('0b', '')
  # Left-pad the bit string to exactly four characters (one per layer).
  if len(pad) < 4:
      for i in range(4-len(pad)):
          pad = '0' + pad
  final_size = compute_in_size(x+y)
  # Difference between the decoder's unpadded output length and x, as bits.
  pad_out = bin(int(compute_out_size(final_size)-x)).replace('0b','')
  if len(pad_out) < 4:
      for i in range(4-len(pad_out)):
          pad_out = '0' + pad_out
  return pad,final_size, pad_out

class Scaler():
    """Row-wise z-scoring of a 2-D array (rows = voxels, columns = time).

    The mean and standard deviation of every row are estimated once, from
    the array passed to the constructor. They are then reused by
    :meth:`transform` and :meth:`inverse_transform`, which accept arrays with
    the same number of rows and any number of columns.

    Rows with zero standard deviation produce ``inf``/``nan`` in
    :meth:`transform`; drop them beforehand (e.g. with ``remove_std0``).
    """

    def __init__(self, inputs):
        self.data = inputs
        self.mean = np.mean(inputs, axis=1)
        self.std = np.std(inputs, axis=1)
        self.vox, self.time = inputs.shape
        # Full (vox, time) read-only views of the statistics, available as
        # soon as the scaler exists. This makes inverse_transform usable
        # even when transform was never called.
        self.m_large = np.broadcast_to(self._column(self.mean), (self.vox, self.time))
        self.s_large = np.broadcast_to(self._column(self.std), (self.vox, self.time))

    def _column(self, stat):
        """Return a per-row statistic as a (vox, 1) column for broadcasting."""
        return np.reshape(stat, (self.vox, 1))

    def transform(self, inputs):
        """Z-score each row of ``inputs`` (shape (vox, n)) with the fitted stats."""
        return np.divide(inputs - self._column(self.mean), self._column(self.std))

    def inverse_transform(self, outputs):
        """Undo :meth:`transform`: ``outputs * std + mean`` for every row."""
        return np.multiply(outputs, self._column(self.std)) + self._column(self.mean)

class TrainDataset(torch.utils.data.Dataset):
    """Paired (observation, augmented noise) samples for PyTorch.

    Item ``i`` is row ``i`` of ``X`` together with row ``i`` of ``Y``. The
    ``Y`` row is multiplied by a random gain ``2 * Beta(4, 4)`` (range 0-2,
    mean 1), drawn from NumPy's global RNG on every access. The length is
    that of the shorter of the two arrays.
    """

    def __init__(self, X, Y):
        self.obs, self.noi = X, Y

    def __len__(self):
        n_obs, n_noi = self.obs.shape[0], self.noi.shape[0]
        return n_obs if n_obs <= n_noi else n_noi

    def __getitem__(self, index):
        gain = np.random.beta(4, 4, 1) * 2
        return self.obs[index], gain * self.noi[index]

class DenoiseDataset(torch.utils.data.Dataset):
    """Unpaired dataset that yields the rows of ``X`` unchanged."""

    def __init__(self, X):
        self.obs = X

    def __len__(self):
        n_rows = self.obs.shape[0]
        return n_rows

    def __getitem__(self, index):
        return self.obs[index]


class cVAE(nn.Module):
    """1-D convolutional VAE with two latent branches, ``z`` and ``s``.

    Two encoders with identical architecture map an input of shape
    (B, in_channels, in_dim) to diagonal-Gaussian latents of size
    ``latent_dim``: ``encoder_z`` produces ``z`` and ``encoder_s`` produces
    ``s``. Each encoder has four stride-2 Conv1d / BatchNorm1d / LeakyReLU
    stages. A shared decoder maps ``cat(z, s)`` (2 * latent_dim) back to
    (B, in_channels, in_dim).

    * ``forward_tg`` encodes both ``z`` and ``s``.
    * ``forward_bg`` encodes only ``s`` and decodes with ``z`` set to zero.
    * ``forward_fg`` encodes only ``z`` and decodes with ``s`` set to zero.

    Padding bits come from ``compute_padding(in_dim)``, so the encoder
    reaches the integer length ``final_size`` and the decoder restores
    ``in_dim`` exactly. This requires exactly four encoder layers.

    :param in_channels: number of input (and reconstructed) channels.
    :param in_dim: length of each input sequence.
    :param latent_dim: size of each of the two latent codes.
    :param hidden_dims: encoder channel widths, mirrored by the decoder.
        Defaults to ``[64, 128, 256, 256]`` and must have four entries.
        The caller's list is not modified.
    :raises ValueError: if ``hidden_dims`` does not have four entries.
    """

    def __init__(self,in_channels: int,in_dim: int, latent_dim: int,hidden_dims: List = None) -> None:
        super(cVAE, self).__init__()

        self.latent_dim = latent_dim
        self.in_channels = in_channels
        self.in_dim = in_dim

        # Work on a private copy so the caller's list is never reordered.
        hidden_dims = [64, 128, 256, 256] if hidden_dims is None else list(hidden_dims)
        if len(hidden_dims) != 4:
            raise ValueError(
                'hidden_dims must have exactly 4 entries (compute_padding '
                f'assumes four stride-2 layers), got {len(hidden_dims)}')
        self.hidden_dims = hidden_dims

        self.pad, self.final_size, self.pad_out = compute_padding(self.in_dim)
        flat_dim = hidden_dims[-1] * int(self.final_size)

        # Modules are created in a fixed order (z branch, s branch, decoder)
        # so that seeded parameter initialisation is reproducible.
        self.encoder_z = self._build_encoder(hidden_dims)
        self.fc_mu_z = nn.Linear(flat_dim, latent_dim)
        self.fc_var_z = nn.Linear(flat_dim, latent_dim)

        self.encoder_s = self._build_encoder(hidden_dims)
        self.fc_mu_s = nn.Linear(flat_dim, latent_dim)
        self.fc_var_s = nn.Linear(flat_dim, latent_dim)

        # Decoder: mirror image of the encoder widths.
        self.decoder_input = nn.Linear(2*latent_dim, flat_dim)
        rev_dims = hidden_dims[::-1]
        modules = []
        for i in range(len(rev_dims) - 1):
            # Decoder layer i uses pad_out[i] (most significant bit first).
            p = int(self.pad_out[-4+i])
            modules.append(
                nn.Sequential(
                    nn.ConvTranspose1d(rev_dims[i],
                                       rev_dims[i + 1],
                                       kernel_size=3,
                                       stride=2,
                                       padding=p,
                                       output_padding=p),
                    nn.BatchNorm1d(rev_dims[i + 1]),
                    nn.LeakyReLU())
            )
        self.decoder = nn.Sequential(*modules)

        p = int(self.pad_out[-1])
        self.final_layer = nn.Sequential(
            nn.ConvTranspose1d(rev_dims[-1],
                               rev_dims[-1],
                               kernel_size=3,
                               stride=2,
                               padding=p,
                               output_padding=p),
            nn.BatchNorm1d(rev_dims[-1]),
            nn.LeakyReLU(),
            # Reconstruct as many channels as the input has.
            nn.Conv1d(rev_dims[-1], out_channels=self.in_channels,
                      kernel_size=3, padding=1))

    def _build_encoder(self, hidden_dims: List) -> nn.Sequential:
        """Build one encoder: a Conv1d(k=3, s=2) / BatchNorm1d / LeakyReLU stage per width.

        Encoder layer i pads by bit ``self.pad[-i-1]`` (least significant first).
        """
        stages = []
        in_channels = self.in_channels
        for i, h_dim in enumerate(hidden_dims):
            stages.append(
                nn.Sequential(
                    nn.Conv1d(in_channels, out_channels=h_dim,
                              kernel_size=3, stride=2, padding=int(self.pad[-i-1])),
                    nn.BatchNorm1d(h_dim),
                    nn.LeakyReLU())
            )
            in_channels = h_dim
        return nn.Sequential(*stages)

    @staticmethod
    def _encode(encoder: nn.Module, fc_mu: nn.Module, fc_var: nn.Module, input: Tensor) -> List[Tensor]:
        """Run ``encoder``, flatten, and project to ``[mu, log_var]``."""
        result = torch.flatten(encoder(input), start_dim=1)
        return [fc_mu(result), fc_var(result)]

    def encode_z(self, input: Tensor) -> List[Tensor]:
        """Encode ``input`` (B x C x L) into the ``z`` latent Gaussian.

        :return: ``[mu, log_var]``, each of shape (B x latent_dim).
        """
        return self._encode(self.encoder_z, self.fc_mu_z, self.fc_var_z, input)

    def encode_s(self, input: Tensor) -> List[Tensor]:
        """Encode ``input`` (B x C x L) into the ``s`` latent Gaussian.

        :return: ``[mu, log_var]``, each of shape (B x latent_dim).
        """
        return self._encode(self.encoder_s, self.fc_mu_s, self.fc_var_s, input)

    def decode(self, z: Tensor) -> Tensor:
        """Map concatenated latents (B x 2*latent_dim) to (B x in_channels x in_dim)."""
        result = self.decoder_input(z)
        # The channel count follows hidden_dims, so custom widths decode correctly.
        result = result.view(-1, self.hidden_dims[-1], int(self.final_size))
        result = self.decoder(result)
        return self.final_layer(result)

    def reparameterize(self, mu: Tensor, logvar: Tensor) -> Tensor:
        """Draw a sample from N(mu, exp(logvar)) via the reparameterisation trick.

        Sampling happens in both train and eval mode.

        :param mu: mean of the latent Gaussian [B x D].
        :param logvar: log-variance of the latent Gaussian [B x D].
        :return: sample [B x D].
        """
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return eps * std + mu

    def forward_tg(self, input: Tensor) -> List[Tensor]:
        """Target pass: encode both ``z`` and ``s`` and decode ``cat(z, s)``.

        :return: ``[output, input, mu_z, log_var_z, mu_s, log_var_s, z, s]``.
        """
        tg_mu_z, tg_log_var_z = self.encode_z(input)
        tg_mu_s, tg_log_var_s = self.encode_s(input)
        tg_z = self.reparameterize(tg_mu_z, tg_log_var_z)
        tg_s = self.reparameterize(tg_mu_s, tg_log_var_s)
        output = self.decode(torch.cat((tg_z, tg_s),1))
        return  [output, input, tg_mu_z, tg_log_var_z, tg_mu_s, tg_log_var_s,tg_z,tg_s]

    def forward_bg(self, input: Tensor) -> List[Tensor]:
        """Background pass: encode ``s`` only and decode ``cat(0, s)``.

        :return: ``[output, input, mu_s, log_var_s]``.
        """
        bg_mu_s, bg_log_var_s = self.encode_s(input)
        bg_s = self.reparameterize(bg_mu_s, bg_log_var_s)
        zeros = torch.zeros_like(bg_s)
        output = self.decode(torch.cat((zeros, bg_s),1))
        return  [output, input, bg_mu_s, bg_log_var_s]

    def forward_fg(self, input: Tensor) -> List[Tensor]:
        """Foreground pass: encode ``z`` only and decode ``cat(z, 0)``.

        :return: ``[output, input, mu_z, log_var_z]``.
        """
        fg_mu_z, fg_log_var_z = self.encode_z(input)
        tg_z = self.reparameterize(fg_mu_z, fg_log_var_z)
        zeros = torch.zeros_like(tg_z)
        output = self.decode(torch.cat((tg_z, zeros),1))
        return  [output, input, fg_mu_z, fg_log_var_z]

    def loss_function(self,
                      *args,
                      ) -> dict:
        """Reconstruction plus weighted KL loss for one target and one background batch.

        Expects the 8 values returned by ``forward_tg`` followed by the 4
        returned by ``forward_bg``: ``(recons_tg, input_tg, tg_mu_z,
        tg_log_var_z, tg_mu_s, tg_log_var_s, tg_z, tg_s, recons_bg, input_bg,
        bg_mu_s, bg_log_var_s)``. Positions 6-7 (the latent samples) are not
        used.

        The KL divergence to N(0, 1) is -1/2 * sum(1 + log(sigma^2) - mu^2 - sigma^2).
        It is computed for three posteriors (z and s on the target, s on the
        background), summed over latent dimensions and averaged over the batch.

        :return: dict with ``'loss'`` (scalar to backpropagate),
            ``'Reconstruction_Loss'`` (sum of the two MSE terms, detached) and
            ``'KLD'`` (detached).
        """
        beta = 0.00001  # weight of the KL term

        recons_tg = args[0]
        input_tg = args[1]
        tg_mu_z = args[2]
        tg_log_var_z = args[3]
        tg_mu_s = args[4]
        tg_log_var_s = args[5]
        recons_bg = args[8]
        input_bg = args[9]
        bg_mu_s = args[10]
        bg_log_var_s = args[11]

        recons_loss = F.mse_loss(recons_tg, input_tg)
        recons_loss += F.mse_loss(recons_bg, input_bg)

        kld_loss = 1 + tg_log_var_z - tg_mu_z ** 2 - tg_log_var_z.exp()
        kld_loss += 1 + tg_log_var_s - tg_mu_s ** 2 - tg_log_var_s.exp()
        kld_loss += 1 + bg_log_var_s - bg_mu_s ** 2 - bg_log_var_s.exp()
        kld_loss = torch.mean(-0.5 * torch.sum(kld_loss, dim = 1), dim = 0)

        loss = torch.mean(recons_loss + beta*kld_loss)
        return {'loss': loss, 'Reconstruction_Loss':recons_loss.detach(), 'KLD': kld_loss.detach()}

    def sample(self,
               num_samples:int,
               current_device: int) -> Tensor:
        """Decode ``num_samples`` draws from the standard-normal prior.

        Both latent codes are sampled, so ``z`` has 2 * latent_dim columns,
        matching ``decoder_input``.

        :param num_samples: number of samples.
        :param current_device: device to place the latent draws on.
        :return: tensor of shape (num_samples x in_channels x in_dim).
        """
        z = torch.randn(num_samples, 2 * self.latent_dim)
        z = z.to(current_device)
        return self.decode(z)

    def generate(self, x: Tensor) -> Tensor:
        """Return the foreground (``z``-only) reconstruction of ``x`` (B x C x L)."""
        return self.forward_fg(x)[0]

In [None]:
from nilearn.glm.first_level import make_first_level_design_matrix
import pandas as pd

from tqdm import tqdm

def get_regs(events_fn, t_r=2.0, n_scans=156):
    """Build face and place regressors from a tab-separated events file.

    The events are convolved with the SPM HRF by nilearn's
    ``make_first_level_design_matrix``, on a grid of ``n_scans`` volumes
    spaced ``t_r`` seconds apart. A cubic polynomial drift is also modelled
    and then discarded.

    :param events_fn: path to the events .tsv. The resulting design matrix
        must contain columns 'face', 'house' and 'scene' (nilearn names them
        after the events' trial types).
    :param t_r: repetition time in seconds.
    :param n_scans: number of volumes; must equal the run length.
    :return: ``(face_reg, place_reg)``, two 1-D arrays of length ``n_scans``.
        ``place_reg`` is the sum of the 'house' and 'scene' columns.
    """
    events = pd.read_csv(events_fn, delimiter='\t')
    frame_times = np.arange(n_scans) * t_r

    X1 = make_first_level_design_matrix(frame_times, events, drift_model="polynomial", drift_order=3, hrf_model="SPM")

    face_reg = X1[['face']].values.sum(axis=1)
    place_reg = X1[['house', 'scene']].values.sum(axis=1)
    return face_reg, place_reg

def show_bashboard(single_fig=True):
    """Draw the training dashboard for the current cVAE repetition.

    Takes no data arguments. Everything is read from notebook-level globals
    set by the training cell: running_loss_L, running_recons_L, track,
    batch_in, batch_out, model, inputs_gm, inputs_cf, ffa_list_coords,
    recon, signal, noise and face_reg, plus sub, r, rep and epoch for the
    suptitle. NOTE(review): ``signal`` must be the NumPy array assigned in
    the training loop, which shadows the ``scipy.signal`` module imported at
    the top of the notebook.

    :param single_fig: if True, clear the cell output, start a new
        45 x 25 inch figure on a 5 x 9 subplot grid, add a suptitle and show
        it. If False, draw into the current figure without clearing or
        showing it. The final notebook cell relies on this to put several
        repetitions on one figure.
    """
    import sys
    from IPython import display
    # 5 x 9 grid; only the first 12 panels are used below.
    nrows=5
    ncols=9
    sp=0  # running subplot index

    if single_fig==True:
        # NOTE(review): plt.close() closes the current figure, so plt.gcf()
        # below creates a new, empty figure that is displayed and never closed.
        plt.close()
        sys.stdout.flush()
        display.clear_output(wait=True);
        display.display(plt.gcf());
        plt.figure(figsize=(5*ncols,5*nrows))

    # Panels 1-4: epoch-mean total loss, reconstruction loss, FFA variance
    # explained, and correlation of the FFA average with its reconstruction.
    sp+=1;plt.subplot(nrows,ncols,sp);plt.plot(running_loss_L);   plt.title('total loss: {:.2f}'.format(running_loss_L[-1]))
    sp+=1;plt.subplot(nrows,ncols,sp);plt.plot(running_recons_L);   plt.title('Recon Loss: {:.2f}'.format(running_recons_L[-1]))
    sp+=1;plt.subplot(nrows,ncols,sp);plt.plot(track['ffa_varexp']);   plt.title('FFA varexp: {:.2f}'.format(track['ffa_varexp'][-1]))
    sp+=1;plt.subplot(nrows,ncols,sp);plt.plot(track['c_io']);   plt.title('ffa_io: {:.2f}'.format(track['c_io'][-1]))

    # Panel 5: first voxel of the last training batch -- input, training-time
    # reconstruction, BG decoding (red) and FG decoding (green). The two
    # decodings draw fresh latent samples and are not run under torch.no_grad().
    sp+=1;plt.subplot(nrows,ncols,sp);
    plt.plot(batch_in[0,0,:])
    plt.plot(batch_out[0,0,:])
    plt.plot(model.forward_bg(inputs_gm)[0].detach().cpu().numpy()[0,0,:],'r-')
    plt.plot(model.forward_fg(inputs_gm)[0].detach().cpu().numpy()[0,0,:],'g-')
    plt.title('batch timecourse (single voxel)')

    # Panel 6: voxel-averaged FFA input vs voxel-averaged full reconstruction.
    sp+=1;plt.subplot(nrows,ncols,sp);
    plt.plot(ffa_list_coords.mean(axis=0)[0,:])
    plt.plot(recon.mean(axis=0))
    plt.title('FFA AVG')

    # Panel 7: FFA average, FG ("signal") average in green, face regressor.
    sp+=1;plt.subplot(nrows,ncols,sp);
    plt.plot(ffa_list_coords.mean(axis=0)[0,:])
    plt.plot(signal.mean(axis=0),'g-')
    plt.plot(face_reg)
    plt.title('FFA SIGNAL')

    # Panel 8: FFA average, BG ("noise") average in red, face regressor.
    sp+=1;plt.subplot(nrows,ncols,sp);
    plt.plot(ffa_list_coords.mean(axis=0)[0,:])
    plt.plot(noise.mean(axis=0),'r-')
    plt.plot(face_reg)
    plt.title('FFA NOISE')

    # Panel 9: first voxel of the last no-interest (CF) batch with its BG
    # and FG decodings.
    sp+=1;plt.subplot(nrows,ncols,sp);
    plt.plot(inputs_cf.detach().cpu().numpy()[0,0,:])
    plt.plot(model.forward_bg(inputs_cf)[0].detach().cpu().numpy()[0,0,:])
    plt.plot(model.forward_fg(inputs_cf)[0].detach().cpu().numpy()[0,0,:])
    plt.title('CF batch voxel')

    # Panels 10-12: per-epoch correlation with the face regressor of the
    # reconstruction (TG), FG and BG averages; black is the raw FFA baseline.
    sp+=1;plt.subplot(nrows,ncols,sp);
    plt.plot(track['r_ffa_reg'],'k-')
    plt.plot(track['r_TG_reg'])
    plt.title('R TG-REG {}'.format(track['r_TG_reg'][-1]))

    sp+=1;plt.subplot(nrows,ncols,sp);
    plt.plot(track['r_ffa_reg'],'k-')
    # Green in single-figure mode; default colour cycle when overlaying.
    if single_fig==True:
        plt.plot(track['r_FG_reg'],'g-')
    else:
        plt.plot(track['r_FG_reg'])

    plt.title('R FG-REG {}'.format(track['r_FG_reg'][-1]))

    sp+=1;plt.subplot(nrows,ncols,sp);
    plt.plot(track['r_ffa_reg'],'k-')
    plt.plot(track['r_BG_reg'],'r-')
    plt.title('R BG-REG {}'.format(track['r_BG_reg'][-1]))

    if single_fig==True:
        plt.suptitle(f'{sub}-R{r}-rep-{rep} E:{epoch}',y=.91,fontsize=20)
        plt.show()

In [None]:
def safe_mkdir(path):
    """Create directory ``path``, including missing parents, if it is absent.

    ``os.makedirs(..., exist_ok=True)`` also avoids the race between checking
    for existence and creating the directory.

    :param path: directory to create.
    :raises FileExistsError: if ``path`` exists but is not a directory.
    """
    os.makedirs(path, exist_ok=True)

In [None]:
def get_batches(data, batch_size):
    """Lazily yield consecutive row slices of ``data`` of ``batch_size`` rows.

    The final slice is shorter when the row count is not a multiple of
    ``batch_size``.
    """
    n_rows = data.shape[0]
    yield from (data[offset:offset + batch_size]
                for offset in range(0, n_rows, batch_size))

In [None]:
def load_pickle(fn):
    """Load and return the object stored in the pickle file ``fn``.

    Security: ``pickle.load`` can execute arbitrary code, so only use this on
    files written by this notebook.

    :param fn: path to the pickle file.
    :return: the unpickled object.
    :raises FileNotFoundError: if ``fn`` does not exist.
    """
    with open(fn, 'rb') as file:
        return pickle.load(file)

In [None]:
def init_track():
    """Return a fresh dict of empty per-epoch metric lists for one training run."""
    metric_names = ('ffa_varexp', 'batch_varexp', 'r_ffa_reg', 'r_TG_reg',
                    'r_FG_reg', 'r_BG_reg', 'c_io')
    return {name: [] for name in metric_names}

In [None]:
# Parameters for looping with Papermill.
# Keep these as plain one-name-per-line assignments: papermill injects its
# overrides in a new cell after this one (tag this cell "parameters").

s = 0 # index into the sorted subject list (out of 14)
r = 1 # run number used in the input file names (out of 4)
analysis_name = 'test' # name of the output subdirectory under ofdir_root

In [None]:
# Output location: <ofdir_root>/<analysis_name>, created if needed.
ofdir_root = '../Data/StudyForrest/ensembles_last_CVAE'  # root for all analyses
ofdir = os.path.join(ofdir_root, analysis_name)          # this analysis' outputs
safe_mkdir(ofdir)
print(ofdir)

In [None]:
# Data directory: point this to the fmriprep output folder.
indir = '../Data/StudyForrest/fmriprep/'

# Subject folders, skipping the per-subject .html reports.
subs_orig = sorted(
    entry for entry in os.listdir(indir)
    if entry.startswith('sub') and not entry.endswith('.html')
)
n_orig = len(subs_orig)

# File-name templates; {sub} and {r} are filled in with str.format.
epi_fn = os.path.join(indir, '{sub}/ses-localizer/func/{sub}_ses-localizer_task-objectcategories_run-{r}_bold_space-MNI152NLin2009cAsym_preproc.nii.gz')
events_fn_temp = '../Data/StudyForrest/events/{sub}_ses-localizer_task-objectcategories_run-{r}_events.tsv'
cf_fn = os.path.join(indir, 'mask_roni.nii')  # region-of-no-interest mask
gm_fn = os.path.join(indir, 'mask_roi.nii')   # region-of-interest (gray matter) mask

sub = subs_orig[s]
sub

In [None]:
# Face / place regressors for this subject and run.
face_reg, place_reg = get_regs(events_fn_temp.format(r=r, sub=sub))

In [None]:
# Preprocessed BOLD run plus the two mask images (compared against 1 below).
func = nib.load(epi_fn.format(sub=sub, r=r))
roni_idx, roi_idx = nib.load(cf_fn), nib.load(gm_fn)

In [None]:
func.shape  # (x, y, z, n_volumes) -- the last axis is time

In [None]:
# Boolean voxel masks: True where the mask image equals exactly 1.
gm_mask_c = np.equal(roi_idx.get_fdata(), 1)
cf_mask_c = np.equal(roni_idx.get_fdata(), 1)

In [None]:
# Extract functional data from masks.
# Flatten the three spatial axes so rows are voxels and columns are volumes,
# then keep the rows selected by each boolean mask.
func_values = func.get_fdata()  # appending [:,:,:,5:] would drop the first 5 volumes
func_reshaped = func_values.reshape(func.shape[0] * func.shape[1] * func.shape[2], func.shape[3])
gm_reshaped = gm_mask_c.reshape(-1)
cf_reshaped = cf_mask_c.reshape(-1)
func_gm = func_reshaped[gm_reshaped]  # functional data in gray matter
func_cf = func_reshaped[cf_reshaped]  # functional data in the regions of no interest

In [None]:
# Normalization of data: drop constant voxels, then z-score every voxel's
# time course with its own mean and standard deviation.
func_gm, func_cf = remove_std0(func_gm), remove_std0(func_cf)

obs_scale, noi_scale = Scaler(func_gm), Scaler(func_cf)
obs_list = obs_scale.transform(func_gm)
noi_list = noi_scale.transform(func_cf)

In [None]:
# Train on the first GPU when one is available, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')

In [None]:
# rFFA ROI for this subject: z-scored voxel time courses, shaped
# (n_voxels, 1, n_volumes) to match the Conv1d model's input layout.
ffa_im = nib.load(f'../Data/StudyForrest/ROIs/rFFA_final_mask_{sub}_bin.nii.gz')
ffa_idx = ffa_im.get_fdata()
func_ffa = remove_std0(func_values[ffa_idx == 1])
ffa_scale = Scaler(func_ffa)
ffa_list = ffa_scale.transform(func_ffa)
ffa_list_coords = np.expand_dims(ffa_list, axis=1)
ffa_list_coords_torch = torch.from_numpy(ffa_list_coords).float().to(device)
ffa_list_coords_torch.shape

In [None]:
# CompCor-style baseline: regress the top 5 principal components of the
# no-interest time courses out of every FFA voxel time course.
# NOTE(review): must run before the padding cell below, which appends
# duplicated rows to noi_list.
conf_pcs = PCA(n_components=5).fit_transform(noi_list.T)
lin_reg = linear_model.LinearRegression()
lin_reg.fit(conf_pcs, ffa_list_coords[:, 0, :].T)
ffa_compcorr = (ffa_list_coords[:, 0, :].T - lin_reg.predict(conf_pcs)).T

In [None]:
print(obs_list.shape)
print(noi_list.shape)
# Resample noise rows (with replacement) until there are as many as
# observation rows; TrainDataset's length is the smaller of the two.
if obs_list.shape[0] > noi_list.shape[0]:
    n_pad = obs_list.shape[0] - noi_list.shape[0]
    pad_idx = np.random.randint(low=0, high=noi_list.shape[0], size=n_pad)
    noi_list = np.concatenate([noi_list, noi_list[pad_idx, :]])
    print(obs_list.shape)
    print(noi_list.shape)

In [None]:
# Ensemble size, voxels per mini-batch, and epochs per ensemble member.
nreps, batch_size, epoch_num = 20, 512, 100

In [None]:
# Train an ensemble of `nreps` independently initialised cVAEs on this
# subject/run. Every epoch the model is evaluated on the FFA ROI. Each
# repetition saves its outputs, metric traces, a checkpoint and a denoised
# NIfTI.
#
# Module-level names assigned here are read afterwards by show_bashboard()
# and the plotting cells: track, model, running_loss_L, running_recons_L,
# inputs_gm, inputs_cf, batch_in, batch_out, recon, signal, noise, epoch, rep.
# NOTE(review): `signal` shadows the `scipy.signal` module.

# Voxels whose time course is constant over the whole run. remove_std0 drops
# these same rows from the GM data, so the mask below restores the layout.
# This is loop-invariant, so it is computed once.
std0 = func_values.std(axis=-1)==0.0
ffa_tc = ffa_list_coords[:,0,:]  # (n_ffa_voxels, n_volumes)

for rep in np.arange(nreps):
    track = init_track()
    train_inputs = TrainDataset(obs_list,noi_list)
    train_in = torch.utils.data.DataLoader(train_inputs, batch_size=batch_size,
                                           shuffle=True, num_workers=1)

    # cVAE(in_channels=1, in_dim=n_volumes, latent_dim=8)
    model = cVAE(1,func_cf.shape[1],8)
    optimizer = optim.Adam(model.parameters(), lr=0.001, betas=(0.9, 0.999), eps=1e-08, weight_decay=0)
    model.to(device)

    running_loss_L = []
    running_recons_L = []
    running_KLD_L = []

    for epoch in tqdm(np.arange(epoch_num)):
        running_loss = 0.0
        running_reconstruction_loss = 0.0
        running_KLD = 0.0

        model.train()
        for inputs_gm, inputs_cf in train_in:
            # (B, T) -> (B, 1, T) float32 on the training device
            inputs_gm = inputs_gm.unsqueeze(1).float().to(device)
            inputs_cf = inputs_cf.unsqueeze(1).float().to(device)
            optimizer.zero_grad()
            [outputs_gm, inputs_gm, tg_mu_z, tg_log_var_z, tg_mu_s, tg_log_var_s,tg_z,tg_x] = model.forward_tg(inputs_gm)
            [outputs_cf, inputs_cf, bg_mu_s, bg_log_var_s] = model.forward_bg(inputs_cf)
            loss = model.loss_function(outputs_gm, inputs_gm, tg_mu_z, tg_log_var_z, tg_mu_s, tg_log_var_s,tg_z,tg_x, outputs_cf, inputs_cf, bg_mu_s, bg_log_var_s)
            loss['loss'].backward()
            optimizer.step()
            # detach(): summing the graph-attached loss would keep every
            # batch's autograd graph alive until the end of the epoch.
            running_loss += loss['loss'].detach()
            running_reconstruction_loss += loss['Reconstruction_Loss']
            running_KLD += loss['KLD']

        # Mean per batch, halved.
        n_terms = len(train_in)*2
        running_loss_L.append(float(running_loss / n_terms))
        running_recons_L.append(float(running_reconstruction_loss / n_terms))
        running_KLD_L.append(float(running_KLD / n_terms))

        # ---- Evaluation, every epoch ----
        model.eval()
        with torch.no_grad():
            recon = model.forward_tg(ffa_list_coords_torch)[0].cpu().numpy()[:,0,:]
            signal = model.forward_fg(ffa_list_coords_torch)[0].cpu().numpy()[:,0,:]
            noise = model.forward_bg(ffa_list_coords_torch)[0].cpu().numpy()[:,0,:]

        # Fraction of FFA variance (around the grand mean) explained by the
        # full reconstruction.
        SST = ((ffa_tc-ffa_tc.mean())**2).sum()
        SSM = ((ffa_tc-recon)**2).sum()
        varexp = (1-SSM/SST).round(2)

        # Same for the last training batch. NOTE(review): this SST centres on
        # the per-volume mean across voxels, unlike the grand mean above.
        batch_in = inputs_gm.detach().cpu().numpy()
        batch_out = outputs_gm.detach().cpu().numpy()
        batch_SST = ((batch_in[:,0,:]-batch_in[:,0,:].mean(axis=0))**2).sum()
        batch_SSM = ((batch_in[:,0,:]-batch_out[:,0,:])**2).sum()
        batch_varexp = (1-batch_SSM/batch_SST).round(2)

        # Correlation with the face regressor of the voxel-averaged raw FFA,
        # reconstruction (TG), FG "signal" and BG "noise" time courses.
        c_t = np.array([np.corrcoef(ts, face_reg)[0,1] for ts in
                        (ffa_tc.mean(axis=0), recon.mean(axis=0),
                         signal.mean(axis=0), noise.mean(axis=0))]).round(2)
        c_io = np.corrcoef(ffa_tc.mean(axis=0),recon.mean(axis=0))[0,1]

        track['ffa_varexp'].append(varexp)
        track['batch_varexp'].append(batch_varexp)
        track['r_ffa_reg'].append(c_t[0])
        track['r_TG_reg'].append(c_t[1])
        track['r_FG_reg'].append(c_t[2])
        track['r_BG_reg'].append(c_t[3])
        track['c_io'].append(c_io)

    outputs = {
            'recon' : recon,
            'signal' : signal,
            'noise' : noise,
            'ffa' : ffa_tc,
            'ffa_compcorr' : ffa_compcorr,
            'face_reg' : face_reg,
            'place_reg' : place_reg,}

    outputs_ofn = os.path.join(ofdir,f'outputs_S{s}_R{r}_rep_{rep}.pickle')
    track_ofn = os.path.join(ofdir,f'track_S{s}_R{r}_rep_{rep}.pickle')
    model_ofn = os.path.join(ofdir,f'model_S{s}_R{r}_rep_{rep}.pickle')

    with open(track_ofn, 'wb') as handle:
        pickle.dump(track, handle, protocol=pickle.HIGHEST_PROTOCOL)

    with open(outputs_ofn, 'wb') as handle:
        pickle.dump(outputs, handle, protocol=pickle.HIGHEST_PROTOCOL)

    torch.save({
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'epoch': epoch,
        'loss': loss,
    }, model_ofn)

    # FG ("signal") time course for every GM voxel, back in original units,
    # written into a full-size volume aligned with the input run.
    with torch.no_grad():
        brain_signals = np.vstack([
            model.forward_fg(torch.from_numpy(gm_batch[:,np.newaxis,:]).float().to(device))[0].cpu().numpy()[:,0,:]
            for gm_batch in get_batches(obs_list, batch_size)])
    brain_signals = obs_scale.inverse_transform(brain_signals)

    brain_signals_arr = np.zeros(func_values.shape)
    brain_signals_arr[gm_mask_c*~std0]=brain_signals
    new_img = nib.Nifti1Image(brain_signals_arr, affine=func.affine, header=func.header)
    signal_ofn = os.path.join(ofdir,f'signal_S{s}_R{r}_rep_{rep}.nii.gz')
    nib.save(new_img,signal_ofn)

    show_bashboard()

In [None]:
# Load the ensembled data: every saved repetition for this subject/run.
# Both lists use the same lexicographic order (rep_10 before rep_2), so
# entry i of each refers to the same repetition.
output_files = sorted(os.path.join(ofdir, val) for val in os.listdir(ofdir)
                      if val.startswith(f'outputs_S{s}_R{r}_'))
track_files = sorted(os.path.join(ofdir, val) for val in os.listdir(ofdir)
                     if val.startswith(f'track_S{s}_R{r}_'))
print(f'{len(output_files)}/{len(track_files)}')

In [None]:
# Unpickle every repetition's outputs dict and metric-track dict.
outputs = list(map(load_pickle, output_files))
tracks = list(map(load_pickle, track_files))

In [None]:
# Per repetition, the FG ("signal") time course averaged over FFA voxels:
# shape (n_reps, n_volumes). Kept under its own name so the median-based
# `signals` computed in the next cell does not silently overwrite it.
signals_mean = np.array([output['signal'].mean(axis=0) for output in outputs])

In [None]:
# Per repetition, the median over FFA voxels of the FG ("signal") decoding:
# (n_reps, n_voxels, n_volumes) -> (n_reps, n_volumes).
signals = np.median(np.array([output['signal'] for output in outputs]), axis=1)

In [None]:
# Combine the repetitions: mean and median across the ensemble.
signal_avg = signals.mean(axis=0)
signal_med = np.median(signals, axis=0)

In [None]:
# Correlation of each FFA estimate with the face regressor for this subject/run.
r_ffa = np.corrcoef(ffa_list_coords[:, 0, :].mean(axis=0), face_reg)[0, 1]
r_compcor = np.corrcoef(ffa_compcorr.mean(axis=0), face_reg)[0, 1]
# NOTE(review): named "med" (and labelled "fg med" in the bar plot) but
# computed from signal_avg, the ensemble mean; confirm which is intended.
r_fg_med = np.corrcoef(signal_avg, face_reg)[0, 1]

In [None]:
# One correlation per line: preprocessed, CompCor, cVAE FG.
print(r_ffa.round(2), r_compcor.round(2), r_fg_med.round(2), sep='\n')

In [None]:
# Gain in face-regressor correlation relative to the preprocessed FFA signal.
xs = [0, 1, 2]
ys = np.array([r_ffa, r_compcor, r_fg_med])
ys = ys - ys[0]
plt.bar(xs, ys)
plt.axhline(0, color='k', linewidth=0.8)
plt.xticks(xs, labels=['preproc', 'compcor', 'fg med'])
plt.ylabel('r(face regressor) - r(preproc)')
plt.title('Improvement');

In [None]:
plt.figure(figsize=(25,5))

# Left: voxel-averaged FFA vs ensemble-median FG signal vs face regressor.
plt.subplot(1, 2, 1)
plt.plot(ffa_list_coords[:, 0, :].mean(axis=0), 'k-', label='FFA (preproc)')
plt.plot(signal_med, 'g-', label='FG signal (ensemble median)')
plt.plot(face_reg, 'y-', label='face regressor')
plt.xlabel('volume')
plt.title('FFA average')
plt.legend()

# Right: each repetition's voxel-median FG signal.
plt.subplot(1, 2, 2)
plt.plot(signals.transpose(), alpha=.3)
plt.plot(face_reg, label='face regressor')
plt.xlabel('volume')
plt.title('FG signal per repetition')
plt.legend();

In [None]:
import warnings
# Overlay the metric traces of every saved repetition on one dashboard.
# show_bashboard() reads the module-level name `track`, which this loop
# rebinds. Its loss panels and example time courses come from the last
# repetition trained in this session (running_loss_L etc.), not from the
# loaded tracks.
plt.figure(figsize=(5*9,5*5))
with warnings.catch_warnings():
    # Suppress warnings for this figure only, not for the whole session.
    warnings.simplefilter("ignore")
    for track in tracks:
        show_bashboard(single_fig=False)
    plt.show()