# 原理
从一个完美的蛋白质序列开始
逐步引入随机突变(添加噪声)
学习如何从被破坏的序列中恢复原始序列

![diff](diff.png)

![reason](reason.png)

In [1]:
import torch 
import torch.nn as nn

class SimpleDiffusion(nn.Module):
    def __init__(self,seq_length,num_steps):
        super().__init__()
        self.seq_length=seq_length
        self.num_steps=num_steps
        self.betas=torch.linspace(1e-4,0.02,num_steps)

    def forward(self,x_0):
        t=torch.randint(0,self.num_steps,(x_0.shape[0],))
        noise=torch.rand_like(x_0)
        alpha_bar=torch.cumprod(1-self.betas,dim=0)[t]
        x_t=torch.sqrt(alpha_bar)[:,None]*x_0+torch.sqrt(1-alpha_bar)[:,None]*noise
        return x_t,noise
    
model=SimpleDiffusion(seq_length=100,num_steps=1000)
x_0=torch.randn(32,100)
x_t,noise=model(x_0)

print(f"x_0 shape:{x_0.shape}")
print(f"x_t shape:{x_t.shape}")
print(f"noise shape:{noise.shape}")

x_0 shape:torch.Size([32, 100])
x_t shape:torch.Size([32, 100])
noise shape:torch.Size([32, 100])


# Score-Based生成模型

![score](energe.png)

In [6]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class ScoreNet(nn.Module):
    def __init__(self,dim):
        super().__init__()
        self.net=nn.Sequential(
            nn.Linear(dim,128),
            nn.ReLU(),
            nn.Linear(128,128),
            nn.ReLU(),
            nn.Linear(128,dim)
        )

    def forward(self,x,t):
        return self.net(x)
    
def langevin_dynamics(score_net,x,n_steps=100,step_size=0.01):
    for _ in range(n_steps):
        noise=torch.rand_like(x)
        grad=score_net(x,None)
        x=x+step_size*grad+torch.sqrt(torch.tensor(2*step_size))*noise
    return x

dim=20
score_net=ScoreNet(dim)
x=torch.randn(10,dim)
generated_sequences=langevin_dynamics(score_net,x)

print(f"Generated sequences shape:{generated_sequences.shape}")
print(f"Sample sequence:\n{generated_sequences[0]}")

Generated sequences shape:torch.Size([10, 20])
Sample sequence:
tensor([7.1661, 5.9199, 7.7893, 4.7096, 8.7382, 7.1403, 7.8489, 9.1252, 6.0345,
        6.8274, 8.1721, 8.1618, 3.8386, 5.0177, 5.3408, 8.8583, 7.7496, 6.3388,
        6.4851, 8.6094], grad_fn=<SelectBackward0>)
