In [None]:
import os
import nbimporter

# Jump to the "survival_analysis" project directory so project-local packages
# (e.g. `nets`) resolve no matter which sub-folder the notebook was opened from.
# Everything in the current path before the FIRST occurrence of the name is kept;
# if the name does not occur, the whole cwd is used as the prefix.
root = os.getcwd().partition("survival_analysis")[0]
os.chdir(f"{root}survival_analysis")

In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt

In [None]:
from nets.cox_nn import CoxNN
from nets.monotone_module import MonotonicIncreasingNet

# MonotonicIncreasingNet

In [None]:
def test_increasing(n=10, plot=True):
    """Assert that randomly initialised ``MonotonicIncreasingNet`` models are monotone.

    For each of ``n`` fresh networks (hidden sizes ``[8, 4, 2]``, eval mode), the
    network is evaluated on 100 evenly spaced points in [0, 1] and must satisfy:

    * output at the first grid point has absolute value below 1e-2,
    * consecutive outputs never decrease,
    * every output is non-negative.

    Args:
        n: number of independently initialised networks to check.
        plot: if True, show each network's curve (after the t=0 check, before the
            monotonicity / non-negativity checks).

    Raises:
        AssertionError: if any network violates one of the properties above.
    """
    for _ in range(n):
        model = MonotonicIncreasingNet(latent_sizes=[8, 4, 2])
        model.eval()

        grid = torch.linspace(0, 1, 100).view(-1, 1)  # (100, 1) column of input times
        out = model(grid)
        assert out[0, 0].abs() < 1e-2

        if plot:
            grid_np = grid.flatten().detach().numpy()
            out_np = out.flatten().detach().numpy()
            plt.plot(grid_np, out_np)
            plt.show()

        steps = out[1:] - out[:-1]
        assert torch.all(steps >= 0)
        assert torch.all(out >= 0)


test_increasing(n=10, plot=False)

In [None]:
def plot(n=100, plot=True):
    """Sample ``n`` random deep ``MonotonicIncreasingNet`` curves and summarise their shape.

    Each network (8 hidden layers of width 32, eval mode) is evaluated on 100 evenly
    spaced points in [0, 1]. Optionally every curve is drawn in its own subplot of a
    5-column grid. Note: inside the body, ``plot`` refers to the boolean argument,
    which shadows the function's own name.

    Args:
        n: number of networks to sample (must be >= 1; ``torch.stack`` fails on 0).
        plot: if True, draw each curve in a subplot and show the figure.

    Returns:
        A 7-tuple ``(score, mean, n_close_to_0, n_growth, n_flat_later, n_bends, n_bends_late)``:

        * ``score`` -- mean over grid steps of
          ``std_across_nets(dz) / (1e-3 + mean_across_nets(dz))``, where ``dz`` are first
          differences of the curves.
        * ``mean`` -- mean of all sampled outputs.
        * ``n_close_to_0`` -- number of curves ``torch.allclose`` to all zeros.
        * ``n_growth`` -- number of curves exceeding 0.5 somewhere in the last 20% of the grid.
        * ``n_flat_later`` -- number of curves with at least one step where
          ``dz / curve_length < 1e-6`` in the last 20% of the grid.
        * ``n_bends`` -- count of second differences > 1e-2 over all curves and points.
        * ``n_bends_late`` -- same count restricted to second-difference index >= 50.
    """
    import warnings

    # (100, 1) column of input times on [0, 1]; identical for every network.
    t = torch.linspace(0, 1, 100).view(-1, 1)

    # 5 columns and enough rows for all n curves. Ceiling division is required:
    # round(n / 5) under-allocates rows (n = 1, 2, 6, 7, ...) and plt.subplot then raises.
    n_cols = 5
    n_rows = (n + n_cols - 1) // n_cols

    zs = []
    for i in range(n):
        net = MonotonicIncreasingNet(
            latent_sizes=[32] * 8,
        )
        net.eval()

        z = net(t)
        zs.append(z.flatten())

        if plot:
            plt.subplot(n_rows, n_cols, i + 1)
            plt.plot(t.flatten().detach().numpy(), z.flatten().detach().numpy())

    if plot:
        plt.gcf().set_size_inches(25, 10)
        plt.show()

    # (n, curve_length) matrix of curves; the statistics below need no autograd graph.
    zs = torch.stack(zs).detach()
    Δzs = zs[:, 1:] - zs[:, :-1]     # first differences (per-step increments)
    ΔΔzs = Δzs[:, 1:] - Δzs[:, :-1]  # second differences

    curve_len = zs.shape[1]
    tail_start = int(curve_len * 0.8)  # the "later" part is the last 20% of the grid

    n_close_to_0 = 0
    n_growth = 0
    n_flat_later = 0
    for z in zs:
        n_close_to_0 += torch.allclose(z, torch.zeros_like(z))

        z_tail = z[tail_start:]
        n_growth += torch.any(0.5 < z_tail).item()
        # NOTE(review): steps are divided by the number of points rather than by the grid
        # spacing (which would give dz/dt); kept as-is — confirm this normalisation is intended.
        n_flat_later += torch.any((z_tail[1:] - z_tail[:-1]) / curve_len < 1e-6).item()

    Δzs_std = torch.std(Δzs, dim=0) / (1e-3 + torch.mean(Δzs, dim=0))
    score = Δzs_std.mean().item()
    mean = zs.mean().item()

    if np.isnan(score):
        # The former branch here re-showed an empty figure. Surface the problem explicitly instead.
        warnings.warn(
            "plot(): score is NaN (e.g. n < 2 makes the unbiased std undefined, "
            "or some curve contains NaN)",
            stacklevel=2,
        )

    n_bends = torch.sum(ΔΔzs > 1e-2).item()
    n_bends_late = torch.sum(ΔΔzs[:, 50:] > 1e-2).item()
    return score, mean, n_close_to_0, n_growth, n_flat_later, n_bends, n_bends_late


plot(n=20)

In [None]:
def test_run():
    """Smoke test: one forward pass of ``CoxNN`` on a random batch must not raise.

    Builds a model with 37 input features, an inner ``MonotonicIncreasingNet`` with
    hidden sizes ``[8, 4, 2]`` and ``t_scaling=28``, then feeds 32 random covariate
    rows together with 32 random times drawn from [0, 1). The output is not checked.
    """
    inner_net = MonotonicIncreasingNet(latent_sizes=[8, 4, 2])
    model = CoxNN(n_input_features=37, monotonic_increasing_net=inner_net, t_scaling=28)

    batch_size = 32
    features = torch.randn(batch_size, 37)
    times = torch.rand(batch_size, 1)
    model(features, times)


test_run()

In [None]:
def test_decreasing(n=10, plot=True, latent_sizes=None, t_scaling=28, n_points=100):
    """Check that ``CoxNN`` outputs never increase in time for a fixed covariate vector.

    For each of ``n`` randomly initialised models, one random 37-dimensional covariate
    row is repeated ``n_points`` times and paired with ``n_points`` evenly spaced times
    in [0, t_scaling]. Consecutive model outputs along that time grid must not increase.

    Args:
        n: number of independently initialised models to check.
        plot: if True, show each model's output curve before checking it.
        latent_sizes: hidden sizes of the inner ``MonotonicIncreasingNet``; ``None``
            (the default) means ``[8, 4, 2]``. A ``None`` sentinel avoids sharing one
            mutable default list across calls.
        t_scaling: passed to ``CoxNN`` and also used as the upper end of the time grid,
            so both always agree (previously 28 was hard-coded in two places).
        n_points: number of time points on the grid.

    Raises:
        AssertionError: if any model's output increases somewhere along the grid.
    """
    if latent_sizes is None:
        latent_sizes = [8, 4, 2]

    n_features = 37

    for trial in range(n):
        cox_model = CoxNN(
            n_input_features=n_features,
            monotonic_increasing_net=MonotonicIncreasingNet(latent_sizes=latent_sizes),
            t_scaling=t_scaling,
        )

        ts = torch.linspace(0, 1, n_points).view(-1, 1) * t_scaling
        # The same covariates at every time point, so only t varies along the grid.
        xs = torch.randn(1, n_features)
        xs = torch.repeat_interleave(xs, n_points, dim=0)
        assert xs.shape == (n_points, n_features)

        zs = cox_model(xs=xs, ts=ts)

        if plot:
            plt.plot(ts.flatten().detach().numpy(), zs.flatten().detach().numpy())
            plt.show()

        assert torch.all(0 >= zs[1:] - zs[:-1]), f"model #{trial}: CoxNN output increases in t"


test_decreasing(n=10, plot=False, latent_sizes=[8, 4, 2])

In [None]:
def test_decreasing():
    """Adversarially train ``CoxNN`` to increase in time and check that it still cannot.

    A model (37 features, inner ``MonotonicIncreasingNet([8, 4, 2])``, ``t_scaling=1``)
    is optimised with Adam for 100 steps. For one fixed covariate row, the objective
    pushes the output at t=1 above the output at t=0.5. Afterwards the first output
    column on a 100-point grid over [0, 1] is plotted and must still be non-increasing.

    NOTE(review): this redefines ``test_decreasing`` from the previous cell, so after
    this cell the name refers to this version.

    Raises:
        AssertionError: if the trained model's output increases anywhere on the grid.
    """
    n_features = 37
    cox_model = CoxNN(
        n_input_features=n_features,
        monotonic_increasing_net=MonotonicIncreasingNet(latent_sizes=[8, 4, 2]),
        t_scaling=1,
    )

    optim = torch.optim.Adam(cox_model.parameters())

    x = torch.randn(1, n_features)
    xs = torch.cat([x] * 3, dim=0)
    ts = torch.linspace(0, 1, 3).view(-1, 1)  # times 0, 0.5, 1

    for _ in range(100):
        results = cox_model(xs, ts)

        # Minimising output(t=0.5) - output(t=1) rewards an increase in t.
        # .sum() yields the scalar backward() requires even if each row has several outputs.
        loss = (results[1] - results[-1]).sum()
        optim.zero_grad()
        loss.backward()
        optim.step()

    results = cox_model(torch.cat([x] * 100, dim=0), torch.linspace(0, 1, 100).view(-1, 1))
    curve = results[:, 0].detach()

    plt.plot(curve.numpy())
    plt.xlabel("time index (100 points on [0, 1])")
    plt.ylabel("CoxNN output")

    diffs = curve[1:] - curve[:-1]
    # A failure means the curve goes UP somewhere, i.e. it is not non-increasing.
    assert torch.all(diffs <= 0), "CoxNN produced a curve that increases in t"


test_decreasing()