Bug in dataloader ? #34

Closed
Brandonnogithub opened this issue Jan 7, 2022 · 3 comments

@Brandonnogithub

Hi guys, I am trying to reproduce your work. In the dataloader, I found this code:

for sample_idx in range(self.num_sample):
    for query_idx in range(len(self.query_examples)):
        # If training, exclude the current example. Else keep all.
        if self.use_demo and args.demo_filter:
            # Demonstration filtering
            candidate = [support_idx for support_idx in support_indices
                           if support_idx != query_idx or mode != "train"]
            sim_score = []
            for support_idx in candidate:
                sim_score.append((support_idx, util.pytorch_cos_sim(self.support_emb[support_idx], self.query_emb[query_idx])))
            sim_score.sort(key=lambda x: x[1], reverse=True)
            if self.num_labels == 1:
                # Regression task
                limit_each_label = int(len(sim_score) // 2 * args.demo_filter_rate)
                count_each_label = {'0': 0, '1': 0}
                context_indices = []

                if args.debug_mode:
                    print("Query %s: %s" % (self.query_examples[query_idx].label, self.query_examples[query_idx].text_a)) # debug
                for support_idx, score in sim_score:
                    if count_each_label['0' if float(self.support_examples[support_idx].label) <= median_mapping[args.task_name] else '1'] < limit_each_label:
                        count_each_label['0' if float(self.support_examples[support_idx].label) <= median_mapping[args.task_name] else '1'] += 1
                        context_indices.append(support_idx)
                        if args.debug_mode:
                            print("    %.4f %s | %s" % (score, self.support_examples[support_idx].label, self.support_examples[support_idx].text_a)) # debug
            else:
                limit_each_label = int(len(sim_score) // self.num_labels * args.demo_filter_rate)
                count_each_label = {label: 0 for label in self.label_list}
                context_indices = []

                if args.debug_mode:
                    print("Query %s: %s" % (self.query_examples[query_idx].label, self.query_examples[query_idx].text_a)) # debug
                for support_idx, score in sim_score:
                    if count_each_label[self.support_examples[support_idx].label] < limit_each_label:
                        count_each_label[self.support_examples[support_idx].label] += 1
                        context_indices.append(support_idx)
                        if args.debug_mode:
                            print("    %.4f %s | %s" % (score, self.support_examples[support_idx].label, self.support_examples[support_idx].text_a)) # debug
        else:
            # Using demonstrations without filtering
            context_indices = [support_idx for support_idx in support_indices
                       if support_idx != query_idx or mode != "train"]

        # We'll subsample context_indices further later.
        self.example_idx.append((query_idx, context_indices, sample_idx))

Here it is calculating the similarity scores.
But I don't know why you put the loop for sample_idx in range(self.num_sample) at the outermost level: sample_idx is only used when you append the result to self.example_idx.

This code is really slow, since you set num_sample=16.
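
To make the cost concrete: nothing inside the loop body depends on sample_idx (the candidate filtering, the cosine similarities, and context_indices are all the same for every sample_idx), so the whole similarity pass is repeated num_sample times per query. A rough sketch with made-up split sizes:

    # Rough cost sketch with made-up sizes, just to show where the factor of
    # num_sample comes from; the real numbers depend on the task and split.
    num_sample = 16
    num_queries = 1000   # hypothetical size of the query split
    num_support = 256    # hypothetical size of the support set

    # Current nesting: the cosine-similarity pass runs once per (sample_idx, query_idx).
    sim_calls_current = num_sample * num_queries * num_support   # 4,096,000
    # Proposed nesting: the pass runs once per query_idx only.
    sim_calls_proposed = num_queries * num_support               #   256,000

    print(sim_calls_current // sim_calls_proposed)  # 16, i.e. num_sample times fewer calls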

I think you can remove the outer for sample_idx in range(self.num_sample) loop and change the last part to:

for query_idx in range(len(self.query_examples)):
    ....
    # We'll subsample context_indices further later.
    for sample_idx in range(self.num_sample):
        self.example_idx.append((query_idx, context_indices, sample_idx))

I don't know whether I am right.

@Brandonnogithub
Author

In my test, I found that after this change the results are different.

@gaotianyu1350
Member

Hi,

I think the two implementations are essentially equivalent. The difference is due to different orders having different random sampling results.
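
To illustrate (a minimal sketch with made-up sizes, not the actual subsampling code in the repo): both nestings build the same set of (query_idx, context_indices, sample_idx) tuples, just in a different order, so the later subsampling step consumes the shared random stream in a different order and the same entry gets a different draw.

    import random

    num_sample, num_queries = 2, 3
    context_indices = list(range(5))  # same candidate demonstrations either way

    # Original nesting: sample_idx outermost.
    original = [(q, context_indices, s)
                for s in range(num_sample) for q in range(num_queries)]
    # Proposed nesting: query_idx outermost.
    proposed = [(q, context_indices, s)
                for q in range(num_queries) for s in range(num_sample)]

    def later_subsample(example_idx, seed=0, k=2):
        # Stand-in for the later subsampling step: walk example_idx in order,
        # drawing k demonstrations per entry from a shared seeded RNG.
        rng = random.Random(seed)
        return {(q, s): rng.sample(ctx, k) for q, ctx, s in example_idx}

    # Same set of tuples...
    print(sorted(original) == sorted(proposed))  # True
    # ...but the same (query_idx, sample_idx) entry sees a different part of the
    # random stream, so the chosen demonstrations differ even with a fixed seed.
    print(later_subsample(original) == later_subsample(proposed))  # almost surely False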

@Brandonnogithub
Author

> Hi,
>
> I think the two implementations are essentially equivalent. The difference is due to different orders having different random sampling results.

Thanks for your reply.
