In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as du
import torch.nn.functional as F
from torch.utils import data
from torch.utils.data import Dataset
from torchvision import datasets, transforms
from collections import defaultdict
import matplotlib.pyplot as plt
from tqdm import tqdm
import numpy as np
import pandas as pd
import sidechainnet as scn
import random
import sklearn
import einops
from einops import rearrange

In [2]:
data = scn.load(casp_version=7, with_pytorch="dataloaders", 
                seq_as_onehot=True, aggregate_model_input=False,
               batch_size=16)

SidechainNet was loaded from ./sidechainnet_data/sidechainnet_casp7_30.pkl.


In [3]:
def get_seq_features(batch):
    '''
    Take a batch of sequence info and return the sequence (one-hot),
    evolutionary info and (phi, psi, omega) angles per position, 
    as well as position mask.
    Also return the distance matrix, and distance mask.
    '''
    str_seqs = batch.str_seqs # seq in str format
    seqs = batch.seqs # seq in one-hot format
    int_seqs = batch.int_seqs # seq in int format
    masks = batch.msks # which positions are valid
    lengths = batch.lengths # seq length
    evos = batch.evos # PSSM / evolutionary info
    angs = batch.angs[:,:,0:2] # torsion angles: phi, psi
    
    # use coords to create distance matrix from c-beta
    # except use c-alpha for G
    # coords[:, 4, :] is c-beta, and coords[:, 1, :] is c-alpha
    coords = batch.crds # seq coord info (all-atom)
    batch_xyz = []
    for i in range(coords.shape[0]):
        xyz = []
        xyz = [coords[i][cpos+4,:] 
                if masks[i][cpos//14] and str_seqs[i][cpos//14] != 'G'
                else coords[i][cpos+1,:]
                for cpos in range(0, coords[i].shape[0]-1, 14)]
        batch_xyz.append(torch.stack(xyz))
    batch_xyz = torch.stack(batch_xyz)
    # now create pairwise distance matrix
    dmats = torch.cdist(batch_xyz, batch_xyz)
    # create matrix mask (0 means i,j invalid)
    dmat_masks = torch.einsum('bi,bj->bij', masks, masks)
    
    return seqs, evos, angs, masks, dmats, dmat_masks

In [4]:
class AttentionHead(nn.Module):
    def __init__(self, in_dim = 256, d_k = 16):
        '''
        Represents an attention head for multihead attention,
        d_k is 16 by default.
        in_dim is 256 by default.
        '''
        super(AttentionHead, self).__init__()
        
        self.d_k = d_k
        #create query, key, and values
        self.q = nn.Linear(in_dim, d_k)
        self.k = nn.Linear(in_dim, d_k)
        self.v = nn.Linear(in_dim, d_k)
        
    def forward(self, sequence, bias, row_or_col):
        '''
        Given a sequence in MSA_rep of size n_res x 256, calculate attention.
        Depending on row_or_col, bias is either added or excluded.
        '''
        query = self.q(sequence)
        key = self.k(sequence)
        value = self.v(sequence)
        
        A_sh = torch.matmul(query, torch.transpose(key, 1, 2))/np.sqrt(self.d_k)
        if row_or_col == "row":
            A_sh += bias.squeeze(dim = -1)
        
        #take softmax with respect to the rows
        A_sh = F.softmax(A_sh, dim = 1)
        A_sh = torch.matmul(A_sh, value)
        
        return A_sh

In [5]:
class Row_Col_Attention(nn.Module):
    '''
    compute either row-wise or column-wise attention depending on the given argument.
    '''
    def __init__(self, row_or_col, num_heads):
        super(Row_Col_Attention, self).__init__()
        
        #define multi head attention
        self.mha = nn.ModuleList([AttentionHead() for i in range(num_heads)])
        
        #create a gate for each head, corresponding to each index.
        #a gate maps msa_rep to 1 and sigmoids it to determine how much information is kept from a head.
        self.gates =  nn.ModuleList([nn.Sequential(nn.Linear(256, 1), nn.Sigmoid()) for i in range(num_heads)])

        self.num_heads = num_heads
        self.row_or_col = row_or_col
        
        #linear layer to project the new msa_rep into 256 dim
        self.fc1 = nn.Linear(128, 256)
        
    def forward(self, msa_rep, bias):
        if self.row_or_col == "col":
            #transpose the msa_rep if we are doing column wise attention
            msa_rep = torch.transpose(msa_rep, 1, 2)
        
        #calculate all the respective gates dot attention head outputs.
        gated_outs = []
        for s in range(msa_rep.shape[1]):
            for i in range(self.num_heads):
                outputs = self.mha[i](msa_rep[:,s,:,:], bias, self.row_or_col)
                gate = self.gates[i](msa_rep).squeeze()
                gate_out = torch.transpose(gate,1,2)*outputs
                gated_outs.append(gate_out)          
        
        #concatenate them to form O_sh
        O_sh = torch.concat(gated_outs, dim = 2)
        new_msa_rep = rearrange(O_sh, 'b i (c j) -> b c i j', c = 16)
        print(msa_rep.shape, new_msa_rep.shape)
        new_msa_rep = self.fc1(new_msa_rep)
        
        return new_msa_rep

In [6]:
class Outer_Prod_Mean(nn.Module):
    '''
    Finds the outer product mean between the pair-wise representation
    and the msa representation.
    The output is a n_res x n_res x 128 pair_rep
    '''
    def __init__(self):
        super(Outer_Prod_Mean, self).__init__()
        #linear layer to project i[s] and j[s] to 32 dim
        self.fc1 = nn.Linear(256, 32)
        
        #flatten the mean outer product to C*C
        self.flatten = nn.Flatten()
        
        #linear layer to project the outer product mean to 128 dim
        self.fc2 = nn.Linear(32, 128)
        
    def forward(self, msa_rep):
        #iterate through clusters, pick slices i and j and project them into 32 dim and gather their outer products
        for i in range(msa_rep.shape[2]):
            outer_prods = [torch.outer(self.fc1(msa_rep[i][s]), self.fc1(msa_rep[j][s]))
                           for j in range(msa_rep.shape[2]) for s in range(msa_rep.shape[1])]
        
        #concatenate all o_ij to make the output and take the mean
        new_pair_rep = torch.mean(torch.concat(outer_prods, dim = 2))
        new_pair_rep = self.flatten(new_pair_rep)
        
        #project to n_res x n_res x 128 dim
        new_pair_rep = self.fc2(new_pair_rep)
        
        #make sure to do residual connection after calling the function
        return new_pair_rep

In [7]:
class Mult_Attention(nn.Module):
    def __init__(self, out = True):
        '''
        Does incoming(default) multiplicative attention on a given pair_rep.
        out: set to False to do incoming attention
        '''
        super(Mult_Attention, self).__init__()
        self.out = out
        self.ln = nn.LayerNorm(128)
        self.fc1 = nn.Linear(128, 128)
        self.fc2 = nn.Linear(128, 128)
        self.fc3 = nn.Linear(128, 128)
        
        self.gate1 = nn.Sequential(nn.Linear(128, 128), nn.Sigmoid())
        self.gate2 = nn.Sequential(nn.Linear(128, 128), nn.Sigmoid())
        self.gate3 = nn.Sequential(nn.Linear(128, 128), nn.Sigmoid())
    
    def forward(self, pair_rep):
        #Do a layer norm on pair_rep
        pair_rep = self.ln(pair_rep)
        
        #make A and B
        A = self.fc1(pair_rep)
        B = self.fc2(pair_rep)
        
        #Make gates for A and B
        gate_A = self.gate1(pair_rep)
        gate_B = self.gate2(pair_rep)
        gate_Z = self.gate3(pair_rep)
        
        #take dot product of A, B and their gates
        new_A = torch.dot(A, gate_A)
        new_B = torch.dot(B, gate_B)
        
        #transpose a and b if we are doing incoming attention
        if not self.out:
            new_A = torch.transpose(new_A, 1, 2)
            new_B = torch.transpose(new_B, 1, 2)
            
        #prepare to do tensorwise dot product of all slices
        new_A = torch.tile(new_A, (1,1,1,128))
        new_B = torch.tile(new_B, (1,1,1,128))
        new_B = torch.transpose(new_B, 1, 2)
        
        #Find the dot product of all slices and sum
        out = torch.dot(new_A, new_B)
        out = torch.sum(out, 1)
        
        #project to n_res x n_res x 128
        out = self.fc3(out)
        
        #gate the output
        gated_out = torch.dot(out, gate_Z)
        
        return gated_out

In [8]:
class Tri_Attention(nn.Module):
    '''
    Does starting triangular attention by default.
    ending: set to true to do ending triangular attention
    '''
    def __init__(self, ending = False, c = 128, num_heads = 4):
        super(Tri_Attention, self).__init__()
        self.ending = ending
        self.num_heads = num_heads
        self.c = c
        
        self.q = nn.ModuleList([nn.Linear(128, 32) for i in range(num_heads)])
        self.k = nn.ModuleList([nn.Linear(128, 32) for i in range(num_heads)])
        self.v = nn.ModuleList([nn.Linear(128, 32) for i in range(num_heads)])
        self.b = nn.ModuleList([nn.Linear(128, 1) for i in range(num_heads)])
        self.g = nn.ModuleList([nn.Sequential(nn.Linear(128,128), nn.Sigmoid()) for i in range(num_heads)])
        
        self.fc1 = nn.Linear(32, 128)
        
    def forward(self, pair_rep):
        output = []
        for h in range(self.num_heads):
            query = self.q[h](pair_rep)
            key = self.k[h](pair_rep)
            value = self.v[h](pair_rep)
            bias = self.b[h](pair_rep)
            gate = self.g[h](pair_rep)
            
            '''
            for i in range(query.shape[1]):
                for j in range(query.shape[2]):
                    for k in range(query.shape[3]):
                        a = torch.matmul(torch.transpose(query[i,j], 1, 2), key[i,k])/torch.sqrt(self.c) + bias[j,k]
                        a = F.softmax(a, dim=-1)
                        a = torch.sum(a*value[j,k], dim=-1)*g[i,j]
                        output.append(a)
            '''
            
            a = torch.matmul(q, torch.transpose(key, 1, 2))/np.sqrt(self.c) + torch.transpose(bias, 1, 2)
            a = F.softmax(a, dim = 1)
            a *= value
            out = a * g
            output.append(out)
        
        #concat all outputs
        output = torch.concat(output, 1)
        output = self.fc1(output)
        
        return output

In [9]:
class Evoformer(nn.Module):
    def __init__(self, n_clust=16, num_heads=8, device='cpu'):
        '''
        Creates the MSA_representation and the Z(pairwise) matrix given a PSSM and a sequence.
        n_clust: number of PSSMs.
        num_heads: number of attention heads(8 by default)
        '''
        super(Evoformer, self).__init__()
        
        self.n_clust = n_clust
        self.num_heads = num_heads
        self.device = device
        
        #linear layers to project evos into n_clust x n_res x 256
        self.fc0 = nn.ModuleList([nn.Linear(21, 256) for i in range(n_clust)])
        #linear layer to project seqs to n_res x 256
        self.fc1 = nn.Linear(20, 256)
        #linear layer to project seqs to n_res x 128
        self.fc2 = nn.Linear(20, 128)
        self.fc3 = nn.Linear(20, 128)
        #Linear layer to project distances into 128 space
        self.fc4 = nn.Linear(64, 128)
        #linear layer to project pair_rep to bias
        self.fc5 = nn.Linear(128, 1)
        #linear layer to project the single representation to 256 dim
        self.fc6 = nn.Linear(128, 256)
        #linear layer to project the single representation to 384 dim
        self.fc7 = nn.Linear(256, 384)
        
        #define the transitional layers to pass the new msa_rep through
        self.transition = nn.Sequential(nn.Linear(256, 1024), nn.ReLU(), nn.Linear(1024, 256))
        
        #define all attentions
        self.row_att = Row_Col_Attention("row", self.num_heads)
        self.col_att = Row_Col_Attention("col", self.num_heads)
        self.mul_att_in = Mult_Attention(out = False)
        self.mul_att_out = Mult_Attention(out = True)
        self.tri_att_start = Tri_Attention(ending = False)
        self.tri_att_end = Tri_Attention(ending = True)
        
        #define outer_product_mean
        self.out_prod_mean = Outer_Prod_Mean()
        
    
    def create_msa_rep(self, evos, seqs):
        '''
        Create the msa_representation given evolutionary data evos
        and the seqs, both are n_res x 21.
        '''
        #obtain n_clust layers of PSSM(evos); stack them into a (n_clust x n_res x 256) matrix
        clusters = [self.fc0[i](evos) for i in range(self.n_clust)]
        msa_rep = torch.stack(clusters, dim=1)
        
        #project the seqs from n_res x 21 to n_res x 256 and tile it.
        new_seqs = self.fc1(seqs)
        new_seqs = new_seqs.unsqueeze(dim=1)
        new_seqs = torch.tile(new_seqs, (1, self.n_clust, 1, 1))
        
        #add the seqs to the msa_rep
        msa_rep += new_seqs
        
        return msa_rep
    
    def create_pair_rep(self, seqs):
        '''
        Create pair_wise representations given seqs.
        '''
        #create the pairwise rep matrix
        a_i = self.fc2(seqs).unsqueeze(dim=2)
        b_j = self.fc3(seqs).unsqueeze(dim=2)
        a_i = torch.tile(a_i, (1, 1, a_i.shape[1], 1))
        b_j = torch.tile(b_j, (1, 1, b_j.shape[1], 1))
        pair_rep = a_i + torch.transpose(b_j, 1, 2)
        
        #add the relative position rel_pos
        idx_j = torch.arange(0, seqs.shape[1]).unsqueeze(dim=1)
        idx_j = torch.tile(idx_j, (1, idx_j.shape[1]))
        idx_i = torch.transpose(idx_j, 0, 1)
        # idx_i , idx_j = idx_i.to(device), idx_j.to(device)
        dist_ij = idx_i - idx_j   
        bins = torch.linspace(-32, 32, 64)
        dist_ij = torch.bucketize(dist_ij, bins)
        dist_ij[dist_ij>=64] = 63
        dist_ij = dist_ij.unsqueeze(dim=0)
        dist_ij = torch.tile(dist_ij, (pair_rep.shape[0], 1, 1))
        dist_ij = F.one_hot(dist_ij).type(torch.float)
        dist_ij = dist_ij.to(self.device)
        # print(dist_ij.shape)
        rel_pos = self.fc4(dist_ij)
        pair_rep += rel_pos
        return pair_rep
    
    def create_bias(self, pair_rep):
        '''
        given the pairwise representation create the bias
        '''
        bias = self.fc5(pair_rep)
        return bias
        
    def single_rep(self, msa_rep):
        '''
        Find the singular representation of M
        Should only be done on the last block.
        '''
        single_rep = self.fc6(msa_rep[:,1,:,:])
        single_rep = self.fc7(single_rep)
        return single_rep  
    
    def forward(self, seqs, evos):
        #create msa_rep, pair_rep, bias
        msa_rep = self.create_msa_rep(evos, seqs)
        pair_rep = self.create_pair_rep(seqs)
        bias = self.create_bias(pair_rep)
        
        #feed msa_rep into row -> col -> transition
        msa_rep = msa_rep + self.row_att(msa_rep, bias) 
        msa_rep = msa_rep + self.col_att(msa_rep, bias)
        msa_rep = msa_rep + self.transition(msa_rep) #output of evoformer for msa_rep
        
        #do the outer product mean
        pair_rep = pair_rep + self.out_prod_mean(msa_rep)
        
        #do triangular attention
        pair_rep = pair_rep + self.mult_att_out(pair_rep)
        pair_rep = pair_rep + self.mult_att_in(pair_rep)
        pair_rep = pair_rep + self.tri_att_start(pair_rep)
        pair_rep = pair_rep + self.tri_att_end(pair_rep)
        
        #do the transition
        pair_rep = pair_rep + self.transition(pair_rep) #output of evoformer for pair_rep
        
        return msa_rep, pair_rep

In [10]:
device = f'cuda:0' if torch.cuda.is_available() else 'cpu'
print(f"using device: {device}")

epochs = 50
learning_rate = 0.0001
n_clust = 16

model = Evoformer(n_clust, device=device)
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
model = model.to(device)
model.train()

using device: cuda:0


Evoformer(
  (fc0): ModuleList(
    (0): Linear(in_features=21, out_features=256, bias=True)
    (1): Linear(in_features=21, out_features=256, bias=True)
    (2): Linear(in_features=21, out_features=256, bias=True)
    (3): Linear(in_features=21, out_features=256, bias=True)
    (4): Linear(in_features=21, out_features=256, bias=True)
    (5): Linear(in_features=21, out_features=256, bias=True)
    (6): Linear(in_features=21, out_features=256, bias=True)
    (7): Linear(in_features=21, out_features=256, bias=True)
    (8): Linear(in_features=21, out_features=256, bias=True)
    (9): Linear(in_features=21, out_features=256, bias=True)
    (10): Linear(in_features=21, out_features=256, bias=True)
    (11): Linear(in_features=21, out_features=256, bias=True)
    (12): Linear(in_features=21, out_features=256, bias=True)
    (13): Linear(in_features=21, out_features=256, bias=True)
    (14): Linear(in_features=21, out_features=256, bias=True)
    (15): Linear(in_features=21, out_features=25

In [11]:
epoch = 1
for epoch in range(1,epochs+1):
    for batch in data['train']:
        seqs, evos, angs, masks, dmats, dmat_masks = get_seq_features(batch)
        seqs, evos, angs, masks, dmats, dmat_masks = seqs.to(device), evos.to(device), angs.to(device), masks.to(device), dmats.to(device), dmat_masks.to(device)
        
        #generate a random starting index
        start_idx = random.randint(1,16)
        
        seqs = F.pad(seqs, (0, 0, 0, 256 - (seqs.shape[1] - start_idx)%256), 'constant', 0)
        evos = F.pad(evos, (0, 0, 0, 256 - (evos.shape[1] - start_idx)%256), 'constant', 0)
        
        #discretize the matrix
        bins = torch.linspace(2,22, 64)
        bins = bins.to(device)
        discretized = torch.clamp(dmats, min = 2, max = 22)
        discretized = torch.bucketize(discretized, bins, right = True)
        discretized = F.pad(discretized, (0, 256-(discretized.shape[2] - start_idx)%256, 0, 256-(discretized.shape[1] - start_idx)%256, 0, 0), 'constant', 0)
        
        for i in range(start_idx, seqs.shape[1], 128):
            seq_crop = seqs[:,i:i+256, :]
            evo_crop = evos[:,i:i+256, :]
            ddmat = discretized[:,i:i+256, i:i+256]
            pred = model(seq_crop.type(torch.float), evo_crop)
        
        break

torch.Size([9, 16, 256, 256]) torch.Size([9, 16, 256, 128])


RuntimeError: The size of tensor a (256) must match the size of tensor b (16) at non-singleton dimension 2