In [5]:
from PIL import Image
import torch
from torchvision.transforms import transforms
from torch.utils.data import Dataset,DataLoader,TensorDataset
import torch.nn as nn
import torch.optim as optim
import pandas as pd
import numpy as np
import os
from IPython.display import display
import random

In [6]:
# Prefer the GPU whenever PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print("Device using:", device)

Device using: cuda


In [7]:
class MultiHeadAttendtion (nn.Module):
    """Multi-head self-attention whose Q/K/V/output projections are 1x1 ``Conv1d`` layers.

    (The misspelled class name is kept so existing references keep working.)

    Args:
        embed_dim: width E of each token; must be divisible by ``num_heads``.
        num_heads: number of attention heads H; each head gets E // H channels.
        qkv_bias: whether the query/key/value projections have a bias term.

    Input/output: ``forward`` takes ``x`` of shape (B, T, E) and returns a
    tensor of the same shape.
    """
    def __init__(self,embed_dim,num_heads,qkv_bias=True):
        super(MultiHeadAttendtion, self).__init__()
        # Fail early with a clear message instead of a cryptic `view` error in forward().
        if embed_dim % num_heads != 0:
            raise ValueError(
                f"embed_dim ({embed_dim}) must be divisible by num_heads ({num_heads})"
            )
        self.num_head = num_heads
        self.embed_dim = embed_dim
        # NOTE(review): standard scaled dot-product attention scales by
        # head_dim**-0.5 (here (embed_dim // num_heads)**-0.5). This scales by the
        # full embed_dim instead. It is kept unchanged because the saved checkpoint
        # was presumably trained with it, and changing it would alter predictions.
        self.scale = embed_dim**-0.5

        self.query = nn.Conv1d(in_channels=embed_dim,out_channels=embed_dim,kernel_size=1,bias=qkv_bias)
        self.key = nn.Conv1d(in_channels=embed_dim,out_channels=embed_dim,kernel_size=1,bias=qkv_bias)
        self.value = nn.Conv1d(in_channels=embed_dim,out_channels=embed_dim,kernel_size=1,bias=qkv_bias)
        self.proj = nn.Conv1d(in_channels=embed_dim,out_channels=embed_dim,kernel_size=1)
    def forward(self,x):
        B, T, E = x.shape
        H = self.num_head
        D = E // H
        # Conv1d wants channels first: (B, T, E) -> (B, E, T). Computed once, reused for Q/K/V.
        x_ch = x.transpose(1, 2)

        def to_heads(t):
            # (B, E, T) -> (B, H, D, T) -> (B, H, T, D)
            return t.view(B, H, D, T).transpose(2, 3)

        q = to_heads(self.query(x_ch))
        k = to_heads(self.key(x_ch))
        v = to_heads(self.value(x_ch))
        atten = (q @ k.transpose(-2, -1)) * self.scale  # (B, H, T, T)
        atten = atten.softmax(dim=-1)
        # (B, H, T, D) -> (B, H, D, T) -> (B, E, T): the inverse of to_heads.
        out = (atten @ v).transpose(2, 3).reshape(B, E, T)
        # Project, then return to token-major layout (B, T, E).
        return self.proj(out).transpose(1, 2)

class TransformerBlock (nn.Module):
    """Pre-norm transformer encoder block.

    forward computes ``x = x + attn(norm1(x))`` and then ``x = x + mlp(norm2(x))``.

    Args:
        embed_dim: token width E.
        num_heads: attention heads passed to ``MultiHeadAttendtion``.
        mlp_ratio: hidden width of the MLP as a multiple of ``embed_dim``
            (int or float; the product is truncated to an int).
        in_channel: accepted for call-site compatibility; unused.
    """
    def __init__(self,embed_dim,num_heads,mlp_ratio=4.0,in_channel=1):
        super(TransformerBlock, self).__init__()
        self.norm1 = nn.LayerNorm(embed_dim)
        self.multi = MultiHeadAttendtion(embed_dim=embed_dim,num_heads=num_heads,qkv_bias=True)
        self.norm2 = nn.LayerNorm(embed_dim)
        # nn.Linear requires an integer width. With the default float mlp_ratio=4.0,
        # embed_dim * mlp_ratio is a float and construction would raise TypeError.
        hidden_dim = int(embed_dim * mlp_ratio)
        self.mlp = nn.Sequential(
            nn.Linear(embed_dim,hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim,embed_dim)
        )
    def forward(self,x):
        x = x + self.multi(self.norm1(x))
        x = x + self.mlp(self.norm2(x))
        return x
    
class EmbeddingLayer (nn.Module):
    """Split each sample into ``patch_num`` flat chunks, project each chunk to
    ``embed_dim``, and prepend a learnable class token.

    Each sample must contain exactly ``patch_num * in_channel * patch_size**2``
    values. The output has shape (batch, patch_num + 1, embed_dim), with the
    class token at index 0 along dim 1. Dropout (p=0.25) is applied to the result.

    NOTE(review): the chunks are contiguous runs of the row-major flattened
    sample, not square patch_size x patch_size tiles. For a single-channel
    H x W image they are horizontal bands of rows. No positional embedding is
    added. Both behaviours are kept because the saved checkpoint presumably
    depends on them.
    """
    def __init__(self,in_channel=1,embed_dim=64,patch_size=128,patch_num=16):
        super(EmbeddingLayer, self).__init__()
        self.patch_num = patch_num
        # Number of values each chunk must contain to match self.proj's input width.
        self.patch_dim = in_channel * patch_size**2
        self.proj = nn.Linear(self.patch_dim,embed_dim)
        self.dp = nn.Dropout(0.25)
        self.token = nn.Parameter(torch.zeros(1,1,embed_dim))
    def forward(self,x):
        batch_size = x.size(0)
        expected = self.patch_num * self.patch_dim
        # Check up front so a wrongly preprocessed image gives a readable error
        # instead of an opaque matmul shape mismatch inside self.proj.
        if x.numel() != batch_size * expected:
            raise ValueError(
                f"EmbeddingLayer expects {expected} values per sample "
                f"(patch_num={self.patch_num} x patch_dim={self.patch_dim}), "
                f"got input of shape {tuple(x.shape)}"
            )
        # reshape (not view) also accepts non-contiguous inputs, e.g. after a transpose.
        chunks = x.reshape(batch_size, self.patch_num, self.patch_dim)
        embeding = self.proj(chunks)
        token = self.token.expand(batch_size,-1,-1)
        return self.dp(torch.cat((token,embeding),dim=1))
    
class ViT(nn.Module):
    """Classifier: EmbeddingLayer -> ``depth`` TransformerBlocks -> linear head on the class token.

    Args:
        in_channel: channels per input sample, forwarded to EmbeddingLayer.
        class_num: number of output logits.
        embed_dim: token width (must be divisible by ``num_head``).
        depth: number of TransformerBlocks.
        num_head: attention heads per block.
        mlp_ratio: MLP hidden width as a multiple of ``embed_dim``.
        patch_size, patch_num: forwarded to EmbeddingLayer. The defaults
            (128, 16) equal EmbeddingLayer's own defaults, so existing calls
            build the same model.

    Returns (forward): logits of shape (batch, class_num).
    """
    def __init__(self,in_channel=1,class_num=3,embed_dim=64,depth=4,num_head=4,mlp_ratio=4.0,
                 patch_size=128,patch_num=16):
        super(ViT, self).__init__()
        self.embed_dim = embed_dim
        self.embedding = EmbeddingLayer(in_channel,embed_dim,patch_size,patch_num)
        self.transformblock = nn.ModuleList([
            TransformerBlock(embed_dim,num_head,mlp_ratio,in_channel) for _ in range(depth)
        ])
        # NOTE(review): self.norm is created (so it appears in the state_dict) but
        # forward() never applies it. A standard pre-norm ViT would apply it before
        # the head. It is left unapplied because the checkpoint was presumably
        # trained this way.
        self.norm = nn.LayerNorm(embed_dim)
        self.head = nn.Linear(embed_dim,class_num)
        self.dp = nn.Dropout(0.25)
    def forward(self,x):
        tokens = self.embedding(x)
        for block in self.transformblock:
            tokens = self.dp(block(tokens))
        # Classify from the class token that EmbeddingLayer prepended at index 0.
        cls_token = tokens[:, 0]
        return self.head(self.dp(cls_token))

In [None]:
# Rebuild the exact architecture the checkpoint was saved from, then load its weights.
model = ViT(in_channel=1, class_num=3, embed_dim=400, depth=4, num_head=16, mlp_ratio=4).to(device=device)
# weights_only=True: load_state_dict only needs tensors, and refusing arbitrary
# pickled objects avoids executing code embedded in the file.
# map_location lets a checkpoint saved on GPU load on a CPU-only machine.
model.load_state_dict(torch.load('DCWreg_ViT.pth', map_location=device, weights_only=True))
model.eval()
BASE_DIR = os.path.abspath(os.path.dirname(os.getcwd()))
BASE_DIR

FileNotFoundError: [Errno 2] No such file or directory: 'Cleaned_ViT_data.pth'

In [None]:
# Pick one random image from every <phase>/<label> folder under BASE_DIR/afhq
# for a quick qualitative check of the model.
AFHQ_DIR = os.path.join(BASE_DIR, "afhq")
rng = random.Random(0)  # fixed seed so Restart & Run All shows the same images
rnd_images = []
rnd_labels = []  # folder name (ground-truth label) for each entry of rnd_images
for phase in sorted(os.listdir(AFHQ_DIR)):
    phase_path = os.path.join(AFHQ_DIR, phase)
    if not os.path.isdir(phase_path):
        continue  # skip stray files such as .DS_Store
    for label in sorted(os.listdir(phase_path)):
        label_path = os.path.join(phase_path, label)
        if not os.path.isdir(label_path):
            continue
        candidates = sorted(f for f in os.listdir(label_path) if not f.startswith("."))
        if not candidates:
            continue  # an empty folder would otherwise crash the random pick
        rnd_image_path = os.path.join(label_path, rng.choice(candidates))
        # Image.open is lazy and keeps the file open. copy() loads the pixels so
        # the `with` block can close the handle instead of leaking one per image.
        with Image.open(rnd_image_path) as img_file:
            rnd_images.append(img_file.copy())
        rnd_labels.append(label)
            # display(image)

In [None]:
def preprocess_Image(image,transform=None):
    """Convert an image into a float32 tensor on the notebook's ``device``.

    Args:
        image: PIL image (or array-like) to convert.
        transform: optional callable applied first, e.g. a torchvision pipeline.
            When None, the raw pixel array of ``image`` is used unchanged.

    Returns:
        A torch.float32 tensor on ``device``, shaped as ``transform`` produced
        it, or as ``np.array(image)`` when no transform is given.
    """
    # `is not None` rather than truthiness. The original left `patch` unbound
    # (UnboundLocalError) whenever no transform was passed.
    patch = transform(image) if transform is not None else image
    if isinstance(patch, torch.Tensor):
        # Already a tensor (e.g. after ToTensor): skip the NumPy round trip.
        return patch.to(device=device, dtype=torch.float32)
    return torch.as_tensor(np.array(patch), dtype=torch.float32, device=device)
def classified_image(input,model):
    """Classify one preprocessed image, print the result, and return the class name.

    Args:
        input: a single sample without a batch dimension, on the same device as
            ``model`` (e.g. the output of ``preprocess_Image``). A batch
            dimension of size 1 is added here.
        model: callable returning logits of shape (1, num_classes). This function
            does not call ``model.eval()``; the caller is responsible for that.

    Returns:
        'cat' for index 0, 'dog' for index 1, and 'wild' for any other index.
    """
    batch = input.unsqueeze(0)
    with torch.no_grad():
        output = model(batch)
        _, predicted = torch.max(output, 1)
        class_idx = predicted.item()
        print(output, class_idx)
        # Any index other than 0/1 maps to 'wild', exactly like the old ternary chain.
        label = {0: 'cat', 1: 'dog'}.get(class_idx, 'wild')
        print(f"The above image is classified as: {label}")
    return label
# Inference preprocessing. It must produce exactly what the loaded model's
# EmbeddingLayer accepts: in_channel * patch_size**2 * patch_num
# = 1 * 128**2 * 16 = 262144 values per image, i.e. one 512 x 512 grayscale
# channel. A 224 x 224 RGB tensor (3 * 224 * 224 = 150528 values) does not fit
# and fails in the patch projection with a shape mismatch.
# NOTE(review): assumes training used 512 x 512 single-channel inputs with the
# same Normalize(0.5, 0.5). Confirm against the training notebook.
transform = transforms.Compose([
    transforms.Resize((512, 512), interpolation=transforms.InterpolationMode.BICUBIC),
    transforms.Grayscale(num_output_channels=1),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5], std=[0.5])
])

In [None]:
# Show every sampled image. Beneath it, classified_image prints the model's raw
# logits, the predicted index and the class name.
for image in rnd_images:
    display(image)
    img_tensor = preprocess_Image(image=image, transform=transform)
    classified_image(input=img_tensor, model=model)