<a href="https://colab.research.google.com/github/zoujiulong/Multimodal/blob/main/DDPM_DDIM.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install -q omegaconf torchvision tqdm matplotlib

In [None]:
import math
import torch
import torch.nn as nn

# class TimeEmedding(nn.Module):
#     def __init__(self,in_channel):
#         super().__init__()
#         self.in_channel=in_channel
#         self.proj1=nn.Linear(in_channel//4,in_channel)
#         self.act1=nn.ELU()
#         self.proj2=nn.Linear(in_channel,in_channel)
#     def forward(self,t):
#         print('t',t.shape)
#         # t_emb=torch.empty((*t.shape[:-1],self.in_channel//4))
#         print('t_emb',t_emb.shape)
#         emb=math.log(10000)/(self.in_channel//8)
#         emb=t[:,None]*torch.exp(torch.arange(self.in_channel//8)*-emb)
#         print('emb',emb.shape)
#         t_emb[:,:,:,::2]=emb.sin()
#         t_emb[:,:,:,1::2]=emb.cos()
#         t_emb=self.proj1(t_emb)
#         t_emb=self.act1(t_emb)
#         t_emb=self.proj2(t_emb)
#         return t_emb
class Swish(nn.Module):
    r"""
    ### Swish activation function

    $$x \cdot \sigma(x)$$

    The docstring is a raw string so that the LaTeX backslashes are not
    parsed as (invalid) string escape sequences.
    """

    def forward(self, x):
        # Gate every element by its own sigmoid.
        return x * torch.sigmoid(x)
class TimeEmbedding(nn.Module):
    """
    ### Embeddings for $t$

    Builds a sinusoidal encoding of the time step and passes it through a
    two-layer MLP (`lin1` -> `Swish` -> `lin2`).
    """

    def __init__(self, n_channels: int):
        """
        * `n_channels` is the number of dimensions in the embedding
        """
        super().__init__()
        self.n_channels = n_channels
        # First linear layer: consumes the `n_channels // 4` sinusoidal features
        self.lin1 = nn.Linear(self.n_channels // 4, self.n_channels)
        # Activation
        self.act = Swish()
        # Second linear layer
        self.lin2 = nn.Linear(self.n_channels, self.n_channels)

    def forward(self, t: torch.Tensor):
        """
        * `t` is indexed as `t[:, None]`, i.e. it is treated as a 1-D batch of time steps

        Returns the result of `lin2`, whose last dimension is `n_channels`.
        """
        # Create sinusoidal position embeddings
        # [same as those from the transformer](../../transformers/positional_encoding.html)
        #
        # \begin{align}
        # PE^{(1)}_{t,i} &= sin\Bigg(\frac{t}{10000^{\frac{i}{d - 1}}}\Bigg) \\
        # PE^{(2)}_{t,i} &= cos\Bigg(\frac{t}{10000^{\frac{i}{d - 1}}}\Bigg)
        # \end{align}
        #
        # where $d$ is `half_dim`
        half_dim = self.n_channels // 8
        # `half_dim - 1` is zero when there is a single frequency; clamp the
        # denominator so that case yields frequency 1 instead of raising
        # ZeroDivisionError.
        scale = math.log(10_000) / max(half_dim - 1, 1)
        freqs = torch.exp(torch.arange(half_dim, device=t.device) * -scale)
        emb = t[:, None] * freqs[None, :]
        # Sine features first, cosine features second, along the feature axis
        emb = torch.cat((emb.sin(), emb.cos()), dim=1)

        # Transform with the MLP
        emb = self.act(self.lin1(emb))
        emb = self.lin2(emb)

        #
        return emb

class Attention(nn.Module):
    """Multi-head self-attention over the spatial positions of a feature map.

    The input `(bs, ch, h, w)` is flattened to `h*w` tokens of `ch` features,
    attended with `head_num` heads of `ch // head_num` features each,
    projected back to `ch` features and added to the input (residual).

    Args:
        head_num: number of attention heads; must divide `ch`.
        ch: number of input (and output) channels.
        p: dropout probability applied to the attention weights and to the
            output projection.

    Raises:
        ValueError: if `ch` is not divisible by `head_num`.
    """

    def __init__(self,head_num,ch,p=0.1):
        super().__init__()
        if ch % head_num != 0:
            raise ValueError(
                "ch ({0}) must be divisible by head_num ({1})".format(ch, head_num)
            )
        self.head_num=head_num
        self.head_dim=ch//head_num
        self.qkv=nn.Linear(ch,3*ch)
        # Normalise over the key axis. Without an explicit `dim`, nn.Softmax
        # falls back to a legacy implicit choice (dim=1 for 4-D input), which
        # would normalise across heads instead of across keys.
        self.softmax=nn.Softmax(dim=-1)
        self.drop1=nn.Dropout(p)
        self.linear=nn.Linear(ch,ch)
        self.drop2=nn.Dropout(p)

    def forward(self,x):
        bs,ch,h,w=x.shape
        # (bs, ch, h, w) -> (bs, h*w, ch): one token per spatial position
        tokens=x.reshape(bs,ch,h*w).permute(0,2,1)
        # (bs, h*w, 3*ch) -> (bs, head_num, h*w, 3*head_dim)
        qkv=self.qkv(tokens).view(bs,-1,self.head_num,3*self.head_dim).permute(0,2,1,3)
        q,k,v=torch.chunk(qkv,3,dim=-1)
        # Scaled dot-product scores: (bs, head_num, h*w, h*w)
        scores=(q@k.transpose(-2,-1))/(self.head_dim)**0.5
        weights=self.drop1(self.softmax(scores))
        # Merge the heads back: (bs, h*w, head_num*head_dim)
        o=(weights@v).permute(0,2,1,3).reshape(bs,h*w,-1)
        o=self.drop2(self.linear(o))
        # Back to the image layout, plus the residual connection
        return o.permute(0,2,1).reshape(bs,-1,h,w)+x

class ResidualBlock(nn.Module):
    """Two GroupNorm -> Swish -> 3x3 conv stages with a time-embedding bias
    and a skip connection.

    Args:
        in_channel: channels of the input feature map.
        out_channel: channels of the output feature map.
        time_channel: feature size of the time embedding passed to `forward`.
        n_groups: number of groups for both GroupNorm layers.
        p: dropout probability applied before the second convolution.
    """

    def __init__(self,in_channel,out_channel,time_channel,n_groups=32,p=0.1):
        super().__init__()
        self.in_ch=in_channel
        self.ins1=nn.GroupNorm(n_groups, in_channel)
        self.act1=Swish()
        self.conv1=nn.Conv2d(in_channel,out_channel,kernel_size=(3,3),padding=1)
        self.ins2=nn.GroupNorm(n_groups, out_channel)
        self.act2=Swish()
        self.conv2=nn.Conv2d(out_channel,out_channel,kernel_size=(3,3),padding=1)
        # A 1x1 convolution matches the channel count on the skip path when
        # the block changes the number of channels.
        if in_channel!=out_channel:
            self.residual=nn.Conv2d(in_channel,out_channel,kernel_size=(1,1))
        else:
            self.residual=nn.Identity()
        self.time_emb = nn.Linear(time_channel, out_channel)
        self.time_act = Swish()
        self.dropout = nn.Dropout(p)

    def forward(self,img,t):
        """Apply the block to `img`; `t` is the time embedding or `None`.

        When `t` is not `None` it is projected to `out_channel` features and
        added to every spatial position of the first convolution's output.
        """
        out=self.conv1(self.act1(self.ins1(img)))
        if t is not None:
            # [:, :, None, None] broadcasts the per-channel bias over H and W
            out=out+self.time_emb(self.time_act(t))[:, :, None, None]
        out=self.conv2(self.dropout(self.act2(self.ins2(out))))
        return out+self.residual(img)

class DownBlock(nn.Module):
    """Residual block optionally followed by self-attention.

    With `has_attention=False` the attention slot holds `nn.Identity`, so
    `forward` reduces to the residual block alone.
    """

    def __init__(self,in_channel,out_channel,time_channel,head_num,has_attention=False):
        super().__init__()
        self.residual=ResidualBlock(in_channel,out_channel,time_channel)
        self.attention=Attention(head_num,out_channel) if has_attention else nn.Identity()

    def forward(self,img,t):
        """Run the residual block on `(img, t)`, then the attention slot."""
        features=self.residual(img,t)
        return self.attention(features)

class MiddleBlock(nn.Module):
    """Residual block -> self-attention -> residual block.

    Args:
        in_channel: channels of the input feature map.
        out_channel: channels produced by the first residual block and kept
            by the attention layer and the second residual block.
        time_channel: feature size of the time embedding.
        head_num: number of attention heads.
    """

    def __init__(self,in_channel,out_channel,time_channel,head_num):
        super().__init__()
        self.residual1=ResidualBlock(in_channel,out_channel,time_channel)
        self.attention=Attention(head_num,out_channel)
        self.residual2=ResidualBlock(out_channel,out_channel,time_channel)

    def forward(self,img,t):
        """Apply the three stages in order; `t` feeds both residual blocks."""
        out=self.residual1(img,t)
        out=self.attention(out)
        out=self.residual2(out,t)
        return out

class UpBlock(nn.Module):
    """Residual block optionally followed by self-attention, for the decoder.

    The residual block is built for `in_channel + out_channel` input
    channels; the caller is expected to supply a feature map with that many
    channels.
    """

    def __init__(self,in_channel,out_channel,time_channel,head_num,has_attention=False):
        super().__init__()
        merged_channels=in_channel+out_channel
        self.residual=ResidualBlock(merged_channels,out_channel,time_channel)
        self.attention=Attention(head_num,out_channel) if has_attention else nn.Identity()

    def forward(self,img,t):
        """Run the residual block on `(img, t)`, then the attention slot."""
        features=self.residual(img,t)
        return self.attention(features)

class DownSampleBlock(nn.Module):
    """Strided 3x3 convolution (stride 2, padding 1) that keeps the channel count."""

    def __init__(self,in_channel):
        super().__init__()
        self.conv=nn.Conv2d(in_channel,in_channel,kernel_size=3,stride=2,padding=1)

    def forward(self,img,t):
        """Convolve `img`; `t` is accepted for a uniform call signature and ignored."""
        del t
        downsampled=self.conv(img)
        return downsampled

class UpSampleBlock(nn.Module):
    """Transposed 4x4 convolution (stride 2, padding 1) that keeps the channel count."""

    def __init__(self,in_channel):
        super().__init__()
        self.conv=nn.ConvTranspose2d(in_channel,in_channel,kernel_size=4,stride=2,padding=1)

    def forward(self,img,t):
        """Apply the transposed convolution; `t` is accepted for a uniform call signature and ignored."""
        del t
        upsampled=self.conv(img)
        return upsampled

class Unet(nn.Module):
    """U-Net noise predictor with time conditioning.

    Args:
        layers: number of resolution levels; must be between 1 and 4 because
            the per-level channel multipliers are the fixed tuple
            `(1, 2, 2, 4)`.
        img_ch: channels of the input image (and of the output).
        in_channel: channels produced by the first convolution.
        blocks: number of `DownBlock`s per level.
        head_num: number of attention heads passed to every block.

    Raises:
        ValueError: if `layers` is outside `1..4`.
    """

    def __init__(self,layers,img_ch=1,in_channel=64,blocks=2,head_num=8):
        super().__init__()
        # Channel multiplier per level. It is applied to the running channel
        # count, so the widths compound from one level to the next.
        coe=(1,2,2,4)
        if not 1<=layers<=len(coe):
            raise ValueError(
                "layers must be between 1 and {0}, got {1}".format(len(coe),layers)
            )
        down=[]
        self.conv1=nn.Conv2d(img_ch,in_channel,kernel_size=(3,3),padding=(1,1))
        self.time=TimeEmbedding(in_channel*4)
        # Base width; the time embedding has `n_channel*4` features everywhere.
        n_channel=in_channel
        for i in range(layers):
            out_ch=coe[i]*in_channel
            for _ in range(blocks):
                down.append(DownBlock(in_channel,out_ch,n_channel*4,head_num))
                in_channel=out_ch
            # Halve the resolution between levels, but not after the last one
            if i<layers-1:
                down.append(DownSampleBlock(out_ch))
        self.down=nn.ModuleList(down)
        self.middle=MiddleBlock(out_ch,out_ch,n_channel*4,head_num)
        up=[]
        in_channel=out_ch
        for i in range(layers):
            out_ch=in_channel
            for _ in range(blocks):
                up.append(UpBlock(in_channel,out_ch,n_channel*4,head_num))
            # One extra block per level undoes that level's channel multiplier
            out_ch=in_channel//coe[layers-1-i]
            up.append(UpBlock(in_channel,out_ch,n_channel*4,head_num))
            in_channel=out_ch
            if i<layers-1:
                up.append(UpSampleBlock(in_channel))
        self.up=nn.ModuleList(up)
        self.ins=nn.GroupNorm(8, n_channel)
        self.act=Swish()
        self.conv2=nn.Conv2d(in_channel,img_ch,kernel_size=(3,3),padding=(1,1))

    def forward(self,x,t):
        """Predict from image batch `x` and time steps `t`.

        `t` is embedded by `self.time` and the embedding is passed to every
        block. The output of the first convolution and of every module in
        `self.down` is recorded and concatenated back, last-in first-out,
        in front of every non-upsampling module of `self.up`.
        """
        out=self.conv1(x)
        t=self.time(t)
        record=[out]
        for down in self.down:
            out=down(out,t)
            record.append(out)
        out=self.middle(out,t)
        for up in self.up:
            if not isinstance(up,UpSampleBlock):
                # Skip connection: concatenate along the channel axis
                out=torch.cat((out,record.pop()),dim=1)
            out=up(out,t)
        return self.conv2(self.act(self.ins(out)))

In [None]:
import torch
import numpy as np
from tqdm.notebook import tqdm
from torchvision.utils import make_grid
import matplotlib.pyplot as plt
from omegaconf import OmegaConf
from torchvision import datasets, transforms
import logging
from torch.utils.data import DataLoader,Dataset,Subset

def custom_collate_fn(batch):
    """Stack the first element of every sample in `batch` into one tensor.

    Anything after the first element of a sample (e.g. a label) is dropped.
    """
    first_items=list(map(lambda sample: sample[0], batch))
    return torch.stack(first_items)
# def custom_collate_fn(batch):
#     images=[x[0] for x in batch]
#     return torch.stack(images)
# Configure the root logger: INFO level, "time - level - message" records.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    force=True  # force the logging configuration to be reset (Jupyter has already configured it once)
)

class DDPM(nn.Module):
  """Denoising diffusion model with DDPM and DDIM samplers around a `Unet`.

  Args:
    beta_s: first value of the linear beta schedule.
    beta_e: last value of the linear beta schedule.
    time_steps: number of diffusion steps (length of the schedule).
    conf_path: path of a config file loaded with `OmegaConf.load`; its
      `run.batch_size` entry is read by the image-folder branch of
      `create_loader`.
    img_path: root directory handed to `ImageFolder` by that same branch.

  The schedule tables are registered as non-persistent buffers, so they
  follow `.to(device)` but are not written to the state dict.
  """

  def __init__(self,beta_s,beta_e,time_steps,conf_path,img_path):
    # nn.Module must be initialised before attributes are attached to it.
    super().__init__()
    self.model_config=OmegaConf.load(conf_path)
    self.img_path=img_path
    self.predict_model=Unet(3)
    self.time_steps=time_steps
    self._build_schedule(beta_s,beta_e,time_steps)
    self.optim=torch.optim.Adam(self.predict_model.parameters(),lr=2e-4)

  def _build_schedule(self,beta_s,beta_e,time_steps):
    """Precompute every per-step coefficient of the linear beta schedule."""
    beta=torch.linspace(beta_s,beta_e,time_steps,dtype=torch.float64).numpy()
    alpha=1.0-beta
    cum_prod_alpha=np.cumprod(alpha)
    # Cumulative product shifted by one step, with 1 for the step before the
    # first one. It has to be built from the cumulative product, not from
    # the per-step alphas.
    cum_prod_alpha_prev=np.append(1.0,cum_prod_alpha[:-1])
    self.cum_prod_alpha=cum_prod_alpha
    self.cum_prod_alpha_prev=cum_prod_alpha_prev
    one_minus_cum=1.0-cum_prod_alpha
    posterior_variance=beta*(1.0-cum_prod_alpha_prev)/one_minus_cum
    tables={
      'alphas_cumprod':cum_prod_alpha,
      'sqrt_cum_prod_alpha':np.sqrt(cum_prod_alpha),
      'sqrt_one_cum_prod_alpha':np.sqrt(one_minus_cum),
      'posterior_mean_coef1':beta*np.sqrt(cum_prod_alpha_prev)/one_minus_cum,
      'posterior_mean_coef2':(1.0-cum_prod_alpha_prev)*np.sqrt(alpha)/one_minus_cum,
      'sqrt_recip_alphas_cumprod':np.sqrt(1.0/cum_prod_alpha),
      'sqrt_recipm1_alphas_cumprod':np.sqrt(1.0/cum_prod_alpha-1.0),
      'posterior_variance':posterior_variance,
      # The variance is 0 at the first step; clip before taking the log.
      'posterior_log_variance_clipped':np.log(np.maximum(posterior_variance,1e-20)),
    }
    for name,values in tables.items():
      self.register_buffer(name,torch.tensor(values,dtype=torch.float32),persistent=False)

  def _device(self):
    """Device of the noise predictor's parameters."""
    return next(self.predict_model.parameters()).device

  def p_sample(self,x,t):
    """Draw x_{t-1} from x_t with the DDPM posterior; no noise is added where t == 0."""
    b=x.shape[0]
    noise=torch.randn_like(x)
    model_mean,_,model_log_var=self.p_mean_variance(x,t,clip_denoised=False)
    nonzero_mask=(t!=0).to(x.dtype).reshape(b,*((1,)*(len(x.shape)-1)))
    return model_mean+nonzero_mask*(0.5*model_log_var).exp()*noise

  def q_posterior(self,x_start,x_t,t):
    """Mean, variance and clipped log-variance of q(x_{t-1} | x_t, x_0)."""
    posterior_mean=(
      self.extract_into_tensor(self.posterior_mean_coef1,t,x_t.shape)*x_start+
      self.extract_into_tensor(self.posterior_mean_coef2,t,x_t.shape)*x_t
    )
    posterior_variance=self.extract_into_tensor(self.posterior_variance,t,x_t.shape)
    posterior_log_variance_clipped=self.extract_into_tensor(self.posterior_log_variance_clipped,t,x_t.shape)
    return posterior_mean,posterior_variance,posterior_log_variance_clipped

  def predict_start_from_noise(self,x_t,t,noise):
    """Estimate x_0 from x_t and the (predicted) noise."""
    return (
      self.extract_into_tensor(self.sqrt_recip_alphas_cumprod,t,x_t.shape)*x_t-
      self.extract_into_tensor(self.sqrt_recipm1_alphas_cumprod,t,x_t.shape)*noise
    )

  def p_mean_variance(self,x,t,clip_denoised: bool):
    """Posterior statistics for x_t using the model's noise prediction.

    `clip_denoised` is accepted for interface compatibility and is not used.
    """
    model_out=self.predict_model(x,t)
    x_recon=self.predict_start_from_noise(x,t=t,noise=model_out)
    model_mean,posterior_variance,posterior_log_variance=self.q_posterior(x_start=x_recon,x_t=x,t=t)
    return model_mean,posterior_variance,posterior_log_variance

  def extract_into_tensor(self,a,t,x_shape):
    """Pick `a[t]` per batch element and reshape it to broadcast against `x_shape`."""
    b,*_=t.shape
    # Look the table up on the index's device so a device mismatch cannot occur.
    out=a.to(t.device).gather(-1,t)
    return out.reshape(b,*((1,)*(len(x_shape)-1)))

  def q_sample(self,x,t,noise):
    """Forward diffusion: noise `x` to step `t` in closed form."""
    signal=self.extract_into_tensor(self.sqrt_cum_prod_alpha,t,x.shape)*x
    return signal+self.extract_into_tensor(self.sqrt_one_cum_prod_alpha,t,x.shape)*noise

  def p_loss(self,x,t,noise):
    """Mean squared error between the predicted and the injected noise."""
    x_noise=self.diffusion(x,t,noise)
    pred=self.predict_model(x_noise,t)
    return torch.nn.functional.mse_loss(pred,noise)

  def diffusion(self,x,t,noise):
    """Alias of `q_sample`."""
    return self.q_sample(x,t,noise)

  def create_loader(self,dataset='mnist'):
    """Build the training data source.

    For `'mnist'` a `DataLoader` over the first 1024 MNIST training images
    is returned; for any other value an iterator over a `DataLoader` on the
    `ImageFolder` at `self.img_path` is returned.
    """
    if dataset=='mnist':
      transform=transforms.Compose([
          transforms.ToTensor(),
      ])
      train_dataset=datasets.MNIST(root='./data',train=True,download=True,transform=transform)
      small_dataset=Subset(train_dataset,range(1024))
      loader=DataLoader(small_dataset,batch_size=128,shuffle=True,collate_fn=custom_collate_fn)
    else:
      transform=transforms.Compose([
          transforms.ToTensor(),
          transforms.Resize(280),
          transforms.CenterCrop(256),
      ])
      logging.info('start transform')
      image_folder=datasets.ImageFolder(self.img_path,transform=transform)
      logging.info('finish transform')
      bs=self.model_config.run.batch_size
      loader=iter(DataLoader(image_folder,batch_size=bs,collate_fn=custom_collate_fn,num_workers=1,pin_memory=True))
    return loader

  # DDPM sampling
  @torch.no_grad()
  def sample_loop(self,shape):
    """Run the full DDPM reverse process from noise of the given `shape`."""
    device=self._device()
    img=torch.randn(shape,device=device)
    # Sample in eval mode so dropout is off, then restore the previous mode.
    was_training=self.predict_model.training
    self.predict_model.eval()
    try:
      for t in reversed(range(self.time_steps)):
        t_batch=torch.full((img.shape[0],),t,dtype=torch.long,device=device)
        img=self.p_sample(img,t_batch)
    finally:
      self.predict_model.train(was_training)
    return torch.clamp(img,0,1)

  def ddim_p_sample(self,x,t,interval,eta):
    """One DDIM step from `t` to `max(t - interval, 0)`.

    `eta` scales the stochastic part; `eta=0` is the deterministic sampler.
    """
    noise=torch.randn_like(x)
    model_out=self.predict_model(x,t)
    x_recon=self.predict_start_from_noise(x,t,model_out)
    t_prev=torch.clamp(t-interval,min=0)
    acp_t=self.extract_into_tensor(self.alphas_cumprod,t,x.shape)
    acp_prev=self.extract_into_tensor(self.alphas_cumprod,t_prev,x.shape)
    # DDIM variance for the (t, t_prev) pair; it depends on both steps, so
    # the one-step posterior variance cannot be used when steps are skipped.
    sigma_sq=((eta**2)*(1-acp_prev)/(1-acp_t)*(1-acp_t/acp_prev)).clamp(min=0)
    # Clamp so that eta > 1 cannot push the radicand below zero.
    direction_coef=torch.sqrt((1-acp_prev-sigma_sq).clamp(min=0))
    return acp_prev.sqrt()*x_recon+direction_coef*model_out+sigma_sq.sqrt()*noise

  @torch.no_grad()
  def ddim_sample_loop(self,shape,interval,eta):
    """Run DDIM sampling, visiting every `interval`-th step in reverse."""
    device=self._device()
    img=torch.randn(shape,device=device)
    # Sample in eval mode so dropout is off, then restore the previous mode.
    was_training=self.predict_model.training
    self.predict_model.eval()
    try:
      for t in reversed(range(0,self.time_steps,interval)):
        t_batch=torch.full((img.shape[0],),t,dtype=torch.long,device=device)
        img=self.ddim_p_sample(img,t_batch,interval,eta)
    finally:
      self.predict_model.train(was_training)
    return torch.clamp(img,0,1)

  def train(self,epochs=True,shape=None,interval=None,eta=0):
    """Run the training loop, or switch train/eval mode.

    This method shadows `nn.Module.train`, which `eval()` calls as
    `self.train(False)`. When the first argument is a bool it is therefore
    treated as the module mode and forwarded to `nn.Module.train`;
    otherwise the call is `train(epochs, shape, interval, eta)` and is
    forwarded to `fit`.
    """
    if isinstance(epochs,bool):
      return nn.Module.train(self,epochs)
    return self.fit(epochs,shape,interval,eta)

  def fit(self,epochs,shape,interval,eta=0):
    """Train for `epochs` epochs, showing DDIM samples after each one.

    Args:
      epochs: number of passes over the loader from `create_loader()`.
      shape: shape of the sample batch drawn after every epoch; `shape[0]`
        is also used as the grid's `nrow`.
      interval: DDIM step stride used for those samples.
      eta: DDIM stochasticity used for those samples.

    Raises:
      TypeError: if `shape` or `interval` is missing.
    """
    if shape is None or interval is None:
      raise TypeError("shape and interval are required to run training")
    device=self._device()
    loss_list=[]
    for epoch in tqdm(range(epochs),desc='epoch'):
      nn.Module.train(self,True)
      loader=self.create_loader()
      for x in tqdm(loader):
        loss=self(x.to(device))
        loss_value=loss.item()
        logging.info("cur_epoch:{0},loss:{1}".format(epoch,loss_value))
        loss_list.append(loss_value)
        self.optim.zero_grad()
        loss.backward()
        self.optim.step()
      img=self.ddim_sample_loop(shape,interval,eta)
      grid=make_grid(img,nrow=shape[0])
      grid=grid.permute(1,2,0).cpu().numpy()
      # Display the samples
      plt.figure(figsize=(12,6))
      plt.imshow(grid)
      plt.axis('off')  # hide the axes
      plt.show()
    plt.figure(figsize=(10,4))
    plt.plot(loss_list)
    plt.title("Loss Curve")
    plt.xlabel("Iteration")
    plt.ylabel("Loss")
    plt.grid(True)
    plt.show()

  def forward(self,x):
    """Training loss for a batch `x` at uniformly sampled time steps."""
    bs=x.shape[0]
    t=torch.randint(0,self.time_steps,(bs,),device=x.device)
    noise=torch.randn_like(x)
    return self.p_loss(x,t,noise)

In [None]:
# Build the model on a 100-step linear beta schedule, then train for 50
# epochs, sampling a (16, 1, 28, 28) batch with DDIM stride 20 after each one.
conf_path='/content/drive/MyDrive/config.yaml'
img_path='/content/drive/MyDrive/filter/'
model=DDPM(
    beta_s=1e-4,
    beta_e=2e-2,
    time_steps=100,
    conf_path=conf_path,
    img_path=img_path,
)
model.train(epochs=50,shape=(16,1,28,28),interval=20)