In [None]:
import os
import re

import numpy as np
import pandas as pd

In [None]:
!pip install fastparquet

In [None]:
# Upper bound used when filtering rows further down: a row is dropped when
# its total NaN count minus the gap after its last valid value exceeds this.
NANS_IN_THE_MIDDLE_THRESHOLD = 15000 

In [None]:
def sort_num(el):
    """Sort key for part files named ``version_02_<name>_<part>.<ext>``.

    Returns the part number as an ``int``; the part called ``final`` maps to
    ``inf`` so that it orders after every numbered part.
    """
    matched = re.search(r"version_02_(\w+)_(.+)\..+", el)
    suffix = matched.group(2)
    return float('inf') if suffix == 'final' else int(suffix)
# Directories that are scanned for the .csv / .parquet part files.
FOLDERS = ['/kaggle/input/nasa-cooked/init_df/', '/kaggle/input/nasa-cooked/init_df_not_in_koi/']
# Slice bounds applied to each folder's sorted file lists when loading.
# NOTE(review): with END_INDEX = -1 the slice stops before the last sorted
# entry, which is the 'final' part whenever one exists (sort_num maps it to
# inf) — confirm that leaving it out is intended.
START_INDEX = 0
END_INDEX = -1

In [None]:
# For every folder, collect its parquet and csv file names ordered by part
# number; both outer lists are aligned with FOLDERS.
parquets, csvs = [], []
for f in FOLDERS:
    l = os.listdir(f)
    print(l)
    parquets.append(sorted((el for el in l if el.endswith('.parquet')), key=sort_num))
    csvs.append(sorted((el for el in l if el.endswith('.csv')), key=sort_num))

In [None]:
# Load every (csv, parquet) pair of each folder and join the id columns with
# the value matrix; the resulting frames are collected for a single concat.
dfs_to_concat = []
for f, csv_list, parquet_list in zip(FOLDERS, csvs, parquets):
    print(f'using folder {f}')
    # Files are paired purely by position in the two sorted lists.
    for pair in zip(csv_list[START_INDEX:END_INDEX], parquet_list[START_INDEX:END_INDEX]):
        print(f'concating pair {pair}')
        df_values_loaded = pd.read_parquet(f + pair[1], engine="fastparquet")
        # Column labels are cast to integers.
        df_values_loaded.columns = df_values_loaded.columns.astype(int)
        df_ids_loaded = pd.read_csv(f + pair[0])
        print(f'{df_values_loaded.shape=}, {df_ids_loaded.shape=}')
        # Orientation heuristic: a value frame with more rows than columns is
        # transposed before joining — presumably such parts were written with
        # one series per column; TODO confirm.
        if df_values_loaded.shape[0] > df_values_loaded.shape[1]: 
            print('concating with transpose')
            df_part = pd.concat([df_ids_loaded, df_values_loaded.T.reset_index(drop=True)], axis=1)
        else: 
            print('concating without transpose')
            df_part = pd.concat([df_ids_loaded, df_values_loaded], axis=1)
        print(f'{df_part.shape=}')
        # Rows are keyed by (KEPID, PLANET_NUM) from here on.
        df_part.set_index(['KEPID', 'PLANET_NUM'], inplace=True)
        dfs_to_concat.append(df_part)


In [None]:
# The file-name lists are no longer needed once all parts are loaded.
del parquets, csvs

In [None]:
# Stack all loaded parts row-wise into one frame.
full_df = pd.concat(dfs_to_concat, axis=0)

In [None]:
# Drop the list holding the per-part frames now that they are concatenated.
del dfs_to_concat

In [None]:
# Class distribution before any filtering.
full_df['LABEL'].value_counts()

In [None]:
mask = full_df['LABEL'] == 1
subset = full_df[mask]
full_df = full_df.drop(subset.index)

In [None]:
del mask, subset

In [None]:
# 0-based column position of the last non-NaN entry of every row (NaN for a
# row without any valid entry). Computed on the boolean array in one pass
# instead of a Python-level apply per row: argmax on the column-reversed mask
# returns the first True from the right.
_notna = full_df.notna().to_numpy()
last_valid_pos = pd.Series(
    np.where(
        _notna.any(axis=1),
        _notna.shape[1] - 1 - np.argmax(_notna[:, ::-1], axis=1),
        np.nan,
    ),
    index=full_df.index,
)
del _notna

In [None]:
# Total number of NaN entries in each row.
nan_count = full_df.isna().sum(axis=1)
print(nan_count)

In [None]:
# Number of NaNs after the last valid entry of each row. last_valid_pos is a
# 0-based position, so the trailing gap is row_len - 1 - position (without
# the -1 every row would be credited with one extra trailing NaN).
row_len = full_df.shape[1]
diff = row_len - 1 - last_valid_pos
print(diff)

In [None]:
# The per-row positions are only needed to derive `diff`.
del last_valid_pos

In [None]:
# Flag rows whose NaN count, after subtracting the gap that follows the last
# valid value, exceeds the threshold.
mask_nans_in_the_middle = (nan_count - diff) > NANS_IN_THE_MIDDLE_THRESHOLD

In [None]:
del diff, nan_count

In [None]:
# Keep only the rows that were not flagged; the unfiltered frame is released.
df_cleaned = full_df[~mask_nans_in_the_middle]
del full_df

In [None]:
del mask_nans_in_the_middle, row_len

In [None]:
import torch
import pandas as pd

def remove_outliers_torch(
    df: pd.DataFrame,
    sigma: float = 3.0,
    method: str = "median",
    fill_value=float("nan"),
    batch_size: int = 2048,
) -> pd.DataFrame:
    """Replace per-row outliers of ``df`` with ``fill_value`` using CUDA devices.

    The frame is converted to one float32 tensor, split row-wise into at most
    one chunk per visible GPU, and each chunk is handed to
    ``_process_chunk_gpu`` on its own device (one chunk after the other).

    Args:
        df: Numeric frame; every column takes part in the computation.
        sigma: Values further than ``sigma`` spreads from the row centre are
            treated as outliers.
        method: ``"median"`` uses the row NaN-median as centre; any other
            value uses the row NaN-mean.
        fill_value: Value written in place of outliers (NaN by default).
        batch_size: Number of rows processed at once on a device.

    Returns:
        A float32 frame with the index and columns of ``df``.

    Raises:
        RuntimeError: If no CUDA device is available.
    """
    n_gpus = torch.cuda.device_count()
    if n_gpus == 0:
        raise RuntimeError("No CUDA device detected")

    data_cpu = torch.tensor(df.to_numpy(dtype="float32"), device="cpu")

    # torch.chunk may return fewer chunks than GPUs; enumerate keeps the
    # chunk-to-device mapping valid in that case. Results are moved back to
    # the CPU immediately so only one chunk lives on a device at a time.
    results_cpu = [
        _process_chunk_gpu(chunk, sigma, method, fill_value, batch_size, f"cuda:{i}").to("cpu")
        for i, chunk in enumerate(torch.chunk(data_cpu, n_gpus))
    ]

    result = torch.cat(results_cpu, dim=0)
    return pd.DataFrame(result.numpy(), index=df.index, columns=df.columns)


@torch.no_grad()
def _process_chunk_gpu(
    data_chunk: torch.Tensor,
    sigma: float,
    method: str,
    fill_value: float,
    batch_size: int,
    device: str,
) -> torch.Tensor:
    data_chunk = data_chunk.to(device, non_blocking=True)
    n_rows = data_chunk.shape[0]
    result = torch.empty_like(data_chunk)

    for start in range(0, n_rows, batch_size):
        end = min(start + batch_size, n_rows)
        batch = data_chunk[start:end]

        nan_mask = torch.isnan(batch)

        if method == "median":
            center = torch.nanmedian(batch, dim=1, keepdim=True).values
        else:
            center = torch.nanmean(batch, dim=1, keepdim=True)

        diff = batch - center
        diff[nan_mask] = 0
        count = (~nan_mask).sum(dim=1, keepdim=True).clamp(min=1)
        var = (diff ** 2).sum(dim=1, keepdim=True) / count
        std = torch.sqrt(var)

        mask = (torch.abs(batch - center) > sigma * std) & ~nan_mask

        if torch.isnan(torch.tensor(fill_value)):
            batch = torch.where(mask, torch.tensor(float("nan"), device=device), batch)
        else:
            batch = torch.where(mask, torch.tensor(fill_value, device=device), batch)

        result[start:end] = batch

    return result


In [None]:
# NOTE(review): df_cleaned still carries the LABEL column here, so LABEL takes
# part in every row's centre/spread estimate and can itself be replaced by the
# fill value; the following cell writes the original labels back.
df_removed_outliers = remove_outliers_torch(df_cleaned)

In [None]:
# Restore LABEL from the pre-outlier frame (the outlier pass returns every
# column, LABEL included, as float32).
df_removed_outliers["LABEL"] = df_cleaned["LABEL"]

In [None]:
df_removed_outliers.head()

In [None]:
# The pre-outlier frame is no longer referenced below.
del df_cleaned

In [None]:
import numpy as np
# Default number of Fourier coefficients kept by fill_nan_fourier: half the
# column count of the frame (LABEL included).
# NOTE(review): np.fft.rfft of a length-L row yields L//2 + 1 coefficients, so
# with n = L//2 the truncation in fill_nan_fourier zeroes only the final
# coefficient — confirm this is the intended amount of smoothing.
n = df_removed_outliers.shape[1] // 2
def fill_nan_fourier(row, n_components=n):
    """Fill the NaN entries of one row from a Fourier-truncated interpolation.

    The valid samples are linearly interpolated over all positions, the real
    FFT of that dense signal is cut after ``n_components`` coefficients (no
    cut when ``n_components`` is falsy), and the inverse transform supplies
    the values written into the originally missing positions.

    Args:
        row: pandas Series of numeric values, possibly containing NaN.
        n_components: Number of leading rFFT coefficients to keep.

    Returns:
        A float32 copy of ``row`` with NaNs filled, or ``row`` itself when it
        has no valid sample at all.
    """
    samples = row.values.copy()
    positions = np.arange(len(row))
    observed = ~np.isnan(samples)
    known_x, known_y = positions[observed], samples[observed]

    # Nothing to interpolate from.
    if len(known_y) == 0:
        return row

    dense = np.interp(positions, known_x, known_y)
    spectrum = np.fft.rfft(dense)
    if n_components:
        spectrum[n_components:] = 0
    smooth = np.fft.irfft(spectrum, n=len(dense))

    missing = np.isnan(row)
    filled = row.astype(np.float32).copy()
    filled[missing] = smooth[missing].astype(np.float32)
    return filled


In [None]:
import warnings
# Silence every FutureWarning for the rest of the session.
warnings.simplefilter(action='ignore', category=FutureWarning)

In [None]:
# Fill the remaining NaNs row by row.
df_filled = df_removed_outliers.apply(fill_nan_fourier, axis=1)

In [None]:
# fill_nan_fourier processes whole rows (LABEL included) and casts them to
# float32, so the labels are written back from the pre-fill frame.
df_filled["LABEL"] = df_removed_outliers["LABEL"]

In [None]:
df_filled.head()

In [None]:
del df_removed_outliers

In [None]:
# Merge label 3 into label 2.
df_filled.loc[df_filled['LABEL'] == 3, 'LABEL'] = 2

In [None]:
df_filled.head()

In [None]:
# Class distribution after merging.
df_filled["LABEL"].value_counts()

In [None]:
# Per-row count of NaNs that are still present after filling.
df_filled.isna().sum(axis=1)

In [None]:
from dataclasses import dataclass
from typing import Optional, Tuple, List
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, random_split

In [None]:

def nan_to_zero_and_mask(X: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
    """Return ``X`` with NaNs set to 0 plus a validity mask, both float32.

    The mask holds 1.0 where ``X`` was not NaN and 0.0 where it was.
    """
    is_valid = np.logical_not(np.isnan(X))
    zero_filled = np.nan_to_num(X, nan=0.0)
    return zero_filled.astype(np.float32), is_valid.astype(np.float32)

def robust_scale_rows(X: np.ndarray, mask: np.ndarray, eps=1e-6) -> np.ndarray:
    """Scale every row by the median and MAD of its valid entries.

    For a row with at least one position where ``mask > 0`` the whole row
    (masked positions included) becomes ``(row - median) / (MAD + eps)``, the
    statistics being taken over the valid positions only. Rows without any
    valid position are copied unchanged. The result is float32.
    """
    scaled = np.empty_like(X, dtype=np.float32)
    for idx, (row, row_mask) in enumerate(zip(X, mask)):
        valid = row_mask > 0
        if not valid.any():
            scaled[idx] = row
            continue
        observed = row[valid]
        center = np.median(observed)
        spread = np.median(np.abs(observed - center)) + eps
        scaled[idx] = (row - center) / spread
    return scaled

def uniform_downsample(x: np.ndarray, factor: int) -> np.ndarray:
    """Average consecutive groups of ``factor`` samples along the last axis.

    Samples that do not fill a complete group at the end are discarded.
    """
    n_bins = x.shape[-1] // factor
    trimmed = x[..., : n_bins * factor]
    grouped = trimmed.reshape(x.shape[:-1] + (n_bins, factor))
    return grouped.mean(axis=-1)

def extract_local_window(x: np.ndarray, mask: np.ndarray, win_len: int) -> Tuple[np.ndarray, np.ndarray]:
    """Cut the highest-variance window of ``win_len`` samples out of ``x``.

    Candidate windows start every ``max(1, win_len // 4)`` samples. A
    candidate is scored by the variance of the first channel of ``x`` over
    the positions where the first channel of ``mask`` is positive, and is
    skipped when it has fewer than ``max(8, win_len // 16)`` such positions.
    The earliest best-scoring window wins; if no candidate qualifies the
    window starting at 0 is used. Inputs no longer than ``win_len`` are
    returned sliced to ``[:win_len]`` without a search.

    Returns:
        The same window taken from ``x`` and from ``mask``.
    """
    signal, validity = x[0], mask[0]
    total = signal.shape[-1]
    if total <= win_len:
        return x[..., :win_len], mask[..., :win_len]

    stride = max(1, win_len // 4)
    min_valid = max(8, win_len // 16)
    top_score, top_start = -1, 0
    for start in range(0, total - win_len + 1, stride):
        stop = start + win_len
        usable = validity[start:stop] > 0
        if usable.sum() >= min_valid:
            score = np.var(signal[start:stop][usable])
            if score > top_score:
                top_score, top_start = score, start

    window = slice(top_start, top_start + win_len)
    return x[..., window], mask[..., window]

In [None]:
class BinaryFocalLoss(nn.Module):
    """Focal loss for binary classification on raw logits.

    Per element: ``alpha_t * (1 - p_t) ** gamma * BCE(logit, target)`` where
    ``p_t`` is the predicted probability of the target class and ``alpha_t``
    is ``alpha`` for positive targets and ``1 - alpha`` for negative ones.

    Args:
        gamma: Focusing exponent.
        alpha: Weight applied to positive targets.
        reduction: ``"mean"`` (default) or ``"sum"``; any other value returns
            the unreduced per-element loss.
    """

    def __init__(self, gamma=2.0, alpha=0.25, reduction="mean"):
        super().__init__()
        self.gamma = gamma
        self.alpha = alpha
        self.reduction = reduction

    def forward(self, logits, targets):
        """Return the focal loss of ``logits`` against float 0/1 ``targets``."""
        # Clamp keeps (1 - pt) ** gamma well-behaved for saturated logits.
        p = torch.sigmoid(logits).clamp(1e-6, 1 - 1e-6)
        ce = nn.functional.binary_cross_entropy_with_logits(logits, targets, reduction="none")
        pt = p * targets + (1 - p) * (1 - targets)
        alpha_t = self.alpha * targets + (1 - self.alpha) * (1 - targets)
        loss = alpha_t * (1 - pt).pow(self.gamma) * ce

        if self.reduction == "mean":
            return loss.mean()
        if self.reduction == "sum":
            return loss.sum()
        return loss

In [None]:
class DualViewDataset(Dataset):
    """Dataset yielding a downsampled global view and a local window per row.

    Every row of ``df_X`` is NaN-masked and robustly scaled, then stacked with
    its validity mask into a 2-channel array. The global view is that array
    averaged in groups of ``downsample_factor``; the local view is the window
    of ``local_len`` samples chosen by ``extract_local_window``. All views are
    precomputed in the constructor.
    """

    def __init__(
        self,
        df_X: pd.DataFrame,
        y: np.ndarray,
        downsample_factor: int = 4,
        local_len: int = 1024
    ):
        raw = df_X.to_numpy(dtype=np.float32, na_value=np.nan)
        filled, valid = nan_to_zero_and_mask(raw)
        scaled = robust_scale_rows(filled, valid)

        self.labels = y.astype(np.float32)
        self.global_views: List[np.ndarray] = []
        self.local_views:  List[np.ndarray] = []

        for flux_row, mask_row in zip(scaled, valid):
            flux = flux_row[None, :]
            msk = mask_row[None, :]
            # channel 0: flux, channel 1: validity mask
            stacked = np.concatenate([flux, msk], axis=0)
            global_view = uniform_downsample(stacked, factor=downsample_factor)
            # the window search only reads the first channel of the mask argument
            local_view, _ = extract_local_window(
                stacked, np.concatenate([msk, msk], axis=0), local_len
            )
            self.global_views.append(global_view.astype(np.float32))
            self.local_views.append(local_view.astype(np.float32))

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, i):
        """Return ``(global_view, local_view, label)`` tensors for sample ``i``."""
        label = torch.tensor(self.labels[i])
        return (
            torch.from_numpy(self.global_views[i]),
            torch.from_numpy(self.local_views[i]),
            label,
        )

In [None]:

class MaskedGAP(nn.Module):
    """Mask-weighted average over the last dimension."""

    def forward(self, feats: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        """Return ``sum(feats * mask) / max(sum(mask), 1)`` along the last axis."""
        weight_total = torch.clamp(mask.sum(dim=-1), min=1.0)
        weighted_sum = (mask * feats).sum(dim=-1)
        return weighted_sum / weight_total

In [None]:
class TCNBlock(nn.Module):
    """Residual block of two dilated, length-preserving 1-D convolutions.

    Layout: conv → batch norm → ReLU → dropout → conv → batch norm, added to
    the block input and passed through a final ReLU.
    """

    def __init__(self, ch, ks=9, dil=1, p=0.2):
        super().__init__()
        same_pad = dil * ((ks - 1) // 2)

        def dilated_conv():
            return nn.Conv1d(ch, ch, kernel_size=ks, dilation=dil, padding=same_pad)

        layers = [
            dilated_conv(),
            nn.BatchNorm1d(ch),
            nn.ReLU(),
            nn.Dropout(p),
            dilated_conv(),
            nn.BatchNorm1d(ch),
        ]
        self.net = nn.Sequential(*layers)
        self.act = nn.ReLU()

    def forward(self, x):
        residual = x
        return self.act(self.net(x) + residual)


class SE1D(nn.Module):
    """Squeeze-and-excitation style channel gating for (batch, ch, length) input.

    Channels are averaged over the last axis, passed through a bottleneck MLP
    of width ``ch // r`` ending in a sigmoid, and the result rescales the input
    per channel.
    """

    def __init__(self, ch, r=8):
        super().__init__()
        hidden = ch // r
        self.fc = nn.Sequential(
            nn.Linear(ch, hidden),
            nn.ReLU(),
            nn.Linear(hidden, ch),
            nn.Sigmoid(),
        )

    def forward(self, x):
        squeezed = x.mean(dim=-1)
        gate = self.fc(squeezed)
        return x * gate.unsqueeze(-1)


In [None]:
# NOTE(review): a later cell in this notebook defines `GlobalTCN` again (with
# an additional SE1D stage). When the notebook is run top to bottom that later
# definition replaces this one, so this version is effectively unused.
class GlobalTCN(nn.Module):
    """Global-view encoder: conv stem, stacked dilated TCN blocks, masked pooling."""

    def __init__(self, in_ch=2, width=128, ks=9, dilations=(1,2,4,8,16,32), p=0.2):
        super().__init__()
        # Length-preserving stem lifting the input channels to `width`.
        self.stem = nn.Sequential(
            nn.Conv1d(in_ch, width, kernel_size=5, padding=2, stride=1),
            nn.ReLU()
        )
        # One residual block per dilation rate.
        self.blocks = nn.Sequential(*[TCNBlock(width, ks=ks, dil=d, p=p) for d in dilations])
        self.gap = MaskedGAP()

    def forward(self, x):
        # Channel 0 is read as flux and channel 1 as the mask.
        flux, mask = x[:, :1, :], x[:, 1:2, :]
        h = torch.cat([flux, mask], dim=1)
        h = self.stem(h)
        h = self.blocks(h)
        # Pool over the last axis, weighting every position by the mask channel.
        pooled = self.gap(h, mask)  
        return pooled

In [None]:
class GlobalTCN(nn.Module):
    """Global-view encoder: conv stem, dilated TCN stack, SE gating, masked pooling.

    Input is (batch, channels, length) with channel 0 read as flux and channel
    1 as the mask; output is one ``width``-dimensional vector per sample.
    """

    def __init__(self, in_ch=2, width=128, ks=9, dilations=(1,2,4,8,16,32), p=0.2):
        super().__init__()
        stem_conv = nn.Conv1d(in_ch, width, kernel_size=5, padding=2, stride=1)
        self.stem = nn.Sequential(stem_conv, nn.ReLU())
        tcn_stack = [TCNBlock(width, ks=ks, dil=d, p=p) for d in dilations]
        self.blocks = nn.Sequential(*tcn_stack)
        self.se = SE1D(width)
        self.gap = MaskedGAP()

    def forward(self, x):
        flux = x[:, :1, :]
        mask = x[:, 1:2, :]
        feats = self.stem(torch.cat([flux, mask], dim=1))
        feats = self.se(self.blocks(feats))
        # every position is weighted by the mask channel when pooling
        return self.gap(feats, mask)

In [None]:
class InceptionModule1D(nn.Module):
    """Multi-kernel 1-D convolution module.

    A 1x1 convolution reduces the input to ``bottleneck`` channels; one
    convolution per entry of ``ks_list`` (each producing
    ``out_ch // len(ks_list)`` channels) runs on the reduced signal; the branch
    outputs are concatenated on the channel axis, batch-normalised with
    ``out_ch`` channels and passed through ReLU.
    """

    def __init__(self, in_ch, out_ch, ks_list=(3,5,9,15), bottleneck=32):
        super().__init__()
        branch_ch = out_ch // len(ks_list)
        self.reduce = nn.Conv1d(in_ch, bottleneck, kernel_size=1)
        branch_convs = []
        for ks in ks_list:
            branch_convs.append(
                nn.Conv1d(bottleneck, branch_ch, kernel_size=ks, padding=ks // 2)
            )
        self.branches = nn.ModuleList(branch_convs)
        self.bn = nn.BatchNorm1d(out_ch)
        self.act = nn.ReLU()

    def forward(self, x):
        reduced = self.reduce(x)
        merged = torch.cat([branch(reduced) for branch in self.branches], dim=1)
        return self.act(self.bn(merged))

class LocalInception(nn.Module):
    """Local-view encoder: conv stem, two inception modules, masked pooling.

    Input is (batch, channels, length) with channel 0 read as flux and channel
    1 as the mask; output is one ``width``-dimensional vector per sample.
    """

    def __init__(self, in_ch=2, width=128):
        super().__init__()
        stem_conv = nn.Conv1d(in_ch, 64, kernel_size=5, padding=2)
        self.stem = nn.Sequential(stem_conv, nn.ReLU())
        self.inc1 = InceptionModule1D(64, width)
        self.inc2 = InceptionModule1D(width, width)
        self.gap  = MaskedGAP()

    def forward(self, x):
        flux = x[:, :1, :]
        mask = x[:, 1:2, :]
        feats = self.stem(torch.cat([flux, mask], dim=1))
        feats = self.inc2(self.inc1(feats))
        # every position is weighted by the mask channel when pooling
        return self.gap(feats, mask)

In [None]:
class DualViewNet(nn.Module):
    """Binary classifier combining a global-view and a local-view encoder.

    The pooled features of both branches (plus an optional embedding of extra
    metadata) are concatenated and mapped by an MLP head to one logit per
    sample.

    Args:
        gv_width: Feature width of the global branch.
        lv_width: Feature width of the local branch.
        use_meta: Build the metadata branch (requires ``meta_dim > 0``).
        meta_dim: Number of metadata features.
    """

    def __init__(self, gv_width=128, lv_width=128, use_meta=False, meta_dim=0):
        super().__init__()
        self.global_branch = GlobalTCN(in_ch=2, width=gv_width)
        self.local_branch  = LocalInception(in_ch=2, width=lv_width)

        fused_dim = gv_width + lv_width
        if use_meta and meta_dim > 0:
            self.meta = nn.Sequential(
                nn.Linear(meta_dim, 64),
                nn.BatchNorm1d(64),
                nn.ReLU(),
                nn.Dropout(0.2),
            )
            fused_dim += 64
        else:
            self.meta = None

        self.head = nn.Sequential(
            nn.BatchNorm1d(fused_dim),
            nn.Dropout(0.3),
            nn.Linear(fused_dim, 128),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(128, 1),
        )

    def forward(self, gv, lv, meta: Optional[torch.Tensor]=None):
        """Return one logit per sample for the given global/local views."""
        parts = [self.global_branch(gv), self.local_branch(lv)]
        # metadata is used only when the branch exists and a tensor is supplied
        if self.meta is not None and meta is not None:
            parts.append(self.meta(meta))
        fused = torch.cat(parts, dim=1)
        return self.head(fused).squeeze(1)

In [None]:
from dataclasses import dataclass
from typing import Optional, Tuple, List
import numpy as np
import torch, torch.nn as nn
from torch.utils.data import DataLoader
from sklearn.metrics import (
    average_precision_score, roc_auc_score,
    precision_score, recall_score, f1_score, confusion_matrix
)
import wandb
from typing import Literal

In [None]:
@dataclass
class TrainCfg:
    downsample_factor: int = 4        
    local_len: int = 1024
    batch_size: int = 16
    epochs: int = 80
    lr: float = 1e-3
    weight_decay: float = 1e-4
    device: str = "cuda" if torch.cuda.is_available() else "cpu"
    amp: bool = False                 
    monitor: str = "val/pr_auc"     
    patience: int = 20                
    ckpt_path: str = "/kaggle/working/best.ckpt"
    project: str = "exoplanets"
    entity: str = "nasa-public_static_void_frogs"
    run_name: Optional[str] = "without-50-candidates-realy-threshold-dynamic-planet=1"
    scheduler: str = "plateau"       
    onecycle_max_lr: float = 2e-3
    threshold: float = 0.65
    loss_fn: Literal["bce","bce_pos_weight","focal"] = "focal"
    auto_pos_weight: bool = True
    pos_weight: float | None = None
    beta=1.5 
    focal_gamma=1.5 
    focal_alpha=0.30

In [None]:

def compute_metrics(y_true: np.ndarray, y_prob: np.ndarray, threshold: float = 0.65):
    """Score binary predictions.

    PR-AUC and ROC-AUC are computed from the probabilities; precision, recall,
    F1 and the confusion matrix from the labels obtained with
    ``y_prob >= threshold``.

    Returns:
        Dict with keys ``pr_auc``, ``roc_auc``, ``precision``, ``recall``,
        ``f1`` and ``cm`` (confusion matrix with label order ``[0, 1]``).
    """
    hard_labels = (y_prob >= threshold).astype(np.int32)
    return {
        "pr_auc": average_precision_score(y_true, y_prob),
        "roc_auc": roc_auc_score(y_true, y_prob),
        "precision": precision_score(y_true, hard_labels, zero_division=0),
        "recall": recall_score(y_true, hard_labels, zero_division=0),
        "f1": f1_score(y_true, hard_labels, zero_division=0),
        "cm": confusion_matrix(y_true, hard_labels, labels=[0,1]),
    }


In [None]:

class Trainer:
    """Training / evaluation loop for a dual-view binary classifier with wandb logging.

    The constructor moves the model to ``cfg.device``, builds the AdamW
    optimiser, the gradient scaler and (optionally) the LR scheduler, and
    starts a wandb run.

    NOTE(review): ``fit`` stores the best weights at ``cfg.ckpt_path`` but
    nothing in this class loads them again, so ``test`` evaluates the weights
    of the last trained epoch — confirm that this is intended.
    """

    def __init__(self, model: nn.Module, cfg: TrainCfg):
        self.model = model
        self.cfg = cfg
        self.device = torch.device(cfg.device)
        self.model.to(self.device)

        # Pass the configured weight decay; without it AdamW silently falls
        # back to its own default and cfg.weight_decay has no effect.
        self.opt = torch.optim.AdamW(
            self.model.parameters(), lr=cfg.lr, weight_decay=cfg.weight_decay
        )
        self.scaler = torch.amp.GradScaler(enabled=cfg.amp)

        if cfg.scheduler == "plateau":
            # `verbose` is deprecated in recent PyTorch releases and therefore
            # not passed; the learning rate is logged to wandb during training.
            self.sched = torch.optim.lr_scheduler.ReduceLROnPlateau(
                self.opt, mode="max", factor=0.5, patience=2
            )
        else:
            # "onecycle" needs the number of steps per epoch and is created
            # lazily in train_one_epoch; any other value means no scheduler.
            self.sched = None
        self.run = wandb.init(project=cfg.project, entity=cfg.entity, name=cfg.run_name, config=vars(cfg))
        wandb.watch(self.model, log="all", log_freq=200)

        self.best_score = -float("inf")
        self.epochs_no_improve = 0
        # Decision threshold chosen on the most recent validation pass.
        self.best_threshold = None

    def _forward_loss(self, batch, criterion):
        """Move one ``(global_view, local_view, label)`` batch to the device and
        return ``(logits, labels, loss)``."""
        gv, lv, y = batch
        gv = gv.to(self.device, non_blocking=True)
        lv = lv.to(self.device, non_blocking=True)
        y  = y.to(self.device, non_blocking=True)

        # Autocast is a no-op unless cfg.amp is set.
        with torch.autocast(device_type=("cuda" if self.device.type=="cuda" else "cpu"),
                            dtype=(torch.float16 if self.device.type=="cuda" else torch.bfloat16),
                            enabled=self.cfg.amp):
            logits = self.model(gv, lv)
            loss = criterion(logits, y)

        return logits, y, loss

    def train_one_epoch(self, loader, criterion, epoch: int):
        """Run one optimisation pass over ``loader`` and return the mean sample loss."""
        self.model.train()
        total = 0; loss_sum = 0.0

        # OneCycleLR is built on first use because it needs len(loader).
        if self.cfg.scheduler == "onecycle" and self.sched is None:
            steps_per_epoch = len(loader)
            self.sched = torch.optim.lr_scheduler.OneCycleLR(
                self.opt, max_lr=self.cfg.onecycle_max_lr,
                epochs=self.cfg.epochs, steps_per_epoch=steps_per_epoch
            )

        for step, batch in enumerate(loader, 1):
            self.opt.zero_grad(set_to_none=True)
            logits, y, loss = self._forward_loss(batch, criterion)

            self.scaler.scale(loss).backward()
            self.scaler.step(self.opt)
            self.scaler.update()
            # OneCycleLR advances once per optimiser step.
            if isinstance(self.sched, torch.optim.lr_scheduler.OneCycleLR):
                self.sched.step()

            bs = y.size(0)
            total += bs
            loss_sum += loss.item() * bs

            if step % 50 == 0:
                wandb.log({"train/loss_step": loss.item(),
                           "train/lr": self.opt.param_groups[0]["lr"],
                           "epoch": epoch})

        avg = loss_sum / max(1, total)
        wandb.log({"train/loss_epoch": avg, "epoch": epoch})
        return avg

    @torch.no_grad()
    def validate(self, loader, criterion, split="val", epoch: int = 0):
        """Evaluate the model on ``loader`` and log metrics under ``split``.

        The threshold maximising F-beta (``cfg.beta``) on this split is always
        computed and logged. For ``split == "val"`` it is stored in
        ``self.best_threshold`` and used for the thresholded metrics. For any
        other split the stored validation threshold is applied when available,
        so those metrics are not tuned on the labels they are measured on.
        The plateau scheduler is stepped on validation passes only.

        Returns:
            ``(mean_loss, metrics_dict)``.
        """
        self.model.eval()
        loss_sum = 0.0; total = 0
        probs_all, y_all = [], []

        for batch in loader:
            logits, y, loss = self._forward_loss(batch, criterion)
            prob = torch.sigmoid(logits).float().detach().cpu().numpy().reshape(-1)
            yy   = y.detach().cpu().numpy().astype(np.int32).reshape(-1)

            probs_all.append(prob); y_all.append(yy)
            loss_sum += loss.item() * yy.shape[0]
            total += yy.shape[0]

        y_true = np.concatenate(y_all)
        y_prob = np.concatenate(probs_all)

        from sklearn.metrics import precision_recall_curve
        prec, rec, thr = precision_recall_curve(y_true, y_prob)
        if thr.size:
            # prec/rec carry one trailing element without a matching threshold.
            f_beta = (1 + self.cfg.beta**2) * prec[:-1] * rec[:-1] / (self.cfg.beta**2 * prec[:-1] + rec[:-1] + 1e-9)
            best_thr = float(thr[int(np.nanargmax(f_beta))])
        else:
            best_thr = 0.5

        if split == "val":
            self.best_threshold = best_thr
            applied_thr = best_thr
        elif self.best_threshold is not None:
            applied_thr = self.best_threshold
        else:
            # No validation pass has run yet; fall back to this split's optimum.
            applied_thr = best_thr
        wandb.log({f"{split}/threshold_bestF1": best_thr,
                   f"{split}/threshold_applied": applied_thr,
                   "epoch": epoch})

        y_probas_2col = np.stack([1.0 - y_prob, y_prob], axis=1)
        cm_plot = wandb.plot.confusion_matrix(
            probs=y_probas_2col,
            y_true=y_true.tolist(),
            class_names=["neg","pos"],
            title=f"{split} Confusion Matrix"
        )
        wandb.log({f"{split}/confusion_matrix": cm_plot})

        pr_plot = wandb.plot.pr_curve(y_true=y_true.tolist(),
                                      y_probas=y_probas_2col.tolist(),
                                      labels=["neg","pos"])
        roc_plot = wandb.plot.roc_curve(y_true=y_true.tolist(),
                                        y_probas=y_probas_2col.tolist(),
                                        labels=["neg","pos"])
        wandb.log({f"{split}/pr_curve": pr_plot, f"{split}/roc_curve": roc_plot})

        metrics = compute_metrics(y_true, y_prob, threshold=applied_thr)
        avg_loss = loss_sum / max(1, total)

        wandb.log({
            f"{split}/loss": avg_loss,
            f"{split}/pr_auc": metrics["pr_auc"],
            f"{split}/roc_auc": metrics["roc_auc"],
            f"{split}/precision": metrics["precision"],
            f"{split}/recall": metrics["recall"],
            f"{split}/f1": metrics["f1"],
            "epoch": epoch
        })

        # Only validation results may drive the learning-rate schedule.
        if split == "val" and isinstance(self.sched, torch.optim.lr_scheduler.ReduceLROnPlateau):
            self.sched.step(metrics["pr_auc"])

        return avg_loss, metrics

    def _is_improved(self, metrics: dict) -> bool:
        """True when the monitored metric (``cfg.monitor`` without its split
        prefix) exceeds the best score seen so far."""
        key = self.cfg.monitor.split("/", 1)[-1]
        return metrics.get(key, -1.0) > self.best_score

    def fit(self, dl_tr: DataLoader, dl_va: DataLoader):
        """Train for up to ``cfg.epochs`` epochs with checkpointing and early stopping.

        Raises:
            ValueError: If ``cfg.loss_fn`` is unknown, or if ``bce_pos_weight``
                is requested with ``auto_pos_weight`` off and no ``pos_weight``.
        """
        if self.cfg.loss_fn == "bce":
            criterion = nn.BCEWithLogitsLoss()
        elif self.cfg.loss_fn == "bce_pos_weight":
            if getattr(self.cfg, "auto_pos_weight", True):
                # Estimate neg/pos from one pass over the training labels.
                pos = 0; neg = 0
                for _, _, y in dl_tr:
                    y = y.numpy()
                    pos += (y == 1).sum()
                    neg += (y == 0).sum()
                pw = neg / max(1, pos)
            else:
                pw = self.cfg.pos_weight
                if pw is None:
                    raise ValueError("cfg.pos_weight must be set when auto_pos_weight is False")
            # float32 keeps the weight in the same dtype as the logits.
            criterion = nn.BCEWithLogitsLoss(
                pos_weight=torch.tensor([float(pw)], dtype=torch.float32, device=self.device)
            )
        elif self.cfg.loss_fn == "focal":
            criterion = BinaryFocalLoss(gamma=self.cfg.focal_gamma, alpha=self.cfg.focal_alpha)
        else:
            raise ValueError("cfg.loss_fn must be one of: bce | bce_pos_weight | focal")

        for ep in range(1, self.cfg.epochs + 1):
            train_loss = self.train_one_epoch(dl_tr, criterion, ep)
            val_loss, val_metrics = self.validate(dl_va, criterion, split="val", epoch=ep)

            current_score = val_metrics[self.cfg.monitor.split("/",1)[-1]]
            improved = self._is_improved(val_metrics)
            if improved:
                self.best_score = current_score
                self.epochs_no_improve = 0
                torch.save(self.model.state_dict(), self.cfg.ckpt_path)
                wandb.run.summary["best_score"] = float(self.best_score)
                wandb.run.summary["best_ckpt"] = self.cfg.ckpt_path
            else:
                self.epochs_no_improve += 1

            print(f"Epoch {ep:02d} | train_loss={train_loss:.4f} | "
                  f"val_loss={val_loss:.4f} | PR-AUC={val_metrics['pr_auc']:.4f} | "
                  f"ROC-AUC={val_metrics['roc_auc']:.4f}")

            if self.epochs_no_improve >= self.cfg.patience:
                print(f"Early stopping at epoch {ep} (no improve {self.cfg.patience} epochs).")
                break

    @torch.no_grad()
    def test(self, dl_te: DataLoader):
        """Evaluate on the test loader with plain BCE loss, print and return
        the metrics, and finish the wandb run."""
        criterion = nn.BCEWithLogitsLoss()
        _, te_metrics = self.validate(dl_te, criterion, split="test", epoch=0)
        if self.best_threshold is not None:
            wandb.run.summary["best_val_threshold"] = float(self.best_threshold)
        print("TEST:", te_metrics)
        wandb.finish()
        return te_metrics


In [None]:
from sklearn.model_selection import train_test_split

# Features are every column except LABEL.
df_X = df_filled.drop(columns=["LABEL"])
y = df_filled["LABEL"].to_numpy()

# Binary target: 1.0 where the original LABEL is 0, otherwise 0.0.
y = (y == 0).astype(np.float32)

# Stratified hold-out: 15% of all rows form the test set ...
X_trainval, X_test, y_trainval, y_test = train_test_split(
    df_X, y, test_size=0.15, random_state=42, stratify=y
)

# ... and 15% of the remaining rows form the validation set.
X_train, X_val, y_train, y_val = train_test_split(
    X_trainval, y_trainval, test_size=0.15, random_state=42, stratify=y_trainval
)

print(f"Shapes: train={X_train.shape}, val={X_val.shape}, test={X_test.shape}")

cfg = TrainCfg(epochs=80, scheduler="plateau", patience=20, monitor="val/pr_auc")

# Each dataset precomputes its global and local views in the constructor.
ds_tr = DualViewDataset(X_train, y_train,
                        downsample_factor=cfg.downsample_factor,
                        local_len=cfg.local_len)
ds_va = DualViewDataset(X_val, y_val,
                        downsample_factor=cfg.downsample_factor,
                        local_len=cfg.local_len)
ds_te = DualViewDataset(X_test, y_test,
                        downsample_factor=cfg.downsample_factor,
                        local_len=cfg.local_len)


In [None]:
# Only the training loader shuffles; validation and test keep dataset order.
dl_tr = DataLoader(ds_tr, batch_size=cfg.batch_size, shuffle=True,
                   num_workers=2, pin_memory=True)
dl_va = DataLoader(ds_va, batch_size=cfg.batch_size, shuffle=False,
                   num_workers=2, pin_memory=True)
dl_te = DataLoader(ds_te, batch_size=cfg.batch_size, shuffle=False,
                   num_workers=2, pin_memory=True)

In [None]:
from kaggle_secrets import UserSecretsClient
# The wandb API key is read from the Kaggle secret store, not from the notebook.
user_secrets = UserSecretsClient()
secret_value_0 = user_secrets.get_secret("WANDB_API_KEY")
wandb.login(key=secret_value_0)


# Build the model, train with validation-based early stopping, then evaluate
# once on the held-out test loader (which also finishes the wandb run).
model = DualViewNet(gv_width=128, lv_width=128)
trainer = Trainer(model, cfg)
trainer.fit(dl_tr, dl_va)    
trainer.test(dl_te)          

In [None]:
 # Saves the weights currently held by `model`, i.e. the state after the last
 # trained epoch; the best-scoring weights were written to cfg.ckpt_path by fit.
 torch.save(model.state_dict(), 'DualViewNet.pth')