
Too many training images, memory overflow #8

Closed
x12901 opened this issue Aug 20, 2021 · 2 comments
Assignees
Labels
good first issue Good for newcomers


x12901 commented Aug 20, 2021

Hi, great project! I have 8,000 images, and I found that memory usage increased a lot during training. My machine has 60 GB of RAM, but it is still not enough.

import os

import numpy as np
import torch
from PIL import Image

# SPADE and StreamingDataset are imported from this repository
model = SPADE(k=42)  # , backbone_name="hypernet")
train_dataset = StreamingDataset()
app_custom_train_images = "D:\\train\\good"
# add the training images
for root, dirs, files in os.walk(app_custom_train_images):
    for file in files:
        train_dataset.add_pil_image(Image.open(os.path.join(root, file)))
model.fit(train_dataset)
PATH = "test.pth"
torch.save(model.state_dict(), PATH)

pil_img = Image.open("20210115_raw.png")  # was plt.open, which does not exist
img_pil_1 = np.array(pil_img)
tensor_pil = torch.from_numpy(np.transpose(img_pil_1, (2, 0, 1)))  # HWC -> CHW
img = tensor_pil.type(torch.float32)
img = img.unsqueeze(0)  # add a batch dimension
pil_to_tensor = img
img_lvl_anom_score, pxl_lvl_anom_score = model.predict(pil_to_tensor)

The test image does not seem to be transformed before prediction. Are images in other formats supported?


rvorias commented Aug 20, 2021

Hi x12901, thanks for your interest!

Off the bat, 8000 images is a lot and way more than I've benchmarked. Here is a short overview of where the bulk of the memory goes for each method:

GENERAL

for each backbone, the more feature maps it returns, the heavier the computations will be.


SPADE

self.zlib <-- stack of feature vectors, float32
self.fmaps <-- list of stacks of feature maps, float32. This will become the bulk of the memory.

solution: use something like mmap (Memory-mapped file support) to build it outside RAM. faiss could also be used. You will likely sacrifice some inference speed.
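To make the mmap suggestion concrete, here is a minimal sketch using NumPy's `np.memmap` (all names, shapes, and the file path are illustrative, not the project's actual internals): the feature-map stack lives in a file on disk, so only the pages currently being touched occupy RAM.

```python
import numpy as np

# Illustrative shapes: n_samples feature maps of (c, h, w), float32.
# Real numbers would be much larger; these keep the sketch quick to run.
n_samples, c, h, w = 16, 8, 4, 4

# mode="w+" creates the backing file; the OS pages data in and out on demand.
fmaps = np.memmap("fmaps.dat", dtype=np.float32, mode="w+",
                  shape=(n_samples, c, h, w))

for i in range(n_samples):
    # stand-in for the backbone forward pass on training image i
    fmaps[i] = np.random.rand(c, h, w).astype(np.float32)

fmaps.flush()  # ensure everything is written to the backing file
```

Indexing `fmaps` afterwards works like a normal array, at some I/O cost per access.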


PADIM

self.patch_lib <-- stack of 2D patches, float32. This will become the main bulk of the memory.
torch.linalg.inv(self.E) <-- this could cause memory issues if your 2D grid is large

solution: online calculation of mean and covar matrix when the samples are added to the training set
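A sketch of what that online update could look like (Welford-style; the class name and `d` are made up for illustration): each incoming sample updates a running mean and a running sum of outer products, so the raw patch stack never has to be held in memory.

```python
import numpy as np

class OnlineMeanCov:
    """Streaming mean and covariance for d-dimensional feature vectors."""

    def __init__(self, d):
        self.n = 0
        self.mean = np.zeros(d)
        self.M2 = np.zeros((d, d))  # running sum of outer products of deviations

    def update(self, x):
        self.n += 1
        delta = x - self.mean
        self.mean += delta / self.n            # new running mean
        self.M2 += np.outer(delta, x - self.mean)

    def cov(self):
        return self.M2 / (self.n - 1)          # unbiased covariance estimate
```

For PaDiM you would keep one such estimator per patch position; the result matches `np.cov` on the stacked data without ever stacking it.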


PatchCore

self.patch_lib <-- collection of patches, float32. This will become the main bulk of the memory.
coreset selection <-- can also eat quite some memory as you will need to calculate distances between vectors

solution: the authors of the paper use faiss in their implementation, likely because it solves a couple of memory issues as well.
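As a dependency-free illustration of the core operation faiss accelerates (the function name and sizes here are invented for the sketch): the nearest-patch distance can be computed in query chunks, so the full pairwise-distance matrix never materialises at once.

```python
import numpy as np

def nn_distances(queries, patch_lib, chunk=256):
    """Squared L2 distance from each query vector to its nearest
    neighbour in patch_lib, processed in chunks to bound peak memory."""
    out = np.empty(len(queries), dtype=np.float32)
    for start in range(0, len(queries), chunk):
        q = queries[start:start + chunk]
        # (chunk, n_patches) pairwise squared distances for this slice only
        d2 = ((q[:, None, :] - patch_lib[None, :, :]) ** 2).sum(axis=-1)
        out[start:start + chunk] = d2.min(axis=1)
    return out
```

faiss performs the same search with index structures (e.g. `IndexFlatL2`), trading setup cost for much faster queries on large patch libraries.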


In a nutshell:

  1. implement one of these solutions (and make a PR :) ) or
  2. select a smaller backbone (resnet18, efficientnet_b0) or
  3. reduce your dataset

Second, I see you are using StreamingDataset.
For this to work, you need to make a train instance and a test instance:

train_dataset = StreamingDataset()
test_dataset = StreamingDataset()

Then you can add samples like this (they are automatically transformed correctly):

for path in train_paths:
    train_dataset.add_pil_image(Image.open(path))
for path in test_paths:
    test_dataset.add_pil_image(Image.open(path))

For inference on test images, you then call:

test_idx = 0
sample, *_ = test_dataset[test_idx]
img_lvl_anom_score, pxl_lvl_anom_score = model.predict(sample.unsqueeze(0))

Let me know if it works out!

@rvorias rvorias added the good first issue Good for newcomers label Aug 20, 2021
@rvorias rvorias self-assigned this Aug 20, 2021

x12901 commented Aug 23, 2021

It worked, thanks
