scan loss stuck, performs worse than pretext #113

Open
mazatov opened this issue Aug 30, 2022 · 14 comments

mazatov commented Aug 30, 2022

Hello,

I'm trying to train the model on my own dataset. I successfully trained the pretext model with very good top-20 accuracy (95%; the dataset is pretty simple). However, when I run scan.py the loss gets stuck without any improvement and the final performance is pretty bad (56%). I wonder what could go wrong in scan.py for the loss to get stuck like that? The only things I changed in the config file were the number of clusters and the crop size.

I also wonder if I should be changing anything here.

update_cluster_head_only: False # Update full network in SCAN
num_heads: 1 # Only use one head

[attached image: training loss]

wvangansbeke (Owner)
Since I don't know any details about your dataset (e.g. imbalance), I can't help much. We only provide results for CIFAR, STL and ImageNet. Try experimenting with these datasets first to know how the loss should behave. Play around with the entropy weight and also see if only updating the clustering head helps.
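
For reference, the objective is roughly the following (a minimal PyTorch sketch of the idea, not the exact code from the repo):

```python
import torch
import torch.nn.functional as F

def scan_loss(anchor_logits, neighbor_logits, entropy_weight=5.0):
    """Consistency between an anchor and its mined neighbor, plus an
    entropy term that spreads probability mass over all clusters."""
    anchors = F.softmax(anchor_logits, dim=1)      # (B, C)
    neighbors = F.softmax(neighbor_logits, dim=1)  # (B, C)

    # Consistency term: push each anchor/neighbor pair to the same cluster.
    similarity = (anchors * neighbors).sum(dim=1)
    consistency = -torch.log(similarity.clamp(min=1e-8)).mean()

    # Entropy of the mean prediction: high entropy = uniform cluster sizes.
    mean_probs = anchors.mean(dim=0)
    entropy = -(mean_probs * torch.log(mean_probs.clamp(min=1e-8))).sum()

    # Minimizing the total subtracts the entropy, so a larger weight
    # pushes harder toward equally sized clusters.
    return consistency - entropy_weight * entropy
```

A large entropy weight rewards uniform cluster sizes, so on a heavily imbalanced dataset it fights the true distribution.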

mazatov (Author) commented Aug 30, 2022

@wvangansbeke, really appreciate you keeping up with the repository and replying. Great paper and very clear code.

My dataset is quite imbalanced: ~250,000 samples across 5 classes, with the largest class at ~150,000 samples and the smallest around 10,000. The images are of athletes in different uniforms, so the main difference is the clothes they wear. I changed the augmentations to skip color augmentation, since color is the crucial signal here. In your experience, what should the entropy weight be for such an imbalanced dataset?

The pretext part worked like a charm: really high accuracy pretty quickly. The SCAN part is where I'm struggling. I looked at how the loss changes on STL, and on my data it pretty much stays the same after a few epochs, with some variation from batch to batch. The accuracy is in the 50s, so it's worse than the pretext. I'm currently trying with the entropy weight set to 1.

Good idea about setting update_cluster_head_only to True. That should at least keep the SCAN step from ending up worse than the pretext features.

I'll write an update here after I do some more experiments.

wvangansbeke (Owner) commented Aug 30, 2022

You might want to reduce the entropy weight in this case. Set it to 1 or 2. Note that we applied stronger augmentations for the SCAN part (strong color augmentations etc., see randaugment.py). I'm not sure if that breaks things for your dataset. Also note that you can't really compare kNN accuracy (top-20) and classification accuracy directly. Due to the heavy imbalance you will have to try out a few things. If the kNNs are good, try overclustering.

mazatov (Author) commented Aug 31, 2022

@wvangansbeke, what do you mean by randaugment.py? Are you referring to data/augment.py? I switched the augmentations for my data by changing the augment_list here.

Is there any other place where augmentations are happening?

wvangansbeke (Owner) commented Aug 31, 2022

Yes, I meant data/augment.py. That should be the only place.
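
If you only want to drop the color-related ops there, filtering the list by name is one way to do it (a sketch; the op names below are assumptions, check the actual function names in data/augment.py):

```python
# Sketch: filter color-altering ops out of the (op, min, max) tuples
# that augment_list() returns in data/augment.py.
COLOR_OPS = {'Color', 'Brightness', 'Contrast', 'AutoContrast',
             'Equalize', 'Invert', 'Posterize', 'Solarize'}

def drop_color_ops(full_list):
    """Keep only ops whose function name is not color-related."""
    return [(op, lo, hi) for (op, lo, hi) in full_list
            if op.__name__ not in COLOR_OPS]
```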

mazatov (Author) commented Aug 31, 2022

Thanks @wvangansbeke! One more question: I was doing some more debugging and found out the following.

Two of my classes are somewhat similar. Let's say, out of five classes, Class 1 is the biggest one, Class 4 is similar to Class 2, and Class 5 is similar to Class 3. What happens with SCAN is that it basically finds 3 clusters: it puts 4 together with 2 and 5 together with 3, creating two superclusters. Since it must produce the 5 classes I specify in the config, it divides the biggest Class 1 into 3 classes that don't really make sense.

So all of this is making sense to me now. I wonder how I can "encourage" the model not to create superclusters and instead separate the two similar classes from their counterparts. If you have any tips or suggestions, I would appreciate it. Thanks for all your help!

wvangansbeke (Owner)
My suggestion is to overcluster (e.g. go to 10 or 20 clusters) and then merge them manually.
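
Concretely, the merge can be a lookup table over predicted cluster ids (a sketch with a hypothetical 10-to-5 mapping you would build by inspecting samples from each cluster):

```python
import numpy as np

# Hypothetical mapping from 10 learned clusters to 5 real classes,
# built by eyeballing samples from each cluster.
CLUSTER_TO_CLASS = {0: 0, 1: 0, 2: 1, 3: 1, 4: 2,
                    5: 2, 6: 3, 7: 3, 8: 4, 9: 4}

def merge_clusters(cluster_ids):
    """Map an array of predicted cluster ids to merged class ids."""
    lut = np.array([CLUSTER_TO_CLASS[c] for c in range(len(CLUSTER_TO_CLASS))])
    return lut[cluster_ids]
```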

mazatov (Author) commented Sep 1, 2022

Thanks @wvangansbeke! In your paper you mention the following:

Table 3 reports the results when we overestimate the number of ground-truth classes by a factor of 2, e.g. we cluster CIFAR10 into 20 rather than 10 classes. The classification accuracy remains stable for CIFAR10 (87.6% to 86.2%) and STL10 (76.7% to 76.8%), and improves for CIFAR100-20 (45.9% to 55.1%). We conclude that the approach does not require knowledge of the exact number of clusters.

I couldn't find more details on the implementation, so I wanted to check how you approached overclustering on CIFAR10 or STL10. Let's say you cluster CIFAR10 into 20 clusters. Do you remove all validation in this scenario, since you don't have true labels for 20 clusters? (In that case you wouldn't be able to pick the best model, so I'm surprised the performance doesn't decrease.) Or do you randomly split every true cluster into 2 fake clusters? Say you have 100 images of the cluster "dog": do you split it into 2 clusters "dog-1" and "dog-2" randomly 50/50, or do you just cluster without any validation checks? (For the 20-clusters-against-10-labels scoring I have in mind, see the sketch at the end of this comment.)

Another question: at which point did you merge the classes manually? To me, the best point seems to be after scan and before the last selflabel step. However, the starting model for selflabel is the one obtained in scan (including the clustering head, right?), so if you reduce the number of clusters before the selflabel step, you would need to retrain the last layer from scratch 🤔 So maybe the best is to continue using 20 clusters and manually merge them after self-label? As you can see I'm not 100% sure how this would work haha. Any tips on how to implement this would be highly appreciated.
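
The many-to-one scoring I mean above would look roughly like this: each of the 20 clusters gets assigned the true class it mostly contains (a numpy sketch; majority_map_accuracy is just my name for it):

```python
import numpy as np

def majority_map_accuracy(cluster_ids, true_labels):
    """Score e.g. 20 predicted clusters against 10 true labels by
    assigning each cluster to the true class it mostly contains."""
    mapped = np.empty_like(true_labels)
    for c in np.unique(cluster_ids):
        mask = cluster_ids == c
        # Majority true label inside this cluster (many-to-one mapping).
        mapped[mask] = np.bincount(true_labels[mask]).argmax()
    return float((mapped == true_labels).mean())
```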

mazatov (Author) commented Sep 8, 2022

Hi @wvangansbeke ,

I've been experimenting with the self-label part of the code and ran into an interesting scenario. As self-label learns, it converges from 10 classes to just 3, zeroing out the number of samples predicted for the other classes. In your opinion, what could cause such behaviour?

I printed out the number of high probability samples per class after every epoch:
print('High probs per class: ', np.sum(probs > 0.99, axis=0))

Epoch 2:
High probs per class: [1371 2819 2211 3704 2897 1163 2403 1157 2189 2634]
Epoch 3:
High probs per class: [1187 0 1397 4594 2455 3059 1984 1290 1874 3023]
Epoch 4:
High probs per class: [ 774 0 504 1526 847 2889 1482 4316 2427 3522]
Epoch 5 and etc.:

High probs per class:  [2607    0 1052 4139    0 3199  787    0 2845 3689]
High probs per class:  [1532    0 2619 3773    0 1151 2865    0 2004 3447]
High probs per class:  [1702    0 2095 5602    0 1699 2104    0 2697 3642]
High probs per class:  [2904    0 1972 4649    0 1824 3143    0  826 3561]
High probs per class:  [2386    0 2337 1689    0 1362 2886    0 4602 3589]
High probs per class:  [1335    0 2161 1559    0 1692 1765    0 4515 2706]
High probs per class:  [2296    0 2617  732    0  837 7086    0 4315  710]
High probs per class:  [1090    0 2879 1480    0 1137 1996    0 6035 1558]
High probs per class:  [1364    0 2188 2445    0 3214 1916    0 4653 1800]
High probs per class:  [1072    0 2022  715    0 4081 2368    0 5979 1707]
High probs per class:  [1201    0 2858 1170    0 4760 1746    0 5435 1661]
High probs per class:  [1111    0 2326 1776    0 4515 1956    0 5525 1314]
High probs per class:  [ 989    0 2732 3265    0 3947  948    0 6475 1446]
High probs per class:  [1050    0 2247 3654    0 3363 1081    0 6538 1361]
High probs per class:  [ 743    0 2510  880    0 3031 3387    0 6589 1642]
High probs per class:  [   0    0 4246    0    0    0 5566    0 7598 3427]
High probs per class:  [   0    0 4693    0    0    0 6366    0 6810 3770]
High probs per class:  [   0    0 2239    0    0    0 5227    0 8872 4242]
High probs per class:  [   0    0 5349    0    0    0 3407    0 7329 4919]
High probs per class:  [   0    0 3019    0    0    0 5101    0 5586 2773]
High probs per class:  [   0    0 4220    0    0    0 5747    0 7029 3240]
High probs per class:  [   0    0 8284    0    0    0 1854    0 3613 6357]
High probs per class:  [   0    0 8553    0    0    0 4588    0 3779 3157]
High probs per class:  [   0    0 5617    0    0    0 1895    0 9253    0]
High probs per class:  [   0    0 3942    0    0    0 6485    0 7002    0]
.
.
.
High probs per class:  [    0     0  9028     0     0     0 10017     0  4992     0]
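
A cheap guard built on the same probs array as the print above could catch this collapse early (a sketch):

```python
import numpy as np

def count_dead_clusters(probs, threshold=0.99):
    """Count clusters that no longer own a single confident prediction.

    probs: (num_samples, num_clusters) softmax outputs, the same array
    used in the print statement above.
    """
    confident_per_cluster = np.sum(probs > threshold, axis=0)
    return int(np.sum(confident_per_cluster == 0))

# e.g. stop training or reload a checkpoint once clusters start dying:
# if count_dead_clusters(probs) > 0: ...
```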

akshay-iyer
Hey @mazatov, I'm curious to know what worked for you. I am facing a similar issue of the SCAN loss plateauing, sometimes starting from a negative value and plateauing at 0. My dataset's classes are balanced, so I'm not sure what the problem could be.

akshay-iyer commented Dec 14, 2022

Also, @wvangansbeke, I loved this work and your implementation equally! I feel it's done beautifully; it took me quite some time to understand the work. I have a couple of questions related to the SCAN loss (including, but not limited to, the scope of OP's problem).

While we are trying to maximize the dot product of a sample and its neighbor, why not also try to minimize the dot product of an image and another image not in the neighbor set? Basically like contrastive learning: also trying to push apart images from different classes.

Also another question about SCAN loss:
My understanding is that the first term forces one-hot and consistent cluster predictions for an image and its neighbor: say a dog image and a neighbor image of another dog, both marked as cluster 1. But what stops the network from also clustering a cat and its neighboring cat image into cluster 1? The dot product would still be high. How is homogeneity within clusters maintained?
The entropy term, I understand, partly alleviates this problem since it tries to maintain equitable cluster sizes. But if the class sizes are not balanced, the entropy weight has to be lowered, and then I'm not clear on how homogeneity within clusters is still maintained.

I'm not sure if this is actually a problem, so do let me know what you think. And in case it is, would something like my first suggestion, also minimizing the dot product between an image and an image of a different class, help here?
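
To make that first suggestion concrete, the extra term I'm imagining would look something like this (purely hypothetical, not part of the repo; negative_logits would come from images sampled outside the neighbor set):

```python
import torch
import torch.nn.functional as F

def repulsion_term(anchor_logits, negative_logits, eps=1e-8):
    """Hypothetical add-on to the SCAN loss: penalize agreement between
    an anchor and images drawn from outside its neighbor set."""
    anchors = F.softmax(anchor_logits, dim=1)
    negatives = F.softmax(negative_logits, dim=1)
    similarity = (anchors * negatives).sum(dim=1)  # want this near 0
    return -torch.log((1.0 - similarity).clamp(min=eps)).mean()

# total = scan_loss + lambda_neg * repulsion_term(anchor_logits, neg_logits)
```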

mazatov (Author) commented Dec 15, 2022

@akshay-iyer, I couldn't do self-labeling for my application because my classes are so unbalanced. I increased the number of classes so that the first two steps work well. But once you merge classes manually, self-labeling, at least the way it's written, doesn't work: it expects you to have the same number of classes.

itsMohitShah
(Quoted @mazatov's Sep 8, 2022 comment above, on the self-label step collapsing from 10 classes to 3.)

Hi @mazatov! I am getting roughly the same output: most of my data points are being assigned to a select 4-5 classes (out of the given 20). Did you figure out why this might be happening and how to resolve it?

mazatov (Author) commented Jun 7, 2023

Sorry, no. For me, I was purposely clustering into more clusters to get the small clusters working, so before the self-label step I needed to manually merge classes. I never figured out how to run the last step once the number of classes changed.
