In [3]:

import argparse
import numpy as np
import random
import torch
import torch.nn.functional as F
from torch.nn.utils import clip_grad_norm_
from torch.cuda.amp import GradScaler, autocast
import torchvision.transforms as transforms
from datasets import load_dataset
from diffusers.optimization import get_scheduler
from tqdm.auto import tqdm
# from torchinfo import summary

from simple_diffusion.scheduler import DDIMScheduler
from simple_diffusion.model import UNet
from simple_diffusion.utils import save_images, normalize_to_neg_one_to_one
from simple_diffusion.dataset import CustomDataset, get_dataset
import pandas as pd
import webdataset as wds

from simple_diffusion.ema import EMA

  from .autonotebook import tqdm as notebook_tqdm


In [5]:
import os
# Directory that holds the training images.
# NOTE(review): absolute, machine-specific path — adjust before running elsewhere.
data_path = '/Users/Sandhanakrishnan/Image_Synthesis_Diffusion/Train'

# sorted(): os.listdir returns entries in arbitrary order, which would make the
# manifest's row order differ between machines/runs.
# lower(): match the extension case-insensitively so files such as "IMG.JPG"
# are not silently skipped.
image_paths = [
    os.path.join(data_path, filename)
    for filename in sorted(os.listdir(data_path))
    if filename.lower().endswith(('.png', '.jpg', '.jpeg'))
]

# Write a one-column manifest (image_path) to train.csv in the working directory.
data_df = pd.DataFrame(image_paths, columns=['image_path'])
csv_file_path = 'train.csv'
data_df.to_csv(csv_file_path, index=False)


In [6]:
# Single seed shared by every random number generator used in this notebook.
SEED = 42

# Seed Python, NumPy and PyTorch (CPU + CUDA) in one pass, in the same order
# as before, so that a fresh kernel run starts from the same RNG state.
for _seed_fn in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed):
    _seed_fn(SEED)
del _seed_fn

# Diffusion schedule lengths: steps used for training vs. steps used when sampling.
n_timesteps, n_inference_timesteps = 1000, 250

def _grayscale_to_rgb(img):
    if img.mode != "RGB":
        return img.convert("RGB")
    return img

In [7]:
def _select_device():
    """Pick the training device: CUDA if available, else Apple MPS, else CPU."""
    if torch.cuda.is_available():
        return torch.device("cuda")
    # torch.backends.mps only exists on recent PyTorch builds, hence the getattr.
    mps_backend = getattr(torch.backends, "mps", None)
    if mps_backend is not None and mps_backend.is_available():
        return torch.device("mps")
    return torch.device("cpu")


def main(args):
    """Train the UNet noise predictor, periodically sampling and checkpointing.

    Args:
        args: settings object exposing the attributes read below (resolution,
            dataset_path, train_batch_size, eval_batch_size, num_epochs,
            gradient_accumulation_steps, learning_rate, adam_* values,
            lr_scheduler, lr_warmup_steps, use_clip_grad, fp16_precision,
            gamma, save_model_steps, pretrained_model_path, output_dir).
            ``args`` is also forwarded to ``save_images``.

    Returns:
        list[float]: the mean training loss of each epoch.

    Raises:
        ValueError: if the dataloader built from ``args.dataset_path`` yields
            no batches.
    """
    device = _select_device()
    model = UNet(3, image_size=args.resolution, hidden_dims=[64, 128, 256, 512])
    noise_scheduler = DDIMScheduler(num_train_timesteps=n_timesteps,
                                    beta_schedule="cosine")
    model = model.to(device)

    if args.pretrained_model_path:
        # NOTE(review): torch.load unpickles the file — only load trusted checkpoints.
        # map_location lets a checkpoint saved on another device load here.
        pretrained = torch.load(args.pretrained_model_path,
                                map_location=device)["model_state"]
        model.load_state_dict(pretrained)

    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=args.learning_rate,
        betas=(args.adam_beta1, args.adam_beta2),
        weight_decay=args.adam_weight_decay,
    )

    # df = pd.read_pickle(args.dataset_path)
    df = pd.read_csv(args.dataset_path)
    dataset = CustomDataset(df)

    train_dataloader = torch.utils.data.DataLoader(
        dataset, batch_size=args.train_batch_size, shuffle=True)
    steps_per_epoch = len(train_dataloader)
    if steps_per_epoch == 0:
        raise ValueError(
            f"No training batches could be built from {args.dataset_path!r}.")

    # NOTE(review): the step budget is divided by gradient_accumulation_steps,
    # but the loop below steps the optimizer and LR scheduler on every batch;
    # the two only agree when gradient_accumulation_steps == 1.
    total_num_steps = (steps_per_epoch * args.num_epochs) // args.gradient_accumulation_steps
    # Pad the schedule length by 10%.
    total_num_steps += int(total_num_steps * 10/100)
    gamma = args.gamma
    ema = EMA(model, gamma, total_num_steps)

    lr_scheduler = get_scheduler(
        args.lr_scheduler,
        optimizer=optimizer,
        num_warmup_steps=args.lr_warmup_steps,
        num_training_steps=total_num_steps,
    )

    # summary(model, [(1, 3, args.resolution, args.resolution), (1,)], verbose=1)

    # args.output_dir is the checkpoint *file* path; make sure its folder exists
    # so the first torch.save does not fail after a long stretch of training.
    checkpoint_dir = os.path.dirname(args.output_dir)
    if checkpoint_dir:
        os.makedirs(checkpoint_dir, exist_ok=True)

    scaler = GradScaler(enabled=args.fp16_precision)
    global_step = 0
    losses = []
    for epoch in range(args.num_epochs):
        losses_log = 0
        # Context manager closes the bar even if the epoch is interrupted.
        with tqdm(total=steps_per_epoch) as progress_bar:
            progress_bar.set_description(f"Epoch {epoch}")
            for step, batch in enumerate(train_dataloader):
                clean_images = batch["image"].to(device)
                clean_images = normalize_to_neg_one_to_one(clean_images)

                # Sample the target noise and one random timestep per image,
                # then let the scheduler produce the noised inputs.
                batch_size = clean_images.shape[0]
                noise = torch.randn(clean_images.shape).to(device)
                timesteps = torch.randint(0,
                                          noise_scheduler.num_train_timesteps,
                                          (batch_size,),
                                          device=device).long()
                noisy_images = noise_scheduler.add_noise(clean_images, noise,
                                                         timesteps)

                optimizer.zero_grad()
                with autocast(enabled=args.fp16_precision):
                    noise_pred = model(noisy_images, timesteps)["sample"]
                    loss = F.l1_loss(noise_pred, noise)

                scaler.scale(loss).backward()
                if args.use_clip_grad:
                    # Clipping must happen before the optimizer consumes the
                    # gradients, and on unscaled values when AMP is active.
                    scaler.unscale_(optimizer)
                    clip_grad_norm_(model.parameters(), 1.0)
                scaler.step(optimizer)
                scaler.update()

                ema.update_params(gamma)
                gamma = ema.update_gamma(global_step)

                lr_scheduler.step()

                progress_bar.update(1)
                loss_value = loss.detach().item()
                losses_log += loss_value
                logs = {
                    "loss_avg": losses_log / (step + 1),
                    "loss": loss_value,
                    "lr": lr_scheduler.get_last_lr()[0],
                    "step": global_step,
                    "gamma": gamma
                }

                progress_bar.set_postfix(**logs)
                global_step += 1

                # Generate sample images for visual inspection
                if global_step % args.save_model_steps == 0:
                    ema.ema_model.eval()
                    with torch.no_grad():
                        # A fresh, dedicated generator keeps the samples
                        # reproducible without reseeding the global RNG, which
                        # would restart the training noise/timestep stream at
                        # every checkpoint.
                        generator = torch.Generator().manual_seed(0)
                        generated_images = noise_scheduler.generate(
                            ema.ema_model,
                            num_inference_steps=n_inference_timesteps,
                            generator=generator,
                            eta=1.0,
                            use_clipped_model_output=True,
                            batch_size=args.eval_batch_size,
                            output_type="numpy")

                        save_images(generated_images, epoch, args)

                        torch.save(
                            {
                                'model_state': model.state_dict(),
                                'ema_model_state': ema.ema_model.state_dict(),
                                'optimizer_state': optimizer.state_dict(),
                            }, args.output_dir)

        losses.append(losses_log / (step + 1))

    return losses




In [13]:
class SimulationArgs:
    """Plain settings container standing in for command-line arguments.

    Every constructor keyword is stored unchanged as an attribute of the same
    name. Override only the values that differ from the defaults, e.g.
    ``SimulationArgs(train_batch_size=32, learning_rate=2e-4)``.
    """

    def __init__(self,
                 dataset_name='',
                 dataset_path='/Users/Sandhanakrishnan/Image_Synthesis_Diffusion/train.csv',
                 dataset_config_name=None,
                 output_dir="trained_models/ddpm-model-64.pth",
                 samples_dir="test_samples/",
                 loss_logs_dir="training_logs",
                 cache_dir=None,
                 resolution=64,
                 train_batch_size=1,
                 eval_batch_size=1,
                 num_epochs=1,
                 save_model_steps=100,
                 gradient_accumulation_steps=1,
                 learning_rate=1e-4,
                 lr_scheduler="cosine",
                 lr_warmup_steps=100,
                 adam_beta1=0.9,
                 adam_beta2=0.99,
                 adam_weight_decay=0.0,
                 use_clip_grad=False,
                 logging_dir="logs",
                 pretrained_model_path=None,
                 fp16_precision=False,
                 gamma=0.996):
        # Dataset selection.
        self.dataset_name = dataset_name
        self.dataset_path = dataset_path
        self.dataset_config_name = dataset_config_name
        # Output locations. NOTE: despite its name, output_dir defaults to a
        # checkpoint *file* path (".pth").
        self.output_dir = output_dir
        self.samples_dir = samples_dir
        self.loss_logs_dir = loss_logs_dir
        self.cache_dir = cache_dir
        # Image size and loop sizing.
        self.resolution = resolution
        self.train_batch_size = train_batch_size
        self.eval_batch_size = eval_batch_size
        self.num_epochs = num_epochs
        self.save_model_steps = save_model_steps
        self.gradient_accumulation_steps = gradient_accumulation_steps
        # Optimizer and learning-rate schedule.
        self.learning_rate = learning_rate
        self.lr_scheduler = lr_scheduler
        self.lr_warmup_steps = lr_warmup_steps
        self.adam_beta1 = adam_beta1
        self.adam_beta2 = adam_beta2
        self.adam_weight_decay = adam_weight_decay
        self.use_clip_grad = use_clip_grad
        # Miscellaneous.
        self.logging_dir = logging_dir
        self.pretrained_model_path = pretrained_model_path
        self.fp16_precision = fp16_precision
        self.gamma = gamma

    # Example validation method
    def validate_args(self):
        """Ensure at least one dataset source is configured.

        Raises:
            ValueError: if both ``dataset_name`` and ``dataset_path`` are
                missing. Empty strings count as missing, because the default
                ``dataset_name`` is ``''`` rather than ``None``.
        """
        if not self.dataset_name and not self.dataset_path:
            raise ValueError("You must specify either a dataset name or a dataset path.")


In [14]:
# Build the run configuration, overriding only what differs from the defaults.
args = SimulationArgs(
    dataset_name="yfcc7m",
    train_batch_size=32,
    learning_rate=2e-4,
)

# Settings are plain attributes; echo two of them, one per line.
print(args.dataset_name, args.learning_rate, sep="\n")

# Validate the configuration before handing it to the training loop.
args.validate_args()

yfcc7m
0.0002


In [15]:
# Start training with the `args` built in the previous cell. The captured output
# below comes from a run that was stopped manually (KeyboardInterrupt).
main(args)

  Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
Epoch 0:  10%|█         | 1/10 [01:40<15:07, 100.86s/it, gamma=0.996, loss=0.852, loss_avg=0.852, lr=2e-6, step=0]

KeyboardInterrupt: 