In [80]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [81]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm

In [82]:
from conv import ProteinConv
from pool import ProteinMaxPool
from concat import ProteinConcat
class ProteinModel(nn.Module):
    def __init__(self, input_dim, dropout = 0.2, activation = "sigmoid", stride = 2):
        super(ProteinModel, self).__init__()
        
        activations = {
        'relu': F.relu,
        'sigmoid': F.sigmoid,
        'tanh': torch.tanh
        }
        
        self.activation = activations[activation]
        
        self.l1 = nn.Linear(input_dim, 100)
        self.pconv1 = ProteinConv(no_filters=50, no_dims=100, no_channels=1, window_size=10, dropout_p = dropout, stride = stride, activation = activation)
        self.pool1  = ProteinMaxPool(3)
        
        self.l2 = nn.Linear(100, 50)
        self.pconcat = ProteinConcat(no_dims=50, no_channels=50, window_size=5, op_size=20, stride = stride, dropout_p = 0.2, activation = activation)
        
        self.outlayer = nn.Linear(20, 2)
        self.softlayer = nn.Softmax(dim=1)
        self.input_dim = input_dim
        
        
    def forward(self, p1, p2):
        
        N1, H1, C1, D1 = p1.shape
        N2, H2, C2, D2 = p2.shape
        
        assert (N1, C1, D1) == (N2, C2, D2)
        assert D1 == self.input_dim

        p1 = self.l1(p1)
        p2 = self.l1(p2)
        p1, p2 = self.pconv1(p1, p2)
        p1     = self.pool1(self.activation(p1))
        p2     = self.pool1(self.activation(p2))
        
        p1 = self.l2(p1)
        p2 = self.l2(p2)
        pout = self.pconcat(p1, p2)
        pout = self.softlayer(self.outlayer(pout))
        return pout

In [83]:
import h5py
from torch.utils.data import Dataset, DataLoader
import pandas as pd

embfile = "/afs/csail.mit.edu/u/k/kdevko01/coip-vs-y2h-folder/data/networks/dscript-tt/y2h-coip.h5"
f = h5py.File(embfile)

trainfile = "/afs/csail.mit.edu/u/k/kdevko01/coip-vs-y2h-folder/data/networks/dscript-tt/coip_train.tsv"
dtr = pd.read_csv(trainfile, sep = "\t", header = None)

dtrain = pd.concat([dtr[dtr[2] == 1].sample(n=200), dtr[dtr[2] == 0].sample(n=1000)])
dval   = pd.concat([dtr[dtr[2] == 1].sample(n=50), dtr[dtr[2] == 0].sample(n=500)])
dval   = dval.drop(set(dval.index).intersection(set(dtrain.index)))

dtrain = pd.concat([dtrain, dtrain.loc[:, [1, 0, 2]]]).reset_index(drop = True)
dval = pd.concat([dval, dval.loc[:, [1, 0, 2]]]).reset_index(drop = True)
 
len(dtrain), len(dval)

(2400, 1096)

In [84]:
testfile = "/afs/csail.mit.edu/u/k/kdevko01/coip-vs-y2h-folder/data/networks/dscript-tt/coip_test.tsv"
dtest = pd.read_csv(testfile, sep = "\t", header = None)
dtest = pd.concat([dtest[dtest[2] == 1].sample(n=100), dtest[dtest[2] == 0].sample(n=1000)])
dtest = pd.concat([dtest, dtest.loc[:, [1, 0, 2]]]).reset_index(drop = True)
len(dtest)

2200

In [85]:
import numpy as np
class ProtDataset(Dataset):
    def __init__(self, df, h5data, min_seqlen=75):
        self.df = df
        self.h5 = h5data
        self.min_seqlen = min_seqlen
    
    def __len__(self):
        return len(self.df)
    
    def __getitem__(self, id):
        p, q, w = self.df.iloc[id, :].values
        p1, p2, w = torch.tensor(np.array(self.h5[p]), dtype = torch.float32).squeeze(0).unsqueeze(1), torch.tensor(np.array(self.h5[q]), dtype = torch.float32).squeeze(0).unsqueeze(1), torch.tensor(w, dtype = torch.long)
        
        dim = p1.shape[2]
        
        p1seqlen = p1.shape[0]
        p2seqlen = p2.shape[0]
        if p1seqlen < self.min_seqlen:
            p1 = torch.cat([p1, torch.zeros(self.min_seqlen - p1seqlen, 1, dim, dtype = torch.float32)], dim = 0)
        if p2seqlen < self.min_seqlen:
            p2 = torch.cat([p2, torch.zeros(self.min_seqlen - p2seqlen, 1, dim, dtype = torch.float32)], dim = 0)
        return p1, p2, w
        
trdata = ProtDataset(dtrain, f)
tedata = ProtDataset(dtest, f)
valdata = ProtDataset(dval, f)

In [86]:
p, q, w = trdata[0]
p.shape

torch.Size([754, 1, 6165])

In [87]:
dev = torch.device("cuda:7")
lr = 0.1
no_ep = 5

In [88]:
trloader = DataLoader(trdata, batch_size = 1, shuffle = True)
valloader = DataLoader(valdata, batch_size = 1, shuffle = True)
model = ProteinModel(6165)
model = model.to(dev)
opt = torch.optim.Adam(model.parameters(), lr = lr)

In [89]:
from sklearn.metrics import average_precision_score

lossfn = torch.nn.CrossEntropyLoss()

def compute_aupr(op, target):
    if isinstance(op, torch.Tensor):
        op = op.numpy()
    if isinstance(target, torch.Tensor):
        target = target.numpy()
    
    op = np.argmax(op, axis = 1)
    return average_precision_score(op, target)
    

for ep in range(no_ep):
    running_loss = 0
    for i, data in enumerate(tqdm(trloader)):
        ps, qs, wt = data
        ps = ps.to(dev)
        qs = qs.to(dev)
        wt = wt.to(dev)
        opt.zero_grad()
        out = model(ps, qs)
        loss = lossfn(out, wt)
        loss.backward()
        opt.step()
        
        if dev.type == "cuda":
            ps = ps.to("cpu")
            qs = qs.to("cpu")
            wt = wt.to("cpu")
            loss = loss.to("cpu")
        running_loss += loss.item()
    with torch.no_grad():
        val_loss = 0
        results  = []
        targets  = []
        for j, data in enumerate(valloader):
            ps, qs, wt = data
            ps = ps.to(dev)
            qs = qs.to(dev)
            wt = wt.to(dev)
            out = model(ps, qs)
            loss = lossfn(out, wt)
        
            if dev.type == "cuda":
                ps = ps.to("cpu")
                qs = qs.to("cpu")
                wt = wt.to("cpu")
                loss = loss.to("cpu")
                out = out.to("cpu")
            val_loss += loss.item()
            results += out.numpy()[:, 1].tolist()
            targets += wt.numpy().tolist()
        auprval = average_precision_score(targets, results)
    print(f"Epoch {ep+1}: Training Loss : {running_loss / (i+1)}: AUPR: {auprval}")

100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2400/2400 [01:18<00:00, 30.75it/s]


Epoch 1: Training Loss : 0.48005625932166973: AUPR: 0.09124087591240876


100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2400/2400 [01:18<00:00, 30.68it/s]


Epoch 2: Training Loss : 0.47992831965287525: AUPR: 0.09124087591240876


100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2400/2400 [01:18<00:00, 30.42it/s]


Epoch 3: Training Loss : 0.47992831965287525: AUPR: 0.09124087591240876


100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2400/2400 [01:17<00:00, 30.89it/s]


Epoch 4: Training Loss : 0.47992831965287525: AUPR: 0.09124087591240876


100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2400/2400 [01:20<00:00, 29.71it/s]


Epoch 5: Training Loss : 0.4803449863071243: AUPR: 0.09124087591240876


In [None]:
p1, p2, w = next(iter(trloader))
p1.shape,p2.shape

In [73]:
import torch
x = torch.randn(2, 15, 3, dtype = torch.float32)
x.unfold(1, 4, 2).shape

torch.Size([2, 6, 3, 4])

In [83]:


no_batch, no_height = 2, 20
no_dims, no_channels, window_size, op_size = 10, 4, 5, 20

p1 = torch.randn(no_batch, no_height, no_channels, no_dims, dtype = torch.float32)
p2 = torch.randn(no_batch, no_height, no_channels, no_dims, dtype = torch.float32)

pconcat = ProteinConcat(no_dims, no_channels, window_size, op_size, stride = 1, dropout_p = 0.2, activation = "tanh")

In [91]:
from conv import ProteinConv
no_filters = 3
pconv = ProteinConv(no_filters, no_dims, no_channels, window_size, stride = 1, dropout_p = 0.2, activation = "tanh")

In [92]:
o1, o2 = pconv(p1, p2)

torch.Size([2, 16, 1, 4, 10, 5])


In [88]:
o1.shape, o2.shape

(torch.Size([2, 16, 3, 10]), torch.Size([2, 16, 3, 10]))

In [94]:
from pool import ProteinMaxPool
pool = ProteinMaxPool(4)
o = pool(p1)
o.shape

torch.Size([2, 5, 4, 10])

In [91]:
from Bio import SeqIO

In [92]:
#!cd ..; ln -s /afs/csail.mit.edu/u/r/rsingh/work/corals/data-scratch1/STRING_foldseek_embeddings foldseek_emb

In [110]:
fasta = SeqIO.parse("../foldseek_emb/r1_foldseekrep_seq.fa", "fasta")

In [108]:
vocab = {}
i = 0
for rec in fasta:
    print(rec)
    break

ID: 9606.ENSP00000386340
Name: 9606.ENSP00000386340
Description: 9606.ENSP00000386340 AF2:AF-P63255-F1-model_v2.pdb.gz 9606.ENSP00000386340
Number of features: 0
Seq('DFAQLQPRDDDDPVQWQQAPNGTHGQCQQAAPPPRHRDDRHQWDDDNSHTHGPP...DDD')


In [105]:
# import json
# json.dump(vocab, open("../foldseek_vocab.json", "w"))

In [109]:
"9606.ENSP00000386340" in fasta

NotImplementedError: SeqRecord comparison is deliberately not implemented. Explicitly compare the attributes of interest.