## Load Dataset

In [1]:
import os
import diffusers
import torch

import matplotlib.pyplot as plt

import numpy as np
import matplotlib.pyplot as plt

from urllib.request import urlopen 
import json 

os.environ['HF_HOME'] = "/run/media/anton/hdd/hf"

In [2]:
from diffusers import DiffusionPipeline
import PIL.Image
import numpy as np

from datasets import load_dataset

dataset = load_dataset("uoft-cs/cifar10", split="train")

  deprecate("Transformer2DModelOutput", "1.0.0", deprecation_message)


In [8]:
from torchvision import transforms

batch_size=32

preprocess = transforms.Compose(
        [
            transforms.ToTensor(),
            transforms.Normalize([0.5], [0.5]),
        ]
    )

def transform(examples):
    images = [preprocess(image.convert("RGB")) for image in examples["img"]]
    return {"images": images}

dataset.set_transform(transform)
train_dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)

## Define VAE Model

In [9]:
from diffusers.models import AutoencoderKL

model = AutoencoderKL(
    in_channels=3,  
    out_channels=3,  
    latent_channels=4,
    layers_per_block=2,  
    block_out_channels=(128, 256, 256, 256,),  # the number of output channels for each UNet block
    down_block_types=(
        "DownEncoderBlock2D",
        "DownEncoderBlock2D",
        "DownEncoderBlock2D",  # a ResNet downsampling block with spatial self-attention
        "DownEncoderBlock2D",
    ),
    up_block_types=(
        "UpDecoderBlock2D",  # a regular ResNet upsampling block
        "UpDecoderBlock2D",  # a ResNet upsampling block with spatial self-attention
        "UpDecoderBlock2D",
        "UpDecoderBlock2D",
    ),
)

In [12]:
from diffusers.optimization import get_constant_schedule
from accelerate import Accelerator

lr=1e-3
epochs=150


optimizer = torch.optim.AdamW(model.parameters(), lr=lr)

lr_scheduler = get_constant_schedule(optimizer=optimizer)

accelerator = Accelerator(mixed_precision="fp16", gradient_accumulation_steps=1)

model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
    model, optimizer, train_dataloader, lr_scheduler
)

In [15]:
batch = next(iter(train_dataloader))

In [38]:
x = batch['images']

q = model.encode(x).latent_dist

loss_kl = -0.5 * torch.sum(1 + q.logvar - q.mean.pow(2) - q.logvar.exp())

x_pred = model.decode(q.sample()).sample

loss_recon = torch.nn.functional.mse_loss(x_pred, x)

loss = loss_recon + loss_kl


In [None]:
import torch.nn.functional as F

for epoch in range(epochs):
    for batch_idx, batch in enumerate(train_dataloader):

        x = batch['images']
        
        with accelerator.accumulate(model):
            q = model.encode(x).latent_dist

            loss_kl = -0.5 * torch.sum(1 + q.logvar - q.mean.pow(2) - q.logvar.exp())
            
            x_pred = model.decode(q.sample()).sample
            
            loss_recon = F.mse_loss(x_pred, x)
            
            loss = loss_recon + loss_kl

            
            accelerator.backward(loss)

            # accelerator.clip_grad_norm_(model.parameters(), 1.0)
            
            optimizer.step()
            lr_scheduler.step()
            optimizer.zero_grad()
        
        if batch_idx % 100 == 0:
            print(f'Epoch [{epoch+1}/{epochs}], Step [{batch_idx+1}/{len(train_dataloader)}], Loss: {loss.item():.4f}')

Epoch [1/150], Step [1/1563], Loss: 1080.9672
Epoch [1/150], Step [101/1563], Loss: 0.4157
Epoch [1/150], Step [201/1563], Loss: 0.3153
Epoch [1/150], Step [301/1563], Loss: 0.2510
Epoch [1/150], Step [401/1563], Loss: 0.2995
Epoch [1/150], Step [501/1563], Loss: 0.2688
Epoch [1/150], Step [601/1563], Loss: 0.2866
Epoch [1/150], Step [701/1563], Loss: 1.3666
Epoch [1/150], Step [801/1563], Loss: 0.4619
Epoch [1/150], Step [901/1563], Loss: 0.2741
Epoch [1/150], Step [1001/1563], Loss: 0.2582
Epoch [1/150], Step [1101/1563], Loss: 0.4918
Epoch [1/150], Step [1201/1563], Loss: 0.3504
Epoch [1/150], Step [1301/1563], Loss: 0.2481
Epoch [1/150], Step [1401/1563], Loss: 0.2373
Epoch [1/150], Step [1501/1563], Loss: 0.2582
Epoch [2/150], Step [1/1563], Loss: 0.2623
Epoch [2/150], Step [101/1563], Loss: 0.2426
Epoch [2/150], Step [201/1563], Loss: 0.2919
Epoch [2/150], Step [301/1563], Loss: 0.2276
Epoch [2/150], Step [401/1563], Loss: 0.2221
Epoch [2/150], Step [501/1563], Loss: 0.2532
Epoch