In [2]:
# --- Paths -------------------------------------------------------------
# Validation-split subject directories (globbed below as ROOT_DIR + "/*").
ROOT_DIR = "/data/rohitrango/BRATS2021/val/"
# Directory holding saved decoder{i}.pth / encoder_*.pth checkpoints.
ENCODER_DIR = "/data/rohitrango/Implicit3DCNNTasks/brats2021_unimodal"
OUT_DIR = "/data/rohitrango/Implicit3DCNNTasks/brats2021_unimodal_val/"

# --- Fitting knobs -----------------------------------------------------
# NOTE(review): NUM_PTS / EPOCHS / OUT_DIR are not read by any cell visible here.
NUM_PTS = 200_000
EPOCHS = 2_500

In [30]:
import torch
from torch import nn
from queue import Queue
from gridencoder import GridEncoder
from dataloaders import BRATS2021Dataset
from glob import glob
import nibabel as nib
from multiprocessing import Process
from utils import uniform_normalize
from time import sleep
from os import path as osp
from tqdm import tqdm
import argparse
import gc
from utils import init_network
from configs.config import get_cfg_defaults
import numpy as np
from torch.nn import functional as F

In [13]:
# Subject directories of the validation split, in deterministic (sorted) order.
dirqueue = sorted(glob(ROOT_DIR + "/*"))
device = torch.device("cuda")

# Number of pretrained decoders stored as decoder{i}.pth in ENCODER_DIR
# (presumably one per MRI modality — confirm against the training script).
NUM_DECODERS = 4

decoders = []
for i in range(NUM_DECODERS):
    # Same architecture as at training time: 32-d grid feature -> 256 hidden -> 1 intensity.
    decoder = nn.Sequential(
        nn.Linear(32, 256),
        nn.LeakyReLU(),
        nn.Linear(256, 1),
    ).to(device)
    # map_location avoids failures when the checkpoint was saved on a different device
    # (e.g. another GPU index that does not exist on this machine).
    state = torch.load(osp.join(ENCODER_DIR, f"decoder{i}.pth"), map_location=device)
    decoder.load_state_dict(state, strict=True)
    # Decoders are frozen: only the per-subject encoders are optimized downstream.
    decoder.eval()
    decoder.requires_grad_(False)
    decoders.append(decoder)

In [14]:
# Load every file of the first subject directory only, as float32 tensors on `device`.
for q in dirqueue[:1]:
    print(f"Processing {q}")
    files = sorted(glob(q + "/*"))
    volumes = (nib.load(path).get_fdata() for path in files)
    images = [torch.from_numpy(vol).float().to(device) for vol in volumes]

Processing /data/rohitrango/BRATS2021/val/BraTS2021_00001


In [17]:
# Rescale each loaded volume with the project's uniform_normalize helper.
# NOTE(review): this rebinds `images` to its normalized form, so re-running the cell
# normalizes again — harmless only if uniform_normalize is idempotent; confirm.
images = list(map(uniform_normalize, images))

# One fresh GridEncoder (+ its own Adam optimizer) for each of the 4 decoder slots.
encoders, optims = [], []
for _ in range(4):
    enc = GridEncoder(level_dim=2, desired_resolution=196).to(device)
    encoders.append(enc)
    optims.append(torch.optim.Adam(enc.parameters(), lr=1e-2))

In [20]:
# Unimodal (mlabel=0), non-augmented dataset over ROOT_DIR using the 'full' sampling mode.
ds = BRATS2021Dataset(
    root_dir=ROOT_DIR,
    augment=False,
    num_points=1,
    multimodal=False,
    mlabel=0,
    sample='full',
)

In [31]:
# Fit encoder 0 to a dataset item through the frozen decoder 0 using random point minibatches.
FIT_ITERS = 1000      # optimization steps per subject
BATCH_SIZE = 100000   # points sampled (with replacement) per step

for idx in [0]:
    datum = ds[idx]
    # Map integer voxel coordinates into [-1, 1] along each axis  (allpoints, 3).
    xyzfloat = datum['xyz'] / (datum['dims'][None] - 1) * 2 - 1
    xyzfloat = xyzfloat.float().cuda()
    image = datum['imgpoints'].cuda()
    total_points = image.shape[0]
    subj = datum['subj']

    pbar = tqdm(range(FIT_ITERS))
    for i in pbar:
        optims[0].zero_grad()
        # Draw indices directly on the GPU; avoids a host->device index copy every step.
        minibatch = torch.randint(total_points, (BATCH_SIZE,), device=xyzfloat.device)
        xyzminibatch = xyzfloat[minibatch]
        imageminibatch = image[minibatch]
        pred_minibatch = decoders[0](encoders[0](xyzminibatch))
        # The decoder ends in Linear(256, 1), so predictions are (B, 1). Reshape the target to
        # match so mse_loss cannot silently broadcast (B, 1) vs (B,) into a (B, B) loss;
        # raises if the element counts disagree.
        target = imageminibatch.reshape(pred_minibatch.shape)
        loss = F.mse_loss(pred_minibatch, target)
        loss.backward()
        # Report the actual subject id (previously hardcoded to 0).
        pbar.set_description("subj: {} iter: {}/{}, Loss: {:.6f}".format(subj, i, FIT_ITERS, loss.item()))
        optims[0].step()

subj: 0 iter: 999/1000, Loss: 0.000340: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:05<00:00, 188.49it/s]


In [35]:
# Per-column min and max of encoder 0's fitted embedding table.
enc0_emb = encoders[0].embeddings
enc0_emb.min(dim=0).values, enc0_emb.max(dim=0).values

(tensor([-1.4450, -1.1141], device='cuda:0', grad_fn=<MinBackward0>),
 tensor([1.1263, 1.7374], device='cuda:0', grad_fn=<MaxBackward0>))

In [37]:
# Mean absolute value per embedding column of the freshly fitted encoder 0.
torch.mean(torch.abs(encoders[0].embeddings), dim=0)

tensor([0.0768, 0.0833], device='cuda:0', grad_fn=<MeanBackward1>)

In [39]:
# List the checkpoint directory (same path as ENCODER_DIR) via the IPython `ls` alias:
# decoder{i}.pth, decoder_optim{i}.pth and per-subject encoder_BraTS2021_*.pth files.
ls /data/rohitrango/Implicit3DCNNTasks/brats2021_unimodal/

decoder0.pth                 encoder_BraTS2021_01038.pth
decoder1.pth                 encoder_BraTS2021_01039.pth
decoder2.pth                 encoder_BraTS2021_01040.pth
decoder3.pth                 encoder_BraTS2021_01041.pth
decoder_optim0.pth           encoder_BraTS2021_01042.pth
decoder_optim1.pth           encoder_BraTS2021_01043.pth
decoder_optim2.pth           encoder_BraTS2021_01044.pth
decoder_optim3.pth           encoder_BraTS2021_01045.pth
encoder_BraTS2021_00000.pth  encoder_BraTS2021_01046.pth
encoder_BraTS2021_00002.pth  encoder_BraTS2021_01047.pth
encoder_BraTS2021_00003.pth  encoder_BraTS2021_01048.pth
encoder_BraTS2021_00005.pth  encoder_BraTS2021_01049.pth
encoder_BraTS2021_00006.pth  encoder_BraTS2021_01050.pth
encoder_BraTS2021_00008.pth  encoder_BraTS2021_01051.pth
encoder_BraTS2021_00009.pth  encoder_BraTS2021_01052.pth
encoder_BraTS2021_00011.pth  encoder_BraTS2021_01053.pth
encoder_BraTS2021_00012.pth  encoder_BraTS2021_01054.pth
encoder_BraTS2

In [51]:
# Load the 'embeddings' tensor of a previously saved encoder checkpoint for comparison.
# Built from ENCODER_DIR (as the decoder loading does) instead of repeating the absolute path.
data = torch.load(osp.join(ENCODER_DIR, "encoder_BraTS2021_00014.pth"))['embeddings']

In [52]:
# Mean absolute value per column of the saved encoder's embeddings.
torch.mean(torch.abs(data), dim=0)

tensor([0.0842, 0.0911, 0.0676, 0.0670, 0.0653, 0.0601, 0.0842, 0.0809],
       device='cuda:0')