<a href="https://colab.research.google.com/github/zoujiulong/Multimodal/blob/main/DiT.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install -q omegaconf torchvision tqdm matplotlib

In [None]:
import logging
import math

import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from tqdm import tqdm

# class TimeEmedding(nn.Module):
#     def __init__(self,in_channel):
#         super().__init__()
#         self.in_channel=in_channel
#         self.proj1=nn.Linear(in_channel//4,in_channel)
#         self.act1=nn.ELU()
#         self.proj2=nn.Linear(in_channel,in_channel)
#     def forward(self,t):
#         print('t',t.shape)
#         # t_emb=torch.empty((*t.shape[:-1],self.in_channel//4))
#         print('t_emb',t_emb.shape)
#         emb=math.log(10000)/(self.in_channel//8)
#         emb=t[:,None]*torch.exp(torch.arange(self.in_channel//8)*-emb)
#         print('emb',emb.shape)
#         t_emb[:,:,:,::2]=emb.sin()
#         t_emb[:,:,:,1::2]=emb.cos()
#         t_emb=self.proj1(t_emb)
#         t_emb=self.act1(t_emb)
#         t_emb=self.proj2(t_emb)
#         return t_emb
class Swish(nn.Module):
    r"""
    ### Swish activation function

    $$x \cdot \sigma(x)$$

    Multiplies the input element-wise by its own sigmoid. The module holds no
    parameters.
    """

    # The docstring is a raw string because `\c` and `\s` are not valid escape
    # sequences in an ordinary string literal.
    def forward(self, x):
        """Return `x * sigmoid(x)` with the same shape as `x`."""
        return x * torch.sigmoid(x)
class TimeEmbedding(nn.Module):
    """
    ### Embeddings for $t$

    Builds a sinusoidal encoding of the time step and passes it through a
    two-layer MLP (`lin1` -> Swish -> `lin2`).
    """

    def __init__(self, n_channels: int):
        """
        * `n_channels` is the number of dimensions in the embedding

        The sinusoidal encoding fed to `lin1` has `2 * (n_channels // 8)`
        features while `lin1` expects `n_channels // 4`, so the two only agree
        when `n_channels // 4` is even.
        """
        super().__init__()
        self.n_channels = n_channels
        # First linear layer
        self.lin1 = nn.Linear(self.n_channels // 4, self.n_channels)
        # Activation
        self.act = Swish()
        # Second linear layer
        self.lin2 = nn.Linear(self.n_channels, self.n_channels)

    def forward(self, t: torch.Tensor):
        """
        * `t` is indexed as `t[:, None]`, i.e. it is treated as a 1-D tensor of
          time steps

        Returns a tensor of shape `[len(t), n_channels]`.
        """
        # Create sinusoidal position embeddings
        # [same as those from the transformer](../../transformers/positional_encoding.html)
        #
        # \begin{align}
        # PE^{(1)}_{t,i} &= sin\Bigg(\frac{t}{10000^{\frac{i}{d - 1}}}\Bigg) \\
        # PE^{(2)}_{t,i} &= cos\Bigg(\frac{t}{10000^{\frac{i}{d - 1}}}\Bigg)
        # \end{align}
        #
        # where $d$ is `half_dim`
        half_dim = self.n_channels // 8
        # With a single frequency (`half_dim == 1`) the exponent denominator
        # `d - 1` is zero; clamp it so the only frequency becomes 1 instead of
        # raising ZeroDivisionError.
        log_step = math.log(10_000) / max(half_dim - 1, 1)
        freqs = torch.exp(torch.arange(half_dim, device=t.device) * -log_step)
        emb = t[:, None] * freqs[None, :]
        emb = torch.cat((emb.sin(), emb.cos()), dim=1)

        # Transform with the MLP
        emb = self.act(self.lin1(emb))
        emb = self.lin2(emb)

        #
        return emb

class Attention(nn.Module):
    """Multi-head self-attention over the spatial positions of a feature map.

    * `head_num` is the number of attention heads
    * `ch` is the channel count of the input feature map; `ch // head_num`
      is used as the per-head width
    * `p` is the dropout probability applied to the attention weights and to
      the output projection

    The input is added back to the result, so the output has the input's shape.
    """

    def __init__(self,head_num,ch,p=0.1):
        super().__init__()
        self.head_num=head_num
        self.head_dim=ch//head_num
        self.qkv=nn.Linear(ch,3*ch)
        # Normalise over the key axis. Without an explicit `dim`, softmax on a
        # 4-D tensor falls back to dim=1, which is the head axis here.
        self.softmax=nn.Softmax(dim=-1)
        self.drop1=nn.Dropout(p)
        self.linear=nn.Linear(ch,ch)
        self.drop2=nn.Dropout(p)

    def forward(self,x):
        """`x` is unpacked as `[batch, ch, h, w]`; returns a tensor of that shape."""
        bs,ch,h,w=x.shape
        # [bs, ch, h, w] -> [bs, h*w, ch]: every spatial position is one token
        tokens=x.reshape(bs,ch,h*w).permute(0,2,1)
        # One projection yields q, k and v; each head owns a contiguous slice
        # of 3*head_dim features which is split three ways below.
        qkv=self.qkv(tokens).view(bs,-1,self.head_num,3*self.head_dim).permute(0,2,1,3)
        q,k,v=torch.chunk(qkv,3,dim=-1)
        scores=(q@k.transpose(-2,-1))/(self.head_dim)**0.5
        weights=self.drop1(self.softmax(scores))
        # [bs, heads, h*w, head_dim] -> [bs, h*w, heads*head_dim]
        o=(weights@v).permute(0,2,1,3).reshape(bs,h*w,-1)
        o=self.drop2(self.linear(o))
        # Back to the feature-map layout, plus the residual connection
        return o.permute(0,2,1).reshape(bs,-1,h,w)+x

class ResidualBlock(nn.Module):
    """Two GroupNorm -> Swish -> 3x3 conv stages with a skip path and optional time conditioning.

    * `in_channel` / `out_channel` are the channel counts of the input and output
    * `time_channel` is the width of the time embedding passed to `forward`
    * `n_groups` is the group count of both `GroupNorm` layers, so both channel
      counts have to be divisible by it
    * `p` is the dropout probability used before the second convolution

    When the channel counts differ the skip path is a 1x1 convolution,
    otherwise it is the identity.
    """

    def __init__(self,in_channel,out_channel,time_channel,n_groups=32,p=0.1):
        super().__init__()
        self.in_ch=in_channel
        self.ins1=nn.GroupNorm(n_groups, in_channel)
        self.act1=Swish()
        self.conv1=nn.Conv2d(in_channel,out_channel,kernel_size=(3,3),padding=1)
        self.ins2=nn.GroupNorm(n_groups, out_channel)
        self.act2=Swish()
        self.conv2=nn.Conv2d(out_channel,out_channel,kernel_size=(3,3),padding=1)
        if in_channel!=out_channel:
            self.residual=nn.Conv2d(in_channel,out_channel,kernel_size=(1,1))
        else:
            self.residual=nn.Identity()
        self.time_emb = nn.Linear(time_channel, out_channel)
        self.time_act = Swish()
        self.dropout = nn.Dropout(p)

    def forward(self,img,t):
        """Return the block output for feature map `img`.

        `t` is either `None` (no conditioning) or a time embedding whose
        projection is broadcast over the two trailing spatial axes.
        """
        out=self.conv1(self.act1(self.ins1(img)))
        if t is not None:
            # The projected embedding gets two trailing singleton axes so it
            # is added per channel.
            out=out+self.time_emb(self.time_act(t))[:, :, None, None]
        out=self.conv2(self.dropout(self.act2(self.ins2(out))))
        return out+self.residual(img)

class DownBlock(nn.Module):
    """A `ResidualBlock` optionally followed by self-attention.

    * `in_channel`, `out_channel`, `time_channel` are forwarded to `ResidualBlock`
    * `head_num` is the head count of the attention layer
    * `has_attention` selects `Attention`; otherwise the second stage is the identity
    """

    def __init__(self,in_channel,out_channel,time_channel,head_num,has_attention=False):
        super().__init__()
        self.residual=ResidualBlock(in_channel,out_channel,time_channel)
        self.attention=Attention(head_num,out_channel) if has_attention else nn.Identity()

    def forward(self,img,t):
        """Apply the residual block to `(img, t)` and then the attention stage."""
        features=self.residual(img,t)
        return self.attention(features)

class MiddleBlock(nn.Module):
    """Bottleneck stage: `ResidualBlock` -> `Attention` -> `ResidualBlock`.

    * `in_channel` / `out_channel` are the channel counts before and after the
      first residual block; the attention layer and the second residual block
      keep `out_channel`
    * `time_channel` is the width of the time embedding
    * `head_num` is the head count of the attention layer
    """

    def __init__(self,in_channel,out_channel,time_channel,head_num):
        super().__init__()
        self.residual1=ResidualBlock(in_channel,out_channel,time_channel)
        self.attention=Attention(head_num,out_channel)
        self.residual2=ResidualBlock(out_channel,out_channel,time_channel)

    def forward(self,img,t):
        """Run the three stages in order; `t` is passed to both residual blocks."""
        out=self.residual1(img,t)
        out=self.attention(out)
        out=self.residual2(out,t)
        return out

class UpBlock(nn.Module):
    """Decoder-side block whose residual stage takes `in_channel + out_channel` input channels.

    `Unet.forward` concatenates a recorded encoder feature map onto the running
    activation before calling this block, which is why the residual block is
    built for the summed channel count.

    * `head_num` is the head count of the attention layer
    * `has_attention` selects `Attention`; otherwise the second stage is the identity
    """

    def __init__(self,in_channel,out_channel,time_channel,head_num,has_attention=False):
        super().__init__()
        merged_channels=in_channel+out_channel
        self.residual=ResidualBlock(merged_channels,out_channel,time_channel)
        self.attention=Attention(head_num,out_channel) if has_attention else nn.Identity()

    def forward(self,img,t):
        """Apply the residual block to `(img, t)` and then the attention stage."""
        features=self.residual(img,t)
        return self.attention(features)

class DownSampleBlock(nn.Module):
    """Channel-preserving 3x3 convolution with stride 2 and padding 1."""

    def __init__(self,in_channel):
        super().__init__()
        self.conv=nn.Conv2d(in_channel,in_channel,kernel_size=3,stride=2,padding=1)

    def forward(self,img,t):
        """Return `conv(img)`; `t` is accepted so the call signature matches the other blocks."""
        del t  # not used by this block
        return self.conv(img)

class UpSampleBlock(nn.Module):
    """Channel-preserving 4x4 transposed convolution with stride 2 and padding 1."""

    def __init__(self,in_channel):
        super().__init__()
        self.conv=nn.ConvTranspose2d(in_channel,in_channel,kernel_size=4,stride=2,padding=1)

    def forward(self,img,t):
        """Return `conv(img)`; `t` is accepted so the call signature matches the other blocks."""
        del t  # not used by this block
        return self.conv(img)

class Unet(nn.Module):
    """U-shaped network of down blocks, a middle block and skip-connected up blocks.

    * `layers` is the number of resolution levels; it indexes a fixed table of
      four channel multipliers, so it has to be between 1 and 4
    * `img_ch` is the channel count of the input and of the output
    * `in_channel` is the width produced by the first convolution
    * `blocks` is the number of `DownBlock`s per level (each level of the up
      path has one block more)
    * `head_num` is forwarded to every block

    The multipliers are applied to the running width, so the widths compound
    from one level to the next.
    """

    def __init__(self,layers,img_ch=1,in_channel=64,blocks=2,head_num=8):
        super().__init__()
        coe=(1,2,2,4)
        if not 1<=layers<=len(coe):
            raise ValueError(f"layers must be between 1 and {len(coe)}, got {layers}")
        self.conv1=nn.Conv2d(img_ch,in_channel,kernel_size=(3,3),padding=(1,1))
        self.time=TimeEmbedding(in_channel*4)
        # Base width; the time embedding is four times as wide.
        n_channel=in_channel
        down=[]
        for i in range(layers):
            out_ch=coe[i]*in_channel
            for _ in range(blocks):
                down.append(DownBlock(in_channel,out_ch,n_channel*4,head_num))
                in_channel=out_ch
            # Every level except the last ends by halving the resolution.
            if i<layers-1:
                down.append(DownSampleBlock(out_ch))
        self.down=nn.ModuleList(down)
        self.middle=MiddleBlock(out_ch,out_ch,n_channel*4,head_num)
        up=[]
        in_channel=out_ch
        for i in range(layers):
            out_ch=in_channel
            for _ in range(blocks):
                up.append(UpBlock(in_channel,out_ch,n_channel*4,head_num))
            # The extra block of the level undoes that level's multiplier.
            out_ch=in_channel//coe[layers-1-i]
            up.append(UpBlock(in_channel,out_ch,n_channel*4,head_num))
            in_channel=out_ch
            if i<layers-1:
                up.append(UpSampleBlock(in_channel))
        self.up=nn.ModuleList(up)
        self.ins=nn.GroupNorm(8, n_channel)
        self.act=Swish()
        self.conv2=nn.Conv2d(in_channel,img_ch,kernel_size=(3,3),padding=(1,1))

    def forward(self,x,t):
        """Return the network output for image batch `x` and time steps `t`.

        `t` is first turned into an embedding by `self.time` and then handed
        to every block.
        """
        out=self.conv1(x)
        t=self.time(t)
        # Every activation of the down path is kept and later concatenated,
        # last-in first-out, onto the input of the matching up block.
        record=[out]
        for down in self.down:
            out=down(out,t)
            record.append(out)
        out=self.middle(out,t)
        for up in self.up:
            if isinstance(up,UpSampleBlock):
                out=up(out,t)
            else:
                out=torch.cat((out,record.pop()),dim=1)
                out=up(out,t)
        out=self.conv2(self.act(self.ins(out)))
        return out

In [None]:
import torch.nn as nn
import torch
import math
import random

class PatchEmbed(nn.Module):
  """Turn an image batch into a sequence of patch tokens.

  A convolution whose kernel size and stride both equal `patch_dim` maps every
  non-overlapping patch of the `in_ch`-channel input to one `embed_dim` vector.
  """

  def __init__(self,in_ch=4,embed_dim=768,patch_dim=2):
    super().__init__()
    self.embed_dim=embed_dim
    self.proj=nn.Conv2d(in_ch,embed_dim,kernel_size=patch_dim,stride=patch_dim)

  def forward(self,x):
    """Map a 4-D `x` to tokens of shape `[batch, patches, embed_dim]`."""
    batch,_,_,_=x.shape
    patches=self.proj(x)
    # Collapse the patch grid into one axis, then put that axis before the features.
    tokens=patches.view(batch,self.embed_dim,-1)
    return tokens.permute(0,2,1)

class CondEmbedding(nn.Module):
  """Build the two conditioning tokens: a sinusoidal time code and a learned label vector.

  Args:
    num_class: number of rows in the label embedding table.
    embed_dim: width of both tokens. The time code has `2 * (embed_dim // 2)`
      features, so an odd `embed_dim` cannot be stacked with the label vector.
  """

  def __init__(self,num_class,embed_dim):
    super().__init__()
    self.d=embed_dim//2
    self.c=nn.Embedding(num_class,embed_dim)

  def forward(self,t,label):
    """Return a `[batch, 2, embed_dim]` tensor: token 0 encodes `t`, token 1 encodes `label`.

    `t` is indexed as `t[:, None]` (one time step per sample) and `label` is
    looked up in the embedding table.
    """
    # Frequencies live on the device of `t` so the product below cannot mix devices.
    idx=torch.arange(self.d,device=t.device)
    # A single frequency (`d == 1`) would otherwise divide 0 by 0 and turn the
    # whole time token into NaN.
    freqs=torch.exp(-idx/max(self.d-1,1)*math.log(10**4))
    angles=t[:,None]*freqs[None,:]
    t_emb=torch.cat((angles.sin(),angles.cos()),dim=-1)
    c=self.c(label)
    # Align the dtype with the label embedding before stacking the two tokens.
    return torch.cat((t_emb.to(c.dtype)[:,None,:],c[:,None,:]),dim=1)

class SpatialAttention(nn.Module):
  """Transformer block: multi-head attention followed by a feed-forward net.

  `types` selects how the condition `c` passed to `forward` is used:
    * 'ada': `c` is flattened per sample and mapped by `mlp1` / `mlp2` to
      scale, shift and gate values applied around the attention and
      feed-forward parts.
    * 'crossattn': keys and values are computed from `c` instead of `x`.
    * anything else: plain self-attention; `c` is not read.

  Args:
    in_dim: feature width the two `LayerNorm`s and the q/k/v projections expect.
    heads: number of attention heads.
    head_dim: per-head width; `heads * head_dim` is the width of the output
      projection and of the feed-forward net. The attention output is added
      to the normalised input, so it has to match `in_dim`.
    types: conditioning mode, see above.
  """

  def __init__(self,in_dim,heads,head_dim,types=''):
    super().__init__()
    self.norm1=nn.LayerNorm(in_dim)
    inner_dim=head_dim*heads
    self.types=types
    self.heads=heads
    self.q=nn.Linear(in_dim,inner_dim)
    self.k=nn.Linear(in_dim,inner_dim)
    self.v=nn.Linear(in_dim,inner_dim)
    self.act=nn.Softmax(dim=-1)
    self.drop1=nn.Dropout()
    self.linear=nn.Linear(inner_dim,inner_dim)
    self.drop2=nn.Dropout()
    self.norm2=nn.LayerNorm(in_dim)
    hidden=4*inner_dim
    self.ffn=nn.Sequential(
        nn.Linear(inner_dim,hidden),
        nn.GELU(),
        nn.Linear(hidden,inner_dim),
    )
    # Both modulation nets read two concatenated condition tokens and emit
    # three vectors (scale, shift, gate).
    cond_in,cond_out=2*inner_dim,3*inner_dim
    self.mlp1=nn.Sequential(
        nn.Linear(cond_in,cond_out),
        nn.ReLU(),
        nn.Linear(cond_out,cond_out),
    )
    self.mlp2=nn.Sequential(
        nn.Linear(cond_in,cond_out),
        nn.ReLU(),
        nn.Linear(cond_out,cond_out),
    )

  def forward(self,x,c):
    """Return the block output for tokens `x` (`[batch, tokens, features]`) and condition `c`."""
    b,n,embed=x.shape
    use_ada=self.types=='ada'
    alpha=None
    if use_ada:
      modulation=self.mlp1(c.view(b,-1)).view(b,3,-1)
      gamma,beta,alpha=modulation.chunk(3,dim=1)
      x=gamma*self.norm1(x)+beta
    else:
      x=self.norm1(x)

    q=self.q(x).view(b,n,self.heads,-1).permute(0,2,1,3)
    per_head=q.shape[-1]
    # Cross-attention reads keys/values from the condition, otherwise from x.
    source=c if self.types=='crossattn' else x
    k=self.k(source).view(b,-1,self.heads,per_head).permute(0,2,1,3)
    v=self.v(source).view(b,-1,self.heads,per_head).permute(0,2,1,3)

    scores=q@k.transpose(-1,-2)/math.sqrt(per_head)
    weights=self.drop1(self.act(scores))
    o=(weights@v).permute(0,2,1,3).flatten(2)
    o=self.drop2(self.linear(o))

    if not use_ada:
      o=self.norm2(o+x)
      return o+self.ffn(o)

    o=self.norm2(alpha*o+x)
    modulation=self.mlp2(c.view(b,-1)).view(b,3,-1)
    gamma,beta,alpha=modulation.chunk(3,dim=1)
    o=gamma*o+beta
    return o+alpha*self.ffn(o)

class DiT(nn.Module):
  """Patch-token transformer denoiser bundled with a DDPM schedule, loss, training step and sampler.

  The network patchifies its input, runs the tokens through `layers`
  `SpatialAttention` blocks together with a time/label embedding, and projects
  every token back to a patch with `2 * in_ch` channels; `forward` returns the
  two channel halves separately.

  Args:
    beta_s, beta_e: first and last value of the linear beta schedule.
    time_steps: number of diffusion steps.
    num_class: number of labels known to the label embedding.
    in_ch: channel count of the input; the output projection is sized from it.
    layers: number of transformer blocks.
    embed_dim: token width.
    head_nums: attention heads per block (per-head width is `embed_dim // head_nums`).
    patch_dim: side length of a square patch.
    types: conditioning mode handed to every block; 'in_context' is
      additionally handled in `forward`.
    img_size: accepted but not read by this class.

  The Adam optimiser used by the training step is created here and stored as
  `self.optim`.
  """

  def __init__(self,beta_s,beta_e,time_steps,num_class,in_ch=4,layers=12,embed_dim=768,head_nums=12,patch_dim=2,types='',img_size=224):
    super().__init__()
    self.cond_emb=CondEmbedding(num_class,embed_dim)
    self.num_class=num_class
    self.types=types
    self.blocks=nn.ModuleList()
    # Not called by `forward`; kept so existing parameter names stay valid.
    self.mlp=nn.Sequential(
        nn.Linear(2 * embed_dim, 3 * embed_dim),
        nn.ReLU(),
        nn.Linear(3 * embed_dim, 3 * embed_dim),
    )
    self.patch=PatchEmbed(in_ch,embed_dim,patch_dim)
    for _ in range(layers):
      self.blocks.append(SpatialAttention(embed_dim,head_nums,embed_dim//head_nums,self.types))
    self.norm=nn.LayerNorm(embed_dim)
    self.patch_dim=patch_dim
    self.linear = nn.Linear(embed_dim, patch_dim * patch_dim * 2 * in_ch)
    self.apply(self._init_weight)
    # Diffusion constants. They are registered as non-persistent buffers so
    # they follow the module across devices without adding state_dict keys.
    self.time_steps=time_steps
    for name,value in self._build_schedule(beta_s,beta_e,time_steps).items():
      self.register_buffer(name,value,persistent=False)
    self.optim=torch.optim.Adam(self.parameters(),lr=3e-4)

  @staticmethod
  def _build_schedule(beta_s,beta_e,time_steps):
    """Return the per-step diffusion constants for a linear beta schedule.

    Everything is computed in float64 and returned as float32 1-D tensors of
    length `time_steps`, keyed by the attribute name it is stored under.
    """
    beta=torch.linspace(beta_s,beta_e,time_steps,dtype=torch.float64)
    alpha=1.0-beta
    cum=torch.cumprod(alpha,dim=0)
    # Cumulative product of the previous step, with 1 for the first step.
    cum_prev=torch.cat((torch.ones(1,dtype=torch.float64),cum[:-1]))
    posterior_variance=beta*(1.0-cum_prev)/(1.0-cum)
    schedule={
        'cum_prod_alpha':cum,
        'cum_prod_alpha_prev':cum_prev,
        'sqrt_cum_prod_alpha':cum.sqrt(),
        'sqrt_one_cum_prod_alpha':(1.0-cum).sqrt(),
        'posterior_mean_coef1':beta*cum_prev.sqrt()/(1.0-cum),
        'posterior_mean_coef2':(1.0-cum_prev)*alpha.sqrt()/(1.0-cum),
        'sqrt_recip_alphas_cumprod':(1.0/cum).sqrt(),
        'sqrt_recipm1_alphas_cumprod':(1.0/cum-1.0).sqrt(),
        'posterior_variance':posterior_variance,
        # The variance of the first step is 0, hence the clamp before the log.
        'posterior_log_variance_clipped':posterior_variance.clamp(min=1e-20).log(),
    }
    return {name:value.to(torch.float32) for name,value in schedule.items()}

  def p_sample(self,x,t,c):
    """Draw the previous-step sample from `x` at steps `t` with labels `c`."""
    b=x.shape[0]
    noise=torch.randn_like(x)
    model_mean,_,model_log_var=self.p_mean_variance(x,t,c,clip_denoised=False)
    # No noise is added for samples that are already at step 0.
    nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
    return model_mean + nonzero_mask * (0.5 * model_log_var).exp() * noise

  def q_posterior(self, x_start, x_t, t):
    """Return mean, variance and clipped log-variance of the posterior given `x_start` and `x_t`."""
    posterior_mean = (
        self.extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start +
        self.extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t
    )
    posterior_variance = self.extract_into_tensor(self.posterior_variance, t, x_t.shape)
    posterior_log_variance_clipped = self.extract_into_tensor(self.posterior_log_variance_clipped, t, x_t.shape)
    return posterior_mean, posterior_variance, posterior_log_variance_clipped

  def predict_start_from_noise(self, x_t, t, noise):
    """Estimate the clean sample from `x_t` and a noise prediction."""
    return (
        self.extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t -
        self.extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise
    )

  def p_mean_variance(self, x, t, c,clip_denoised: bool=False):
    """Return the posterior mean, variance and log-variance predicted for `x` at steps `t`.

    `clip_denoised` is accepted for interface compatibility and is not read.
    """
    model_out,_ = self(x, t,c)
    x_recon = self.predict_start_from_noise(x, t=t, noise=model_out)
    model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t)
    return model_mean, posterior_variance, posterior_log_variance

  def extract_into_tensor(self,a, t, x_shape):
    """Pick `a[t]` per sample and shape it as `[batch, 1, ..., 1]` to broadcast against `x_shape`."""
    b=t.shape[0]
    index=t.long().view(-1).to(a.device)
    out=a.gather(-1,index).to(t.device)
    return out.reshape(b, *((1,) * (len(x_shape) - 1)))

  def q_sample(self,x,t,noise):
    """Return `x` noised to steps `t` with the given `noise`."""
    return self.extract_into_tensor(self.sqrt_cum_prod_alpha,t,x.shape)*x+self.extract_into_tensor(self.sqrt_one_cum_prod_alpha,t,x.shape)*noise

  def p_loss(self,x,t,c,noise):
    """Return the mean squared error between the predicted and the true noise."""
    x_noise=self.diffusion(x,t,noise)
    pred,_=self(x_noise,t,c)
    return torch.nn.functional.mse_loss(pred,noise)

  def diffusion(self,x,t,noise):
    """Forward diffusion; identical to `q_sample`."""
    return self.q_sample(x,t,noise)

  def _init_weight(self,m):
    """Reset every `LayerNorm` to weight 1 and bias 0."""
    if isinstance(m,nn.LayerNorm):
      nn.init.constant_(m.bias,0)
      nn.init.constant_(m.weight,1.0)

  def train(self,x=True,shape=None,c=None):
    """Run one optimisation step, or switch the training mode.

    This method shadows `nn.Module.train`, which `eval()` and parent modules
    call with a single bool. A bool first argument is therefore forwarded to
    `nn.Module.train`; anything else is treated as a data batch.

    Args:
      x: batch to train on, or a bool training-mode flag.
      shape: shape of the batch sampled after the update.
      c: labels for `x`.

    Returns:
      `(samples, loss_value)` for a training step, or `self` for a mode switch.
    """
    if isinstance(x,bool):
      return super().train(x)
    return self._train_step(x,shape,c)

  def _train_step(self,x,shape,c):
    """Noise `x` at random steps, take one Adam step on the loss, then sample a batch."""
    bs=x.shape[0]
    t=torch.randint(0,self.time_steps,(bs,),device=x.device)
    noise=torch.randn_like(x)
    loss=self.p_loss(x,t,c,noise)
    loss_value=loss.item()
    logging.info("loss:{0}".format(loss_value))
    self.optim.zero_grad()
    loss.backward()
    self.optim.step()
    # Note: this runs the full reverse process after every update.
    img=self.sample_loop(shape)
    return img,loss_value

  @torch.no_grad()
  def sample_loop(self, shape, c=1):
    """Generate a batch of `shape` by running the reverse process from pure noise.

    Args:
      shape: shape of the generated batch; its first entry is the batch size.
      c: label used for every sample.

    Returns:
      The final sample clamped to `[0, 1]`.
    """
    device=self.linear.weight.device
    # Sample with dropout disabled, then restore the previous mode.
    # `nn.Module.train` is called directly because `train` is overridden.
    was_training=self.training
    nn.Module.train(self,False)
    try:
      img=torch.randn(shape,device=device)
      for step in reversed(range(self.time_steps)):
        t_batch=torch.full((img.shape[0],),step,dtype=torch.long,device=device)
        c_batch=torch.full((img.shape[0],),c,dtype=torch.long,device=device)
        img=self.p_sample(img,t_batch,c_batch)
    finally:
      nn.Module.train(self,was_training)
    return torch.clamp(img,0,1)

  def forward(self,x,t,c):
    """Return two tensors of the input's shape for `x` (`[b, c, h, w]`), steps `t` and labels `c`.

    The first tensor is what the loss and the sampler use as the noise
    prediction; the second is the remaining half of the output channels.
    """
    b,ch,h,w=x.shape
    # [b, 2, embed_dim]: one time token and one label token
    cond_emb=self.cond_emb(t,c)
    # [b, n, embed_dim]
    x=self.patch(x)
    n=x.shape[1]
    if self.types=='in_context':
      # The condition tokens join the sequence once and are dropped again
      # after the last block, so the token count matches the patch grid.
      x=torch.cat((x,cond_emb),dim=1)
      for block in self.blocks:
        x=block(x,None)
      x=x[:,:n]
    else:
      for block in self.blocks:
        x=block(x,cond_emb)
    # Un-patchify: each token holds a (2*ch, p, p) tile that is placed at the
    # grid position of the patch it came from.
    p=self.patch_dim
    gh,gw=h//p,w//p
    x=self.linear(self.norm(x))
    x=x.reshape(b,gh,gw,2*ch,p,p).permute(0,3,1,4,2,5).reshape(b,2*ch,gh*p,gw*p)
    pred_n,cov=x.chunk(2,dim=1)
    return pred_n,cov

In [None]:
from torchvision.utils import make_grid
from torchvision import datasets,transforms
class UpBlock(nn.Module):
    """A `ResidualBlock` from `in_channel` to `out_channel`, optionally followed by self-attention.

    NOTE: once this cell has run, this definition replaces the earlier
    `UpBlock`, whose residual block takes `in_channel + out_channel` channels.

    * `head_num` is the head count of the attention layer
    * `has_attention` selects `Attention`; otherwise the second stage is the identity
    """

    def __init__(self,in_channel,out_channel,time_channel,head_num,has_attention=False):
        super().__init__()
        self.residual=ResidualBlock(in_channel,out_channel,time_channel)
        self.attention=Attention(head_num,out_channel) if has_attention else nn.Identity()

    def forward(self,img,t):
        """Apply the residual block to `(img, t)` and then the attention stage."""
        features=self.residual(img,t)
        return self.attention(features)
def custom_collate_fn(batch):
    """Stack the first element of every sample in `batch` into one tensor; the rest is dropped."""
    return torch.stack([sample[0] for sample in batch])
class Encoder(nn.Module):
    """Convolutional encoder: down blocks, a middle block and a head with `2 * z_ch` channels.

    * `z_ch` sets the channel count of the head (`2 * z_ch`)
    * `embed_dim` sets the channel count `sample` splits into mean and
      log-variance (`2 * embed_dim` in total)
    * `ch_mult` holds one width multiplier per resolution level; each is
      applied to the running width, so the widths compound
    * `img_ch` is the channel count of the input
    * `in_channel` is the width produced by the first convolution
    * `blocks` is the number of `DownBlock`s per level
    * `head_num` is forwarded to every block
    """

    def __init__(self,z_ch,embed_dim,ch_mult=(1,2,2,4),img_ch=3,in_channel=64,blocks=2,head_num=8):
        super().__init__()
        down=[]
        layers=len(ch_mult)
        self.conv1=nn.Conv2d(img_ch,in_channel,kernel_size=(3,3),padding=(1,1))
        self.time=TimeEmbedding(in_channel*4)
        # Base width; the time embedding is four times as wide.
        n_channel=in_channel
        for i in range(layers):
            out_ch=ch_mult[i]*in_channel
            for _ in range(blocks):
                down.append(DownBlock(in_channel,out_ch,n_channel*4,head_num))
                in_channel=out_ch
            # Every level except the last ends by halving the resolution.
            if i<layers-1:
                down.append(DownSampleBlock(out_ch))
        self.down=nn.ModuleList(down)

        self.middle=MiddleBlock(out_ch,out_ch,n_channel*4,head_num)
        self.ins=nn.InstanceNorm2d(in_channel)
        self.act=nn.ELU()
        self.conv2=nn.Conv2d(in_channel,2*z_ch,kernel_size=(3,3),padding=(1,1))
        self.dis=nn.Conv2d(2*z_ch,2*embed_dim,1)

    def sample(self,z):
        """Project `z` to a mean and a log-variance and return `mean + eps * std`."""
        mean,log_var=self.dis(z).chunk(2,dim=1)
        std=torch.exp(0.5*log_var)
        # randn_like keeps the noise on the device and dtype of `mean`.
        return mean+torch.randn_like(mean)*std

    def forward(self,x,t):
        """Encode `x`; `t` may be `None`, in which case no time embedding is computed."""
        out=self.conv1(x)
        if t is not None:
            t=self.time(t)
        for down in self.down:
            out=down(out,t)
        out=self.middle(out,t)
        return self.conv2(self.act(self.ins(out)))

class Decoder(nn.Module):
    """Convolutional decoder: a middle block, a stack of up blocks and an output head.

    * `z_ch` is the channel count of the input
    * `ch_multi` holds one divisor per resolution level; they are read from
      last to first and each divides the running width
    * `img_ch` is the channel count of the output
    * `in_channel` is the width produced by the first convolution
    * `blocks` is the number of width-preserving `UpBlock`s per level; one
      narrowing `UpBlock` follows them
    * `head_num` is forwarded to every block
    """

    def __init__(self,z_ch,ch_multi=(1,2,2,4),img_ch=3,in_channel=64,blocks=2,head_num=8):
        super().__init__()
        self.conv1=nn.Conv2d(z_ch,in_channel,kernel_size=3,padding=1)
        self.time=TimeEmbedding(in_channel*4)
        time_ch=in_channel*4
        self.middle=MiddleBlock(in_channel,in_channel,time_ch,head_num)

        stages=[]
        width=in_channel
        last_level=len(ch_multi)-1
        for level in range(len(ch_multi)):
            stages.extend(UpBlock(width,width,time_ch,head_num) for _ in range(blocks))
            narrowed=width//ch_multi[last_level-level]
            stages.append(UpBlock(width,narrowed,time_ch,head_num))
            width=narrowed
            # Every level except the last ends by doubling the resolution.
            if level<last_level:
                stages.append(UpSampleBlock(width))
        self.up=nn.ModuleList(stages)

        self.ins=nn.InstanceNorm2d(width)
        self.act=nn.ELU()
        self.conv2=nn.Conv2d(width,img_ch,kernel_size=3,padding=1)

    def forward(self,x,t):
        """Decode `x`; `t` may be `None`, in which case no time embedding is computed."""
        hidden=self.conv1(x)
        if t is not None:
            t=self.time(t)
        hidden=self.middle(hidden,t)
        for stage in self.up:
            hidden=stage(hidden,t)
        return self.conv2(self.act(self.ins(hidden)))

class DiTWrapper(nn.Module):
  """Bundle a frozen `Encoder`/`Decoder` pair with a `DiT` and drive its training.

  Args:
    encoder_config, decoder_config, dit_config: keyword arguments for
      `Encoder`, `Decoder` and `DiT`.
    img_path: root directory handed to `ImageFolder`.
    bs: batch size of the data loader.
    dataset: 'mnist' trains on torchvision's MNIST; any other value (the
      default) reads images from `img_path`.

  NOTE: `forward` currently feeds batches to the DiT directly; the encoder and
  decoder are built and frozen but not called.
  """

  def __init__(self,encoder_config,decoder_config,dit_config,img_path,bs,dataset=''):
    super().__init__()
    self.encoder=Encoder(**encoder_config)
    self.decoder=Decoder(**decoder_config)
    self.encoder.requires_grad_(False)
    self.decoder.requires_grad_(False)
    self.DiT=DiT(**dit_config)
    self.bs=bs
    self.img_path=img_path
    self.dataset=dataset

  def frozen(self,model):
    """Disable gradients for every parameter of `model`."""
    for param in model.parameters():
      param.requires_grad=False

  def create_loader(self):
    """Build a loader that yields image batches only (labels are dropped by the collate function)."""
    if self.dataset=='mnist':
      transform=transforms.Compose([
          transforms.ToTensor(),
      ])
      train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
      # Same batch size and collate function as the folder branch, so the
      # training loop sees identically shaped batches in both cases.
      return DataLoader(train_dataset, batch_size=self.bs, shuffle=True, collate_fn=custom_collate_fn)

    transform=transforms.Compose([
        transforms.ToTensor(),
        transforms.Resize(280),
        transforms.CenterCrop(256),
    ])
    logging.info('start transform')
    dataset=datasets.ImageFolder(self.img_path,transform=transform)
    logging.info('finish transform')
    return iter(DataLoader(dataset,batch_size=self.bs,collate_fn=custom_collate_fn,num_workers=1,pin_memory=True))

  def train(self,c=True,epochs=None,shape=None):
    """Train for `epochs` passes over the data, or switch the training mode.

    This method shadows `nn.Module.train`, which `eval()` calls with a single
    bool. A bool first argument is therefore forwarded to `nn.Module.train`;
    anything else starts a training run.

    Args:
      c: labels used for every batch, or a bool training-mode flag.
      epochs: number of passes over the loader.
      shape: shape of the batch sampled after each step.
    """
    if isinstance(c,bool):
      return super().train(c)
    loss_list=[]
    for _ in tqdm(range(epochs)):
      for x in tqdm(self.create_loader()):
        loss_list.append(self(x,c,shape))
    plt.figure(figsize=(10, 4))
    plt.plot(loss_list)
    plt.title("Loss Curve")
    plt.xlabel("Iteration")
    plt.ylabel("Loss")
    plt.grid(True)
    plt.show()

  def forward(self,x,c,shape):
    """Run one DiT training step on `x`, display the sampled batch and return the loss value."""
    # The last batch of an epoch can be smaller than the label tensor.
    labels=c[:x.shape[0]]
    z_rec,loss=self.DiT.train(x,shape,labels)

    grid=make_grid(z_rec,nrow=shape[0])
    grid = grid.permute(1, 2, 0).detach().cpu().numpy()
    # Show the samples
    plt.figure(figsize=(12, 6))
    plt.imshow(grid)
    plt.axis('off')  # hide the axes
    plt.show()
    return loss

In [None]:
from omegaconf import OmegaConf
# Image folder and model configuration, both read from fixed Google Drive paths.
img_path='/content/drive/MyDrive/filter/'
conf=OmegaConf.load('/content/drive/MyDrive/DiT.yaml')
print(conf.model.DiT)
model=DiTWrapper(
    encoder_config=conf.model.Encoder,
    decoder_config=conf.model.Decoder,
    dit_config=conf.model.DiT,
    img_path=img_path,
    bs=16,
)
# 16 epochs with label 1 for each of the 16 samples; batches of shape (16, 4, 32, 32) are sampled.
model.train(
    c=torch.full((16,),1,dtype=torch.long),
    epochs=16,
    shape=(16,4,32,32),
)