# 原理
从一个完美的蛋白质序列开始
逐步引入随机突变(添加噪声)
学习如何从被破坏的序列中恢复原始序列

![diff](diff.png)

![reason](reason.png)

In [1]:
import torch 
import torch.nn as nn

class SimpleDiffusion(nn.Module):
    def __init__(self,seq_length,num_steps):
        super().__init__()
        self.seq_length=seq_length
        self.num_steps=num_steps
        self.betas=torch.linspace(1e-4,0.02,num_steps)

    def forward(self,x_0):
        t=torch.randint(0,self.num_steps,(x_0.shape[0],))
        noise=torch.rand_like(x_0)
        alpha_bar=torch.cumprod(1-self.betas,dim=0)[t]
        x_t=torch.sqrt(alpha_bar)[:,None]*x_0+torch.sqrt(1-alpha_bar)[:,None]*noise
        return x_t,noise
    
model=SimpleDiffusion(seq_length=100,num_steps=1000)
x_0=torch.randn(32,100)
x_t,noise=model(x_0)

print(f"x_0 shape:{x_0.shape}")
print(f"x_t shape:{x_t.shape}")
print(f"noise shape:{noise.shape}")

x_0 shape:torch.Size([32, 100])
x_t shape:torch.Size([32, 100])
noise shape:torch.Size([32, 100])


# Score-Based生成模型

![score](energe.png)

In [6]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class ScoreNet(nn.Module):
    def __init__(self,dim):
        super().__init__()
        self.net=nn.Sequential(
            nn.Linear(dim,128),
            nn.ReLU(),
            nn.Linear(128,128),
            nn.ReLU(),
            nn.Linear(128,dim)
        )

    def forward(self,x,t):
        return self.net(x)
    
def langevin_dynamics(score_net,x,n_steps=100,step_size=0.01):
    for _ in range(n_steps):
        noise=torch.rand_like(x)
        grad=score_net(x,None)
        x=x+step_size*grad+torch.sqrt(torch.tensor(2*step_size))*noise
    return x

dim=20
score_net=ScoreNet(dim)
x=torch.randn(10,dim)
generated_sequences=langevin_dynamics(score_net,x)

print(f"Generated sequences shape:{generated_sequences.shape}")
print(f"Sample sequence:\n{generated_sequences[0]}")

Generated sequences shape:torch.Size([10, 20])
Sample sequence:
tensor([7.1661, 5.9199, 7.7893, 4.7096, 8.7382, 7.1403, 7.8489, 9.1252, 6.0345,
        6.8274, 8.1721, 8.1618, 3.8386, 5.0177, 5.3408, 8.8583, 7.7496, 6.3388,
        6.4851, 8.6094], grad_fn=<SelectBackward0>)


# 基于扩散模型的蛋白质设计

In [9]:
import torch 
import torch.nn as nn 
import torch.nn.functional as F

class SimpleDiffusion(nn.Module):
    def __init__(self,seq_length,num_steps):
        super().__init__()
        self.seq_length=seq_length
        self.num_steps=num_steps

        self.betas=torch.linspace(1e-4,0.02,num_steps)
        self.alphas=1.-self.betas
        self.alphas_cumprod=torch.cumprod(self.alphas,dim=0)

        self.model=nn.Sequential(
            nn.Linear(seq_length,128),
            nn.ReLU(),
            nn.Linear(128,128),
            nn.ReLU(),
            nn.Linear(128,seq_length)
        )

    def forward(self,x_0,t):
        noise=torch.rand_like(x_0)
        alpha_cumprod=self.alphas_cumprod[t]
        x_t=torch.sqrt(alpha_cumprod)[:,None]*x_0+torch.sqrt(1-alpha_cumprod)[:,None]*noise
        predict_noise=self.model(x_t)

        return predict_noise,noise
    
    @torch.no_grad()
    def sample(self,batch_size):
        x=torch.randn(batch_size,self.seq_length)

        for t in reversed(range(self.num_steps)):
            z=torch.rand_like(x) if t>0 else 0

            alpha=self.alphas[t]
            alpha_cumprod=self.alphas_cumprod[t]
            alpha_cumprod_prev=self.alphas_cumprod[t-1] if t >0 else 1.0

            predicted_noise=self.model(x)
            beta_tlide=self.betas[t]*(1.-alpha_cumprod_prev)/(1.-alpha_cumprod)

            x=(1/torch.sqrt(alpha))*(x-((1-alpha)/torch.sqrt(1-alpha_cumprod))*predicted_noise)+torch.sqrt(beta_tlide)*z

            return x
        
seq_length=100
num_steps=1000
model=SimpleDiffusion(seq_length,num_steps)

x_0=torch.randn(32,seq_length)
t=torch.randint(0,num_steps,(32,))

predicted_noise,noise=model(x_0,t)

print(f"input shape:{x_0.shape}")
print(f"predict noise shape:{predicted_noise.shape}")

generated_sequences=model.sample(5)
print(f"generated_sequences shape:{generated_sequences.shape}")
print(f"sample sequence:\n{generated_sequences[0]}")

input shape:torch.Size([32, 100])
predict noise shape:torch.Size([32, 100])
generated_sequences shape:torch.Size([5, 100])
sample sequence:
tensor([ 0.7650, -0.7414,  0.6749,  0.7513,  0.2382,  0.6098, -1.3936, -0.6235,
        -0.2113, -0.0903,  1.9056,  0.3702,  1.2035, -0.6425,  0.0374, -0.2673,
         1.1183,  1.0885,  0.8043, -0.0401,  0.7595,  0.2840,  0.4250, -0.7042,
        -0.0581, -0.9208, -0.3186,  0.1742, -0.2153, -0.9581,  0.3822, -0.5750,
         0.2951,  0.5060,  1.6777, -0.1411, -0.2222,  0.7055,  0.2090, -2.3242,
         0.1939,  0.4066, -0.3884, -2.9212, -0.6391,  0.4944, -2.1755, -0.9518,
         2.0915, -0.6694, -0.4965,  0.1828, -2.4428,  2.0834,  0.4515,  0.6233,
        -1.7486,  0.9551,  0.7655, -0.5323,  0.1307, -0.9355, -0.4072, -2.3063,
         0.3661, -1.1529, -0.1884,  0.1885,  1.1705,  0.1068,  1.6158, -1.2301,
         0.8702, -0.5995,  0.6600,  0.8932,  0.0264, -2.1870, -0.1651,  1.0527,
        -1.2005,  0.3289,  0.6678,  0.8852,  1.9004,  0.6726