Feat/replay #108

Closed
wants to merge 4 commits
105 changes: 72 additions & 33 deletions examples/vae/main.py
@@ -6,21 +6,41 @@
from torch.nn import functional as F
from torchvision import datasets, transforms
from torchvision.utils import save_image
import os

import flor


parser = argparse.ArgumentParser(description='VAE MNIST Example')
parser.add_argument('--batch-size', type=int, default=128, metavar='N',
help='input batch size for training (default: 128)')
parser.add_argument('--epochs', type=int, default=10, metavar='N',
help='number of epochs to train (default: 10)')
parser.add_argument('--no-cuda', action='store_true', default=False,
help='disables CUDA training')
parser.add_argument('--seed', type=int, default=1, metavar='S',
help='random seed (default: 1)')
parser.add_argument('--log-interval', type=int, default=10, metavar='N',
help='how many batches to wait before logging training status')
parser.add_argument(
'--batch-size',
type=int,
default=128,
metavar='N',
help='input batch size for training (default: 128)'
)
parser.add_argument(
'--epochs',
type=int,
default=10,
metavar='N',
help='number of epochs to train (default: 10)'
)
parser.add_argument(
'--no-cuda',
action='store_true',
default=False,
help='disables CUDA training'
)
parser.add_argument(
'--seed', type=int, default=1, metavar='S', help='random seed (default: 1)'
)
parser.add_argument(
'--log-interval',
type=int,
default=10,
metavar='N',
help='how many batches to wait before logging training status'
)
args = parser.parse_args()
args.cuda = not args.no_cuda and torch.cuda.is_available()

@@ -30,12 +50,19 @@

kwargs = {'num_workers': 1, 'pin_memory': True} if args.cuda else {}
train_loader = torch.utils.data.DataLoader(
datasets.MNIST('../data', train=True, download=True,
transform=transforms.ToTensor()),
batch_size=args.batch_size, shuffle=True, **kwargs)
datasets.MNIST(
'../data', train=True, download=True, transform=transforms.ToTensor()
),
batch_size=args.batch_size,
shuffle=True,
**kwargs
)
test_loader = torch.utils.data.DataLoader(
datasets.MNIST('../data', train=False, transform=transforms.ToTensor()),
batch_size=args.batch_size, shuffle=True, **kwargs)
batch_size=args.batch_size,
shuffle=True,
**kwargs
)


class VAE(nn.Module):
@@ -53,9 +80,9 @@ def encode(self, x):
return self.fc21(h1), self.fc22(h1)

def reparameterize(self, mu, logvar):
std = torch.exp(0.5*logvar)
std = torch.exp(0.5 * logvar)
eps = torch.randn_like(std)
return mu + eps*std
return mu + eps * std

def decode(self, z):
h3 = F.relu(self.fc3(z))
@@ -88,7 +115,7 @@ def train(epoch):
model.train()
train_loss = 0

if flor.SkipBlock.step_into():
if flor.SkipBlock.step_into('train'):
for batch_idx, (data, _) in enumerate(train_loader):
data = data.to(device)
optimizer.zero_grad()
@@ -98,37 +125,49 @@ def train(epoch):
train_loss += loss.item()
optimizer.step()
if batch_idx % args.log_interval == 0:
print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
epoch, batch_idx * len(data), len(train_loader.dataset),
100. * batch_idx / len(train_loader),
loss.item() / len(data)))
print(
'Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
epoch, batch_idx * len(data), len(train_loader.dataset),
100. * batch_idx / len(train_loader),
loss.item() / len(data)
)
)
_, train_loss = flor.SkipBlock.end(model, train_loss)

print('====> Epoch: {} Average loss: {:.4f}'.format(
epoch, train_loss / len(train_loader.dataset)))
print(
'====> Epoch: {} Average loss: {:.4f}'.format(
epoch, train_loss / len(train_loader.dataset)
)
)


def test(epoch):
model.eval()
test_loss = 0

if flor.SkipBlock.start():
if flor.SkipBlock.step_into('test'):
with torch.no_grad():
for i, (data, _) in enumerate(test_loader):
data = data.to(device)
recon_batch, mu, logvar = model(data)
test_loss += loss_function(recon_batch, data, mu, logvar).item()
if i == 0:
n = min(data.size(0), 8)
comparison = torch.cat([data[:n],
recon_batch.view(args.batch_size, 1, 28, 28)[:n]])
save_image(comparison.cpu(),
'results/reconstruction_' + str(epoch) + '.png', nrow=n)
comparison = torch.cat([
data[:n],
recon_batch.view(args.batch_size, 1, 28, 28)[:n]
])
save_image(
comparison.cpu(),
'results/reconstruction_' + str(epoch) + '.png',
nrow=n
)
test_loss = flor.SkipBlock.end(test_loss)

test_loss /= len(test_loader.dataset)
print('====> Test set loss: {:.4f}'.format(test_loss))


if __name__ == "__main__":
if not os.path.exists('results'):
os.mkdir('results')
@@ -139,7 +178,7 @@ def test(epoch):
with torch.no_grad():
sample = torch.randn(64, 20).to(device)
sample = model.decode(sample).cpu()
save_image(sample.view(64, 1, 28, 28),
'results/sample_' + str(epoch) + '.png')

flor.flush()
save_image(
sample.view(64, 1, 28, 28),
'results/sample_' + str(epoch) + '.png'
)
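
The main.py changes above exercise the replay API this PR introduces: SkipBlock.step_into now takes a block name ('train', 'test'), and SkipBlock.end returns the values it checkpoints, so a block that is skipped on replay can hand restored state back to its caller. The following is a minimal sketch of the same pattern on a toy model; the model, optimizer, and loop sizes are hypothetical, and it assumes a Flor session has already been set up for record or replay.

import flor
import torch

# Hypothetical toy setup; only the flor calls mirror the pattern in this PR.
model = torch.nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

for epoch in range(3):
    train_loss = 0.0
    if flor.SkipBlock.step_into('train'):  # body runs on record; may be skipped on replay
        for _ in range(10):
            x = torch.randn(8, 4)
            loss = model(x).pow(2).mean()
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            train_loss += loss.item()
    # end() checkpoints (or, on replay, restores) its arguments and returns them.
    # The model is restored in place via load_state_dict, so its slot is discarded here.
    _, train_loss = flor.SkipBlock.end(model, train_loss)
    print('epoch', epoch, 'loss', train_loss)

flor.flush()
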
3 changes: 3 additions & 0 deletions flor/interface/skipblock/abstract.py
@@ -16,6 +16,9 @@ def step_into(block_name: str, probed=False):
def end(*args, values=None):
if flags.NAME is not None:
raise RuntimeError("SkipBlock missing dynamic linking")
if len(args) == 1:
return args[0]
return args

@staticmethod
def bind():
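
The change to the abstract interface gives SkipBlock.end a pass-through return: a single argument comes back as the value itself, multiple arguments come back as a tuple. ReadBlock and WriteBlock below follow the same contract, so a call site such as _, train_loss = flor.SkipBlock.end(model, train_loss) unpacks identically whether the block body executed or was skipped. A minimal sketch of just the return shape (the helper name is hypothetical, not part of Flor):

def _end_return_shape(*args):
    # Mirrors the return convention this PR adds to SkipBlock.end:
    # unwrap a lone argument, otherwise hand back the tuple.
    if len(args) == 1:
        return args[0]
    return args

assert _end_return_shape(0.25) == 0.25
assert _end_return_shape('model', 0.25) == ('model', 0.25)
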
12 changes: 12 additions & 0 deletions flor/interface/skipblock/readblock.py
@@ -5,6 +5,8 @@
from flor.journal.entry import Bracket, LBRACKET

from typing import List
import types
import weakref


class ReadBlock(SeemBlock):
@@ -25,10 +27,12 @@ def step_into(block_name: str, probed=False):
def end(*args, values=None):
lbracket = ReadBlock.pda.pop()
block = journal.as_tree()[lbracket.sk].blocks[lbracket.gk]
return_args = []
if not lbracket.predicate:
for data_record, arg in zip(block.data_records, args):
data_record.make_val()
value_serialized = data_record.value
return_args.append(arg)
if hasattr(arg, 'load_state_dict'):
# PyTorch support
arg.load_state_dict(value_serialized)
@@ -44,8 +48,16 @@
else:
assert type(arg) == type(value_serialized)
arg.__dict__.update(value_serialized.__dict__)
elif type(arg) in (type(None), int, float, bool, complex, str, tuple,
bytes, frozenset, type, range, slice, property,
types.BuiltinFunctionType, type(Ellipsis), type(NotImplemented),
types.FunctionType, weakref.ref):
return_args[-1] = value_serialized
else:
# TODO: ...
raise RuntimeError("TODO: add hooks for user-defined de-serialization")
# TODO: ...
assert values is None, "TODO: Add support for literals/atomics"
if len(return_args) == 1:
return return_args[0]
return return_args
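
On replay, ReadBlock.end rebuilds its arguments from the journal rather than re-running the block: objects that expose load_state_dict (PyTorch modules and optimizers) are restored in place, plain objects have their __dict__ updated from the recorded value, immutable or atomic types are replaced outright by the recorded value in the returned tuple, and anything else raises until user-defined deserialization hooks exist. A condensed, hypothetical sketch of that dispatch (the function and variable names are not Flor's):

import types
import weakref

# Types that are substituted outright rather than mutated in place.
ATOMIC_TYPES = (type(None), int, float, bool, complex, str, tuple, bytes,
                frozenset, type, range, slice, property,
                types.BuiltinFunctionType, type(Ellipsis), type(NotImplemented),
                types.FunctionType, weakref.ref)

def restore(arg, recorded):
    """Return the value the caller should receive in place of arg on replay."""
    if hasattr(arg, 'load_state_dict'):
        arg.load_state_dict(recorded)            # PyTorch support: restore in place
        return arg
    if type(arg) in ATOMIC_TYPES:
        return recorded                          # immutables: hand back the recorded value
    if type(arg) == type(recorded):
        arg.__dict__.update(recorded.__dict__)   # generic objects: copy recorded attributes
        return arg
    raise RuntimeError("TODO: add hooks for user-defined de-serialization")
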
3 changes: 3 additions & 0 deletions flor/interface/skipblock/writeblock.py
@@ -54,6 +54,9 @@ def end(*args, values=None):
else:
rbracket = Bracket(lbracket.sk, lbracket.gk, RBRACKET)
journal.feed(rbracket)
if len(args) == 1:
return args[0]
return args

@staticmethod
def _should_materialize(block_group):
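
WriteBlock.end gains the same pass-through return on the record path: after journaling the closing bracket it hands back exactly the values it was given. Together with the ReadBlock change above, this keeps a call site identical in both modes; a short hedged illustration (expensive_evaluation is a hypothetical placeholder for the block body):

test_loss = 0.0
if flor.SkipBlock.step_into('test'):
    test_loss = expensive_evaluation()        # hypothetical helper; only runs when recording
test_loss = flor.SkipBlock.end(test_loss)     # same statement works on record and on replay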