In [None]:

import os
# Project root and run parameters.
# NOTE(review): `dir` shadows the Python builtin of the same name; the name is
# kept because later cells may reference it.
dir = '/home/houshiyuan/frp/CCTAD'
Dataset = 'HIC002'
resolution = 50  # bin size in kb; forms the "<N>kb" path segments below

# Derived output / model locations (always '/'-separated).
output_dir = f'{dir}/output/{Dataset}/{resolution}kb'
models_dir = f'{dir}/config/models/{Dataset}/{resolution}kb'
best_threshold_dir = f'{dir}/config/best_threshold'
final_output = f'{dir}/final_output/{Dataset}/{resolution}kb'

# External data locations.
path = '/mnt/sdi/frp/data/marks'
hic_matrix_dir = '/mnt/sdi/frp/data/hic_matrix'

# chr1 ... chr22; chrX / chrY are not included.
chromosomes = ["chr" + str(number) for number in range(1, 23)]


In [None]:

# Create the top-level output, model and final-output directories
# (each call is a no-op when the directory already exists).
for _dir_to_make in (
    output_dir,
    models_dir,
    final_output,
    os.path.join(final_output, "TAD"),
    os.path.join(final_output, "TAD_res"),
):
    os.makedirs(_dir_to_make, exist_ok=True)
del _dir_to_make  # keep the notebook namespace free of the loop variable

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from sklearn.cluster import AgglomerativeClustering
from sklearn.preprocessing import StandardScaler
import numpy as np
from scipy.sparse import lil_matrix
import sys
import os
import random

def seed_everything(seed=42):
    """Seed every RNG used by the pipeline and force deterministic cuDNN.

    Seeds Python's ``random``, NumPy's global RNG and PyTorch (CPU, current
    CUDA device, all CUDA devices), exports ``PYTHONHASHSEED`` and disables
    cuDNN autotuning.

    NOTE: setting ``PYTHONHASHSEED`` at runtime only influences interpreters
    started afterwards; it does not change hashing in the running kernel.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

# Preferred GPU when CUDA is available. On machines that expose fewer GPUs the
# default CUDA device is used instead of failing later on `.to(device)`; without
# CUDA everything runs on the CPU.
_CUDA_INDEX = 3
if torch.cuda.is_available():
    device = torch.device(
        f"cuda:{_CUDA_INDEX}" if torch.cuda.device_count() > _CUDA_INDEX else "cuda:0"
    )
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

class ConvAutoencoder1D(nn.Module):
    """1-D convolutional autoencoder that compresses each row to a small code.

    Input is ``(batch, 1, input_dim)`` (the convolutions preserve length, and
    the flatten/linear sizes are derived from ``input_dim``). ``forward``
    returns ``(reconstruction, code)`` with shapes ``(batch, 1, input_dim)``
    and ``(batch, embedding_dim)``.

    The layer order inside ``encoder``/``decoder`` determines the
    ``state_dict`` keys (``encoder.0.weight`` ...); keep it stable so saved
    checkpoints still load.
    """

    def __init__(self, input_dim=100, embedding_dim=16):
        super().__init__()
        flat_size = input_dim * 4  # 4 channels x input_dim positions
        encoder_layers = [
            nn.Conv1d(1, 8, kernel_size=5, padding=2),
            nn.ReLU(),
            nn.Conv1d(8, 4, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Flatten(),
            nn.Linear(flat_size, embedding_dim),
        ]
        decoder_layers = [
            nn.Linear(embedding_dim, flat_size),
            nn.Unflatten(1, (4, input_dim)),
            nn.ConvTranspose1d(4, 8, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.ConvTranspose1d(8, 1, kernel_size=5, padding=2),
        ]
        self.encoder = nn.Sequential(*encoder_layers)
        self.decoder = nn.Sequential(*decoder_layers)

    def forward(self, x):
        code = self.encoder(x)
        return self.decoder(code), code

def create_connectivity(n):
    """Return an ``n x n`` LIL adjacency matrix linking each bin to its immediate neighbours.

    Entries ``(i, i+1)`` and ``(i+1, i)`` are 1; everything else is 0. Used as
    the ``connectivity`` constraint so clusters stay contiguous along the bins.
    """
    adjacency = lil_matrix((n, n))
    for left, right in zip(range(n - 1), range(1, n)):
        adjacency[left, right] = 1
        adjacency[right, left] = 1
    return adjacency

def run_clustering(features, distance_threshold=1.5):
    """Cluster rows of ``features`` with neighbour-constrained average-linkage clustering.

    The number of clusters is not fixed: merging stops at
    ``distance_threshold``. Only adjacent rows may be merged directly
    (see ``create_connectivity``). Returns one integer label per row.
    """
    neighbour_graph = create_connectivity(len(features))
    settings = dict(
        n_clusters=None,
        distance_threshold=distance_threshold,
        linkage='average',
        connectivity=neighbour_graph,
    )
    return AgglomerativeClustering(**settings).fit_predict(features)

def train_model(hic_matrix, chrom,input_dim=100, embedding_dim=16, num_epochs=500, lr=1e-4,load_model=False):
    """Fit (or load) a ConvAutoencoder1D for one chromosome and embed its bins.

    ``hic_matrix`` is standardised column-wise with ``StandardScaler``. Each row
    is fed to the autoencoder as a 1-channel sequence of length ``input_dim``.

    Args:
        hic_matrix: 2-D array whose row ``i`` is the feature vector of bin ``i``;
            its column count must equal ``input_dim``.
        chrom: chromosome name, used to locate ``best_model_{chrom}.pt`` in
            ``models_dir`` when loading.
        input_dim, embedding_dim: autoencoder sizes.
        num_epochs: number of full-batch Adam steps when training.
        lr: Adam learning rate.
        load_model: if True, load weights from ``models_dir``. If that file does
            not exist, the model is trained instead of embedding with random,
            untrained weights.

    Returns:
        ``(model, z)``. ``z`` is the ``(n_rows, embedding_dim)`` embedding
        produced by the returned model's final weights, in eval mode and
        without an autograd graph.
    """
    features = StandardScaler().fit_transform(hic_matrix)
    features_tensor = torch.tensor(features, dtype=torch.float32).unsqueeze(1).to(device)
    model = ConvAutoencoder1D(input_dim=input_dim, embedding_dim=embedding_dim).to(device)

    loaded = False
    if load_model:
        best_model_path = models_dir+f'/best_model_{chrom}.pt'
        if os.path.exists(best_model_path):
            # "Loading existing model parameters:"
            print("加载已有模型参数：", best_model_path)
            # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
            model.load_state_dict(torch.load(best_model_path, map_location=device))
            loaded = True
        else:
            # "Model file not found, cannot load." -- fall through to training
            # rather than clustering embeddings from random weights.
            print("未找到模型文件，无法加载。")

    if not loaded:
        optimizer = torch.optim.Adam(model.parameters(), lr=lr)
        model.train()
        for _ in range(num_epochs):
            optimizer.zero_grad()
            recon, _ = model(features_tensor)
            recon_loss = F.mse_loss(recon.squeeze(1), features_tensor.squeeze(1))
            recon_loss.backward()
            optimizer.step()

    # Embed with the final weights. The last training forward pass ran *before*
    # the final optimizer step, and is undefined when num_epochs == 0.
    model.eval()
    with torch.no_grad():
        _, z = model(features_tensor)
    return model, z


Using device: cuda:3


In [None]:
def identify_tad_boundaries(cluster_labels):
    """Return every index ``i >= 1`` whose label differs from the label at ``i - 1``."""
    return [
        idx
        for idx in range(1, len(cluster_labels))
        if cluster_labels[idx] != cluster_labels[idx - 1]
    ]

def create_boundary_labels_from_clusters(cluster_labels, hic_len, save_path):
    """Save a 0/1 boundary vector of length ``hic_len`` and return the boundary indices.

    Bin ``i`` (``i >= 1``) is a boundary when its cluster label differs from
    that of bin ``i - 1``. The vector is written to ``save_path``, one integer
    per line.

    Raises:
        IndexError: if a boundary index is ``>= hic_len``.
    """
    labels = np.zeros(hic_len, dtype=int)
    boundary_bins = [
        pos
        for pos in range(1, len(cluster_labels))
        if cluster_labels[pos] != cluster_labels[pos - 1]
    ]
    labels[boundary_bins] = 1
    np.savetxt(save_path, labels, fmt='%d')
    return boundary_bins

def merge_consecutive_small_tads(input_file, output_file, min_bin_size=3):
    """Collapse runs of adjacent small TADs into a single TAD.

    Reads ``start end`` rows (blank lines ignored). A TAD is small when
    ``end - start < min_bin_size``. Starting from a small TAD, following small
    TADs are absorbed while each one starts exactly where the current merged
    interval ends. Large TADs are copied unchanged. The whole input is read
    before writing, so ``input_file`` and ``output_file`` may be the same path.
    """
    with open(input_file, 'r') as fh:
        intervals = [list(map(int, row.split())) for row in fh if row.strip()]

    result = []
    idx = 0
    total = len(intervals)
    while idx < total:
        lo, hi = intervals[idx]
        idx += 1
        if hi - lo < min_bin_size:
            # absorb following small, touching intervals
            while idx < total:
                nxt = intervals[idx]
                if nxt[0] != hi or nxt[1] - nxt[0] >= min_bin_size:
                    break
                hi = nxt[1]
                idx += 1
        result.append((lo, hi))

    with open(output_file, 'w') as fh:
        fh.writelines(f"{lo} {hi}\n" for lo, hi in result)

def process_tad_file(input_path, output_path, min_size=2):
    """Absorb TADs shorter than ``min_size`` into an adjacent TAD and rewrite the file.

    Only lines with exactly two fields are read. Each small TAD (``end - start < min_size``)
    is handled as follows:
    - If the last kept TAD ends where it starts, that TAD is extended to its end.
    - Otherwise, if the next TAD starts where it ends, the next TAD's start is
      pulled back to its start. The next TAD is then re-evaluated with its new length.
    - Otherwise it is dropped.

    The input is fully read before writing, so both paths may be the same file.
    """
    tads = []
    with open(input_path, 'r') as src:
        for raw in src:
            fields = raw.strip().split()
            if len(fields) == 2:
                tads.append([int(fields[0]), int(fields[1])])

    kept = []
    for idx in range(len(tads)):
        lo, hi = tads[idx]  # re-read: may have been widened on the previous step
        if hi - lo >= min_size:
            kept.append([lo, hi])
            continue
        if kept and kept[-1][1] == lo:
            kept[-1][1] = hi
        elif idx + 1 < len(tads) and tads[idx + 1][0] == hi:
            tads[idx + 1][0] = lo

    with open(output_path, 'w') as dst:
        for lo, hi in kept:
            dst.write(f"{lo} {hi}\n")

def convert_boundary_file_to_tads(input_file, output_file, start=1, end=None):
    """Turn a one-boundary-per-line file into consecutive ``start end`` TAD rows.

    Row ``k`` spans from the previous boundary (``start`` for the first row) to
    boundary ``k``. Lines that are not plain non-negative integers are skipped.

    Args:
        input_file: file with one boundary index per line.
        output_file: destination for the ``left right`` rows.
        start: left edge of the first TAD.
        end: optional right edge of the last domain. Without it, the region
            after the final boundary is not emitted. When given and greater
            than the last edge, a closing row ``<last edge> <end>`` is added.
    """
    with open(input_file, 'r') as f:
        boundaries = [int(line.strip()) for line in f if line.strip().isdigit()]

    edges = [start] + boundaries
    if end is not None and end > edges[-1]:
        edges.append(end)

    with open(output_file, 'w') as f:
        for left, right in zip(edges, edges[1:]):
            f.write(f"{left} {right}\n")
            
def convert_tad_resolution(input_file, output_file, resolution=25000):
    """Scale bin-index TAD rows to base-pair coordinates.

    Reads whitespace-separated integer rows (e.g. ``start end`` in bins),
    multiplies every value by ``resolution`` and writes them tab-separated.

    An input with no data rows yields an empty output file. The input can be
    empty, for example when every TAD was merged away.
    """
    with open(input_file, 'r') as fh:
        has_rows = any(line.strip() for line in fh)
    if not has_rows:
        with open(output_file, 'w'):
            pass
        return

    # ndmin=2 keeps a single-row file as shape (1, k) rather than a flat vector.
    tads = np.loadtxt(input_file, dtype=int, ndmin=2)
    np.savetxt(output_file, tads * resolution, fmt="%d", delimiter="\t")
  
def mark_label_transitions(labels):
    """Flag both sides of every label change.

    Returns an array shaped and typed like ``np.array(labels)`` in which
    positions ``i - 1`` and ``i`` are 1 whenever ``labels[i] != labels[i - 1]``.
    All other positions are 0.
    """
    arr = np.array(labels)
    flags = np.zeros_like(arr)
    changed = arr[1:] != arr[:-1]
    flags[1:][changed] = 1   # element right after each change
    flags[:-1][changed] = 1  # element right before each change
    return flags
def mark_transitions_from_file(input_file, output_file):
    """Read integer labels (one per line), flag label changes, and save the 0/1 flags.

    Lines that are not plain non-negative integers are skipped. The flags come
    from ``mark_label_transitions`` and are written one per line.
    """
    with open(input_file, 'r') as src:
        values = [int(text.strip()) for text in src if text.strip().isdigit()]
    np.savetxt(output_file, mark_label_transitions(values), fmt='%d')
  

In [None]:


def load_tad_intervals(filepath):
    """Read ``start end`` rows and expand each into the bin list ``start, ..., end - 1``.

    Raises:
        ValueError: if any line (including a blank one) does not hold exactly
            two integers.
    """
    with open(filepath, 'r') as handle:
        rows = handle.readlines()
    expanded = []
    for row in rows:
        first, last = map(int, row.strip().split())
        expanded.append(list(range(first, last)))
    return expanded

def compute_intra_contact(tad_bins, hic_matrix):
    """Mean (NaN-ignoring) contact over distinct bin pairs inside one TAD.

    Uses the strict lower triangle of ``hic_matrix[tad_bins][:, tad_bins]``.
    Returns 0.0 for TADs with fewer than two bins.
    """
    if len(tad_bins) <= 1:
        return 0.0
    sub_matrix = hic_matrix[np.ix_(tad_bins, tad_bins)]
    tril_indices = np.tril_indices_from(sub_matrix, k=-1)
    values = sub_matrix[tril_indices]
    return np.nanmean(values) if len(values) > 0 else 0.0

def compute_inter_contact(tad_bins, neighbor_bins, hic_matrix, boundary_size=2):
    """Mean (NaN-ignoring) contact between the edges of two adjacent TADs.

    Compares the first ``boundary_size`` bins of ``tad_bins`` with the last
    ``boundary_size`` bins of ``neighbor_bins``. These meet at the shared
    boundary when ``neighbor_bins`` lies immediately to the LEFT of
    ``tad_bins``. Returns 0.0 if either TAD is empty.
    """
    if len(tad_bins) == 0 or len(neighbor_bins) == 0:
        return 0.0
    left_edge = tad_bins[:boundary_size]
    right_edge = neighbor_bins[-boundary_size:]
    sub_matrix = hic_matrix[np.ix_(left_edge, right_edge)]
    return np.nanmean(sub_matrix) if sub_matrix.size > 0 else 0.0

def evaluate_tad_quality(tads, hic_matrix, boundary_size=2):
    """Score each TAD as intra-TAD contact minus contact across its boundaries.

    For each TAD the boundary term is the mean of the positive contacts across
    its left and right boundaries. Only the ``boundary_size`` bins on either
    side of each boundary are used.

    Returns:
        ``(quality_scores, average_score)``: per-TAD scores in input order and
        their mean (0 when ``tads`` is empty).
    """
    quality_scores = []
    for i, domain in enumerate(tads):
        intra = compute_intra_contact(domain, hic_matrix)
        inters = []
        if i > 0:
            # Left neighbour: its last bins vs. this TAD's first bins.
            inters.append(compute_inter_contact(domain, tads[i - 1], hic_matrix, boundary_size))
        if i < len(tads) - 1:
            # Right neighbour: swap the roles so the compared slices are this TAD's
            # last bins and the neighbour's first bins, i.e. the shared boundary.
            inters.append(compute_inter_contact(tads[i + 1], domain, hic_matrix, boundary_size))
        inters = [v for v in inters if v > 0]
        inter = np.mean(inters) if inters else 0
        quality_scores.append(intra - inter)

    average_score = np.mean(quality_scores) if quality_scores else 0
    return quality_scores, average_score


In [None]:

import time
def tad(hic_matrix, chrom, seed, load_model, threshold):
    """Run the TAD-calling pipeline for one chromosome at a fixed threshold.

    Steps:
    1. Optionally reseed.
    2. Train or load the autoencoder.
    3. Cluster the bin embeddings.
    4. Write cluster labels, boundaries, bin-level TADs (merged/cleaned) and
       bp-level TADs under ``output_dir``.

    Returns:
        ``(0, None)`` when every cluster label is 0 (nothing is written);
        otherwise ``(None, trained_model)``.
    """
    def _under_output(*parts):
        # Build a path below output_dir and make sure its parent directory exists.
        target = os.path.join(output_dir, *parts)
        os.makedirs(os.path.dirname(target), exist_ok=True)
        return target

    if seed is not None:
        seed_everything(seed)
    trained_model, embedding = train_model(
        hic_matrix, chrom, input_dim=hic_matrix.shape[1], num_epochs=500, load_model=load_model
    )
    labels = run_clustering(embedding.detach().cpu().numpy(), distance_threshold=threshold)
    if not any(lab != 0 for lab in labels):
        return 0, None

    np.savetxt(_under_output("cluster_labels", f"{chrom}_cluster_labels.txt"), labels, fmt="%d")

    boundary = identify_tad_boundaries(labels)
    boundary_file = _under_output("boundary", f"boundary_{chrom}.txt")
    np.savetxt(boundary_file, boundary, fmt="%d")

    tad_path = _under_output("TAD", f"tads_{chrom}.txt")
    convert_boundary_file_to_tads(boundary_file, tad_path, start=1)
    merge_consecutive_small_tads(tad_path, tad_path, min_bin_size=3)
    process_tad_file(tad_path, tad_path, min_size=3)

    res_tad_path = _under_output("TAD_res", f"res_tads_{chrom}.txt")
    convert_tad_resolution(tad_path, res_tad_path, resolution=resolution * 1000)
    return None, trained_model


In [None]:
def find_best_threshold(hic_matrix, chrom, thresholds):
    """Pick the clustering distance threshold that maximises mean TAD quality.

    Trains one autoencoder for ``chrom``. For every candidate threshold it then:
    - clusters the embedding;
    - converts the clusters into TADs;
    - scores the TADs with ``evaluate_tad_quality``.

    The per-chromosome boundary/TAD files under ``output_dir`` are overwritten
    on each iteration, so they end up holding the *last* candidate's result,
    not the best one.

    Returns:
        The threshold with the highest average score (the earliest one wins ties).

    Raises:
        ValueError: if no candidate produced a non-empty TAD set with a
            comparable (non-NaN) score.
    """
    for sub in ("cluster_labels", "boundary", "TAD", "TAD_res", "labels"):
        os.makedirs(os.path.join(output_dir, sub), exist_ok=True)
    boundary_file = os.path.join(output_dir, "boundary", f"boundary_{chrom}.txt")
    tad_path = os.path.join(output_dir, "TAD", f"tads_{chrom}.txt")

    # Train the model once; only the clustering threshold varies below.
    _, z = train_model(hic_matrix, chrom, input_dim=hic_matrix.shape[1], num_epochs=500, load_model=False)
    z_np = z.detach().cpu().numpy()

    best_score = -np.inf
    best_t = None
    for threshold in thresholds:
        cluster_labels = run_clustering(z_np, distance_threshold=threshold)
        if all(x == 0 for x in cluster_labels):
            continue  # every label is 0: no boundaries to score
        boundary = identify_tad_boundaries(cluster_labels)
        np.savetxt(boundary_file, boundary, fmt="%d")
        convert_boundary_file_to_tads(boundary_file, tad_path, start=1)
        merge_consecutive_small_tads(tad_path, tad_path, min_bin_size=3)
        process_tad_file(tad_path, tad_path, min_size=3)
        tads = load_tad_intervals(tad_path)
        if not tads:
            continue  # everything was merged away; nothing to score
        # Compute the TAD quality score for this threshold.
        _, avg = evaluate_tad_quality(tads, hic_matrix, boundary_size=2)
        if avg > best_score:
            best_score = avg
            best_t = threshold

    if best_t is None:
        raise ValueError(f"{chrom}: no clustering threshold produced scorable TADs")
    return best_t

In [None]:
import shutil
# Search the best clustering threshold per chromosome, then record their average.
candidate_thresholds = np.arange(0.1, 2.0, 0.1)
all_thresholds = []
for chrom in chromosomes:
    hic_matrix = np.loadtxt(f"{hic_matrix_dir}/{Dataset}/{resolution}kb/{Dataset}_{resolution}k_KR.{chrom}")
    best_threshold = find_best_threshold(hic_matrix, chrom, thresholds=candidate_thresholds)
    all_thresholds.append(best_threshold)
    print(f'{chrom}:{best_threshold:.2f}')

average_threshold = np.mean(all_thresholds)
print(f'Average threshold across all chromosomes: {average_threshold:.3f}')

# The directory-setup cell does not create best_threshold_dir; opening a file in
# a missing directory would raise FileNotFoundError.
os.makedirs(best_threshold_dir, exist_ok=True)
with open(os.path.join(best_threshold_dir, f"{Dataset}_{resolution}kb_best_threshold.txt"), "a") as f:
    f.write(f"best_threshold:{average_threshold:.1f}\n")