-
Notifications
You must be signed in to change notification settings - Fork 23
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
extend corr_stop function to return a weighted mean of correlations #16
Conversation
nnfabrik/training/trainers.py
Outdated
target, output = np.array([]), np.array([]) | ||
target, output = torch.empty(0), torch.empty(0) | ||
for images, responses in loader[data_key]: | ||
output = torch.cat((output, (model(images.to(device), data_key).detach().cpu())), dim=0) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
when you pass in data_key
into the model, be sure to pass it as keyword argument (e.g. data_key=data_key
) rather than a positional argument.
""" | ||
computes model predictions for a given dataloader and a model | ||
Returns: | ||
target: ground truth, i.e. neuronal firing rates of the neurons | ||
output: responses as predicted by the network | ||
""" | ||
target, output = np.array([]), np.array([]) | ||
target, output = torch.empty(0), torch.empty(0) | ||
for images, responses in loader[data_key]: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This assumes that the loader only return inputs and responses and will break if applied on a dataloader that has more parts (e.g. eye movements). Rather, use something like the following:
for data_batch in loader[data_key]:
# put into a dictionary and also move to target device already
data_batch = {f: getattr(data_batch, f).to(device) for f in data_batch._fields}
# extract inputs and targets that are supposed to always exist
inputs, targets = data_batch.pop('inputs'), data_batch.pop('targets')
output = torch.cat((output, (model(inputs, data_key=data_key, **data_batch).detach().cpu())), dim=0)
target, output = torch.empty(0), torch.empty(0) | ||
for images, responses in loader[data_key]: | ||
output = torch.cat((output, (model(images.to(device), data_key).detach().cpu())), dim=0) | ||
target = torch.cat((target, responses.detach().cpu()), dim=0) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
if you are going to use my codesnippet above, I called it targets
instead of responses
already.
all_correlations = np.append(all_correlations, ret) | ||
else: | ||
n_neurons[0,i] = output.shape[1] | ||
correlations[i,0] = ret.mean() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
wouldn't it make more sense too just get sum of correlations here and later just divide by the total number of neurons?
desc='Epoch {}'.format(epoch)): | ||
|
||
loss = full_objective(model, data_key, **data) | ||
loss = full_objective(model, data_key, *data) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is a dangerous assumption about the order of elements returned from the dataloader. You are assuming that it always comes back as inputs, targets, other1, other2,...
but this is not guaranteed. Rather we should construct a dictionary to properly deal with this. Refer to my comments above on how to properly pass a dictionary constructed from the namedtuple.
|
||
model, epoch = run(model=model, | ||
full_objective=full_objective, | ||
optimizer=optimizer, | ||
scheduler=scheduler, | ||
stop_closure=stop_closure, | ||
train_loader=train_loader, | ||
train=train, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
sorry I think this argument was better off staying as train_loader
. I just didn't want the dictiionary of data loaders' key to be train_loader
.
No description provided.