# imports and static stuff

In [1]:
%cd src

/home/evry/Desktop/master-degree/repositories/vision-anomaly/src


  self.shell.db['dhist'] = compress_dhist(dhist)[-100:]


In [2]:
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import optim

import cv2

from torchsummary import summary
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from torch.optim import Adam

from model.VT_AE import VT_AE

from src.data_loader.data_loader import MVTec
from progressbar import Bar, DynamicMessage, ProgressBar, ETA

from torchmetrics.image import StructuralSimilarityIndexMeasure

In [3]:
print(torch.__version__)

2.3.0


In [4]:
import os

def save_comparison(class_name: str, file_name:str, image, mask, reconstruction, ssim_map):
    with torch.no_grad():
        _input_image = image.cpu().numpy()[0].transpose(1, 2, 0)
        _mask = mask.cpu().numpy()[0].transpose(1, 2, 0)
        _ssim_map = ssim_map.cpu().numpy()[0][0]

        # Normalize the SSIM map
        ssim_map_norm = (_ssim_map - _ssim_map.min()) / (_ssim_map.max() - _ssim_map.min())
        
        # Create a heatmap from the normalized SSIM map
        heatmap = cv2.applyColorMap((ssim_map_norm * 255).astype(np.uint8), cv2.COLORMAP_JET)

        # Convert _input_image to uint8
        _input_image_uint8 = _input_image_uint8 = (_input_image * 255).astype(np.uint8)

        # Overlay the heatmap on the original input image
        overlay = cv2.addWeighted(_input_image_uint8, 0.7, heatmap, 0.3, 0)

        # Plot the results

        fig = plt.figure(figsize=(10, 3))

        ax1 = fig.add_subplot(141)
        ax1.imshow(_input_image, cmap='gray')
        ax1.set_title('Input')
        
        reconstruction_norm = (reconstruction - reconstruction.min()) / (reconstruction.max() - reconstruction.min())
        
        ax2 = fig.add_subplot(142)
        ax2.imshow(reconstruction_norm.cpu().numpy()[0].transpose(1, 2, 0))
        ax2.set_title('Reconstructed')
        
        ax3 = fig.add_subplot(143)
        ax3.imshow(_mask, cmap='gray')
        ax3.set_title('GT')
        
        ax4 = fig.add_subplot(144)
        ax4.imshow(overlay)
        ax4.set_title('Pred')
        
        plt.tight_layout()
        # plt.show()
        
        path = f"../runs/{class_name}/plots/"
        
        if not os.path.exists(path):
            os.makedirs(path)
        
        fig.savefig(path + file_name)
        
        fig.clear()
        plt.close()
        plt.cla()
        plt.clf()

# Loading data

In [5]:
dataset_root_path = "/home/evry/Desktop/master-degree/repositories/two-stage-coarse-to-fine-image-anomaly-segmentation-and-detection-model/data/images"

def read_data(class_name: str):
    train_dataset = MVTec(class_name=class_name, root_dir=dataset_root_path + f"/{class_name}", test=False, resize_shape=[256, 256])
    test_dataset = MVTec(class_name=class_name,root_dir=dataset_root_path + f"/{class_name}", test=True, resize_shape=[256, 256])

    train_loader = DataLoader(dataset = train_dataset, batch_size=32, shuffle=True)
    test_loader = DataLoader(dataset = test_dataset, batch_size=1, shuffle=False)
    
    return train_loader, test_loader

# Model training

- [ ] Test model after each epoch (remember to put model in train mode)
- [ ] Perform early stopping based on the model test

In [6]:
def test(class_name: str, epoch: int, model: VT_AE, plot: bool, test_loader: DataLoader):
    mse_sum = 0
    ssim_sum = 0
    loss_sum = 0

    anomaly_samples_saved = 0
    
    SSIM = StructuralSimilarityIndexMeasure(return_full_image=True).cpu()
    
    model.eval()
    
    for i, (input_batch, mask_batch, has_anomaly_batch) in enumerate(test_loader):
        _, reconstruction = model(input_batch.cuda())
        
        mse = F.mse_loss(reconstruction.cpu(), input_batch.cpu(), reduction='mean')
            
        mse_sum += mse.item()
        
        ssim_value, ssim_map = SSIM(input_batch.cpu(), reconstruction.cpu())
        
        ssim_sum += ssim_value.item()
        
        loss_sum += mse.item() + (1.0 - ssim_value.item())
        
        if has_anomaly_batch[0] and anomaly_samples_saved < 1 and plot:
            save_comparison(
                class_name,
                f"anomaly_sample_epoch_{epoch}.jpg",
                input_batch,
                mask_batch,
                reconstruction,
                ssim_map
            )
            
            anomaly_samples_saved += 1
            
    model.train()
    
    batches = test_loader.__len__()
    
    return mse_sum / batches, ssim_sum / batches, loss_sum / batches
            
def train(class_name: str):
    print(f"\n\nStart training for object \"{class_name}\"\n\n\n")
    
    progressbar_widgets = [
        DynamicMessage('log', format = '{formatted_value}'),
        Bar(marker='=', left='[', right=']'),
        ' ',  ETA(),
    ]

    SSIM = StructuralSimilarityIndexMeasure().cuda()
    
    best_epoch = -1
    best_loss = 1e10
    epochs = 400
    
    train_loader, test_loader = read_data(class_name)
    
    model = VT_AE(patch_size=32, depth=16).cuda()
    model.train()

    # print(summary(model, (3, 256, 256)))
    
    summary_writer = SummaryWriter(log_dir=f'../runs/{class_name}')

    optimizer = Adam(list(model.parameters()), lr=0.001, weight_decay=0.0001)
    scheduler = optim.lr_scheduler.MultiStepLR(optimizer,[epochs*0.05, epochs*0.1,epochs*0.5, epochs*0.7],gamma=0.1, last_epoch=-1)

    for epoch in range(epochs):
        train_mse_losses = 0.0
        train_ssim_losses = 0.0
        train_losses = 0.0
        
        with ProgressBar(widgets=progressbar_widgets, max_value=train_loader.__len__() + 1) as progress_bar:
            for sample_i, (input_batch, mask_batch, _) in enumerate(train_loader):
                model.zero_grad()
                
                input_batch = input_batch.cuda()
                
                _, reconstruction = model(input_batch)
                
                reconstruction = reconstruction.cuda()
                
                mse = F.mse_loss(reconstruction, input_batch, reduction='mean')
                
                train_mse_losses += mse.item()
                
                ssim_value = SSIM(input_batch.cuda(), reconstruction)
                
                train_ssim_losses += ssim_value.item()
                
                loss = mse.item() + (1.0 - ssim_value.item())
                
                train_losses += loss

                mse.backward()
                
                optimizer.step()                
                
                progress_bar.update(
                                sample_i,
                                log=f"({epoch+1}/{epochs}) MSE: {mse:.4f} SSIM: {ssim_value:.4f} Loss: {loss:.4f}")

            batches = train_loader.__len__()
            
            train_mse = train_mse_losses / batches
            train_ssim = train_ssim_losses / batches
            train_loss = train_losses / batches

            summary_writer.add_scalar('train_mse', train_mse, epoch)
            summary_writer.add_scalar('train_ssim', train_ssim, epoch)
            summary_writer.add_scalar('train_loss', train_loss, epoch)
            
            log = f"({epoch+1}/{epochs}) MSE: {train_mse:.4f} SSIM: {train_ssim:.4f} Loss: {train_loss:.4f}"
            
            test_mse, test_ssim, test_loss = test(class_name, epoch, model, epoch % 5 == 0, test_loader)
            
            summary_writer.add_scalar('test_mse', test_mse, epoch)
            summary_writer.add_scalar('test_ssim', test_ssim, epoch)
            summary_writer.add_scalar('test_loss', test_loss, epoch)
            
            log += f" | Test MSE: {test_mse:.4f} Test SSIM: {test_ssim:.4f} Test Loss: {test_loss:.4f} Best loss: {best_loss:.4f} ({best_epoch})"      
                    
            progress_bar.update(batches, log=log)

        scheduler.step()

        if test_loss < best_loss and best_loss - test_loss >= 0.01:
            best_epoch = epoch
            best_loss = test_loss
            
            torch.save(model.state_dict(), f'../vit_models/vit_{class_name}.pt')
        elif epoch - best_epoch >= 30:
            print("\n\n==========================================================================================")
            print(f"Stopping training for object {class_name}. No improvements since epoch {best_epoch}")
            print("==========================================================================================")
            
            break
        
    del model
    del optimizer
                

In [7]:
# 'bottle','cable','capsule','carpet','grid','hazelnut','leather','metal_nut','pill','screw','tile','toothbrush','transistor','wood','zipper'
classes = ['hazelnut']
for class_name in classes:
    train(class_name)



Start training for object "hazelnut"





  return F.conv_transpose2d(
(1/400) MSE: 0.1706 SSIM: -0.0589 Loss: 1.2295 | Test MSE: 0.0794 Test SSIM: -0.2880 Test Loss: 1.3674 Best loss: 10000000000.0000 (-1)[] Time:  0:00:18
(2/400) MSE: 0.0576 SSIM: 0.1674 Loss: 0.8902 | Test MSE: 0.0349 Test SSIM: 0.1553 Test Loss: 0.8796 Best loss: 1.3674 (0)[] Time:  0:00:17
(3/400) MSE: 0.0305 SSIM: 0.3404 Loss: 0.6901 | Test MSE: 0.0213 Test SSIM: 0.1818 Test Loss: 0.8395 Best loss: 0.8796 (1)[] Time:  0:00:17
(4/400) MSE: 0.0210 SSIM: 0.4913 Loss: 0.5297 | Test MSE: 0.0192 Test SSIM: 0.3697 Test Loss: 0.6495 Best loss: 0.8395 (2)[] Time:  0:00:17
(5/400) MSE: 0.0173 SSIM: 0.5550 Loss: 0.4623 | Test MSE: 0.0189 Test SSIM: 0.5282 Test Loss: 0.4907 Best loss: 0.6495 (3)[] Time:  0:00:17
(6/400) MSE: 0.0147 SSIM: 0.6272 Loss: 0.3875 | Test MSE: 0.0158 Test SSIM: 0.6270 Test Loss: 0.3889 Best loss: 0.4907 (4)[] Time:  0:00:17
(7/400) MSE: 0.0125 SSIM: 0.6431 Loss: 0.3694 | Test MSE: 0.0131 Test SSIM: 0.6509 Test Loss: 0.3621 Best loss: 0.3889

OutOfMemoryError: CUDA out of memory. Tried to allocate 1024.00 MiB. GPU 

<Figure size 640x480 with 0 Axes>

# Loading saved model and testing it

In [None]:
# class_name = ""

# model = VT_AE(patch_size=32).cuda()

# model.load_state_dict(torch.load(f'../vit_models/vit_{class_name}.pt'))

# train_loader, test_loader = read_data(class_name)

# test_image, mask = test_loader.dataset.__getitem__(0)

# test_image = test_image.unsqueeze(0).cuda()

# plt.imshow(test_image.cpu().numpy()[0].transpose(1, 2, 0))

In [None]:
# mask = model.mask

# output = model.vt.transformer(test_image.cuda()[0])

# with torch.no_grad():
#     plt.imshow(output.cpu().numpy().transpose(1, 2, 0))

In [None]:
# SSIM = StructuralSimilarityIndexMeasure(return_full_image=True).cpu()

# encoded, reconstruction = model(test_image)

# mse = F.mse_loss(reconstruction, test_image, reduction='mean')
# ssim_value, ssim_map = SSIM(reconstruction.cpu(), test_image.cpu())

# mse.item(), ssim_value.item()

In [None]:
# with torch.no_grad():
#     _input_image = test_image.cpu().numpy()[0].transpose(1, 2, 0)
#     _mask = mask.cpu().numpy().transpose(1, 2, 0)
#     _ssim_map = ssim_map.cpu().numpy()[0][0]

#     # Normalize the SSIM map
#     ssim_map_norm = (_ssim_map - _ssim_map.min()) / (_ssim_map.max() - _ssim_map.min())
    
#     ssim_map_norm = np.where(ssim_map_norm > 0.5, 0.0, ssim_map_norm + 0.5)
    
#     # Create a heatmap from the normalized SSIM map
#     heatmap = cv2.applyColorMap((ssim_map_norm * 255).astype(np.uint8), cv2.COLORMAP_JET)

#     # Convert _input_image to uint8
#     _input_image_uint8 = _input_image_uint8 = (_input_image * 255).astype(np.uint8)

#     # Overlay the heatmap on the original input image
#     overlay = cv2.addWeighted(_input_image_uint8, 0.7, heatmap, 0.3, 0)

#     # Plot the results

#     fig = plt.figure(figsize=(15, 5))

#     ax1 = fig.add_subplot(141)
#     ax1.imshow(_input_image, cmap='gray')
#     ax1.set_title('Input')
    
#     reconstruction_norm = (reconstruction - reconstruction.min()) / (reconstruction.max() - reconstruction.min())
    
#     ax2 = fig.add_subplot(142)
#     ax2.imshow(reconstruction_norm.cpu().numpy()[0].transpose(1, 2, 0))
#     ax2.set_title('Reconstructed')
    
#     ax3 = fig.add_subplot(143)
#     ax3.imshow(_mask, cmap='gray')
#     ax3.set_title('GT')
    
#     ax4 = fig.add_subplot(144)
#     ax4.imshow(overlay)
#     ax4.set_title('Pred')
    
#     plt.tight_layout()
#     plt.show()

![image.png](attachment:image.png)