In [None]:
import os
import torch
import wandb
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import torch.nn.functional as F
import seaborn as sns
import torch.nn as nn

from PIL import Image
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader, random_split
from transformers import ViTForImageClassification
from sklearn.metrics import accuracy_score, precision_recall_fscore_support
from kaggle_secrets import UserSecretsClient
from tqdm.auto import tqdm as tqdm_progress

In [None]:
# Run on the GPU when CUDA is usable, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

print(f"Using {device} device")

In [None]:
def init_wandb(project_name="vit-retinal-disease-cls", config=None):
    """Log in to Weights & Biases and start a run.

    The API key is read from the Kaggle secret named "wandb", exported
    through the WANDB_API_KEY environment variable and used to log in
    before the run is created.

    Args:
        project_name: W&B project the run is created in.
        config: Optional configuration object forwarded to ``wandb.init``.

    Returns:
        The value returned by ``wandb.init``, or None when any step fails
        (the error is printed instead of being raised).
    """
    run_options = dict(
        project=project_name,
        config=config,
        tags=["ViT", "Retinal Disease"],
        notes="Vision Transformer on Retinal Disease Classification Dataset",
    )

    try:
        wandb_api_key = UserSecretsClient().get_secret("wandb")
        os.environ['WANDB_API_KEY'] = wandb_api_key
        wandb.login(key=wandb_api_key)

        active_run = wandb.init(**run_options)
        print("W&B successfully initialized")
    except Exception as e:
        # Best effort: the notebook keeps working without experiment tracking.
        print(f"Error initializing W&B: {str(e)}")
        return None

    return active_run

In [None]:
class RetinalDiseaseDataset(Dataset):
    """Multi-label retinal image dataset backed by a labels CSV.

    The first CSV column identifies the image file (".png" is appended when
    the value has no .png/.jpg/.jpeg extension). Every column after the
    first is returned as part of the float32 label vector.

    Args:
        csv_file: Path to the labels CSV.
        img_dir: Directory that contains the image files.
        transform: Optional callable applied to the RGB PIL image. When
            omitted, the image is resized to 224x224, converted to a tensor
            and normalised with the mean/std constants below.
    """

    def __init__(self, csv_file, img_dir, transform=None):
        self.labels_frame = pd.read_csv(csv_file)
        self.img_dir = img_dir
        # Column names other than 'ID' and 'Disease_Risk'. Note that
        # __getitem__ does not use this list: it returns all columns after
        # the first one.
        self.diseases = [col for col in self.labels_frame.columns if col not in ['ID', 'Disease_Risk']]

        self.transform = transform or transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
        ])

    def __len__(self):
        """Return the number of rows in the labels CSV."""
        return len(self.labels_frame)

    def __getitem__(self, idx):
        """Load the image and label vector for row ``idx``.

        Returns:
            Tuple ``(image, labels)`` where ``image`` is the transformed RGB
            image and ``labels`` is a float32 tensor built from every column
            after the first.

        Raises:
            FileNotFoundError: If the image file for this row does not exist.
        """
        img_filename = str(self.labels_frame.iloc[idx, 0])

        if not img_filename.lower().endswith(('.png', '.jpg', '.jpeg')):
            img_filename += '.png'

        img_name = os.path.join(self.img_dir, img_filename)

        # Fail loudly instead of substituting another image: pairing a
        # different picture with this row's labels would silently corrupt
        # training and evaluation data.
        if not os.path.exists(img_name):
            raise FileNotFoundError(
                f"Image for dataset row {idx} not found: {img_name}"
            )

        # The context manager releases the file handle once the pixel data
        # has been converted.
        with Image.open(img_name) as source_image:
            image = source_image.convert('RGB')

        if self.transform:
            image = self.transform(image)

        labels = torch.tensor(
            self.labels_frame.iloc[idx, 1:].values,
            dtype=torch.float32
        )

        return image, labels

In [None]:
def visualize_dataset_samples(dataloader, num_samples=12, title="Dataset Samples"):
    """Show images from the first batch of ``dataloader`` in a square grid.

    Each image is de-normalised with the mean/std constants used elsewhere
    in this notebook, clipped to [0, 1] and titled with its label vector.

    Args:
        dataloader: Iterable whose first item is an ``(images, labels)`` batch.
        num_samples: Maximum number of images to draw; fewer are drawn when
            the batch is smaller.
        title: Figure-level title.
    """
    norm_std = [0.229, 0.224, 0.225]
    norm_mean = [0.485, 0.456, 0.406]

    # Smallest square grid that can hold num_samples tiles.
    grid_side = int(np.ceil(np.sqrt(num_samples)))

    plt.figure(figsize=(15, 15))
    plt.suptitle(title)

    images, labels = next(iter(dataloader))
    shown = min(num_samples, len(images))

    for i in range(shown):
        plt.subplot(grid_side, grid_side, i + 1)

        sample = images[i]
        if isinstance(sample, torch.Tensor):
            # Channels-first tensor -> channels-last array for imshow.
            sample = sample.cpu().permute(1, 2, 0).numpy()

        # Undo the normalisation, then clamp to the displayable range.
        sample = np.clip(sample * norm_std + norm_mean, 0, 1)

        plt.imshow(sample)
        plt.title(f"Labels: {labels[i].numpy()}")
        plt.axis('off')

    plt.tight_layout()
    plt.show()

In [None]:
def analyze_retinal_dataset(csv_path):
    """Plot and summarise the label distribution of a labels CSV.

    Every column except 'ID' and 'Disease_Risk' is treated as a disease
    label. A 2x2 figure is drawn (per-disease percentage bar chart, label
    correlation heatmap, pie chart of the ten most frequent diseases and a
    histogram of labels per image) and a short text report is printed.

    Args:
        csv_path: Path to the labels CSV.

    Returns:
        Dict with the keys 'total_samples', 'multi_label_samples',
        'multi_label_percentage', 'most_common_diseases' and
        'least_common_diseases'.

    Raises:
        ValueError: If the CSV contains no rows.
    """
    df = pd.read_csv(csv_path)
    diseases = [col for col in df.columns if col not in ['ID', 'Disease_Risk']]

    total_samples = len(df)
    if total_samples == 0:
        # Percentages and histogram bins are undefined without any samples.
        raise ValueError(f"No samples found in {csv_path}; nothing to analyze")

    disease_flags = df[diseases]
    class_distribution = disease_flags.sum().sort_values(ascending=False)
    percentage_per_label = (class_distribution / total_samples * 100).round(2)

    disease_correlation = disease_flags.corr()

    labels_count = disease_flags.sum(axis=1)
    multi_label_count = (labels_count > 1).sum()
    multi_label_percentage = (multi_label_count / total_samples * 100).round(2)

    fig, axes = plt.subplots(2, 2, figsize=(20, 15))

    ax_bar = axes[0, 0]
    percentage_per_label.plot(kind='bar', ax=ax_bar)
    ax_bar.set_title('Disease Distribution')
    ax_bar.set_xlabel('Disease')
    ax_bar.set_ylabel('% of Dataset')
    ax_bar.tick_params(axis='x', labelrotation=90)

    ax_corr = axes[0, 1]
    sns.heatmap(disease_correlation, cmap='coolwarm', center=0,
                annot=False, cbar_kws={'label': 'Correlation'}, ax=ax_corr)
    ax_corr.set_title('Disease Correlation')

    ax_pie = axes[1, 0]
    percentage_per_label.head(10).plot(kind='pie', autopct='%1.1f%%', ax=ax_pie)
    ax_pie.set_title('Top-10 Diseases')

    ax_hist = axes[1, 1]
    # Bins start at 0 so that images without any disease label are counted
    # too; values outside the bin range would be dropped from the histogram.
    max_labels = int(labels_count.max())
    labels_count.plot(kind='hist', bins=range(0, max_labels + 2), ax=ax_hist)
    ax_hist.set_title('Distribution of Labels per Image')
    ax_hist.set_xlabel('Number of Diseases')
    ax_hist.set_ylabel('Frequency')

    fig.tight_layout()
    plt.show()

    print("\nDataset Analysis Report")
    print(f"Total number of samples: {total_samples}")
    print(f"\nSamples with multiple labels: {multi_label_count} ({multi_label_percentage}%)")

    print("\nTop-5 Most Common Diseases:")
    for disease, percentage in percentage_per_label.head().items():
        print(f"{disease}: {percentage}%")

    print("\nTop-5 Least Common Diseases:")
    for disease, percentage in percentage_per_label.tail().items():
        print(f"{disease}: {percentage}%")

    return {
        'total_samples': total_samples,
        'multi_label_samples': multi_label_count,
        'multi_label_percentage': multi_label_percentage,
        'most_common_diseases': percentage_per_label.head(),
        'least_common_diseases': percentage_per_label.tail()
    }

In [None]:
def create_dataloader(csv_path, img_dir, batch_size=32, shuffle=True, train=True):
    """Build a DataLoader over a RetinalDiseaseDataset.

    Images are resized to 224x224, converted to tensors and normalised.
    With ``train`` set, a random horizontal flip and a random rotation of up
    to 10 degrees are inserted after the resize.

    Args:
        csv_path: Path to the labels CSV.
        img_dir: Directory that contains the images.
        batch_size: Number of samples per batch.
        shuffle: Whether the loader shuffles the samples.
        train: Whether to add the random augmentations.

    Returns:
        A DataLoader with ``num_workers=0`` and ``pin_memory=True``.
    """
    augmentations = (
        [transforms.RandomHorizontalFlip(), transforms.RandomRotation(10)]
        if train
        else []
    )

    pipeline = transforms.Compose([
        transforms.Resize((224, 224)),
        *augmentations,
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])

    retinal_dataset = RetinalDiseaseDataset(
        csv_file=csv_path,
        img_dir=img_dir,
        transform=pipeline,
    )

    loader_options = dict(
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=0,
        pin_memory=True,
    )
    return DataLoader(retinal_dataset, **loader_options)

In [None]:
class FocalLoss(nn.Module):
    """Focal loss on raw logits, averaged over all elements.

    The element-wise binary cross-entropy is scaled by
    ``alpha * (1 - pt) ** gamma`` with ``pt = exp(-BCE)``. ``alpha`` is
    applied as one uniform factor to every element.

    Args:
        alpha: Constant multiplier applied to every element of the loss.
        gamma: Exponent of the ``(1 - pt)`` modulating term.
    """

    def __init__(self, alpha=0.25, gamma=2):
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma

    def forward(self, inputs, targets):
        """Return the mean focal loss for ``inputs`` (logits) and ``targets``."""
        elementwise_bce = F.binary_cross_entropy_with_logits(
            inputs, targets, reduction='none'
        )
        # exp(-BCE) is the probability the model assigns to the target value.
        target_prob = torch.exp(-elementwise_bce)
        focal_weight = self.alpha * (1 - target_prob) ** self.gamma
        return (focal_weight * elementwise_bce).mean()

In [None]:
def create_model(num_classes=46, pretrained=True):
    """Build a ViT image classifier with ``num_classes`` output logits.

    Args:
        num_classes: Size of the classification head.
        pretrained: When True, load the 'google/vit-base-patch16-224'
            checkpoint with a head of the requested size. When False, build
            a randomly initialised 6-layer ViT with dropout 0.1 and do not
            load any checkpoint weights.

    Returns:
        A ``ViTForImageClassification`` instance.
    """
    if pretrained:
        return ViTForImageClassification.from_pretrained(
            'google/vit-base-patch16-224',
            num_labels=num_classes,
            # Lets the classifier head be re-initialised when its size
            # differs from the one stored in the checkpoint.
            ignore_mismatched_sizes=True
        )

    # Only this branch needs the config class, so import it locally.
    from transformers import ViTConfig

    # Build the architecture from a config instead of from_pretrained():
    # "not pretrained" must not load checkpoint weights, and a checkpoint
    # head of a different size cannot be loaded without
    # ignore_mismatched_sizes.
    config = ViTConfig(
        num_labels=num_classes,
        num_hidden_layers=6,
        hidden_dropout_prob=0.1,
        attention_probs_dropout_prob=0.1,
    )
    return ViTForImageClassification(config)

In [None]:
def log_image_predictions(model, val_loader, device, wandb_run=None):
    """Log predictions for the first validation batch to W&B.

    The model is switched to eval mode and run on the first batch of
    ``val_loader``. Sigmoid outputs above 0.5 count as predicted labels.
    When ``wandb_run`` is truthy, up to 16 de-normalised images are logged
    under "sample_predictions" with their true and predicted label indices
    as caption.

    Args:
        model: Module whose forward output exposes ``.logits``.
        val_loader: Iterable whose first item is an ``(images, labels)`` batch.
        device: Device the images are moved to before inference.
        wandb_run: Active W&B run; nothing is logged when it is falsy.
    """
    model.eval()
    batch_images, batch_targets = next(iter(val_loader))
    batch_images = batch_images.to(device)
    true_labels = batch_targets.cpu().numpy()

    with torch.no_grad():
        logits = model(batch_images).logits
        pred_probs = torch.sigmoid(logits).cpu().numpy()
        pred_labels = (pred_probs > 0.5).astype(int)

    if not wandb_run:
        return

    norm_std = [0.229, 0.224, 0.225]
    norm_mean = [0.485, 0.456, 0.406]

    wandb_images = []
    for i in range(min(16, len(batch_images))):
        # Channels-last, de-normalised and clamped to the displayable range.
        pixels = batch_images[i].cpu().permute(1, 2, 0).numpy()
        pixels = np.clip(pixels * norm_std + norm_mean, 0, 1)

        true_diseases = np.where(true_labels[i])[0]
        pred_diseases = np.where(pred_labels[i])[0]

        wandb_images.append(
            wandb.Image(
                pixels,
                caption=f"True: {true_diseases}, Pred: {pred_diseases}"
            )
        )

    wandb.log({"sample_predictions": wandb_images})

In [None]:
def train_model(model, train_loader, val_loader, device, epochs=20, learning_rate=1e-4):
    """Train ``model`` with focal loss and evaluate it after every epoch.

    Each epoch runs one optimisation pass over ``train_loader`` followed by
    one gradient-free pass over ``val_loader``. Sigmoid outputs are
    thresholded at 0.5 and scored with micro-averaged precision, recall and
    F1. The metrics are printed and, when a W&B run could be started, also
    logged to W&B.

    Args:
        model: Module whose forward output exposes ``.logits``.
        train_loader: Loader of ``(images, labels)`` training batches; its
            ``batch_size`` is recorded in the W&B config.
        val_loader: Loader of ``(images, labels)`` validation batches.
        device: Device every batch is moved to.
        epochs: Number of training epochs.
        learning_rate: AdamW learning rate, also recorded in the W&B config.

    Returns:
        The same ``model`` object after training.
    """
    try:
        wandb_run = init_wandb(config={
            "learning_rate": learning_rate,
            "epochs": epochs,
            "batch_size": train_loader.batch_size
        })
    except Exception as e:
        print(f"WandB initialization error: {e}")
        wandb_run = None

    criterion = FocalLoss()
    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)

    try:
        for epoch in range(epochs):
            # Training pass
            model.train()
            train_loss = 0
            train_preds, train_labels = [], []

            pbar = tqdm_progress(train_loader, desc=f"Epoch {epoch+1}/{epochs}")

            for images, labels in pbar:
                images, labels = images.to(device), labels.to(device)

                optimizer.zero_grad()
                outputs = model(images).logits
                loss = criterion(outputs, labels.float())

                loss.backward()
                optimizer.step()

                train_loss += loss.item()

                preds = torch.sigmoid(outputs).detach().cpu().numpy()
                train_preds.append(preds)
                train_labels.append(labels.cpu().numpy())

                pbar.set_postfix({'loss': f'{loss.item():.4f}'})

            train_preds = np.concatenate(train_preds)
            train_labels = np.concatenate(train_labels)
            train_pred_binary = (train_preds > 0.5).astype(int)

            train_precision, train_recall, train_f1, _ = precision_recall_fscore_support(
                train_labels, train_pred_binary, average='micro')

            # Validation pass
            model.eval()
            val_loss = 0
            val_preds, val_labels = [], []

            with torch.no_grad():
                val_pbar = tqdm_progress(val_loader, desc=f"Validation {epoch+1}/{epochs}")
                for images, labels in val_pbar:
                    images, labels = images.to(device), labels.to(device)
                    outputs = model(images).logits
                    val_loss += criterion(outputs, labels.float()).item()

                    preds = torch.sigmoid(outputs).detach().cpu().numpy()
                    val_preds.append(preds)
                    val_labels.append(labels.cpu().numpy())

                    # Running mean over the batches seen so far.
                    val_pbar.set_postfix({'loss': f'{val_loss/len(val_preds):.4f}'})

            val_preds = np.concatenate(val_preds)
            val_labels = np.concatenate(val_labels)
            val_pred_binary = (val_preds > 0.5).astype(int)

            val_precision, val_recall, val_f1, _ = precision_recall_fscore_support(
                val_labels, val_pred_binary, average='micro')

            if wandb_run:
                wandb.log({
                    "epoch": epoch,
                    "train_loss": train_loss/len(train_loader),
                    "val_loss": val_loss/len(val_loader),
                    "train_precision": train_precision,
                    "train_recall": train_recall,
                    "train_f1": train_f1,
                    "val_precision": val_precision,
                    "val_recall": val_recall,
                    "val_f1": val_f1
                })

            print(f"Epoch {epoch+1}:")
            print(f"Train Loss: {train_loss/len(train_loader):.4f}")
            print(f"Val Loss: {val_loss/len(val_loader):.4f}")
            print(f"Train F1: {train_f1:.4f}, Val F1: {val_f1:.4f}")

            # NOTE: early stopping on a mean train loss below 0.01 was
            # sketched here at some point but is intentionally disabled.
    finally:
        # Close the W&B run even when training fails or is interrupted, so
        # that no run is left open.
        if wandb_run:
            wandb.finish()

    return model

In [None]:
# Root of the Kaggle input dataset; every path below is relative to it.
BASE_PATH = "/kaggle/input/retinal-disease-classification"

# Training split: labels CSV and image directory.
TRAIN_CSV, TRAIN_IMG_DIR = (
    os.path.join(BASE_PATH, relative_path)
    for relative_path in (
        "Training_Set/Training_Set/RFMiD_Training_Labels.csv",
        "Training_Set/Training_Set/Training",
    )
)

# Validation split: labels CSV and image directory.
VAL_CSV, VAL_IMG_DIR = (
    os.path.join(BASE_PATH, relative_path)
    for relative_path in (
        "Evaluation_Set/Evaluation_Set/RFMiD_Validation_Labels.csv",
        "Evaluation_Set/Evaluation_Set/Validation",
    )
)

In [None]:
# Build the training loader (default settings: augmented and shuffled) and
# preview its first batch.
print("\nVisualizing Training Dataset Samples")
train_loader = create_dataloader(csv_path=TRAIN_CSV, img_dir=TRAIN_IMG_DIR)
visualize_dataset_samples(dataloader=train_loader, title="Training Dataset Samples")

In [None]:
# Build the validation loader and preview its first batch.
print("\nVisualizing Validation Dataset Samples")
# create_dataloader shuffles by default; validation data is kept in CSV order
# (shuffle=False) so that evaluation batches and previews are reproducible.
val_loader = create_dataloader(VAL_CSV, VAL_IMG_DIR, shuffle=False, train=False)
visualize_dataset_samples(val_loader, title="Validation Dataset Samples")

In [None]:
# Plot and summarise the label distribution of the training CSV.
print("Analyzing Training Dataset")
train_analysis = analyze_retinal_dataset(csv_path=TRAIN_CSV)

In [None]:
# Build the classifier with its default settings, then move it to the
# selected device.
model = create_model()
model = model.to(device)

In [None]:
# Run the full training loop with the default number of epochs.
trained_model = train_model(
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    device=device,
)