# 单步扩散模型

# In [5]:
import torch
import torch.nn as nn
import torch.nn.functional as F
# Default configuration for this cell. DiffusionModel.forward reads SEQ_LEN
# and IMG_SIZE as module-level globals when reshaping its input, so they must
# be defined before the model is called.
input_channels = 1  # channels per input frame (ConvLSTM input_dim)
latent_dim = 128  # stored on DiffusionModel; only used to derive its unused z_w attribute here
output_channels = 1  # channels of the decoded prediction
IMG_SIZE = 64  # frame height and width (frames are reshaped to IMG_SIZE x IMG_SIZE)
SEQ_LEN = 10  # number of frames fed through the ConvLSTM encoder
from convlstm import ConvLSTMCell
import numpy as np
class DoubleConv(nn.Module):
    """Two stacked ``Conv3x3 -> BatchNorm -> ReLU`` stages.

    Both 3x3 convolutions use ``padding=1``, so height and width are
    preserved; only the channel count changes (``in_channels`` ->
    ``out_channels``).
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        stages = []
        # First stage maps in_channels -> out_channels, second keeps out_channels.
        for stage_in in (in_channels, out_channels):
            stages += [
                nn.Conv2d(stage_in, out_channels, kernel_size=3, padding=1),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True),
            ]
        # A flat Sequential under the original attribute name keeps the
        # state_dict keys (double_conv.0.weight, ...) checkpoint-compatible.
        self.double_conv = nn.Sequential(*stages)

    def forward(self, x):
        """Run both conv stages on ``x`` of shape ``(B, in_channels, H, W)``."""
        return self.double_conv(x)

class UNet(nn.Module):
    """Five-level U-Net: ``(B, input_channels, H, W) -> (B, output_channels, H, W)``.

    The encoder applies a DoubleConv at each level and halves the resolution
    with 2x2 max-pooling between levels (four times). The decoder mirrors it
    with 2x2 stride-2 transposed convolutions, concatenates each upsampled map
    with the same-level encoder map (skip connection) and applies a DoubleConv.
    A final 1x1 conv maps to ``output_channels``.

    H and W must be at least 16 (four poolings), but need not be divisible by
    16: when flooring in a pooling step leaves an upsampled map smaller than
    its skip connection, it is zero-padded to match instead of crashing in
    ``torch.cat``.

    Args:
        input_channels: channels of the input tensor.
        output_channels: channels produced by the final 1x1 conv.
        base_channels: width of the first level; each deeper level doubles it.
            The default 64 gives the 64-128-256-512-1024 layout.
    """

    def __init__(self, input_channels, output_channels, base_channels=64):
        super(UNet, self).__init__()
        self.in_channels = input_channels
        self.out_channels = output_channels

        c1, c2, c3, c4, c5 = (base_channels * 2 ** level for level in range(5))

        # Registration order matches the original layout so state_dict keys
        # and seeded parameter initialisation are unchanged.
        self.enc1 = DoubleConv(input_channels, c1)
        self.enc2 = DoubleConv(c1, c2)
        self.enc3 = DoubleConv(c2, c3)
        self.enc4 = DoubleConv(c3, c4)
        self.enc5 = DoubleConv(c4, c5)

        self.pool = nn.MaxPool2d(2)

        # Each decoder DoubleConv sees upsampled + skip channels (c_k + c_k).
        self.up4 = nn.ConvTranspose2d(c5, c4, kernel_size=2, stride=2)
        self.dec4 = DoubleConv(c5, c4)
        self.up3 = nn.ConvTranspose2d(c4, c3, kernel_size=2, stride=2)
        self.dec3 = DoubleConv(c4, c3)
        self.up2 = nn.ConvTranspose2d(c3, c2, kernel_size=2, stride=2)
        self.dec2 = DoubleConv(c3, c2)
        self.up1 = nn.ConvTranspose2d(c2, c1, kernel_size=2, stride=2)
        self.dec1 = DoubleConv(c2, c1)

        self.final_conv = nn.Conv2d(c1, output_channels, kernel_size=1)

    @staticmethod
    def _pad_to(x, ref):
        """Centre-zero-pad ``x`` spatially so its H and W equal those of ``ref``."""
        dh = ref.size(-2) - x.size(-2)
        dw = ref.size(-1) - x.size(-1)
        if dh == 0 and dw == 0:
            return x
        # F.pad order for the last two dims: (left, right, top, bottom).
        return F.pad(x, [dw // 2, dw - dw // 2, dh // 2, dh - dh // 2])

    def _decode_level(self, up, dec, x, skip):
        """Upsample ``x``, align it to ``skip``, concatenate (upsampled first) and fuse."""
        upsampled = self._pad_to(up(x), skip)
        return dec(torch.cat((upsampled, skip), dim=1))

    def forward(self, x):
        """Map ``x`` of shape ``(B, input_channels, H, W)`` to ``(B, output_channels, H, W)``."""
        enc1 = self.enc1(x)
        enc2 = self.enc2(self.pool(enc1))
        enc3 = self.enc3(self.pool(enc2))
        enc4 = self.enc4(self.pool(enc3))
        enc5 = self.enc5(self.pool(enc4))

        dec4 = self._decode_level(self.up4, self.dec4, enc5, enc4)
        dec3 = self._decode_level(self.up3, self.dec3, dec4, enc3)
        dec2 = self._decode_level(self.up2, self.dec2, dec3, enc2)
        dec1 = self._decode_level(self.up1, self.dec1, dec2, enc1)

        return self.final_conv(dec1)

class DiffusionModel(nn.Module):
    """Single-step latent denoising model for frame-sequence prediction.

    A ConvLSTM cell encodes a sequence of ``SEQ_LEN`` frames into its final
    hidden map. Gaussian noise is added to that map, and a U-Net conditioned on
    a constant time-step channel predicts the noise. The prediction is
    subtracted from the noisy latent, and a second U-Net decodes the result
    into a sigmoid-activated prediction.

    Args:
        input_channels: channels per input frame (ConvLSTM ``input_dim``).
        latent_dim: stored as ``self.latent_dim``; only used to derive
            ``self.z_w``, which this class does not otherwise use.
        output_channels: channels of the decoded prediction.
    """

    def __init__(self, input_channels, latent_dim, output_channels):
        super(DiffusionModel, self).__init__()
        # parameters:
        self.input_channels = input_channels
        self.latent_dim = latent_dim
        self.output_channels = output_channels
        self.z_w = int(np.sqrt(latent_dim // 2))

        # Width of the ConvLSTM hidden state, i.e. of the latent map.
        num_hiddens = 128
        hidden_dim = num_hiddens // 4

        # Sequence encoder.
        self._convlstm = ConvLSTMCell(input_dim=self.input_channels,
                                      hidden_dim=hidden_dim,
                                      kernel_size=(3, 3),
                                      bias=True)

        # The noise predictor sees the latent plus one time-step channel.
        self._noise_predictor = UNet(input_channels=hidden_dim + 1, output_channels=hidden_dim)
        self._decoder = UNet(input_channels=hidden_dim, output_channels=self.output_channels)

    def forward(self, x, t, noise=None):
        """Encode ``x``, add noise, remove the predicted noise once, and decode.

        Args:
            x: input sequence, reshaped to
                ``(-1, SEQ_LEN, input_channels, IMG_SIZE, IMG_SIZE)``.
            t: time step(s). Either a tensor with one value per batch element or
                a scalar applied to the whole batch. It is broadcast into a
                constant extra input channel of the noise predictor.
            noise: optional noise shaped like the ConvLSTM hidden state; drawn
                with ``torch.randn_like`` when omitted.

        Returns:
            Tensor of shape ``(B, output_channels, H, W)`` with values in (0, 1).
        """
        x = x.reshape(-1, SEQ_LEN, self.input_channels, IMG_SIZE, IMG_SIZE)
        b, seq_len, _, h, w = x.size()

        # Run the ConvLSTM over the sequence; its last hidden state is the latent.
        h_enc, enc_state = self._convlstm.init_hidden(batch_size=b, image_size=(h, w))
        for t_step in range(seq_len):
            h_enc, enc_state = self._convlstm(input_tensor=x[:, t_step],
                                              cur_state=[h_enc, enc_state])
        enc_in = h_enc

        if noise is None:
            noise = torch.randn_like(enc_in)
        z_noisy = enc_in + noise

        # Constant time-step channel built on the latent's dtype/device so the
        # concat works on GPU and for scalar or integer t.
        t_map = torch.as_tensor(t, dtype=z_noisy.dtype, device=z_noisy.device)
        t_map = t_map.reshape(-1, 1, 1, 1).expand(b, 1, h, w)
        z_predicted = self._noise_predictor(torch.cat([z_noisy, t_map], dim=1))

        # Estimate the clean latent by removing the predicted noise from the
        # same noisy latent the predictor was given.
        z_denoised = z_noisy - z_predicted

        prediction = self._decoder(z_denoised)
        return torch.sigmoid(prediction)

# Example usage: build the single-step model and check the output shape.
input_channels = 1
latent_dim = 128
output_channels = 1
IMG_SIZE = 64
SEQ_LEN = 10

model = DiffusionModel(input_channels, latent_dim, output_channels)
input_binary_maps = torch.randn(128, SEQ_LEN, input_channels, IMG_SIZE, IMG_SIZE)
t = torch.randint(0, 1000, (128,))  # one random time step per sample
# Shape check only: no_grad avoids keeping both U-Nets' activations around
# for a backward pass that never happens.
with torch.no_grad():
    output = model(input_binary_maps, t)
print(output.size())  # Should print torch.Size([128, 1, 64, 64])
assert output.shape == (128, output_channels, IMG_SIZE, IMG_SIZE)


# Out: torch.Size([128, 1, 64, 64])


# 多步扩散模型

# In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
# Default configuration for this cell. DiffusionModel.forward reads SEQ_LEN
# and IMG_SIZE as module-level globals when reshaping its input, so they must
# be defined before the model is called.
input_channels = 1  # channels per input frame (ConvLSTM input_dim)
latent_dim = 128  # stored on DiffusionModel; only used to derive its unused z_w attribute here
output_channels = 1  # channels of the decoded prediction
IMG_SIZE = 64  # frame height and width (frames are reshaped to IMG_SIZE x IMG_SIZE)
SEQ_LEN = 10  # number of frames fed through the ConvLSTM encoder
from convlstm import ConvLSTMCell
import numpy as np
class DoubleConv(nn.Module):
    """(Conv3x3 => BatchNorm => ReLU) applied twice; spatial size is preserved."""

    def __init__(self, in_channels, out_channels):
        super().__init__()

        def conv_bn_relu(stage_in):
            # One 3x3 conv stage; padding=1 keeps H and W unchanged.
            return [
                nn.Conv2d(stage_in, out_channels, kernel_size=3, padding=1),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True),
            ]

        # A flat Sequential under the original attribute name keeps the
        # state_dict keys (double_conv.0.weight, ...) checkpoint-compatible.
        self.double_conv = nn.Sequential(
            *conv_bn_relu(in_channels),
            *conv_bn_relu(out_channels),
        )

    def forward(self, x):
        """Apply both stages to ``x`` of shape ``(B, in_channels, H, W)``."""
        return self.double_conv(x)

class UNet(nn.Module):
    """Five-level U-Net: ``(B, input_channels, H, W) -> (B, output_channels, H, W)``.

    The encoder applies a DoubleConv at each level and halves the resolution
    with 2x2 max-pooling between levels (four times). The decoder mirrors it
    with 2x2 stride-2 transposed convolutions, concatenates each upsampled map
    with the same-level encoder map (skip connection) and applies a DoubleConv.
    A final 1x1 conv maps to ``output_channels``.

    H and W must be at least 16 (four poolings), but need not be divisible by
    16: when flooring in a pooling step leaves an upsampled map smaller than
    its skip connection, it is zero-padded to match instead of crashing in
    ``torch.cat``.

    Args:
        input_channels: channels of the input tensor.
        output_channels: channels produced by the final 1x1 conv.
        base_channels: width of the first level; each deeper level doubles it.
            The default 64 gives the 64-128-256-512-1024 layout.
    """

    def __init__(self, input_channels, output_channels, base_channels=64):
        super(UNet, self).__init__()
        self.in_channels = input_channels
        self.out_channels = output_channels

        c1, c2, c3, c4, c5 = (base_channels * 2 ** level for level in range(5))

        # Registration order matches the original layout so state_dict keys
        # and seeded parameter initialisation are unchanged.
        self.enc1 = DoubleConv(input_channels, c1)
        self.enc2 = DoubleConv(c1, c2)
        self.enc3 = DoubleConv(c2, c3)
        self.enc4 = DoubleConv(c3, c4)
        self.enc5 = DoubleConv(c4, c5)

        self.pool = nn.MaxPool2d(2)

        # Each decoder DoubleConv sees upsampled + skip channels (c_k + c_k).
        self.up4 = nn.ConvTranspose2d(c5, c4, kernel_size=2, stride=2)
        self.dec4 = DoubleConv(c5, c4)
        self.up3 = nn.ConvTranspose2d(c4, c3, kernel_size=2, stride=2)
        self.dec3 = DoubleConv(c4, c3)
        self.up2 = nn.ConvTranspose2d(c3, c2, kernel_size=2, stride=2)
        self.dec2 = DoubleConv(c3, c2)
        self.up1 = nn.ConvTranspose2d(c2, c1, kernel_size=2, stride=2)
        self.dec1 = DoubleConv(c2, c1)

        self.final_conv = nn.Conv2d(c1, output_channels, kernel_size=1)

    @staticmethod
    def _pad_to(x, ref):
        """Centre-zero-pad ``x`` spatially so its H and W equal those of ``ref``."""
        dh = ref.size(-2) - x.size(-2)
        dw = ref.size(-1) - x.size(-1)
        if dh == 0 and dw == 0:
            return x
        # F.pad order for the last two dims: (left, right, top, bottom).
        return F.pad(x, [dw // 2, dw - dw // 2, dh // 2, dh - dh // 2])

    def _decode_level(self, up, dec, x, skip):
        """Upsample ``x``, align it to ``skip``, concatenate (upsampled first) and fuse."""
        upsampled = self._pad_to(up(x), skip)
        return dec(torch.cat((upsampled, skip), dim=1))

    def forward(self, x):
        """Map ``x`` of shape ``(B, input_channels, H, W)`` to ``(B, output_channels, H, W)``."""
        enc1 = self.enc1(x)
        enc2 = self.enc2(self.pool(enc1))
        enc3 = self.enc3(self.pool(enc2))
        enc4 = self.enc4(self.pool(enc3))
        enc5 = self.enc5(self.pool(enc4))

        dec4 = self._decode_level(self.up4, self.dec4, enc5, enc4)
        dec3 = self._decode_level(self.up3, self.dec3, dec4, enc3)
        dec2 = self._decode_level(self.up2, self.dec2, dec3, enc2)
        dec1 = self._decode_level(self.up1, self.dec1, dec2, enc1)

        return self.final_conv(dec1)

class DiffusionModel(nn.Module):
    """Multi-step latent denoising model for frame-sequence prediction.

    A ConvLSTM cell encodes a sequence of ``SEQ_LEN`` frames into its final
    hidden map. Gaussian noise is added to that map. A time-conditioned U-Net
    is then applied ``num_steps`` times, each time subtracting its predicted
    noise from the running latent. A second U-Net decodes the result into a
    sigmoid-activated prediction.

    Args:
        input_channels: channels per input frame (ConvLSTM ``input_dim``).
        latent_dim: stored as ``self.latent_dim``; only used to derive
            ``self.z_w``, which this class does not otherwise use.
        output_channels: channels of the decoded prediction.
        num_steps: number of denoising iterations run in ``forward``.
    """

    def __init__(self, input_channels, latent_dim, output_channels, num_steps):
        super(DiffusionModel, self).__init__()
        # parameters:
        self.input_channels = input_channels
        self.latent_dim = latent_dim
        self.output_channels = output_channels
        self.z_w = int(np.sqrt(latent_dim // 2))
        self.num_steps = num_steps

        # Width of the ConvLSTM hidden state, i.e. of the latent map.
        num_hiddens = 128
        hidden_dim = num_hiddens // 4

        # Sequence encoder.
        self._convlstm = ConvLSTMCell(input_dim=self.input_channels,
                                      hidden_dim=hidden_dim,
                                      kernel_size=(3, 3),
                                      bias=True)

        # The noise predictor sees the latent plus one time-step channel.
        self._noise_predictor = UNet(input_channels=hidden_dim + 1, output_channels=hidden_dim)
        self._decoder = UNet(input_channels=hidden_dim, output_channels=self.output_channels)

    def forward(self, x, timesteps, noise=None):
        """Encode ``x``, add noise, run ``num_steps`` denoising steps, and decode.

        Args:
            x: input sequence, reshaped to
                ``(-1, SEQ_LEN, input_channels, IMG_SIZE, IMG_SIZE)``.
            timesteps: accepted for interface compatibility but NOT used. The
                time channel given to the noise predictor is the loop index
                ``0, 1, ..., num_steps - 1``.
            noise: optional noise shaped like the ConvLSTM hidden state; drawn
                with ``torch.randn_like`` when omitted.

        Returns:
            Tensor of shape ``(B, output_channels, H, W)`` with values in (0, 1).
        """
        x = x.reshape(-1, SEQ_LEN, self.input_channels, IMG_SIZE, IMG_SIZE)
        b, seq_len, _, h, w = x.size()

        # Run the ConvLSTM over the sequence; its last hidden state is the latent.
        h_enc, enc_state = self._convlstm.init_hidden(batch_size=b, image_size=(h, w))
        for t_step in range(seq_len):
            h_enc, enc_state = self._convlstm(input_tensor=x[:, t_step],
                                              cur_state=[h_enc, enc_state])
        enc_in = h_enc

        if noise is None:
            noise = torch.randn_like(enc_in)
        z = enc_in + noise

        for step in range(self.num_steps):
            # Build the time channel directly on the latent's device/dtype
            # (no per-step CPU allocation + copy, and it works under half precision).
            t_map = torch.full((b, 1, h, w), step, dtype=z.dtype, device=z.device)
            z_predicted = self._noise_predictor(torch.cat([z, t_map], dim=1))
            z = z - z_predicted

        prediction = self._decoder(z)
        return torch.sigmoid(prediction)

# Example usage: build the 3-step model and check the output shape.
input_channels = 1
latent_dim = 128
output_channels = 1
IMG_SIZE = 64
SEQ_LEN = 10

model = DiffusionModel(input_channels, latent_dim, output_channels, num_steps=3)
input_binary_maps = torch.randn(128, SEQ_LEN, input_channels, IMG_SIZE, IMG_SIZE)
# Passed for interface parity; the multi-step forward does not read it.
t = torch.randint(0, 1000, (128,))
# Shape check only: no_grad avoids keeping activations of every denoising
# step for a backward pass that never happens.
with torch.no_grad():
    output = model(input_binary_maps, t)
print(output.size())  # Should print torch.Size([128, 1, 64, 64])
assert output.shape == (128, output_channels, IMG_SIZE, IMG_SIZE)


# Out: torch.Size([128, 1, 64, 64])
