In [26]:
import matplotlib.pyplot as plt
import os

os.environ["CUDA_VISIBLE_DEVICES"] = "3"  

import torch
import torch.nn.functional as F
from torchinfo import summary
from accelerate import Accelerator

import cv2
import numpy as np
from tqdm import tqdm
from dataclasses import dataclass

from diffusers import DDPMScheduler, UNet2DModel, get_cosine_schedule_with_warmup
from datasets import RAMDatasetDIDC, LazyDatasetDIDC
from mt_DIDC_config import GROUPING_RULES, NEW_LABELS

%load_ext autoreload
%autoreload 2


The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [5]:
DATA_PATH = 'New_dictionary'

In [6]:
TARGET_SIZE = (128, 128)
N_EPOCHS = 5


In [8]:
dataset = LazyDatasetDIDC(DATA_PATH, grouping_rules=GROUPING_RULES, new_labels=NEW_LABELS)
print(f"Dataset size: {len(dataset)}")

Lazy Dataset: File...


Indexing files and slices:  98%|█████████▊| 480/489 [00:11<00:00, 36.22it/s]

5645 samples indexed.
Dataset size: 5645


Indexing files and slices: 100%|██████████| 489/489 [00:11<00:00, 41.30it/s]


In [9]:
dataloader = torch.utils.data.DataLoader(dataset, batch_size=2, shuffle=True)
batch = next(iter(dataloader))
print(f"Batch keys: {batch.keys()}")
print(f"Batch fg shape: {batch['input_label'].shape}")
print(f"Batch mask shape: {batch['multiClassMask'].shape}")

Batch keys: dict_keys(['input_label', 'multiClassMask'])
Batch fg shape: torch.Size([2, 4, 384, 384])
Batch mask shape: torch.Size([2, 384, 384])


In [16]:
batch['input_label'].dtype, batch['multiClassMask'].dtype

(torch.float32, torch.int64)

In [15]:
def training_step(batch, model, num_classes, optimizer, noise_scheduler, accelerator):
    model
    model.train()

    clean_images = batch['multiClassMask']  # Shape: (B, C, H, W)
    clean_images = F.one_hot(clean_images.long(), num_classes=num_classes).permute(0, 3, 1, 2).float()  # Shape: (B, C, H, W)
    clean_images = clean_images * 2.0 - 1.0  # Scale to [-1, 1]

    batch_size = clean_images.size(0)
    timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (batch_size,), device=accelerator.device).long()

    noise = torch.randn_like(clean_images.float())
    noisy_images = noise_scheduler.add_noise(clean_images, noise, timesteps)
    noise_pred = model(noisy_images, timesteps).sample
    loss = F.mse_loss(noise_pred, noise)

    with accelerator.accumulate(model):
        optimizer.zero_grad()
        accelerator.backward(loss)
        if accelerator.sync_gradients:
            accelerator.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()

    return loss.detach().item()


In [12]:
# accelerator = Accelerator()

# # STAMPE DI DEBUG (Il tuo Sanity Check)
# print(f"--- SANITY CHECK ---")
# print(f"Processo corrente index: {accelerator.process_index}")
# print(f"Device assegnato a questo processo: {accelerator.device}")
# print(f"Numero totale di processi (GPU) attivi: {accelerator.num_processes}")
# print(f"--------------------\n")

# model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader)
# batch = next(iter(dataloader))

# print(f"[Process {accelerator.process_index}] Il batch si trova su: {batch['multiClassMask'].device}")
# print(f"[Process {accelerator.process_index}] Il modello si trova su: {next(model.parameters()).device}")

# loss_value = training_step(batch, model, len(dataset.new_labels), optimizer, noise_scheduler, accelerator)

# accelerator.print(f"Step completato con successo! Loss: {loss_value:.4f}")

## Final code

In [27]:
@dataclass
class TrainingConfig:
    target_size = 128  # the generated image resolution
    train_batch_size = 16
    eval_batch_size = 16  # how many images to sample during evaluation
    num_epochs = 50
    batch_size_per_gpu = 2
    num_gpus = torch.cuda.device_count()
    gradient_accumulation_steps = max(1, train_batch_size // (batch_size_per_gpu * num_gpus))
    learning_rate = 1e-4
    lr_warmup_steps = 500
    save_image_epochs = 10
    save_model_epochs = 30
    mixed_precision = "fp16"  # `no` for float32, `fp16` for automatic mixed precision
    output_dir = "ddpm-butterflies-128"  # the model name locally and on the HF Hub

    seed = 0

In [28]:
config = TrainingConfig()
noise_scheduler = DDPMScheduler(num_train_timesteps=1000, beta_schedule="sigmoid")
model = UNet2DModel(sample_size=384,  in_channels=22,   out_channels=22, layers_per_block=2,block_out_channels=(128, 256, 512, 512),
    down_block_types=("DownBlock2D", "DownBlock2D", "AttnDownBlock2D", "DownBlock2D"),
    up_block_types=("UpBlock2D", "AttnUpBlock2D", "UpBlock2D", "UpBlock2D"),
)

dataset = LazyDatasetDIDC(DATA_PATH, grouping_rules=GROUPING_RULES, new_labels=NEW_LABELS)
train_dataloader = torch.utils.data.DataLoader(dataset, batch_size=config.train_batch_size, shuffle=True)

model_summary = summary(model, input_data=(torch.randn(1, 22, 384, 384), torch.tensor([0])))

optimizer = torch.optim.AdamW(model.parameters(), lr=config.learning_rate)
lr_scheduler = get_cosine_schedule_with_warmup(optimizer=optimizer, num_warmup_steps=config.lr_warmup_steps,num_training_steps=(len(train_dataloader) * config.num_epochs))

Lazy Dataset: File...



Indexing files and slices:   0%|          | 0/489 [00:00<?, ?it/s][A
Indexing files and slices:   1%|          | 4/489 [00:00<00:13, 34.66it/s][A
Indexing files and slices:   2%|▏         | 8/489 [00:00<00:17, 28.07it/s][A
Indexing files and slices:   2%|▏         | 12/489 [00:00<00:15, 29.96it/s][A
Indexing files and slices:   4%|▎         | 18/489 [00:00<00:12, 38.01it/s][A
Indexing files and slices:   5%|▍         | 24/489 [00:00<00:10, 42.60it/s][A
Indexing files and slices:   6%|▌         | 29/489 [00:00<00:11, 39.77it/s][A
Indexing files and slices:   7%|▋         | 34/489 [00:00<00:11, 41.28it/s][A
Indexing files and slices:   8%|▊         | 39/489 [00:00<00:10, 42.90it/s][A
Indexing files and slices:   9%|▉         | 44/489 [00:01<00:10, 43.72it/s][A
Indexing files and slices:  10%|█         | 49/489 [00:01<00:10, 42.41it/s][A
Indexing files and slices:  11%|█         | 54/489 [00:01<00:13, 32.46it/s][A
Indexing files and slices:  12%|█▏        | 58/489 [00:01<00:1

5645 samples indexed.


In [23]:
def train_loop(config, model, noise_scheduler, optimizer, train_dataloader, lr_scheduler):
    accelerator = Accelerator(
        mixed_precision=config.mixed_precision,
        gradient_accumulation_steps=config.gradient_accumulation_steps,
        log_with="tensorboard",
        project_dir=os.path.join(config.output_dir, "logs"),
    )

    if accelerator.is_main_process:
        if config.output_dir is not None:
            os.makedirs(config.output_dir, exist_ok=True)
        accelerator.init_trackers("train_example")

    # Prepare accelerator model
    model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(model, optimizer, train_dataloader, lr_scheduler)

    global_step = 0

    for epoch in range(config.num_epochs):
        progress_bar = tqdm(total=len(train_dataloader), disable=not accelerator.is_local_main_process)
        progress_bar.set_description(f"Epoch {epoch}")

        for step, batch in enumerate(train_dataloader):
            loss = training_step(batch, model, len(train_dataloader.dataset.new_labels), optimizer, noise_scheduler, accelerator)

            progress_bar.update(1)
            logs = {"loss": loss, "lr": lr_scheduler.get_last_lr()[0], "step": global_step}
            progress_bar.set_postfix(**logs)
            accelerator.log(logs, step=global_step)
            global_step += 1

        if accelerator.is_main_process:
            # implement sampling pipeline

            if (epoch + 1) % config.save_image_epochs == 0 or epoch == config.num_epochs - 1:
                # implement evaluate() to see how well your model is doing, and optionally save generated images
                pass

            if (epoch + 1) % config.save_model_epochs == 0 or epoch == config.num_epochs - 1:
                # implement saving logic
                pass



In [24]:
config = TrainingConfig()
train_loop(config, model, noise_scheduler, optimizer, dataloader, lr_scheduler=None)

Epoch 0:   0%|          | 1/2823 [00:04<3:29:37,  4.46s/it]

AttributeError: 'NoneType' object has no attribute 'get_last_lr'