In [5]:
# --- Imports & data location ---
import torch 
import pandas as pd, numpy as np
import os
import sys
# Make the project package (`data.*`) importable; must run before the `data` imports below.
# NOTE(review): hardcoded absolute cluster path — consider a configurable project root.
sys.path.append('/global/cfs/cdirs/m3443/usr/pmtuan/hadsim')
import yaml
# NOTE(review): star import — presumably the source of `make_mlp` used by the
# Generator/Discriminator cells; prefer importing the needed names explicitly.
from data.utils import *
from data.datamodule import CartesianDataModule

# NOTE(review): hardcoded absolute path; consider a configurable DATA_DIR.
DATA_PATH = "/global/cfs/cdirs/m3443/usr/pmtuan/HadronicMCData/train_data_2_particles_processed/"
# NOTE(review): `data_files` is not used by the visible cells, and hparams['input_dir']
# points at a different directory — confirm which dataset is intended.
data_files = os.listdir(DATA_PATH)

In [6]:
# Hyperparameters for GAN(...); the same mapping is forwarded to Generator,
# Discriminator and the data module named by 'data_module'.
hparams = dict(
    # --- data selection (not read by the visible model code; presumably used by the data module) ---
    n_particle=1,
    max_etot=100000,
    min_etot=10000,

    # --- per-network MLP options, consumed via make_mlp ---
    gen_hidden_activation='LeakyReLU',
    dis_hidden_activation='LeakyReLU',
    gen_output_activation='Tanh',
    dis_output_activation='Sigmoid',
    gen_batchnorm=True,
    dis_batchnorm=True,
    gen_dropout_rate=0.5,
    dis_dropout_rate=0.0,
    nb_gen_layer=10,
    nb_dis_layer=10,
    gen_lr=0.001,
    dis_lr=0.001,

    # --- data loading ---
    sort_by=0,
    # NOTE(review): 8096 is not a power of two — confirm it is not a typo for 8192.
    batch_size=8096,
    input_dir='/global/cfs/cdirs/m3443/usr/pmtuan/HadronicMCData/2_particle_fstate',
    hidden=128,  # width of every hidden layer

    # --- shapes: generator input width = noise_dim + cond_dim ---
    noise_dim=4,
    cond_dim=1,
    # NOTE(review): gen_in is not read by the visible code and differs from noise_dim + cond_dim (5).
    gen_in=4,
    gen_dim=4,  # width of a generated sample
    data_module='CartesianDataModule',
)

In [11]:
# NOTE(review): stray IDE auto-import — `forward` from turtle is unused here, and the
# turtle module imports tkinter, which may be unavailable on headless compute nodes. Delete.
from turtle import forward
from pytorch_lightning import LightningModule
from torch import nn
from scipy.stats import wasserstein_distance

# NOTE(review): `device` is not referenced by any visible cell; the model classes use
# `self.device` instead. Candidate for removal.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

class Generator(LightningModule):
    """Conditional generator: an MLP applied to [condition, Gaussian noise].

    The network input width is ``noise_dim + cond_dim`` and the final layer
    width is ``gen_dim`` (the last entry of the ``sizes`` list).
    """

    def __init__(self, hparams):
        super().__init__()
        self.save_hyperparameters(hparams)

        # All options are read from self.hparams for consistency with the rest of the model.
        self.network = make_mlp(
            input_size=self.hparams['noise_dim'] + self.hparams['cond_dim'],
            sizes=[self.hparams['hidden']] * self.hparams['nb_gen_layer'] + [self.hparams['gen_dim']],
            hidden_activation=self.hparams['gen_hidden_activation'],
            output_activation=self.hparams['gen_output_activation'],
            dropout_rate=self.hparams['gen_dropout_rate'],
            batch_norm=self.hparams['gen_batchnorm'],
        )

    def forward(self, cond):
        """Generate one sample per row of ``cond``.

        Args:
            cond: conditioning tensor; assumed (batch, cond_dim) — TODO confirm.

        Returns:
            ``self.network`` applied to ``cond`` concatenated with
            ``noise_dim`` standard-normal values along the last dim.
        """
        # Sample directly on cond's device: avoids drawing on the CPU and copying
        # host-to-device on every step, and does not depend on self.device tracking.
        noise = torch.randn(cond.shape[0], self.hparams['noise_dim'], device=cond.device)
        z = torch.cat([cond, noise], dim=-1)
        return self.network(z)

class Discriminator(LightningModule):
    """Conditional discriminator: an MLP scoring (condition, sample) pairs.

    Input width is ``gen_dim + cond_dim``; the final layer width is 1 (last
    entry of the ``sizes`` list), followed by ``dis_output_activation``.
    """

    def __init__(self, hparams):
        super().__init__()
        self.save_hyperparameters(hparams)

        width = self.hparams['hidden']
        depth = self.hparams['nb_dis_layer']
        layer_sizes = [width] * depth + [1]
        in_features = self.hparams['gen_dim'] + self.hparams['cond_dim']

        self.network = make_mlp(
            input_size=in_features,
            sizes=layer_sizes,
            hidden_activation=hparams['dis_hidden_activation'],
            output_activation=hparams['dis_output_activation'],
            dropout_rate=self.hparams['dis_dropout_rate'],
            batch_norm=self.hparams['dis_batchnorm'],
        )

    def forward(self, cond, x):
        """Score ``x`` given ``cond`` (the two are concatenated along the last dim)."""
        return self.network(torch.cat([cond, x], dim=-1))

class GAN(LightningModule):
    """Conditional GAN: a ``Generator`` and ``Discriminator`` trained adversarially.

    The model builds its own data module from ``hparams['data_module']`` and
    forwards Lightning's ``setup`` / ``*_dataloader`` hooks to it, so
    ``trainer.fit(gan)`` needs no separate datamodule argument.

    ``training_step`` takes ``optimizer_idx`` (multi-optimizer API): optimizer 0
    (see ``configure_optimizers``) updates the generator, optimizer 1 the
    discriminator.
    """

    def __init__(self, hparams) -> None:
        super().__init__()

        self.save_hyperparameters(hparams)

        self.gen, self.dis = Generator(self.hparams), Discriminator(self.hparams)

        data_module_cls = self._resolve_name(self.hparams['data_module'])
        self.data_module = data_module_cls(self.hparams)

    @staticmethod
    def _resolve_name(name):
        """Resolve a (possibly dotted) name such as ``'CartesianDataModule'``.

        Looked up in this module's globals instead of calling ``eval``, so a
        config string can only select an already-imported object and can never
        execute arbitrary code.

        Raises:
            NameError: if the leading name is not defined (the same type ``eval`` raised).
        """
        head, *rest = name.split('.')
        try:
            obj = globals()[head]
        except KeyError:
            raise NameError(f"name {head!r} is not defined; import the data module class first") from None
        for attr in rest:
            obj = getattr(obj, attr)
        return obj

    def setup(self, stage = 'fit') -> None:
        """Delegate Lightning's setup hook to the owned data module."""
        self.data_module.setup(stage)

    def train_dataloader(self):
        return self.data_module.train_dataloader()

    def val_dataloader(self):
        return self.data_module.val_dataloader()

    def test_dataloader(self):
        return self.data_module.test_dataloader()

    def forward(self, cond):
        """Generate samples for ``cond`` with the generator."""
        return self.gen(cond)

    def adversarial_loss(self, y_hat, y):
        """Binary cross-entropy; expects ``y_hat`` already in [0, 1] (e.g. Sigmoid output)."""
        return nn.functional.binary_cross_entropy(y_hat, y)

    def training_step(self, batch, batch_idx, optimizer_idx):
        """Run a generator (``optimizer_idx == 0``) or discriminator (``== 1``) update.

        Args:
            batch: ``(cond, _, truth)``; the middle element is not used here.
            batch_idx: batch index (unused).
            optimizer_idx: index into the list returned by ``configure_optimizers``.

        Returns:
            The loss for the selected optimizer (``None`` for any other index).
        """
        cond, _, truth = batch
        # Weight epoch-level averages by the actual row count: the last batch of
        # an epoch can be smaller than hparams['batch_size'].
        n_rows = cond.size(0)

        if optimizer_idx == 0:
            # Generator update: push the discriminator to label fakes as real (1).
            fake = self(cond)
            real_label = torch.ones(n_rows, 1, device=self.device)
            g_loss = self.adversarial_loss(self.dis(cond, fake), real_label)

            self.log('gen_loss', g_loss, prog_bar=True, on_step=False, on_epoch=True, batch_size=n_rows)
            return g_loss

        if optimizer_idx == 1:
            # Discriminator update: truth -> 1, fake -> 0. detach() stops this loss
            # from back-propagating into the generator.
            truth_label = torch.ones(truth.size(0), 1, device=self.device)
            fake = self(cond).detach()
            fake_label = torch.zeros(n_rows, 1, device=self.device)

            y_hat = self.dis(
                torch.cat([cond, cond], dim=0),
                torch.cat([truth, fake], dim=0),
            )
            y = torch.cat([truth_label, fake_label], dim=0)

            d_loss = self.adversarial_loss(y_hat, y)

            self.log('dis_loss', d_loss, prog_bar=True, on_step=False, on_epoch=True, batch_size=n_rows)
            return d_loss

    def configure_optimizers(self):
        """One Adam optimizer per network: index 0 = generator, index 1 = discriminator."""
        opt_g = torch.optim.Adam(self.gen.parameters(), lr=self.hparams['gen_lr'])
        opt_d = torch.optim.Adam(self.dis.parameters(), lr=self.hparams['dis_lr'])
        return [opt_g, opt_d], []

    def shared_evaluation(self, batch, batch_idx, log=True):
        """Evaluate the generator on a validation/test batch.

        Computes the generator's adversarial loss (fakes scored against the
        "real" label) and, when ``log`` is true, logs ``mean_wdis`` via
        ``log_metrics``.

        Returns:
            ``{'loss': loss}``.
        """
        cond, _, truth = batch

        fake = self(cond)
        y = torch.ones(cond.size(0), 1, device=self.device)
        loss = self.adversarial_loss(self.dis(cond, fake), y)

        # `log=False` skips the (host-side, comparatively slow) metric computation.
        if log:
            self.log_metrics(cond, truth, fake)

        return {'loss': loss}

    def log_metrics(self, cond, truth, fake):
        """Log ``mean_wdis``, an averaged 1-D Wasserstein distance between truth and fake.

        Rows are grouped by their condition rounded to one decimal. Within each
        group a ``wasserstein_distance`` is computed per column and averaged over
        columns; the logged value is the mean over groups.

        ``cond`` is flattened with ``view(-1)`` to build the row mask, so this
        assumes one condition value per row (``cond_dim == 1``).
        """
        nominal_cond = torch.round(cond, decimals=1).view(-1).detach().cpu().numpy()
        # Copy to host once, rather than once per (group, column) pair.
        truth_np = truth.detach().cpu().numpy()
        fake_np = fake.detach().cpu().numpy()

        per_group = []
        for value in np.unique(nominal_cond):
            mask = nominal_cond == value
            c_truth, c_fake = truth_np[mask], fake_np[mask]
            per_group.append(np.mean([
                wasserstein_distance(c_truth[:, col], c_fake[:, col])
                for col in range(c_truth.shape[1])
            ]))
        mean_wdis = float(np.mean(per_group))

        self.log_dict(
            {"mean_wdis": mean_wdis},
            on_epoch=True, on_step=False, prog_bar=True, batch_size=cond.size(0),
        )

    def validation_step(self, batch, batch_idx):
        outputs = self.shared_evaluation(batch, batch_idx)
        return outputs['loss']

    def test_step(self, batch, batch_idx):
        outputs = self.shared_evaluation(batch, batch_idx)
        return outputs['loss']



In [12]:
from pytorch_lightning import Trainer, seed_everything
from pytorch_lightning.callbacks.progress.rich_progress import RichProgressBar

# Seed Python, NumPy and torch RNGs (and dataloader workers) *before* building the
# model, so weight init, generator noise and any data shuffling are reproducible
# under Restart & Run All.
SEED = 42
seed_everything(SEED, workers=True)

gan = GAN(hparams)
trainer = Trainer(max_epochs=2, callbacks=[RichProgressBar(refresh_rate=10)], accelerator='auto', devices=1)

trainer.fit(gan)

`Trainer.fit` stopped: `max_epochs=2` reached.
