In [8]:
import torch
import torch.optim as optim
from models.DiffLoad.diffusion.layers import CondModel_v2
from matplotlib import pyplot as plt
from utils.helper import make_beta_schedule, EMA, ObjectView
from utils.plots import hdr_plot_style
hdr_plot_style()
from tqdm import tqdm 
from models.DiffLoad.ddpm import DDPM1d
from utils.config import config_dataset, config_ddpm, config_nn

In [9]:
# Use the GPU when one is visible. Always build a torch.device (rather than a
# torch.device on one branch and the plain string "cpu" on the other) so every
# downstream `.to(device)` / `map_location=device` sees a consistent type.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [10]:
# Training hyper-parameters for this notebook. Dataset / DDPM / network settings
# from utils.config are merged on top afterwards (dict.update: later sources win
# on key collisions).
configs = dict(
    epoch=80000,
    batch_size=3000,
    learning_rate=1e-4,
    lr_decay=0.99,
    lr_decay_step=200,
    mode=None,  # "checkpoint" loads a saved model in the training cell; None trains from scratch
    use_MLP=True,
    use_solar=True,
)
for extra_cfg in (config_dataset, config_ddpm, config_nn):
    configs.update(extra_cfg)
args = ObjectView(configs)

# Checkpoint file prefix: depends on whether the PV baseline (use_solar) is fed to the model.
if args.use_solar:
    args.save_name = "diff_phy"
else:
    args.save_name = "diff_base"
print(args.save_name)

diff_phy


# Dataset

In [11]:
# Pre-split tensors for the Pecan Street smart-meter dataset.
DATA_DIR = "../data/Pecan Street Smart Meter Data (large) (tensor)"


def load_split(name):
    """Load ``{DATA_DIR}/{name}.pt`` onto the CPU.

    map_location="cpu" makes loading work on CPU-only machines even if the
    tensors were saved from a GPU. It also keeps the test tensors on the CPU
    for plotting; the training cell moves what it needs to ``device`` itself.
    """
    return torch.load(f"{DATA_DIR}/{name}.pt", map_location="cpu")


X_train = load_split("X_train")
X_test = load_split("X_test")
cond_train = load_split("cond_train")
cond_test = load_split("cond_test")
PV_base_train = load_split("PV_base_train")
PV_base_test = load_split("PV_base_test")

In [12]:
# Smallest value in the training targets (quick range sanity check).
torch.min(X_train)

tensor(-1.)

In [13]:
# Presumably (num_profiles, points_per_profile) -- confirm against the data export.
X_train.size()

torch.Size([8760, 96])

In [14]:
import os

# Single checkpoint directory used both for saving and for resuming, so a run in
# "checkpoint" mode finds the files this cell writes.
CKPT_DIR = "../result/models/pecan"

# Diffusion noise schedule: linear betas over `args.n_steps` steps.
n_steps = args.n_steps
betas = make_beta_schedule(schedule='linear', n_timesteps=n_steps, start=args.beta_start, end=args.beta_end)
betas = betas.to(device)

model = CondModel_v2(args)
start_epoch = 0  # epochs already trained before this run (non-zero only when resuming)
if args.mode == "checkpoint":
    print("load checkpoint")
    # map_location lets a GPU-saved checkpoint load on a CPU-only machine.
    # NOTE(review): the checkpoint pickles the whole DDPM1d object; torch >= 2.6
    # defaults to weights_only=True and may refuse it -- pass weights_only=False
    # there (only for checkpoints you trust).
    checkpoint = torch.load(f"{CKPT_DIR}/{args.save_name}_{args.epoch}.pth", map_location=device)
    model = checkpoint['ddpm'].model
    start_epoch = args.epoch
    # Re-create the LR StepLR had reached after `start_epoch` epochs. StepLR
    # decays floor(steps / step_size) times, and it is stepped once per epoch below.
    args.learning_rate = args.learning_rate * args.lr_decay ** (start_epoch // args.lr_decay_step)

model = model.to(device)
X_train = X_train.to(device)
cond_train = cond_train.to(device)
PV_base_train = PV_base_train.to(device)
optimizer = optim.Adam(model.parameters(), lr=args.learning_rate)
# Multiply the LR by `lr_decay` every `lr_decay_step` epochs.
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=args.lr_decay_step, gamma=args.lr_decay)
ddpm = DDPM1d(model, betas, n_steps, (args.input_dim,), loss_type='l2')
# Make sure the DDPM wrapper uses the very model object the optimizer updates.
# The optimizer mutates it in place, so a single assignment is enough.
ddpm.model = model
# Exponential moving average of the weights.
# NOTE(review): the EMA shadow is updated every step but never applied to `ddpm`
# before saving or sampling -- confirm whether EMA weights were meant to be used.
ema = EMA(args.ema_decay)
ema.register(model)

os.makedirs(CKPT_DIR, exist_ok=True)
n_train = X_train.size(0)
Loss = []  # loss of the last mini-batch, recorded every 100 epochs
for j in tqdm(range(args.epoch)):
    permutation = torch.randperm(n_train)
    for i in range(0, n_train, args.batch_size):
        indices = permutation[i:i + args.batch_size]
        # Jitter the targets with Gaussian noise (std 0.05) before computing the loss.
        batch_x = X_train[indices]
        batch_x = batch_x + 0.05 * torch.randn_like(batch_x)
        batch_cond = cond_train[indices]
        if args.use_solar:
            loss = ddpm(batch_x, batch_cond, PV_base_train[indices])
        else:
            loss = ddpm(batch_x, batch_cond)
        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.)
        optimizer.step()
        ema.update(model)
    # Step the LR schedule once per epoch, so `lr_decay_step` counts epochs. This
    # matches the resume formula above; stepping per mini-batch would decay the
    # LR several times faster than that formula assumes.
    scheduler.step()

    if (j+1) % 100 == 0:
        Loss.append(loss.item())
    if (j+1) % 10000 == 0:
        print("loss: ", loss.item())
    if (j+1) % 10000 == 0 or (j+1) == args.epoch:
        checkpoint = {
            'config': configs,
            'ddpm': ddpm,
            'Loss': Loss
        }
        # Number files by total epochs trained, so a resumed run does not
        # overwrite the checkpoint it was loaded from.
        torch.save(checkpoint, f"{CKPT_DIR}/{args.save_name}_{start_epoch + j + 1}.pth")

  0%|          | 0/80000 [00:21<?, ?it/s]


KeyboardInterrupt: 

In [None]:
# Training-loss curve. The training loop appends one value every 100 epochs:
# the loss of the last mini-batch of that epoch.
fig, ax = plt.subplots()
ax.plot(Loss)
ax.set(title="DDPM training loss", xlabel="logging step (1 step = 100 epochs)", ylabel="loss (last mini-batch)");

In [None]:
# Generate one synthetic profile per test condition and compare the samples with
# the real test profiles, class by class.
cond_test = cond_test.to(device)
PV_base_test = PV_base_test.to(device)
# NOTE(review): PV_base is passed even when args.use_solar is False, whereas the
# training loop only passes it when use_solar is True -- confirm sample_seq
# ignores it in that case.
# no_grad: sampling needs no gradients. This keeps autograd from recording the
# denoising chain, and the result can be converted to NumPy for plotting.
with torch.no_grad():
    # [-1]: last element of the returned sequence (presumably the final denoised sample).
    X_test_hat = ddpm.sample_seq(batch_size=len(cond_test), cond=cond_test, PV_base=PV_base_test)[-1]

SEQ_LEN = 96  # points per profile (X_train is [N, 96] above)
X_test_hat = X_test_hat.detach().cpu().reshape(args.num_class, -1, SEQ_LEN)
# New name instead of overwriting X_test, so X_test keeps its loaded layout.
X_test_by_class = X_test.cpu().reshape(args.num_class, -1, SEQ_LEN)

# Panel title derived from the actual configuration instead of a fixed tag.
model_tag = f"MLP{int(args.use_MLP)}_Solar{int(args.use_solar)}"
for j in range(min(10, args.num_class)):  # plot at most the first 10 classes
    actual = X_test_by_class[j]
    generated = X_test_hat[j]
    fig, axes = plt.subplots(1, 4, figsize=(36, 6), dpi=300)

    axes[0].plot(actual.T.numpy())  # one line per profile
    axes[0].set_title("actual data")
    axes[1].plot(generated.T.numpy())
    axes[1].set_title(f"generated load profile ({model_tag})")

    # torch .var() is the unbiased estimator, as in the original panels.
    axes[2].plot(generated.mean(dim=0).numpy(), label = "mean of generated data")
    axes[2].plot(actual.mean(dim=0).numpy(), label = "mean of actual data")
    axes[2].set_title("mean profile")
    axes[2].legend(fontsize=10)
    axes[3].plot(generated.var(dim=0).numpy(), label = "var of generated data")
    axes[3].plot(actual.var(dim=0).numpy(), label = "var of actual data")
    axes[3].set_title("variance profile")
    axes[3].legend(fontsize=10)

    fig.tight_layout()
    plt.show()
    plt.close(fig)  # release the 300-dpi figure once it has been rendered
