In [22]:
# load custom packages from src dir
import sys
sys.path.insert(0, '..')

# python packages
import logging
import os
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from tqdm import tqdm

# custom packages
import src.commons.dataset as ds
import src.commons.constants as cons

# Define the logging level
logging.getLogger().setLevel(logging.DEBUG)

# Vision Transformer approach

The ViT is an architecture that leverages transformers to image-based tasks. 

The main idea is to convert input image into a sequence of patches. We embed this sequence with a linear layer, plus positional encoding. Then, we apply a multi-head transformer layer. The output features are then used to solve the given task, e.g. classification.

The idea is to frame our anomaly detection as a weakly supervised problem, where weak binary labels $y$ are given. $y = 0$ means the image is normal, $y=1$ means the image is anomalous. This is called *weak* supervision because no information about *where* the anomaly is (i.e. ground truth mask) is supplied.

Then, we train a ViT in the task of predicting $\hat{y}$. By extracting information from the attention layers inside the transformer, we should be able to obtain anomaly masks. This should be the case because the attention should learn to prioritize anomalous regions in order to perform the classification task.

## DataLoader and data pre-processing

In [2]:
# Load dataset
cat = "capsule"
data = ds.MVTECTestDataset(os.path.join(ds.current_dir(),'../', cons.DATA_PATH), cat)

# Split data into train/test/val
lengths = torch.floor(torch.tensor([0.8, 0.1, 0.1])*len(data))
diff = torch.abs(lengths.sum() - len(data))
if diff > 0:
    lengths[0] += diff

train_data, test_data, val_data = torch.utils.data.random_split(data, lengths.to(int).tolist())

In [3]:
PATCH_SIZE = 50

def collate_fn(x):
    ''' 
    x: list - batch_size
    
    Converts list of input tensors (C, W, H) to batch.
    Normalizes each image from 0-255 to 0-1.
    For each image, convert it into sequence of W/P vectors of dimension C*P**2,
    where P is a fixed patch size.

    output: Tensor - (batch_size, W/P, C*P**2)
    '''   
    # Extract input images and labels
    imgs = [data["test"] for data in x]

    targets = [1 if torch.any(data["ground_truth"] > 0) else 0
               for data in x]

    C, W, H = imgs[0].shape

    # Convert images to 0-1
    imgs = [img / 255.0 for img in imgs]

    # Extract patches 
    patch_size = PATCH_SIZE
    temp = []
    for id, img in enumerate(imgs):
        # Cut up image into flatenned sequence of patches
        patches = [img[:, i*patch_size:(i+1)*patch_size, j*patch_size:(j+1)*patch_size].flatten()
                   for i in range(W//patch_size) for j in range(H//patch_size)]
        patches = torch.stack(patches)
        temp.append(patches)
    
    # Convert to tensor
    data, targets = torch.stack(temp), torch.tensor(targets)
    return data, targets

In [4]:
# Define DataLoaders for batching input
BATCH_SIZE = 4

train_dataloader = torch.utils.data.DataLoader(train_data, batch_size=BATCH_SIZE, collate_fn=collate_fn, shuffle=True)
test_dataloader = torch.utils.data.DataLoader(test_data, batch_size=BATCH_SIZE, collate_fn=collate_fn, shuffle=False)
val_dataloader = torch.utils.data.DataLoader(val_data, batch_size=BATCH_SIZE, collate_fn=collate_fn, shuffle=False)

# Check if loaders are working properly
def test_dataloader(data_loader):
    pbar = tqdm(enumerate(data_loader), total=len(data_loader))
    for i, batch in pbar:   
        data, targets = batch
    return 

test_dataloader(train_dataloader)

100%|██████████| 27/27 [00:07<00:00,  3.56it/s]


## Defining the model architecture

In [57]:
# https://pytorch.org/tutorials/beginner/transformer_tutorial.html

# Define model
class VisionTransformer(nn.Module):
    def __init__(self, seq_len, input_features, num_encoder_layers, d_model, nhead):
        super().__init__()
        
        self.linear = nn.Linear(in_features=input_features, out_features=d_model)
        self.pe_encoder = PositionalEncoding(d_model)
        self.encoder = nn.TransformerEncoder(nn.TransformerEncoderLayer(d_model, nhead),
                                                 num_layers=num_encoder_layers)
        self.output = nn.Linear(in_features=d_model*seq_len, out_features=1)
        
        pass

    def forward(self, x):
        logging.debug(f"DTypes: raw input {type(x)}")

        x = self.linear(x)

        logging.debug(f"DTypes: linear output {type(x)}")
        logging.debug(f"PE Encoder: input (after Linear) {x.shape}, permuted {torch.swapdims(x, 0, 1).shape}")

        u = self.pe_encoder(torch.swapdims(x, 0, 1))
        logging.debug(f"PE Encoder: output {u.shape}, permuted: {torch.swapdims(u, 0, 1).shape}")
        x += torch.swapdims(u, 0, 1)
        x = self.encoder(x)
        logging.debug(f"Pre logits: {x.shape}")
        logits = self.output(torch.flatten(x, start_dim=1, end_dim=2))
        logging.debug(f"Logits: {logits.shape}")
        return logits.squeeze()

class PositionalEncoding(nn.Module):

    def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 5000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-np.log(10000.0) / d_model))
        pe = torch.zeros(max_len, 1, d_model)
        pe[:, 0, 0::2] = torch.sin(position * div_term)
        pe[:, 0, 1::2] = torch.cos(position * div_term)
        self.register_buffer('pe', pe)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: Tensor, shape [seq_len, batch_size, embedding_dim]
        """
        x = x + self.pe[:x.size(0)]
        logging.debug(f"DTypes: PE before droput output {type(x)}")
        return self.dropout(x)

## Training and inference routines

In [58]:
def evaluate(model, data_loader, **kwargs):
    loss_fn = kwargs.get('loss_fn', torch.nn.functional.binary_cross_entropy_with_logits)
    device = kwargs.get('device', torch.device('cpu'))
    
    model.eval() # set model to evaluation mode
    pbar = tqdm(enumerate(data_loader), total=len(data_loader))
    avg_loss = 0.
    for i, batch in pbar:
        data, targets = batch
        data, targets = data.to(device), targets.to(device)
        with torch.no_grad(): # no need to compute gradients
            logits = model(data)
        loss = loss_fn(logits, targets)
        avg_loss += loss.item()
        pbar.set_description(f'loss = {loss:.3f}')
    avg_loss /= len(data_loader)
    return avg_loss

def fit(model, train_loader, val_loader, optimizer, **kwargs):
    num_epochs = kwargs.get('num_epochs', 100)
    loss_fn = kwargs.get('loss_fn', torch.nn.functional.mse_loss)
    device = kwargs.get('device', torch.device('cpu'))
    
    train_loss_hist, val_loss_hist = [], []
    for epoch in range(num_epochs):
        print(f'Epoch {epoch + 1}/{num_epochs}')
        
        print('Training phase...')
        model.train() # set model to training mode
        train_loss = 0.
        pbar = tqdm(enumerate(train_loader), total=len(train_loader))
        for i, batch in pbar:
            data, targets = batch
            data, targets = data.to(device), targets.to(device)
            model.zero_grad() # initialize gradients to zero
            logits = model(data) # forward pass
            logging.debug(f"Before loss computation (training), logits: {logits.shape} targets: {targets.shape}")
            logging.debug(f"DTypes: logits {type(logits)}, targets {type(targets)}")
            loss = loss_fn(logits, targets) # loss computation
            loss.backward() # computing gradients (backward pass)
            optimizer.step() # updating the parameters of the model
            
            train_loss += loss.item()
            pbar.set_description(f'loss = {loss:.3f}')
        train_loss /= len(train_loader)
        print(f'train loss = {train_loss:.3f}')
        train_loss_hist.append(train_loss)
        
        print('Validation phase...')
        val_loss = evaluate(model, val_loader, loss_fn=loss_fn, device=device)
        print(f'validation loss = {val_loss:.3f}')
        val_loss_hist.append(val_loss)
        
    return train_loss_hist, val_loss_hist


## Model training

In [59]:
LEARNING_RATE = 1e-3
NUM_EPOCHS = 100

SEQUENCE_LENGTH = 400
INPUT_FEATURES = 7500
ENCODER_LAYERS = 3
ENCODER_DIM = 128
ENCODER_HEADS = 8

if torch.cuda.is_available():
    DEVICE = torch.device('cuda')
# elif torch.backends.mps.is_available():
#     DEVICE = torch.device('mps')
else:
    DEVICE = torch.device('cpu')
print('DEVICE:', DEVICE)

model = VisionTransformer(SEQUENCE_LENGTH, INPUT_FEATURES, ENCODER_LAYERS, ENCODER_DIM, ENCODER_HEADS).to(DEVICE)
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

DEVICE: cuda


In [60]:
train_loss, val_loss = fit(model, train_dataloader, val_dataloader, optimizer, num_epochs=NUM_EPOCHS, device=DEVICE)

Epoch 1/100
Training phase...


  0%|          | 0/27 [00:00<?, ?it/s]DEBUG:root:DTypes: raw input <class 'torch.Tensor'>
DEBUG:root:DTypes: linear output <class 'torch.Tensor'>
DEBUG:root:PE Encoder: input (after Linear) torch.Size([4, 400, 128]), permuted torch.Size([400, 4, 128])
DEBUG:root:DTypes: PE before droput output <class 'torch.Tensor'>
DEBUG:root:PE Encoder: output torch.Size([400, 4, 128]), permuted: torch.Size([4, 400, 128])
DEBUG:root:Pre logits: torch.Size([4, 400, 128])
DEBUG:root:Logits: torch.Size([4, 1])
DEBUG:root:Before loss computation (training), logits: torch.Size([4]) targets: torch.Size([4])
DEBUG:root:DTypes: logits <class 'torch.Tensor'>, targets <class 'torch.Tensor'>
  0%|          | 0/27 [00:00<?, ?it/s]


RuntimeError: Found dtype Long but expected Float