### Define the model architecture using a self-attention transformer


In [1]:
import numpy as np
from PIL import Image
import torch
from torch import nn
from torch.optim import Adam
from torch.utils.data import Dataset, DataLoader
from torchvision.transforms import ToTensor
import torch.nn.functional as F
from tqdm import tqdm
from torch import Tensor
import math
import matplotlib.pyplot as plt

In [2]:
# define the diffusion model
class DiffusionModel(nn.Module):
    def __init__(self, n_steps, n_heads, n_dims, n_hidden, output_dim):
        super().__init__()

        self.n_steps = n_steps
        self.n_heads = n_heads
        self.n_dims = n_dims
        self.n_hidden = n_hidden

        self.transformer_layers = nn.ModuleList([
            nn.TransformerEncoderLayer(n_dims, n_heads, n_hidden) 
            for _ in range(n_steps)
        ])

        self.to_params = nn.Linear(n_dims, 2 * n_dims)
        self.to_output = nn.Linear(n_dims, output_dim)

    def sample_noise(self, batch_size, device):
        return torch.randn(batch_size, self.n_dims, device=device)

    def forward(self, x, timesteps_left):
        # apply the diffusion process
        for i in reversed(range(self.n_steps)):
            # sample noise for this step
            noise = self.sample_noise(x.shape[0], x.device)

            # get the parameters for this step
            params = self.to_params(x)
            mean, log_std = params.chunk(2, dim=-1)

            # calculate the new state
            std = torch.exp(log_std)
            state = (x - mean) / std
            state = state + noise * std
            state = self.transformer_layers[i](state.transpose(0, 1), None).transpose(0, 1)

            # update x with the new state
            x = mean + state * std

            # apply the schedule
            x = x * (1 - timesteps_left[i]) + noise * timesteps_left[i].sqrt()

        # map the final state to RGB values
        output = self.to_output(x)
        return output

In [4]:
class ColoredPiDataset(Dataset):
    def __init__(self, image_path, xs_path, ys_path):
        self.xs = np.load(xs_path)
        self.ys = np.load(ys_path)
        self.image_array = np.array(Image.open(image_path))
        self.rgb_values = self.image_array[self.xs, self.ys]
        
        # Normalize xy values to be between 0 and 1
        self.xs, self.ys = self.xs / 299.0, self.ys / 299.0

        # Normalize rgb values to be between 0 and 1
        self.rgb_values = self.rgb_values / 255.0
        
        # # Normalize rgb values to be between -1 and 1
        # self.rgb_values = (self.rgb_values / 127.5) - 1

    def __len__(self):
        return len(self.xs)
        # return 30000

    def __getitem__(self, idx):
        if idx >= 5000:
            return torch.zeros((5)).to(torch.float32)
        return torch.tensor([self.xs[idx], self.ys[idx], self.rgb_values[idx][0], self.rgb_values[idx][1], self.rgb_values[idx][2]]).to(torch.float32)

# Define training function
def train_diffusion(model, optimizer, criterion, dataloader, device):
    model.train()
    running_loss = 0.0
    for batch in dataloader:
        batch = batch.to(device)
        optimizer.zero_grad()
        
        noise = torch.randn(batch.shape[0], 128).to(device)
        
        # generate the timesteps_left schedule
        timesteps_left = torch.linspace(0, 1, model.n_steps, device=device)
        timesteps_left = timesteps_left.expand(batch_size, -1)
        print(timesteps_left.shape)
        # generate the RGB values using the model
        outputs = model(noise, timesteps_left)

        loss = criterion(outputs, batch)

        loss.backward()
        optimizer.step()
        running_loss += loss.item() * batch.size(0)
    epoch_loss = running_loss / len(dataloader.dataset)
    return epoch_loss

# Set up device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Define hyperparameters
input_dim = 5 # XYRGB values
output_dim = 5 # XYRGB values
hidden_dim = 128
latent_dim = 16
num_layers = 2
num_heads = 4
dropout = 0.1

batch_size = 128
learning_rate = 3e-4
num_epochs = 10
num_samples = 500

# Load the dataset
dataset = ColoredPiDataset('sparse_pi_colored.jpg', 'pi_xs.npy', 'pi_ys.npy')
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, drop_last=True)

# Initialize model, optimizer, and loss function
model = DiffusionModel(n_steps=1000, n_heads=4, n_dims=128, n_hidden=512, output_dim=output_dim).to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
criterion = nn.MSELoss()

# Train model
iteration = tqdm(range(num_epochs))
for epoch in iteration:
    train_loss = train_diffusion(model, optimizer, criterion, dataloader, device)
    iteration.set_description('Epoch [{}/{}], Train Loss: {:.4f}'.format(epoch+1, num_epochs, train_loss))
    
# Generate some samples from the model
generated_image = np.zeros(dataset.image_array.shape)

# xy  = np.zeros((num_samples*batch_size, 2))
# rgb = np.zeros((num_samples*batch_size, 3))
# sample_iter = tqdm(range(num_samples))
# for sample_idx in sample_iter:

xy  = np.zeros((len(dataloader)*batch_size, 2))
rgb = np.zeros((len(dataloader)*batch_size, 3))
for sample_idx, batch in enumerate(dataloader):
    with torch.no_grad():
        # samples = model(torch.randn(batch_size, 5).to(device))
        # samples, _, _ = model(torch.randn(batch_size, 5).to(device))
        samples, _, _ = model(batch.to(device))
        # samples = model.decode(torch.randn(batch_size, latent_dim).to(device))

        # Denomarlizing samples
        # samples[:, :2] = (samples[:, :2] + 1) * 149.5
        samples[:, :2] = (samples[:, :2]) * 299
        
        # Denomarlizing samples
        # samples[:, 2:] = (samples[:, 2:] + 1) * 127.5
        samples[:, 2:] = (samples[:, 2:]) * 255
        
        xy[sample_idx*batch_size:(sample_idx+1)*batch_size, :] = samples[:, :2].cpu().numpy()
        rgb[sample_idx*batch_size:(sample_idx+1)*batch_size, :] = samples[:, 2:].cpu().numpy()

        samples = samples.cpu().numpy().astype(np.uint8)
        for i in range(batch_size):
            x, y, r, g, b = samples[i]
            generated_image[x, y] = [r, g, b]
            
print(f'xy mean: {np.mean(xy)}, xy std: {np.std(xy)}, xy max: {np.max(xy)}, xy min: {np.min(xy)}')
print(f'rgb mean: {np.mean(rgb)}, rgb std: {np.std(rgb)}, rgb max: {np.max(rgb)}, rgb min: {np.min(rgb)}')
print(f'Error: {np.mean(np.abs(generated_image - dataset.image_array))}')

# Save the output image
# Image.fromarray(generated_image).save('generated_pi_colored.jpg')
plt.imshow(generated_image)

  0%|          | 0/10 [00:00<?, ?it/s]

torch.Size([128, 1000])





ValueError: not enough values to unpack (expected 3, got 2)