performance after fine-tuning #100

Open · doheon114 opened this issue Mar 31, 2025 · 3 comments

@doheon114

Hello,

This model was pre-trained on Protein, DNA, and RNA data collectively. Therefore, I initially expected that fine-tuning with RNA-only data would improve its performance specifically for RNA structure prediction. However, the results did not align with my expectations.

Despite adjusting parameters such as lowering the learning rate and tuning ema_decay, the more I fine-tuned the model, the worse its performance became on the test RNA dataset. In the end, the pretrained model remained the best-performing version.

Do you have any advice on why this might be happening or how I could improve the fine-tuning process?

@zhangyuxuann
Collaborator

What is the learning rate set to? How are the RNA-only training and test data constructed?

@doheon114
Author

doheon114 commented Apr 1, 2025

> What is the learning rate set to? How are the RNA-only training and test data constructed?

The learning rate was set to various values (0.01, 0.001, 0.0001, 0.00001, 0.000001). The RNA-only training set consisted of 351 CIF files and their corresponding MSA files (.sto). There were only 5 test sets, which I don't really care about, since I just want to see how well the model fits the training set.
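
Roughly, the sweep was scripted along these lines (a minimal sketch: `rna_train_set` is a placeholder name for the 351-structure set, and the flags are the ones from the `runner/train.py` demo shown later in this thread):

```bash
# Minimal sketch of the learning-rate sweep; paths and dataset name are placeholders.
export LAYERNORM_TYPE=fast_layernorm
export USE_DEEPSPEED_EVO_ATTENTION=true
checkpoint_path="/af3-dev/release_model/model_v0.2.0.pt"

for lr in 0.01 0.001 0.0001 0.00001 0.000001; do
    python3 ./runner/train.py \
        --run_name "rna_finetune_lr_${lr}" \
        --base_dir ./output \
        --lr "${lr}" \
        --ema_decay 0.999 \
        --load_checkpoint_path "${checkpoint_path}" \
        --load_ema_checkpoint_path "${checkpoint_path}" \
        --data.train_sets rna_train_set  # placeholder: the 351-file RNA set
done
```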

[Images: training loss curves]

Contrary to what I expected, the loss values are not decreasing.

@zhangyuxuann
Collaborator

zhangyuxuann commented Apr 2, 2025

The scale of the loss is related to the noise level. When the pre-trained model is already very good and the log interval is set to 1 (as in your run above), the loss fluctuates heavily while the step count is small. Below is the curve with the log interval set to 50, based on the provided finetune_demo.sh example running 1000 steps. It shows that when fine-tuning on a small amount of data, the model can at least overfit the training set further. Note that the theoretical minimum of the smooth LDDT loss is about 0.196, i.e. 1 - 0.25 * (torch.sigmoid(torch.tensor(0.5)) + torch.sigmoid(torch.tensor(1.0)) + torch.sigmoid(torch.tensor(2.0)) + torch.sigmoid(torch.tensor(4.0))).
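
Spelling that out, with $\sigma$ the logistic sigmoid evaluated at the four smooth-LDDT distance thresholds:

$$
1 - \tfrac{1}{4}\bigl(\sigma(0.5) + \sigma(1.0) + \sigma(2.0) + \sigma(4.0)\bigr)
\approx 1 - \tfrac{1}{4}\,(0.6225 + 0.7311 + 0.8808 + 0.9820)
\approx 0.196
$$

So a training smooth LDDT loss plateauing near 0.2 already indicates a near-perfect fit, not a failure to learn.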
[Image: training loss curve with log interval set to 50]

```bash
export LAYERNORM_TYPE=fast_layernorm
export USE_DEEPSPEED_EVO_ATTENTION=true
# wget -P /af3-dev/release_model/ https://af3-dev.tos-cn-beijing.volces.com/release_model/model_v0.2.0.pt
checkpoint_path="/af3-dev/release_model/model_v0.2.0.pt"

python3 ./runner/train.py \
    --run_name protenix_finetune \
    --seed 42 \
    --base_dir ./output \
    --dtype bf16 \
    --project protenix \
    --use_wandb true \
    --diffusion_batch_size 48 \
    --eval_interval 1000 \
    --log_interval 50 \
    --eval_first false \
    --checkpoint_interval 2000 \
    --ema_decay 0.999 \
    --train_crop_size 384 \
    --max_steps 100000 \
    --warmup_steps 100 \
    --lr 0.0001 \
    --sample_diffusion.N_step 20 \
    --load_checkpoint_path ${checkpoint_path} \
    --load_ema_checkpoint_path ${checkpoint_path} \
    --data.train_sets weightedPDB_before2109_wopb_nometalc_0925 \
    --data.weightedPDB_before2109_wopb_nometalc_0925.base_info.pdb_list examples/finetune_subset.txt \
    --data.test_sets recentPDB_1536_sample384_0925,posebusters_0925
```
