<a href="https://colab.research.google.com/github/pig8pig/ViT-Lithography-Hotspot-Detection/blob/main/ICCAD_1_ViT_(2).ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Mount Google Drive so files stored there are reachable under /content/drive.
from google.colab import drive
drive.mount(mountpoint='/content/drive')

In [None]:
!unzip /content/drive/MyDrive/Yuzhong_Luo/iccad-official.zip -d /content/drive/MyDrive/Yuzhong_Luo/Data/

In [None]:
# Show the current working directory. The ImageFolder datasets further down are loaded
# from a relative path ("iccad1_modified/..."), which is resolved against this directory.
!pwd

In [None]:
import tensorflow as tf
import os
import numpy as np
from tensorflow.keras import datasets, layers, models
import matplotlib.pyplot as plt
from tensorflow.keras.preprocessing.image import ImageDataGenerator
import logging
import warnings

# Keep TensorFlow's Python logger quiet below ERROR level.
logger = tf.get_logger()
logger.setLevel(level=logging.ERROR)

# NOTE(review): a blanket "ignore" hides every Python warning for the rest of the
# session (including ones raised by torch/timm); consider filtering specific categories.
warnings.filterwarnings(action='ignore')

In [None]:
# Sub-folder of the extracted dataset that the split directories are built from.
folder = "iccad1"

In [None]:
import zipfile

# Archive on Google Drive and the directory it is extracted into.
zip_file_path = '/content/drive/MyDrive/Yuzhong_Luo/iccad-official.zip'
# Same "MyDrive" mount spelling as zip_file_path (the path that is known to exist).
extract_path = '/content/drive/MyDrive/Yuzhong_Luo/'

with zipfile.ZipFile(zip_file_path, 'r') as zip_ref:
    # Skip the slow, Drive-backed extraction when every archive member is already on
    # disk, so "Restart & Run All" does not rewrite the whole dataset each time.
    already_extracted = all(
        os.path.exists(os.path.join(extract_path, member))
        for member in zip_ref.namelist()
    )
    if not already_extracted:
        zip_ref.extractall(extract_path)

print(f"Unzipped files are in: {extract_path}")

In [None]:
# Dataset root for the selected `folder`; the split directories are derived from it.
# NOTE(review): the zipfile cell extracts into Drive, not /content — confirm that
# /content/iccad-official is populated by some other step.
base_dir = os.path.join('/content/iccad-official', folder)
print(base_dir)

In [None]:
# Split directories, both derived from base_dir so the training and validation paths
# always refer to the same dataset copy (and follow any change to `folder`).
train_dir = os.path.join(base_dir, 'train')
validation_dir = os.path.join(base_dir, 'validation')

# Per-class sub-directories of each split.
train_hotspot_dir = os.path.join(train_dir, 'Hotspot')
train_not_hotspot_dir = os.path.join(train_dir, 'Not_Hotspot')
validation_hotspot_dir = os.path.join(validation_dir, 'Hotspot')
validation_not_hotspot_dir = os.path.join(validation_dir, 'Not_Hotspot')


def _count_images(directory):
    """Count the regular, non-hidden files directly inside ``directory``.

    Plain ``os.listdir`` would also count sub-directories and dot-files such as
    ``.ipynb_checkpoints`` or ``.DS_Store``, inflating the image totals.
    """
    with os.scandir(directory) as entries:
        return sum(1 for entry in entries
                   if entry.is_file() and not entry.name.startswith('.'))


num_hs_tr = _count_images(train_hotspot_dir)
num_nhs_tr = _count_images(train_not_hotspot_dir)

num_hs_val = _count_images(validation_hotspot_dir)
num_nhs_val = _count_images(validation_not_hotspot_dir)

total_train = num_hs_tr + num_nhs_tr
total_val = num_hs_val + num_nhs_val

print('The dataset contains:')
print('\u2022 {:,} training images'.format(total_train))
print('\u2022 {:,} validation images'.format(total_val))

print('\nThe training set contains:')
print('\u2022 {:,} images with hotspot'.format(num_hs_tr))
print('\u2022 {:,} images without hotspot'.format(num_nhs_tr))

print('\nThe validation set contains:')
print('\u2022 {:,} images with hotspot'.format(num_hs_val))
print('\u2022 {:,} images without hotspot'.format(num_nhs_val))

In [None]:
!pip install timm # torch image models library

In [None]:
# PyTorch / timm stack used for the ViT model, data loading and training below.
import torch
import timm
import tqdm
import torch.nn as nn
from PIL import Image
import matplotlib.pyplot as plt
# NOTE: this rebinds `datasets`, shadowing the tensorflow.keras `datasets` imported
# earlier; the ImageFolder calls further down rely on it being torchvision's.
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

In [None]:
# Training configuration.
pretrained_model = "vit_base_patch16_224"  # timm model name (name suggests 224x224 inputs)
img_size = 224      # side length every image is resized to; keep in line with the model
epochs = 15
batch_size = 100
lr = 1e-3           # Adam learning rate
seed = 42           # fixes data shuffling and the random init of the new head

# Seed torch's global RNG before any random state (head init, DataLoader shuffling)
# is consumed, so a full re-run reproduces the same training trajectory on CPU.
torch.manual_seed(seed)

In [None]:
# Predicted class index -> short label printed by predict().
# NOTE(review): assumes index 0 is the hotspot folder and 1 the non-hotspot folder;
# verify against train_data.class_to_idx.
class_dict = dict(enumerate(("HS", "NHS")))

In [None]:
# Per-image preprocessing: resize to img_size x img_size, convert to a float tensor in
# [0, 1], then normalise each of the 3 channels with mean 0.5 / std 0.5, mapping to [-1, 1].
transform = transforms.Compose(
    [
        transforms.Resize(size=(img_size, img_size)),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    ]
)

In [None]:
"""
BATCH_SIZE = 64
IMG_SHAPE  = 224

image_gen = ImageDataGenerator(rescale=1./255,
                              horizontal_flip = True )

train_data = image_gen.flow_from_directory(directory=train_dir,
                                          batch_size=BATCH_SIZE,
                                          shuffle=True,
                                          target_size=(IMG_SHAPE,IMG_SHAPE),
                                          class_mode='binary')

plt.imshow(train_data[0][0][0])
plt.show()
"""

In [None]:
"""
image_gen_val = ImageDataGenerator(rescale=1./255)

test_data = image_gen_val.flow_from_directory(directory=validation_dir,
                                                 batch_size=BATCH_SIZE,
                                                 target_size=(IMG_SHAPE, IMG_SHAPE),
                                                 class_mode='binary')
plt.imshow(test_data[0][0][0])
plt.show()
"""

In [None]:

# One ImageFolder per split, both using the same preprocessing `transform`.
# NOTE(review): "iccad1_modified" is a relative path (resolved against the working
# directory) that no cell in this notebook creates, and it differs from the
# train_dir / validation_dir counted above — confirm it is the intended dataset copy.
train_data, test_data = (
    datasets.ImageFolder(root=f"iccad1_modified/{split}", transform=transform)
    for split in ("train", "validation")
)

In [None]:
# Number of images in each split: (train, validation).
tuple(len(split_data) for split_data in (train_data, test_data))

In [None]:
# Shuffle only the training split; evaluation batches keep a fixed order.
train_batches = DataLoader(dataset=train_data, batch_size=batch_size, shuffle=True)
test_batches = DataLoader(dataset=test_data, batch_size=batch_size, shuffle=False)

In [None]:
# Batches per epoch: (train, validation).
tuple(map(len, (train_batches, test_batches)))

In [None]:
# Pretrained ViT backbone from timm, used as a frozen feature extractor.
net = timm.create_model(pretrained_model, pretrained=True)

# Freeze every pretrained weight; only the new head created below is trained.
for param in net.parameters():
    param.requires_grad = False

# One output per class actually found by ImageFolder. A hard-coded width larger than
# the number of labels adds logits that class_dict cannot map, so predict() could hit
# a KeyError.
num_classes = len(train_data.classes)
if num_classes != len(class_dict):
    # print rather than warnings.warn: warnings are globally ignored earlier in the notebook.
    print(f"WARNING: ImageFolder found {num_classes} classes {train_data.classes}, "
          f"but class_dict has {len(class_dict)} labels")
net.head = nn.Linear(net.head.in_features, num_classes)
# net.to(device)  # enable together with the commented-out `device` cell to use a GPU

In [None]:
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
# Smoke test: push one random batch through the network and check that the output shape
# is (batch_size, number of classes). No gradients are needed, so no graph is built.
# (`dummy_batch` instead of `input` avoids shadowing the Python builtin.)
with torch.no_grad():
    dummy_batch = torch.randn(batch_size, 3, img_size, img_size)
    output = net(dummy_batch)
output.shape

In [None]:
# Count of trainable scalars (the backbone is frozen, so this is just the new head).
num_parameters = sum(
    tensor.numel()
    for tensor in filter(lambda candidate: candidate.requires_grad, net.parameters())
)
num_parameters

In [None]:
# Give Adam only the parameters that are actually trained (the new head). The frozen
# backbone weights never receive gradients, so registering them only adds bookkeeping
# and obscures what is being optimised.
trainable_params = [p for p in net.parameters() if p.requires_grad]
opt = torch.optim.Adam(trainable_params, lr=lr)
# Default reduction="mean": each call returns the average loss over the batch.
loss_fn = nn.CrossEntropyLoss()

In [None]:
def get_accuracy(preds, y):
    """Return the fraction of samples whose highest-scoring class equals the label.

    Args:
        preds: scores with one row per sample and one column per class
            (argmax is taken over dim 1).
        y: integer class labels, one per row of ``preds``.

    Returns:
        Accuracy as a Python float in [0, 1]; ``nan`` for an empty batch.
    """
    # Averaging on the comparison tensor itself keeps everything on the inputs' device;
    # dividing by a freshly built CPU FloatTensor fails once preds/y live on a GPU.
    correct = preds.argmax(dim=1).eq(y)
    return correct.float().mean().item()

In [None]:
def loop(net, batches, train):
    """Run one pass over ``batches`` and return its mean loss and accuracy.

    Uses the notebook-level ``loss_fn`` and, when training, the optimiser ``opt``.
    Each batch is moved to the device holding ``net``'s parameters, so the same code
    works whether the model is on CPU or GPU.

    Args:
        net: the model.
        batches: iterable of ``(X, y)`` pairs supporting ``len`` (e.g. a DataLoader).
        train: True for a training pass (train mode, backward + optimiser step),
            False for evaluation (eval mode, no gradient tracking).

    Returns:
        ``(mean_loss, mean_accuracy)`` averaged over samples, so a smaller final batch
        carries its true weight instead of counting as much as a full batch.

    Raises:
        ValueError: if ``batches`` yields no samples.
    """
    print("Train Loop:" if train else "Validation Loop:")
    print("")
    net.train(train)  # train(False) is equivalent to eval()
    device = next(net.parameters()).device

    total_loss = 0.0
    total_correct = 0.0
    total_seen = 0

    # Gradients are only needed for training; disabling them in eval saves memory.
    with torch.set_grad_enabled(train):
        for X, y in tqdm.tqdm(batches, total=len(batches)):
            X, y = X.to(device), y.to(device)

            preds = net(X)
            loss = loss_fn(preds, y)

            if train:
                opt.zero_grad()
                loss.backward()
                opt.step()

            batch_n = y.shape[0]
            # loss_fn averages over the batch, so scale back up to a per-batch sum.
            total_loss += loss.item() * batch_n
            total_correct += get_accuracy(preds, y) * batch_n
            total_seen += batch_n

    print("")
    print("")

    if total_seen == 0:
        raise ValueError("loop() received no samples; check that the dataset is not empty")
    return total_loss / total_seen, total_correct / total_seen

In [None]:
def predict(net, img, transform, class_dict):
    """Classify a single image file, print its label, and return it.

    Args:
        net: trained model whose output argmax indexes ``class_dict``.
        img: path (or file object) accepted by ``PIL.Image.open``.
        transform: callable mapping a PIL RGB image to a (C, H, W) tensor.
        class_dict: mapping from predicted class index to label.

    Returns:
        The predicted label (it is also printed, as before).
    """
    # The context manager releases the file handle promptly; convert() returns a new,
    # fully loaded image, so it stays usable after the file is closed.
    with Image.open(img) as raw:
        rgb = raw.convert("RGB")

    # Run on whatever device holds the model's weights.
    device = next(net.parameters()).device
    x = transform(rgb).unsqueeze(0).to(device)  # add a batch dimension of 1

    net.eval()
    with torch.no_grad():
        logits = net(x)

    label = class_dict[logits.argmax(dim=1).item()]
    print(label)
    return label

In [None]:
# Alternate one training pass and one validation pass per epoch, then log both.
for epoch in range(epochs):
    train_loss, train_acc = loop(net, train_batches, train=True)
    val_loss, val_acc = loop(net, test_batches, train=False)

    print(
        f"epoch: {epoch} | train_loss: {train_loss:.4f} | train_acc: {train_acc:.4f} | "
        f"val_loss: {val_loss:.4f} | val_acc: {val_acc:.4f}",
        end="\n\n",
    )