Feat/replay #108

Closed · wants to merge 4 commits

Changes from 3 commits
105 changes: 72 additions & 33 deletions examples/vae/main.py
@@ -6,21 +6,41 @@
from torch.nn import functional as F
from torchvision import datasets, transforms
from torchvision.utils import save_image
import os

import flor


parser = argparse.ArgumentParser(description='VAE MNIST Example')
parser.add_argument('--batch-size', type=int, default=128, metavar='N',
help='input batch size for training (default: 128)')
parser.add_argument('--epochs', type=int, default=10, metavar='N',
help='number of epochs to train (default: 10)')
parser.add_argument('--no-cuda', action='store_true', default=False,
help='disables CUDA training')
parser.add_argument('--seed', type=int, default=1, metavar='S',
help='random seed (default: 1)')
parser.add_argument('--log-interval', type=int, default=10, metavar='N',
help='how many batches to wait before logging training status')
parser.add_argument(
'--batch-size',
type=int,
default=128,
metavar='N',
help='input batch size for training (default: 128)'
)
parser.add_argument(
'--epochs',
type=int,
default=10,
metavar='N',
help='number of epochs to train (default: 10)'
)
parser.add_argument(
'--no-cuda',
action='store_true',
default=False,
help='disables CUDA training'
)
parser.add_argument(
'--seed', type=int, default=1, metavar='S', help='random seed (default: 1)'
)
parser.add_argument(
'--log-interval',
type=int,
default=10,
metavar='N',
help='how many batches to wait before logging training status'
)
args = parser.parse_args()
args.cuda = not args.no_cuda and torch.cuda.is_available()

@@ -30,12 +50,19 @@

kwargs = {'num_workers': 1, 'pin_memory': True} if args.cuda else {}
train_loader = torch.utils.data.DataLoader(
datasets.MNIST('../data', train=True, download=True,
transform=transforms.ToTensor()),
batch_size=args.batch_size, shuffle=True, **kwargs)
datasets.MNIST(
'../data', train=True, download=True, transform=transforms.ToTensor()
),
batch_size=args.batch_size,
shuffle=True,
**kwargs
)
test_loader = torch.utils.data.DataLoader(
datasets.MNIST('../data', train=False, transform=transforms.ToTensor()),
batch_size=args.batch_size, shuffle=True, **kwargs)
batch_size=args.batch_size,
shuffle=True,
**kwargs
)


class VAE(nn.Module):
@@ -53,9 +80,9 @@ def encode(self, x):
return self.fc21(h1), self.fc22(h1)

def reparameterize(self, mu, logvar):
std = torch.exp(0.5*logvar)
std = torch.exp(0.5 * logvar)
eps = torch.randn_like(std)
return mu + eps*std
return mu + eps * std

def decode(self, z):
h3 = F.relu(self.fc3(z))
@@ -88,7 +115,7 @@ def train(epoch):
model.train()
train_loss = 0

if flor.SkipBlock.step_into():
if flor.SkipBlock.step_into('train'):
for batch_idx, (data, _) in enumerate(train_loader):
data = data.to(device)
optimizer.zero_grad()
@@ -98,37 +125,49 @@ def train(epoch):
train_loss += loss.item()
optimizer.step()
if batch_idx % args.log_interval == 0:
print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
epoch, batch_idx * len(data), len(train_loader.dataset),
100. * batch_idx / len(train_loader),
loss.item() / len(data)))
print(
'Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
epoch, batch_idx * len(data), len(train_loader.dataset),
100. * batch_idx / len(train_loader),
loss.item() / len(data)
)
)
_, train_loss = flor.SkipBlock.end(model, train_loss)

print('====> Epoch: {} Average loss: {:.4f}'.format(
epoch, train_loss / len(train_loader.dataset)))
print(
'====> Epoch: {} Average loss: {:.4f}'.format(
epoch, train_loss / len(train_loader.dataset)
)
)


def test(epoch):
model.eval()
test_loss = 0

if flor.SkipBlock.start():
if flor.SkipBlock.step_into('test'):
with torch.no_grad():
for i, (data, _) in enumerate(test_loader):
data = data.to(device)
recon_batch, mu, logvar = model(data)
test_loss += loss_function(recon_batch, data, mu, logvar).item()
if i == 0:
n = min(data.size(0), 8)
comparison = torch.cat([data[:n],
recon_batch.view(args.batch_size, 1, 28, 28)[:n]])
save_image(comparison.cpu(),
'results/reconstruction_' + str(epoch) + '.png', nrow=n)
comparison = torch.cat([
data[:n],
recon_batch.view(args.batch_size, 1, 28, 28)[:n]
])
save_image(
comparison.cpu(),
'results/reconstruction_' + str(epoch) + '.png',
nrow=n
)
test_loss = flor.SkipBlock.end(test_loss)

test_loss /= len(test_loader.dataset)
print('====> Test set loss: {:.4f}'.format(test_loss))


if __name__ == "__main__":
if not os.path.exists('results'):
os.mkdir('results')
@@ -139,7 +178,7 @@ def test(epoch):
with torch.no_grad():
sample = torch.randn(64, 20).to(device)
sample = model.decode(sample).cpu()
save_image(sample.view(64, 1, 28, 28),
'results/sample_' + str(epoch) + '.png')

flor.flush()
save_image(
sample.view(64, 1, 28, 28),
'results/sample_' + str(epoch) + '.png'
)
1 change: 1 addition & 0 deletions flor/interface/__init__.py
@@ -1,2 +1,3 @@
from .iterator import it
from .skipblock import SkipBlock
from ..logger import Logger
3 changes: 3 additions & 0 deletions flor/interface/skipblock/abstract.py
@@ -16,6 +16,9 @@ def step_into(block_name: str, probed=False):
def end(*args, values=None):
if flags.NAME is not None:
raise RuntimeError("SkipBlock missing dynamic linking")
if len(args) == 1:
return args[0]
return args

@staticmethod
def bind():
12 changes: 12 additions & 0 deletions flor/interface/skipblock/readblock.py
@@ -5,6 +5,8 @@
from flor.journal.entry import Bracket, LBRACKET

from typing import List
import types
import weakref


class ReadBlock(SeemBlock):
@@ -25,10 +27,12 @@ def step_into(block_name: str, probed=False):
def end(*args, values=None):
lbracket = ReadBlock.pda.pop()
block = journal.as_tree()[lbracket.sk].blocks[lbracket.gk]
return_args = []
if not lbracket.predicate:
for data_record, arg in zip(block.data_records, args):
data_record.make_val()
value_serialized = data_record.value
return_args.append(arg)
if hasattr(arg, 'load_state_dict'):
# PyTorch support
arg.load_state_dict(value_serialized)
@@ -44,8 +48,16 @@ def end(*args, values=None):
else:
assert type(arg) == type(value_serialized)
arg.__dict__.update(value_serialized.__dict__)
elif type(arg) in (type(None), int, float, bool, complex, str, tuple,
bytes, frozenset, type, range, slice, property,
types.BuiltinFunctionType, type(Ellipsis), type(NotImplemented),
types.FunctionType, weakref.ref):
return_args[-1] = value_serialized
else:
# TODO: ...
raise RuntimeError("TODO: add hooks for user-defined de-serialization")
# TODO: ...
assert values is None, "TODO: Add support for literals/atomics"
if len(return_args) == 1:
return return_args[0]
return return_args
3 changes: 3 additions & 0 deletions flor/interface/skipblock/writeblock.py
@@ -54,6 +54,9 @@ def end(*args, values=None):
else:
rbracket = Bracket(lbracket.sk, lbracket.gk, RBRACKET)
journal.feed(rbracket)
if len(args) == 1:
return args[0]
return args

@staticmethod
def _should_materialize(block_group):
20 changes: 12 additions & 8 deletions flor/journal/file.py
@@ -1,13 +1,21 @@
from .entry import DataVal, DataRef, Bracket, EOF, make_entry

from flor import shelf
from ..logger import Logger

import json
import pathlib
from typing import Union, List

entries: List[Union[DataRef, DataVal, Bracket, EOF]] = []

def _proc_log_record(log_record):
Collaborator: Isn't this function already defined elsewhere in the code? Can we avoid code duplication here? Have the code block be defined in a single spot and used in multiple spots.

Collaborator Author: I couldn't find where it's duplicated. Mind pointing me to it?

if isinstance(log_record, DataRef):
log_record.set_ref_and_dump(shelf.get_pkl_ref())
return json.dumps(log_record.jsonify()) + pathlib.os.linesep

entries_logger = Logger()
entries_logger.register_pre(_proc_log_record)
Collaborator: I think I see what you're doing here. You're being clear about the code that gets run on pre. But the user won't be changing the Flor code base, so this show-and-tell format is better off in a readme or Docs.

Collaborator Author: I'm not too sure what you mean by "show-and-tell format." It's registering a preprocessing step particular to entries_logger but not all loggers.
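To make the reply above concrete, here is a minimal, self-contained sketch of a per-logger preprocessing hook. It mirrors the shape of register_pre and _proc_log_record in this diff, but the class body, file path, and record format below are simplified illustrations, not code from the PR.

import json


class Logger:
    """Buffered line logger: a registered pre-hook transforms each record at flush time."""

    def __init__(self, path):
        self.path = path
        self.buffer = []
        self.preprocess_f = None

    def register_pre(self, func):
        # The hook is stored per instance, so it applies only to this logger.
        self.preprocess_f = func

    def append(self, record):
        self.buffer.append(record)

    def flush(self):
        records = self.buffer
        if self.preprocess_f is not None:
            records = [self.preprocess_f(r) for r in records]
        with open(self.path, 'a') as f:
            f.writelines(records)
        self.buffer = []


# Only entries_logger serializes its records to JSON lines; other Logger
# instances are unaffected because the hook is registered per instance.
entries_logger = Logger('journal.index')
entries_logger.register_pre(lambda rec: json.dumps(rec) + '\n')
entries_logger.append({'type': 'LBRACKET', 'sk': 0, 'gk': 0})
entries_logger.flush()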


def read():
with open(shelf.get_index(), 'r') as f:
@@ -17,16 +25,12 @@ def read():


def feed(journal_entry: Union[DataRef, DataVal, Bracket, EOF]):
entries.append(journal_entry)

if entries_logger.path is None:
entries_logger.set_path(shelf.get_index())
entries_logger.append(journal_entry)

def write():
with open(shelf.get_index(), 'w') as f:
for log_record in entries:
if isinstance(log_record, DataRef):
log_record.set_ref_and_dump(shelf.get_pkl_ref())
f.write(json.dumps(log_record.jsonify()) + pathlib.os.linesep)
entries[:] = []
entries_logger.flush()


def merge():
39 changes: 39 additions & 0 deletions flor/logger/__init__.py
@@ -1 +1,40 @@
from .copy import deepcopy
import os
from flor import shelf

class Logger:
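    """Buffered journal logger: appended records are flushed to `path` by a forked child process."""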
def __init__(self, path=None, buf_size=1024):
self.path = path
self.buf_size = buf_size
self.buffer = []
self.preprocess_f = None

def set_path(self, path):
self.path = path

def append(self, *args):
assert self.path is not None, "Logger path not set."
for e in args:
self.buffer.append(e)
if len(self.buffer) >= self.buf_size:
self.flush()

    def flush(self):
        # Write from a forked child so disk I/O does not block the training process.
        pid = os.fork()
        if not pid:
            # Child: persist the buffered records, then exit immediately so it
            # does not fall through and re-run the rest of the program.
            self._flush_buffer()
            os._exit(0)
        else:
            # Parent: the child holds the old buffer; start a fresh one.
            self.buffer = []

def _flush_buffer(self):
if self.preprocess_f is not None:
self.buffer = list(map(self.preprocess_f, self.buffer))
        # Append rather than truncate: the buffer is cleared after each flush,
        # so opening with 'w' would overwrite previously flushed records.
        with open(self.path, 'a') as f:
for e in self.buffer:
f.write(e)

def force(self):
self._flush_buffer()

def register_pre(self, func):
self.preprocess_f = func