In [71]:
import tempfile

import numpy as np
import tensorflow as tf
from neurocombat_sklearn import CombatModel
from sklearn.impute import SimpleImputer
from sklearn.model_selection import train_test_split
import torchvision

from afqinsight import AFQDataset
from afqinsight.nn.pt_models import mlp4_pt
from afqinsight.nn.tf_models import mlp4, cnn_vgg
import torch
import torch.nn.functional as F

In [72]:
dataset = AFQDataset.from_study("hbn")
dataset.drop_target_na()
torch_dataset = dataset.as_torch_dataset(bundles_as_channels=True, channels_last=False)

File /Users/samchou/.cache/afq-insight/hbn/subjects.tsv exists.
File /Users/samchou/.cache/afq-insight/hbn/nodes.csv exists.


In [73]:
for input, label in torch_dataset:
    print(f"Input shape: {input}")
    print(f"Label: {label}")

Input shape: tensor([[0.2482, 0.3212, 0.3459,  ..., 0.3363, 0.3030, 0.2284],
        [0.2790, 0.3873, 0.4087,  ..., 0.3557, 0.3262, 0.2327],
        [0.2286, 0.2927, 0.3187,  ..., 0.2877, 0.2613, 0.2234],
        ...,
        [0.0012, 0.0010, 0.0010,  ..., 0.0010, 0.0009, 0.0011],
        [0.0009, 0.0008, 0.0008,  ..., 0.0010, 0.0010, 0.0011],
        [0.0008, 0.0009, 0.0009,  ..., 0.0009, 0.0009, 0.0010]])
Label: tensor([21.2167,  0.0000,  3.0000])
Input shape: tensor([[0.2520, 0.2961, 0.3261,  ..., 0.3200, 0.2794, 0.2206],
        [0.2231, 0.2717, 0.3117,  ..., 0.3097, 0.2716, 0.2165],
        [0.2451, 0.2869, 0.3174,  ..., 0.2841, 0.2590, 0.2211],
        ...,
        [0.0012, 0.0010, 0.0010,  ..., 0.0009, 0.0010, 0.0011],
        [0.0010, 0.0009, 0.0009,  ..., 0.0010, 0.0010, 0.0010],
        [0.0010, 0.0010, 0.0010,  ..., 0.0010, 0.0010, 0.0010]])
Label: tensor([11.9984,  1.0000,  4.0000])
Input shape: tensor([[0.2406, 0.2843, 0.3122,  ..., 0.3145, 0.2793, 0.2249],
        [0.2277

In [74]:
train_dataset, test_dataset =torch.utils.data.random_split(torch_dataset, [int(0.8*len(torch_dataset)), len(torch_dataset)-int(0.8*len(torch_dataset))])
train_dataset, val_dataset = torch.utils.data.random_split(train_dataset, [int(0.8*len(train_dataset)), len(train_dataset)-int(0.8*len(train_dataset))])

In [75]:
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=32, shuffle=False)

In [76]:
if torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [77]:
x = next((iter(torch_dataset)))
input_shape = tuple(x[0].size())
gt_shape = tuple(x[1].size())

print("Input shape: ", input_shape)
print("Ground truth shape: ", gt_shape)

Input shape:  (48, 100)
Ground truth shape:  (3,)


In [86]:
model = mlp4_pt(48*100, 3).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)
criterion = torch.nn.CrossEntropyLoss()

In [87]:
for epoch in range(100):
    model.train()
    train_loss = 0
    num_samples = 0
    # for input_batch, gt_batch in train_loader:
    for input_batch, gt_batch in train_loader:
        input_batch, gt_batch = input_batch.to(device).float(), gt_batch.to(device).float()

        optimizer.zero_grad()
        output = model(input_batch)
        loss = criterion(output, gt_batch)
        loss.backward()
        optimizer.step()
        train_loss += loss.item() * input_batch.size(0)
        num_samples += input_batch.size(0)
    train_loss /= num_samples

    model.eval()
    val_loss = 0
    num_samples = 0
    with torch.no_grad():
        # for input_batch, gt_batch in val_loader:
        for input_batch, gt_batch in val_loader:
            input_batch, gt_batch = input_batch.to(device).float(), gt_batch.to(device).float()

            output = model(input_batch)
            loss = criterion(output, gt_batch)
            val_loss += loss.item() * input_batch.size(0)
            num_samples += input_batch.size(0)
    val_loss /= num_samples

    print(f"Epoch {epoch}: train loss {train_loss}, val loss {val_loss}")

Epoch 0: train loss 14.15719858965083, val loss 14.131691661566794
Epoch 1: train loss 13.502758859989033, val loss 13.206790276594385
Epoch 2: train loss 12.286972641745205, val loss 11.648741352119574
Epoch 3: train loss 10.813784918793083, val loss 10.451676684478453
Epoch 4: train loss 10.042918438488115, val loss 10.025756768957029
Epoch 5: train loss 9.789829163096059, val loss 9.890160554228817
Epoch 6: train loss 9.702065603617248, val loss 9.836636399744346
Epoch 7: train loss 9.664082503198978, val loss 9.810375873859112
Epoch 8: train loss 9.64491990982388, val loss 9.794820488894663
Epoch 9: train loss 9.631992768202995, val loss 9.785411129827084
Epoch 10: train loss 9.624303875257022, val loss 9.779340696175362
Epoch 11: train loss 9.61914270806752, val loss 9.775189804791607
Epoch 12: train loss 9.615553697748998, val loss 9.772172994836916
Epoch 13: train loss 9.612857590168964, val loss 9.77005718703254
Epoch 14: train loss 9.611118085819673, val loss 9.76853875252714
