NaN probabilities for step_sample #57

Closed
xiaotingxuan opened this issue Sep 23, 2023 · 4 comments

Comments

@xiaotingxuan

xiaotingxuan commented Sep 23, 2023

When I run experiments on my dataset, I sometimes get the following error:

  File "train.py", line 136, in <module>
    main()
  File "train.py", line 113, in main
    TrainLoop(
  File "/mnt/dolphinfs/ssd_pool/docker/user/hadoop-mtcv/xiaotingxuan/DiffuProj/text-DiffuSeq/train_util.py", line 183, in run_loop
    self.run_step(batch, cond)
  File "/mnt/dolphinfs/ssd_pool/docker/user/hadoop-mtcv/xiaotingxuan/DiffuProj/text-DiffuSeq/train_util.py", line 202, in run_step
    self.forward_backward(batch, cond)
  File "/mnt/dolphinfs/ssd_pool/docker/user/hadoop-mtcv/xiaotingxuan/DiffuProj/text-DiffuSeq/train_util.py", line 249, in forward_backward
    t, weights = self.schedule_sampler.sample(micro.shape[0], dist_util.dev())
  File "/mnt/dolphinfs/ssd_pool/docker/user/hadoop-mtcv/xiaotingxuan/DiffuProj/text-DiffuSeq/diffuseq/step_sample.py", line 56, in sample
    indices_np = np.random.choice(len(p), size=(batch_size,), p=p)
  File "mtrand.pyx", line 935, in numpy.random.mtrand.RandomState.choice
ValueError: probabilities contain NaN

The original code is:

import numpy as np

def sample(self, batch_size, device):
    w = self.weights()
    p = w / np.sum(w)
    indices_np = np.random.choice(len(p), size=(batch_size,), p=p)
    # ... (rest of the method)
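A likely reason every entry of p turns into NaN at once: if any single weight in w is NaN (for example, when the weights are derived from recorded losses and the loss has diverged), np.sum(w) is NaN and the division spreads it to every entry. The weights below are made up purely for illustration.

```python
import numpy as np

# Made-up per-timestep weights; one of them has become NaN.
w = np.array([0.5, 1.0, np.nan, 2.0])

# np.sum(w) is NaN, so the division turns every entry of p into NaN.
p = w / np.sum(w)
print(p)  # [nan nan nan nan]

try:
    np.random.choice(len(p), size=(8,), p=p)
except ValueError as err:
    print(err)  # probabilities contain NaN
```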
  

I added the line below to avoid the NaN error, but I am not sure whether that is OK. Has anyone else run into this problem? I am really confused and hope someone can give me some advice.

        p = [0 if np.isnan(i) else i for i in p]
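Replacing NaN with 0 gets past the NaN check, but np.random.choice also requires the probabilities to sum to 1 (within a small tolerance), so a partly zeroed vector is rejected with a different ValueError. The vector below is hypothetical, with only one NaN entry.

```python
import numpy as np

# Hypothetical probability vector with a single NaN entry.
p = np.array([0.25, np.nan, 0.25, 0.25])

# Zero out the NaN entry without renormalizing.
p_fixed = [0 if np.isnan(i) else i for i in p]
print(sum(p_fixed))  # 0.75

try:
    np.random.choice(len(p_fixed), size=(8,), p=p_fixed)
except ValueError as err:
    print(err)  # probabilities do not sum to 1
```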
@summmeer
Collaborator

summmeer commented Oct 9, 2023

Usually the code won't produce NaN, but it's OK to add this line.

@zkzhou126

Hello! I hit the same error and tried your method, but it makes the entries of p no longer sum to 1, so np.random.choice raises another error. So I changed the code to:

        p = w / np.sum(w)
        if np.sum(np.isnan(p)) > 0:
            p[np.isnan(p)] = 0  # replace NaN entries with 0
            p = p / np.sum(p)   # renormalize so that p sums to 1
        indices_np = np.random.choice(len(p), size=(batch_size,), p=p)
        indices = th.from_numpy(indices_np).long().to(device)

But I got the NaN error again.

In my train.sh I only switched to my own dataset and changed the dim. I hope you can give me some advice.
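A probable reason the renormalized version still fails: when the NaN comes from np.sum(w) itself, every entry of p is NaN, so zeroing them leaves all zeros and p / np.sum(p) computes 0 / 0, which is NaN again. The weights are made up for illustration; np.errstate only silences the invalid-value warning from the division.

```python
import numpy as np

# Made-up weights with one NaN: np.sum(w) is NaN, so all of p is NaN.
w = np.array([0.5, 1.0, np.nan, 2.0])
p = w / np.sum(w)

if np.sum(np.isnan(p)) > 0:
    p[np.isnan(p)] = 0  # every entry was NaN, so p is now all zeros
    print(np.sum(p))    # 0.0
    with np.errstate(invalid="ignore"):
        p = p / np.sum(p)  # 0 / 0 is NaN for every entry

print(np.isnan(p).all())  # True
```

So renormalizing alone cannot recover from the all-NaN case; it needs a fallback distribution (or a fix for whatever made the weights NaN).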

@xiaotingxuan
Author

Hello, I also hit the NaN error when using my own dataset.
One suggestion is to use a smaller batch size; I set --bsz 64 (it works for me).
Another suggestion is to change the code so that it falls back to a uniform distribution when p contains NaN:

        hasNan = np.isnan(p).any()
        if hasNan:
            print("has Nan prob p=", p)
            # fall back to a uniform distribution over all timesteps
            size = len(p)
            p = np.ones(p.shape) * (1 / size)
            print("new p =", p)

@zkzhou126

OK, I'll try your method, thank you!


3 participants