In [1]:
import os
import torch
import pandas as pd
import subprocess

from utils.anti_numbering import get_regions

In [24]:
kernel_size = 3
dilation = 1
stride = 1
input_size = 256

padding = ((stride - 1) * (input_size - 1) - input_size + dilation * (kernel_size - 1) + 1) // 2

In [25]:
padding

-127

In [22]:
import torch
import torch.nn as nn

class CrossAttention(nn.Module):
    def __init__(self, d_model, nhead, dropout=0.1):
        super(CrossAttention, self).__init__()
        self.attention = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
        self.nhead = nhead

    def forward(self, x1, x2, mask1, mask2, need_weights=True):
        x1 = x1.transpose(0, 1)
        x2 = x2.transpose(0, 1)

        # 将mask1和mask2组合成一个合适的注意力mask
        attn_mask = mask1 * mask2.transpose(1, 2)
        attn_mask = attn_mask.repeat(self.nhead, 1, 1)  # 重复mask以适应多头注意力的维度
        attn_mask = attn_mask.eq(0)  # 将mask中的0转换为True，其他值转换为False

        print(attn_mask)

        # 更新key_padding_mask
        key_padding_mask = mask2.squeeze(-1).eq(0)  # 将mask2中的0转换为True，其他值转换为False

        attn_output, attn_weights = self.attention(x1, x2, x2, key_padding_mask=key_padding_mask, attn_mask=attn_mask, need_weights=need_weights)
        attn_output = attn_output.transpose(0, 1)
        return attn_output, attn_weights

# 示例
batch_size = 16
seq_len1 = 10
seq_len2 = 12
d_model = 256
nhead = 8

x1 = torch.randn(batch_size, seq_len1, d_model)
x2 = torch.randn(batch_size, seq_len2, d_model)

mask1 = torch.ones(batch_size, seq_len1, 1)
mask2 = torch.ones(batch_size, seq_len2, 1)

# 将mask1和mask2的某些元素设置为0，表示这些位置应该被mask掉
mask1[:, 5:] = 0
mask2[:, 8:] = 0

cross_attention = CrossAttention(d_model, nhead)
output, attn_weights = cross_attention(x1, x2, mask1, mask2)

print("Output shape:", output.shape)
print("Attention weights shape:", attn_weights.shape)
print(output)

tensor([[[False, False, False,  ...,  True,  True,  True],
         [False, False, False,  ...,  True,  True,  True],
         [False, False, False,  ...,  True,  True,  True],
         ...,
         [ True,  True,  True,  ...,  True,  True,  True],
         [ True,  True,  True,  ...,  True,  True,  True],
         [ True,  True,  True,  ...,  True,  True,  True]],

        [[False, False, False,  ...,  True,  True,  True],
         [False, False, False,  ...,  True,  True,  True],
         [False, False, False,  ...,  True,  True,  True],
         ...,
         [ True,  True,  True,  ...,  True,  True,  True],
         [ True,  True,  True,  ...,  True,  True,  True],
         [ True,  True,  True,  ...,  True,  True,  True]],

        [[False, False, False,  ...,  True,  True,  True],
         [False, False, False,  ...,  True,  True,  True],
         [False, False, False,  ...,  True,  True,  True],
         ...,
         [ True,  True,  True,  ...,  True,  True,  True],
         [

In [12]:
import torch
import torch.nn as nn

class CrossAttention(nn.Module):
    def __init__(self, d_model, nhead, dropout=0.1):
        super(CrossAttention, self).__init__()
        self.attention = nn.MultiheadAttention(d_model, nhead, dropout=dropout)

    def forward(self, x1, x2, need_weights=True):
        # x1 shape: (batch_size, seq_len1, d_model)
        # x2 shape: (batch_size, seq_len2, d_model)
        x1 = x1.transpose(0, 1)  # 将x1转置为（seq_len1, batch_size, d_model）
        x2 = x2.transpose(0, 1)  # 将x2转置为（seq_len2, batch_size, d_model）
        attn_output, attn_weights = self.attention(x1, x2, x2, need_weights=need_weights)
        attn_output = attn_output.transpose(0, 1)  # 将输出转置回（batch_size, seq_len1, d_model）
        return attn_output, attn_weights

# 示例
batch_size = 16
seq_len1 = 10
seq_len2 = 12
d_model = 256
nhead = 8

x1 = torch.randn(batch_size, seq_len1, d_model)  # 假设有16个样本，每个样本具有10个位置，每个位置的向量维度为256
x2 = torch.randn(batch_size, seq_len2, d_model)  # 假设有16个样本，每个样本具有12个位置，每个位置的向量维度为256

cross_attention = CrossAttention(d_model, nhead)
output, attn_weights = cross_attention(x1, x2)

print("Output shape:", output.shape)  # 输出的形状应该与x1张量的形状相同
print("Attention weights shape:", attn_weights.shape)  # 注意力权重的形状应为（batch_size, seq_len1, seq_len2）

Output shape: torch.Size([16, 10, 256])
Attention weights shape: torch.Size([16, 10, 12])


In [9]:
import numpy as np
from sklearn.metrics import roc_auc_score
from sklearn.preprocessing import label_binarize

# 示例数据
y_true = np.array([0, 1, 2, 0, 1, 2])  # 注意这里没有类别3
y_score = np.array([[0.9, 0.05, 0.03, 0.02],
                    [0.1, 0.8, 0.05, 0.05],
                    [0.0, 0.2, 0.8, 0.0],
                    [0.9, 0.1, 0.0, 0.0],
                    [0.1, 0.9, 0.0, 0.0],
                    [0.1, 0.1, 0.8, 0.0]])

# 将y_true转换为二进制矩阵
n_classes = 4  # 明确指定类别数量
y_true_bin = label_binarize(y_true, classes=np.arange(n_classes))
print(y_true_bin)
# 计算多分类ROC AUC
roc_auc = roc_auc_score(y_true_bin, y_score, labels=np.arange(n_classes))
print("Multiclass ROC AUC:", roc_auc)

[[1 0 0 0]
 [0 1 0 0]
 [0 0 1 0]
 [1 0 0 0]
 [0 1 0 0]
 [0 0 1 0]]


ValueError: Only one class present in y_true. ROC AUC score is not defined in that case.

In [63]:
# Get imgt sequence
def trans_seq_to_imgt_type(aa_seq):
    cmd_str = f'ANARCI --sequence {aa_seq}'
    cmd_out = subprocess.check_output(cmd_str, shell=True)
    # cmd_out = os.system(cmd_str)
    line_strs = cmd_out.decode('utf-8').split('\n')
    sub_strs = line_strs[5].split('|')

    chn_type = sub_strs[2]
    if chn_type == 'K':
        chn_type = 'L'
    idx_resd_beg = int(sub_strs[5])  # inclusive
    idx_resd_end = int(sub_strs[6])  # inclusive

    idx_resd = idx_resd_beg
    labl_vec = torch.zeros(len(aa_seq), dtype=torch.int8)  # 0: framework
    reorder_seq = []
    for line_str in line_strs:
        if not line_str.startswith(chn_type):
            continue
        if line_str.endswith('-'):
            reorder_seq.append('-')
            continue
        idx_resd_imgt = int(line_str.split()[1])
        reorder_seq.append(line_str.split()[2])
        if 27 <= idx_resd_imgt <= 38:
            labl_vec[idx_resd] = 1  # CDR-1
        elif 56 <= idx_resd_imgt <= 65:
            labl_vec[idx_resd] = 2  # CDR-2
        elif 105 <= idx_resd_imgt <= 117:
            labl_vec[idx_resd] = 3  # CDR-3
        idx_resd += 1
    assert idx_resd == idx_resd_end + 1, f'{idx_resd} {idx_resd_beg} {idx_resd_end} {chn_type} {cmd_out}'

    print(''.join(reorder_seq))
    print(len(''.join(reorder_seq)))
    print(len(labl_vec))
    return labl_vec

In [66]:
def trans_batch(csv_path):
    """
    :param csv_path: specific csv path
    :return: the seq of the data frame
    """
    pair_data = pd.read_csv(csv_path, header=0)
    for idx, row_data in pair_data.iterrows():
        H_seq = row_data['HSEQ']
        L_seq = row_data['LSEQ']
        if idx == 0:
            print(H_seq)
            imgt_H_seq = trans_seq_to_imgt_type(H_seq)
            print(L_seq)
            imgt_L_seq = trans_seq_to_imgt_type(L_seq)
        elif idx == 0:
            continue
        else:
            break

In [67]:
paired_csv = '/data/home/waitma/antibody_proj/antidiff/data/oas_pair_human_data/oas_paired.csv'
trans_batch(paired_csv)

QVQLVQSGAEVKKPGASVKVSCKASGYTFTGYYMHWVRQAPGQGLEWMGWINPNSGGTNYAQKFQGRVTMTRDTSISTAYMELSRLRSDDTAVYYCARAGAVAASKGYYYYYYGMDVWGQGTTVTVSS
QVQLVQSGA-EVKKPGASVKVSCKASGYTF----TGYYMHWVRQAPGQGLEWMGWINPN--SGGTNYAQKFQ-GRVTMTRDTSISTAYMELSRLRSDDTAVYYCARAGAVAABCDDCBAYYGMDVWGQGTTVTVSS
136
128
 QSVLTQPPSVSEAPRQRVTISCSGSSSNIGNNAVNWYQQLPGKAPKLLIYYDDLLPSGVSDRFSGSKSGTSASLAISGLQSEDEADYYCAAWDDSLNVVVFGGGTKLTVL
QSVLTQPPS-VSEAPRQRVTISCSGSSSNI----GNNAVNWYQQLPGKAPKLLIYYD-------DLLPSGVS-DRFSGSK--SGTSASLAISGLQSEDEADYYCAAWDDS--LNVVVFGGGTKLTVL
127
111


In [22]:
paired_csv = '/data/home/waitma/antibody_proj/antidiff/data/oas_pair_human_data/oas_paired.csv'
trans_batch(paired_csv)

QVQLVQSGAEVKKPGASVKVSCKASGYTFTGYYMHWVRQAPGQGLEWMGWINPNSGGTNYAQKFQGRVTMTRDTSISTAYMELSRLRSDDTAVYYCARAGAVAASKGYYYYYYGMDVWGQGTTVTVSS
QVQLVQSGA-EVKKPGASVKVSCKASGYTF----TGYYMHWVRQAPGQGLEWMGWINPN--SGGTNYAQKFQ-GRVTMTRDTSISTAYMELSRLRSDDTAVYYCARAGAVAABCDDCBAYYGMDVWGQGTTVTVSS
 QSVLTQPPSVSEAPRQRVTISCSGSSSNIGNNAVNWYQQLPGKAPKLLIYYDDLLPSGVSDRFSGSKSGTSASLAISGLQSEDEADYYCAAWDDSLNVVVFGGGTKLTVL
QSVLTQPPS-VSEAPRQRVTISCSGSSSNI----GNNAVNWYQQLPGKAPKLLIYYD-------DLLPSGVS-DRFSGSK--SGTSASLAISGLQSEDEADYYCAAWDDS--LNVVVFGGGTKLTVL


# change compare

In [24]:
import abnumber
from abnumber import Chain

In [32]:
paired_csv = '/data/home/waitma/antibody_proj/antidiff/data/oas_pair_human_data/oas_paired.csv'

pair_data = pd.read_csv(paired_csv, header=0)

In [39]:
seqs = pair_data.head(100)

In [40]:
seqs

Unnamed: 0,ENTRY,HSEQ,LSEQ
0,1279049_1_Paired_All_0_H|1279049_1_Paired_All_0_L,QVQLVQSGAEVKKPGASVKVSCKASGYTFTGYYMHWVRQAPGQGLE...,QSVLTQPPSVSEAPRQRVTISCSGSSSNIGNNAVNWYQQLPGKAP...
1,1279049_1_Paired_All_1_H|1279049_1_Paired_All_1_L,QVQLVQSGAEVKKPGASVKVSCKASGYTFTGYYMHWVRQAPGQGLE...,QSVLTQPPSASGTPGQRVTISCSGSSSNIGSNTVNWYQQLPGTAP...
2,1279049_1_Paired_All_2_H|1279049_1_Paired_All_2_L,QVQLVQSGAEVKKPGASVKVSCKASGYTFTSYAMHWVRQAPGQRLE...,SYELTQPPSVSVSPGQTARITCSGDALPKQYAYWYQQKPGQAPVL...
3,1279049_1_Paired_All_3_H|1279049_1_Paired_All_3_L,EVQLVESGGGLVKPGGSLRLSCAASGFTFSNAWMSWVRQAPGKGLE...,SYELTQPPSVSVSPGQTARITCSGDALPKKYAYWYQQKSGQAPVL...
4,1279049_1_Paired_All_4_H|1279049_1_Paired_All_4_L,EVQLLESGGGLVQPGGSLRLSCAASGFTFSSYAMSWVRQAPGKGLE...,DIQMTQSPSSLSASVGDRVTITCRASQSISSYLNWYQQKPGKAPK...
...,...,...,...
95,1279049_1_Paired_All_95_H|1279049_1_Paired_All...,EVQLVESGGGLVQPGGSLRLSCAASGFTFSSYEMNWVRQAPGKGLE...,DIQLTQSPSFLSASVGDRVTITCRASQGISSYLAWYQQKPGKAPK...
96,1279049_1_Paired_All_96_H|1279049_1_Paired_All...,QVQLQESGPGLVKPSETLSLTCTVSGGSISSYYWSWIRQPPGKGLE...,SYELTQPPSVSVSPGQTARITCSGDALPKQYAYWYQQKPGQAPVL...
97,1279049_1_Paired_All_97_H|1279049_1_Paired_All...,QLQLQESGPGLVKPSETLSLTCTVSGGSISSSSYYWGWIRQPPGKG...,QSVLTQPPSVSGAPGQRVTISCTGSSSNIGAGYDVHWYQQLPGTA...
98,1279049_1_Paired_All_98_H|1279049_1_Paired_All...,QLQLQESGPGLVKPSETLSLTCTVSGGSISSSSYYWGWIRQPPGKG...,EIVLTQSPATLSLSPGERATLSCRASQSVSSYLAWYQQKPGQAPR...


In [41]:
len(seqs)

100

In [42]:
heavy_chains = seqs.apply(lambda row: Chain(row['HSEQ'], name=row.name, scheme='imgt'), axis=1)

In [87]:
heavy_chains[0]

QVQLVQSGAEVKKPGASVKVSCKASGYTFTGYYMHWVRQAPGQGLEWMGWINPNSGGTNYAQKFQGRVTMTRDTSISTAYMELSRLRSDDTAVYYCARAGAVAASKGYYYYYYGMDVWGQGTTVTVSS
                         ^^^^^^^^                 ^^^^^^^^                                      ^^^^^^^^^^^^^^^^^^^^^           

In [89]:
light_chains = seqs.apply(lambda row: Chain(row['LSEQ'], name=row.name, scheme='imgt'), axis=1)

In [90]:
light_chains[0]

QSVLTQPPSVSEAPRQRVTISCSGSSSNIGNNAVNWYQQLPGKAPKLLIYYDDLLPSGVSDRFSGSKSGTSASLAISGLQSEDEADYYCAAWDDSLNVVVFGGGTKLTVL
                         ^^^^^^^^                 ^^^                                    ^^^^^^^^^^^          

In [None]:
QVQLVQSGAEVKKPGASVKVSCKASGYTFTGYYMHWVRQAPGQGLEWMGWINPNSGGTNYAQKFQGRVTMTRDTSISTAYMELSRLRSDDTAVYYCARAGAVAASKGYYYYYYGMDVWGQGTTVTVSS
QVQLVQSGA-EVKKPGASVKVSCKASGYTF----TGYYMHWVRQAPGQGLEWMGWINPN--SGGTNYAQKFQ-GRVTMTRDTSISTAYMELSRLRSDDTAVYYCARAGAVAABCDDCBAYYGMDVWGQGTTVTVSS

In [46]:
len('QVQLVQSGAEVKKPGASVKVSCKASGYTFTGYYMHWVRQAPGQGLEWMGWINPNSGGTNYAQKFQGRVTMTRDTSISTAYMELSRLRSDDTAVYYCARAGAVAASKGYYYYYYGMDVWGQGTTVTVSS')

128

In [47]:
seq1 = 'QVQLVQSGAEVKKPGASVKVSCKASGYTFTGYYMHWVRQAPGQGLEWMGWINPNSGGTNYAQKFQGRVTMTRDTSISTAYMELSRLRSDDTAVYYCARDSLSAWGQGTLVTVSS'
len(seq1)

114

In [48]:
chain1 = Chain(seq1, scheme='imgt')

In [49]:
chain1.print_tall()


       H1 Q          H25 A          H52 W          H78 M         H101 V
       H2 V          H26 S          H53 M          H79 T         H102 Y
       H3 Q     CDR1 H27 G          H54 G          H80 R         H103 Y
       H4 L     CDR1 H28 Y          H55 W          H81 D         H104 C
       H5 V     CDR1 H29 T     CDR2 H56 I          H82 T    CDR3 H105 A
       H6 Q     CDR1 H30 F     CDR2 H57 N          H83 S    CDR3 H106 R
       H7 S     CDR1 H35 T     CDR2 H58 P          H84 I    CDR3 H107 D
       H8 G     CDR1 H36 G     CDR2 H59 N          H85 S    CDR3 H108 S
       H9 A     CDR1 H37 Y     CDR2 H62 S          H86 T    CDR3 H115 L
      H11 E     CDR1 H38 Y     CDR2 H63 G          H87 A    CDR3 H116 S
      H12 V          H39 M     CDR2 H64 G          H88 Y    CDR3 H117 A
      H13 K          H40 H     CDR2 H65 T          H89 M         H118 W
      H14 K          H41 W          H66 N          H90 E         H119 G
      H15 P          H42 V          H67 Y          H91 L        

In [10]:
117-105


12

AttributeError: 'str' object has no attribute 'print_tall'

In [54]:
chain1

QVQLVQSGAEVKKPGASVKVSCKASGYTFTGYYMHWVRQAPGQGLEWMGWINPNSGGTNYAQKFQGRVTMTRDTSISTAYMELSRLRSDDTAVYYCARDSLSAWGQGTLVTVSS
                         ^^^^^^^^                 ^^^^^^^^                                      ^^^^^^^           

In [57]:
len('QVQLVQSGAEVKKPGASVKVSCKASGYTFTGYYMHWVRQAPGQGLEWMGWINPNSGGTNYAQKFQGRVTMTRDTSISTAYMELSRLRSDDTAVYYCARDSLSAWGQGTLVTVSS')

114

In [68]:
seq1


'QVQLVQSGAEVKKPGASVKVSCKASGYTFTGYYMHWVRQAPGQGLEWMGWINPNSGGTNYAQKFQGRVTMTRDTSISTAYMELSRLRSDDTAVYYCARDSLSAWGQGTLVTVSS'

In [69]:
for i in seq1:
    print(i)


Q
V
Q
L
V
Q
S
G
A
E
V
K
K
P
G
A
S
V
K
V
S
C
K
A
S
G
Y
T
F
T
G
Y
Y
M
H
W
V
R
Q
A
P
G
Q
G
L
E
W
M
G
W
I
N
P
N
S
G
G
T
N
Y
A
Q
K
F
Q
G
R
V
T
M
T
R
D
T
S
I
S
T
A
Y
M
E
L
S
R
L
R
S
D
D
T
A
V
Y
Y
C
A
R
D
S
L
S
A
W
G
Q
G
T
L
V
T
V
S
S


In [70]:
a = {'c': 2, 'd': 3}

In [71]:
a.keys()

dict_keys(['c', 'd'])

In [73]:
'd' in a.keys()

True

In [86]:
red_list = ['Q', 'V', 'L', 'S', 'G', 'A', 'E', 'K', 'P', 'C', 'Y', 'T', 'F', 'M', 'H', 'W', 'R', 'I', 'N', 'D']
len(red_list)

20

In [75]:
High_H = {'Q': 9943359, 
         'V': 16586989, 
         'L': 14587607, 
         'S': 27400373, 
         'G': 22067603, 
         'A': 12735476, 
         'E': 6586589, 
         'K': 7895777, 
         'P': 6166750, 
         'C': 3647638, 
         'Y': 13508200, 
         'T': 15424020, 
         'F': 5453165, 
         'M': 3637888, 
         'H': 1289489, 
         'W': 6699793, 
         'R': 9260831, 
         'I': 6144554, 
         'N': 5661022, 
         'D': 8427505}
High_L = {'Q': 12722690, 
          'S': 27412177, 
          'V': 9593981, 
          'L': 13354920, 
          'T': 15532925, 
          'P': 10751101, 
          'E': 5382286, 
          'A': 10910077, 
          'R': 6685286, 
          'I': 8250769, 
          'C': 3473020, 
          'G': 17643109, 
          'N': 4115497, 
          'W': 2670343, 
          'Y': 9838025, 
          'K': 6728829, 
          'D': 7253657, 
          'F': 5683407, 
          'M': 1164497, 
          'H': 897210}


In [79]:
sum = 0
for v in High_H.values():
    sum += v


In [81]:
sumH = sum

In [82]:
sumV = 0
for v in High_L.values():
    sumV += v

In [83]:
sumV

180063806

In [85]:
for k in High_H.keys():
    v = High_H[k]
    p = v / sumH
    print(k, p)

Q 0.04895201088072885
V 0.08165917231858266
L 0.0718160429074115
S 0.13489439104351245
G 0.10864070603984072
A 0.06269784282386477
E 0.03242634369279928
K 0.03887158872729111
P 0.03035944021519636
C 0.017957635348875567
Y 0.06650202948310138
T 0.07593377598702605
F 0.02684639993531459
M 0.017909635260968945
H 0.006348265164576696
W 0.03298365671345377
R 0.04559186688085898
I 0.030250167399691188
N 0.027869697809366573
D 0.04148933136753855
