
extend corr_stop function to return a weighted mean of correlations #16

Merged · 6 commits · Nov 16, 2019

Conversation

KonstantinWilleke
Collaborator

No description provided.

- target, output = np.array([]), np.array([])
+ target, output = torch.empty(0), torch.empty(0)
  for images, responses in loader[data_key]:
      output = torch.cat((output, (model(images.to(device), data_key).detach().cpu())), dim=0)
Member

When you pass data_key into the model, be sure to pass it as a keyword argument (e.g. data_key=data_key) rather than as a positional argument.
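A minimal, self-contained sketch of why the positional call is fragile, assuming (hypothetically) that the model's forward accepts extra positional inputs via *args: a positional data_key gets silently absorbed there instead of binding to the data_key parameter.

```python
# Hypothetical stand-in for a model forward pass; only the call convention matters here.
def forward(x, *args, data_key=None):
    return x, args, data_key

# Positional call: "scan1" lands in args, and data_key silently stays None.
assert forward("img", "scan1") == ("img", ("scan1",), None)

# Keyword call: data_key is bound as intended.
assert forward("img", data_key="scan1") == ("img", (), "scan1")
```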

"""
computes model predictions for a given dataloader and a model
Returns:
target: ground truth, i.e. neuronal firing rates of the neurons
output: responses as predicted by the network
"""
target, output = np.array([]), np.array([])
target, output = torch.empty(0), torch.empty(0)
for images, responses in loader[data_key]:
Member

This assumes that the loader only returns inputs and responses, and it will break when applied to a dataloader that has more parts (e.g. eye movements). Rather, use something like the following:

for data_batch in loader[data_key]:
    # put into a dictionary and also move to target device already
    data_batch = {f: getattr(data_batch, f).to(device) for f in data_batch._fields}
    # extract inputs and targets that are supposed to always exist
    inputs, targets = data_batch.pop('inputs'), data_batch.pop('targets')
    output = torch.cat((output, (model(inputs, data_key=data_key, **data_batch).detach().cpu())), dim=0)
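The snippet above rests on the namedtuple-to-dict pattern. Here is a self-contained sketch of just that pattern with plain Python lists standing in for tensors; the batch field names (including pupil_center) are illustrative, not from the actual dataloaders.

```python
from collections import namedtuple

# Hypothetical batch type: the loader may return more than inputs/targets.
Batch = namedtuple("Batch", ["inputs", "targets", "pupil_center"])
batch = Batch(inputs=[1.0, 2.0], targets=[0.5], pupil_center=[0.1, 0.2])

# _fields lists every entry, however many the loader happens to return ...
data = {f: getattr(batch, f) for f in batch._fields}

# ... so inputs/targets can be popped out, and whatever remains can be
# forwarded to the model as keyword arguments (**data).
inputs, targets = data.pop("inputs"), data.pop("targets")

assert inputs == [1.0, 2.0] and targets == [0.5]
assert data == {"pupil_center": [0.1, 0.2]}
```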

target, output = torch.empty(0), torch.empty(0)
for images, responses in loader[data_key]:
    output = torch.cat((output, (model(images.to(device), data_key).detach().cpu())), dim=0)
    target = torch.cat((target, responses.detach().cpu()), dim=0)
Member

If you are going to use my code snippet above: I already called it targets there instead of responses.

    all_correlations = np.append(all_correlations, ret)
else:
    n_neurons[0, i] = output.shape[1]
    correlations[i, 0] = ret.mean()
Member

Wouldn't it make more sense to just accumulate the sum of correlations here and later divide by the total number of neurons?
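A sketch of the suggested alternative, with made-up correlation values: accumulate the sum of per-neuron correlations and the neuron counts per dataset, then divide once at the end. This yields exactly the neuron-weighted mean without storing per-dataset means.

```python
# Two hypothetical datasets with different neuron counts; values are illustrative.
per_dataset_corrs = [
    [0.2, 0.4, 0.6],  # dataset A: 3 neurons
    [0.8, 1.0],       # dataset B: 2 neurons
]

corr_sum, n_neurons = 0.0, 0
for corrs in per_dataset_corrs:
    corr_sum += sum(corrs)   # running sum of per-neuron correlations
    n_neurons += len(corrs)  # running total of neurons

weighted_mean = corr_sum / n_neurons
assert abs(weighted_mean - 0.6) < 1e-12  # (0.2+0.4+0.6+0.8+1.0) / 5
```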

desc='Epoch {}'.format(epoch)):

- loss = full_objective(model, data_key, **data)
+ loss = full_objective(model, data_key, *data)
Member

This is a dangerous assumption about the order of elements returned from the dataloader. You are assuming that it always comes back as inputs, targets, other1, other2, ..., but this is not guaranteed. Rather, we should construct a dictionary to handle this properly. Refer to my comments above on how to pass a dictionary constructed from the namedtuple.
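A minimal sketch of the difference, with stand-in names (the objective body and the behavior field are hypothetical): unpacking the namedtuple with ** binds each field to the matching keyword regardless of field order, whereas * relies on positions.

```python
from collections import namedtuple

# Hypothetical batch layout; the "behavior" field is illustrative.
Batch = namedtuple("Batch", ["inputs", "targets", "behavior"])

def full_objective(model, data_key, *, inputs, targets, **extras):
    """Stand-in objective; only the call convention matters here."""
    return model, data_key, inputs, targets, extras

data = Batch(inputs="x", targets="y", behavior="b")

# **data._asdict() binds fields by name, so reordering the namedtuple
# fields would not silently swap the arguments.
result = full_objective("net", "scan1", **data._asdict())
assert result == ("net", "scan1", "x", "y", {"behavior": "b"})
```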


model, epoch = run(model=model,
                   full_objective=full_objective,
                   optimizer=optimizer,
                   scheduler=scheduler,
                   stop_closure=stop_closure,
                   train_loader=train_loader,
                   train=train,
Member

Sorry, I think this argument was better off staying as train_loader. I just didn't want the key in the dictionary of data loaders to be train_loader.
