In [None]:
import warnings 
warnings.filterwarnings('ignore')
import torch
from torch import nn
import torch.nn.functional as F
from torch.distributions.normal import Normal
from torch.utils.data import Dataset, DataLoader, random_split
from torch.utils.tensorboard import SummaryWriter
import torch.optim as optim
import torchio as tio
from torchinfo import summary
from scipy.spatial import Delaunay
from scipy.spatial.distance import cdist
import numpy as np
import math
import nibabel as nib
import nrrd
import SimpleITK as sitk
import time
#from tqdm import tqdm
from tqdm.notebook import tqdm
import os
import shutil
import glob
import json
from natsort import natsorted
import matplotlib.pyplot as plt

In [None]:
class LinearInterpolation3d(nn.Module):
    """
    Densify sparse 3D displacements with piecewise-linear (Delaunay) interpolation.

    The eight corners of the normalized cube [-1, 1]^3 are appended to the keypoints with a
    zero displacement. Every point of the identity sampling grid therefore lies inside the
    triangulation's convex hull.

    Args:
        size: spatial output shape (D, H, W), as a tuple.
    """

    def __init__(self, size):
        super().__init__()

        self.size = size

        # Identity sampling grid in normalized coordinates, flattened to (1, D*H*W, 3).
        # affine_grid orders the last dim as (x, y, z) = (W, H, D).
        identity = torch.eye(3, 4).unsqueeze(0)
        sampling_grid = F.affine_grid(identity, (1, 1) + size, align_corners=True)
        self.register_buffer("grid", sampling_grid.view(1, -1, 3))

        # Cube-corner anchor points, shape (1, 8, 3). Keep this row order: it is the order in
        # which the points are handed to Delaunay.
        corners = torch.tensor([[
            [1.0, 1.0, 1.0],
            [-1.0, 1.0, 1.0],
            [1.0, -1.0, 1.0],
            [1.0, 1.0, -1.0],
            [-1.0, -1.0, 1.0],
            [-1.0, 1.0, -1.0],
            [1.0, -1.0, -1.0],
            [-1.0, -1.0, -1.0],
        ]])
        self.register_buffer("pads", corners)

        # The anchors carry zero displacement.
        self.register_buffer("pads_values", torch.zeros((1, 8, 3)))

    def _get_barycentric_coordinates(self, points_tri, target):
        """
        Barycentric coordinates of each target point inside its enclosing simplex.

        points_tri: scipy Delaunay triangulation.
        target: numpy array (M, dim).
        Returns (coords, simplex_ids): coords is (M, dim + 1); simplex_ids is (M,) as given by
        ``find_simplex``.
        """
        simplex_ids = points_tri.find_simplex(target)
        ndim = target.shape[1]

        affine = points_tri.transform[simplex_ids]   # (M, ndim + 1, ndim)
        inverse_part = affine[:, :ndim]               # (M, ndim, ndim)
        origin = affine[:, ndim]                      # (M, ndim)
        leading = (inverse_part.transpose([1, 0, 2]) * (target - origin)).sum(axis=2).T
        coords = np.c_[leading, 1 - leading.sum(axis=1)]

        return coords, simplex_ids

    def _linear_interp_material(self, points, target):
        """
        Triangulate ``points`` and locate every ``target`` point in the triangulation.

        points: numpy array (N, D)
        target: numpy array (M, D)
        Returns (simplices, transform, barycentric coords, simplex ids).
        """
        triangulation = Delaunay(points)
        coords, simplex_ids = self._get_barycentric_coordinates(triangulation, target)
        return triangulation.simplices, triangulation.transform, coords, simplex_ids

    def _linear_interp(self, points, values, target):
        """
        points: torch tensor (B, N, D), locations where the signal is known
        values: torch tensor (B, N, C), the signal at ``points``
        target: torch tensor (B, M, D), query locations
        Returns a (1, M, C) tensor. Only B == 1 is supported.
        """
        if points.size(0) > 1:
            raise NotImplementedError("Linear interpolation not implemented for batches larger than 1.")

        device = points.device
        simplices, _, coords, simplex_ids = self._linear_interp_material(
            points.detach().cpu().numpy()[0],
            target.detach().cpu().numpy()[0],
        )
        simplices = torch.tensor(simplices).long().to(device)      # (n_simplices, D+1)
        coords = torch.tensor(coords).float().to(device)            # (M, D+1)
        simplex_ids = torch.tensor(simplex_ids).long().to(device)   # (M,)

        # Barycentric-weighted sum of the values at each enclosing simplex's vertices.
        vertex_values = values[0, simplices[simplex_ids]]           # (M, D+1, C)
        return (vertex_values * coords.unsqueeze(-1)).sum(1).unsqueeze(0)

    def forward(self, kpts, disp):
        """
        kpts: (B, N, 3) keypoints; expected in the same normalized [-1, 1] frame and (W, H, D)
              coordinate order as the affine_grid sampling grid.
        disp: (B, N, 3) displacement at each keypoint.
        Returns: (B, 3, D, H, W) dense displacement field.
        """
        points = torch.cat([kpts, self.pads], dim=1)
        values = torch.cat([disp, self.pads_values], dim=1)
        dense = self._linear_interp(points, values, self.grid)
        out_shape = (kpts.size(0),) + self.size + (3,)
        return dense.reshape(out_shape).permute(0, 4, 1, 2, 3)

In [None]:
class ThinPlateSpline(nn.Module):
    """
    Dense displacement field from sparse keypoint displacements via a regularized thin-plate spline.

    The spline is evaluated on a coarse identity grid (``s // step`` points per axis). The result
    is then trilinearly upsampled to ``shape``.

    Args:
        shape: output spatial shape (D, H, W).
        step: downsampling factor of the evaluation grid.
        lambd: value added to the kernel-matrix diagonal (regularization).
        unroll_step_size: number of grid points evaluated per chunk.
    """

    def __init__(self, shape, step=4, lambd=0.1, unroll_step_size=2**12):
        super().__init__()
        self.shape = shape
        self.step = step
        self.lambd = lambd
        self.unroll_step_size = unroll_step_size

        # Flattened (M, 3) identity grid at the coarse resolution.
        coarse = tuple(s // step for s in shape)
        identity = torch.eye(3, 4).unsqueeze(0)
        coarse_grid = F.affine_grid(identity, size=(1, 1) + coarse, align_corners=True)
        self.register_buffer("base_grid", coarse_grid.view(-1, 3))

    def forward(self, kpts, disps):
        """
        kpts: (1, N, 3) source keypoints
        disps: (1, N, 3) displacement at each keypoint
        Returns: (1, 3, D, H, W) dense displacement field
        """
        centers = kpts[0]
        targets = self.base_grid
        theta = self._fit(centers, disps[0])

        n_points = targets.shape[0]
        field = torch.zeros((1, n_points, 3), device=targets.device)
        chunk = self.unroll_step_size
        # Evaluate in chunks so the (chunk, N) distance matrix stays small.
        for start in range(0, n_points, chunk):
            stop = min(start + chunk, n_points)
            field[0, start:stop, :] = self._z(targets[start:stop], centers, theta)

        coarse = [s // self.step for s in self.shape]
        field = field.view(1, *coarse, 3).permute(0, 4, 1, 2, 3)  # (1, 3, D1, H1, W1)
        return F.interpolate(field, size=self.shape, mode='trilinear', align_corners=True)

    def _fit(self, c, f):
        """Solve the regularized TPS linear system; returns theta of shape (n + 4, f_dim)."""
        device = c.device
        n = c.shape[0]

        kernel = self._u(self._d(c, c)) + torch.eye(n, device=device) * self.lambd

        # Affine part: columns [1, x, y, z].
        poly = torch.ones((n, 4), device=device)
        poly[:, 1:] = c

        system = torch.zeros((n + 4, n + 4), device=device)
        system[:n, :n] = kernel
        system[:n, -4:] = poly
        system[-4:, :n] = poly.t()

        rhs = torch.zeros((n + 4, f.shape[1]), device=device)
        rhs[:n, :] = f

        return torch.linalg.solve(system, rhs)

    def _z(self, x, c, theta):
        """Evaluate the fitted spline at points x (M, 3); returns (M, f_dim)."""
        weights = theta[:-4]
        affine = theta[-4:].unsqueeze(2)                          # (4, f_dim, 1)
        radial = torch.matmul(self._u(self._d(x, c)), weights)    # (M, f_dim)
        linear = affine[0] + affine[1] * x[:, 0] + affine[2] * x[:, 1] + affine[3] * x[:, 2]  # (f_dim, M)
        return (linear + radial.t()).t()

    def _d(self, a, b):
        """Pairwise Euclidean distances between the rows of a (P, 3) and b (Q, 3) -> (P, Q)."""
        sq_a = a.pow(2).sum(dim=1, keepdim=True)      # (P, 1)
        sq_b = b.pow(2).sum(dim=1).unsqueeze(0)       # (1, Q)
        sq_dist = sq_a + sq_b - 2.0 * (a @ b.T)
        # Clamp tiny negative values caused by floating-point cancellation before sqrt.
        return sq_dist.clamp(min=0.0).sqrt()

    def _u(self, r):
        """Radial basis r^2 * log(r); the 1e-6 offset keeps log finite at r == 0."""
        return r.pow(2) * (r + 1e-6).log()


In [None]:
class NeuriPhyDataset(Dataset):
    """
    Case-directory dataset for learning dense displacement fields (DDFs).

    Every child of ``path`` is one case directory. For each case, ``__getitem__`` loads:
    - the brain and tumor ``*.seg.nrrd`` masks from ``segmentations/``;
    - one simulated DDF from ``simulations/simulation*/*disp_field*.npz``;
    - a T1ce image from ``images/`` (falling back to T2 when no T1ce exists);
    - the matching keypoint file from ``keypoints/``.

    A random subset of keypoint displacements is densified into an initial DDF. All volumes are
    then cropped or padded to ``out_size``.

    Args:
        path: directory containing one sub-directory per case.
        out_size: spatial size passed to ``tio.CropOrPad``.
        mode: 'train' -> (img, init_ddf, gt_ddf, brain_seg); 'test' -> (img, init_ddf, brain_seg).
        normalization: None, 'standard' or 'min-max'.
        ddf_interpolator: 'linear' (LinearInterpolation3d) or 'tps' (ThinPlateSpline).

    Raises:
        ValueError: for an unknown ``mode`` or ``ddf_interpolator``.
    """

    def __init__(self, path, out_size=(240,240,155), mode='train', normalization=None, ddf_interpolator='linear'):
        # Fail fast instead of erroring (or silently returning None) on the first item.
        if mode not in ('train', 'test'):
            raise ValueError(f"mode '{mode}' should be in ['train', 'test']")
        if ddf_interpolator not in ('linear', 'tps'):
            raise ValueError(f"ddf_interpolator '{ddf_interpolator}' should be in ['linear', 'tps']")
        self.data = natsorted(glob.glob(os.path.join(path, "*")))
        self.mode = mode
        self.normalization = normalization
        self.ddf_interp = ddf_interpolator
        self.out_size = out_size

    @staticmethod
    def _first_match(*parts):
        """Return the first path matching the glob built from ``parts``; raise if nothing matches."""
        pattern = os.path.join(*parts)
        matches = glob.glob(pattern)
        if not matches:
            raise FileNotFoundError(f"No file matches '{pattern}'")
        return matches[0]

    def preprocess(self, img, mask):
        """
        Normalize image intensities using statistics from inside ``mask``.

        'standard': z-score using the brain-mask mean and std.
        'min-max': maps [brain-mask min, whole-image 99.95th percentile] to [0, 1]. Values above
        the percentile are not clipped.

        Returns a float32 tensor with the same shape as ``img``.
        """
        if self.normalization is not None:
            mask = np.asarray(mask) > 0
            if self.normalization == 'standard':
                mean = img[mask].mean()
                std = img[mask].std()
                img = (img - mean) / (std + 1e-8)
            elif self.normalization == 'min-max':
                max_data = np.percentile(img, 99.95)
                min_data = img[mask].min()
                # Same epsilon as the 'standard' branch: guards against a constant image.
                img = (img - min_data) / (max_data - min_data + 1e-8)
            else:
                raise NotImplementedError(f"Normalization '{self.normalization}' should be in ['min-max', 'standard']")
        return torch.tensor(img.astype(np.float32))

    def initialize_disp_field(self, kpts, gt_ddf, tumor_seg, min_kpts=5, max_kpts=20):
        """
        Build an initial dense DDF from a random, tumor-weighted subset of keypoints.

        Args:
            kpts: path to a tab-separated keypoint file (6 header lines). The first 3 columns are
                voxel coordinates, assumed to be in gt_ddf's (D, H, W) axis order (TODO confirm).
            gt_ddf: (3, D, H, W) tensor from which the sparse displacements are sampled.
            tumor_seg: array where voxels > 0 mark the tumor; assumed aligned with gt_ddf.
            min_kpts, max_kpts: inclusive range for the number of keypoints drawn.

        Returns:
            (3, D, H, W) tensor.
        """
        _, D, H, W = gt_ddf.shape

        if self.ddf_interp == 'linear':
            ddf_interp = LinearInterpolation3d((D, H, W))
        elif self.ddf_interp == 'tps':
            ddf_interp = ThinPlateSpline((D, H, W))
        else:
            raise ValueError(f"ddf_interpolator '{self.ddf_interp}' should be in ['linear', 'tps']")

        kpts = np.genfromtxt(kpts, delimiter="\t", skip_header=6, dtype=np.float32)[:, :3]
        # Drop duplicate keypoints while keeping file order (duplicates are degenerate for Delaunay).
        _, unique_idxs = np.unique(kpts, axis=0, return_index=True)
        kpts = kpts[np.sort(unique_idxs)]

        # Voxel index -> [-1, 1] (align_corners=True convention); column i scales by spatial dim i.
        scale = np.array([D - 1, H - 1, W - 1], dtype=np.float32)
        kpts = kpts / scale * 2 - 1

        # Sampling probability ~ exp(-distance to the tumor centroid). The centroid is normalized
        # exactly like the keypoints so that both live in the same coordinate frame.
        tumor_coords = np.argwhere(np.asarray(tumor_seg) > 0)
        if tumor_coords.size > 0:
            tumor_center = tumor_coords.mean(axis=0) / scale * 2 - 1
            weight = np.exp(-cdist(tumor_center.reshape(1, -1), kpts)[0])
            weight /= weight.sum()
        else:
            weight = None  # empty tumor mask: fall back to uniform sampling

        # Never ask for more keypoints than exist (choice without replacement would fail).
        k = min(np.random.randint(min_kpts, max_kpts + 1), kpts.shape[0])
        choices = np.random.choice(kpts.shape[0], p=weight, size=k, replace=False)

        # grid_sample and affine_grid read the last dim as (x, y, z) = (W, H, D). Reverse the
        # (D, H, W)-ordered columns so sampling and interpolation use the right axes.
        kpts = torch.from_numpy(kpts[choices]).flip(-1)
        grid = kpts.view(1, -1, 1, 1, 3)
        sampled_disps = F.grid_sample(gt_ddf.unsqueeze(0), grid, mode='bilinear', align_corners=True)
        sampled_disps = sampled_disps.reshape(3, k).t()  # (1, 3, k, 1, 1) -> (k, 3)

        init_ddf = ddf_interp(kpts.unsqueeze(0), sampled_disps.unsqueeze(0)).squeeze(0)  # (3, D, H, W)
        return init_ddf

    def __getitem__(self, idx):
        """Load case ``idx`` and return the tensors described in the class docstring."""
        case_dir = self.data[idx]
        brain_seg_path = self._first_match(case_dir, "segmentations", "*brain_mask*.seg.nrrd")
        tumor_seg_path = self._first_match(case_dir, "segmentations", "*tumor*.seg.nrrd")
        # TODO: only the first matching simulation of each case is used for now.
        gt_ddf_path = self._first_match(case_dir, "simulations", "simulation*", "*disp_field*.npz")

        t1ce = glob.glob(os.path.join(case_dir, "images", "*T1ce*.nii"))
        if len(t1ce) > 0:
            kpts_path = self._first_match(case_dir, "keypoints", "*T1ce.key")
            img = nib.load(t1ce[0]).get_fdata()
        else:
            kpts_path = self._first_match(case_dir, "keypoints", "*T2.key")
            img = nib.load(self._first_match(case_dir, "images", "*T2*.nii")).get_fdata()

        brain_seg, _ = nrrd.read(brain_seg_path)
        tumor_seg, _ = nrrd.read(tumor_seg_path)
        # Close the npz archive after reading. The three spatial axes are reversed,
        # presumably to match the image's axis order (TODO confirm).
        with np.load(gt_ddf_path) as npz:
            gt_ddf = npz['field'].transpose(0, 3, 2, 1)
        gt_ddf = torch.tensor(gt_ddf.astype(np.float32))
        brain_seg = torch.tensor(brain_seg.astype(np.float32))

        img = self.preprocess(img, brain_seg)

        init_ddf = self.initialize_disp_field(kpts_path, gt_ddf, tumor_seg)

        subject = tio.Subject(
            img=tio.ScalarImage(tensor=img.unsqueeze(0)),
            brain_seg=tio.LabelMap(tensor=brain_seg.unsqueeze(0)),
            gt_ddf=tio.ScalarImage(tensor=gt_ddf),
            init_ddf=tio.ScalarImage(tensor=init_ddf),
        )
        transformed = tio.CropOrPad(self.out_size)(subject)

        if self.mode == 'train':
            return transformed['img'].tensor, transformed['init_ddf'].tensor, transformed['gt_ddf'].tensor, transformed['brain_seg'].tensor
        if self.mode == 'test':
            return transformed['img'].tensor, transformed['init_ddf'].tensor, transformed['brain_seg'].tensor
        raise ValueError(f"mode '{self.mode}' should be in ['train', 'test']")

    def __len__(self):
        return len(self.data)

In [None]:
class InitWeights_He:
    """
    Callable for ``nn.Module.apply``: He (Kaiming-normal) init of convolution weights.

    Conv2d/Conv3d/ConvTranspose2d/ConvTranspose3d weights are drawn with
    ``kaiming_normal_(a=neg_slope)``, and their biases (if any) are set to 0. Other modules are
    left untouched.

    Args:
        neg_slope: negative slope of the following leaky ReLU, passed as ``a``.
    """

    def __init__(self, neg_slope=1e-2):
        self.neg_slope = neg_slope

    def __call__(self, module):
        conv_types = (nn.Conv3d, nn.Conv2d, nn.ConvTranspose2d, nn.ConvTranspose3d)
        if not isinstance(module, conv_types):
            return
        # Both initializers work in place on the existing Parameters.
        nn.init.kaiming_normal_(module.weight, a=self.neg_slope)
        if module.bias is not None:
            nn.init.constant_(module.bias, 0)


class ConvBlock(nn.Module):
    """
    Unet building block: a kernel-3, padding-1 convolution followed by LeakyReLU(0.2).

    Args:
        ndims: spatial dimensionality; selects nn.Conv1d / nn.Conv2d / nn.Conv3d.
        in_channels: number of input channels.
        out_channels: number of output channels.
        stride: convolution stride.
    """

    def __init__(self, ndims, in_channels, out_channels, stride=1):
        super().__init__()

        conv_cls = getattr(nn, 'Conv%dd' % ndims)
        self.main = conv_cls(in_channels, out_channels, kernel_size=3, stride=stride, padding=1)
        self.activation = nn.LeakyReLU(0.2)

    def forward(self, x):
        return self.activation(self.main(x))
    

class Unet(nn.Module):
    """
    A unet architecture. Layer features can be specified directly as a list of encoder and decoder
    features or as a single integer along with a number of unet levels. The default network features
    per layer (when no options are specified) are:

        encoder: [16, 32, 32, 32]
        decoder: [32, 32, 32, 32, 32, 16, 16]

    The output has ``ndims`` channels (one per spatial axis). The output head is initialized with
    near-zero weights and zero bias, so the initial prediction is close to zero. Before each skip
    connection, decoder features are resized to the exact skip size (nearest neighbour), so
    spatial sizes that are not divisible by the pooling factor are supported.
    """

    def __init__(self,
                 inshape=None,
                 infeats=None,
                 nb_features=None,
                 nb_levels=None,
                 max_pool=2,
                 feat_mult=1,
                 nb_conv_per_level=1,
                 half_res=False):
        """
        Parameters:
            inshape: Input shape. e.g. (192, 192, 192). Required.
            infeats: Number of input features.
            nb_features: Unet convolutional features. Can be specified via a list of lists with
                the form [[encoder feats], [decoder feats]], or as a single integer. 
                If None (default), the unet features are defined by the default config described in 
                the class documentation.
            nb_levels: Number of levels in unet. Only used when nb_features is an integer. 
                Default is None.
            max_pool: Pooling factor, either one int for all levels or a per-level list. Default is 2.
            feat_mult: Per-level feature multiplier. Only used when nb_features is an integer. 
                Default is 1.
            nb_conv_per_level: Number of convolutions per unet level. Default is 1.
            half_res: Skip the last decoder upsampling. Default is False.

        Raises:
            ValueError: if inshape is missing, or nb_features / nb_levels are inconsistent.
        """

        super().__init__()

        if inshape is None:
            raise ValueError('inshape must be provided, e.g. (192, 192, 192)')

        # ensure correct dimensionality
        ndims = len(inshape)
        assert ndims in [1, 2, 3], 'ndims should be one of 1, 2, or 3. found: %d' % ndims
        Conv = getattr(nn, 'Conv%dd' % ndims)

        # cache some parameters
        self.half_res = half_res

        # default encoder and decoder layer features if nothing provided
        if nb_features is None:
            nb_features = [
                [16, 32, 32, 32],             # encoder
                [32, 32, 32, 32, 32, 16, 16]  # decoder
            ]

        # build feature list automatically
        if isinstance(nb_features, int):
            if nb_levels is None:
                raise ValueError('must provide unet nb_levels if nb_features is an integer')
            feats = np.round(nb_features * feat_mult ** np.arange(nb_levels)).astype(int)
            nb_features = [
                np.repeat(feats[:-1], nb_conv_per_level),
                np.repeat(np.flip(feats), nb_conv_per_level)
            ]
        elif nb_levels is not None:
            raise ValueError('cannot use nb_levels if nb_features is not an integer')

        # extract any surplus (full resolution) decoder convolutions
        enc_nf, dec_nf = nb_features
        nb_dec_convs = len(enc_nf)
        final_convs = dec_nf[nb_dec_convs:]
        dec_nf = dec_nf[:nb_dec_convs]
        self.nb_levels = int(nb_dec_convs / nb_conv_per_level) + 1

        if isinstance(max_pool, int):
            max_pool = [max_pool] * self.nb_levels

        # cache downsampling / upsampling operations (parameter-free, so plain lists suffice)
        MaxPooling = getattr(nn, 'MaxPool%dd' % ndims)
        self.pooling = [MaxPooling(s) for s in max_pool]
        self.upsampling = [lambda x, ref: F.interpolate(x, size=ref.shape[2:], mode='nearest') for _ in max_pool]

        # configure encoder (down-sampling path)
        prev_nf = infeats
        encoder_nfs = [prev_nf]
        self.encoder = nn.ModuleList()
        for level in range(self.nb_levels - 1):
            convs = nn.ModuleList()
            for conv in range(nb_conv_per_level):
                nf = enc_nf[level * nb_conv_per_level + conv]
                convs.append(ConvBlock(ndims, prev_nf, nf))
                prev_nf = nf
            self.encoder.append(convs)
            encoder_nfs.append(prev_nf)

        # configure decoder (up-sampling path); skip channels are added after each level
        encoder_nfs = np.flip(encoder_nfs)
        self.decoder = nn.ModuleList()
        for level in range(self.nb_levels - 1):
            convs = nn.ModuleList()
            for conv in range(nb_conv_per_level):
                nf = dec_nf[level * nb_conv_per_level + conv]
                convs.append(ConvBlock(ndims, prev_nf, nf))
                prev_nf = nf
            self.decoder.append(convs)
            if not half_res or level < (self.nb_levels - 2):
                prev_nf += encoder_nfs[level]

        # now we take care of any remaining convolutions
        self.remaining = nn.ModuleList()
        for nf in final_convs:
            self.remaining.append(ConvBlock(ndims, prev_nf, nf))
            prev_nf = nf

        self.apply(InitWeights_He())

        # Output head: one channel per spatial axis, using the conv type that matches ndims.
        # Created after apply() so it keeps its own near-zero initialization.
        self.final = Conv(prev_nf, ndims, kernel_size=3, padding=1)
        self.final.weight = nn.Parameter(Normal(0, 1e-5).sample(self.final.weight.shape))
        self.final.bias = nn.Parameter(torch.zeros(self.final.bias.shape))

    def forward(self, x):
        """x: (B, infeats, *inshape) -> (B, ndims, *spatial)."""

        # encoder forward pass
        x_history = [x]
        for level, convs in enumerate(self.encoder):
            for conv in convs:
                x = conv(x)
            x_history.append(x)
            x = self.pooling[level](x)

        # decoder forward pass with upsampling and concatenation
        for level, convs in enumerate(self.decoder):
            for conv in convs:
                x = conv(x)
            if not self.half_res or level < (self.nb_levels - 2):
                skip = x_history.pop()
                x = self.upsampling[level](x, skip)
                x = torch.cat([x, skip], dim=1)

        # remaining convs at full resolution
        for conv in self.remaining:
            x = conv(x)

        return self.final(x)

In [None]:
def jacobian(disp):
    """
    Forward-difference Jacobian of a displacement field.

    disp: (B, C, X, Y, Z) tensor.
    Returns a (B, 3, C, X, Y, Z) tensor where [:, i, c] is the forward difference of channel c
    along spatial axis i (0 = X, 1 = Y, 2 = Z). It is computed on the first X-1, Y-1, Z-1
    positions and zero-padded at the far end of each spatial axis.
    """
    base = disp[:, :, :-1, :-1, :-1]
    along_x = disp[:, :, 1:, :-1, :-1] - base
    along_y = disp[:, :, :-1, 1:, :-1] - base
    along_z = disp[:, :, :-1, :-1, 1:] - base
    stacked = torch.stack((along_x, along_y, along_z), dim=1)
    # Restore the original spatial size by padding one zero slice at the end of each axis.
    return F.pad(stacked, (0, 1, 0, 1, 0, 1))

def Jacobian_det(disp):
    """
    Folding penalty: mean squared deviation of the Jacobian determinant from 1.

    The deformation gradient I + grad(disp) is built from forward differences (``jacobian``),
    and the function returns mean((det - 1) ** 2) over batch and voxels. Note that this is NOT
    the mean determinant. Voxels on the zero-padded far faces have det == 1 and so contribute 0.

    disp: (B, 3, X, Y, Z) displacement field. Its channels are reversed ([2, 1, 0]) before
        differentiation. NOTE(review): this assumes the channels are stored in the reverse of the
        spatial-axis order; confirm against the data convention.
    Returns a scalar tensor.
    """
    jac = jacobian((disp)[:, [2, 1, 0]])  # (B, 3, 3, X, Y, Z): [:, axis, channel]
    # Deformation gradient = identity + displacement gradient.
    jac[:, 0, 0] += 1.0
    jac[:, 1, 1] += 1.0
    jac[:, 2, 2] += 1.0
    # 3x3 determinant (rule of Sarrus), evaluated per voxel.
    det = (
        jac[:, 0, 0] * jac[:, 1, 1] * jac[:, 2, 2] +
        jac[:, 0, 1] * jac[:, 1, 2] * jac[:, 2, 0] +
        jac[:, 0, 2] * jac[:, 1, 0] * jac[:, 2, 1] -
        jac[:, 0, 0] * jac[:, 1, 2] * jac[:, 2, 1] - 
        jac[:, 0, 1] * jac[:, 1, 0] * jac[:, 2, 2] -
        jac[:, 0, 2] * jac[:, 1, 1] * jac[:, 2, 0]
    )
    return ((det-1)**2).mean()

def Hessian_penalty(ddf):
    """
    Bending energy of a displacement field.

    Sums the squared second-order forward differences d^2 u_c / (d a d b) over all channels c
    and axis pairs (a, b), then averages over batch and voxels.

    Only positions where every second difference is fully defined (the first X-2, Y-2, Z-2
    indices) contribute. The remaining positions are zero-padded, so the mean is still taken over
    B*X*Y*Z voxels. Chaining a zero-padded first-derivative operator twice would instead yield
    spurious "second derivatives" equal to minus the first derivative on the last interior
    planes, penalizing even perfectly linear fields.

    ddf: (B, C, X, Y, Z) tensor (C = 3 for a 3D displacement field).
    Returns a scalar tensor.
    """
    def _forward_diffs(t):
        # (B, C, X, Y, Z) -> (B, 3, C, X-1, Y-1, Z-1): unpadded forward differences along X, Y, Z.
        core = t[:, :, :-1, :-1, :-1]
        return torch.stack([
            t[:, :, 1:, :-1, :-1] - core,
            t[:, :, :-1, 1:, :-1] - core,
            t[:, :, :-1, :-1, 1:] - core,
        ], dim=1)

    B, C, X, Y, Z = ddf.shape
    jac = _forward_diffs(ddf)                                          # (B, 3, C, X-1, Y-1, Z-1)
    hess = _forward_diffs(jac.reshape(B, 3 * C, X - 1, Y - 1, Z - 1))  # (B, 3, 3C, X-2, Y-2, Z-2)
    energy = (hess ** 2).sum(dim=(1, 2))                               # (B, X-2, Y-2, Z-2)
    # Pad back to (B, X, Y, Z) so the normalization is a per-voxel mean over the full grid.
    energy = F.pad(energy, (0, 2, 0, 2, 0, 2))
    return energy.mean()

In [None]:
def train(data, model, loss_fn, optimizer, epoch, writer, hess_weight=150, log_every=50):
    """
    Run one training epoch and report losses.

    The optimized loss is ``loss_fn(pred, gt) + hess_weight * Hessian_penalty(pred)`` over the
    whole volume. The masked metrics apply ``loss_fn`` after zeroing voxels outside the mask; they
    are reported only and do not affect the gradient. Running averages are written to ``writer``
    every ``log_every`` iterations, and epoch averages are printed at the end.

    Args:
        data: iterable (e.g. DataLoader) yielding (img, init_ddf, gt_ddf, mask) batches;
            ``mask.squeeze(0)`` below assumes batch size 1.
        model: network mapping cat([img, init_ddf], dim=1) to a predicted DDF.
        loss_fn: loss callable, e.g. nn.MSELoss().
        optimizer: optimizer over ``model``'s parameters.
        epoch: epoch index, used to compute the TensorBoard global step.
        writer: object with ``add_scalar`` (e.g. SummaryWriter).
        hess_weight: weight of the Hessian (bending-energy) penalty.
        log_every: number of iterations averaged per logged point.
    """
    model.train()
    device = next(model.parameters()).device

    epoch_masked_mse = 0.0
    epoch_full_mse = 0.0
    epoch_hess = 0.0
    epoch_total_loss = 0.0
    epoch_initial_mask_mse = 0.0
    epoch_initial_whole_mse = 0.0

    running_mse = 0.0
    running_hess = 0.0
    running_total_loss = 0.0

    len_cases = len(data)

    for i, (img, init_ddf, gt_ddf, mask) in enumerate(tqdm(data, leave=False)):
        img = img.to(device)
        init_ddf = init_ddf.to(device)
        gt_ddf = gt_ddf.to(device)
        mask = (mask.to(device) > 0).squeeze(0)
        gt_masked = torch.where(mask, gt_ddf, 0.0)

        # Baseline: how far the interpolated initial field already is from the ground truth.
        epoch_initial_whole_mse += loss_fn(init_ddf, gt_ddf).item()
        epoch_initial_mask_mse += loss_fn(torch.where(mask, init_ddf, 0.0), gt_masked).item()

        pred_ddf = model(torch.cat([img, init_ddf], dim=1))
        hess = Hessian_penalty(pred_ddf)
        mse = loss_fn(pred_ddf, gt_ddf)
        total_loss = mse + hess_weight * hess

        optimizer.zero_grad()
        total_loss.backward()
        optimizer.step()

        mse_val, hess_val, total_val = mse.item(), hess.item(), total_loss.item()
        running_mse += mse_val
        running_hess += hess_val
        running_total_loss += total_val
        # Reporting only: detach so no extra autograd graph is built.
        epoch_masked_mse += loss_fn(torch.where(mask, pred_ddf.detach(), 0.0), gt_masked).item()
        epoch_full_mse += mse_val
        epoch_hess += hess_val
        epoch_total_loss += total_val

        if (i + 1) % log_every == 0:
            step = epoch * len_cases + i
            writer.add_scalar("training_mse", running_mse / log_every, step)
            writer.add_scalar("training_hess", running_hess / log_every, step)
            writer.add_scalar("training_loss", running_total_loss / log_every, step)
            running_mse = 0.0
            running_hess = 0.0
            running_total_loss = 0.0

    n = max(len_cases, 1)  # avoid ZeroDivisionError on an empty loader
    tqdm.write(f"Initial MSE (whole): {epoch_initial_whole_mse / n}")
    tqdm.write(f"Initial MSE (mask): {epoch_initial_mask_mse / n}")
    tqdm.write(f"Epoch MSE (whole): {epoch_full_mse / n}")
    tqdm.write(f"Epoch MSE (mask): {epoch_masked_mse / n}")
    tqdm.write(f"Epoch Hessian: {epoch_hess / n}")
    tqdm.write(f"Epoch Loss: {epoch_total_loss / n}\n")

In [None]:
def evaluate(data, model, loss_fn, epoch, writer, hess_weight=150, log_every=10):
    """
    Evaluate ``model`` on ``data`` without gradient tracking and report losses.

    Computes the same quantities as training: the whole-volume loss plus ``hess_weight`` times
    the Hessian penalty, the masked MSE, and the MSE of the initial DDF for reference. Running
    averages are written to ``writer`` every ``log_every`` iterations, and epoch averages are
    printed at the end. The model is left in eval mode.

    Args:
        data: iterable yielding (img, init_ddf, gt_ddf, mask) batches; ``mask.squeeze(0)``
            below assumes batch size 1.
        model: network mapping cat([img, init_ddf], dim=1) to a predicted DDF.
        loss_fn: loss callable, e.g. nn.MSELoss().
        epoch: epoch index, used to compute the TensorBoard global step.
        writer: object with ``add_scalar`` (e.g. SummaryWriter).
        hess_weight: weight of the Hessian penalty in the reported total loss.
        log_every: number of iterations averaged per logged point.
    """
    model.eval()
    device = next(model.parameters()).device

    val_masked_mse = 0.0
    val_full_mse = 0.0
    val_hess = 0.0
    val_total_loss = 0.0
    val_initial_mask_mse = 0.0
    val_initial_whole_mse = 0.0

    running_mse = 0.0
    running_hess = 0.0
    running_total_loss = 0.0

    len_cases = len(data)

    with torch.no_grad():
        for i, (img, init_ddf, gt_ddf, mask) in enumerate(tqdm(data, leave=False)):
            img = img.to(device)
            init_ddf = init_ddf.to(device)
            gt_ddf = gt_ddf.to(device)
            mask = (mask.to(device) > 0).squeeze(0)
            gt_masked = torch.where(mask, gt_ddf, 0.0)

            val_initial_whole_mse += loss_fn(init_ddf, gt_ddf).item()
            val_initial_mask_mse += loss_fn(torch.where(mask, init_ddf, 0.0), gt_masked).item()

            pred_ddf = model(torch.cat([img, init_ddf], dim=1))
            hess = Hessian_penalty(pred_ddf)
            mse = loss_fn(pred_ddf, gt_ddf)
            total_loss = mse + hess_weight * hess

            mse_val, hess_val, total_val = mse.item(), hess.item(), total_loss.item()
            running_mse += mse_val
            running_hess += hess_val
            running_total_loss += total_val
            val_masked_mse += loss_fn(torch.where(mask, pred_ddf, 0.0), gt_masked).item()
            val_hess += hess_val
            val_total_loss += total_val
            val_full_mse += mse_val

            if (i + 1) % log_every == 0:
                step = epoch * len_cases + i
                writer.add_scalar("validation_mse", running_mse / log_every, step)
                writer.add_scalar("validation_hess", running_hess / log_every, step)
                writer.add_scalar("validation_loss", running_total_loss / log_every, step)
                running_mse = 0.0
                running_hess = 0.0
                running_total_loss = 0.0

    n = max(len_cases, 1)  # avoid ZeroDivisionError on an empty loader
    tqdm.write(f"Initial MSE (whole): {val_initial_whole_mse / n}")
    tqdm.write(f"Initial MSE (mask): {val_initial_mask_mse / n}")
    tqdm.write(f"Val MSE (whole): {val_full_mse / n}")
    tqdm.write(f"Val MSE (mask): {val_masked_mse / n}")
    tqdm.write(f"Val Hessian: {val_hess / n}")
    tqdm.write(f"Val Loss: {val_total_loss / n}\n\n")

In [None]:
# Experiment configuration: DDF interpolator ('linear' or 'tps') and intensity normalization.
interp = 'linear'
norm = 'min-max'
DATA_DIR = "/kaggle/input/neuriphy/Training"
SEED = 13

# Seed the global RNGs: keypoint subsampling in the dataset draws from np.random, and the
# training shuffle and model init draw from torch's RNG.
np.random.seed(SEED)
torch.manual_seed(SEED)

dataset = NeuriPhyDataset(DATA_DIR, ddf_interpolator=interp, normalization=norm)
train_dataset, val_dataset = random_split(dataset, [0.85, 0.15], generator=torch.Generator().manual_seed(SEED))
train_dataloader = DataLoader(train_dataset, batch_size=1, shuffle=True)
# Validation needs no shuffling; a fixed order keeps per-step validation logs comparable across epochs.
val_dataloader = DataLoader(val_dataset, batch_size=1, shuffle=False)

In [None]:
# Compute device, TensorBoard run directory (named after the interpolator/normalization) and loss.
device = "cuda" if torch.cuda.is_available() else "cpu"

norm_tag = "no_norm" if norm is None else norm
run_name = f"cosanneal_lesskpts_{interp}+{norm_tag}"
writer = SummaryWriter(f"/kaggle/working/runs/{run_name}")
loss_fn = nn.MSELoss()
loss_fn = nn.MSELoss()

In [None]:
# Model and optimizer. Input channels: the image (1) plus the 3-component initial DDF.
size = (240, 240, 155)
in_channels = 1 + 3
epochs = 20

model = Unet(inshape=size, infeats=in_channels).to(device)
optimizer = optim.Adam(params=model.parameters(), lr=1e-4)
# Cosine-annealing LR schedule, currently disabled:
#scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, epochs, eta_min=1e-6)

In [None]:
# Train for `epochs` epochs and validate every EVAL_EVERY epochs, plus once at the end.
EVAL_EVERY = 3

for e in tqdm(range(epochs)):
    train(train_dataloader, model, loss_fn, optimizer, e, writer)
    if e % EVAL_EVERY == EVAL_EVERY - 1:
        evaluate(val_dataloader, model, loss_fn, e, writer)
    #scheduler.step()

# Final validation pass. Skip it if the last epoch was already evaluated in the loop, which
# would log duplicate points at the same step. Also skip it if no epoch ran.
if epochs > 0 and (epochs - 1) % EVAL_EVERY != EVAL_EVERY - 1:
    evaluate(val_dataloader, model, loss_fn, epochs - 1, writer)
writer.flush()

In [None]:
#%load_ext tensorboard
#%tensorboard --logdir /kaggle/working/runs