In [None]:
# Imports
from fastai.vision.all import *
from fastai.data.external import *
import pandas as pd
import os
import timm
from datetime import datetime
import json
import torch
import torch.nn as nn
from sklearn.metrics import average_precision_score, roc_auc_score
from functools import partial
import matplotlib.pyplot as plt
import gc
import numpy as np
import cv2
from torch.optim.lr_scheduler import OneCycleLR
from torch.utils.data import DataLoader
from queue import Queue, Empty
import threading
from torchvision import transforms
import time
#import keyboard
import ipywidgets as widgets
from IPython.display import display, clear_output
import glob
import traceback


# FastAI defaults for safe loading: request weights-only deserialization.
# NOTE(review): presumably intended to be forwarded by fastai to torch.load;
# confirm that the installed fastai version actually reads
# `defaults.torch_load_kwargs`.
# NOTE(review): `pickle` is not imported explicitly in this cell; it is
# presumably provided by one of the star imports above -- verify.
# NOTE(review): torch.load reportedly rejects an explicit `pickle_module`
# when `weights_only=True` -- confirm this combination against the torch docs.
import fastai.torch_core
fastai.torch_core.defaults.torch_load_kwargs = {'pickle_module': pickle, 'weights_only': True}

In [None]:
class CheckpointManager:
    """Interactive helper that optionally writes a training checkpoint.

    Attributes:
        output_manager: object exposing ``get_path(category, filename)``; it is
            used to resolve the location of checkpoint files.
    """

    def __init__(self, output_manager):
        self.output_manager = output_manager

    def wait_for_input(self, timeout=3):
        """Wait briefly for the user to request a checkpoint on stdin.

        Args:
            timeout: seconds to wait for a line on stdin (default 3).

        Returns:
            True if a line equal to 'j' (case-insensitive, stripped) was entered
            within the timeout, otherwise False.
        """
        print("\n" + "="*50)
        print("Epoche beendet. Drücken Sie schnell 'j' + Enter für einen Checkpoint")
        print("oder warten Sie kurz für die nächste Epoche.")
        print("="*50 + "\n")

        import sys
        import select

        try:
            ready, _, _ = select.select([sys.stdin], [], [], timeout)
        except (OSError, ValueError, TypeError):
            # stdin cannot be polled here (no usable file descriptor); treat it
            # as "no input" so the prompt never aborts the training run.
            ready = []

        if ready:
            return sys.stdin.readline().strip().lower() == 'j'

        print("Keine Eingabe - Training wird fortgesetzt...")
        return False

    def save_checkpoint(self, model, optimizer, epoch, metrics_dict):
        """Save the current training state to a timestamped ``.pth`` file.

        Args:
            model: object providing ``state_dict()``.
            optimizer: object providing ``state_dict()``.
            epoch: epoch value stored in the checkpoint.
            metrics_dict: metrics stored in the checkpoint under 'metrics'.

        Returns:
            False if saving failed or the file was not created.

        Raises:
            CancelTrainException: after a checkpoint was written successfully.
        """
        saved = False
        try:
            print("\nStarte Checkpoint-Speicherung...")

            timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
            checkpoint_path = self.output_manager.get_path(
                'models',
                f'checkpoint_{timestamp}.pth'
            )

            # Collect everything that goes into the checkpoint file.
            checkpoint = {
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'metrics': metrics_dict
            }

            os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)

            print(f"Speichere nach: {checkpoint_path}")
            torch.save(checkpoint, checkpoint_path)

            saved = os.path.exists(checkpoint_path)
            if saved:
                file_size = os.path.getsize(checkpoint_path)
                print(f"Checkpoint erfolgreich gespeichert! ({file_size/1024/1024:.1f} MB)")
        except Exception as e:
            print(f"Fehler beim Speichern des Checkpoints: {str(e)}")
            traceback.print_exc()
            return False

        if saved:
            # Raised outside the try block: the handler above therefore never
            # has to reference the exception class, which previously was an
            # unbound local whenever saving failed before the import ran.
            from fastai.learner import CancelTrainException
            raise CancelTrainException()

        print("Fehler: Checkpoint-Datei wurde nicht erstellt!")
        return False

In [None]:
def clear_gpu_memory():
    """Run garbage collection, empty the CUDA cache and reset peak-memory stats.

    The peak-memory counters are only reset when CUDA is available.
    """
    gc.collect()
    torch.cuda.empty_cache()
    if not torch.cuda.is_available():
        return
    torch.cuda.reset_peak_memory_stats()

In [None]:
class OutputManager:
    """Creates and manages the timestamped output directory tree of one run.

    Attributes:
        output_dir: name of the run directory, ``run_<YYYYmmdd_HHMMSS>``.
        dirs: mapping from category ('models', 'plots', 'metrics',
            'predictions') to the corresponding sub-directory path.
    """

    def __init__(self):
        # Main directory name carries the creation timestamp.
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        self.output_dir = f"run_{timestamp}"

        self.create_subdirs()

        print(f"Ausgabeverzeichnis erstellt: {self.output_dir}")

    def create_subdirs(self):
        """Create the category sub-directories and record them in ``self.dirs``."""
        self.dirs = {
            'models': os.path.join(self.output_dir, 'models'),
            'plots': os.path.join(self.output_dir, 'plots'),
            'metrics': os.path.join(self.output_dir, 'metrics'),
            'predictions': os.path.join(self.output_dir, 'predictions')
        }

        for dir_path in self.dirs.values():
            os.makedirs(dir_path, exist_ok=True)
            print(f"Verzeichnis erstellt: {dir_path}")

    def get_path(self, category, filename):
        """Return the full path of ``filename`` inside the given category.

        Args:
            category: key of ``self.dirs``.
            filename: file name to append to the category directory.

        Returns:
            The joined path; the category directory is (re)created if missing.

        Raises:
            ValueError: if ``category`` is not a known category.
        """
        if category not in self.dirs:
            raise ValueError(f"Unbekannte Kategorie: {category}")

        # Re-create the directory in case it was removed after start-up.
        os.makedirs(self.dirs[category], exist_ok=True)

        return os.path.join(self.dirs[category], filename)

    def ensure_dir_exists(self, path):
        """Make sure the parent directory of ``path`` exists and return ``path``.

        A path without a directory component (a bare file name) has no parent
        to create and is returned unchanged.
        """
        parent = os.path.dirname(path)
        if parent:
            # os.makedirs('') raises, so only create a non-empty parent.
            os.makedirs(parent, exist_ok=True)
        return path

In [None]:
# Modellklassen auf Modulebene definieren
class CustomHead(nn.Module):
    """Classification head: global average pooling, dropout and a linear layer.

    Args:
        n_features: input size of the linear layer (channel count of the
            incoming feature map).
        n_classes: number of output values per sample.
        dropout: dropout probability applied before the linear layer
            (default 0.3, the previously hard-coded value).
    """

    def __init__(self, n_features, n_classes, dropout=0.3):
        super().__init__()
        self.global_pool = nn.AdaptiveAvgPool2d((1, 1))
        self.flatten = nn.Flatten()
        self.drop = nn.Dropout(dropout)
        self.fc = nn.Linear(n_features, n_classes)

    def forward(self, x):
        """Map a feature map to per-class outputs.

        Args:
            x: feature map; presumably shaped (N, C, H, W) with C == n_features
                -- confirm against the backbone's output.

        Returns:
            Tensor produced by the final linear layer (no activation applied).
        """
        x = self.global_pool(x)  # spatial dimensions pooled down to 1x1
        x = self.flatten(x)      # flatten everything after the batch dimension
        x = self.drop(x)
        x = self.fc(x)
        return x

class MultiLabelModel(nn.Module):
    """Composes a backbone module with a head module.

    Args:
        backbone: module applied to the input first.
        head: module applied to the backbone's output.
    """

    def __init__(self, backbone, head):
        super().__init__()
        self.backbone, self.head = backbone, head

    def forward(self, x):
        """Return ``head(backbone(x))``."""
        return self.head(self.backbone(x))

In [None]:
class LRSchedulerCallback(Callback):
    """Callback that builds a torch LR scheduler and attaches it to the optimizer.

    Args:
        schedule_type: one of 'cosine', 'cosine_warmup', 'step' or
            'reduce_on_plateau'.
        **kwargs: optional overrides for the scheduler settings read in
            ``before_fit`` (eta_min, T_0, T_mult, step_size, gamma, factor,
            patience).
    """

    def __init__(self, schedule_type, **kwargs):
        self.schedule_type = schedule_type
        self.kwargs = kwargs

    def before_fit(self):
        """Create the scheduler for ``schedule_type`` and store it on ``learn.opt``.

        The scheduler factory (a ``partial``) is kept in ``self.scheduler``; the
        instantiated scheduler is assigned to ``self.learn.opt.lr_scheduler``.

        Raises:
            ValueError: if ``schedule_type`` is not one of the known names.
        """
        # NOTE(review): nothing in this class calls ``step()`` on the created
        # scheduler -- confirm it is stepped elsewhere.
        sched = torch.optim.lr_scheduler
        opts = self.kwargs
        kind = self.schedule_type

        if kind == 'cosine':
            factory = partial(
                sched.CosineAnnealingLR,
                T_max=self.learn.n_epoch,
                eta_min=opts.get('eta_min', 1e-7),
            )
        elif kind == 'cosine_warmup':
            factory = partial(
                sched.CosineAnnealingWarmRestarts,
                T_0=opts.get('T_0', 10),
                T_mult=opts.get('T_mult', 2),
            )
        elif kind == 'step':
            factory = partial(
                sched.StepLR,
                step_size=opts.get('step_size', 30),
                gamma=opts.get('gamma', 0.1),
            )
        elif kind == 'reduce_on_plateau':
            factory = partial(
                sched.ReduceLROnPlateau,
                mode='min',
                factor=opts.get('factor', 0.1),
                patience=opts.get('patience', 10),
                verbose=True,
            )
        else:
            raise ValueError(f"Unbekannter Schedule-Typ: {self.schedule_type}")

        self.scheduler = factory
        # assumes learn.opt wraps a torch optimizer exposed as `.opt` -- TODO confirm
        self.learn.opt.lr_scheduler = self.scheduler(self.learn.opt.opt)

In [None]:
class MultiLabelMetrics(Callback):
    def __init__(self, labels, bbox_df, image_dir, output_manager):
        """Store the configuration and create empty per-epoch metric histories.

        Args:
            labels: class labels; used as keys of the per-class histories.
            bbox_df: table of bounding boxes (queried in ``get_bboxes_for_image``).
            image_dir: image directory, stored as given.
            output_manager: object used to resolve output file paths.
        """
        self.labels = labels
        self.bbox_df = bbox_df
        self.image_dir = image_dir
        self.output_manager = output_manager

        # Each label gets its own list object so the histories stay independent.
        history = {key: [] for key in ('mean_ap', 'mean_roc_auc', 'overall_bias')}
        history['per_class_ap'] = {name: [] for name in labels}
        history['per_class_roc'] = {name: [] for name in labels}
        self.epoch_metrics = history
        
    def get_bboxes_for_image(self, image_name):
        """Return the bounding boxes stored for one image.

        Args:
            image_name: value compared with the 'Image Index' column of
                ``self.bbox_df``.

        Returns:
            A list with one dict per matching row (keys 'x', 'y', 'w', 'h',
            'label'); an empty list when nothing matches or the lookup fails.
        """
        try:
            matching_rows = self.bbox_df[self.bbox_df['Image Index'] == image_name]
            return [
                {
                    'x': record['Bbox [x'],
                    'y': record['y'],
                    'w': record['w'],
                    'h': record['h]'],
                    'label': record['Finding Label'],
                }
                for _, record in matching_rows.iterrows()
            ]
        except Exception as e:
            # Best effort: report the problem and behave as if there were no boxes.
            print(f"Fehler beim Laden der Bounding Boxes für {image_name}: {str(e)}")
            return []

    def plot_metrics(self):
        """Create and save line plots for loss, mean AP / ROC AUC and overall bias.

        Reads ``mean_ap_history``, ``mean_roc_history``, ``overall_bias_history``,
        ``train_losses`` and ``valid_losses`` when they exist on the instance and
        writes one PNG per plot into the 'plots' directory of
        ``self.output_manager``. File names carry the number of entries in
        ``mean_ap_history``.

        A history whose length differs from ``mean_ap_history`` is reported with
        a warning and plotted over the epochs both have in common; the stored
        histories are not modified. Errors are printed with a traceback and not
        re-raised.
        """
        try:
            if not hasattr(self, 'mean_ap_history') or len(self.mean_ap_history) == 0:
                print("Keine Metrik-Daten für Plots verfügbar")
                return

            num_epochs = len(self.mean_ap_history)
            epochs = list(range(1, num_epochs + 1))

            def aligned(values):
                """Pair the first min(len(values), num_epochs) points with their epochs."""
                count = min(len(values), num_epochs)
                return epochs[:count], list(values)[:count]

            # 1. Training and validation loss
            if hasattr(self, 'train_losses') and len(self.train_losses) > 0:
                if len(self.train_losses) != num_epochs:
                    print(f"Warnung: train_losses hat Länge {len(self.train_losses)}, mean_ap_history hat Länge {num_epochs}")
                train_x, train_y = aligned(self.train_losses)
                valid_x, valid_y = aligned(self.valid_losses)
                self._save_line_plot(
                    [
                        (train_x, train_y, 'b-', 'Training Loss'),
                        (valid_x, valid_y, 'r-', 'Validation Loss'),
                    ],
                    'Training und Validation Loss',
                    'Loss',
                    f'loss_epoch_{num_epochs}.png',
                )

            # 2. Mean AP and ROC AUC
            ap_x, ap_y = aligned(self.mean_ap_history)
            roc_x, roc_y = aligned(self.mean_roc_history)
            self._save_line_plot(
                [
                    (ap_x, ap_y, 'g-', 'Mean AP'),
                    (roc_x, roc_y, 'm-', 'Mean ROC AUC'),
                ],
                'Mean AP und ROC AUC',
                'Score',
                f'metrics_epoch_{num_epochs}.png',
            )

            # 3. Overall bias
            if hasattr(self, 'overall_bias_history') and len(self.overall_bias_history) > 0:
                if len(self.overall_bias_history) != num_epochs:
                    print(f"Warnung: overall_bias_history hat Länge {len(self.overall_bias_history)}, mean_ap_history hat Länge {num_epochs}")
                bias_x, bias_y = aligned(self.overall_bias_history)
                self._save_line_plot(
                    [(bias_x, bias_y, 'r-', None)],
                    'Overall Bias über Epochen',
                    'Bias (%)',
                    f'bias_epoch_{num_epochs}.png',
                    show_legend=False,
                )

            print(f"Plots gespeichert in {self.output_manager.dirs['plots']}")

        except Exception as e:
            print(f"Fehler beim Erstellen der Plots: {str(e)}")
            traceback.print_exc()

    def _save_line_plot(self, curves, title, ylabel, filename, show_legend=True):
        """Draw the given curves into one figure and save it under 'plots'.

        Args:
            curves: iterable of ``(x_values, y_values, fmt, label)`` tuples.
            title: figure title.
            ylabel: label of the y axis (the x axis is always 'Epoch').
            filename: file name inside the 'plots' directory.
            show_legend: whether to draw a legend.
        """
        fig, ax = plt.subplots(figsize=(12, 6))
        try:
            for x_values, y_values, fmt, label in curves:
                ax.plot(x_values, y_values, fmt, label=label)
            ax.set_title(title)
            ax.set_xlabel('Epoch')
            ax.set_ylabel(ylabel)
            if show_legend:
                ax.legend()
            ax.grid(True)
            fig.savefig(
                self.output_manager.get_path('plots', filename),
                dpi=300,
                bbox_inches='tight',
            )
        finally:
            # Always release the figure, even if drawing or saving failed.
            plt.close(fig)
    
    def plot_performance_curves(self):
        """Save a three-panel figure of the metric histories in ``epoch_metrics``.

        Panels: mean AP / mean ROC AUC, per-class AP, and overall bias, each
        plotted over the epoch number. The PNG is written to the 'plots'
        directory of ``self.output_manager``. Errors are printed with a
        traceback and not re-raised; the figure is closed in every case.
        """
        try:
            history = self.epoch_metrics
            if len(history['mean_ap']) == 0:
                print("Keine Metrikdaten für Performance-Plots verfügbar")
                return
            epochs = range(1, len(history['mean_ap']) + 1)

            fig = plt.figure(figsize=(15, 12))
            try:
                # Panel 1: averaged metrics
                ax_mean = fig.add_subplot(3, 1, 1)
                ax_mean.plot(epochs, history['mean_ap'], 'b-', label='Mean AP')
                ax_mean.plot(epochs, history['mean_roc_auc'], 'r-', label='Mean ROC AUC')
                ax_mean.set_title('Durchschnittliche Performance-Metriken über Epochen')
                ax_mean.set_xlabel('Epoche')
                ax_mean.set_ylabel('Score')
                ax_mean.legend()
                ax_mean.grid(True)

                # Panel 2: AP per class; long labels are shortened for the legend
                ax_class = fig.add_subplot(3, 1, 2)
                for label in self.labels:
                    legend_name = f'{label[:10]}...' if len(label) > 10 else label
                    ax_class.plot(epochs, history['per_class_ap'][label], label=legend_name)
                ax_class.set_title('Average Precision pro Klasse')
                ax_class.set_xlabel('Epoche')
                ax_class.set_ylabel('AP Score')
                ax_class.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
                ax_class.grid(True)

                # Panel 3: overall bias
                ax_bias = fig.add_subplot(3, 1, 3)
                ax_bias.plot(epochs, history['overall_bias'], 'g-')
                ax_bias.set_title('Gesamtabweichung (Bias) über Epochen')
                ax_bias.set_xlabel('Epoche')
                ax_bias.set_ylabel('Absolute Abweichung (%)')
                ax_bias.grid(True)

                fig.tight_layout()
                fig.savefig(
                    self.output_manager.get_path('plots', f'performance_epoch_{len(epochs)}.png'),
                    bbox_inches='tight',
                    dpi=300
                )
            finally:
                # Release the figure even when plotting or saving raised.
                plt.close(fig)
        except Exception as e:
            print(f"Fehler beim Erstellen der Performance-Plots: {str(e)}")
            traceback.print_exc()

    def visualize_predictions(self, n_samples=6):
        """Plot model predictions for random training and validation images.

        Draws one row of up to ``n_samples`` images from a training batch and
        one row from a validation batch, titles each image with its true and
        predicted labels and saves the figure into the 'predictions' directory
        of ``self.output_manager``. Errors are printed with a traceback and not
        re-raised; the figure is closed in every case.

        Args:
            n_samples: number of columns, i.e. images shown per row (default 6).
        """
        try:
            if not hasattr(self.learn.dls, 'train') or not hasattr(self.learn.dls, 'valid'):
                print("Keine Datenlader verfügbar")
                return

            # squeeze=False keeps `axes` two-dimensional even for n_samples == 1,
            # so the axes[row, col] lookups below stay valid.
            fig, axes = plt.subplots(2, n_samples, figsize=(20, 10), squeeze=False)
            try:
                fig.suptitle('Trainings- und Validierungsvorhersagen', fontsize=16)

                # Training images in the first row
                self._plot_prediction_batch(self.learn.dls.train, axes[0], n_samples)
                axes[0, 0].set_ylabel('Training', fontsize=12)

                # Validation images in the second row
                self._plot_prediction_batch(self.learn.dls.valid, axes[1], n_samples)
                axes[1, 0].set_ylabel('Validierung', fontsize=12)

                fig.tight_layout()
                fig.savefig(
                    self.output_manager.get_path('predictions', f'predictions_epoch_{len(self.epoch_metrics["mean_ap"])}.png'),
                    bbox_inches='tight',
                    dpi=300
                )
            finally:
                # Release the figure even when plotting or saving raised.
                plt.close(fig)
        except Exception as e:
            print(f"Fehler bei der Visualisierung der Vorhersagen: {str(e)}")
            traceback.print_exc()

    def _plot_prediction_batch(self, dl, axes_row, n_samples):
        """Fill one row of axes with images and labels from one batch of ``dl``.

        Args:
            dl: data loader providing ``one_batch()``.
            axes_row: sequence of matplotlib axes, one per column.
            n_samples: maximum number of images to draw.
        """
        try:
            # Fetch one batch and make sure it has the expected structure.
            batch = dl.one_batch()
            if not isinstance(batch, (tuple, list)) or len(batch) < 2:
                print(f"Ungültiger Batch-Format: {type(batch)}")
                return

            images, labels = batch
            if len(images) == 0:
                print("Keine Bilder im Batch")
                return

            # Predict in eval mode, then restore the previous mode so calling
            # this during training does not leave the model switched to eval.
            model = self.learn.model
            was_training = model.training
            model.eval()
            try:
                with torch.no_grad():
                    predictions = torch.sigmoid(model(images))
            finally:
                if was_training:
                    model.train()

            # Random subset of the batch, limited to the images available.
            actual_n_samples = min(n_samples, len(images))
            indices = torch.randperm(len(images))[:actual_n_samples]

            for col_idx, idx in enumerate(indices):
                try:
                    ax = axes_row[col_idx]

                    img = images[idx].cpu()

                    # assumes inputs were normalised with mean=std=0.5 per
                    # channel -- TODO confirm against the data pipeline
                    mean = torch.tensor([0.5, 0.5, 0.5]).view(-1, 1, 1)
                    std = torch.tensor([0.5, 0.5, 0.5]).view(-1, 1, 1)
                    img = img * std + mean

                    img = img.permute(1, 2, 0)    # CHW -> HWC
                    img = torch.clamp(img, 0, 1)  # keep values inside [0, 1]

                    ax.imshow(img)

                    # Try to recover the image name to look up bounding boxes.
                    try:
                        if hasattr(dl.dataset, 'items') and hasattr(dl, 'indices'):
                            data_idx = dl.indices[idx] if idx < len(dl.indices) else None
                            if data_idx is not None and data_idx < len(dl.dataset.items):
                                image_name = dl.dataset.items[data_idx].get('Image Index', 'unknown')
                            else:
                                image_name = 'unknown'
                        else:
                            image_name = 'unknown'

                        # NOTE(review): box coordinates are drawn exactly as
                        # stored in bbox_df; verify they match the displayed
                        # image size.
                        if image_name != 'unknown':
                            bboxes = self.get_bboxes_for_image(image_name)
                            for bbox in bboxes:
                                rect = plt.Rectangle(
                                    (bbox['x'], bbox['y']),
                                    bbox['w'],
                                    bbox['h'],
                                    fill=False,
                                    edgecolor='red',
                                    linewidth=2
                                )
                                ax.add_patch(rect)
                                ax.text(
                                    bbox['x'],
                                    bbox['y'] - 5,
                                    bbox['label'],
                                    color='red',
                                    fontsize=8,
                                    bbox=dict(facecolor='white', alpha=0.7)
                                )
                    except Exception as e:
                        print(f"Fehler beim Zeichnen der Bounding Boxes: {e}")

                    # Title: first three true labels and first three predicted
                    # labels (probability above 0.5).
                    try:
                        true_labels = [self.labels[j] for j in range(len(self.labels))
                                       if idx < len(labels) and j < len(labels[idx]) and labels[idx][j] == 1]
                        pred_labels = [self.labels[j] for j in range(len(self.labels))
                                       if j < len(predictions[idx]) and predictions[idx][j] > 0.5]

                        ax.set_title(
                            f'Wahr: {", ".join(true_labels[:3])}\n' +
                            f'Vorher.: {", ".join(pred_labels[:3])}',
                            fontsize=8
                        )
                    except Exception as e:
                        print(f"Fehler bei der Label-Verarbeitung: {e}")
                        ax.set_title("Fehler bei Labels", fontsize=8)

                    ax.axis('off')

                except Exception as e:
                    print(f"Fehler bei der Verarbeitung von Bild {idx}: {e}")

        except Exception as e:
            print(f"Fehler bei der Batch-Verarbeitung: {e}")
            traceback.print_exc()

    def update_metrics_csv(self):
        """Append the latest epoch's metrics as one row to 'training_metrics.csv'.

        The file lives in the 'metrics' directory of ``self.output_manager``; a
        header row is written first when the file does not exist yet. Each row
        holds the epoch count, mean AP, mean ROC AUC, overall bias, train and
        validation loss, followed by the actual and the predicted class
        distribution (in percent) for every label. Nothing is written when no
        metrics have been recorded.
        """
        history = self.epoch_metrics
        if len(history['mean_ap']) == 0:
            print("Keine Metrikdaten für CSV-Update verfügbar")
            return

        csv_path = self.output_manager.get_path('metrics', 'training_metrics.csv')

        # Write the header once, when the file is created.
        if not os.path.exists(csv_path):
            header = ['epoch', 'mean_ap', 'mean_roc_auc', 'overall_bias', 'train_loss', 'valid_loss']
            header.extend([f'{label}_actual' for label in self.labels])
            header.extend([f'{label}_predicted' for label in self.labels])

            with open(csv_path, 'w') as f:
                f.write(','.join(header) + '\n')

        epoch = len(history['mean_ap'])
        mean_ap = history['mean_ap'][-1]
        mean_roc = history['mean_roc_auc'][-1]
        overall_bias = history['overall_bias'][-1]

        # Latest recorder entry; a loss is only read when the entry is long
        # enough, otherwise it is recorded as None.
        recorded = self.learn.recorder.values
        latest = recorded[-1] if recorded else ()
        train_loss = float(latest[0]) if len(latest) > 0 else None
        valid_loss = float(latest[1]) if len(latest) > 1 else None

        # Class distributions in percent; zeros when none were stored yet.
        actual_values = []
        predicted_values = []
        if hasattr(self, 'last_target_distribution') and hasattr(self, 'last_pred_distribution'):
            for i, label in enumerate(self.labels):
                true_dist = self.last_target_distribution[i].item() * 100 if i < len(self.last_target_distribution) else 0
                pred_dist = self.last_pred_distribution[i].item() * 100 if i < len(self.last_pred_distribution) else 0
                actual_values.append(str(true_dist))
                predicted_values.append(str(pred_dist))
        else:
            for _ in self.labels:
                actual_values.append("0")
                predicted_values.append("0")

        row = [str(epoch), str(mean_ap), str(mean_roc), str(overall_bias), str(train_loss), str(valid_loss)]
        row.extend(actual_values)
        row.extend(predicted_values)

        with open(csv_path, 'a') as f:
            f.write(','.join(row) + '\n')

    def plot_and_save_all_metrics(self):
        """Save a 2x2 overview figure of all recorded metrics.

        Panels: mean AP / ROC AUC, train / validation loss taken from
        ``self.learn.recorder.values``, class distribution of the last epoch
        (the eight classes with the highest actual share) and overall bias.
        The PNG is written to the 'plots' directory of ``self.output_manager``.
        Exceptions propagate to the caller; the figure is closed in every case.
        """
        if len(self.epoch_metrics['mean_ap']) == 0:
            print("Keine Metrikdaten für umfassende Plots verfügbar")
            return

        fig = plt.figure(figsize=(15, 10))
        try:
            # Performance metrics
            plt.subplot(2, 2, 1)
            epochs = range(1, len(self.epoch_metrics['mean_ap'])+1)
            plt.plot(epochs, self.epoch_metrics['mean_ap'], 'b-', label='Mean AP')
            plt.plot(epochs, self.epoch_metrics['mean_roc_auc'], 'g-', label='Mean ROC AUC')
            plt.title('Performance-Metriken')
            plt.xlabel('Epoche')
            plt.ylabel('Score')
            plt.legend()
            plt.grid(True)

            # Loss curves
            plt.subplot(2, 2, 2)
            train_losses = [x[0] for x in self.learn.recorder.values]
            valid_losses = [x[1] for x in self.learn.recorder.values]
            plt.plot(range(1, len(train_losses)+1), train_losses, 'b-', label='Train Loss')
            plt.plot(range(1, len(valid_losses)+1), valid_losses, 'r-', label='Valid Loss')
            plt.title('Loss-Kurven')
            plt.xlabel('Epoche')
            plt.ylabel('Loss')
            plt.legend()
            plt.grid(True)

            # Class distribution of the last epoch
            plt.subplot(2, 2, 3)

            # The distributions only exist once they have been stored on the instance.
            if hasattr(self, 'last_target_distribution') and hasattr(self, 'last_pred_distribution'):
                top_classes = sorted(range(len(self.labels)),
                                 key=lambda i: self.last_target_distribution[i].item() if i < len(self.last_target_distribution) else 0,
                                 reverse=True)[:8]  # top 8 classes

                x = range(len(top_classes))
                width = 0.35
                true_vals = [self.last_target_distribution[i].item() * 100 if i < len(self.last_target_distribution) else 0 for i in top_classes]
                pred_vals = [self.last_pred_distribution[i].item() * 100 if i < len(self.last_pred_distribution) else 0 for i in top_classes]

                plt.bar([p - width/2 for p in x], true_vals, width, label='Tatsächlich')
                plt.bar([p + width/2 for p in x], pred_vals, width, label='Vorhergesagt')
                plt.xticks(x, [self.labels[i] for i in top_classes], rotation=45, ha='right')
                plt.title('Klassenverteilung (Epoche {})'.format(len(self.epoch_metrics['mean_ap'])))
                plt.ylabel('Prozent')
                plt.legend()
            else:
                # No data available: show a placeholder text instead.
                plt.text(0.5, 0.5, "Keine Verteilungsdaten verfügbar",
                        horizontalalignment='center', verticalalignment='center')
                plt.title('Klassenverteilung (keine Daten)')

            # Overall bias over the epochs
            plt.subplot(2, 2, 4)
            plt.plot(epochs, self.epoch_metrics['overall_bias'], 'r-')
            plt.title('Overall Bias')
            plt.xlabel('Epoche')
            plt.ylabel('Prozent')
            plt.grid(True)

            plt.tight_layout()
            fig.savefig(self.output_manager.get_path('plots', f'all_metrics_epoch_{len(self.epoch_metrics["mean_ap"])}.png'),
                       bbox_inches='tight', dpi=300)
        finally:
            # Release the figure even when a plotting step raised.
            plt.close(fig)

    def after_batch(self):
        """Collect predictions and targets of every non-training batch.

        The buffers ``self.predictions`` and ``self.targets`` are created on the
        first call. During training batches nothing is recorded; otherwise the
        detached CPU copies of ``learn.pred`` and ``learn.y`` are appended.
        """
        if not hasattr(self, 'predictions'):
            self.predictions, self.targets = [], []

        if not self.learn.training:
            batch_preds, batch_targets = self.learn.pred, self.learn.y
            self.predictions.append(batch_preds.detach().cpu())
            self.targets.append(batch_targets.detach().cpu())
        
    def after_epoch(self):
        if len(self.predictions) == 0:
            return
                
        try:
            preds = torch.sigmoid(torch.cat(self.predictions))
            targets = torch.cat(self.targets)
    
            # Berechne Metriken...
            metrics_dict = {
                'average_precision': {},
                'roc_auc': {},
                'class_distribution': {},
                'prediction_distribution': {}
            }
            
            # Berechne Vorhersageverteilung
            pred_binary = (preds > 0.5).float()
            pred_distribution = pred_binary.mean(dim=0)
            target_distribution = targets.float().mean(dim=0)

            # Speichere die Verteilungen für andere Methoden
            self.last_target_distribution = target_distribution
            self.last_pred_distribution = pred_distribution
            
            print("\nVerteilungsanalyse der Krankheiten:")
            print("--------------------------------")
            
            # Sicherstellen, dass die Indizes korrekt sind
            for i, label in enumerate(self.labels):
                if i >= len(pred_distribution) or i >= len(target_distribution):
                    print(f"Warnung: Index {i} außerhalb des gültigen Bereichs")
                    continue
                    
                true_dist = target_distribution[i].item() * 100
                pred_dist = pred_distribution[i].item() * 100
                bias = pred_dist - true_dist
                
                print(f"\n{label}:")
                print(f"  Anteil in Daten: {true_dist:.1f}%")
                print(f"  Anteil in Vorhersagen: {pred_dist:.1f}%")
                print(f"  Abweichung: {bias:+.1f}%")
                
                metrics_dict['prediction_distribution'][label] = {
                    'true_distribution': true_dist,
                    'predicted_distribution': pred_dist,
                    'bias': bias
                }
                
                # Performance Metriken
                try:
                    if i < targets.shape[1] and i < preds.shape[1]:
                        # Prüfe, ob positive Beispiele vorhanden sind
                        if torch.sum(targets[:, i]) > 0:
                            ap = average_precision_score(targets[:, i], preds[:, i])
                            metrics_dict['average_precision'][label] = float(ap)
                            self.epoch_metrics['per_class_ap'][label].append(float(ap))
                        else:
                            print(f"Hinweis: Keine positiven Beispiele für {label} in diesem Batch")
                            metrics_dict['average_precision'][label] = None
                            self.epoch_metrics['per_class_ap'][label].append(0.0)
                except Exception as e:
                    print(f"Fehler bei AP Score für {label}: {str(e)}")
                    metrics_dict['average_precision'][label] = None
                    self.epoch_metrics['per_class_ap'][label].append(0.0)
                
                try:
                    if i < targets.shape[1] and i < preds.shape[1]:
                        # Prüfe, ob positive und negative Beispiele vorhanden sind
                        unique_values = torch.unique(targets[:, i])
                        if len(unique_values) > 1:
                            roc = roc_auc_score(targets[:, i], preds[:, i])
                            metrics_dict['roc_auc'][label] = float(roc)
                            self.epoch_metrics['per_class_roc'][label].append(float(roc))
                        else:
                            value_type = "keine" if len(unique_values) == 0 else ("nur positive" if unique_values[0] == 1 else "nur negative")
                            print(f"Hinweis: {value_type} Beispiele für {label} in diesem Batch")
                            metrics_dict['roc_auc'][label] = None
                            self.epoch_metrics['per_class_roc'][label].append(0.0)
                except Exception as e:
                    print(f"Fehler bei ROC AUC für {label}: {str(e)}")
                    metrics_dict['roc_auc'][label] = None
                    self.epoch_metrics['per_class_roc'][label].append(0.0)
            
            # Durchschnittsmetriken
            valid_ap_scores = [score for score in metrics_dict['average_precision'].values() if score is not None]
            valid_roc_scores = [score for score in metrics_dict['roc_auc'].values() if score is not None]
            
            mean_ap = np.mean(valid_ap_scores) if valid_ap_scores else 0.0
            mean_roc = np.mean(valid_roc_scores) if valid_roc_scores else 0.0
            overall_bias = float(torch.abs(pred_distribution - target_distribution).mean() * 100)

            # Loss-Werte für Plots speichern
            if hasattr(self.learn.recorder, 'values') and len(self.learn.recorder.values) > 0:
                if not hasattr(self, 'train_losses'):
                    self.train_losses = []
                    self.valid_losses = []
                    
                # Extrahiere die neuesten Loss-Werte
                latest_record = self.learn.recorder.values[-1]
                if len(latest_record) >= 2:
                    self.train_losses.append(float(latest_record[0]))
                    self.valid_losses.append(float(latest_record[1]))
            
            # Metriken für Plots speichern
            if not hasattr(self, 'mean_ap_history'):
                self.mean_ap_history = []
                self.mean_roc_history = []
                self.overall_bias_history = []
                
            self.mean_ap_history.append(mean_ap)
            self.mean_roc_history.append(mean_roc)
            self.overall_bias_history.append(overall_bias)

            # Nach Berechnung der Metriken
            epoch_metrics = {
                'epoch': len(self.epoch_metrics['mean_ap']),
                'mean_ap': float(mean_ap),
                'mean_roc_auc': float(mean_roc),
                'overall_bias': float(overall_bias),
                'class_distribution': {
                    label: {
                        'actual': float(true_dist),
                        'predicted': float(pred_dist),
                        'bias': float(bias)
                    } for i, label in enumerate(self.labels)
                    if i < len(pred_distribution) and i < len(target_distribution)
                },
                'loss': {
                    'train': float(self.learn.recorder.values[-1][0]) if self.learn.recorder.values else None,
                    'valid': float(self.learn.recorder.values[-1][1]) if self.learn.recorder.values else None
                }
            }

            # Metriken in csv speichern
            self.update_metrics_csv()
            
            # Speichere Metriken als JSON
            metrics_path = self.output_manager.get_path('metrics', f'epoch_{len(self.epoch_metrics["mean_ap"])}_metrics.json')
            with open(metrics_path, 'w') as f:
                json.dump(epoch_metrics, f, indent=2)
            
            # Recorder-Werte verarbeiten
            values = []
            for x in self.learn.recorder.values:
                if x is not None:
                    try:
                        # X ist eine Liste von [train_loss, valid_loss, accuracy]
                        values.append({
                            'train_loss': float(x[0]),
                            'valid_loss': float(x[1]),
                            'accuracy': float(x[2])
                        })
                    except Exception as e:
                        print(f"Fehler beim Konvertieren der Values: {str(e)}")
            
            
            metrics_dict['mean_ap'] = mean_ap
            metrics_dict['mean_roc_auc'] = mean_roc
            metrics_dict['overall_bias'] = overall_bias
            metrics_dict['validation_values'] = values
            
            self.epoch_metrics['mean_ap'].append(mean_ap)
            self.epoch_metrics['mean_roc_auc'].append(mean_roc)
            self.epoch_metrics['overall_bias'].append(overall_bias)
            
            print(f"\nEpoche {len(self.epoch_metrics['mean_ap'])}:")
            print(f"Mean AP: {mean_ap:.3f}")
            print(f"Mean ROC AUC: {mean_roc:.3f}")
            print(f"Overall Bias: {overall_bias:.1f}%")
            
            # Metrics im Learner speichern
            self.learn.metrics_dict = metrics_dict
            
            # Plots aktualisieren
            #self.plot_metrics()
            self.plot_performance_curves()
            self.visualize_predictions()
            self.plot_and_save_all_metrics()
            
            # Checkpoint-Info ausgeben
            print("\n" + "="*50)
            print(f"Epoche {len(self.epoch_metrics['mean_ap'])} abgeschlossen")
            print(f"Mean AP: {mean_ap:.3f}")
            print(f"Mean ROC AUC: {mean_roc:.3f}")
            print(f"Overall Bias: {overall_bias:.1f}%")
            print("="*50)
            
            # Automatisches Speichern des Checkpoints
            checkpoint_mgr = CheckpointManager(self.output_manager)
            timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
            checkpoint_path = self.output_manager.get_path(
                'models', 
                f'checkpoint_epoch_{len(self.epoch_metrics["mean_ap"])}_{timestamp}.pth'
            )
            
            # Checkpoint-Daten vorbereiten
            checkpoint = {
                'epoch': len(self.epoch_metrics['mean_ap']),
                'model_state_dict': self.learn.model.state_dict(),
                'optimizer_state_dict': self.learn.opt.state_dict(),
                'metrics': self.learn.metrics_dict
            }
            
            # Verzeichnis erstellen
            os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)
            
            # Checkpoint speichern
            print(f"Speichere Modell nach: {checkpoint_path}")
            torch.save(checkpoint, checkpoint_path)
            
            if os.path.exists(checkpoint_path):
                file_size = os.path.getsize(checkpoint_path)
                print(f"Checkpoint erfolgreich gespeichert! ({file_size/1024/1024:.1f} MB)")
                print("Training wird fortgesetzt...")
            
        except Exception as e:
            print(f"Fehler in after_epoch: {str(e)}")
            traceback.print_exc()
        finally:
            # Listen zurücksetzen
            self.predictions = []
            self.targets = []

In [None]:
class LungXRayDataset:
    """Multi-label chest X-ray dataset described by CSV files, exposed as fastai DataLoaders.

    The label table must provide an 'Image Index' column (image file name) and a
    'Finding Labels' column ('|'-separated label names). A second label table
    describing pre-computed augmented images can be appended.

    Attributes:
        labels_df: combined (and optionally sub-sampled) label table.
        bbox_df: table read from `bbox_csv` (only loaded here, not used by this class).
        image_path_cache: image file name -> resolved path on disk.
        disease_labels: sorted list of all distinct label names.
        class_counts: label name -> number of occurrences in `labels_df`.
    """

    def __init__(self, image_dir, labels_csv, bbox_csv, aug_dir=None, aug_csv=None, max_samples=None):
        """Load the label tables, index the image files and collect label statistics.

        Args:
            image_dir: directory containing the original images.
            labels_csv: CSV with 'Image Index' and 'Finding Labels' columns.
            bbox_csv: CSV with bounding boxes.
            aug_dir: optional directory containing augmented images.
            aug_csv: optional CSV with labels of the augmented images; it is only
                used when the file exists.
            max_samples: optional upper bound on the number of rows; rows are drawn
                with a fixed random state (42).
        """
        self.image_dir = image_dir
        self.aug_dir = aug_dir  # directory with pre-computed augmented images (may be None)

        # Load the original labels
        self.labels_df = pd.read_csv(labels_csv)

        # Append the labels of the augmented images, if available
        if aug_csv is not None and os.path.exists(aug_csv):
            aug_df = pd.read_csv(aug_csv)
            self.labels_df = pd.concat([self.labels_df, aug_df], ignore_index=True)
            print(f"Kombinierten Datensatz erstellt mit {len(self.labels_df)} Bildern")

        if max_samples is not None:
            self.labels_df = self.labels_df.sample(
                n=min(max_samples, len(self.labels_df)),
                random_state=42
            )

        # Resolve every image name once so that lookups during training avoid disk checks
        print("Erstelle Cache für Bildpfade...")
        self.image_path_cache = {}
        for img_name in self.labels_df['Image Index'].unique():
            found = self._locate_on_disk(img_name)
            if found is not None:
                self.image_path_cache[img_name] = found
        print(f"Pfad-Cache erstellt für {len(self.image_path_cache)} Bilder")

        self.bbox_df = pd.read_csv(bbox_csv)

        # Count label occurrences in a single pass; the vocabulary is derived from the counts
        self.class_counts = {}
        for labels in self.labels_df['Finding Labels'].str.split('|'):
            for label in labels:
                label = label.strip()
                self.class_counts[label] = self.class_counts.get(label, 0) + 1
        self.disease_labels = sorted(self.class_counts)

        # Report the statistics
        print("Gefundene Krankheiten:")
        for label, count in self.class_counts.items():
            print(f"{label}: {count} Bilder")

    def _locate_on_disk(self, image_name):
        """Return the first existing path for `image_name`, or None.

        The original image directory is searched first, then the augmented
        directory (if one was given).
        """
        search_dirs = [self.image_dir]
        if self.aug_dir:
            search_dirs.append(self.aug_dir)
        for directory in search_dirs:
            candidate = os.path.join(directory, image_name)
            if os.path.exists(candidate):
                return candidate
        return None

    def _find_image_path(self, image_name):
        """Resolve `image_name` to a file path, consulting the path cache first.

        When the image exists in neither directory, the path inside `image_dir`
        is returned anyway so that the failure surfaces when the file is opened.
        """
        cached = self.image_path_cache.get(image_name)
        if cached is not None:
            return cached

        # Cache miss (should be rare): look on disk and remember the result
        found = self._locate_on_disk(image_name)
        if found is not None:
            self.image_path_cache[image_name] = found
            return found

        return os.path.join(self.image_dir, image_name)

    def _calculate_sample_weights(self):
        """Compute one sampling weight per row of `labels_df`.

        Each label contributes `total_rows / class_count`; a row's weight is the
        mean over its labels, so rows carrying rare labels get larger weights.

        Returns:
            torch.FloatTensor with one entry per row, in row order.
        """
        total_samples = len(self.labels_df)
        weights = []

        for labels in self.labels_df['Finding Labels'].str.split('|'):
            # 1e-5 guards against a division by zero
            label_weights = [total_samples / (self.class_counts[label.strip()] + 1e-5)
                             for label in labels]
            weights.append(np.mean(label_weights))

        return torch.FloatTensor(weights)

    def get_data(self, image_size=224):
        """Build the fastai DataBlock for this dataset.

        Args:
            image_size: side length every image is resized to. The value used to be
                ignored in favour of a fixed 460 px resize, so the configured image
                size never took effect.

        Returns:
            DataBlock mapping label-table rows to (image, multi-hot labels).
        """
        return DataBlock(
            blocks=(ImageBlock, MultiCategoryBlock(vocab=self.disease_labels)),
            get_x=lambda x: self._find_image_path(x['Image Index']),
            get_y=lambda x: [label.strip() for label in x['Finding Labels'].split('|')],
            # Fixed seed: a resumed run must validate on the same rows it held out
            # before, otherwise validation images leak into training.
            splitter=RandomSplitter(seed=42),
            item_tfms=Resize(image_size),
            batch_tfms=[
                # No augmentation here, only normalisation
                Normalize.from_stats([0.5]*3, [0.5]*3)
            ]
        )

    def create_dataloader(self, batch_size=32, image_size=224, num_workers=6):
        """Create fastai DataLoaders whose training loader samples rows by weight.

        The previous implementation assigned a plain torch DataLoader (with a
        WeightedRandomSampler built from weights of the *whole* table) to
        `dls.train.dl`. fastai never iterates that attribute, so the weighting was
        silently inactive, and the weights did not line up with the training
        split. `num_workers` was likewise set on the loaders after construction,
        where it has no effect. Both are now handed to fastai directly.

        Args:
            batch_size: number of images per batch.
            image_size: side length the images are resized to.
            num_workers: number of worker processes for loading.

        Returns:
            fastai DataLoaders (weighted training loader, regular validation loader).
        """
        dblock = self.get_data(image_size)

        # One weight per row of labels_df; fastai selects the training-split entries
        weights = self._calculate_sample_weights().numpy()

        dls = dblock.weighted_dataloaders(
            self.labels_df,
            wgts=weights,
            bs=batch_size,
            num_workers=num_workers
        )

        # Debug output: report the worker count the loader actually uses
        effective_workers = getattr(getattr(dls.train, 'fake_l', None), 'num_workers', num_workers)
        print(f"DataLoader settings:")
        print(f"Training DataLoader workers: {effective_workers}")
        print(f"Training DataLoader batch size: {dls.train.bs}")

        return dls

In [None]:
class GradientAccumulation(Callback):
    """Accumulate gradients over `n_acc` training batches before each optimizer step.

    Batches whose (1-based) index within the epoch is not a multiple of `n_acc`
    are cancelled right after the backward pass, so neither the optimizer step
    nor the following gradient reset runs and the gradients keep adding up.

    The loss used for the backward pass is divided by `n_acc`, so the accumulated
    gradient is the mean over the accumulated batches. Previously the factor was
    only stored in `learn.loss_scale` and never applied, which made the effective
    step size grow with `n_acc`.

    Note: this class has the same name as fastai's built-in gradient accumulation
    callback and therefore shadows it in this notebook.

    Args:
        n_acc: number of batches to accumulate; must be at least 1.

    Raises:
        ValueError: if `n_acc` is smaller than 1.
    """

    def __init__(self, n_acc=2):
        if n_acc < 1:
            raise ValueError(f"n_acc must be >= 1, got {n_acc}")
        self.n_acc = n_acc

    def before_batch(self):
        """Publish the loss scaling factor on the learner (training batches only)."""
        if not self.training: return
        self.learn.loss_scale = 1./self.n_acc

    def after_loss(self):
        """Scale the loss that is back-propagated; the reported loss stays unscaled."""
        if not self.training: return
        self.learn.loss_grad = self.learn.loss_grad / self.n_acc

    def after_backward(self):
        """Skip the optimizer step unless `n_acc` batches have been accumulated."""
        if not self.training: return
        # Only optimise after every n_acc-th batch
        if (self.iter+1) % self.n_acc != 0:
            # Release cached GPU blocks between accumulation steps
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            raise CancelBatchException()

In [None]:
class ModelTrainer:
    """Builds a fastai Learner for the multi-label task and runs (or resumes) training.

    Attributes:
        history: dict with the keys 'train_loss', 'valid_loss', 'learning_rates'
            and 'scheduler_info'; `train` fills the first three from the
            learner's recorder and adds 'training_time' and 'final_metrics'.
        learn: the fastai Learner, available after `get_model()` or `train()`.
    """

    def __init__(self, model_name, dataloader, disease_labels, bbox_csv, image_dir, output_manager):
        """Store the training inputs; no model is created yet.

        Args:
            model_name: name used in log messages and output file names.
            dataloader: fastai DataLoaders passed to the Learner.
            disease_labels: list of label names (defines the output size).
            bbox_csv: path of the bounding-box CSV handed to the metrics callback.
            image_dir: image directory handed to the metrics callback.
            output_manager: object providing `get_path` / `ensure_dir_exists`.
        """
        self.model_name = model_name
        self.dataloader = dataloader
        self.disease_labels = disease_labels
        self.bbox_csv = bbox_csv
        self.image_dir = image_dir
        self.output_manager = output_manager
        self.history = {
            'train_loss': [],
            'valid_loss': [],
            'learning_rates': [],
            'scheduler_info': {}
        }

    def _clear_gpu_memory(self):
        """Run the garbage collector and release cached GPU memory."""
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            torch.cuda.reset_peak_memory_stats()

    def get_model(self):
        """Create the Learner (timm backbone + custom head) and store it in `self.learn`.

        NOTE(review): the backbone is always 'convnext_tiny.fb_in22k', independent
        of `self.model_name` - confirm that this is intended.

        Returns:
            The mixed-precision fastai Learner.
        """
        base_arch = timm.create_model(
            'convnext_tiny.fb_in22k',
            pretrained=True,
            num_classes=0,
            global_pool=''
        )

        n_features = base_arch.num_features
        head = CustomHead(n_features, len(self.disease_labels))
        model = MultiLabelModel(base_arch, head)

        self.learn = Learner(
            self.dataloader,
            model,
            loss_func=nn.BCEWithLogitsLoss(),
            metrics=[accuracy_multi]
        )

        # Mixed precision (AMP) to reduce GPU memory usage
        self.learn = self.learn.to_fp16()

        return self.learn

    def find_learning_rate(self, model, suggestion_mode='steep'):
        """Run the learning-rate finder and return its 'valley' suggestion.

        Args:
            model: the Learner to run `lr_find` on.
            suggestion_mode: currently unused; kept for interface compatibility.

        Returns:
            The suggested learning rate, or 1e-4 if the finder fails.
        """
        print(f"Suche optimale Learning Rate für {self.model_name}...")
        try:
            suggestions = model.lr_find(
                start_lr=1e-7,
                end_lr=1e-1,
                num_it=50,
                show_plot=True
            )
            return suggestions.valley
        except Exception as e:
            print(f"Fehler beim LR Finding: {str(e)}")
            return 1e-4
        finally:
            # Free GPU memory after the finder, on success and on failure alike
            self._clear_gpu_memory()

    def _resolve_total_epochs(self, scheduler_config, epochs, additional_epochs):
        """Return the total number of epochs to train.

        `epochs` falls back to `scheduler_config['epochs']` (default 10). A positive
        `additional_epochs` is always added; it used to be applied only when
        `epochs` was None because the statement was nested one level too deep.
        """
        if epochs is None:
            epochs = scheduler_config.get('epochs', 10)
        if additional_epochs > 0:
            print(f"Erhöhe Anzahl der Epochen um {additional_epochs}")
            epochs += additional_epochs
        return epochs

    def _configured_lr(self, scheduler_config):
        """Return the learning rate from the config: 'lr', else 'base_lr', else 1e-4."""
        return scheduler_config.get('lr', scheduler_config.get('base_lr', 1e-4))

    def _load_checkpoint(self, checkpoint_path):
        """Restore training state from `checkpoint_path`.

        Accepts either a full checkpoint dict (with 'model_state_dict' and
        optionally 'optimizer_state_dict', 'epoch', 'metrics') or a bare model
        state dict. Errors are reported and swallowed so that training can still
        start from the current weights.

        Returns:
            Number of epochs already completed according to the checkpoint (0 if
            unknown or if loading failed before the epoch was read).
        """
        start_epoch = 0
        try:
            print(f"Versuche Checkpoint zu laden: {checkpoint_path}")

            # NOTE(review): the checkpoint stores a metrics dict next to the tensors;
            # with a torch default of weights_only=True this may be rejected - confirm.
            checkpoint = torch.load(checkpoint_path)

            if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
                print("Vollständiger Checkpoint erkannt")
                self.learn.model.load_state_dict(checkpoint['model_state_dict'])

                if 'optimizer_state_dict' in checkpoint:
                    self.learn.opt.load_state_dict(checkpoint['optimizer_state_dict'])

                if 'epoch' in checkpoint:
                    start_epoch = checkpoint['epoch']
                    print(f"Training wird ab Epoche {start_epoch + 1} fortgesetzt")

                if 'metrics' in checkpoint:
                    self.learn.metrics_dict = checkpoint['metrics']
                    print("Metriken wurden geladen")
            else:
                print("Nur Modell-Gewichte (state_dict) erkannt")
                self.learn.model.load_state_dict(checkpoint)

            print("Checkpoint erfolgreich geladen!")

        except Exception as e:
            print(f"Fehler beim Laden des Checkpoints: {str(e)}")
            traceback.print_exc()
        return start_epoch

    def _record_history(self):
        """Copy losses and learning rates from the learner's recorder into `history`.

        Recorder rows are laid out as [train_loss, valid_loss, metric...]. Values
        are converted to plain floats so the history can be written as JSON.
        """
        recorder = getattr(self.learn, 'recorder', None)
        if recorder is None:
            return
        self.history['train_loss'] = [float(loss) for loss in getattr(recorder, 'losses', [])]
        self.history['learning_rates'] = [float(lr) for lr in getattr(recorder, 'lrs', [])]
        self.history['valid_loss'] = [
            float(row[1]) for row in getattr(recorder, 'values', [])
            if row is not None and len(row) > 1 and row[1] is not None
        ]

    def train(self, scheduler_config, epochs=None, freeze_epochs=2, resume_from_checkpoint=None, additional_epochs=0):
        """Train the model with the one-cycle policy, optionally resuming from a checkpoint.

        Args:
            scheduler_config: dict; the keys read here are 'epochs', 'use_lr_finder',
                'lr' / 'base_lr', 'gradient_accumulation' and 'weight_decay'.
            epochs: total number of epochs; defaults to `scheduler_config['epochs']`.
            freeze_epochs: currently unused; kept for interface compatibility.
            resume_from_checkpoint: optional path of a checkpoint to load first.
            additional_epochs: epochs added on top of `epochs`.

        Returns:
            The trained Learner, or None if training raised an exception (the
            error is printed, not re-raised).
        """
        print(f"Training {self.model_name}...")

        # Ask the CUDA caching allocator for expandable segments
        os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'

        self._clear_gpu_memory()

        epochs = self._resolve_total_epochs(scheduler_config, epochs, additional_epochs)
        print(f"Geplante Gesamtanzahl der Epochen: {epochs}")

        if not hasattr(self, 'learn'):
            self.learn = self.get_model()

        # Attach the metrics callback once; repeated train() calls must not stack copies
        if not any(isinstance(cb, MultiLabelMetrics) for cb in self.learn.cbs):
            self.learn.add_cb(MultiLabelMetrics(
                self.disease_labels,
                pd.read_csv(self.bbox_csv),
                self.image_dir,
                self.output_manager
            ))

        # The optimizer must exist BEFORE a checkpoint is loaded into it
        if not hasattr(self.learn, 'opt') or self.learn.opt is None:
            self.learn.create_opt()
            print("Optimizer initialisiert")

        start_epoch = 0
        if resume_from_checkpoint:
            start_epoch = self._load_checkpoint(resume_from_checkpoint)

        try:
            start_time = datetime.now()

            # Run the LR finder only for a fresh start; otherwise use the configured rate
            if start_epoch == 0 and scheduler_config.get('use_lr_finder', True):
                suggested_lr = self.find_learning_rate(self.learn)
                print(f"Empfohlene Learning Rate: {suggested_lr:.2e}")
                lr = suggested_lr
            else:
                lr = self._configured_lr(scheduler_config)

            print("\nStarting training...")

            n_acc = scheduler_config.get('gradient_accumulation', 2)
            print(f"Training mit Gradient Accumulation: {n_acc}")
            remaining_epochs = epochs - start_epoch

            print(f"Start Epoche: {start_epoch}")
            print(f"Verbleibende Epochen: {remaining_epochs}")

            if remaining_epochs > 0:
                self.learn.fit_one_cycle(
                    remaining_epochs,
                    slice(lr/25, lr),
                    wd=scheduler_config.get('weight_decay', 0.01),
                    cbs=[GradientAccumulation(n_acc=n_acc)]
                )

            # Save the final weights
            model_filename = f"{self.model_name}_final"

            state_dict_path = self.output_manager.get_path('models', f"{model_filename}_state_dict.pth")
            self.output_manager.ensure_dir_exists(state_dict_path)
            torch.save(self.learn.model.state_dict(), state_dict_path)

            training_time = (datetime.now() - start_time).total_seconds()

            # Fill the per-batch / per-epoch curves and the summary entries
            self._record_history()
            self.history.update({
                'training_time': training_time,
                'final_metrics': getattr(self.learn, 'metrics_dict', {})
            })

            return self.learn

        except Exception as e:
            print(f"Fehler während des Trainings: {str(e)}")
            traceback.print_exc()

        finally:
            self._clear_gpu_memory()

In [None]:
def setup_training_config():
    """Assemble and return the basic training configuration.

    Returns:
        dict with three entries: 'paths' (dataset locations), 'model_config'
        (per-model loader settings) and 'scheduler_config' (list of training
        schedules). A fresh dict is built on every call.
    """
    paths = {
        'IMAGE_DIR': "../Dataset/images",
        'LABELS_CSV': "../Dataset/Data_Entry_2017_v2020.csv",
        'BBOX_CSV': "../Dataset/BBox_List_2017.csv",
        # Directory holding the pre-computed augmented images
        'AUGMENTED_DIR': "../Dataset/augmented_images",
    }

    convnext_settings = {
        'batch_size': 12,
        'image_size': 192,   # instead of 224
        'base_lr': 7.59e-5,  # previously 1e-4
        'num_workers': 6,
        'prefetch_factor': 4,
    }

    one_cycle_schedule = {
        'type': 'one_cycle',
        'use_lr_finder': False,
        'epochs': 10,
        'pct_start': 0.3,
        'div': 25,
        'num_it': 50,                # reduced number of iterations for the LR finder
        'gradient_accumulation': 2,  # new option; 4 is worth trying as well
    }

    return {
        'paths': paths,
        'model_config': {'convnext_base': convnext_settings},
        'scheduler_config': [one_cycle_schedule],
    }

In [None]:
def setup_training_environment(config, max_samples=None):
    """Validate the configured paths, create the output manager and load the dataset.

    All entries of `config['paths']` must exist, except 'AUGMENTED_DIR': the
    augmented data is optional (the dataset is built without it when the
    directory or its label file is missing), so a missing directory no longer
    aborts the setup.

    Args:
        config: configuration dict as returned by `setup_training_config`.
        max_samples: optional cap on the number of dataset rows.

    Returns:
        Tuple `(output_manager, dataset)`.

    Raises:
        FileNotFoundError: if a required path does not exist.
    """
    optional_path_keys = {'AUGMENTED_DIR'}
    for key, path in config['paths'].items():
        if key in optional_path_keys:
            continue
        if not os.path.exists(path):
            raise FileNotFoundError(f"Pfad nicht gefunden: {path}")
    print("Alle Pfade sind verfügbar.")

    # Create the output manager
    output_manager = OutputManager()

    # Persist the configuration next to the metrics
    with open(output_manager.get_path('metrics', 'configuration.json'), 'w') as f:
        json.dump(config, f, indent=4)

    # Use the augmented data only if it is actually present
    aug_dir = config['paths'].get('AUGMENTED_DIR', '../Dataset/augmented_images')
    aug_csv = os.path.join(aug_dir, 'augmented_labels.csv')
    has_aug_dir = os.path.exists(aug_dir)
    if not has_aug_dir:
        print(f"Hinweis: Kein Augmentierungs-Verzeichnis unter {aug_dir} gefunden - es werden nur Originaldaten verwendet")

    # Dataset from the original data plus (optionally) the augmented data
    dataset = LungXRayDataset(
        image_dir=config['paths']['IMAGE_DIR'],
        labels_csv=config['paths']['LABELS_CSV'],
        bbox_csv=config['paths']['BBOX_CSV'],
        aug_dir=aug_dir if has_aug_dir else None,
        aug_csv=aug_csv if os.path.exists(aug_csv) else None,
        max_samples=max_samples
    )

    return output_manager, dataset

In [None]:
def create_trainer(model_name, config, dataset, output_manager):
    """Build the data loaders for `model_name` and wrap them in a ModelTrainer.

    Args:
        model_name: key into `config['model_config']`.
        config: configuration dict with 'model_config' and 'paths' entries.
        dataset: dataset object providing `create_dataloader` and `disease_labels`.
        output_manager: passed through to the trainer.

    Returns:
        The configured ModelTrainer.
    """
    settings = config['model_config'][model_name]

    # 'prefetch_factor' from the settings is intentionally not forwarded
    loaders = dataset.create_dataloader(
        batch_size=settings['batch_size'],
        image_size=settings['image_size'],
        num_workers=settings['num_workers'],
    )

    paths = config['paths']
    return ModelTrainer(
        model_name,
        loaders,
        dataset.disease_labels,
        paths['BBOX_CSV'],
        paths['IMAGE_DIR'],
        output_manager,
    )

In [None]:
#Hauptausfuehrungscode
def _resolve_checkpoint_path(resume_from):
    """Turn `resume_from` (file, directory or None) into a checkpoint file path.

    For a directory the most recent '*.pth' file (by `os.path.getctime`) is
    chosen. Returns None when nothing usable is found, so that the trainer does
    not attempt to load a file that does not exist.
    """
    if not resume_from:
        return None

    if os.path.isdir(resume_from):
        checkpoint_files = glob.glob(os.path.join(resume_from, "*.pth"))
        if not checkpoint_files:
            print(f"Keine Checkpoint-Dateien in {resume_from} gefunden.")
            return None
        newest = max(checkpoint_files, key=os.path.getctime)
        print(f"Neuester Checkpoint gefunden: {newest}")
        return newest

    if not os.path.exists(resume_from):
        print(f"Warnung: Checkpoint {resume_from} existiert nicht!")
        return None
    return resume_from


def main(resume_from=None, additional_epochs=0):
    """Run training for every configured model, optionally resuming from a checkpoint.

    Args:
        resume_from: checkpoint file, or directory containing '*.pth' files; None
            starts a fresh training.
        additional_epochs: epochs to train on top of the configured number.

    Side effects:
        Writes 'final_results.json' and, if results for 'convnext_base' /
        'one_cycle' exist, 'final_training_metrics.png' via the output manager.
    """
    # Free GPU memory
    clear_gpu_memory()
    print("GPU-Speicher bereinigt")

    config = setup_training_config()
    output_manager, dataset = setup_training_environment(config)

    results = {}
    for model_name, model_config in config['model_config'].items():
        print(f"\nTraining {model_name}...")
        model_results = {}

        trainer = create_trainer(model_name, config, dataset, output_manager)

        for scheduler_config in config['scheduler_config']:
            print(f"\nVerwende {scheduler_config['type']} Scheduler...")

            checkpoint_path = _resolve_checkpoint_path(resume_from)

            # The learning rate is configured per model ('base_lr') but read by the
            # trainer from the scheduler settings ('lr'); hand it over explicitly so
            # the configured value is used instead of the trainer's default.
            run_config = dict(scheduler_config)
            if 'base_lr' in model_config:
                run_config.setdefault('lr', model_config['base_lr'])

            trainer.train(
                run_config,
                resume_from_checkpoint=checkpoint_path,
                additional_epochs=additional_epochs
            )
            # Snapshot: the trainer keeps updating the same history dict
            model_results[scheduler_config['type']] = dict(trainer.history)

        results[model_name] = model_results

    # Save the final results; default=str keeps an unserialisable metric value
    # from aborting the dump after training has finished
    with open(output_manager.get_path('metrics', 'final_results.json'), 'w') as f:
        json.dump(results, f, indent=4, default=str)

    # Plot the final training metrics
    if results and 'convnext_base' in results and 'one_cycle' in results['convnext_base']:
        history = results['convnext_base']['one_cycle']
        plt.figure(figsize=(15, 5))

        # Loss plot
        plt.subplot(1, 2, 1)
        plt.plot(history['train_loss'], label='Training Loss')
        plt.plot(history['valid_loss'], label='Validation Loss')
        plt.title('Training und Validation Loss')
        plt.xlabel('Batch')
        plt.ylabel('Loss')
        plt.legend()

        # Learning-rate plot
        plt.subplot(1, 2, 2)
        plt.plot(history['learning_rates'])
        plt.title('Learning Rate Schedule')
        plt.xlabel('Batch')
        plt.ylabel('Learning Rate')
        plt.tight_layout()

        plt.savefig(output_manager.get_path('plots', 'final_training_metrics.png'))
        plt.close()
    else:
        print("Keine Trainingsergebnisse verfügbar für das Plotting")

In [None]:
# Free GPU memory before starting a run.
# NOTE(review): clear_gpu_memory is not defined in this part of the notebook -
# presumably it comes from an earlier cell; confirm it is run first.
clear_gpu_memory()
print("Go! Go! Go!")

In [None]:
# Resume training: a directory is given, so main() picks the newest *.pth file in it.
# The commented-out suffix names one specific checkpoint file instead.
main(
    resume_from="run_20250227_180651/models",#/checkpoint_epoch_5_20250220_180123.pth",
)

In [None]:
# Start a fresh training run (no checkpoint).
# NOTE(review): the cells before and after this one also call main(); running all
# cells therefore trains several times in a row - run only the cell you need.
main()

In [None]:
if __name__ == "__main__":
    # Start a new training run
    main()
    
    # Or: resume training from a checkpoint
    
    # main(
        #resume_from="run_20250213_135241/models",
        #additional_epochs=5
    #)

### Zum Laden eines Modells:

In [None]:
# NOTE(review): usage sketch - `weights_path`, `full_model_path`, `model` and
# `optimizer` are not defined in the visible cells; set them before running this.

# Load the weights only
state_dict = torch.load(weights_path)
model.load_state_dict(state_dict)

# Load a full checkpoint (model and optimizer state)
checkpoint = torch.load(full_model_path)
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])