In [None]:
import sys
import pathlib

# Put the project checkout on the import path, presumably so that project-local
# imports used further down (e.g. `notebooks.utils`) can be resolved.
# NOTE(review): this is a machine-specific absolute path; change it before
# running the notebook on another machine.
PROJECT_ROOT = r"C:\Users\amrul\programming\deep_learning\dl_projects\Generative_Deep_Learning_2nd_Edition"
# Guarded so that re-running this cell does not keep adding duplicate entries.
if PROJECT_ROOT not in sys.path:
    sys.path.append(PROJECT_ROOT)

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch
from torchvision import datasets, transforms
from tqdm import tqdm

# --- data pipeline ---
IMAGE_SIZE = 64            # presumably the side length (in pixels) of the training images; TODO confirm
BATCH_SIZE = 64            # batch size given to the training DataLoader
DATASET_REPETITIONS = 5    # number of passes over the dataset indices made by RepeatSampler

# --- model / sampling ---
NOISE_EMBEDDING_SIZE = 32  # presumably the width of the noise-level embedding; TODO confirm against the model code
PLOT_DIFFUSION_STEPS = 20  # presumably the number of reverse-diffusion steps used for preview images; TODO confirm

# --- optimization ---
LEARNING_RATE = 1e-3       # learning rate given to the optimizer
WEIGHT_DECAY = 1e-4        # weight-decay coefficient intended for the optimizer
EMA = 0.999                # presumably the decay of an exponential moving average of the weights; TODO confirm
EPOCHS = 50                # number of epochs given to training_loop

# --- checkpointing ---
LOAD_MODEL = False         # when True, a saved state dict is restored before training


In [None]:
# Both output locations are resolved against the current working directory.
_working_dir = pathlib.Path.cwd()

checkpoints_folder = _working_dir / "torch_checkpoints"
images_folder = _working_dir / "output" / "generated_images_torch"

# Report each location and whether it is already present on disk
# (nothing is created here).
print(f"checkpoints folder : {checkpoints_folder} and it exists {checkpoints_folder.exists()}")
print(f"images folder : {images_folder} and it exists : {images_folder.exists()}")

In [None]:
from notebooks.utils import display

In [None]:
import torch
import torch.nn as nn
from ddm_torch_model import DiffusionModel,get_flower_images_train_dataset,convert_images_torch_to_numpy_for_display,training_loop,Normalizer

In [None]:
# Build the flower training dataset with the project helper and report its size.
train_data = get_flower_images_train_dataset()
train_size = len(train_data)
print(f"Loaded flower train dataset : {train_size}")

In [None]:
# Sanity check: unpack the first dataset item and show the image's shape.
image, label = train_data[0]
first_image_shape = image.shape
print(f"image shape : {first_image_shape}")

In [None]:
from torch.utils.data.sampler import Sampler
import numpy as np

class RepeatSampler(Sampler):
    """Sampler that walks a dataset's indices in order, several times over.

    A single iteration yields ``0 .. len(data_source) - 1`` and then starts
    again from ``0``, for ``repetitions`` passes in total. The order is never
    shuffled.

    Args:
        data_source: Sized object (anything supporting ``len()``) whose index
            range is sampled.
        repetitions: Number of back-to-back passes over that index range.
    """

    def __init__(self, data_source, repetitions):
        self.repetitions = repetitions
        self.data_source = data_source

    def __len__(self):
        """Total number of indices produced by one full iteration."""
        return self.repetitions * len(self.data_source)

    def __iter__(self):
        """Return an iterator over the tiled index sequence (NumPy integers)."""
        single_pass = np.arange(len(self.data_source))
        all_passes = np.tile(single_pass, self.repetitions)
        return iter(all_passes)


In [None]:
from torch.utils.data import DataLoader

# Sampler that replays the dataset indices DATASET_REPETITIONS times, in order.
repeat_sampler = RepeatSampler(train_data, DATASET_REPETITIONS)

# The custom sampler decides the iteration order, so DataLoader's own shuffling
# stays at its default (off); DataLoader does not accept shuffle=True together
# with an explicit sampler.
train_loader = DataLoader(train_data, batch_size=BATCH_SIZE, sampler=repeat_sampler)

In [None]:
# Use the GPU when CUDA is available; otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

In [None]:
# Shared Normalizer instance: its denormalize() is used when previewing a batch,
# and the same object is handed to DiffusionModel when the model is built.
normalizer = Normalizer()

In [None]:
# Preview the first batch only: once exactly as the loader yields it and once
# after normalizer.denormalize(). Breaking right after the first batch means no
# second batch is fetched from the loader just to be thrown away.
for images, _ in train_loader:
    display(convert_images_torch_to_numpy_for_display(images))
    display(convert_images_torch_to_numpy_for_display(normalizer.denormalize(images)))
    break


In [None]:

# Build the diffusion model. The meaning of the first argument (3) is defined by
# DiffusionModel; presumably the number of image channels (TODO confirm).
model = DiffusionModel(3,normalizer,device)

if LOAD_MODEL:
    # NOTE(review): the checkpoint file name is hard-coded; point it at the
    # checkpoint that should actually be resumed.
    checkpoint_path = checkpoints_folder/"ddm_torch_checkpoints_31.pt"
    # map_location remaps the stored tensors onto the device in use, so a
    # checkpoint written on a GPU machine can still be restored on a CPU-only one.
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    state_dict = torch.load(str(checkpoint_path), map_location=device)
    model.load_state_dict(state_dict)

model = model.to(device)

In [None]:
import torch.optim as optim
# AdamW's own weight_decay default is 1e-2, so the configured WEIGHT_DECAY has
# to be passed explicitly for it to take effect.
optimizer = optim.AdamW(model.parameters(),lr=LEARNING_RATE,weight_decay=WEIGHT_DECAY)

In [None]:
# Mean-absolute-error (L1) criterion; it is passed to training_loop as the loss.
mae_loss = nn.L1Loss()

In [None]:
# Run training. `ret` keeps whatever training_loop returns (presumably a loss
# history; TODO confirm); it is plotted afterwards.
ret = training_loop(
    EPOCHS,
    optimizer,
    model,
    mae_loss,
    train_loader,
    device,
    checkpoints_folder,
    images_folder,
)

In [None]:
import matplotlib.pyplot as plt
# Plot the value returned by training_loop (presumably the recorded losses; TODO confirm).
plt.plot(ret)