In [None]:
from Noise import NoiseScheduler
from DatasetLoader import ImageDataset
from torch.utils.data import DataLoader
import torch
import torchvision
from model import U_net
from utilites import count_parameters
import os
from tqdm import tqdm

# Seed PyTorch's global RNG before anything stochastic runs (DataLoader
# shuffling, noise/timestep sampling in later cells, and presumably the
# model's weight initialisation) so Restart & Run All is repeatable.
SEED = 42
torch.manual_seed(SEED)

# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

#data
image_size = 28
BATCH_SIZE = 64
data = ImageDataset("dataset/train", image_size=image_size)
data_loader = DataLoader(data, batch_size=BATCH_SIZE, shuffle=True)

#scheduler
# num_timesteps is reused by the training and sampling cells below.
num_timesteps = 1000
noise_scheduler = NoiseScheduler(num_timesteps=num_timesteps, beta_start=0.0001, beta_end=0.02, device=device)

#model
model = U_net(device)
model.to(device)

# Last expression of the cell: displays the per-module parameter summary.
count_parameters(model)

Module                                                         Parameters
time_mlp                                                          525,568
Down_blocks                                                       316,800
bottleneck                                                        328,064
Up_blocks                                                         434,880
final_conv                                                             65
Total Trainable Parameters                                      1,605,377


1605377

In [None]:
#parameters
num_epochs = 100
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
criterion = torch.nn.MSELoss()

# Create the checkpoint directory up front so the first torch.save cannot
# fail on a fresh checkout.
checkpoint_dir = "save_model"
os.makedirs(checkpoint_dir, exist_ok=True)

#training loop
for epoch in range(num_epochs):
    # Ensure training mode in case another cell switched the model to eval.
    model.train()
    losses = []
    for images in tqdm(data_loader):
        images = images.to(device)

        # Target noise; randn_like already matches the device/dtype of `images`.
        noise = torch.randn_like(images)

        # Sample one random timestep per image in the batch
        timestep = torch.randint(0, num_timesteps, (images.shape[0],), device=device).long()

        # Add noise to images according to the timestep
        noisy_images = noise_scheduler.add_noise(images, noise, timestep)

        # Predict the noise using the model
        noise_pred = model(noisy_images, timestep)

        # Compute the loss between predicted and actual noise
        loss = criterion(noise_pred, noise)
        losses.append(loss.item())

        # Backpropagation and optimization step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    # Report the mean loss over the epoch rather than only the last batch;
    # NaN when the loader produced no batches (avoids a NameError on `loss`).
    epoch_loss = sum(losses) / len(losses) if losses else float("nan")
    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {epoch_loss:.4f}")

    # Save the model checkpoint
    torch.save(model.state_dict(), os.path.join(checkpoint_dir, f"model_epoch_{epoch+1}.pth"))

  2%|▏         | 17/938 [00:01<00:56, 16.40it/s]


KeyboardInterrupt: 

In [None]:
#sample
# Reverse process: start from pure Gaussian noise and step from
# t = num_timesteps - 1 down to 0, saving the intermediate image at every step.
output_dir = os.path.join(os.getcwd(), "generated_imgs")
os.makedirs(output_dir, exist_ok=True)

# tensor -> PIL image converter; stateless, so build it once outside the loop
to_pil = torchvision.transforms.ToPILImage()

xt = torch.randn(1, 1, image_size, image_size).to(device)

# Remember the current mode so this cell leaves the model as it found it.
# eval() only matters if U_net contains mode-dependent layers (dropout, batch norm).
was_training = model.training
model.eval()
try:
    # No gradients are needed for sampling. Without no_grad, xt would carry
    # an autograd graph chained through every one of the steps.
    with torch.no_grad():
        # reversed(range(...)) has no len(), so give tqdm the total explicitly.
        for i in tqdm(reversed(range(num_timesteps)), total=num_timesteps):
            t = torch.as_tensor(i).to(device)

            # The model receives a batch of timesteps (shape (1,)); the scheduler gets the scalar.
            noise_pred = model(xt, t.unsqueeze(0))
            xt, _ = noise_scheduler.sample_prev_timestep(xt, noise_pred, t)

            # denormalize from [-1, 1] to [0, 1]
            img = torch.clamp(xt, -1., 1.).detach().cpu()
            img = img * 0.5 + 0.5

            # tensor -> PIL image (first and only sample in the batch)
            img_pil = to_pil(img[0])

            # output path including extension
            save_path = os.path.join(output_dir, f"x_{i}.png")

            # write to disk
            img_pil.save(save_path)
            img_pil.close()
finally:
    model.train(was_training)
    img_pil.close()

1000it [00:06, 158.41it/s]
