In [1]:
import os

import cv2 as cv2
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data.dataloader
import torchvision.transforms as transforms
from PIL import Image
from sklearn.model_selection import train_test_split
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm

In [None]:
# Reproducibility: the split below is seeded; also seed torch (weight init,
# DataLoader shuffling, random flips) and NumPy (random sample picking).
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)

# Training-time pipeline. RandomHorizontalFlip is a random augmentation, so it
# should not be used when evaluating.
transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
])

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Keep the full labelled table under its own name instead of overwriting it
# with one of its splits.
labels_df = pd.read_csv('./dataset/raw/aptos-eye/train.csv')

# Stratify on the grade so both splits keep the same class proportions
# (compare the value counts printed below).
train_df, val_df = train_test_split(
    labels_df,
    test_size=0.15,
    random_state=SEED,
    stratify=labels_df['diagnosis'],
)

print(train_df['diagnosis'].value_counts())
print(val_df['diagnosis'].value_counts())

def circle_crop(img):
    """Trim a fundus image to its bright region and mask it to an inscribed circle.

    The image is border-trimmed with ``crop_image_from_gray``, then stretched to a
    square whose side is the larger of the trimmed height/width. Everything outside
    the centred inscribed circle is zeroed, and the dark border is trimmed again.

    Args:
        img: channels-last colour image (must unpack as height, width, depth).

    Returns:
        The circularly masked and re-trimmed image.
    """
    trimmed = crop_image_from_gray(img)
    rows, cols, _ = trimmed.shape
    side = max(rows, cols)
    square = cv2.resize(trimmed, (side, side))

    rows, cols, _ = square.shape
    cx, cy = cols // 2, rows // 2

    # Filled circle (thickness=-1) of 1s marks the pixels to keep.
    keep_mask = np.zeros((rows, cols), np.uint8)
    cv2.circle(keep_mask, (cx, cy), min(cx, cy), 1, thickness=-1)
    masked = cv2.bitwise_and(square, square, mask=keep_mask)

    return crop_image_from_gray(masked)

def crop_image_from_gray(img, tol=7):
    """Crop away dark border rows and columns around an image.

    A row (column) is kept when at least one of its pixels is brighter than
    ``tol``. For a 3-D image, brightness is taken from
    ``cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)``.

    Args:
        img: 2-D grayscale array or 3-D channels-last RGB array.
        tol: intensity threshold; pixels ``<= tol`` count as background.

    Returns:
        The cropped array, or ``img`` unchanged when every pixel is background
        (cropping would otherwise produce an empty image). This applies to both
        2-D and 3-D input.

    Raises:
        ValueError: if ``img`` is neither 2-D nor 3-D.
    """
    if img.ndim == 2:
        mask = img > tol
    elif img.ndim == 3:
        mask = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY) > tol
    else:
        raise ValueError(f"expected a 2-D or 3-D image, got ndim={img.ndim}")

    keep_rows = mask.any(1)
    keep_cols = mask.any(0)
    if not keep_rows.any():
        # Nothing above the threshold: an empty crop would break cv2.resize
        # downstream, so keep the original image.
        return img
    # Boolean row selection then column selection; trailing channel axis (if
    # any) is carried along, so all channels are cropped identically.
    return img[keep_rows][:, keep_cols]
    

class CustomDataset(Dataset):
    """Fundus-image dataset yielding ``(image, label)`` pairs.

    Each row of ``df`` must provide ``id_code`` (the image file stem) and
    ``diagnosis`` (returned unchanged as the label). Images are read from
    ``img_dir_path`` and preprocessed in this order:

    1. BGR to RGB conversion.
    2. Resize to 256x256.
    3. ``circle_crop``.
    4. Blur-based contrast enhancement.
    5. Conversion to a PIL image, then ``transform`` if one is given.
    """

    def __init__(self, df, img_dir_path, extension =".png", transform=None):
        self.df = df
        self.img_dir_path = img_dir_path
        self.extension = extension
        self.transform = transform

    def __len__(self):
        return len(self.df)

    def __getitem__(self, index):
        frame = self.df.iloc[index]
        img_name = frame['id_code']
        label = frame['diagnosis']
        img_path = f"{self.img_dir_path}/{img_name}{self.extension}"

        # cv2.imread returns None instead of raising on a missing or undecodable
        # file, which would otherwise surface as an opaque cvtColor error.
        if not os.path.isfile(img_path):
            raise FileNotFoundError(f"Image not found: {img_path}")
        img = cv2.imread(img_path)
        if img is None:
            raise OSError(f"OpenCV could not decode image: {img_path}")

        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        img = cv2.resize(img, (256, 256))
        img = circle_crop(img)
        # saturate(4*img - 4*blur + 128): subtracts the local average colour to
        # boost local contrast.
        img = cv2.addWeighted(img, 4, cv2.GaussianBlur(img, (0, 0), 10), -4, 128)
        image = Image.fromarray(img)

        if self.transform is not None:
            image = self.transform(image)

        return image, label
    

batch_size = 40
TRAIN_IMG_DIR = "./dataset/raw/aptos-eye/train_images"

# Validation must be deterministic: reuse the training pipeline but drop the
# random horizontal flip so every epoch scores exactly the same images.
eval_transform = transforms.Compose(
    [t for t in transform.transforms if not isinstance(t, transforms.RandomHorizontalFlip)]
)

train_dataset = CustomDataset(df=train_df, img_dir_path=TRAIN_IMG_DIR, transform=transform)
val_dataset = CustomDataset(df=val_df, img_dir_path=TRAIN_IMG_DIR, transform=eval_transform)

train_dataloader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
# Evaluation metrics do not depend on sample order, so no need to shuffle.
val_dataloader = DataLoader(dataset=val_dataset, batch_size=batch_size, shuffle=False)


def visualizeImage(count=5):
    """Show the first ``count`` training samples in one row, each titled by its label."""
    plt.figure(figsize=(16, 9))
    for pos in range(1, count + 1):
        ax = plt.subplot(1, count, pos)
        tensor_img, grade = train_dataset[pos - 1]
        ax.set_title(grade)
        # ToTensor gives (C, H, W); imshow wants channels last.
        ax.imshow(tensor_img.permute(1, 2, 0))


visualizeImage()



In [3]:
class CnnModel(nn.Module):
    """LeNet-style CNN mapping a (N, 3, 256, 256) batch to 5 class logits.

    Spatial sizes:

    - 256 -> conv5 -> 252
    - 252 -> avgpool(k=2, stride=5) -> 51
    - 51 -> conv5 -> 47
    - 47 -> avgpool(k=2, stride=5) -> 10

    The flattened feature size is therefore 16 * 10 * 10 = 1600, which fixes the
    input resolution at 256x256. Because each pool's stride (5) exceeds its kernel
    (2), it subsamples and skips 3 of every 5 rows/columns.
    """

    def __init__(self):
        super(CnnModel, self).__init__()

        self.cnn_model = nn.Sequential(
            nn.Conv2d(in_channels=3, out_channels=6, kernel_size=5),
            nn.ReLU(),
            nn.AvgPool2d(kernel_size=2, stride=5),
            nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5),
            nn.ReLU(),
            nn.AvgPool2d(kernel_size=2, stride=5),
        )

        self.fc_model = nn.Sequential(
            nn.Linear(in_features=1600, out_features=120),
            nn.ReLU(),
            nn.Linear(in_features=120, out_features=84),
            nn.ReLU(),
            nn.Linear(in_features=84, out_features=5),
        )

    def forward(self, x):
        """Return raw, unnormalised class logits of shape (N, 5).

        No sigmoid/softmax is applied: nn.CrossEntropyLoss performs log-softmax
        internally and expects logits, and squashing them with a sigmoid first
        cripples the gradients. argmax-based predictions are unaffected.
        """
        x = self.cnn_model(x)
        x = x.view(x.size(0), -1)
        return self.fc_model(x)

In [4]:

def visualizeResult(train_loss, train_acc, val_loss, val_acc, epochs):
    """Plot per-epoch loss curves, then per-epoch accuracy curves, as two figures.

    Each figure overlays the training and validation series against ``range(epochs)``
    and is shown with ``plt.show()`` before the next one is drawn.
    """
    x_axis = range(epochs)
    panels = (
        ("Model's Loss Visualization", "Loss",
         (train_loss, "training loss"), (val_loss, "validation loss")),
        ("Model's Accuracy Visualization", "Accuracy",
         (train_acc, "training accuracy"), (val_acc, "validation accuracy")),
    )
    for title, y_name, *series in panels:
        plt.title(title)
        for values, curve_name in series:
            plt.plot(x_axis, values, label=curve_name)
        plt.legend()
        plt.xlabel("Epochs")
        plt.ylabel(y_name)
        plt.show()


In [None]:
epochs = 35


def _run_epoch(model, loader, criterion, optimizer=None, progress=None):
    """Make one pass over ``loader`` and return ``(mean_loss, accuracy)``.

    When ``optimizer`` is given, runs in train mode with backward/step. Otherwise
    runs in eval mode with gradients disabled. Loss and accuracy are averaged per
    sample rather than per batch, so a smaller final batch does not skew the
    epoch metrics. ``progress`` optionally wraps the loader (e.g. ``tqdm``).
    Returns ``(nan, nan)`` for an empty loader.
    """
    is_train = optimizer is not None
    model.train(is_train)
    loss_sum, n_correct, n_seen = 0.0, 0, 0
    batches = progress(loader) if progress is not None else loader
    with torch.set_grad_enabled(is_train):
        for imgs, labels in batches:
            imgs = imgs.to(device)
            labels = labels.to(device)
            y_hat = model(imgs)
            loss = criterion(y_hat, labels)
            if is_train:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
            batch_n = labels.size(0)
            # Assumes criterion uses reduction='mean' (the CrossEntropyLoss
            # default); re-weight by batch size to get a per-sample sum.
            loss_sum += loss.item() * batch_n
            n_correct += (torch.argmax(y_hat, 1) == labels).sum().item()
            n_seen += batch_n
    if n_seen == 0:
        return float('nan'), float('nan')
    return loss_sum / n_seen, n_correct / n_seen


def training(
        model, criterion, optimizer, epochs, save_path='swint.pt'
):
    """Train ``model`` for ``epochs`` epochs and validate after each one.

    Reads the notebook-level ``train_dataloader``, ``val_dataloader`` and
    ``device``. After training, saves the state dict to ``save_path`` and plots
    the curves with ``visualizeResult``. The default ``save_path`` keeps the
    original file name, even though the model here is a CNN.

    Returns:
        dict with per-epoch lists ``train_loss``, ``train_acc``, ``val_loss``,
        ``val_acc``.
    """
    train_loss, train_acc, val_loss, val_acc = [], [], [], []

    for epoch in range(epochs):
        t_loss, t_acc = _run_epoch(model, train_dataloader, criterion, optimizer, progress=tqdm)
        v_loss, v_acc = _run_epoch(model, val_dataloader, criterion)

        train_loss.append(t_loss)
        train_acc.append(t_acc)
        val_loss.append(v_loss)
        val_acc.append(v_acc)

        print(f'Epoch {epoch+1}  Train Loss: {train_loss[epoch]:.2f},  Train accuracy: {train_acc[epoch]:.2f}, Validation Loss: {val_loss[epoch]:.2f},  Validation accuracy: {val_acc[epoch]:.2f}')

    torch.save(model.state_dict(), save_path)
    visualizeResult(train_loss, train_acc, val_loss, val_acc, epochs)
    return {
        'train_loss': train_loss,
        'train_acc': train_acc,
        'val_loss': val_loss,
        'val_acc': val_acc,
    }
    


# Adam at 1e-3 with cross-entropy over the 5 grades.
learning_rate = 0.001
model = CnnModel().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
criterion = nn.CrossEntropyLoss()

training(model, criterion, optimizer, epochs)



In [None]:
eyeClass = {0:'Normal', 1: 'Mild', 2:'Moderate',
             3:'Severe', 4:'Proliferative'}


def inference(
        model: torch.nn.Module,
        device: torch.device,
        image: Image.Image,
):
    """Predict the grade of one raw fundus image, print it and return it.

    The image goes through the same preprocessing as ``CustomDataset``: resize to
    256x256, ``circle_crop``, then blur-based contrast enhancement. It is then
    passed through ``transform`` with its random horizontal flip removed. This
    makes the prediction deterministic and matches what the model saw in training.

    Returns:
        The predicted class name from ``eyeClass``.
    """
    model = model.to(device)
    model.eval()

    # Mirror CustomDataset.__getitem__ (which works on RGB uint8 arrays).
    img = np.asarray(image.convert("RGB"))
    img = cv2.resize(img, (256, 256))
    img = circle_crop(img)
    img = cv2.addWeighted(img, 4, cv2.GaussianBlur(img, (0, 0), 10), -4, 128)

    eval_transform = transforms.Compose(
        [t for t in transform.transforms if not isinstance(t, transforms.RandomHorizontalFlip)]
    )
    tensor = eval_transform(Image.fromarray(img))

    with torch.no_grad():
        prediction = model(tensor.to(device).unsqueeze(0)).squeeze(0)
    print(prediction)
    predict_index = torch.argmax(prediction, 0).item()
    eye_diagnosis = eyeClass[predict_index]
    print(f'prediction: {eye_diagnosis}')
    return eye_diagnosis


# Kaggle's sample_submission.csv is a submission template: its diagnosis column
# is a placeholder, not ground truth, so it cannot be compared with a
# prediction. Use a labelled image from the held-out validation split instead.
sample_row = val_df.iloc[np.random.randint(len(val_df))]
row_img = sample_row["id_code"]
print(f"img: {row_img}  label: {eyeClass[sample_row['diagnosis']]}")
with Image.open(f"./dataset/raw/aptos-eye/train_images/{row_img}.png") as raw_image:
    # convert() forces a full load, so the file handle can be closed here.
    patient_eye_image = raw_image.convert("RGB")

inference(model, device, patient_eye_image)

        
