In [99]:
# load custom packages from src dir
import sys
sys.path.insert(0, '..')

# python packages
import logging
import os
import matplotlib.pyplot as plt
import torch
import torchvision.transforms as transforms
from tqdm import tqdm

# custom packages
import src.commons.dataset as ds
import src.commons.constants as cons

# Define the logging level
logging.getLogger().setLevel(logging.INFO)

# Vision Transformer approach

The Vision Transformer (ViT) is an architecture that applies transformers to image-based tasks.

The main idea is to split the input image into a sequence of fixed-size patches. Each flattened patch is embedded with a linear layer, and a positional encoding is added. The sequence is then processed by transformer layers with multi-head self-attention, and the output features are used to solve the given task, e.g. classification.

We frame anomaly detection as a weakly supervised problem, in which only image-level binary labels $y$ are given: $y = 0$ means the image is normal and $y = 1$ means the image is anomalous. The supervision is called *weak* because no information about *where* the anomaly is (i.e. no ground-truth mask) is supplied.

We then train a ViT on the task of predicting $\hat{y}$. By extracting the attention maps inside the transformer, we expect to obtain anomaly masks: in order to perform the classification task, the attention should learn to focus on the anomalous regions of the image.

In [2]:
# Load dataset
cat = "capsule"
data = ds.MVTECTestDataset(os.path.join(ds.current_dir(), '../', cons.DATA_PATH), cat)

# Split data into train/test/val (80% / 10% / 10%).
# Flooring each share can leave a few samples unassigned; they are added to the
# training split so that the three lengths always sum to len(data).
SPLIT_FRACTIONS = (0.8, 0.1, 0.1)
SPLIT_SEED = 0  # fixed seed so that Restart & Run All reproduces the same split
n_samples = len(data)
lengths = [int(fraction * n_samples) for fraction in SPLIT_FRACTIONS]
lengths[0] += n_samples - sum(lengths)

train_data, test_data, val_data = torch.utils.data.random_split(
    data, lengths, generator=torch.Generator().manual_seed(SPLIT_SEED))

In [94]:
PATCH_SIZE = 50

def collate_fn(x, patch_size=None):
    '''
    x: list - batch_size dataset items, each a dict with a "test" image tensor
       of shape (C, W, H) and a "ground_truth" mask tensor
    patch_size: int - side P of the square patches
       (default: the global PATCH_SIZE, read at call time)

    Converts list of input tensors (C, W, H) to batch.
    Normalizes each image from 0-255 to 0-1 (assumes 0-255 pixel values).
    Each image is cut into a grid of non-overlapping P x P patches, taken row by
    row (dim 1 outer, dim 2 inner); trailing pixels that do not fill a whole patch
    are dropped. Each patch is flattened into a vector of dimension C*P**2.
    The target is 1 if the ground-truth mask has any positive pixel, else 0.

    output: (Tensor - (batch_size, (W//P)*(H//P), C*P**2),
             Tensor - (batch_size,) int64 targets)
    '''
    if patch_size is None:
        patch_size = PATCH_SIZE

    # Weak binary label: 1 if the ground-truth mask marks any pixel, else 0
    targets = torch.tensor([1 if torch.any(item["ground_truth"] > 0) else 0
                            for item in x])

    # Stack to (batch, C, W, H) and convert images to 0-1
    imgs = torch.stack([item["test"] for item in x]) / 255.0
    batch_size, C = imgs.shape[:2]

    # Split dims 2 and 3 into non-overlapping windows:
    # (batch, C, W//P, H//P, P, P); remainder pixels are dropped
    patches = imgs.unfold(2, patch_size, patch_size).unfold(3, patch_size, patch_size)
    n_rows, n_cols = patches.shape[2], patches.shape[3]

    # Move channels next to the pixel dims so each patch flattens as
    # (C, P, P), exactly like img[:, rows, cols].flatten()
    patches = patches.permute(0, 2, 3, 1, 4, 5)
    data = patches.reshape(batch_size, n_rows * n_cols, C * patch_size ** 2)
    return data, targets

In [100]:
# Define DataLoaders for batching input
BATCH_SIZE = 4

train_dataloader = torch.utils.data.DataLoader(train_data, batch_size=BATCH_SIZE, collate_fn=collate_fn, shuffle=True)
test_dataloader = torch.utils.data.DataLoader(test_data, batch_size=BATCH_SIZE, collate_fn=collate_fn, shuffle=False)
val_dataloader = torch.utils.data.DataLoader(val_data, batch_size=BATCH_SIZE, collate_fn=collate_fn, shuffle=False)

# Check if loaders are working properly.
# NOTE: this helper must not be named `test_dataloader`, otherwise it would
# overwrite the test-split DataLoader defined just above.
def check_dataloader(data_loader):
    '''
    Iterate once over `data_loader` so that any loading/collating error
    surfaces immediately.

    data_loader: iterable of (inputs, targets) batches
    output: int - number of batches produced
    '''
    n_batches = 0
    for inputs, targets in tqdm(data_loader):
        # collate_fn must return exactly one target per input sequence
        assert inputs.shape[0] == targets.shape[0], 'inputs/targets batch size mismatch'
        n_batches += 1
    return n_batches

check_dataloader(train_dataloader);

100%|██████████| 27/27 [00:04<00:00,  5.87it/s]


In [32]:
class VisualTransformer(torch.nn.Module):
    '''
    Placeholder for the Vision Transformer classifier.

    TODO: implement patch embedding, positional encoding, transformer encoder
    layers and the classification head.
    '''
    def __init__(self):
        # nn.Module's constructor must run before any submodule/parameter is
        # registered and before .train()/.eval()/.to() can be used
        super().__init__()

    def forward(self, x):
        '''
        x: Tensor - presumably (batch_size, n_patches, C*P**2) as produced by
           `collate_fn` (TODO confirm once implemented)
        '''
        # Fail loudly instead of silently returning None into the loss
        raise NotImplementedError('VisualTransformer.forward is not implemented yet')


In [None]:
def evaluate(model, data_loader, **kwargs):
    '''
    Compute the average loss of `model` over all batches of `data_loader`.

    The model is switched to evaluation mode (and left in it on return) and
    gradients are disabled.

    model: torch.nn.Module - called as model(inputs)
    data_loader: iterable of (inputs, targets) batches, e.g. a DataLoader
        built with `collate_fn`
    kwargs:
        loss_fn: callable(preds, targets) -> scalar tensor
            (default torch.nn.functional.mse_loss; NOTE(review): a
            classification loss is presumably needed for binary targets)
        device: torch.device the batches are moved to (default CPU);
            the model is assumed to already be on this device

    output: float - mean of the per-batch losses, NaN if the loader is empty
    '''
    loss_fn = kwargs.get('loss_fn', torch.nn.functional.mse_loss)
    device = kwargs.get('device', torch.device('cpu'))

    model.eval()  # set model to evaluation mode
    total_loss, n_batches = 0., 0
    pbar = tqdm(data_loader)
    for inputs, targets in pbar:
        inputs, targets = inputs.to(device), targets.to(device)
        with torch.no_grad():  # no need to compute gradients
            preds = model(inputs)
            loss = loss_fn(preds, targets)
        total_loss += loss.item()
        n_batches += 1
        pbar.set_description(f'loss = {loss.item():.3f}')
    # Average over the batches actually seen; avoids ZeroDivisionError on an empty loader
    return total_loss / n_batches if n_batches else float('nan')

def fit(model, train_loader, val_loader, optimizer, **kwargs):
    '''
    Train `model` for several epochs, evaluating it on `val_loader` after each one.

    model: torch.nn.Module - called as model(inputs); assumed to already be on `device`
    train_loader, val_loader: iterables of (inputs, targets) batches,
        e.g. DataLoaders built with `collate_fn`
    optimizer: torch.optim.Optimizer over the model's parameters
    kwargs:
        num_epochs: int (default 100)
        loss_fn: callable(preds, targets) -> scalar tensor
            (default torch.nn.functional.mse_loss)
        device: torch.device the batches are moved to (default CPU)

    output: (train_loss_hist, val_loss_hist) - lists of per-epoch mean losses
        (NaN for an epoch whose training loader yielded no batches)
    '''
    num_epochs = kwargs.get('num_epochs', 100)
    loss_fn = kwargs.get('loss_fn', torch.nn.functional.mse_loss)
    device = kwargs.get('device', torch.device('cpu'))

    train_loss_hist, val_loss_hist = [], []
    for epoch in range(num_epochs):
        print(f'Epoch {epoch + 1}/{num_epochs}')

        print('Training phase...')
        model.train()  # set model to training mode
        total_loss, n_batches = 0., 0
        pbar = tqdm(train_loader)
        for inputs, targets in pbar:
            inputs, targets = inputs.to(device), targets.to(device)
            model.zero_grad()  # initialize gradients to zero
            preds = model(inputs)  # forward pass
            loss = loss_fn(preds, targets)  # loss computation
            loss.backward()  # computing gradients (backward pass)
            optimizer.step()  # updating the parameters of the model
            total_loss += loss.item()
            n_batches += 1
            pbar.set_description(f'loss = {loss.item():.3f}')
        # NaN instead of ZeroDivisionError when the training loader is empty
        train_loss = total_loss / n_batches if n_batches else float('nan')
        print(f'train loss = {train_loss:.3f}')
        train_loss_hist.append(train_loss)

        print('Validation phase...')
        val_loss = evaluate(model, val_loader, loss_fn=loss_fn, device=device)
        print(f'validation loss = {val_loss:.3f}')
        val_loss_hist.append(val_loss)

    return train_loss_hist, val_loss_hist
