
63 test results #78

Merged · 9 commits · Mar 3, 2021
2 changes: 1 addition & 1 deletion platalea/asr.py
@@ -148,7 +148,7 @@ def val_loss():
if 'epsilon_decay' in config.keys():
# Save full model for inference
torch.save(net, 'net.best.pt')

return {'validation loss': val_loss().item()}

def get_default_config(hidden_size_factor=1024):
fd = D.Flickr8KData
4 changes: 4 additions & 0 deletions platalea/basic.py
@@ -99,6 +99,7 @@ def val_loss(net):

debug_logging_active = logging.getLogger().isEnabledFor(logging.DEBUG)

loss_value = None
with open("result.json", "w") as out:
for epoch in range(1, config['epochs']+1):
cost = Counter()
@@ -170,6 +171,9 @@ def val_loss(net):
result["validation loss"] = validation_loss
wandb.log(result)

# Return loss of the final model for automated testing
return {'final_loss': loss_value}
Contributor:

Is there a reason why we return different criteria (validation_loss vs. final_loss) for the asr and basic experiments? I would make this consistent by returning the same criterion, or both.

Contributor:

Or, maybe better, do as in e.g. mtl.py: save and return all the intermediate scores.

Contributor (author):

I'll have a look at whether it can always return the result dict (which is written to JSON anyway).

Contributor (@bhigy, Feb 24, 2021):

You might have to save the results at each epoch.

Little reminder: please also check experiments/flickr8k/pip_ind.py.
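
A minimal sketch of what that suggestion could look like in the basic/asr training loops, assuming the per-epoch dict that is already streamed to result.json is simply collected and returned. The `validate` parameter and `results` list are illustrative names for this sketch, not part of platalea's current API:

```python
import json


def experiment(net, data, config, validate):
    """Hypothetical sketch: collect every per-epoch result dict (the same dict
    that is written to result.json) and return the whole history, so tests can
    assert on any epoch rather than only on the final loss."""
    results = []
    with open("result.json", "w") as out:
        for epoch in range(1, config['epochs'] + 1):
            # ... training and validation for this epoch elided ...
            result = {'epoch': epoch, 'validation loss': validate(net)}
            results.append(result)
            print(json.dumps(result), file=out, flush=True)
    return results
```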



DEFAULT_CONFIG = dict(SpeechEncoder=dict(conv=dict(in_channels=39, out_channels=64, kernel_size=6, stride=2, padding=0,
bias=False),
2 changes: 1 addition & 1 deletion platalea/experiments/flickr8k/asr.py
@@ -39,4 +39,4 @@
l2_regularization=args.l2_regularization,)

logging.info('Training')
M.experiment(net, data, run_config, slt=data['train'].dataset.is_slt())
result = M.experiment(net, data, run_config, slt=data['train'].dataset.is_slt())
2 changes: 1 addition & 1 deletion platalea/experiments/flickr8k/basic_default.py
@@ -48,4 +48,4 @@
l2_regularization=args.l2_regularization,)

logging.info('Training')
M.experiment(net, data, run_config)
result = M.experiment(net, data, run_config)
2 changes: 1 addition & 1 deletion platalea/experiments/flickr8k/mtl_asr.py
@@ -89,4 +89,4 @@
dict(name='ASR', net=net.SpeechTranscriber, data=data, eval=scorer)]

logging.info('Training')
M.experiment(net, tasks, run_config)
result = M.experiment(net, tasks, run_config)
2 changes: 1 addition & 1 deletion platalea/experiments/flickr8k/mtl_st.py
@@ -73,4 +73,4 @@
dict(name='ST', net=net.SpeechText, data=data, eval=score_speech_text)]

logging.info('Training')
M.experiment(net, tasks, run_config)
result = M.experiment(net, tasks, run_config)
2 changes: 1 addition & 1 deletion platalea/experiments/flickr8k/pip_seq.py
@@ -87,6 +87,6 @@
l2_regularization=args.l2_regularization)

logging.info('Training text-image')
M2.experiment(net, data, run_config)
result = M2.experiment(net, data, run_config)
copyfile('result.json', 'result_text_image.json')
copy_best('.', 'result_text_image.json', 'ti.best.pt')
2 changes: 1 addition & 1 deletion platalea/experiments/flickr8k/text_image.py
@@ -36,4 +36,4 @@
l2_regularization=args.l2_regularization,)

logging.info('Training')
M.experiment(net, data, run_config)
result = M.experiment(net, data, run_config)
2 changes: 1 addition & 1 deletion platalea/experiments/flickr8k/transformer.py
@@ -97,4 +97,4 @@ def __new__(cls, value):
logged_config['encoder_config'].pop('SpeechEncoder') # Object info is redundant in log.

logging.info('Training')
M.experiment(net, data, run_config, wandb_project='platalea_transformer', wandb_log=logged_config)
result = M.experiment(net, data, run_config, wandb_project='platalea_transformer', wandb_log=logged_config)
2 changes: 2 additions & 0 deletions platalea/mtl.py
@@ -144,3 +144,5 @@ def experiment(net, tasks, config):
# Saving model
logging.info("Saving model in net.{}.pt".format(epoch))
torch.save(net, "net.{}.pt".format(epoch))

return result
2 changes: 2 additions & 0 deletions platalea/text_image.py
@@ -79,6 +79,7 @@ def val_loss():
optimizer = create_optimizer(config, net_parameters)
scheduler = create_scheduler(config, optimizer, data)

result = None
with open("result.json", "w") as out:
for epoch in range(1, config['epochs']+1):
cost = Counter()
@@ -101,6 +102,7 @@ def val_loss():
print('', file=out, flush=True)
logging.info("Saving model in net.{}.pt".format(epoch))
torch.save(net, "net.{}.pt".format(epoch))
return result


def get_default_config(hidden_size_factor=1024):
38 changes: 37 additions & 1 deletion tests/test_experiments.py
@@ -29,6 +29,7 @@ def test_transformer_experiment():
'--trafo_d_model=4',
'--trafo_feedforward_dim=4']):
import platalea.experiments.flickr8k.transformer
assert platalea.experiments.flickr8k.transformer.result == {'final_loss': 0.5153712034225464}


def test_basic_default_experiment():
@@ -38,6 +39,7 @@ def test_basic_default_experiment():
f'--flickr8k_root={flickr1d_path}',
'--hidden_size_factor=4']):
import platalea.experiments.flickr8k.basic_default
assert platalea.experiments.flickr8k.basic_default.result == {'final_loss': 0.41894787549972534}


def test_mtl_asr_experiment():
@@ -47,6 +49,23 @@ def test_mtl_asr_experiment():
f'--flickr8k_root={flickr1d_path}',
'--hidden_size_factor=4']):
import platalea.experiments.flickr8k.mtl_asr
assert platalea.experiments.flickr8k.mtl_asr.result == {
'ASR': {'cer': {'CER': 6.791171477079796,
'Cor': 0,
'Del': 0,
'Ins': 3411,
'Sub': 589},
'wer': {'Cor': 0,
'Del': 118,
'Ins': 0,
'Sub': 10,
'WER': 1.0}},
'SI': {'medr': 1.5,
'recall': {1: 0.5,
5: 1.0,
10: 1.0}},
'epoch': 1,
}


def test_mtl_st_experiment():
@@ -56,15 +75,20 @@
f'--flickr8k_root={flickr1d_path}',
'--hidden_size_factor=4']):
import platalea.experiments.flickr8k.mtl_st
assert platalea.experiments.flickr8k.mtl_st.result == {'SI': {'medr': 2.0, 'recall': {1: 0.4, 5: 1.0, 10: 1.0}},
'ST': {'medr': 6.0, 'recall': {1: 0.0, 5: 0.5, 10: 1.0}},
'epoch': 1}


def test_asr_experiment():
with unittest.mock.patch('sys.argv', ['[this gets ignored]',
'--epochs=1',
'-c', f'{flickr1d_path}/config.yml',
f'--flickr8k_root={flickr1d_path}',
'--hidden_size_factor=4']):
'--hidden_size_factor=4',
'--epsilon_decay=0.001']):
import platalea.experiments.flickr8k.asr
assert platalea.experiments.flickr8k.asr.result == {'validation loss': 4.364380836486816}
# save output of this experiment to serve as input for pip_ind and pip_seq


@@ -75,6 +99,10 @@ def test_text_image_experiment():
f'--flickr8k_root={flickr1d_path}',
'--hidden_size_factor=4']):
import platalea.experiments.flickr8k.text_image
assert platalea.experiments.flickr8k.text_image.result == {
'epoch': 1,
'medr': 1.5,
'recall': {1: 0.5, 5: 1.0, 10: 1.0}}
# save output of this experiment to serve as input for pip_ind


@@ -88,6 +116,12 @@ def test_pip_ind_experiment():
# '--text_image_model_dir={text_image_out_path}'
]):
import platalea.experiments.flickr8k.pip_ind
assert platalea.experiments.flickr8k.pip_ind.result == {
'ranks': [1, 1, 1, 1, 1, 2, 2, 2, 2, 2],
'recall': {
1: [1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0],
5: [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0],
10: [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]}}


def test_pip_seq_experiment():
@@ -100,3 +134,5 @@ def test_pip_seq_experiment():
# '--asr_model_dir={asr_out_path}'
]):
import platalea.experiments.flickr8k.pip_seq
assert platalea.experiments.flickr8k.pip_seq.result == {'epoch': 1, 'medr': 1.5,
'recall': {1: 0.5, 5: 1.0, 10: 1.0}}