In [1]:
# Import necessary PyTorch libraries
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.utils.data as data
import torchvision
import torchvision.transforms as transforms


# Additional libraries for visualization and utilities
import matplotlib.pyplot as plt


import numpy as np
from unet import UNet
from unet_decoder import UNetDecoder
from echo import echo_sample, echo_loss

In [2]:
def get_device():
    """Selects the best available device for PyTorch computations.

    Returns:
        torch.device: The selected device.
    """

    if torch.cuda.is_available():
        return torch.device('cuda')
    elif torch.backends.mps.is_available():
        return torch.device('mps')
    else:
        return torch.device('cpu')

device = get_device()
print(f"using device: {device}")

using device: mps


In [3]:
from torchvision import datasets
from torch.utils.data import DataLoader, random_split
from torchvision.transforms import Compose, Normalize, ToTensor,Resize

transform = Compose([
    ToTensor(),
    Normalize((0.1307,), (0.3081,))
  # Normalize with MNIST mean and std
])


# Load the CelebA dataset
dataset = datasets.MNIST(root='./data', download=True, transform=transform)

# Print the total number of images in the dataset
print(f"Total number of images in the dataset: {len(dataset)}")

# Splitting dataset into training and validation sets
train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size
train_dataset, val_dataset = random_split(dataset, [train_size, val_size])

# Print the number of images in the train and validation sets
print(f"Number of images in the training set: {len(train_dataset)}")
print(f"Number of images in the validation set: {len(val_dataset)}")

# Create DataLoader instances
train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=128, shuffle=False)

Total number of images in the dataset: 60000
Number of images in the training set: 48000
Number of images in the validation set: 12000


In [4]:
class Encoder(nn.Module):
    def __init__(self, input_shape, latent_dims):
        super(Encoder, self).__init__()
        self.input_shape = input_shape
        self.latent_dims = latent_dims

        self.unet = UNet(
            n_channels=input_shape[0],
            n_classes=input_shape[0],  # Ensure the output channels match the input channels
            bilinear=True
        )

        # Output layers for mean and log variance
        self.out_mean = nn.Conv2d(input_shape[0], input_shape[0], kernel_size=1)
        self.out_log_var = nn.Conv2d(input_shape[0], input_shape[0], kernel_size=1)

    def forward(self, x):
        x = self.unet(x)
        f_x = torch.tanh(self.out_mean(x))
        log_var = torch.sigmoid(self.out_log_var(x))
        return f_x, log_var


In [5]:
class ColdDiffusionModel(nn.Module):
    def __init__(self, encoder, input_shape, T=1000):
        super(ColdDiffusionModel, self).__init__()
        self.encoder = encoder
        self.input_shape = input_shape
        self.T = T
        self.decoder = UNetDecoder(n_channels=input_shape[0])

        # Define the noise schedule
        self.alpha = self.create_noise_schedule(T)

    def create_noise_schedule(self, T):
        beta_start = 0.0001
        beta_end = 0.02
        betas = torch.linspace(beta_start, beta_end, T)
        alphas = 1 - betas
        alphas_cumprod = torch.cumprod(alphas, dim=0)
        return alphas_cumprod

    def forward(self, x):

        # Calculate the gaussian noise tensor
        batch_size = x.shape[0]
        f_x, sx_matrix = self.encoder(x)
        epsilon = echo_sample((f_x, sx_matrix)).detach()
        z = f_x + sx_matrix * epsilon
        
        del epsilon 
        
        torch.cuda.empty_cache()

        #sample a timestep t
        t = np.random.randint(0, self.T)
        # Retrieve noise scheduler alpha_T
        alpha_t = self.alpha[t]

        # Calculate square root alphas
        sqrt_alpha_t = torch.sqrt(alpha_t)
        sqrt_one_minus_alpha_t = torch.sqrt(1 - alpha_t)
        
        # Perform the weighted sum
        x_t = sqrt_alpha_t * x + sqrt_one_minus_alpha_t * z

        torch.cuda.empty_cache()

        #Calculate the timestep tensor
        t = torch.tensor([t] * x_t.size(0), dtype=torch.long).to(x_t.device)

        # Perform the reconstruction process 
        estimated_image = self.decoder(x_t,t)
        torch.cuda.empty_cache()
        return estimated_image

In [6]:
def freeze_module(module):
    for param in module.parameters():
        param.requires_grad = False

def unfreeze_module(module):
    for param in module.parameters():
        param.requires_grad = True

In [10]:
import os
import torch

def save_checkpoint(epoch, model, optimizer, filename="checkpoint.pth"):
    """Saves the model and optimizer state at the specified path."""
    torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict()
    }, filename)
    print(f"Checkpoint saved at epoch {epoch} to {filename}")

def load_checkpoint(model, optimizer, filename="checkpoint.pth", device='cpu'):
    """Loads the model and optimizer state from the specified path."""
    if os.path.isfile(filename):
        checkpoint = torch.load(filename, map_location=torch.device(device))
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        epoch = checkpoint['epoch']
        print(f"Checkpoint loaded from {filename}, resuming training from epoch {epoch}")
        return epoch
    else:
        print(f"No checkpoint found at {filename}, starting from scratch.")
        return -1

In [8]:
import time  # Importing time to log the duration
from tqdm import tqdm


def validate(model, val_loader, device):
    model.eval()  # Set the model to evaluation mode
    total_val_loss = 0
    with torch.no_grad():  # Disable gradient computation during validation
        for data, _ in val_loader:
            data = data.to(device)
            # fx, sx_matrix = model(data)
            # mi_loss = echo_loss(sx_matrix)
            # total_val_loss += mi_loss
            estimated_image = model(data)
            reconstruction_loss = nn.functional.l1_loss(data, estimated_image)
            total_val_loss += reconstruction_loss.item()  # Accumulate the validation loss
    
    avg_val_loss = total_val_loss / len(val_loader)  # Calculate average loss
    return avg_val_loss

def train(model, optimizer, train_loader, device,start_epoch, num_epochs, filename):
    model.train()
    freeze_module(model.encoder)

    for epoch in range(start_epoch+1, num_epochs):
        epoch_loss = 0.0
        epoch_start_time = time.time()  # Time tracking for the epoch

        print(f"Starting epoch {epoch+1}/{num_epochs}")
        for batch_idx, (data, _) in enumerate(tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}", leave=False)):
            data = data.to(device)

            # Forward pass
            estimated_image = model(data)
            total_loss = torch.nn.functional.l1_loss(data, estimated_image)

            # Backward pass
            if not torch.isnan(total_loss).any():
                total_loss.backward()
            else:
                pass
                print(f"Warning: NaN detected in total_loss at batch {batch_idx+1}, skipping backward pass.")

            optimizer.step()  # Only step the optimizer every `accumulation_steps`
            optimizer.zero_grad()  # Reset gradients only after accumulation

            # Safe-guarding against NaN for epoch_loss
            if not torch.isnan(total_loss).any():
                epoch_loss += total_loss.item()
            else:
                epoch_loss += 0.0
                print(f"NaN detected, not adding to epoch_loss at batch {batch_idx+1}")

        # Save the model checkpoint
        save_checkpoint(epoch, model, optimizer, filename)
        
        # Average loss after training for an epoch
        avg_loss = epoch_loss / len(train_loader)
        print(f"Epoch [{epoch+1}/{num_epochs}] completed in {time.time() - epoch_start_time:.2f} seconds, Avg Loss: {avg_loss}")

        # Validation phase
        avg_val_loss = validate(model, val_loader, device)
        print(f"Epoch [{epoch+1}/{num_epochs}] validation completed, Avg Validation Loss: {avg_val_loss}")
        
    unfreeze_module(model.encoder)
    return model


## Train model to minimize mi loss

In [11]:
# Define the input shape
input_shape = (1, 28, 28)
latent_dims = latent_dims = [64, 128, 256, 512]

# Create the encoder model
encoder = Encoder(input_shape, latent_dims).to(device)

# Define the optimizer
optimizer = optim.Adam(encoder.parameters(), lr=1e-3)

# Define the filename
filename = "encoder.pth"

# Load the model checkpoint
start_epoch = load_checkpoint(encoder, optimizer, filename)
print(f"The training ended in epoch number: {start_epoch+1}")

Checkpoint loaded from encoder.pth, resuming training from epoch 38
The training ended in epoch number: 39


In [12]:
# Create the Diffuion Model
model = ColdDiffusionModel(encoder, input_shape).to(device)

# Define the optimizer 
optimizer = optim.Adam(model.parameters(), lr = 1e-3)

# Define the number of epochs and loss weights
num_epochs = 50

# Filename
filename = "mnist_echo_cold_l1.pth"

# Load the model training checkpoint
start_epoch = load_checkpoint(model, optimizer, filename)

print(f"The training ended in epoch number: {start_epoch+1}")

assert False

# Train the model
trained_model = train(model, optimizer, train_loader, device, start_epoch, num_epochs, filename)

Checkpoint loaded from mnist_echo_cold_l1.pth, resuming training from epoch 49
The training ended in epoch number: 50


AssertionError: 

## Sampling according to Algorithm 1

In [14]:
import torch
import torch.nn.functional as F
from tqdm import tqdm

# Define the necessary parameters and variables
T = 1000
beta_start = 0.0001
beta_end = 0.02
beta = torch.linspace(beta_start, beta_end, T)
alpha = 1 - beta
alpha = torch.cumprod(alpha, dim=0)

batch_size = 128
model.eval()
total_loss = 0.0

with torch.no_grad():
    for batch_idx, (data, _) in enumerate(tqdm(val_loader, desc="Sampling Progress:", leave=False)):
        data = data.to(device)
        f_x, sx_matrix = model.encoder(data)
        epsilon = echo_sample((f_x, sx_matrix))
        x = f_x + sx_matrix * epsilon
        x = torch.sqrt(alpha[-1]) * data + torch.sqrt(1-alpha[-1]) * x
        for s in range(T-1, 0, -1):
            t = torch.tensor([s] * data.size(0), dtype=torch.long).to(device)
            x_hat = model.decoder(x, t)
            z_hat = (1.0 / torch.sqrt(1-alpha[s])) * (x - torch.sqrt(alpha[s]) * x_hat)
            x = torch.sqrt(alpha[s-1]) * x_hat + torch.sqrt(1 - alpha[s-1]) * z_hat
        
        # Calculate the reconstruction error
        reconstruct_loss = torch.nn.functional.mse_loss(data, x)
        print(f"Loss for the batch no #{batch_idx}: {reconstruct_loss}") 
        total_loss += reconstruct_loss.item()

avg_loss = (total_loss) / len(val_loader)
print(f"Average loss for the val_loader: {avg_loss}")


Sampling Progress::   1%|▏                     | 1/94 [00:50<1:18:05, 50.38s/it]

Loss for the batch no #0: 0.09644512832164764


Sampling Progress::   2%|▍                     | 2/94 [01:38<1:14:59, 48.91s/it]

Loss for the batch no #1: 0.11078450828790665


Sampling Progress::   3%|▋                     | 3/94 [02:26<1:13:28, 48.45s/it]

Loss for the batch no #2: 0.10683412104845047


Sampling Progress::   4%|▉                     | 4/94 [03:14<1:12:21, 48.24s/it]

Loss for the batch no #3: 0.09600592404603958


Sampling Progress::   5%|█▏                    | 5/94 [04:02<1:11:24, 48.14s/it]

Loss for the batch no #4: 0.1028551533818245


Sampling Progress::   6%|█▍                    | 6/94 [04:51<1:11:11, 48.54s/it]

Loss for the batch no #5: 0.11183424293994904


Sampling Progress::   7%|█▋                    | 7/94 [05:42<1:11:25, 49.25s/it]

Loss for the batch no #6: 0.10976279526948929


Sampling Progress::   9%|█▊                    | 8/94 [06:32<1:11:18, 49.74s/it]

Loss for the batch no #7: 0.10266856849193573


Sampling Progress::  10%|██                    | 9/94 [07:21<1:09:47, 49.26s/it]

Loss for the batch no #8: 0.10842149704694748


Sampling Progress::  11%|██▏                  | 10/94 [08:09<1:08:28, 48.91s/it]

Loss for the batch no #9: 0.10337303578853607


Sampling Progress::  12%|██▍                  | 11/94 [08:57<1:07:19, 48.66s/it]

Loss for the batch no #10: 0.10676375776529312


Sampling Progress::  13%|██▋                  | 12/94 [09:47<1:07:16, 49.23s/it]

Loss for the batch no #11: 0.10271502286195755


Sampling Progress::  14%|██▉                  | 13/94 [10:35<1:05:58, 48.87s/it]

Loss for the batch no #12: 0.09922736138105392


Sampling Progress::  15%|███▏                 | 14/94 [11:26<1:05:49, 49.37s/it]

Loss for the batch no #13: 0.1040693148970604


Sampling Progress::  16%|███▎                 | 15/94 [12:14<1:04:27, 48.95s/it]

Loss for the batch no #14: 0.09888769686222076


Sampling Progress::  17%|███▌                 | 16/94 [13:05<1:04:19, 49.48s/it]

Loss for the batch no #15: 0.09626230597496033


Sampling Progress::  18%|███▊                 | 17/94 [13:55<1:03:57, 49.83s/it]

Loss for the batch no #16: 0.09709352254867554


Sampling Progress::  19%|████                 | 18/94 [14:45<1:03:10, 49.88s/it]

Loss for the batch no #17: 0.11574236303567886


Sampling Progress::  20%|████▏                | 19/94 [15:35<1:02:15, 49.81s/it]

Loss for the batch no #18: 0.1058899313211441


Sampling Progress::  21%|████▍                | 20/94 [16:23<1:00:45, 49.26s/it]

Loss for the batch no #19: 0.1065383031964302


Sampling Progress::  22%|████▋                | 21/94 [17:12<1:00:02, 49.35s/it]

Loss for the batch no #20: 0.10125280916690826


Sampling Progress::  23%|█████▍                 | 22/94 [18:01<58:45, 48.97s/it]

Loss for the batch no #21: 0.10542917251586914


Sampling Progress::  24%|█████▋                 | 23/94 [18:50<58:11, 49.17s/it]

Loss for the batch no #22: 0.09943728148937225


Sampling Progress::  26%|█████▊                 | 24/94 [19:40<57:30, 49.29s/it]

Loss for the batch no #23: 0.10771872103214264


Sampling Progress::  27%|██████                 | 25/94 [20:30<56:54, 49.48s/it]

Loss for the batch no #24: 0.1108691394329071


Sampling Progress::  28%|██████▎                | 26/94 [21:20<56:15, 49.65s/it]

Loss for the batch no #25: 0.11106160283088684


Sampling Progress::  29%|██████▌                | 27/94 [22:10<55:36, 49.80s/it]

Loss for the batch no #26: 0.10855427384376526


Sampling Progress::  30%|██████▊                | 28/94 [23:01<55:07, 50.11s/it]

Loss for the batch no #27: 0.1059296652674675


Sampling Progress::  31%|███████                | 29/94 [23:49<53:42, 49.57s/it]

Loss for the batch no #28: 0.1096615195274353


Sampling Progress::  32%|███████▎               | 30/94 [24:37<52:22, 49.10s/it]

Loss for the batch no #29: 0.10024868696928024


Sampling Progress::  33%|███████▌               | 31/94 [25:28<52:02, 49.57s/it]

Loss for the batch no #30: 0.10375369340181351


Sampling Progress::  34%|███████▊               | 32/94 [26:18<51:35, 49.93s/it]

Loss for the batch no #31: 0.10364905744791031


Sampling Progress::  35%|████████               | 33/94 [27:09<51:00, 50.17s/it]

Loss for the batch no #32: 0.11005517095327377


Sampling Progress::  36%|████████▎              | 34/94 [28:00<50:21, 50.36s/it]

Loss for the batch no #33: 0.11362864822149277


Sampling Progress::  37%|████████▌              | 35/94 [28:48<48:50, 49.67s/it]

Loss for the batch no #34: 0.10814471542835236


Sampling Progress::  38%|████████▊              | 36/94 [29:36<47:38, 49.29s/it]

Loss for the batch no #35: 0.10092265903949738


Sampling Progress::  39%|█████████              | 37/94 [30:27<47:13, 49.71s/it]

Loss for the batch no #36: 0.10378625988960266


Sampling Progress::  40%|█████████▎             | 38/94 [31:15<45:58, 49.26s/it]

Loss for the batch no #37: 0.10112452507019043


Sampling Progress::  41%|█████████▌             | 39/94 [32:06<45:37, 49.77s/it]

Loss for the batch no #38: 0.10093019902706146


Sampling Progress::  43%|█████████▊             | 40/94 [32:57<45:07, 50.14s/it]

Loss for the batch no #39: 0.11190170794725418


Sampling Progress::  44%|██████████             | 41/94 [33:48<44:30, 50.39s/it]

Loss for the batch no #40: 0.1047465056180954


Sampling Progress::  45%|██████████▎            | 42/94 [34:39<43:49, 50.58s/it]

Loss for the batch no #41: 0.1003771647810936


Sampling Progress::  46%|██████████▌            | 43/94 [35:30<43:04, 50.68s/it]

Loss for the batch no #42: 0.10482946783304214


Sampling Progress::  47%|██████████▊            | 44/94 [36:21<42:16, 50.73s/it]

Loss for the batch no #43: 0.09729932248592377


Sampling Progress::  48%|███████████            | 45/94 [37:09<40:47, 49.95s/it]

Loss for the batch no #44: 0.1057434231042862


Sampling Progress::  49%|███████████▎           | 46/94 [37:57<39:28, 49.35s/it]

Loss for the batch no #45: 0.10372905433177948


Sampling Progress::  50%|███████████▌           | 47/94 [38:45<38:20, 48.95s/it]

Loss for the batch no #46: 0.09949061274528503


Sampling Progress::  51%|███████████▋           | 48/94 [39:33<37:18, 48.65s/it]

Loss for the batch no #47: 0.1052006408572197


Sampling Progress::  52%|███████████▉           | 49/94 [40:24<36:57, 49.28s/it]

Loss for the batch no #48: 0.10966183245182037


Sampling Progress::  53%|████████████▏          | 50/94 [41:15<36:28, 49.74s/it]

Loss for the batch no #49: 0.10744540393352509


Sampling Progress::  54%|████████████▍          | 51/94 [42:05<35:51, 50.04s/it]

Loss for the batch no #50: 0.10718915611505508


Sampling Progress::  55%|████████████▋          | 52/94 [42:56<35:09, 50.23s/it]

Loss for the batch no #51: 0.10737992078065872


Sampling Progress::  56%|████████████▉          | 53/94 [43:47<34:28, 50.44s/it]

Loss for the batch no #52: 0.105152927339077


Sampling Progress::  57%|█████████████▏         | 54/94 [44:35<33:09, 49.73s/it]

Loss for the batch no #53: 0.10519362986087799


Sampling Progress::  59%|█████████████▍         | 55/94 [45:23<31:58, 49.20s/it]

Loss for the batch no #54: 0.09953133016824722


Sampling Progress::  60%|█████████████▋         | 56/94 [46:11<30:55, 48.83s/it]

Loss for the batch no #55: 0.10055584460496902


Sampling Progress::  61%|█████████████▉         | 57/94 [47:02<30:27, 49.40s/it]

Loss for the batch no #56: 0.10838373750448227


Sampling Progress::  62%|██████████████▏        | 58/94 [47:52<29:52, 49.80s/it]

Loss for the batch no #57: 0.11784079670906067


Sampling Progress::  63%|██████████████▍        | 59/94 [48:43<29:12, 50.09s/it]

Loss for the batch no #58: 0.1163589209318161


Sampling Progress::  64%|██████████████▋        | 60/94 [49:31<28:03, 49.53s/it]

Loss for the batch no #59: 0.1032753586769104


Sampling Progress::  65%|██████████████▉        | 61/94 [50:19<26:59, 49.08s/it]

Loss for the batch no #60: 0.11245069652795792


Sampling Progress::  66%|███████████████▏       | 62/94 [51:08<26:04, 48.89s/it]

Loss for the batch no #61: 0.11997876316308975


Sampling Progress::  67%|███████████████▍       | 63/94 [51:59<25:36, 49.56s/it]

Loss for the batch no #62: 0.10342323035001755


Sampling Progress::  68%|███████████████▋       | 64/94 [52:50<24:58, 49.94s/it]

Loss for the batch no #63: 0.10570650547742844


Sampling Progress::  69%|███████████████▉       | 65/94 [53:41<24:17, 50.26s/it]

Loss for the batch no #64: 0.10413028299808502


Sampling Progress::  70%|████████████████▏      | 66/94 [54:32<23:32, 50.44s/it]

Loss for the batch no #65: 0.10903480648994446


Sampling Progress::  71%|████████████████▍      | 67/94 [55:20<22:23, 49.77s/it]

Loss for the batch no #66: 0.1093294769525528


Sampling Progress::  72%|████████████████▋      | 68/94 [56:11<21:41, 50.05s/it]

Loss for the batch no #67: 0.11285581439733505


Sampling Progress::  73%|████████████████▉      | 69/94 [56:59<20:36, 49.44s/it]

Loss for the batch no #68: 0.10240580141544342


Sampling Progress::  74%|█████████████████▏     | 70/94 [57:49<19:54, 49.77s/it]

Loss for the batch no #69: 0.10144686698913574


Sampling Progress::  76%|█████████████████▎     | 71/94 [58:40<19:10, 50.02s/it]

Loss for the batch no #70: 0.10109712183475494


Sampling Progress::  77%|█████████████████▌     | 72/94 [59:30<18:24, 50.20s/it]

Loss for the batch no #71: 0.09945575147867203


Sampling Progress::  78%|████████████████▎    | 73/94 [1:00:21<17:36, 50.29s/it]

Loss for the batch no #72: 0.10500527173280716


Sampling Progress::  79%|████████████████▌    | 74/94 [1:01:09<16:31, 49.60s/it]

Loss for the batch no #73: 0.1014329269528389


Sampling Progress::  80%|████████████████▊    | 75/94 [1:01:59<15:47, 49.85s/it]

Loss for the batch no #74: 0.10328815877437592


Sampling Progress::  81%|████████████████▉    | 76/94 [1:02:47<14:47, 49.29s/it]

Loss for the batch no #75: 0.10634681582450867


Sampling Progress::  82%|█████████████████▏   | 77/94 [1:03:38<14:03, 49.64s/it]

Loss for the batch no #76: 0.1051219031214714


Sampling Progress::  83%|█████████████████▍   | 78/94 [1:04:28<13:18, 49.90s/it]

Loss for the batch no #77: 0.10551751405000687


Sampling Progress::  84%|█████████████████▋   | 79/94 [1:05:19<12:31, 50.13s/it]

Loss for the batch no #78: 0.10372797399759293


Sampling Progress::  85%|█████████████████▊   | 80/94 [1:06:09<11:43, 50.24s/it]

Loss for the batch no #79: 0.1082509234547615


Sampling Progress::  86%|██████████████████   | 81/94 [1:07:00<10:54, 50.34s/it]

Loss for the batch no #80: 0.09717614203691483


Sampling Progress::  87%|██████████████████▎  | 82/94 [1:07:48<09:55, 49.63s/it]

Loss for the batch no #81: 0.10055525600910187


Sampling Progress::  88%|██████████████████▌  | 83/94 [1:08:38<09:08, 49.88s/it]

Loss for the batch no #82: 0.10359656810760498


Sampling Progress::  89%|██████████████████▊  | 84/94 [1:09:29<08:20, 50.08s/it]

Loss for the batch no #83: 0.10483364015817642


Sampling Progress::  90%|██████████████████▉  | 85/94 [1:10:17<07:25, 49.45s/it]

Loss for the batch no #84: 0.10790012776851654


Sampling Progress::  91%|███████████████████▏ | 86/94 [1:11:05<06:31, 48.99s/it]

Loss for the batch no #85: 0.10785548388957977


Sampling Progress::  93%|███████████████████▍ | 87/94 [1:11:55<05:46, 49.45s/it]

Loss for the batch no #86: 0.10219015181064606


Sampling Progress::  94%|███████████████████▋ | 88/94 [1:12:43<04:54, 49.03s/it]

Loss for the batch no #87: 0.1045791357755661


Sampling Progress::  95%|███████████████████▉ | 89/94 [1:13:31<04:03, 48.71s/it]

Loss for the batch no #88: 0.09928957372903824


Sampling Progress::  96%|████████████████████ | 90/94 [1:14:19<03:13, 48.48s/it]

Loss for the batch no #89: 0.09673163294792175


Sampling Progress::  97%|████████████████████▎| 91/94 [1:15:10<02:27, 49.07s/it]

Loss for the batch no #90: 0.1014121025800705


Sampling Progress::  98%|████████████████████▌| 92/94 [1:16:00<01:38, 49.49s/it]

Loss for the batch no #91: 0.10446931421756744


Sampling Progress::  99%|████████████████████▊| 93/94 [1:16:51<00:49, 49.80s/it]

Loss for the batch no #92: 0.10841048508882523


                                                                                

Loss for the batch no #93: 0.09818530827760696
Average loss for the val_loader: 0.10496607066151943




In [None]:
filename = 'mnist_echo_cold_alg1.pt'
# Save the dictionary to a .pt file
torch.save(sampled_data, filename)

print(f"Sampled data saved to {filename}")


In [None]:
print(len(sampled_data))

In [None]:
import torch
import matplotlib.pyplot as plt

# Load the saved data
sampled_data = torch.load(filename)

# Access a specific entry
index = 120 # Replace with the index you want to check
original_image = sampled_data[index]['original_image']
sampled_image = sampled_data[index]['sampled']



# Convert the tensor to a numpy array
original_image = original_image.cpu().numpy().transpose(1, 2, 0)
sampled_image = sampled_image.cpu().numpy().transpose(1, 2, 0)

# Plot the images
fig, ax = plt.subplots(1, 2, figsize=(10, 5))
ax[0].imshow(original_image, cmap='gray')
ax[0].set_title('Original Image')
ax[0].axis('off')

ax[1].imshow(sampled_image, cmap='gray')
ax[1].set_title('Sampled Image')
ax[1].axis('off')

plt.show()

## Sampling according to Algorithm 2

In [15]:
import torch
import torch.nn.functional as F
from tqdm import tqdm

# Define the necessary parameters and variables
T = 1000
beta_start = 0.0001
beta_end = 0.02
beta = torch.linspace(beta_start, beta_end, T)
alpha = 1 - beta
alpha = torch.cumprod(alpha, dim=0)

batch_size = 128
model.eval()
total_loss = 0.0

with torch.no_grad():
    for batch_idx, (data, _) in enumerate(tqdm(val_loader, desc="Sampling Progress:", leave=False)):  # Corrected line
        data = data.to(device)
        f_x, sx_matrix = model.encoder(data)
        epsilon = echo_sample((f_x, sx_matrix))
        x = f_x + sx_matrix * epsilon
        x = torch.sqrt(alpha[-1]) * data + torch.sqrt(1-alpha[-1]) * x
        for s in range(T-1, -1, -1):
            t = torch.tensor([s] * data.size(0), dtype=torch.long).to(device)
            x_hat = model.decoder(x, t)
            z_hat = (1.0 / torch.sqrt(1-alpha[s])) * (x - torch.sqrt(alpha[s]) * x_hat)
            D_s = torch.sqrt(alpha[s]) * x_hat + torch.sqrt(1-alpha[s]) * z_hat
            D_s_minus_one = torch.sqrt(alpha[s-1]) * x_hat + torch.sqrt(1-alpha[s-1]) * z_hat
            x = x - D_s + D_s_minus_one
        
        # Calculate the reconstruction loss
        reconstruct_loss = torch.nn.functional.mse_loss(data, x)
        print(f"The reconstruction loss for the batch no #{batch_idx}: {reconstruct_loss}")
        total_loss += reconstruct_loss.item()

avg_loss = total_loss/len(val_loader)
print(f"The average loss for the val_loader: {avg_loss}")

Sampling Progress::   1%|▏                     | 1/94 [00:50<1:18:14, 50.48s/it]

The reconstruction loss for the batch no #0: 186.7362060546875


Sampling Progress::   2%|▍                     | 2/94 [01:40<1:17:26, 50.50s/it]

The reconstruction loss for the batch no #1: 231.85494995117188


Sampling Progress::   3%|▋                     | 3/94 [02:31<1:16:40, 50.55s/it]

The reconstruction loss for the batch no #2: 155.24732971191406


Sampling Progress::   4%|▉                     | 4/94 [03:22<1:15:50, 50.56s/it]

The reconstruction loss for the batch no #3: 117.20182800292969


Sampling Progress::   5%|█▏                    | 5/94 [04:12<1:15:01, 50.58s/it]

The reconstruction loss for the batch no #4: 194.0275115966797


Sampling Progress::   6%|█▍                    | 6/94 [05:03<1:14:10, 50.57s/it]

The reconstruction loss for the batch no #5: 212.83616638183594


Sampling Progress::   7%|█▋                    | 7/94 [05:53<1:13:20, 50.57s/it]

The reconstruction loss for the batch no #6: 192.78085327148438


Sampling Progress::   9%|█▊                    | 8/94 [06:44<1:12:30, 50.59s/it]

The reconstruction loss for the batch no #7: 205.5034637451172


Sampling Progress::  10%|██                    | 9/94 [07:35<1:11:47, 50.68s/it]

The reconstruction loss for the batch no #8: 202.681640625


Sampling Progress::  11%|██▏                  | 10/94 [08:26<1:11:00, 50.72s/it]

The reconstruction loss for the batch no #9: 157.88494873046875


Sampling Progress::  12%|██▍                  | 11/94 [09:17<1:10:13, 50.76s/it]

The reconstruction loss for the batch no #10: 182.56275939941406


Sampling Progress::  13%|██▋                  | 12/94 [10:07<1:09:23, 50.78s/it]

The reconstruction loss for the batch no #11: 195.79391479492188


Sampling Progress::  14%|██▉                  | 13/94 [10:58<1:08:33, 50.79s/it]

The reconstruction loss for the batch no #12: 146.22711181640625


Sampling Progress::  15%|███▏                 | 14/94 [11:49<1:07:43, 50.80s/it]

The reconstruction loss for the batch no #13: 146.32290649414062


Sampling Progress::  16%|███▎                 | 15/94 [12:40<1:06:52, 50.80s/it]

The reconstruction loss for the batch no #14: 121.00567626953125


Sampling Progress::  17%|███▌                 | 16/94 [13:31<1:06:02, 50.80s/it]

The reconstruction loss for the batch no #15: 168.2917938232422


Sampling Progress::  18%|███▊                 | 17/94 [14:21<1:05:11, 50.80s/it]

The reconstruction loss for the batch no #16: 161.49951171875


Sampling Progress::  19%|████                 | 18/94 [15:12<1:04:21, 50.81s/it]

The reconstruction loss for the batch no #17: 267.9104309082031


Sampling Progress::  20%|████▏                | 19/94 [16:03<1:03:30, 50.80s/it]

The reconstruction loss for the batch no #18: 220.66273498535156


Sampling Progress::  21%|████▍                | 20/94 [16:54<1:02:38, 50.79s/it]

The reconstruction loss for the batch no #19: 194.68609619140625


Sampling Progress::  22%|████▋                | 21/94 [17:45<1:01:48, 50.80s/it]

The reconstruction loss for the batch no #20: 148.8715362548828


Sampling Progress::  23%|████▉                | 22/94 [18:36<1:00:59, 50.82s/it]

The reconstruction loss for the batch no #21: 170.4435577392578


Sampling Progress::  24%|█████▏               | 23/94 [19:26<1:00:10, 50.85s/it]

The reconstruction loss for the batch no #22: 156.53189086914062


Sampling Progress::  26%|█████▊                 | 24/94 [20:17<59:19, 50.85s/it]

The reconstruction loss for the batch no #23: 171.6936798095703


Sampling Progress::  27%|██████                 | 25/94 [21:08<58:28, 50.85s/it]

The reconstruction loss for the batch no #24: 222.6040496826172


Sampling Progress::  28%|██████▎                | 26/94 [21:59<57:37, 50.85s/it]

The reconstruction loss for the batch no #25: 194.22763061523438


Sampling Progress::  29%|██████▌                | 27/94 [22:50<56:47, 50.85s/it]

The reconstruction loss for the batch no #26: 200.9840087890625


Sampling Progress::  30%|██████▊                | 28/94 [23:41<55:55, 50.85s/it]

The reconstruction loss for the batch no #27: 150.0196533203125


Sampling Progress::  31%|███████                | 29/94 [24:31<55:03, 50.82s/it]

The reconstruction loss for the batch no #28: 198.03977966308594


Sampling Progress::  32%|███████▎               | 30/94 [25:22<54:12, 50.83s/it]

The reconstruction loss for the batch no #29: 153.77435302734375


Sampling Progress::  33%|███████▌               | 31/94 [26:13<53:18, 50.76s/it]

The reconstruction loss for the batch no #30: 197.61627197265625


Sampling Progress::  34%|███████▊               | 32/94 [27:04<52:26, 50.75s/it]

The reconstruction loss for the batch no #31: 132.02674865722656


Sampling Progress::  35%|████████               | 33/94 [27:54<51:36, 50.77s/it]

The reconstruction loss for the batch no #32: 179.3961639404297


Sampling Progress::  36%|████████▎              | 34/94 [28:45<50:49, 50.83s/it]

The reconstruction loss for the batch no #33: 233.62562561035156


Sampling Progress::  37%|████████▌              | 35/94 [29:36<49:59, 50.84s/it]

The reconstruction loss for the batch no #34: 170.9837188720703


Sampling Progress::  38%|████████▊              | 36/94 [30:27<49:06, 50.80s/it]

The reconstruction loss for the batch no #35: 154.47869873046875


Sampling Progress::  39%|█████████              | 37/94 [31:18<48:11, 50.73s/it]

The reconstruction loss for the batch no #36: 184.02993774414062


Sampling Progress::  40%|█████████▎             | 38/94 [32:08<47:17, 50.66s/it]

The reconstruction loss for the batch no #37: 166.1663360595703


Sampling Progress::  41%|█████████▌             | 39/94 [32:59<46:24, 50.62s/it]

The reconstruction loss for the batch no #38: 185.67306518554688


Sampling Progress::  43%|█████████▊             | 40/94 [33:49<45:31, 50.59s/it]

The reconstruction loss for the batch no #39: 193.4693603515625


Sampling Progress::  44%|██████████             | 41/94 [34:40<44:40, 50.58s/it]

The reconstruction loss for the batch no #40: 186.6087188720703


Sampling Progress::  45%|██████████▎            | 42/94 [35:30<43:49, 50.56s/it]

The reconstruction loss for the batch no #41: 190.9784698486328


Sampling Progress::  46%|██████████▌            | 43/94 [36:21<42:58, 50.55s/it]

The reconstruction loss for the batch no #42: 190.3313751220703


Sampling Progress::  47%|██████████▊            | 44/94 [37:11<42:07, 50.54s/it]

The reconstruction loss for the batch no #43: 130.58038330078125


Sampling Progress::  48%|███████████            | 45/94 [38:02<41:16, 50.54s/it]

The reconstruction loss for the batch no #44: 205.69342041015625


Sampling Progress::  49%|███████████▎           | 46/94 [38:52<40:26, 50.55s/it]

The reconstruction loss for the batch no #45: 151.8600311279297


Sampling Progress::  50%|███████████▌           | 47/94 [39:43<39:35, 50.55s/it]

The reconstruction loss for the batch no #46: 125.61173248291016


Sampling Progress::  51%|███████████▋           | 48/94 [40:33<38:44, 50.54s/it]

The reconstruction loss for the batch no #47: 192.16871643066406


Sampling Progress::  52%|███████████▉           | 49/94 [41:24<37:53, 50.53s/it]

The reconstruction loss for the batch no #48: 208.54571533203125


Sampling Progress::  53%|████████████▏          | 50/94 [42:14<37:03, 50.53s/it]

The reconstruction loss for the batch no #49: 191.00880432128906


Sampling Progress::  54%|████████████▍          | 51/94 [43:05<36:12, 50.53s/it]

The reconstruction loss for the batch no #50: 220.54873657226562


Sampling Progress::  55%|████████████▋          | 52/94 [43:55<35:22, 50.53s/it]

The reconstruction loss for the batch no #51: 191.60227966308594


Sampling Progress::  56%|████████████▉          | 53/94 [44:46<34:31, 50.53s/it]

The reconstruction loss for the batch no #52: 200.0660400390625


Sampling Progress::  57%|█████████████▏         | 54/94 [45:37<33:41, 50.54s/it]

The reconstruction loss for the batch no #53: 163.29005432128906


Sampling Progress::  59%|█████████████▍         | 55/94 [46:27<32:51, 50.54s/it]

The reconstruction loss for the batch no #54: 188.45835876464844


Sampling Progress::  60%|█████████████▋         | 56/94 [47:18<32:00, 50.54s/it]

The reconstruction loss for the batch no #55: 123.52729034423828


Sampling Progress::  61%|█████████████▉         | 57/94 [48:06<30:42, 49.80s/it]

The reconstruction loss for the batch no #56: 178.25222778320312


Sampling Progress::  62%|██████████████▏        | 58/94 [48:56<30:01, 50.04s/it]

The reconstruction loss for the batch no #57: 265.7788391113281


Sampling Progress::  63%|██████████████▍        | 59/94 [49:47<29:16, 50.18s/it]

The reconstruction loss for the batch no #58: 232.8602294921875


Sampling Progress::  64%|██████████████▋        | 60/94 [50:37<28:30, 50.30s/it]

The reconstruction loss for the batch no #59: 166.16799926757812


Sampling Progress::  65%|██████████████▉        | 61/94 [51:28<27:43, 50.41s/it]

The reconstruction loss for the batch no #60: 208.3987579345703


Sampling Progress::  66%|███████████████▏       | 62/94 [52:19<26:54, 50.45s/it]

The reconstruction loss for the batch no #61: 211.71180725097656


Sampling Progress::  67%|███████████████▍       | 63/94 [53:09<26:04, 50.48s/it]

The reconstruction loss for the batch no #62: 177.88902282714844


Sampling Progress::  68%|███████████████▋       | 64/94 [54:00<25:14, 50.49s/it]

The reconstruction loss for the batch no #63: 160.4415283203125


Sampling Progress::  69%|███████████████▉       | 65/94 [54:50<24:24, 50.51s/it]

The reconstruction loss for the batch no #64: 182.21844482421875


Sampling Progress::  70%|████████████████▏      | 66/94 [55:41<23:34, 50.51s/it]

The reconstruction loss for the batch no #65: 173.4050750732422


Sampling Progress::  71%|████████████████▍      | 67/94 [56:31<22:43, 50.48s/it]

The reconstruction loss for the batch no #66: 176.44387817382812


Sampling Progress::  72%|████████████████▋      | 68/94 [57:19<21:32, 49.72s/it]

The reconstruction loss for the batch no #67: 251.88856506347656


Sampling Progress::  73%|████████████████▉      | 69/94 [58:09<20:47, 49.91s/it]

The reconstruction loss for the batch no #68: 155.03854370117188


Sampling Progress::  74%|█████████████████▏     | 70/94 [59:00<20:01, 50.06s/it]

The reconstruction loss for the batch no #69: 183.3903045654297


Sampling Progress::  76%|█████████████████▎     | 71/94 [59:49<19:08, 49.93s/it]

The reconstruction loss for the batch no #70: 175.4279327392578


Sampling Progress::  77%|████████████████     | 72/94 [1:00:38<18:06, 49.38s/it]

The reconstruction loss for the batch no #71: 125.4189453125


Sampling Progress::  78%|████████████████▎    | 73/94 [1:01:26<17:08, 48.97s/it]

The reconstruction loss for the batch no #72: 230.11264038085938


Sampling Progress::  79%|████████████████▌    | 74/94 [1:02:14<16:13, 48.67s/it]

The reconstruction loss for the batch no #73: 162.8058624267578


Sampling Progress::  80%|████████████████▊    | 75/94 [1:03:02<15:21, 48.48s/it]

The reconstruction loss for the batch no #74: 192.53993225097656


Sampling Progress::  81%|████████████████▉    | 76/94 [1:03:50<14:29, 48.32s/it]

The reconstruction loss for the batch no #75: 166.10035705566406


Sampling Progress::  82%|█████████████████▏   | 77/94 [1:04:38<13:40, 48.24s/it]

The reconstruction loss for the batch no #76: 138.43739318847656


Sampling Progress::  83%|█████████████████▍   | 78/94 [1:05:26<12:50, 48.16s/it]

The reconstruction loss for the batch no #77: 132.23902893066406


Sampling Progress::  84%|█████████████████▋   | 79/94 [1:06:14<12:01, 48.11s/it]

The reconstruction loss for the batch no #78: 172.03550720214844


Sampling Progress::  85%|█████████████████▊   | 80/94 [1:07:02<11:12, 48.07s/it]

The reconstruction loss for the batch no #79: 196.77377319335938


Sampling Progress::  86%|██████████████████   | 81/94 [1:07:50<10:24, 48.05s/it]

The reconstruction loss for the batch no #80: 126.64762115478516


Sampling Progress::  87%|██████████████████▎  | 82/94 [1:08:38<09:36, 48.08s/it]

The reconstruction loss for the batch no #81: 151.05870056152344


Sampling Progress::  88%|██████████████████▌  | 83/94 [1:09:26<08:48, 48.08s/it]

The reconstruction loss for the batch no #82: 162.36134338378906


Sampling Progress::  89%|██████████████████▊  | 84/94 [1:10:14<08:00, 48.10s/it]

The reconstruction loss for the batch no #83: 198.8800506591797


Sampling Progress::  90%|██████████████████▉  | 85/94 [1:11:02<07:13, 48.12s/it]

The reconstruction loss for the batch no #84: 210.96661376953125


Sampling Progress::  91%|███████████████████▏ | 86/94 [1:11:50<06:24, 48.10s/it]

The reconstruction loss for the batch no #85: 178.01353454589844


Sampling Progress::  93%|███████████████████▍ | 87/94 [1:12:38<05:36, 48.08s/it]

The reconstruction loss for the batch no #86: 170.0158233642578


Sampling Progress::  94%|███████████████████▋ | 88/94 [1:13:26<04:48, 48.09s/it]

The reconstruction loss for the batch no #87: 153.8350830078125


Sampling Progress::  95%|███████████████████▉ | 89/94 [1:14:14<04:00, 48.12s/it]

The reconstruction loss for the batch no #88: 144.33929443359375


Sampling Progress::  96%|████████████████████ | 90/94 [1:15:03<03:12, 48.10s/it]

The reconstruction loss for the batch no #89: 145.3271484375


Sampling Progress::  97%|████████████████████▎| 91/94 [1:15:51<02:24, 48.16s/it]

The reconstruction loss for the batch no #90: 148.20730590820312


Sampling Progress::  98%|████████████████████▌| 92/94 [1:16:39<01:36, 48.23s/it]

The reconstruction loss for the batch no #91: 171.45184326171875


Sampling Progress::  99%|████████████████████▊| 93/94 [1:17:28<00:48, 48.38s/it]

The reconstruction loss for the batch no #92: 201.347412109375


                                                                                

The reconstruction loss for the batch no #93: 135.44277954101562
The average loss for the val_loader: 178.73324658008332




In [None]:
filename = 'mnist_echo_cold_alg2.pt'
# Save the dictionary to a .pt file
torch.save(sampled_data, filename)

print(f"Sampled data saved to {filename}")


In [None]:
print(len(sampled_data))

In [None]:
import torch
import matplotlib.pyplot as plt

# Load the saved data
sampled_data = torch.load(filename)

# Access a specific entry
index = 120 # Replace with the index you want to check
original_image = sampled_data[index]['original_image']
sampled_image = sampled_data[index]['sampled']



# Convert the tensor to a numpy array
original_image = original_image.cpu().numpy().transpose(1, 2, 0)
sampled_image = sampled_image.cpu().numpy().transpose(1, 2, 0)

# Plot the images
fig, ax = plt.subplots(1, 2, figsize=(10, 5))
ax[0].imshow(original_image, cmap='gray')
ax[0].set_title('Original Image')
ax[0].axis('off')

ax[1].imshow(sampled_image, cmap='gray')
ax[1].set_title('Sampled Image')
ax[1].axis('off')

plt.show()