In [None]:
import torch
import torch.nn as nn
import json
import csv
import numpy as np
from bindenergy import *
#%env TORCH_EXTENSIONS_DIR=.

In [None]:
# Fixed framework segments. Sampled CDRs are spliced between them to form the
# full sequence: frame1 + CDR1 + frame2 + CDR2 + frame3 + CDR3 + frame4.
# NOTE(review): presumably a single-domain antibody (VHH) scaffold — confirm
# the source of these framework sequences.
frame1 = "EVQLVESGGGLVQAGDSLRLSCTASG"
frame2 = "MGWFRQAPGKEREFVASIS"
frame3 = "TYYADSVKGRFTISRDDARNTVYLQMNSLKPEDTAVYYCNM"
frame4 = "EYWGQGTQVTVSS"

def has_no_repeat(seq):
    """Return True when no residue occurs three times consecutively in ``seq``.

    Only the residues in ``ALPHABET[1:]`` are checked. ``ALPHABET[0]`` is
    skipped (presumably the padding token).
    """
    triplets = (residue * 3 for residue in ALPHABET[1:])
    return not any(triplet in seq for triplet in triplets)

#### DSMBind Student Model

In [None]:
class SequenceModel(nn.Module):
    """Sequence-only binary classifier over tokenized CDR residues.

    Residue indices are one-hot encoded over ``ALPHABET`` and contextualized
    by a bidirectional SRU++ encoder. The encoder output is mean-pooled over
    positions and mapped to a single logit by a two-layer MLP head.

    Args:
        args: configuration object providing ``hidden_size``, ``depth`` and
            ``dropout``.
    """

    def __init__(self, args):
        super().__init__()
        self.bce_loss = nn.BCEWithLogitsLoss()
        # Each direction gets hidden_size // 2 units. W_o expects hidden_size
        # input features, so the bidirectional output is presumably the
        # concatenation of both directions.
        self.encoder = SRUpp(
                len(ALPHABET),
                args.hidden_size // 2,
                args.hidden_size // 2,
                num_layers=args.depth,
                dropout=args.dropout,
                bidirectional=True,
        )
        self.W_o = nn.Sequential(
                nn.Linear(args.hidden_size, args.hidden_size),
                nn.ReLU(),
                nn.Linear(args.hidden_size, 1)
        )
        # Xavier-init every weight matrix; 1-D parameters (biases etc.) keep
        # their default initialization.
        for param in self.parameters():
            if param.dim() > 1:
                nn.init.xavier_uniform_(param)

    def _logits(self, cdr_S):
        """Return one unnormalized logit per sequence.

        Args:
            cdr_S: integer residue indices, presumably shaped ``[B, N]``.
                Positions with index 0 are passed to the encoder as padding.

        Returns:
            Tensor of shape ``[B]`` with raw (pre-sigmoid) scores.
        """
        # True at padding positions (index <= 0).
        pad_mask = ~(cdr_S > 0)
        one_hot = nn.functional.one_hot(cdr_S, num_classes=len(ALPHABET)).float()
        # The encoder is fed length-first: [N, B, |ALPHABET|] plus an [N, B] mask.
        hidden, _, _ = self.encoder(
                one_hot.transpose(0, 1),
                mask_pad=pad_mask.transpose(0, 1),
        )
        hidden = hidden.transpose(0, 1)  # [B, N, H]
        # NOTE(review): unmasked mean, so padding positions contribute to the
        # pooled vector. Kept as-is to match the trained checkpoint; consider a
        # masked mean only when retraining.
        pooled = hidden.mean(dim=1)
        return self.W_o(pooled).squeeze(-1)

    def forward(self, cdr_S, label):
        """Score ``cdr_S`` and compute the BCE-with-logits loss against ``label``.

        Args:
            cdr_S: integer residue indices (see ``_logits``).
            label: float targets in [0, 1], one per sequence.

        Returns:
            ``(loss, probs)``: the scalar BCE loss and per-sequence sigmoid
            probabilities.
        """
        logit = self._logits(cdr_S)
        return self.bce_loss(logit, label), torch.sigmoid(logit)

    def predict(self, cdr_S):
        """Return per-sequence sigmoid probabilities for ``cdr_S``.

        This does not disable autograd; wrap the call in ``torch.no_grad()``
        for inference.
        """
        return torch.sigmoid(self._logits(cdr_S))

In [None]:
# Load the trained student model. The checkpoint is a (state_dict, args)
# tuple; `args` is a pickled config object, which `torch.load`'s
# weights_only=True default (PyTorch >= 2.6) refuses to unpickle.
# SECURITY: weights_only=False runs arbitrary pickle code — only load
# checkpoints from a trusted source.
model_ckpt, model_args = torch.load('ckpts/PDL1.ckpt', weights_only=False)
model = SequenceModel(model_args).cuda()
model.load_state_dict(model_ckpt)
model.eval()
# Residue sampling weights, indexed like ALPHABET (element 0 is skipped when
# sampling).
aaratio = torch.load('ckpts/aaratio.ckpt')

In [None]:
# Rejection-sample random CDR1/CDR2/CDR3 triples from the residue prior
# `aaratio`. Each full batch of accepted candidates is scored with the student
# model. Every candidate scoring above SCORE_THRESHOLD is printed together with
# the full sequence obtained by splicing its CDRs into frame1..frame4.
CDR_LENGTHS = (7, 5, 9)   # lengths of CDR1, CDR2, CDR3
MAX_WY = (2, 1, 2)        # max number of W + Y residues allowed in each CDR
BATCH_SIZE = 10000        # accepted candidates scored per model call
NUM_ROUNDS = 100000       # number of sample-and-score rounds
SCORE_THRESHOLD = 0.8     # print candidates scoring strictly above this


def _split_cdrs(cdr):
    """Split a concatenated CDR string into consecutive pieces of CDR_LENGTHS."""
    pieces, start = [], 0
    for length in CDR_LENGTHS:
        pieces.append(cdr[start:start + length])
        start += length
    return tuple(pieces)


def _passes_filters(cdr, cdrs):
    """Return True if the candidate has no 3-residue runs and respects MAX_WY."""
    # has_no_repeat is applied to the concatenated CDRs. Runs spanning a CDR
    # boundary are therefore also rejected, even though those residues are not
    # adjacent in the full sequence.
    if not has_no_repeat(cdr):
        return False
    return all(c.count('W') + c.count('Y') <= limit for c, limit in zip(cdrs, MAX_WY))


residue_prior = aaratio[1:]  # never sample ALPHABET[0]
device = next(model.parameters()).device
for _ in range(NUM_ROUNDS):
    batch = []
    cdr_list = []
    while len(batch) < BATCH_SIZE:
        # multinomial indexes aaratio[1:], so shift by +1 to index ALPHABET.
        sample = torch.multinomial(residue_prior, num_samples=sum(CDR_LENGTHS), replacement=True) + 1
        cdr = ''.join(ALPHABET[aa] for aa in sample.tolist())
        cdrs = _split_cdrs(cdr)
        if _passes_filters(cdr, cdrs):
            batch.append(sample)
            cdr_list.append(cdrs)

    S = torch.stack(batch, dim=0).to(device)
    with torch.no_grad():
        prob = model.predict(S)

    for (cdr1, cdr2, cdr3), score in zip(cdr_list, prob.tolist()):
        if score > SCORE_THRESHOLD:
            seq = frame1 + cdr1 + frame2 + cdr2 + frame3 + cdr3 + frame4
            print(f'{score:.4f} {cdr1} {cdr2} {cdr3} {seq}')