In [2]:

import torch
from nn.rrdbunet import UNetRRDB2
from utils.dataloader import DatasetConfig, CC_Dataset, NumpyDataset
from nn.trainer import Trainer
from nn.hooks import EarlyStopping
from nn.losses import CombinedLoss
from utils.transforms import norm_ccmat
import logging

In [6]:
# There are two dataset options: pre-converted NumPy (.npz) data, or reading cooler
# files directly (presumably via CC_Dataset). NumPy is more straightforward but needs
# much more storage, whereas cooler data is loaded on demand.
# Sample data can be downloaded (see the README). To generate your own data from
# cooler files, use ccut/data_prep/convert_and_normalize_v3.py.

# Location of the input files; change this one constant to point at your own data.
DATA_DIR = "../data"
# Bin resolution; matches the "50k" in the file names below.
RESOLUTION = 50_000

df_numpy = NumpyDataset(
    # Sample list based on genomic coordinates; can be generated with
    # ccut/data_prep/create_sliding_window_coor.v2.py.
    # (The keyword spelling "cordinates" is NumpyDataset's own API.)
    sample_cordinates_file=f"{DATA_DIR}/chr19-22_40x40x50k_nozeroes.csv",
    # Pair of cc data at high and low resolution (target, input).
    highres_path=f"{DATA_DIR}/porec-perc999-hr_cutoff-norm-chr19-22-50k-chromosomes.npz",
    lowres_path=f"{DATA_DIR}/porec-4x-perc999-hr_cutoff-norm-chr19-22-50k-chromosomes.npz",
    resolution=RESOLUTION,
)

In [11]:

# Set up the model. Seed torch's global RNG first so that weight initialisation
# (and any later draws from the global RNG) is reproducible across kernel restarts.
SEED = 42
torch.manual_seed(SEED)
# features: presumably the channel width of each UNet level — confirm in nn/rrdbunet.py
unet = UNetRRDB2(in_channels=1, out_channels=1, features=[64, 128, 256, 512, 1024])

In [10]:
# Wrap the dataset in a DataLoader (works with CC_Dataset as well as NumpyDataset).
# A dedicated, seeded generator makes the shuffle order reproducible, independent of
# whatever else has consumed torch's global RNG before this cell runs.
train_loader = torch.utils.data.DataLoader(
    df_numpy,
    batch_size=4,
    num_workers=0,
    shuffle=True,
    generator=torch.Generator().manual_seed(42),
)

In [12]:

# Hooks are passed as classes, not instances; Trainer presumably instantiates them
# itself using `hook_args` (TODO confirm against nn/trainer.py).
hook_classes = [EarlyStopping]

# Arguments for the hooks; left empty here (assumed to mean "use hook defaults" — confirm).
hook_args = {}

# Loss function; a custom loss can be used if it derives from the abstract base class
# in ccut/nn/losses.py.
loss = CombinedLoss(window_size=15)

# Gradient scaling / mixed precision only work on CUDA. On a CPU-only machine
# torch.cuda.amp.GradScaler() warns and disables itself, so make that explicit and keep
# the mixed-precision flag consistent with the scaler.
use_amp = torch.cuda.is_available()

trainer = Trainer(
    model=unet,
    loss_function=loss,
    # gradient scaler for mixed-precision training (disabled without CUDA)
    grad_scaler=torch.cuda.amp.GradScaler(enabled=use_amp),
    # hook classes and their arguments
    hooks=hook_classes,
    hook_args=hook_args,
    # optimizer class and its keyword arguments.
    # NOTE: beta1=0.0 disables Adam's first-moment (momentum) running average, and
    # weight_decay in torch.optim.Adam is coupled L2 regularisation
    # (torch.optim.AdamW would decouple it).
    optim_class=torch.optim.Adam,
    optim_params={"lr": 5e-5, "betas": (0.0, 0.9), "weight_decay": 1e-4},
    # mixed precision only where the scaler is actually enabled
    mixed_precision=use_amp,
)



In [15]:
# Configure logging so the trainer's INFO progress messages are displayed.
# `force=True` (Python 3.8+) replaces any handlers already attached to the root logger;
# without it basicConfig silently does nothing when the root logger is already
# configured (e.g. by the notebook host or an earlier basicConfig call).
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(levelname)s - %(message)s",
    force=True,
)

# start training for NUM_EPOCHS epochs
NUM_EPOCHS = 10
trainer.train(train_loader, NUM_EPOCHS)

2024-06-08 14:01:04,738 - INFO - Starting training
2024-06-08 14:01:07,633 - INFO - 0.644659
2024-06-08 14:01:09,968 - INFO - 0.357500
2024-06-08 14:01:12,188 - INFO - 0.246894
2024-06-08 14:01:14,377 - INFO - 0.157305
2024-06-08 14:01:16,525 - INFO - 0.125291


In [None]:
# TODO: save/export the trained model and add an evaluation step.