
About layer regularization. #39

Closed
Tyler-D opened this issue Jan 23, 2019 · 12 comments
Labels
enhancement (New feature or request), help wanted (Extra attention is needed)

Comments

Tyler-D (Contributor) commented Jan 23, 2019

I'm curious why there is no regularizer option for the layers. Is it a deliberate choice to train without regularization?

qubvel (Owner) commented Jan 23, 2019

Hi @Tyler-D,
Did you mean the possibility of adding regularization to all convolution layers of the model?

Tyler-D (Contributor, Author) commented Jan 24, 2019

Well, I think it would be better if there were a function that adds a specific regularizer to all layers.

qubvel added the enhancement label Jan 24, 2019
qubvel (Owner) commented Jan 24, 2019

According to this and this issue, it can be implemented as follows:

def set_regularization(model, 
                       kernel_regularizer=None, 
                       bias_regularizer=None, 
                       activity_regularizer=None):
    
    for layer in model.layers:
        
        # set kernel_regularizer
        if kernel_regularizer is not None and hasattr(layer, 'kernel_regularizer'):
            layer.kernel_regularizer = kernel_regularizer

        # set bias_regularizer
        if bias_regularizer is not None and hasattr(layer, 'bias_regularizer'):
            layer.bias_regularizer = bias_regularizer

        # set activity_regularizer
        if activity_regularizer is not None and hasattr(layer, 'activity_regularizer'):
            layer.activity_regularizer = activity_regularizer

# example
set_regularization(model, kernel_regularizer=keras.regularizers.l2(0.0001))
model.compile(...)  # you have to recompile the model if regularization is changed

I did not test this code; if it works, it can be added as a utility function.

qubvel added the help wanted label Jan 24, 2019
Tyler-D (Contributor, Author) commented Jan 25, 2019

Cool, that's exactly the function I want. I could help add it. What kind of tests do you need?

Tyler-D (Contributor, Author) commented Jan 25, 2019

Actually, I'm wondering whether it would be possible to build a segmentation pipeline on top of your repo, including training, evaluation, data loaders for public datasets (e.g. Pascal VOC, COCO), and even a tool to export the Keras model to an inference framework (e.g. TensorRT). Then I'm sure this repository would be extremely appealing.

qubvel (Owner) commented Jan 25, 2019

Just test that it works as expected (see the sketch below):

Regularization appears in conv/dense layers and is applied during training.
A saved/loaded model keeps the regularization.
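
A minimal sketch of such a check, assuming the set_regularization helper above and a small toy model (file name and layer sizes are only illustrative):

import keras

model = keras.models.Sequential([
    keras.layers.Dense(4, input_shape=(8,)),
    keras.layers.Dense(2),
])

set_regularization(model, kernel_regularizer=keras.regularizers.l2(1e-4))
model.compile(optimizer='adam', loss='mse')

# 1) regularization should show up as extra loss terms used during training
print(len(model.losses))

# 2) a saved/loaded model should keep the regularizers in its config
model.save('tmp_model.h5')
reloaded = keras.models.load_model('tmp_model.h5')
print([layer.kernel_regularizer for layer in reloaded.layers
       if hasattr(layer, 'kernel_regularizer')])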

qubvel (Owner) commented Jan 25, 2019

A segmentation pipeline is a cool idea; however, I think it should be built in another repo or added as an example here.
If you can recommend any cool repos with that kind of pipeline, it would be extremely helpful! 😄

Tyler-D (Contributor, Author) commented Jan 25, 2019

I've tried the code you offered in my training scripts, and the thing is that only the model config is changed. After some investigation, I found this. A workaround can be found here:

from keras.regularizers import l2
from keras.models import model_from_json

def create_model():
    model = your_model()
    model.save_weights("tmp.h5")

    # optionally do some other modifications (freezing layers, adding convolutions etc.)
    # ...

    regularizer = l2(WEIGHT_DECAY / 2)
    for layer in model.layers:
        for attr in ['kernel_regularizer', 'bias_regularizer']:
            if hasattr(layer, attr) and layer.trainable:
                setattr(layer, attr, regularizer)

    # rebuild the model from its config so the regularizers actually take effect,
    # then restore the original weights
    out = model_from_json(model.to_json())
    out.load_weights("tmp.h5", by_name=True)

    return out

It doesn't seem like an elegant way to do things. I'm thinking about how to refactor it.
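
For a quick sanity check, one could compare the model's collected regularization losses before and after the JSON round-trip (a rough sketch, assuming your_model() from above):

model = your_model()
print(len(model.losses))        # typically 0: setting the attribute alone adds no loss terms

regularized = create_model()    # rebuilds the layers from config, so the regularizers take effect
print(len(regularized.losses))  # should now be > 0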

qubvel (Owner) commented Jan 25, 2019

Yes, I agree, this is not an elegant way.
Another inelegant way, but at least it does not require saving the model to disk:

import keras
from keras.models import model_from_json

def set_regularization(model,
                       kernel_regularizer=None,
                       bias_regularizer=None,
                       activity_regularizer=None):

    for layer in model.layers:

        # set kernel_regularizer
        if kernel_regularizer is not None and hasattr(layer, 'kernel_regularizer'):
            layer.kernel_regularizer = kernel_regularizer

        # set bias_regularizer
        if bias_regularizer is not None and hasattr(layer, 'bias_regularizer'):
            layer.bias_regularizer = bias_regularizer

        # set activity_regularizer
        if activity_regularizer is not None and hasattr(layer, 'activity_regularizer'):
            layer.activity_regularizer = activity_regularizer

    # rebuild the model from its updated config so the regularizers take effect,
    # then copy the weights over
    out = model_from_json(model.to_json())
    out.set_weights(model.get_weights())

    return out

new_model = set_regularization(model, kernel_regularizer=keras.regularizers.l2(0.0001))
new_model.compile(...)  # you have to recompile the model after changing regularization
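
If useful, a quick sketch of a check that the rebuilt model actually picked up the loss terms:

print(len(model.losses), len(new_model.losses))  # e.g. 0 vs. > 0 after the rebuild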

Tyler-D (Contributor, Author) commented Feb 18, 2019

Hi @qubvel. I've tested the new implementation, and it works well! You can add it in #54.

qubvel added a commit that referenced this issue Feb 21, 2019
qubvel closed this as completed Feb 21, 2019
qubvel (Owner) commented Feb 22, 2019

Hi @Tyler-D, ok, no problem

mathmanu commented
Try this:

import tensorflow as tf

# a utility function to add weight decay after the model is defined.
def add_weight_decay(model, weight_decay):
    if (weight_decay is None) or (weight_decay == 0.0):
        return

    # recursion inside the model
    def add_decay_loss(m, factor):
        if isinstance(m, tf.keras.Model):
            for layer in m.layers:
                add_decay_loss(layer, factor)
        else:
            for param in m.trainable_weights:
                with tf.keras.backend.name_scope('weight_regularizer'):
                    # bind param as a default argument so each lambda keeps its own weight
                    regularizer = lambda p=param: tf.keras.regularizers.l2(factor)(p)
                    m.add_loss(regularizer)

    # weight decay and l2 regularization differ by a factor of 2
    add_decay_loss(model, weight_decay / 2.0)
    return
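
Usage might look roughly like this (a sketch; the model and the 1e-4 value are only illustrative):

model = tf.keras.applications.MobileNetV2(weights=None)
add_weight_decay(model, weight_decay=1e-4)
model.compile(optimizer='adam', loss='categorical_crossentropy')
print(len(model.losses))  # one l2 term per trainable weight tensor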
