In [1]:
import csv
import os
import re
import subprocess

from peft import PeftModel
from transformers import T5Tokenizer, T5EncoderModel
from clalign import ProteinSeq, AlignmentResult, get_embs, fast_align, f1score

In [2]:
# Base ProtT5-XL (UniRef50) encoder and its tokenizer; the encoder is moved to the GPU.
BASE_MODEL = 'Rostlab/prot_t5_xl_uniref50'
# legacy=True pins the tokenizer behaviour that is already used when the flag is
# left unset (see the transformers warning), and makes that choice explicit.
tokenizer = T5Tokenizer.from_pretrained(BASE_MODEL, legacy=True)
model = T5EncoderModel.from_pretrained(BASE_MODEL).cuda()

You are using the default legacy behaviour of the <class 'transformers.models.t5.tokenization_t5.T5Tokenizer'>. This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thoroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565


In [3]:
# Attach the 'clalign/CLAlignT5' PEFT adapter to the base encoder loaded above.
# The isinstance guard keeps this cell idempotent: re-running it would otherwise
# wrap the already-adapted PeftModel inside a second PeftModel.
if not isinstance(model, PeftModel):
    model = PeftModel.from_pretrained(model, 'clalign/CLAlignT5')

In [4]:
def align(seq1, seq2):
    """Align two protein sequences using CLAlign embeddings.

    Both sequences are embedded in one `get_embs` call with the notebook-level
    `tokenizer` and `model`. The score matrix is the matrix product of the first
    embedding with the transpose of the second. That matrix is handed to
    `fast_align` together with the two sequences.

    NOTE(review): each embedding is presumably a (length, dim) per-residue array,
    which would make the product a (len1, len2) residue-pair score matrix. The
    positional arguments 2048 and True are forwarded to `get_embs` unchanged;
    check its signature for their meaning.

    Returns:
        Whatever `fast_align` returns; it is used below as an alignment result
        exposing `.aln1` / `.aln2`.
    """
    embeddings = get_embs([seq1, seq2], tokenizer, model, 2048, True)
    first, second = embeddings[0], embeddings[1]
    pair_scores = first @ second.T
    return fast_align(seq1, seq2, pair_scores)

In [5]:
import numpy as np
from clalign import ProteinSeq, AlignmentResult, f1score

def _write_alignment(path, aln_res):
    """Write a pairwise alignment to ``path`` as two FASTA records, ``>p1`` and ``>p2``.

    The record bodies are ``aln_res.aln1`` and ``aln_res.aln2``, one per line.
    """
    with open(path, 'w') as faln:
        print('>p1', aln_res.aln1, '>p2', aln_res.aln2, sep='\n', file=faln)


# Matches TMalign output lines such as "TM-score= 0.49928 (if normalized by ...)".
_TM_SCORE_LINE = re.compile(r'TM-score=\s*(\d+(?:\.\d+)?)')


def _tm_score(pdb1, pdb2, aln_path):
    """Superpose two structures with TMalign using a fixed alignment; return the best TM-score.

    Runs ``TMalign pdb1 pdb2 -I aln_path`` as an argument list (no shell), so
    file names coming from the CSV are never interpreted by a shell. It returns
    the largest value found on output lines starting with ``TM-score=``.
    TMalign may print several such lines, presumably one per normalising chain;
    confirm against the TMalign docs.

    Raises:
        FileNotFoundError: if the ``TMalign`` executable is not on PATH.
        RuntimeError: if TMalign prints no parseable TM-score.
    """
    proc = subprocess.run(['TMalign', pdb1, pdb2, '-I', aln_path],
                          capture_output=True, text=True)
    scores = [float(m.group(1))
              for m in map(_TM_SCORE_LINE.match, proc.stdout.splitlines()) if m]
    if not scores:
        raise RuntimeError(
            f'TMalign produced no TM-score for {pdb1} vs {pdb2} '
            f'(exit code {proc.returncode}):\n{proc.stderr or proc.stdout}'
        )
    return max(scores)


def test(data):
    """Benchmark CLAlign against the reference alignments of one dataset.

    Reads ``data/{data}.csv``: a header row, then rows of
    ``name, pdb1, pdb2, seq1, seq2, aln1, aln2``. For each row it:

    * aligns the two sequences with `align`;
    * scores that alignment against the reference one with `f1score`
      (indexed below as precision, recall, F1);
    * writes it to ``results/{data}/clalign/{name}.txt``;
    * runs TMalign on ``data/{data}/{name}/{pdb1}`` and
      ``data/{data}/{name}/{pdb2}`` with that alignment.

    Per-pair scores are printed and also written tab-separated to
    ``results/{data}/clalign.txt``. A dataset-wide mean is printed at the end.

    Args:
        data: Dataset name, e.g. 'malidup' or 'malisam'.

    Returns:
        List of TM-scores, one per pair, in CSV order.

    Raises:
        RuntimeError: if TMalign reports no TM-score for a pair.
        ValueError: if a CSV row does not have exactly seven fields.
    """
    csv_path = f'data/{data}.csv'
    out_dir = f'results/{data}/clalign'
    # Also creates results/{data}/, which holds the summary file opened below.
    os.makedirs(out_dir, exist_ok=True)

    f_list, tms = [], []
    with open(csv_path, newline='') as fp, open(f'results/{data}/clalign.txt', 'w') as fout:
        reader = csv.reader(fp)
        next(reader, None)  # skip the header row
        for row in reader:
            if not row:
                continue  # blank line, e.g. a trailing newline at EOF
            name, pdb1, pdb2, s1, s2, aln1, aln2 = (field.strip() for field in row)
            seq1, seq2 = ProteinSeq(s1), ProteinSeq(s2)
            manual = AlignmentResult(seq1, seq2, aln1, aln2)
            aln_res = align(seq1, seq2)
            prf = f1score(manual, aln_res)
            f_list.append(prf)

            aln_path = f'{out_dir}/{name}.txt'
            _write_alignment(aln_path, aln_res)
            tm = _tm_score(f'data/{data}/{name}/{pdb1}', f'data/{data}/{name}/{pdb2}', aln_path)
            tms.append(tm)

            print(f'{name}: P: {prf[0]:.3f}, R: {prf[1]:.3f}, F: {prf[2]:.3f}, S: {tm:.5f}')
            print(name, prf[0], prf[1], prf[2], tm, file=fout, sep='\t')

    if not f_list:
        # Averaging an empty list would yield NaN and break the 3-way unpack below.
        print(f'total: 0 (no alignment pairs found in {csv_path})')
        return tms
    p_, r_, f_ = np.asarray(f_list).mean(axis=0)
    print(f'total: {len(f_list)}, P: {p_:.3f}, R: {r_:.3f}, F: {f_:.3f}, TM-score: {np.mean(tms):.5f}')
    return tms

In [6]:
# Benchmark on the MALIDUP dataset and keep the returned per-pair TM-scores.
malidup_tms = test(data='malidup')

d19hca_: P: 0.703, R: 0.703, F: 0.703, S: 0.49928
d1a4pa_: P: 0.650, R: 0.650, F: 0.650, S: 0.49066
d1a4sa_: P: 0.561, R: 0.561, F: 0.561, S: 0.58717
d1a6da1: P: 0.240, R: 0.258, F: 0.249, S: 0.31504
d1a8l_1: P: 0.863, R: 0.863, F: 0.863, S: 0.73713
d1af2a1: P: 0.667, R: 0.677, F: 0.672, S: 0.54023
d1afwb1: P: 0.739, R: 0.739, F: 0.739, S: 0.63921
d1ahja_: P: 0.889, R: 0.889, F: 0.889, S: 0.37244
d1ahua1: P: 0.559, R: 0.559, F: 0.559, S: 0.38793
d1ai3__: P: 0.262, R: 0.341, F: 0.297, S: 0.33178
d1aj8a_: P: 0.000, R: 0.000, F: 0.000, S: 0.18884
d1ako__: P: 0.755, R: 0.763, F: 0.759, S: 0.55926
d1ala___1: P: 1.000, R: 1.000, F: 1.000, S: 0.91556
d1ala___2: P: 0.986, R: 0.986, F: 0.986, S: 0.73035
d1ala___3: P: 0.973, R: 0.973, F: 0.973, S: 0.74045
d1ala___4: P: 1.000, R: 1.000, F: 1.000, S: 0.90142
d1anv_2: P: 0.250, R: 0.288, F: 0.268, S: 0.28842
d1aoa_1: P: 0.794, R: 0.833, F: 0.813, S: 0.63398
d1aora2: P: 0.681, R: 0.689, F: 0.685, S: 0.60356
d1at0__: P: 0.778, R: 0.790, F: 0.784, S: 

In [7]:
# Benchmark on the MALISAM dataset and keep the returned per-pair TM-scores.
malisam_tms = test(data='malisam')

d1a05a_d1dgsa3: P: 0.392, R: 0.403, F: 0.397, S: 0.34854
d1a05a_d1j71a_: P: 0.000, R: 0.000, F: 0.000, S: 0.21305
d1a05a_d1rblm_: P: 0.494, R: 0.506, F: 0.500, S: 0.45011
d1a2za_d1ghha_: P: 0.103, R: 0.104, F: 0.104, S: 0.30960
d1a2za_d1u9da_: P: 0.323, R: 0.375, F: 0.347, S: 0.30102
d1a7j__d1kafa_: P: 0.145, R: 0.196, F: 0.167, S: 0.21667
d1a7j__d2if1__: P: 0.592, R: 0.714, F: 0.647, S: 0.42443
d1aa7a_d1b68a_: P: 0.000, R: 0.000, F: 0.000, S: 0.20846
d1aa7a_d1qkra_: P: 0.022, R: 0.022, F: 0.022, S: 0.19060
d1ac5__d1jroa3: P: 0.139, R: 0.139, F: 0.139, S: 0.26876
d1ac5__d1vk0a_: P: 0.522, R: 0.529, F: 0.526, S: 0.38105
d1adja2d1drw_2: P: 0.528, R: 0.573, F: 0.550, S: 0.38127
d1aora2d1dkza2: P: 0.260, R: 0.257, F: 0.259, S: 0.25733
d1axn__d1nkta3: P: 0.000, R: 0.000, F: 0.000, S: 0.17349
d1axn__d1sf9a_: P: 0.000, R: 0.000, F: 0.000, S: 0.30260
d1b05a_d1lc0a2: P: 0.016, R: 0.016, F: 0.016, S: 0.20032
d1b05a_d1nkia_: P: 0.040, R: 0.038, F: 0.039, S: 0.21122
d1b0pa2d1k0da2: P: 0.000, R: 0.