Pretrained Swin-Transformer for multiple output #9

Closed
imanuelroz opened this issue Oct 16, 2021 · 2 comments

imanuelroz commented Oct 16, 2021

Hi rishigami,

Thank you for the TensorFlow implementation. I am trying to use the Swin Transformer for a classification problem with multiple outputs. In your guide on how to use a pretrained model, you wrap it in a Sequential model, but that way I cannot attach multiple Dense layers for the different classification outputs. Could you help me adapt your TF code to my problem, perhaps using the Functional API?

orilifs commented Oct 19, 2021

Hello,

Thank you for the implementation :)
I am also interested in an answer to the question above.

rishigami (Owner) commented

@imanuelroz @orilifs
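Hi, you can implement this with the Keras Subclassing API:

```python
import tensorflow as tf
from swintransformer import SwinTransformer


class CustomModel(tf.keras.Model):
    def __init__(self, output_num1: int, output_num2: int):
        super().__init__()
        # ImageNet "torch"-mode normalization, as in the Sequential example from the guide.
        self.preprocess = tf.keras.layers.Lambda(
            lambda data: tf.keras.applications.imagenet_utils.preprocess_input(
                tf.cast(data, tf.float32), mode="torch"))
        # Pretrained backbone without the classification head: returns one feature vector per image.
        self.backbone = SwinTransformer('swin_tiny_224', include_top=False, pretrained=True)
        # One Dense head per task; no activation, so both heads output logits.
        self.dense1 = tf.keras.layers.Dense(output_num1)
        self.dense2 = tf.keras.layers.Dense(output_num2)

    def call(self, inputs):
        x = self.preprocess(inputs)
        x = self.backbone(x)      # shared features
        output1 = self.dense1(x)  # first classification output
        output2 = self.dense2(x)  # second classification output
        return output1, output2
```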
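Since the original question asked about the Functional API, here is a minimal sketch of the same two-head idea written that way. It assumes the repository's `SwinTransformer(...)` can be called on a symbolic tensor like any other Keras layer, which the Sequential example in the guide suggests. To keep the sketch runnable offline, `make_backbone`, `NUM_CLASSES_1` and `NUM_CLASSES_2` are hypothetical stand-ins: the small Conv2D + pooling backbone only mimics the shape of the pretrained model's output and should be replaced by `SwinTransformer('swin_tiny_224', include_top=False, pretrained=True)` in practice. Unlike the subclassed model above, the heads here use softmax, so the built-in `"sparse_categorical_crossentropy"` loss applies directly.

```python
import numpy as np
import tensorflow as tf

IMAGE_SIZE = (224, 224)  # assumed input resolution for 'swin_tiny_224'
NUM_CLASSES_1 = 5        # hypothetical number of classes for the first output
NUM_CLASSES_2 = 3        # hypothetical number of classes for the second output


def make_backbone() -> tf.keras.Model:
    # Hypothetical stand-in for SwinTransformer('swin_tiny_224', include_top=False, pretrained=True):
    # like the real backbone, it maps an image batch to one feature vector per image.
    return tf.keras.Sequential([
        tf.keras.layers.Conv2D(8, 3, strides=4, activation="relu"),
        tf.keras.layers.GlobalAveragePooling2D(),
    ], name="backbone")


def build_multi_output_model() -> tf.keras.Model:
    inputs = tf.keras.Input(shape=(*IMAGE_SIZE, 3))
    # Same "torch"-mode ImageNet preprocessing as in the answer above.
    x = tf.keras.layers.Lambda(
        lambda data: tf.keras.applications.imagenet_utils.preprocess_input(
            tf.cast(data, tf.float32), mode="torch"))(inputs)
    features = make_backbone()(x)  # shared features for both heads
    # Two independent classification heads on top of the shared features.
    out1 = tf.keras.layers.Dense(NUM_CLASSES_1, activation="softmax", name="head1")(features)
    out2 = tf.keras.layers.Dense(NUM_CLASSES_2, activation="softmax", name="head2")(features)
    return tf.keras.Model(inputs=inputs, outputs=[out1, out2])


model = build_multi_output_model()
# One loss per output, given in the same order as the model's outputs.
model.compile(
    optimizer="adam",
    loss=["sparse_categorical_crossentropy", "sparse_categorical_crossentropy"],
)

# Tiny random batch, only to show the expected input and label shapes.
images = np.random.randint(0, 256, size=(4, *IMAGE_SIZE, 3)).astype("float32")
labels1 = np.random.randint(0, NUM_CLASSES_1, size=(4,))
labels2 = np.random.randint(0, NUM_CLASSES_2, size=(4,))
model.fit(images, [labels1, labels2], epochs=1, verbose=0)

# predict returns one array per output head.
pred1, pred2 = model.predict(images, verbose=0)
```

Each head receives its own label array and loss, and Keras sums the per-output losses during training; `loss_weights` in `compile` can rebalance them if one task should count more.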

Colab example here
https://colab.research.google.com/drive/1v1yrlaQUDluwJvBsgOBfJPlmmDW7-fYk?usp=sharing
