In [None]:
import sys
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets,transforms

In [None]:
# Add the book's `notebooks` directory to the import path (presumably where the `utils` module imported below lives).
# NOTE(review): machine-specific absolute path; the notebook will not run on another machine without editing it.
sys.path.append(r"C:\Users\amrul\programming\deep_learning\dl_projects\Generative_Deep_Learning_2nd_Edition\notebooks")

In [None]:
from utils import display

In [None]:
# Notebook-wide constants.

# Image geometry used when creating random images for the buffer.
IMAGE_SIZE = 32
CHANNELS = 1

# Sampler settings; STEPS / STEP_SIZE / NOISE are passed to Buffer.sample_new_examples below.
STEPS = 60
STEP_SIZE = 10
NOISE = 0.005
# Not referenced by the visible cells (the sampler cell defines its own GRAD_CLIP) -- TODO confirm usage.
GRADIENT_CLIP = 0.03

# Batch and replay-buffer sizes used by the Buffer class.
BATCH_SIZE = 128
BUFFER_SIZE = 8192

# Not referenced by the visible cells; presumably consumed by a later training loop -- TODO confirm.
ALPHA = 0.1
LEARNING_RATE = 1e-4
EPOCHS = 60

In [None]:
import pathlib

# Directory handed to torchvision as the dataset root (files are downloaded there if missing).
datapath = pathlib.Path(
    r"C:\Users\amrul\programming\deep_learning\dl_projects\Generative_Deep_Learning_2nd_Edition\data"
)

# Pad every border by 2 pixels, convert to a tensor, then normalise with mean 0.5 / std 0.5.
transform = transforms.Compose(
    [
        transforms.Pad(2),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5], std=[0.5]),
    ]
)

# Training and test splits share the same root directory and preprocessing.
train_mnist = datasets.MNIST(str(datapath), train=True, download=True, transform=transform)
test_mnist = datasets.MNIST(str(datapath), train=False, download=True, transform=transform)

print(f"train mnist size : {len(train_mnist)}, test mnist size : {len(test_mnist)}")

In [None]:
import random
# Pick one training example uniformly at random and print its index, shape, label and value range.
idx = random.choice(range(len(train_mnist)))
# Indexing the dataset yields an (image, label) pair.
image,label = train_mnist[idx]
print(f"picked index : {idx}")
print(image.shape,label)
print(f"image min : {image.min()}, image max : {image.max()}")

In [None]:
import matplotlib.pyplot as plt

# Histogram of the pixel values of the sampled image.
plt.hist(image.flatten().numpy())
# Start a second figure so the image is not drawn into the histogram's axes.
plt.figure()
# squeeze(0) drops the leading dimension before plotting -- assumes image is (1, H, W); TODO confirm.
plt.imshow(image.squeeze(0))

In [None]:
def calc_out_height(height, kernel_size, stride, padding):
    """Return the output height of a strided convolution (dilation 1).

    Implements ``floor((height + 2 * padding - kernel_size) / stride) + 1``.
    The division must be floored: a kernel position that does not fit
    entirely inside the padded input produces no output row.

    Args:
        height: Input height in pixels.
        kernel_size: Kernel height.
        stride: Step between successive kernel positions.
        padding: Padding added to each side of the input.

    Returns:
        The output height as an integer when all arguments are integers.
    """
    return (height + 2 * padding - kernel_size) // stride + 1

In [None]:
# Follow the spatial size through four stride-2 convolutions (kernel 5 first, then kernel 3 three times).
for _label, _conv_args in (
    ("kernel 5, stride 2, padding 1", (32, 5, 2, 1)),
    ("height 15, kernel 3, stride 2, padding 1", (15, 3, 2, 1)),
    ("height 8, kernel 3, stride 2, padding 1", (8, 3, 2, 1)),
    ("height 4, kernel 3, stride 2, padding 1", (4, 3, 2, 1)),
):
    print(f"{_label} : {calc_out_height(*_conv_args)}")


In [None]:
def swish(x, beta=1.0):
    """Swish activation: ``x * sigmoid(beta * x)``, applied element-wise.

    Args:
        x: Input tensor.
        beta: Scale applied to ``x`` inside the sigmoid (default 1.0).

    Returns:
        A tensor with the same shape as ``x``.
    """
    gate = torch.sigmoid(beta * x)
    return x * gate

In [None]:
# Smoke test: apply swish to the flattened sample image. `ret` is not used by the visible cells.
ret=swish(image.flatten().squeeze(0))

In [None]:
# Plot swish on 100 evenly spaced points between -1 and -0.9.
# NOTE(review): this is a very narrow window of the curve -- confirm the end point is not a typo for a wider range.
x=torch.linspace(-1,-0.9,steps=100)
y=swish(x)
plt.plot(x,y)

In [None]:
# Draw 100 single-channel 32x32 tensors from a standard normal and flatten everything after the batch axis.
images = torch.normal(0, 1, size=(100, 1, 32, 32))
flat_images = torch.flatten(images, start_dim=1)
print(f"images shape : {images.shape}, flat_images shape : {flat_images.shape}")

In [None]:
class EnergyFunction(nn.Module):
    """Convolutional network that maps an image batch to one scalar per image.

    Four stride-2 convolutions (kernel 5, then three of kernel 3, all with
    padding 1) are followed by two linear layers. Every layer except the last
    uses the SiLU activation ``x * sigmoid(x)``, i.e. swish with ``beta = 1``.

    Args:
        out_size: Spatial height/width of the last convolution's output
            (2 for 32x32 inputs: 32 -> 15 -> 8 -> 4 -> 2).
        out_channels: Number of channels produced by the last convolution.
        in_channels: Number of channels of the input images.
    """

    def __init__(self, out_size=2, out_channels=64, in_channels=1) -> None:
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, 16, kernel_size=5, stride=2, padding=1)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, stride=2, padding=1)
        self.conv3 = nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1)
        # The last convolution must emit `out_channels` channels, otherwise the
        # flattened feature size would not match the dense layer sized below.
        self.conv4 = nn.Conv2d(64, out_channels, kernel_size=3, stride=2, padding=1)
        self.dense = nn.Linear(out_size * out_size * out_channels, 64)
        self.dense2 = nn.Linear(64, 1)

    def forward(self, x):
        """Return a ``(batch, 1)`` tensor of scores for a ``(batch, in_channels, H, W)`` input."""
        for conv in (self.conv1, self.conv2, self.conv3, self.conv4):
            x = F.silu(conv(x))
        x = x.flatten(start_dim=1)
        x = F.silu(self.dense(x))
        # No activation on the final layer: the score is unbounded.
        return self.dense2(x)


In [None]:
from torch.utils.data import DataLoader

# Shuffled mini-batches of the training set.
# NOTE(review): the batch size is hard-coded to 64 although BATCH_SIZE is 128 -- confirm which one is intended.
train_loader = DataLoader(train_mnist,batch_size=64, shuffle=True)



In [None]:
# Choose a random batch index, then walk the loader until that batch is reached.
idx = random.choice(range(len(train_loader)))
print(f"picked idx {idx} from {len(train_loader)} batches")
# The loop counter needs its own name: reusing `idx` would overwrite the target
# index and make the comparison trivially true on the very first batch.
for batch_idx, (images, labels) in enumerate(train_loader):
    if batch_idx == idx:
        break

In [None]:
# Instantiate the energy network with its default constructor arguments.
model = EnergyFunction()

In [None]:
# Sanity check: run the untrained model on the current `images` batch and inspect the output.
out = model(images)
print(f"out shape : {out.shape}")
# Mean taken over dim 0 (the batch axis of the model output).
print(f"its mean is : {torch.mean(out,dim=0)}")

In [None]:
# Two batches of 64 uniform-noise images, joined along the batch axis into one batch of 128.
images1 = torch.rand(64, 1, 32, 32)
images2 = torch.rand(64, 1, 32, 32)
images = torch.cat((images1, images2), dim=0)
print(f"images shape : {images.shape}")

In [None]:
# Score the concatenated batch in one forward pass, then split the scores back into two halves.
out = model(images)
# The second argument of torch.split is the chunk size (half the batch), not the number of chunks.
out1,out2 = torch.split(out,out.size(0)//2)
print(f"out1 shape : {out1.shape}, out2 shape : {out2.shape}")

In [None]:
# Bound applied to every element of the input gradient by the generate_samples defined in this cell.
# NOTE(review): same value as GRADIENT_CLIP in the constants cell -- consider using a single constant.
GRAD_CLIP = 0.03
# NOTE(review): debugger import; its only uses in this cell are commented out.
import ipdb

def generate_samples(model, inp_images, steps, step_size, noise, return_imgs_per_step=False,
                     gradient_clip=None):
    """Refine images with noisy, clipped gradient ascent on the model output.

    Every step (1) adds zero-mean Gaussian noise with standard deviation
    ``noise`` to the current samples and clamps them to [-1, 1], (2) computes
    the gradient of the model output with respect to the samples, (3) clamps
    that gradient element-wise, and (4) moves the samples by
    ``step_size * gradient`` and clamps them to [-1, 1] again.

    Args:
        model: Callable mapping an image batch to one score per image.
        inp_images: Starting images; the caller's tensor is not modified.
        steps: Number of update steps.
        step_size: Multiplier applied to the clipped gradient.
        noise: Standard deviation of the noise added at every step.
        return_imgs_per_step: If True, return the samples after every step
            stacked along a new leading axis (requires ``steps >= 1``).
        gradient_clip: Element-wise bound on the gradient. ``None`` falls back
            to the notebook-level ``GRAD_CLIP``.

    Returns:
        The final samples, or a ``(steps, *inp_images.shape)`` tensor when
        ``return_imgs_per_step`` is True. The result is detached from autograd.
    """
    clip = GRAD_CLIP if gradient_clip is None else gradient_clip
    imgs_per_step = []
    samples = inp_images.detach()

    for _ in range(steps):
        # The noise becomes part of the sample itself (not only of the point
        # where the gradient is evaluated); randn_like keeps device and dtype.
        noisy = samples + torch.randn_like(samples) * noise
        # Clamp first, then mark as a leaf, so the clamp does not mask the gradient.
        noisy = torch.clamp(noisy, -1.0, 1.0).requires_grad_(True)

        score = model(noisy)
        # Summing (rather than averaging) keeps each sample's gradient independent
        # of the batch size. autograd.grad differentiates with respect to the
        # images only, leaving the model parameters' .grad untouched.
        (grads,) = torch.autograd.grad(score.sum(), noisy)
        grads = torch.clamp(grads, -clip, clip)

        samples = torch.clamp(noisy.detach() + step_size * grads, -1.0, 1.0)

        if return_imgs_per_step:
            imgs_per_step.append(samples)

    if return_imgs_per_step:
        return torch.stack(imgs_per_step, dim=0)
    return samples

In [None]:
#chatgpt version of generate_samples
import torch

# Function to generate samples using Langevin Dynamics in PyTorch
import torch

def generate_samples(
    model, inp_imgs, steps, step_size, noise, return_img_per_step=False, return_energy_scores_per_step=False,
    gradient_clip=None
):
    """Refine images with noisy gradient ascent on the model output.

    Every step adds Gaussian noise (standard deviation ``noise``), clamps the
    images to [-1, 1], computes the gradient of the summed model output with
    respect to the images, optionally clamps that gradient, moves the images
    by ``step_size * gradient`` and clamps them to [-1, 1] again.

    NOTE: this function has the same name as the one defined in the previous
    cell, so it replaces that definition once this cell has run.

    Args:
        model: Callable mapping an image batch to one score per image.
        inp_imgs: Starting images; converted to float, the caller's tensor is
            not modified.
        steps: Number of update steps.
        step_size: Multiplier applied to the (clipped) gradient.
        noise: Standard deviation of the noise added at every step.
        return_img_per_step: If True, return the images after every step
            (requires ``steps >= 1``).
        return_energy_scores_per_step: If True together with
            ``return_img_per_step``, also return the model output of every
            step. On its own this flag has no effect on the return value.
        gradient_clip: Element-wise bound on the gradient, or None for no clipping.

    Returns:
        * the final images, when ``return_img_per_step`` is False;
        * the stacked per-step images, when only ``return_img_per_step`` is True;
        * ``(stacked images, stacked scores)`` when both flags are True.
        All returned tensors are detached from autograd.

    Raises:
        ValueError: If the model output does not depend on the input images.
    """
    imgs_per_step = []
    energy_scores_per_step = []

    # Work on a detached float copy so the caller's tensor and graph are left alone.
    inp_imgs = inp_imgs.detach().float()

    for _ in range(steps):
        # Add noise and clip
        inp_imgs = inp_imgs + torch.randn_like(inp_imgs) * noise
        # Gradient tracking is enabled only AFTER noise and clamp, so the tensor
        # fed to the model is a leaf and its gradient is actually available.
        inp_imgs = torch.clamp(inp_imgs, min=-1.0, max=1.0).requires_grad_(True)

        out_score = model(inp_imgs)

        # Differentiate with respect to the images only; the model parameters'
        # .grad buffers are neither read nor written.
        (grads,) = torch.autograd.grad(out_score.sum(), inp_imgs, allow_unused=True)

        if grads is None:
            raise ValueError("No gradients were computed for the input. Check the model's forward pass.")

        # Clipping gradients if a gradient clip value is provided
        if gradient_clip is not None:
            grads = grads.clamp(min=-gradient_clip, max=gradient_clip)

        # Detach inp_imgs from the current graph and update
        inp_imgs = inp_imgs.detach() + step_size * grads
        inp_imgs = torch.clamp(inp_imgs, min=-1.0, max=1.0)

        if return_img_per_step:
            imgs_per_step.append(inp_imgs.clone())

        if return_energy_scores_per_step:
            energy_scores_per_step.append(out_score.detach().clone())

    if return_img_per_step:
        if return_energy_scores_per_step:
            return torch.stack(imgs_per_step), torch.stack(energy_scores_per_step)
        # Only images were requested: there are no scores to stack.
        return torch.stack(imgs_per_step)
    return inp_imgs


In [None]:
# Sampler settings for a one-off test run (same values as STEPS / STEP_SIZE / NOISE in the constants cell).
steps=60
step_size = 10
noise = 0.005

# Run the sampler on whatever `images` currently holds (the 128 uniform-noise images when cells run in order).
out_images = generate_samples(model,images,steps,step_size,noise)

In [None]:
import random
import numpy as np

class Buffer:
    """Pool of previously generated samples used to seed new sampling runs.

    Attributes are plain Python objects:
      * ``model``    -- the network handed to ``generate_samples``.
      * ``examples`` -- list of ``(1, CHANNELS, IMAGE_SIZE, IMAGE_SIZE)``
        tensors, newest first, truncated to ``BUFFER_SIZE`` entries.
    """

    def __init__(self, model) -> None:
        self.model = model
        # Start from uniform noise in [-1, 1), the same range used for the fresh
        # images drawn in sample_new_examples.
        self.examples = [
            torch.rand(1, CHANNELS, IMAGE_SIZE, IMAGE_SIZE) * 2 - 1 for _ in range(BATCH_SIZE)
        ]

    def sample_new_examples(self, steps, step_size, noise):
        """Build a batch from buffered and fresh images, refine it, and store the result.

        Args:
            steps: Number of sampler steps.
            step_size: Sampler step size.
            noise: Standard deviation of the sampler noise.

        Returns:
            The refined batch of ``BATCH_SIZE`` images, detached from autograd.
        """
        # Number of brand-new images: BATCH_SIZE Bernoulli trials with a 5% success rate.
        n_new = int(np.random.binomial(BATCH_SIZE, 0.05))
        n_old = BATCH_SIZE - n_new

        # Fresh uniform noise rescaled from [0, 1) to [-1, 1).
        rand_images = torch.rand(n_new, CHANNELS, IMAGE_SIZE, IMAGE_SIZE) * 2 - 1

        # Buffered images are drawn with replacement and placed before the fresh ones.
        batch_parts = []
        if n_old > 0:  # torch.cat rejects an empty list
            batch_parts.append(torch.cat(random.choices(self.examples, k=n_old), dim=0))
        batch_parts.append(rand_images)
        inp_images = torch.cat(batch_parts, dim=0)

        # Refine the batch with the sampler; detach so the buffer never keeps an
        # autograd graph alive.
        inp_images = generate_samples(self.model, inp_images, steps, step_size, noise).detach()

        # Split into single-image tensors (chunk size 1 along the batch axis) and
        # put them at the front of the buffer.
        self.examples = list(torch.split(inp_images, 1, dim=0)) + self.examples

        # Drop the oldest entries beyond the buffer capacity.
        self.examples = self.examples[:BUFFER_SIZE]

        return inp_images

In [None]:
# Create the sample buffer around the current model; its constructor fills it with BATCH_SIZE random images.
buffer = Buffer(model)

In [None]:
# Draw and refine one batch using the notebook-level sampler constants; the call also grows buffer.examples.
inp_images = buffer.sample_new_examples(STEPS,STEP_SIZE,NOISE)
print(f"inp_images shape : {inp_images.shape}, examples size : {len(buffer.examples)}")

In [None]:
# Show the sampled batch with the `display` helper imported from `utils`.
# detach() drops autograd tracking so numpy() is allowed; squeeze(1) removes the size-1 channel axis.
display(inp_images.detach().squeeze(1).numpy())