computing alignment and uniformity #120

Closed
lephong opened this issue Nov 17, 2021 · 6 comments

lephong commented Nov 17, 2021

I'm following Wang and Isola to compute alignment and uniformity (using their code in Fig 5, http://proceedings.mlr.press/v119/wang20k/wang20k.pdf) to reproduce Fig 2 in your paper, but I failed. I observed that alignment decreases while uniformity stays almost unchanged, which is completely different from Fig 2. Details are below.
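For reference, these are the two metrics as defined by Wang and Isola (and restated in the SimCSE paper), where $f$ is the L2-normalized encoder, $p_{\text{pos}}$ is the distribution of positive pairs and $p_{\text{data}}$ is the data distribution. The code below uses $\alpha = 2$ and $t = 2$; lower is better for both.

$$
\begin{aligned}
\ell_{\text{align}} &= \mathbb{E}_{(x, x^+) \sim p_{\text{pos}}} \, \lVert f(x) - f(x^+) \rVert_2^{\alpha} \\
\ell_{\text{uniform}} &= \log \, \mathbb{E}_{x, y \,\overset{\text{i.i.d.}}{\sim}\, p_{\text{data}}} \, e^{-t \lVert f(x) - f(y) \rVert_2^2}
\end{aligned}
$$

Since $e^{-t d^2} \le 1$ for $t > 0$, uniformity is never positive, and it equals 0 only when all embeddings collapse to a single point.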

To compute alignment and uniformity, I changed lines 66-79 of SimCSE/blob/main/SentEval/senteval/sts.py, adding the code from Wang and Isola:

            ...
            input1, input2, gs_scores = self.data[dataset]
            all_enc1 = []
            all_enc2 = []
            for ii in range(0, len(gs_scores), params.batch_size):
                batch1 = input1[ii:ii + params.batch_size]
                batch2 = input2[ii:ii + params.batch_size]

                # we assume get_batch already throws out the faulty ones
                if len(batch1) == len(batch2) and len(batch1) > 0:
                    enc1 = batcher(params, batch1)
                    enc2 = batcher(params, batch2)

                    all_enc1.append(enc1.detach())
                    all_enc2.append(enc2.detach())
                    ...
            
            # L2-normalize each row; eps guards against division by zero
            def _norm(x, eps=1e-8):
                xnorm = torch.linalg.norm(x, dim=-1)
                xnorm = torch.max(xnorm, torch.ones_like(xnorm) * eps)
                return x / xnorm.unsqueeze(dim=-1)

            # from Wang and Isola (with a bit of modification)
            # only consider pairs with gs > 4 (from footnote 3)
            def _lalign(x, y, ok, alpha=2):
                return ((_norm(x) - _norm(y)).norm(dim=1).pow(alpha) * ok).sum() / ok.sum()
            
            def _lunif(x, t=2):
                sq_pdist = torch.pdist(_norm(x), p=2).pow(2)
                return sq_pdist.mul(-t).exp().mean().log()

            ok = (torch.Tensor(gs_scores) > 4).int()
            align = _lalign(
                torch.cat(all_enc1), 
                torch.cat(all_enc2), 
                ok).item()

            # consider all sentences (from footnote 3)
            unif = _lunif(torch.cat(all_enc1 + all_enc2)).item()
            logging.info(f'align {align}\t\t uniform {unif}')

The output (which also shows the Spearman correlation on the STS-B dev set) is:

align 0.2672557830810547 uniform -2.5320491790771484 'eval_stsb_spearman': 0.6410360622426501, 'epoch': 0.01
align 0.2519586384296417 uniform -2.629746913909912 'eval_stsb_spearman': 0.6859433315879646, 'epoch': 0.02
align 0.2449202835559845 uniform -2.5870673656463623 'eval_stsb_spearman': 0.7198291431689111, 'epoch': 0.02
align 0.22248655557632446 uniform -2.557053565979004 'eval_stsb_spearman': 0.7538674335025006, 'epoch': 0.03
align 0.22624073922634125 uniform -2.6622540950775146 'eval_stsb_spearman': 0.7739112284380941, 'epoch': 0.04
align 0.22583454847335815 uniform -2.5768041610717773 'eval_stsb_spearman': 0.7459814500897265, 'epoch': 0.05
align 0.22845414280891418 uniform -2.5601420402526855 'eval_stsb_spearman': 0.7683573046863201, 'epoch': 0.06
align 0.22689573466777802 uniform -2.560364007949829 'eval_stsb_spearman': 0.7766837072148098, 'epoch': 0.06
align 0.22807720303535461 uniform -2.5539987087249756 'eval_stsb_spearman': 0.7692866256106997, 'epoch': 0.07
align 0.20026598870754242 uniform -2.50628399848938 'eval_stsb_spearman': 0.7939010002048291, 'epoch': 0.08
align 0.20466476678848267 uniform -2.535121440887451 'eval_stsb_spearman': 0.8011027122797894, 'epoch': 0.09
align 0.2030458152294159 uniform -2.5547776222229004 'eval_stsb_spearman': 0.8044623693996088, 'epoch': 0.1
align 0.20119303464889526 uniform -2.5325350761413574 'eval_stsb_spearman': 0.8070404405714893, 'epoch': 0.1
align 0.19329915940761566 uniform -2.488903522491455 'eval_stsb_spearman': 0.8220311448535872, 'epoch': 0.11
align 0.19556573033332825 uniform -2.5273373126983643 'eval_stsb_spearman': 0.8183500898254208, 'epoch': 0.12
align 0.19112755358219147 uniform -2.4959402084350586 'eval_stsb_spearman': 0.8146496522216178, 'epoch': 0.13
align 0.18491695821285248 uniform -2.4762508869171143 'eval_stsb_spearman': 0.8088527080054781, 'epoch': 0.14
align 0.19815796613693237 uniform -2.5905373096466064 'eval_stsb_spearman': 0.8333401056438776, 'epoch': 0.14
align 0.1950838416814804 uniform -2.4894299507141113 'eval_stsb_spearman': 0.8293951990138778, 'epoch': 0.15
align 0.19777807593345642 uniform -2.5985066890716553 'eval_stsb_spearman': 0.8268435050866446, 'epoch': 0.16
align 0.2016373723745346 uniform -2.616013765335083 'eval_stsb_spearman': 0.8199602019842832, 'epoch': 0.17
align 0.19906719028949738 uniform -2.57528018951416 'eval_stsb_spearman': 0.8094202934650283, 'epoch': 0.18
align 0.18731220066547394 uniform -2.517271041870117 'eval_stsb_spearman': 0.8231122818777513, 'epoch': 0.18
align 0.18802008032798767 uniform -2.508246421813965 'eval_stsb_spearman': 0.8248523275594679, 'epoch': 0.19
align 0.20015984773635864 uniform -2.4563515186309814 'eval_stsb_spearman': 0.8061084765791668, 'epoch': 0.2
align 0.2015877515077591 uniform -2.5121841430664062 'eval_stsb_spearman': 0.8113328705761889, 'epoch': 0.21
align 0.20187602937221527 uniform -2.5167288780212402 'eval_stsb_spearman': 0.8124173161634701, 'epoch': 0.22
align 0.20096932351589203 uniform -2.5201926231384277 'eval_stsb_spearman': 0.8127754107163266, 'epoch': 0.22
align 0.19966433942317963 uniform -2.5182201862335205 'eval_stsb_spearman': 0.8152261579570365, 'epoch': 0.23
align 0.19897222518920898 uniform -2.557129383087158 'eval_stsb_spearman': 0.8169452712415308, 'epoch': 0.24
...

We can see that alignment drops from about 0.27 to about 0.20, whereas uniformity stays around -2.55. This suggests that reducing alignment, not uniformity, is the key factor. This trend is completely different from Fig 2.
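The comparison above can be made explicit with a small helper that pulls the three numbers out of log lines in the format shown. This is a sketch: the name `parse_log` and the regular expression are assumptions tailored to the lines as printed here (plain decimals, no scientific notation), and only the first and last lines of the excerpt are used.

```python
import re

# Matches lines like: align 0.26 uniform -2.53 'eval_stsb_spearman': 0.64, 'epoch': 0.01
LINE_RE = re.compile(
    r"align (?P<align>-?[\d.]+)\s+uniform (?P<unif>-?[\d.]+)\s+"
    r"'eval_stsb_spearman': (?P<spearman>-?[\d.]+)"
)


def parse_log(lines):
    """Return (align, uniform, spearman) tuples, skipping lines that do not match."""
    rows = []
    for line in lines:
        m = LINE_RE.search(line)
        if m:
            rows.append((float(m["align"]), float(m["unif"]), float(m["spearman"])))
    return rows


log = [
    "align 0.2672557830810547 uniform -2.5320491790771484 'eval_stsb_spearman': 0.6410360622426501, 'epoch': 0.01",
    "align 0.19897222518920898 uniform -2.557129383087158 'eval_stsb_spearman': 0.8169452712415308, 'epoch': 0.24",
]
(a0, u0, s0), (a1, u1, s1) = parse_log(log)
print(f"align change:    {a1 - a0:+.3f}")
print(f"uniform change:  {u1 - u0:+.3f}")
print(f"spearman change: {s1 - s0:+.3f}")
```

Over this span, alignment falls by roughly 0.07 while uniformity moves by only about 0.03, as the issue describes.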

Did you also use the code from Wang and Isola like I did? If possible, could you please provide the code for reproducing alignment and uniformity?
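To sanity-check the `_lalign`/`_lunif` formulas in the sts.py snippet above independently of SentEval, here is a dependency-free sketch of the same computations in plain Python. The helper names (`normalize`, `align_loss`, `uniform_loss`) are hypothetical, and this is an illustration, not SimCSE's or Wang and Isola's code. Like `torch.pdist`, it averages uniformity over all unordered pairs of distinct items.

```python
import itertools
import math


def normalize(v, eps=1e-8):
    """L2-normalize a vector (list of floats), mirroring _norm in the snippet."""
    n = max(math.sqrt(sum(c * c for c in v)), eps)
    return [c / n for c in v]


def sq_dist(u, v):
    """Squared Euclidean distance between two vectors."""
    return sum((a - b) ** 2 for a, b in zip(u, v))


def align_loss(xs, ys, ok, alpha=2):
    """Mean ||f(x) - f(y)||^alpha over the pairs where ok[i] is truthy."""
    total, count = 0.0, 0
    for x, y, keep in zip(xs, ys, ok):
        if keep:
            d = math.sqrt(sq_dist(normalize(x), normalize(y)))
            total += d ** alpha
            count += 1
    return total / count


def uniform_loss(xs, t=2):
    """log of the mean of exp(-t * ||f(x) - f(y)||^2) over all unordered pairs."""
    zs = [normalize(x) for x in xs]
    vals = [math.exp(-t * sq_dist(u, v)) for u, v in itertools.combinations(zs, 2)]
    return math.log(sum(vals) / len(vals))


# Pairs pointing in the same direction are perfectly aligned.
print(align_loss([[1.0, 0.0], [0.0, 2.0]], [[3.0, 0.0], [0.0, 1.0]], ok=[1, 1]))  # 0.0
# Embeddings collapsed onto one direction give the worst possible uniformity.
print(uniform_loss([[1.0, 0.0], [2.0, 0.0], [5.0, 0.0]]))  # 0.0
# Four evenly spaced directions on the unit circle score much lower (better).
print(uniform_loss([[1.0, 0.0], [0.0, 1.0], [-1.0, 0.0], [0.0, -1.0]]))
```

The toy cases show the two extremes of each metric, which is a quick way to check that a batched PyTorch version has the right sign and scale.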

lephong closed this as completed Nov 17, 2021
lephong reopened this Nov 17, 2021
gaotianyu1350 (Member) commented

The uniformity drops very fast at the beginning of training. Can you specify your initialization and the stride (in training steps) at which you calculate the uniformity?


lephong commented Nov 24, 2021

I didn't change anything except adding the lines that calculate alignment and uniformity (as mentioned above). Specifically, I ran the command from run_unsup_example.sh:


python train.py \
    --model_name_or_path bert-base-uncased \
    --train_file data/wiki1m_for_simcse.txt \
    --output_dir result/my-unsup-simcse-bert-base-uncased \
    --num_train_epochs 1 \
    --per_device_train_batch_size 64 \
    --learning_rate 3e-5 \
    --max_seq_length 32 \
    --evaluation_strategy steps \
    --metric_for_best_model stsb_spearman \
    --load_best_model_at_end \
    --eval_steps 125 \
    --pooler_type cls \
    --mlp_only_train \
    --overwrite_output_dir \
    --temp 0.05 \
    --do_train \
    --do_eval \
    --fp16

For initialization, I didn't change the random seed, so I guess it's 42, the Hugging Face default (not sure, I may be wrong).
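As a side check on these flags, the `epoch` values in the log above are consistent with one evaluation every `--eval_steps 125` optimizer steps at `--per_device_train_batch_size 64`, assuming a single GPU, no gradient accumulation, and that `wiki1m_for_simcse.txt` holds exactly 1,000,000 sentences (all assumptions, not stated in the thread). The sketch also shows how coarse a 125-step stride is in epoch terms.

```python
import math

NUM_SENTENCES = 1_000_000  # assumption: size of wiki1m_for_simcse.txt
BATCH_SIZE = 64            # --per_device_train_batch_size (single device assumed)
EVAL_STEPS = 125           # --eval_steps

steps_per_epoch = math.ceil(NUM_SENTENCES / BATCH_SIZE)


def epoch_at(step):
    """Fraction of an epoch completed after `step` optimizer steps, rounded to 2 decimals."""
    return round(step / steps_per_epoch, 2)


epochs = [epoch_at(k * EVAL_STEPS) for k in range(1, 6)]
print(steps_per_epoch)  # 15625
print(epochs)           # [0.01, 0.02, 0.02, 0.03, 0.04]
```

These match the first five `'epoch'` values in the log, so each log line is 125 steps apart.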

gaotianyu1350 (Member) commented

If I understand correctly, you calculate the alignment/uniformity every 125 steps (the same interval as validation). In the original paper, we calculated them every 10 steps because, as I mentioned, the uniformity drops very fast at the beginning of training.
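To make the difference in stride concrete: over the first 125 training steps (the first validation point), measuring every 10 steps yields 12 data points, while measuring at the validation interval yields only one. A minimal sketch (the helper name `measurement_steps` is hypothetical):

```python
def measurement_steps(total_steps, stride):
    """Optimizer steps (1-based) at which metrics are computed with a fixed stride."""
    return list(range(stride, total_steps + 1, stride))


every_10 = measurement_steps(125, 10)
every_125 = measurement_steps(125, 125)
print(len(every_10), every_10[:3], every_10[-1])  # 12 [10, 20, 30] 120
print(every_125)                                 # [125]
```

Lowering `--eval_steps` to 10 is one (slow) way to get this stride, since it also runs the full STS-B evaluation each time; this is a suggestion, not necessarily how the paper's figure was produced.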


lephong commented Nov 24, 2021

Ah, so you mean every 10 update steps (batches)? I thought it was every 10 * 125 batches.

But even if that's the case, I'm not sure Fig 2 provides a good explanation here: after 125 steps (about 12 little red stars in Fig 2), the Spearman correlation on the STS-B dev set is only around 60%, much lower than the 82.5% reported in the paper. So I think Fig 2 can explain what happens in the very first phase of training, but the remaining gap of 82.5 - 60 = 22.5 points is not explained.

gaotianyu1350 (Member) commented

You can use Figure 3 as a reference (although it's not a rigorous comparison, because we didn't include the BERT [CLS] representation, which is the initialization for SimCSE, in that figure). There, it's the uniformity that makes a huge difference.


lephong commented Nov 24, 2021

That makes sense, thanks.

lephong closed this as completed Nov 24, 2021