Skip to content

Commit

Permalink
Cosmoflow (#87)
Browse files Browse the repository at this point in the history
* Add stub cosmoflow

* Add argparse to cosmoflow

* Fix arg parse of input shape

* Add model compilation

* Guess get_data()

* Refactor into build_model

* Finish cosmoflow (see full commit message!)

It seems cosmoflow is finished but it still doesn't work, because ATM
`do_tensorflow.py` does only classification (see line 77, where it
applies `to_categorical` to `y_train`).  So to make it work - you have
to comment out that line.

* Remove to_categorical()

* Change resnet, vggs, xception to use sparse CCE

* Add cosmoflow test

* Add cosmoflow/pytorch.py

* Move `proc_params()` to separate file

* Add support for channels_first

* Finish pytorch inference for cosmoflow

* Remove unused cosmoflow model

* Add `Regression` class to `helpers_torch.py`

* Make pytorch cosmoflow a regression

* Add test/pytorch/test_cosmoflow.py

Co-authored-by: Emil VATAI <vatai@x1carbon.subliminal>
  • Loading branch information
vatai and Emil VATAI committed Jul 23, 2020
1 parent 11f7e3d commit 96ff510
Show file tree
Hide file tree
Showing 6 changed files with 130 additions and 22 deletions.
17 changes: 10 additions & 7 deletions benchmarker/modules/problems/cosmoflow/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,18 @@


def get_data(params):
shape = (params["problem"]["cnt_batches_per_epoch"], params["batch_size"])
shape += params["input_shape"]
cnt_batches = params["problem"]["cnt_batches_per_epoch"]
batch_size = params["batch_size"]

shape = params["input_shape"]
if params["channels_first"]:
shape = shape[-1:] + shape[:-1]
# params["input_shape"] = shape # should this be updated?

shape = (cnt_batches, batch_size) + shape
X = np.random.random(shape).astype(np.float32)

shape = (
params["problem"]["cnt_batches_per_epoch"],
params["batch_size"],
params["target_size"],
)
shape = (cnt_batches, batch_size, params["target_size"])
Y = np.random.random(shape).astype(np.float32)

return X, Y
19 changes: 19 additions & 0 deletions benchmarker/modules/problems/cosmoflow/params.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
import argparse
import ast


def proc_params(params, unparsed_args):
"""Process args for CosmoFlow in (D1, D2, D3, Channels) tuple."""

parser = argparse.ArgumentParser()
parser.add_argument("--input_shape", default="128, 128, 128, 4")
parser.add_argument("--target_size", default=4)
parser.add_argument("--dropout", default=0)

args, unparsed = parser.parse_known_args(unparsed_args)

params["input_shape"] = ast.literal_eval(args.input_shape)
params["target_size"] = args.target_size
params["dropout"] = args.dropout

assert unparsed == []
64 changes: 64 additions & 0 deletions benchmarker/modules/problems/cosmoflow/pytorch.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
import numpy as np
import torch.nn as nn

from ..helpers_torch import Regression
from .params import proc_params


class scale_1p2(nn.Module):
def forward(self, input):
return 1.2 * input


def build_model(input_shape, target_size, dropout=0):
shape = np.array(input_shape[:-1], dtype=np.int)
conv_args = {
"in_channels": input_shape[-1],
"out_channels": 16,
"kernel_size": 2,
}
maxpool_args = dict(kernel_size=2)

layers = [
nn.Conv3d(**conv_args),
nn.LeakyReLU(),
nn.MaxPool3d(**maxpool_args),
]
shape = (shape - 1) // 2

conv_args["in_channels"] = 16
for _ in range(4):
layers += [
nn.Conv3d(**conv_args),
nn.LeakyReLU(),
nn.MaxPool3d(**maxpool_args),
]
shape = (shape - 1) // 2

flat_shape = np.prod(shape) * 16
layers += [
nn.Flatten(),
nn.Dropout(dropout),
#
nn.Linear(flat_shape, 128),
nn.LeakyReLU(),
nn.Dropout(),
#
nn.Linear(128, 64),
nn.LeakyReLU(),
nn.Dropout(),
#
nn.Linear(64, target_size),
nn.Tanh(),
scale_1p2(),
]

return nn.Sequential(*layers)


def get_kernel(params, unparsed_args):
"""Construct the CosmoFlow 3D CNN model"""

proc_params(params, unparsed_args)
net = build_model(params["input_shape"], params["target_size"], params["dropout"])
return Regression(params, net)
17 changes: 2 additions & 15 deletions benchmarker/modules/problems/cosmoflow/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,11 @@
"""

import argparse
import ast

import tensorflow as tf
import tensorflow.keras.layers as layers

from .params import proc_params


def scale_1p2(x):
"""Simple scaling function for Lambda layers.
Expand All @@ -23,18 +22,6 @@ def scale_1p2(x):
return x * 1.2


def proc_params(params, unparsed_args):
parser = argparse.ArgumentParser()
parser.add_argument("--input_shape", default="128, 128, 128, 4")
parser.add_argument("--target_size", default=4)
parser.add_argument("--dropout", default=0)
args, unparsed = parser.parse_known_args(unparsed_args)
params["input_shape"] = ast.literal_eval(args.input_shape)
params["target_size"] = args.target_size
params["dropout"] = args.dropout
assert unparsed == []


def build_model(input_shape, target_size, dropout=0):
conv_args = dict(kernel_size=2, padding="valid")

Expand Down
10 changes: 10 additions & 0 deletions benchmarker/modules/problems/helpers_torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,3 +72,13 @@ def __init__(self, net):
def Recommender(params, net):
"""Returns an inference or training recommender."""
return Net4Both(params, net, RecommenderInference, RecommenderTraining)


class RegressionTraining(Net4Train):
def __init__(self, net_and_loss):
super().__init__(*net_and_loss)


def Regression(params, net, loss=nn.MSELoss()):
"""Returns an inference or training recommender."""
return Net4Both(params, (net, loss), lambda t: t[0], RegressionTraining)
25 changes: 25 additions & 0 deletions test/pytorch/test_cosmoflow.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
import logging
import os
import unittest

from benchmarker.benchmarker import run

logging.basicConfig(level=logging.DEBUG)
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"


class CosmoflowTests(unittest.TestCase):
def setUp(self):
self.args = [
"--problem=cosmoflow",
"--framework=pytorch",
"--problem_size=1",
"--batch_size=1",
"--nb_epoch=1",
]

def test_cosmoflow(self):
run(self.args)

def test_cosmoflow_inference(self):
run(self.args + ["--mode=inference"])

0 comments on commit 96ff510

Please sign in to comment.