# Static stuff

In [1]:
%cd src

/home/evry/Desktop/master-degree/repositories/vision-anomaly/src


  self.shell.db['dhist'] = compress_dhist(dhist)[-100:]


In [2]:
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import optim
import torchvision
from torchvision import transforms
import cv2

from torchsummary import summary
# from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from torch.optim import Adam

from model.model import Model

from src.data_loader.data_loader import MVTec
from progressbar import Bar, DynamicMessage, ProgressBar, ETA

from torchmetrics.image import StructuralSimilarityIndexMeasure
from torch.utils.data import random_split

from sklearn.metrics import roc_auc_score

import matplotlib.cm as cm
from PIL import Image
from torch.utils.data import DataLoader

In [3]:
from tasad.data_loader import MVTecTrainDataset
from tasad.data_loader_test import MVTecTestDataset

dataset_root_path = "/home/evry/Desktop/master-degree/repositories/two-stage-coarse-to-fine-image-anomaly-segmentation-and-detection-model/data/images"
anomaly_path = "/home/evry/Desktop/master-degree/repositories/two-stage-coarse-to-fine-image-anomaly-segmentation-and-detection-model/data/anomaly/images"
# dataset_root_path = "/home/evry/Desktop/master-degree/dataset/BTech_Dataset_transformed"

def read_data(class_name: str, batch_size=1):
    train_dataset = MVTecTrainDataset(root_dir=dataset_root_path + f"/{class_name}/train/", anomaly_source_path=anomaly_path, resize_shape=[256, 256])
    test_dataset = MVTecTestDataset(root_dir=dataset_root_path + f"/{class_name}/test/", resize_shape=[256, 256])
    
    train_loader = DataLoader(dataset = train_dataset, batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(dataset = test_dataset, batch_size=1, shuffle=True)
    
    return train_loader, test_loader

In [4]:
import os

def save_comparison(
    class_name: str, 
    file_name:str, 
    image, 
    mask, 
    reconstruction, 
    ssim_map, 
    fas_input, 
    fas_output, 
    processed_fas_output, 
    save_fig: bool, 
    summary_writer: SummaryWriter, 
    epoch,
    path: str = "../runs/tasad-vitcnn"
):
    with torch.no_grad():
        _input_image = image.cpu().numpy()[0].transpose(1, 2, 0)
        _mask = mask.cpu().numpy()[0].transpose(1, 2, 0)
        _fas_input = fas_input.cpu().numpy()[0].transpose(1, 2, 0)
        _fas_output = fas_output.cpu().numpy()[0].transpose(1, 2, 0)
        _ssim_map = ssim_map.cpu().numpy()[0][0]
        _processed_fas_output = processed_fas_output.cpu().numpy()[0].transpose(1, 2, 0)
        
        # Create a heatmap from the normalized SSIM map
        heatmap = cv2.applyColorMap((_ssim_map * 255).astype(np.uint8), cv2.COLORMAP_JET)

        # Convert _input_image to uint8
        _input_image_uint8 = (_input_image * 255).astype(np.uint8)

        # Overlay the heatmap on the original input image
        overlay = cv2.addWeighted(_input_image_uint8, 0.4, heatmap, 0.6, 0)

        # Plot the results

        fig = plt.figure(figsize=(15, 3))

        ax1 = fig.add_subplot(171)
        ax1.imshow(_input_image, cmap='gray')
        ax1.set_title('ViT-CNN entrada')
        ax1.axis("off")
        
        reconstruction_norm = (reconstruction - reconstruction.min()) / (reconstruction.max() - reconstruction.min())
        
        ax2 = fig.add_subplot(172)
        ax2.imshow(reconstruction_norm.cpu().numpy()[0].transpose(1, 2, 0))
        ax2.set_title('ViT-CNN saída')
        ax2.axis("off")
        
        ax3 = fig.add_subplot(173)
        ax3.imshow(overlay)
        ax3.set_title('ViT-CNN mapa SSIM')
        ax3.axis("off")
        
        ax4 = fig.add_subplot(174)
        ax4.imshow(_mask, cmap='gray')
        ax4.set_title('Padrão ouro')
        ax4.axis("off")
        
        ax5 = fig.add_subplot(175)
        ax5.imshow(_fas_input)
        ax5.set_title('Entrada FAS')
        ax5.axis("off")
        
        ax6 = fig.add_subplot(176)
        ax6.imshow(_fas_output, cmap='gray')
        ax6.set_title('Saída FAS')
        ax6.axis("off")
        
        ax6 = fig.add_subplot(177)
        ax6.imshow(_processed_fas_output, cmap='gray')
        ax6.set_title('Saída FAS binária')
        ax6.axis("off")
        
        plt.tight_layout()
        # if plot:
        #     plt.show()
        
        path += f"/{class_name}/plots/"
        
        if not os.path.exists(path):
            os.makedirs(path)
        
        if save_fig:
            fig.savefig(path + file_name)        

        summary_writer.add_figure('plot', fig, epoch)
        
        fig.clear()
        plt.close()
        plt.cla()
        plt.clf()

# Training

In [5]:
def get_vitcnn_output_mask(input_batch, vitcnn):
    _, reconstruction = vitcnn(input_batch)
    
    SSIM = StructuralSimilarityIndexMeasure(return_full_image=True).cpu()
    
    ssim_value, ssim_map = SSIM(input_batch.cpu(), reconstruction.cpu())
    
    norm_ssim_map = (ssim_map - ssim_map.min()) / (ssim_map.max() - ssim_map.min())
    
    mean_tensor = norm_ssim_map.mean(dim=1, keepdim=True)

    norm_ssim_map = mean_tensor.expand(-1, 3, -1, -1)
    
    binary_ssim_map = torch.where(norm_ssim_map > 0.9, torch.zeros_like(norm_ssim_map), torch.ones_like(norm_ssim_map))
    
    return input_batch * binary_ssim_map.cuda(), reconstruction, ssim_value, ssim_map, binary_ssim_map.cpu().float()

In [6]:
import gc
import torch
from torch.utils.data import DataLoader
import numpy as np
from sklearn.metrics import roc_auc_score, average_precision_score
import matplotlib.pyplot as plt 
from tasad.tasad_model import TasadModel
from tasad.utils.utilts_custom_class import *
from tasad.utils.utilts_func import *
import cv2
from typing import Union
from datetime import datetime
from progressbar import Bar, DynamicMessage, ProgressBar, ETA

def test_segmentation_model(
    vitcnn: Model,
    class_name: str,
    dataloader,
    epoch: int,
    gpu_id,
    fas_model: TasadModel,
    visualizer: SummaryWriter = None,
    print_logs: bool = True
):
    """
    Test the segmentation model on the MVTec test dataset.
    
    Args:
    segmentation_model (TasadModel): The main segmentation model to be tested.
    class_name (str): The name of the class to be tested.
    data_path (str): Path to the dataset.
    epoch (int): Current epoch number.
    gpu_id: GPU ID for CUDA device.
    fas_model (TasadModel, optional): Additional model for further segmentation (if any).
    visualizer (TensorboardVisualizer, optional): Tensorboard visualizer for performance plotting.
    print_logs (bool, optional): Flag to print logs.
    
    Returns:
    tuple: AP, AP per pixel, AUROC, AUROC per pixel
    """
    
    cuda_device = torch.device(f'cuda:{gpu_id}')
    img_dimension = 256
    image_index = 0

    # Initialize the FAS model if provided
    fas_model.cuda(cuda_device)
    fas_model.eval()

    dataset = dataloader.dataset
    total_pixel_scores = np.zeros((img_dimension * img_dimension * len(dataset)))
    total_gt_pixel_scores = np.zeros((img_dimension * img_dimension * len(dataset)))
    mask_count = 0

    anomaly_score_gt = []
    anomaly_score_prediction = []
    
    widgets = [
        DynamicMessage('test'),
        Bar(marker='=', left='[', right=']'),
        ' ', ETA(),
    ]

    with ProgressBar(widgets=widgets, max_value=len(dataset)) as progress_bar:
        saved_in_this_epoch = False
        for i_batch, sample_batched in enumerate(dataloader):
            vitcnn_input = sample_batched["augmented_image"].cuda(cuda_device)
            original_image = plt.imread(dataloader.dataset.images[i_batch]) 
            resized_original_image = cv2.resize(original_image, (img_dimension, img_dimension))

            has_anomaly = sample_batched["has_anomaly"].detach().numpy()
            anomaly_score_gt.append(has_anomaly)
            ground_truth_mask = sample_batched["anomaly_mask"]
            
            ground_truth_mask_np = ground_truth_mask.detach().numpy()[0, :, :, :].transpose((1, 2, 0))
                
            fas_input, vitcnn_reconstruction, vitcnn_ssim_value, vitcnn_ssim_map, vitcnn_binary_ssim_map = \
                    get_vitcnn_output_mask(vitcnn_input, vitcnn)
                
            fas_output = fas_model(fas_input)
            
            vitcnn_binary_ssim_map = vitcnn_binary_ssim_map.to(fas_output.device)
            
            processed_fas_output = fas_output * vitcnn_binary_ssim_map.mean(dim=1, keepdim=True)
            
            processed_fas_output = torch.where(fas_output < 0.3, torch.zeros_like(processed_fas_output), torch.ones_like(processed_fas_output))
            
            if has_anomaly[0] and not saved_in_this_epoch:
                saved_in_this_epoch = True
                
                save_comparison(
                    class_name=class_name,
                    file_name=f"test_sample_epoch_{epoch}.jpg",
                    image=vitcnn_input,
                    mask=ground_truth_mask,
                    reconstruction=vitcnn_reconstruction,
                    ssim_map=vitcnn_ssim_map,
                    fas_input=fas_input,
                    fas_output=fas_output,
                    processed_fas_output=processed_fas_output,
                    save_fig=True,
                    summary_writer=visualizer,
                    epoch=epoch,
                    path=f"../runs/tasad-vitcnn/{class_name}/test")
                
            # output_mask_np = processed_fas_output.detach().cpu().numpy()
            # output = fas_output                                                                                                                                   
            
            # fas_input = torch.tensor(seg_module(vitcnn_input, cas_output, th_pix=0.95, th_val=30)).cuda(cuda_device)
            # fas_output = fas_model(fas_input)
            # combined_output = fas_output + cas_output

            output_mask_np = processed_fas_output[0, 0, :, :].detach().cpu().numpy()
            output = fas_output                                                                                                                                   
           
            try:
                fas_input_np = fas_input.detach().cpu().numpy()[0, :, :, :].transpose((1, 2, 0))   
                fas_input_rgb = cv2.cvtColor(fas_input_np, cv2.COLOR_BGR2RGB)
                query_image_rgb = cv2.cvtColor(resized_original_image, cv2.COLOR_BGR2RGB)
                out_mask_fas_np = abs(fas_output.detach().cpu().numpy()[0, :, :, :].transpose((1, 2, 0))[:,:,0])/torch.max(fas_output).item()
                # out_mask_cas_np = abs(cas_output.detach().cpu().numpy()[0, :, :, :].transpose((1, 2, 0))[:,:,0])/torch.max(cas_output).item()
                combined_output_normalized = abs(out_mask_fas_np)
                combined_output_normalized = combined_output_normalized/np.max(combined_output_normalized)
                
                all_images = [query_image_rgb,  ground_truth_mask_np, fas_input_rgb, out_mask_fas_np, combined_output_normalized] 
                # Image saving function can be implemented here if needed

            except Exception as e:
                pass
            
            image_index += 1
            averaged_output_mask = torch.nn.functional.avg_pool2d(output, 21, stride=1, padding=21 // 2).cpu().detach().numpy()
            image_score = np.max(averaged_output_mask)

            anomaly_score_prediction.append(image_score)

            flat_gt_mask = ground_truth_mask_np.flatten()
            flat_output_mask = output_mask_np.flatten()
            
            total_pixel_scores[mask_count * img_dimension * img_dimension:(mask_count + 1) * img_dimension * img_dimension] = flat_output_mask
            total_gt_pixel_scores[mask_count * img_dimension * img_dimension:(mask_count + 1) * img_dimension * img_dimension] = flat_gt_mask
            mask_count += 1
            
            progress_bar.update(
                i_batch,
                test=f"({i_batch}/{len(dataset)}) Class {class_name} ")

    anomaly_score_prediction = np.array(anomaly_score_prediction)
    anomaly_score_gt = np.array(anomaly_score_gt)
    auroc = roc_auc_score(anomaly_score_gt, anomaly_score_prediction)
    ap = average_precision_score(anomaly_score_gt, anomaly_score_prediction)

    total_gt_pixel_scores = total_gt_pixel_scores.astype(np.uint8)
    total_gt_pixel_scores = total_gt_pixel_scores[:img_dimension * img_dimension * mask_count]
    total_pixel_scores = total_pixel_scores[:img_dimension * img_dimension * mask_count]
    
    auroc_pixel = roc_auc_score(total_gt_pixel_scores, total_pixel_scores)
    ap_pixel = average_precision_score(total_gt_pixel_scores, total_pixel_scores)
        
    if visualizer:
        visualizer.add_scalar("test_AP_pixel", ap_pixel, epoch)
        visualizer.add_scalar("test_AUROC_pixel", auroc_pixel, epoch)
        visualizer.add_scalar("test_AUROC", auroc, epoch)
        visualizer.add_scalar("test_AP", ap, epoch)
    
    if print_logs:
        print(f"{datetime.now()} Test for epoch {epoch}: Class {class_name} Pixel AP {ap_pixel:.2f} Pixel AUC {auroc_pixel:.2f} Image AUC {auroc:.2f} Image AP {ap:.2f}")
    
    del total_gt_pixel_scores
    del total_pixel_scores
    del anomaly_score_gt
    del anomaly_score_prediction
    del progress_bar
    
    gc.collect()
    
    return ap, ap_pixel, auroc, auroc_pixel

In [7]:
from tasad.loss import SSIM
from tasad.tasad_model import TasadModel


progressbar_widgets = [
    DynamicMessage('log', format = '{formatted_value}'),
    Bar(marker='=', left='[', right=']'),
    ' ',  ETA(),
]

def train_tasad_and_vitcnn(
    class_name: str,
    vitcnn,
    epochs=500,
    learning_rate=0.0001,
    batch_size=1
):
    train_loader, test_loader = read_data(class_name)
    
    print(f"\n\nStarting training for class: {class_name}\n")
    print(f"Info: Found {len(train_loader.dataset)} sample for training")
    print(f"Info: Found {len(test_loader.dataset)} sample for test")    

    fas_model = TasadModel(in_channels=3, out_channels=1).cuda()
    fas_parameters_amount = TasadModel.get_n_params(fas_model)
    
    print(f"Info: Initializing fas with {fas_parameters_amount} parameters")
    
    optimizer = torch.optim.Adam([{"params": fas_model.parameters(), "lr": learning_rate}])
    scheduler = optim.lr_scheduler.MultiStepLR(optimizer, [epochs*0.3, epochs*0.5, epochs*0.8], gamma=0.1, last_epoch=-1)

    best_epoch = -1
    best_loss = 1e10
    
    summary_writer = SummaryWriter(log_dir=f'../runs/tasad-vitcnn/{class_name}')
    
    loss_l2 = torch.nn.modules.loss.MSELoss()
    # SSIM = StructuralSimilarityIndexMeasure().cuda()
    loss_ssim = SSIM(0)

    for epoch in range(epochs):
        sum_loss = 0
        sum_ssim_loss = 0
        sum_l2_loss = 0
        saved_in_this_epoch = False
        with ProgressBar(widgets=progressbar_widgets, max_value=test_loader.__len__() + 1) as progress_bar:
            for batch_index, batch in enumerate(test_loader):
                input_batch = batch['augmented_image'].cuda()
                ground_truth_batch = batch['anomaly_mask'].cuda()
                has_anomaly = batch['has_anomaly'].cuda()
                
                optimizer.zero_grad()
                
                fas_input, vitcnn_reconstruction, vitcnn_ssim_value, vitcnn_ssim_map, vitcnn_binary_ssim_map = \
                    get_vitcnn_output_mask(input_batch, vitcnn)
                
                fas_output = fas_model(fas_input)
                
                vitcnn_binary_ssim_map = vitcnn_binary_ssim_map.to(fas_output.device)
                
                processed_fas_output = fas_output * vitcnn_binary_ssim_map.mean(dim=1, keepdim=True)
                
                processed_fas_output = torch.where(fas_output < 0.3, torch.zeros_like(processed_fas_output), torch.ones_like(processed_fas_output))
                
                if has_anomaly[0] and not saved_in_this_epoch:
                    saved_in_this_epoch = True
                    
                    save_comparison(
                        class_name=class_name,
                        file_name=f"train_sample_epoch_{epoch}.jpg",
                        image=input_batch,
                        mask=ground_truth_batch,
                        reconstruction=vitcnn_reconstruction,
                        ssim_map=vitcnn_ssim_map,
                        fas_input=fas_input,
                        fas_output=fas_output,
                        processed_fas_output=processed_fas_output,
                        save_fig=True,
                        summary_writer=summary_writer,
                        epoch=epoch,
                        path=f"../runs/tasad-vitcnn/{class_name}/train")
                
                l2_loss = loss_l2(fas_output, ground_truth_batch)
                    
                sum_l2_loss += l2_loss.cpu()
                
                ssim_loss = loss_ssim(fas_output, ground_truth_batch)
                
                sum_ssim_loss += ssim_loss.cpu()
                
                loss = l2_loss + ssim_loss
                
                sum_loss += loss.cpu()                
                
                loss.backward()
                
                optimizer.step()
                
                progress_bar.update(
                    batch_index, 
                    log=f"({epoch+1}) Class: {class_name} | L2 loss: {(sum_l2_loss / (batch_index + 1)):.2f} | SSIM loss: {(sum_ssim_loss / (batch_index + 1)):.2f} | L2 and SSIM loss: {(sum_loss / (batch_index + 1)):.2f} ")
        
        scheduler.step()
        
        del input_batch
        del ground_truth_batch
        del has_anomaly
        del batch
        del fas_input
        del vitcnn_reconstruction
        del vitcnn_ssim_value
        del vitcnn_ssim_map
        del vitcnn_binary_ssim_map
        
        gc.collect()
        
        ap, ap_pixel, auroc, auroc_pixel = test_segmentation_model(
            vitcnn=vitcnn,
            class_name=class_name,
            dataloader=test_loader,
            epoch=epoch,
            gpu_id=0,
            fas_model=fas_model,
            visualizer=summary_writer
        )
        
        fas_model.train()
            
        print(f"(test) AP: {ap:.2f} | AP pixel: {ap_pixel:.2f} | AUROC: {auroc:.2f} AUROC pixel: {auroc_pixel:.2f}")
        
        summary_writer.add_scalar('fas_l2_loss', sum_l2_loss / len(train_loader), epoch)
        summary_writer.add_scalar('fas_ssim_loss', sum_ssim_loss / len(train_loader), epoch)
        
        avg_loss = sum_loss / len(train_loader)
        if avg_loss < best_loss and best_loss - avg_loss >= 0.01:
            best_loss = avg_loss
            best_epoch = epoch + 1
            
            torch.save(fas_model.state_dict(), os.path.join(f"../tasad_models/tasad_{class_name}.pt"))
        elif (epoch + 1) - best_epoch >= 50:
            break

In [8]:
def get_vit_model(class_name: str):
    model = Model(patch_size=16, depth=32).cuda()
    model.load_state_dict(torch.load(f"../vit_models/reconstruction/vit_{class_name}.pt"))
    model.eval()

    return model

In [9]:
def train_class(class_name: str):
    vit_model = get_vit_model(class_name)

    train_tasad_and_vitcnn(class_name, vit_model, learning_rate=0.0001)

In [10]:
train_class("bottle")



Starting training for class: bottle

Info: Found 209 sample for training
Info: Found 83 sample for test
Info: Initializing fas with 11959681 parameters


  return F.conv_transpose2d(
  return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
(1) Class: bottle | L2 loss: 0.10 | SSIM loss: 0.91 | L2 and SSIM loss: 1.01 [] Time:  0:00:24


2024-06-12 08:08:38.227907 Test for epoch 0: Class bottle Pixel AP 0.06 Pixel AUC 0.39 Image AUC 0.58 Image AP 0.83
(test) AP: 0.83 | AP pixel: 0.06 | AUROC: 0.58 AUROC pixel: 0.39


(2) Class: bottle | L2 loss: 0.07 | SSIM loss: 0.54 | L2 and SSIM loss: 0.60 [] Time:  0:00:24


OutOfMemoryError: CUDA out of memory. Tried to allocate 64.00 MiB. GPU 

<Figure size 640x480 with 0 Axes>