In [1]:
import argparse
import datetime
import json
import os
import shlex

import torch
import yaml

from dataset_pm25 import get_dataloader
from main_model import CSDI_PM25
from utils import train, evaluate

In [3]:
# Command-line options for this CSDI PM2.5 run. Inside the notebook they are
# parsed from the `args_str` string below rather than from sys.argv, so the
# cell is self-contained and re-runnable.
parser = argparse.ArgumentParser(description="CSDI")
parser.add_argument("--config", type=str, default="base.yaml")
parser.add_argument(
    "--device",
    type=str,
    default="cuda:0",
    help="torch device used for the data loaders and the model (e.g. 'cuda:0', 'cpu')",
)
parser.add_argument("--modelfolder", type=str, default="")
parser.add_argument(
    "--targetstrategy", type=str, default="mix", choices=["mix", "random", "historical"]
)
parser.add_argument(
    "--validationindex",
    type=int,
    default=0,
    # Enforce the documented [0-7] range at parse time instead of failing later.
    choices=range(8),
    help="index of month used for validation (value:[0-7])",
)
parser.add_argument("--nsample", type=int, default=100)
parser.add_argument("--unconditional", action="store_true")
parser.add_argument(
    "--seed", type=int, default=1, help="random seed passed to torch.manual_seed"
)

args_str = '--nsample 16 --config small_try.yaml'

# shlex.split honours shell-style quoting, so values containing spaces
# (e.g. a quoted config path) are kept as a single argument.
args = parser.parse_args(shlex.split(args_str))
print(args)

# Seed torch before any data loader or model is constructed in later cells so
# that repeated runs with the same arguments are reproducible.
torch.manual_seed(args.seed)

Namespace(config='small_try.yaml', device='cuda:0', modelfolder='', targetstrategy='mix', validationindex=0, nsample=16, unconditional=False)


In [4]:
# Read the YAML experiment configuration, then apply the two model options
# that are driven by the command-line arguments.
path = f"config/{args.config}"
with open(path, "r") as f:
    config = yaml.safe_load(f)

config["model"].update(
    is_unconditional=args.unconditional,
    target_strategy=args.targetstrategy,
)

print(json.dumps(config, indent=4))

# Each run writes into its own timestamped directory under ./save/, and the
# effective configuration is stored there as config.json.
current_time = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
foldername = f"./save/pm25_validationindex{args.validationindex}_{current_time}/"

print('model folder:', foldername)
os.makedirs(foldername, exist_ok=True)
with open(f"{foldername}config.json", "w") as f:
    json.dump(config, f, indent=4)

{
    "train": {
        "epochs": 100,
        "batch_size": 16,
        "lr": 0.001
    },
    "diffusion": {
        "layers": 4,
        "channels": 16,
        "nheads": 4,
        "diffusion_embedding_dim": 4,
        "beta_start": 0.0001,
        "beta_end": 0.5,
        "num_steps": 50,
        "schedule": "quad"
    },
    "model": {
        "is_unconditional": false,
        "timeemb": 12,
        "featureemb": 16,
        "target_strategy": "mix"
    }
}
model folder: ./save/pm25_validationindex0_20220617_135452/


In [5]:
# Build the train / validation / test loaders (get_dataloader also returns the
# `scaler` and `mean_scaler` values) and place a fresh CSDI_PM25 model on the
# configured device.
batch_size = config["train"]["batch_size"]
(
    train_loader,
    valid_loader,
    test_loader,
    scaler,
    mean_scaler,
) = get_dataloader(batch_size, device=args.device, validindex=args.validationindex)
model = CSDI_PM25(config, args.device)
model = model.to(args.device)

In [17]:
# Sanity check: pull the first batch from the validation loader and display
# the shape of its "observed_data" entry.
valid_batches = iter(valid_loader)
item = next(valid_batches)
item["observed_data"].shape

torch.Size([16, 36, 36])

In [18]:
# The loaders (train_loader / valid_loader / test_loader, scaler, mean_scaler)
# and `model` built in the setup cell above are reused as-is; they are not
# rebuilt here. A second copy of that cell would re-read the dataset and
# re-initialise the model weights with exactly the same configuration, and the
# batch inspection above only creates a fresh iterator, so nothing needs
# resetting.
# If dataset_pm25.py or main_model.py were edited, enable `%load_ext autoreload`
# and `%autoreload 2` in the import cell (a plain re-call does not pick up
# module edits), then re-run the setup cell instead of duplicating it.