In [1]:
import shutil
import os

import torch
import torch.nn as nn
from torch.utils.data import TensorDataset, DataLoader
import torch.nn.functional as F
from functorch import combine_state_for_ensemble, vmap

import boda

import pandas as pd
import numpy as np
import csv
from scipy.stats import pearsonr
import tqdm.notebook as tqdm
import matplotlib.pyplot as plt

# Define helpers

In [2]:
def load_model(artifact_path):
    """
    Unpack a model artifact and return the model it contains, in eval mode.

    Any existing ``./artifacts`` directory is removed first, then
    ``boda.common.utils.unpack_artifact`` is called on ``artifact_path`` and
    the model is built from ``./artifacts`` with ``boda.common.utils.model_fn``.

    Args:
        artifact_path: Location of the model artifact, passed unchanged to
            ``boda.common.utils.unpack_artifact``.

    Returns:
        The loaded model, set to eval mode and moved to CUDA when at least one
        CUDA device is visible.
    """
    artifacts_dir = './artifacts'
    has_gpu = torch.cuda.device_count() >= 1

    # Drop whatever a previous call left behind before unpacking
    # (presumably unpack_artifact writes into ./artifacts - TODO confirm).
    if os.path.isdir(artifacts_dir):
        shutil.rmtree(artifacts_dir)
    boda.common.utils.unpack_artifact(artifact_path)

    loaded = boda.common.utils.model_fn(artifacts_dir)
    loaded.eval()
    if has_gpu:
        loaded.cuda()

    return loaded

class ConsistentModelPool(nn.Module):
    """
    Average the predictions of several models loaded from artifacts.

    Every path is loaded with ``load_model`` and the resulting models are
    combined with ``combine_state_for_ensemble`` so that one ``vmap`` call
    evaluates all of them on the same batch.

    Args:
        path_list (list): List of paths to model artifacts.

    Attributes:
        fmodel: Functional model returned by ``combine_state_for_ensemble``.
        params: Combined parameters returned by ``combine_state_for_ensemble``;
            mapped over dimension 0 in ``forward``.
        buffers: Combined buffers returned by ``combine_state_for_ensemble``;
            mapped over dimension 0 in ``forward``.
    """

    def __init__(self, path_list):
        """
        Load every model in ``path_list`` and combine their state.

        Args:
            path_list (list): List of paths to model artifacts.
        """
        super().__init__()

        ensemble_members = []
        for artifact in path_list:
            ensemble_members.append(load_model(artifact))

        combined = combine_state_for_ensemble(ensemble_members)
        # NOTE(review): the attribute name `buffers` coincides with
        # nn.Module.buffers(); confirm nothing calls that method on this object.
        self.fmodel, self.params, self.buffers = combined

    def forward(self, batch):
        """
        Run every pooled model on ``batch`` and average the results.

        Args:
            batch (torch.Tensor): Input data batch, passed unmapped
                (``in_dims`` entry ``None``) to every model.

        Returns:
            torch.Tensor: Mean over dimension 0 of the vmapped model outputs.
        """
        ensemble_call = vmap(self.fmodel, in_dims=(0, 0, None))
        per_model = ensemble_call(self.params, self.buffers, batch)
        return per_model.mean(dim=0)


# Set up data

In [3]:
# Encode the last 200 characters of MPRA_UPSTREAM and the first 200 characters
# of MPRA_DOWNSTREAM; unsqueeze(0) adds a leading dimension of size 1.
left_flank = boda.common.utils.dna2tensor(
    boda.common.constants.MPRA_UPSTREAM[-200:]
).unsqueeze(0)
print(f'left flank shape: {left_flank.shape}')

right_flank = boda.common.utils.dna2tensor(
    boda.common.constants.MPRA_DOWNSTREAM[:200]
).unsqueeze(0)
print(f'right flank shape: {right_flank.shape}')

flank_builder = boda.common.utils.FlankBuilder(
    left_flank=left_flank,
    right_flank=right_flank,
)

# Only move to the GPU when one is visible, using the same check as load_model,
# so this cell does not fail on a CPU-only machine.
if torch.cuda.device_count() >= 1:
    flank_builder.cuda()

flank_builder

left flank shape: torch.Size([1, 4, 200])
right flank shape: torch.Size([1, 4, 200])


FlankBuilder()

In [4]:
fn_in = 'satmut_20230921.txt'
SEQ_LEN = 200  # width every sequence tensor is left-padded to

seq_data = {}  # chromosome label -> list of (4, SEQ_LEN) tensors, stacked below
id_data  = {}  # chromosome label -> list of IDs, in the same order as seq_data

with open(fn_in, 'r') as f:
    header = f.readline()  # first line is consumed as a header and not parsed
    for line_no, line in enumerate(f, start=2):
        if not line.strip():
            # A blank line (e.g. at end of file) has no columns to index.
            continue

        lsplit = line.rstrip().split('\t')
        chrom = lsplit[9]
        ID    = lsplit[0]
        seq   = line.split()[-1]  # last whitespace-delimited token on the line

        if len(seq) > SEQ_LEN:
            # torch.zeros would otherwise fail on a negative dimension with an
            # error that does not say which record is at fault.
            raise ValueError(
                f'{fn_in}, line {line_no}: sequence for {ID} has {len(seq)} bases; '
                f'expected at most {SEQ_LEN}'
            )

        # Left-pad shorter sequences with all-zero columns up to SEQ_LEN.
        padding = torch.zeros(4, SEQ_LEN - len(seq))
        seq = torch.cat([padding, boda.common.utils.dna2tensor(seq)], dim=1)

        seq_data.setdefault(chrom, []).append(seq)
        id_data.setdefault(chrom, []).append(ID)

# One (n_records, 4, SEQ_LEN) tensor per chromosome.
for key in seq_data:
    seq_data[key] = torch.stack(seq_data[key])


In [5]:
# Chromosome labels seen in the tenth tab-separated column of the input file.
id_data.keys()

dict_keys(['1', '10', '11', '12', '13', '14', '15', '16', '17', '18', '19', '2', '20', '21', '22', '3', '4', '5', '6', '7', '8', '9', 'X'])

In [7]:
grab_str = 'gs://tewhey-public-data/crossval_models/models_230111/test_{}_*/model_artifacts*'

for i in range(22):
    chrom = str(i+1)
    
    if int(chrom) > 11:
        model_chrom = 23 - int(chrom)
    else:
        model_chrom = chrom
    
    model_paths = !gsutil ls {grab_str.format(model_chrom)}
    
    my_model = ConsistentModelPool(model_paths)
    
    !rm model_artifacts*tar.gz
    
    seq_dataset = TensorDataset(seq_data[chrom])
    seq_loader  = DataLoader(seq_dataset, batch_size=8, shuffle=False)

    results = []
    with torch.no_grad():
        for j, batch in enumerate(tqdm.tqdm(seq_loader)):
            prepped_seq = flank_builder( batch[0].cuda() )
            predictions = my_model( prepped_seq ) + \
                          my_model( prepped_seq.flip(dims=[1,2]) ) # Also
            predictions = predictions.div(2.)
            results.append(predictions.detach().cpu())
        
    predictions = torch.cat(results, dim=0)
    
    outputs = pd.DataFrame(predictions.numpy(), columns=['K562','HepG2', 'SKNSH'])
    outputs['ID'] = id_data[chrom]
    
    outputs.loc[:,['ID','K562','HepG2', 'SKNSH']] \
    .to_csv(f'satmut_20230921_preds_chr{chrom}.txt', sep='\t', index=False, quoting=csv.QUOTE_NONE)


Copying gs://tewhey-public-data/crossval_models/models_230111/test_1_val_10/model_artifacts__20230112_024037__325298.tar.gz...
\ [1 files][ 49.1 MiB/ 49.1 MiB]                                                
Operation completed over 1 objects/49.1 MiB.                                     
archive unpacked in ./


Loaded model from 20230112_024037 in eval mode


Copying gs://tewhey-public-data/crossval_models/models_230111/test_1_val_11/model_artifacts__20230112_003539__758935.tar.gz...
\ [1 files][ 49.8 MiB/ 49.8 MiB]                                                
Operation completed over 1 objects/49.8 MiB.                                     
archive unpacked in ./


Loaded model from 20230112_003539 in eval mode


Copying gs://tewhey-public-data/crossval_models/models_230111/test_1_val_2/model_artifacts__20230111_192235__715916.tar.gz...
| [1 files][ 49.8 MiB/ 49.8 MiB]                                                
Operation completed over 1 objects/49.8 MiB.                                     
archive unpacked in ./


Loaded model from 20230111_192235 in eval mode


Copying gs://tewhey-public-data/crossval_models/models_230111/test_1_val_3/model_artifacts__20230111_190417__993143.tar.gz...
| [1 files][ 49.2 MiB/ 49.2 MiB]                                                
Operation completed over 1 objects/49.2 MiB.                                     
archive unpacked in ./


Loaded model from 20230111_190417 in eval mode


Copying gs://tewhey-public-data/crossval_models/models_230111/test_1_val_4/model_artifacts__20230112_000934__941678.tar.gz...
| [1 files][ 49.8 MiB/ 49.8 MiB]                                                
Operation completed over 1 objects/49.8 MiB.                                     
archive unpacked in ./


Loaded model from 20230112_000934 in eval mode


Copying gs://tewhey-public-data/crossval_models/models_230111/test_1_val_5/model_artifacts__20230112_003327__287605.tar.gz...
| [1 files][ 49.4 MiB/ 49.4 MiB]                                                
Operation completed over 1 objects/49.4 MiB.                                     
archive unpacked in ./


Loaded model from 20230112_003327 in eval mode


Copying gs://tewhey-public-data/crossval_models/models_230111/test_1_val_6/model_artifacts__20230112_020038__431749.tar.gz...
/ [1 files][ 51.0 MiB/ 51.0 MiB]                                                
Operation completed over 1 objects/51.0 MiB.                                     
archive unpacked in ./


Loaded model from 20230112_020038 in eval mode


Copying gs://tewhey-public-data/crossval_models/models_230111/test_1_val_7/model_artifacts__20230112_182326__436818.tar.gz...
| [1 files][ 49.4 MiB/ 49.4 MiB]                                                
Operation completed over 1 objects/49.4 MiB.                                     
archive unpacked in ./


Loaded model from 20230112_182326 in eval mode


Copying gs://tewhey-public-data/crossval_models/models_230111/test_1_val_8/model_artifacts__20230112_014853__150994.tar.gz...
\ [1 files][ 51.3 MiB/ 51.3 MiB]                                                
Operation completed over 1 objects/51.3 MiB.                                     
archive unpacked in ./


Loaded model from 20230112_014853 in eval mode


Copying gs://tewhey-public-data/crossval_models/models_230111/test_1_val_9/model_artifacts__20230111_232551__863644.tar.gz...
\ [1 files][ 51.1 MiB/ 51.1 MiB]                                                
Operation completed over 1 objects/51.1 MiB.                                     
archive unpacked in ./


Loaded model from 20230111_232551 in eval mode


  0%|          | 0/3303 [00:00<?, ?it/s]

  return torch.max_pool1d(input, kernel_size, stride, padding, dilation, ceil_mode)


KeyboardInterrupt: 

In [25]:
# Scratch check: list the model artifacts matching the pattern for index 2.
# grab_str repeats the pattern already defined in the prediction loop cell.
grab_str = 'gs://tewhey-public-data/crossval_models/models_230111/test_{}_*/model_artifacts*'

model_list = !gsutil ls {grab_str.format(2)}

In [26]:
# Show the captured gsutil listing.
model_list

['gs://tewhey-public-data/crossval_models/models_230111/test_2_val_1/model_artifacts__20230112_041842__826676.tar.gz',
 'gs://tewhey-public-data/crossval_models/models_230111/test_2_val_10/model_artifacts__20230112_031741__363749.tar.gz',
 'gs://tewhey-public-data/crossval_models/models_230111/test_2_val_11/model_artifacts__20230112_002409__248129.tar.gz',
 'gs://tewhey-public-data/crossval_models/models_230111/test_2_val_3/model_artifacts__20230112_173424__415672.tar.gz',
 'gs://tewhey-public-data/crossval_models/models_230111/test_2_val_4/model_artifacts__20230112_014533__217034.tar.gz',
 'gs://tewhey-public-data/crossval_models/models_230111/test_2_val_5/model_artifacts__20230112_003824__879872.tar.gz',
 'gs://tewhey-public-data/crossval_models/models_230111/test_2_val_6/model_artifacts__20230111_233731__117900.tar.gz',
 'gs://tewhey-public-data/crossval_models/models_230111/test_2_val_7/model_artifacts__20230112_023910__737605.tar.gz',
 'gs://tewhey-public-data/crossval_models/mode

In [30]:
# Scratch check: formats the pattern with the leftover `chrom` value directly,
# without the 23 - n remapping used in the prediction loop.
# NOTE(review): `chrom` is whatever an earlier cell left behind (hidden state).
model_paths = !gsutil ls {grab_str.format(chrom)}
model_paths

['CommandException: One or more URLs matched no objects.']

In [31]:
# Show the exact pattern the listing above was run with.
grab_str.format(chrom)

'gs://tewhey-public-data/crossval_models/models_230111/test_18_*/model_artifacts*'

In [28]:
# Scratch check: unpack a single artifact directly, bypassing load_model.
test = boda.common.utils.unpack_artifact('gs://tewhey-public-data/crossval_models/models_230111/test_2_val_3/model_artifacts__20230112_173424__415672.tar.gz')


Copying gs://tewhey-public-data/crossval_models/models_230111/test_2_val_3/model_artifacts__20230112_173424__415672.tar.gz...
| [1 files][ 49.1 MiB/ 49.1 MiB]                                                
Operation completed over 1 objects/49.1 MiB.                                     
archive unpacked in ./


In [29]:
# Scratch check: load the same single artifact through load_model.
test = load_model('gs://tewhey-public-data/crossval_models/models_230111/test_2_val_3/model_artifacts__20230112_173424__415672.tar.gz')


Copying gs://tewhey-public-data/crossval_models/models_230111/test_2_val_3/model_artifacts__20230112_173424__415672.tar.gz...
\ [1 files][ 49.1 MiB/ 49.1 MiB]                                                
Operation completed over 1 objects/49.1 MiB.                                     


Loaded model from 20230112_173424 in eval mode


archive unpacked in ./


In [None]:
# NOTE(review): this cell relies on `seq_loader`, `my_model` and `flank_builder`
# left over from earlier cells, and repeats the inner loop of the
# per-chromosome prediction cell.
results = []

with torch.no_grad():
    for batch in tqdm.tqdm(seq_loader):
        prepped_seq = flank_builder( batch[0].cuda() )
        # Average the prediction for the sequence with the prediction for the
        # same tensor flipped along dims 1 and 2 (presumably its reverse
        # complement - TODO confirm channel order).
        fwd_preds = my_model( prepped_seq )
        rev_preds = my_model( prepped_seq.flip(dims=[1,2]) )
        results.append( (fwd_preds + rev_preds).div(2.).detach().cpu() )

predictions = torch.cat(results, dim=0)

In [21]:
# NOTE(review): `seq_tensor` is not defined in any cell visible here; this
# cell depends on leftover kernel state.
seq_tensor.shape

torch.Size([350768, 4, 200])

In [9]:
# Number of entries in seq_tensor whose second dimension is exactly 200.
sum(1 for entry in seq_tensor if entry.shape[1] == 200)

347294

In [11]:
# Number of entries in seq_tensor whose second dimension exceeds 200.
sum(1 for entry in seq_tensor if entry.shape[1] > 200)

0