# 原理
1. 从一个完好的（未加噪的）原始蛋白质序列出发；
2. 前向过程：逐步引入随机突变（即添加噪声），使序列被逐渐破坏；
3. 反向过程：训练模型学习如何从被破坏的序列中逐步恢复出原始序列。

![diff](diff.png)

In [4]:
import torch 
import torch.nn as nn

class SimpleDiffusion(nn.Module):
    """Forward (noising) process of a DDPM-style diffusion model.

    Uses a linear beta schedule from 1e-4 to 0.02 over ``num_steps`` steps and
    samples ``x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise``
    with standard Gaussian ``noise``.

    Args:
        seq_length: Sequence length; stored on the module, not used by ``forward``.
        num_steps: Number of diffusion steps T.
    """

    def __init__(self, seq_length, num_steps):
        super().__init__()
        self.seq_length = seq_length
        self.num_steps = num_steps
        # Registered as buffers (not plain attributes) so .to(device)/.cuda()
        # moves them with the module; persistent=False keeps state_dict unchanged.
        self.register_buffer(
            "betas", torch.linspace(1e-4, 0.02, num_steps), persistent=False
        )
        # alpha_bar_t = prod_{s<=t} (1 - beta_s) is constant, so compute it once
        # instead of on every forward call.
        self.register_buffer(
            "alphas_cumprod", torch.cumprod(1 - self.betas, dim=0), persistent=False
        )

    def forward(self, x_0, t=None):
        """Noise a clean batch to a random (or caller-given) timestep.

        Args:
            x_0: Clean batch; dim 0 is the batch dimension, any trailing shape.
            t: Optional integer tensor of shape (batch,) with timesteps in
                [0, num_steps). Sampled uniformly at random when omitted.

        Returns:
            (x_t, noise): the noised batch and the Gaussian noise that was
            added, both shaped like ``x_0``.
        """
        if t is None:
            t = torch.randint(0, self.num_steps, (x_0.shape[0],), device=x_0.device)
        # The DDPM forward process q(x_t | x_0) uses standard normal noise;
        # torch.rand_like (uniform on [0, 1)) is non-Gaussian and has mean 0.5.
        noise = torch.randn_like(x_0)
        alpha_bar = self.alphas_cumprod[t.to(self.alphas_cumprod.device)].to(x_0)
        # Reshape to (batch, 1, ..., 1) so it broadcasts per-sample over any
        # trailing dims (the old [:, None] only worked for 2-D inputs).
        alpha_bar = alpha_bar.view(-1, *([1] * (x_0.dim() - 1)))
        x_t = alpha_bar.sqrt() * x_0 + (1 - alpha_bar).sqrt() * noise
        return x_t, noise
    
# Demo: noise a random batch of 32 length-100 sequences and report the shapes.
model = SimpleDiffusion(num_steps=1000, seq_length=100)
x_0 = torch.randn(32, 100)
x_t, noise = model(x_0)

# One print of newline-joined lines emits exactly the same text as three prints.
print("\n".join(
    f"{label} shape:{tensor.shape}"
    for label, tensor in (("x_0", x_0), ("x_t", x_t), ("noise", noise))
))

x_0 shape:torch.Size([32, 100])
x_t shape:torch.Size([32, 100])
noise shape:torch.Size([32, 100])
