Batches returned by the get_batch method of the MNISTLoader class contain duplicate samples #42

Open
opensourcedigest opened this issue Feb 7, 2020 · 1 comment

@opensourcedigest

Change the code:

    def get_batch(self, batch_size):
        # Randomly draw batch_size samples from the dataset and return them
        index = np.random.randint(0, np.shape(self.train_data)[0], batch_size)
        return self.train_data[index, :], self.train_label[index]

to:

    def get_batch(self, batch_size):
        # Randomly draw batch_size distinct samples from the dataset and return them
        # index = np.random.randint(0, np.shape(self.train_data)[0], batch_size)
        index = np.random.choice(np.shape(self.train_data)[0], batch_size, replace=False)
        return self.train_data[index, :], self.train_label[index]

This ensures that no sample appears more than once within a single batch.
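To illustrate the difference, here is a small self-contained simulation (a sketch, not code from the project). It repeats the two index-sampling calls many times and measures how often a batch contains a repeated index. The dataset size of 60,000 assumes MNIST's training split; batch_size, num_batches and the helpers has_duplicates and birthday_probability are hypothetical names introduced only for this illustration.

```python
import numpy as np


def has_duplicates(index):
    """Return True if any index value occurs more than once."""
    return len(np.unique(index)) < len(index)


def birthday_probability(num_samples, batch_size):
    """Probability that batch_size uniform draws *with* replacement
    from num_samples items contain at least one repeat."""
    k = np.arange(batch_size)
    return 1.0 - np.prod((num_samples - k) / num_samples)


num_samples = 60000   # assumed: size of the MNIST training set
batch_size = 50       # hypothetical batch size for illustration
num_batches = 2000    # number of simulated get_batch calls

np.random.seed(0)     # fixed seed so the run is reproducible

# Original get_batch: indices drawn independently, i.e. with replacement.
randint_rate = np.mean([
    has_duplicates(np.random.randint(0, num_samples, batch_size))
    for _ in range(num_batches)
])

# Proposed get_batch: indices drawn without replacement.
choice_rate = np.mean([
    has_duplicates(np.random.choice(num_samples, batch_size, replace=False))
    for _ in range(num_batches)
])

print(f"theoretical rate with randint:    {birthday_probability(num_samples, batch_size):.4f}")
print(f"observed rate with randint:       {randint_rate:.4f}")
print(f"observed rate with replace=False: {choice_rate:.4f}")
```

With these parameters the theoretical rate for randint is about 2%, so duplicates are rare but do occur over a training run; with replace=False the observed rate is always 0.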

@huan
Collaborator

huan commented Mar 29, 2020

Thank you for the suggestion. Could you please submit a Pull Request to fix that?

Much appreciated!
