HiCFoundation Resolution Enhancement Pipeline
This document provides a complete pipeline for Hi-C resolution enhancement using HiCFoundation, including data preprocessing, pre-training, fine-tuning, and inference.
Table of Contents

1. Environment Setup
2. Data Preprocessing
3. Submatrix Generation
4. Pre-training
5. Fine-tuning Preparation
6. Fine-tuning
7. Inference
8. Visualization and Evaluation

Environment Setup:
Install Dependencies

Install pip: https://pip.pypa.io/en/stable/installation/


Install PyTorch: check the PyTorch site (https://pytorch.org/get-started/previous-versions/) and select the pytorch==1.8.1 build that is compatible with your CUDA version.
You can check the CUDA version of your server with 'nvidia-smi' or 'nvcc -V'. Then run the recommended installation command (the pip command is recommended) from the website in this environment.


In [None]:
# Install required Python packages
pip install easydict opencv-python simplejson lvis Pillow==9.5.0 pytorch_msssim 
pip install pandas hic-straw matplotlib scikit-image scipy einops tensorboard cooler numba pyBigWig timm==0.3.2


# Clone HiCFoundation repositories
git clone https://github.com/Noble-Lab/HiCFoundation.git
git clone https://github.com/Noble-Lab/HiCFoundation_paper.git

Create Directory Structure

In [None]:
import os

# Create necessary directories
# Working folders that the later pipeline steps read from or write to,
# all relative to the current working directory.
dirs_to_create = [
    'utils',
    'input-dirs', 'input-dirs/pre-train-dirs',
    'ft-inputs',
    'outputs',
    'models',
]

# exist_ok=True keeps this cell safe to re-run.
for folder in dirs_to_create:
    os.makedirs(folder, exist_ok=True)

Check GPU Availability

In [None]:
import torch

# Pick the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
cuda_ready = torch.cuda.is_available()
device = torch.device('cuda') if cuda_ready else torch.device('cpu')
print(f"Using device: {device}")

# Report the name and total memory of GPU 0 only when CUDA is usable.
if cuda_ready:
    gpu_name = torch.cuda.get_device_name(0)
    gpu_bytes = torch.cuda.get_device_properties(0).total_memory
    print(f"GPU: {gpu_name}")
    print(f"GPU Memory: {gpu_bytes / 1e9:.2f} GB")

Download data from here: 
https://drive.google.com/drive/folders/1D5MqwauHKRFixhRbGljSnouxWFNVfL1l?usp=sharing 

Data Preprocessing:
1. Create hic2array.py
Save this as utils/hic2array.py

In [None]:
import numpy as np
from scipy.sparse import coo_matrix
import hicstraw
import os
import pickle 

def write_pkl(data, path):
    """Pickle ``data`` into the file at ``path``, replacing any existing file."""
    with open(path, mode='wb') as sink:
        pickle.dump(data, sink)

def read_chrom_array(chr1, chr2, normalization, hic_file, resolution):
    """
    Read the observed contacts between two chromosomes into a sparse matrix.

    chr1, chr2: chromosome objects; only their ``name`` attribute is used
    normalization: str, normalization name forwarded to hicstraw.straw
    hic_file: str, path of the .hic file
    resolution: int, bin size in base pairs; bin coordinates are divided by it
    Returns a float32 coo_matrix, or None when the summed counts are zero
    (including when no records are returned). Records are stored exactly as
    returned by straw; no mirroring across the diagonal is applied.
    """
    name_a, name_b = chr1.name, chr2.name
    query = ['observed', normalization, hic_file, name_a, name_b, 'BP', resolution]
    print(query)
    records = hicstraw.straw(*query)
    print('\tlen(rets): {:3e}'.format(len(records)))

    # Convert base-pair bin starts into bin indices at this resolution.
    bin_rows = [int(rec.binX // resolution) for rec in records]
    bin_cols = [int(rec.binY // resolution) for rec in records]
    counts = [rec.counts for rec in records]

    total = sum(counts)
    print('\tsum(val): {:3e}'.format(total))
    if total == 0:
        return None

    if name_a == name_b:
        # Same chromosome: force a square matrix large enough for both axes.
        side = max(max(bin_rows), max(bin_cols)) + 1
        dims = (side, side)
    else:
        dims = (max(bin_rows) + 1, max(bin_cols) + 1)

    return coo_matrix((counts, (bin_rows, bin_cols)), shape=dims, dtype=np.float32)


def hic2array(input_hic,output_pkl=None,
              resolution=25000,normalization="NONE",
              tondarray=0):
    """
    Convert a .hic file into a dict of contact matrices, optionally pickling it.

    input_hic: str, input hic file path
    output_pkl: str or None, output pickle file path; nothing is written when None
    resolution: int, resolution of the hic file; must be one of the resolutions
        stored in the file
    normalization: str, normalization name forwarded to read_chrom_array
    tondarray: int, output mode
        0: sparse matrices, all chromosome pairs, keys "<chrom1>_<chrom2>"
        1: dense arrays, all chromosome pairs, keys "<chrom1>_<chrom2>"
        2: sparse matrices, same-chromosome pairs only, keys "<chrom>"
        3: dense arrays, same-chromosome pairs only, keys "<chrom>"

    Returns the dict that was built.
    Raises SystemExit(1) when ``resolution`` is not available in the file.
    """

    hic = hicstraw.HiCFile(input_hic)
    chrom_list=[]
    for chrom in hic.getChromosomes():
        print(chrom.name, chrom.length)
        # Skip the genome-wide pseudo-chromosome (name contains "all").
        if "all" in chrom.name.lower():
            continue
        chrom_list.append(chrom)
    resolution_list = hic.getResolutions()
    if resolution not in resolution_list:
        print("Resolution not found in the hic file, please choose from the following list:")
        print(resolution_list)
        # Exit with a non-zero status: a bare exit() reports success (status 0),
        # which hides the failure from shell pipelines chained with `&&`.
        raise SystemExit(1)

    cis_only = tondarray in [2,3]
    to_dense = tondarray in [1,3]
    output_dict={}
    for i, chrom1 in enumerate(chrom_list):
        for j in range(i,len(chrom_list)):
            if i!=j and cis_only:
                #skip inter-chromosome region
                continue

            chrom2 = chrom_list[j]
            chrom1_name = chrom1.name
            chrom2_name = chrom2.name
            # Skip unplaced / random / alternate contigs by name.
            if 'Un' in chrom1_name or 'Un' in chrom2_name:
                continue
            if "random" in chrom1_name.lower() or "random" in chrom2_name.lower():
                continue
            if "alt" in chrom1_name.lower() or "alt" in chrom2_name.lower():
                continue
            read_array=read_chrom_array(chrom1,chrom2, normalization, input_hic, resolution)
            if read_array is None:
                print("No data found for",chrom1_name,chrom2_name)
                continue
            if to_dense:
                read_array = read_array.toarray()
            if cis_only:
                output_dict[chrom1_name]=read_array
            else:
                output_dict[chrom1_name+"_"+chrom2_name]=read_array
    if output_pkl is not None:
        output_dir = os.path.dirname(os.path.realpath(output_pkl))
        os.makedirs(output_dir, exist_ok=True)
        write_pkl(output_dict,output_pkl)

    return output_dict

if __name__ == '__main__':
    import os
    import sys

    # Help lines shared by the usage banner and the per-argument error paths.
    norm_help = "normalization type: 0: None normalization; 1: VC normalization; 2: VC_SQRT normalization; 3: KR normalization; 4: SCALE normalization"
    mode_help = "mode: 0 for sparse matrix, 1 for dense matrix, 2 for sparce matrix (only cis-contact); 3 for dense matrix (only cis-contact)."

    def _abort(*lines):
        """Print every line, then exit with status 1."""
        for line in lines:
            print(line)
        sys.exit(1)

    # Expect exactly five positional arguments after the script name.
    if len(sys.argv) != 6:
        _abort('Usage: python3 hic2array.py [input.hic] [output.pkl] [resolution] [normalization_type] [mode]',
               "This is the full hic2array script. ",
               norm_help,
               mode_help)

    resolution = int(sys.argv[3])
    normalization_type = int(sys.argv[4])
    mode = int(sys.argv[5])

    # Map the numeric command-line code to the normalization name.
    normalization_dict={0:"NONE",1:"VC",2:"VC_SQRT",3:"KR",4:"SCALE"}
    if normalization_type not in normalization_dict:
        _abort('normalization type should be 0,1,2,3,4', norm_help)
    normalization_type = normalization_dict[normalization_type]

    if mode not in [0,1,2,3]:
        _abort('mode should be in choice of 0/1/2/3', mode_help)

    input_hic_path = os.path.abspath(sys.argv[1])
    output_pkl_path = os.path.abspath(sys.argv[2])
    output_dir = os.path.dirname(output_pkl_path)
    os.makedirs(output_dir,exist_ok=True)
    hic2array(input_hic_path,output_pkl_path,resolution,normalization_type,mode)

Submatrix Generation:
1. Create scan_array.py
Save this as utils/scan_array.py:

In [None]:
import numpy as np
import pickle
from scipy.sparse import coo_matrix
import os

def write_pickle(output_dict,output_path):
    """
    Pickle a dictionary to disk, replacing any existing file.

    output_dict: dict, object to serialize
    output_path: str, destination file path
    """
    with open(output_path, mode='wb') as sink:
        pickle.dump(output_dict, sink)

def scan_matrix(matrix, input_row_size,input_col_size, stride_row,
                stride_col,hic_count,output_dir,current_chrom,
                filter_threshold=0.05):
    """
    Slide a fixed-size window over ``matrix`` and pickle every dense-enough window.

    matrix: 2D array
    input_row_size: int, row size of scanned output submatrix
    input_col_size: int, column size of scanned output submatrix
    stride_row: int, row stride
    stride_col: int, column stride
    hic_count: int, total read count of the Hi-C experiments
    output_dir: str, output directory
    current_chrom: str, current chromosome
    filter_threshold: float, minimum fraction of non-zero cells a window needs
        in order to be saved

    Each saved file is named "<current_chrom>_<row>_<col>.pkl" and holds a dict
    with keys 'input' (zero-padded window), 'input_count' and 'diag' (signed
    offset of the main diagonal relative to the window, or None when the window
    does not touch it). Returns None.
    """
    n_rows, n_cols = matrix.shape[0], matrix.shape[1]
    min_nonzero = input_row_size * input_col_size * filter_threshold
    saved = 0

    # Window origins run until half a window before each edge.
    for i in range(0, n_rows - input_row_size//2, stride_row):
        top = max(0, i)
        bottom = min(n_rows, i + input_row_size)
        for j in range(0, n_cols - input_col_size//2, stride_col):
            left = max(0, j)
            right = min(n_cols, j + input_col_size)

            # Copy the (possibly clipped) region into a zero-padded window.
            window = np.zeros((input_row_size, input_col_size))
            window[:bottom - top, :right - left] = matrix[top:bottom, left:right]

            # Drop windows with too few non-zero cells.
            if np.count_nonzero(window) < min_nonzero:
                continue

            # Work out where (if anywhere) the main diagonal crosses the window.
            if left == top:
                diag_offset = 0
            elif left < top:
                diag_offset = top - left if right > top else None
            else:
                diag_offset = top - left if left < bottom else None

            record = {
                'input': window,
                'input_count': hic_count,
                'diag': diag_offset,
            }
            file_name = str(current_chrom) + '_' + str(i) + '_' + str(j) + '.pkl'
            write_pickle(record, os.path.join(output_dir, file_name))
            saved += 1
            if saved % 100 == 0:
                print('Processed %d submatrices' % saved, " for chromosome ", current_chrom)

    return

def scan_pickle(input_pkl_path, input_row_size,input_col_size, stride_row,
                stride_col,output_dir,filter_threshold):
    """
    Load a pickled dict of contact matrices and scan each one into submatrices.

    input_pkl_path: str, input pickle path
    input_row_size: int, row size of scanned output submatrix
    input_col_size: int, column size of scanned output submatrix
    stride_row: int, row stride
    stride_col: int, column stride
    output_dir: str, output directory (created if missing)
    filter_threshold: float, minimum non-zero fraction forwarded to scan_matrix

    Raises SystemExit(1) when a value in the pickle is neither a numpy array
    nor a scipy coo_matrix.
    """

    os.makedirs(output_dir, exist_ok=True)

    # NOTE: pickle.load can execute arbitrary code; only load trusted files.
    with open(input_pkl_path, 'rb') as f:
        data = pickle.load(f)

    # First pass: total read count over every matrix in the file.
    total_count = 0
    for matrix in data.values():
        if isinstance(matrix, np.ndarray):
            cur_count = np.sum(matrix)
        elif isinstance(matrix, coo_matrix):
            cur_count = matrix.sum()
        else:
            print("Type not supported", type(matrix))
            # Non-zero status so that callers and shell chains see the failure;
            # a bare exit() would report success (status 0).
            raise SystemExit(1)
        total_count += cur_count
    print("Total read count of Hi-C: ", total_count)

    # Second pass: scan every matrix.
    for key, matrix in data.items():
        if isinstance(matrix, coo_matrix):
            matrix = matrix.toarray()

            if matrix.shape[0]==matrix.shape[1]:
                # Square (treated as intra-chromosome): mirror the upper
                # triangle below the diagonal to get a symmetric matrix.
                # Dense-array inputs are used as stored, without mirroring.
                upper_tri = np.triu(matrix,1)
                all_triu = np.triu(matrix)
                matrix = all_triu + upper_tri.T
        current_chrom = str(key)
        if "chr" not in current_chrom:
            current_chrom = "chr" + current_chrom

        scan_matrix(matrix, input_row_size,input_col_size, stride_row,
                stride_col,total_count,output_dir,current_chrom,filter_threshold)

#run with the simple command line
if __name__ == '__main__':
    import argparse

    # All scan parameters are mandatory except the sparsity filter.
    parser = argparse.ArgumentParser()
    required_options = (
        ('--input_pkl_path', str),
        ('--input_row_size', int),
        ('--input_col_size', int),
        ('--stride_row', int),
        ('--stride_col', int),
        ('--output_dir', str),
    )
    for flag, value_type in required_options:
        parser.add_argument(flag, type=value_type, required=True)
    parser.add_argument('--filter_threshold', type=float, default=0.05)
    args = parser.parse_args()

    # Resolve both paths against the current working directory before scanning.
    input_pkl_path = os.path.abspath(args.input_pkl_path)
    output_dir = os.path.abspath(args.output_dir)
    scan_pickle(input_pkl_path,
                args.input_row_size,
                args.input_col_size,
                args.stride_row,
                args.stride_col,
                output_dir,
                args.filter_threshold)

2. Create .pkl files from .hic files

In [None]:
# Arguments: [input.hic] [output.pkl] [resolution] [normalization_type] [mode]
# Here: 25 kb bins, no normalization (0), sparse matrices for all chromosome pairs (0).
# Each line ends with `&& \` so the chain is one valid shell command and stops
# at the first failure (a line may not start with `&&`).
python3 utils/hic2array.py Ft1-GSM6077013_at_hic_ndx1-4_r2.hic Ftr1.pkl 25000 0 0 && \
python3 utils/hic2array.py Pt1-GSM4705443_ddcc.hic Ptr1.pkl 25000 0 0 && \
python3 utils/hic2array.py Pt2-GSM6077012_at_hic_ndx1-4_r1.hic Ptr2.pkl 25000 0 0 && \
python3 utils/hic2array.py Pv1-GSM5091844_S_WT_2h1_DNB-15.allValidPairs.hic Pv1.pkl 25000 0 0

3. Generate Submatrices

In [None]:
# Generate submatrices for pre-training
# Each command scans one .pkl file with a 448x448 window moved 224 bins at a
# time, and keeps windows whose non-zero fraction is at least 0.01.
# NOTE(review): the outputs land in ./HiC-PTR1, ./HiC-PTR2 and ./HiC-PV1, while
# pre-training is run with --data_path "input-dirs/pre-train-dirs/" and config
# files that list these directory names. Confirm whether pretrain.py resolves
# the names relative to --data_path; if so, write the outputs under
# input-dirs/pre-train-dirs/ instead.
python3 utils/scan_array.py --input_pkl_path Ptr1.pkl  --input_row_size 448 \
    --input_col_size 448 --stride_row 224 --stride_col 224 \
    --output_dir HiC-PTR1 --filter_threshold 0.01

python3 utils/scan_array.py --input_pkl_path Ptr2.pkl  --input_row_size 448 \
    --input_col_size 448 --stride_row 224 --stride_col 224 \
    --output_dir HiC-PTR2 --filter_threshold 0.01

python3 utils/scan_array.py --input_pkl_path Pv1.pkl  --input_row_size 448 \
    --input_col_size 448 --stride_row 224 --stride_col 224 \
    --output_dir HiC-PV1 --filter_threshold 0.01

4. Create Configuration Files

In [None]:
# Create train.txt
echo "HiC-PTR1" > input-dirs/pre-train-dirs/train.txt
echo "HiC-PTR2" >> input-dirs/pre-train-dirs/train.txt

# Create val.txt
echo "HiC-PV1" > input-dirs/pre-train-dirs/val.txt

Pre-training:
Run the pre-training command

In [None]:
python3 pretrain.py --batch_size 1 --accum_iter 4 \
    --epochs 1 --warmup_epochs 1 --pin_mem \
    --mask_ratio 0.75 --sparsity_ratio 0.05 \
    --blr 1.5e-4 --min_lr 1e-7 --weight_decay 0.05 \
    --model "vit_large_patch16" --loss_alpha 1 --seed 888 \
    --data_path "input-dirs/pre-train-dirs/" --train_config "train.txt" \
    --valid_config "val.txt" --output "hicfoundation_finetune" \
    --tensorboard 1 --world_size 1 --dist_url "tcp://localhost:10001" --rank 0 \
    --input_row_size 448 --input_col_size 448 --patch_size 16 \
    --print_freq 1 --save_freq 1

After training, rename the output directory:

In [None]:
mv hicfoundation_finetune hicfoundation_pretrain

Fine-tuning Preparation
1. Create downsample_pkl.py
Save this as utils/downsample_pkl.py:

In [None]:
import sys
import os
from collections import defaultdict
import pickle
import numpy as np
from scipy.sparse import coo_matrix

def array_to_coo(array):
    """
    Build a scipy.sparse.coo_matrix holding the non-zero entries of a 2D array.

    Parameters:
    - array (numpy.ndarray): The input 2D array.

    Returns:
    - scipy.sparse.coo_matrix: Matrix with the same shape as ``array`` that
      stores only its non-zero cells.
    """
    # (rows, cols) index arrays of every non-zero cell.
    nonzero_index = np.nonzero(array)
    nonzero_values = array[nonzero_index]
    return coo_matrix((nonzero_values, nonzero_index), shape=array.shape)

def sparse2tag(coo_mat):
    """
    Expand a COO count matrix into one (row, col) "tag" per counted read.

    Each stored entry is repeated ``int(value)`` times (fractional parts are
    truncated), in the order the entries are stored.

    Parameters:
    - coo_mat (scipy.sparse.coo_matrix): matrix of non-negative counts.

    Returns:
    - tag_mat (numpy.ndarray): integer array of shape (tag_len, 2).
    - tag_len (int): number of tags, i.e. the sum of the truncated counts.

    Raises:
    - ValueError: if any stored value truncates to a negative count.
    """
    # Truncate per entry so that tag_len always equals the number of rows
    # actually produced. Truncating the overall sum instead leaves trailing
    # all-zero rows (spurious (0, 0) reads) when values are not integers.
    counts = np.asarray(coo_mat.data).astype(int)
    coords = np.stack((coo_mat.row, coo_mat.col), axis=1).astype(int)
    tag_mat = np.repeat(coords, counts, axis=0)
    tag_len = int(counts.sum())
    return tag_mat, tag_len

def tag2sparse(tag, nsize):
    """
    Convert a tag matrix (one (row, col) pair per read) to a sparse count matrix.

    Parameters:
    - tag (numpy.ndarray): integer array of shape (n_reads, 2).
    - nsize (int or tuple): an int gives a square (nsize, nsize) matrix; a
      (n_rows, n_cols) tuple gives a rectangular one.

    Returns:
    - scipy.sparse.coo_matrix: entry (r, c) counts how often (r, c) occurs in ``tag``.
    """
    shape = (nsize, nsize) if np.isscalar(nsize) else tuple(nsize)
    tag = np.asarray(tag)
    if tag.size == 0:
        # No reads: return an all-zero matrix rather than calling np.unique
        # on an empty array.
        return coo_matrix(shape, dtype=int)
    coo_data, data = np.unique(tag, axis=0, return_counts=True)
    row, col = coo_data[:, 0], coo_data[:, 1]
    sparse_mat = coo_matrix((data, (row, col)), shape=shape)
    return sparse_mat

def downsampling_sparce(matrix, down_ratio, verbose=False):
    """
    Downsampling method for sparse matrix.

    Draws ``int(total_reads * down_ratio)`` reads at random from the reads
    encoded in ``matrix`` and returns their counts as a matrix with the same
    shape as the input (square or rectangular).

    NOTE: np.random.choice samples with replacement by default, so a cell of
    the result can exceed its original count.
    """
    if verbose: print(f"[Downsampling] Matrix shape is {matrix.shape}")
    tag_mat, tag_len = sparse2tag(matrix)
    n_sample = int(tag_len * down_ratio)
    if n_sample == 0:
        # Nothing to draw (empty matrix or a very small ratio).
        if verbose: print(f'[Downsampling] Sampling {down_ratio} of {tag_len} reads')
        return tag2sparse(tag_mat[:0], matrix.shape)
    sample_idx = np.random.choice(tag_len, n_sample)
    sample_tag = tag_mat[sample_idx]
    if verbose: print(f'[Downsampling] Sampling {down_ratio} of {tag_len} reads')
    # Pass the full shape so rectangular (inter-chromosomal) matrices keep
    # their column count instead of being forced to (n_rows, n_rows).
    down_mat = tag2sparse(sample_tag, matrix.shape)
    return down_mat


def downsample_pkl(input_pkl, output_pkl, downsample_rate):
    """
    Downsample every matrix in a pickled dict and pickle the result.

    input_pkl: str, path of the input pickle (dict of numpy arrays or sparse matrices)
    output_pkl: str, path of the output pickle
    downsample_rate: float, fraction of reads to keep

    Matrices with 100 rows or fewer are skipped, so their keys are absent from
    the output file.
    """
    # NOTE: pickle.load can execute arbitrary code; only load trusted files.
    # `with` closes the handles; the bare open() calls left that to the GC.
    with open(input_pkl, 'rb') as f:
        data = pickle.load(f)
    return_dict={}
    for chrom, current_data in data.items():
        if current_data.shape[0] <=100:
            print("skip %s: only %d rows"%(chrom, current_data.shape[0]))
            continue
        #if it is numpy array convert to sparse matrix
        if isinstance(current_data, np.ndarray):
            current_data = array_to_coo(current_data)

        downsampled_data = downsampling_sparce(current_data, downsample_rate,verbose=1)
        return_dict[chrom] = downsampled_data
    with open(output_pkl, "wb") as f:
        pickle.dump(return_dict, f)
    print("finish downsampling %s"%output_pkl)

if __name__ == '__main__':
    # Expect exactly three arguments after the script name.
    if len(sys.argv) != 4:
        usage_lines = (
            "Usage: python3 downsample_pkl.py [input.pkl] [output.pkl] [downsample_rate]",
            "This script is used to downsample the input pickle file.",
            "[input.pkl]: the input pickle file",
            "[output.pkl]: the output pickle file",
            "[downsample_rate]: the downsample rate [float].",
        )
        for usage_line in usage_lines:
            print(usage_line)
        sys.exit(1)

    input_pkl, output_pkl = (os.path.abspath(p) for p in sys.argv[1:3])
    # Make sure the destination folder exists before any work is done.
    output_dir = os.path.dirname(output_pkl)
    os.makedirs(output_dir, exist_ok=True)
    downsample_rate = float(sys.argv[3])
    downsample_pkl(input_pkl, output_pkl, downsample_rate)

2. Downsample Data

In [None]:
python3 utils/downsample_pkl.py Ftr1.pkl Ftr1_downsampled.pkl 0.1

3. Create scan_array_diag.py
Save this as utils/scan_array_diag.py:

In [None]:
import numpy as np
import pickle
from scipy.sparse import coo_matrix
import os

def write_pickle(output_dict,output_path):
    """
    Serialize a dictionary with pickle, overwriting the target file.

    output_dict: dict, object to store
    output_path: str, file to write
    """
    with open(output_path, mode='wb') as out_handle:
        pickle.dump(output_dict, out_handle)

def scan_matrix_paired(original_matrix, downsampled_matrix, input_row_size, input_col_size, stride,
                      hic_count, output_dir, current_chrom):
    """
    Cut matching windows out of an original and a downsampled matrix and pickle them as pairs.

    original_matrix: 2D array, original high-quality Hi-C matrix
    downsampled_matrix: 2D array, downsampled low-quality Hi-C matrix
    input_row_size: int, row size of scanned output submatrix
    input_col_size: int, column size of scanned output submatrix
    stride: int, step between window origins (rows and columns)
    hic_count: int, total read count of the Hi-C experiments
    output_dir: str, output directory
    current_chrom: str, current chromosome

    Square matrices are scanned along the main diagonal only; rectangular ones
    over the full grid. Only windows lying completely inside the matrix are
    used, and a window is skipped when its original part is entirely zero.
    Each saved file is named "<current_chrom>_<row>_<col>.pkl" and holds a dict
    with keys 'input', '2d_target' and 'input_count'. Returns None.
    """
    n_rows = original_matrix.shape[0]
    n_cols = original_matrix.shape[1]

    # Both matrices must describe the same region.
    assert original_matrix.shape == downsampled_matrix.shape, \
        f"Matrix shapes don't match: {original_matrix.shape} vs {downsampled_matrix.shape}"

    print(f"Scanning matrix {current_chrom} with shape {original_matrix.shape}")
    print(f"Submatrix size: {input_row_size}x{input_col_size}, stride: {stride}")

    row_origins = range(0, n_rows - input_row_size + 1, stride)
    if n_rows == n_cols:
        # Square matrix: windows whose origin sits on the diagonal.
        window_origins = ((r, r) for r in row_origins
                          if r + input_col_size <= n_cols)
    else:
        # Rectangular matrix: every grid position.
        window_origins = ((r, c) for r in row_origins
                          for c in range(0, n_cols - input_col_size + 1, stride))

    saved = 0
    for top, left in window_origins:
        row_span = slice(top, top + input_row_size)
        col_span = slice(left, left + input_col_size)
        target_window = original_matrix[row_span, col_span]
        input_window = downsampled_matrix[row_span, col_span]

        # Skip windows whose original part holds no contact at all.
        if np.count_nonzero(target_window) < 1:
            continue

        pair = {
            'input': input_window.copy(),
            '2d_target': target_window.copy(),
            'input_count': hic_count,
        }
        file_name = str(current_chrom) + '_' + str(top) + '_' + str(left) + '.pkl'
        write_pickle(pair, os.path.join(output_dir, file_name))
        saved += 1

        if saved % 100 == 0:
            print('Processed %d paired submatrices' % saved, " for chromosome ", current_chrom)

    print(f"Total submatrices saved for {current_chrom}: {saved}")
    return

def scan_pickle_paired(original_pkl_path, downsampled_pkl_path, input_row_size, input_col_size, 
                      stride, output_dir):
    """
    Load an original and a downsampled pickle and write paired submatrices.

    original_pkl_path: str, path to original (high-quality) pickle file
    downsampled_pkl_path: str, path to downsampled (low-quality) pickle file
    input_row_size: int, row size of scanned output submatrix
    input_col_size: int, column size of scanned output submatrix
    stride: int, step between window origins
    output_dir: str, output directory (created if missing)

    Only keys present in both files are scanned; keys found in just one file
    are reported and skipped. The total read count is taken over every matrix
    of the original file.

    Raises SystemExit(1) when a matrix in the original file is neither a numpy
    array nor a scipy coo_matrix.
    """

    os.makedirs(output_dir, exist_ok=True)

    # Load both pickle files
    # NOTE: pickle.load can execute arbitrary code; only load trusted files.
    with open(original_pkl_path, 'rb') as f:
        original_data = pickle.load(f)

    with open(downsampled_pkl_path, 'rb') as f:
        downsampled_data = pickle.load(f)

    # The downsampling step drops matrices with 100 rows or fewer, so the two
    # key sets legitimately differ. Requiring them to be equal would abort the
    # pipeline on any short chromosome, so warn and skip instead.
    only_original = sorted(str(k) for k in set(original_data) - set(downsampled_data))
    only_downsampled = sorted(str(k) for k in set(downsampled_data) - set(original_data))
    if only_original:
        print(f"Warning: skipping {only_original}: missing from the downsampled data")
    if only_downsampled:
        print(f"Warning: ignoring {only_downsampled}: missing from the original data")

    # Calculate total count from original data
    total_count = 0
    for matrix in original_data.values():
        if isinstance(matrix, np.ndarray):
            cur_count = np.sum(matrix)
        elif isinstance(matrix, coo_matrix):
            cur_count = matrix.sum()
        else:
            print("Type not supported", type(matrix))
            # Non-zero status; a bare exit() would report success (status 0).
            raise SystemExit(1)
        total_count += cur_count
    print("Total read count of original Hi-C: ", total_count)

    # Process each chromosome
    for key in original_data:
        if key not in downsampled_data:
            continue
        original_matrix = original_data[key]
        downsampled_matrix = downsampled_data[key]

        # Convert sparse matrices to dense arrays
        if isinstance(original_matrix, coo_matrix):
            original_matrix = original_matrix.toarray()

        if isinstance(downsampled_matrix, coo_matrix):
            downsampled_matrix = downsampled_matrix.toarray()

        current_chrom = str(key)
        if "chr" not in current_chrom:
            current_chrom = "chr" + current_chrom

        # Only apply symmetry operation if matrix is square
        if original_matrix.shape[0] == original_matrix.shape[1]:
            # Mirror the upper triangle below the diagonal (diagonal kept once).
            upper_tri = np.triu(original_matrix, 1)
            all_triu = np.triu(original_matrix)
            original_matrix = all_triu + upper_tri.T

            upper_tri = np.triu(downsampled_matrix, 1)
            all_triu = np.triu(downsampled_matrix)
            downsampled_matrix = all_triu + upper_tri.T
        else:
            print(f"Warning: Matrix for {current_chrom} is not square ({original_matrix.shape}). Skipping symmetry operation.")

        print(f"Processing chromosome {current_chrom}")
        print(f"Original matrix shape: {original_matrix.shape}")
        print(f"Downsampled matrix shape: {downsampled_matrix.shape}")

        scan_matrix_paired(original_matrix, downsampled_matrix, input_row_size, input_col_size, 
                          stride, total_count, output_dir, current_chrom)

# Run with the simple command line
if __name__ == '__main__':
    # Every line of this block uses the same 4-space indent; mixing 4-space
    # and 3-space lines is an IndentationError.
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument('--original_pkl_path', type=str, required=True, 
                      help='Path to original (high-quality) pickle file')
    parser.add_argument('--downsampled_pkl_path', type=str, required=True,
                      help='Path to downsampled (low-quality) pickle file')
    parser.add_argument('--input_row_size', type=int, required=True)
    parser.add_argument('--input_col_size', type=int, required=True)
    parser.add_argument('--stride', type=int, required=True)
    parser.add_argument('--output_dir', type=str, required=True)
    args = parser.parse_args()

    # Resolve all paths against the current working directory.
    original_pkl_path = os.path.abspath(args.original_pkl_path)
    downsampled_pkl_path = os.path.abspath(args.downsampled_pkl_path)
    output_dir = os.path.abspath(args.output_dir)

    scan_pickle_paired(original_pkl_path, downsampled_pkl_path, args.input_row_size, 
                      args.input_col_size, args.stride, output_dir)

4. Generate Paired Submatrices

In [None]:
python3 utils/scan_array_diag.py \
    --original_pkl_path Ftr1.pkl \
    --downsampled_pkl_path Ftr1_downsampled.pkl \
    --input_row_size 224 --input_col_size 224 --stride 20 \
    --output_dir Ftr1

5. Prepare Fine-tuning Data

In [None]:
import glob
import random
import shutil

# `os` is used below; import it here so this cell also runs on its own.
import os

# Fixed seed so that re-running this cell reproduces the same split. With an
# unseeded shuffle, a second run copies a different split into the same
# folders and train/val files end up mixed.
SPLIT_SEED = 888

# Get all pkl files from Ftr1 directory.
# glob order depends on the filesystem, so sort before the seeded shuffle.
ftr1_files = sorted(glob.glob('Ftr1/*.pkl'))

# Shuffle and split (80-20 split)
random.Random(SPLIT_SEED).shuffle(ftr1_files)
split_idx = int(0.8 * len(ftr1_files))

train_files = ftr1_files[:split_idx]
val_files = ftr1_files[split_idx:]

# NOTE(review): if the submatrices in Ftr1 were cut with a stride smaller than
# the window size, neighbouring windows overlap and a random split shares
# pixels between train and val; consider splitting by chromosome instead.

# Create directories for fine-tuning
os.makedirs('ft-inputs/train', exist_ok=True)
os.makedirs('ft-inputs/val', exist_ok=True)

# Copy files to respective directories
for f in train_files:
    shutil.copy(f, 'ft-inputs/train/')
for f in val_files:
    shutil.copy(f, 'ft-inputs/val/')

# Create configuration files: each lists one sub-directory of ft-inputs.
with open('ft-inputs/train_config.txt', 'w') as f:
    f.write('train\n')

with open('ft-inputs/val_config.txt', 'w') as f:
    f.write('val\n')

print(f"Created fine-tuning dataset: {len(train_files)} train, {len(val_files)} validation samples")

Fine-tuning:
1. Modified train_epoch.py
Save this modified version as finetune/train_epoch.py:

In [None]:
import math
import sys
import numpy as np
from typing import Iterable
import torch
import torch.nn.functional as F
import time

from ops.Logger import MetricLogger,SmoothedValue
import model.lr_sched as lr_sched
from finetune.loss import configure_loss
from ops.train_utils import list_to_device, to_value, create_image, torch_to_nparray, convert_gray_rgbimage


def train_epoch(model, data_loader_train, optimizer, 
                loss_scaler, epoch, device,
                log_writer=None, args=None):
    """
    Run one fine-tuning epoch with gradient accumulation.

    model: network called as model(input_matrix, total_count); it returns
        (output_embedding, output_2d, output_1d)
    data_loader_train: iterable of training batches with a defined len()
    optimizer: optimizer whose gradients are zeroed here
    loss_scaler: callable invoked as loss_scaler(loss, optimizer, parameters=..., update_grad=...)
        # presumably runs backward and steps the optimizer when update_grad is
        # True - confirm against its definition
    epoch: int, current epoch index (used for the LR schedule and log x-axis)
    device: device the batch is moved to
    log_writer: optional TensorBoard-style writer; no logging when None
    args: namespace providing print_freq and accum_iter; also passed to
        configure_loss and the LR scheduler

    Returns a dict mapping each meter name to its global average.
    """
    model.train()
    metric_logger = MetricLogger(delimiter="  ")
    metric_logger.add_meter('lr', SmoothedValue(window_size=1, fmt='{value:.6f}'))

    header = 'Epoch: [{}]'.format(epoch)
    print_freq = args.print_freq

    accum_iter = args.accum_iter

    optimizer.zero_grad()
    if log_writer is not None:
        print('Tensorboard log dir: {}'.format(log_writer.log_dir))
    print("number of iterations: ",len(data_loader_train))
    criterion = configure_loss(args)

    for data_iter_step, train_data in enumerate(metric_logger.log_every(data_loader_train, print_freq, header)):
        # The learning rate is adjusted once per accumulation cycle, using a
        # fractional epoch position.
        if data_iter_step % accum_iter == 0:
            lr_sched.adjust_learning_rate(optimizer, data_iter_step / len(data_loader_train) + epoch, args)
        input_matrix, total_count, target_matrix, embed_target, target_vector = list_to_device(train_data,device=device)
        
        # Forward pass
        output_embedding, output_2d, output_1d = model(input_matrix, total_count)
        
        # Calculate losses - ensure all outputs participate in loss calculation
        loss_components = []
        
        if embed_target is not None:
            embedding_loss = criterion(output_embedding, embed_target)
            loss_components.append(embedding_loss)
        else:
            # Zero-weighted term: keeps this output in the autograd graph
            # without changing the loss value.
            embedding_loss = 0.0 * output_embedding.mean()
            loss_components.append(embedding_loss)
            
        if target_matrix is not None:
            #flatten 2d matrix
            output_2d_flatten = torch.flatten(output_2d, start_dim=1,end_dim=-1)
            target_matrix_flatten = torch.flatten(target_matrix, start_dim=1,end_dim=-1)
            output_2d_loss = criterion(output_2d_flatten, target_matrix_flatten)
            loss_components.append(output_2d_loss)
        else:
            # Zero-weighted term, see above.
            output_2d_loss = 0.0 * output_2d.mean()
            loss_components.append(output_2d_loss)
            
        if target_vector is not None:
            output_1d_loss = criterion(output_1d, target_vector)
            loss_components.append(output_1d_loss)
        else:
            # Zero-weighted term, see above.
            output_1d_loss = 0.0 * output_1d.mean()
            loss_components.append(output_1d_loss)
        
        # Sum all loss components
        loss = sum(loss_components)

        # Capture the unscaled loss once. It is used for the meters, the
        # finiteness check and TensorBoard, so the logged total stays
        # comparable with the component losses (which are never divided by
        # accum_iter).
        loss_value = to_value(loss)
        
        # Update metrics
        metric_logger.update(loss=loss_value)
        metric_logger.update(embedding_loss=to_value(embedding_loss))
        metric_logger.update(output_2d_loss=to_value(output_2d_loss))
        metric_logger.update(output_1d_loss=to_value(output_1d_loss))
        
        if not math.isfinite(loss_value):
            # Training is not stopped: the gradients accumulated so far are
            # dropped and the loop moves on to the next batch.
            print("Loss is {}, skipping this iteration".format(loss_value))
            #sys.exit(1)
            optimizer.zero_grad()
            continue
            
        loss = loss / accum_iter
        loss_scaler(loss, optimizer, parameters=model.parameters(),
                    update_grad=(data_iter_step + 1) % accum_iter == 0)

        if (data_iter_step + 1) % accum_iter == 0:
            optimizer.zero_grad()

        # Make sure all gradients are finished computing before moving on.
        # Guarded so the loop also runs on machines without CUDA.
        if torch.cuda.is_available():
            torch.cuda.synchronize()
        lr = optimizer.param_groups[0]["lr"]
        metric_logger.update(lr=lr)
        

        if log_writer is not None and ((data_iter_step + 1) % accum_iter == 0 or data_iter_step==0):
            """ 
            We use epoch_1000x as the x-axis in tensorboard.
            This calibrates different curves when batch size changes.
            """
            epoch_1000x = int((data_iter_step / len(data_loader_train) + epoch) * 1000)
            log_writer.add_scalars('Loss/loss', {'train_loss': loss_value}, epoch_1000x)
            log_writer.add_scalars('Loss/embedding_loss', {'train_loss': to_value(embedding_loss)}, epoch_1000x)
            log_writer.add_scalars('Loss/output_2d_loss', {'train_loss': to_value(output_2d_loss)}, epoch_1000x)
            log_writer.add_scalars('Loss/output_1d_loss', {'train_loss': to_value(output_1d_loss)}, epoch_1000x)
            log_writer.add_scalars('LR/lr', {'lr': lr}, epoch_1000x)
            if ((data_iter_step+1)//accum_iter)%50==0 or data_iter_step==0:
                #add visualization for your output and input
                new_samples = create_image(input_matrix)
                select_num = min(8,len(new_samples))
                sample_image = torch_to_nparray(new_samples.clone().detach()[:select_num])
                log_writer.add_images('Input_%s'%"train", sample_image, epoch_1000x)
                output_2d_image = convert_gray_rgbimage(output_2d.clone().detach()[:select_num])
                output_2d_image = torch_to_nparray(output_2d_image)
                log_writer.add_images('Output_2d_%s'%"train", output_2d_image, epoch_1000x)
                # for name, param in model.named_parameters():
                #     log_writer.add_histogram(name, param, epoch_1000x)
                #raise errors, see https://github.com/pytorch/pytorch/issues/91516
                #If you want to use this, install tensorboardX 
                #then change the code in main_worker.py to "from tensorboardX import SummaryWriter"
    # gather the stats from all processes
    metric_logger.synchronize_between_processes()
    print("Averaged stats:", metric_logger)
    return {k: meter.global_avg for k, meter in metric_logger.meters.items()}

2. Run Fine-tuning

In [None]:
python3 finetune.py --batch_size 1 --accum_iter 4 \
    --epochs 1 --warmup_epochs 0 --pin_mem \
    --blr 1e-3 --min_lr 1e-7 --weight_decay 0.05 \
    --layer_decay 0.75 --model vit_large_patch16 \
    --pretrain hicfoundation_pretrain/model/model_best.pth.tar \
    --finetune 1 --seed 888 \
    --loss_type 1 --data_path "ft-inputs" \
    --train_config "train_config.txt" \
    --valid_config "val_config.txt" \
    --output "hicfoundation_finetune" --tensorboard 1 \
    --world_size 1 --dist_url "tcp://localhost:10001" --rank 0 \
    --input_row_size 448 --input_col_size 448 --patch_size 16 \
    --print_freq 1 --save_freq 1

Inference:
First check whether you already have inference.py, main_worker.py, load_model.py and inference_worker.py from https://github.com/Noble-Lab/HiCFoundation/tree/main/inference. If all four Python modules are already present, continue from step 4. 
1. Create a directory named inference

In [None]:
mkdir inference

2. Save this as inference/inference.py:

In [None]:
import os
import timm
assert timm.__version__ == "0.3.2" # version check for timm
from ops.argparser import  argparser_infer
from ops.file_format_convert import convert_to_pkl

def main(args):
    """Convert the input file to .pkl and hand it to the inference worker.

    Args:
        args: Parsed command-line namespace. This function reads
            ``args.output``, ``args.input``, ``args.resolution`` and
            ``args.task``; the whole namespace is forwarded to
            ``main_worker``.
    """
    import socket

    # The IP address is printed for information only, so a host name that
    # cannot be resolved (socket.gaierror is an OSError) must not abort the run.
    try:
        local_ip = socket.gethostbyname(socket.gethostname())
    except OSError as err:
        local_ip = "unknown (%s)" % err
    print("local ip: ",local_ip)

    #format processing, convert different formats to .pkl format for further processing
    output_dir = os.path.abspath(args.output)
    os.makedirs(output_dir,exist_ok=True)
    input_file = os.path.abspath(args.input)
    config_resolution = args.resolution
    input_pkl=convert_to_pkl(input_file, output_dir,config_resolution)

    #for reproducibility analysis, we need to smooth the matrix to generate embeddings.
    if args.task==1:
        from ops.smooth_matrix import smooth_pkl
        smooth_pkl_file = os.path.join(output_dir,"input_smoothed.pkl")
        input_pkl = smooth_pkl(input_pkl,smooth_pkl_file)
        print("Reproducibility analysis smoothed input matrix saved to ",input_pkl)

    # Imported lazily, as in the rest of this script, right before it is needed.
    from inference.main_worker import main_worker
    main_worker(args, input_pkl)


if __name__ == '__main__':
    print("HiCFoundation inference started!")
    parser = argparser_infer()
    args = parser.parse_args()
    if args.gpu is not None:
        os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu

    # Banner printed for each supported --task value.
    task_banners = {
        1: "Reproducibility analysis",
        2: "Loop calling",
        3: "Resolution enhancement",
        4: "Epigenomic assay prediction",
        5: "scHi-C enhancement",
        6: "Hi-C embedding generation",
    }
    banner = task_banners.get(args.task)
    if banner is None:
        print("Unknown task specified ",args.task)
        print("Please specify the task using --task with 1,2,3,4,5,6")
        exit(1)
    print(banner)

    if args.task==6:
        # Embedding generation: the requested layer must not exceed the decoder depth.
        embed_depth = args.embed_depth
        if embed_depth>8:
            print("Error: embed_depth is larger than 8, that is beyond decoder depth. Please set embed_depth<=8")
            print("0 indicates the encoder output, k indicates the k-th decoder layer's output")
            exit(1)

    # Both input dimensions have to be a whole number of patches.
    rows_misaligned = args.input_row_size%args.patch_size!=0
    cols_misaligned = args.input_col_size%args.patch_size!=0
    if rows_misaligned or cols_misaligned:
        print("args configuration error: input_row_size and input_col_size must be a multiple of patch_size")
        exit(1)

    main(args)


3. Save this as inference/main_worker.py:

In [None]:
import torch
import os
import torch.nn as nn
from scipy.sparse import coo_matrix

from utils.hic_coverage import calculate_coverage
from data_processing.inference_dataset import Inference_Dataset
from ops.io_utils import write_pickle,append_record
from ops.mean_shift_merge import mean_shift_merge
from ops.file_format_convert import pkl2others
from utils.array2bigwig import array2bigwig
from model.pos_embed import interpolate_pos_embed_inputsize


def configure_dataset(args,input_pkl):
    """Build the non-shuffled DataLoader that feeds the inference loop.

    Args:
        args: Parsed inference arguments. Reads ``resolution``, ``task``,
            ``bound``, ``stride``, ``input_row_size``, ``input_col_size``,
            ``batch_size`` and ``num_workers``.
        input_pkl: Passed to ``Inference_Dataset`` as ``data_path`` and, for
            task 3, to ``calculate_coverage``.

    Returns:
        A ``torch.utils.data.DataLoader`` over an ``Inference_Dataset``
        (``shuffle=False``, ``drop_last=False``).
    """
    resolution = args.resolution
    import torchvision.transforms as transforms

    task = args.task
    preprocess = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    # Only resolution enhancement (task 3) asks the dataset to zero the diagonal.
    zero_diagonal = True if task==3 else False

    # Per-task clipping value handed to the dataset as max_cutoff.
    if task==3:
        # Deeply sequenced data (coverage per resolution above 1) is left unclipped.
        coverage_per_resolution = calculate_coverage(input_pkl)/resolution
        max_cutoff = None if coverage_per_resolution>1 else 100
    elif task==2:
        # loop calling
        max_cutoff = 1000
    elif task==5:
        # scHi-C enhancement
        max_cutoff = 100
    else:
        max_cutoff = None

    # Only epigenomic assay prediction (task 4) requests locus embeddings.
    use_locus_embedding = True if task==4 else False

    inference_set = Inference_Dataset(
        data_path=input_pkl,
        transform=preprocess,
        stride=args.stride,
        window_height=args.input_row_size,
        window_width=args.input_col_size,
        max_cutoff=max_cutoff,
        fill_diagonal_zero=zero_diagonal,
        bounding=args.bound,
        locus_embedding=use_locus_embedding,
        task=task,
    )
    return torch.utils.data.DataLoader(
        inference_set,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=args.num_workers,
        drop_last=False,
    )

def _detect_loops(dense_array, threshold, region):
    """Run mean_shift_merge on one dense matrix; return [] if detection fails."""
    try:
        return mean_shift_merge(dense_array, cutoff=threshold)
    except Exception as err:
        # Best effort, as before: one failing region must not stop the whole run.
        # Unlike a bare ``except``, KeyboardInterrupt/SystemExit still propagate,
        # and the failure is reported instead of being silently dropped.
        print("Warning: loop detection failed for %s: %s" % (region, err))
        return []


def generate_loop(return_dict,threshold,output_bedpe,config_resolution):
    """Detect loops per chromosome and write them to a .bedpe file.

    The header line is always written first (overwriting any existing file);
    the records of each chromosome are then added through ``append_record``.

    Args:
        return_dict: Mapping chromosome key -> scipy sparse matrix of
            predictions.
        threshold: Cutoff forwarded to ``mean_shift_merge``.
        output_bedpe: Path of the .bedpe file to write.
        config_resolution: Resolution forwarded to ``append_record``.
    """
    # Matrices larger than this are processed in square tiles along the
    # diagonal, so only one tile is densified at a time.
    tile_size = 10000

    with open(output_bedpe,'w') as wfile:
        wfile.write("chr1\tx1\tx2\tchr2\ty1\ty2\n")

    for chrom in return_dict:
        mean_array = return_dict[chrom]
        if mean_array.shape[0]<=tile_size:
            mean_loc_list = _detect_loops(mean_array.toarray(), threshold, chrom)
        else:
            # The tiling below reads .row/.col/.data, which only the COO format exposes.
            mean_array = mean_array.tocoo()
            mean_loc_list = []
            for cur_start in range(0,mean_array.shape[0],tile_size):
                cur_end = min(cur_start+tile_size,mean_array.shape[0])
                # Keep only the entries whose row AND column fall inside this tile.
                in_rows = (mean_array.row>=cur_start)&(mean_array.row<cur_end)
                in_cols = (mean_array.col>=cur_start)&(mean_array.col<cur_end)
                select_index = in_rows&in_cols
                cur_size = cur_end-cur_start
                cur_array = coo_matrix(
                    (mean_array.data[select_index],
                     (mean_array.row[select_index]-cur_start,
                      mean_array.col[select_index]-cur_start)),
                    shape=(cur_size,cur_size)).toarray()
                cur_loc_list = _detect_loops(
                    cur_array, threshold, "%s tile starting at %d" % (chrom, cur_start))
                # Shift tile-local coordinates back to chromosome coordinates.
                for x,y in cur_loc_list:
                    mean_loc_list.append([x+cur_start,y+cur_start])
                print(cur_start, "detect length: mean",len(cur_loc_list),"total",len(mean_loc_list))

        print("%s detect length: mean"%chrom,len(mean_loc_list))
        # Keys containing "_" are reported under the part before the first underscore.
        if "_" in chrom:
            chrom = chrom.split("_")[0]
        append_record(output_bedpe,mean_loc_list,chrom,resolution=config_resolution)
def main_worker(args, input_pkl):
    """Load the model for ``args.task``, run inference and write the results.

    Args:
        args: Parsed inference arguments. Reads ``resolution``, ``model_path``,
            ``output``, ``input_row_size``, ``input_col_size``, ``patch_size``,
            ``model`` (architecture name looked up in
            ``model.Vision_Transformer_count``), ``task`` and, for tasks 3 and
            5, ``input`` and ``genome_id``.
        input_pkl: Forwarded to ``configure_dataset`` as the dataset input.

    Raises:
        AssertionError: If ``args.model_path`` does not exist.

    Outputs are written under ``args.output``; the file names depend on the
    task (see the branches at the end of the function).
    """
    resolution = args.resolution
    #check model_path exists
    model_path = os.path.abspath(args.model_path)
    assert os.path.exists(model_path), "model_path does not exist"
    output_dir = os.path.abspath(args.output)
    dataloader = configure_dataset(args, input_pkl)
    import model.Vision_Transformer_count as Vision_Transformer
    #should be a dynamic input model
    # Number of patches along each axis of one input window.
    patch_wise_size = (args.input_row_size//args.patch_size,args.input_col_size//args.patch_size)
    vit_backbone = Vision_Transformer.__dict__[args.model](img_size=(args.input_row_size,args.input_col_size))
    if args.task==6:
        # embedding generation inference
        # only load encoder weights
        checkpoint = torch.load(model_path, map_location='cpu')
        checkpoint_model = checkpoint['model']
        state_dict = vit_backbone.state_dict()
        # Drop classification-head weights whose shape does not match the backbone.
        for k in ['head.weight', 'head.bias']:
            if k in checkpoint_model and checkpoint_model[k].shape != state_dict[k].shape:
                print(f"Removing key {k} from pretrained checkpoint")
                del checkpoint_model[k]

        # interpolate position embedding
        #this can apply to most scenarios but not our condition

        interpolate_pos_embed_inputsize(vit_backbone, checkpoint_model,input_size=patch_wise_size,
                                            use_decoder=False)
        # load pre-trained model
        msg = vit_backbone.load_state_dict(checkpoint_model, strict=False)
        print("Loading pre-train encoder message:",msg)
    from model.Finetune_Model_Head import Finetune_Model_Head
    model = Finetune_Model_Head(vit_backbone, task=args.task,
                            decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16,
                        mlp_ratio=4., norm_layer=nn.LayerNorm,pos_embed_size=patch_wise_size)


    #load model weights
    if args.task!=6:
        checkpoint = torch.load(model_path, map_location='cpu')
        # Accept checkpoints that wrap the weights under "model" or "state_dict",
        # as well as a bare state dict.
        if "model" in checkpoint:
            checkpoint_model = checkpoint["model"]
        elif "state_dict" in checkpoint:
            checkpoint_model = checkpoint["state_dict"]
        else:
            checkpoint_model = checkpoint
        msg = model.load_state_dict(checkpoint_model, strict=False)
        print("Loading fine-tuned task-specific model message:",msg)
    else:
        # Reload from disk: the copy used for the encoder above may have had
        # its head keys deleted.
        checkpoint = torch.load(model_path, map_location='cpu')
        checkpoint_model = checkpoint['model']
        #loading pre-trained decoder
        interpolate_pos_embed_inputsize(model, checkpoint['model'],
                                        input_size=patch_wise_size,use_decoder=True)
        msg = model.load_state_dict(checkpoint_model, strict=False)
        print("Loading pre-train model decoder message:",msg)

    model = model.cuda()
    model = nn.DataParallel(model, device_ids=None)
    from inference.inference_worker import inference_worker
    return_dict= inference_worker(model,dataloader,
                                  log_dir=output_dir,
                                  args=args)
    if args.task==1:
        output_path = os.path.join(output_dir,"HiCFoundation_reproducibility_embedding.pkl")
        write_pickle(return_dict,output_path)
        print("Reproducibility analysis finished!")
        print("The embedding results are saved to ",output_path)
    elif args.task==2:
        #0.9 is used for benchmark, but please choose the threshold based on your own data

        threshold_list= [0.5,0.75,0.9]
        for threshold in threshold_list:
            output_bedpe = os.path.join(output_dir,"HiCFoundation_loop_{}.bedpe".format(threshold))
            generate_loop(return_dict,threshold,output_bedpe,resolution)
        print("Loop calling finished!")
        print("The loop calling results are saved to ",output_dir," with different thresholds in .bedpe format.")
    elif args.task==3:
        #convert to hic format as final output
        output_pkl = os.path.join(output_dir,"HiCFoundation_enhanced.pkl")
        #revise the return dict key if it has "_", make to one chromosome
        for key in list(return_dict.keys()):
            if "_" in key:
                key_list = key.split("_")
                return_dict[key_list[0]] = return_dict[key]
                del return_dict[key]
        write_pickle(return_dict,output_pkl)
        input_file = os.path.abspath(args.input)
        # The converted output uses the same extension as the input file.
        extention_name = input_file.split('.')[-1]
        output_file = os.path.join(output_dir,"HiCFoundation_enhanced."+extention_name)
        pkl2others(output_pkl, output_file,resolution,args.genome_id)
        if not os.path.exists(output_file):
            print("Error: file conversion failed.")
            print("Resolution enhancement finished!")
            print("The final output is saved in .pkl format, please convert it to other formats manually.")
            print("The .pkl file is saved to ",output_pkl)
        else:
            # Previously nothing was reported when the conversion succeeded.
            print("Resolution enhancement finished!")
            print("The enhanced output is saved to ",output_file)
            print("The .pkl file is saved to ",output_pkl)
    elif args.task==4:
        #epigenomic assay prediction
        output_path = os.path.join(output_dir,"HiCFoundation_epigenomic_assay_prediction.pkl")
        write_pickle(return_dict,output_path)
        #write to bigWig file
        # Track order: index k of each chromosome's array is written as key_word_list[k].
        key_word_list=['CTCF','H3K4me3','H3K27ac','H3K27me3','ATAC-seq','DNase-seq']
        for key_index,key_word in enumerate(key_word_list):
            current_dict={}
            for chrom in return_dict:
                if "_" in chrom:
                    chrom_key = chrom.split("_")[0]
                else:
                    chrom_key = chrom
                current_dict[chrom_key] = return_dict[chrom][key_index]
            current_pkl = os.path.join(output_dir,"HiCFoundation_epigenomic_assay_prediction_%s.pkl"%key_word)
            write_pickle(current_dict,current_pkl)
            output_bigwig = os.path.join(output_dir,"HiCFoundation_pred_%s.bigWig"%key_word)
            array2bigwig(current_pkl,output_bigwig,resolution=resolution)
        print("Epigenomic assay prediction finished!")
        print("The prediction results are saved to ",output_dir," in .pkl and .bigWig format.")

    elif args.task==5:
        #scHi-C enhancement
        output_path = os.path.join(output_dir,"HiCFoundation_sc_enhanced.pkl")
        write_pickle(return_dict,output_path)
        input_file = os.path.abspath(args.input)
        extention_name = input_file.split('.')[-1]
        output_file = os.path.join(output_dir,"HiCFoundation_sc_enhanced."+extention_name)
        pkl2others(output_path, output_file,resolution,args.genome_id)
        if not os.path.exists(output_file):
            print("Error: file conversion failed.")
            print("scHi-C enhancement finished!")
            print("The final output is saved in .pkl format, please convert it to other formats manually.")
            print("The .pkl file is saved to ",output_path)
        else:
            # Previously nothing was reported when the conversion succeeded.
            print("scHi-C enhancement finished!")
            print("The enhanced output is saved to ",output_file)
            print("The .pkl file is saved to ",output_path)
    elif args.task==6:
        #embedding generation
        output_path = os.path.join(output_dir,"HiCFoundation_embedding.pkl")
        write_pickle(return_dict,output_path)
        print("Hi-C embedding generation finished!")
        print("The embedding results are saved to ",output_path," in .pkl format.")


    print("Enjoy your HiCFoundation results!")

5. Save this as inference/load_model.py:

In [None]:
def load_model(model_path,input_row_size,input_col_size, task=6):
    """
    Load a model from a file.

    Args:
        model_path (str): The path to the model file. The checkpoint must
            store its weights under the "model" key.
        input_row_size (int): The number of rows in the input matrix.
        input_col_size (int): The number of columns in the input matrix.
        task (int): Task id passed to ``Finetune_Model_Head`` (see Notes).
            Defaults to 6.

    Returns:
        model: The loaded model.

    Notes:
        task 0: fine-tuning setting
        task 1: reproducibility analysis
        task 2: loop calling
        task 3: resolution enhancement
        task 4: epigenomic assay prediction
        task 5: scHi-C enhancement
        task 6: embedding analysis
    """
    import torch
    import model.Vision_Transformer_count as Vision_Transformer
    from model.pos_embed import interpolate_pos_embed_inputsize
    import torch.nn as nn

    # Architecture and patch size are fixed in this loader.
    model_name="vit_large_patch16"
    patch_size=16

    # Number of patches along each axis of the input.
    patch_wise_size = (input_row_size//patch_size, input_col_size//patch_size)
    vit_backbone = Vision_Transformer.__dict__[model_name](img_size=(input_row_size,input_col_size))
    checkpoint = torch.load(model_path, map_location='cpu')
    checkpoint_model = checkpoint['model']
    state_dict = vit_backbone.state_dict()
    # Drop classification-head weights whose shape does not match the backbone.
    for k in ['head.weight', 'head.bias']:
        if k in checkpoint_model and checkpoint_model[k].shape != state_dict[k].shape:
            print(f"Removing key {k} from pretrained checkpoint")
            del checkpoint_model[k]
    interpolate_pos_embed_inputsize(vit_backbone, checkpoint_model,input_size=patch_wise_size,
                                            use_decoder=False)
    # load pre-trained model
    # strict=False hides missing/unexpected keys, so report the result explicitly.
    msg = vit_backbone.load_state_dict(checkpoint_model, strict=False)
    print("Loading pre-train encoder message:",msg)

    from model.Finetune_Model_Head import Finetune_Model_Head
    model = Finetune_Model_Head(vit_backbone, task=task,
                            decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16,
                        mlp_ratio=4., norm_layer=nn.LayerNorm,pos_embed_size=patch_wise_size)
    # Reload from disk: the copy used for the encoder above may have had its
    # head keys deleted.
    checkpoint = torch.load(model_path, map_location='cpu')
    checkpoint_model = checkpoint['model']
    #loading pre-trained decoder
    interpolate_pos_embed_inputsize(model, checkpoint['model'],
                                    input_size=patch_wise_size,use_decoder=True)
    msg = model.load_state_dict(checkpoint_model, strict=False)
    print("Loading pre-train model decoder message:",msg)
    return model # return the loaded model

def to_cuda(x):
    """
    Move a tensor to the GPU.

    Args:
        x (torch.Tensor | int | float | None): The value to move to the GPU.
            Python numbers are wrapped in a tensor first.

    Returns:
        torch.Tensor: The tensor on the GPU, or None when ``x`` is None.
    """
    if x is None:
        return None
    # isinstance also covers int/float subclasses (bool, numpy.float64), which
    # an exact type() comparison would send to ``x.cuda()`` and crash on.
    if isinstance(x, (int, float)):
        import torch
        x = torch.tensor(x)
    return x.cuda()
    
def to_float(x):
    """
    Return ``x.float()``, passing ``None`` through unchanged.

    Args:
        x (torch.Tensor): The tensor to convert, or None.

    Returns:
        torch.Tensor: The result of ``x.float()``, or None when ``x`` is None.
    """
    import torch  # kept from the original; not referenced below
    return None if x is None else x.float()

def convert_rgb(data_log,max_value):
    """
    Stack three channels along dim 0 from a single matrix.

    The first channel is all ones; the second and third are both
    ``(max_value - data_log) / max_value``.

    Args:
        data_log (torch.Tensor): 2-D matrix, or a tensor that already has a
            leading dimension (a 2-D input gets one added).
        max_value: Value used to scale ``data_log``. It is not guarded against
            zero, so a zero ``max_value`` divides by zero.

    Returns:
        torch.Tensor: The three channels concatenated along dim 0.
    """
    import torch
    if len(data_log.shape)==2:
        data_log = data_log[None,:,:]
    scaled = (max_value-data_log)/max_value
    # ones_like keeps the ones channel on the same device and dtype as the
    # other channels, so torch.cat also works for CUDA or float64 inputs.
    ones_channel = torch.ones_like(scaled)
    return torch.cat([ones_channel,scaled,scaled],dim=0)

def format_input(input):
    """
    Turn a raw matrix into the normalized 3-channel tensor given to the model.

    Steps: replace NaNs with ``torch.nan_to_num``, apply ``log10(x + 1)``,
    expand to three channels with ``convert_rgb`` (scaled by the log of the
    maximum), then apply the channel-wise ``Normalize`` below.

    Args:
        input (torch.Tensor): The input tensor.

    Returns:
        torch.Tensor: The formatted input tensor.
    """
    import torch
    import torchvision.transforms as transforms

    normalize = transforms.Compose([
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    cleaned = torch.nan_to_num(input)
    # The maximum is taken before the log transform and then transformed the same way.
    log_peak = torch.log10(torch.max(cleaned)+1)
    log_values = torch.log10(cleaned+1)
    return normalize(convert_rgb(log_values,log_peak))

6. Save this as inference/inference_worker.py:

In [None]:
import math
import sys
import numpy as np
from typing import Iterable
import torch
import torch.nn as nn
import time
from ops.Logger import MetricLogger,SmoothedValue
import os
from collections import defaultdict
from ops.sparse_ops import array_to_coo
from scipy.sparse import coo_matrix,triu
def inference_worker(model,data_loader,log_dir=None,args=None):
    """
    Run ``model`` over every batch of ``data_loader`` and assemble the results.

    model: model for inference, called as ``model(input, total_count)``
    data_loader: data loader for inference; each batch is
        ``(input, total_count, (chrs, row_starts, col_starts))`` and
        ``data_loader.dataset.dataset_shape`` maps each chromosome key to its
        matrix shape
    log_dir: log directory for inference (not used inside this function)
    args: arguments for inference; reads ``task``, ``resolution``,
        ``print_freq``, ``input_row_size``, ``input_col_size`` and, for task 6,
        ``embed_depth``, ``patch_size`` and ``patch_embedding``

    Returns, depending on ``args.task``:
        1: dict "chr:row,col" (coordinates multiplied by the resolution)
           -> model output of that window
        2, 3: dict chromosome -> upper-triangular sparse matrix holding the
           average prediction per position (values <= 0.01 are dropped)
        4: dict chromosome -> array of shape (6, length) with the average
           prediction per track
        5: dict chromosome -> dense upper-triangular array of enhanced counts
        6: dict with "submat_embedding", "patch_embedding", "chromo_embedding"
           and "genome_embedding"
        any other value: None
    """
    model.eval()
    config_resolution = args.resolution
    metric_logger = MetricLogger(delimiter="  ")
    header = 'Inference: '
    print_freq = args.print_freq
    print("number of iterations: ",len(data_loader))
    dataset_shape_dict = data_loader.dataset.dataset_shape
    infer_task = args.task

    # ---- per-task accumulators -------------------------------------------
    if infer_task==1:
        output_dict=defaultdict(list)
    elif infer_task==2 or infer_task==3:
        # Sparse records per window; overlapping windows are averaged at the end.
        output_dict={}
        for chrom in dataset_shape_dict:
            output_dict[chrom] = {"row_record":[],"col_record":[],"value_record":[],"count_record":[]}
    elif infer_task==4:
        #epigenomic assay prediction
        num_track = 6
        output_dict={}
        for chrom in dataset_shape_dict:
            current_length = dataset_shape_dict[chrom][0]
            output_dict[chrom] = {"mean":np.zeros([num_track,current_length]),
                                  "count":np.zeros([num_track,current_length])}
    elif infer_task==5:
        #scHi-C enhancement: dense running sum and hit count per chromosome
        cutoff= 1000
        log_cutoff = np.log10(cutoff+1)
        output_dict={}
        for chrom in dataset_shape_dict:
            current_shape = dataset_shape_dict[chrom]
            output_dict[chrom] = {"mean":np.zeros(current_shape),"count":np.zeros(current_shape)}
    elif infer_task==6:
        output_dict={"submat_embedding":defaultdict(list),"patch_embedding":defaultdict(list)}

    if infer_task==3:
        #resolution enhancement: the output is multiplied by log10(cutoff+1) before 10**x-1
        cutoff= 1000
        cutoff = torch.tensor(cutoff).float().cuda()
        log_cutoff = torch.log10(cutoff+1).cuda()

    # ---- forward passes ----------------------------------------------------
    for data_iter_step, data in enumerate(metric_logger.log_every(data_loader, print_freq, header)):
        batch_input,total_count,indexes = data
        batch_input = batch_input.cuda().float()
        total_count = total_count.cuda().float()
        with torch.no_grad():
            output = model(batch_input,total_count)
            # fixme: loop, and epigenomic assay prediction did not take count in benchmark, I think this will not impact performance, will check later. If yes, will revise it to model(input)
        if infer_task==2:
            #loop calling
            output= torch.sigmoid(output)
        elif infer_task==3:
            #resolution enhancement
            output = output*log_cutoff
            output = torch.pow(10,output)-1
            output = torch.clamp(output,min=0)
        elif infer_task==6:
            #get the specified encoder/decoder layer's output
            output = output[args.embed_depth]
        # Tasks 1, 4 and 5 keep the raw output here; task 5 is converted back
        # to counts after averaging (see the end of this function).

        output = output.detach().cpu().numpy()
        batch_input = batch_input.detach().cpu().numpy()
        chrs, row_starts, col_starts = indexes
        for i in range(len(output)):
            chrom = chrs[i]
            row_start = max(0,int(row_starts[i]))
            col_start = max(0,int(col_starts[i]))
            current_shape = dataset_shape_dict[chrom]
            # Clip the window to the chromosome matrix.
            row_end = min(row_start+args.input_row_size,current_shape[0])
            col_end = min(col_start+args.input_col_size,current_shape[1])
            current_input = batch_input[i]
            #ignore empty matrix (flagged by NaN in the input)
            if np.isnan(np.sum(current_input)):
                print("empty matrix:",chrom,row_start,col_start)
                continue

            cur_output = output[i]
            if infer_task==1:
                match_key = f"{chrom}:{row_start*config_resolution},{col_start*config_resolution}"
                output_dict[match_key] = cur_output
            elif infer_task==2 or infer_task==3:
                #loop calling, resolution enhancement
                cur_output = cur_output[:row_end-row_start,:col_end-col_start]
                cur_output = array_to_coo(cur_output)
                output_dict[chrom]["row_record"].append(cur_output.row+row_start)
                output_dict[chrom]["col_record"].append(cur_output.col+col_start)
                output_dict[chrom]["value_record"].append(cur_output.data)
                output_dict[chrom]["count_record"].append([1]*len(cur_output.data))
            elif infer_task==4:
                #epigenomic assay prediction
                cur_output = cur_output[:, :row_end-row_start]
                output_dict[chrom]['mean'][:, row_start:row_end] += cur_output
                output_dict[chrom]['count'][:, row_start:row_end] += 1

            elif infer_task==6:
                # Here the recorded position is used as the window centre.
                refer_row = row_start
                refer_col = col_start
                real_row_start = max(0,refer_row-args.input_row_size//2)
                real_col_start = max(0,refer_col-args.input_col_size//2)
                real_row_end = min(current_shape[0],refer_row+args.input_row_size//2)
                real_col_end = min(current_shape[1],refer_col+args.input_col_size//2)
                # we can let the patch embedding choice.
                if args.patch_embedding:
                    for row_index in range(real_row_start,real_row_end, args.patch_size):
                        for col_index in range(real_col_start,real_col_end,args.patch_size):
                            row_index = int(row_index)
                            col_index = int(col_index)
                            patch_row_index = (row_index-real_row_start)//args.patch_size
                            patch_col_index = (col_index-real_col_start)//args.patch_size
                            cur_patch_embedding = cur_output[patch_row_index,patch_col_index]
                            # Patches are keyed by the coordinate of their centre.
                            middle_row = row_index+args.patch_size//2
                            middle_col = col_index+args.patch_size//2
                            search_key = f"{chrom}:{middle_row*config_resolution},{middle_col*config_resolution}"
                            output_dict["patch_embedding"][search_key].append(cur_patch_embedding)

                search_key = f"{chrom}:{refer_row*config_resolution},{refer_col*config_resolution}"
                #average embedding over all patches of the window
                all_embedding = cur_output.reshape(-1,cur_output.shape[-1])
                all_embedding = np.mean(all_embedding,axis=0)
                output_dict["submat_embedding"][search_key].append(all_embedding)

            elif infer_task == 5:
                #scHi-C enhancement
                if current_shape[0] < args.input_row_size or current_shape[1] < args.input_col_size:
                    #remove padding regions
                    left_up_pad_size = (args.input_row_size - current_shape[0]) // 2
                    right_down_pad_size = args.input_row_size - current_shape[0] - left_up_pad_size
                    left_up_pad_size_col = (args.input_col_size - current_shape[1]) // 2
                    right_down_pad_size_col = args.input_col_size - current_shape[1] - left_up_pad_size_col
                    cur_output = cur_output[left_up_pad_size:-right_down_pad_size, left_up_pad_size_col:-right_down_pad_size_col]
                    output_dict[chrom]['mean'] += cur_output
                    output_dict[chrom]['count'] += 1
                else:
                    output_dict[chrom]['mean'][row_start:row_start+args.input_row_size, col_start:col_start+args.input_col_size] += cur_output
                    output_dict[chrom]['count'][row_start:row_start+args.input_row_size, col_start:col_start+args.input_col_size] += 1

    # ---- aggregation -------------------------------------------------------
    if infer_task==1:
        return output_dict
    elif infer_task==2 or infer_task==3:
        final_dict={}
        for chrom in output_dict:
            records = output_dict[chrom]
            if len(records["row_record"])==0:
                # Every window of this chromosome was skipped; np.concatenate
                # raises on an empty list, so return an empty matrix instead.
                print("finish summarize %s prediction"%chrom,0)
                final_dict[chrom] = triu(coo_matrix(tuple(dataset_shape_dict[chrom])),0)
                continue
            row_record = np.concatenate(records["row_record"])
            col_record = np.concatenate(records["col_record"])
            value_record = np.concatenate(records["value_record"])
            count_record = np.concatenate(records["count_record"])
            # Add the transposed coordinates so (i,j) and (j,i) are averaged together.
            combine_row=np.concatenate([row_record,col_record])
            combine_col=np.concatenate([col_record,row_record])
            combine_value=np.concatenate([value_record,value_record])
            combine_count=np.concatenate([count_record,count_record])
            prediction_sym = coo_matrix((combine_value, (combine_row, combine_col)), shape=dataset_shape_dict[chrom])
            count_sym = coo_matrix((combine_count, (combine_row, combine_col)), shape=dataset_shape_dict[chrom])

            # Both matrices share the same coordinates, so after summing
            # duplicates their data arrays line up element by element.
            prediction_sym.sum_duplicates()
            count_sym.sum_duplicates()
            prediction_sym.data = prediction_sym.data/count_sym.data
            #remove very small prediction to save time
            select_index = prediction_sym.data>0.01
            prediction_sym.data = prediction_sym.data[select_index]
            prediction_sym.row = prediction_sym.row[select_index]
            prediction_sym.col = prediction_sym.col[select_index]
            print("finish summarize %s prediction"%chrom,prediction_sym.nnz)
            final_dict[chrom] = triu(prediction_sym,0)
        return final_dict
    elif infer_task==4:
        #epigenomic assay prediction
        return_dict={}
        for chrom in dataset_shape_dict:
            count_array=output_dict[chrom]['count']
            mean_array=output_dict[chrom]['mean']
            # Positions never covered keep a count of 0; avoid dividing by zero.
            count_array =np.maximum(count_array,1)
            mean_array = mean_array/count_array
            mean_array = np.nan_to_num(mean_array)
            return_dict[chrom] = mean_array
        return return_dict
    elif infer_task == 5:
        return_dict={}
        for chrom in dataset_shape_dict:
            count_array=output_dict[chrom]['count']
            mean_array=output_dict[chrom]['mean']
            count_array =np.maximum(count_array,1)
            mean_array = (mean_array + mean_array.T)/2
            mean_array = mean_array/count_array
            mean_array = np.nan_to_num(mean_array)
            # Convert the averaged output back to counts.
            mean_array = mean_array*log_cutoff
            mean_array = np.power(10, mean_array) - 1
            mean_array = np.round(mean_array) - 2
            # Floor at zero. np.clip(x, 0, np.max(x)) returned all-negative
            # values whenever the maximum itself was below zero.
            mean_array = np.maximum(mean_array, 0)
            return_dict[chrom] = np.triu(mean_array)
        return return_dict

    elif infer_task==6:
        #embedding generation
        return_dict={"submat_embedding":{},"patch_embedding":{},"chromo_embedding":{},"genome_embedding":{}}

        #read patch embedding in output_dict, average the same location embedding
        for key in output_dict["patch_embedding"]:
            cur_embedding = np.stack(output_dict["patch_embedding"][key],axis=0)
            return_dict["patch_embedding"][key] = np.mean(cur_embedding,axis=0)

        #read submat embedding in output_dict, average the same location embedding
        chrom_embedding = defaultdict(list)
        for key in output_dict["submat_embedding"]:
            cur_embedding = np.stack(output_dict["submat_embedding"][key],axis=0)
            cur_embedding = np.mean(cur_embedding,axis=0)
            return_dict["submat_embedding"][key] = cur_embedding
            chrom = key.split(":")[0]
            chrom_embedding[chrom].append(cur_embedding)

        #get average chromo embedding
        for chrom in chrom_embedding:
            cur_embedding = np.stack(chrom_embedding[chrom],axis=0)
            return_dict["chromo_embedding"][chrom] = np.mean(cur_embedding,axis=0)
        #get average genome embedding
        all_embedding = list(return_dict["chromo_embedding"].values())
        all_embedding = np.stack(all_embedding,axis=0)
        all_embedding = np.mean(all_embedding,axis=0)
        return_dict["genome_embedding"] = all_embedding
        return return_dict

7. Run inference on new Hi-C data (task 3 is for resolution enhancement):

In [None]:
# --model is the architecture name (main_worker looks it up in
# model.Vision_Transformer_count), while --model_path is the checkpoint file.
python inference.py --batch_size 1 \
    --input B1-GSM4705442_cmt2cmt3.hic \
    --resolution 10000 \
    --task 3 \
    --input_row_size 224 --input_col_size 224 \
    --stride 32 --bound 0 \
    --num_workers 1 \
    --model vit_large_patch16 \
    --model_path hicfoundation_finetune/model/model_best.pth.tar \
    --output B1

And Done! The directory B1 should have the enhanced files in both .hic and .pkl format