In [None]:
# Run CPU-only, GPU code needs further testing
"""
Restart kernel after running
Only need to run once
"""
!pip install scikit-learn matplotlib seaborn

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.colors as colors
import seaborn as sns
import torch
import torch.utils.data as td
import torch.nn as nn
import torch.nn.init as init
import torch.nn.functional as F
import torch.optim as optim
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.utils.data import DataLoader, Dataset, Subset
import wandb


froot = './data/k562_samp_epft_norm_test_1.csv'

df = pd.read_csv(froot)

In [None]:
wandb.login()

In [None]:
print(df.head())

In [None]:
column_names = np.array(df.columns)
feature_names = column_names[6:-2]
num_features = len(feature_names)
#nucleotides = column_names[-6:-2]
print(feature_names)
#print(nucleotides)
num_samples = df.shape[0]

# process read counts
X_ji = df['score'].values

# process GLM simulated elongation rates
Z_ji = df['zeta'].values

In [None]:
print("Number of Samples: " + str(num_samples))

In [None]:
print("Number of Features: " + str(num_features))

In [None]:
#Y_ji is a list of samples containing lists of their feature values
    # [   
    #   sample_1: [feat_1, feat_2,...,feat_n],
    #   sample_2: [feat_1, feat_2,...,feat_n],
    # ]

Y_ji = df.iloc[:, 6:-2].values
Y_ji_shape = Y_ji.shape
print(Y_ji.shape)

In [None]:
C_j = df['lambda_alphaj'].values

In [None]:
gene_ids = df['ensembl_gene_id'].values

In [66]:
sweep_config = {
    'method': 'grid'
}
metric = {
    'name': 'loss',
    'goal': 'minimize'   
    }

sweep_config['metric'] = metric

parameters_dict = {
    'optimizer': {
        'values': ['adam', 'sgd']
        },
    'momentum': {
        'values': [0, 0.8]
    },
    'learn_rate': {
          'values': [1e-2, 1e-4, 1e-6, 1e-8]
        },
    'batch_size': {
        'values': [16000000, 2000, 500]
    },
    }

parameters_dict.update({
    'epochs': {
        'value': 3}
    })

sweep_config['parameters'] = parameters_dict

In [67]:
sweep_id = wandb.sweep(sweep_config, project="elongation-net")

Create sweep with ID: u5th2vz9
Sweep URL: https://wandb.ai/elongation-net/elongation-net/sweeps/u5th2vz9


In [None]:
cuda_available = torch.cuda.is_available()
print("CUDA (GPU support) is available:", cuda_available)
num_gpus = torch.cuda.device_count()
print("Number of GPUs available:", num_gpus)

In [None]:
class GeneDataset(Dataset):
    def __init__(self, batches):
        self.batches = batches
        self.cache = {}

    def __len__(self):
        return len(self.batches)

    def __getitem__(self, idx):
        if idx in self.cache:
            return self.cache[idx]
        
        batch = self.batches[idx]
        gene_id = batch['GeneId'].values[0]
                
        y_ji_array = np.array(batch['Y_ji'].tolist()).reshape(-1, 12)
        y_ji_tensor = torch.tensor(y_ji_array, dtype=torch.float64)
            
        data = batch.drop(columns=['GeneId', 'dataset', 'Y_ji'])
        tensor_data = torch.tensor(data.values, dtype=torch.float64)
    
        result = {
            'GeneId': batch['GeneId'].values[0],
            'Y_ji': y_ji_tensor,
        }
        for col in data.columns:
            result[col] = tensor_data[:, data.columns.get_loc(col)]
            
        self.cache[idx] = result

        return result

In [None]:
from sklearn.model_selection import train_test_split

data = pd.DataFrame({
    'GeneId': gene_ids,
    'Y_ji': [row for row in Y_ji],
    'X_ji': X_ji,
    'C_j': C_j,
    'Z_ji': Z_ji
})

grouped = data.groupby('GeneId')

train_idx, temp_idx = train_test_split(list(grouped.groups.keys()), test_size=0.2, random_state=42)

val_idx, test_idx = train_test_split(temp_idx, test_size=0.5, random_state=42)

dataset_mapping = {gene_id: 'train' for gene_id in train_idx}
dataset_mapping.update({gene_id: 'val' for gene_id in val_idx})
dataset_mapping.update({gene_id: 'test' for gene_id in test_idx})

data['dataset'] = data['GeneId'].map(dataset_mapping)

train_data = data[data['dataset'] == 'train']
val_data = data[data['dataset'] == 'val']
test_data = data[data['dataset'] == 'test']

In [None]:
print("train data size: " + str(len(train_data)))
print("val data size: " + str(len(val_data)))
print("test data size: " + str(len(test_data)))

In [None]:
print("train # genes: " + str(len(train_data.groupby('GeneId'))))
print("val # genes: " + str(len(val_data.groupby('GeneId'))))
print("test # genes: " + str(len(test_data.groupby('GeneId'))))

In [None]:
def create_batches(grouped_data, max_batch_size):
    batches = []
    for _, group in grouped_data:
        if len(group) <= 2000:
            continue
        # Check if the group size exceeds the max_batch_size
        if len(group) > max_batch_size:
            # Split the group into smaller batches
            for start_idx in range(0, len(group), max_batch_size):
                end_idx = start_idx + max_batch_size
                batch = group.iloc[start_idx:end_idx]
                batches.append(batch)
        else:
            # If the group size is within the limit, add it as is
            batches.append(group)
    return batches


train_batches = create_batches(train_data.groupby('GeneId'), 64)
val_batches = create_batches(val_data.groupby('GeneId'), 64)
test_batches = create_batches(test_data.groupby('GeneId'), 64)

In [None]:
def build_dataset(train_data, batch_size):
    batches = create_batches(train_data.groupby('GeneId'), batch_size)
    dataset = GeneDataset(batches)
    loader = DataLoader(dataset, batch_size=1, num_workers=7, shuffle=False, pin_memory=True)
    return loader

In [None]:
train_dataset = GeneDataset(train_batches)
val_dataset = GeneDataset(val_batches)
test_dataset = GeneDataset(test_batches)

In [None]:
train_loader = DataLoader(train_dataset, batch_size=1)
val_loader = DataLoader(val_dataset, batch_size=1)
test_loader = DataLoader(test_dataset, batch_size=1)

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = nn.Linear(num_features, 1, bias=False)

if cuda_available:
    """
    if num_gpus > 1:
        print("Using", num_gpus, "GPUs")
        model = torch.nn.DataParallel(model)
    """
    model = model.to('cuda')

print(model)

arr = torch.randn((64,num_features)).to(device)
print(model(arr).shape)
nparm = sum(p.numel() for p in model.parameters() if p.requires_grad)
print("Number of parameters: " + str(nparm))

first_param_device = next(model.parameters()).device
print("Model is on device:", first_param_device)
model.double()

In [None]:
class CustomLoss(nn.Module):
    def __init__(self):
        super(CustomLoss, self).__init__()

    def forward(self, X_ji, C_j, rho_ji):
        C_j_value = C_j[0]
        X_ji = X_ji.squeeze(0)
        rho_ji = rho_ji.squeeze(0).squeeze(1)
        loss = X_ji * rho_ji + C_j * torch.exp(-rho_ji) - X_ji * torch.log(C_j_value)
        return (loss).mean()

In [None]:
def build_optimizer(network, optimizer, learning_rate, momentum):
    if optimizer == "sgd":
        optimizer = optim.SGD(network.parameters(),
                              lr=learning_rate, momentum=momentum)
    elif optimizer == "adam":
        optimizer = optim.Adam(network.parameters(),
                               lr=learning_rate)
    return optimizer

In [None]:
def train_epoch(model, loader, optimizer, loss_fn):
    model.train()
    print(len(loader))
    total_loss = 0
    for idx, batch in enumerate(loader):
        optimizer.zero_grad()
        Y_ji_batch = batch['Y_ji'].to(device)
        X_ji_batch = batch['X_ji'].to(device)
        C_j_batch = batch['C_j'].to(device)
        outputs = model(Y_ji_batch)
        loss = loss_fn(X_ji_batch, C_j_batch, outputs)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    avg_loss = total_loss / len(loader)
    return avg_loss

In [68]:
def train(config=None):
    with wandb.init(config=config):
        config = wandb.config
        model = nn.Linear(num_features, 1, bias=False)
        if cuda_available:
            model = model.to('cuda')
        model.double()
        loader = build_dataset(train_data, config.batch_size)
        optimizer = build_optimizer(model, config.optimizer, config.learn_rate, 0.8)
        loss_fn = CustomLoss()
        
        for epoch in range(3):
            print(f'Epoch {epoch+1}')
            avg_loss = train_epoch(model, loader, optimizer, loss_fn)
            print(avg_loss)
            wandb.log({"epoch": epoch, "loss": avg_loss})

In [None]:
wandb.agent(sweep_id, train)

[34m[1mwandb[0m: Agent Starting Run: d8xr54b8 with config:
[34m[1mwandb[0m: 	batch_size: 16000000
[34m[1mwandb[0m: 	epochs: 3
[34m[1mwandb[0m: 	learn_rate: 0.01
[34m[1mwandb[0m: 	momentum: 0
[34m[1mwandb[0m: 	optimizer: adam


Epoch 1
366




0.1308530953554396
Epoch 2
366




0.129544826952785
Epoch 3
366




0.12982867496951264


VBox(children=(Label(value='0.009 MB of 0.009 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
epoch,▁▅█
loss,█▁▃

0,1
epoch,2.0
loss,0.12983


[34m[1mwandb[0m: Agent Starting Run: dtdsjo72 with config:
[34m[1mwandb[0m: 	batch_size: 16000000
[34m[1mwandb[0m: 	epochs: 3
[34m[1mwandb[0m: 	learn_rate: 0.01
[34m[1mwandb[0m: 	momentum: 0
[34m[1mwandb[0m: 	optimizer: sgd


Epoch 1
366




0.1344645417690116
Epoch 2
366




0.13000261640578883
Epoch 3
366




0.129463743761936


VBox(children=(Label(value='0.009 MB of 0.009 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
epoch,▁▅█
loss,█▂▁

0,1
epoch,2.0
loss,0.12946


[34m[1mwandb[0m: Agent Starting Run: x9skm1vw with config:
[34m[1mwandb[0m: 	batch_size: 16000000
[34m[1mwandb[0m: 	epochs: 3
[34m[1mwandb[0m: 	learn_rate: 0.01
[34m[1mwandb[0m: 	momentum: 0.8
[34m[1mwandb[0m: 	optimizer: adam


Epoch 1
366




0.13100490441004403
Epoch 2
366




0.1295475339099603
Epoch 3
366




0.12979783945807485


VBox(children=(Label(value='0.009 MB of 0.009 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
epoch,▁▅█
loss,█▁▂

0,1
epoch,2.0
loss,0.1298


[34m[1mwandb[0m: Sweep Agent: Waiting for job.
[34m[1mwandb[0m: Job received.
[34m[1mwandb[0m: Agent Starting Run: 483sayy3 with config:
[34m[1mwandb[0m: 	batch_size: 16000000
[34m[1mwandb[0m: 	epochs: 3
[34m[1mwandb[0m: 	learn_rate: 0.01
[34m[1mwandb[0m: 	momentum: 0.8
[34m[1mwandb[0m: 	optimizer: sgd


Epoch 1
366




0.13162143831482073
Epoch 2
366




0.12963441974966647
Epoch 3
366




0.1293529466489215


VBox(children=(Label(value='0.017 MB of 0.017 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
epoch,▁▅█
loss,█▂▁

0,1
epoch,2.0
loss,0.12935


[34m[1mwandb[0m: Agent Starting Run: kwsueiq8 with config:
[34m[1mwandb[0m: 	batch_size: 16000000
[34m[1mwandb[0m: 	epochs: 3
[34m[1mwandb[0m: 	learn_rate: 0.0001
[34m[1mwandb[0m: 	momentum: 0
[34m[1mwandb[0m: 	optimizer: adam


Epoch 1
366




0.13640999112201763
Epoch 2
366




0.13522531464585102
Epoch 3
366




0.13440680843127895


VBox(children=(Label(value='0.009 MB of 0.009 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
epoch,▁▅█
loss,█▄▁

0,1
epoch,2.0
loss,0.13441


[34m[1mwandb[0m: Sweep Agent: Waiting for job.
[34m[1mwandb[0m: Job received.
[34m[1mwandb[0m: Agent Starting Run: tt3m0rp3 with config:
[34m[1mwandb[0m: 	batch_size: 16000000
[34m[1mwandb[0m: 	epochs: 3
[34m[1mwandb[0m: 	learn_rate: 0.0001
[34m[1mwandb[0m: 	momentum: 0
[34m[1mwandb[0m: 	optimizer: sgd


Epoch 1
366




0.14665749055183427
Epoch 2
366




0.1457964111864675
Epoch 3
366




0.14505680743209382


VBox(children=(Label(value='0.009 MB of 0.009 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
epoch,▁▅█
loss,█▄▁

0,1
epoch,2.0
loss,0.14506


[34m[1mwandb[0m: Agent Starting Run: 9s0tbewp with config:
[34m[1mwandb[0m: 	batch_size: 16000000
[34m[1mwandb[0m: 	epochs: 3
[34m[1mwandb[0m: 	learn_rate: 0.0001
[34m[1mwandb[0m: 	momentum: 0.8
[34m[1mwandb[0m: 	optimizer: adam


Epoch 1
366




0.13879531302697168
Epoch 2
366




0.13751229630565617
Epoch 3
366




0.13656078349012823


VBox(children=(Label(value='0.009 MB of 0.009 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
epoch,▁▅█
loss,█▄▁

0,1
epoch,2.0
loss,0.13656


[34m[1mwandb[0m: Agent Starting Run: aszk3qgj with config:
[34m[1mwandb[0m: 	batch_size: 16000000
[34m[1mwandb[0m: 	epochs: 3
[34m[1mwandb[0m: 	learn_rate: 0.0001
[34m[1mwandb[0m: 	momentum: 0.8
[34m[1mwandb[0m: 	optimizer: sgd


Epoch 1
366




0.1344926343703227
Epoch 2
366




0.1343293377019568
Epoch 3
366




0.1341785501242913


VBox(children=(Label(value='0.017 MB of 0.017 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
epoch,▁▅█
loss,█▄▁

0,1
epoch,2.0
loss,0.13418


[34m[1mwandb[0m: Agent Starting Run: k5nzi7gf with config:
[34m[1mwandb[0m: 	batch_size: 16000000
[34m[1mwandb[0m: 	epochs: 3
[34m[1mwandb[0m: 	learn_rate: 1e-06
[34m[1mwandb[0m: 	momentum: 0
[34m[1mwandb[0m: 	optimizer: adam


Epoch 1
366




0.14285164707186435
Epoch 2
366




0.1428335212908466
Epoch 3
366




0.1428166066149378


VBox(children=(Label(value='0.017 MB of 0.017 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
epoch,▁▅█
loss,█▄▁

0,1
epoch,2.0
loss,0.14282


[34m[1mwandb[0m: Sweep Agent: Waiting for job.
[34m[1mwandb[0m: Job received.
[34m[1mwandb[0m: Agent Starting Run: z8de3qko with config:
[34m[1mwandb[0m: 	batch_size: 16000000
[34m[1mwandb[0m: 	epochs: 3
[34m[1mwandb[0m: 	learn_rate: 1e-06
[34m[1mwandb[0m: 	momentum: 0
[34m[1mwandb[0m: 	optimizer: sgd


Epoch 1
366




0.14957285681318877
Epoch 2
366




0.14893680006930646
Epoch 3
366




0.14837672589008727


VBox(children=(Label(value='0.009 MB of 0.009 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
epoch,▁▅█
loss,█▄▁

0,1
epoch,2.0
loss,0.14838


[34m[1mwandb[0m: Agent Starting Run: 0v0bhrdl with config:
[34m[1mwandb[0m: 	batch_size: 16000000
[34m[1mwandb[0m: 	epochs: 3
[34m[1mwandb[0m: 	learn_rate: 1e-06
[34m[1mwandb[0m: 	momentum: 0.8
[34m[1mwandb[0m: 	optimizer: adam


Epoch 1
366




0.2358471613772771
Epoch 2
366




0.23567255507138749
Epoch 3
366




0.23550181361160136


VBox(children=(Label(value='0.009 MB of 0.009 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
epoch,▁▅█
loss,█▄▁

0,1
epoch,2.0
loss,0.2355


[34m[1mwandb[0m: Agent Starting Run: v7j87ppq with config:
[34m[1mwandb[0m: 	batch_size: 16000000
[34m[1mwandb[0m: 	epochs: 3
[34m[1mwandb[0m: 	learn_rate: 1e-06
[34m[1mwandb[0m: 	momentum: 0.8
[34m[1mwandb[0m: 	optimizer: sgd


Epoch 1
366




0.14431180910303354
Epoch 2
366




0.144284573613554
Epoch 3
366




0.14425770421023087


VBox(children=(Label(value='0.009 MB of 0.009 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
epoch,▁▅█
loss,█▄▁

0,1
epoch,2.0
loss,0.14426


[34m[1mwandb[0m: Agent Starting Run: ohte1awr with config:
[34m[1mwandb[0m: 	batch_size: 16000000
[34m[1mwandb[0m: 	epochs: 3
[34m[1mwandb[0m: 	learn_rate: 1e-08
[34m[1mwandb[0m: 	momentum: 0
[34m[1mwandb[0m: 	optimizer: adam


Epoch 1
366




0.16739793239338996
Epoch 2
366




0.16739727806850788
Epoch 3
366




0.1673967057716301


VBox(children=(Label(value='0.017 MB of 0.017 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
epoch,▁▅█
loss,█▄▁

0,1
epoch,2.0
loss,0.1674


[34m[1mwandb[0m: Agent Starting Run: qpaqu5ae with config:
[34m[1mwandb[0m: 	batch_size: 16000000
[34m[1mwandb[0m: 	epochs: 3
[34m[1mwandb[0m: 	learn_rate: 1e-08
[34m[1mwandb[0m: 	momentum: 0
[34m[1mwandb[0m: 	optimizer: sgd


Epoch 1
366




nan
Epoch 2
366




nan
Epoch 3
366




nan


VBox(children=(Label(value='0.017 MB of 0.017 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
epoch,▁▅█

0,1
epoch,2.0
loss,


[34m[1mwandb[0m: Agent Starting Run: 4mgtf1f1 with config:
[34m[1mwandb[0m: 	batch_size: 16000000
[34m[1mwandb[0m: 	epochs: 3
[34m[1mwandb[0m: 	learn_rate: 1e-08
[34m[1mwandb[0m: 	momentum: 0.8
[34m[1mwandb[0m: 	optimizer: adam


Epoch 1
366




0.16062337783261704
Epoch 2
366




0.16062282435017378
Epoch 3
366




0.16062226809920355


VBox(children=(Label(value='0.017 MB of 0.017 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
epoch,▁▅█
loss,█▅▁

0,1
epoch,2.0
loss,0.16062


[34m[1mwandb[0m: Agent Starting Run: p0lclvya with config:
[34m[1mwandb[0m: 	batch_size: 16000000
[34m[1mwandb[0m: 	epochs: 3
[34m[1mwandb[0m: 	learn_rate: 1e-08
[34m[1mwandb[0m: 	momentum: 0.8
[34m[1mwandb[0m: 	optimizer: sgd


Epoch 1
366




0.1395812455497791
Epoch 2
366




0.13958121124167855
Epoch 3
366




0.13958117693073327


VBox(children=(Label(value='0.009 MB of 0.009 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
epoch,▁▅█
loss,█▅▁

0,1
epoch,2.0
loss,0.13958


[34m[1mwandb[0m: Sweep Agent: Waiting for job.
[34m[1mwandb[0m: Job received.
[34m[1mwandb[0m: Agent Starting Run: j2k3jrab with config:
[34m[1mwandb[0m: 	batch_size: 2000
[34m[1mwandb[0m: 	epochs: 3
[34m[1mwandb[0m: 	learn_rate: 0.01
[34m[1mwandb[0m: 	momentum: 0
[34m[1mwandb[0m: 	optimizer: adam


Epoch 1
6324




0.09680786089953732
Epoch 2
6324




0.09679737629619999
Epoch 3
6324




0.09679527988474662


VBox(children=(Label(value='0.017 MB of 0.017 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
epoch,▁▅█
loss,█▂▁

0,1
epoch,2.0
loss,0.0968


[34m[1mwandb[0m: Agent Starting Run: zqr7c3cb with config:
[34m[1mwandb[0m: 	batch_size: 2000
[34m[1mwandb[0m: 	epochs: 3
[34m[1mwandb[0m: 	learn_rate: 0.01
[34m[1mwandb[0m: 	momentum: 0
[34m[1mwandb[0m: 	optimizer: sgd


Epoch 1
6324




0.09660010808801732
Epoch 2
6324




0.09613972264873549
Epoch 3
6324




0.09610160250098519


VBox(children=(Label(value='0.017 MB of 0.017 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
epoch,▁▅█
loss,█▂▁

0,1
epoch,2.0
loss,0.0961


[34m[1mwandb[0m: Agent Starting Run: gxh50b3x with config:
[34m[1mwandb[0m: 	batch_size: 2000
[34m[1mwandb[0m: 	epochs: 3
[34m[1mwandb[0m: 	learn_rate: 0.01
[34m[1mwandb[0m: 	momentum: 0.8
[34m[1mwandb[0m: 	optimizer: adam


Epoch 1
6324




0.09703474863888983
Epoch 2
6324




0.09682600370697367
Epoch 3
6324




0.09700405979550021


VBox(children=(Label(value='0.009 MB of 0.009 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
epoch,▁▅█
loss,█▁▇

0,1
epoch,2.0
loss,0.097


[34m[1mwandb[0m: Agent Starting Run: 8kmhyfzn with config:
[34m[1mwandb[0m: 	batch_size: 2000
[34m[1mwandb[0m: 	epochs: 3
[34m[1mwandb[0m: 	learn_rate: 0.01
[34m[1mwandb[0m: 	momentum: 0.8
[34m[1mwandb[0m: 	optimizer: sgd


Epoch 1
6324




0.0964794132533411
Epoch 2
6324




0.09597364461159458
Epoch 3
6324




0.09596392454942228


VBox(children=(Label(value='0.009 MB of 0.009 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
epoch,▁▅█
loss,█▁▁

0,1
epoch,2.0
loss,0.09596


[34m[1mwandb[0m: Agent Starting Run: g2uwifob with config:
[34m[1mwandb[0m: 	batch_size: 2000
[34m[1mwandb[0m: 	epochs: 3
[34m[1mwandb[0m: 	learn_rate: 0.0001
[34m[1mwandb[0m: 	momentum: 0
[34m[1mwandb[0m: 	optimizer: adam


Epoch 1
6324




0.12693055006500462
Epoch 2
6324




0.1122887398195429
Epoch 3
6324




0.10444113473017641


VBox(children=(Label(value='0.017 MB of 0.017 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
epoch,▁▅█
loss,█▃▁

0,1
epoch,2.0
loss,0.10444


[34m[1mwandb[0m: Agent Starting Run: x402p8sx with config:
[34m[1mwandb[0m: 	batch_size: 2000
[34m[1mwandb[0m: 	epochs: 3
[34m[1mwandb[0m: 	learn_rate: 0.0001
[34m[1mwandb[0m: 	momentum: 0
[34m[1mwandb[0m: 	optimizer: sgd


Epoch 1
6324




0.10438687672661517
Epoch 2
6324




0.10187289659746089
Epoch 3
6324




0.10041824112763816


VBox(children=(Label(value='0.009 MB of 0.009 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
epoch,▁▅█
loss,█▄▁

0,1
epoch,2.0
loss,0.10042


[34m[1mwandb[0m: Agent Starting Run: s64to8rm with config:
[34m[1mwandb[0m: 	batch_size: 2000
[34m[1mwandb[0m: 	epochs: 3
[34m[1mwandb[0m: 	learn_rate: 0.0001
[34m[1mwandb[0m: 	momentum: 0.8
[34m[1mwandb[0m: 	optimizer: adam


Epoch 1
6324




0.10031264259508219
Epoch 2
6324




0.0975524248193476
Epoch 3
6324




0.09679877651775898


VBox(children=(Label(value='0.009 MB of 0.009 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
epoch,▁▅█
loss,█▃▁

0,1
epoch,2.0
loss,0.0968


[34m[1mwandb[0m: Agent Starting Run: d8ssobtt with config:
[34m[1mwandb[0m: 	batch_size: 2000
[34m[1mwandb[0m: 	epochs: 3
[34m[1mwandb[0m: 	learn_rate: 0.0001
[34m[1mwandb[0m: 	momentum: 0.8
[34m[1mwandb[0m: 	optimizer: sgd


Epoch 1
6324




0.10383013671524677
Epoch 2
6324




0.10257742326309553
Epoch 3
6324




0.10028696061408579


VBox(children=(Label(value='0.009 MB of 0.009 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
epoch,▁▅█
loss,█▆▁

0,1
epoch,2.0
loss,0.10029


[34m[1mwandb[0m: Agent Starting Run: k01dojzh with config:
[34m[1mwandb[0m: 	batch_size: 2000
[34m[1mwandb[0m: 	epochs: 3
[34m[1mwandb[0m: 	learn_rate: 1e-06
[34m[1mwandb[0m: 	momentum: 0
[34m[1mwandb[0m: 	optimizer: adam


Epoch 1
6324




0.10319220694740211
Epoch 2
6324




0.10311624782181747
Epoch 3
6324




0.10304444517255734


VBox(children=(Label(value='0.017 MB of 0.017 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
epoch,▁▅█
loss,█▄▁

0,1
epoch,2.0
loss,0.10304


[34m[1mwandb[0m: Agent Starting Run: qmxu1n2d with config:
[34m[1mwandb[0m: 	batch_size: 2000
[34m[1mwandb[0m: 	epochs: 3
[34m[1mwandb[0m: 	learn_rate: 1e-06
[34m[1mwandb[0m: 	momentum: 0
[34m[1mwandb[0m: 	optimizer: sgd


Epoch 1
6324




0.09836565397706454
Epoch 2
6324




0.09835020046490359
Epoch 3
6324




0.09833479029230013


VBox(children=(Label(value='0.009 MB of 0.009 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
epoch,▁▅█
loss,█▄▁

0,1
epoch,2.0
loss,0.09833


[34m[1mwandb[0m: Agent Starting Run: 98uoub89 with config:
[34m[1mwandb[0m: 	batch_size: 2000
[34m[1mwandb[0m: 	epochs: 3
[34m[1mwandb[0m: 	learn_rate: 1e-06
[34m[1mwandb[0m: 	momentum: 0.8
[34m[1mwandb[0m: 	optimizer: adam


Epoch 1
6324




0.10160963976410499
Epoch 2
6324




0.10154271977114951
Epoch 3
6324




0.10148043578799758


VBox(children=(Label(value='0.009 MB of 0.009 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
epoch,▁▅█
loss,█▄▁

0,1
epoch,2.0
loss,0.10148


[34m[1mwandb[0m: Agent Starting Run: zihsypoz with config:
[34m[1mwandb[0m: 	batch_size: 2000
[34m[1mwandb[0m: 	epochs: 3
[34m[1mwandb[0m: 	learn_rate: 1e-06
[34m[1mwandb[0m: 	momentum: 0.8
[34m[1mwandb[0m: 	optimizer: sgd


Epoch 1
6324




0.10136349040128184
Epoch 2
6324




0.10116732641125722
Epoch 3
6324




0.10099161070874495


VBox(children=(Label(value='0.009 MB of 0.009 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
epoch,▁▅█
loss,█▄▁

0,1
epoch,2.0
loss,0.10099


[34m[1mwandb[0m: Agent Starting Run: mtqrqtjn with config:
[34m[1mwandb[0m: 	batch_size: 2000
[34m[1mwandb[0m: 	epochs: 3
[34m[1mwandb[0m: 	learn_rate: 1e-08
[34m[1mwandb[0m: 	momentum: 0
[34m[1mwandb[0m: 	optimizer: adam


Epoch 1
6324




0.10160511037177884
Epoch 2
6324




0.1016044248179549
Epoch 3
6324




0.10160376746243671


VBox(children=(Label(value='0.009 MB of 0.015 MB uploaded\r'), FloatProgress(value=0.6443106719884839, max=1.0…

0,1
epoch,▁▅█
loss,█▄▁

0,1
epoch,2.0
loss,0.1016


[34m[1mwandb[0m: Agent Starting Run: 29l87nzj with config:
[34m[1mwandb[0m: 	batch_size: 2000
[34m[1mwandb[0m: 	epochs: 3
[34m[1mwandb[0m: 	learn_rate: 1e-08
[34m[1mwandb[0m: 	momentum: 0
[34m[1mwandb[0m: 	optimizer: sgd


Epoch 1
6324




0.10776691939848666
Epoch 2
6324




0.10776561888809073
Epoch 3
6324




0.10776431923035792


VBox(children=(Label(value='0.009 MB of 0.015 MB uploaded\r'), FloatProgress(value=0.64426411884039, max=1.0))…

0,1
epoch,▁▅█
loss,█▅▁

0,1
epoch,2.0
loss,0.10776


[34m[1mwandb[0m: Agent Starting Run: xgjenjn4 with config:
[34m[1mwandb[0m: 	batch_size: 2000
[34m[1mwandb[0m: 	epochs: 3
[34m[1mwandb[0m: 	learn_rate: 1e-08
[34m[1mwandb[0m: 	momentum: 0.8
[34m[1mwandb[0m: 	optimizer: adam


Epoch 1
6324




571.6491351725881
Epoch 2
6324




571.6174225375363
Epoch 3
6324




571.5856973708644


VBox(children=(Label(value='0.017 MB of 0.017 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
epoch,▁▅█
loss,█▅▁

0,1
epoch,2.0
loss,571.5857


[34m[1mwandb[0m: Agent Starting Run: g80jmocr with config:
[34m[1mwandb[0m: 	batch_size: 2000
[34m[1mwandb[0m: 	epochs: 3
[34m[1mwandb[0m: 	learn_rate: 1e-08
[34m[1mwandb[0m: 	momentum: 0.8
[34m[1mwandb[0m: 	optimizer: sgd


Epoch 1
6324




3.34694452412626
Epoch 2
6324




0.10978631551559126
Epoch 3
6324




0.10950277751306449


VBox(children=(Label(value='0.017 MB of 0.017 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
epoch,▁▅█
loss,█▁▁

0,1
epoch,2.0
loss,0.1095


[34m[1mwandb[0m: Agent Starting Run: t9sja0lh with config:
[34m[1mwandb[0m: 	batch_size: 500
[34m[1mwandb[0m: 	epochs: 3
[34m[1mwandb[0m: 	learn_rate: 0.01
[34m[1mwandb[0m: 	momentum: 0
[34m[1mwandb[0m: 	optimizer: adam


Epoch 1
25261




0.09868305315048954
Epoch 2
25261




0.10347437464927325
Epoch 3
25261




0.10027414497938493


VBox(children=(Label(value='0.009 MB of 0.009 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
epoch,▁▅█
loss,▁█▃

0,1
epoch,2.0
loss,0.10027


[34m[1mwandb[0m: Agent Starting Run: 0i2r9xfb with config:
[34m[1mwandb[0m: 	batch_size: 500
[34m[1mwandb[0m: 	epochs: 3
[34m[1mwandb[0m: 	learn_rate: 0.01
[34m[1mwandb[0m: 	momentum: 0
[34m[1mwandb[0m: 	optimizer: sgd


Epoch 1
25261




nan
Epoch 2
25261




nan
Epoch 3
25261




nan


VBox(children=(Label(value='0.009 MB of 0.009 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
epoch,▁▅█

0,1
epoch,2.0
loss,


[34m[1mwandb[0m: Agent Starting Run: x5d23jc3 with config:
[34m[1mwandb[0m: 	batch_size: 500
[34m[1mwandb[0m: 	epochs: 3
[34m[1mwandb[0m: 	learn_rate: 0.01
[34m[1mwandb[0m: 	momentum: 0.8
[34m[1mwandb[0m: 	optimizer: adam


Epoch 1
25261




0.1072528173950774
Epoch 2
25261




0.11125604206387255
Epoch 3
25261




0.10039170147961046


VBox(children=(Label(value='0.009 MB of 0.009 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
epoch,▁▅█
loss,▅█▁

0,1
epoch,2.0
loss,0.10039


[34m[1mwandb[0m: Agent Starting Run: bvvwk8qs with config:
[34m[1mwandb[0m: 	batch_size: 500
[34m[1mwandb[0m: 	epochs: 3
[34m[1mwandb[0m: 	learn_rate: 0.01
[34m[1mwandb[0m: 	momentum: 0.8
[34m[1mwandb[0m: 	optimizer: sgd


Epoch 1
25261




nan
Epoch 2
25261




nan
Epoch 3
25261




nan


VBox(children=(Label(value='0.009 MB of 0.009 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
epoch,▁▅█

0,1
epoch,2.0
loss,


[34m[1mwandb[0m: Agent Starting Run: emdqm0gt with config:
[34m[1mwandb[0m: 	batch_size: 500
[34m[1mwandb[0m: 	epochs: 3
[34m[1mwandb[0m: 	learn_rate: 0.0001
[34m[1mwandb[0m: 	momentum: 0
[34m[1mwandb[0m: 	optimizer: adam


Epoch 1
25261




0.09644668020781894
Epoch 2
25261




0.09517559490360654
Epoch 3
25261




0.0951931364903424


VBox(children=(Label(value='0.009 MB of 0.009 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
epoch,▁▅█
loss,█▁▁

0,1
epoch,2.0
loss,0.09519


[34m[1mwandb[0m: Agent Starting Run: 7n6so4ld with config:
[34m[1mwandb[0m: 	batch_size: 500
[34m[1mwandb[0m: 	epochs: 3
[34m[1mwandb[0m: 	learn_rate: 0.0001
[34m[1mwandb[0m: 	momentum: 0
[34m[1mwandb[0m: 	optimizer: sgd


Epoch 1
25261




0.0988161216191748
Epoch 2
25261




0.09668285208897442
Epoch 3
25261




0.09595216712181835


VBox(children=(Label(value='0.009 MB of 0.015 MB uploaded\r'), FloatProgress(value=0.6443804788695539, max=1.0…

0,1
epoch,▁▅█
loss,█▃▁

0,1
epoch,2.0
loss,0.09595


[34m[1mwandb[0m: Agent Starting Run: tcs104ms with config:
[34m[1mwandb[0m: 	batch_size: 500
[34m[1mwandb[0m: 	epochs: 3
[34m[1mwandb[0m: 	learn_rate: 0.0001
[34m[1mwandb[0m: 	momentum: 0.8
[34m[1mwandb[0m: 	optimizer: adam


Epoch 1
25261




0.22035438136758065
Epoch 2
25261




0.13308522715019006
Epoch 3
25261




0.10688737353465926


VBox(children=(Label(value='0.009 MB of 0.009 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
epoch,▁▅█
loss,█▃▁

0,1
epoch,2.0
loss,0.10689


[34m[1mwandb[0m: Agent Starting Run: x67a6cqe with config:
[34m[1mwandb[0m: 	batch_size: 500
[34m[1mwandb[0m: 	epochs: 3
[34m[1mwandb[0m: 	learn_rate: 0.0001
[34m[1mwandb[0m: 	momentum: 0.8
[34m[1mwandb[0m: 	optimizer: sgd


Epoch 1
25261




0.1011552423210236
Epoch 2
25261




0.09996219886179793
Epoch 3
25261




0.0958034538923835


VBox(children=(Label(value='0.017 MB of 0.017 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
epoch,▁▅█
loss,█▆▁

0,1
epoch,2.0
loss,0.0958


[34m[1mwandb[0m: Agent Starting Run: zlglllar with config:
[34m[1mwandb[0m: 	batch_size: 500
[34m[1mwandb[0m: 	epochs: 3
[34m[1mwandb[0m: 	learn_rate: 1e-06
[34m[1mwandb[0m: 	momentum: 0
[34m[1mwandb[0m: 	optimizer: adam


Epoch 1
25261




0.1045283379747307
Epoch 2
25261




0.10428705398263449
Epoch 3
25261




0.1040559461468143


VBox(children=(Label(value='0.009 MB of 0.009 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
epoch,▁▅█
loss,█▄▁

0,1
epoch,2.0
loss,0.10406


[34m[1mwandb[0m: Agent Starting Run: tgqwcjg8 with config:
[34m[1mwandb[0m: 	batch_size: 500
[34m[1mwandb[0m: 	epochs: 3
[34m[1mwandb[0m: 	learn_rate: 1e-06
[34m[1mwandb[0m: 	momentum: 0
[34m[1mwandb[0m: 	optimizer: sgd


Epoch 1
25261




0.11485647724976378
Epoch 2
25261




0.11257793060434189
Epoch 3
25261




0.1112896872133655


VBox(children=(Label(value='0.009 MB of 0.009 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
epoch,▁▅█
loss,█▄▁

0,1
epoch,2.0
loss,0.11129


[34m[1mwandb[0m: Agent Starting Run: 3mkcypa7 with config:
[34m[1mwandb[0m: 	batch_size: 500
[34m[1mwandb[0m: 	epochs: 3
[34m[1mwandb[0m: 	learn_rate: 1e-06
[34m[1mwandb[0m: 	momentum: 0.8
[34m[1mwandb[0m: 	optimizer: adam


Epoch 1
25261




40.33837113098377
Epoch 2
25261




39.540178316115735
Epoch 3
25261




38.753337480875814


VBox(children=(Label(value='0.009 MB of 0.009 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
epoch,▁▅█
loss,█▄▁

0,1
epoch,2.0
loss,38.75334


[34m[1mwandb[0m: Agent Starting Run: sf9eqn04 with config:
[34m[1mwandb[0m: 	batch_size: 500
[34m[1mwandb[0m: 	epochs: 3
[34m[1mwandb[0m: 	learn_rate: 1e-06
[34m[1mwandb[0m: 	momentum: 0.8
[34m[1mwandb[0m: 	optimizer: sgd


Epoch 1
25261




0.1934208454429918
Epoch 2
25261




0.10128092702606818
Epoch 3
25261




0.10118660904973129


VBox(children=(Label(value='0.009 MB of 0.009 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
epoch,▁▅█
loss,█▁▁

0,1
epoch,2.0
loss,0.10119


[34m[1mwandb[0m: Agent Starting Run: xi6saskz with config:
[34m[1mwandb[0m: 	batch_size: 500
[34m[1mwandb[0m: 	epochs: 3
[34m[1mwandb[0m: 	learn_rate: 1e-08
[34m[1mwandb[0m: 	momentum: 0
[34m[1mwandb[0m: 	optimizer: adam


Epoch 1
25261




143.33380108264433
Epoch 2
25261




143.3150549566279
Epoch 3
25261




143.29631264763267


VBox(children=(Label(value='0.009 MB of 0.009 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
epoch,▁▅█
loss,█▄▁

0,1
epoch,2.0
loss,143.29631


[34m[1mwandb[0m: Agent Starting Run: f2a40mpd with config:
[34m[1mwandb[0m: 	batch_size: 500
[34m[1mwandb[0m: 	epochs: 3
[34m[1mwandb[0m: 	learn_rate: 1e-08
[34m[1mwandb[0m: 	momentum: 0
[34m[1mwandb[0m: 	optimizer: sgd


Epoch 1
25261




0.10143461909362178
Epoch 2
25261




0.10143416498327495
Epoch 3
25261




0.10143371091855509


VBox(children=(Label(value='0.009 MB of 0.009 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
epoch,▁▅█
loss,█▅▁

0,1
epoch,2.0
loss,0.10143


[34m[1mwandb[0m: Agent Starting Run: kx6cnc6i with config:
[34m[1mwandb[0m: 	batch_size: 500
[34m[1mwandb[0m: 	epochs: 3
[34m[1mwandb[0m: 	learn_rate: 1e-08
[34m[1mwandb[0m: 	momentum: 0.8
[34m[1mwandb[0m: 	optimizer: adam


Epoch 1
25261




0.09871511240046611
Epoch 2
25261




In [None]:
from datetime import datetime
"""
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
filename = f"models/Elongation_Linear_Model.pth"
torch.save(model.state_dict(), filename)
"""

In [None]:
model = nn.Linear(num_features, 1, bias=False)
model.load_state_dict(torch.load("models/Elongation_Linear_Model.pth"))
cuda_available = torch.cuda.is_available()
print("CUDA (GPU support) is available:", cuda_available)
num_gpus = torch.cuda.device_count()
print("Number of GPUs available:", num_gpus)
if cuda_available:
    """
    if num_gpus > 1:
        print("Using", num_gpus, "GPUs")
        model = torch.nn.DataParallel(model)
    """
    model = model.to('cuda')

first_param_device = next(model.parameters()).device
print("Model is on device:", first_param_device)
model.double()

In [None]:
weights = model.weight.data.cpu().numpy()
#bias = model.bias.data.cpu().numpy()

combined = ', '.join([f'"{s}": {f}' for s, f in zip(feature_names, weights[0])])
print(combined)

#print("bias: " + str(model.bias.data))

In [None]:
glm_kappa = [-0.0224536145637661, -0.094592589, -0.023815382, 0.030402922, -0.067234092, -0.032196914, -0.040911478, -0.018557168, -0.033545905, -0.051103287, -0.204434712, 0.015831043]

In [None]:
"""
GLM K

* ctcf: -0.02
* h3k36me3: -0.09
* h3k4me1: -0.02
* h3k79me2: +0.03
* h3k9me1: -0.06
* h3k9me3: -0.03
* h4k20me1: -0.04
* sj5: -0.02
* sj3: -0.03
* dms->stem-loop: -0.05
* rpts->low-complex: +0.01
* wgbs->DNAm: -0.2
"""

In [None]:
epochs = range(1, len(loss_hist_train) + 1)
plt.plot(epochs, loss_hist_train, label='train_loss')
plt.plot(epochs, loss_hist_valid, label='valid_loss')

plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.legend()
plt.show 

In [None]:
def plot_data(glm_zeta, net_zeta):
    indices = range(len(glm_zeta))
    
    fig, ax = plt.subplots(figsize=(10, 5))
    
    ax.scatter(indices, net_zeta, color='blue', label='Neural Net Zeta', s=10, alpha=0.5)
    ax.scatter(indices, glm_zeta, color='orange', label='GLM Zeta', s=10, alpha=0.5)
    
    ax.set_title('Neural Net vs GLM Elongation Rate')
    ax.set_xlabel('Index')
    ax.set_ylabel('Elongation Rate')
    ax.legend()
    
    plt.ylim(0.5, 1.3)

    plt.show()


In [None]:
test_batches = create_batches(test_data.groupby('GeneId'), max_batch_size=64)

test_dataset = GeneDataset(test_batches)
tstdl = DataLoader(test_dataset, batch_size=1)

In [None]:
net_zeta = []
glm_zeta = []
with torch.no_grad():
    for batch in tstdl:
        y_inputs = batch['Y_ji'].to(device)
        rho_ji = model(y_inputs)
        # convert log(Z) outputs to Z
        net_zeta.append(torch.exp(rho_ji.cpu()[0]))
        glm_zeta.append(batch['Z_ji'][0])

net_zeta = torch.cat(net_zeta, dim=0)
glm_zeta = torch.cat(glm_zeta, dim=0)
mae = F.l1_loss(net_zeta.squeeze(), glm_zeta)
mse = F.mse_loss(net_zeta.squeeze(), glm_zeta)

correlation_coefficient = np.corrcoef(glm_zeta, net_zeta.squeeze())[0, 1]
print("Correlation Coefficient:", correlation_coefficient)

print(f"Mean Absolute Error: {mae.item():.4f}")
print(f"Mean Squared Error: {mse.item():.4f}")

In [None]:
def density_plot(glm_zeta, net_zeta, gene_id):
    sns.kdeplot(x=glm_zeta, y=net_zeta, fill=True, cmap="Blues")
            
    plt.xlim([min(glm_zeta), max(glm_zeta)])
    plt.ylim([min(net_zeta), max(net_zeta)])


    plt.xlabel('GLM Elongation Rate')
    plt.ylabel('Neural Net Elongation Rate')
    plt.title(gene_id)
    plt.show()

In [None]:
def scatterplot(net_zeta, glm_zeta, gene_id):
    indices = range(len(glm_zeta))
    
    fig, ax = plt.subplots(figsize=(10, 5))
    
    min_val = min(min(net_zeta), min(glm_zeta))
    max_val = max(max(net_zeta), max(glm_zeta))

    plt.xlim(min_val, max_val)
    plt.ylim(min_val, max_val)
    
    ax.scatter(net_zeta, glm_zeta, s=5)
    
    ax.set_title(gene_id)
    ax.set_xlabel('Neural Net Zeta')
    ax.set_ylabel('GLM Zeta')
    ax.legend()

    plt.show()


In [None]:
test_batches2 = create_batches(test_data.groupby('GeneId'), max_batch_size=2000)

test_dataset2 = GeneDataset(test_batches2)
tstdl2 = DataLoader(test_dataset2, batch_size=1)

In [None]:
total_loss = 0
loss_fn = CustomLoss()
for batch in tstdl2:
    gene_id = batch['GeneId'][0]
    model.eval()
    #print("number of samples: " + str(len(batch)))

    with torch.no_grad():
        y_inputs = batch['Y_ji'].to(device)
        rho_ji = model(y_inputs)
    
    glm_zeta = batch['Z_ji'][0]
    # convert log(Z) outputs to Z
    net_zeta = torch.exp(rho_ji.cpu().squeeze())
        
    density_plot(glm_zeta, net_zeta, gene_id)
                
    plot_data(glm_zeta, net_zeta)

In [None]:
plt.figure(figsize=(10, 10))

sns.scatterplot(x=glm_kappa, y=weights[0])

for i in range(len(glm_kappa)):
    plt.text(glm_kappa[i], weights[0][i], feature_names[i], fontsize=13, ha='right', va='top')
plt.xlabel('GLM Weights')
plt.ylabel('Neural Net Weights')

max_val = max(np.max(glm_kappa), np.max(weights[0])) + 0.04
min_val = min(np.min(glm_kappa), np.min(weights[0])) - 0.04

plt.xlim(max_val, min_val)
plt.ylim(max_val, min_val)

# Show the plot
plt.show()

In [None]:
# profiling code
"""
def print_profiler_results(profiler):
    print(profiler.key_averages().table(sort_by="self_cpu_time_total", row_limit=10))

with torch.profiler.profile(
    activities=[torch.profiler.ProfilerActivity.CPU, torch.profiler.ProfilerActivity.CUDA],
    schedule=torch.profiler.schedule(wait=1, warmup=1, active=3),
    on_trace_ready=print_profiler_results,
    record_shapes=True,
    profile_memory=True
) as profiler:

    for epoch in range(num_epochs):
        print(f'Epoch {epoch+1}')
        model.train()
        trndl = DataLoader(train_set, batch_size=batch_size, num_workers=7, shuffle=False, pin_memory=True)
        for i, batch in enumerate(trndl):
            optimizer.zero_grad()
            Y_ji_batch = batch['Y_ji'].to(device)
            X_ji_batch = batch['X_ji'].to(device)
            C_j_batch = batch['C_j'].to(device)
            outputs = model(Y_ji_batch)
            loss = loss_fn(X_ji_batch, C_j_batch, outputs)
            loss.backward()
            optimizer.step()
            loss_hist_train[epoch] += loss.item()
            profiler.step()
        loss_hist_train[epoch] /= len(trndl)
        del trndl
"""

In [None]:
# sparse loss function
"""
def sparse_dense_mul(s, d):
    i = s._indices()
    v = s._values()
    dv = d[i.squeeze()]
    return torch.sparse.FloatTensor(i, v * dv, s.size())

def sparse_dense_add(s, d):
    s = s.to(d.device)

    indices = s._indices()
    if indices.dim() == 1:
        indices = indices.squeeze()
    values = s._values()

    d[indices] += values

    return d

class SparseCustomLoss(nn.Module):
    def __init__(self):
        super(SparseCustomLoss, self).__init__()

    def forward(self, X_ji, C_j, rho_ji):
        C_j_value = C_j[0]
        X_ji = X_ji.squeeze(0)
        rho_ji = rho_ji.squeeze(0).squeeze(1)
        #X_ji_sparse = X_ji.to_sparse()
        #term1 = sparse_dense_mul(X_ji_sparse, rho_ji)
        #term2 = C_j_value * torch.exp(-rho_ji)
        #loss = sparse_dense_add(term1, term2)
        loss = X_ji * rho_ji + C_j * torch.exp(-rho_ji)
        return (loss).mean()
"""