Skip to content

Commit

Permalink
fix spacing
Browse files Browse the repository at this point in the history
  • Loading branch information
9q9q committed Dec 19, 2023
1 parent b47cf89 commit 3559a83
Show file tree
Hide file tree
Showing 3 changed files with 60 additions and 3 deletions.
2 changes: 1 addition & 1 deletion sparsecoding/data/datasets/field.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ class FieldDataset(Dataset):
patch_size : int
Side length of patches for sparse dictionary learning.
stride : int, optional
Stride for sampling patches. If not specified, set to `patch_size`
Stride for sampling patches. If not specified, set to `patch_size`
(non-overlapping patches).
"""

Expand Down
57 changes: 57 additions & 0 deletions sparsecoding/data/datasets/vanhateren.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
import os

from scipy.io import loadmat
import torch
from torch.utils.data import Dataset

from sparsecoding.data.transforms.patch import patchify


class VanHaterenDataset(Dataset):
"""Dataset used in Olshausen & Field (1996).
Paper:
https://courses.cs.washington.edu/courses/cse528/11sp/Olshausen-nature-paper.pdf
Emergence of simple-cell receptive field properties
by learning a sparse code for natural images.
Parameters
----------
root : str
Location to download the dataset to.
patch_size : int
Side length of patches for sparse dictionary learning.
"""

B = 10
C = 1
H = 512
W = 512

def __init__(
self,
root: str,
patch_size: int = 8,
):
self.P = patch_size

root = os.path.expanduser(root)
os.system(f"mkdir -p {root}")
if not os.path.exists(f"{root}/field.mat"):
os.system("wget https://rctn.org/bruno/sparsenet/IMAGES.mat")
os.system(f"mv IMAGES.mat {root}/field.mat")

self.images = torch.tensor(loadmat(f"{root}/field.mat")["IMAGES"]) # [H, W, B]
assert self.images.shape == (self.H, self.W, self.B)

self.images = torch.permute(self.images, (2, 0, 1)) # [B, H, W]
self.images = torch.reshape(self.images, (self.B, self.C, self.H, self.W)) # [B, C, H, W]

self.patches = patchify(patch_size, self.images) # [B, N, C, P, P]
self.patches = torch.reshape(self.patches, (-1, self.C, self.P, self.P)) # [B*N, C, P, P]

def __len__(self):
return self.patches.shape[0]

def __getitem__(self, idx):
return self.patches[idx]
4 changes: 2 additions & 2 deletions sparsecoding/data/transforms/patch.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,8 +80,8 @@ def patchify(
C is the number of channels,
H is the image height,
W is the image width.
stride : int, optional
Stride between patches in pixel space. If not specified, set to
stride : int, optional
Stride between patches in pixel space. If not specified, set to
`patch_size` (non-overlapping patches).
Returns
Expand Down

0 comments on commit 3559a83

Please sign in to comment.