In [1]:
# Move the kernel's working directory up one level. The cells below build
# every path from os.getcwd(), so this should run exactly once per kernel
# session: the target is relative, and re-running it climbs another level.
%cd ..

/home/js228/ARUNA


In [2]:
import os
import yaml
from pathlib import Path
from datetime import datetime

In [3]:
from aruna.data_utils import get_mslice_dataset
from torch.utils.data import DataLoader
import torch
import torch.nn as nn
import torch.optim as optim
from aruna.models import DCAE_MSLICE
from aruna.model_utils import get_peObj
from aruna.model_engine import train_step

In [4]:
# Every path in this notebook is anchored on the current working directory.
CWD = os.getcwd()

# Second-resolution timestamp (YYYYmmdd_HHMMSS) used to name this run's checkpoint.
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")

# Where the trained weights are written: <CWD>/checkpoints/trained_model_<timestamp>.pth
checkpoint_name = f"trained_model_{timestamp}.pth"
model_fpath = os.path.join(CWD, "checkpoints", checkpoint_name)

# YAML file with the run configuration.
config_fpath = os.path.join(CWD, "configs", "example_config.yaml")

# Samples to impute: each sub-directory of data/gtex_subset counts as one
# sample (plain files are ignored); names are kept in sorted order.
base_dir = Path(CWD).joinpath("data", "gtex_subset")
sample_names = [entry.name for entry in base_dir.iterdir() if entry.is_dir()]
sample_names.sort()

In [5]:
# load config
# Read the YAML run configuration; safe_load is used rather than the
# unrestricted loader.
with open(config_fpath) as config_file:
    config = yaml.safe_load(config_file)

# Convenience handles for two settings read from the config.
# NOTE: `chr` shadows Python's built-in chr(); the name is kept because a
# later cell passes it on as the chromosome argument.
chr = config["data"]["chrom"]
batch_dim = config["model"]["batch_dim"]

In [6]:
# Build the training dataset for the selected samples from the loaded config.
trainData_obj = get_mslice_dataset(config, samples=sample_names, mode="train")

# Wrap the dataset in a shuffling loader that uses 4 worker processes;
# the batch size comes from the model section of the config.
trainloader = DataLoader(
    dataset=trainData_obj,
    batch_size=config["model"]["batch_dim"],
    shuffle=True,
    num_workers=4,
)

# Report how many mini-batches make up one pass over the loader.
n_batches = len(trainloader)
print("#Batches in Training (batch_dim={}): {}".format(config["model"]["batch_dim"], n_batches))

Loading Data...
Getting Patch Data for:
Dataset: gtex
Chr(s): chr21
Patch Type: mpatch
NR: mcar_90
#Samples: 16
Current chromosome:  chr21
Looking for Ground-Truth Patchified FM files...
/home/js228/ARUNA/data/gtex/patch_centric/numCpg128/true/FractionalMethylation/chr21_patches.fm.pkl
Patchified data found at: /home/js228/ARUNA/data/gtex/patch_centric/numCpg128/true/FractionalMethylation/chr21_patches.fm.pkl
Curent chromosome:  chr21
Looking for Patchified Noise-Simulated FractionalMethylation and Coverage files...
All files exist!
FM at: /home/js228/ARUNA/data/gtex/patch_centric/numCpg128/mcar_90/FractionalMethylation/chr21_patches.mask.fm.pkl
MASK at: /home/js228/ARUNA/data/gtex/patch_centric/numCpg128/mcar_90/SimulatedMask/chr21_patches.mask.pkl
Loaded Data with 16 samples and 577920 patches.
#Batches in Training (batch_dim=256): 2258


In [7]:
print("Initializing model based on config...")

# Validate the configuration up front, so an unsupported setting fails before
# the model is constructed and moved to the device.
assert config["model"]["posn_embed"] == 'type1_concat', "Alternate PE is currently not supported"
criterion = config["model"]["criterion"]
assert criterion == "mse", "Alternate loss functions are currently not supported"

model = DCAE_MSLICE(config = config["model"])

# Only the part before the first underscore ("type1" for "type1_concat")
# is passed on as the positional-encoding type.
pe_type = config["model"]["posn_embed"].split("_")[0]

# No embedding dimension is supplied for this positional-encoding type.
embed_dim = None
pe_obj = get_peObj(pe_type = pe_type,
                   num_cpg = config["data"]["num_cpgs"],
                   embed_dim = embed_dim,
                   chrom = chr)
# Attach the positional-encoding object to the training dataset.
trainData_obj.pe_obj = pe_obj

device = config["model"]["device"]
num_epochs = config["model"]["num_epochs"]
model = model.to(device)
# NOTE(review): the message says "GPU" regardless of the configured device;
# the model actually lives on whatever config["model"]["device"] names.
print("Model initialized and loaded onto GPU!")

# MSE is the only supported criterion (checked above).
loss_fn = nn.MSELoss()

Initializing model based on config...
Model initialized and loaded onto GPU!


In [8]:
# Adam over all model parameters; the learning rate and the L2 penalty
# (passed as weight_decay) are read from the model section of the config.
optimizer = optim.Adam(
    params=model.parameters(),
    lr=config["model"]["learning_rate"],
    weight_decay=config["model"]["l2_penalty"],
)

In [None]:
# Create the checkpoint directory before training starts: torch.save does not
# create missing parent folders, and failing only after the run would lose
# the trained weights.
os.makedirs(os.path.dirname(model_fpath), exist_ok=True)

# Short demonstration run; replace with num_epochs for full training.
epochs_to_run = 3

for epoch in range(epochs_to_run):
    # The denominator shown is the configured total (num_epochs), not the
    # number of epochs this demo loop actually runs.
    print("Epoch {}/{}".format(epoch+1, num_epochs))

    # Train 1 Epoch; the value returned by train_step is reported as the epoch loss.
    epoch_trainLoss = train_step(model, trainloader,
                                 loss_fn, optimizer, device)
    print("Epoch Train {}/{} Loss: {}".format(epoch+1, num_epochs, epoch_trainLoss))

# Persist only the model weights (state_dict) to the timestamped checkpoint path.
torch.save(model.state_dict(), model_fpath)
print("Training Complete")

Epoch 1/100
0/2258 batches complete
800/2258 batches complete
1600/2258 batches complete
2258/2258 batches complete
Epoch Train 1/100 Loss: 0.05905976460409751
Epoch 2/100
0/2258 batches complete
800/2258 batches complete
1600/2258 batches complete
2258/2258 batches complete
Epoch Train 2/100 Loss: 0.05305686353171747
Epoch 3/100
0/2258 batches complete
800/2258 batches complete
1600/2258 batches complete
2258/2258 batches complete
Epoch Train 3/100 Loss: 0.04941611934161651
Training Complete
