In [1]:
try:
    import utils.environment
    env_opts = utils.environment.get_env()
except ImportError:
    # Not running from inside a checkout of the repository (presumably a
    # fresh Colab runtime, whose working directory is /content): fetch the
    # repository and import the project modules from there.
    import os, subprocess, sys
    assert os.getcwd() == "/content"
    repo_dir = "/content/liver"
    # After a kernel restart the cwd is back to /content but the clone may
    # already exist; `git clone` refuses a non-empty target, so only clone
    # when the directory is missing.
    if not os.path.isdir(repo_dir):
        # check=True reports a failed clone here instead of as a confusing
        # error from os.chdir below.
        subprocess.run(
            ["git", "clone", "https://github.com/yamatteo/liver.git", repo_dir],
            check=True,
        )
    os.chdir(repo_dir)
    import utils.environment
    utils.environment.reset_environment()
    env_opts = utils.environment.get_env()


import argparse
import importlib
import pickle
import shutil
from pathlib import Path

import nibabel
import numpy as np
import torch
from adabelief_pytorch import AdaBelief
from rich.progress import Progress
from torch import nn
from torch.nn import functional
from torch.utils.data import DataLoader
from rich import print

import newmodel
import report
import utils.ndarray as nd
import utils.path_explorer as px
from distances import liverscore, tumorscore
from newmodel.data import store_441_dataset, Dataset, BufferDataset
from newmodel.models import UNet
from newmodel.funet import FUNet
from utils.debug import dbg
from utils.namespace import Namespace
from utils.slices import slices
from utils.storage import gen_split_dataset






def trainsetup(opts: Namespace):
    """Build the model, data loaders, optimizer and loss for a training run.

    Args:
        opts: run options. Attributes read here: ``arch`` ("unet" or "funet"),
            ``device``, ``resume`` (a checkpoint file name, or a positive int
            selecting ``checkpoint{resume:03}.pth``; anything else starts a new
            model), ``models_path``, ``momentum``, ``buffer_size`` (optional;
            set to 0 on ``opts`` when missing), ``dataset_path``,
            ``batch_size`` and ``lr``. The whole namespace is also forwarded as
            keyword arguments to the model constructor.

    Returns:
        ``Namespace(opts, setup)`` where ``setup`` holds ``model``,
        ``train_dataset``, ``valid_dataset``, ``tdl``, ``vdl``, ``optimizer``,
        ``loss_function`` and, when ``buffer_size > 0``, ``score_function``.

    Raises:
        ValueError: if ``opts.arch`` is not a supported architecture.
    """
    self = Namespace()
    if opts.arch == "unet":
        self.model = UNet(**opts)
    elif opts.arch == "funet":
        self.model = FUNet(**opts)
    else:
        raise ValueError(f"Unknown architecture {opts.arch!r}; expected 'unet' or 'funet'.")
    self.model.to(device=opts.device)

    # Resolve which checkpoint (if any) to resume from.
    if isinstance(opts.resume, str):
        model_name = opts.resume
    elif isinstance(opts.resume, int) and opts.resume > 0:
        model_name = f"checkpoint{opts.resume:03}.pth"
    else:
        model_name = None

    if model_name:
        model_path = opts.models_path / model_name
        try:
            self.model.load_state_dict(torch.load(model_path, map_location=opts.device))
            print(f"Model loaded from {model_path}")
        except FileNotFoundError:
            # A missing checkpoint is not fatal: training starts from scratch.
            print(f"Model {model_path} does not exist. Starting with a new model.")
    else:
        print("Starting with a new model.")

    if opts.momentum:
        self.model.set_momentum(opts.momentum)

    if not hasattr(opts, "buffer_size"):
        opts.buffer_size = 0
    if opts.buffer_size > 0:
        self.train_dataset = BufferDataset(opts.dataset_path / "train", opts.buffer_size, opts.buffer_size // 10)
        dbg(len(list((opts.dataset_path / "train").iterdir())))
        dbg(len(self.train_dataset.buffer))
    else:
        self.train_dataset = Dataset(opts.dataset_path / "train")
        dbg(len(self.train_dataset.files))

    self.valid_dataset = Dataset(opts.dataset_path / "valid")
    dbg(len(self.valid_dataset.files))

    self.tdl = DataLoader(
        self.train_dataset,
        pin_memory=True,
        batch_size=opts.batch_size,
    )
    self.vdl = DataLoader(
        self.valid_dataset,
        pin_memory=True,
        batch_size=opts.batch_size,
    )

    self.optimizer = AdaBelief(
        self.model.parameters(),
        lr=opts.lr,
        eps=1e-8,
        betas=(0.9, 0.999),
        weight_decouple=False,
        rectify=False,
        print_change_log=False,
    )

    # CrossEntropyLoss class weights must be floating point; an integer tensor
    # is rejected with a dtype mismatch against float predictions.
    class_weights = torch.tensor([1.0, 2.0, 5.0])
    self.loss_function = nn.CrossEntropyLoss(class_weights).to(device=opts.device)
    if opts.buffer_size > 0:
        # Unreduced loss, so each sample in a batch can be scored separately.
        per_voxel_loss = nn.CrossEntropyLoss(class_weights, reduction="none").to(device=opts.device)

        def score_function(pred, segm, keys):
            """Return ``{keys[i].item(): mean weighted loss of sample i}``."""
            loss = per_voxel_loss(pred, segm)
            # Averages over dims 1-3, i.e. everything but the batch axis,
            # assuming a 4-D unreduced loss — TODO confirm against the model output.
            loss = torch.mean(loss, dim=[1, 2, 3])
            return {k.item(): loss[i].item() for i, k in enumerate(keys)}

        self.score_function = score_function
    return Namespace(opts, self)

In [None]:
### Setup 0A91
# FUNet whose downsampled sub-model is initialised from the "BBE9_400.pth"
# checkpoint and then frozen.
env_opts = Namespace(
    utils.environment.get_env(),
    device=torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu"),
)

run_opts = Namespace(
    batch_size=5,
    buffer_size=100,
    lr=1e-3 / 20,
    momentum = 0.9,
    downsampled_resume="BBE9_400.pth",
    setup_id = "0A91",
    store_slice_shape = (512, 512, 16),
    store_pool_kernel = (1, 1, 1),
)

net_opts = Namespace(
    downsampled_channels=[4, 48, 64, 80, 96],
    funnel_channels=[4+48, 128, 256, 512],
    funnel_size = 5,
    down_normalization="instance",
    up_dropout="drop",
)

opts = Namespace(env_opts, run_opts, net_opts)

model = FUNet(**opts)
model.to(device=opts.device)

if opts.downsampled_resume:
    checkpoint_path = opts.models_path / opts.downsampled_resume
    checkpoint_state = torch.load(checkpoint_path, map_location=opts.device)
    # The checkpoint stores the sub-model's weights under this prefix; strip it
    # so the keys match model.downsampled_model's own state dict.
    downsampled_prefix = "model.0."
    downsampled_state = {
        key[len(downsampled_prefix):]: parameter
        for key, parameter in checkpoint_state.items()
        if key.startswith(downsampled_prefix)
    }
    if not downsampled_state:
        raise RuntimeError(
            f"No parameters with prefix {downsampled_prefix!r} found in {checkpoint_path}"
        )
    model.downsampled_model.load_state_dict(downsampled_state)
    dbg("Submodel loaded from", checkpoint_path)
# Freeze the downsampled sub-model (also when no checkpoint was given).
for param in model.downsampled_model.parameters():
    param.requires_grad = False

def rebuild_dataset():
    """Recreate ``opts.dataset_path`` and compute the sub-model's per-case predictions.

    The dataset directory is wiped and recreated with empty ``train/`` and
    ``valid/`` sub-directories. For every trainable case in
    ``opts.sources_path``, the scan is processed as follows:

    - it is loaded and clipped to [-1024, 1024];
    - it is cut into non-overlapping (512, 512, 16) slices;
    - each slice is average-pooled by (4, 4, 1) and run through
      ``model.downsampled_model`` (in eval mode, without autograd);
    - the output is upsampled back by (4, 4, 1) and moved to the CPU.

    The per-slice outputs are then concatenated along dim 3.

    NOTE(review): storing the dataset items is unfinished. The previous
    version stopped with ``assert False`` after the first case, and the code
    past it referenced undefined names (``target_path``, ``gen``, ``segm``).
    This version raises ``NotImplementedError`` at the same point instead.

    Raises:
        NotImplementedError: once the first case's predictions are computed.
    """
    # Reset the dataset directory. Errors other than "does not exist"
    # propagate as-is instead of being masked by a follow-up mkdir.
    if opts.dataset_path.exists():
        shutil.rmtree(opts.dataset_path)
    opts.dataset_path.mkdir(parents=True)
    pooler = nn.AvgPool3d((4, 4, 1))
    unpooler = nn.Upsample(scale_factor=(4, 4, 1), mode='trilinear')
    print("Storing dataset:")
    print("  source =", opts.sources_path)
    print("  target =", opts.dataset_path)
    (opts.dataset_path / "train").mkdir(exist_ok=True)
    (opts.dataset_path / "valid").mkdir(exist_ok=True)

    submodel = model.downsampled_model
    was_training = submodel.training
    # Pure inference: eval() makes any dropout layers deterministic, and the
    # caller's train/eval mode is restored in the finally clause.
    submodel.eval()
    try:
        for case in utils.path_explorer.iter_trainable(opts.sources_path):
            print("  ...working on:", case)
            case_path = opts.sources_path / case
            scan = nd.load_scan_from_regs(case_path)
            scan = np.clip(scan, -1024, 1024)
            prediction_aid = []
            with torch.no_grad():
                for scan_slice in slices(scan, (512, 512, 16), overlap=False, pad_seed="scan"):
                    batch = torch.tensor(scan_slice, dtype=torch.float32, device=opts.device).unsqueeze(0)
                    aid = submodel(pooler(batch))
                    # squeeze(0) drops only the batch axis added above; a bare
                    # squeeze() would also drop any other size-1 axis.
                    prediction_aid.append(unpooler(aid).squeeze(0).cpu())
                    del batch, aid
                    # Release cached GPU memory between slices (no-op without CUDA).
                    torch.cuda.empty_cache()
            prediction_aid = torch.cat(prediction_aid, dim=3)
            dbg(prediction_aid.shape)
            # TODO: load this case's segmentation and save generate_dataset(...)
            # items as f"{i:06}.pt", sending cases with k % 10 == 0 to valid/
            # and the rest to train/ (the intent of the unfinished original).
            raise NotImplementedError(
                "rebuild_dataset: storing dataset items is not implemented yet "
                "(segmentation loading and item serialization are missing)."
            )
    finally:
        submodel.train(was_training)
    print(f"Train dataset: {len(list((opts.dataset_path / 'train').iterdir()))} items.")
    print(f"Valid dataset: {len(list((opts.dataset_path / 'valid').iterdir()))} items.")
def generate_dataset(scan, segm, pool_kernel=(4, 4, 1)):
    """Bundle a scan with the downsampled sub-model's prediction and its segmentation.

    Args:
        scan: tensor accepted by ``nn.AvgPool3d`` without a batch axis
            (presumably (C, X, Y, Z) — TODO confirm).
        segm: segmentation, stored unchanged.
        pool_kernel: average-pooling kernel applied to ``scan`` before it is
            fed to ``model.downsampled_model``. Defaults to (4, 4, 1).

    Returns:
        dict with keys "scan", "dgpred" and "segm".
    """
    dgscan = nn.AvgPool3d(pool_kernel)(scan)
    submodel = model.downsampled_model
    was_training = submodel.training
    # Pure inference for stored targets: eval() makes any dropout
    # deterministic; the caller's train/eval mode is restored afterwards.
    submodel.eval()
    try:
        with torch.no_grad():
            dgpred = submodel(dgscan.unsqueeze(0)).squeeze(0)
    finally:
        submodel.train(was_training)
    return {"scan": scan, "dgpred": dgpred, "segm": segm}

# NOTE(review): rebuild_dataset() currently raises once it processes a case
# (dataset storing is unfinished), so on a fresh run the statements below are
# not reached.
rebuild_dataset()
if not hasattr(opts, "buffer_size"):
    opts.buffer_size = 0
if opts.buffer_size > 0:
    train_dataset = BufferDataset(opts.dataset_path / "train", opts.buffer_size, opts.buffer_size // 10)
    dbg(len(list((opts.dataset_path / "train").iterdir())))
    dbg(len(train_dataset.buffer))
else:
    train_dataset = Dataset(opts.dataset_path / "train")
    dbg(len(train_dataset.files))

valid_dataset = Dataset(opts.dataset_path / "valid")
dbg(len(valid_dataset.files))

tdl = DataLoader(
    train_dataset,
    pin_memory=True,
    batch_size=opts.batch_size,
)
vdl = DataLoader(
    valid_dataset,
    pin_memory=True,
    batch_size=opts.batch_size,
)

# Only the funnel part is optimised; the downsampled sub-model is frozen.
optimizer = AdaBelief(
    model.funnel_model.parameters(),
    lr=opts.lr,
    eps=1e-8,
    betas=(0.9, 0.999),
    weight_decouple=False,
    rectify=False,
    print_change_log=False,
)

# CrossEntropyLoss class weights must be floating point; an integer tensor is
# rejected with a dtype mismatch against float predictions.
class_weights = torch.tensor([1.0, 2.0, 5.0])
# Mean-reduced (scalar) training loss, as trainsetup() builds it.
loss_function = nn.CrossEntropyLoss(class_weights).to(device=opts.device)
# Unreduced loss, so each sample in a batch can be scored separately.
per_voxel_loss = nn.CrossEntropyLoss(class_weights, reduction="none").to(device=opts.device)

def score_function(pred, segm, keys):
    """Return ``{keys[i].item(): mean weighted loss of sample i}``."""
    loss = per_voxel_loss(pred, segm)
    # Averages over dims 1-3, i.e. everything but the batch axis, assuming a
    # 4-D unreduced loss — TODO confirm against the model output.
    loss = torch.mean(loss, dim=[1, 2, 3])
    return {k.item(): loss[i].item() for i, k in enumerate(keys)}

In [None]:
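# ### Setup BBE9
# # NOTE: archived setup kept for reference. It uses `drive_mount`, which is not
# # defined in this notebook, and an older `rebuild_dataset(opts)` call signature.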
# env_opts = Namespace(
#     dataset_path = Path("/content/dataset"),
#     device=torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu"),
#     drive_folder = Path(drive_mount) / "MyDrive" / "COLAB",
#     models_path = Path(drive_mount) / "MyDrive" / "saved_models",
# )

# run_opts = Namespace(
#     batch_size=5,
#     buffer_size=100,
#     lr=1e-3 / 20,
#     momentum = 0.9,
#     resume=False,
#     setup_id = "BBE9",
#     store_slice_shape = (512, 512, 16),
#     store_pool_kernel = (4, 4, 1),
# )

# net_opts = Namespace(
#     channels=[4, 48, 64, 80, 96],
#     down_normalization="instance",
#     up_dropout="drop",
# )

# opts = Namespace(env_opts, run_opts, net_opts)

# # rebuild_dataset(opts)
# setup = trainsetup(opts)
# newmodel.train(setup, epochs=401)

In [None]:
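# ### Setup 8C87
# # NOTE: archived setup kept for reference. It uses `drive_mount`, which is not
# # defined in this notebook, and an older `rebuild_dataset(shape, opts)` call signature.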
# env_opts = Namespace(
#     dataset_path = Path("/content/dataset"),
#     device=torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu"),
#     drive_folder = Path(drive_mount) / "MyDrive" / "COLAB",
#     models_path = Path(drive_mount) / "MyDrive" / "saved_models",
# )

# run_opts = Namespace(
#     batch_size=5,
#     buffer_size=100,
#     lr=1e-3 / 20,
#     momentum = 0.9,
#     resume=False,
#     setup_id = "8C87",
# )

# net_opts = Namespace(
#     channels=[4, 32, 48, 64],
#     down_normalization="instance",
# )

# opts = Namespace(env_opts, run_opts, net_opts)

# rebuild_dataset((512, 512, 16), opts)
# setup = trainsetup(opts)
# newmodel.train(setup, epochs=201)