In [None]:
import os
import nbimporter

PROJECT_DIR_NAME = "survival_analysis"


def _find_project_root(start, name=PROJECT_DIR_NAME):
    """Return the nearest directory called ``name`` among ``start`` and its ancestors.

    Whole path components are compared, instead of splitting the path string on
    ``name``. This avoids false matches on directories such as
    ``survival_analysis_old``.

    Raises:
        FileNotFoundError: if no directory named ``name`` contains ``start``.
    """
    current = os.path.abspath(start)
    while True:
        if os.path.basename(current) == name:
            return current
        parent = os.path.dirname(current)
        if parent == current:  # reached the filesystem root without a match
            raise FileNotFoundError(
                f"No ancestor directory named {name!r} found above {start!r}; "
                "start the notebook from inside the project checkout."
            )
        current = parent


_project_dir = _find_project_root(os.getcwd())
# `root` keeps its original meaning: the directory that *contains* the project,
# with a trailing separator.
root = os.path.join(os.path.dirname(_project_dir), "")
# Work from the project directory. Presumably this is so the `nets` imports
# below resolve via the kernel's cwd-relative sys.path entry — TODO confirm.
os.chdir(_project_dir)

In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt

In [None]:
from nets.cox_nn import CoxNN, CoxTimeDependentNN
from nets.monotone_module import MonotonicIncreasingNet, MonotonicIncreasingVectorNet

# CoxNN

In [None]:
def test_run():
    """Smoke test: a freshly initialised CoxNN accepts a random batch.

    Feeds 32 random feature vectors with random times in [0, 1) through the
    model. Checks that it returns one output row per sample instead of only
    checking that the call does not raise.
    """
    batch_size, n_features = 32, 37
    cox_model = CoxNN(
        n_input_features=n_features,
        monotonic_increasing_net=MonotonicIncreasingNet(latent_sizes=[8, 4, 2]),
        t_scaling=28,
    )

    out = cox_model(torch.randn(batch_size, n_features), torch.rand(batch_size, 1))
    assert out.shape[0] == batch_size, (
        f"expected {batch_size} output rows, got shape {tuple(out.shape)}"
    )


test_run()

In [None]:
def test_decreasing(n=10, plot=True, latent_sizes=(8, 4, 2), t_max=28, n_points=100, n_features=37):
    """Check that randomly initialised CoxNN models give non-increasing curves.

    For each of ``n`` freshly initialised models, one random covariate vector is
    evaluated on a uniform time grid over [0, t_max]. The model output must
    never increase from one time point to the next.

    Args:
        n: number of independently initialised models to check.
        plot: if True, show each model's curve.
        latent_sizes: hidden sizes for MonotonicIncreasingNet. It is a tuple to
            avoid a shared mutable default; each net receives its own list copy.
        t_max: time horizon, used both as ``t_scaling`` and as the grid end.
        n_points: number of points on the time grid.
        n_features: number of input covariates.
    """
    for trial in range(n):
        cox_model = CoxNN(
            n_input_features=n_features,
            monotonic_increasing_net=MonotonicIncreasingNet(latent_sizes=list(latent_sizes)),
            t_scaling=t_max,
        )

        ts = torch.linspace(0, 1, n_points).view(-1, 1) * t_max
        # Repeat one covariate vector at every time point, so only t varies.
        xs = torch.repeat_interleave(torch.randn(1, n_features), n_points, dim=0)
        assert xs.shape == (n_points, n_features)

        zs = cox_model(xs=xs, ts=ts)

        if plot:
            fig, ax = plt.subplots()
            ax.plot(ts.flatten().detach().numpy(), zs.flatten().detach().numpy())
            ax.set(xlabel="t", ylabel="model output", title=f"CoxNN, trial {trial}")
            plt.show()

        diffs = zs[1:] - zs[:-1]
        assert torch.all(diffs <= 0), f"CoxNN produced an increasing step in trial {trial}"


test_decreasing(n=10, plot=False, latent_sizes=[8, 4, 2])

In [None]:
def test_decreasing(n_steps=100):
    """Adversarially train CoxNN to increase over time, then check it cannot.

    For one fixed covariate vector, Adam minimises out(t=0.5) - out(t=1).
    This rewards an increase between those times. The test then asserts that
    the output is still non-increasing on a 100-point grid over [0, 1].

    Args:
        n_steps: number of adversarial optimisation steps.
    """
    cox_model = CoxNN(
        n_input_features=37,
        monotonic_increasing_net=MonotonicIncreasingNet(latent_sizes=[8, 4, 2]),
        t_scaling=1,
    )
    optimizer = torch.optim.Adam(cox_model.parameters())

    x = torch.randn(1, 37)
    xs = torch.cat([x] * 3, dim=0)
    ts = torch.linspace(0, 1, 3).view(-1, 1)  # t = 0, 0.5, 1

    for _ in range(n_steps):
        results = cox_model(xs, ts)
        # Scalar loss (row t=0.5 minus row t=1) rewards an increase over time.
        loss = results[1, 0] - results[-1, 0]
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    t_grid = torch.linspace(0, 1, 100).view(-1, 1)
    results = cox_model(torch.cat([x] * 100, dim=0), t_grid)
    curve = results[:, 0].detach()

    fig, ax = plt.subplots()
    ax.plot(t_grid.flatten().numpy(), curve.numpy())
    ax.set(xlabel="t", ylabel="model output", title="CoxNN after adversarial training")
    plt.show()

    diffs = curve[1:] - curve[:-1]
    assert torch.all(diffs <= 0), "CoxNN produced non-decreasing curve"


test_decreasing()

# CoxTimeDependentNN

In [None]:
def test_run(n_features=37):
    """Smoke test for CoxTimeDependentNN on random inputs.

    1. A batch of 32 distinct random covariate vectors at increasing times must
       produce a (32, 1) output.
    2. One covariate vector repeated over a 100-point time grid on [0, 1] must
       produce a (100, 1) output that never increases over time.

    Args:
        n_features: number of input covariates. This is also used as the last
            latent size of the coefficient net.
    """
    cox_model = CoxTimeDependentNN(
        n_input_features=n_features,
        monotonic_increasing_net_baseline=MonotonicIncreasingNet(latent_sizes=[8, 4, 2]),
        # Last latent size is kept equal to n_features; presumably one
        # coefficient per covariate — TODO confirm.
        monotonic_increasing_net_coefficients=MonotonicIncreasingVectorNet(latent_sizes=[8, 4, n_features]),
        t_scaling=1,
    )

    batch_results = cox_model(torch.randn(32, n_features), torch.linspace(0, 1, 32).view(-1, 1))
    assert batch_results.shape == (32, 1)

    results = cox_model(torch.cat([torch.randn(1, n_features)] * 100, dim=0), torch.linspace(0, 1, 100).view(-1, 1))
    assert results.shape == (100, 1)
    diffs = results[1:, 0] - results[:-1, 0]
    assert torch.all(diffs <= 0), "CoxTimeDependentNN produced non-decreasing curve"


# Repeat with fresh random initialisations to catch init-dependent failures.
for _ in range(100):
    test_run()

In [None]:
def test_decreasing():
    """Adversarial monotonicity check for a 4-feature CoxTimeDependentNN.

    Runs 100 Adam steps that minimise out(t=0.5) - out(t=1) for one fixed
    covariate vector. It then plots the model output for those covariates on a
    100-point grid over [0, 1] and asserts that the curve never rises.
    """
    baseline_net = MonotonicIncreasingNet(latent_sizes=[8, 4, 2])
    coefficient_net = MonotonicIncreasingVectorNet(latent_sizes=[8, 4, 4])
    cox_model = CoxTimeDependentNN(
        n_input_features=4,
        monotonic_increasing_net_baseline=baseline_net,
        monotonic_increasing_net_coefficients=coefficient_net,
        t_scaling=1,
    )
    adam = torch.optim.Adam(cox_model.parameters())

    covariates = torch.randn(1, 4)
    train_x = covariates.repeat(3, 1)
    train_t = torch.linspace(0, 1, 3).unsqueeze(1)

    for _step in range(100):
        out = cox_model(train_x, train_t)
        objective = out[1] - out[-1]
        adam.zero_grad()
        objective.backward()
        adam.step()

    eval_x = covariates.repeat(100, 1)
    eval_t = torch.linspace(0, 1, 100).unsqueeze(1)
    curve = cox_model(eval_x, eval_t)[:, 0]

    plt.plot(curve.detach().numpy())
    plt.show()

    steps = curve[1:] - curve[:-1]
    assert torch.all(steps <= 0), "CoxTimeDependentNN produced non-decreasing curve"


test_decreasing()

In [None]:
def test_decreasing():
    """Fit CoxTimeDependentNN to a target value and check monotonicity survives.

    For one fixed covariate vector, runs 100 Adam steps that pull the output at
    t=0.5 towards 0.5 (squared error). It then plots the output on a 100-point
    grid over [0, 1] and asserts that the curve never rises.
    """
    cox_model = CoxTimeDependentNN(
        n_input_features=4,
        monotonic_increasing_net_baseline=MonotonicIncreasingNet(latent_sizes=[8, 4, 2]),
        monotonic_increasing_net_coefficients=MonotonicIncreasingVectorNet(latent_sizes=[8, 4, 4]),
        t_scaling=1,
    )
    adam = torch.optim.Adam(cox_model.parameters())

    x_row = torch.randn(1, 4)
    t_train = torch.linspace(0, 1, 3).unsqueeze(1)
    x_train = x_row.repeat(3, 1)

    for _step in range(100):
        predicted = cox_model(x_train, t_train)
        squared_error = (0.5 - predicted[1]) ** 2
        adam.zero_grad()
        squared_error.backward()
        adam.step()

    grid_t = torch.linspace(0, 1, 100).unsqueeze(1)
    survival_curve = cox_model(x_row.repeat(100, 1), grid_t)[:, 0]

    plt.plot(survival_curve.detach().numpy())
    plt.show()

    increments = survival_curve[1:] - survival_curve[:-1]
    assert torch.all(increments <= 0), "CoxTimeDependentNN produced non-decreasing curve"


test_decreasing()