In [1]:
!pip install diffusers[torch]

Collecting diffusers[torch]
  Downloading diffusers-0.31.0-py3-none-any.whl.metadata (18 kB)
Collecting huggingface-hub>=0.23.2 (from diffusers[torch])
  Downloading huggingface_hub-0.26.3-py3-none-any.whl.metadata (13 kB)
Collecting regex!=2019.12.17 (from diffusers[torch])
  Downloading regex-2024.11.6-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (40 kB)
Collecting safetensors>=0.3.1 (from diffusers[torch])
  Downloading safetensors-0.4.5-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (3.8 kB)
Collecting accelerate>=0.31.0 (from diffusers[torch])
  Downloading accelerate-1.1.1-py3-none-any.whl.metadata (19 kB)
Downloading accelerate-1.1.1-py3-none-any.whl (333 kB)
Downloading huggingface_hub-0.26.3-py3-none-any.whl (447 kB)
Downloading regex-2024.11.6-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (781 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m781.7/781.7 kB[0m [31m12.6 MB/s[0m eta [36m0:00:00[0m
[?

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import torchvision
from diffusers import UNet2DModel, DDPMScheduler
from PIL import Image
import os
from tqdm.auto import tqdm
import numpy as np

In [3]:
import os
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

In [4]:
!nvidia-smi

Mon Dec  2 02:42:03 2024       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.127.05             Driver Version: 550.127.05     CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  Tesla V100-SXM2-16GB           On  |   00000000:00:1B.0 Off |                    0 |
| N/A   30C    P0             34W /  300W |       1MiB /  16384MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
|   1  Tesla V100-SXM2-16GB           On  |   00

In [6]:
def get_device():
    """
    Get the appropriate device (CUDA, MPS, or CPU)
    """
    if torch.cuda.is_available():
        return "cuda"
    elif torch.backends.mps.is_available():
        return "mps"
    return "cpu"


class CustomNormalization:
    def __call__(self, tensor):
        normalized = tensor * 2 - 1
        rounded = normalized.round()

        return rounded + (rounded == 0).to(torch.float32) * 0.0


class PianoRollDataset(Dataset):
    def __init__(self, image_dir):
        self.image_dir = image_dir
        self.image_files = [f for f in os.listdir(image_dir) if f.endswith('.jpg')]
        self.transform = transforms.Compose([
            transforms.ToTensor(),
            CustomNormalization()
        ])

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, idx):
        image_path = os.path.join(self.image_dir, self.image_files[idx])
        rgb_image = Image.open(image_path).convert('RGB')
        image = self.transform(rgb_image)
        return image

In [7]:
def train_loop(config, model, noise_scheduler, optimizer, train_dataloader, device):
    def calculate_weights(batch):
        unique, counts = torch.unique(batch, return_counts=True)
        total_pixels = batch.numel()
        weights = total_pixels / (len(unique) * counts)
        weights = weights / weights.sum()
        return {val.item(): weight.item() for val, weight in zip(unique, weights)}

    progress_bar = tqdm(total=config["num_epochs"] * len(train_dataloader))
    global_step = 0

    loss_record = []
    for epoch in range(config["num_epochs"]):
        model.train()
        for batch in train_dataloader:
            clean_images = batch.to(device)
            batch_size = clean_images.shape[0]
            class_weights = calculate_weights(clean_images)

            # Sample noise and add to images
            noise = torch.randn(clean_images.shape).to(device)
            timesteps = torch.randint(
                0, noise_scheduler.config.num_train_timesteps, (batch_size,),
                device=device
            ).long()
            noisy_images = noise_scheduler.add_noise(clean_images, noise, timesteps)

            # Get model prediction
            noise_pred = model(noisy_images, timesteps, return_dict=False)[0]

            # Calculate loss
            loss = F.mse_loss(noise_pred, noise, reduction='none')
            # print(f"loss:{loss}")
            pixel_weights = torch.ones_like(loss)
            for val, weight in class_weights.items():
                mask = (clean_images == val)
                pixel_weights[mask] = weight

            # 应用权重
            weighted_loss = (loss * pixel_weights).mean()
            loss_record.append(weighted_loss)

            print(f"loss:{weighted_loss}")

            # Backpropagation
            optimizer.zero_grad()
            weighted_loss.backward()
            optimizer.step()

            progress_bar.update(1)
            global_step += 1

            if global_step % config["save_interval"] == 0:
                # Save checkpoint
                torch.save({
                    'step': global_step,
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'loss': loss,
                }, f"checkpoint_{global_step}.pt")

        # Save model after each epoch
        torch.save(model.state_dict(), f"model_epoch_{epoch}.pt")
        with open("loss.txt", "w") as file:
            for item in loss:
                file.write(f"{item}\n")

In [8]:
def main():
    # Configuration
    config = {
        "image_height": 768,
        "image_width": 512,
        "batch_size": 2,
        "num_epochs": 5,
        "learning_rate": 1e-4,
        "save_interval": 500,
        "sample_interval": 1000,  # Interval for generating sample images
        "data_dir": "piano_roll_images",  # Your image directory
        "sample_dir": "samples"  # Directory to save generated samples
    }

    # Initialize device
    device = get_device()
    print(f"Using device: {device}")

    # Create dataset and dataloader
    dataset = PianoRollDataset(config["data_dir"])
    dataloader = DataLoader(
        dataset,
        batch_size=config["batch_size"],
        shuffle=True,
        num_workers=4
    )

    # Initialize model with 3 input/output channels for RGB
    model = UNet2DModel(
        sample_size=(config["image_height"], config["image_width"]),
        in_channels=3,  # RGB input
        out_channels=3,  # RGB output
        layers_per_block=3,
        block_out_channels=(32, 64, 128),  # Further reduced channels
        down_block_types=(
            "DownBlock2D",
            "DownBlock2D", 
            "DownBlock2D",
        ),
        up_block_types=(
            "UpBlock2D",
            "UpBlock2D",
            "UpBlock2D",
        ),
    ).to(device)

    # Initialize noise scheduler
    noise_scheduler = DDPMScheduler(
        num_train_timesteps=1000,
        beta_start=0.0001,
        beta_end=0.02,
        beta_schedule="linear"
    )

    # Initialize optimizer
    optimizer = torch.optim.AdamW(model.parameters(), lr=config["learning_rate"])

    # Train model
    train_loop(config, model, noise_scheduler, optimizer, dataloader, device)

In [9]:
 main()

Using device: cuda


  0%|          | 0/1990 [00:00<?, ?it/s]

loss:0.04017787054181099
loss:0.0526822991669178
loss:0.046404678374528885
loss:0.043543148785829544
loss:0.05463096871972084
loss:0.05601043999195099
loss:0.040390584617853165
loss:0.043204713612794876
loss:0.04385123774409294
loss:0.041011303663253784
loss:0.05031537264585495
loss:0.03435595706105232
loss:0.033751413226127625
loss:0.024198532104492188
loss:0.03635222092270851
loss:0.03783315420150757
loss:0.04369577765464783
loss:0.04939572885632515
loss:0.03165976330637932
loss:0.02894187718629837
loss:0.02820318378508091
loss:0.0331728458404541
loss:0.03773244097828865
loss:0.03654574602842331
loss:0.027767380699515343
loss:0.03592529520392418
loss:0.027708517387509346
loss:0.03313636779785156
loss:0.032084401696920395
loss:0.025182465091347694
loss:0.02848581224679947
loss:0.017751876264810562
loss:0.03169655427336693
loss:0.02112816646695137
loss:0.022709695622324944
loss:0.02313653565943241
loss:0.02545454353094101
loss:0.017870189622044563
loss:0.021791033446788788
loss:0.01939

loss:0.005438117776066065
loss:0.003949799109250307
loss:0.006492826621979475
loss:0.0025630812160670757
loss:0.007184735499322414
loss:0.003609722713008523
loss:0.003693987848237157
loss:0.0048707015812397
loss:0.005393119994550943
loss:0.003029213286936283
loss:0.006033487617969513
loss:0.0038835108280181885
loss:0.003630556631833315
loss:0.0034503419883549213
loss:0.004685811698436737
loss:0.009554793126881123
loss:0.004866846837103367
loss:0.0027110064402222633
loss:0.004631513729691505
loss:0.0030694985762238503
loss:0.0029792285058647394
loss:0.0034407461062073708
loss:0.007347329054027796
loss:0.004180348012596369
loss:0.002332722069695592
loss:0.001641064533032477
loss:0.0035375787410885096
loss:0.003260734723880887
loss:0.004012526944279671
loss:0.004437626339495182
loss:0.007329680025577545
loss:0.004301406443119049
loss:0.004462577402591705
loss:0.003417940577492118
loss:0.004432103130966425
loss:0.006178099196404219
loss:0.0054225013591349125
loss:0.0030060235876590014
loss

loss:0.001382871181704104
loss:0.0021862422581762075
loss:0.0028648637235164642
loss:0.0017929611494764686
loss:0.0023808451369404793
loss:0.0018235124880447984
loss:0.003389816964045167
loss:0.0021886113099753857
loss:0.0024433049838989973
loss:0.0021373352501541376
loss:0.0011925348080694675
loss:0.0022535324096679688
loss:0.003769150236621499
loss:0.002428896725177765
loss:0.002707269974052906
loss:0.002512099454179406
loss:0.002292082877829671
loss:0.0022350321523845196
loss:0.0031595106702297926
loss:0.0019867566879838705
loss:0.004491033963859081
loss:0.007555072661489248
loss:0.00370453717187047
loss:0.007067806087434292
loss:0.002514579566195607
loss:0.002112130867317319
loss:0.01016545481979847
loss:0.0015167695237323642
loss:0.010333436541259289
loss:0.0020509324967861176
loss:0.0021563994232565165
loss:0.003886539489030838
loss:0.002452262444421649
loss:0.004088092129677534
loss:0.003049281658604741
loss:0.0019842716865241528
loss:0.0021565649658441544
loss:0.001441567670553

loss:0.001496578217484057
loss:0.0026207726914435625
loss:0.0017320887418463826
loss:0.0008711975533515215
loss:0.0032124698627740145
loss:0.0018279001815244555
loss:0.0013055403251200914
loss:0.0016252920031547546
loss:0.00780549505725503
loss:0.01073262095451355
loss:0.0038926468696445227
loss:0.0013778667198494077
loss:0.0015731065068393946
loss:0.002606237307190895
loss:0.001681793131865561
loss:0.005747344810515642
loss:0.0020206861663609743
loss:0.0019803205505013466
loss:0.0024005745071917772
loss:0.0023464004043489695
loss:0.0010602136608213186
loss:0.0014348957920446992
loss:0.0014511261833831668
loss:0.005551953334361315
loss:0.0016399013111367822
loss:0.0016103758243843913
loss:0.006661290768533945
loss:0.0014444672269746661
loss:0.0012691776501014829
loss:0.002596786478534341
loss:0.0012993077980354428
loss:0.0011191994417458773
loss:0.0018755982164293528
loss:0.0021167797967791557
loss:0.010473226197063923
loss:0.0015633353032171726
loss:0.0019679118413478136
loss:0.006468

loss:0.0010976173216477036
loss:0.0010152362519875169
loss:0.002047415357083082
loss:0.005107903853058815
loss:0.0012332737678661942
loss:0.001391288242302835
loss:0.00140715844463557
loss:0.0015616065356880426
loss:0.0031853935215622187
loss:0.0017243872862309217
loss:0.001441733562387526
loss:0.001798121491447091
loss:0.0015903324820101261
loss:0.0017734031425788999
loss:0.0018059518188238144
loss:0.0010232666973024607
loss:0.0011798582272604108
loss:0.0012615652522072196
loss:0.0012113532284274697
loss:0.0013480059569701552
loss:0.001530898385681212
loss:0.0037465915083885193
loss:0.0013395188143476844
loss:0.0038429934065788984
loss:0.0010831912513822317
loss:0.0010001116897910833
loss:0.0013186174910515547
loss:0.0008825738332234323
loss:0.003016931703314185
loss:0.001971364486962557
loss:0.0011807932751253247
loss:0.0014966018497943878
loss:0.0011715609580278397
loss:0.003954519983381033
loss:0.001010481035336852
loss:0.002290082164108753
loss:0.005512558855116367
loss:0.00131095

loss:0.0013283273437991738
loss:0.001207290799356997
loss:0.0015318907098844647
loss:0.005238141864538193
loss:0.0026916752103716135
loss:0.0019603653345257044
loss:0.001035416149534285
loss:0.003423114074394107
loss:0.0010267546167597175
loss:0.0010639457032084465
loss:0.004871281795203686
loss:0.0018714043544605374
loss:0.0007815437857061625
loss:0.00183896126691252
loss:0.004140304401516914
loss:0.0012083137407898903
loss:0.0009428334305994213
loss:0.0009706239798106253
loss:0.0011498458916321397
loss:0.0009253875468857586
loss:0.0059862444177269936
loss:0.0010324511677026749
loss:0.0011056290240958333
loss:0.0014222523896023631
loss:0.0011212987592443824
loss:0.0011229208903387189
loss:0.0008112986688502133
loss:0.002945284591987729
loss:0.0011487791780382395
loss:0.0024052683729678392
loss:0.00105563597753644
loss:0.0011182860471308231
loss:0.0007727149059064686
loss:0.0011751505080610514
loss:0.0013282239669933915
loss:0.0014599727001041174
loss:0.0008531954372301698
loss:0.00125

loss:0.0010074006859213114
loss:0.0007770118536427617
loss:0.0011404359247535467
loss:0.0006487775244750082
loss:0.0012916993582621217
loss:0.0009035997209139168
loss:0.0010782243916764855
loss:0.0018264768877997994
loss:0.0026090615428984165
loss:0.0006503884796984494
loss:0.0006142064812593162
loss:0.000724201905541122
loss:0.0008428048458881676
loss:0.0007943356758914888
loss:0.001735297730192542
loss:0.0012872519437223673
loss:0.001113921171054244
loss:0.003080161055549979
loss:0.006069476250559092
loss:0.005221264436841011
loss:0.0016519486671313643
loss:0.008726025931537151
loss:0.0017065288266167045
loss:0.001431935583241284
loss:0.0009484494221396744
loss:0.0028662688564509153
loss:0.0008904169080778956
loss:0.0015621198108419776
loss:0.000991626176983118
loss:0.0011994438245892525
loss:0.0011067779269069433
loss:0.0019846868235617876
loss:0.0020617519039660692
loss:0.004133055452257395
loss:0.0012058725114911795
loss:0.0008695280994288623
loss:0.0014786765677854419
loss:0.0010

In [5]:
def save_images(images, path, index):
    """Save a batch of images during training for monitoring."""
    images = images.mean(dim=1, keepdim=True)
    # Convert to binary
    images = (images >= -0.8).float()
    grid = torchvision.utils.make_grid(images)
    # Convert to PIL image
    grid_image = torchvision.transforms.ToPILImage()(grid)
    os.makedirs(path, exist_ok=True)
    grid_image.save(f"{path}/sample_{index}.png")


def generate_images(
        checkpoint_path,
        image_height=768,
        image_width=512,
        output_dir="generated_images"
):
    config = {
        "image_height": image_height,
        "image_width": image_width,
        "sample_dir": output_dir
    }

    device = torch.device("mps" if torch.backends.mps.is_available() else "cuda")
    print(f"Using device: {device}")

    model = UNet2DModel(
        sample_size=[image_height, image_width],
        in_channels=3,
        out_channels=3,
        layers_per_block=3,
        block_out_channels=(32, 64, 128),
        down_block_types=(
            "DownBlock2D",
            "DownBlock2D",
            "DownBlock2D",
        ),
        up_block_types=(
            "UpBlock2D",
            "UpBlock2D",
            "UpBlock2D",
        ),
    ).to(device)

    noise_scheduler = DDPMScheduler(
        num_train_timesteps=1000,
        beta_start=0.0001,
        beta_end=0.02,
        beta_schedule="linear"
    )

    # 加载checkpoint
    checkpoint = torch.load(checkpoint_path, map_location=device)
    if 'model_state_dict' in checkpoint:
        model.load_state_dict(checkpoint['model_state_dict'])
    else:
        model.load_state_dict(checkpoint)
    model.eval()

    with torch.no_grad():
        for i in range(100):
            # Generate sample images
            sample = torch.randn(1, 3, config["image_height"], config["image_width"]).to(device)
            timesteps = torch.linspace(999, 0, 50).long().to(device)
            for t in timesteps:
                residual = model(sample, t.repeat(1), return_dict=False)[0]
                sample = noise_scheduler.step(residual, t, sample).prev_sample
            save_images(sample, config["sample_dir"], i)

In [6]:
generate_images(
        checkpoint_path="model_epoch_4.pt",  
        image_height=768,
        image_width=512,
        output_dir="generated_images"
    )

Using device: cuda
