## VS CODE

In [45]:
import os
import pandas as pd
import torch
import numpy as np
from sklearn.model_selection import StratifiedKFold
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from PIL import Image
import timm
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm
from sklearn.metrics import accuracy_score, f1_score, confusion_matrix, classification_report
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path

In [46]:
# === Configurations ===
# Project root. Set the IND_STUDY_ROOT environment variable to run the notebook
# on another machine; the fallback keeps the original location.
ROOT_DIR = Path(os.environ.get("IND_STUDY_ROOT", "C:/Users/rsriram3/Documents/ind_study"))
OUTPUT_DIR = ROOT_DIR / "data"
CSV_PATH = OUTPUT_DIR / "full_augmented_images_metrics.csv"
CHECKPOINT_DIR = OUTPUT_DIR / "best_checkpoints"
CHECKPOINT_DIR.mkdir(parents=True, exist_ok=True)
# IMAGE_ROOT = ROOT_DIR / "images"
NUM_CLASSES = 2      # size of the classifier output layer
IMAGE_SIZE = 224     # images are resized to IMAGE_SIZE x IMAGE_SIZE
BATCH_SIZE = 32
EPOCHS = 15          # upper bound; early stopping may end training sooner
PATIENCE = 5         # epochs without validation improvement before stopping
NUM_FOLDS = 5
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"


In [47]:
# === Transforms ===
# Per-channel normalisation constants (the standard ImageNet mean/std values).
_NORM_MEAN = [0.485, 0.456, 0.406]
_NORM_STD = [0.229, 0.224, 0.225]

# Resize to a square input, convert to a tensor, then normalise each channel.
_preprocess_steps = [
    transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize(mean=_NORM_MEAN, std=_NORM_STD),
]
transform = transforms.Compose(_preprocess_steps)

In [48]:
# === Dataset with optional aux features ===
class SharedHeadDataset(Dataset):
    """Image dataset that can also yield three auxiliary image metrics.

    Each row of ``df`` must provide an ``image`` path (joined onto the
    module-level ``ROOT_DIR``) and a ``label``. When ``use_aux`` is true the
    row must also provide ``brightness``, ``edge_density`` and ``entropy``.
    Images are preprocessed with the module-level ``transform``.

    Args:
        df: DataFrame with one row per sample. Its index is reset so that
            positional lookups line up with ``__len__``.
        use_aux: If true, ``__getitem__`` returns ``(image, aux, label)``;
            otherwise it returns ``(image, label)``.
    """

    def __init__(self, df, use_aux=True):
        self.df = df.reset_index(drop=True)
        self.use_aux = use_aux

    def __len__(self):
        """Return the number of samples (rows in the DataFrame)."""
        return len(self.df)

    def __getitem__(self, idx):
        """Load, preprocess and return the sample at position ``idx``."""
        row = self.df.iloc[idx]
        image_path = (ROOT_DIR / row['image']).resolve()
        # Open through a context manager so the underlying file handle is
        # released deterministically, even if decoding fails part-way.
        with Image.open(image_path) as raw_image:
            image = transform(raw_image.convert("RGB"))
        label = torch.tensor(row['label'], dtype=torch.long)

        if self.use_aux:
            aux = torch.tensor([row['brightness'], row['edge_density'], row['entropy']], dtype=torch.float32)
            return image, aux, label
        else:
            # NOTE: this notebook's collate_fn unpacks 3-tuples, so 2-tuple
            # samples need a different collate function.
            return image, label

In [49]:
# === Collate function to unify batch ===
def collate_fn(batch):
    """Merge a list of ``(image, aux, label)`` samples into batched tensors.

    Args:
        batch: Sequence of 3-tuples as produced by ``SharedHeadDataset`` with
            ``use_aux=True``.

    Returns:
        Tuple ``(images, auxs, labels)``: images and aux vectors are stacked
        along a new leading dimension; labels are packed into one tensor.
    """
    # Transpose the list of samples into three parallel tuples.
    columns = tuple(zip(*batch))
    image_items, aux_items, label_items = columns

    stacked_images = torch.stack(image_items)
    stacked_aux = torch.stack(aux_items)
    packed_labels = torch.tensor(label_items)
    return stacked_images, stacked_aux, packed_labels

In [50]:
# === Mixup ===
def mixup_data(x, aux, y, alpha=1.0):
    """Blend each sample in a batch with a randomly chosen partner (mixup).

    Args:
        x: Batch of inputs, indexed along dimension 0.
        aux: Batch of auxiliary features, permuted with the same indices as ``x``.
        y: Batch of labels.
        alpha: Beta distribution parameter; values <= 0 disable mixing.

    Returns:
        Tuple ``(mixed_x, mixed_aux, y_a, y_b, lam)`` where ``y_a`` is the
        original labels, ``y_b`` the partner labels and ``lam`` the weight
        given to the original sample.
    """
    # Mixing weight: a Beta(alpha, alpha) draw, or 1 (no mixing) otherwise.
    lam = np.random.beta(alpha, alpha) if alpha > 0 else 1

    n_samples = x.size(0)
    perm = torch.randperm(n_samples).to(x.device)

    partner_x = x[perm, :]
    partner_aux = aux[perm, :]
    mixed_x = lam * x + (1 - lam) * partner_x
    mixed_aux = lam * aux + (1 - lam) * partner_aux
    return mixed_x, mixed_aux, y, y[perm], lam

def mixup_criterion(criterion, pred, y_a, y_b, lam):
    """Combine the losses against both mixup targets, weighted by ``lam``.

    Args:
        criterion: Callable ``criterion(pred, target)`` returning a loss.
        pred: Model output for the mixed batch.
        y_a: Targets of the original samples (weight ``lam``).
        y_b: Targets of the partner samples (weight ``1 - lam``).
        lam: Mixing weight used to build the batch.
    """
    loss_original = criterion(pred, y_a)
    loss_partner = criterion(pred, y_b)
    return lam * loss_original + (1 - lam) * loss_partner

In [51]:
# === Extended Model ===
class ExtendedModel(nn.Module):
    """Backbone feature extractor followed by an MLP head over ``[features, aux]``.

    Args:
        backbone: Module exposing ``num_features`` and ``reset_classifier``
            (the interface used here matches timm models). Its forward pass
            is assumed to return a ``(batch, num_features)`` tensor, since the
            output is concatenated with ``aux`` along dimension 1.
        num_classes: Number of output logits.
        aux_dim: Width of the auxiliary feature vector concatenated to the
            backbone features. Defaults to 3 (the original hard-coded value).
        hidden_dim: Width of the hidden layer in the head. Defaults to 512.
        dropout: Dropout probability in the head. Defaults to 0.3.
    """

    def __init__(self, backbone, num_classes, aux_dim=3, hidden_dim=512, dropout=0.3):
        super().__init__()
        self.backbone = backbone
        self.aux_dim = aux_dim
        self.head = nn.Sequential(
            nn.Linear(backbone.num_features + aux_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, num_classes)
        )
        # Ask the backbone for zero output classes so it emits features only.
        self.backbone.reset_classifier(0)

    def forward(self, x, aux):
        """Return class logits for images ``x`` and auxiliary features ``aux``."""
        x = self.backbone(x)
        combined = torch.cat([x, aux], dim=1)
        return self.head(combined)

In [52]:
# === Evaluation ===
def evaluate_model(model, dataloader):
    """Run ``model`` over ``dataloader`` and score its predictions.

    The model is switched to eval mode and left there; gradients are disabled
    during inference. A classification report is printed as a side effect.

    Args:
        model: Callable as ``model(images, aux)`` returning per-class scores.
        dataloader: Iterable of ``(images, aux, labels)`` batches.

    Returns:
        Tuple ``(accuracy, f1, confusion_matrix)`` computed over all batches.
    """
    model.eval()
    predicted, expected = [], []

    with torch.no_grad():
        for batch_images, batch_aux, batch_labels in dataloader:
            batch_images = batch_images.to(DEVICE)
            batch_aux = batch_aux.to(DEVICE)
            batch_labels = batch_labels.to(DEVICE)

            logits = model(batch_images, batch_aux)
            # Predicted class = index of the largest score along dim 1.
            batch_preds = torch.max(logits, 1).indices

            predicted.extend(batch_preds.cpu().numpy())
            expected.extend(batch_labels.cpu().numpy())

    accuracy = accuracy_score(expected, predicted)
    f1 = f1_score(expected, predicted)
    matrix = confusion_matrix(expected, predicted)
    print(classification_report(expected, predicted))
    return accuracy, f1, matrix

In [53]:
def plot_metrics(train_accs, val_accs, name, path):
    """Save a train-vs-validation accuracy curve as ``<path>/<name>.png``.

    The plot is drawn on its own figure, which is closed afterwards. This
    keeps repeated calls (one per fold) from accumulating open figures or
    leaving an empty figure behind for the notebook to render.

    Args:
        train_accs: Per-epoch training accuracies.
        val_accs: Per-epoch validation accuracies.
        name: Plot title, also used as the output file stem.
        path: Existing directory to write the PNG into (``str`` or ``Path``).
    """
    fig, ax = plt.subplots()
    try:
        ax.plot(train_accs, label='Train Acc')
        ax.plot(val_accs, label='Val Acc')
        ax.set_xlabel('Epoch')
        ax.set_ylabel('Accuracy')
        ax.set_title(name)
        ax.legend()
        ax.grid(True)
        fig.savefig(Path(path) / f"{name}.png")
    finally:
        # Always release the figure, even if saving fails.
        plt.close(fig)

In [54]:
# === Plot Confusion Matrix ===
def plot_confusion_matrix(cm, model_name):
    """Show a confusion matrix as an annotated heatmap.

    Args:
        cm: Confusion matrix of integer counts (cells are formatted with "d").
        model_name: Name inserted into the plot title.
    """
    plt.figure(figsize=(6, 5))
    heatmap_ax = sns.heatmap(cm, annot=True, fmt="d", cmap="Blues")
    heatmap_ax.set_xlabel("Predicted")
    heatmap_ax.set_ylabel("True")
    heatmap_ax.set_title(f"Confusion Matrix: {model_name}")
    plt.show()

In [55]:
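# === Model Training Function (superseded; kept commented out for reference — the active version using the EarlyStopping helper is defined in a later cell) ===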
# def train_model(model_name, train_csv, val_csv, root_dir, model_save_path):
#     train_df = pd.read_csv(train_csv)
#     val_df = pd.read_csv(val_csv)

#     train_dataset = SharedHeadDataset(train_df, use_aux=True)
#     val_dataset = SharedHeadDataset(val_df, use_aux=True)

#     train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_fn)
#     val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, collate_fn=collate_fn)

#     base_model = timm.create_model(model_name, pretrained=True, num_classes=0)
#     model = ExtendedModel(base_model, num_classes=NUM_CLASSES).to(DEVICE)
#     optimizer = optim.AdamW(model.parameters(), lr=3e-4)
#     scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=EPOCHS)
#     criterion = nn.CrossEntropyLoss()

#     train_accs, val_accs = [], []
#     best_val_acc, early_stop_counter = 0.0, 0

#     for epoch in range(EPOCHS):
#         model.train()
#         correct, total = 0, 0
#         for images, aux, labels in tqdm(train_loader):
#             images, aux, labels = images.to(DEVICE), aux.to(DEVICE), labels.to(DEVICE)
#             optimizer.zero_grad()
#             mixed_x, mixed_aux, y_a, y_b, lam = mixup_data(images, aux, labels)
#             outputs = model(mixed_x, mixed_aux)
#             loss = mixup_criterion(criterion, outputs, y_a, y_b, lam)
#             loss.backward()
#             optimizer.step()
#             _, preds = torch.max(outputs, 1)
#             total += labels.size(0)
#             correct += (lam * preds.eq(y_a).sum().item() + (1 - lam) * preds.eq(y_b).sum().item())
#         train_acc = correct / total
#         val_acc, val_f1, val_cm = evaluate_model(model, val_loader)
#         train_accs.append(train_acc)
#         val_accs.append(val_acc)
#         scheduler.step()

#         if val_acc > best_val_acc:
#             best_val_acc = val_acc
#             early_stop_counter = 0
#             torch.save(model.state_dict(), model_save_path)
#         else:
#             early_stop_counter += 1
#             if early_stop_counter >= PATIENCE:
#                 break
#     return model, train_accs, val_accs

In [56]:
class EarlyStopping:
    """Track validation accuracy and flag when it has stopped improving.

    Call the instance once per epoch with the latest validation accuracy.
    ``early_stop`` becomes true after ``patience`` consecutive calls without a
    strictly higher score, and is never reset afterwards.

    Args:
        patience: Number of consecutive non-improving calls tolerated.
        verbose: If true, print the counter on every non-improving call.
    """

    def __init__(self, patience=5, verbose=False):
        self.patience, self.verbose = patience, verbose
        self.best_score = None
        self.counter = 0
        self.early_stop = False

    def __call__(self, val_acc):
        """Record ``val_acc`` and update ``counter`` / ``early_stop``."""
        has_improved = self.best_score is None or val_acc > self.best_score
        if has_improved:
            self.best_score, self.counter = val_acc, 0
            return

        # No improvement: ties count against patience as well.
        self.counter += 1
        if self.verbose:
            print(f"EarlyStopping counter: {self.counter} / {self.patience}")
        if self.patience <= self.counter:
            self.early_stop = True

In [57]:
# === Model Training Function - Early stopping===
def train_model(model_name, train_csv, val_csv, root_dir, model_save_path):
    """Fine-tune a timm backbone with the aux-feature head, using mixup and early stopping.

    Args:
        model_name: Model identifier passed to ``timm.create_model``.
        train_csv: Path of the CSV holding the training rows.
        val_csv: Path of the CSV holding the validation rows.
        root_dir: Unused; kept so existing call sites keep working. Image
            paths are resolved by ``SharedHeadDataset`` against ``ROOT_DIR``.
        model_save_path: File the best ``state_dict`` is written to whenever
            validation accuracy improves.

    Returns:
        Tuple ``(model, train_accs, val_accs)``. When a checkpoint was saved
        during this call, ``model`` carries the weights of the best epoch
        rather than those of the last epoch run.
    """
    train_df = pd.read_csv(train_csv)
    val_df = pd.read_csv(val_csv)

    train_dataset = SharedHeadDataset(train_df, use_aux=True)
    val_dataset = SharedHeadDataset(val_df, use_aux=True)

    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_fn)
    val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, collate_fn=collate_fn)

    base_model = timm.create_model(model_name, pretrained=True, num_classes=0)
    model = ExtendedModel(base_model, num_classes=NUM_CLASSES).to(DEVICE)
    optimizer = optim.AdamW(model.parameters(), lr=3e-4)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=EPOCHS)
    criterion = nn.CrossEntropyLoss()

    early_stopper = EarlyStopping(patience=PATIENCE, verbose=True)

    train_accs, val_accs = [], []
    best_val_acc = 0.0
    checkpoint_saved = False

    print("Epoch\tTrain Acc\tVal Acc\tVal F1")
    for epoch in range(EPOCHS):
        model.train()
        correct, total = 0, 0
        for images, aux, labels in tqdm(train_loader):
            images, aux, labels = images.to(DEVICE), aux.to(DEVICE), labels.to(DEVICE)
            optimizer.zero_grad()
            mixed_x, mixed_aux, y_a, y_b, lam = mixup_data(images, aux, labels)
            outputs = model(mixed_x, mixed_aux)
            loss = mixup_criterion(criterion, outputs, y_a, y_b, lam)
            loss.backward()
            optimizer.step()
            _, preds = torch.max(outputs, 1)
            total += labels.size(0)
            # Mixup accuracy: agreement with each of the two mixed targets,
            # weighted by lam. It is measured on blended inputs, so it is not
            # directly comparable to validation accuracy on clean images.
            correct += (lam * preds.eq(y_a).sum().item() + (1 - lam) * preds.eq(y_b).sum().item())

        # Guard against an empty training loader instead of dividing by zero.
        train_acc = correct / total if total else 0.0
        val_acc, val_f1, _ = evaluate_model(model, val_loader)
        train_accs.append(train_acc)
        val_accs.append(val_acc)
        scheduler.step()

        print(f"{epoch+1}\t{train_acc:.4f}\t{val_acc:.4f}\t{val_f1:.4f}")

        if val_acc > best_val_acc:
            best_val_acc = val_acc
            torch.save(model.state_dict(), model_save_path)
            checkpoint_saved = True

        early_stopper(val_acc)
        if early_stopper.early_stop:
            print("Early stopping triggered.")
            break

    # Training continues past the best epoch (up to PATIENCE extra epochs), so
    # restore the best weights to keep the returned model consistent with the
    # reported accuracy. Only reload a checkpoint written by this call.
    if checkpoint_saved:
        model.load_state_dict(torch.load(model_save_path, map_location=DEVICE))

    print("Best Val Accuracy:", best_val_acc)
    return model, train_accs, val_accs

In [58]:
# Candidate models as (model name, checkpoint-style file name) pairs.
# The first element is the identifier later handed to train_model; the second
# is used (extension included or stripped) to name output folders and fold CSVs.
model_configs = [
    ("swin_small_patch4_window7_224.ms_in1k", "swin_model.pth"),
    ("coatnet_1_rw_224.sw_in1k", "coatnet_model.pth"),
    ("convnext_small.fb_in1k", "convnext_model.pth"),
    ("tiny_vit_5m_224.dist_in22k_ft_in1k", "tiny_vit_model.pth"),
    ("edgenext_xx_small.in1k", "edgenext_xx_model.pth")
]

In [59]:
# === Load & Clean CSV ===
# Columns every sample needs: image path, class label and the three aux metrics.
required_columns = ['image', 'label', 'brightness', 'edge_density', 'entropy']

# Read the metrics table and drop rows missing any required value.
df = (
    pd.read_csv(CSV_PATH)
    .dropna(subset=required_columns)
    .reset_index(drop=True)
)

In [60]:
# Sanity check: (rows, columns) left after dropping incomplete rows.
df.shape

(14810, 5)

In [61]:
# === Cross-validation ===
# Stratified K-fold with shuffling and a fixed seed, so the fold assignment is
# the same on every run.
# NOTE(review): the CSV file name suggests the rows are augmented images. If
# several rows derive from the same source image, a row-level split can put
# near-duplicates in both the train and validation folds and inflate the
# validation scores — confirm, and consider a group-aware splitter if so.
skf = StratifiedKFold(n_splits=NUM_FOLDS, shuffle=True, random_state=42)

In [62]:
# Index into model_configs selecting which model this run trains.
model_id = 3
model_name, history = model_configs[model_id]

# Folder-friendly run name: the configured file name up to its first dot.
run_name = history.split('.')[0]

# Per-model checkpoint directory.
save_path = ROOT_DIR / "models" / run_name
save_path.mkdir(parents=True, exist_ok=True)

# Per-model figure directory, derived from ROOT_DIR instead of repeating the
# absolute project path a second time.
plot_save_path = ROOT_DIR / "test-IIM" / "figures" / "shared_head_figures" / run_name
plot_save_path.mkdir(parents=True, exist_ok=True)

In [63]:
# Directory that receives the per-fold train/validation CSV files
# (created, including parents, if it does not exist yet).
FOLDS_DIR = OUTPUT_DIR / "folds"
FOLDS_DIR.mkdir(parents=True, exist_ok=True)

In [64]:
# Train and plot one model per cross-validation fold.
for fold, (train_idx, val_idx) in enumerate(skf.split(df, df['label'])):
    fold_no = fold + 1  # 1-based fold number used in messages and file names
    print(f"\n Fold {fold_no} ")

    # Persist the split so train_model can read it back from disk.
    # `history` still carries its file extension, which ends up in these names.
    train_fold_path = FOLDS_DIR / f"{history}_train_fold_{fold_no}.csv"
    val_fold_path = FOLDS_DIR / f"{history}_val_fold_{fold_no}.csv"
    train_df = df.iloc[train_idx]
    val_df = df.iloc[val_idx]
    train_df.to_csv(train_fold_path, index=False)
    val_df.to_csv(val_fold_path, index=False)

    # Shared stem for this fold's checkpoint file and accuracy plot.
    fold_tag = f"{model_name}_fold{fold_no}"
    model_save_path = save_path / f"{fold_tag}.pth"
    model, train_accs, val_accs = train_model(
        model_name,
        train_fold_path,
        val_fold_path,
        ROOT_DIR,
        model_save_path,
    )
    plot_metrics(train_accs, val_accs, fold_tag, plot_save_path)



 Fold 1 


To support symlinks on Windows, you either need to activate Developer Mode or to run Python as an administrator. In order to activate developer mode, see this article: https://docs.microsoft.com/en-us/windows/apps/get-started/enable-your-device-for-development


Epoch	Train Acc	Val Acc	Val F1


100%|██████████| 371/371 [03:45<00:00,  1.64it/s]


              precision    recall  f1-score   support

           0       1.00      1.00      1.00      2003
           1       1.00      1.00      1.00       959

    accuracy                           1.00      2962
   macro avg       1.00      1.00      1.00      2962
weighted avg       1.00      1.00      1.00      2962

1	0.8643	0.9997	0.9995


100%|██████████| 371/371 [03:45<00:00,  1.65it/s]


              precision    recall  f1-score   support

           0       1.00      1.00      1.00      2003
           1       1.00      1.00      1.00       959

    accuracy                           1.00      2962
   macro avg       1.00      1.00      1.00      2962
weighted avg       1.00      1.00      1.00      2962

2	0.8865	0.9997	0.9995
EarlyStopping counter: 1 / 5


100%|██████████| 371/371 [03:44<00:00,  1.65it/s]


              precision    recall  f1-score   support

           0       1.00      1.00      1.00      2003
           1       1.00      1.00      1.00       959

    accuracy                           1.00      2962
   macro avg       1.00      1.00      1.00      2962
weighted avg       1.00      1.00      1.00      2962

3	0.8856	1.0000	1.0000


100%|██████████| 371/371 [03:46<00:00,  1.64it/s]


              precision    recall  f1-score   support

           0       1.00      1.00      1.00      2003
           1       1.00      1.00      1.00       959

    accuracy                           1.00      2962
   macro avg       1.00      1.00      1.00      2962
weighted avg       1.00      1.00      1.00      2962

4	0.8859	1.0000	1.0000
EarlyStopping counter: 1 / 5


100%|██████████| 371/371 [03:44<00:00,  1.65it/s]


              precision    recall  f1-score   support

           0       1.00      1.00      1.00      2003
           1       1.00      1.00      1.00       959

    accuracy                           1.00      2962
   macro avg       1.00      1.00      1.00      2962
weighted avg       1.00      1.00      1.00      2962

5	0.8947	1.0000	1.0000
EarlyStopping counter: 2 / 5


100%|██████████| 371/371 [03:44<00:00,  1.65it/s]


              precision    recall  f1-score   support

           0       1.00      1.00      1.00      2003
           1       1.00      1.00      1.00       959

    accuracy                           1.00      2962
   macro avg       1.00      1.00      1.00      2962
weighted avg       1.00      1.00      1.00      2962

6	0.8950	1.0000	1.0000
EarlyStopping counter: 3 / 5


100%|██████████| 371/371 [03:45<00:00,  1.65it/s]


              precision    recall  f1-score   support

           0       1.00      1.00      1.00      2003
           1       1.00      1.00      1.00       959

    accuracy                           1.00      2962
   macro avg       1.00      1.00      1.00      2962
weighted avg       1.00      1.00      1.00      2962

7	0.8972	1.0000	1.0000
EarlyStopping counter: 4 / 5


100%|██████████| 371/371 [03:44<00:00,  1.65it/s]


              precision    recall  f1-score   support

           0       1.00      1.00      1.00      2003
           1       1.00      1.00      1.00       959

    accuracy                           1.00      2962
   macro avg       1.00      1.00      1.00      2962
weighted avg       1.00      1.00      1.00      2962

8	0.8931	1.0000	1.0000
EarlyStopping counter: 5 / 5
Early stopping triggered.
Best Val Accuracy: 1.0

 Fold 2 
Epoch	Train Acc	Val Acc	Val F1


100%|██████████| 371/371 [03:43<00:00,  1.66it/s]


              precision    recall  f1-score   support

           0       1.00      1.00      1.00      2003
           1       1.00      1.00      1.00       959

    accuracy                           1.00      2962
   macro avg       1.00      1.00      1.00      2962
weighted avg       1.00      1.00      1.00      2962

1	0.8576	0.9997	0.9995


100%|██████████| 371/371 [03:42<00:00,  1.67it/s]


              precision    recall  f1-score   support

           0       1.00      1.00      1.00      2003
           1       1.00      1.00      1.00       959

    accuracy                           1.00      2962
   macro avg       1.00      1.00      1.00      2962
weighted avg       1.00      1.00      1.00      2962

2	0.8859	1.0000	1.0000


100%|██████████| 371/371 [03:42<00:00,  1.67it/s]


              precision    recall  f1-score   support

           0       1.00      1.00      1.00      2003
           1       1.00      1.00      1.00       959

    accuracy                           1.00      2962
   macro avg       1.00      1.00      1.00      2962
weighted avg       1.00      1.00      1.00      2962

3	0.8895	1.0000	1.0000
EarlyStopping counter: 1 / 5


100%|██████████| 371/371 [03:43<00:00,  1.66it/s]


              precision    recall  f1-score   support

           0       1.00      1.00      1.00      2003
           1       1.00      1.00      1.00       959

    accuracy                           1.00      2962
   macro avg       1.00      1.00      1.00      2962
weighted avg       1.00      1.00      1.00      2962

4	0.8881	1.0000	1.0000
EarlyStopping counter: 2 / 5


100%|██████████| 371/371 [03:41<00:00,  1.67it/s]


              precision    recall  f1-score   support

           0       1.00      1.00      1.00      2003
           1       1.00      1.00      1.00       959

    accuracy                           1.00      2962
   macro avg       1.00      1.00      1.00      2962
weighted avg       1.00      1.00      1.00      2962

5	0.8903	1.0000	1.0000
EarlyStopping counter: 3 / 5


100%|██████████| 371/371 [03:43<00:00,  1.66it/s]


              precision    recall  f1-score   support

           0       1.00      1.00      1.00      2003
           1       1.00      1.00      1.00       959

    accuracy                           1.00      2962
   macro avg       1.00      1.00      1.00      2962
weighted avg       1.00      1.00      1.00      2962

6	0.8933	1.0000	1.0000
EarlyStopping counter: 4 / 5


100%|██████████| 371/371 [03:44<00:00,  1.65it/s]


              precision    recall  f1-score   support

           0       1.00      1.00      1.00      2003
           1       1.00      1.00      1.00       959

    accuracy                           1.00      2962
   macro avg       1.00      1.00      1.00      2962
weighted avg       1.00      1.00      1.00      2962

7	0.8970	1.0000	1.0000
EarlyStopping counter: 5 / 5
Early stopping triggered.
Best Val Accuracy: 1.0

 Fold 3 
Epoch	Train Acc	Val Acc	Val F1


100%|██████████| 371/371 [03:44<00:00,  1.65it/s]


              precision    recall  f1-score   support

           0       1.00      1.00      1.00      2004
           1       1.00      1.00      1.00       958

    accuracy                           1.00      2962
   macro avg       1.00      1.00      1.00      2962
weighted avg       1.00      1.00      1.00      2962

1	0.8616	1.0000	1.0000


100%|██████████| 371/371 [03:42<00:00,  1.67it/s]


              precision    recall  f1-score   support

           0       1.00      1.00      1.00      2004
           1       1.00      1.00      1.00       958

    accuracy                           1.00      2962
   macro avg       1.00      1.00      1.00      2962
weighted avg       1.00      1.00      1.00      2962

2	0.8818	1.0000	1.0000
EarlyStopping counter: 1 / 5


100%|██████████| 371/371 [03:44<00:00,  1.65it/s]


              precision    recall  f1-score   support

           0       1.00      1.00      1.00      2004
           1       1.00      1.00      1.00       958

    accuracy                           1.00      2962
   macro avg       1.00      1.00      1.00      2962
weighted avg       1.00      1.00      1.00      2962

3	0.8872	1.0000	1.0000
EarlyStopping counter: 2 / 5


100%|██████████| 371/371 [04:10<00:00,  1.48it/s]


              precision    recall  f1-score   support

           0       1.00      1.00      1.00      2004
           1       1.00      1.00      1.00       958

    accuracy                           1.00      2962
   macro avg       1.00      1.00      1.00      2962
weighted avg       1.00      1.00      1.00      2962

4	0.8935	1.0000	1.0000
EarlyStopping counter: 3 / 5


100%|██████████| 371/371 [04:21<00:00,  1.42it/s]


              precision    recall  f1-score   support

           0       1.00      1.00      1.00      2004
           1       1.00      1.00      1.00       958

    accuracy                           1.00      2962
   macro avg       1.00      1.00      1.00      2962
weighted avg       1.00      1.00      1.00      2962

5	0.8937	1.0000	1.0000
EarlyStopping counter: 4 / 5


100%|██████████| 371/371 [03:42<00:00,  1.67it/s]


              precision    recall  f1-score   support

           0       1.00      1.00      1.00      2004
           1       1.00      1.00      1.00       958

    accuracy                           1.00      2962
   macro avg       1.00      1.00      1.00      2962
weighted avg       1.00      1.00      1.00      2962

6	0.8913	1.0000	1.0000
EarlyStopping counter: 5 / 5
Early stopping triggered.
Best Val Accuracy: 1.0

 Fold 4 
Epoch	Train Acc	Val Acc	Val F1


100%|██████████| 371/371 [03:40<00:00,  1.68it/s]


              precision    recall  f1-score   support

           0       1.00      1.00      1.00      2004
           1       1.00      1.00      1.00       958

    accuracy                           1.00      2962
   macro avg       1.00      1.00      1.00      2962
weighted avg       1.00      1.00      1.00      2962

1	0.8567	1.0000	1.0000


100%|██████████| 371/371 [03:39<00:00,  1.69it/s]


              precision    recall  f1-score   support

           0       1.00      1.00      1.00      2004
           1       1.00      1.00      1.00       958

    accuracy                           1.00      2962
   macro avg       1.00      1.00      1.00      2962
weighted avg       1.00      1.00      1.00      2962

2	0.8851	1.0000	1.0000
EarlyStopping counter: 1 / 5


100%|██████████| 371/371 [03:42<00:00,  1.67it/s]


              precision    recall  f1-score   support

           0       1.00      1.00      1.00      2004
           1       1.00      1.00      1.00       958

    accuracy                           1.00      2962
   macro avg       1.00      1.00      1.00      2962
weighted avg       1.00      1.00      1.00      2962

3	0.8845	1.0000	1.0000
EarlyStopping counter: 2 / 5


100%|██████████| 371/371 [03:45<00:00,  1.64it/s]


              precision    recall  f1-score   support

           0       1.00      1.00      1.00      2004
           1       1.00      1.00      1.00       958

    accuracy                           1.00      2962
   macro avg       1.00      1.00      1.00      2962
weighted avg       1.00      1.00      1.00      2962

4	0.8904	1.0000	1.0000
EarlyStopping counter: 3 / 5


100%|██████████| 371/371 [03:40<00:00,  1.68it/s]


              precision    recall  f1-score   support

           0       1.00      1.00      1.00      2004
           1       1.00      1.00      1.00       958

    accuracy                           1.00      2962
   macro avg       1.00      1.00      1.00      2962
weighted avg       1.00      1.00      1.00      2962

5	0.8963	1.0000	1.0000
EarlyStopping counter: 4 / 5


100%|██████████| 371/371 [03:41<00:00,  1.67it/s]


              precision    recall  f1-score   support

           0       1.00      1.00      1.00      2004
           1       1.00      1.00      1.00       958

    accuracy                           1.00      2962
   macro avg       1.00      1.00      1.00      2962
weighted avg       1.00      1.00      1.00      2962

6	0.8912	0.9986	0.9979
EarlyStopping counter: 5 / 5
Early stopping triggered.
Best Val Accuracy: 1.0

 Fold 5 
Epoch	Train Acc	Val Acc	Val F1


100%|██████████| 371/371 [03:42<00:00,  1.67it/s]


              precision    recall  f1-score   support

           0       1.00      1.00      1.00      2004
           1       1.00      1.00      1.00       958

    accuracy                           1.00      2962
   macro avg       1.00      1.00      1.00      2962
weighted avg       1.00      1.00      1.00      2962

1	0.8613	1.0000	1.0000


100%|██████████| 371/371 [03:43<00:00,  1.66it/s]


              precision    recall  f1-score   support

           0       1.00      1.00      1.00      2004
           1       1.00      1.00      1.00       958

    accuracy                           1.00      2962
   macro avg       1.00      1.00      1.00      2962
weighted avg       1.00      1.00      1.00      2962

2	0.8815	1.0000	1.0000
EarlyStopping counter: 1 / 5


100%|██████████| 371/371 [03:44<00:00,  1.65it/s]


              precision    recall  f1-score   support

           0       1.00      1.00      1.00      2004
           1       1.00      1.00      1.00       958

    accuracy                           1.00      2962
   macro avg       1.00      1.00      1.00      2962
weighted avg       1.00      1.00      1.00      2962

3	0.8843	0.9997	0.9995
EarlyStopping counter: 2 / 5


100%|██████████| 371/371 [03:42<00:00,  1.67it/s]


              precision    recall  f1-score   support

           0       1.00      1.00      1.00      2004
           1       1.00      1.00      1.00       958

    accuracy                           1.00      2962
   macro avg       1.00      1.00      1.00      2962
weighted avg       1.00      1.00      1.00      2962

4	0.8911	1.0000	1.0000
EarlyStopping counter: 3 / 5


100%|██████████| 371/371 [03:43<00:00,  1.66it/s]


              precision    recall  f1-score   support

           0       1.00      1.00      1.00      2004
           1       1.00      1.00      1.00       958

    accuracy                           1.00      2962
   macro avg       1.00      1.00      1.00      2962
weighted avg       1.00      1.00      1.00      2962

5	0.8948	1.0000	1.0000
EarlyStopping counter: 4 / 5


100%|██████████| 371/371 [03:42<00:00,  1.66it/s]


              precision    recall  f1-score   support

           0       1.00      1.00      1.00      2004
           1       1.00      1.00      1.00       958

    accuracy                           1.00      2962
   macro avg       1.00      1.00      1.00      2962
weighted avg       1.00      1.00      1.00      2962

6	0.8889	0.9997	0.9995
EarlyStopping counter: 5 / 5
Early stopping triggered.
Best Val Accuracy: 1.0


<Figure size 640x480 with 0 Axes>