In [7]:
from dataclasses import dataclass

@dataclass
class TrainingConfig:
    image_size = 128  # the generated image resolution
    train_batch_size = 256
    eval_batch_size = 256  # how many images to sample during evaluation
    num_epochs = 500
    patience = 20  # early stopping patience
    gradient_accumulation_steps = 1
    learning_rate = 1e-4
    lr_warmup_steps = 500
    save_image_epochs = 10
    save_model_epochs = 30
    output_dir = "test-model-1"  # the model name locally and on the HF Hub



config = TrainingConfig()

In [8]:
import torch
from torch.utils.data import Dataset, ConcatDataset
import numpy as np
from tqdm import tqdm

class GaussianDataset(Dataset):
    def __init__(self, num_samples, input_dim, cov, samples_per_obs):
        self.num_samples = num_samples
        self.input_dim = input_dim
        self.cov = cov

        self.o = []
        self.a = []

        if (samples_per_obs * num_samples) >= 1000000:
            # Generate input data (o)
            for _ in tqdm(range(num_samples)):
                obs = np.random.normal(size=(input_dim))

                for _ in range(samples_per_obs):

                    # Generate output data (a) using the Gaussian distribution
                    act = np.random.multivariate_normal(mean=obs, cov=self.cov)

                    self.o.append(obs)
                    self.a.append(act)
        else:
            # Generate input data (o)
            for _ in range(num_samples):
                obs = np.random.normal(size=(input_dim))

                for _ in range(samples_per_obs):

                    # Generate output data (a) using the Gaussian distribution
                    act = np.random.multivariate_normal(mean=obs, cov=self.cov)

                    self.o.append(obs)
                    self.a.append(act)

        print(f"Generated {len(self.o)} samples")


    def __len__(self):
        return len(self.o)
    
    def get_observations(self):
        return self.o

    def __getitem__(self, idx):
        o = torch.tensor(self.o[idx], dtype=torch.float32)
        a = torch.tensor(self.a[idx], dtype=torch.float32)
        return o, a

In [9]:
# Set parameters
real_samples = 100
sim_saples = 10000
action_dim = 2
cov = np.array([[0.01, 0.005], [0.005, 0.01]])
sim_cov = cov + np.random.normal(0, 0.001, cov.shape)
samples_per_obs = 10

# Create dataset
real_dataset = GaussianDataset(real_samples, action_dim, cov, samples_per_obs)

sim_dataset = GaussianDataset(sim_saples, action_dim, sim_cov, samples_per_obs)

val_dataset = GaussianDataset(100, action_dim, cov, samples_per_obs)

test_dataset = GaussianDataset(1000, action_dim, cov, samples_per_obs)

# Concatenate datasets
dataset = ConcatDataset([real_dataset, sim_dataset])

# Create data loader
batch_size = config.train_batch_size
data_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)

val_data_loader = torch.utils.data.DataLoader(val_dataset, batch_size=batch_size, shuffle=True)

test_data_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=True)

for step, batch in enumerate(data_loader):
    o, a = batch
    print(o.shape, a.shape)
    break

# Construct validation dataset


Generated 1000 samples


  act = np.random.multivariate_normal(mean=obs, cov=self.cov)


Generated 100000 samples
Generated 1000 samples
Generated 10000 samples
torch.Size([256, 2]) torch.Size([256, 2])


In [10]:
import torch.nn as nn
from diffusers import UNet2DModel

import torch
import torch.nn as nn
import torch.nn.functional as F

class ClassConditionedMLP(nn.Module):
    def __init__(self, action_dim=2, class_emb_size=2, hidden_dim=64, num_layers=3):
        super().__init__()
        
        # Input dimension: (action_dim + class_emb_size) + 1 (timestep)
        input_dim = action_dim * 1 + class_emb_size * 1 + 1
        output_dim = action_dim * 1  # Matches the output shape of the UNet
        
        # Define a simple MLP architecture
        layers = [nn.Linear(input_dim, hidden_dim), nn.ReLU()]
        for _ in range(num_layers - 1):
            layers.extend([nn.Linear(hidden_dim, hidden_dim), nn.ReLU()])
        layers.append(nn.Linear(hidden_dim, output_dim))
        
        self.mlp = nn.Sequential(*layers)
    
    def forward(self, obs, act, t):
        # Flatten input dimensions and prepare input for MLP
        act = torch.tensor(act).unsqueeze(1).unsqueeze(-1)  # Convert to tensor and shape (bs, 1, 2, 1)
        bs, ch, w, h = act.shape
        act_flat = act.view(bs, -1)  # Flatten to (bs, action_dim)
        
        class_cond = torch.tensor(obs)  # Convert to tensor (bs, class_emb_size)
        t = torch.tensor(t).unsqueeze(-1)  # Convert timestep to tensor (bs, 1)
        
        # Concatenate action, class conditioning, and timestep
        mlp_input = torch.cat([act_flat, class_cond, t], dim=1)  # Shape: (bs, input_dim)
        
        # Pass through MLP
        output = self.mlp(mlp_input)  # Shape: (bs, output_dim)
        
        # Reshape to match UNet's output: (bs, 1, action_dim, 1)
        return output.view(bs, 1, w, h)


class ClassConditionedUnet(nn.Module):
    def __init__(self, action_dim=2, class_emb_size=2):
        super().__init__()

        # Self.model is an unconditional UNet with extra input channels to accept the conditioning information (the class embedding)
        self.model = UNet2DModel(
            sample_size=(action_dim, 1),  # the target image resolution
            in_channels=1 + class_emb_size,  # Additional input channels for class cond.
            out_channels=1,  # the number of output channels
            block_out_channels=(32,),
            down_block_types=(
                "DownBlock2D",  # a regular ResNet downsampling block
            ),
            up_block_types=(
                "UpBlock2D",  # a regular ResNet upsampling block
            ),
        )

    # Our forward method now takes the class labels as an additional argument
    def forward(self, obs, act, t):
        act = torch.Tensor(act).unsqueeze(1).unsqueeze(-1)  # Convert to tensor and move to device
        # Shape of x:
        bs, ch, w, h = act.shape

        # class conditioning in right shape to add as additional input channels
        class_cond = torch.Tensor(obs) # Convert to tensor and move to device
        class_cond = class_cond.view(bs, class_cond.shape[1], 1, 1).expand(bs, class_cond.shape[1], w, h)
        # x is shape (bs, 1, 2, 1) and class_cond is now (bs, 2, 2, 1)

        # Net input is now x and class cond concatenated together along dimension 1
        net_input = torch.cat((act, class_cond), 1)  # (bs, 3, 2, 1)

        # Feed this to the UNet alongside the timestep and return the prediction
        return self.model(net_input, t).sample  # (bs, 1, 2, 1)

# Initialize the MLP with the desired action dimension
model = ClassConditionedMLP(action_dim=action_dim)

# Print the MLP architecture
print(model)

ClassConditionedMLP(
  (mlp): Sequential(
    (0): Linear(in_features=5, out_features=64, bias=True)
    (1): ReLU()
    (2): Linear(in_features=64, out_features=64, bias=True)
    (3): ReLU()
    (4): Linear(in_features=64, out_features=64, bias=True)
    (5): ReLU()
    (6): Linear(in_features=64, out_features=2, bias=True)
  )
)


In [41]:
o, a = real_dataset[0]
sample_inp = torch.cat([o, a, torch.tensor([0])])
print("Input shape:", sample_inp.shape)

num_weights = sum(p.numel() for p in model.parameters())
print(f"Number of weights: {num_weights}")

# Forward pass
#output = model(o, a, torch.tensor([0]))
#print("Output shape:", output.shape)


Input shape: torch.Size([5])
Number of weights: 8834


In [42]:
from diffusers import DDPMScheduler
import torch.nn.functional as F

noise_scheduler = DDPMScheduler(num_train_timesteps=1000)
noise = torch.randn(a.shape)
timesteps = torch.LongTensor([50])
noisy_image = noise_scheduler.add_noise(a, noise, timesteps)

observation = torch.tensor([0.0, 0.0])

print("Noisy image shape:", noisy_image.shape)
print("Noisy image:", noisy_image)


#noise_pred = model(noisy_image, observation, timesteps)
#loss = F.mse_loss(noise_pred, noise)

Noisy image shape: torch.Size([2])
Noisy image: tensor([ 0.8982, -0.5868])


In [43]:
from diffusers.optimization import get_cosine_schedule_with_warmup

optimizer = torch.optim.AdamW(model.parameters(), lr=config.learning_rate)
lr_scheduler = get_cosine_schedule_with_warmup(
    optimizer=optimizer,
    num_warmup_steps=config.lr_warmup_steps,
    num_training_steps=(len(data_loader) * config.num_epochs),
)

In [44]:
from accelerate import Accelerator
from huggingface_hub import create_repo, upload_folder
from tqdm.auto import tqdm
from pathlib import Path
import os
from diffusers import DDPMPipeline
from diffusers.utils import make_image_grid
import os
import wandb
import torch
import torch.nn.functional as F


wandb.init(project="diffusion-toy", entity="rohanb27-csail")


def train_loop(config, model, noise_scheduler, optimizer, train_dataloader, val_dataloader, lr_scheduler):
    # Initialize accelerator and tensorboard logging
    accelerator = Accelerator(
        gradient_accumulation_steps=config.gradient_accumulation_steps,
        log_with="wandb",
        project_dir=os.path.join(config.output_dir, "logs"),
    )
    if accelerator.is_main_process:
        if config.output_dir is not None:
            os.makedirs(config.output_dir, exist_ok=True)

    # Prepare everything
    # There is no specific order to remember, you just need to unpack the
    # objects in the same order you gave them to the prepare method.
    model, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare(
        model, optimizer, train_dataloader, val_dataloader, lr_scheduler
    )

    global_step = 0

    epochs_since_improvement = 0

    prev_val_loss = float("inf")

    # Now you train the model
    for epoch in range(config.num_epochs):
        progress_bar = tqdm(total=len(train_dataloader), disable=not accelerator.is_local_main_process)
        progress_bar.set_description(f"Epoch {epoch}")

        for step, batch in enumerate(train_dataloader):
            # Unpack the batch
            obs, act = batch
            # Concatenate observation and action
            inputs = torch.cat((obs, act), dim=1)
            bs = inputs.shape[0]

            # Sample noise to add to the actions
            noise = torch.randn(act.shape, device=act.device)

            # Sample a random timestep for each action
            timesteps = torch.randint(
                0, noise_scheduler.config.num_train_timesteps, (bs,), device=act.device,
                dtype=torch.int64
            )

            # Add noise to the clean actions according to the noise magnitude at each timestep
            # (this is the forward diffusion process)
            noisy_actions = noise_scheduler.add_noise(act, noise, timesteps)
            
            with accelerator.accumulate(model):
                # Predict the noise residual
                noise_pred = model(obs, noisy_actions, timesteps).squeeze()

                loss = F.mse_loss(noise_pred, noise)
                accelerator.backward(loss)

                if accelerator.sync_gradients:
                    accelerator.clip_grad_norm_(model.parameters(), 1.0)
                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad()

            progress_bar.update(1)
            logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0], "step": global_step}
            progress_bar.set_postfix(**logs)
            accelerator.log(logs, step=global_step)
            global_step += 1

        # Validate after each epoch
        model.eval()
        val_loss = 0
        with torch.no_grad():
            for batch in val_dataloader:
                obs, act = batch
                inputs = torch.cat((obs, act), dim=1)
                bs = inputs.shape[0]

                noise = torch.randn(act.shape, device=act.device)
                timesteps = torch.randint(
                    0, noise_scheduler.config.num_train_timesteps, (bs,), device=act.device,
                    dtype=torch.int64
                )
                noisy_actions = noise_scheduler.add_noise(act, noise, timesteps)
                noise_pred = model(obs, noisy_actions, timesteps).squeeze()
                loss = F.mse_loss(noise_pred, noise)
                val_loss += loss.item()

        val_loss /= len(val_dataloader)
        accelerator.log({"val_loss": val_loss}, step=global_step)

        if val_loss < prev_val_loss:
            prev_val_loss = val_loss
            epochs_since_improvement = 0
        else:
            epochs_since_improvement += 1

        if epochs_since_improvement >= config.patience:
            print(f"Early stopping at epoch {epoch}")
            break

        # Save the model after each epoch
        if accelerator.is_main_process and epoch % config.save_model_epochs == 0:
            ...
            #pipeline = DDPMPipeline(unet=accelerator.unwrap_model(model.model), scheduler=noise_scheduler)
            #pipeline.save_pretrained(config.output_dir)

        model.train()

In [45]:
from accelerate import notebook_launcher
args = (config, model, noise_scheduler, optimizer, data_loader, val_data_loader, lr_scheduler)
notebook_launcher(train_loop, args, num_processes=1)

Launching training on one GPU.


  act = torch.tensor(act).unsqueeze(1).unsqueeze(-1)  # Convert to tensor and shape (bs, 1, 2, 1)
  class_cond = torch.tensor(obs)  # Convert to tensor (bs, class_emb_size)
  t = torch.tensor(t).unsqueeze(-1)  # Convert timestep to tensor (bs, 1)
Epoch 0: 100%|██████████| 395/395 [00:02<00:00, 140.44it/s, loss=0.895, lr=7.9e-5, step=394] 
Epoch 1: 100%|██████████| 395/395 [00:03<00:00, 121.04it/s, loss=1.1, lr=0.0001, step=789]
Epoch 2: 100%|██████████| 395/395 [00:03<00:00, 129.60it/s, loss=0.782, lr=0.0001, step=1184]
Epoch 3: 100%|██████████| 395/395 [00:03<00:00, 116.60it/s, loss=0.793, lr=0.0001, step=1579]
Epoch 4: 100%|██████████| 395/395 [00:02<00:00, 139.32it/s, loss=0.832, lr=0.0001, step=1974]
Epoch 5: 100%|██████████| 395/395 [00:03<00:00, 113.43it/s, loss=0.682, lr=0.0001, step=2369]
Epoch 6: 100%|██████████| 395/395 [00:02<00:00, 138.03it/s, loss=0.557, lr=0.0001, step=2764]
Epoch 7: 100%|██████████| 395/395 [00:03<00:00, 120.14it/s, loss=0.478, lr=0.0001, step=3159]
Epoc

Early stopping at epoch 93





In [46]:
torch.save(model, "saved_models/MLP-model.pth")

In [58]:
print(obs)
print(model)
print(act_desired)

tensor([[-2.8121,  1.9047],
        [-2.8121,  1.9047]], device='cuda:0')
ClassConditionedMLP(
  (mlp): Sequential(
    (0): Linear(in_features=5, out_features=64, bias=True)
    (1): ReLU()
    (2): Linear(in_features=64, out_features=64, bias=True)
    (3): ReLU()
    (4): Linear(in_features=64, out_features=64, bias=True)
    (5): ReLU()
    (6): Linear(in_features=64, out_features=2, bias=True)
  )
)
tensor([[-2.8110,  1.8979],
        [-2.6461,  1.8752]])


In [39]:
# @markdown Sampling some different digits:
import torch
from diffusers import DDPMScheduler

device = "cuda" if torch.cuda.is_available() else "cpu"
device = "cpu"
model = torch.load("saved_models/MLP-model.pth").to(device)
noise_scheduler = DDPMScheduler(num_train_timesteps=1000)
model.eval()

# Prepare random x to start from, plus some desired labels y
# obs, act_desired = sim_dataset[0:2]

obs = torch.tensor([[-2.8121,  1.9047],
        [-2.8121,  1.9047], [-2.8121,  1.9047]])
act_desired = torch.tensor([[-2.8110,  1.8979],
        [-2.6461,  1.8752]])

act = torch.randn(3, 2)
#act = torch.tensor([[-2.8121,  1.9047],
#        [-2.8121,  1.9047]])
act = act.to(device)
obs = obs.to(device)

        
# Sampling loop
for i, t in tqdm(enumerate(noise_scheduler.timesteps)):
    if i % 100 == 0:
        print(act)
    # Get model pred
    with torch.no_grad():
        residual = model(obs, act, [t for _ in range(3)]).squeeze() # Again, note that we pass in our labels y

    # Update sample with step
    act = noise_scheduler.step(residual, t, act).prev_sample

    if (t < 10):
        print(act)


# Show the results
#fig, ax = plt.subplots(1, 1, figsize=(12, 12))
#ax.imshow(torchvision.utils.make_grid(x.detach().cpu().clip(-1, 1), nrow=8)[0], cmap="Greys")

  model = torch.load("saved_models/MLP-model.pth").to(device)
  act = torch.tensor(act).unsqueeze(1).unsqueeze(-1)  # Convert to tensor and shape (bs, 1, 2, 1)
  class_cond = torch.tensor(obs)  # Convert to tensor (bs, class_emb_size)
542it [00:00, 2718.63it/s]

tensor([[ 0.1077,  0.6939],
        [ 0.5206,  0.8794],
        [ 1.1180, -0.1546]])
tensor([[-0.7261, -1.2146],
        [ 0.0039,  1.3934],
        [-0.3356, -1.5278]])
tensor([[ 0.5718, -0.2566],
        [ 0.6911, -0.4354],
        [ 0.2615, -0.2394]])
tensor([[-1.8840, -0.3613],
        [ 0.6994,  0.5392],
        [ 0.1358, -0.2755]])
tensor([[-0.9662, -1.2247],
        [ 0.3034,  1.4216],
        [-1.0925, -0.2926]])
tensor([[-0.6319, -0.9908],
        [-0.7299,  0.5565],
        [-0.4612,  0.3663]])


1000it [00:00, 2678.26it/s]

tensor([[-0.4744, -0.6845],
        [-1.1193,  0.5861],
        [-0.1344, -0.3410]])
tensor([[-1.4067, -0.0784],
        [-0.6382,  1.0888],
        [ 0.0040,  1.1440]])
tensor([[-1.1687,  0.2444],
        [-0.8814,  1.3923],
        [-0.9857,  1.5158]])
tensor([[-0.8136,  0.7452],
        [-1.1298,  0.7762],
        [-0.8542,  0.6980]])
tensor([[-0.9941,  0.9996],
        [-1.0484,  1.0286],
        [-1.0304,  0.9995]])
tensor([[-0.9994,  1.0139],
        [-1.0210,  1.0223],
        [-1.0504,  0.9778]])
tensor([[-1.0060,  1.0267],
        [-1.0009,  1.0178],
        [-1.0260,  0.9753]])
tensor([[-1.0075,  1.0082],
        [-0.9995,  1.0133],
        [-1.0180,  0.9897]])
tensor([[-1.0090,  0.9854],
        [-0.9817,  1.0153],
        [-1.0133,  1.0068]])
tensor([[-1.0079,  0.9926],
        [-1.0052,  1.0110],
        [-0.9992,  0.9982]])
tensor([[-1.0160,  0.9852],
        [-1.0215,  1.0346],
        [-1.0061,  0.9887]])
tensor([[-1.0173,  0.9922],
        [-1.0253,  1.0205],
        [




In [31]:
print(act_desired)
print(obs)
print(act)

tensor([[-2.8110,  1.8979],
        [-2.6461,  1.8752]])
tensor([[-2.8121,  1.9047],
        [-2.8121,  1.9047]])
tensor([[-5.2368,  4.7500],
        [-5.2368,  4.7500]])
