Skip to content

Commit

Permalink
rebased to 1.2
Browse files Browse the repository at this point in the history
  • Loading branch information
mansishr committed Nov 11, 2021
1 parent 4d5f8da commit 50f1c64
Show file tree
Hide file tree
Showing 11 changed files with 373 additions and 216 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -25,20 +25,20 @@ cd <openfl_folder>/openfl-tutorials/interactive_api/PyTorch_MVTec_PatchSVDD

1. Run director:
```sh
cd director_folder
./start_director.sh
cd director
bash start_director.sh
```

2. Run envoy:
```sh
cd envoy_folder
./start_envoy.sh env_one shard_config.yaml
cd envoy
bash start_envoy.sh env_one envoy_config.yaml
```

Optional: start second envoy:
- Copy `envoy_folder` to another place and run from there:
- Copy `envoy` to another place and run from there:
```sh
./start_envoy.sh env_two shard_config_two.yaml
bash start_envoy.sh env_two envoy_config_two.yaml
```

3. Run `PatchSVDD_with_Director.ipynb` jupyter notebook:
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
params:
cuda_devices: []

optional_plugin_components: {}

shard_descriptor:
template: mvtec_shard_descriptor.MVTecShardDescriptor
params:
data_folder: MVTec_data
rank_worldsize: 1,2
obj: bottle
Original file line number Diff line number Diff line change
Expand Up @@ -10,97 +10,46 @@
from imageio import imread
from PIL import Image

from openfl.interface.interactive_api.shard_descriptor import ShardDataset
from openfl.interface.interactive_api.shard_descriptor import ShardDescriptor


class MVTecShardDescriptor(ShardDescriptor):
"""MVTec Shard descriptor class."""
class MVTecShardDataset(ShardDataset):
"""MVTec Shard dataset class."""

def __init__(self, data_folder: str = 'MVTec_data',
rank_worldsize: str = '1,1',
enforce_image_hw: str = None,
obj: str = '',
mode: str = 'train') -> None:
"""Initialize MVTecShardDescriptor."""
super().__init__()

self.dataset_path = Path.cwd() / data_folder
self.download_data(self.dataset_path)
# Settings for resizing data
self.enforce_image_hw = None
if enforce_image_hw is not None:
self.enforce_image_hw = tuple(int(size) for size in enforce_image_hw.split(','))
# Settings for sharding the dataset
self.rank_worldsize = tuple(int(num) for num in rank_worldsize.split(','))
self.mode = mode
# Train dataset
fpattern = os.path.join(self.dataset_path, f'{obj}/train/*/*.png')
fpaths = sorted(glob(fpattern))
self.train_path = list(fpaths)[self.rank_worldsize[0] - 1::self.rank_worldsize[1]]
# Test dataset
fpattern = os.path.join(self.dataset_path, f'{obj}/test/*/*.png')
fpaths = sorted(glob(fpattern))
fpaths_anom = list(
filter(lambda fpath: os.path.basename(os.path.dirname(fpath)) != 'good', fpaths))
fpaths_good = list(
filter(lambda fpath: os.path.basename(os.path.dirname(fpath)) == 'good', fpaths))
fpaths = fpaths_anom + fpaths_good
self.test_path = fpaths[self.rank_worldsize[0] - 1::self.rank_worldsize[1]]
# Sharding the labels
self.labels = np.zeros(len(fpaths_anom) + len(fpaths_good), dtype=np.int32)
self.labels[:len(fpaths_anom)] = 1 # anomalies
self.labels = self.labels[self.rank_worldsize[0] - 1::self.rank_worldsize[1]]
# Masks
fpattern_mask = os.path.join(self.dataset_path, f'{obj}/ground_truth/*/*.png')
self.mask_path = sorted(glob(fpattern_mask))

def set_mode(self, mode='train'):
"""Set mode for getitem."""
self.mode = mode
if self.mode == 'train':
self.imgs_path = self.train_path
elif self.mode == 'test':
self.imgs_path = self.test_path
else:
raise Exception(f'Wrong mode: {mode}')

@staticmethod
def download_data(data_folder):
"""Download data."""
zip_file_path = data_folder / 'mvtec_anomaly_detection.tar.xz'
if not Path(zip_file_path).exists():
os.makedirs(data_folder, exist_ok=True)
print('Downloading MVTec Dataset...this might take a while')
os.system('wget -nc'
" 'https://www.mydrive.ch/shares/38536/3830184030e49fe74747669442f0f282/download/420938113-1629952094/mvtec_anomaly_detection.tar.xz'" # noqa
f' -O {zip_file_path.relative_to(Path.cwd())}')
print('Downloaded MVTec dataset, untar-ring now')
os.system(f'tar -xvf {zip_file_path.relative_to(Path.cwd())}'
f' -C {data_folder.relative_to(Path.cwd())}')
def __init__(self, images_path,
mask_path, labels,
rank=1,
worldsize=1):
"""Initialize MVTecShardDataset."""
self.rank = rank
self.worldsize = worldsize
self.images_path = images_path[self.rank - 1::self.worldsize]
self.mask_path = mask_path
self.labels = labels[self.rank - 1::self.worldsize]

def __getitem__(self, index):
"""Return a item by the index."""
img = np.asarray(imread(self.imgs_path[index]))
img = np.asarray(imread(self.images_path[index]))
if img.shape[-1] != 3:
img = self.gray2rgb(img)

img = self.resize(img)
img = np.asarray(img)
if self.mode == 'train':
mask = np.full(img.shape, None)
label = 0
if self.mask_path[index]:
mask = np.asarray(imread(self.mask_path[index]))
mask = self.resize(mask)
mask = np.asarray(mask)
else:
if self.mask_path[index]:
mask = np.asarray(imread(self.mask_path[index]))
mask = self.resize(mask)
mask = np.asarray(mask)
label = self.labels[index]
else:
mask = np.full(img.shape, None)
label = self.labels[index]
mask = np.full(img.shape, None)
label = self.labels[index]

return img, mask, label

def __len__(self):
"""Return the len of the dataset."""
return len(self.images_path)

def resize(self, image, shape=(256, 256)):
"""Resize image."""
return np.array(Image.fromarray(image).resize(shape))
Expand All @@ -113,25 +62,89 @@ def gray2rgb(self, images):
images = np.tile(np.expand_dims(images, axis=-1), tile_shape)
return images

def __len__(self):
"""Return the len of the dataset."""
if self.mode == 'train':
return len(self.train_path)
if self.mode == 'test':
return len(self.test_path)

class MVTecShardDescriptor(ShardDescriptor):
"""MVTec Shard descriptor class."""

def __init__(self, data_folder: str = 'MVTec_data',
rank_worldsize: str = '1,1',
obj: str = 'bottle'):
"""Initialize MVTecShardDescriptor."""
super().__init__()

self.dataset_path = Path.cwd() / data_folder
self.download_data()
self.rank, self.worldsize = tuple(int(num) for num in rank_worldsize.split(','))
self.obj = obj

# Calculating data and target shapes
ds = self.get_dataset()
sample, masks, target = ds[0]
self._sample_shape = [str(dim) for dim in sample.shape]
self._target_shape = [str(dim) for dim in target.shape]

def download_data(self):
"""Download data."""
zip_file_path = self.dataset_path / 'mvtec_anomaly_detection.tar.xz'
if not Path(zip_file_path).exists():
os.makedirs(self.dataset_path, exist_ok=True)
print('Downloading MVTec Dataset...this might take a while')
os.system('wget -nc'
" 'https://www.mydrive.ch/shares/38536/3830184030e49fe74747669442f0f282/download/420938113-1629952094/mvtec_anomaly_detection.tar.xz'" # noqa
f' -O {zip_file_path.relative_to(Path.cwd())}')
print('Downloaded MVTec dataset, untar-ring now')
os.system(f'tar -xvf {zip_file_path.relative_to(Path.cwd())}'
f' -C {self.dataset_path.relative_to(Path.cwd())}')

def get_dataset(self, dataset_type='train'):
"""Return a shard dataset by type."""
# Train dataset
if dataset_type == 'train':
fpattern = os.path.join(self.dataset_path, f'{self.obj}/train/*/*.png')
fpaths = sorted(glob(fpattern))
self.images_path = list(fpaths)
self.labels = np.zeros(len(fpaths), dtype=np.int32)
# Test dataset
elif dataset_type == 'test':
fpattern = os.path.join(self.dataset_path, f'{self.obj}/test/*/*.png')
fpaths = sorted(glob(fpattern))
fpaths_anom = list(
filter(lambda fpath: os.path.basename(os.path.dirname(fpath)) != 'good', fpaths))
fpaths_good = list(
filter(lambda fpath: os.path.basename(os.path.dirname(fpath)) == 'good', fpaths))
fpaths = fpaths_anom + fpaths_good
self.images_path = fpaths
self.labels = np.zeros(len(fpaths_anom) + len(fpaths_good), dtype=np.int32)
self.labels[:len(fpaths_anom)] = 1 # anomalies
else:
raise Exception(f'Wrong dataset type: {dataset_type}.'
f'Choose from the list: [train, test]')
# Masks
fpattern_mask = os.path.join(self.dataset_path, f'{self.obj}/ground_truth/*/*.png')
self.mask_path = sorted(glob(fpattern_mask))

return MVTecShardDataset(
images_path=self.images_path,
mask_path=self.mask_path,
labels=self.labels,
rank=self.rank,
worldsize=self.worldsize,
)

@property
def sample_shape(self):
"""Return the sample shape info."""
return ['256', '256', '3']
# return self._sample_shape

@property
def target_shape(self):
"""Return the target shape info."""
return ['256', '256']
# return self._target_shape

@property
def dataset_description(self) -> str:
"""Return the dataset description."""
return (f'MVTec dataset, shard number {self.rank_worldsize[0]}'
f' out of {self.rank_worldsize[1]}')
"""Return the shard dataset description."""
return (f'MVTec dataset, shard number {self.rank}'
f' out of {self.worldsize}')
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
numpy
pillow
pillow
imageio

This file was deleted.

Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
#!/bin/bash
set -e

fx envoy start -n env_one --disable-tls --shard-config-path shard_config.yaml -dh localhost -dp 50050
fx envoy start -n env_one --disable-tls --envoy-config-path envoy_config.yaml -dh localhost -dp 50050
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,4 @@ set -e
ENVOY_NAME=$1
DIRECTOR_FQDN=$2

fx envoy start -n "$ENVOY_NAME" --shard-config-path shard_config.yaml -dh"$DIRECTOR_FQDN" -dp 50051 -rc cert/root_ca.crt -pk cert/"$ENVOY_NAME".key -oc cert/"$ENVOY_NAME".crt
fx envoy start -n "$ENVOY_NAME" --envoy-config-path envoy_config.yaml -dh"$DIRECTOR_FQDN" -dp 50050 -rc cert/root_ca.crt -pk cert/"$ENVOY_NAME".key -oc cert/"$ENVOY_NAME".crt

0 comments on commit 50f1c64

Please sign in to comment.