学习AlphaFold,UniFold的策略结合ESM-IF来实现Uni_IF

In [1]:
import torch
import torch.nn as nn
from matplotlib import pyplot as plt 
import numpy as np

In [84]:
# Function 1: torch.split cuts a tensor along `dim` and returns a tuple of chunks.
# The second argument is either the size of a single chunk, or a list of chunk
# sizes that add up to the length of that dimension (10 for `x` below).
x = torch.arange(10)
x_chunk = x.split(5, dim=0)
# Example, taken from the IPA forward pass:
#     # [*, N_res, H * P_q * 3]
#     q_pts = self.linear_q_points(s)
#
#     # This is kind of clunky, but it's how the original does it
#     # [*, N_res, H * P_q, 3]
#     q_pts = torch.split(q_pts, q_pts.shape[-1] // 3, dim=-1)
#     q_pts = torch.stack(q_pts, dim=-1)
N_res, H, P_q, P_v = 100, 10, 4, 8
q_pts = torch.randn((5, N_res, H * P_q * 3))
# The last dimension is cut into three chunks of width H * P_q.
q_pts = q_pts.split(q_pts.size(-1) // 3, dim=-1)
# q_pts = torch.stack(q_pts, dim=-1)

In [None]:
# 函数 1.5 : torch.chunk(input, chunks, dim=0)
# chunks指的是将tensor切成几个chunk（是chunk的个数，而不是每个chunk的大小）

In [87]:
# Function 2: torch.unbind(input, dim=0)
# Removes a tensor dimension.
# Returns a tuple of all slices along a given dimension, already without it.
# torch.unbind: cuts a dimension into slices of size 1 and drops that dimension
# torch.split : cuts a dimension into chunks of the requested chunk size
torch.arange(10).unbind(dim=0) == torch.arange(10).split(1, dim=0)
# E.g. *torch.unbind(o_pt, dim=-1): the * unpacks the tuple returned by unbind
# so that its elements are passed as separate positional arguments.
o_pt = torch.randn(N_res, H * P_v, 3)
print(*o_pt.unbind(dim=-1))

tensor([[ 1.1089,  0.1828, -1.1277,  ...,  0.0441,  0.1841, -0.1688],
        [-1.0839, -1.7649,  0.9898,  ...,  0.5986, -0.8992, -0.4005],
        [-0.4315,  0.1409,  0.5621,  ...,  1.9601, -0.0056, -1.6909],
        ...,
        [ 0.2440, -0.4118,  0.0060,  ..., -1.3390, -0.4782,  0.8730],
        [-0.7326,  0.6082, -0.5335,  ..., -0.1101,  0.5573,  0.8758],
        [ 1.5065,  1.4604,  0.7707,  ..., -2.6308, -0.7498, -0.5366]]) tensor([[-1.3005, -0.1431, -1.2304,  ...,  0.6920,  0.0064,  0.0126],
        [-1.4925,  0.1571,  2.0259,  ...,  0.6518,  2.2871,  0.0633],
        [ 0.0748,  1.0305,  0.3305,  ...,  0.4663, -0.2982, -1.2689],
        ...,
        [ 0.6041,  0.5428, -0.5820,  ...,  0.5456, -0.5846,  0.5794],
        [ 0.1015, -2.0565,  0.2441,  ...,  0.8954, -1.0636,  0.9557],
        [ 1.5222,  0.5061,  1.4869,  ...,  1.4584, -0.2904, -0.7817]]) tensor([[ 0.5649, -0.2386,  0.0038,  ...,  1.2166,  0.2035, -0.0906],
        [-0.1684,  2.3154, -1.6595,  ...,  0.1254, -0.4387, -0

In [83]:
# Function 3: torch.nn.Softplus(beta=1, threshold=20)
# SoftPlus is a smooth approximation to the ReLU function and can be used to constrain the output of a machine to always be positive.
# Evenly spaced sample points on [-10, 10].
a = np.linspace(start=-10.0, stop=10.0, num=1000)


In [4]:
# Random stand-in for a weight vector with one entry per head (length H).
H = 16
head_weights = torch.randn(H)

In [5]:
# Scratch check: `*()` unpacks an empty tuple, so this is the same call as
# `head_weights.view()` with no shape arguments at all.
# NOTE(review): the output recorded below still lists all 16 values of
# head_weights; confirm against the installed torch version whether an empty
# shape is really accepted here or whether that output is stale.
head_weights.view(*())

tensor([ 0.6482, -0.9007, -0.7718, -0.1868, -0.4537,  0.7549, -0.1214, -0.6609,
         2.3571, -0.3103, -2.0868,  0.8933, -1.5520, -0.1901,  0.5534, -1.4012])

In [38]:
# class 内部的装饰器 1.
# 一共有三种class method的类型: instance method; static method; class method
# instance method绑定的对象是instance，必须先创建一个对象之后才能调用该方法(函数定义时形参是self)
# static method创建的时候不需要添加self 形参,可以在不创建对象的时候就调用该函数，避免了创建对象所带来的内存消耗;另一个角度理解是static method本身就是一个普通的函数，只是挂靠在class的这个命名空间底下而言，不需要实例化对象
# classmethod  类方法主要是为了优雅的创建多个对象
class MyDog:
    """Toy class demonstrating the three kinds of methods a class can hold."""

    def __init__(self) -> None:
        # Nothing to initialise; the class only exists to host the methods.
        return None

    @classmethod
    def class_method(cls):
        """Callable on the class itself; ``cls`` is the class object."""
        print(f"My dog is a class method,given {cls}")

    @staticmethod
    def static_method():
        """Plain function in the class namespace; takes no ``self`` or ``cls``."""
        print("My dog is a static method")

    def instance_method(self):
        """Bound to an instance; ``self`` is the object it is called on."""
        print("My dog is an instance method")


My dog is a class method,given <class '__main__.MyDog'>


In [55]:
# class 内部的装饰器 2.
# @property 将一个method的调用方法变成属性的调用方法
# @property 将某个方法变成只读属性
# @xxx.setter 把该属性进行赋值
# 注意属性的方法名不要和属性名重名,否则self.birth调用时会递归的找self.birth这个函数造成栈溢出
class Student:
    """Example of exposing a private field through ``@property``."""

    def __init__(self) -> None:
        # Backing field; named differently from the property so that the
        # property does not end up calling itself.
        self._birth = 10

    @property
    def birth(self):
        # Read access: `obj.birth` returns the backing field.
        return self._birth

    @birth.setter
    def birth(self, new_birth):
        # Write access: `obj.birth = x` stores x in the backing field.
        self._birth = new_birth


a = Student()
a.birth = 100

### Module 0. Rigid Transformation

如果对坐标做变换，如何存储transformation

1. Rotation 是一个类似张量的 Rotation 对象，每一个point代表一个Rotation
2. 

In [88]:
from typing import Optional,Tuple
# Q1 : 3*3 旋转矩阵到旋转四元数的变换？3*3旋转正交矩阵自由度不是只有3?
# identity rotation matrix eye(3)
def identity_rot_mats(
    batch_dims: Tuple[int],
    dtype: Optional[torch.dtype] = None,
    device: Optional[torch.device] = None,
    requires_grad: bool = True,
) -> torch.Tensor:
    """Return identity rotation matrices of shape ``(*batch_dims, 3, 3)``.

    Args:
        batch_dims: Leading batch dimensions of the result.
        dtype: Forwarded to ``torch.eye``.
        device: Forwarded to ``torch.eye``.
        requires_grad: Forwarded to ``torch.eye``.

    Returns:
        A contiguous tensor whose every trailing 3x3 block is ``eye(3)``.
    """
    eye3 = torch.eye(
        3, dtype=dtype, device=device, requires_grad=requires_grad
    )
    # Prepend one singleton axis per batch dimension so the single matrix can
    # be broadcast over the batch, then materialise the broadcast result.
    singleton_axes = (1,) * len(batch_dims)
    broadcastable = eye3.view(*singleton_axes, 3, 3)
    return broadcastable.expand(*batch_dims, -1, -1).contiguous()

# identity translation vector [0,0,0]
def identity_trans(
    batch_dims: Tuple[int],
    dtype: Optional[torch.dtype] = None,
    device: Optional[torch.device] = None,
    requires_grad: bool = True,
) -> torch.Tensor:
    """Return zero translation vectors of shape ``(*batch_dims, 3)``.

    Args:
        batch_dims: Leading batch dimensions of the result.
        dtype: Forwarded to ``torch.zeros``.
        device: Forwarded to ``torch.zeros``.
        requires_grad: Forwarded to ``torch.zeros``.

    Returns:
        An all-zero tensor with a trailing axis of length 3.
    """
    out_shape = (*batch_dims, 3)
    return torch.zeros(
        out_shape, dtype=dtype, device=device, requires_grad=requires_grad
    )

# identity quaternion [1, 0, 0, 0]
def identity_quats(
    batch_dims: Tuple[int],
    dtype: Optional[torch.dtype] = None,
    device: Optional[torch.device] = None,
    requires_grad: bool = True,
) -> torch.Tensor:
    """Return identity quaternions ``[1, 0, 0, 0]`` of shape ``(*batch_dims, 4)``.

    Args:
        batch_dims: Leading batch dimensions of the result.
        dtype: Forwarded to ``torch.zeros``.
        device: Forwarded to ``torch.zeros``.
        requires_grad: Forwarded to ``torch.zeros``.

    Returns:
        A tensor that is zero everywhere except index 0 of the last axis,
        which is set to 1.
    """
    out_shape = (*batch_dims, 4)
    quat = torch.zeros(
        out_shape, dtype=dtype, device=device, requires_grad=requires_grad
    )

    # The write happens under no_grad so that the in-place fill is allowed on
    # a tensor that may have been created with requires_grad=True.
    with torch.no_grad():
        # quat: [*, 4]; set the first component of every quaternion.
        quat[..., 0].fill_(1)

    return quat
# Labels of the four quaternion components, every ordered pair of labels
# ("aa", "ab", ..., "dd"), and a lookup from each pair to its position in
# that list.
_quat_elements = ["a", "b", "c", "d"]
_qtr_keys = [
    left + right
    for left in _quat_elements
    for right in _quat_elements
]
_qtr_ind_dict = dict(zip(_qtr_keys, range(len(_qtr_keys))))


### Module 1. IPA
PS : 
1. H最开始肯定放在N_res后面,但是在attention score中一定是放在N_res前面，也就是H,N_res,N_res
2. key,value是二维的情况下怎么画图


总结：
- 两套qkv,全部从single chain representation中得到
  - qkv 为标准的多头注意力, 将c_s 投影为H * C_hidden
  - q_pts, k_pts, v_pts 显式地构建了坐标的注意力机制,将c_s 投影为H*P_qk/P_v * 3,之后直接用Rigid transformation变换这些点的坐标
- 三个attention score (affinities) 和一个bias
  - pair-bias 由pair-representation直接将c_z投影到H的维度进行加和
  - dot-product 由标准的多头注意力机制qkv计算得来
  - squared distance affinities 由点注意力机制构建 $\text{pt\_att}^{h}_{ij} = -\frac{\gamma ^h W_c}{2}\sum\limits_{p}||q^{hp}_{i}-k^{hp}_{j} ||^2$
    - 具体计算步骤就是输入一个[*, N_res, H, P_q, 3]的query q_pts和key k_pts
    - pt_att = q_pts.unsqueeze(-4) - k_pts.unsqueeze(-5) query key氨基酸每个原子中的xyz坐标之差，利用broadcast来进行对称
    - pt_att = pt_att** 2得到xyz距离平方
    - pt_att = torch.sum(pt_att,dim=-1) 求和xyz坐标，得到每个residue的每个原子的坐标距离
      - 源代码里是这样写的 sum(*torch.unbind(pt_att,dim=-1)),似乎没什么差别
    - 加入weight
    - 加入squared mask,常规操作:
      - mask [*, N_res]
      - inf * (mask.unsqueeze(-1) * mask.unsqueeze(-2) - 1).unsqueeze(-3) -> [*, 1, N_res, N_res]（有效位置为0,被mask的位置为-inf）
- 四个output value, 最后concat h,q
  - pair-representation value
    - [ *, H, N_res, N_res].transpose(-2,-3) @ [ *, N_res, N_res, c_z] -> [ *, N_res, H, c_z] -> [ *, N_res, H * c_z]
  - dot-product value
    - [ *, H, N_res, N_res] matmul [ \*, H, N_res, C_hidden] -> [ \*, H, N_res ,C_hidden] -> [ *, N_res, H * C_hidden]
  - o_pt (point attention value)
    - input
      - v_pts : [ *, N_res, H, P_v, 3] 
      - a     : [ *, H, N_res, N_res]
    - output : o_pt : [ *, N_res, H, P_v, 3], split into 3 points : [ *, N_res, H * P_v] * 3
  - o_pt_norm
    - o_pt_norm : $\sqrt{\sum_{xyz} o_{pt}^2 + \epsilon}$
- 1个输出
  - 更新之后的s

In [None]:
# From OpenFold
class InvariantPointAttention(nn.Module):
    """
    Implements Algorithm 22.

    The attention logits are the sum of three terms (scalar q/k dot product,
    a bias projected from the pair representation, and a squared-distance
    term between transformed query/key points) plus a mask term. The output
    concatenates the scalar values, the point values (x, y, z separately),
    the point norms and the pair values before a final linear layer.

    NOTE(review): `Linear`, `Rigid`, `Sequence`, `math`, `sys`, `reduce`,
    `mul`, `permute_final_dims`, `flatten_final_dims`, `is_fp16_enabled`,
    `ipa_point_weights_init_` and `attn_core_inplace_cuda` are used below but
    are not imported or defined in the visible part of this file; presumably
    they come from OpenFold / the standard library -- confirm before running.
    """
    def __init__(
        self,
        c_s: int,
        c_z: int,
        c_hidden: int,
        no_heads: int,
        no_qk_points: int,
        no_v_points: int,
        inf: float = 1e5,
        eps: float = 1e-8,
    ):
        """
        Args:
            c_s:
                Single representation channel dimension
            c_z:
                Pair representation channel dimension
            c_hidden:
                Hidden channel dimension
            no_heads:
                Number of attention heads
            no_qk_points:
                Number of query/key points to generate
            no_v_points:
                Number of value points to generate
            inf:
                Magnitude of the (negative) value added to masked attention
                logits in forward()
            eps:
                Constant added under the square root when computing the
                norm of the output points in forward()
        """
        super(InvariantPointAttention, self).__init__()

        self.c_s = c_s
        self.c_z = c_z
        self.c_hidden = c_hidden
        self.no_heads = no_heads
        self.no_qk_points = no_qk_points
        self.no_v_points = no_v_points
        self.inf = inf
        self.eps = eps

        # These linear layers differ from their specifications in the
        # supplement. There, they lack bias and use Glorot initialization.
        # Here as in the official source, they have bias and use the default
        # Lecun initialization.
        hc = self.c_hidden * self.no_heads
        self.linear_q = Linear(self.c_s, hc)
        # 2 * hc: one half is split off as keys, the other as values
        self.linear_kv = Linear(self.c_s, 2 * hc)

        # Width of the point-attention query projection
        hpq = self.no_heads * self.no_qk_points * 3
        self.linear_q_points = Linear(self.c_s, hpq)

        # Width of the point-attention key + value projection
        hpkv = self.no_heads * (self.no_qk_points + self.no_v_points) * 3
        self.linear_kv_points = Linear(self.c_s, hpkv)

        # NOTE(review): hpv is computed but never used in this class.
        hpv = self.no_heads * self.no_v_points * 3

        self.linear_b = Linear(self.c_z, self.no_heads)

        # Learnable per-head weight (gamma^h); forward() passes it through
        # softplus and uses it to scale the point-attention affinity of
        # each head.
        self.head_weights = nn.Parameter(torch.zeros((no_heads)))
        ipa_point_weights_init_(self.head_weights)

        # Width of the concatenated output: per head, c_z for the pair
        # output, c_hidden for the scalar output, and 4 * no_v_points for
        # the point output (3 coordinates + 1 norm per point).
        concat_out_dim = self.no_heads * (
            self.c_z + self.c_hidden + self.no_v_points * 4
        )
        self.linear_out = Linear(concat_out_dim, self.c_s, init="final")

        self.softmax = nn.Softmax(dim=-1)
        self.softplus = nn.Softplus()

    def forward(
        self,
        s: torch.Tensor,
        z: Optional[torch.Tensor],
        r: Rigid,
        mask: torch.Tensor,
        inplace_safe: bool = False,
        _offload_inference: bool = False,
        _z_reference_list: Optional[Sequence[torch.Tensor]] = None,
    ) -> torch.Tensor:
        """
        Args:
            s:
                [*, N_res, C_s] single representation
            z:
                [*, N_res, N_res, C_z] pair representation
            r:
                [*, N_res] transformation object
            mask:
                [*, N_res] mask
            inplace_safe:
                If True, the in-place variants of the element-wise updates
                and of the softmax are used
            _offload_inference:
                If True, z[0] is moved to the CPU once the pair bias has
                been computed and moved back before the pair output
            _z_reference_list:
                Used in place of `z` when both `_offload_inference` and
                `inplace_safe` are True; it is indexed as a one-element
                container holding the pair representation
        Returns:
            [*, N_res, C_s] single representation update
        """
        # From here on `z` is always a container whose element 0 is the
        # pair representation.
        if(_offload_inference and inplace_safe):
            z = _z_reference_list
        else:
            z = [z]

        #######################################
        # Generate scalar and point activations
        #######################################
        # [*, N_res, H * C_hidden]
        q = self.linear_q(s)
        kv = self.linear_kv(s)

        # [*, N_res, H, C_hidden]
        q = q.view(q.shape[:-1] + (self.no_heads, -1))

        # [*, N_res, H, 2 * C_hidden]
        kv = kv.view(kv.shape[:-1] + (self.no_heads, -1))

        # [*, N_res, H, C_hidden]
        k, v = torch.split(kv, self.c_hidden, dim=-1)

        # [*, N_res, H * P_q * 3]
        q_pts = self.linear_q_points(s)

        # This is kind of clunky, but it's how the original does it
        # Author's note: indexing a Rigid instance acts on the batch
        # dimensions (*) only; the trailing 3x3 is not involved.
        q_pts = torch.split(q_pts, q_pts.shape[-1] // 3, dim=-1)
        # [*, N_res, H * P_q, 3]
        q_pts = torch.stack(q_pts, dim=-1)
        q_pts = r[..., None].apply(q_pts)

        # [*, N_res, H, P_q, 3]
        q_pts = q_pts.view(
            q_pts.shape[:-2] + (self.no_heads, self.no_qk_points, 3)
        )

        # [*, N_res, H * (P_q + P_v) * 3]
        kv_pts = self.linear_kv_points(s)

        # [*, N_res, H * (P_q + P_v), 3]
        kv_pts = torch.split(kv_pts, kv_pts.shape[-1] // 3, dim=-1)
        kv_pts = torch.stack(kv_pts, dim=-1)
        # r[...,None] ->
        # Rigid : _rot       = [ *, N_res, 1, 3, 3]
        #         _trans     = [ *, N_res, 1, 3]
        kv_pts = r[..., None].apply(kv_pts)


        # [*, N_res, H, (P_q + P_v), 3]
        kv_pts = kv_pts.view(kv_pts.shape[:-2] + (self.no_heads, -1, 3))

        # [*, N_res, H, P_q/P_v, 3]
        k_pts, v_pts = torch.split(
            kv_pts, [self.no_qk_points, self.no_v_points], dim=-2
        )

        ##########################
        # Compute attention scores
        ##########################
        # [*, N_res, N_res, H]
        b = self.linear_b(z[0])

        if(_offload_inference):
            # The refcount check guards against other live references to
            # the pair tensor before it is replaced by its CPU copy.
            assert(sys.getrefcount(z[0]) == 2)
            z[0] = z[0].cpu()

        # [*, H, N_res, N_res]
        if(is_fp16_enabled()):
            # Autocast is disabled and q/k are cast to float for this matmul.
            with torch.cuda.amp.autocast(enabled=False):
                a = torch.matmul(
                    permute_final_dims(q.float(), (1, 0, 2)),  # [*, H, N_res, C_hidden]
                    permute_final_dims(k.float(), (1, 2, 0)),  # [*, H, C_hidden, N_res]
                )
        else:
            # q,k,v original [*, N_res, H, C_hidden]
            a = torch.matmul(
                permute_final_dims(q, (1, 0, 2)),  # [*, H, N_res, C_hidden]
                permute_final_dims(k, (1, 2, 0)),  # [*, H, C_hidden, N_res]
            ) # attention affinity [*, H, N_res, N_res]

        a *= math.sqrt(1.0 / (3 * self.c_hidden))
        a += (math.sqrt(1.0 / 3) * permute_final_dims(b, (2, 0, 1)))
        # bias from  [*, N_res, N_res, H] -> [*, H, N_res, N_res]

        # [*, N_res, N_res, H, P_q, 3]
        # Original q_pts,k_pts,v_pts after global transformation  -> [*, N_res, H, P_qk, 3]
        # broadcast from [*, N_res, 1, H, P_q, 3] - [*, 1, N_res, H, P_q, 3]
        pt_att = q_pts.unsqueeze(-4) - k_pts.unsqueeze(-5)
        if(inplace_safe):
            pt_att *= pt_att
        else:
            pt_att = pt_att ** 2

        # [*, N_res, N_res, H, P_q]
        # Adds up the x, y, z slices (the author notes this should match
        # torch.sum(pt_att, dim=-1)): squared distance per point pair.
        pt_att = sum(torch.unbind(pt_att, dim=-1))
        head_weights = self.softplus(self.head_weights).view(
            *((1,) * len(pt_att.shape[:-2]) + (-1, 1))
        ) # [*, 1, 1, H, 1]
        head_weights = head_weights * math.sqrt(
            1.0 / (3 * (self.no_qk_points * 9.0 / 2))
        )
        if(inplace_safe):
            pt_att *= head_weights
        else:
            pt_att = pt_att * head_weights

        # [*, N_res, N_res, H]
        pt_att = torch.sum(pt_att, dim=-1) * (-0.5)
        # [*, N_res, N_res]
        # Broadcast from [*, N_res, 1] * [*, 1, N_res]
        square_mask = mask.unsqueeze(-1) * mask.unsqueeze(-2)
        # 0 where the mask product is 1, -inf where it is 0
        square_mask = self.inf * (square_mask - 1)

        # [*, H, N_res, N_res]
        pt_att = permute_final_dims(pt_att, (2, 0, 1))

        if(inplace_safe):
            a += pt_att
            del pt_att
            a += square_mask.unsqueeze(-3)
            # in-place softmax
            attn_core_inplace_cuda.forward_(
                a,
                reduce(mul, a.shape[:-1]),
                a.shape[-1],
            )
        else:
            # `a` already holds the dot-product affinities and the pair
            # bias; the point-attention term and the mask are added here.
            a = a + pt_att
            a = a + square_mask.unsqueeze(-3)
            a = self.softmax(a)

        ################
        # Compute output
        ################
        # [*, N_res, H, C_hidden]
        o = torch.matmul(
            a, v.transpose(-2, -3).to(dtype=a.dtype)
        ).transpose(-2, -3)

        # [*, N_res, H * C_hidden]
        o = flatten_final_dims(o, 2)

        # [*, H, 3, N_res, P_v]
        if(inplace_safe):
            v_pts = permute_final_dims(v_pts, (1, 3, 0, 2))
            # One matmul per coordinate slice, stacked back on the same axis.
            o_pt = [
                torch.matmul(a, v.to(a.dtype))
                for v in torch.unbind(v_pts, dim=-3)
            ]
            o_pt = torch.stack(o_pt, dim=-3)
        else:
            o_pt = torch.sum(
                (
                    a[..., None, :, :, None]
                    * permute_final_dims(v_pts, (1, 3, 0, 2))[..., None, :, :]
                ),
                dim=-2,
            )
            # step by step
            # a[...,None,:,:,None] from [*, H, N_res, N_res]                -> [*, H, 1, N_res, N_res,    1]
            # v_pts from [*, N_res, H, P_v, 3] ->  [*, H, 3, N_res, P_v]    -> [*, H, 3, 1,     N_res,  P_v]
            # Open question (author): how does this differ from
            # [*, H, N_res, N_res] @ [*, H, N_res, 3*P_v] followed by
            # unflattening the last dimension?

        # [*, N_res, H, P_v, 3]
        o_pt = permute_final_dims(o_pt, (2, 0, 3, 1))
        # _rot : [*, N_res, 1, 1, 3, 3], _trans : [*, N_res, 1, 1, 3]
        o_pt = r[..., None, None].invert_apply(o_pt)

        # [*, N_res, H * P_v]
        o_pt_norm = flatten_final_dims(
            torch.sqrt(torch.sum(o_pt ** 2, dim=-1) + self.eps), 2
        )

        # [*, N_res, H * P_v, 3]
        o_pt = o_pt.reshape(*o_pt.shape[:-3], -1, 3)

        if(_offload_inference):
            # Bring the pair representation back to the compute device.
            z[0] = z[0].to(o_pt.device)

        # [*, N_res, H, C_z]
        # [*, N_res, H, N_res] @ [*, N_res, N_res, c_z] -> [*, N_res, H, c_z]
        o_pair = torch.matmul(a.transpose(-2, -3), z[0].to(dtype=a.dtype))

        # [*, N_res, H * C_z]
        o_pair = flatten_final_dims(o_pair, 2)

        # [*, N_res, C_s]
        s = self.linear_out(
            torch.cat(
                (o, *torch.unbind(o_pt, dim=-1), o_pt_norm, o_pair), dim=-1
            ).to(dtype=z[0].dtype)
            # concatenated width : H * (4*P_v + C_z + c_hidden)
            # o : [*, N_res, H*c_hidden]
            # *torch.unbind(o_pt, dim=-1) unpacks one tensor per coordinate:
            # - o_pt_p1 [*, N_res, H*P_v]
            # - o_pt_p2 [*, N_res, H*P_v]
            # - o_pt_p3 [*, N_res, H*P_v]
            # o_pt_norm : [*, N_res, H*P_v]
            # o_pair    : [*, N_res, H*C_z]
        )

        return s