In [None]:
import inspect
import os
import random
from functools import partial

import numpy as np
import pandas as pd
import torch
from fastai.optimizer import Adam, RMSProp
from fastai.tabular.model import TabularModel
from torch import nn

from matplotlib import pyplot as plt

# from data.synthetic_wind_data import get_synthetic_wind_data
# from utils_plots import data_to_plot, plot_hists, wind_vs_power
from fastai.callback.all import *

# from gan import GANLearner
from fastai.vision.gan import *

rng_seed = 42
# Seed every RNG this notebook draws from (Python, NumPy, PyTorch) so runs repeat.
for _seed_rng in (random.seed, np.random.seed, torch.manual_seed):
    _seed_rng(rng_seed)
del _seed_rng

from fastai.tabular.all import TabularDataLoaders
from fastai.tabular.all import TabularProc, Tabular

In [None]:
class NormalizePerTask(TabularProc):
    """Standardize continuous columns to zero mean / unit variance, except `cols_to_ignore`.

    Despite the class name, nothing is grouped per TaskId: one mean and one standard
    deviation per column are computed over the training split (or over the whole
    table when it has no `train` split) and applied to every row.

    Args:
        cols_to_ignore: Column name or iterable of column names left untouched.
            Defaults to no columns.
    """

    order = 1

    def __init__(self, cols_to_ignore=None):
        # `None` instead of a mutable `[]` default; a bare string is one column name,
        # not a sequence of characters. Copy so later mutation by the caller is harmless.
        if cols_to_ignore is None:
            cols_to_ignore = []
        elif isinstance(cols_to_ignore, str):
            cols_to_ignore = [cols_to_ignore]
        self.cols_to_ignore = list(cols_to_ignore)

    def setups(self, to: Tabular):
        """Compute per-column mean/std of the continuous columns not ignored."""
        # Use the training subset when available so other splits are scaled with
        # training statistics.
        train = getattr(to, "train", to)
        self.rel_cols = [c for c in to.cont_names if c not in self.cols_to_ignore]
        self.means = train[self.rel_cols].mean()
        # Population std (ddof=0); the epsilon avoids division by zero for constant columns.
        self.stds = train[self.rel_cols].std(ddof=0) + 1e-7

    def encodes(self, to):
        """Standardize `rel_cols` of `to` in place."""
        to.loc[:, self.rel_cols] = (to.loc[:, self.rel_cols] - self.means) / self.stds

    def decodes(self, to):
        """Undo `encodes` on `rel_cols` of `to` in place."""
        to.loc[:, self.rel_cols] = to.loc[:, self.rel_cols] * self.stds + self.means

        
def get_synthetic_wind_data(data_path, file_name, config, cont_names=None, y_names=None):
    """Load synthetic wind-power HDF5 files into fastai ``TabularDataLoaders``.

    Each file ``os.path.join(data_path, f)`` is read with ``key="powerdata"`` and the
    frames are concatenated row-wise. Feature and target columns are combined into one
    list that is passed as *both* ``cont_names`` and ``y_names``, so every batch's
    inputs and targets hold the same values.

    Args:
        data_path: Directory containing the files (trailing separator optional).
        file_name: A single file name or an iterable of file names.
        config: Dict; only ``config["batch_size"]`` is read.
        cont_names: Feature columns. Defaults to the weather and wind-speed columns below.
        y_names: Target columns. Defaults to ``["PowerGeneration"]``. These columns are
            excluded from normalization.

    Returns:
        The ``TabularDataLoaders`` built by ``TabularDataLoaders.from_df``.

    Raises:
        ValueError: If ``file_name`` names no files.
    """

    def _as_list(names):
        # A bare string is one name; iterating it would yield single characters.
        return [names] if isinstance(names, str) else list(names)

    file_names = _as_list(file_name)
    if not file_names:
        raise ValueError("file_name must name at least one file")

    if cont_names is None:
        cont_names = [
            "T_HAG_2_M",
            "RELHUM_HAG_2_M",
            "PS_SFC_0_M",
            "ASWDIFDS_SFC_0_M",
            "ASWDIRS_SFC_0_M",
            "WindSpeed58m",
            "WindSpeed60m",
        ]
    if y_names is None:
        y_names = ["PowerGeneration"]
    y_names = _as_list(y_names)
    all_names = _as_list(cont_names) + y_names

    # os.path.join also works when data_path has no trailing "/", which plain
    # string concatenation silently got wrong.
    df = pd.concat(
        [pd.read_hdf(os.path.join(data_path, f), key="powerdata") for f in file_names],
        axis=0,
    )

    return TabularDataLoaders.from_df(
        df=df,
        cat_names=[],
        cont_names=all_names,
        y_names=all_names,
        bs=config["batch_size"],
        # Targets stay in their original units; all other columns are standardized.
        procs=[NormalizePerTask(cols_to_ignore=list(y_names))],
    )

In [None]:
# Directory with the processed synthetic wind-power files. Set DAF_ICON_DATA_PATH to
# point elsewhere instead of editing this absolute path; os.path.join(..., "")
# guarantees the trailing separator the loader's path handling may rely on.
data_path = os.path.join(
    os.environ.get(
        "DAF_ICON_DATA_PATH",
        "/home/scribbler/data/DAF_ICON_Synthetic_Wind_Power_processed/",
    ),
    "",
)
# Subset of the available .h5 files used in this notebook.
file_name = ["00011.h5", "01303.h5", "02559.h5"]

In [None]:
ls {data_path}

00011.h5  01303.h5  02559.h5  03651.h5  05347.h5  06344.h5
00090.h5  01346.h5  02564.h5  03668.h5  05349.h5  07341.h5
00161.h5  01357.h5  02573.h5  03730.h5  05371.h5  07351.h5
00164.h5  01358.h5  02597.h5  03761.h5  05397.h5  07367.h5
00183.h5  01379.h5  02601.h5  03811.h5  05404.h5  07368.h5
00197.h5  01420.h5  02638.h5  03821.h5  05412.h5  07369.h5
00198.h5  01443.h5  02667.h5  03897.h5  05426.h5  07370.h5
00232.h5  01468.h5  02712.h5  03925.h5  05440.h5  07374.h5
00282.h5  01490.h5  02794.h5  03946.h5  05480.h5  07389.h5
00298.h5  01503.h5  02812.h5  03987.h5  05490.h5  07391.h5
00303.h5  01544.h5  02856.h5  04024.h5  05516.h5  07392.h5
00342.h5  01550.h5  02878.h5  04032.h5  05538.h5  07393.h5
00427.h5  01580.h5  02897.h5  04036.h5  05546.h5  07394.h5
00430.h5  01587.h5  02907.h5  04039.h5  05629.h5  07395.h5
00433.h5  01605.h5  02925.h5  04094.h5  05705.h5  07396.h5
00460.h5  01612.h5  02928.h5  04104.h5  05779.h5  07403.h5
00591.h5  01639.h5  02932.h5  04177.h5  

In [None]:
# Run hyper-parameters; "n_features" is added once the dataloaders are built.
config = dict(
    n_samples=100_000,
    batch_size=1024,
    n_noise_samples=100,
    lr=1e-5,
    epochs=100,
    # Powers of two from 2**11 down to 2**6: [2048, 1024, 512, 256, 128, 64].
    structure=[2 ** exponent for exponent in range(11, 5, -1)],
)

In [None]:
dls = get_synthetic_wind_data(data_path=data_path, file_name=file_name, config=config)
# Record how many continuous columns the models will see.
config.update(n_features=len(dls.cont_names))

In [None]:
(cat, x, y) = dls.one_batch()  # one batch from the dataloaders for the checks below

In [None]:
x.size()  # (batch_size, n_features)

torch.Size([1024, 8])

In [None]:
# Inputs and targets are built from the same column list, so they must match exactly
# (same shape and values); fail loudly instead of displaying a tensor to eyeball.
assert torch.equal(x, y), "expected the x and y batches to be identical"

tensor(True)

In [None]:
from fastai.callback.core import Callback, TrainEvalCallback
from fastai.callback.progress import CSVLogger
from fastai.learner import Learner, Metric
from fastcore.basics import class2attr
from fastcore.foundation import L
from torch import nn
import torch

from fastrenewables.tabular.model import MultiLayerPerceptron
from torch.nn import Tanh
from torch.nn import Sigmoid
from enum import Enum

In [None]:
class TrainMode(Enum):
    """Phase of the alternating GAN update a batch is used for."""

    DISC_REAL = 0  # critic step on real samples
    DISC_FAKE = 1  # critic step on generated samples
    GEN = 2  # generator step
    

In [None]:
TrainMode["DISC_REAL"]

<TrainMode.DISC_REAL: 0>

In [None]:
class TabularGANModule(nn.Module):
    """Generator + critic pair whose ``forward`` depends on the current training phase.

    ``forward(x_cat, x_cont)`` returns the critic's output for

    - ``TrainMode.DISC_REAL``: the real continuous data ``x_cont``;
    - ``TrainMode.DISC_FAKE`` / ``TrainMode.GEN``: samples produced by the generator
      from Gaussian noise (one sample per row of ``x_cont``).

    It also toggles ``requires_grad`` so that only the network trained in the current
    phase receives gradients.

    Args:
        generator: Module called as ``generator(x_cat, noise)``.
        critic: Module called as ``critic(x_cat, x_cont)``.
        noise_size: Number of noise features per sample fed to the generator.
    """

    def __init__(self, generator, critic, noise_size=100):
        super(TabularGANModule, self).__init__()
        self.generator = generator
        self.critic = critic
        self.noise_size = noise_size
        self.gen_mode = True  # for forward-fct
        self.train_gen = True  # for optimizer handling
        # Same phase GANTrainer starts in; lets forward() run before any
        # update_train_mode() call instead of failing with AttributeError.
        self.train_mode = TrainMode.DISC_REAL

    def _input_noise(self, bs, device=None):
        """Return a ``(bs, noise_size)`` standard-normal tensor on ``device`` (CPU if None)."""
        # Sample directly on the target device; an unconditional `.cuda()` fails on
        # CPU-only machines.
        with torch.no_grad():
            return torch.randn(bs, self.noise_size, device=device)

    def _default_device(self, x_cont):
        """Device of the generator's weights; else ``x_cont``'s device; else CPU."""
        first_param = next(self.generator.parameters(), None)
        if first_param is not None:
            return first_param.device
        return x_cont.device if torch.is_tensor(x_cont) else torch.device("cpu")

    def generate_samples(self, x_cat, x_cont, device=None):
        """Generate one synthetic row per row of ``x_cont``.

        Args:
            x_cat: Categorical inputs, passed unchanged to the generator.
            x_cont: Only its first dimension (the batch size) is used.
            device: Device for the noise. Defaults to the generator's device so the
                noise and the weights always match (a fixed "cpu" default broke
                training with the model on GPU).

        Returns:
            Whatever ``generator(x_cat, noise)`` returns.
        """
        bs = x_cont.shape[0]
        if device is None:
            device = self._default_device(x_cont)
        noise = self._input_noise(bs, device=device)
        return self.generator(x_cat, noise)

    def _requires_grad(self, model, freeze):
        """Set ``requires_grad = freeze`` on every parameter of ``model``.

        Note the naming: ``freeze=True`` makes the parameters *trainable*.
        """
        for p in model.parameters():
            p.requires_grad = freeze

    def update_train_mode(self, train_mode):
        """Select which phase (a ``TrainMode`` member) ``forward`` runs."""
        self.train_mode = train_mode

    def forward(self, x_cat, x_cont):
        """Return the critic's output on real or generated data for the current phase.

        Raises:
            ValueError: If ``train_mode`` is not a ``TrainMode`` member.
        """
        # TODO: should not be in the model
        if self.train_mode in (TrainMode.DISC_REAL, TrainMode.DISC_FAKE):
            # Critic phases: only the critic accumulates gradients.
            self._requires_grad(self.generator, False)
            self._requires_grad(self.critic, True)
        elif self.train_mode == TrainMode.GEN:
            # Generator phase: gradients flow through the frozen critic into the generator.
            self._requires_grad(self.generator, True)
            self._requires_grad(self.critic, False)
        else:
            raise ValueError(f"Unknown train_mode: {self.train_mode!r}")

        if self.train_mode == TrainMode.DISC_REAL:
            return self.critic(x_cat, x_cont)
        gen_data = self.generate_samples(x_cat, x_cont)
        return self.critic(x_cat, gen_data)

In [None]:
class _LearnerLossMetric(Metric):
    """Report the latest value of the Learner attribute named by ``attr``.

    ``accumulate`` copies ``learn.<attr>`` onto the metric (as ``self.<attr>``) and
    ``value`` formats it with six decimals. While no value is available yet, ``value``
    is None instead of raising AttributeError.
    """

    attr = None  # Learner attribute to read; set by subclasses

    def reset(self):
        # Intentionally keeps the last value: it is produced elsewhere, not summed here.
        pass

    def accumulate(self, learn):
        setattr(self, self.attr, getattr(learn, self.attr, None))

    @property
    def value(self):
        current = getattr(self, self.attr, None)
        return None if current is None else f"{current:.6f}"

    @property
    def name(self):
        return class2attr(self, "Metric")


class RealLossMetric(_LearnerLossMetric):
    "Reports `learn.real_loss` (metric name: real_loss)."
    attr = "real_loss"


class FakeLossMetric(_LearnerLossMetric):
    "Reports `learn.fake_loss` (metric name: fake_loss)."
    attr = "fake_loss"

In [None]:
# Quick look at the binary cross-entropy criterion GANLoss uses by default.
a = nn.BCELoss(reduction="mean")
a

BCELoss()

In [None]:
class GANLoss(nn.Module):
    """Binary-classification loss for the critic and generator phases of GAN training.

    The dataloader target ``t`` is ignored; labels are synthesized from the wrapped
    model's ``train_mode``:

    - ``TrainMode.DISC_REAL``: real label (1.0), so the critic should accept real data;
    - ``TrainMode.DISC_FAKE``: fake label (0.0), so the critic should reject generated data;
    - ``TrainMode.GEN``: real label (1.0), so the generator tries to fool the critic.

    Args:
        model: Object exposing a ``train_mode`` attribute (e.g. TabularGANModule).
        criterion: Loss applied to flattened predictions and labels. Defaults to a
            fresh ``nn.BCELoss()``, which expects predictions in [0, 1].
    """

    def __init__(self, model, criterion=None):
        super(GANLoss, self).__init__()
        self.model = model
        self.real_label = 1.0
        self.fake_label = 0.0
        # Build the default per instance instead of sharing one default object.
        self.criterion = nn.BCELoss() if criterion is None else criterion

    @property
    def train_mode(self):
        return self.model.train_mode

    def forward(self, y, t):
        """Return ``criterion(y, labels)`` with labels chosen from the current phase.

        Args:
            y: Critic output of any shape; labels get the same shape, dtype and device.
            t: Dataloader target; unused.

        Raises:
            ValueError: If ``train_mode`` is not a ``TrainMode`` member.
        """
        mode = self.train_mode
        if mode in (TrainMode.DISC_REAL, TrainMode.GEN):
            target_value = self.real_label
        elif mode == TrainMode.DISC_FAKE:
            target_value = self.fake_label
        else:
            raise ValueError(f"Unknown train_mode: {mode!r}")

        # Labels follow y's device/dtype/shape instead of a hard-coded "cuda:0" device.
        label = torch.full_like(y, target_value)
        return self.criterion(y.reshape(-1), label.reshape(-1))


In [None]:
from fastai.callback.core import *

class GANTrainer(Callback):
    """Alternate critic and generator updates, one training batch per phase.

    Training batches cycle DISC_REAL -> DISC_FAKE -> GEN -> DISC_REAL ... . Before each
    batch the optimizer of the network trained in that phase is installed as
    ``learn.opt``. After each training batch the phase's loss is stored on the learner
    (``real_loss``, ``fake_loss`` or ``gen_loss``) and the next phase is pushed to the
    model via ``update_train_mode``.

    Args:
        n_gen: Stored only. NOTE(review): the schedule is fixed at one batch per
            phase; this value is not used yet.
        n_crit: Stored only (see ``n_gen``).
        clip: If not None, critic gradients are clipped to ``[-clip, clip]`` before
            every critic optimizer step.
    """

    run_after = TrainEvalCallback

    def __init__(self, n_gen=1, n_crit=1, clip=None):
        super(GANTrainer, self).__init__()
        self.n_gen = n_gen
        self.n_crit = n_crit
        self.clip = clip

    @property
    def train_mode(self):
        """Current phase, read from the model so there is a single source of truth."""
        # Presumably also avoids fastai's "shadowing an attribute" warning, since the
        # learner resolves unknown attributes through its model — confirm if relied on.
        return self.model.train_mode

    def after_create(self):
        # One optimizer per network (built from learn.opt_func / learn.lr) so each
        # phase only updates its own weights.
        self.learn.gen_opt = self.opt_func(self.model.generator.parameters(), lr=self.lr)
        self.learn.crit_opt = self.opt_func(self.model.critic.parameters(), lr=self.lr)
        self._set_train_mode(TrainMode.DISC_REAL)

    def before_batch(self):
        # Compare the mode itself: `mode.GEN == TrainMode.GEN` is always true and
        # meant the critic optimizer was never selected.
        if self.train_mode == TrainMode.GEN:
            self.learn.opt = self.learn.gen_opt
        else:
            self.learn.opt = self.learn.crit_opt
        self.learn.opt.zero_grad()

    def before_step(self):
        # Clip only the critic's gradients, and only during critic phases.
        if self.clip is not None and self.train_mode != TrainMode.GEN:
            nn.utils.clip_grad_value_(self.learn.model.critic.parameters(), self.clip)

    def _set_train_mode(self, new_mode):
        """Push ``new_mode`` to the model, which owns the phase state."""
        self.model.update_train_mode(new_mode)

    def after_batch(self):
        if not self.training:
            return

        # Record the loss under the phase that produced it, then advance the cycle.
        loss = self.learn.loss.item()
        mode = self.train_mode
        if mode == TrainMode.DISC_REAL:
            self.learn.real_loss = loss
            self._set_train_mode(TrainMode.DISC_FAKE)
        elif mode == TrainMode.DISC_FAKE:
            self.learn.fake_loss = loss
            self._set_train_mode(TrainMode.GEN)
        elif mode == TrainMode.GEN:
            self.learn.gen_loss = loss
            self._set_train_mode(TrainMode.DISC_REAL)
        else:
            raise ValueError(f"Unknown train_mode: {mode!r}")

    def after_fit(self):
        # Leave the generator optimizer installed as learn.opt after training.
        self.learn.opt = self.learn.gen_opt

In [None]:
class GANLearner(Learner):
    """fastai ``Learner`` that trains a tabular GAN built from ``generator`` and ``critic``.

    The two networks are wrapped in a ``TabularGANModule``, scored with ``GANLoss`` and
    trained by a ``GANTrainer`` callback that is appended after ``cbs``.

    Args:
        dls: Tabular dataloaders.
        generator: Generator network passed to ``TabularGANModule``.
        critic: Critic network passed to ``TabularGANModule``.
        criterion: Loss used by ``GANLoss``; a fresh ``nn.BCELoss()`` when None.
        lr: Learning rate for the generator and critic optimizers.
        noise_size: Number of noise features fed to the generator.
        cbs: Extra callbacks; a fresh ``[CSVLogger()]`` when None.
        metrics: Metrics (classes or instances); ``[RealLossMetric, FakeLossMetric]``
            when None.
        opt_func: Optimizer factory; fastai ``Adam`` when None. ``GANTrainer`` calls
            it to build both optimizers, so it must not stay None.
        clip, n_gen, n_crit: Forwarded to ``GANTrainer``.
    """

    def __init__(
        self,
        dls,
        generator,
        critic,
        criterion=None,
        lr=1e-3,
        noise_size=100,
        cbs=None,
        metrics=None,
        opt_func=None,
        clip=0.01,
        n_gen=1,
        n_crit=5,
    ):
        # Build defaults per instance: a `[CSVLogger()]` default in the signature
        # would be one callback object shared by every learner.
        if criterion is None:
            criterion = nn.BCELoss()
        if cbs is None:
            cbs = [CSVLogger()]
        if metrics is None:
            metrics = [RealLossMetric, FakeLossMetric]
        if opt_func is None:
            # GANTrainer.after_create calls learn.opt_func; None would raise TypeError.
            opt_func = Adam

        gan_model = TabularGANModule(
            generator=generator,
            critic=critic,
            noise_size=noise_size,
        )

        gan_loss = GANLoss(gan_model, criterion=criterion)
        trainer = GANTrainer(n_gen, n_crit, clip)
        cbs = L(cbs) + L(trainer)
        metrics = L(metrics)

        super(GANLearner, self).__init__(
            dls,
            gan_model,
            cbs=cbs,
            lr=lr,
            metrics=metrics,
            opt_func=opt_func,
            loss_func=gan_loss,
        )
        

In [None]:
# Optimizer hyper-parameters for the commented-out training code below:
# lr is the learning rate, beta1 the Adam first-moment decay rate.
lr, beta1 = 1e-2, 0.5



# learner = GANLearner(dls, generator=generator_model, 
#                      critic=critic_model, 
# #                      opt_func=partial(RMSProp, mom=0.5, sqr_mom=0.99),)
#                     opt_func=partial(Adam, lr=lr),)

In [None]:
# learner.fit(50, lr=lr)

In [None]:
# from fastai.torch_core import to_np

# def data_to_plot(dls, learner):

#     # dls.cuda()
#     learner.model.cuda()
#     real_cat_data = dls.train_ds.cats
#     real_cont_data = dls.train_ds.conts
#     fake_data = to_np(learner.model.generate_samples(real_cat_data, real_cont_data))
#     # fake_data = (
#     #         learner.model.generate_samples(real_cat_data, real_cont_data).detach().cpu().numpy()
#     # )
#     real_data = real_cont_data.to_numpy()
#     return real_data, fake_data

In [None]:
# real, fake = data_to_plot(dls, learner)
# for i in range(real.shape[1]):
    
#     plt.hist(fake[:,i], label="fake")
#     plt.hist(real[:,i], label="real")
    
#     plt.title(dls.train_ds.cont_names[i])
#     plt.legend()
#     plt.show()

In [None]:
dict(config)  # run configuration, including the derived n_features

{'n_samples': 100000,
 'batch_size': 1024,
 'n_noise_samples': 100,
 'lr': 1e-05,
 'epochs': 100,
 'structure': [2048, 1024, 512, 256, 128, 64],
 'n_features': 8}

In [None]:
# gan_model = TabularGANModule(
#             generator=generator_model,
#             critic=critic_model,
#             noise_size=config["n_noise_samples"],
#         )

In [None]:
(cat, x, y) = dls.one_batch()  # fresh batch; rebinds cat, x and y from the earlier cell

In [None]:
# critic = MultiLayerPerceptron(
#             ann_structure=[config["n_features"], 100, 50, 1],
#             act_cls=nn.LeakyReLU(),
#             final_activation=Sigmoid,
#             bn_cont=False
#         )

# generator = MultiLayerPerceptron(
#     [config["n_noise_samples"], 400, 200, 100, 50, config["n_features"]],
#     act_cls=nn.LeakyReLU(),
# #     final_activation=Sigmoid,
#     bn_cont=False
# )

# import torch.optim as optim# Initialize BCELoss function
# criterion = nn.BCELoss()

# # Create batch of latent vectors that we will use to visualize
# #  the progression of the generator
# # fixed_noise = torch.randn(64, nz, 1, 1, device=device)

# # Establish convention for real and fake labels during training
# real_label = 1.
# fake_label = 0.

# # Setup Adam optimizers for both G and D
# optimizerD = optim.Adam(generator.parameters(), lr=lr, betas=(beta1, 0.999))
# optimizerG = optim.Adam(critic.parameters(), lr=lr, betas=(beta1, 0.999))


# critic.zero_grad()
# batch_size = x.shape[0]
# label = torch.full((batch_size,), real_label, dtype=torch.float)
# output = critic(cat, y).view(-1)

# errD_real = criterion(output, label)
# errD_real.backward()

# D_x = output.mean().item()

# # fake = gan_model.generate_samples(cat,x, "cpu")
# bs = x.shape[0]
# noise = gan_model._input_noise(bs).to("cpu")
# fake = generator(cat, noise)
# # # noise
# label.fill_(fake_label)
# output = critic(cat, fake.detach()).view(-1)
# errD_fake = criterion(output, label)
# errD_fake.backward()
# D_G_z1 = output.mean().item()
# errD = errD_real + errD_fake

# optimizerD.step()
# critic.zero_grad()

# generator.zero_grad()
# label.fill_(real_label)
# fake = generator(cat, noise)
# ouput = critic(cat, fake).view(-1)
# errG = criterion(output,label)
# errG.backward()

In [None]:
# class GANLearner(Learner):
#     def __init__(
#         self,
#         dls,
#         generator,
#         critic,
#         criterion=nn.BCELoss(),
#         lr=1e-3,
#         noise_size=100,
#         cbs=[CSVLogger()],
#         metrics=[RealLossMetric, FakeLossMetric],
#         opt_func=None,
#         clip=0.01,
#         n_gen=1,
#         n_crit=5,
#     ):
#         gan_model = TabularGANModule(
#             generator=generator,
#             critic=critic,
#             noise_size=noise_size,
#         )

#         gan_loss = GANLoss(gan_model, criterion=criterion)
# #         trainer = GANTrainer(n_gen, n_crit, clip)
# #         cbs = L(cbs) + L(trainer)
# #         metrics = L(metrics)
        
#         # Setup Adam optimizers for both G and D
#         self.optimizerD = optim.Adam(netD.parameters(), lr=lr, betas=(beta1, 0.999))
#         self.optimizerG = optim.Adam(netG.parameters(), lr=lr, betas=(beta1, 0.999))

        
#         super(GANLearner, self).__init__(
#             dls,
#             gan_model,
#             cbs=cbs,
#             lr=lr,
#             metrics=metrics,
#             opt_func=opt_func,
#             loss_func=gan_loss,
#         )
        
#     def _do_one_batch(self):
#         self.pred = self.model(*self.xb)
#         self('after_pred')
#         if len(self.yb):
#             self.loss_grad = self.loss_func(self.pred, *self.yb)
#             self.loss = self.loss_grad.clone()
#         self('after_loss')
#         if not self.training or not len(self.yb): return
#         self('before_backward')
#         self.loss_grad.backward()
#         self._with_events(self.opt.step, 'step', CancelStepException)
#         self.opt.zero_grad()
        