In [29]:
%load_ext autoreload
%autoreload 2

import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
import torch
import torch.nn as nn
from src.dataloaders_and_sets.simple_dataset import SimpleDataset
from tqdm import tqdm
from torch.utils.data import DataLoader
import torch.optim as optim
from torcheval.metrics.functional import r2_score
from torch.optim.lr_scheduler import ExponentialLR
import tensorboard
from torch.utils.tensorboard import SummaryWriter


The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


# Loading Data and Preprocessing

In [2]:
data_with_targets = pd.read_csv('data/data.csv', index_col=0)

In [3]:
data = data_with_targets.fillna(0.0)

In [4]:
# remove the possible y labels:
y_labels = ['primary_disease', 'gender', 'age', 'dataset']
data_columns = [col for col in data.columns if col not in y_labels]
y = "gender"
# TODO: makes no sense when using a autoencoder...
X_train, X_test, y_train, y_test = train_test_split(data[data_columns], data[y_labels], train_size=0.8, random_state=42)
X_train, X_val, y_train, y_val = train_test_split(X_train, y_train, train_size=0.8, random_state=42)
print(X_train.shape, y_train.shape, X_val.shape, y_val.shape)

(7539, 17137) (7539, 4) (1885, 17137) (1885, 4)


In [5]:
# transform dataset for the simple autoencoder

transform_fc_ae = {
    "z_score": "per_gene",
    "most_variant": 5000,
}

In [6]:
fc_ae_dataset = SimpleDataset(X_train, transform=transform_fc_ae)
fc_ae_dataset[0]

  return torch.tensor(self.data.iloc[index], dtype=torch.float32)


tensor([-0.0508, -0.0571, -0.0392,  ...,  0.3304, -0.1530,  0.4498])

In [7]:
# config
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("{} is used".format(device))

cuda is used


# Setup Tensorboard

In [31]:
writer = SummaryWriter("logs/autoencoder")


In [24]:
input_dim = len(fc_ae_dataset[0])
hidden_one_dim = 512
hidden_two_dim = 128
z_dim = 64

num_epochs = 50
batch_size = 512
learning_rate = 3e-4
beta = 0.0
print(input_dim)

5000


In [25]:
# dataset loading
from src.models.fc import VAE

train_dataset = SimpleDataset(X_train, transform=transform_fc_ae)
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_dataset = SimpleDataset(X_val, transform=transform_fc_ae)
val_dataloader = DataLoader(val_dataset, batch_size=batch_size, shuffle=True)
model = VAE(input_size=input_dim, hidden_one_size=hidden_one_dim, hidden_two_size=hidden_two_dim, z_size=z_dim).to(device)

optimizer = optim.Adam(model.parameters(), lr=learning_rate)
scheduler = ExponentialLR(optimizer, gamma=0.9)

loss_mse = nn.MSELoss()
def loss_kl(mu, sigma):
    return - torch.sum(1 + torch.log(sigma.pow(2)) - mu.pow(2) -sigma.pow(2)) / len(mu)

In [36]:
writer.add_graph(model.cpu(), train_dataset[0])
writer.close()

	%eps : Float(64, strides=[1], requires_grad=0, device=cpu) = aten::randn_like(%98, %45, %46, %47, %48, %49) # /home/fes/Nextcloud/Uni/B.Sc. Bioinfo/Bachelorarbeit/thesis/src/models/fc.py:37:0
This may cause errors in trace checking. To disable trace checking, pass check_trace=False to torch.jit.trace()
  _check_trace(
Tensor-likes are not close!

Mismatched elements: 4994 / 5000 (99.9%)
Greatest absolute difference: 0.04024442099034786 at index (140,) (up to 1e-05 allowed)
Greatest relative difference: 3022.1378215654076 at index (3565,) (up to 1e-05 allowed)
  _check_trace(


In [26]:
for epoch in range(num_epochs):
    train_loss = 0.0
    val_loss = 0.0
    number_train_batches = 0
    number_val_batches = 0
    
    # Training loop
    model.train()
    for batch_idx, batch in enumerate(tqdm(train_dataloader, desc=f"Epoch {epoch+1} Training")):
        batch = batch.to(device)
        optimizer.zero_grad()
        recon_batch, mu, sigma = model(batch)

        # Compute training loss
        single_loss_mse = loss_mse(recon_batch, batch)
        single_loss_kl = loss_kl(mu, sigma)
        # print("MSE: {}, KL: {}".format(single_loss_mse, single_loss_kl))
        loss = single_loss_mse + beta * single_loss_kl
        loss.backward()
        optimizer.step()
    scheduler.step()

        train_loss += loss.item()
        
        number_train_batches += 1
        

    # Validation loop
    model.eval()
    with torch.no_grad():
        for batch_idx, batch in enumerate(tqdm(val_dataloader, desc=f"Epoch {epoch+1} Validation")):
            batch = batch.to(device)
            reconstructed, _, _ = model(batch)
            loss = r2_score(reconstructed, batch)
            #print(loss.item())
            #print(reconstructed - batch)
            val_loss += loss.item()
            

    # Calculate average losses
    avg_train_loss = train_loss / len(train_dataloader.dataset)
    avg_val_loss = val_loss / len(val_dataloader.dataset)

    # Output average losses for the epoch
    print(f"Epoch {epoch+1}:")
    print(f"\tTraining Loss: {avg_train_loss:.6f}")
    print(f"\tValidation Loss: {avg_val_loss:.6f}")


Epoch 1 Training:   0%|          | 0/7539 [00:00<?, ?it/s]
Epoch 1 Validation:   0%|          | 0/1885 [00:00<?, ?it/s]


Epoch 1:
	Training Loss: 0.000056
	Validation Loss: 0.000000


Epoch 2 Training:   0%|          | 0/7539 [00:00<?, ?it/s]
Epoch 2 Validation:   0%|          | 0/1885 [00:00<?, ?it/s]


Epoch 2:
	Training Loss: 0.000053
	Validation Loss: 0.000000


Epoch 3 Training:   0%|          | 0/7539 [00:00<?, ?it/s]
Epoch 3 Validation:   0%|          | 0/1885 [00:00<?, ?it/s]


Epoch 3:
	Training Loss: 0.000079
	Validation Loss: 0.000000


Epoch 4 Training:   0%|          | 0/7539 [00:00<?, ?it/s]
Epoch 4 Validation:   0%|          | 0/1885 [00:00<?, ?it/s]


Epoch 4:
	Training Loss: 0.000087
	Validation Loss: 0.000000


Epoch 5 Training:   0%|          | 0/7539 [00:00<?, ?it/s]
Epoch 5 Validation:   0%|          | 0/1885 [00:00<?, ?it/s]


Epoch 5:
	Training Loss: 0.000072
	Validation Loss: 0.000000


Epoch 6 Training:   0%|          | 0/7539 [00:00<?, ?it/s]
Epoch 6 Validation:   0%|          | 0/1885 [00:00<?, ?it/s]


Epoch 6:
	Training Loss: 0.000073
	Validation Loss: 0.000000


Epoch 7 Training:   0%|          | 0/7539 [00:00<?, ?it/s]
Epoch 7 Validation:   0%|          | 0/1885 [00:00<?, ?it/s]


Epoch 7:
	Training Loss: 0.000074
	Validation Loss: 0.000000


Epoch 8 Training:   0%|          | 0/7539 [00:00<?, ?it/s]
Epoch 8 Validation:   0%|          | 0/1885 [00:00<?, ?it/s]


Epoch 8:
	Training Loss: 0.000059
	Validation Loss: 0.000000


Epoch 9 Training:   0%|          | 0/7539 [00:00<?, ?it/s]
Epoch 9 Validation:   0%|          | 0/1885 [00:00<?, ?it/s]


Epoch 9:
	Training Loss: 0.000406
	Validation Loss: 0.000000


Epoch 10 Training:   0%|          | 0/7539 [00:00<?, ?it/s]
Epoch 10 Validation:   0%|          | 0/1885 [00:00<?, ?it/s]


Epoch 10:
	Training Loss: 0.000133
	Validation Loss: 0.000000


Epoch 11 Training:   0%|          | 0/7539 [00:00<?, ?it/s]
Epoch 11 Validation:   0%|          | 0/1885 [00:00<?, ?it/s]


Epoch 11:
	Training Loss: 0.000421
	Validation Loss: 0.000000


Epoch 12 Training:   0%|          | 0/7539 [00:00<?, ?it/s]
Epoch 12 Validation:   0%|          | 0/1885 [00:00<?, ?it/s]


Epoch 12:
	Training Loss: 0.000460
	Validation Loss: 0.000000


Epoch 13 Training:   0%|          | 0/7539 [00:00<?, ?it/s]
Epoch 13 Validation:   0%|          | 0/1885 [00:00<?, ?it/s]


Epoch 13:
	Training Loss: 0.000094
	Validation Loss: 0.000000


Epoch 14 Training:   0%|          | 0/7539 [00:00<?, ?it/s]
Epoch 14 Validation:   0%|          | 0/1885 [00:00<?, ?it/s]


Epoch 14:
	Training Loss: 0.000129
	Validation Loss: 0.000000


Epoch 15 Training:   0%|          | 0/7539 [00:00<?, ?it/s]
Epoch 15 Validation:   0%|          | 0/1885 [00:00<?, ?it/s]


Epoch 15:
	Training Loss: 0.000105
	Validation Loss: 0.000000


Epoch 16 Training:   0%|          | 0/7539 [00:00<?, ?it/s]
Epoch 16 Validation:   0%|          | 0/1885 [00:00<?, ?it/s]


Epoch 16:
	Training Loss: 0.000115
	Validation Loss: 0.000000


Epoch 17 Training:   0%|          | 0/7539 [00:00<?, ?it/s]
Epoch 17 Validation:   0%|          | 0/1885 [00:00<?, ?it/s]


Epoch 17:
	Training Loss: 0.000147
	Validation Loss: 0.000000


Epoch 18 Training:   0%|          | 0/7539 [00:00<?, ?it/s]
Epoch 18 Validation:   0%|          | 0/1885 [00:00<?, ?it/s]


Epoch 18:
	Training Loss: 0.000138
	Validation Loss: 0.000000


Epoch 19 Training:   0%|          | 0/7539 [00:00<?, ?it/s]
Epoch 19 Validation:   0%|          | 0/1885 [00:00<?, ?it/s]


Epoch 19:
	Training Loss: 0.000064
	Validation Loss: 0.000000


Epoch 20 Training:   0%|          | 0/7539 [00:00<?, ?it/s]
Epoch 20 Validation:   0%|          | 0/1885 [00:00<?, ?it/s]


Epoch 20:
	Training Loss: 0.000112
	Validation Loss: 0.000000


Epoch 21 Training:   0%|          | 0/7539 [00:00<?, ?it/s]
Epoch 21 Validation:   0%|          | 0/1885 [00:00<?, ?it/s]


Epoch 21:
	Training Loss: 0.000118
	Validation Loss: 0.000000


Epoch 22 Training:   0%|          | 0/7539 [00:00<?, ?it/s]
Epoch 22 Validation:   0%|          | 0/1885 [00:00<?, ?it/s]


Epoch 22:
	Training Loss: 0.000075
	Validation Loss: 0.000000


Epoch 23 Training:   0%|          | 0/7539 [00:00<?, ?it/s]
Epoch 23 Validation:   0%|          | 0/1885 [00:00<?, ?it/s]


Epoch 23:
	Training Loss: 0.000093
	Validation Loss: 0.000000


Epoch 24 Training:   0%|          | 0/7539 [00:00<?, ?it/s]
Epoch 24 Validation:   0%|          | 0/1885 [00:00<?, ?it/s]


Epoch 24:
	Training Loss: 0.000066
	Validation Loss: 0.000000


Epoch 25 Training:   0%|          | 0/7539 [00:00<?, ?it/s]
Epoch 25 Validation:   0%|          | 0/1885 [00:00<?, ?it/s]


Epoch 25:
	Training Loss: 0.000168
	Validation Loss: 0.000000


Epoch 26 Training:   0%|          | 0/7539 [00:00<?, ?it/s]
Epoch 26 Validation:   0%|          | 0/1885 [00:00<?, ?it/s]


Epoch 26:
	Training Loss: 0.000066
	Validation Loss: 0.000000


Epoch 27 Training:   0%|          | 0/7539 [00:00<?, ?it/s]
Epoch 27 Validation:   0%|          | 0/1885 [00:00<?, ?it/s]


Epoch 27:
	Training Loss: 0.000074
	Validation Loss: 0.000000


Epoch 28 Training:   0%|          | 0/7539 [00:00<?, ?it/s]
Epoch 28 Validation:   0%|          | 0/1885 [00:00<?, ?it/s]


Epoch 28:
	Training Loss: 0.000069
	Validation Loss: 0.000000


Epoch 29 Training:   0%|          | 0/7539 [00:00<?, ?it/s]
Epoch 29 Validation:   0%|          | 0/1885 [00:00<?, ?it/s]


Epoch 29:
	Training Loss: 0.000103
	Validation Loss: 0.000000


Epoch 30 Training:   0%|          | 0/7539 [00:00<?, ?it/s]
Epoch 30 Validation:   0%|          | 0/1885 [00:00<?, ?it/s]


Epoch 30:
	Training Loss: 0.000066
	Validation Loss: 0.000000


Epoch 31 Training:   0%|          | 0/7539 [00:00<?, ?it/s]
Epoch 31 Validation:   0%|          | 0/1885 [00:00<?, ?it/s]


Epoch 31:
	Training Loss: 0.000081
	Validation Loss: 0.000000


Epoch 32 Training:   0%|          | 0/7539 [00:00<?, ?it/s]
Epoch 32 Validation:   0%|          | 0/1885 [00:00<?, ?it/s]


Epoch 32:
	Training Loss: 0.000088
	Validation Loss: 0.000000


Epoch 33 Training:   0%|          | 0/7539 [00:00<?, ?it/s]
Epoch 33 Validation:   0%|          | 0/1885 [00:00<?, ?it/s]


Epoch 33:
	Training Loss: 0.000169
	Validation Loss: 0.000000


Epoch 34 Training:   0%|          | 0/7539 [00:00<?, ?it/s]
Epoch 34 Validation:   0%|          | 0/1885 [00:00<?, ?it/s]


Epoch 34:
	Training Loss: 0.000081
	Validation Loss: 0.000000


Epoch 35 Training:   0%|          | 0/7539 [00:00<?, ?it/s]
Epoch 35 Validation:   0%|          | 0/1885 [00:00<?, ?it/s]


Epoch 35:
	Training Loss: 0.000062
	Validation Loss: 0.000000


Epoch 36 Training:   0%|          | 0/7539 [00:00<?, ?it/s]
Epoch 36 Validation:   0%|          | 0/1885 [00:00<?, ?it/s]


Epoch 36:
	Training Loss: 0.000081
	Validation Loss: 0.000000


Epoch 37 Training:   0%|          | 0/7539 [00:00<?, ?it/s]
Epoch 37 Validation:   0%|          | 0/1885 [00:00<?, ?it/s]


Epoch 37:
	Training Loss: 0.000207
	Validation Loss: 0.000000


Epoch 38 Training:   0%|          | 0/7539 [00:00<?, ?it/s]
Epoch 38 Validation:   0%|          | 0/1885 [00:00<?, ?it/s]


Epoch 38:
	Training Loss: 0.000067
	Validation Loss: 0.000000


Epoch 39 Training:   0%|          | 0/7539 [00:00<?, ?it/s]
Epoch 39 Validation:   0%|          | 0/1885 [00:00<?, ?it/s]


Epoch 39:
	Training Loss: 0.000126
	Validation Loss: 0.000000


Epoch 40 Training:   0%|          | 0/7539 [00:00<?, ?it/s]
Epoch 40 Validation:   0%|          | 0/1885 [00:00<?, ?it/s]


Epoch 40:
	Training Loss: 0.000058
	Validation Loss: 0.000000


Epoch 41 Training:   0%|          | 0/7539 [00:00<?, ?it/s]
Epoch 41 Validation:   0%|          | 0/1885 [00:00<?, ?it/s]


Epoch 41:
	Training Loss: 0.000067
	Validation Loss: 0.000000


Epoch 42 Training:   0%|          | 0/7539 [00:00<?, ?it/s]
Epoch 42 Validation:   0%|          | 0/1885 [00:00<?, ?it/s]


Epoch 42:
	Training Loss: 0.000143
	Validation Loss: 0.000000


Epoch 43 Training:   0%|          | 0/7539 [00:00<?, ?it/s]
Epoch 43 Validation:   0%|          | 0/1885 [00:00<?, ?it/s]


Epoch 43:
	Training Loss: 0.000069
	Validation Loss: 0.000000


Epoch 44 Training:   0%|          | 0/7539 [00:00<?, ?it/s]
Epoch 44 Validation:   0%|          | 0/1885 [00:00<?, ?it/s]


Epoch 44:
	Training Loss: 0.000051
	Validation Loss: 0.000000


Epoch 45 Training:   0%|          | 0/7539 [00:00<?, ?it/s]
Epoch 45 Validation:   0%|          | 0/1885 [00:00<?, ?it/s]


Epoch 45:
	Training Loss: 0.000065
	Validation Loss: 0.000000


Epoch 46 Training:   0%|          | 0/7539 [00:00<?, ?it/s]
Epoch 46 Validation:   0%|          | 0/1885 [00:00<?, ?it/s]


Epoch 46:
	Training Loss: 0.000056
	Validation Loss: 0.000000


Epoch 47 Training:   0%|          | 0/7539 [00:00<?, ?it/s]
Epoch 47 Validation:   0%|          | 0/1885 [00:00<?, ?it/s]


Epoch 47:
	Training Loss: 0.000089
	Validation Loss: 0.000000


Epoch 48 Training:   0%|          | 0/7539 [00:00<?, ?it/s]
Epoch 48 Validation:   0%|          | 0/1885 [00:00<?, ?it/s]


Epoch 48:
	Training Loss: 0.000119
	Validation Loss: 0.000000


Epoch 49 Training:   0%|          | 0/7539 [00:00<?, ?it/s]
Epoch 49 Validation:   0%|          | 0/1885 [00:00<?, ?it/s]


Epoch 49:
	Training Loss: 0.000068
	Validation Loss: 0.000000


Epoch 50 Training:   0%|          | 0/7539 [00:00<?, ?it/s]
Epoch 50 Validation:   0%|          | 0/1885 [00:00<?, ?it/s]

Epoch 50:
	Training Loss: 0.000112
	Validation Loss: 0.000000



