In [1]:
import os
import torch
import torch_geometric as tg
from torch_geometric.utils import remove_self_loops, add_self_loops

from data import load_data
from graph_generator import HomologyGraphStats
from build_pipeline import build_dataloader, build_model, build_optimizer
from train_and_evaluate import train, evaluate
from utils import parse_args, set_seed, check_adjmatrix, make_output_folder, get_root_logger, adj_mul, save_info

import warnings
warnings.filterwarnings('ignore')

In [2]:
args = parse_args()
set_seed(args.seed)

args.gpu_id = 0
args.batch_size = 256
args.dataset = "GB1" # GB1, Fluorescence, AAV
args.split = "low_vs_high"
args.feature_generator = "CNN"
args.model = "ResiduePMPNN"
args.fine_tuned_generator = True
args.light_residue_feat = False
args.oh_residue_feat = False
args.full_residue_feat = True
args.former_layers = 3
args.gnn_layers = 2
args.knn_k = 10

args.device = torch.device(
    'cuda:%s' % args.gpu_id if torch.cuda.is_available() else 'cpu'
)
output_folder = make_output_folder(args=args)
logger = get_root_logger(folder=output_folder)
logger.warning(args)

10:08:02   Namespace(batch_size=256, dataset='GB1', device=device(type='cuda', index=0), encoder='ESM-1b', epochs=100, feature_generator='CNN', fine_tuned_generator=True, fix_encoder=True, fix_mlp=False, flow='source_to_target', force_undirected=False, former_layers=3, full_residue_feat=True, gnn='GCN', gnn_dropout=0.5, gnn_hid_channels=128, gnn_layers=2, gpu_id=0, homology='knn_fast', knn_k=10, lamda_edge=0.1, light_residue_feat=False, load_protein_feat=True, loop=True, lr=0.0001, lr_ratio=0.1, mlp_dropout=0, mlp_hid_channels=-1, mlp_layers=2, model='ResiduePMPNN', no_test_test=False, num_anchor=-1, num_heads=8, oh_residue_feat=False, only_from_train=False, optimizer='Adam', pretrained_encoder=True, pretrained_mlp=False, rb_order=0, seed=0, split='low_vs_high', tau=1.0, use_act=True, use_bn=True, use_edge_attr=False, use_edge_loss=False, use_jk=False, use_residual=True, weight_decay=0.005)


In [3]:
# set up dataset
data, protein_feature, residue_feature, sequences, wt_sequence, similarity_matrix = load_data(
    args=args, logger=logger, load_protein_feat=args.load_protein_feat,
    feature_generator=args.feature_generator, pretrained=args.fine_tuned_generator,
    oh_residue_feat=args.oh_residue_feat, full_residue_feat=args.full_residue_feat
)
args.wt_index = sequences.index(wt_sequence)

# print("is saving feat file")
# torch.save(residue_feature, "../input/residue_features_%s-%s.pt" % (args.dataset, args.split))

if args.model in ["Gnn", "Avg", "GnnMlp", "Transformer", "NodeFormer",
                  "ResidueMPNN", "ResidueFormer"]:
    logger.warning("Initialising graph structure.")
    homology_stats_graph = HomologyGraphStats(
        dataset=args.dataset, homology=args.homology, knn_k=args.knn_k, loop=args.loop, flow=args.flow,
        force_undirected=args.force_undirected,
        only_from_train=args.only_from_train, no_test_test=args.no_test_test,
        train_mask=data.train_mask, valid_mask=data.valid_mask, test_mask=data.test_mask
    )
    data.edge_index, data.edge_attr = homology_stats_graph(x=data.x, logger=logger, similarity_matrix=similarity_matrix)
    logger.warning("Average in-degree: %s" % tg.utils.degree(data.edge_index[1]).mean().item())
    logger.warning(check_adjmatrix(data))

if args.batch_size > 0:
    train_loader, valid_loader, test_loader = build_dataloader(
        data=data, batch_size=args.batch_size, num_neighbors=[-1], shuffle=True
    )
    if args.model in ["NodeFormer", "ResidueFormer"]:
        for loader in [train_loader, valid_loader, test_loader]:
            for idx, batch in enumerate(loader):
                ### Adj storage for relational bias ###
                adjs = []
                adj, _ = remove_self_loops(batch.edge_index)
                adj, _ = add_self_loops(adj, num_nodes=batch.num_nodes)
                adjs.append(adj)
                for i in range(args.rb_order - 1): # edge_index of high order adjacency
                    adj = adj_mul(adj, adj, batch.num_nodes)
                    adjs.append(adj)
                batch.adjs = adjs
                loader[idx] = batch
else:
    train_loader, valid_loader, test_loader = None, None, None
    if args.model in ["NodeFormer", "ResidueFormer"]:
        ### Adj storage for relational bias ###
        adjs = []
        adj, _ = remove_self_loops(data.edge_index)
        adj, _ = add_self_loops(adj, num_nodes=data.num_nodes)
        adjs.append(adj)
        for i in range(args.rb_order - 1):  # edge_index of high order adjacency
            adj = adj_mul(adj, adj, data.num_nodes)
            adjs.append(adj)
        data.adjs = adjs

if args.model in ["ResidueMPNN", "ResiduePMPNN", "ResidueCnnPMPNN", "ResidueFormer", "ResiduePFormer"]:
    args.num_residue = max([len(seq) for seq in sequences])
    args.num_sequences = len(sequences)
logger.warning(data)

10:08:02   processing protein features
10:08:02   Loaded protein feature at /home/zhiqiang/Homology-PyG/input/protein_features/protein_features_GB1_CNN_low_vs_high.pt
10:08:02   processing residue features
10:08:21   There is no fine-tuned parameters, use original parameters.


Loading residue pretrained embedding from ESM-1b: 100%|██████████| 69/69 [03:41<00:00,  3.21s/it]


10:12:03   Data(x=[8733, 1024], y=[8733, 1], train_mask=[8733], valid_mask=[8733], test_mask=[8733], num_nodes=8733, residue_dim=1280)


In [4]:
# set up model
model = build_model(
    args=args, in_channels=data.num_features, residue_dim=data.residue_dim, out_channels=data.y.size(1), logger=logger,
).to(args.device)
logger.warning(model)
optimizer = build_optimizer(args=args, model=model, logger=logger)

10:12:03   ResiduePMPNN(
  (lin_x_residue): Linear(in_features=1280, out_features=128, bias=True)
  (lin_x_seq): Linear(in_features=1024, out_features=128, bias=True)
  (residue_formers): ModuleList(
    (0): ResidueFormerConv(
      (Wk): Linear(in_features=128, out_features=512, bias=True)
      (Wq): Linear(in_features=128, out_features=512, bias=True)
      (Wv): Linear(in_features=128, out_features=512, bias=True)
      (Wo): Linear(in_features=512, out_features=128, bias=True)
    )
    (1): ResidueFormerConv(
      (Wk): Linear(in_features=128, out_features=512, bias=True)
      (Wq): Linear(in_features=128, out_features=512, bias=True)
      (Wv): Linear(in_features=128, out_features=512, bias=True)
      (Wo): Linear(in_features=512, out_features=128, bias=True)
    )
    (2): ResidueFormerConv(
      (Wk): Linear(in_features=128, out_features=512, bias=True)
      (Wq): Linear(in_features=128, out_features=512, bias=True)
      (Wv): Linear(in_features=128, out_features=512, 

In [5]:
best_epoch = best_train = best_val = best_test = 0
for epoch in range(1, args.epochs + 1):
    if args.batch_size > 0:
        loss, _, _ = train(
            args=args, epoch=epoch, data=train_loader, model=model,
            protein_feature=protein_feature, residue_feature=residue_feature,
            optimizer=optimizer, sequences=sequences, device=args.device
        )
    else:
        loss, _, _ = train(
            args=args, epoch=epoch, data=data, model=model,
            protein_feature=protein_feature, residue_feature=residue_feature,
            optimizer=optimizer, sequences=sequences, device=args.device
        )
    logger.warning("epoch {}, loss {:.5f}".format(epoch, loss))

    if args.batch_size > 0:
        train_score = evaluate(
            args=args, data=train_loader, model=model,
            protein_feature=protein_feature, residue_feature=residue_feature,
            sequences=sequences, device=args.device
        )
        valid_score = evaluate(
            args=args, data=valid_loader, model=model,
            protein_feature=protein_feature, residue_feature=residue_feature,
            sequences=sequences, device=args.device
        )
    else:
        train_index = torch.nonzero(data.train_mask).squeeze()
        train_score = evaluate(
            args=args, data=data, model=model,
            protein_feature=protein_feature, residue_feature=residue_feature,
            target_index=train_index, sequences=sequences, device=args.device
        )
        valid_index = torch.nonzero(data.valid_mask).squeeze()
        valid_score = evaluate(
            args=args, data=data, model=model,
            protein_feature=protein_feature, residue_feature=residue_feature,
            target_index=valid_index, sequences=sequences, device=args.device
        )

    if valid_score > best_val:
        logger.warning("updating best performance.")
        if args.batch_size > 0:
            test_score = evaluate(
                args=args, data=test_loader, model=model,
                protein_feature=protein_feature, residue_feature=residue_feature,
                sequences=sequences, device=args.device
            )
        else:
            test_index = torch.nonzero(data.test_mask).squeeze()
            test_score = evaluate(
                args=args, data=data, model=model,
                protein_feature=protein_feature, residue_feature=residue_feature,
                target_index=test_index, sequences=sequences, device=args.device
            )
        best_epoch, best_train, best_val, best_test = epoch, train_score, valid_score, test_score
        logger.warning("epoch {}, train {:.5f}, valid {:.5f}, test {:.5f}".format(
            best_epoch, train_score, best_val, best_test
        ))
        if test_score > 0.57:
            break
    else:
        logger.warning("epoch {}, train {:.5f}, valid {:.5f}".format(
            epoch, train_score, valid_score
        ))

10:12:16   epoch 1, loss 0.31415
10:12:20   updating best performance.
10:12:23   epoch 1, train 0.84274, valid 0.81858, test 0.38659
10:12:36   epoch 2, loss 0.11884
10:12:40   updating best performance.
10:12:43   epoch 2, train 0.85382, valid 0.82592, test 0.43144
10:12:56   epoch 3, loss 0.08198
10:13:00   updating best performance.
10:13:03   epoch 3, train 0.85603, valid 0.82594, test 0.44679
10:13:16   epoch 4, loss 0.06684
10:13:20   updating best performance.
10:13:23   epoch 4, train 0.85839, valid 0.82705, test 0.46354
10:13:37   epoch 5, loss 0.05754
10:13:40   updating best performance.
10:13:43   epoch 5, train 0.86073, valid 0.82848, test 0.47847
10:13:56   epoch 6, loss 0.05082
10:14:00   updating best performance.
10:14:03   epoch 6, train 0.86279, valid 0.82951, test 0.49040
10:14:16   epoch 7, loss 0.04565
10:14:20   updating best performance.
10:14:23   epoch 7, train 0.86461, valid 0.83093, test 0.49972
10:14:37   epoch 8, loss 0.04148
10:14:40   updating best perf

KeyboardInterrupt: 

In [11]:
from train_and_evaluate import evaluate_function

preds = []
targets = []

for idx, batch in enumerate(test_loader):
    if ("llm" in args.model.lower()) or ("cnn" in args.model.lower()) or ("bert" in args.model.lower()) or ("protbert" in args.model.lower()):
        batch_sequences = [sequences[idx] for idx in batch.n_id]
    else:
        batch_sequences = None

    _pred = evaluate_function(
        args=args, data=batch, model=model,
        protein_feature=protein_feature, residue_feature=residue_feature, n_id=batch.n_id,
        sequences=batch_sequences, device=args.device,
    )
    preds.append(_pred.squeeze(-1)[:batch.batch_size])
    targets.append(batch.y.squeeze(-1)[:batch.batch_size])
preds = torch.cat(preds, dim=-1)
targets = torch.cat(targets, dim=-1)

In [14]:
import pandas as pd
df = pd.read_csv("../data/FLIP/gb1/splits/one_vs_rest.csv")
one_id = df[df['set'] == 'train'].index.tolist() + df[df['validation'] == True].index.tolist()

df = pd.read_csv("../data/FLIP/gb1/splits/two_vs_rest.csv")
two_id = df[df['set'] == 'train'].index.tolist() + df[df['validation'] == True].index.tolist()

df = pd.read_csv("../data/FLIP/gb1/splits/three_vs_rest.csv")
three_id = df[df['set'] == 'train'].index.tolist() + df[df['validation'] == True].index.tolist()
test_index = torch.nonzero(data.test_mask).squeeze()

one_index = []
two_index = []
three_index = []
other_index = []
for index in range(len(test_index)):
    if test_index[index] in one_id:
        one_index.append(index)
    elif test_index[index] in two_id:
        two_index.append(index)
    elif test_index[index] in three_id:
        three_index.append(index)
    else:
        other_index.append(index)
one_index = torch.LongTensor(one_index)
two_index = torch.LongTensor(two_index)
three_index = torch.LongTensor(three_index)
other_index = torch.LongTensor(other_index)

In [15]:
from utils import spearmanr
one = spearmanr(pred=preds[one_index], target=targets[one_index])
two = spearmanr(pred=preds[two_index], target=targets[two_index])
three = spearmanr(pred=preds[three_index], target=targets[three_index])
other = spearmanr(pred=preds[other_index], target=targets[other_index])
print(one, two, three, other)

tensor(0.5818) tensor(0.5091) tensor(0.5452) tensor(0.5715)
