In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torch.nn.utils.rnn import pad_sequence
import torch.nn.functional as F
from torch.utils.tensorboard import SummaryWriter

import polars as pl

from sklearn.model_selection import train_test_split

from tqdm.notebook import tqdm

import random
import numpy as np

# Set the seed for Python's built-in random module
random.seed(69)

# Set the seed for NumPy's random number generator
np.random.seed(69)

# Set the seed for PyTorch's random number generators
torch.manual_seed(69)
torch.cuda.manual_seed(69)
torch.cuda.manual_seed_all(69)  # if you are using multi-GPU.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False


In [2]:
pl.read_parquet('/home/owner/Documents/DEV/BrainLabyrinth/data/combined.parquet')

event_id,marker,time,Fp1,Fpz,Fp2,F7,F3,Fz,F4,F8,FC5,FC1,FC2,FC6,M1,T7,C3,Cz,C4,T8,M2,CP5,CP1,CP2,CP6,P7,P3,Pz,P4,P8,POz,O1,O2,EOG,AF7,AF3,AF4,AF8,F5,F1,F2,F6,FC3,FCz,FC4,C5,C1,C2,C6,CP3,CP4,P5,P1,P2,P6,PO5,PO3,PO4,PO6,FT7,FT8,TP7,TP8,PO7,PO8,Oz,__null_dask_index__
i64,str,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,i64
0,"""Stimulus/1""",9.438,9.53724,7.803913,-6.957125,-3.662101,-1.831665,6.678428,0.643271,-28.855688,17.219198,10.443791,8.557194,5.307324,-41.916786,-20.510318,20.105794,7.447749,18.204764,-21.123294,-54.043086,-0.779434,9.156651,3.35639,7.050375,8.437749,10.035689,18.781002,12.663772,-1.516957,5.923493,9.755382,5.507086,56.784839,-16.524688,0.995008,1.210651,-16.334868,5.347515,7.849705,4.708359,-21.316867,2.927581,9.449216,5.787111,3.851742,5.561468,1.623157,0.986564,13.195735,5.790764,9.944564,14.604585,21.713373,4.75506,10.76817,11.458283,-8.927917,3.754084,0.357387,-4.394041,-8.618822,-5.172546,11.513064,4.785936,5.749947,0
0,"""Stimulus/1""",9.44,13.396619,13.868635,-2.922562,3.933463,1.559054,9.798244,3.592579,-24.953035,15.061343,8.709912,9.965191,5.720247,-70.75158,-41.867714,15.554878,6.088562,21.294248,-0.120773,-50.053209,-8.458973,6.922121,3.178088,16.021653,-9.844435,-2.435621,18.337207,18.720653,21.930048,5.166919,-11.485745,27.682245,62.681269,-15.060192,0.757932,5.598742,-17.990556,2.588855,7.099604,5.117679,-16.808443,5.047086,12.311436,11.37285,3.025393,8.060817,5.687988,5.582753,10.034344,11.733741,-3.211322,11.257885,25.065433,21.465142,-6.313013,-5.096768,-1.385707,27.073425,-1.288571,-4.394788,-19.987022,14.889966,-10.593543,27.311613,-6.368204,1
0,"""Stimulus/1""",9.442,14.850883,14.096507,-0.078051,-0.246867,1.38473,11.079935,4.932076,-24.924988,14.634668,9.740824,11.113594,5.932303,-66.520317,-25.757618,14.687831,6.78713,22.50476,-9.39821,-44.082786,-6.577221,8.424037,4.776816,11.782222,-2.206932,-3.447845,17.317284,15.465552,7.507261,4.749701,-4.053521,20.096318,56.423504,-16.052469,0.542124,3.936092,-14.449495,2.010495,9.100914,5.373109,-16.629689,4.217228,12.391046,10.807673,2.673265,6.832248,5.108662,5.212829,9.382513,9.761762,-3.64864,9.981714,22.28212,13.84102,-1.795933,-1.438725,-4.064985,22.191193,-9.739858,-5.565701,-9.834637,8.252451,-1.853507,20.793088,3.175448,2
0,"""Stimulus/1""",9.444,21.293703,18.676068,2.396635,5.01313,5.984928,13.828105,6.36971,-23.503992,20.354059,13.549383,12.007385,2.86629,-32.600562,6.407807,19.207374,8.063451,18.753907,-36.384114,-46.5438,0.755845,9.151167,1.696862,-2.183605,9.926201,2.112486,14.258738,2.697055,-30.135514,-6.710183,0.557059,-28.83586,66.307082,-10.671446,5.90164,6.894815,-11.230897,8.280238,14.231695,8.181461,-18.193031,10.290102,16.717396,9.37912,4.499651,9.509004,5.286035,-0.039267,11.469463,0.325132,2.030729,9.968108,16.758211,-15.392703,2.525013,3.591888,-23.756455,-26.819381,-14.263371,-7.832686,1.443246,-6.503959,3.155976,-28.110124,-16.510437,3
0,"""Stimulus/1""",9.446,20.054007,18.219546,3.120409,0.285677,6.550377,17.138199,15.234075,-13.355245,18.774967,14.254335,17.074288,12.947393,-34.977934,5.612512,15.83888,9.753912,22.196893,-37.742914,-36.697359,-3.596704,5.289391,1.301041,4.294134,1.056083,-4.711626,13.018475,5.3802,-11.755816,-9.147353,-10.122114,-23.32428,64.653867,-5.254813,8.025885,10.051126,-3.915474,10.896383,16.240127,13.117371,-7.493371,10.223294,18.24204,15.672042,-0.109563,9.109484,6.398258,10.450412,8.586929,4.823086,-5.016421,7.561401,18.00294,-12.701589,-6.145933,-6.488521,-21.213362,-21.120105,-17.366758,9.093167,-11.152638,1.806184,-7.890925,-21.705784,-25.354465,4
…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…
1381,"""Stimulus/A""",928.476,9.591813,-41.557977,-60.77412,23.760604,-41.33637,-66.987031,-11.20374,-90.704749,3.790776,-30.919723,-37.216313,-61.566037,-52.382623,-9.286577,-42.072383,-38.344749,-53.684958,-44.898705,-49.312139,-0.965291,-33.180112,-17.891398,-10.890542,-19.449326,-9.946805,-33.991378,-89.135265,-85.339557,-19.949215,-13.598605,-35.072145,379.205543,10.027611,-13.869941,-38.946625,-80.264638,-1.0685,-43.047086,-10.514064,-57.963668,31.12568,-46.874566,-32.352284,18.198025,-51.127463,-36.473882,-43.249089,-26.761444,-42.808292,-27.528111,-29.581585,-52.231929,-52.319126,-18.26109,-19.034752,-54.394882,-40.090876,1.734976,-61.119441,6.031929,-18.942257,-19.640234,-39.629149,-28.380572,275996
1381,"""Stimulus/A""",928.478,10.05211,-39.285851,-57.304729,24.489219,-41.405448,-65.223565,-7.581582,-85.183478,7.284242,-32.467265,-36.44456,-58.760747,-51.05454,-0.406084,-41.466138,-37.820387,-53.660458,-51.282983,-47.950155,-3.073491,-34.703956,-18.293676,-11.464762,-24.24428,-11.731681,-33.641051,-89.751296,-86.665381,-21.44362,-17.000163,-35.684019,385.31099,13.91835,-10.738223,-30.392338,-74.628701,2.910838,-41.721126,-7.360083,-49.079759,30.052943,-45.777716,-31.916799,19.00681,-50.908402,-34.346295,-36.57473,-26.347318,-41.537574,-29.77699,-30.339979,-52.22948,-53.980062,-21.867465,-21.754061,-57.051582,-40.913533,3.709452,-60.064794,4.821594,-19.524317,-20.815015,-40.568466,-29.987083,275997
1381,"""Stimulus/A""",928.48,12.065095,-40.534576,-57.783633,29.010773,-43.466559,-65.123134,-6.594105,-84.364808,11.379481,-32.393301,-35.587802,-55.748424,-54.481948,9.565403,-42.304131,-38.425967,-52.91369,-47.393221,-47.164644,-2.847837,-34.77862,-17.350552,-9.44111,-21.020738,-8.33714,-32.512395,-87.232313,-83.327347,-20.088512,-11.905468,-34.854007,402.957563,16.298075,-10.369621,-30.212791,-73.874753,5.971585,-39.426622,-7.338018,-49.622007,30.432534,-45.860797,-31.460015,18.195945,-51.894871,-34.708639,-30.659673,-27.710193,-41.63766,-27.134769,-30.320361,-51.321614,-52.401906,-19.546804,-20.653188,-55.703298,-39.288035,12.232324,-56.132915,10.525082,-15.37765,-18.393694,-40.166692,-28.50722,275998
1381,"""Stimulus/A""",928.482,14.72505,-38.596546,-53.248571,31.472895,-39.738697,-64.958224,-6.417419,-83.385095,11.147532,-30.443848,-35.688861,-55.982359,-46.464062,18.694285,-45.26478,-38.533711,-50.081167,-42.353148,-42.931054,-3.614683,-32.516321,-16.236222,-6.832051,-17.149998,-11.510808,-35.3496,-88.705435,-85.313381,-21.334178,-14.573576,-37.156088,394.311109,17.552649,-8.430964,-33.06964,-70.96307,3.752907,-39.039754,-5.491022,-53.773884,29.962905,-46.757481,-30.543767,14.668657,-51.72058,-34.076067,-32.300195,-29.405876,-39.132014,-25.195138,-29.335414,-51.595348,-49.798575,-17.942612,-18.328749,-55.256379,-41.800521,12.451331,-55.866528,9.126756,-17.300153,-20.074211,-42.179936,-31.332407,275999


In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torch.nn.utils.rnn import pad_sequence
from torch.utils.tensorboard import SummaryWriter
from tqdm.notebook import tqdm
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
from torch.optim.lr_scheduler import ReduceLROnPlateau
import polars as pl
import numpy as np
from scipy.interpolate import CubicSpline
from imblearn.over_sampling import SMOTE

torch.backends.cudnn.benchmark = True

#============================================================
# Model Architecture (Unchanged)
#============================================================
class EEGDSConv(nn.Module):
    def __init__(self, dropout=0.5):
        super().__init__()
        self.block = nn.Sequential(
            nn.Conv1d(64, 64, 15, padding='same', groups=64),
            nn.Conv1d(64, 16, 1),
            nn.BatchNorm1d(16),
            nn.ReLU(),
            nn.MaxPool1d(4),
            nn.Dropout(dropout),
            nn.Conv1d(16, 16, 7, padding='same', groups=16),
            nn.Conv1d(16, 8, 1),
            nn.BatchNorm1d(8),
            nn.ReLU(),
            nn.AdaptiveAvgPool1d(1),
            nn.Flatten(),
            nn.Linear(8, 1)
        )
    
    def forward(self, x):
        x = x.permute(0, 2, 1)
        return self.block(x).squeeze(-1)  # Squeeze last dimension to match target shape


#============================================================
# Enhanced Dataset Class with Proper Encapsulation
#============================================================
class EEGDataset(Dataset):
    def __init__(self, source, max_length=2000):
        self.df = self._load_and_filter(source)
        self.event_ids = self.df['event_id'].unique().to_list()
        self.max_length = max_length
        # Compute class weights after filtering, based on the marker distribution.
        self._class_weights = self.compute_class_weights()
        # Use all columns except event_id and marker as features.
        self.feature_cols = [c for c in self.df.columns if c not in {'event_id', 'marker'}]
        self._precompute_samples()
    
    @property
    def class_weights(self):
        # Expose the computed weights as a property.
        return self._class_weights 

    def __len__(self):
        return len(self.event_ids)
    
    def _precompute_samples(self):
        """Cache all samples in memory during initialization"""
        self.samples = []
        for event_id in self.event_ids:
            event_data = self.df.filter(pl.col("event_id") == event_id)
            features = torch.tensor(event_data.select(self.feature_cols).to_numpy(), 
                                  dtype=torch.float32)
            features = self._pad_sequence(features)
            label = 1.0 if event_data['marker'][0] == "Stimulus/P" else 0.0
            self.samples.append( (torch.tensor(label, dtype=torch.float32), features) )

    def __getitem__(self, idx):
        return self.samples[idx]  # Direct access now!
    
    def compute_class_weights(self):
        """
        Compute inverse frequency weights based on the 'marker' column.
        Assumes markers are "Stimulus/A" and "Stimulus/P".
        """
        # Get unique combinations of event_id and marker.
        unique_events = self.df.select(["event_id", "marker"]).unique()
        
        # Use value_counts on the "marker" column.
        counts_df = unique_events["marker"].value_counts()

        # We'll use 'values' if it exists, otherwise 'marker'.
        d = { (row.get("values") or row.get("marker")): row["count"] 
            for row in counts_df.to_dicts() }
        
        weight_A = 1.0 / d.get("Stimulus/A", 1)
        weight_P = 1.0 / d.get("Stimulus/P", 1)
        return {"A": weight_A, "P": weight_P}
   
    def split_dataset(self, ratios=(0.7, 0.15, 0.15), seed=None):
        """
        Splits the dataset into three EEGDataset instances for train, val, and test.
        This method shuffles the event_ids and then partitions them based on the given ratios.
        """
        if seed is not None:
            np.random.seed(seed)
        
        # Copy and shuffle the event_ids
        event_ids = self.event_ids.copy()
        np.random.shuffle(event_ids)
        total = len(event_ids)
        
        n_train = int(ratios[0] * total)
        n_val   = int(ratios[1] * total)
        
        train_ids = event_ids[:n_train]
        val_ids   = event_ids[n_train:n_train+n_val]
        test_ids  = event_ids[n_train+n_val:]
        
        # Filter self.df for the selected event_ids
        train_df = self.df.filter(pl.col("event_id").is_in(train_ids))
        val_df   = self.df.filter(pl.col("event_id").is_in(val_ids))
        test_df  = self.df.filter(pl.col("event_id").is_in(test_ids))
        
        # Create new EEGDataset instances using the filtered data
        train_set = EEGDataset(train_df, self.max_length)
        val_set   = EEGDataset(val_df, self.max_length)
        test_set  = EEGDataset(test_df, self.max_length)
        
        return train_set, val_set, test_set
    
    def _load_and_filter(self, source):
        if isinstance(source, str):
            df = pl.read_parquet(source)
        elif isinstance(source, pl.DataFrame):
            df = source
        else:
            raise ValueError("Unsupported source type")
        
        # Keep only events with these markers and drop unneeded columns.
        df = df.filter(pl.col('marker').is_in(["Stimulus/A", "Stimulus/P"]))
        for col in ['time', '__null_dask_index__']:
            if col in df.columns:
                df = df.drop(col)
        return df

    def _pad_sequence(self, tensor):
        # Pre-allocate tensor for maximum efficiency
        padded = torch.zeros((self.max_length, tensor.size(1)), dtype=tensor.dtype)
        length = min(tensor.size(0), self.max_length)
        padded[:length] = tensor[:length]
        return padded
    
    def apply_smote(self):
        # Extract from precomputed PADDED samples (2000×64 guaranteed)
        X = np.stack([features.numpy().flatten() for _, features in self.samples])
        y = np.array([label.item() for label, _ in self.samples])
        
        # Apply SMOTE
        smote = SMOTE()
        X_res, y_res = smote.fit_resample(X, y)

        # Create synthetic events
        new_events = []
        new_event_id = self.df['event_id'].max() + 1
        feature_columns = self.feature_cols
        
        # Create dtype conversion map (Polars → NumPy)
        dtype_map = {
            pl.Float64: np.float64,
            pl.Float32: np.float32,
            pl.Int64: np.int64,
            pl.Int32: np.int32
        }

        # Process only new synthetic samples
        for features_flat, label in zip(X_res[len(self):], y_res[len(self):]):
            # Reshape to (2000, 64)
            features_2d = features_flat.reshape(self.max_length, len(feature_columns))
            
            # Create event DataFrame
            event_data = {
                "event_id": [new_event_id] * self.max_length,
                "marker": ["Stimulus/P" if label else "Stimulus/A"] * self.max_length
            }
            
            # Add features with original dtypes
            for col_idx, col in enumerate(feature_columns):
                pl_dtype = self.df.schema[col]
                # Handle list types if necessary
                if isinstance(pl_dtype, pl.List):
                    np_dtype = dtype_map.get(pl_dtype.inner, np.float64)
                else:
                    np_dtype = dtype_map.get(pl_dtype, np.float64)
                
                event_data[col] = features_2d[:, col_idx].astype(np_dtype)

            new_events.append(pl.DataFrame(event_data))
            new_event_id += 1

        # Update dataset
        self.df = pl.concat([self.df, *new_events])
        self.event_ids = self.df['event_id'].unique().to_list()
        self._precompute_samples()
        self._class_weights = self.compute_class_weights()
        
        return self

#============================================================
# Training Pipeline
#============================================================

# Configuration
config = {
    'data_path': '/home/owner/Documents/DEV/BrainLabyrinth/data/combined.parquet',
    'split_ratios': (0.7, 0.15, 0.15),
    'batch_size': 32,
    'dropout': 0.6,
    'lr': 1e-5,
    'weight_decay': 5e-5,
    'epochs': 200,
    'device': torch.device('cuda' if torch.cuda.is_available() else 'cpu'),
    'log_dir': './runs/CNN'
}

def collate_fn(batch):
    """
    Collate function for variable-length EEG feature sequences.

    Each sample is expected to be a tuple (label, feature), where:
    - label is a scalar tensor (or 1D tensor) representing the class/target.
    - feature is a tensor of shape (seq_len, num_channels), where seq_len may vary.

    This function stacks labels and pads features along the time dimension so that
    all sequences in the batch have the same length.
    """
    # Unzip the batch into labels and features
    labels, features = zip(*batch)
    
    labels = torch.stack(labels)
    # Optionally: labels = labels.unsqueeze(1)  # Uncomment if required by your loss function
    padded_features = pad_sequence(features, batch_first=True)
    
    return labels, padded_features

# Initialize dataset
full_dataset = EEGDataset(config['data_path'])

# Split dataset
train_set, val_set, test_set = full_dataset.split_dataset(
    ratios=config['split_ratios']
)

len_dataset = len(train_set)
sample = train_set[0]
label_shape = sample[0].shape
feature_shape = sample[1].shape

print(f"unbalanced train dataset shape: ({len_dataset}, [labels: {label_shape}, features: {list(feature_shape)}])")

# Balance training set
train_set.apply_smote()

train_loader = DataLoader(
    train_set,
    batch_size=config['batch_size'],
    shuffle=True,
    num_workers=8,        # Use class-conscious workers
    pin_memory=True,      # For GPU acceleration
    persistent_workers=True,
    collate_fn=collate_fn
)
val_loader = DataLoader(val_set, batch_size=config['batch_size'], collate_fn=collate_fn)
test_loader = DataLoader(test_set, batch_size=config['batch_size'], collate_fn=collate_fn)

len_dataset = len(train_set)
sample = train_set[0]
label_shape = sample[0].shape
feature_shape = sample[1].shape

print(f"train dataset shape: ({len_dataset}, [labels: {label_shape}, features: {list(feature_shape)}])")

def train_model(config):    
    # Model initialization
    model = EEGDSConv(dropout=config['dropout']).to(config['device'])
    
    # Loss function with automatic class weighting
    pos_weight = torch.tensor([
        train_set.class_weights['A'] / train_set.class_weights['P']
    ]).to(config['device'])
    criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight)
    
    optimizer = optim.AdamW(
        model.parameters(),
        lr=config['lr'],
        weight_decay=config['weight_decay']
    )
    scheduler = ReduceLROnPlateau(
        optimizer,
        mode='min',
        factor=0.5,
        patience=10,
        verbose=True
    )
    
    # Training loop
    writer = SummaryWriter(log_dir=config['log_dir'])
    best_metric = float('inf')
    
    for epoch in tqdm(range(config['epochs'])):
        # Training phase
        model.train()
        train_loss = 0
        for labels, features in train_loader:
            features = features.to(config['device'])
            labels = labels.to(config['device'])
            
            optimizer.zero_grad()
            outputs = model(features)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            
            train_loss += loss.item()
        
        # Validation phase
        model.eval()
        val_loss = 0
        all_preds = []
        all_labels = []
        
        with torch.no_grad():
            for labels, features in val_loader:
                features = features.to(config['device'])
                labels = labels.to(config['device'])
                
                outputs = model(features)
                val_loss += criterion(outputs, labels).item()
                
                preds = torch.sigmoid(outputs)
                all_preds.extend(preds.cpu().numpy())
                all_labels.extend(labels.cpu().numpy())
        
        # Calculate metrics
        train_loss /= len(train_loader)
        val_loss /= len(val_loader)
        predictions = (np.array(all_preds) > 0.5).astype(int)
        
        # Update scheduler
        scheduler.step(val_loss)
        
        # Log metrics
        writer.add_scalars('Loss', {'train': train_loss, 'val': val_loss}, epoch)
        metrics = {
            'accuracy': accuracy_score(all_labels, predictions),
            'precision': precision_score(all_labels, predictions),
            'recall': recall_score(all_labels, predictions),
            'f1': f1_score(all_labels, predictions)
        }
        writer.add_scalars('Metrics', metrics, epoch)
        
        # Save best model
        if val_loss < best_metric:
            best_metric = val_loss
            torch.save(model.state_dict(), f"{config['log_dir']}/best_model.pth")
    
    writer.close()
    return model

# Start training
trained_model = train_model(config)

unbalanced train dataset shape: (1926, [labels: torch.Size([]), features: [2000, 64]])
train dataset shape: (2606, [labels: torch.Size([]), features: [2000, 64]])




  0%|          | 0/200 [00:00<?, ?it/s]

In [4]:
import torch
from tqdm.notebook import tqdm
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_auc_score

# Assuming model, criterion, test_loader, device, writer, and epoch are already defined
best_model = torch.load('best_model.torch')
best_model.eval()
test_loss = 0
all_test_markers = []
all_test_predictions = []
with torch.no_grad():
    for markers, features in tqdm(test_loader):
        features = features.to(device)
        markers = markers.unsqueeze(-1).to(device)

        outputs = best_model(features)
        loss = criterion(outputs, markers)
        test_loss += loss.item()

        # Collect markers and predictions for metrics calculation
        all_test_markers.extend(markers.cpu().numpy().flatten())
        all_test_predictions.extend(torch.sigmoid(outputs).cpu().numpy().flatten())

test_loss /= len(test_loader)
# Calculate test metrics
test_accuracy = accuracy_score(all_test_markers, [1 if p > 0.5 else 0 for p in all_test_predictions])
test_precision = precision_score(all_test_markers, [1 if p > 0.5 else 0 for p in all_test_predictions])
test_recall = recall_score(all_test_markers, [1 if p > 0.5 else 0 for p in all_test_predictions])
test_f1 = f1_score(all_test_markers, [1 if p > 0.5 else 0 for p in all_test_predictions])
test_roc_auc = roc_auc_score(all_test_markers, all_test_predictions)

# Log test metrics to TensorBoard
writer.add_scalar('Metrics/test_accuracy', test_accuracy, epoch)
writer.add_scalar('Metrics/test_precision', test_precision, epoch)
writer.add_scalar('Metrics/test_recall', test_recall, epoch)
writer.add_scalar('Metrics/test_f1', test_f1, epoch)
writer.add_scalar('Metrics/test_roc_auc', test_roc_auc, epoch)

# Close the TensorBoard writer
writer.close()

  best_model = torch.load('best_model.torch')


AttributeError: 'collections.OrderedDict' object has no attribute 'eval'

In [None]:
print(f"""
{test_accuracy=}
{test_precision=}
{test_recall=}
{test_f1=}
{test_roc_auc=}
"""
)

In [None]:
from sklearn.metrics import f1_score
import numpy as np
best_threshold = 0.0
best_f1 = 0.0
thresholds = np.arange(0.0, 1.0, 0.01)

for threshold in tqdm(thresholds):
    binary_predictions = (all_test_predictions > threshold).astype(int)
    current_recall = f1_score(all_test_markers, binary_predictions)

    if current_recall > best_f1:
        best_f1 = current_recall
        best_threshold = threshold

print(f"{best_threshold=}")
print(f"{best_f1=}")

In [None]:
from sklearn.metrics import recall_score
import numpy as np
best_threshold = 0.1
best_recall = 0.0
thresholds = np.arange(0.1, 1.0, 0.01)

for threshold in tqdm(thresholds):
    binary_predictions = (all_test_predictions > threshold).astype(int)
    current_recall = recall_score(all_test_markers, binary_predictions)

    if current_recall > best_recall:
        best_recall = current_recall
        best_threshold = threshold

print(f"{best_threshold=}")
print(f"{best_recall=}")