In [1]:
import os
import re
import numpy as np
import pandas as pd

from joblib import load
from tqdm.notebook import tqdm
from torch.utils import data

from sklearn.metrics import roc_auc_score, f1_score, precision_score, recall_score
from IPython.display import clear_output

import torch
from torch.utils import data

from src.Sparse_vector.sparse_vector import SparseVector
from src.data_preparation import get_train_test_dataset
from src.train_test import set_random_seed, train

In [2]:
def chrom_reader(chrom, root="z_dna/hg38_dna/"):
    """Concatenate the pickled chunks of one chromosome into a single sequence.

    Every file in ``root`` whose name contains ``f"{chrom}_"`` is loaded with
    ``joblib.load`` and the loaded pieces are joined in chunk order.

    Args:
        chrom: chromosome name such as ``"chr1"``.
        root: directory holding the chunk files (defaults to the hg38 dump).

    Returns:
        The concatenation of all loaded chunks (``"".join`` of them).
    """

    def natural_key(name):
        # re.split with a capturing group alternates text / digit-run tokens,
        # so odd positions are always digit runs and compare as integers.
        # This puts "chr1_2" before "chr1_10" (plain sorted() would not);
        # for zero-padded indices it matches lexicographic order exactly.
        parts = re.split(r"(\d+)", name)
        return [int(tok) if pos % 2 else tok for pos, tok in enumerate(parts)]

    prefix = f"{chrom}_"
    files = sorted(
        (name for name in os.listdir(root) if prefix in name), key=natural_key
    )
    # NOTE: joblib.load unpickles; only point ``root`` at trusted local data.
    return "".join(load(os.path.join(root, name)) for name in files)


# Autosomes 1-22, then X, Y and the mitochondrial chromosome.
chroms = [f"chr{i}" for i in [*range(1, 23), "X", "Y", "M"]]

# Feature names are the stems of the ".pkl" files in the sparse feature folder.
# NOTE(review): os.listdir order is unspecified. feature_names is passed to
# get_train_test_dataset and presumably fixes the feature column order the model
# sees, so an order different from training time would silently change
# predictions. Consider sorting once the training-time order is confirmed.
all_features = [
    fname[: -len(".pkl")]
    for fname in os.listdir("z_dna/hg38_features/sparse/")
    if fname.endswith(".pkl")
]
# Feature group labels (not referenced in the visible cells).
groups = ["DNase-seq", "Histone", "RNA polymerase", "TFs and others"]
feature_names = list(all_features)


In [3]:
%%time
# Full chromosome sequences keyed by name, e.g. DNA["chr1"].
DNA = dict((chrom, chrom_reader(chrom)) for chrom in tqdm(chroms))

# Z-DNA annotation track (presumably the training labels — TODO confirm).
ZDNA_data = load('z_dna/hg38_zdna/sparse/ZDNA_cousine.pkl')

# One sparse track per feature, inserted in feature_names order.
DNA_features = dict(
    (feature, load(f'z_dna/hg38_features/sparse/{feature}.pkl'))
    for feature in tqdm(feature_names)
)

  0%|          | 0/25 [00:00<?, ?it/s]

  0%|          | 0/1946 [00:00<?, ?it/s]

CPU times: user 1min 17s, sys: 6.03 s, total: 1min 23s
Wall time: 1min 23s


In [4]:
# Presumably the window length in bases; the same value sizes ImageZ later.
width = 100

# Seeds NumPy's global RNG only; `random` and torch stay unseeded here
# (set_random_seed is imported but not called in the visible cells).
np.random.seed(10)

train_dataset, test_dataset = get_train_test_dataset(
    width,
    chroms,
    feature_names,
    DNA,
    DNA_features,
    ZDNA_data,
)


100%|██████████| 2489564/2489564 [00:28<00:00, 88825.31it/s] 
100%|██████████| 2421935/2421935 [00:26<00:00, 93090.79it/s] 
100%|██████████| 1982955/1982955 [00:21<00:00, 94006.96it/s] 
100%|██████████| 1902145/1902145 [00:20<00:00, 92947.13it/s] 
100%|██████████| 1815382/1815382 [00:19<00:00, 92066.47it/s] 
100%|██████████| 1708059/1708059 [00:17<00:00, 98563.90it/s] 
100%|██████████| 1593459/1593459 [00:17<00:00, 89336.54it/s] 
100%|██████████| 1451386/1451386 [00:14<00:00, 97134.63it/s] 
100%|██████████| 1383947/1383947 [00:16<00:00, 86148.48it/s] 
100%|██████████| 1337974/1337974 [00:13<00:00, 98076.71it/s] 
100%|██████████| 1350866/1350866 [00:13<00:00, 98410.23it/s] 
100%|██████████| 1332753/1332753 [00:16<00:00, 82801.21it/s] 
100%|██████████| 1143643/1143643 [00:11<00:00, 97476.80it/s] 
100%|██████████| 1070437/1070437 [00:11<00:00, 96024.21it/s]
100%|██████████| 1019911/1019911 [00:10<00:00, 100163.54it/s]
100%|██████████| 903383/903383 [00:09<00:00, 96513.35it/s] 
100%|██████

In [5]:
params = dict(batch_size=1, num_workers=5, shuffle=True, pin_memory=True)

# Same loader settings for both splits (train built first, then test).
# NOTE(review): shuffling the test loader is not needed for evaluation.
loader_train, loader_test = (
    data.DataLoader(split, **params) for split in (train_dataset, test_dataset)
)


In [6]:
from torch import nn
import torch.nn.functional as F
from sklearn.metrics import roc_auc_score, f1_score, average_precision_score
from IPython.display import clear_output


class ImageZ(nn.Module):
    """Convolutional per-row classifier over a (width x channels) grid.

    The flat input is viewed as a single-channel image with ``width`` rows and
    ``features_count + 4`` columns (presumably the 4 extra columns encode the
    nucleotide — TODO confirm against get_train_test_dataset). A stack of
    shape-preserving Conv2d/ReLU/GroupNorm stages is followed by a dense head
    applied along the last axis, giving two log-probabilities per row.

    NOTE: ``forward`` calls ``torch.squeeze`` without a dim, which drops every
    size-1 axis. For a batch of one the batch axis disappears too, and the
    output is ``(width, 2)`` rather than ``(1, width, 2)``.
    """

    def __init__(self, width, features_count):
        super().__init__()
        self.width = width
        self.features_count = features_count

        # (in_channels, out_channels, kernel_size, GroupNorm groups) per stage.
        # Order and module count must not change: state_dict keys are
        # "seq.<index>.*", so saved checkpoints depend on this exact layout.
        conv_stages = (
            (1, 4, 3, 2),
            (4, 8, 3, 4),
            (8, 16, 3, 8),
            (16, 32, 3, 16),
            (32, 64, 3, 16),
            (64, 128, 5, 32),
            (128, 64, 3, 32),
            (64, 32, 3, 16),
            (32, 16, 3, 8),
            (16, 8, 3, 4),
            (8, 4, 3, 4),
            (4, 1, 3, 1),
        )
        layers = []
        for c_in, c_out, k, n_groups in conv_stages:
            # padding = k // 2 keeps height and width unchanged (stride 1).
            layers += [
                nn.Conv2d(c_in, c_out, kernel_size=(k, k), padding=k // 2),
                nn.ReLU(),
                nn.GroupNorm(n_groups, c_out),
            ]
        # Dense head acting on the last axis (features_count + 4 wide).
        layers += [
            nn.AlphaDropout(p=0.2),
            nn.Linear(features_count + 4, 500),
            nn.AlphaDropout(p=0.2),
            nn.SELU(),
            nn.Linear(500, 2),
        ]
        self.seq = nn.Sequential(*layers)

    def forward(self, x):
        """Return log-probabilities shaped (batch, width, 2); see class NOTE for batch == 1."""
        grid = x.reshape(x.shape[0], 1, self.width, self.features_count + 4)
        logits = torch.squeeze(self.seq(grid))
        return F.log_softmax(logits, dim=-1)


In [7]:
import gc

# Pick the compute device once and expose it as a global so later cells can
# place tensors alongside the model; falls back to CPU when CUDA is absent.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = ImageZ(width, len(feature_names))
# map_location lets a checkpoint saved from a GPU load on whichever device
# was selected above.
model.load_state_dict(
    torch.load("couzine_0.897250.pt", map_location=device, weights_only=True)
)
model = model.to(device)
model.eval()  # disables AlphaDropout for inference

gc.collect()

torch.cuda.empty_cache()


In [11]:
from captum.metrics import infidelity

In [20]:
from captum.attr import (
    IntegratedGradients,
    GradientShap,
    DeepLift,
    DeepLiftShap,
    Saliency,
    InputXGradient,
    GuidedBackprop,
    Deconvolution,
    GuidedGradCam,
    FeatureAblation,
    FeaturePermutation,
    Occlusion,
    ShapleyValueSampling,
    Lime,
    KernelShap,
)

# Attribution classes to benchmark, mapped to the label reported for each.
# Iterating this dict yields the classes in the order listed here.
ATTR_METHODS = dict(
    [
        (IntegratedGradients, "IntegratedGradients"),
        (GradientShap, "GradientShap"),
        (DeepLift, "DeepLift"),
        (Saliency, "Saliency"),
        (InputXGradient, "InputXGradient"),
        (GuidedBackprop, "GuidedBackpropagation"),
        (Deconvolution, "Deconvolution"),
    ]
)


In [21]:
from captum.metrics import infidelity
import random
from torch.utils.data import DataLoader, Subset
import statistics


def perturb_fn(inputs, *, std=0.001):
    """Gaussian perturbation function for ``captum.metrics.infidelity``.

    Noise is drawn with NumPy's global RNG (so ``np.random.seed`` controls it),
    cast to float32, and placed on the same device as ``inputs``.

    Args:
        inputs: tensor being explained.
        std: noise standard deviation. Keyword-only so Captum's positional
            ``perturb_func(inputs)`` / ``perturb_func(inputs, baselines)``
            calls can never bind to it.

    Returns:
        ``(perturbation, perturbed_inputs)`` with
        ``perturbed_inputs = inputs - perturbation``, the convention Captum's
        infidelity metric expects.
    """
    # Use the input's own device instead of a global ``device`` that this
    # notebook never defines.
    noise = (
        torch.from_numpy(np.random.normal(0, std, tuple(inputs.shape)))
        .float()
        .to(inputs.device)
    )
    return noise, inputs - noise


def get_infidelity(
    method,
    n_perturb_samples=10,
    normalize=False,
    subset_size=10000,
    n_positions=5,
):
    """Mean Captum infidelity of one attribution method over a test subset.

    Uses the notebook-level ``model`` and ``test_dataset``. The first
    ``subset_size`` test examples are visited one at a time (shuffled).

    - DeepLift / GradientShap: in each of ``n_perturb_samples`` rounds,
      ``n_positions`` random positions are drawn. For each position an
      attribution for target ``(index, 1)`` is computed and scored once. The
      position scores are averaged per round, then the rounds are averaged.
    - Other methods: one attribution for target ``1`` is computed and scored
      ``n_perturb_samples`` times. IntegratedGradients uses ``n_steps=1``.

    Each score is a ``captum.metrics.infidelity`` call with one perturbation,
    evaluated for the same target the attribution was computed for.

    Args:
        method: Captum attribution class; must be a key of ``ATTR_METHODS``.
        n_perturb_samples: rounds / repeated scores averaged per example.
        normalize: forwarded to ``captum.metrics.infidelity``.
        subset_size: number of leading test examples to use; capped at
            ``len(test_dataset)``.
        n_positions: positions sampled per round for DeepLift / GradientShap;
            capped at ``x.shape[1]``.

    Returns:
        The mean infidelity (float) over all evaluated examples.

    Raises:
        statistics.StatisticsError: if nothing was evaluated (empty subset or
            ``n_perturb_samples < 1``).
    """
    method_name = ATTR_METHODS[method]
    # Follow the model's device rather than a global ``device`` that this
    # notebook never defines.
    device = next(model.parameters()).device

    # Cap the subset so a small test set does not raise IndexError.
    n_examples = min(subset_size, len(test_dataset))
    subset = Subset(test_dataset, list(range(n_examples)))
    params = {"batch_size": 1, "num_workers": 5, "shuffle": True, "pin_memory": True}
    loader_test_mini = data.DataLoader(subset, **params)

    # The explainer only wraps the model, so one instance serves every example.
    explain = method(model)

    def _infidelity_once(x, attribution, target):
        """Single-perturbation infidelity score of ``attribution`` for ``target``."""
        # ``target`` must match the one used for the attribution; otherwise the
        # metric compares against changes in the whole model output.
        infid = infidelity(
            model,
            perturb_fn,
            x,
            attribution,
            target=target,
            n_perturb_samples=1,
            normalize=normalize,
        )
        value = infid.item()
        # Free cached GPU memory between calls (attribution is memory-heavy).
        gc.collect()
        torch.cuda.empty_cache()
        return value

    infid_list = []
    for x, _ in tqdm(loader_test_mini):
        x = x.to(device)
        if method_name in ("DeepLift", "GradientShap"):
            attr_kwargs = (
                {"baselines": torch.zeros_like(x)}
                if method_name == "GradientShap"
                else {}
            )
            k = min(n_positions, x.shape[1])
            round_means = []
            for _ in range(n_perturb_samples):
                position_scores = []
                for index in random.sample(range(x.shape[1]), k=k):
                    target = (index, 1)
                    attribution = explain.attribute(x, target=target, **attr_kwargs)
                    position_scores.append(_infidelity_once(x, attribution, target))
                round_means.append(statistics.mean(position_scores))
            infid_list.append(statistics.mean(round_means))
        else:
            # n_steps=1 is a one-point approximation of the IG path integral,
            # presumably chosen to bound memory — TODO confirm it is intended.
            attr_kwargs = (
                {"n_steps": 1} if method_name == "IntegratedGradients" else {}
            )
            attribution = explain.attribute(x, target=1, **attr_kwargs)
            infid_list.append(
                statistics.mean(
                    _infidelity_once(x, attribution, 1)
                    for _ in range(n_perturb_samples)
                )
            )

    return statistics.mean(infid_list)


In [2]:
answer = []
for method in tqdm(ATTR_METHODS):
    # Normalized infidelity with n_perturb_samples=4 for each method.
    inf_value = get_infidelity(method, n_perturb_samples=4, normalize=True)
    answer.append({ATTR_METHODS[method]: inf_value})