In [1]:
import os

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
import torch.utils.data.dataloader
import matplotlib.pyplot as plt
import pandas as pd
from sklearn.model_selection import train_test_split
from PIL import Image
import numpy as np
from tqdm import tqdm
from torchvision.models.swin_transformer import swin_t
from torchvision.models import resnet50, ResNet50_Weights
import cv2 as cv2

In [None]:
# Seed used both for the train/val split and for torch's RNG (weight init,
# DataLoader shuffling), so a Restart & Run All reproduces the same run.
SEED = 8
torch.manual_seed(SEED)

# Training-time pipeline: RandomHorizontalFlip is augmentation and should only
# be applied to training data.
transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
])

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

df = pd.read_csv('./dataset/raw/aptos-eye/train.csv')

# Stratify on the label so the 10% validation split keeps the class proportions
# of the full dataset (train_test_split raises if a class has < 2 samples).
train_df, val_df = train_test_split(
    df, test_size=0.1, random_state=SEED, stratify=df['diagnosis'],
)

print(train_df['diagnosis'].value_counts())
print(val_df['diagnosis'].value_counts())

    

class CustomDataset(Dataset):
    """Map-style dataset that loads images listed in a DataFrame.

    Each row of ``df`` must provide an ``id_code`` column (the image file stem)
    and a ``diagnosis`` column (the label returned alongside the image). The
    image for a row is read from ``{img_dir_path}/{id_code}{extension}``.

    Args:
        df: DataFrame with ``id_code`` and ``diagnosis`` columns.
        img_dir_path: directory containing the image files.
        extension: file extension appended to ``id_code`` (default ``.png``).
        transform: optional callable applied to the loaded PIL image.
    """

    def __init__(self, df, img_dir_path, extension=".png", transform=None):
        self.df = df
        self.img_dir_path = img_dir_path
        self.extension = extension
        self.transform = transform

    def __len__(self):
        return len(self.df)

    def __getitem__(self, index):
        """Return ``(image, label)`` for positional row ``index``.

        Raises:
            FileNotFoundError: the image file does not exist.
            OSError: the file exists but OpenCV could not decode it.
        """
        frame = self.df.iloc[index]
        img_name = frame['id_code']
        label = frame['diagnosis']
        img_path = f"{self.img_dir_path}/{img_name}{self.extension}"

        # cv2.imread does not raise on a bad path: it returns None, which would
        # otherwise surface as an opaque error inside cv2.cvtColor below.
        if not os.path.isfile(img_path):
            raise FileNotFoundError(f"Image not found: {img_path}")
        img = cv2.imread(img_path)
        if img is None:
            raise OSError(f"Could not decode image: {img_path}")

        # OpenCV loads channels as BGR; PIL/torchvision expect RGB.
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        img = cv2.resize(img, (256, 256))
        image = Image.fromarray(img)

        if self.transform is not None:
            image = self.transform(image)

        return image, label
    

batch_size = 40
train_img_dir = "./dataset/raw/aptos-eye/train_images"

# Validation must be deterministic: same resize/ToTensor steps as the training
# `transform`, but without RandomHorizontalFlip (training-only augmentation).
eval_transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
])

train_dataset = CustomDataset(df=train_df, img_dir_path=train_img_dir, transform=transform)
val_dataset = CustomDataset(df=val_df, img_dir_path=train_img_dir, transform=eval_transform)

train_dataloader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
# Sample order does not affect validation metrics, so there is no need to shuffle.
val_dataloader = DataLoader(dataset=val_dataset, batch_size=batch_size, shuffle=False)



def visualizeImage(count=5):
    """Plot the first `count` samples of `train_dataset` in one row, titled by label."""
    fig = plt.figure(figsize=(16, 9))
    for position in range(1, count + 1):
        ax = fig.add_subplot(1, count, position)
        sample_img, sample_label = train_dataset[position - 1]
        ax.set_title(sample_label)
        # The dataset yields channel-first tensors; imshow wants channel-last.
        ax.imshow(sample_img.permute(1, 2, 0))


visualizeImage()


In [3]:
class SwinResModel(nn.Module):
    """Classifier that fuses a Swin-T branch and a ResNet-50 branch.

    Both backbones are built with ``weights=None`` (randomly initialised) and
    their classification heads are replaced by linear projections of width
    ``branch_features``. The two projections are concatenated and passed
    through a BatchNorm -> ReLU -> Dropout -> Linear head that outputs
    ``out_features`` logits.

    Args:
        out_features: number of output logits/classes (default 5). Previously
            this argument was accepted but ignored.
        branch_features: width of each backbone's projection (default 150, so
            the fused vector is 300-wide, matching the original architecture
            and state_dict layout).
    """

    def __init__(self, out_features=5, branch_features=150):
        super().__init__()
        self.swinModel = swin_t(weights=None)
        swin_in_features = self.swinModel.head.in_features
        self.swinModel.head = nn.Linear(swin_in_features, branch_features)

        self.resModel = resnet50(weights=None)
        res_in_features = self.resModel.fc.in_features
        self.resModel.fc = nn.Linear(res_in_features, branch_features)

        fused_features = 2 * branch_features
        # NOTE: in train mode nn.BatchNorm1d raises on a batch of a single
        # sample, so a final training batch of size 1 would fail.
        self.fc = nn.Sequential(
            nn.BatchNorm1d(fused_features),
            nn.ReLU(),
            nn.Dropout(0.4),
            nn.Linear(fused_features, out_features),
        )

    def forward(self, x):
        """Return logits of shape ``(batch, out_features)``.

        ``x`` is fed unchanged to both backbones; presumably an image batch of
        shape (batch, 3, H, W) — TODO confirm against the data pipeline.
        """
        out_swin = self.swinModel(x)
        out_res = self.resModel(x)
        # Concatenate along the feature dimension: (batch, 2 * branch_features).
        cat = torch.cat((out_swin, out_res), 1)
        return self.fc(cat)



In [4]:

def visualizeResult(train_loss, train_acc, val_loss, val_acc, epochs):
    """Draw two figures, loss and accuracy, comparing training vs validation.

    Each metric argument is a per-epoch sequence plotted against
    ``range(epochs)``, so its length must equal ``epochs``.
    """
    epoch_axis = range(epochs)
    panels = (
        ("Model's Loss Visualization", "Loss",
         ((train_loss, "training loss"), (val_loss, "validation loss"))),
        ("Model's Accuracy Visualization", "Accuracy",
         ((train_acc, "training accuracy"), (val_acc, "validation accuracy"))),
    )
    for panel_title, y_label, curves in panels:
        plt.title(panel_title)
        for values, curve_name in curves:
            plt.plot(epoch_axis, values, label=curve_name)
        plt.legend()
        plt.xlabel("Epochs")
        plt.ylabel(y_label)
        plt.show()


In [None]:
epochs = 30


def _run_epoch(model, loader, criterion, device, optimizer=None):
    """Run one full pass over ``loader`` and return ``(mean_loss, accuracy)``.

    With an ``optimizer`` the pass trains: train mode, a tqdm progress bar,
    backprop and an optimizer step per batch. Without one it only evaluates:
    eval mode and no gradient tracking.

    Both metrics are averaged over samples, not batches, so a smaller final
    batch is weighted correctly. The loss weighting assumes ``criterion``
    returns the batch mean (``nn.CrossEntropyLoss``'s default reduction).
    An empty loader yields ``(nan, nan)`` instead of a ZeroDivisionError.
    """
    is_training = optimizer is not None
    model.train(is_training)  # train(False) is equivalent to eval()
    loss_sum = 0.0
    correct = 0
    seen = 0
    batches = tqdm(loader) if is_training else loader
    with torch.set_grad_enabled(is_training):
        for imgs, labels in batches:
            imgs = imgs.to(device)
            labels = labels.to(device)
            if is_training:
                optimizer.zero_grad()
            y_hat = model(imgs)
            loss = criterion(y_hat, labels)
            if is_training:
                loss.backward()
                optimizer.step()

            batch_len = labels.size(0)
            loss_sum += loss.item() * batch_len
            correct += (torch.argmax(y_hat, 1) == labels).sum().item()
            seen += batch_len

    if seen == 0:
        return float("nan"), float("nan")
    return loss_sum / seen, correct / seen


def training(
        model, criterion, optimizer, epochs,
        train_loader=None, val_loader=None,
        save_path='/kaggle/working/resswin.pt',
):
    """Train ``model`` for ``epochs`` epochs, validating after each one.

    Prints per-epoch metrics, saves the final ``state_dict`` to ``save_path``
    and plots the loss/accuracy curves with ``visualizeResult``.

    Args:
        model: network to train; batches are moved to the notebook-level
            ``device``, so the model is expected to live there too.
        criterion: loss function, e.g. ``nn.CrossEntropyLoss()``.
        optimizer: optimizer over ``model``'s parameters.
        epochs: number of passes over the training data.
        train_loader: training batches; defaults to ``train_dataloader``.
        val_loader: validation batches; defaults to ``val_dataloader``.
        save_path: destination for the final weights; its parent directory is
            created if it does not exist.
    """
    if train_loader is None:
        train_loader = train_dataloader
    if val_loader is None:
        val_loader = val_dataloader

    train_loss = []
    train_acc = []
    val_loss = []
    val_acc = []

    for epoch in range(epochs):
        t_loss, t_acc = _run_epoch(model, train_loader, criterion, device, optimizer)
        train_loss.append(t_loss)
        train_acc.append(t_acc)

        v_loss, v_acc = _run_epoch(model, val_loader, criterion, device)
        val_loss.append(v_loss)
        val_acc.append(v_acc)

        print(f'Epoch {epoch+1}  Train Loss: {train_loss[epoch]:.2f},  Train accuracy: {train_acc[epoch]:.2f}, Validation Loss: {val_loss[epoch]:.2f},  Validation accuracy: {val_acc[epoch]:.2f}')

    # torch.save does not create missing directories; do it here so a long
    # training run is not lost at the very last step.
    save_dir = os.path.dirname(save_path)
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)
    torch.save(model.state_dict(), save_path)
    visualizeResult(train_loss, train_acc, val_loss, val_acc, epochs)
    


# Hyper-parameters and objects for the training run.
learning_rate = 0.001
criterion = nn.CrossEntropyLoss()
model = SwinResModel().to(device)
optimizer = torch.optim.Adam(params=model.parameters(), lr=learning_rate)

training(model=model, criterion=criterion, optimizer=optimizer, epochs=epochs)



In [None]:
#inference

# Class index -> human-readable diagnosis, indexed by the model's argmax output.
eyeClass = {
    0: 'Normal',
    1: 'Mild',
    2: 'Moderate',
    3: 'Severe',
    4: 'Proliferative',
}


def inference(
        model: torch.nn.Module,
        device: torch.device,
        image: "Image.Image",
        preprocess=None,
):
    """Classify a single image, print the result and return the diagnosis name.

    Args:
        model: trained classifier returning one row of logits per input.
        device: device to run the model on (the model is moved there).
        image: input image, typically a PIL image.
        preprocess: callable turning ``image`` into a channel-first tensor.
            Defaults to a deterministic Resize((256, 256)) + ToTensor pipeline.
            The training ``transform`` is not reused because its
            RandomHorizontalFlip would make predictions random.

    Returns:
        The ``eyeClass`` label of the highest-scoring logit.
    """
    if preprocess is None:
        preprocess = transforms.Compose([
            transforms.Resize((256, 256)),
            transforms.ToTensor(),
        ])
    model = model.to(device)
    model.eval()
    tensor = preprocess(image)
    with torch.no_grad():
        img = tensor.to(device).unsqueeze(0)  # add a batch dimension of 1
        prediction = model(img).squeeze(0)
        print(prediction)
        predict_index = torch.argmax(prediction, 0).item()
        eye_diagnosis = eyeClass[predict_index]
        print(f'prediction: {eye_diagnosis}')
    return eye_diagnosis


test_df = pd.read_csv("./dataset/raw/aptos-eye/sample_submission.csv")
n = np.random.randint(test_df.shape[0])
row = test_df.iloc[n]
row_img = row["id_code"]
# NOTE(review): this file is a *sample submission*, so its `diagnosis` column is
# presumably a placeholder rather than ground truth — confirm before trusting "label".
print(f"img: {row_img}  label: {eyeClass[row['diagnosis']]}")

# Open inside `with` so the file handle is released. convert("RGB") loads the
# pixels and forces 3 channels, matching the cv2.imread-based training images.
with Image.open(f"./dataset/raw/aptos-eye/test_images/{row_img}.png") as raw_image:
    patient_eye_image = raw_image.convert("RGB")


inference(model, device, patient_eye_image)

        
