In [1]:
import pandas as pd
import numpy as np
import koopomics as ko
import torch
import torch.nn as nn

In [2]:
pregnancy_df = pd.read_csv('/Users/daviddornig/Documents/Master_Thesis/Bioinf/Code/philipp-trinh/KOOPOMICS/input_data/pregnancy/pregnancy_interpolated_featselected_normalized.csv')

sample_id = 'Subject ID'
time_id = 'Gestational age (GA)/weeks'

feature_list = pregnancy_df.columns[7:]

train_set_df = pregnancy_df[pregnancy_df['Cohort'] == 'Discovery'].copy()
test_set_df = pregnancy_df[pregnancy_df['Cohort'] == 'Validation (Test Set 1)'].copy()

train_dataloader = ko.dataloader_AE(train_set_df, feature_list, sample_id=sample_id, time_id=time_id, batch_size=5)
test_dataloader = ko.dataloader_AE(test_set_df, feature_list, sample_id=sample_id, time_id=time_id, batch_size=5)


In [3]:
import pandas as pd
import numpy as np

# Naive model class that predicts the average of the target
class NaiveMeanPredictor(nn.Module):
    def __init__(self):
        # Call the parent class (nn.Module) constructor
        super(NaiveMeanPredictor, self).__init__()
        self.means = nn.Parameter(torch.zeros(1), requires_grad=False)  # Register a dummy parameter

    def fit(self, df, feature_list):
        """
        Calculate the mean of each feature in feature_list from the dataframe
        and store them as a tensor.
        """
        # Calculate means for the features and convert to a torch tensor
        means_values = df[feature_list].mean().values
        self.means = nn.Parameter(torch.tensor(means_values, dtype=torch.float32), requires_grad=False)
    def kmatrix(self):
        return torch.zeros(4,4), torch.zeros(4,4)
    def forward(self, input_vector, fwd=0, bwd=0):
        """
        For the forward pass, ignore the input and return the mean values
        calculated during the fit step.
        """
        # Return the means, repeated for each input sample
        batch_size = input_vector.size(0)
        return [self.means.unsqueeze(0).expand(batch_size, -1)], [self.means.unsqueeze(0).expand(batch_size, -1)]

model = NaiveMeanPredictor()
model.fit(pregnancy_df, feature_list)


In [4]:
import torch
input_tensor = torch.tensor(pregnancy_df[feature_list].values, dtype=torch.float32)
criterion = nn.MSELoss()


with torch.no_grad(): 
    latent_representations, identity_outputs = model(input_tensor)
criterion(identity_outputs[0], input_tensor)

tensor(0.2560)

In [5]:
ko.test(model, train_dataloader, max_Kstep=10, disable_tempcons=True)

{'test_fwd_loss': tensor(0.1427),
 'test_bwd_loss': tensor(1.1729),
 'test_temp_cons_loss': 0,
 'test_inv_cons_loss': 4.0,
 'dict_fwd_step_loss': {1: 0.20049383165314794,
  2: 0.18151442290462078,
  3: 0.1600100714713335,
  4: 0.1459430495792247,
  5: 0.13573912099162314,
  6: 0.13325542316554076,
  7: 0.12733971741684574,
  8: 0.12036833270116055,
  9: 0.11358103643547982,
  10: 0.10883821735277455},
 'dict_fwd_step_tempcons_loss': {},
 'dict_bwd_step_loss': {1: 0.4454635258446983,
  2: 0.6903189714165762,
  3: 0.8678905580866527,
  4: 1.0272069751146309,
  5: 1.1593575851417126,
  6: 1.2968178406318134,
  7: 1.3993743391350864,
  8: 1.5136198594452852,
  9: 1.5976044727846028,
  10: 1.7310986356373796},
 'dict_bwd_step_tempcons_loss': {}}

In [9]:
for batch in train_dataloader:
    # Extract batch data
    inputs = batch['input_data']
    current_time_ids = batch['current_time_ids']
    current_sample_ids = batch['current_sample_ids']
    current_row_ids = batch['current_row_ids']
    timeseries_ids = batch['timeseries_ids']
    timeseries_tensor = batch['timeseries_tensor']

    # Get validation targets (forward or backward)
    target_tensor, comparable = ko.get_validation_targets(inputs, timeseries_tensor, 
                                                         current_time_ids, timeseries_ids, 
                                                         fwd=0, bwd=1)
    
    # Print everything for inspection
    print("Inputs:")
    print(inputs)
    
    print("\nCurrent Time IDs:")
    print(current_time_id)
    
    print("\nCurrent Sample IDs:")
    print(current_sample_ids)
    
    print("\nCurrent Row IDs:")
    print(current_row_ids)
    
    print("\nTimeseries IDs:")
    print(timeseries_ids)
    
    print("\nTimeseries Tensor:")
    print(timeseries_tensor)
    
    print("\nTarget Tensor:")
    print(target_tensor)
    
    print("\nComparable Boolean Mask:")
    print(comparable)
    
    # Break after the first batch if you don't want to print all batches
    break

    

RuntimeError: a Tensor with 30 elements cannot be converted to Scalar

In [6]:
ko.test(model, test_dataloader, max_Kstep=1)

{'test_fwd_loss': tensor(0.2455),
 'test_bwd_loss': tensor(0.7552),
 'test_temp_cons_loss': 0,
 'test_inv_cons_loss': 4.0,
 'dict_fwd_step_loss': {1: 0.24553624524715098},
 'dict_fwd_step_tempcons_loss': {},
 'dict_bwd_step_loss': {1: 0.7551885096633688},
 'dict_bwd_step_tempcons_loss': {}}

In [7]:
for batch in test_dataloader:
    print(batch['input_data'])
    print(batch['current_row_ids'])
    target, boolean = ko.get_validation_targets(batch['input_data'], batch['timeseries_tensor'], 
                        batch['current_time_ids'], batch['timeseries_ids'], 
                        fwd=1, bwd=0)
    print('TARGET')
    print(list(target))

tensor([[-2.1231e+00,  7.6356e-02, -4.4696e+00, -2.4899e+00, -6.4788e-01,
         -7.1412e-01, -4.3584e+00, -1.6773e+00,  1.7712e-01, -7.0211e-01,
          6.4927e-01, -5.6470e+00, -1.5826e-01, -1.4763e+00, -2.9320e+00,
         -2.2691e+00, -2.3313e+00, -2.6855e+00, -2.2469e+00,  2.2987e+00,
         -3.4548e+00,  4.1453e+00, -3.1028e+00,  2.1641e+00, -2.2151e-01,
         -3.5503e+00,  1.3813e+00, -1.3262e+00,  9.0943e-01, -1.1455e+00,
          5.3736e+00, -2.0274e+00, -5.4137e-01, -9.0868e-01, -3.0228e+00,
         -2.3110e-01,  1.3465e+00,  4.4835e+00,  2.7333e+00, -1.2855e+00,
         -4.6267e+00,  1.0634e+00,  5.8567e-01,  1.7588e+00,  3.5660e+00,
         -1.2610e+00, -1.4219e+00,  4.0306e-01, -5.2373e+00,  3.2028e-01],
        [ 2.4413e-01,  2.6456e-01, -5.9159e-01,  2.4878e-01, -1.7833e+00,
         -1.5389e-01, -2.8949e+00, -1.1749e+00,  1.0277e+00, -8.5531e-01,
          1.1297e+00, -1.3790e+00, -5.4755e-01, -1.1825e-01, -1.3361e+00,
         -2.2093e+00, -2.5100e+00, -3