In [None]:
# Unet configuration for 3-channel 128 x 128 images.
# NOTE(review): 'im_size' and 'time_emb_dim' are not read by the Unet class in this notebook.
model_config = dict(
    im_channels=3,
    im_size=128,
    down_channels=[32, 64, 128, 256, 256],
    mid_channels=[256, 256, 256],
    down_sample=[True, True, True, False],
    time_emb_dim=256,
    num_down_layers=1,
    num_mid_layers=1,
    num_up_layers=1,
    num_heads=16,
)

Model architecture adapted from https://github.com/explainingai-code/DDPM-Pytorch (without the time-embedding layers).

In [None]:
class DownBlock(nn.Module):
    """Encoder-side Unet block: ``num_layers`` residual conv units, then optional downsampling.

    Each residual unit is GroupNorm -> SiLU -> 3x3 conv, twice, added to a 1x1-conv
    projection of the unit's input. The first unit maps ``in_channels`` to
    ``out_channels``; later units keep ``out_channels``. When ``down_sample`` is True a
    4x4 stride-2 convolution halves the spatial size, otherwise an identity is used.
    """

    def __init__(self, in_channels, out_channels,
                 down_sample=True,  num_layers=1):
        super().__init__()
        self.num_layers = num_layers
        self.down_sample = down_sample
        # Input width of each residual unit (only the first one sees in_channels).
        unit_inputs = [in_channels if idx == 0 else out_channels for idx in range(num_layers)]

        # Attribute names and creation order are kept so state_dict keys stay checkpoint-compatible.
        self.resnet_conv_first = nn.ModuleList([
            nn.Sequential(
                nn.GroupNorm(8, c_in),
                nn.SiLU(),
                nn.Conv2d(c_in, out_channels, kernel_size=3, stride=1, padding=1),
            )
            for c_in in unit_inputs
        ])
        self.resnet_conv_second = nn.ModuleList([
            nn.Sequential(
                nn.GroupNorm(8, out_channels),
                nn.SiLU(),
                nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1),
            )
            for _ in unit_inputs
        ])
        self.residual_input_conv = nn.ModuleList([
            nn.Conv2d(c_in, out_channels, kernel_size=1) for c_in in unit_inputs
        ])
        if self.down_sample:
            self.down_sample_conv = nn.Conv2d(out_channels, out_channels, 4, 2, 1)
        else:
            self.down_sample_conv = nn.Identity()

    def forward(self, x):
        h = x
        for first, second, skip in zip(self.resnet_conv_first,
                                       self.resnet_conv_second,
                                       self.residual_input_conv):
            h = second(first(h)) + skip(h)
        return self.down_sample_conv(h)

In [None]:
class MidBlock(nn.Module):
    """
    Unet bottleneck block with self-attention.

    Layout: ResNet unit, then ``num_layers`` repetitions of
    (multi-head self-attention with residual add -> ResNet unit).
    The first ResNet unit maps ``in_channels`` to ``out_channels``; everything after it
    works at ``out_channels``. ``num_heads`` must divide ``out_channels``
    (nn.MultiheadAttention requirement).
    """
    def __init__(self, in_channels, out_channels,  num_heads=4, num_layers=1):
        super().__init__()
        self.num_layers = num_layers
        # Input width of each of the num_layers + 1 ResNet units.
        unit_inputs = [in_channels] + [out_channels] * num_layers

        # Attribute names and creation order are kept so state_dict keys stay checkpoint-compatible.
        self.resnet_conv_first = nn.ModuleList([
            nn.Sequential(
                nn.GroupNorm(8, c_in),
                nn.SiLU(),
                nn.Conv2d(c_in, out_channels, kernel_size=3, stride=1, padding=1),
            )
            for c_in in unit_inputs
        ])
        self.resnet_conv_second = nn.ModuleList([
            nn.Sequential(
                nn.GroupNorm(8, out_channels),
                nn.SiLU(),
                nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1),
            )
            for _ in unit_inputs
        ])
        self.attention_norms = nn.ModuleList([
            nn.GroupNorm(8, out_channels) for _ in range(num_layers)
        ])
        self.attentions = nn.ModuleList([
            nn.MultiheadAttention(out_channels, num_heads, batch_first=True)
            for _ in range(num_layers)
        ])
        self.residual_input_conv = nn.ModuleList([
            nn.Conv2d(c_in, out_channels, kernel_size=1) for c_in in unit_inputs
        ])

    def _resnet(self, idx, x):
        # Residual unit `idx`: two norm/act/conv stages plus a 1x1 projection of the input.
        h = self.resnet_conv_second[idx](self.resnet_conv_first[idx](x))
        return h + self.residual_input_conv[idx](x)

    def forward(self, x):
        out = self._resnet(0, x)
        for layer in range(self.num_layers):
            # Flatten the spatial grid into a token sequence: (B, C, H, W) -> (B, H*W, C).
            b, c, height, width = out.shape
            tokens = self.attention_norms[layer](out.reshape(b, c, height * width))
            tokens = tokens.transpose(1, 2)
            attended, _ = self.attentions[layer](tokens, tokens, tokens)
            out = out + attended.transpose(1, 2).reshape(b, c, height, width)
            out = self._resnet(layer + 1, out)
        return out

In [None]:
class UpBlock(nn.Module):
    r"""
    Decoder-side Unet block (no attention).

    Sequence of operations
    1. Upsample the input with a 4x4 stride-2 ConvTranspose2d acting on
       ``in_channels // 2`` channels (identity when ``up_sample`` is False)
    2. Concatenate the matching down-path activation along the channel dimension
    3. ``num_layers`` residual conv units mapping ``in_channels`` -> ``out_channels``

    Args:
        in_channels: channels after concatenation (upsampled input + skip).
        out_channels: channels produced by the block.
        up_sample: whether to double the spatial size before concatenation.
        num_layers: number of residual conv units.
    """
    def __init__(self, in_channels, out_channels,  up_sample=True, num_layers=1):
        super().__init__()
        # forward() iterates over range(self.num_layers), so it must be stored here.
        self.num_layers = num_layers
        self.up_sample = up_sample
        self.resnet_conv_first = nn.ModuleList(
            [
                nn.Sequential(
                    nn.GroupNorm(8, in_channels if i == 0 else out_channels),
                    nn.SiLU(),
                    nn.Conv2d(in_channels if i == 0 else out_channels, out_channels, kernel_size=3, stride=1,
                              padding=1),
                )
                for i in range(num_layers)
            ]
        )

        self.resnet_conv_second = nn.ModuleList(
            [
                nn.Sequential(
                    nn.GroupNorm(8, out_channels),
                    nn.SiLU(),
                    nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1),
                )
                for _ in range(num_layers)
            ]
        )

        # 1x1 projections so the residual add works when the channel count changes.
        self.residual_input_conv = nn.ModuleList(
            [
                nn.Conv2d(in_channels if i == 0 else out_channels, out_channels, kernel_size=1)
                for i in range(num_layers)
            ]
        )
        # The incoming tensor carries half of in_channels; the skip supplies the other half.
        self.up_sample_conv = nn.ConvTranspose2d(in_channels // 2, in_channels // 2,
                                                 4, 2, 1) \
            if self.up_sample else nn.Identity()

    def forward(self, x, out_down):
        """Upsample `x`, concatenate the skip `out_down`, and apply the residual units."""
        x = self.up_sample_conv(x)
        out = torch.cat([x, out_down], dim=1)
        for i in range(self.num_layers):
            resnet_input = out
            out = self.resnet_conv_first[i](out)
            out = self.resnet_conv_second[i](out)
            out = out + self.residual_input_conv[i](resnet_input)
        return out

In [None]:
class Unet(nn.Module):
    r"""
    Unet model comprising Down blocks, Mid blocks and Up blocks.

    The network sees only the (noisy) image. There is no time / noise-level embedding,
    so ``time_emb_dim`` and ``im_size`` in the config are not used. In this notebook the
    noise level enters through the loss and the samplers, which divide the output by sigma.

    Config keys read:
        im_channels, down_channels, mid_channels, down_sample,
        num_down_layers, num_mid_layers, num_up_layers,
        num_heads (optional, default 4): attention heads in every MidBlock,
        out_block_channels (optional, default 16): channels of the last UpBlock
            (must be divisible by 8 because of GroupNorm(8, .)).
    """
    def __init__(self, model_config):
        super().__init__()
        im_channels = model_config['im_channels']
        self.down_channels = model_config['down_channels']
        self.mid_channels = model_config['mid_channels']
        self.down_sample = model_config['down_sample']
        self.num_down_layers = model_config['num_down_layers']
        self.num_mid_layers = model_config['num_mid_layers']
        self.num_up_layers = model_config['num_up_layers']
        # 'num_heads' used to be ignored (every MidBlock fell back to 4 heads). The head
        # count does not change parameter shapes, but it does change the computation.
        self.num_heads = model_config.get('num_heads', 4)
        out_block_channels = model_config.get('out_block_channels', 16)

        self.up_sample = list(reversed(self.down_sample))
        self.conv_in = nn.Conv2d(im_channels, self.down_channels[0], kernel_size=3, padding=(1, 1))

        # Downsampling path
        self.downs = nn.ModuleList([])
        for i in range(len(self.down_channels) - 1):
            self.downs.append(DownBlock(self.down_channels[i], self.down_channels[i + 1],
                                        down_sample=self.down_sample[i],
                                        num_layers=self.num_down_layers))
        # Bottleneck
        self.mids = nn.ModuleList([])
        for i in range(len(self.mid_channels) - 1):
            self.mids.append(MidBlock(self.mid_channels[i], self.mid_channels[i + 1],
                                      num_heads=self.num_heads,
                                      num_layers=self.num_mid_layers))
        # Upsampling path, mirroring the down path. Up block for level i concatenates the
        # upsampled features with the activation saved before DownBlock i (down_channels[i]
        # channels each, hence "* 2"). Iterating i in reverse makes down_sample[i] the
        # matching upsample flag.
        self.ups = nn.ModuleList([])
        for i in reversed(range(len(self.down_channels) - 1)):
            self.ups.append(UpBlock(self.down_channels[i] * 2,
                                    self.down_channels[i - 1] if i != 0 else out_block_channels,
                                    up_sample=self.down_sample[i],
                                    num_layers=self.num_up_layers))

        self.norm_out = nn.GroupNorm(8, out_block_channels)
        # Final convolution back to the input channel count.
        self.conv_out = nn.Conv2d(out_block_channels, im_channels, kernel_size=3, padding=1)

    def forward(self, x):
        """Map an image batch (B x im_channels x H x W) to a tensor of the same shape."""
        # Shapes below are for the default 3 x 128 x 128 config.
        out = self.conv_in(x)   # [B x 3 x 128 x 128] --> [B x 32 x 128 x 128]

        # Save each block's input as the skip connection for the mirrored up block.
        down_outs = []
        for down in self.downs:
            down_outs.append(out)
            out = down(out)
        # [B x 32 x 128 x 128] --> [B x 64 x 64 x 64] --> [B x 128 x 32 x 32] --> [B x 256 x 16 x 16] --> [B x 256 x 16 x 16]

        for mid in self.mids:
            out = mid(out)
        # out = [B x 256 x 16 x 16]

        for up in self.ups:
            out = up(out, down_outs.pop())
        # [B x 256 x 16 x 16] --> [B x 128 x 16 x 16] --> [B x 64 x 32 x 32] --> [B x 32 x 64 x 64] --> [B x 16 x 128 x 128]

        out = self.norm_out(out)
        out = nn.functional.silu(out)
        out = self.conv_out(out)   # [B x 16 x 128 x 128] --> [B x 3 x 128 x 128]
        return out

In [None]:
# Score network for 3 x 128 x 128 images (nn.Module.to moves parameters in place and returns self).
NCSN = Unet(model_config)
NCSN = NCSN.to(device)

In [None]:
# Geometric noise schedule sigma_i = sigma_max * decay**i, largest level first.
# NOTE(review): `dataset` is compared to a string here but passed to DataLoader in the FID
# section; confirm which one it is (if it is a Dataset object, the else-branch always runs).
sigmas = torch.tensor(
    [150 * 0.9 ** i for i in range(90)] if dataset == "butterfly_dataset"
    else [450 * 0.95 ** i for i in range(200)]
).to(device)


In [None]:
def loss_fn(model, x, sigmas=sigmas):
  """Denoising score-matching loss for a noise conditional score network (NCSN).

  For every sample a noise level sigma is drawn uniformly from `sigmas`, the sample is
  perturbed with N(0, sigma^2) noise, and model(perturbed_x) / sigma is taken as the
  score estimate. The loss is the sigma^2-weighted objective
  || sigma * score + z ||^2, summed over all non-batch dims and averaged over the batch.

  Args:
    model: score network; called as model(perturbed_x), output must have x's shape.
    x: mini-batch of training data, batch dimension first.
    sigmas: 1-D tensor of shape [number of noise levels] containing all the noise levels.

  Returns:
    Scalar loss tensor.
  """
  batch_size = x.shape[0]
  # Draw indices on the device of `sigmas` so indexing works for CPU or GPU schedules.
  random_indices = torch.randint(0, len(sigmas), (batch_size,), device=sigmas.device)
  # (B,) -> (B, 1, 1, ...) so the per-sample sigma broadcasts over every non-batch dim.
  random_sigmas = sigmas[random_indices].to(x.device).view(batch_size, *([1] * (x.dim() - 1)))

  z = torch.randn_like(x)
  perturbed_x = x + z * random_sigmas
  score = model(perturbed_x) / random_sigmas  # normalizing the score by dividing by noise level
  reduce_dims = tuple(range(1, x.dim()))
  loss = torch.mean(torch.sum((score * random_sigmas + z) ** 2, dim=reduce_dims))
  return loss

In [None]:
n_epochs = 100
# The score network built above is `NCSN`; `model` was not defined before this cell.
model = NCSN.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr = 1e-5)
# NOTE(review): `img_dataset` is defined outside this section; confirm it is the intended training set.
dataloader = torch.utils.data.DataLoader(img_dataset, batch_size = 64, shuffle=True, num_workers  =4)
for epoch in range(n_epochs):
    model.train()  # sampling (ALD) switches the model to eval mode
    running_loss = 0.0
    for X, _ in tqdm(dataloader):
        X = X.to(device)
        optimizer.zero_grad()
        loss = loss_fn(model, X)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
    # Mean over all batches of the epoch (previously used only the last batch's loss tensor).
    avg_loss = running_loss / len(dataloader)
    print(f'epoch {epoch}: loss:{avg_loss}')
save_dir = 'q6_ncsn'
os.makedirs(save_dir, exist_ok=True)
torch.save(model.state_dict(), os.path.join(save_dir, 'model-weights.pth'))

In [None]:
# Annealed Langevin dynamics (ALD) for sampling
def ALD(model, sigmas=sigmas, num_samples=100, num_timesteps=100, eps=1e-6, denoise=False,
        image_shape=(3, 128, 128), save_dir='generated_images'):
    """Generate images with annealed Langevin dynamics.

    For each noise level sigma_i, in the order given by `sigmas` (largest first), run
    `num_timesteps` updates
        x <- x + alpha_i * model(x) / sigma_i + sqrt(2 * alpha_i) * z,
        alpha_i = eps * sigma_i^2 / sigma_L^2,  z ~ N(0, I),
    and save a preview grid of the current samples to `save_dir`.

    Args:
      model: a trained score model (model(x) / sigma is the score estimate).
      sigmas: 1-D tensor of noise levels, largest first.
      num_samples: number of images to generate.
      num_timesteps: number of Langevin steps at each noise level.
      eps: base step size.
      denoise: whether to apply the final denoising step x + sigma_L * model(x).
      image_shape: (C, H, W) of each generated image.
      save_dir: directory for the per-noise-level preview grids.

    Returns:
      Tensor of shape (num_samples, *image_shape); values are not clamped.
    """
    os.makedirs(save_dir, exist_ok=True)
    model.eval()
    sigma_min = sigmas[-1]
    with torch.no_grad():
        xt = torch.randn(num_samples, *image_shape).to(device)
        alphas = [eps * (sigmas[i] ** 2) / (sigma_min ** 2) for i in range(len(sigmas))]
        for i in tqdm(range(len(sigmas))):
            for _ in range(num_timesteps):
                zt = torch.randn_like(xt)
                xt = xt + alphas[i] * model(xt) / sigmas[i] + torch.sqrt(2 * alphas[i]) * zt
            # Preview grid: bring samples from [-1, 1] to [0, 1] and actually draw them
            # (the previous version saved an empty figure).
            grid = torchvision.utils.make_grid((torch.clamp(xt, -1, 1) + 1) / 2, nrow=10)
            fig = plt.figure(figsize=(10, 10))
            plt.imshow(grid.permute(1, 2, 0).cpu().numpy())
            plt.axis('off')
            # Use the level index and a plain number, not the tensor repr, in the filename.
            fig.savefig(os.path.join(save_dir, f'sigma_step{i:03d}_{sigmas[i].item():.4g}.png'))
            plt.close(fig)
        if denoise:
            # Last-step denoising x + sigma_L^2 * score, with score = model(x) / sigma_L;
            # kept inside no_grad so no autograd graph is built.
            xt = xt + model(xt) * sigma_min
    return xt

In [None]:
# Display the final generated samples as a 10-column grid.
image = ALD(model=NCSN)
image = (image.clamp(-1, 1) + 1) / 2   # [-1, 1] -> [0, 1] for imshow
image = torchvision.utils.make_grid(image, nrow=10).permute(1, 2, 0).detach().cpu().numpy()
plt.figure(figsize=(10, 10))
plt.imshow(image)

FID (Fréchet Inception Distance) calculation

In [None]:
from torchvision.models import inception_v3

# Inception v3 input pipeline: resize to the network's 299 x 299 input size, then
# normalise with the ImageNet channel statistics Inception v3 was pretrained with.
# NOTE(review): these mean/std values are the ones torchvision uses for inputs in [0, 1].
preprocess = transforms.Compose(
    [
        transforms.Resize((299, 299)),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
)

def prepare_inception_input(images):
    """Resize and ImageNet-normalise a batch of images for Inception v3."""
    prepared = preprocess(images)
    return prepared

In [None]:
def load_inception_model():
    """Return a pretrained Inception v3 feature extractor on `device`, in eval mode.

    The final fully connected layer is replaced by an identity, so the network
    outputs its pooled features instead of class logits.
    """
    net = inception_v3(pretrained=True, transform_input=False)  # inputs are normalised by `preprocess`
    net.fc = torch.nn.Identity()
    return net.eval().to(device)  # same device as the score network and the image batches

inception_model = load_inception_model()

In [None]:
def extract_features(images):
    """Run the global `inception_model` on a preprocessed batch without tracking gradients.

    Returns a tensor of shape [batch_size, 2048].
    """
    with torch.no_grad():
        return inception_model(images)

In [None]:
# Real-image loader for FID. shuffle is left at its default (False), so the first
# images of `dataset` are the ones featurized.
dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=100,
    num_workers=4,
)

In [None]:
def extract_real_features(dataloader, num_images=1000):
    """Return Inception features for the first `num_images` real images of `dataloader`.

    Each batch is expected to be a (images, ...) tuple; only batch[0] is used. The last
    batch is truncated so that exactly `num_images` images are featurized when the
    loader has enough data.
    """
    collected = []
    remaining = num_images  # images still needed
    with torch.no_grad():
        for batch in dataloader:
            images = batch[0].to(device)[:remaining]
            feats = extract_features(prepare_inception_input(images))
            collected.append(feats)
            remaining -= feats.size(0)
            if remaining <= 0:
                break
    return torch.cat(collected, dim=0)[:num_images]

# Features of 1000 real images.
real_features = extract_real_features(dataloader, num_images=1000)

In [None]:
def generate_and_extract_features(sampler , num_samples=1000, batch_size=100):
    """Return Inception features of `num_samples` generated images.

    Args:
        sampler: either a callable ``sampler(n) -> images`` that draws ``n`` fresh samples
            (e.g. ``lambda n: ALD(model=NCSN, num_samples=n)``), or a pre-generated image
            tensor with the sample index first. A tensor is consumed in consecutive chunks
            and never reused, so fewer than ``num_samples`` features are returned if it
            holds fewer images. (Previously the same tensor was featurized for every batch,
            duplicating samples.)
        num_samples: number of generated images to featurize.
        batch_size: images requested per batch.

    Returns:
        Tensor of shape [min(num_samples, available images), 2048].

    Raises:
        ValueError: if the sampler yields no images at all.
    """
    generated_features = []
    num_extracted = 0  # images featurized so far (not batches)
    num_batches = (num_samples + batch_size - 1) // batch_size
    print(f"Starting feature extraction for {num_samples} samples with batch size {batch_size}")
    with torch.no_grad():
        for batch_idx in range(num_batches):
            current_batch_size = min(batch_size, num_samples - num_extracted)
            print(f"Processing batch {batch_idx + 1}/{num_batches} with size {current_batch_size}")
            if callable(sampler):
                fake_images = sampler(current_batch_size)[:current_batch_size]
            else:
                fake_images = sampler[num_extracted:num_extracted + current_batch_size]
            if fake_images.shape[0] == 0:
                print("No more generated images available. Stopping extraction.")
                break
            preprocessed_images = prepare_inception_input(fake_images)
            batch_features = extract_features(preprocessed_images)
            generated_features.append(batch_features)
            num_extracted += batch_features.shape[0]
            print(f"Extracted features shape: {batch_features.shape}")
            if num_extracted >= num_samples:
                print("Reached target number of samples. Stopping extraction.")
                break
    print(f"Feature extraction completed. Total features extracted: {num_extracted}")

    if not generated_features:
        raise ValueError("sampler produced no images; cannot extract features")
    generated_features_tensor = torch.cat(generated_features, dim=0)
    return generated_features_tensor[:num_samples]

# Generate 1000 fresh samples (one ALD run per batch) and extract their features.
# NOTE: this runs the full annealed Langevin sampler num_samples / batch_size times.
generated_features = generate_and_extract_features(
    sampler=lambda n: ALD(model=NCSN, num_samples=n))

In [None]:
# Mean and covariance of the Inception features, computed in float64: the FID takes a
# matrix square root of a product of 2048 x 2048 covariances, which is sensitive to
# float32 round-off. NOTE: with fewer samples than feature dims (1000 < 2048) both
# covariance matrices are rank-deficient.
def feature_statistics(features):
    """Return (mean over samples, covariance matrix) of an [N, D] feature tensor, in float64."""
    features = features.double()
    return torch.mean(features, dim=0), torch.cov(features.T)

real_mean, real_cov = feature_statistics(real_features)
generated_mean, generated_cov = feature_statistics(generated_features)

In [None]:
def calculate_frechet_inception_distance(real_mean, real_cov, generated_mean, generated_cov, eps=1e-6):
    """
    Calculate the Fréchet Inception Distance (FID) between real and generated image features.

        FID = ||mu_r - mu_g||^2 + Tr(S_r + S_g - 2 (S_r S_g)^{1/2})

    Args:
    real_mean (torch.Tensor): Mean of real image features.
    real_cov (torch.Tensor): Covariance matrix of real image features.
    generated_mean (torch.Tensor): Mean of generated image features.
    generated_cov (torch.Tensor): Covariance matrix of generated image features.
    eps (float): Diagonal offset added to both covariances if the matrix square root of
        their product is not finite (e.g. nearly singular covariances).
    Returns:
    float: The calculated FID score.
    """
    def _to_numpy(tensor):
        # All statistics are evaluated in float64 on the CPU for scipy.
        return np.asarray(tensor.detach().cpu().numpy(), dtype=np.float64)

    mu_real, mu_gen = _to_numpy(real_mean), _to_numpy(generated_mean)
    cov_real, cov_gen = _to_numpy(real_cov), _to_numpy(generated_cov)

    # Squared L2 distance between the means.
    diff = mu_real - mu_gen
    mean_diff = diff.dot(diff)

    # Matrix square root of the covariance product; retry with a small diagonal offset
    # when the product is too ill-conditioned for a finite result.
    covmean = scipy.linalg.sqrtm(cov_real.dot(cov_gen))
    if not np.isfinite(covmean).all():
        offset = np.eye(cov_real.shape[0]) * eps
        covmean = scipy.linalg.sqrtm((cov_real + offset).dot(cov_gen + offset))
    # Numerical error can introduce a tiny imaginary component; keep the real part.
    if np.iscomplexobj(covmean):
        covmean = covmean.real

    trace_term = np.trace(cov_real) + np.trace(cov_gen) - 2.0 * np.trace(covmean)
    return float(mean_diff + trace_term)


# FID between the real-image and generated-image feature statistics.
fid_score = calculate_frechet_inception_distance(
    real_mean=real_mean,
    real_cov=real_cov,
    generated_mean=generated_mean,
    generated_cov=generated_cov,
)
print(f"Fréchet Inception Distance (FID): {fid_score:.4f}")

**NCSN trained on the latent space of a pretrained VQ-VAE**

In [None]:
# ---- VQ-VAE hyperparameters ----
# Codebook: K vectors of dimension D.
dimension_of_codebook_vectors = 128   # D, channels of each latent vector
number_of_codebook_vectors = 1024     # K, codebook size
Commitment_cost = 1                   # weight of the commitment term in the VQ loss
# Optimisation settings
num_epochs = 50        # training epochs
batch_size = 64        # samples per gradient update
learning_rate = 1e-3   # optimizer step size

In [None]:
class Encoder(nn.Module):
    """Convolutional VQ-VAE encoder: N x 3 x H x W image -> N x latent_dim x H/16 x W/16.

    Four 4x4 stride-2 convolutions (3 -> 64 -> 128 -> 256 -> 512 channels), each followed
    by LeakyReLU(0.2), with BatchNorm before the activation on all but the first, then a
    3x3 projection to ``latent_dim`` channels. A 128 x 128 input gives an 8 x 8 latent grid.
    """
    def __init__(self, latent_dim):
        super().__init__()
        # Layer order is preserved so state_dict keys (encoder.0, encoder.2, ...) still match.
        layers = [
            nn.Conv2d(in_channels=3, out_channels=64, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2),
        ]
        for c_in, c_out in ((64, 128), (128, 256), (256, 512)):
            layers += [
                nn.Conv2d(in_channels=c_in, out_channels=c_out, kernel_size=4, stride=2, padding=1),
                nn.BatchNorm2d(c_out),
                nn.LeakyReLU(0.2),
            ]
        self.encoder = nn.Sequential(*layers)
        self.final_conv = nn.Conv2d(in_channels=512, out_channels=latent_dim, kernel_size=3, stride=1, padding=1)

    def forward(self, x):
        return self.final_conv(self.encoder(x))

In [None]:
class VectorQuantizer(nn.Module):
    """Nearest-neighbour vector quantizer with a learnable codebook.

    forward(inputs) takes a B x C x H x W tensor with C == embedding_dim and returns
    (vq_loss, quantized, perplexity, encodings):
      vq_loss    -- codebook MSE + commitment_cost * commitment MSE
      quantized  -- B x C x H x W nearest codebook vectors, routed through a
                    straight-through estimator so gradients reach `inputs`
      perplexity -- exp(entropy) of the average code usage over the batch
      encodings  -- (B*H*W) x num_embeddings one-hot code assignments
    """
    def __init__(self, num_embeddings, embedding_dim, commitment_cost):
        super().__init__()
        self.embedding_dim = embedding_dim
        self.num_embeddings = num_embeddings
        self.commitment_cost = commitment_cost

        # Codebook of num_embeddings vectors, initialised uniformly in [-1/K, 1/K].
        self.embedding = nn.Embedding(self.num_embeddings, self.embedding_dim)
        bound = 1 / self.num_embeddings
        self.embedding.weight.data.uniform_(-bound, bound)

    def forward(self, inputs):
        # Channels-last so every spatial position is one embedding_dim-long vector.
        channels_last = inputs.permute(0, 2, 3, 1).contiguous()
        vectors = channels_last.view(-1, self.embedding_dim)
        codebook = self.embedding.weight

        # Squared L2 distance to every code via ||x||^2 + ||e||^2 - 2 x.e
        x_sq = torch.sum(vectors ** 2, dim=1, keepdim=True)
        e_sq = torch.sum(codebook ** 2, dim=1)
        cross = torch.matmul(vectors, codebook.t())
        distances = x_sq + e_sq - 2 * cross

        # One-hot assignment of each vector to its nearest code.
        nearest = torch.argmin(distances, dim=1).unsqueeze(1)
        encodings = torch.zeros(nearest.shape[0], self.num_embeddings, device=inputs.device)
        encodings.scatter_(1, nearest, 1)

        quantized = torch.matmul(encodings, codebook).view(channels_last.shape)

        # Codebook term moves codes toward (frozen) encoder outputs; the commitment term
        # moves encoder outputs toward (frozen) codes.
        embedding_loss = F.mse_loss(quantized, channels_last.detach())
        commitment_loss = F.mse_loss(quantized.detach(), channels_last)
        vq_loss = embedding_loss + self.commitment_cost * commitment_loss

        # Straight-through estimator: forward value is `quantized`, gradient passes to inputs.
        quantized = channels_last + (quantized - channels_last).detach()
        quantized = quantized.permute(0, 3, 1, 2).contiguous()

        # Perplexity of code usage (epsilon avoids log(0)).
        avg_probs = torch.mean(encodings, dim=0)
        perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10)))

        return vq_loss, quantized, perplexity, encodings

In [None]:
class Decoder(nn.Module):
    """Convolutional VQ-VAE decoder: N x latent_dim x h x w -> N x 3 x 16h x 16w in [0, 1].

    Four 4x4 stride-2 transposed convolutions (latent_dim -> 64 -> 32 -> 16 -> 3). The
    first three are each followed by ReLU then BatchNorm; the output goes through a
    Sigmoid. An 8 x 8 latent grid decodes to a 128 x 128 image.
    """
    def __init__(self, latent_dim):
        super().__init__()
        self.latent_dim = latent_dim
        # Layer order is preserved so state_dict keys (decoder.0, decoder.2, ...) still match.
        layers = []
        for c_in, c_out in ((latent_dim, 64), (64, 32), (32, 16)):
            layers += [
                nn.ConvTranspose2d(c_in, c_out, kernel_size=4, stride=2, padding=1),
                nn.ReLU(),
                nn.BatchNorm2d(c_out),
            ]
        layers += [
            nn.ConvTranspose2d(16, 3, kernel_size=4, stride=2, padding=1),
            nn.Sigmoid(),
        ]
        self.decoder = nn.Sequential(*layers)

    def forward(self, x):
        return self.decoder(x)

In [None]:
class VQVAE(nn.Module):
    """Encoder -> VectorQuantizer -> Decoder.

    forward(x) returns (vq_loss, x_recon, perplexity).
    """
    def __init__(self, embedding_dim, num_embeddings=512, commitment_cost=0.25):
        super().__init__()
        self.encoder = Encoder(embedding_dim)
        self.vq_layer = VectorQuantizer(num_embeddings, embedding_dim, commitment_cost)
        self.decoder = Decoder(embedding_dim)

    def forward(self, x):
        vq_loss, quantized, perplexity, _ = self.vq_layer(self.encoder(x))
        return vq_loss, self.decoder(quantized), perplexity

In [None]:
# Location of the pretrained VQ-VAE weights (Kaggle input dataset).
VQVAE_WEIGHTS_PATH = '/kaggle/input/vqvaemodel/pytorch/default/1/vqvae_model_epoch_32.pth'

# Build the VQ-VAE from the hyperparameter cell. The previous call referenced
# `embedding_dim` / `num_embeddings`, which are not defined in this section.
vqvae = VQVAE(embedding_dim=dimension_of_codebook_vectors,
              num_embeddings=number_of_codebook_vectors,
              commitment_cost=Commitment_cost)
vqvae.load_state_dict(torch.load(VQVAE_WEIGHTS_PATH, map_location=device, weights_only=True))
# Frozen model: eval() makes BatchNorm use running statistics. It must also live on the
# same device as the batches fed to it during latent training and sampling.
vqvae = vqvae.to(device).eval()

In [None]:
# Unet configuration for the VQ-VAE latent space (128-channel, 8 x 8 latents).
model_config = dict(
    im_channels=128,
    im_size=8,
    down_channels=[128, 128, 128, 256, 256],
    mid_channels=[256, 256, 256],
    down_sample=[True, True, False, False],
    time_emb_dim=256,
    num_down_layers=1,
    num_mid_layers=1,
    num_up_layers=1,
    num_heads=16,
)

# Score network trained on VQ-VAE latents.
latent_unet = Unet(model_config).to(device)

In [None]:
# Noise levels for the latent-space NCSN: geometric schedule 100 * 0.9**i for i = 0..79.
sigmas = torch.tensor([100 * 0.9 ** i for i in range(80)]).to(device)

In [None]:
def loss_fn(model, x, sigmas=sigmas, eps=1e-5):
  """Denoising score-matching loss for a noise conditional score network (NCSN).

  Redefined here so the default `sigmas` is the latent-space noise schedule.
  For every sample a noise level sigma is drawn uniformly from `sigmas`, the sample is
  perturbed with N(0, sigma^2) noise, and model(perturbed_x) / sigma is taken as the
  score estimate. The loss is || sigma * score + z ||^2, summed over all non-batch
  dims and averaged over the batch.

  Args:
    model: score network; called as model(perturbed_x), output must have x's shape.
    x: mini-batch of training data (here VQ-VAE latents), batch dimension first.
    sigmas: 1-D tensor of shape [number of noise levels] containing all the noise levels.
    eps: unused; kept for signature compatibility.

  Returns:
    Scalar loss tensor.
  """
  batch_size = x.shape[0]
  # Draw indices on the device of `sigmas` so indexing works for CPU or GPU schedules.
  random_indices = torch.randint(0, len(sigmas), (batch_size,), device=sigmas.device)
  # (B,) -> (B, 1, 1, ...) so the per-sample sigma broadcasts over every non-batch dim.
  random_sigmas = sigmas[random_indices].to(x.device).view(batch_size, *([1] * (x.dim() - 1)))

  z = torch.randn_like(x)
  perturbed_x = x + z * random_sigmas
  score = model(perturbed_x) / random_sigmas
  reduce_dims = tuple(range(1, x.dim()))
  loss = torch.mean(torch.sum((score * random_sigmas + z) ** 2, dim=reduce_dims))
  return loss

In [None]:
n_epochs = 100
model = latent_unet.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr = 1e-4)
dataloader = torch.utils.data.DataLoader(butterfly_dataset, batch_size = 64, shuffle=True, num_workers  =4)
# The VQ-VAE is a frozen encoder here: keep it on `device` and in eval mode so its
# BatchNorm layers use running statistics instead of per-batch statistics.
vqvae = vqvae.to(device).eval()
for epoch in range(n_epochs):
    model.train()  # sampling switches the model to eval mode
    running_loss = 0.0
    for X, _ in tqdm(dataloader):
        X = X.to(device)
        with torch.no_grad():
            # Encode, then snap to the codebook: the NCSN is trained on quantized latents.
            # (Previously `self.vq_layer`, which is undefined at notebook top level.)
            _, X, _, _ = vqvae.vq_layer(vqvae.encoder(X))
        optimizer.zero_grad()
        loss = loss_fn(model, X)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
    # Mean over all batches of the epoch (previously used only the last batch's loss tensor).
    avg_loss = running_loss / len(dataloader)
    print(f'epoch {epoch}: loss:{avg_loss}')
save_dir = 'vqvae_ncsn'
os.makedirs(save_dir, exist_ok=True)
torch.save(model.state_dict(), os.path.join(save_dir, 'model-weights.pth'))

In [None]:
def ALD_vqvae(model, sigmas=sigmas, num_samples=100, num_timesteps=100, eps=1e-4,
              denoise=True, latent_shape=(128, 8, 8), quantize=False):
    """Sample VQ-VAE latents with annealed Langevin dynamics and decode them to images.

    For each noise level sigma_i (largest first) run `num_timesteps` updates
        z <- z + alpha_i * model(z) / sigma_i + sqrt(2 * alpha_i) * noise,
        alpha_i = eps * sigma_i^2 / sigma_L^2,
    then decode with the global `vqvae`.

    Args:
      model: a trained latent score model (model(z) / sigma is the score estimate).
      sigmas: 1-D tensor of noise levels, largest first.
      num_samples: number of images to generate.
      num_timesteps: number of time steps to take at each noise level.
      eps: base step size.
      denoise: whether to apply the last-step denoising z + sigma_L * model(z) before
        decoding (default True, which is what this function always did).
      latent_shape: (C, H, W) of one latent; must match the VQ-VAE encoder output.
      quantize: if True, snap sampled latents to the nearest codebook vectors with
        `vqvae.vq_layer` before decoding (the NCSN was trained on quantized latents).

    Returns:
      Decoded images from `vqvae.decoder` (which ends in a Sigmoid, so values are in [0, 1]).
    """
    model.eval()
    vqvae.eval()  # decoder BatchNorm must use running statistics at sampling time
    sigma_min = sigmas[-1]
    with torch.no_grad():
        zt = torch.randn(num_samples, *latent_shape).to(device)
        alphas = [eps * (sigmas[i] ** 2) / (sigma_min ** 2) for i in range(len(sigmas))]
        for i in tqdm(range(len(sigmas))):
            for _ in range(num_timesteps):
                noise = torch.randn_like(zt)
                zt = zt + alphas[i] * model(zt) / sigmas[i] + torch.sqrt(2 * alphas[i]) * noise
        if denoise:
            zt = zt + model(zt) * sigma_min
        if quantize:
            _, zt, _, _ = vqvae.vq_layer(zt)
        return vqvae.decoder(zt)



In [None]:
# Plot images sampled in the VQ-VAE latent space and decoded by the VQ-VAE.
image = ALD_vqvae(latent_unet)
# vqvae.decoder ends in a Sigmoid, so samples are already in [0, 1]; the former
# (image + 1) / 2 squashed them into [0.5, 1] and washed out the colours.
image = torch.clamp(image, 0, 1)
image = torchvision.utils.make_grid(image, nrow=10)
image = image.permute(1, 2, 0).detach().cpu().numpy()
plt.figure(figsize=(10, 10))
plt.axis('off')
plt.imshow(image)
