Skip to content

Commit

Permalink
Merge ea3e7c2 into 18d6b96
Browse files Browse the repository at this point in the history
  • Loading branch information
sgbaird committed Jun 10, 2022
2 parents 18d6b96 + ea3e7c2 commit fa79c17
Showing 1 changed file with 7 additions and 13 deletions.
20 changes: 7 additions & 13 deletions notebooks/ddpm.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from os import path
from uuid import uuid4

from denoising_diffusion_pytorch import GaussianDiffusion, Trainer, Unet
from mp_time_split.core import MPTimeSplit
Expand All @@ -13,37 +14,30 @@

data_path = path.join("data", "preprocessed", "mp-time-split")
xc = XtalConverter(save_dir=data_path)
xc.xtal2png(train_inputs.tolist())

model = Unet(dim=64, dim_mults=(1, 2, 4, 8), channels=1).cuda()

diffusion = GaussianDiffusion(
model,
channels=1,
image_size=64,
timesteps=10000, # number of steps
loss_type="l1", # L1 or L2
model, channels=1, image_size=64, timesteps=1000, loss_type="l1"
).cuda()

trainer = Trainer(
diffusion,
data_path,
image_size=64,
train_batch_size=2,
train_batch_size=32,
train_lr=2e-5,
train_num_steps=700000, # total training steps
gradient_accumulate_every=2, # gradient accumulation steps
ema_decay=0.995, # exponential moving average decay
amp=True, # turn on mixed precision
augment_horizontal_flip=False,
results_folder=path.join("results", str(uuid4())[0:4]),
)

trainer.train()

sampled_images = diffusion.sample(batch_size=100)

# import numpy as np
# from PIL import Image
# data = np.squeeze(sampled_images.cpu().numpy())
# imgs = []
# for d in data:
# img = Image.fromarray(d, mode="L")
# imgs.append(img)
1 + 1

0 comments on commit fa79c17

Please sign in to comment.