In [2]:
import sys

# Make the directories that provide the `stad` and `fdfe` packages importable.
# The membership check keeps this cell idempotent: re-running it does not pile
# up duplicate entries in sys.path.
for repo_dir in ('/app/github/STAD', '/app/github/FDFE'):
    if repo_dir not in sys.path:
        sys.path.append(repo_dir)

In [1]:
import cv2
import torch
import numpy as np
import albumentations as albu

In [3]:
from stad.datasets import MVTecDataset
from fdfe.modules import MultiMaxPool2d

In [4]:
from torch.utils.data import DataLoader
from torchvision import models
from albumentations.pytorch import ToTensorV2
from pathlib import Path
from tqdm import tqdm

In [5]:
# Use the first GPU when CUDA is available; otherwise fall back to the CPU so
# the notebook still runs (slowly) on machines without a GPU.
DEVICE = 'cuda:0' if torch.cuda.is_available() else 'cpu'
# Number of passes over the training loader in the training loop.
EPOCHS = 100

In [6]:
# Teacher: first 36 layers of the pretrained VGG19 feature extractor.
# It only ever produces regression targets, so it is put in eval mode and its
# parameters are frozen up front.
pretrained_vgg = models.vgg19(pretrained=True)
teacher = pretrained_vgg.features[:36]
teacher = teacher.to(DEVICE)
teacher.eval()
for param in teacher.parameters():
    param.requires_grad_(False)

# Student: same architecture, randomly initialised; this is the network that
# gets trained to reproduce the teacher's outputs.
vgg = models.vgg19(pretrained=False)
student = vgg.features[:36]
student = student.to(DEVICE)

# Independent copy of the teacher (kept unmodified for the dense per-patch
# baseline). The weights come from `teacher.state_dict()`, so there is no need
# to load the pretrained checkpoint a second time just to overwrite it.
teacher_copy = models.vgg19(pretrained=False).features[:36]
teacher_copy.load_state_dict(teacher.state_dict())
teacher_copy = teacher_copy.to(DEVICE)
teacher_copy.eval()
for param in teacher_copy.parameters():
    param.requires_grad_(False)

In [7]:
# Training-time pipeline, applied in this order:
#   1. horizontal flip with probability 0.5,
#   2. random 128x128 crop (always applied),
#   3. normalisation with the library's default statistics (always applied),
#   4. conversion to a torch tensor.
train_transforms = [
    albu.HorizontalFlip(p=0.5),
    albu.RandomCrop(height=128, width=128, always_apply=True, p=1),
    albu.Normalize(always_apply=True, p=1),
    ToTensorV2(),
]
train_augs = albu.Compose(train_transforms)

In [8]:
# Training data: the images under bottle/train/good, run through `train_augs`.
train_dir = Path('/app/github/data/bottle/train/good')
mvtec = MVTecDataset(img_dir=train_dir, augs=train_augs)

# One image per step, reshuffled every epoch.
train_loader = DataLoader(mvtec, batch_size=1, shuffle=True)

In [9]:
# Mean-squared error between the student's and the teacher's outputs.
criterion = torch.nn.MSELoss(reduction='mean')

# Only the student's parameters are optimised.
adam_kwargs = dict(lr=2e-4, weight_decay=1e-5)
optimizer = torch.optim.Adam(student.parameters(), **adam_kwargs)

In [22]:
# Train the student to reproduce the teacher's output on every training crop.
teacher.eval()   # the teacher only supplies targets
student.train()  # explicit: a later cell calls student.eval(), so re-running
                 # this cell must put the student back in training mode

epoch_bar = tqdm(range(EPOCHS))
for epoch in epoch_bar:
    running_loss = 0.0
    n_batches = 0

    for img in train_loader:
        img = img.to(DEVICE)

        # Targets come from the teacher; no gradients are needed for them.
        with torch.no_grad():
            surrogate_label = teacher(img)

        optimizer.zero_grad()
        pred = student(img)
        loss = criterion(pred, surrogate_label)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        n_batches += 1

    # Show the mean loss of the finished epoch on the progress bar so that
    # convergence can be monitored (skipped if the loader yielded nothing).
    if n_batches:
        epoch_bar.set_postfix(loss=running_loss / n_batches)

10%|█         | 10/100 [01:26<13:00,  8.67s/it]


KeyboardInterrupt: 

In [10]:
# Test-time pipeline: no flipping or cropping, only normalisation with the
# library's default statistics followed by conversion to a torch tensor.
test_transforms = [
    albu.Normalize(always_apply=True, p=1),
    ToTensorV2(),
]
test_augs = albu.Compose(test_transforms)

In [11]:
# Test data: the images under bottle/test/broken_small, run through `test_augs`.
# NOTE: this rebinds `mvtec`, which previously referred to the training set.
mvtec = MVTecDataset(img_dir=Path('/app/github/data/bottle/test/broken_small'),
                     augs=test_augs)

# Evaluation data is iterated in a fixed order: shuffling brings no benefit at
# test time and would make per-image results non-reproducible between runs.
test_loader = DataLoader(dataset=mvtec,
                         batch_size=1,
                         shuffle=False)

In [12]:
# Load one test image and turn it into a (1, C, 256, 256) batch on DEVICE.
sample_path = '/app/github/data/bottle/test/broken_small/000.png'
img = cv2.imread(sample_path)
# cv2.imread does not raise on a missing/unreadable file, it returns None;
# fail here with a clear message instead of a cryptic cvtColor error below.
if img is None:
    raise FileNotFoundError(f'cv2.imread could not read image: {sample_path}')

img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)  # OpenCV loads BGR; convert to RGB
img = img[0:256, 0:256, :]                  # keep only the top-left 256x256 window

sample = test_augs(image=img)
img = sample['image']
img = img.unsqueeze(0)  # add the batch dimension
img = img.to(DEVICE)

In [13]:
def _swap_in_multi_max_pool(net):
    """Replace every direct `MaxPool2d` child of `net` with `MultiMaxPool2d`.

    `net` is modified in place. Each replacement is built from the layer it
    replaces, i.e. `MultiMaxPool2d(original_layer)`.

    The children are enumerated directly, so the position used for the
    assignment is the child's real index (no offset arithmetic on
    `modules()`, which also yields the container itself). The list snapshot
    avoids iterating a container while it is being modified. The class-name
    comparison matches the original check, so already-replaced layers are
    left alone when the cell is re-run.
    """
    for idx, layer in enumerate(list(net.children())):
        if type(layer).__name__ == 'MaxPool2d':
            net[idx] = MultiMaxPool2d(layer)


_swap_in_multi_max_pool(student)
_swap_in_multi_max_pool(teacher)

In [14]:
teacher.eval()
student.eval()

# Inference only: disable autograd so no graph (and no extra activation
# memory) is kept for this forward pass.
with torch.no_grad():
    surrogate_label = teacher(img)
# pred = student(img)

In [15]:
# Rearrange the teacher output into a (512, 128, 128, k) tensor.
# Presumably this undoes the interleaving introduced by the MultiMaxPool2d
# layers -- confirm against the fdfe implementation.

# Flatten everything but the leading dimension, then move that dimension last.
x = surrogate_label.view(surrogate_label.shape[0], -1).transpose(0, 1).contiguous()

# Three identical steps at growing resolution: expose two size-2 axes, swap
# axes 2 and 3, and let the next view merge them into a grid twice as large.
for side in (16, 32, 64):
    x = x.view(512, side, side, 2, 2, -1).transpose(2, 3).contiguous()

x = x.view(512, 128, 128, -1).contiguous()
x.shape

torch.Size([512, 128, 128, 4])

In [22]:
# Fill fdfe_map[h, w] with the mean of a p x p window of `x` (over all of
# axis 0). The window starts at (h // 2, w // 2) and the last-axis slot is
# picked from the parities of h and w.
# NOTE(review): for h // 2 + p > 128 (likewise for w) the slice is silently
# clipped by the 128-wide axes of `x`, so border windows are smaller -- confirm
# this is intended.
p = 16
fdfe_map = np.zeros((256 - p, 256 - p))

for h in tqdm(range(256 - p)):
    top = h // 2
    for w in range(256 - p):
        left = w // 2
        # (even, even) -> 0, (even, odd) -> 1, (odd, even) -> 2, (odd, odd) -> 3
        slot = 2 * (h % 2) + (w % 2)
        fdfe_map[h, w] = float(x[:, top:top + p, left:left + p, slot].mean())


100%|██████████| 240/240 [00:03<00:00, 68.58it/s]


In [19]:
# Dense baseline: run the unmodified teacher copy on every p x p patch of the
# image and store the mean of its output.
p = 32
# NOTE: `map` shadows the Python builtin of the same name; the name is kept
# because a later cell displays it.
map = np.zeros((256-p, 256-p))

# Pure inference over ~50k patches: disable autograd so no graph is built for
# each forward pass.
with torch.no_grad():
    for h in tqdm(range(256-p)):
        for w in range(256-p):
            patch = img[:, :, h:h+p, w:w+p]
            # .item() gives a plain Python float instead of relying on the
            # implicit conversion of a (possibly CUDA) tensor by numpy.
            map[h, w] = teacher_copy(patch).mean().item()

100%|██████████| 224/224 [02:54<00:00,  1.28it/s]


In [20]:
# Display the dense baseline: mean teacher_copy output for each 32x32 image patch.
map

array([[0.07538726, 0.07538726, 0.07538726, ..., 0.07538726, 0.07538726,
        0.07538726],
       [0.07538726, 0.07538726, 0.07538726, ..., 0.07538726, 0.07538726,
        0.07538726],
       [0.07538726, 0.07538726, 0.07538726, ..., 0.07538726, 0.07538726,
        0.07538726],
       ...,
       [0.07538726, 0.07538726, 0.07538726, ..., 0.20150873, 0.21962997,
        0.20382014],
       [0.07538726, 0.07538726, 0.07538726, ..., 0.20688003, 0.20472288,
        0.21367201],
       [0.07538726, 0.07538726, 0.07538726, ..., 0.20820379, 0.2067745 ,
        0.19261743]])

In [23]:
# Display the map computed from windows of the rearranged teacher output `x`.
fdfe_map

array([[0.0192838 , 0.01921315, 0.01885455, ..., 0.03513261, 0.02748283,
        0.03610511],
       [0.01917665, 0.01914162, 0.01874706, ..., 0.03721201, 0.02720499,
        0.03828237],
       [0.01787147, 0.01782434, 0.01749247, ..., 0.03336752, 0.02598732,
        0.03433357],
       ...,
       [0.03206759, 0.03206943, 0.03114278, ..., 0.14543083, 0.1850986 ,
        0.14502591],
       [0.0249638 , 0.0248473 , 0.02397401, ..., 0.17411387, 0.22492272,
        0.17345975],
       [0.03282092, 0.03283486, 0.03191298, ..., 0.14748918, 0.18687105,
        0.14712675]])