In [1]:
import os
import glob
import numpy as np

from tqdm import tqdm


# Length normalization: pad or truncate each embedding to a fixed length

In [2]:
def get_feature(input_dir, output_dir, length, dim, pca_model="None"):
    """Pad or truncate each embedding file to `length` rows and save it as .npy.

    Every input is loaded and reshaped to ``(1, -1, dim)``. It is then
    truncated, or zero-padded at the end of axis 1, so the result has shape
    ``(1, length, dim)``. The result is written with ``np.save`` to
    ``<output_dir>/<name>.npy``, where ``name`` is the file's basename up to
    its first ".".

    Args:
        input_dir: Iterable of input file paths. Despite the name, this is a
            list of files, not a directory.
        output_dir: Destination directory; created if it does not exist.
        length: Target size of axis 1.
        dim: Size of the last axis. Files are read with ``np.load`` when
            ``dim == 768`` and with ``np.loadtxt`` otherwise.
        pca_model: Unused; kept only so existing call sites keep working.

    Raises:
        ValueError: if two different input files map to the same output
            name (they would otherwise silently overwrite each other), or if
            a file's element count is not a multiple of ``dim`` (from
            ``reshape``).
    """
    os.makedirs(output_dir, exist_ok=True)

    # Map output name -> source path up front, so a name collision is
    # reported before anything is written. Repeats of the very same path
    # are harmless and are processed once.
    sources = {}
    for path in input_dir:
        name = os.path.basename(path).split(".")[0]
        previous = sources.setdefault(name, path)
        if previous != path:
            raise ValueError(
                f"{previous!r} and {path!r} both map to output name {name!r}"
            )

    for name, path in tqdm(sources.items(), desc="Processing", unit="file"):
        if dim == 768:
            data = np.load(path).reshape(1, -1, dim)
        else:
            data = np.loadtxt(path).reshape(1, -1, dim)

        # Truncate (a no-op when already short enough), then zero-pad the
        # tail of axis 1 up to `length`.
        data = data[:, :length, :]
        missing = length - data.shape[1]
        if missing > 0:
            data = np.pad(data, [(0, 0), (0, missing), (0, 0)],
                          mode="constant", constant_values=0)
        np.save(os.path.join(output_dir, name), data)
            

In [3]:
# Embedding dimension: 1024, 768 or 1280 (must match the chosen FEATURE).
# get_feature() reads files with np.load only when dim == 768 and with
# np.loadtxt otherwise.
dim = 1024

# Number of rows every embedding is padded/truncated to.
length = 1000

# "pt" (ProtTrans), "tape", "esm2" or "esm1b"
FEATURE = "pt"


# ---------------------------- input dir ----------------------------
def feature_input_root(feature, dim, length):
    """Return the directory holding the raw embeddings produced by `feature`.

    ESM outputs live in a sub-directory keyed by model, dim and length;
    ProtTrans and TAPE outputs do not.

    Raises:
        ValueError: if `feature` is not "pt", "tape", "esm2" or "esm1b".
    """
    if feature == "pt":
        return "PT_out/example"
    if feature == "tape":
        return "TAPE_out/example"
    if feature in ("esm2", "esm1b"):
        return f"ESM_out/example/{feature}_d{dim}_L{length}"
    # Fail fast here, instead of leaving ch_train etc. undefined.
    raise ValueError(
        f"Unknown FEATURE {feature!r}; expected 'pt', 'tape', 'esm2' or 'esm1b'"
    )


def list_inputs(root, category, split):
    """Return the files in <root>/<category>/<split>/, sorted.

    Sorting makes the order deterministic; glob's order is arbitrary.
    """
    return sorted(glob.glob(f"{root}/{category}/{split}/*"))


input_root = feature_input_root(FEATURE, dim, length)
ch_train = list_inputs(input_root, "ionchannels", "train")
ch_test = list_inputs(input_root, "ionchannels", "test")
tr_train = list_inputs(input_root, "iontransporters", "train")
tr_test = list_inputs(input_root, "iontransporters", "test")
me_train = list_inputs(input_root, "membraneproteins", "train")
me_test = list_inputs(input_root, "membraneproteins", "test")

# Sanity check: number of input files found per category/split.
for files in (ch_train, ch_test, tr_train, tr_test, me_train, me_test):
    print(len(files))


# ---------------------------- output dir ----------------------------
# One output tree per FEATURE / dim / length setting.
output_root = f"get_feature/example/{FEATURE}_d{dim}_L{length}"
ch_train_out = f"{output_root}/ionchannels/train/"
ch_test_out = f"{output_root}/ionchannels/test/"
tr_train_out = f"{output_root}/iontransporters/train/"
tr_test_out = f"{output_root}/iontransporters/test/"
me_train_out = f"{output_root}/membraneproteins/train/"
me_test_out = f"{output_root}/membraneproteins/test/"

# ---------------------------- function ----------------------------
# Pad/truncate every input embedding to `length` and save it under the
# matching output directory (same calls and order as before).
for files, out_dir in (
    (ch_train, ch_train_out),
    (ch_test, ch_test_out),
    (tr_train, tr_train_out),
    (tr_test, tr_test_out),
    (me_train, me_train_out),
    (me_test, me_test_out),
):
    get_feature(files, out_dir, length, dim)



10
10
10
10
10
10


Processing: 100%|████████████████████████████████████████████████████████████████████| 10/10 [00:05<00:00,  1.99file/s]
Processing: 100%|████████████████████████████████████████████████████████████████████| 10/10 [00:06<00:00,  1.48file/s]
Processing: 100%|████████████████████████████████████████████████████████████████████| 10/10 [00:02<00:00,  4.44file/s]
Processing: 100%|████████████████████████████████████████████████████████████████████| 10/10 [00:03<00:00,  2.81file/s]
Processing: 100%|████████████████████████████████████████████████████████████████████| 10/10 [00:03<00:00,  2.86file/s]
Processing: 100%|████████████████████████████████████████████████████████████████████| 10/10 [00:03<00:00,  3.03file/s]
