In [21]:
import sys
import torch
from os.path import join
import numpy as np
from sklearn.ensemble import RandomForestClassifier
import sklearn.metrics
import numpy as np

sys.path.append('/home/agajan/DeepMRI')
from deepmri import Datasets, utils  # noqa: E402
from DiffusionMRI.Conv2dAE import ConvEncoder as Encoder
from DiffusionMRI.Conv2dAE import ConvDecoder as Decoder  # noqa: E402

In [40]:
experiment_dir = '/home/agajan/experiment_DiffusionMRI/'
data_path = experiment_dir + 'tractseg_data/784565/'
model_name = 'Conv2dAECoronal'
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")  # device

In [41]:
trainset = Datasets.OrientationDatasetChannelNorm(data_path + 'training_slices/coronal/', normalize=True, bg_zero=True)
total_examples = len(trainset)

batch_size = 1
trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=False, num_workers=6)
print("Total training examples: {}, Batch size: {}, Iters per epoch: {}".format(total_examples,
                                                                                batch_size,
                                                                                total_examples / batch_size))


Total training examples: 142, Batch size: 1, Iters per epoch: 142.0


In [42]:
criterion = torch.nn.MSELoss()

encoder = Encoder()
decoder = Decoder()
encoder.to(device)
decoder.to(device)

start_epoch = 50
encoder_path = "{}/models/{}_encoder_epoch_{}".format(experiment_dir, model_name, start_epoch)
decoder_path = "{}/models/{}_decoder_epoch_{}".format(experiment_dir, model_name, start_epoch)
encoder.load_state_dict(torch.load(encoder_path))
decoder.load_state_dict(torch.load(decoder_path))
print("Loaded pretrained weights starting from epoch {}".format(start_epoch))

Loaded pretrained weights starting from epoch 50


In [43]:
feature_name = 'coronal_features_epoch_50.npz'
features_path = join(data_path, 'learned_features', feature_name)
learned_features = np.load(features_path)['data']

In [47]:
idx = np.random.randint(len(trainset))
data = trainset[idx]
crd_0 = 72
crd_1 = 87
crd_2 = 72

In [48]:
with torch.no_grad():
    encoder.eval()
    decoder.eval()
    x = data['data'].unsqueeze(0).to(device)
    h = encoder(x)
    y = decoder(h)
    print('Loss: ', criterion(x, y).item())
    h = h.detach().cpu().squeeze().numpy()
    x = x.detach().cpu().squeeze().numpy()
    y = y.detach().cpu().squeeze().numpy()

Loss:  0.03805779293179512


In [49]:
print('x: ', x[3:6])
print('y: ', y[3:6])
print(learned_features[crd_0, crd_1, crd_2][:3])
print(h[:3])


x:  [[[0. 0. 0. ... 0. 0. 0.]
  [0. 0. 0. ... 0. 0. 0.]
  [0. 0. 0. ... 0. 0. 0.]
  ...
  [0. 0. 0. ... 0. 0. 0.]
  [0. 0. 0. ... 0. 0. 0.]
  [0. 0. 0. ... 0. 0. 0.]]

 [[0. 0. 0. ... 0. 0. 0.]
  [0. 0. 0. ... 0. 0. 0.]
  [0. 0. 0. ... 0. 0. 0.]
  ...
  [0. 0. 0. ... 0. 0. 0.]
  [0. 0. 0. ... 0. 0. 0.]
  [0. 0. 0. ... 0. 0. 0.]]

 [[0. 0. 0. ... 0. 0. 0.]
  [0. 0. 0. ... 0. 0. 0.]
  [0. 0. 0. ... 0. 0. 0.]
  ...
  [0. 0. 0. ... 0. 0. 0.]
  [0. 0. 0. ... 0. 0. 0.]
  [0. 0. 0. ... 0. 0. 0.]]]
y:  [[[ 0.05632295  0.05632295  0.05632295 ...  0.05632295  0.05632295
    0.05632295]
  [ 0.05632295  0.05632295  0.05632295 ...  0.05632295  0.05632295
    0.05632295]
  [ 0.05632295  0.05632295  0.05632295 ...  0.05632295  0.05632295
    0.05632295]
  ...
  [ 0.05632295  0.05632295  0.05632295 ...  0.05632295  0.05632295
    0.05632295]
  [ 0.05632295  0.05632295  0.05632295 ...  0.05632295  0.05632295
    0.05632295]
  [ 0.05632295  0.05632295  0.05632295 ...  0.05632295  0.05632295
    0.056322

In [55]:
h[0, :, :].tolist()

[[-0.092465840280056,
  -0.092465840280056,
  -0.092465840280056,
  -0.092465840280056,
  -0.092465840280056,
  -0.092465840280056,
  -0.092465840280056,
  -0.092465840280056,
  -0.092465840280056,
  -0.092465840280056,
  -0.092465840280056,
  -0.092465840280056,
  -0.092465840280056,
  -0.092465840280056,
  -0.092465840280056,
  -0.092465840280056,
  -0.092465840280056,
  -0.092465840280056,
  -0.092465840280056,
  -0.092465840280056,
  -0.092465840280056,
  -0.092465840280056,
  -0.092465840280056,
  -0.092465840280056,
  -0.092465840280056,
  -0.092465840280056,
  -0.092465840280056,
  -0.092465840280056,
  -0.092465840280056,
  -0.092465840280056,
  -0.092465840280056,
  -0.092465840280056,
  -0.092465840280056,
  -0.092465840280056,
  -0.092465840280056,
  -0.092465840280056,
  -0.092465840280056,
  -0.092465840280056,
  -0.092465840280056,
  -0.092465840280056,
  -0.092465840280056,
  -0.092465840280056,
  -0.092465840280056,
  -0.092465840280056,
  -0.092465840280056,
  -0.09246