# Federated PatchSVDD algorithm with Director example

# PatchSVDD algorithm
Anomaly detection involves making a binary decision as to whether an input image contains an anomaly, and anomaly segmentation aims to locate the anomaly at the pixel level. PatchSVDD extends the deep learning variant of Support Vector Data Description (SVDD, a long-standing anomaly detection algorithm) to a patch-based method trained with self-supervised learning. This extension enables anomaly segmentation and improves detection performance, measured as AUROC on the MVTec AD dataset.

![alt text](https://media.arxiv-vanity.com/render-output/5520416/x4.png "Patch Level SVDD for Anomaly Detection")

* Original paper: https://arxiv.org/abs/2006.16067
* Original Github code: https://github.com/nuclearboy95/Anomaly-Detection-PatchSVDD-PyTorch/tree/934d6238e5e0ad511e2a0e7fc4f4899010e7d892
* MVTec AD dataset download link: https://www.mydrive.ch/shares/38536/3830184030e49fe74747669442f0f282/download/420938113-1629952094/mvtec_anomaly_detection.tar.xz

In [None]:
# Install dependencies if not already installed
!pip install torchvision==0.8.1 matplotlib numpy scikit-image scikit-learn torch tqdm Pillow imageio opencv-python ngt

# Connect to the Federation

In [None]:
# Create a federation
from openfl.interface.interactive_api.federation import Federation

# 1) Run with API layer - Director mTLS 
# To enable mTLS, the user must provide the CA root chain and a signed key pair to the federation interface
# cert_chain = 'cert/root_ca.crt'
# API_certificate = 'cert/frontend.crt'
# API_private_key = 'cert/frontend.key'

# federation = Federation(client_id='frontend', director_node_fqdn='localhost', director_port='50051',
#                        cert_chain=cert_chain, api_cert=API_certificate, api_private_key=API_private_key)

# --------------------------------------------------------------------------------------------------------------------

# 2) Run with TLS disabled (trusted environment only)
# Federation can also determine the local FQDN automatically
# Connects as client 'frontend' to a Director on localhost:50050 without TLS.
federation = Federation(client_id='frontend', director_node_fqdn='localhost', director_port='50050', tls=False)


In [None]:
# Fetch the shard registry from the federation; the bare name on the last line
# makes the notebook display it.
shard_registry = federation.get_shard_registry()
shard_registry

In [None]:
# Display the target shape reported by the federation.
federation.target_shape

## Creating an FL experiment using the Interactive API

In [None]:
from openfl.interface.interactive_api.experiment import TaskInterface, DataInterface, ModelInterface, FLExperiment

### Register dataset

In [None]:
# Arguments shared by the dataset, model and task definitions below.
args = dict(
    obj='bottle',         # MVTec AD object category
    lambda_value='1e-3',  # weight of the SVDD loss terms; parsed with float() where used
    D=64,                 # embedding dimension of the encoders and classifiers
    lr='1e-4',            # Adam learning rate; parsed with float() where used
)

In [None]:
import argparse
import torch
from functools import reduce
from torch.utils.data import DataLoader,Dataset
import torch.nn as nn
import torch.nn.functional as F
import math
import numpy as np
from PIL import Image
from imageio import imread
from glob import glob
from sklearn.metrics import roc_auc_score
import os, shutil
import _pickle as p
from contextlib import contextmanager
import PIL
from torch.utils.data import Dataset, DataLoader, SubsetRandomSampler
from torchvision import transforms as tsf
from utils import to_device, task, DictionaryConcatDataset, crop_chw, cnn_output_size, crop_image_chw
from functools import reduce

In [None]:
def generate_coords(H, W, K):
    """Return a uniformly random top-left corner ``(h, w)`` for a K x K patch.

    ``h`` is drawn from [0, H - K] first, then ``w`` from [0, W - K], so the
    patch always lies fully inside an H x W image.
    """
    return tuple(np.random.randint(0, extent - K + 1) for extent in (H, W))


def generate_coords_position(H, W, K):
    """Sample a K x K patch and a neighbouring patch in one of 8 directions.

    The neighbour is offset along ``pos_to_diff[pos]`` by 3K/4 plus a random
    extra in [0, K/4) per axis, then clipped so it stays inside the H x W image.

    Returns:
        ``(p1, p2, pos)``: the two top-left corners and the direction class 0-7.
    """
    with task('P1'):
        p1 = generate_coords(H, W, K)
        h1, w1 = p1

    pos = np.random.randint(8)

    with task('P2'):
        base_offset = 3 * K // 4
        h_dir, w_dir = pos_to_diff[pos]
        h_extra, w_extra = np.random.randint(K // 4, size=2)

        h2 = np.clip(h1 + h_dir * (h_extra + base_offset), 0, H - K)
        w2 = np.clip(w1 + w_dir * (w_extra + base_offset), 0, W - K)
        p2 = (h2, w2)

    return p1, p2, pos


def generate_coords_svdd(H, W, K):
    """Sample a K x K patch and a slightly jittered copy of it (for the SVDD loss).

    The second patch is shifted by a non-zero jitter in [-J, J] per axis, where
    J = K // 32, and is clipped to stay inside the H x W image.

    Returns:
        ``(p1, p2)``: top-left corners of the two patches.
    """
    with task('P1'):
        p1 = generate_coords(H, W, K)
        h1, w1 = p1

    with task('P2'):
        # For K < 32, K // 32 is 0. The only candidate jitter would then be (0, 0),
        # so the rejection loop below could never terminate. Clamp the radius to 1.
        J = max(K // 32, 1)

        h_jit, w_jit = 0, 0

        # Reject the all-zero jitter so the two patches differ (before clipping).
        while h_jit == 0 and w_jit == 0:
            h_jit = np.random.randint(-J, J + 1)
            w_jit = np.random.randint(-J, J + 1)

        h2 = np.clip(h1 + h_jit, 0, H - K)
        w2 = np.clip(w1 + w_jit, 0, W - K)

        p2 = (h2, w2)

    return p1, p2


# Direction (row, col) for each of the 8 position classes: the 3x3 neighbourhood
# of a patch, visited in row-major order with the centre itself skipped.
pos_to_diff = dict(enumerate(
    (dh, dw)
    for dh in (-1, 0, 1)
    for dw in (-1, 0, 1)
    if (dh, dw) != (0, 0)
))


In [None]:
class SVDD_Dataset(Dataset):
    """Pairs of a patch and a slightly jittered copy, for the SVDD loss.

    Args:
        memmap: image array converted with ``np.asarray``; each item ``arr[n]`` is
            cropped as CHW (presumably the NCHW array built in
            ``MVTecSD.get_train_loader`` — confirm for other callers).
        K: patch side length.
        repeat: how many times the images are cycled per epoch.
    """

    def __init__(self, memmap, K=64, repeat=1):
        super().__init__()
        self.arr = np.asarray(memmap)
        self.K = K
        self.repeat = repeat

    def __len__(self):
        N = self.arr.shape[0]
        return N * self.repeat

    def __getitem__(self, idx):
        N = self.arr.shape[0]
        K = self.K
        n = idx % N  # cycle through the images `repeat` times

        image = self.arr[n]

        # Sample inside the real image extent (last two axes of a CHW image)
        # instead of assuming every image is 256 x 256.
        H, W = image.shape[-2:]
        p1, p2 = generate_coords_svdd(H, W, K)

        patch1 = crop_image_chw(image, p1, K)
        patch2 = crop_image_chw(image, p2, K)

        return patch1, patch2

In [None]:
class PositionDataset(Dataset):
    """Pairs of neighbouring patches labelled with their relative position.

    Each item is ``(patch1, patch2, pos)``. ``patch2`` lies in direction
    ``pos_to_diff[pos]`` from ``patch1``. Each patch independently gets a
    per-channel colour shift and additive Gaussian noise (scale 0.02).

    Args:
        x: image array converted with ``np.asarray``; each item ``x[n]`` is
            cropped as CHW (presumably the NCHW array built in
            ``MVTecSD.get_train_loader`` — confirm for other callers).
        K: patch side length.
        repeat: how many times the images are cycled per epoch.
    """

    def __init__(self, x, K=64, repeat=1):
        super(PositionDataset, self).__init__()
        self.x = np.asarray(x)
        self.K = K
        self.repeat = repeat

    def __len__(self):
        N = self.x.shape[0]
        return N * self.repeat

    def __getitem__(self, idx):
        N = self.x.shape[0]
        K = self.K
        n = idx % N  # cycle through the images `repeat` times

        image = self.x[n]

        # Sample inside the real image extent (last two axes of a CHW image)
        # instead of assuming every image is 256 x 256.
        H, W = image.shape[-2:]
        p1, p2, pos = generate_coords_position(H, W, K)

        # Copy so the in-place perturbations below do not modify the source images.
        patch1 = crop_image_chw(image, p1, K).copy()
        patch2 = crop_image_chw(image, p2, K).copy()

        # perturb RGB: one random offset per channel
        rgbshift1 = np.random.normal(scale=0.02, size=(3, 1, 1))
        rgbshift2 = np.random.normal(scale=0.02, size=(3, 1, 1))

        patch1 += rgbshift1
        patch2 += rgbshift2

        # additive per-pixel noise
        noise1 = np.random.normal(scale=0.02, size=(3, K, K))
        noise2 = np.random.normal(scale=0.02, size=(3, K, K))

        patch1 += noise1
        patch2 += noise2

        return patch1, patch2, pos


In [None]:
class PatchDataset_NCHW(Dataset):
    """Dense grid of K x K patches (stride S) over every image of a 4-D array.

    Index ``idx`` is unravelled in row-major order over
    ``(N, row_num, col_num)``, and each item is ``(patch, n, i, j)``. If ``tfs``
    is given, it is applied to the cropped patch.
    """

    def __init__(self, memmap, tfs=None, K=32, S=1):
        super().__init__()
        self.arr = memmap
        self.tfs = tfs
        self.S = S
        self.K = K
        self.N = self.arr.shape[0]

    def __len__(self):
        return self.N * self.row_num * self.col_num

    @property
    def row_num(self):
        # Number of patch positions along axis 2 (H); the unpack requires a 4-D array.
        _, _, height, _ = self.arr.shape
        return cnn_output_size(height, k=self.K, s=self.S)

    @property
    def col_num(self):
        # Number of patch positions along axis 3 (W).
        _, _, _, width = self.arr.shape
        return cnn_output_size(width, k=self.K, s=self.S)

    def __getitem__(self, idx):
        grid = (self.N, self.row_num, self.col_num)
        n, i, j = np.unravel_index(idx, grid)
        patch = crop_chw(self.arr[n], i, j, self.K, self.S)

        if self.tfs:
            patch = self.tfs(patch)

        return patch, n, i, j


In [None]:
"""
ShardDataset class
"""
class MVTecShardDataset(Dataset):
    
    def __init__(self, dataset):
        self._dataset = dataset
        
    def __getitem__(self, index):
        img, mask, label = self._dataset[index]
        return img, mask, label
    
    def __len__(self):
        return len(self._dataset)
    
class MVTecSD(DataInterface):
    """Data interface that serves a local MVTec AD shard to the PatchSVDD tasks.

    Constructor keyword arguments are forwarded to ``DataInterface``. The
    training loader reads ``train_bs`` from ``self.kwargs`` (default 64).
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    @property
    def shard_descriptor(self):
        return self._shard_descriptor

    @shard_descriptor.setter
    def shard_descriptor(self, shard_descriptor):
        """
        Describe per-collaborator procedures or sharding.

        This method will be called during a collaborator initialization.
        Local shard_descriptor will be set by Envoy.
        """
        self._shard_descriptor = shard_descriptor

        self.train_set = MVTecShardDataset(shard_descriptor.get_dataset('train'))

        self.test_set = MVTecShardDataset(shard_descriptor.get_dataset('test'))

    @staticmethod
    def _stack_images(dataset):
        """Stack the images of an (img, mask, label) dataset into one float32 array."""
        return np.stack([image for image, _, _ in dataset]).astype(np.float32)

    @staticmethod
    def _standardize(images, mean):
        """Subtract ``mean``, divide by 255 and reorder NHWC -> NCHW."""
        return np.transpose((images - mean) / 255, [0, 3, 1, 2])

    def get_train_loader(self, **kwargs):
        """
        Output of this method will be provided to tasks with optimizer in contract.

        Returns a shuffled DataLoader over the dictionary of position-prediction
        and SVDD patch datasets built from the standardized training images.
        """
        train_raw = self._stack_images(self.train_set)
        mean = train_raw.mean(axis=0)
        train_x = self._standardize(train_raw, mean)

        # Fall back to 64 when 'train_bs' is missing or falsy instead of raising KeyError.
        batch_size = self.kwargs.get('train_bs') or 64

        return DataLoader(self.get_train_dataset_dict(train_x), batch_size=batch_size,
                          shuffle=True, num_workers=2, pin_memory=True)

    def get_valid_loader(self, **kwargs):
        """
        Output of this method will be provided to tasks without optimizer in contract.

        Returns ``(train_x, test_x, masks, labels, mean)``:
        - ``train_x``, ``test_x``: standardized NCHW images;
        - ``masks``: test masks binarized to 0/255 (int32);
        - ``labels``: int32 test labels;
        - ``mean``: the per-pixel training mean used for standardization.
        """
        # We need both train and test data for obtaining embeddings
        train_raw = self._stack_images(self.train_set)
        mean = train_raw.mean(axis=0)
        train_x = self._standardize(train_raw, mean)

        # Standardize the test images with the TRAINING mean so both splits share
        # one input space, as in the training loader. Using the test split's own
        # mean (which includes the anomalous images) would shift test embeddings
        # relative to the training embeddings they are evaluated against.
        test_x = self._standardize(self._stack_images(self.test_set), mean)

        masks = np.stack([mask for _, mask, _ in self.test_set]).astype(np.int32)
        masks[masks <= 128] = 0
        masks[masks > 128] = 255
        labels = np.stack([label for _, _, label in self.test_set]).astype(np.int32)

        return (train_x, test_x, masks, labels, mean)

    def get_train_data_size(self):
        """
        Information for aggregation
        """
        return len(self.train_set)

    def get_valid_data_size(self):
        """
        Information for aggregation
        """
        return len(self.test_set)

    def get_train_dataset_dict(self, inp_x):
        """Combine 64/32 px position and SVDD patch datasets, each cycled 100 times."""
        rep = 100
        datasets = {
            'pos_64': PositionDataset(inp_x, K=64, repeat=rep),
            'pos_32': PositionDataset(inp_x, K=32, repeat=rep),
            'svdd_64': SVDD_Dataset(inp_x, K=64, repeat=rep),
            'svdd_32': SVDD_Dataset(inp_x, K=32, repeat=rep),
        }
        return DictionaryConcatDataset(datasets)


In [None]:
# Batch sizes are passed as kwargs; get_train_loader reads 'train_bs'. get_valid_loader
# returns whole arrays, so 'val_bs' is not read by the visible code.
fed_dataset = MVTecSD(train_bs=64, val_bs=64)

### Describe a model and optimizer

In [None]:
import torch.optim as optim
import torch.nn as nn
import torch
import torch.nn.functional as F
import math
from utils import makedirpath

In [None]:
class Encoder(nn.Module):
    """Shallow encoder of 5x5 convolutions ending in a tanh.

    ``conv4`` (and the activation before it) runs only when ``K == 64``. For
    other K, the output of ``conv3`` (128 channels) goes straight to the tanh.
    """

    def __init__(self, K, D=64, bias=True):
        super().__init__()

        # Layer creation order is kept so parameter order and initialization match.
        self.conv1 = nn.Conv2d(3, 64, 5, 2, 0, bias=bias)
        self.conv2 = nn.Conv2d(64, 64, 5, 2, 0, bias=bias)
        self.conv3 = nn.Conv2d(64, 128, 5, 2, 0, bias=bias)
        self.conv4 = nn.Conv2d(128, D, 5, 1, 0, bias=bias)

        self.K = K
        self.D = D
        self.bias = bias

    def forward(self, x):
        out = x
        for layer in (self.conv1, self.conv2):
            out = F.leaky_relu(layer(out), 0.1)

        out = self.conv3(out)

        if self.K == 64:
            out = self.conv4(F.leaky_relu(out, 0.1))

        return torch.tanh(out)

def forward_hier(x, emb_small, K):
    """Embed the four K/2 quadrants of ``x`` with ``emb_small`` and tile them back.

    All quadrants of the whole batch go through ``emb_small`` in a single call
    (stacked along dim 0). The outputs are reassembled as a 2x2 spatial grid:
    top-left | top-right over bottom-left | bottom-right.
    """
    half = K // 2
    batch = x.size(0)

    quadrants = [
        x[..., :half, :half],  # top-left
        x[..., :half, half:],  # top-right
        x[..., half:, :half],  # bottom-left
        x[..., half:, half:],  # bottom-right
    ]
    embedded = emb_small(torch.cat(quadrants, dim=0))

    top_left, top_right, bottom_left, bottom_right = (
        embedded[:batch],
        embedded[batch:2 * batch],
        embedded[2 * batch:3 * batch],
        embedded[3 * batch:],
    )

    top_row = torch.cat([top_left, top_right], dim=3)
    bottom_row = torch.cat([bottom_left, bottom_right], dim=3)
    return torch.cat([top_row, bottom_row], dim=2)


In [None]:
class EncoderDeep(nn.Module):
    """Eight-layer 3x3 conv encoder used for the smallest (K/2) patches.

    Only ``conv1`` has stride 2. Every layer except the last is followed by
    LeakyReLU(0.1); the last (``conv8``, D channels) is followed by tanh.
    """

    # (in_channels, out_channels, stride) for conv1 .. conv7; conv8 maps to D.
    _LAYERS = ((3, 32, 2), (32, 64, 1), (64, 128, 1), (128, 128, 1),
               (128, 64, 1), (64, 32, 1), (32, 32, 1))

    def __init__(self, K, D=64, bias=True):
        super().__init__()

        # Registered as conv1..conv8 in order, so state_dict keys and
        # initialization order are unchanged.
        specs = self._LAYERS + ((32, D, 1),)
        for number, (c_in, c_out, stride) in enumerate(specs, start=1):
            setattr(self, f'conv{number}', nn.Conv2d(c_in, c_out, 3, stride, 0, bias=bias))

        self.K = K
        self.D = D

    def forward(self, x):
        out = x
        for layer in (self.conv1, self.conv2, self.conv3, self.conv4,
                      self.conv5, self.conv6, self.conv7):
            out = F.leaky_relu(layer(out), 0.1)
        return torch.tanh(self.conv8(out))


In [None]:
class EncoderHier(nn.Module):
    """Hierarchical encoder for K x K patches (K >= 64).

    The patch is split into four K/2 quadrants by ``forward_hier``. Each
    quadrant is embedded by ``self.enc``: an ``EncoderHier`` when K > 64, or an
    ``EncoderDeep`` when K == 64. The resulting 2x2 grid is fused by a 2x2
    conv, LeakyReLU(0.1), a 1x1 conv and tanh.

    Raises:
        ValueError: if K < 64.
    """

    def __init__(self, K, D=64, bias=True):
        super().__init__()

        if K == 64:
            sub_encoder = EncoderDeep(K // 2, D, bias=bias)
        elif K > 64:
            sub_encoder = EncoderHier(K // 2, D, bias=bias)
        else:
            raise ValueError()
        # Registered before the fusion convs so parameter order is unchanged.
        self.enc = sub_encoder

        self.conv1 = nn.Conv2d(D, 128, 2, 1, 0, bias=bias)
        self.conv2 = nn.Conv2d(128, D, 1, 1, 0, bias=bias)

        self.K = K
        self.D = D

    def forward(self, x):
        grid = forward_hier(x, self.enc, K=self.K)
        fused = F.leaky_relu(self.conv1(grid), 0.1)
        return torch.tanh(self.conv2(fused))


In [None]:
# Cross-entropy loss for the relative-position classification task.
xent = nn.CrossEntropyLoss()

class NormalizedLinear(nn.Module):
    """Linear layer whose weight is divided by its column norms on every forward.

    The normalized weight is computed under ``torch.no_grad()``, so no gradient
    flows back to ``weight`` through it (the bias still receives gradients).
    """

    __constants__ = ['bias', 'in_features', 'out_features']

    def __init__(self, in_features, out_features, bias=True):
        super(NormalizedLinear, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.weight = nn.Parameter(torch.Tensor(out_features, in_features))
        if not bias:
            self.register_parameter('bias', None)
        else:
            self.bias = nn.Parameter(torch.Tensor(out_features))
        self.reset_parameters()

    def reset_parameters(self):
        # Same scheme as nn.Linear's default initialization.
        nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        if self.bias is None:
            return
        fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight)
        limit = 1 / math.sqrt(fan_in)
        nn.init.uniform_(self.bias, -limit, limit)

    def forward(self, x):
        with torch.no_grad():
            column_norms = self.weight.data.norm(keepdim=True, dim=0)
            normalized = self.weight / column_norms
        return F.linear(x, normalized, self.bias)

    def extra_repr(self):
        has_bias = self.bias is not None
        return f'in_features={self.in_features}, out_features={self.out_features}, bias={has_bias}'


In [None]:
class PositionClassifier(nn.Module):
    """Predict the relative position class of two patches from their embeddings.

    Both embeddings are flattened with ``view(-1, D)``. Their difference goes
    through two Linear + LeakyReLU(0.1) layers and a frozen ``NormalizedLinear``
    head that produces ``class_num`` logits.
    """

    def __init__(self, K, D, class_num=8):
        super().__init__()
        self.D = D

        self.fc1 = nn.Linear(D, 128)
        self.act1 = nn.LeakyReLU(0.1)

        self.fc2 = nn.Linear(128, 128)
        self.act2 = nn.LeakyReLU(0.1)

        # The output head is never trained.
        self.fc3 = NormalizedLinear(128, class_num)
        self.fc3.requires_grad_(False)

        self.K = K

    def forward(self, h1, h2):
        diff = h1.view(-1, self.D) - h2.view(-1, self.D)
        hidden = self.act1(self.fc1(diff))
        hidden = self.act2(self.fc2(hidden))
        return self.fc3(hidden)


In [None]:
"""
Model definition (ensembled)
"""
class MyEnsembledModel(nn.Module):
    def __init__(self, enc, cls_64, cls_32):
        super().__init__()
        self._enc = enc
        self._cls_64 = cls_64
        self._cls_32 = cls_32
        
    def forward(self):
        pass


# Hierarchical encoder for 64 px patches; its inner `.enc` handles 32 px patches.
enc = EncoderHier(64, args['D'])
cls_64 = PositionClassifier(64, args['D'])
cls_32 = PositionClassifier(32, args['D'])

model = MyEnsembledModel(enc, cls_64, cls_32)

# Optimize only trainable parameters (the classifiers' fc3 heads are frozen).
# A comprehension keeps the loop variable out of module scope. The previous
# `for p in ...` loop rebound the module-level `p`, clobbering `import _pickle as p`.
params_to_update = [param for param in model.parameters() if param.requires_grad]
optimizer_adam = torch.optim.Adam(params=params_to_update, lr=float(args['lr']))


#### Register model

In [None]:
from copy import deepcopy

# Register the model and optimizer with the Interactive API through the PyTorch framework adapter plugin.
framework_adapter = 'openfl.plugins.frameworks_adapters.pytorch_adapter.FrameworkAdapterPlugin'
MI = ModelInterface(model=model, optimizer=optimizer_adam, framework_plugin=framework_adapter)

# Save the initial model state (a deep copy, unaffected by later training)
initial_model = deepcopy(model)

### Define and register FL tasks

In [None]:
# Task registry: the FL tasks below are registered on it through its decorators.
TI = TaskInterface()
import torch
import tqdm
from utils import cnn_output_size
from inspection import eval_embeddings_nn_multik

# The Interactive API supports registering functions defined in the main module or imported ones.
def function_defined_in_notebook(some_parameter):
    """Print ``some_parameter``; shows that registered tasks can call notebook helpers."""
    message = f'Also I accept a parameter and it is {some_parameter}'
    print(message)

# Task interface currently supports only standalone functions.
@TI.add_kwargs(some_parameter=42)
@TI.register_fl_task(model='model', data_loader='train_loader',
                     device='device', optimizer='optimizer')
def train(model, train_loader, optimizer, device, some_parameter=None):
    """Train the PatchSVDD model for one pass over ``train_loader``.

    Each batch is a dict with 'pos_64', 'pos_32', 'svdd_64' and 'svdd_32'
    entries. The loss is the sum of both position-classification losses plus
    ``args['lambda_value']`` times the sum of both SVDD losses. The 64 px terms
    use ``model._enc``; the 32 px terms use its inner ``model._enc.enc``.

    Returns:
        dict with 'train_loss', the mean of the per-batch losses.
    """
    print(f'\n\n TASK TRAIN GOT DEVICE {device}\n\n')

    function_defined_in_notebook(some_parameter)

    batches = tqdm.tqdm(train_loader, desc="train")

    model.train()
    model.to(device)
    batch_losses = []

    for batch in batches:
        batch = to_device(batch, device, non_blocking=True)
        optimizer.zero_grad()

        outer, inner = model._enc, model._enc.enc
        pos_64 = PositionClassifier_infer(model._cls_64, outer, batch['pos_64'])
        pos_32 = PositionClassifier_infer(model._cls_32, inner, batch['pos_32'])
        svdd_64 = SVDD_Dataset_infer(outer, batch['svdd_64'])
        svdd_32 = SVDD_Dataset_infer(inner, batch['svdd_32'])

        loss = pos_64 + pos_32 + float(args['lambda_value']) * (svdd_64 + svdd_32)
        loss.backward()
        optimizer.step()
        batch_losses.append(loss.detach().cpu().numpy())

    return {'train_loss': np.mean(batch_losses)}
    

@TI.register_fl_task(model='model', data_loader='val_loader', device='device')
def validate(model, val_loader, device):
    """Evaluate anomaly detection and segmentation with the current encoder.

    ``val_loader`` is the tuple ``(x_tr, x_te, masks, labels, mean)`` returned
    by ``MVTecSD.get_valid_loader``. Train and test patches are embedded at two
    scales:
    - 64 px, stride 16, with ``model._enc``;
    - 32 px, stride 4, with ``model._enc.enc``.
    The embeddings are then scored by ``eval_embeddings_nn_multik``.

    Returns:
        dict of detection / segmentation / balanced-accuracy scores for the
        'sum' and 'mult' combinations of both scales.
    """
    print(f'\n\n TASK VALIDATE GOT DEVICE {device}\n\n')

    model._enc.eval()
    model._enc.to(device)

    x_tr, x_te, masks, labels, _ = val_loader

    embs64 = (infer_(x_tr, model._enc, K=64, S=16, device=device),
              infer_(x_te, model._enc, K=64, S=16, device=device))
    embs32 = (infer_(x_tr, model._enc.enc, K=32, S=4, device=device),
              infer_(x_te, model._enc.enc, K=32, S=4, device=device))

    results = eval_embeddings_nn_multik(args['obj'], embs64, embs32, masks, labels)

    # Only scalar metrics are reported here. The anomaly maps are not read, so a
    # result without 'maps_mult' no longer aborts validation with a KeyError.
    for tag, key in (('K64', '64'), ('K32', '32'), ('sum', 'sum'), ('mult', 'mult')):
        print("| {} | Det: {:.3f} Seg:{:.3f} BA: {:.3f}".format(
            tag, results[f'det_{key}'], results[f'seg_{key}'], results[f'bal_acc_{key}']))

    return {
        'detection_score_sum': results['det_sum'],
        'segmentation_score_sum': results['seg_sum'],
        'balanced_accuracy_score_sum': results['bal_acc_sum'],
        'detection_score_mult': results['det_mult'],
        'segmentation_score_mult': results['seg_mult'],
        'balanced_accuracy_score_mult': results['bal_acc_mult'],
    }

    

In [None]:
# Infer functions
def PositionClassifier_infer(c, enc, batch):
    """Return the cross-entropy loss of classifier ``c`` on a position batch.

    ``batch`` is ``(x1s, x2s, ys)``. Both patch batches are embedded with
    ``enc`` (first ``x1s``, then ``x2s``), and ``c`` predicts ``ys`` from the pair.
    """
    first_patches, second_patches, targets = batch
    logits = c(enc(first_patches), enc(second_patches))
    return xent(logits, targets)

def SVDD_Dataset_infer(enc, batch):
    """Return the mean L2 distance (over dim 1) between embeddings of paired patches.

    ``batch`` is ``(x1s, x2s)``. Minimizing this pulls the embeddings of
    neighbouring patches together.
    """
    first_patches, second_patches = batch
    distances = (enc(first_patches) - enc(second_patches)).norm(dim=1)
    return distances.mean()

def infer_(x, enc, K, S, device):
    """Embed every K x K patch (stride S) of the NCHW images ``x`` with ``enc``.

    Returns:
        float32 array of shape (N, I, J, args['D']), where I x J is the patch
        grid of ``PatchDataset_NCHW``.
    """
    dataset = PatchDataset_NCHW(x, K=K, S=S)
    loader = DataLoader(dataset, batch_size=64, shuffle=False, pin_memory=True)
    embs = np.empty((dataset.N, dataset.row_num, dataset.col_num, args['D']), dtype=np.float32)  # [-1, I, J, D]
    enc = enc.eval()
    with torch.no_grad():
        for xs, ns, iis, js in loader:
            embedding = enc(xs.to(device)).cpu().numpy()
            # Scatter the whole batch with one fancy-indexed assignment instead of
            # a Python loop per patch. The encoder output is expected to be
            # (B, D, 1, 1), so flattening gives one D-vector per patch; other
            # shapes fail here just as the per-patch assignment did.
            embs[np.asarray(ns), np.asarray(iis), np.asarray(js)] = embedding.reshape(len(embedding), -1)
    return embs

## Time to start a federated learning experiment

In [None]:
# Create an experiment in the federation
experiment_name = 'MVTec_test_experiment'
fl_experiment = FLExperiment(federation=federation, experiment_name=experiment_name)

In [None]:
# NOTE: enabling IPython autoreload has caused a pickling error when starting the experiment.

# The following command zips the workspace and Python requirements to be transferred to collaborator nodes.
# Trains for 10 rounds with the registered model (MI), tasks (TI) and data interface (fed_dataset).
fl_experiment.start(model_provider=MI, 
                    task_keeper=TI,
                    data_loader=fed_dataset,
                    rounds_to_train=10,
                    opt_treatment='CONTINUE_GLOBAL',
                    device_assignment_policy='CUDA_PREFERRED')


In [None]:
# If you stop the IPython session, reconnect and restore the experiment state to check how it is going:
# fl_experiment.restore_experiment_state(MI)

# Stream the metrics reported by the running experiment.
fl_experiment.stream_metrics()

## Now we validate the last model and plot anomaly maps!

In [None]:
!pip install -r ../envoy/sd_requirements.txt
import sys
sys.path.insert(1, '../envoy')
from mvtec_shard_descriptor import MVTecShardDescriptor
from inspection import measure_emb_nn, eval_embeddings_nn_maps
import matplotlib.pyplot as plt
from PIL import Image
from skimage.segmentation import mark_boundaries
from utils import makedirpath, distribute_scores
import pickle

In [None]:
def obtain_maps(model, val_loader, device):
    """Compute anomaly maps for the test images and plot them.

    ``val_loader`` is the tuple ``(x_tr, x_te, masks, labels, mean)`` returned
    by ``MVTecSD.get_valid_loader``. Patches are embedded at two scales:
    - 64 px, stride 16, with ``model._enc``;
    - 32 px, stride 4, with ``model._enc.enc``.
    The maps come from ``eval_embeddings_nn_maps`` and are drawn by
    ``print_anomaly_maps``.
    """
    print(f'\n\n OBTAIN MAPS GOT DEVICE {device}\n\n')

    encoder = model._enc
    encoder.eval()
    encoder.to(device)

    x_tr, x_te, masks, labels, mean = val_loader

    embs64 = (infer_(x_tr, encoder, K=64, S=16, device=device),
              infer_(x_te, encoder, K=64, S=16, device=device))
    embs32 = (infer_(x_tr, encoder.enc, K=32, S=4, device=device),
              infer_(x_te, encoder.enc, K=32, S=4, device=device))

    anomaly_maps = eval_embeddings_nn_maps(args['obj'], embs64, embs32, masks, labels)
    print_anomaly_maps(args['obj'], anomaly_maps, x_te, masks, mean)
    

def infer_(x, enc, K, S, device):
    """Embed every K x K patch (stride S) of the NCHW images ``x`` with ``enc``.

    Returns:
        float32 array of shape (N, I, J, args['D']), where I x J is the patch
        grid of ``PatchDataset_NCHW``.
    """
    dataset = PatchDataset_NCHW(x, K=K, S=S)
    loader = DataLoader(dataset, batch_size=64, shuffle=False, pin_memory=True)
    embs = np.empty((dataset.N, dataset.row_num, dataset.col_num, args['D']), dtype=np.float32)  # [-1, I, J, D]
    enc = enc.eval()
    with torch.no_grad():
        for xs, ns, iis, js in loader:
            embedding = enc(xs.to(device)).cpu().numpy()
            # Scatter the whole batch with one fancy-indexed assignment instead of
            # a Python loop per patch. The encoder output is expected to be
            # (B, D, 1, 1), so flattening gives one D-vector per patch; other
            # shapes fail here just as the per-patch assignment did.
            embs[np.asarray(ns), np.asarray(iis), np.asarray(js)] = embedding.reshape(len(embedding), -1)
    return embs

def print_anomaly_maps(obj, maps, images, masks, mean, num_maps=10):
    """Plot test images (ground-truth boundaries in red) next to their anomaly maps.

    Args:
        obj: MVTec AD object category; not used by the plotting itself.
        maps: anomaly maps, one per test image, drawn with ``imshow``.
        images: standardized NCHW test images, as returned by
            ``MVTecSD.get_valid_loader``.
        masks: ground-truth masks aligned with ``images``.
        mean: per-pixel HWC mean that was subtracted during standardization.
        num_maps: maximum number of images to plot (default 10). It is clamped
            to the number of images and maps available.
    """
    # Undo standardization. The images are NCHW, so [0, 2, 3, 1] gives NHWC,
    # matching `mean` and the masks ([0, 3, 2, 1] swapped H and W).
    images = np.transpose(images, [0, 2, 3, 1]).astype(np.float32) * 255 + mean
    # Pixels are now back on the 0..255 scale. Clip and cast directly; another
    # "* 255" before the uint8 cast would overflow and wrap the values.
    images = np.clip(images, 0, 255).astype(np.uint8)

    shape = (128, 128)
    for n in range(min(num_maps, len(images), len(maps))):
        fig, axes = plt.subplots(ncols=2)
        fig.set_size_inches(6, 3)

        # PIL's resize takes (width, height), hence the reversed shape.
        image = np.array(Image.fromarray(images[n]).resize(shape[::-1]))
        mask = np.array(Image.fromarray(masks[n]).resize(shape[::-1]))
        image = mark_boundaries(image, mask, color=(1, 0, 0), mode='thick')

        axes[0].imshow(image)
        axes[0].set_axis_off()

        axes[1].imshow(maps[n], vmax=maps[n].max(), cmap='Reds')
        axes[1].set_axis_off()

        plt.tight_layout()
        plt.show()
        plt.close()

In [None]:
# Build a local data interface over the full MVTec AD data (single shard: rank 1 of 1).
fed_dataset = MVTecSD(train_bs=64, val_bs=64)
fed_dataset.shard_descriptor = MVTecShardDescriptor(obj=args['obj'], data_folder='MVTec_data', rank_worldsize='1,1')

last_model = fl_experiment.get_last_model()
# Use the GPU when available, and fall back to CPU instead of failing on a hard-coded 'cuda'.
eval_device = 'cuda' if torch.cuda.is_available() else 'cpu'
obtain_maps(last_model, fed_dataset.get_valid_loader(), eval_device)