In [None]:
# --- Configuration -----------------------------------------------------------
# Raw BraTS 2021 validation set; each entry is a directory of image volumes
# that are read with nibabel.
ROOT_DIR = "/data/rohitrango/BRATS2021/val/"
# Holds the trained decoders (decoder{i}.pth, i = 0..3) and the training-set
# encoders (encoder_*.pth).
ENCODER_DIR = "/data/rohitrango/Implicit3DCNNTasks/brats2021_unimodal"
# Directory where the validation-set encoders live.
OUT_DIR = "/data/rohitrango/Implicit3DCNNTasks/brats2021_unimodal_val/"

# NOTE(review): NUM_PTS and EPOCHS are not referenced by the cells in this
# notebook, which hard-code their own batch size / iteration count — confirm intent.
NUM_PTS = 200_000
EPOCHS = 2_500

In [None]:
import torch
from torch import nn
from queue import Queue
from gridencoder import GridEncoder
from dataloaders import BRATS2021Dataset
from glob import glob
import nibabel as nib
from multiprocessing import Process
from utils import uniform_normalize
from time import sleep
from os import path as osp
from tqdm import tqdm
import argparse
import gc
from utils import init_network
from configs.config import get_cfg_defaults
import numpy as np
from torch.nn import functional as F
from matplotlib import pyplot as plt

# Load a random sample of train and val encoders and compare their embedding distributions

In [None]:
# Collect the saved encoder checkpoints for each split. Each list is sorted
# first and then shuffled with its own fixed seed, so the random subsets taken
# in the next cells are reproducible.
train_encoded = sorted(glob(osp.join(ENCODER_DIR, "encoder_*.pth")))
rng = np.random.RandomState(55)
rng.shuffle(train_encoded)

val_encoded = sorted(glob(osp.join(OUT_DIR, "*.pth")))
rng = np.random.RandomState(12321)
rng.shuffle(val_encoded)

# Fail early with a readable message instead of a cryptic torch.stack error later.
assert train_encoded, "no training encoders found in {}".format(ENCODER_DIR)
assert val_encoded, "no validation encoders found in {}".format(OUT_DIR)

In [None]:
# val_encoded[0], train_encoded[0]

In [None]:
# Number of checkpoints to sample from each split.
N = 200


def load_embeddings(path):
    """Load the embedding tensor stored in an encoder checkpoint.

    Training checkpoints are dicts with an ``'embeddings'`` entry. The
    validation files were previously used as loaded, without indexing, which
    suggests they hold the tensor directly. Either layout is accepted here, so
    both splits go through the same code path.

    Args:
        path: path to a ``.pth`` file written with ``torch.save``.

    Returns:
        The stored embeddings, mapped to CPU.
    """
    checkpoint = torch.load(path, map_location='cpu')
    # dict also covers OrderedDict state_dicts.
    if isinstance(checkpoint, dict):
        return checkpoint['embeddings']
    return checkpoint


train_data = [load_embeddings(f) for f in train_encoded[:N]]
val_data = [load_embeddings(f) for f in val_encoded[:N]]

In [None]:
# Stack each split's list of embedding tensors along a new leading axis (dim 0).
trainall, valall = map(torch.stack, (train_data, val_data))

In [None]:
# print(trainall.abs().mean(1))
# print(valall.abs().mean(1))

In [None]:
# Compare train vs. val: mean |embedding| over axis 1 of each stacked tensor,
# flattened so every remaining entry is one sample in the histogram.
train_p = trainall.abs().mean(1).reshape(-1).detach().cpu().numpy()
val_p = valall.abs().mean(1).reshape(-1).detach().cpu().numpy()

# Shared bin edges make the two histograms directly comparable. density=True
# keeps them on the same scale even if the splits hold different numbers of
# checkpoints.
bins = np.histogram_bin_edges(np.concatenate([train_p, val_p]), bins=20)
fig, ax = plt.subplots(1, 1, figsize=(10, 10))
ax.hist(train_p, bins=bins, density=True, label="train")
ax.hist(val_p, bins=bins, density=True, alpha=0.4, label="val")
ax.set(title="Mean |embedding| per checkpoint", xlabel="mean |embedding|", ylabel="density")
ax.legend();

# Trial run: fit a fresh encoder to a single validation subject

In [None]:
# Entries of the raw validation directory, in a stable (sorted) order.
dirqueue = sorted(glob(osp.join(ROOT_DIR, "*")))
device = torch.device("cuda")

# One frozen decoder per saved decoder{i}.pth, i = 0..3. The architecture must
# match the checkpoints: 32 input features -> 256 hidden -> 1 output.
decoders = [nn.Sequential(
        nn.Linear(32, 256),
        nn.LeakyReLU(),
        nn.Linear(256, 1)
    ).to(device) for _ in range(4)]
for i, decoder in enumerate(decoders):
    # map_location keeps loading working if the checkpoint was saved on a different device.
    state = torch.load(osp.join(ENCODER_DIR, f"decoder{i}.pth"), map_location=device)
    decoder.load_state_dict(state, strict=True)
    decoder.eval()
    decoder.requires_grad_(False)

In [None]:
# Load every file of the first directory in `dirqueue` onto `device` as float32.
for subject_dir in dirqueue[:1]:
    volume_paths = sorted(glob(subject_dir + "/*"))
    print("Processing {:s}".format(subject_dir))
    images = list(map(
        lambda volume_path: torch.from_numpy(nib.load(volume_path).get_fdata()).float().to(device),
        volume_paths,
    ))

In [None]:
# Normalise every loaded volume, then create four fresh grid encoders, each
# paired with its own Adam optimiser (lr = 1e-2).
images = list(map(uniform_normalize, images))
encoders = [
    GridEncoder(level_dim=2, desired_resolution=196).to(device)
    for _ in range(4)
]
optims = [torch.optim.Adam(grid_encoder.parameters(), lr=1e-2) for grid_encoder in encoders]

In [None]:
# Dataset over the validation root directory.
# NOTE(review): presumably sample='full' yields every voxel (making num_points
# irrelevant) and mlabel=0 selects the modality that decoders[0] was trained
# on — confirm against BRATS2021Dataset.
ds = BRATS2021Dataset(
    root_dir=ROOT_DIR,
    augment=False,
    num_points=1,
    multimodal=False,
    mlabel=0,
    sample='full',
)

In [None]:
# Fit encoders[0] so that decoders[0](encoders[0](xyz)) regresses the sample's
# `imgpoints`, using random minibatches of point indices.
NUM_ITERS = 1000
BATCH_SIZE = 100000

for idx in [0]:
    datum = ds[idx]
    xyzfloat = datum['xyz'] / (datum['dims'][None] - 1) * 2 - 1  # ranges from -1 to 1  (allpoints, 3)
    xyzfloat = xyzfloat.float().to(device)
    image = datum['imgpoints'].to(device)
    total_points = image.shape[0]
    subj = datum['subj']
    print(datum['dims'])

    pbar = tqdm(range(NUM_ITERS))
    for i in pbar:
        optims[0].zero_grad()
        # Draw indices directly on the device: no host->device copy per step.
        minibatch = torch.randint(total_points, (BATCH_SIZE,), device=device)
        xyzminibatch = xyzfloat[minibatch]
        imageminibatch = image[minibatch]
        pred_minibatch = decoders[0](encoders[0](xyzminibatch))
        # mse_loss would silently broadcast e.g. (B, 1) against (B,) into (B, B).
        assert pred_minibatch.shape == imageminibatch.shape, (pred_minibatch.shape, imageminibatch.shape)
        loss = F.mse_loss(pred_minibatch, imageminibatch)
        loss.backward()
        pbar.set_description("subj: {} iter: {}/{}, Loss: {:06f}".format(subj, i, NUM_ITERS, loss.item()))
        optims[0].step()

In [None]:
# Extremes along dim 0 of the fitted encoder-0 embeddings.
fitted_embeddings = encoders[0].embeddings
fitted_embeddings.min(dim=0).values, fitted_embeddings.max(dim=0).values

In [None]:
# Mean |value| along dim 0 of the fitted encoder-0 embeddings.
torch.mean(torch.abs(encoders[0].embeddings), dim=0)

In [None]:
# File names in the training encoder directory. This is plain Python rather
# than a bare `ls`, which only works when IPython automagic is enabled and no
# variable named `ls` shadows it.
[osp.basename(p) for p in sorted(glob(osp.join(ENCODER_DIR, "*")))]

In [None]:
# Embeddings of one specific training encoder, loaded onto the CPU for a quick
# sanity check.
data = torch.load(osp.join(ENCODER_DIR, "encoder_BraTS2021_00014.pth"), map_location='cpu')['embeddings']

In [None]:
# Mean |value| along dim 0 of the loaded embeddings.
torch.mean(torch.abs(data), dim=0)