
Custom BERT Model outputs #22

Closed · ashdtu opened this issue Feb 28, 2024 · 5 comments · Fixed by #31
ashdtu (Contributor) commented Feb 28, 2024

For flexibility when fine-tuning on downstream tasks, we should have the following options in the BERT family model outputs (see the sketch after this list):

  1. Return the output of a pooling layer on top of the [CLS] token embedding, via a user-configurable flag (add_pooling_layer=True) like the transformers API.
  2. Return the full last hidden states of the encoder layer, shaped [Batch, Sequence, Embedding Dim], instead of just the [CLS] token embedding, through a user-configurable flag.
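A minimal sketch of what such an output type could look like, assuming Burn's tensor API; the `BertModelOutput` name and its fields are hypothetical, loosely following the transformers API:

```rust
use burn::tensor::{backend::Backend, Tensor};

/// Hypothetical output type for the BERT family models; the struct and
/// field names are illustrative, not a committed API.
pub struct BertModelOutput<B: Backend> {
    /// Full last hidden states from the encoder: [Batch, Sequence, Embedding Dim].
    pub hidden_states: Tensor<B, 3>,
    /// Pooled [CLS] embedding: [Batch, Embedding Dim]; `Some` only when the
    /// model was built with `add_pooling_layer = true`.
    pub pooled_output: Option<Tensor<B, 2>>,
}
```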
ashdtu self-assigned this Feb 28, 2024
bkonkle (Contributor) commented Apr 11, 2024

I have a quick first pass at point 1 in a fork, based on how rust-bert handles it: https://github.com/bkonkle/burn-models/blob/707153f5ef1f1f2e8478cebf45e3ca58247d8348/bert-burn/src/model.rs#L168-L171

How would I approach point 2?

bkonkle (Contributor) commented Apr 20, 2024

I'm making more progress on the very beginnings of a transformers-style library for Burn, using traits for pipeline implementations, but in my WIP testing so far training isn't working correctly. It doesn't seem to be using the pre-trained weights from bert-base-uncased correctly, so accuracy fluctuates between roughly 25% and 50%.

https://github.com/bkonkle/burn-transformers

This is using my branch with the pooled BERT output. The branch doesn't currently build, but I plan to do more work on it this week to fix that and get a good example in place for feedback.

nathanielsimard (Member) commented

Awesome @bkonkle! I think the current implementation is using RoBERTa weights instead of BERT, so maybe this isn't compatible with the BERT weights for the classification head. Not sure if this helps, but if you find something not working, make sure to test multiple backends and report a bug if there are differences.

bkonkle (Contributor) commented May 2, 2024

Okay, I believe I understand goal 2 better now. I was thinking this meant a flag for all_hidden_states, like this flag in Hugging Face's transformers library. I now believe it means just the full tensor from the last hidden_states value, like this property in Hugging Face's transformers. That would correspond to the final x value in Burn's Transformer Encoder, here.

If my interpretation is correct, I believe the approach in my fork here addresses this by returning both the last hidden states and the optional pooled output if available.
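For illustration, a rough sketch of that shape of solution, with the pooler implemented as a dense layer plus tanh over the [CLS] token as in rust-bert; the `pool_cls` helper and the exact slicing calls are assumptions, not the code in the fork:

```rust
use burn::{
    nn::Linear,
    tensor::{activation, backend::Backend, Tensor},
};

/// Pool the [CLS] token (sequence position 0) through a dense layer + tanh,
/// mirroring the rust-bert / transformers pooler. Hypothetical helper, not
/// necessarily how the fork implements it.
fn pool_cls<B: Backend>(hidden_states: Tensor<B, 3>, dense: &Linear<B>) -> Tensor<B, 2> {
    let [batch, _seq_len, d_model] = hidden_states.dims();
    // Take the first token along the sequence dimension: [batch, 1, d_model].
    let cls = hidden_states.slice([0..batch, 0..1, 0..d_model]);
    // Drop the singleton sequence dimension: [batch, d_model].
    let cls: Tensor<B, 2> = cls.squeeze(1);
    activation::tanh(dense.forward(cls))
}
```

The forward pass can then return the full hidden states unconditionally and wrap the pooled value in an `Option`, so callers that only need token-level outputs pay nothing for the pooler.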

Update: Solved - see the next comment below.

Previous troubleshooting details

From what I can tell, the BERT model here should support `bert-base-uncased` without any issues using the additional pooler layer (which is also loaded from the safetensors file), despite being originally written for RoBERTa. Unfortunately, I'm still getting really poor accuracy when loading from safetensors and then fine-tuning on the Snips dataset.

The training loop I use is defined here, in my early-stage port of transformers in the text-classification pipeline.

======================== Learner Summary ========================
Model: Model[num_params=109489161]
Total Epochs: 10


| Split | Metric        | Min.     | Epoch    | Max.     | Epoch    |
|-------|---------------|----------|----------|----------|----------|
| Train | Loss          | 2.004    | 10       | 2.269    | 1        |
| Train | Learning Rate | 0.000    | 10       | 0.000    | 1        |
| Train | Accuracy      | 12.050   | 1        | 31.490   | 10       |
| Valid | Loss          | 1.955    | 10       | 2.205    | 1        |
| Valid | Accuracy      | 18.000   | 1        | 40.000   | 10       |

The learning rate seems like a clue - it shouldn't be zero, right?

By comparison, the training implementation for the Snips dataset in the JointBERT repo (which just combines text and token classification into one) hits accuracy in the 90% range within the first few epochs.

I'll probably move on to token classification for now and come back to text classification once I get some feedback and try to figure out what's going wrong. 😅

Thank you for any guidance you can provide!

bkonkle (Contributor) commented May 3, 2024

The learning rate was indeed a hint. I had it set way too low, based on the default value in the JointBERT repo I was learning from. 😅 Setting the learning rate to 1e-2 solves my problem, so I think my branch is ready for some review to see if this is the right approach to enabling custom BERT model outputs. 👍
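For reference, a minimal sketch of where that value plugs in when building a Burn learner (the builder calls shown are assumptions based on Burn's training API; `artifact_dir`, `model`, and the metric setup are defined elsewhere in the pipeline):

```rust
use burn::optim::AdamConfig;
use burn::train::LearnerBuilder;

// `artifact_dir` and `model` are assumed to come from the surrounding
// pipeline code; only the learning rate argument matters here.
let learner = LearnerBuilder::new(artifact_dir)
    // ...metrics, checkpointing, devices, etc...
    .num_epochs(10)
    // A plain f64 acts as a constant learning-rate scheduler in Burn.
    .build(model, AdamConfig::new().init(), 1e-2);
```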

======================== Learner Summary ========================
Model: Model[num_params=109489161]
Total Epochs: 10

| Split | Metric        | Min.     | Epoch    | Max.     | Epoch    |
|-------|---------------|----------|----------|----------|----------|
| Train | Loss          | 0.003    | 10       | 0.246    | 1        |
| Train | Accuracy      | 92.540   | 1        | 99.900   | 10       |
| Train | Learning Rate | 0.000    | 10       | 0.000    | 1        |
| Valid | Loss          | 0.060    | 10       | 0.109    | 8        |
| Valid | Accuracy      | 96.900   | 8        | 98.100   | 10       |
