In [5]:
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
import torch
import torch.nn as nn
from src.dataloaders_and_sets.simple_dataset import SimpleDataset
from tqdm import tqdm
from torch.utils.data import DataLoader
import torch.optim as optim

In [6]:
data_with_targets = pd.read_csv('data/data.csv', index_col=0)

In [10]:
data_test = data_with_targets.drop(['primary_disease', 'gender', 'age', 'dataset'], axis=1)
data_test.shape

(11780, 17137)

In [11]:
data_test.isna().sum().sum()

1392628

In [15]:
data_test.shape[0] * data_test.shape[1]

201873860

In [3]:
data = data_with_targets.fillna(0.0)

In [4]:
# remove the possible y labels:
data_columns = [col for col in data.columns if col not in ['primary_disease', 'gender', 'age', 'dataset', 'normalized_age']]
y = "gender"
# TODO: makes no sense when using a autoencoder...
X_train, y_train, X_test, y_test = train_test_split(data[data_columns], data["age"], train_size=0.8, random_state=42)


In [5]:
X_train.shape

(9424, 17137)

In [6]:
X_train.head()

Unnamed: 0,A1BG,A1CF,A2M,A2ML1,A4GALT,A4GNT,AAAS,AACS,AADAC,AADACL2,...,ZWILCH,ZWINT,ZXDA,ZXDB,ZXDC,ZYG11A,ZYG11B,ZYX,ZZEF1,ZZZ3
TCGA-AS-3778-01,5.53,7.37,15.02,0.0,10.68,5.63,9.19,9.36,1.27,0.0,...,7.12,7.95,6.85,8.83,10.42,4.74,10.18,12.26,10.55,9.61
TCGA-BQ-5877-01,5.9,4.96,13.72,0.89,10.34,1.58,10.09,10.29,0.0,0.0,...,7.33,8.02,6.51,8.66,9.39,3.41,9.9,12.93,9.74,9.97
ACH-001484,0.275007,0.0,0.014355,0.014355,1.201634,0.0,5.521365,3.718088,0.0,0.0,...,3.632268,5.467606,1.201634,2.166715,2.582556,1.257011,1.550901,5.06652,3.087463,3.521051
TCGA-NG-A4VU-01,8.07,0.0,14.12,0.0,9.55,0.6,9.8,9.27,0.0,0.0,...,9.44,9.39,5.52,9.05,10.26,5.42,10.49,11.99,9.81,9.39
TCGA-BH-A18P-01,7.1,0.45,12.7,1.5,5.97,0.45,9.32,9.93,0.0,0.0,...,9.3,9.95,6.42,9.3,10.31,4.52,9.8,11.64,10.44,10.62


In [7]:
# transform dataset for the simple autoencoder

transform_fc_ae = {
    "z_score": "per_gene",
}

In [8]:
fc_ae_dataset = SimpleDataset(X_train, transform=transform_fc_ae)
fc_ae_dataset[0]

  return torch.tensor(self.data.iloc[index], dtype=torch.float32)


tensor([-0.2390,  1.9360,  0.7182,  ...,  0.4927,  0.5141,  0.4326])

In [9]:
# config
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("{} is used".format(device))

cuda is used


In [13]:
input_dim = X_train.shape[1]
hidden_dim = 2048
z_dim = 128

num_epochs = 10
batch_size = 512
learning_rate = 2e-5
print(input_dim)

17137


In [14]:
# dataset loading
from src.models.fc import VAE

dataset = SimpleDataset(X_train, transform=transform_fc_ae)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
model = VAE(input_size=input_dim, hidden_size=hidden_dim, z_size=z_dim).to(device)

optimizer = optim.Adam(model.parameters(), lr=learning_rate)
loss_mse = nn.MSELoss()
def loss_kl(mu, sigma):
    return - torch.sum(1 + torch.log(sigma.pow(2)) - mu.pow(2) -sigma.pow(2))

In [15]:
for epoch in range(num_epochs):
    train_loss = 0.0
    val_loss = 0.0

    # Training loop
    model.train()
    for batch_idx, data in enumerate(tqdm(dataloader, desc=f"Epoch {epoch+1} Training")):
        data = data.to(device)
        optimizer.zero_grad()
        recon_batch, mu, sigma = model(data)

        # Compute training loss
        loss = loss_mse(recon_batch, data) + loss_kl(mu, sigma)
        loss.backward()
        optimizer.step()

        train_loss += loss.item()
        

    # Validation loop
    model.eval()
    with torch.no_grad():
        for batch_idx, data in enumerate(tqdm(dataloader, desc=f"Epoch {epoch+1} Validation")):
            data = data.to(device)
            pred_recon, _, _ = model(data)
            loss = loss_mse(pred_recon, data)
            val_loss += loss.item()
            

    # Calculate average losses
    avg_train_loss = train_loss / len(dataloader.dataset)
    avg_val_loss = val_loss / len(dataloader.dataset)

    # Output average losses for the epoch
    print(f"Epoch {epoch+1}:")
    print(f"\tTraining Loss: {avg_train_loss:.6f}")
    print(f"\tValidation Loss: {avg_val_loss:.6f}")


  return torch.tensor(self.data.iloc[index], dtype=torch.float32)
Epoch 1 Training: 100%|██████████| 19/19 [00:27<00:00,  1.44s/it]
Epoch 1 Validation: 100%|██████████| 19/19 [00:26<00:00,  1.39s/it]


Epoch 1:
	Training Loss: 549.567292
	Validation Loss: 0.001980


Epoch 2 Training: 100%|██████████| 19/19 [00:27<00:00,  1.46s/it]
Epoch 2 Validation: 100%|██████████| 19/19 [00:25<00:00,  1.37s/it]


Epoch 2:
	Training Loss: 424.340777
	Validation Loss: 0.001846


Epoch 3 Training: 100%|██████████| 19/19 [00:27<00:00,  1.43s/it]
Epoch 3 Validation: 100%|██████████| 19/19 [00:26<00:00,  1.37s/it]


Epoch 3:
	Training Loss: 361.023377
	Validation Loss: 0.001656


Epoch 4 Training: 100%|██████████| 19/19 [00:27<00:00,  1.46s/it]
Epoch 4 Validation: 100%|██████████| 19/19 [00:26<00:00,  1.37s/it]


Epoch 4:
	Training Loss: 313.954908
	Validation Loss: 0.001606


Epoch 5 Training: 100%|██████████| 19/19 [00:26<00:00,  1.42s/it]
Epoch 5 Validation: 100%|██████████| 19/19 [00:26<00:00,  1.37s/it]


Epoch 5:
	Training Loss: 274.378563
	Validation Loss: 0.001631


Epoch 6 Training: 100%|██████████| 19/19 [00:27<00:00,  1.43s/it]
Epoch 6 Validation: 100%|██████████| 19/19 [00:25<00:00,  1.34s/it]


Epoch 6:
	Training Loss: 246.727424
	Validation Loss: 0.001654


Epoch 7 Training: 100%|██████████| 19/19 [00:26<00:00,  1.40s/it]
Epoch 7 Validation: 100%|██████████| 19/19 [00:25<00:00,  1.35s/it]


Epoch 7:
	Training Loss: 218.943676
	Validation Loss: 0.001680


Epoch 8 Training: 100%|██████████| 19/19 [00:27<00:00,  1.44s/it]
Epoch 8 Validation: 100%|██████████| 19/19 [00:26<00:00,  1.39s/it]


Epoch 8:
	Training Loss: 195.134870
	Validation Loss: 0.001689


Epoch 9 Training: 100%|██████████| 19/19 [00:27<00:00,  1.43s/it]
Epoch 9 Validation: 100%|██████████| 19/19 [00:25<00:00,  1.36s/it]


Epoch 9:
	Training Loss: 176.808202
	Validation Loss: 0.001692


Epoch 10 Training: 100%|██████████| 19/19 [00:27<00:00,  1.46s/it]
Epoch 10 Validation: 100%|██████████| 19/19 [00:26<00:00,  1.38s/it]

Epoch 10:
	Training Loss: 160.959551
	Validation Loss: 0.001684



