In [1]:
import os
import sys
import random
from pathlib import Path

from tqdm import tqdm
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from sklearn.metrics import mean_absolute_error
from sklearn.preprocessing import MinMaxScaler, MaxAbsScaler
import statsmodels.api as sm
from statsmodels.tsa.vector_ar.var_model import VAR

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

import xgboost as xgb

SEED = 2103
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def eval_fn(y_pred, y_target):
    return mean_absolute_error(y_pred, y_target)

# Load Train/Test Data

In [2]:
train_df = pd.read_pickle('./data_postprocessing/train_df.pkl')
test_df = pd.read_pickle('./data_postprocessing/test_df.pkl')

In [3]:
train_df.head(2)

Unnamed: 0,window_id,ask_price_0,ask_size_0,bid_price_0,bid_size_0,far_price_0,imbalance_buy_sell_flag_0,imbalance_size_0,matched_size_0,near_price_0,...,bid_size_198,far_price_198,imbalance_buy_sell_flag_198,imbalance_size_198,matched_size_198,near_price_198,reference_price_198,target_198,wap_198,last_5_mins
0,0,1.000026,8493.03,0.999812,60651.5,0.0,1,3180602.69,13380276.64,0.0,...,54300.05,0.0,1,15249373.9,26134518.94,0.0,0.999775,-0.510216,1.0,0
1,0,1.000026,23519.16,0.999812,13996.5,0.0,1,1299772.7,15261106.63,0.0,...,153691.34,0.0,1,13496480.93,27604966.3,0.0,1.000288,-1.419783,1.000222,0


In [4]:
test_df.head(2)

Unnamed: 0,window_id,ask_price_0,ask_size_0,bid_price_0,bid_size_0,far_price_0,imbalance_buy_sell_flag_0,imbalance_size_0,matched_size_0,near_price_0,...,bid_size_198,far_price_198,imbalance_buy_sell_flag_198,imbalance_size_198,matched_size_198,near_price_198,reference_price_198,target_198,wap_198,last_5_mins
0,433,1.000066,2765.73,0.999697,12685.14,0.0,-1,5128680.68,10543243.05,0.0,...,5753.28,0.0,-1,26806621.39,83364970.7,0.0,1.000486,5.459785,1.0,0
1,433,0.999882,78944.32,0.999697,2331.03,0.0,-1,3901807.9,11733675.67,0.0,...,77927.44,0.0,-1,26782684.93,83388907.16,0.0,0.999944,3.089905,1.000051,0


In [5]:
feature_cols = [col for col in train_df if not col.startswith("target") and not col.startswith("window_id")]
scaler = MinMaxScaler(feature_range=(-1, 1))
scaler.fit(train_df[feature_cols].values)

In [6]:
class StockDataset(Dataset):
    """
    To preserve original window structure, do not shuffle, and set batch size to be 55.
    """
    def __init__ (self, df):
        self.target_cols = [col for col in df if col.startswith("target")] # we have to change this when we add lags
        self.feature_cols = [col for col in df if not col.startswith("target") and not col.startswith("window_id")]
        self.labels = df[self.target_cols] # labels for all stocks at all times
        self.features = df[self.feature_cols]
    def __len__(self):
        return len(self.labels)
    def __getitem__(self, idx):
        return self.features.iloc[idx].values, self.labels.iloc[idx].values

In [7]:
train_dataset = torch.load('./data_postprocessing/train.pt')
test_dataset = torch.load('./data_postprocessing/test.pt')
train_dataloader = DataLoader(train_dataset, batch_size=55, shuffle=False)
test_dataloader = DataLoader(test_dataset, batch_size=55, shuffle=False)

In [8]:
x, y = train_dataset[0]
print(x.shape, y.shape)

(2036,) (185,)


# XGBoost

In [None]:
temp_train_loader = DataLoader(train_dataset, batch_size=len(train_dataset))
X, y = next(iter(temp_train_loader))
X, y = X.numpy(), y.numpy()

In [None]:
print(X.shape, y.shape)

(23815, 2036) (23815, 185)


In [None]:
temp_test_loader = DataLoader(test_dataset, batch_size=len(test_dataset))
test_X, test_y = next(iter(temp_test_loader))
test_X, test_y = test_X.numpy(), test_y.numpy()

In [None]:
print(test_X.shape, test_y.shape)

(2640, 2036) (2640, 185)


In [None]:
xgb_model = xgb.XGBRegressor(device="cuda")
xgb_model.fit(X, y)

In [None]:
train_y_hat = xgb_model.predict(X)
print(f'Train MAE: {eval_fn(train_y_hat, y)})')

Potential solutions:
- Use a data structure that matches the device ordinal in the booster.
- Set the device for booster before call to inplace_predict.




Train MAE: 2.33017348998886)


In [None]:
test_y_hat = xgb_model.predict(test_X)
print(f'Test MAE: {eval_fn(test_y_hat, test_y)})')

Test MAE: 6.913688415179943)


# VARMAX

In this part, we do a rolling prediction using VARMAX as follows (the `|` operator is a concatenation along the 0th axis, and each `train/test[i]` refers to a window of 55 samples): \
`(WAP | Target)_train1 | .. | (WAP | Target)_trainN | WAP_test1 -> Pred_Target_test1` \
`(WAP | Target)_train1 | .. | (WAP | Target)_trainN | (WAP | Pred_Target)_test1 | WAP_test2 -> Pred_Target_test2` \
and so on... \
\
We then just use the last 100 pairs of (WAP | Target) from this big array, plus the current WAP_test appended at the end, to predict the current Target_test \
\
This may perform poorly, but serves as a good baseline for other models to beat

In [None]:
wap_cols = [col for col in train_df if col.startswith("wap")]
target_cols = [col for col in train_df if col.startswith("target")]
exo_cols = [col for col in train_df if not col.startswith("wap") and not col.startswith("target")]

num_train_samples = len(train_df)
wap_indices = [i * 55 + j for i in range(0, (num_train_samples * 2) // 55, 2) for j in range(55)]
target_indices = [55 + i * 55 + j for i in range(0, (num_train_samples * 2) // 55, 2) for j in range(55)]
concat_train_arr = np.empty((num_train_samples * 2, len(target_cols)))
concat_train_arr[wap_indices, :] = train_df[wap_cols].values
concat_train_arr[target_indices, :] = train_df[target_cols].values

# exogenous variables are just duplicated
exo_train_arr = np.empty((num_train_samples * 2, len(exo_cols)))
exo_train_arr[wap_indices, :] = train_df[exo_cols].values
exo_train_arr[target_indices, :] = train_df[exo_cols].values

In [None]:
num_test_samples = len(test_df)

curr_test_wap_and_target_arr = test_df[wap_cols][:55]
curr_test_exo_arr = test_df[exo_cols][:55]

total_MAE = 0.0

for i in tqdm(range(num_test_samples // 55)):
    # train inputs
    curr_concat_train_arr = np.concatenate([concat_train_arr, curr_test_wap_and_target_arr])
    curr_exo_train_arr = np.concatenate([exo_train_arr, curr_test_exo_arr])

    # scale inputs
    curr_scaler = MinMaxScaler()
    curr_scaler.fit(curr_concat_train_arr)
    curr_concat_train_arr = curr_scaler.transform(curr_concat_train_arr)

    # Fit the VAR model
    model = VAR(endog=curr_concat_train_arr[-55*100-55*101:])
    model_fit = model.fit()

    # Make prediction on validation
    curr_pred = model_fit.forecast(model.endog, steps=55)

    # evaluate
    curr_true_target = test_df[target_cols][i*55:(i+1)*55].values
    curr_MAE = mean_absolute_error(curr_pred, curr_true_target)
    total_MAE += curr_MAE

    # concat new inputs
    if i != num_test_samples // 55 - 1:
        curr_test_wap_and_target_arr = np.concatenate([
            curr_test_wap_and_target_arr,
            curr_pred,
            test_df[wap_cols][(i+1)*55:(i+2)*55]
        ])
        curr_test_exo_arr = np.concatenate([curr_test_exo_arr, test_df[exo_cols][(i+1)*55:(i+2)*55], test_df[exo_cols][(i+1)*55:(i+2)*55]])

    print(curr_MAE)

print(f'Test MAE:', total_MAE / (num_test_samples // 55))

  2%|▏         | 1/48 [00:00<00:12,  3.67it/s]

5.626262182701441


  4%|▍         | 2/48 [00:00<00:12,  3.75it/s]

6.199536437138377


  6%|▋         | 3/48 [00:00<00:12,  3.71it/s]

6.017054281416677


  8%|▊         | 4/48 [00:01<00:11,  3.73it/s]

6.1177923941989025


 10%|█         | 5/48 [00:01<00:11,  3.72it/s]

6.028102483441858


 12%|█▎        | 6/48 [00:01<00:11,  3.73it/s]

6.229354085890746


 15%|█▍        | 7/48 [00:01<00:11,  3.67it/s]

7.416029570724428


 17%|█▋        | 8/48 [00:02<00:11,  3.62it/s]

6.213474132797025


 19%|█▉        | 9/48 [00:02<00:10,  3.65it/s]

5.919268747283808


 21%|██        | 10/48 [00:02<00:10,  3.68it/s]

5.649721362770189


 23%|██▎       | 11/48 [00:02<00:10,  3.70it/s]

5.883322399699134


 25%|██▌       | 12/48 [00:03<00:09,  3.72it/s]

6.248745724406667


 27%|██▋       | 13/48 [00:03<00:09,  3.75it/s]

5.685401908647022


 29%|██▉       | 14/48 [00:03<00:08,  3.78it/s]

5.318636627176094


 31%|███▏      | 15/48 [00:04<00:08,  3.75it/s]

5.372311657808444


 33%|███▎      | 16/48 [00:04<00:08,  3.75it/s]

5.258559826693388


 35%|███▌      | 17/48 [00:04<00:08,  3.75it/s]

5.590861889870388


 38%|███▊      | 18/48 [00:04<00:08,  3.66it/s]

5.754411445697447


 40%|███▉      | 19/48 [00:05<00:07,  3.67it/s]

5.262738301294624


 42%|████▏     | 20/48 [00:05<00:07,  3.68it/s]

5.2854818312367335


 44%|████▍     | 21/48 [00:05<00:07,  3.71it/s]

5.493222665531559


 46%|████▌     | 22/48 [00:05<00:07,  3.62it/s]

5.934616690198129


 48%|████▊     | 23/48 [00:06<00:07,  3.39it/s]

6.4125353092475965


 50%|█████     | 24/48 [00:06<00:06,  3.45it/s]

5.757635257674721


 52%|█████▏    | 25/48 [00:06<00:06,  3.50it/s]

5.923644343597231


 54%|█████▍    | 26/48 [00:07<00:06,  3.54it/s]

9.359717009623


 56%|█████▋    | 27/48 [00:07<00:05,  3.57it/s]

6.269975450432604


 58%|█████▊    | 28/48 [00:07<00:05,  3.59it/s]

5.672876660039966


 60%|██████    | 29/48 [00:07<00:05,  3.57it/s]

5.828086856501665


 62%|██████▎   | 30/48 [00:08<00:05,  3.57it/s]

5.829092129637158


 65%|██████▍   | 31/48 [00:08<00:04,  3.58it/s]

6.025820050863868


 67%|██████▋   | 32/48 [00:08<00:04,  3.55it/s]

5.581841985011598


 69%|██████▉   | 33/48 [00:09<00:04,  3.55it/s]

5.129937306706472


 71%|███████   | 34/48 [00:09<00:03,  3.54it/s]

5.004912432849947


 73%|███████▎  | 35/48 [00:09<00:03,  3.54it/s]

5.56809162807789


 75%|███████▌  | 36/48 [00:09<00:03,  3.53it/s]

6.35659884898354


 77%|███████▋  | 37/48 [00:10<00:03,  3.52it/s]

5.683609448930777


 79%|███████▉  | 38/48 [00:10<00:02,  3.51it/s]

8.172022366636876


 81%|████████▏ | 39/48 [00:10<00:02,  3.49it/s]

5.1809255557209255


 83%|████████▎ | 40/48 [00:11<00:02,  3.49it/s]

5.567672432617637


 85%|████████▌ | 41/48 [00:11<00:02,  3.48it/s]

4.855197323266491


 88%|████████▊ | 42/48 [00:11<00:01,  3.47it/s]

8.135826620625844


 90%|████████▉ | 43/48 [00:11<00:01,  3.44it/s]

5.703407243200001


 92%|█████████▏| 44/48 [00:12<00:01,  3.42it/s]

5.268876048371801


 94%|█████████▍| 45/48 [00:12<00:00,  3.42it/s]

5.203573569168554


 96%|█████████▌| 46/48 [00:12<00:00,  3.41it/s]

5.117107175166288


 98%|█████████▊| 47/48 [00:13<00:00,  3.39it/s]

5.592916167068363


100%|██████████| 48/48 [00:13<00:00,  3.58it/s]

4.798367229914929
Test MAE: 5.8855244395116415





# LSTM
In this part we train a simple bi-directional LSTM

In [None]:
class LSTM(nn.Module):
    def __init__(self, hidden=256, out_dim=185):
        super().__init__()
        self.hidden = hidden
        self.out_dim = out_dim
        self.lstm = nn.LSTM(2036, hidden_size=self.hidden, num_layers=2, bidirectional=True, batch_first=True)
        self.linear = nn.Linear(self.hidden * 2, self.out_dim)
    
    def forward(self, x):
        x, _ = self.lstm(x)
        x = self.linear(x)
        return x

In [None]:
lstm_model = LSTM().to(device)
lstm_optimizer = torch.optim.Adam(lstm_model.parameters())
loss_fn = nn.L1Loss()

for epoch in range(30):
    print(f"--- Epoch [{epoch + 1}/{30}] ---")
    total_epoch_loss = 0.
    for step, batch in enumerate(train_dataloader):
        features, targets = batch
        features = torch.from_numpy(scaler.transform(features))
        features, targets = features.float(), targets.float()
        features, targets = features.to(device), targets.to(device)
        targets = targets.unsqueeze(0)
        features = features.unsqueeze(0)

        preds = lstm_model(features)
        loss = loss_fn(targets, preds)
        total_epoch_loss += loss.detach().cpu()

        if step % 100 == 0:
            print("Loss:", loss.item())

        lstm_optimizer.zero_grad()
        loss.backward()
        lstm_optimizer.step()
    print(f"Epoch Loss: {total_epoch_loss / len(train_dataloader)}")

--- Epoch [1/30] ---
Loss: 5.549416542053223
Loss: 5.237961292266846
Loss: 8.095149040222168
Loss: 6.798412799835205
Loss: 5.570059776306152
Epoch Loss: 6.318953037261963
--- Epoch [2/30] ---
Loss: 5.544959545135498
Loss: 5.238502025604248
Loss: 8.09648609161377
Loss: 6.800079345703125
Loss: 5.575272560119629
Epoch Loss: 6.317306041717529
--- Epoch [3/30] ---
Loss: 5.541306495666504
Loss: 5.239346504211426
Loss: 8.089181900024414
Loss: 6.8014750480651855
Loss: 5.572729587554932
Epoch Loss: 6.3166728019714355
--- Epoch [4/30] ---
Loss: 5.542489528656006
Loss: 5.236428737640381
Loss: 8.09494400024414
Loss: 6.800201416015625
Loss: 5.568076133728027
Epoch Loss: 6.315825462341309
--- Epoch [5/30] ---
Loss: 5.5440497398376465
Loss: 5.2335333824157715
Loss: 8.132625579833984
Loss: 6.76423978805542
Loss: 5.565145492553711
Epoch Loss: 6.310324668884277
--- Epoch [6/30] ---
Loss: 5.535722732543945
Loss: 5.228562355041504
Loss: 8.079418182373047
Loss: 6.7209882736206055
Loss: 5.56702184677124
Epo

In [None]:
lstm_model.eval()
total_loss = 0.
total_elems = 0
for step, batch in enumerate(test_dataloader):
    features, targets = batch
    features = torch.from_numpy(scaler.transform(features))
    features, targets = features.float(), targets.float()
    features, targets = features.to(device), targets.to(device)
    targets = targets.unsqueeze(0)
    features = features.unsqueeze(0)
    with torch.inference_mode():
        preds = lstm_model(features)
        total_loss += torch.sum(torch.abs(preds - targets)).item()
        total_elems += torch.numel(targets)
print(f"MAE: {total_loss / total_elems}")

MAE: 5.760513931011978


# Conditional DDPM
In this part we treat the inputs and outputs as 1-channel images and apply conditional DDPM

In [8]:
from diffusion_model import Unet, p_losses, sample
U = Unet(dim=64, channels=1, dim_mults=(1, 2, 4), self_condition=True)
if torch.cuda.is_available():
    U.cuda()
    print('Models moved to GPU.')
u_optimizer = torch.optim.Adam(U.parameters(), 0.0002, [0.5, 0.999])

  from .autonotebook import tqdm as notebook_tqdm


Models moved to GPU.


In [None]:
num_epc = 5
for epoch in range(num_epc):
    print(f"--- Epoch [{epoch+1}/{num_epc}] ---")

    total_epoch_loss = 0.

    for step, batch in enumerate(train_dataloader):
        cond, targets = batch
        cond = torch.from_numpy(scaler.transform(cond))
        cond, targets = cond.float(), targets.float()
        cond, targets = cond.to(device), targets.to(device)
        targets = targets.unsqueeze(0).unsqueeze(0)
        cond = cond.unsqueeze(0)
        targets = torch.nn.functional.pad(targets, (0, 71, 0, 9), "constant", 0)

        # 1. Sample t uniformally for every example in the batch
        t = torch.randint(low=0, high=500, size=(1,), device=device).long()

        # 2. Get l1 loss
        loss = p_losses(U, targets, t, loss_type='l1', time_cond=cond)

        if step % 100 == 0:
            print("Loss:", loss.item())
    
        total_epoch_loss += loss.detach().item()

        u_optimizer.zero_grad()
        loss.backward()
        u_optimizer.step()
    
    print("Epoch Loss:", total_epoch_loss / len(train_dataloader))

    # Save every epoch
    print("Saving...")
    torch.save(U.state_dict(), "diffusion.pth")

--- Epoch [1/5] ---
Loss: 0.8664166331291199
Loss: 0.342109739780426
Loss: 0.5751746892929077
Loss: 0.36596542596817017
Loss: 0.6054816246032715
Epoch Loss: 0.5488351859479248
Saving...
--- Epoch [2/5] ---
Loss: 0.48535293340682983
Loss: 0.5322597026824951
Loss: 0.36875081062316895
Loss: 0.5387202501296997
Loss: 0.27861088514328003
Epoch Loss: 0.4736612992528955
Saving...
--- Epoch [3/5] ---
Loss: 0.2649272382259369
Loss: 0.243866428732872
Loss: 0.5191173553466797
Loss: 0.293118953704834
Loss: 0.5198401212692261
Epoch Loss: 0.4562738733946864
Saving...
--- Epoch [4/5] ---
Loss: 0.7090773582458496
Loss: 0.5887994766235352
Loss: 0.47108373045921326
Loss: 0.5159890055656433
Loss: 0.32491758465766907
Epoch Loss: 0.44573564989186876
Saving...
--- Epoch [5/5] ---
Loss: 0.4543801248073578
Loss: 0.4828573763370514
Loss: 0.42861083149909973
Loss: 0.48692965507507324
Loss: 0.3912566900253296
Epoch Loss: 0.44426095021101397
Saving...


In [None]:
# sampling loop
U = Unet(dim=64, channels=1, dim_mults=(1, 2, 4), self_condition=True)
if torch.cuda.is_available():
    U.cuda()
U.load_state_dict(torch.load('diffusion.pth'))

U.eval()
total_loss = 0.
total_elems = 0
for step, batch in tqdm(enumerate(test_dataloader), total=len(test_dataloader)):
    cond, targets = batch
    cond = torch.from_numpy(scaler.transform(cond))
    cond, targets = cond.float(), targets.float()
    cond, targets = cond.to(device), targets.to(device)
    targets = targets.unsqueeze(0).unsqueeze(0)
    cond = cond.unsqueeze(0)
    targets = torch.nn.functional.pad(targets, (0, 71, 0, 9), "constant", 0)
    # torch inference_mode already annotated for sample() function
    samples = sample(U, (64, 256), batch_size=1, channels=1, time_cond=cond)
    total_loss += torch.sum(torch.abs(torch.from_numpy(samples[-1]) - targets.detach().cpu())).item()
    total_elems += torch.numel(targets)
    print(total_loss / total_elems)
print(f"MAE: {total_loss / total_elems}")

  0%|          | 0/48 [00:00<?, ?it/s]

sampling loop time step:   0%|          | 0/500 [00:00<?, ?it/s]

  2%|▏         | 1/48 [01:11<56:12, 71.76s/it]

4.201510429382324


sampling loop time step:   0%|          | 0/500 [00:00<?, ?it/s]

  4%|▍         | 2/48 [02:24<55:18, 72.13s/it]

4.367911338806152


sampling loop time step:   0%|          | 0/500 [00:00<?, ?it/s]

  6%|▋         | 3/48 [03:36<54:11, 72.25s/it]

4.455802122751872


sampling loop time step:   0%|          | 0/500 [00:00<?, ?it/s]

  8%|▊         | 4/48 [04:48<52:54, 72.14s/it]

4.494023323059082


sampling loop time step:   0%|          | 0/500 [00:00<?, ?it/s]

 10%|█         | 5/48 [06:00<51:40, 72.10s/it]

4.475032806396484


sampling loop time step:   0%|          | 0/500 [00:00<?, ?it/s]

 12%|█▎        | 6/48 [07:12<50:25, 72.03s/it]

4.495413939158122


sampling loop time step:   0%|          | 0/500 [00:00<?, ?it/s]

 15%|█▍        | 7/48 [08:24<49:11, 72.00s/it]

4.604225022452218


sampling loop time step:   0%|          | 0/500 [00:00<?, ?it/s]

 17%|█▋        | 8/48 [09:36<47:58, 71.97s/it]

4.6070767641067505


sampling loop time step:   0%|          | 0/500 [00:00<?, ?it/s]

 19%|█▉        | 9/48 [10:48<46:45, 71.95s/it]

4.579028765360515


sampling loop time step:   0%|          | 0/500 [00:00<?, ?it/s]

 21%|██        | 10/48 [12:00<45:33, 71.95s/it]

4.528861236572266


sampling loop time step:   0%|          | 0/500 [00:00<?, ?it/s]

 23%|██▎       | 11/48 [13:12<44:23, 71.99s/it]

4.523338491266424


sampling loop time step:   0%|          | 0/500 [00:00<?, ?it/s]

 25%|██▌       | 12/48 [14:24<43:10, 71.97s/it]

4.542020678520203


sampling loop time step:   0%|          | 0/500 [00:00<?, ?it/s]

 27%|██▋       | 13/48 [15:36<41:58, 71.95s/it]

4.533690856053279


sampling loop time step:   0%|          | 0/500 [00:00<?, ?it/s]

 29%|██▉       | 14/48 [16:47<40:45, 71.94s/it]

4.5020498888833185


sampling loop time step:   0%|          | 0/500 [00:00<?, ?it/s]

 31%|███▏      | 15/48 [17:59<39:33, 71.93s/it]

4.477288659413656


sampling loop time step:   0%|          | 0/500 [00:00<?, ?it/s]

 33%|███▎      | 16/48 [19:11<38:22, 71.94s/it]

4.447746425867081


sampling loop time step:   0%|          | 0/500 [00:00<?, ?it/s]

 35%|███▌      | 17/48 [20:23<37:10, 71.94s/it]

4.434576567481546


sampling loop time step:   0%|          | 0/500 [00:00<?, ?it/s]

 38%|███▊      | 18/48 [21:35<35:58, 71.94s/it]

4.418250693215264


sampling loop time step:   0%|          | 0/500 [00:00<?, ?it/s]

 40%|███▉      | 19/48 [22:47<34:47, 71.97s/it]

4.397729823463841


sampling loop time step:   0%|          | 0/500 [00:00<?, ?it/s]

 42%|████▏     | 20/48 [23:59<33:34, 71.95s/it]

4.382056307792664


sampling loop time step:   0%|          | 0/500 [00:00<?, ?it/s]

 44%|████▍     | 21/48 [25:11<32:22, 71.95s/it]

4.371080920809791


sampling loop time step:   0%|          | 0/500 [00:00<?, ?it/s]

 46%|████▌     | 22/48 [26:23<31:10, 71.94s/it]

4.367576534097845


sampling loop time step:   0%|          | 0/500 [00:00<?, ?it/s]

 48%|████▊     | 23/48 [27:35<29:58, 71.93s/it]

4.389462533204452


sampling loop time step:   0%|          | 0/500 [00:00<?, ?it/s]

 50%|█████     | 24/48 [28:47<28:46, 71.93s/it]

4.385650336742401


sampling loop time step:   0%|          | 0/500 [00:00<?, ?it/s]

 52%|█████▏    | 25/48 [29:59<27:34, 71.95s/it]

4.382399520874023


sampling loop time step:   0%|          | 0/500 [00:00<?, ?it/s]

 54%|█████▍    | 26/48 [31:11<26:22, 71.93s/it]

4.449407980992244


sampling loop time step:   0%|          | 0/500 [00:00<?, ?it/s]

 56%|█████▋    | 27/48 [32:23<25:10, 71.93s/it]

4.458234769326669


sampling loop time step:   0%|          | 0/500 [00:00<?, ?it/s]

 58%|█████▊    | 28/48 [33:35<23:58, 71.93s/it]

4.451662080628531


sampling loop time step:   0%|          | 0/500 [00:00<?, ?it/s]

 60%|██████    | 29/48 [34:47<22:46, 71.93s/it]

4.447116868249301


sampling loop time step:   0%|          | 0/500 [00:00<?, ?it/s]

 62%|██████▎   | 30/48 [35:58<21:34, 71.93s/it]

4.448305098215739


sampling loop time step:   0%|          | 0/500 [00:00<?, ?it/s]

 65%|██████▍   | 31/48 [37:10<20:22, 71.93s/it]

4.445637072286298


sampling loop time step:   0%|          | 0/500 [00:00<?, ?it/s]

 67%|██████▋   | 32/48 [38:22<19:10, 71.93s/it]

4.434021577239037


sampling loop time step:   0%|          | 0/500 [00:00<?, ?it/s]

 69%|██████▉   | 33/48 [39:34<17:58, 71.93s/it]

4.41458933281176


sampling loop time step:   0%|          | 0/500 [00:00<?, ?it/s]

 71%|███████   | 34/48 [40:46<16:47, 71.93s/it]

4.402802467346191


sampling loop time step:   0%|          | 0/500 [00:00<?, ?it/s]

 73%|███████▎  | 35/48 [41:58<15:35, 71.93s/it]

4.392192963191441


sampling loop time step:   0%|          | 0/500 [00:00<?, ?it/s]

 75%|███████▌  | 36/48 [43:10<14:23, 71.94s/it]

4.40048282676273


sampling loop time step:   0%|          | 0/500 [00:00<?, ?it/s]

 77%|███████▋  | 37/48 [44:22<13:11, 71.95s/it]

4.390925587834539


sampling loop time step:   0%|          | 0/500 [00:00<?, ?it/s]

 79%|███████▉  | 38/48 [45:34<11:59, 71.95s/it]

4.423069376694529


sampling loop time step:   0%|          | 0/500 [00:00<?, ?it/s]

 81%|████████▏ | 39/48 [46:46<10:47, 71.96s/it]

4.408555727738601


sampling loop time step:   0%|          | 0/500 [00:00<?, ?it/s]

 83%|████████▎ | 40/48 [47:58<09:35, 71.94s/it]

4.403669440746308


sampling loop time step:   0%|          | 0/500 [00:00<?, ?it/s]

 85%|████████▌ | 41/48 [49:10<08:23, 71.95s/it]

4.387494308192555


sampling loop time step:   0%|          | 0/500 [00:00<?, ?it/s]

 88%|████████▊ | 42/48 [50:22<07:11, 71.94s/it]

4.414678380602882


sampling loop time step:   0%|          | 0/500 [00:00<?, ?it/s]

 90%|████████▉ | 43/48 [51:34<05:59, 71.94s/it]

4.416277574938397


sampling loop time step:   0%|          | 0/500 [00:00<?, ?it/s]

 92%|█████████▏| 44/48 [52:46<04:47, 71.94s/it]

4.409924279559743


sampling loop time step:   0%|          | 0/500 [00:00<?, ?it/s]

 94%|█████████▍| 45/48 [53:58<03:35, 71.94s/it]

4.399718639585707


sampling loop time step:   0%|          | 0/500 [00:00<?, ?it/s]

 96%|█████████▌| 46/48 [55:09<02:23, 71.93s/it]

4.390979839407879


sampling loop time step:   0%|          | 0/500 [00:00<?, ?it/s]

 98%|█████████▊| 47/48 [56:21<01:11, 71.94s/it]

4.384790430677698


sampling loop time step:   0%|          | 0/500 [00:00<?, ?it/s]

100%|██████████| 48/48 [57:33<00:00, 71.96s/it]

4.376183301210403
MAE: 4.376183301210403





# TimeGrad

In this part we implement TimeGrad based on the original paper, and apply to our time series, trying to transform the WAP from our data into the target

In [8]:
CONTEXT_LENGTH = 55
HIDDEN_SIZE = 256
NUM_LAYERS = 2
NUM_CELLS = 256
RESIDUAL_LAYERS = 8
RESIDUAL_CHANNELS = 256
DILATION_CYCLE_LENGTH = 2
RESIDUAL_HIDDEN = 256
CONDITIONING_LENGTH = HIDDEN_SIZE

In [9]:
x, y = train_dataset[0]
num_features, num_target_features = x.shape[0], y.shape[0]
feature_cols = [col for col in train_df if not col.startswith("target") and not col.startswith("window_id")]
wap_idx = [i for i, col in enumerate(feature_cols) if col.startswith("wap")]
assert len(wap_idx) == num_target_features

time_grad_model = TimeGrad(
    num_features=num_features,
    num_target_features=num_target_features,
    context_length=CONTEXT_LENGTH,
    non_covariate_col_idx=wap_idx,
    hidden_size=HIDDEN_SIZE,
    num_layers=NUM_LAYERS,
    num_cells=NUM_CELLS,
    residual_layers=RESIDUAL_LAYERS,
    residual_channels=RESIDUAL_CHANNELS,
    dilation_cycle_length=DILATION_CYCLE_LENGTH,
    residual_hidden=RESIDUAL_HIDDEN,
    conditioning_length=CONDITIONING_LENGTH
).to(device)

In [10]:
time_grad_optimizer = torch.optim.Adam(time_grad_model.parameters())

for epoch in range(30):
    print(f"--- Epoch [{epoch + 1}/{30}] ---")
    total_epoch_loss = 0.
    for step, batch in enumerate(train_dataloader):
        features, targets = batch
        features, targets = features.float(), targets.float()
        features, targets = features.to(device), targets.to(device)
        features, targets = features.unsqueeze(0), targets.unsqueeze(0)

        loss, _, _ = time_grad_model(features, targets)
        total_epoch_loss += loss.detach().cpu()

        if step % 100 == 0:
            print("Loss:", loss.item())

        time_grad_optimizer.zero_grad()
        loss.backward()
        time_grad_optimizer.step()
    print(f"Epoch Loss: {total_epoch_loss / len(train_dataloader)}")
    
    # Save every epoch
    print("Saving...")
    torch.save(time_grad_model.state_dict(), "timegrad.pth")

--- Epoch [1/30] ---
Loss: 0.9918442964553833
Loss: 0.764417827129364
Loss: 0.9035812616348267
Loss: 0.8236040472984314
Loss: 0.8140358328819275
Epoch Loss: 0.8518611788749695
Saving...
--- Epoch [2/30] ---
Loss: 0.8248192667961121
Loss: 0.8697466254234314
Loss: 0.8730795383453369
Loss: 0.90981525182724
Loss: 0.859169602394104
Epoch Loss: 0.8389438986778259
Saving...
--- Epoch [3/30] ---
Loss: 0.8506718277931213
Loss: 0.8634642958641052
Loss: 0.9100229740142822
Loss: 0.9128495454788208
Loss: 0.8314695954322815
Epoch Loss: 0.8396666646003723
Saving...
--- Epoch [4/30] ---
Loss: 0.8246649503707886
Loss: 0.8713688254356384
Loss: 0.9252825379371643
Loss: 0.9046648740768433
Loss: 0.7982223033905029
Epoch Loss: 0.8388842344284058
Saving...
--- Epoch [5/30] ---
Loss: 0.78267902135849
Loss: 0.8183034062385559
Loss: 0.8853309154510498
Loss: 0.8482010364532471
Loss: 0.7952440977096558
Epoch Loss: 0.8380935788154602
Saving...
--- Epoch [6/30] ---
Loss: 0.8352630138397217
Loss: 0.804884672164917
L

In [12]:
x, y = train_dataset[0]
num_features, num_target_features = x.shape[0], y.shape[0]
feature_cols = [col for col in train_df if not col.startswith("target") and not col.startswith("window_id")]
wap_idx = [i for i, col in enumerate(feature_cols) if col.startswith("wap")]
assert len(wap_idx) == num_target_features

time_grad_model = TimeGradPredictionNetwork(
    num_parallel_samples=1,
    prediction_length=CONTEXT_LENGTH,
    num_features=num_features,
    num_target_features=num_target_features,
    context_length=CONTEXT_LENGTH,
    non_covariate_col_idx=wap_idx,
    hidden_size=HIDDEN_SIZE,
    num_layers=NUM_LAYERS,
    num_cells=NUM_CELLS,
    residual_layers=RESIDUAL_LAYERS,
    residual_channels=RESIDUAL_CHANNELS,
    dilation_cycle_length=DILATION_CYCLE_LENGTH,
    residual_hidden=RESIDUAL_HIDDEN,
    conditioning_length=CONDITIONING_LENGTH
).to(device)

In [13]:
# sampling loop
time_grad_model.load_state_dict(torch.load('timegrad.pth'))

total_loss = 0.
total_elems = 0
for step, batch in tqdm(enumerate(test_dataloader), total=len(test_dataloader)):
    features, targets = batch
    features, targets = features.float(), targets.float()
    features, targets = features.to(device), targets.to(device)
    features, targets = features.unsqueeze(0), targets.unsqueeze(0)

    with torch.inference_mode():
        samples = time_grad_model(features)
    total_loss += torch.sum(torch.abs(samples[0] - targets)).item()
    total_elems += torch.numel(targets)
    print(total_loss / total_elems)
print(f"MAE: {total_loss / total_elems}")

  2%|▏         | 1/48 [00:00<00:44,  1.05it/s]

7.694186885749386


  4%|▍         | 2/48 [00:01<00:43,  1.07it/s]

7.911542920761671


  6%|▋         | 3/48 [00:02<00:41,  1.07it/s]

7.9253173628173625


  8%|▊         | 4/48 [00:03<00:40,  1.08it/s]

7.85518082002457


 10%|█         | 5/48 [00:04<00:39,  1.08it/s]

7.8722515356265355


 12%|█▎        | 6/48 [00:05<00:38,  1.08it/s]

7.8828391175266175


 15%|█▍        | 7/48 [00:06<00:38,  1.08it/s]

8.412434187434187


 17%|█▋        | 8/48 [00:07<00:37,  1.08it/s]

8.43403485872236


 19%|█▉        | 9/48 [00:08<00:36,  1.08it/s]

8.379474303849303


 21%|██        | 10/48 [00:09<00:35,  1.08it/s]

8.302979806511056


 23%|██▎       | 11/48 [00:10<00:34,  1.08it/s]

8.233068530824212


 25%|██▌       | 12/48 [00:11<00:33,  1.08it/s]

8.239568680896806


 27%|██▋       | 13/48 [00:12<00:32,  1.08it/s]

8.208337448024949


 29%|██▉       | 14/48 [00:12<00:31,  1.08it/s]

8.142971163127413


 31%|███▏      | 15/48 [00:13<00:30,  1.08it/s]

8.085344748157247


 33%|███▎      | 16/48 [00:14<00:29,  1.08it/s]

8.026081033092751


 35%|███▌      | 17/48 [00:15<00:28,  1.07it/s]

8.003442106879607


 38%|███▊      | 18/48 [00:16<00:27,  1.07it/s]

8.016361546205296


 40%|███▉      | 19/48 [00:17<00:27,  1.07it/s]

7.991005633324712


 42%|████▏     | 20/48 [00:18<00:26,  1.07it/s]

7.974786317567568


 44%|████▍     | 21/48 [00:19<00:25,  1.07it/s]

7.942720033345033


 46%|████▌     | 22/48 [00:20<00:24,  1.07it/s]

7.92932812849006


 48%|████▊     | 23/48 [00:21<00:23,  1.07it/s]

7.954429381209272


 50%|█████     | 24/48 [00:22<00:22,  1.07it/s]

7.944922322891073


 52%|█████▏    | 25/48 [00:23<00:21,  1.07it/s]

7.940217997542997


 54%|█████▍    | 26/48 [00:24<00:20,  1.06it/s]

8.03528420903421


 56%|█████▋    | 27/48 [00:25<00:19,  1.06it/s]

8.0276397988898


 58%|█████▊    | 28/48 [00:26<00:18,  1.06it/s]

8.012578947657072


 60%|██████    | 29/48 [00:27<00:17,  1.06it/s]

7.98606980216894


 62%|██████▎   | 30/48 [00:28<00:16,  1.06it/s]

7.984210892710893


 65%|██████▍   | 31/48 [00:28<00:16,  1.06it/s]

7.998219019973052


 67%|██████▋   | 32/48 [00:29<00:15,  1.06it/s]

7.994692346821253


 69%|██████▉   | 33/48 [00:30<00:14,  1.06it/s]

7.987559121621621


 71%|███████   | 34/48 [00:31<00:13,  1.06it/s]

7.973795594558462


 73%|███████▎  | 35/48 [00:32<00:12,  1.06it/s]

7.965572108634609


 75%|███████▌  | 36/48 [00:33<00:11,  1.06it/s]

7.965196581524706


 77%|███████▋  | 37/48 [00:34<00:10,  1.06it/s]

7.951715730626203


 79%|███████▉  | 38/48 [00:35<00:09,  1.06it/s]

7.994062702056123


 81%|████████▏ | 39/48 [00:36<00:08,  1.06it/s]

7.980420486045486


 83%|████████▎ | 40/48 [00:37<00:07,  1.06it/s]

7.96212046990172


 85%|████████▌ | 41/48 [00:38<00:06,  1.06it/s]

7.944544480733505


 88%|████████▊ | 42/48 [00:39<00:05,  1.06it/s]

8.026688677313677


 90%|████████▉ | 43/48 [00:40<00:04,  1.06it/s]

8.032369864579167


 92%|█████████▏| 44/48 [00:41<00:03,  1.06it/s]

8.027864624050704


 94%|█████████▍| 45/48 [00:42<00:02,  1.06it/s]

8.019674003549003


 96%|█████████▌| 46/48 [00:43<00:01,  1.06it/s]

8.00874849775665


 98%|█████████▊| 47/48 [00:44<00:00,  1.06it/s]

8.007391068534686


100%|██████████| 48/48 [00:45<00:00,  1.07it/s]

7.995133343570844
MAE: 7.995133343570844





# TimeGrad-Adapted Conditional Diffusion

In this part we do our own version of TimeGrad, where we run the input time series through an LSTM and also condition on its hidden states to do our diffusion.

In [8]:
from cond_diffusion_model import Unet, p_losses, sample
U = Unet(dim=128, channels=1, dim_mults=(1, 2, 4, 8), self_condition=True)
if torch.cuda.is_available():
    U.cuda()
    print('Models moved to GPU.')
u_optimizer = torch.optim.Adam(U.parameters(), 0.0001, [0.5, 0.999])

  from .autonotebook import tqdm as notebook_tqdm


Models moved to GPU.


In [9]:
num_epc = 5
for epoch in range(num_epc):
    print(f"--- Epoch [{epoch+1}/{num_epc}] ---")

    total_epoch_loss = 0.

    for step, batch in enumerate(train_dataloader):
        cond, targets = batch
        cond = torch.from_numpy(scaler.transform(cond))
        cond, targets = cond.float(), targets.float()
        cond, targets = cond.to(device), targets.to(device)
        targets = targets.unsqueeze(0).unsqueeze(0)
        cond = cond.unsqueeze(0)
        targets = torch.nn.functional.pad(targets, (0, 71, 0, 9), "constant", 0)

        # 1. Sample t uniformally for every example in the batch
        t = torch.randint(low=0, high=500, size=(1,), device=device).long()

        # 2. Get l1 loss
        loss = p_losses(U, targets, t, loss_type='l2', time_cond=cond)

        if step % 100 == 0:
            print("Loss:", loss.item())
    
        total_epoch_loss += loss.detach().item()

        u_optimizer.zero_grad()
        loss.backward()
        u_optimizer.step()
    
    print("Epoch Loss:", total_epoch_loss / len(train_dataloader))

    # Save every epoch
    print("Saving...")
    torch.save(U.state_dict(), "cond_diffusion.pth")

--- Epoch [1/5] ---
Loss: 1.354584813117981
Loss: 0.2350679486989975
Loss: 0.6375114917755127
Loss: 0.28499627113342285
Loss: 0.7179057002067566
Epoch Loss: 0.6123815074253303
Saving...
--- Epoch [2/5] ---
Loss: 0.4805217683315277
Loss: 0.6244086027145386
Loss: 0.2914869785308838
Loss: 0.6305547952651978
Loss: 0.17532482743263245
Epoch Loss: 0.48986321815029304
Saving...
--- Epoch [3/5] ---
Loss: 0.16021402180194855
Loss: 0.12643244862556458
Loss: 0.605862021446228
Loss: 0.1947467029094696
Loss: 0.5922020673751831
Epoch Loss: 0.468313440508435
Saving...
--- Epoch [4/5] ---
Loss: 0.5997691750526428
Loss: 0.6881994605064392
Loss: 0.5229768753051758
Loss: 0.6171546578407288
Loss: 0.24648571014404297
Epoch Loss: 0.4588907350706449
Saving...
--- Epoch [5/5] ---
Loss: 0.4747392535209656
Loss: 0.5189821720123291
Loss: 0.42935848236083984
Loss: 0.5588717460632324
Loss: 0.351311057806015
Epoch Loss: 0.4622343526357316
Saving...


In [10]:
# sampling loop
U = Unet(dim=128, channels=1, dim_mults=(1, 2, 4, 8), self_condition=True)
if torch.cuda.is_available():
    U.cuda()
U.load_state_dict(torch.load('cond_diffusion.pth'))

U.eval()
total_loss = 0.
total_elems = 0
for step, batch in tqdm(enumerate(test_dataloader), total=len(test_dataloader)):
    cond, targets = batch
    cond = torch.from_numpy(scaler.transform(cond))
    cond, targets = cond.float(), targets.float()
    cond, targets = cond.to(device), targets.to(device)
    targets = targets.unsqueeze(0).unsqueeze(0)
    cond = cond.unsqueeze(0)
    targets = torch.nn.functional.pad(targets, (0, 71, 0, 9), "constant", 0)
    # torch inference_mode already annotated for sample() function
    samples = sample(U, (64, 256), batch_size=1, channels=1, time_cond=cond)
    total_loss += torch.sum(torch.abs(torch.from_numpy(samples[-1]) - targets.detach().cpu())).item()
    total_elems += torch.numel(targets)
    print(total_loss / total_elems)
print(f"MAE: {total_loss / total_elems}")

sampling loop time step: 100%|██████████| 500/500 [01:18<00:00,  6.33it/s]
  2%|▏         | 1/48 [01:18<1:01:52, 78.99s/it]

4.113003730773926


sampling loop time step: 100%|██████████| 500/500 [01:18<00:00,  6.34it/s]
  4%|▍         | 2/48 [02:37<1:00:33, 78.98s/it]

4.315007448196411


sampling loop time step: 100%|██████████| 500/500 [01:18<00:00,  6.34it/s]
  6%|▋         | 3/48 [03:56<59:13, 78.97s/it]  

4.344380855560303


sampling loop time step: 100%|██████████| 500/500 [01:18<00:00,  6.34it/s]
  8%|▊         | 4/48 [05:15<57:54, 78.96s/it]

4.341216683387756


sampling loop time step: 100%|██████████| 500/500 [01:18<00:00,  6.34it/s]
 10%|█         | 5/48 [06:34<56:34, 78.95s/it]

4.341076850891113


sampling loop time step: 100%|██████████| 500/500 [01:18<00:00,  6.34it/s]
 12%|█▎        | 6/48 [07:53<55:15, 78.95s/it]

4.371092875798543


sampling loop time step: 100%|██████████| 500/500 [01:18<00:00,  6.33it/s]
 15%|█▍        | 7/48 [09:12<53:57, 78.96s/it]

4.485645907265799


sampling loop time step: 100%|██████████| 500/500 [01:18<00:00,  6.34it/s]
 17%|█▋        | 8/48 [10:31<52:38, 78.97s/it]

4.491291701793671


sampling loop time step: 100%|██████████| 500/500 [01:18<00:00,  6.34it/s]
 19%|█▉        | 9/48 [11:50<51:19, 78.96s/it]

4.461538526746962


sampling loop time step: 100%|██████████| 500/500 [01:18<00:00,  6.34it/s]
 21%|██        | 10/48 [13:09<50:00, 78.95s/it]

4.432057332992554


sampling loop time step: 100%|██████████| 500/500 [01:18<00:00,  6.34it/s]
 23%|██▎       | 11/48 [14:28<48:41, 78.95s/it]

4.418404925953258


sampling loop time step: 100%|██████████| 500/500 [01:18<00:00,  6.34it/s]
 25%|██▌       | 12/48 [15:47<47:22, 78.95s/it]

4.426957607269287


sampling loop time step: 100%|██████████| 500/500 [01:18<00:00,  6.34it/s]
 27%|██▋       | 13/48 [17:06<46:03, 78.96s/it]

4.418472106640156


sampling loop time step: 100%|██████████| 500/500 [01:18<00:00,  6.34it/s]
 29%|██▉       | 14/48 [18:25<44:44, 78.95s/it]

4.377064074788775


sampling loop time step: 100%|██████████| 500/500 [01:18<00:00,  6.34it/s]
 31%|███▏      | 15/48 [19:44<43:25, 78.94s/it]

4.356523434321086


sampling loop time step: 100%|██████████| 500/500 [01:18<00:00,  6.34it/s]
 33%|███▎      | 16/48 [21:03<42:06, 78.95s/it]

4.328624039888382


sampling loop time step: 100%|██████████| 500/500 [01:18<00:00,  6.34it/s]
 35%|███▌      | 17/48 [22:22<40:47, 78.95s/it]

4.318182047675638


sampling loop time step: 100%|██████████| 500/500 [01:18<00:00,  6.34it/s]
 38%|███▊      | 18/48 [23:41<39:28, 78.94s/it]

4.309334145651923


sampling loop time step: 100%|██████████| 500/500 [01:18<00:00,  6.34it/s]
 40%|███▉      | 19/48 [25:00<38:09, 78.94s/it]

4.291256653635125


sampling loop time step: 100%|██████████| 500/500 [01:18<00:00,  6.34it/s]
 42%|████▏     | 20/48 [26:19<36:50, 78.94s/it]

4.274409139156342


sampling loop time step: 100%|██████████| 500/500 [01:18<00:00,  6.34it/s]
 44%|████▍     | 21/48 [27:37<35:31, 78.94s/it]

4.262370870226905


sampling loop time step: 100%|██████████| 500/500 [01:18<00:00,  6.34it/s]
 46%|████▌     | 22/48 [28:56<34:12, 78.94s/it]

4.267400275577199


sampling loop time step: 100%|██████████| 500/500 [01:18<00:00,  6.34it/s]
 48%|████▊     | 23/48 [30:15<32:53, 78.92s/it]

4.285747289657593


sampling loop time step: 100%|██████████| 500/500 [01:18<00:00,  6.34it/s]
 50%|█████     | 24/48 [31:34<31:34, 78.92s/it]

4.280434479316075


sampling loop time step: 100%|██████████| 500/500 [01:18<00:00,  6.34it/s]
 52%|█████▏    | 25/48 [32:53<30:15, 78.93s/it]

4.278882894515991


sampling loop time step: 100%|██████████| 500/500 [01:18<00:00,  6.34it/s]
 54%|█████▍    | 26/48 [34:12<28:56, 78.93s/it]

4.3549306851166945


sampling loop time step: 100%|██████████| 500/500 [01:18<00:00,  6.34it/s]
 56%|█████▋    | 27/48 [35:31<27:37, 78.92s/it]

4.361088708594993


sampling loop time step: 100%|██████████| 500/500 [01:18<00:00,  6.34it/s]
 58%|█████▊    | 28/48 [36:50<26:18, 78.93s/it]

4.349992215633392


sampling loop time step: 100%|██████████| 500/500 [01:18<00:00,  6.34it/s]
 60%|██████    | 29/48 [38:09<24:59, 78.93s/it]

4.344578192151826


sampling loop time step: 100%|██████████| 500/500 [01:18<00:00,  6.34it/s]
 62%|██████▎   | 30/48 [39:28<23:40, 78.93s/it]

4.339537723859151


sampling loop time step: 100%|██████████| 500/500 [01:18<00:00,  6.34it/s]
 65%|██████▍   | 31/48 [40:47<22:21, 78.93s/it]

4.341339180546422


sampling loop time step: 100%|██████████| 500/500 [01:18<00:00,  6.34it/s]
 67%|██████▋   | 32/48 [42:06<21:03, 78.94s/it]

4.329653985798359


sampling loop time step: 100%|██████████| 500/500 [01:18<00:00,  6.34it/s]
 69%|██████▉   | 33/48 [43:25<19:44, 78.94s/it]

4.316612930008859


sampling loop time step: 100%|██████████| 500/500 [01:18<00:00,  6.34it/s]
 71%|███████   | 34/48 [44:44<18:25, 78.94s/it]

4.299175374648151


sampling loop time step: 100%|██████████| 500/500 [01:18<00:00,  6.34it/s]
 73%|███████▎  | 35/48 [46:03<17:06, 78.94s/it]

4.295046438489641


sampling loop time step: 100%|██████████| 500/500 [01:18<00:00,  6.34it/s]
 75%|███████▌  | 36/48 [47:21<15:47, 78.94s/it]

4.298680702845256


sampling loop time step: 100%|██████████| 500/500 [01:18<00:00,  6.34it/s]
 77%|███████▋  | 37/48 [48:40<14:28, 78.94s/it]

4.2928874557082715


sampling loop time step: 100%|██████████| 500/500 [01:18<00:00,  6.34it/s]
 79%|███████▉  | 38/48 [49:59<13:09, 78.94s/it]

4.330372320978265


sampling loop time step: 100%|██████████| 500/500 [01:18<00:00,  6.34it/s]
 81%|████████▏ | 39/48 [51:18<11:50, 78.94s/it]

4.322590815715301


sampling loop time step: 100%|██████████| 500/500 [01:18<00:00,  6.34it/s]
 83%|████████▎ | 40/48 [52:37<10:31, 78.94s/it]

4.319117915630341


sampling loop time step: 100%|██████████| 500/500 [01:18<00:00,  6.34it/s]
 85%|████████▌ | 41/48 [53:56<09:12, 78.94s/it]

4.304647637576592


sampling loop time step: 100%|██████████| 500/500 [01:18<00:00,  6.34it/s]
 88%|████████▊ | 42/48 [55:15<07:53, 78.94s/it]

4.333165140379043


sampling loop time step: 100%|██████████| 500/500 [01:18<00:00,  6.34it/s]
 90%|████████▉ | 43/48 [56:34<06:34, 78.94s/it]

4.331406765205916


sampling loop time step: 100%|██████████| 500/500 [01:18<00:00,  6.34it/s]
 92%|█████████▏| 44/48 [57:53<05:15, 78.94s/it]

4.319239681417292


sampling loop time step: 100%|██████████| 500/500 [01:18<00:00,  6.34it/s]
 94%|█████████▍| 45/48 [59:12<03:56, 78.94s/it]

4.3122179455227325


sampling loop time step: 100%|██████████| 500/500 [01:18<00:00,  6.34it/s]
 96%|█████████▌| 46/48 [1:00:31<02:37, 78.93s/it]

4.303508618603582


sampling loop time step: 100%|██████████| 500/500 [01:18<00:00,  6.34it/s]
 98%|█████████▊| 47/48 [1:01:50<01:18, 78.92s/it]

4.297000174826764


sampling loop time step: 100%|██████████| 500/500 [01:18<00:00,  6.34it/s]
100%|██████████| 48/48 [1:03:09<00:00, 78.94s/it]

4.28546083966891
MAE: 4.28546083966891





# TimeGrad-Adapted Conditional Diffusion with FRFT

Here we perform our modified TimeGrad and add on conditioning on the Fractional Fourier transform.

In [8]:
from fft_cond_diffusion_model import Unet, p_losses, sample
U = Unet(dim=128, channels=1, dim_mults=(1, 2, 4, 8), self_condition=True)
if torch.cuda.is_available():
    U.cuda()
    print('Models moved to GPU.')
u_optimizer = torch.optim.Adam(U.parameters(), 0.0001, [0.5, 0.999])

  from .autonotebook import tqdm as notebook_tqdm


Models moved to GPU.


In [9]:
num_epc = 5
for epoch in range(num_epc):
    print(f"--- Epoch [{epoch+1}/{num_epc}] ---")

    total_epoch_loss = 0.

    for step, batch in enumerate(train_dataloader):
        cond, targets = batch
        cond = torch.from_numpy(scaler.transform(cond))
        cond, targets = cond.float(), targets.float()
        cond, targets = cond.to(device), targets.to(device)
        targets = targets.unsqueeze(0).unsqueeze(0)
        cond = cond.unsqueeze(0)
        targets = torch.nn.functional.pad(targets, (0, 71, 0, 9), "constant", 0)

        # 1. Sample t uniformally for every example in the batch
        t = torch.randint(low=0, high=500, size=(1,), device=device).long()

        # 2. Get l1 loss
        loss = p_losses(U, targets, t, loss_type='l2', time_cond=cond)

        if step % 100 == 0:
            print("Loss:", loss.item())
    
        total_epoch_loss += loss.detach().item()

        u_optimizer.zero_grad()
        loss.backward()
        u_optimizer.step()
    
    print("Epoch Loss:", total_epoch_loss / len(train_dataloader))

    # Save every epoch
    print("Saving...")
    torch.save(U.state_dict(), "frft_cond_diffusion.pth")

--- Epoch [1/5] ---
Loss: 1.3692975044250488
Loss: 0.22155994176864624
Loss: 0.606671929359436
Loss: 0.29038989543914795
Loss: 0.7157468795776367
Epoch Loss: 0.5967875802351751
Saving...
--- Epoch [2/5] ---
Loss: 0.48445069789886475
Loss: 0.624477207660675
Loss: 0.2906903326511383
Loss: 0.6239463686943054
Loss: 0.19061361253261566
Epoch Loss: 0.48953074664221624
Saving...
--- Epoch [3/5] ---
Loss: 0.16034150123596191
Loss: 0.12739458680152893
Loss: 0.6061078906059265
Loss: 0.1953548789024353
Loss: 0.5962360501289368
Epoch Loss: 0.4697800937687277
Saving...
--- Epoch [4/5] ---
Loss: 0.6069482564926147
Loss: 0.6854190826416016
Loss: 0.5220898389816284
Loss: 0.6204911470413208
Loss: 0.24741126596927643
Epoch Loss: 0.45765033223199514
Saving...
--- Epoch [5/5] ---
Loss: 0.47280484437942505
Loss: 0.515881359577179
Loss: 0.42626556754112244
Loss: 0.5580533742904663
Loss: 0.3658735454082489
Epoch Loss: 0.46285366972448094
Saving...


In [11]:
# sampling loop
U = Unet(dim=128, channels=1, dim_mults=(1, 2, 4, 8), self_condition=True)
if torch.cuda.is_available():
    U.cuda()
U.load_state_dict(torch.load('frft_cond_diffusion.pth'))

U.eval()
total_loss = 0.
total_elems = 0
for step, batch in tqdm(enumerate(test_dataloader), total=len(test_dataloader)):
    cond, targets = batch
    cond = torch.from_numpy(scaler.transform(cond))
    cond, targets = cond.float(), targets.float()
    cond, targets = cond.to(device), targets.to(device)
    targets = targets.unsqueeze(0).unsqueeze(0)
    cond = cond.unsqueeze(0)
    targets = torch.nn.functional.pad(targets, (0, 71, 0, 9), "constant", 0)
    # torch inference_mode already annotated for sample() function
    samples = sample(U, (64, 256), batch_size=1, channels=1, time_cond=cond)
    total_loss += torch.sum(torch.abs(torch.from_numpy(samples[-1]) - targets.detach().cpu())).item()
    total_elems += torch.numel(targets)
    print(total_loss / total_elems)
print(f"MAE: {total_loss / total_elems}")

sampling loop time step: 100%|██████████| 500/500 [01:22<00:00,  6.09it/s]
  2%|▏         | 1/48 [01:22<1:04:20, 82.13s/it]

4.11014461517334


sampling loop time step: 100%|██████████| 500/500 [01:21<00:00,  6.10it/s]
  4%|▍         | 2/48 [02:44<1:02:55, 82.07s/it]

4.281827926635742


sampling loop time step: 100%|██████████| 500/500 [01:21<00:00,  6.10it/s]
  6%|▋         | 3/48 [04:06<1:01:32, 82.06s/it]

4.312503973642985


sampling loop time step: 100%|██████████| 500/500 [01:22<00:00,  6.08it/s]
  8%|▊         | 4/48 [05:28<1:00:14, 82.14s/it]

4.281505703926086


sampling loop time step: 100%|██████████| 500/500 [01:21<00:00,  6.11it/s]
 10%|█         | 5/48 [06:50<58:48, 82.06s/it]  

4.310191059112549


sampling loop time step: 100%|██████████| 500/500 [01:21<00:00,  6.11it/s]
 12%|█▎        | 6/48 [08:12<57:24, 82.01s/it]

4.339882691701253


sampling loop time step: 100%|██████████| 500/500 [01:21<00:00,  6.10it/s]
 15%|█▍        | 7/48 [09:34<56:03, 82.03s/it]

4.443565164293561


sampling loop time step: 100%|██████████| 500/500 [01:22<00:00,  6.09it/s]
 17%|█▋        | 8/48 [10:56<54:43, 82.09s/it]

4.46360445022583


sampling loop time step: 100%|██████████| 500/500 [01:21<00:00,  6.10it/s]
 19%|█▉        | 9/48 [12:18<53:20, 82.06s/it]

4.434910085466173


sampling loop time step: 100%|██████████| 500/500 [01:21<00:00,  6.11it/s]
 21%|██        | 10/48 [13:40<51:55, 82.00s/it]

4.392377042770386


sampling loop time step: 100%|██████████| 500/500 [01:21<00:00,  6.10it/s]
 23%|██▎       | 11/48 [15:02<50:34, 82.01s/it]

4.368496374650435


sampling loop time step: 100%|██████████| 500/500 [01:21<00:00,  6.11it/s]
 25%|██▌       | 12/48 [16:24<49:11, 81.98s/it]

4.370772322018941


sampling loop time step: 100%|██████████| 500/500 [01:21<00:00,  6.11it/s]
 27%|██▋       | 13/48 [17:46<47:48, 81.96s/it]

4.365747121664194


sampling loop time step: 100%|██████████| 500/500 [01:22<00:00,  6.09it/s]
 29%|██▉       | 14/48 [19:08<46:28, 82.00s/it]

4.328195895467486


sampling loop time step: 100%|██████████| 500/500 [01:21<00:00,  6.11it/s]
 31%|███▏      | 15/48 [20:30<45:05, 81.98s/it]

4.303321154912313


sampling loop time step: 100%|██████████| 500/500 [01:21<00:00,  6.11it/s]
 33%|███▎      | 16/48 [21:52<43:42, 81.96s/it]

4.285396412014961


sampling loop time step: 100%|██████████| 500/500 [01:22<00:00,  6.09it/s]
 35%|███▌      | 17/48 [23:14<42:22, 82.00s/it]

4.2741837361279655


sampling loop time step: 100%|██████████| 500/500 [01:21<00:00,  6.11it/s]
 38%|███▊      | 18/48 [24:36<40:59, 81.97s/it]

4.264785779847039


sampling loop time step: 100%|██████████| 500/500 [01:21<00:00,  6.11it/s]
 40%|███▉      | 19/48 [25:58<39:36, 81.95s/it]

4.240985456265901


sampling loop time step: 100%|██████████| 500/500 [01:22<00:00,  6.09it/s]
 42%|████▏     | 20/48 [27:20<38:16, 82.00s/it]

4.227737259864807


sampling loop time step: 100%|██████████| 500/500 [01:21<00:00,  6.11it/s]
 44%|████▍     | 21/48 [28:42<36:53, 81.97s/it]

4.215077774865287


sampling loop time step: 100%|██████████| 500/500 [01:21<00:00,  6.11it/s]
 46%|████▌     | 22/48 [30:04<35:30, 81.94s/it]

4.221634030342102


sampling loop time step: 100%|██████████| 500/500 [01:22<00:00,  6.10it/s]
 48%|████▊     | 23/48 [31:26<34:09, 81.99s/it]

4.240104084429533


sampling loop time step: 100%|██████████| 500/500 [01:21<00:00,  6.11it/s]
 50%|█████     | 24/48 [32:47<32:46, 81.94s/it]

4.239701141913732


sampling loop time step: 100%|██████████| 500/500 [01:21<00:00,  6.11it/s]
 52%|█████▏    | 25/48 [34:09<31:24, 81.92s/it]

4.241749601364136


sampling loop time step: 100%|██████████| 500/500 [01:21<00:00,  6.11it/s]
 54%|█████▍    | 26/48 [35:31<30:02, 81.93s/it]

4.318916091552148


sampling loop time step: 100%|██████████| 500/500 [01:21<00:00,  6.12it/s]
 56%|█████▋    | 27/48 [36:53<28:39, 81.88s/it]

4.319468560042204


sampling loop time step: 100%|██████████| 500/500 [01:21<00:00,  6.12it/s]
 58%|█████▊    | 28/48 [38:15<27:16, 81.84s/it]

4.310112450804029


sampling loop time step: 100%|██████████| 500/500 [01:21<00:00,  6.11it/s]
 60%|██████    | 29/48 [39:37<25:55, 81.88s/it]

4.307964875780303


sampling loop time step: 100%|██████████| 500/500 [01:21<00:00,  6.12it/s]
 62%|██████▎   | 30/48 [40:59<24:33, 81.84s/it]

4.298997807502746


sampling loop time step: 100%|██████████| 500/500 [01:21<00:00,  6.12it/s]
 65%|██████▍   | 31/48 [42:20<23:10, 81.81s/it]

4.30323520014363


sampling loop time step: 100%|██████████| 500/500 [01:21<00:00,  6.11it/s]
 67%|██████▋   | 32/48 [43:42<21:49, 81.85s/it]

4.295509420335293


sampling loop time step: 100%|██████████| 500/500 [01:21<00:00,  6.12it/s]
 69%|██████▉   | 33/48 [45:04<20:27, 81.81s/it]

4.280568491328847


sampling loop time step: 100%|██████████| 500/500 [01:21<00:00,  6.12it/s]
 71%|███████   | 34/48 [46:26<19:05, 81.79s/it]

4.265242681783788


sampling loop time step: 100%|██████████| 500/500 [01:21<00:00,  6.11it/s]
 73%|███████▎  | 35/48 [47:48<17:43, 81.83s/it]

4.257589197158813


sampling loop time step: 100%|██████████| 500/500 [01:21<00:00,  6.12it/s]
 75%|███████▌  | 36/48 [49:09<16:21, 81.80s/it]

4.258629262447357


sampling loop time step: 100%|██████████| 500/500 [01:21<00:00,  6.12it/s]
 77%|███████▋  | 37/48 [50:31<14:59, 81.78s/it]

4.250288847330454


sampling loop time step: 100%|██████████| 500/500 [01:21<00:00,  6.11it/s]
 79%|███████▉  | 38/48 [51:53<13:38, 81.83s/it]

4.290512059864245


sampling loop time step: 100%|██████████| 500/500 [01:21<00:00,  6.12it/s]
 81%|████████▏ | 39/48 [53:15<12:16, 81.79s/it]

4.278892761621719


sampling loop time step: 100%|██████████| 500/500 [01:21<00:00,  6.12it/s]
 83%|████████▎ | 40/48 [54:36<10:54, 81.77s/it]

4.274121737480163


sampling loop time step: 100%|██████████| 500/500 [01:21<00:00,  6.11it/s]
 85%|████████▌ | 41/48 [55:58<09:32, 81.81s/it]

4.25708652705681


sampling loop time step: 100%|██████████| 500/500 [01:21<00:00,  6.12it/s]
 88%|████████▊ | 42/48 [57:20<08:10, 81.78s/it]

4.2854203553426835


sampling loop time step: 100%|██████████| 500/500 [01:21<00:00,  6.12it/s]
 90%|████████▉ | 43/48 [58:42<06:48, 81.76s/it]

4.2820365650709284


sampling loop time step: 100%|██████████| 500/500 [01:21<00:00,  6.10it/s]
 92%|█████████▏| 44/48 [1:00:04<05:27, 81.82s/it]

4.272343418814919


sampling loop time step: 100%|██████████| 500/500 [01:21<00:00,  6.12it/s]
 94%|█████████▍| 45/48 [1:01:25<04:05, 81.80s/it]

4.263354645835029


sampling loop time step: 100%|██████████| 500/500 [01:21<00:00,  6.12it/s]
 96%|█████████▌| 46/48 [1:02:47<02:43, 81.78s/it]

4.254121687101281


sampling loop time step: 100%|██████████| 500/500 [01:21<00:00,  6.11it/s]
 98%|█████████▊| 47/48 [1:04:09<01:21, 81.82s/it]

4.2482460356773215


sampling loop time step: 100%|██████████| 500/500 [01:21<00:00,  6.12it/s]
100%|██████████| 48/48 [1:05:31<00:00, 81.90s/it]

4.234675755103429
MAE: 4.234675755103429



