## Load the dataset and split it into training and validation sets

In [6]:
from image_datasets.IXIdataset import IXIDataset
from sklearn.model_selection import train_test_split

# Dataset location (cluster-specific absolute path; adjust for your environment).
root = "/lustre/fswork/projects/rech/krk/usy14zi/datasets/IXI-dataset/size64/"
# Absolute number of examples held out for validation.
VAL_SIZE = 41
# Seed so the train/validation split is identical on every run.
SPLIT_SEED = 42

dataset = IXIDataset(root, mode="train")

# Split the dataset loaded above into train / validation subsets.
train_set, val_set = train_test_split(dataset, test_size=VAL_SIZE, random_state=SPLIT_SEED)

# Each item is a dict; its 'T1' entry unpacks as (channels, height, width).
# Only the height is kept as `image_size`, i.e. images are assumed square.
channels, image_size, _ = dataset[0]['T1'].shape
print(f"Image of size {image_size}, with {channels} channel(s).")

Image of size 64, with 1 channel(s).


In [7]:
import torch
from torch.utils.data import DataLoader

batch_size = 20
# Seed for the training-loader shuffle so the batch order is reproducible.
SHUFFLE_SEED = 0

# Reshuffle training data every epoch so mini-batches are not always presented
# in the same fixed order. Validation order does not matter for evaluation,
# so that loader stays deterministic and unshuffled.
train_loader = DataLoader(
    train_set,
    batch_size=batch_size,
    shuffle=True,
    generator=torch.Generator().manual_seed(SHUFFLE_SEED),
)
val_loader = DataLoader(
    val_set,
    batch_size=batch_size,
    shuffle=False,
)

## Initialize denoising model

In [8]:
import torch
from model.unet import Unet
from torchinfo import summary

# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Denoising U-Net with per-stage channel multipliers (1, 2, 4).
# NOTE(review): `dim` receives the spatial image size here; in common U-Net
# implementations `dim` is the base feature width instead. Confirm this is intended.
model = Unet(
    dim=image_size,
    channels=channels,
    dim_mults=(1, 2, 4),
)
model.to(device)

# Layer-by-layer parameter table, rendered as the cell output.
summary(model)

Layer (type:depth-idx)                                       Param #
Unet                                                         --
├─Conv2d: 1-1                                                128
├─Sequential: 1-2                                            --
│    └─SinusoidalPositionEmbeddings: 2-1                     --
│    └─Linear: 2-2                                           16,640
│    └─GELU: 2-3                                             --
│    └─Linear: 2-4                                           65,792
├─ModuleList: 1-3                                            --
│    └─ModuleList: 2-5                                       --
│    │    └─ResnetBlock: 3-1                                 107,008
│    │    └─ResnetBlock: 3-2                                 107,008
│    │    └─Residual: 3-3                                    33,088
│    │    └─Sequential: 3-4                                  16,448
│    └─ModuleList: 2-6                                       --
│    │  

## Initialize the diffusion process (beta schedule and DDPM wrapper)

In [9]:
from diffusion.time_scheduler import quadratic_beta_schedule

# Endpoints of the quadratic beta schedule and the number of diffusion steps.
BETA_START, BETA_END = 1e-5, 0.01
timesteps = 600

betas = quadratic_beta_schedule(
    timesteps,
    beta_start=BETA_START,
    beta_end=BETA_END,
)

In [10]:
from diffusion.diffusion import DiffusionModel

# Bundle the denoiser, the number of steps, the beta schedule and the target
# device into the project's DiffusionModel; 'l2' selects its loss type.
ddpm = DiffusionModel(
    model,
    timesteps,
    betas,
    device,
    loss_type='l2',
)

## Train the model

In [23]:
from torch.optim import Adam

# Training hyper-parameters.
learning_rate = 3e-4
epochs = 100

# Adam over every parameter of the denoising U-Net.
optimizer = Adam(model.parameters(), lr=learning_rate)

In [None]:
# Run the project's training loop (DiffusionModel.train) for `epochs` epochs,
# optimizing with `optimizer` over batches from `train_loader`.
# NOTE(review): `val_loader` is presumably used for validation, and the logged
# "Loss:" value could be the training or the validation loss. Neither is
# visible here; confirm in diffusion.diffusion.
ddpm.train(epochs, optimizer, train_loader, val_loader)

Epoch 0: 100%|##########| 24/24 [00:13<00:00,  1.84it/s]


Loss: 0.762143088504672


Epoch 1: 100%|##########| 24/24 [00:03<00:00,  7.85it/s]


Loss: 0.3397834428275625


Epoch 2: 100%|##########| 24/24 [00:02<00:00,  8.21it/s]


Loss: 0.263161089271307


Epoch 3: 100%|##########| 24/24 [00:02<00:00,  8.93it/s]


Loss: 0.20557330114146075


Epoch 4: 100%|##########| 24/24 [00:06<00:00,  3.47it/s]


Loss: 0.18417405057698488


Epoch 5: 100%|##########| 24/24 [00:04<00:00,  5.75it/s]


Loss: 0.16969719265277186


Epoch 6: 100%|##########| 24/24 [00:02<00:00,  8.70it/s]


Loss: 0.15497096876303354


Epoch 7: 100%|##########| 24/24 [00:02<00:00,  8.88it/s]


Loss: 0.13406257859120765


Epoch 8: 100%|##########| 24/24 [00:02<00:00,  8.29it/s]


Loss: 0.13333342876285315


Epoch 9: 100%|##########| 24/24 [00:02<00:00,  8.93it/s]


Loss: 0.13696973553548256


Epoch 10: 100%|##########| 24/24 [00:02<00:00,  8.77it/s]


Loss: 0.12214901794989903


Epoch 11: 100%|##########| 24/24 [00:02<00:00,  8.94it/s]


Loss: 0.11969653004780412


Epoch 12: 100%|##########| 24/24 [00:02<00:00,  8.95it/s]


Loss: 0.11212637081431846


Epoch 13: 100%|##########| 24/24 [00:02<00:00,  8.47it/s]


Loss: 0.1098571087544163


Epoch 14: 100%|##########| 24/24 [00:02<00:00,  8.93it/s]


Loss: 0.12708501672993103


Epoch 15: 100%|##########| 24/24 [00:02<00:00,  8.94it/s]


Loss: 0.12591202619175115


Epoch 16: 100%|##########| 24/24 [00:02<00:00,  8.95it/s]


Loss: 0.11380932449052732


Epoch 17: 100%|##########| 24/24 [00:02<00:00,  8.96it/s]


Loss: 0.11256256699562073


Epoch 18: 100%|##########| 24/24 [00:02<00:00,  8.42it/s]


Loss: 0.11997802276164293


Epoch 19: 100%|##########| 24/24 [00:02<00:00,  8.51it/s]


Loss: 0.10555162808547418


Epoch 20: 100%|##########| 24/24 [00:02<00:00,  8.98it/s]


Loss: 0.10047798782276611


Epoch 21: 100%|##########| 24/24 [00:02<00:00,  8.96it/s]


Loss: 0.10933972196653485


Epoch 22: 100%|##########| 24/24 [00:02<00:00,  8.95it/s]


Loss: 0.09490908356383443


Epoch 23: 100%|##########| 24/24 [00:02<00:00,  8.94it/s]


Loss: 0.10209509000803034


Epoch 24: 100%|##########| 24/24 [00:02<00:00,  8.98it/s]


Loss: 0.10459820770968993


Epoch 25: 100%|##########| 24/24 [00:02<00:00,  8.52it/s]


Loss: 0.10426626028493047


Epoch 26: 100%|##########| 24/24 [00:02<00:00,  8.85it/s]


Loss: 0.08738992960813145


Epoch 27: 100%|##########| 24/24 [00:02<00:00,  8.72it/s]


Loss: 0.08279940447149177


Epoch 28: 100%|##########| 24/24 [00:02<00:00,  8.91it/s]


Loss: 0.09846481719675164


Epoch 29: 100%|##########| 24/24 [00:02<00:00,  8.88it/s]


Loss: 0.08694962396596868


Epoch 30: 100%|##########| 24/24 [00:02<00:00,  8.84it/s]


Loss: 0.09911422757431865


Epoch 31: 100%|##########| 24/24 [00:02<00:00,  8.53it/s]


Loss: 0.09538753610104322


Epoch 32: 100%|##########| 24/24 [00:02<00:00,  9.00it/s]


Loss: 0.08463193802163005


Epoch 33: 100%|##########| 24/24 [00:02<00:00,  8.98it/s]


Loss: 0.0838304867502302


Epoch 34: 100%|##########| 24/24 [00:02<00:00,  8.94it/s]


Loss: 0.09142405989890297


Epoch 35: 100%|##########| 24/24 [00:02<00:00,  8.97it/s]


Loss: 0.09121485892683268


Epoch 36: 100%|##########| 24/24 [00:02<00:00,  8.95it/s]


Loss: 0.0855864838231355


Epoch 37: 100%|##########| 24/24 [00:02<00:00,  8.50it/s]


Loss: 0.08691422268748283


Epoch 38: 100%|##########| 24/24 [00:02<00:00,  8.99it/s]


Loss: 0.09875167502711217


Epoch 39: 100%|##########| 24/24 [00:02<00:00,  9.00it/s]


Loss: 0.08551171142607927


Epoch 40: 100%|##########| 24/24 [00:02<00:00,  8.84it/s]


Loss: 0.07622081336254875


Epoch 41: 100%|##########| 24/24 [00:02<00:00,  8.98it/s]


Loss: 0.08891753215963642


Epoch 42: 100%|##########| 24/24 [00:02<00:00,  8.51it/s]


Loss: 0.08875541125113766


Epoch 43: 100%|##########| 24/24 [00:02<00:00,  8.95it/s]


Loss: 0.08907305317309995


Epoch 44: 100%|##########| 24/24 [00:02<00:00,  8.91it/s]


Loss: 0.0814302380507191


Epoch 45: 100%|##########| 24/24 [00:02<00:00,  8.92it/s]


Loss: 0.09018850357582171


Epoch 46: 100%|##########| 24/24 [00:02<00:00,  8.75it/s]


Loss: 0.07894146861508489


Epoch 47: 100%|##########| 24/24 [00:02<00:00,  8.93it/s]


Loss: 0.08515355441098411


Epoch 48: 100%|##########| 24/24 [00:02<00:00,  8.33it/s]


Loss: 0.08781236553719889


Epoch 49: 100%|##########| 24/24 [00:02<00:00,  8.95it/s]


Loss: 0.0906438158514599


Epoch 50: 100%|##########| 24/24 [00:05<00:00,  4.47it/s]


Loss: 0.07980929951493938


Epoch 51: 100%|##########| 24/24 [00:15<00:00,  1.54it/s]


Loss: 0.08211805469666918


Epoch 52: 100%|##########| 24/24 [00:06<00:00,  3.66it/s]


Loss: 0.08989225110659997


Epoch 53: 100%|##########| 24/24 [00:05<00:00,  4.29it/s]


Loss: 0.0797460990337034


Epoch 54: 100%|##########| 24/24 [00:09<00:00,  2.54it/s]


Loss: 0.07904892476896445


Epoch 55: 100%|##########| 24/24 [00:03<00:00,  6.86it/s]


Loss: 0.08948487012336652


Epoch 56: 100%|##########| 24/24 [00:05<00:00,  4.04it/s]


Loss: 0.08720035835479696


Epoch 57: 100%|##########| 24/24 [00:05<00:00,  4.25it/s]


Loss: 0.08184156939387321


Epoch 58: 100%|##########| 24/24 [00:06<00:00,  3.92it/s]


Loss: 0.08294311786691348


Epoch 59: 100%|##########| 24/24 [00:05<00:00,  4.62it/s]


Loss: 0.07414915210877855


Epoch 60: 100%|##########| 24/24 [00:05<00:00,  4.14it/s]


Loss: 0.08270055195316672


Epoch 61: 100%|##########| 24/24 [00:06<00:00,  3.80it/s]


Loss: 0.08875709911808372


Epoch 62: 100%|##########| 24/24 [00:06<00:00,  3.77it/s]


Loss: 0.08041661302559078


Epoch 63:  83%|########3 | 20/24 [00:05<00:01,  2.87it/s]

## Sample

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from diffusion.sampler import sample

# Generate a batch of 64 images. `samples` is presumably one batch per
# denoising step; samples[-1] is treated as the final (denoised) batch.
samples = sample(ddpm, image_size=image_size, batch_size=64, channels=channels)

# Show every 5th generated image (indices 5, 10, ..., 50) side by side.
display_indexes = list(range(5, 55, 5))


def to_hwc(img):
    """Convert one channel-first (C, H, W) image into (H, W, C) for imshow.

    A plain reshape to (H, W, C) does not move the channel axis; it only gives
    a correct picture when C == 1. Reshape to (C, H, W), then transpose.
    """
    return np.asarray(img).reshape(channels, image_size, image_size).transpose(1, 2, 0)


fig, axs = plt.subplots(1, len(display_indexes), figsize=(20, 8))
for ax, idx in zip(axs, display_indexes):
    ax.axis("off")
    ax.set_title(f"Image {idx}")
    ax.imshow(to_hwc(samples[-1][idx]), cmap="gray")
plt.show()

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation

# Which image of the sampled batch to follow through the denoising steps.
random_index = 5

fig, ax = plt.subplots()
frames = []
# Iterate over the stored steps directly rather than range(timesteps), so the
# loop cannot index past the end of `samples`.
for step, batch in enumerate(samples):
    # Channel-first (C, H, W) -> (H, W, C); a bare reshape would scramble C > 1.
    img = np.asarray(batch[random_index]).reshape(channels, image_size, image_size).transpose(1, 2, 0)
    im = ax.imshow(img, cmap="gray", animated=True)
    # plt.title() would overwrite one shared title on every iteration, so all
    # frames would show the last label. A per-frame Text artist changes with
    # the frame instead. It sits inside the axes so blitting redraws it.
    label = ax.text(
        0.5, 0.97, f"T = {step}",
        transform=ax.transAxes, ha="center", va="top",
        color="white", animated=True,
    )
    frames.append([im, label])

animate = animation.ArtistAnimation(fig, frames, interval=50, blit=True, repeat_delay=1000)
# Pillow writes GIFs without needing ffmpeg; the filename follows the image size.
animate.save(f"diffusion{image_size}x{image_size}.gif", writer="pillow")
plt.show()