In [4]:
from google.colab import drive
drive.mount("/content/gdrive")

Drive already mounted at /content/gdrive; to attempt to forcibly remount, call drive.mount("/content/gdrive", force_remount=True).


In [5]:
# dataset is stored in gdrive
Data = '/content/gdrive/My Drive/MVTec_dataset/mvtec_anomaly_detection.tar.xz'

In [8]:
import os
import lzma
import torch
import random
import shutil
import tarfile
import contextlib
from PIL import Image
import torch.nn as nn
import torch.optim as optim
from collections import defaultdict
from torch.autograd import Variable
import torchvision.transforms as transforms
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader, Dataset




In [3]:
# since it is tar archive file --> using Lzma to extract this archive compressed file
with contextlib.closing(lzma.LZMAFile(Data)) as xz:
    with tarfile.open(fileobj=xz) as f:
        f.extractall('./MVTec-AD')

NameError: ignored

In [3]:

def rename_images_with_folder_name(folder_path, folder_name):
    if not os.path.exists(folder_path) or not os.path.isdir(folder_path):
        raise ValueError("The provided path is not a valid directory.")

    image_extensions = (".png")

    image_count = 0

    for file_name in os.listdir(folder_path):
        file_ext = os.path.splitext(file_name)[1].lower()

        if file_ext in image_extensions:
            image_count += 1
            new_file_name = f"{folder_name}_{image_count:03d}{file_ext}"
            old_file_path = os.path.join(folder_path, file_name)
            new_file_path = os.path.join(folder_path, new_file_name)
            os.rename(old_file_path, new_file_path)


In [4]:
def resize_image(image_path, target_size=(256, 256)):
    img = Image.open(image_path)
    img = img.resize(target_size, Image.ANTIALIAS)

    # If the image is not already in RGB mode, convert it to RGB
    if img.mode != "RGB":
        img = img.convert("RGB")

    return img

In [4]:
def preprocess_dataset(dataset_path, output_path, num_support_examples=5, num_query_examples=5):
    # Create output directories
    if not os.path.exists(output_path):
      os.makedirs(output_path)
    support_set_path = os.path.join(output_path, "support_set")
    query_set_path = os.path.join(output_path, "query_set")
    os.makedirs(support_set_path, exist_ok=True)
    os.makedirs(query_set_path, exist_ok=True)

    classes = os.listdir(dataset_path)
    for class_name in classes:
        if class_name == '.ipynb_checkpoints': continue
        class_path = os.path.join(dataset_path, class_name)

        print(class_path)
        # Create class directories in support set and query set
        support_class_dir = os.path.join(support_set_path, class_name)
        query_class_dir = os.path.join(query_set_path, class_name)
        os.makedirs(support_class_dir, exist_ok=True)
        os.makedirs(query_class_dir, exist_ok=True)

        # Get list of images in the 'good' folder (normal images) for this class
        good_images_dir = os.path.join(class_path, "train", "good")

        # rename images with it label
        rename_images_with_folder_name(good_images_dir, class_name)

        good_images = [os.path.join(good_images_dir, img) for img in os.listdir(good_images_dir) if img.endswith(".png")]

        # Randomly select support and query examples
        selected_support_examples = random.sample(good_images, num_support_examples)
        selected_query_examples = random.sample(good_images, num_query_examples)

        # Resize and move support examples to support set directory
        for img_path in selected_support_examples:
          img = resize_image(img_path)  # Resize the image
          img.save(os.path.join(support_class_dir, os.path.basename(img_path)))



        # Resize and move query examples (including defect images) to query set directory
        query_defect_dir = os.path.join(class_path, "test")


        for img_path in selected_query_examples:
          img = resize_image(img_path)  # Resize the image
          img.save(os.path.join(query_class_dir, os.path.basename(img_path)))


        defect_folders = os.listdir(query_defect_dir)
        for defect_folder in defect_folders:
            defect_images_dir = os.path.join(query_defect_dir, defect_folder)
            defect_folder = class_name if defect_folder == 'good' else defect_folder
            # rename images with it label
            rename_images_with_folder_name(defect_images_dir, defect_folder)
            defect_images = [os.path.join(defect_images_dir, img) for img in os.listdir(defect_images_dir) if img.endswith(".png")]
            for img_path in defect_images:
                img = resize_image(img_path)  # Resize the image
                img.save(os.path.join(query_class_dir, os.path.basename(img_path)))



In [9]:
dataset_path = "./MVTec-AD/"
output_path = "./MVTec-AD_preprocessed"

# Number of support examples per class
num_support_examples = 5

# Number of query examples per class
num_query_examples = 5

preprocess_dataset(dataset_path, output_path, num_support_examples, num_query_examples)

./MVTec-AD/hazelnut
./MVTec-AD/zipper
./MVTec-AD/transistor
./MVTec-AD/carpet
./MVTec-AD/wood
./MVTec-AD/leather
./MVTec-AD/grid
./MVTec-AD/bottle
./MVTec-AD/toothbrush
./MVTec-AD/metal_nut
./MVTec-AD/pill
./MVTec-AD/cable
./MVTec-AD/tile
./MVTec-AD/screw
./MVTec-AD/capsule


In [7]:
# Source path in Colab
source_folder_path = 'MVTec-AD_preprocessed'

# Destination path in Google Drive
destination_folder_path = '/content/gdrive/My Drive/'

# Copy the data folder to Google Drive
!cp -r "$source_folder_path" "$destination_folder_path"


In [9]:
# Replace 'folder_name' with the name of the folder you want to delete
folder_name = "MVTec-AD_preprocessed"

# Use shutil.rmtree() to delete the folder
shutil.rmtree(folder_name)

In [10]:
#neural network for image encoding
class Encoder(nn.Module):
    def __init__(self, input_size, hidden_size):
        super(Encoder, self).__init__()
        self.fc = nn.Sequential(
            nn.Linear(input_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU()
        )

    def forward(self, x):
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        return x



class PrototypicalNetwork(nn.Module):
    def __init__(self, input_size, hidden_size, num_classes):
        super(PrototypicalNetwork, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.num_classes = num_classes

        # Flatten the input size for the first linear layer
        flattened_input_size = input_size * 256

        self.encoder = nn.Sequential(
            nn.Linear(flattened_input_size, self.hidden_size),
            nn.ReLU(),
            nn.Linear(self.hidden_size, self.hidden_size),
            nn.ReLU()
        )

    def forward(self, x):
        # Flatten the input tensor for the first linear layer
        x = x.view(x.size(0), -1)
        embeddings = self.encoder(x)
        return embeddings


# load images from subfolders
class CustomImageFolder(Dataset):
    def __init__(self, root, transform=None):
        self.root = root
        self.transform = transform
        self.classes = sorted(os.listdir(root))
        self.class_to_idx = {cls_name: i for i, cls_name in enumerate(self.classes)}
        self.images = self._load_images()

    def _load_images(self):
        images = []
        for target_class in self.classes:
            class_dir = os.path.join(self.root, target_class)
            if not os.path.isdir(class_dir):
                continue
            for image_name in os.listdir(class_dir):
                image_path = os.path.join(class_dir, image_name)
                images.append((image_path, self.class_to_idx[target_class]))
        return images

    def __getitem__(self, index):
        image_path, target = self.images[index]
        image = Image.open(image_path).convert("RGB")
        if self.transform is not None:
            image = self.transform(image)
        return image, target

    def __len__(self):
        return len(self.images)



In [11]:
support_set_path = '/content/gdrive/My Drive/support_set'
query_set_path = '/content/gdrive/My Drive/query_set'

In [None]:
# Set hyperparameters
input_size = 256 * 256 * 3  # Replace with the actual size of your resized image data
hidden_size = 256
num_classes = 14  # Number of classes
num_support_examples = 5  # Number of support examples per class
num_query_examples = 5  # Number of query examples per class
num_episodes = 5  # Number of episodes for training
learning_rate = 0.001

# Create an instance of the Prototypical Network model
model = PrototypicalNetwork(input_size, hidden_size, num_classes)

# Define the loss function (e.g., cross-entropy loss)
criterion = nn.CrossEntropyLoss()

# Define the optimizer
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

# Example data loaders for the support set and query set
support_transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor()
])
query_transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor()
])

support_set = CustomImageFolder(support_set_path, transform=support_transform)
query_set = CustomImageFolder(query_set_path, transform=query_transform)

support_loader = DataLoader(support_set, batch_size=num_support_examples, shuffle=True)
query_loader = DataLoader(query_set, batch_size=num_query_examples, shuffle=True)


for episode in range(num_episodes):
    model.train()
    total_loss = 0

    for support_batch, support_labels in support_loader:
        for query_batch, query_labels in query_loader:
            optimizer.zero_grad()

            # Forward pass: compute embeddings for support set and query set
            support_embeddings = model(support_batch)
            query_embeddings = model(query_batch)

            # Calculate distances between class prototypes and query embeddings
            distances = torch.cdist(support_embeddings, query_embeddings)

            # Compute the loss
            loss = criterion(-distances, query_labels)  # Use negative distances for cross-entropy loss

            # Backpropagation
            loss.backward()
            optimizer.step()

            total_loss += loss.item()

    print(f"Episode [{episode+1}/{num_episodes}]: Average Loss: {total_loss/len(support_loader)}")

# Save the trained model if needed
torch.save(model.state_dict(), "path/to/save/model.pth")
