In [1]:
import albumentations as A
import gc
import matplotlib.pyplot as plt
import math
import multiprocessing
import numpy as np
import os
import pandas as pd
import random
import time
import timm
import torch
import torch.nn as nn
from pathlib import Path


from albumentations.pytorch import ToTensorV2
from glob import glob
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm
from typing import Dict, List

from sklearn.model_selection import KFold, GroupKFold
from skimage.transform import resize
from torch.optim.lr_scheduler import OneCycleLR
import torch.nn.functional as F
import logging
import functools
import pywt

os.environ["CUDA_VISIBLE_DEVICES"] = "0,1"
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print('Using', torch.cuda.device_count(), 'GPU(s)')



class config:
    AMP = True
    BATCH_SIZE_TRAIN = 8
    BATCH_SIZE_VALID = 8
    EPOCHS = 4
    FOLDS = 5
    FREEZE = False
    GRADIENT_ACCUMULATION_STEPS = 1
    MAX_GRAD_NORM = 1e7
    MODEL = "efficientnet_b4"
    NUM_FROZEN_LAYERS = 39
    NUM_WORKERS = 0 # multiprocessing.cpu_count()
    PRINT_FREQ = 20
    SEED = 20
    TRAIN_FULL_DATA = False
    VISUALIZE = False
    WEIGHT_DECAY = 0.01
    
    
class paths:
    OUTPUT_DIR = Path("/kaggle/working/")
#     PRE_LOADED_EEGS = '/kaggle/input/brain-eeg-spectrograms/eeg_specs.npy'
#     PRE_LOADED_SPECTROGRAMS = '/kaggle/input/brain-spectrograms/specs.npy'
    TRAIN_CSV = "/kaggle/input/hms-harmful-brain-activity-classification/train.csv"
    TRAIN_EEGS = "/kaggle/input/hms-harmful-brain-activity-classification/train_eegs/"
    TRAIN_SPECTROGRAMS = "/kaggle/input/hms-harmful-brain-activity-classification/train_spectrograms/"
    ROOT = Path.cwd()
    INPUT = ROOT / "input"
    # OUTPUT_DIR = ROOT / "output"
    DATA = Path("./original_data")
    
    PRE_LOADED_EEGS = '/kaggle/working/brain-eeg-spectrograms/eeg_specs.npy'
    PRE_LOADED_SPECTROGRAMS = '/kaggle/brain-spectrograms/working/specs.npy'
    PRE_LOADED_Wavelets = '/kaggle/working/brain-wavelets/specs.npy'
    
    # PRE_LOADED_EEGS = './kaggle/input/brain-eeg-spectrograms/eeg_specs.npy'
    # PRE_LOADED_SPECTROGRAMS = './kaggle/input/brain-spectrograms/specs.npy'
#     PRE_LOADED_Wavelets = './kaggle/input/brain-wavelets/specs.npy'
    
    # TRAIN_SPECTROGRAMS = DATA / "train_spectrograms"
    # TRAIN_EEGS = DATA / "train_eegs"
    # TRAIN_CSV = DATA / "train.csv"

log_filename = paths.ROOT/'new_version_training_record.log'

logging.basicConfig(filename=log_filename, level=logging.INFO,
                    format='%(asctime)s %(levelname)s %(message)s', datefmt='%Y-%m-%d %H:%M:%S')
def log_time(func):
    """warpper for logging running time"""

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        start_time = time.time()
        result = func(*args, **kwargs)
        end_time = time.time()
        logging.info(f"{func.__name__} took {end_time - start_time:.4f} seconds.")
        print(f"{func.__name__} took {end_time - start_time:.4f} seconds.")
        return result

    return wrapper


class AverageMeter(object):
    """Computes and stores the average and current value"""
    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


def asMinutes(s: float):
    "Convert to minutes."
    m = math.floor(s / 60)
    s -= m * 60
    return '%dm %ds' % (m, s)


def timeSince(since: float, percent: float):
    now = time.time()
    s = now - since
    es = s / (percent)
    rs = es - s
    return '%s (remain %s)' % (asMinutes(s), asMinutes(rs))



def plot_spectrogram(spectrogram_path: str):
    """
    Source: https://www.kaggle.com/code/mvvppp/hms-eda-and-domain-journey
    Visualize spectrogram recordings from a parquet file.
    :param spectrogram_path: path to the spectrogram parquet.
    """
    sample_spect = pd.read_parquet(spectrogram_path)
    
    split_spect = {
        "LL": sample_spect.filter(regex='^LL', axis=1),
        "RL": sample_spect.filter(regex='^RL', axis=1),
        "RP": sample_spect.filter(regex='^RP', axis=1),
        "LP": sample_spect.filter(regex='^LP', axis=1),
    }
    
    fig, axes = plt.subplots(nrows=2, ncols=2, figsize=(15, 12))
    axes = axes.flatten()
    label_interval = 5
    for i, split_name in enumerate(split_spect.keys()):
        ax = axes[i]
        img = ax.imshow(np.log(split_spect[split_name]).T, cmap='viridis', aspect='auto', origin='lower')
        cbar = fig.colorbar(img, ax=ax)
        cbar.set_label('Log(Value)')
        ax.set_title(split_name)
        ax.set_ylabel("Frequency (Hz)")
        ax.set_xlabel("Time")

        ax.set_yticks(np.arange(len(split_spect[split_name].columns)))
        ax.set_yticklabels([column_name[3:] for column_name in split_spect[split_name].columns])
        frequencies = [column_name[3:] for column_name in split_spect[split_name].columns]
        ax.set_yticks(np.arange(0, len(split_spect[split_name].columns), label_interval))
        ax.set_yticklabels(frequencies[::label_interval])
    plt.tight_layout()
    plt.show()
    
@log_time   
def seed_everything(seed: int):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed) 

   
def sep():
    print("-"*100)
    

class CustomDataset(Dataset):
    def __init__(
        self, df: pd.DataFrame, config,
        augment: bool = False, mode: str = 'train',
        specs: Dict[int, np.ndarray] = None,
        eeg_specs: Dict[int, np.ndarray] = None,
        wavelets_spectrograms: Dict[int, np.ndarray] = None,
    ): 
        self.df = df
        self.config = config
        self.batch_size = self.config.BATCH_SIZE_TRAIN
        self.augment = augment
        self.mode = mode
        self.spectrograms = specs if specs is not None else {}
        self.eeg_spectrograms = eeg_specs if eeg_specs is not None else {}
        self.wavelets_spectrograms = wavelets_spectrograms if wavelets_spectrograms is not None else {}
        
    def __len__(self):
        """
        Denotes the number of batches per epoch.
        """
        return len(self.df)
        
    def __getitem__(self, index):
        """
        Generate one batch of data.
        """
        X, y = self.__data_generation(index)
        if self.augment:
            X = self.__transform(X) 
        return torch.tensor(X, dtype=torch.float32), torch.tensor(y, dtype=torch.float32)

    def log_and_Standarize(self,img):
        # Log transform spectrogram
            img = np.clip(img, np.exp(-4), np.exp(8))
            img = np.log(img)

            # Standarize per image
            ep = 1e-6
            mu = np.nanmean(img.flatten())
            std = np.nanstd(img.flatten())
            img = (img - mu) / (std + ep)
            img = np.nan_to_num(img, nan=0.0)
            return img

    def __data_generation(self, index):
        """
        Generates data containing batch_size samples.
        """
        X = np.zeros((128, 256, 12), dtype='float32')
        y = np.zeros(6, dtype='float32')
        img = np.ones((128,256), dtype='float32')
        row = self.df.iloc[index]
        if self.mode=='test': 
            r = 0
        else: 
            r = int((row['min'] + row['max']) // 4)
            
        for region in range(4):
            img = self.spectrograms[row.spectrogram_id][r:r+300, region*100:(region+1)*100].T
            img = self.log_and_Standarize(img)
            X[14:-14, :, region] = img[:, 22:-22] / 2.0
            
        img = self.eeg_spectrograms[row.eeg_id]
        img = img.to_numpy()
        img = self.log_and_Standarize(img)
        img = resize(img, (128, 256, 4))
        X[:, :, 4:8] = img

        # Combine wavelet features
        img = self.wavelets_spectrograms[row.spectrogram_id]
        img = self.log_and_Standarize(img)
        img = resize(img, (128, 256,4))
        X[:, :, 8:12] = img


        if self.mode != 'test':
            y = row[label_cols].values.astype(np.float32)
    
        return X, y
    
    def __transform(self, img):
        transforms = A.Compose([
            A.HorizontalFlip(p=0.5),
        ])
        return transforms(image=img)['image']


class CustomModel(nn.Module):
    def __init__(self, config, num_classes: int = 6, pretrained: bool = True):
        super(CustomModel, self).__init__()
        self.USE_KAGGLE_SPECTROGRAMS = True
        self.USE_EEG_SPECTROGRAMS = True
        self.USE_WAVELET_SPECTROGRAMS = True
        self.model = timm.create_model(
            config.MODEL,
            pretrained=pretrained,
            drop_rate = 0.1,
            drop_path_rate = 0.2,
        )
        if config.FREEZE:
            for i,(name, param) in enumerate(list(self.model.named_parameters())\
                                             [0:config.NUM_FROZEN_LAYERS]):
                param.requires_grad = False

        self.features = nn.Sequential(*list(self.model.children())[:-2])
        self.custom_layers = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Linear(self.model.num_features, num_classes)
        )

    def __reshape_input(self, x):
        """
        Reshapes input torch.Size([8, 128, 256, 12]) -> [8, 3, 512, 768] monotone image.
        """ 
        components = []
        if self.USE_KAGGLE_SPECTROGRAMS:
            spectrograms = [x[:, :, :, i:i+1] for i in range(4)]
            components.append(torch.cat(spectrograms, dim=1))
        if self.USE_EEG_SPECTROGRAMS:
            eegs = [x[:, :, :, i:i+1] for i in range(4,8)]
            eegs = torch.cat(eegs, dim=1)
            components.append(eegs)

        if self.USE_WAVELET_SPECTROGRAMS:
            wavelets = [x[:, :, :, i:i+1] for i in range(8,12)]
            wavelets = torch.cat(wavelets, dim=1)
            components.append(wavelets)

        if components:
            x = torch.cat(components, dim=2)

        x = torch.cat([x, x, x], dim=3)  
        x = x.permute(0, 3, 1, 2)
        return x

    def forward(self, x):
        x = self.__reshape_input(x)
        x = self.features(x)
        x = self.custom_layers(x)
        return x


@log_time
def train_epoch(train_loader, model, criterion, optimizer, epoch, scheduler, device):
    """One epoch training pass."""
    model.train() 
    criterion = nn.KLDivLoss(reduction="batchmean")
    scaler = torch.cuda.amp.GradScaler(enabled=config.AMP)
    losses = AverageMeter()
    start = end = time.time()
    global_step = 0
    
    # ========== ITERATE OVER TRAIN BATCHES ============
    with tqdm(train_loader, unit="train_batch", desc='Train') as tqdm_train_loader:
        for step, (X, y) in enumerate(tqdm_train_loader):
            X = X.to(device)
            y = y.to(device)
            batch_size = y.size(0)
            with torch.cuda.amp.autocast(enabled=config.AMP):
                y_preds = model(X) 
                loss = criterion(F.log_softmax(y_preds, dim=1), y)
            if config.GRADIENT_ACCUMULATION_STEPS > 1:
                loss = loss / config.GRADIENT_ACCUMULATION_STEPS
            losses.update(loss.item(), batch_size)
            scaler.scale(loss).backward()
            grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), config.MAX_GRAD_NORM)

            if (step + 1) % config.GRADIENT_ACCUMULATION_STEPS == 0:
                scaler.step(optimizer)
                scaler.update()
                optimizer.zero_grad()
                global_step += 1
                scheduler.step()
            end = time.time()

            # ========== LOG INFO ==========
            if step % config.PRINT_FREQ == 0 or step == (len(train_loader)-1):
                print('Epoch: [{0}][{1}/{2}] '
                      'Elapsed {remain:s} '
                      'Loss: {loss.avg:.4f} '
                      'Grad: {grad_norm:.4f}  '
                      'LR: {lr:.8f}  '
                      .format(epoch+1, step, len(train_loader), 
                              remain=timeSince(start, float(step+1)/len(train_loader)),
                              loss=losses,
                              grad_norm=grad_norm,
                              lr=scheduler.get_last_lr()[0]))

    return losses.avg

@log_time
def valid_epoch(valid_loader, model, criterion, device):
    model.eval()
    softmax = nn.Softmax(dim=1)
    losses = AverageMeter()
    prediction_dict = {}
    preds = []
    start = end = time.time()
    with tqdm(valid_loader, unit="valid_batch", desc='Validation') as tqdm_valid_loader:
        for step, (X, y) in enumerate(tqdm_valid_loader):
            X = X.to(device)
            y = y.to(device)
            batch_size = y.size(0)
            with torch.no_grad():
                y_preds = model(X)
                loss = criterion(F.log_softmax(y_preds, dim=1), y)
            if config.GRADIENT_ACCUMULATION_STEPS > 1:
                loss = loss / config.GRADIENT_ACCUMULATION_STEPS
            losses.update(loss.item(), batch_size)
            y_preds = softmax(y_preds)
            preds.append(y_preds.to('cpu').numpy())
            end = time.time()

            # ========== LOG INFO ==========
            if step % config.PRINT_FREQ == 0 or step == (len(valid_loader)-1):
                print('EVAL: [{0}/{1}] '
                      'Elapsed {remain:s} '
                      'Loss: {loss.avg:.4f} '
                      .format(step, len(valid_loader),
                              remain=timeSince(start, float(step+1)/len(valid_loader)),
                              loss=losses))
                
#     prediction_dict["predictions"] = np.concatenate(preds)
        # 修改的部分来处理空的preds列表
    if preds:  # 如果preds列表不为空
        prediction_dict["predictions"] = np.concatenate(preds)
    else:  # 如果preds列表为空
        prediction_dict["predictions"] = np.array([])  # 返回空数组或其他默认值

    return losses.avg, prediction_dict
    return losses.avg, prediction_dict

@log_time
def train_loop(df, fold,stage =1):
    
    logging.info(f"========== Fold: {fold} training ==========")
    
    paths.OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
    
    # ======== SPLIT ==========
    train_folds = df[df['fold'] != fold].reset_index(drop=True)
    valid_folds = df[df['fold'] == fold].reset_index(drop=True)
    
    label_cols = ['seizure_vote', 'lpd_vote', 'gpd_vote', 'lrda_vote', 'grda_vote', 'other_vote']
    
    
    # ======== CALCULATE KL DIV ==========
    if stage == 2:
        # Normalize label distributions
        y_train = train_folds[label_cols].values
        y_train_normalized = y_train / y_train.sum(axis=1, keepdims=True)
        
        # Add small value to avoid log(0)
        labels = torch.tensor(y_train_normalized, dtype=torch.float) + 1e-5
        
        # Compute KL Loss with uniform distribution
        kl_loss = F.kl_div(torch.log(labels), torch.tensor([1/6]*6, dtype=torch.float), reduction='none').sum(dim=1)
        
        # Filter based on KL Loss
        train_folds = train_folds[kl_loss.numpy() < 5.5].reset_index(drop=True)
        logging.info(f"Filtered training data to {len(train_folds)} samples based on KL Loss < 5.5.")

        
        
    # ======== DATASETS ==========
    train_dataset = CustomDataset(train_folds,config, mode="train", augment=True, specs=all_spectrograms, eeg_specs=all_eegs,wavelets_spectrograms = all_wavelet_spectrograms)
    valid_dataset = CustomDataset(valid_folds, config, mode="train", augment=False, specs=all_spectrograms, eeg_specs=all_eegs,wavelets_spectrograms = all_wavelet_spectrograms)
    
    # ======== DATALOADERS ==========
    train_loader = DataLoader(train_dataset,
                              batch_size=config.BATCH_SIZE_TRAIN,
                              shuffle=False,
                              num_workers=config.NUM_WORKERS, pin_memory=True, drop_last=True)
    valid_loader = DataLoader(valid_dataset,
                              batch_size=config.BATCH_SIZE_VALID,
                              shuffle=False,
                              num_workers=config.NUM_WORKERS, pin_memory=True, drop_last=False)
    
    # ======== MODEL ==========
    model = CustomModel(config)
    model.to(device)

    
    optimizer = torch.optim.AdamW(model.parameters(), lr=0.1, weight_decay=config.WEIGHT_DECAY)
    scheduler = OneCycleLR(
        optimizer,
        max_lr=1e-3,
        epochs=config.EPOCHS,
        steps_per_epoch=len(train_loader),
        pct_start=0.1,
        anneal_strategy="cos",
        final_div_factor=100,
    )

    # ======= LOSS ==========
    criterion = nn.KLDivLoss(reduction="batchmean")
    
    best_loss = np.inf
    # ====== ITERATE EPOCHS ========
    for epoch in range(config.EPOCHS):
        start_time = time.time()

        # ======= TRAIN ==========
        avg_train_loss = train_epoch(train_loader, model, criterion, optimizer, epoch, scheduler, device)

        # ======= EVALUATION ==========
        avg_val_loss, prediction_dict = valid_epoch(valid_loader, model, criterion, device)
        predictions = prediction_dict["predictions"]
        
        # ======= SCORING ==========
        elapsed = time.time() - start_time

        logging.info(f'Epoch {epoch+1} - avg_train_loss: {avg_train_loss:.4f}  avg_val_loss: {avg_val_loss:.4f}  time: {elapsed:.0f}s')
        
        if avg_val_loss < best_loss:
            best_loss = avg_val_loss
            logging.info(f'Epoch {epoch+1} - Save Best Loss: {best_loss:.4f} Model')
            torch.save({'model': model.state_dict(),
                        'predictions': predictions},
                        paths.OUTPUT_DIR / f"{config.MODEL.replace('/', '_')}_fold_{fold}_best.pth")

    ## TypeError: unsupported operand type(s) for +: 'WindowsPath' and 'str'
    # predictions = torch.load(paths.OUTPUT_DIR + f"/{config.MODEL.replace('/', '_')}_fold_{fold}_best.pth", 
    #                          map_location=torch.device('cpu'))['predictions']
    paths.OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
    predictions = torch.load(paths.OUTPUT_DIR / f"{config.MODEL.replace('/', '_')}_fold_{fold}_best.pth",
                         map_location=torch.device('cpu'))['predictions']

    valid_folds[target_preds] = predictions
    #+++ 将预测结果添加到valid_folds中，为了区分预测和真实标签，我们给预测的列名添加后缀'_pred'
    for i, label in enumerate(label_cols):
        valid_folds[label + '_pred'] = predictions[:, i]
    
    torch.cuda.empty_cache()
    gc.collect()
    
    return valid_folds

@log_time
def train_loop_full_data(df):
    train_dataset = CustomDataset(df, config, mode="train", augment=True,specs=all_spectrograms, eeg_specs=all_eegs,wavelets_spectrograms = all_wavelet_spectrograms)
    train_loader = DataLoader(train_dataset,
                              batch_size=config.BATCH_SIZE_TRAIN,
                              shuffle=False,
                              num_workers=config.NUM_WORKERS, pin_memory=True, drop_last=True)
    model = CustomModel(config)
    model.to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=0.1, weight_decay=config.WEIGHT_DECAY)
    scheduler = OneCycleLR(
        optimizer,
        max_lr=1e-3,
        epochs=config.EPOCHS,
        steps_per_epoch=len(train_loader),
        pct_start=0.1,
        anneal_strategy="cos",
        final_div_factor=100,
    )
    criterion = nn.KLDivLoss(reduction="batchmean")
    best_loss = np.inf
    for epoch in range(config.EPOCHS):
        start_time = time.time()
        avg_train_loss = train_epoch(train_loader, model, criterion, optimizer, epoch, scheduler, device)
        elapsed = time.time() - start_time
        logging.info(f'Epoch {epoch+1} - avg_train_loss: {avg_train_loss:.4f}  time: {elapsed:.0f}s')
        save_path = output_dir / f"{config.MODEL.replace('/', '_')}_fold_{fold}_best.pth"
        torch.save({'model': model.state_dict(), 'predictions': predictions}, save_path)
#         torch.save(
#             {'model': model.state_dict()},
#             paths.OUTPUT_DIR + f"/{config.MODEL.replace('/', '_')}_epoch_{epoch}.pth")
    torch.cuda.empty_cache()
    gc.collect()
    return 

@log_time
def get_result(oof_df):
    kl_loss = nn.KLDivLoss(reduction="batchmean")
    labels = torch.tensor(oof_df[label_cols].values)
    preds = torch.tensor(oof_df[target_preds].values)
    preds = F.log_softmax(preds, dim=1)
    result = kl_loss(preds, labels)
    return result

@log_time
def preparing_data(df):
    train_df = df.groupby('eeg_id')[['spectrogram_id','spectrogram_label_offset_seconds']].agg({
        'spectrogram_id':'first',
        'spectrogram_label_offset_seconds':'min'
    })
    train_df.columns = ['spectrogram_id','min']

    aux = df.groupby('eeg_id')[['spectrogram_id','spectrogram_label_offset_seconds']].agg({
        'spectrogram_label_offset_seconds':'max'
    })
    train_df['max'] = aux

    aux = df.groupby('eeg_id')[['patient_id']].agg('first')
    train_df['patient_id'] = aux

    aux = df.groupby('eeg_id')[label_cols].agg('sum')
    for label in label_cols:
        train_df[label] = aux[label].values
        
    y_data = train_df[label_cols].values
    y_data = y_data / y_data.sum(axis=1,keepdims=True)
    train_df[label_cols] = y_data

    aux = df.groupby('eeg_id')[['expert_consensus']].agg('first')
    train_df['target'] = aux

    train_df = train_df.reset_index()
    return train_df

## +++ 计算每个样本的KL散度，并将其作为新列添加到train_df
def compute_kl_divergence(data, label_cols):
    labels = data[label_cols].values + 1e-5
    labels /= labels.sum(axis=1, keepdims=True)
    kl_div = torch.nn.functional.kl_div(
        torch.log(torch.tensor(labels, dtype=torch.float)),
        torch.tensor([[1/6] * 6], dtype=torch.float),
        reduction='none'
    ).sum(dim=1).numpy()
    return kl_div

def compute_wavelet_features(signal, wavelet='db4', level=5):
    coeffs = pywt.wavedec(signal, wavelet, level=level)
    # 从小波系数中提取特征而不是直接用小波系数，因为有不规则大小。
    features = []
    for coeff in coeffs:
        features.extend([np.mean(coeff), np.std(coeff)])
    return np.array(features)


@log_time
def loading_parquet(train_df, config = config, READ_SPEC_FILES = False,READ_EEG_SPEC_FILES = False,wavelet = 'None'):
    paths_spectrograms = glob(paths.TRAIN_SPECTROGRAMS + "*.parquet")
#     paths_spectrograms = glob(str(paths.TRAIN_SPECTROGRAMS / "*.parquet"))
    print(f'There are {len(paths_spectrograms)} spectrogram parquets in total path')

    if READ_SPEC_FILES:    
        all_spectrograms = {}
        all_wavelet_spectrograms = {}
        spectrogram_ids = train_df['spectrogram_id'].unique()
        print(f'There are {len(spectrogram_ids)} spectrogram parquets in this training process')
        for spec_id in tqdm(spectrogram_ids):
        # for file_path in tqdm(paths_spectrograms):
            file_path = f"{paths.TRAIN_SPECTROGRAMS}/{spec_id}.parquet"
            aux = pd.read_parquet(file_path)
            spec_arr = aux.fillna(0).values[:, 1:].T.astype("float32")  # (Hz, Time) = (400, 300)
            wavelet_features = np.array([compute_wavelet_features(row, wavelet=wavelet) for row in spec_arr])
            name = int(file_path.split("/")[-1].split('.')[0])
            # all_spectrograms[name] = aux.iloc[:,1:].values  
            all_spectrograms[name] = aux.fillna(0).iloc[:,1:].values.astype("float32")
            all_wavelet_spectrograms[name] = wavelet_features
            del aux
            del wavelet_features
        os.makedirs(os.path.dirname(paths.PRE_LOADED_SPECTROGRAMS), exist_ok=True)
        os.makedirs(os.path.dirname(paths.PRE_LOADED_Wavelets), exist_ok=True)
        np.save(paths.PRE_LOADED_SPECTROGRAMS, all_spectrograms, allow_pickle=True)
        np.save(paths.PRE_LOADED_Wavelets, all_wavelet_spectrograms, allow_pickle=True)
    else:
        all_spectrograms = np.load(paths.PRE_LOADED_SPECTROGRAMS, allow_pickle=True).item()
        all_wavelet_spectrograms = np.load(paths.PRE_LOADED_Wavelets, allow_pickle=True).item()
        
    if config.VISUALIZE:
        idx = np.random.randint(0,len(paths_spectrograms))
        spectrogram_path = paths_spectrograms[idx]
        plot_spectrogram(spectrogram_path)

    # Read EEG Spectrograms
    paths_eegs = glob(paths.TRAIN_EEGS + "*.parquet")
#     paths_eegs = glob(str(paths.TRAIN_EEGS / "*.parquet"))
    print(f'There are {len(paths_eegs)} EEG spectrograms in total path')
    if READ_EEG_SPEC_FILES:
        all_eegs = {}
        eeg_ids = train_df['eeg_id'].unique()
        print(f'There are {len(eeg_ids)} EEG spectrograms in this training path')
        for eeg_id in tqdm(eeg_ids):
            file_path = f"{paths.TRAIN_EEGS}/{eeg_id}.parquet"
            eeg_spectrogram =  pd.read_parquet(file_path)
            all_eegs[eeg_id] = eeg_spectrogram
            del eeg_spectrogram
        os.makedirs(os.path.dirname(paths.PRE_LOADED_EEGS), exist_ok=True)
        np.save(paths.PRE_LOADED_EEGS, all_eegs, allow_pickle=True)
    else:
        all_eegs = np.load(paths.PRE_LOADED_EEGS, allow_pickle=True).item()



    
    return all_spectrograms,all_eegs,all_wavelet_spectrograms


if __name__ == "__main__":
    overall_start_time = time.time()
    print(f"Log file path: {log_filename.absolute()}")
    logging.info('--------------------------------------------------')
    logging.info(f'Into loading stage')

    target_preds = [x + "_pred" for x in ['seizure_vote', 'lpd_vote', 'gpd_vote', 'lrda_vote', 'grda_vote', 'other_vote']]
    label_to_num = {'Seizure': 0, 'LPD': 1, 'GPD': 2, 'LRDA': 3, 'GRDA': 4, 'Other':5}
    num_to_label = {v: k for k, v in label_to_num.items()}
    seed_everything(config.SEED)

#     df = pd.read_csv(paths.TRAIN_CSV)
    df = pd.read_csv(paths.TRAIN_CSV, nrows=1000)
    label_cols = df.columns[-6:]
    print(f"Train cataframe shape is: {df.shape}")
    print(f"Labels: {list(label_cols)}")
    print(df.head())

    #处理train_df，eeg_id,只保留第一个spectrogram_id，min及max spec offset，第一个patient_id等
    train_df = preparing_data(df)
    print('Train non-overlapp eeg_id shape:', train_df.shape )
    print(train_df.head())
    train_df.to_csv('./local_train_df.csv', index=False)
    
    # 添加KL散度列到train_df
    train_df['kl_divergence'] = compute_kl_divergence(train_df, label_cols)
    print('kl value is ---',train_df[['kl_divergence']])
    
    logging.info(f'Into loading stage: combine wavelet feature into X')
    logging.info(f'Into loading stage: loading single npy from local file')
    all_spectrograms,all_eegs,all_wavelet_spectrograms = loading_parquet(train_df, config = config, READ_SPEC_FILES = True,READ_EEG_SPEC_FILES = True,wavelet='db1')
    

    # Validation 
    gkf = GroupKFold(n_splits=config.FOLDS)
    for fold, (train_index, valid_index) in enumerate(gkf.split(train_df, train_df.target, train_df.patient_id)):
        train_df.loc[valid_index, "fold"] = int(fold)
        
    print(train_df.groupby('fold').size()), sep()
    print(train_df.head())

    logging.info(f'training based on model: efficientnet_b4')
    logging.info(f'Feature: without eegs, only specs and wavelets')
    train_dataset = CustomDataset(train_df, config, mode="train", 
                                  specs=all_spectrograms, eeg_specs=all_eegs,wavelets_spectrograms = all_wavelet_spectrograms)
    train_loader = DataLoader(
        train_dataset,
        batch_size=config.BATCH_SIZE_TRAIN,
        shuffle=False,
        num_workers=config.NUM_WORKERS, pin_memory=True, drop_last=True
    )
    X, y = train_dataset[0]
    print(f"X shape: {X.shape}")
    print(f"y shape: {y.shape}")

    if config.VISUALIZE:
        ROWS = 2
        COLS = 3
        for (X, y) in train_loader:
            plt.figure(figsize=(20,8))
            for row in range(ROWS):
                for col in range(COLS):
                    plt.subplot(ROWS, COLS, row*COLS + col+1)
                    t = y[row*COLS + col]
                    img = X[row*COLS + col, :, :, 0]
                    mn = img.flatten().min()
                    mx = img.flatten().max()
                    img = (img-mn)/(mx-mn)
                    plt.imshow(img)
                    tars = f'[{t[0]:0.2f}'
                    for s in t[1:]:
                        tars += f', {s:0.2f}'
                    eeg = train_df.eeg_id.values[row*config.BATCH_SIZE_TRAIN + row*COLS + col]
                    plt.title(f'EEG = {eeg}\nTarget = {tars}',size=12)
                    plt.yticks([])
                    plt.ylabel('Frequencies (Hz)',size=14)
                    plt.xlabel('Time (sec)',size=16)
            plt.show()
            break

    #dynamic learning rate
    EPOCHS = config.EPOCHS
    BATCHES = len(train_loader)
    steps = []
    lrs = []
    optim_lrs = []
    model = CustomModel(config)
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
    scheduler = OneCycleLR(
        optimizer,
        max_lr=1e-3,
        epochs=config.EPOCHS,
        steps_per_epoch=len(train_loader),
        pct_start=0.05,
        anneal_strategy="cos",
        final_div_factor=100,
    )
    for epoch in range(EPOCHS):
        for batch in range(BATCHES):
            scheduler.step()
            lrs.append(scheduler.get_last_lr()[0])
            steps.append(epoch * BATCHES + batch)

    # max_lr = max(lrs)
    # min_lr = min(lrs)
    # print(f"Maximum LR: {max_lr} | Minimum LR: {min_lr}")
    # plt.figure()
    # plt.plot(steps, lrs, label='OneCycle')
    # plt.ticklabel_format(axis='y', style='sci', scilimits=(0,0))
    # plt.xlabel("Step")
    # plt.ylabel("Learning Rate")
    # plt.show()

    if not config.TRAIN_FULL_DATA:
        oof_df = pd.DataFrame()
        results = []  # 存储每个阶段的结果
        for fold_id in range(config.FOLDS):
            # 第一阶段训练，使用所有数据
            print(f"Starting Stage 1 Training for Fold {fold_id}")
            valid_folds_stage_1 = train_loop(train_df, fold_id, stage=1)
            # train_loop返回了字典包含最佳模型的评分和路径
            best_model_score_1 = valid_folds_stage_1.get('best_score', None)
            best_model_path_1 = valid_folds_stage_1.get('best_model_path', None)

            # 第二阶段训练，基于KL散度过滤数据
            print(f"Starting Stage 2 Training for Fold {fold_id}")
            valid_folds_stage_2 = train_loop(train_df, fold_id, stage=2)
            best_model_score_2 = valid_folds_stage_2.get('best_score', None)
            best_model_path_2 = valid_folds_stage_2.get('best_model_path', None)

            results.append((fold_id, best_model_score_1, best_model_path_1, best_model_score_2, best_model_path_2))
            logging.info(f"Fold {fold_id}: Stage 1 - Score: {best_model_score_1}, Path: {best_model_path_1}; Stage 2 - Score: {best_model_score_2}, Path: {best_model_path_2}")

            oof_df = pd.concat([oof_df, valid_folds_stage_2])
        # 循环结束后处理oof_df
        oof_df.reset_index(drop=True, inplace=True)
        
         # 对整体CV结果的打印和日志记录
        print(f"========== CV: {get_result(oof_df)} ==========")
        logging.info(f"========== CV: {get_result(oof_df)} ==========")
        logging.info(f"----------------------------------------------------------------------------------")
        # 保存oof_df或进一步分析
        oof_df.to_csv(os.path.join(paths.OUTPUT_DIR, 'oof_df.csv'), index=False)
    else:
        train_loop_full_data(train_df)

Using 1 GPU(s)
Log file path: /kaggle/working/new_version_training_record.log
seed_everything took 0.0020 seconds.
Train cataframe shape is: (1000, 15)
Labels: ['seizure_vote', 'lpd_vote', 'gpd_vote', 'lrda_vote', 'grda_vote', 'other_vote']
       eeg_id  eeg_sub_id  eeg_label_offset_seconds  spectrogram_id  \
0  1628180742           0                       0.0          353733   
1  1628180742           1                       6.0          353733   
2  1628180742           2                       8.0          353733   
3  1628180742           3                      18.0          353733   
4  1628180742           4                      24.0          353733   

   spectrogram_sub_id  spectrogram_label_offset_seconds    label_id  \
0                   0                               0.0   127492639   
1                   1                               6.0  3887563113   
2                   2                               8.0  1142670488   
3                   3                           

100%|██████████| 72/72 [00:13<00:00,  5.16it/s]


There are 17300 EEG spectrograms in total path
There are 105 EEG spectrograms in this training path


100%|██████████| 105/105 [00:02<00:00, 40.85it/s]


loading_parquet took 17.6791 seconds.
fold
0.0    21
1.0    21
2.0    21
3.0    21
4.0    21
dtype: int64
----------------------------------------------------------------------------------------------------
      eeg_id  spectrogram_id     min     max  patient_id  seizure_vote  \
0    8071080         2593634     0.0    18.0        2944           0.0   
1   72355774        11526349   112.0   112.0       40966           0.0   
2  122762465         8440102   320.0   376.0       54724           0.0   
3  138236967         3252414     0.0     0.0       44623           0.0   
4  142901500        12916371  1284.0  1298.0       21996           0.0   

   lpd_vote  gpd_vote  lrda_vote  grda_vote  other_vote target  kl_divergence  \
0  0.000000  0.000000        0.0       1.00        0.00   GRDA       7.802402   
1  0.000000  0.000000        0.0       0.00        1.00  Other       7.802402   
2  0.000000  0.000000        0.0       0.25        0.75  Other       6.162571   
3  0.000000  0.000000   

model.safetensors:   0%|          | 0.00/77.9M [00:00<?, ?B/s]



Starting Stage 1 Training for Fold 0


Train:  10%|█         | 1/10 [00:05<00:46,  5.18s/train_batch]

Epoch: [1][0/10] Elapsed 0m 5s (remain 0m 46s) Loss: 1.4667 Grad: 206618.1250  LR: 0.00028000  


Train: 100%|██████████| 10/10 [00:39<00:00,  3.98s/train_batch]


Epoch: [1][9/10] Elapsed 0m 39s (remain 0m 0s) Loss: 1.4317 Grad: 153886.2188  LR: 0.00090961  
train_epoch took 39.8329 seconds.


Validation:  33%|███▎      | 1/3 [00:07<00:15,  7.58s/valid_batch]

EVAL: [0/3] Elapsed 0m 7s (remain 0m 15s) Loss: 1.1072 


Validation: 100%|██████████| 3/3 [00:26<00:00,  8.93s/valid_batch]

EVAL: [2/3] Elapsed 0m 26s (remain 0m 0s) Loss: 1.0488 
valid_epoch took 26.8023 seconds.



Train:  10%|█         | 1/10 [00:03<00:32,  3.63s/train_batch]

Epoch: [2][0/10] Elapsed 0m 3s (remain 0m 32s) Loss: 1.1001 Grad: 110346.8984  LR: 0.00088307  


Train: 100%|██████████| 10/10 [00:38<00:00,  3.80s/train_batch]


Epoch: [2][9/10] Elapsed 0m 38s (remain 0m 0s) Loss: 1.1805 Grad: 45817.0039  LR: 0.00054376  
train_epoch took 38.0544 seconds.


Validation:  33%|███▎      | 1/3 [00:07<00:15,  7.51s/valid_batch]

EVAL: [0/3] Elapsed 0m 7s (remain 0m 15s) Loss: 1.0011 


Validation: 100%|██████████| 3/3 [00:26<00:00,  8.88s/valid_batch]

EVAL: [2/3] Elapsed 0m 26s (remain 0m 0s) Loss: 1.0399 
valid_epoch took 26.6373 seconds.



Train:  10%|█         | 1/10 [00:03<00:32,  3.59s/train_batch]

Epoch: [3][0/10] Elapsed 0m 3s (remain 0m 32s) Loss: 0.9174 Grad: 186931.7656  LR: 0.00050020  


Train: 100%|██████████| 10/10 [00:38<00:00,  3.80s/train_batch]


Epoch: [3][9/10] Elapsed 0m 38s (remain 0m 0s) Loss: 0.9514 Grad: 105155.8672  LR: 0.00014679  
train_epoch took 38.0227 seconds.


Validation:  33%|███▎      | 1/3 [00:07<00:15,  7.60s/valid_batch]

EVAL: [0/3] Elapsed 0m 7s (remain 0m 15s) Loss: 1.0417 


Validation: 100%|██████████| 3/3 [00:26<00:00,  8.92s/valid_batch]


EVAL: [2/3] Elapsed 0m 26s (remain 0m 0s) Loss: 1.1213 
valid_epoch took 26.7746 seconds.


Train:  10%|█         | 1/10 [00:03<00:32,  3.61s/train_batch]

Epoch: [4][0/10] Elapsed 0m 3s (remain 0m 32s) Loss: 0.8237 Grad: 91463.0781  LR: 0.00011733  


Train: 100%|██████████| 10/10 [00:38<00:00,  3.81s/train_batch]


Epoch: [4][9/10] Elapsed 0m 38s (remain 0m 0s) Loss: 0.8358 Grad: 57273.0039  LR: 0.00000230  
train_epoch took 38.0934 seconds.


Validation:  33%|███▎      | 1/3 [00:07<00:15,  7.53s/valid_batch]

EVAL: [0/3] Elapsed 0m 7s (remain 0m 15s) Loss: 1.0844 


Validation: 100%|██████████| 3/3 [00:26<00:00,  8.87s/valid_batch]

EVAL: [2/3] Elapsed 0m 26s (remain 0m 0s) Loss: 1.1510 
valid_epoch took 26.6175 seconds.





train_loop took 262.6432 seconds.
Starting Stage 2 Training for Fold 0


Train:  50%|█████     | 1/2 [00:04<00:04,  4.74s/train_batch]

Epoch: [1][0/2] Elapsed 0m 4s (remain 0m 4s) Loss: 0.9491 Grad: 116464.1406  LR: 0.00093304  


Train: 100%|██████████| 2/2 [00:08<00:00,  4.24s/train_batch]


Epoch: [1][1/2] Elapsed 0m 8s (remain 0m 0s) Loss: 1.0828 Grad: 423023.2188  LR: 0.00078687  
train_epoch took 8.4908 seconds.


Validation:  33%|███▎      | 1/3 [00:07<00:15,  7.63s/valid_batch]

EVAL: [0/3] Elapsed 0m 7s (remain 0m 15s) Loss: 1.2465 


Validation: 100%|██████████| 3/3 [00:26<00:00,  8.98s/valid_batch]

EVAL: [2/3] Elapsed 0m 26s (remain 0m 0s) Loss: 1.1605 
valid_epoch took 26.9373 seconds.



Train:  50%|█████     | 1/2 [00:04<00:04,  4.76s/train_batch]

Epoch: [2][0/2] Elapsed 0m 4s (remain 0m 4s) Loss: 0.6344 Grad: 153091.3594  LR: 0.00058699  


Train: 100%|██████████| 2/2 [00:08<00:00,  4.27s/train_batch]


Epoch: [2][1/2] Elapsed 0m 8s (remain 0m 0s) Loss: 0.7345 Grad: 94386.0391  LR: 0.00037084  
train_epoch took 8.5487 seconds.


Validation:  33%|███▎      | 1/3 [00:07<00:15,  7.55s/valid_batch]

EVAL: [0/3] Elapsed 0m 7s (remain 0m 15s) Loss: 1.3456 


Validation: 100%|██████████| 3/3 [00:26<00:00,  8.88s/valid_batch]


EVAL: [2/3] Elapsed 0m 26s (remain 0m 0s) Loss: 1.2279 
valid_epoch took 26.6629 seconds.


Train:  50%|█████     | 1/2 [00:04<00:04,  4.80s/train_batch]

Epoch: [3][0/2] Elapsed 0m 4s (remain 0m 4s) Loss: 0.6809 Grad: 122046.0703  LR: 0.00017893  


Train: 100%|██████████| 2/2 [00:08<00:00,  4.29s/train_batch]


Epoch: [3][1/2] Elapsed 0m 8s (remain 0m 0s) Loss: 0.7705 Grad: 93925.9297  LR: 0.00004723  
train_epoch took 8.5869 seconds.


Validation:  33%|███▎      | 1/3 [00:07<00:15,  7.57s/valid_batch]

EVAL: [0/3] Elapsed 0m 7s (remain 0m 15s) Loss: 1.3355 


Validation: 100%|██████████| 3/3 [00:26<00:00,  8.94s/valid_batch]


EVAL: [2/3] Elapsed 0m 26s (remain 0m 0s) Loss: 1.2188 
valid_epoch took 26.8372 seconds.


Train:  50%|█████     | 1/2 [00:04<00:04,  4.76s/train_batch]

Epoch: [4][0/2] Elapsed 0m 4s (remain 0m 4s) Loss: 0.6304 Grad: 98058.7812  LR: 0.00000040  


Train: 100%|██████████| 2/2 [00:08<00:00,  4.27s/train_batch]


Epoch: [4][1/2] Elapsed 0m 8s (remain 0m 0s) Loss: 0.6995 Grad: 134640.2812  LR: 0.00004723  
train_epoch took 8.5429 seconds.


Validation:  33%|███▎      | 1/3 [00:07<00:15,  7.55s/valid_batch]

EVAL: [0/3] Elapsed 0m 7s (remain 0m 15s) Loss: 1.2572 


Validation: 100%|██████████| 3/3 [00:26<00:00,  8.96s/valid_batch]

EVAL: [2/3] Elapsed 0m 26s (remain 0m 0s) Loss: 1.1585 
valid_epoch took 26.8771 seconds.





train_loop took 143.2504 seconds.
Starting Stage 1 Training for Fold 1


Train:  10%|█         | 1/10 [00:03<00:33,  3.69s/train_batch]

Epoch: [1][0/10] Elapsed 0m 3s (remain 0m 33s) Loss: 1.4224 Grad: 142572.5469  LR: 0.00028000  


Train: 100%|██████████| 10/10 [00:48<00:00,  4.82s/train_batch]


Epoch: [1][9/10] Elapsed 0m 48s (remain 0m 0s) Loss: 1.2989 Grad: 185577.5469  LR: 0.00090961  
train_epoch took 48.1980 seconds.


Validation:  33%|███▎      | 1/3 [00:01<00:02,  1.02s/valid_batch]

EVAL: [0/3] Elapsed 0m 1s (remain 0m 2s) Loss: 1.6232 


Validation: 100%|██████████| 3/3 [00:02<00:00,  1.04valid_batch/s]

EVAL: [2/3] Elapsed 0m 2s (remain 0m 0s) Loss: 1.6218 
valid_epoch took 2.8966 seconds.



Train:  10%|█         | 1/10 [00:03<00:33,  3.70s/train_batch]

Epoch: [2][0/10] Elapsed 0m 3s (remain 0m 33s) Loss: 1.0969 Grad: 115692.5078  LR: 0.00088307  


Train: 100%|██████████| 10/10 [00:47<00:00,  4.79s/train_batch]


Epoch: [2][9/10] Elapsed 0m 47s (remain 0m 0s) Loss: 1.0214 Grad: 98284.5703  LR: 0.00054376  
train_epoch took 47.9356 seconds.


Validation:  33%|███▎      | 1/3 [00:01<00:02,  1.01s/valid_batch]

EVAL: [0/3] Elapsed 0m 1s (remain 0m 2s) Loss: 2.2842 


Validation: 100%|██████████| 3/3 [00:02<00:00,  1.06valid_batch/s]


EVAL: [2/3] Elapsed 0m 2s (remain 0m 0s) Loss: 1.9216 
valid_epoch took 2.8308 seconds.


Train:  10%|█         | 1/10 [00:03<00:32,  3.62s/train_batch]

Epoch: [3][0/10] Elapsed 0m 3s (remain 0m 32s) Loss: 1.1530 Grad: 200647.2812  LR: 0.00050020  


Train: 100%|██████████| 10/10 [00:47<00:00,  4.79s/train_batch]


Epoch: [3][9/10] Elapsed 0m 47s (remain 0m 0s) Loss: 0.9097 Grad: 122259.9297  LR: 0.00014679  
train_epoch took 47.8837 seconds.


Validation:  33%|███▎      | 1/3 [00:00<00:01,  1.01valid_batch/s]

EVAL: [0/3] Elapsed 0m 0s (remain 0m 1s) Loss: 2.7491 


Validation: 100%|██████████| 3/3 [00:02<00:00,  1.07valid_batch/s]


EVAL: [2/3] Elapsed 0m 2s (remain 0m 0s) Loss: 2.1397 
valid_epoch took 2.8105 seconds.


Train:  10%|█         | 1/10 [00:03<00:32,  3.65s/train_batch]

Epoch: [4][0/10] Elapsed 0m 3s (remain 0m 32s) Loss: 0.9102 Grad: 90855.3203  LR: 0.00011733  


Train: 100%|██████████| 10/10 [00:47<00:00,  4.78s/train_batch]


Epoch: [4][9/10] Elapsed 0m 47s (remain 0m 0s) Loss: 0.7851 Grad: 149457.5000  LR: 0.00000230  
train_epoch took 47.8491 seconds.


Validation:  33%|███▎      | 1/3 [00:00<00:01,  1.01valid_batch/s]

EVAL: [0/3] Elapsed 0m 0s (remain 0m 1s) Loss: 2.6892 


Validation: 100%|██████████| 3/3 [00:02<00:00,  1.08valid_batch/s]

EVAL: [2/3] Elapsed 0m 2s (remain 0m 0s) Loss: 2.0982 
valid_epoch took 2.7997 seconds.





train_loop took 204.5692 seconds.
Starting Stage 2 Training for Fold 1


Train:  33%|███▎      | 1/3 [00:05<00:10,  5.34s/train_batch]

Epoch: [1][0/3] Elapsed 0m 5s (remain 0m 10s) Loss: 0.8371 Grad: 92562.7812  LR: 0.00098653  


Train: 100%|██████████| 3/3 [00:16<00:00,  5.45s/train_batch]


Epoch: [1][2/3] Elapsed 0m 16s (remain 0m 0s) Loss: 0.8385 Grad: 94097.5156  LR: 0.00084318  
train_epoch took 16.3463 seconds.


Validation:  33%|███▎      | 1/3 [00:00<00:01,  1.01valid_batch/s]

EVAL: [0/3] Elapsed 0m 0s (remain 0m 1s) Loss: 1.5496 


Validation: 100%|██████████| 3/3 [00:02<00:00,  1.08valid_batch/s]

EVAL: [2/3] Elapsed 0m 2s (remain 0m 0s) Loss: 1.5809 
valid_epoch took 2.7940 seconds.



Train:  33%|███▎      | 1/3 [00:05<00:10,  5.25s/train_batch]

Epoch: [2][0/3] Elapsed 0m 5s (remain 0m 10s) Loss: 0.7245 Grad: 96759.8203  LR: 0.00072451  


Train: 100%|██████████| 3/3 [00:16<00:00,  5.45s/train_batch]


Epoch: [2][2/3] Elapsed 0m 16s (remain 0m 0s) Loss: 0.6527 Grad: 130439.3828  LR: 0.00044218  
train_epoch took 16.3474 seconds.


Validation:  33%|███▎      | 1/3 [00:00<00:01,  1.01valid_batch/s]

EVAL: [0/3] Elapsed 0m 0s (remain 0m 1s) Loss: 1.5294 


Validation: 100%|██████████| 3/3 [00:02<00:00,  1.08valid_batch/s]


EVAL: [2/3] Elapsed 0m 2s (remain 0m 0s) Loss: 1.5857 
valid_epoch took 2.7992 seconds.


Train:  33%|███▎      | 1/3 [00:05<00:10,  5.26s/train_batch]

Epoch: [3][0/3] Elapsed 0m 5s (remain 0m 10s) Loss: 0.5869 Grad: 96894.1016  LR: 0.00030224  


Train: 100%|██████████| 3/3 [00:16<00:00,  5.42s/train_batch]


Epoch: [3][2/3] Elapsed 0m 16s (remain 0m 0s) Loss: 0.5577 Grad: 202016.4688  LR: 0.00008262  
train_epoch took 16.2798 seconds.


Validation:  33%|███▎      | 1/3 [00:00<00:01,  1.00valid_batch/s]

EVAL: [0/3] Elapsed 0m 0s (remain 0m 1s) Loss: 1.5227 


Validation: 100%|██████████| 3/3 [00:02<00:00,  1.07valid_batch/s]

EVAL: [2/3] Elapsed 0m 2s (remain 0m 0s) Loss: 1.5797 
valid_epoch took 2.8087 seconds.



Train:  33%|███▎      | 1/3 [00:05<00:10,  5.26s/train_batch]

Epoch: [4][0/3] Elapsed 0m 5s (remain 0m 10s) Loss: 0.5200 Grad: 65304.5508  LR: 0.00002140  


Train: 100%|██████████| 3/3 [00:16<00:00,  5.43s/train_batch]


Epoch: [4][2/3] Elapsed 0m 16s (remain 0m 0s) Loss: 0.5445 Grad: 129407.7031  LR: 0.00002140  
train_epoch took 16.2956 seconds.


Validation:  33%|███▎      | 1/3 [00:00<00:01,  1.01valid_batch/s]

EVAL: [0/3] Elapsed 0m 0s (remain 0m 1s) Loss: 1.5504 


Validation: 100%|██████████| 3/3 [00:02<00:00,  1.07valid_batch/s]

EVAL: [2/3] Elapsed 0m 2s (remain 0m 0s) Loss: 1.5869 
valid_epoch took 2.8143 seconds.





train_loop took 78.1571 seconds.
Starting Stage 1 Training for Fold 2


Train:  10%|█         | 1/10 [00:04<00:36,  4.11s/train_batch]

Epoch: [1][0/10] Elapsed 0m 4s (remain 0m 36s) Loss: 1.4505 Grad: 175500.2500  LR: 0.00028000  


Train: 100%|██████████| 10/10 [00:47<00:00,  4.75s/train_batch]


Epoch: [1][9/10] Elapsed 0m 47s (remain 0m 0s) Loss: 1.3217 Grad: inf  LR: 0.00090961  
train_epoch took 47.4818 seconds.


Validation:  33%|███▎      | 1/3 [00:01<00:03,  1.91s/valid_batch]

EVAL: [0/3] Elapsed 0m 1s (remain 0m 3s) Loss: 1.5187 


Validation: 100%|██████████| 3/3 [00:03<00:00,  1.26s/valid_batch]

EVAL: [2/3] Elapsed 0m 3s (remain 0m 0s) Loss: 1.5496 
valid_epoch took 3.7768 seconds.



Train:  10%|█         | 1/10 [00:04<00:37,  4.15s/train_batch]

Epoch: [2][0/10] Elapsed 0m 4s (remain 0m 37s) Loss: 1.3899 Grad: 246738.6875  LR: 0.00088307  


Train: 100%|██████████| 10/10 [00:47<00:00,  4.74s/train_batch]


Epoch: [2][9/10] Elapsed 0m 47s (remain 0m 0s) Loss: 1.1376 Grad: inf  LR: 0.00054376  
train_epoch took 47.3677 seconds.


Validation:  33%|███▎      | 1/3 [00:01<00:03,  1.90s/valid_batch]

EVAL: [0/3] Elapsed 0m 1s (remain 0m 3s) Loss: 1.3804 


Validation: 100%|██████████| 3/3 [00:03<00:00,  1.25s/valid_batch]

EVAL: [2/3] Elapsed 0m 3s (remain 0m 0s) Loss: 1.4826 
valid_epoch took 3.7539 seconds.



Train:  10%|█         | 1/10 [00:04<00:37,  4.13s/train_batch]

Epoch: [3][0/10] Elapsed 0m 4s (remain 0m 37s) Loss: 1.0636 Grad: 169333.2188  LR: 0.00050020  


Train: 100%|██████████| 10/10 [00:47<00:00,  4.74s/train_batch]


Epoch: [3][9/10] Elapsed 0m 47s (remain 0m 0s) Loss: 0.9904 Grad: 268612.7500  LR: 0.00014679  
train_epoch took 47.4581 seconds.


Validation:  33%|███▎      | 1/3 [00:01<00:03,  1.90s/valid_batch]

EVAL: [0/3] Elapsed 0m 1s (remain 0m 3s) Loss: 1.4377 


Validation: 100%|██████████| 3/3 [00:03<00:00,  1.24s/valid_batch]


EVAL: [2/3] Elapsed 0m 3s (remain 0m 0s) Loss: 1.5285 
valid_epoch took 3.7183 seconds.


Train:  10%|█         | 1/10 [00:04<00:37,  4.13s/train_batch]

Epoch: [4][0/10] Elapsed 0m 4s (remain 0m 37s) Loss: 1.0540 Grad: 182006.5156  LR: 0.00011733  


Train: 100%|██████████| 10/10 [00:47<00:00,  4.75s/train_batch]


Epoch: [4][9/10] Elapsed 0m 47s (remain 0m 0s) Loss: 0.9083 Grad: 253590.6562  LR: 0.00000230  
train_epoch took 47.4699 seconds.


Validation:  33%|███▎      | 1/3 [00:01<00:03,  1.90s/valid_batch]

EVAL: [0/3] Elapsed 0m 1s (remain 0m 3s) Loss: 1.3991 


Validation: 100%|██████████| 3/3 [00:03<00:00,  1.24s/valid_batch]

EVAL: [2/3] Elapsed 0m 3s (remain 0m 0s) Loss: 1.5260 
valid_epoch took 3.7266 seconds.





train_loop took 206.3971 seconds.
Starting Stage 2 Training for Fold 2


Train:  33%|███▎      | 1/3 [00:04<00:08,  4.37s/train_batch]

Epoch: [1][0/3] Elapsed 0m 4s (remain 0m 8s) Loss: 0.8854 Grad: 105124.7500  LR: 0.00098653  


Train: 100%|██████████| 3/3 [00:15<00:00,  5.14s/train_batch]


Epoch: [1][2/3] Elapsed 0m 15s (remain 0m 0s) Loss: 0.8628 Grad: 134195.4062  LR: 0.00084318  
train_epoch took 15.4208 seconds.


Validation:  33%|███▎      | 1/3 [00:01<00:03,  1.91s/valid_batch]

EVAL: [0/3] Elapsed 0m 1s (remain 0m 3s) Loss: 1.6445 


Validation: 100%|██████████| 3/3 [00:03<00:00,  1.24s/valid_batch]

EVAL: [2/3] Elapsed 0m 3s (remain 0m 0s) Loss: 1.6012 
valid_epoch took 3.7407 seconds.



Train:  33%|███▎      | 1/3 [00:04<00:08,  4.33s/train_batch]

Epoch: [2][0/3] Elapsed 0m 4s (remain 0m 8s) Loss: 0.7738 Grad: 92117.0703  LR: 0.00072451  


Train: 100%|██████████| 3/3 [00:15<00:00,  5.12s/train_batch]


Epoch: [2][2/3] Elapsed 0m 15s (remain 0m 0s) Loss: 0.7012 Grad: 308892.4375  LR: 0.00044218  
train_epoch took 15.3679 seconds.


Validation:  33%|███▎      | 1/3 [00:01<00:03,  1.89s/valid_batch]

EVAL: [0/3] Elapsed 0m 1s (remain 0m 3s) Loss: 1.6352 


Validation: 100%|██████████| 3/3 [00:03<00:00,  1.24s/valid_batch]

EVAL: [2/3] Elapsed 0m 3s (remain 0m 0s) Loss: 1.5933 
valid_epoch took 3.7197 seconds.



Train:  33%|███▎      | 1/3 [00:04<00:08,  4.28s/train_batch]

Epoch: [3][0/3] Elapsed 0m 4s (remain 0m 8s) Loss: 0.7874 Grad: 233918.0312  LR: 0.00030224  


Train: 100%|██████████| 3/3 [00:15<00:00,  5.10s/train_batch]


Epoch: [3][2/3] Elapsed 0m 15s (remain 0m 0s) Loss: 0.6496 Grad: 120171.8125  LR: 0.00008262  
train_epoch took 15.3058 seconds.


Validation:  33%|███▎      | 1/3 [00:01<00:03,  1.89s/valid_batch]

EVAL: [0/3] Elapsed 0m 1s (remain 0m 3s) Loss: 1.5306 


Validation: 100%|██████████| 3/3 [00:03<00:00,  1.25s/valid_batch]

EVAL: [2/3] Elapsed 0m 3s (remain 0m 0s) Loss: 1.5466 
valid_epoch took 3.7633 seconds.



Train:  33%|███▎      | 1/3 [00:04<00:08,  4.27s/train_batch]

Epoch: [4][0/3] Elapsed 0m 4s (remain 0m 8s) Loss: 0.6670 Grad: 160364.0156  LR: 0.00002140  


Train: 100%|██████████| 3/3 [00:15<00:00,  5.10s/train_batch]


Epoch: [4][2/3] Elapsed 0m 15s (remain 0m 0s) Loss: 0.6020 Grad: 88116.0625  LR: 0.00002140  
train_epoch took 15.2975 seconds.


Validation:  33%|███▎      | 1/3 [00:01<00:03,  1.89s/valid_batch]

EVAL: [0/3] Elapsed 0m 1s (remain 0m 3s) Loss: 1.4748 


Validation: 100%|██████████| 3/3 [00:03<00:00,  1.23s/valid_batch]

EVAL: [2/3] Elapsed 0m 3s (remain 0m 0s) Loss: 1.5344 
valid_epoch took 3.7109 seconds.





train_loop took 78.5049 seconds.
Starting Stage 1 Training for Fold 3


Train:  10%|█         | 1/10 [00:03<00:32,  3.57s/train_batch]

Epoch: [1][0/10] Elapsed 0m 3s (remain 0m 32s) Loss: 1.6172 Grad: 211082.9531  LR: 0.00028000  


Train: 100%|██████████| 10/10 [00:48<00:00,  4.81s/train_batch]


Epoch: [1][9/10] Elapsed 0m 48s (remain 0m 0s) Loss: 1.4223 Grad: 149608.7812  LR: 0.00090961  
train_epoch took 48.1594 seconds.


Validation:  33%|███▎      | 1/3 [00:01<00:02,  1.26s/valid_batch]

EVAL: [0/3] Elapsed 0m 1s (remain 0m 2s) Loss: 1.3724 


Validation: 100%|██████████| 3/3 [00:02<00:00,  1.13valid_batch/s]

EVAL: [2/3] Elapsed 0m 2s (remain 0m 0s) Loss: 1.4049 
valid_epoch took 2.6594 seconds.



Train:  10%|█         | 1/10 [00:03<00:32,  3.56s/train_batch]

Epoch: [2][0/10] Elapsed 0m 3s (remain 0m 32s) Loss: 1.4219 Grad: 156934.7344  LR: 0.00088307  


Train: 100%|██████████| 10/10 [00:48<00:00,  4.84s/train_batch]


Epoch: [2][9/10] Elapsed 0m 48s (remain 0m 0s) Loss: 1.2340 Grad: 118126.0000  LR: 0.00054376  
train_epoch took 48.4053 seconds.


Validation:  33%|███▎      | 1/3 [00:01<00:02,  1.24s/valid_batch]

EVAL: [0/3] Elapsed 0m 1s (remain 0m 2s) Loss: 1.1324 


Validation: 100%|██████████| 3/3 [00:02<00:00,  1.12valid_batch/s]

EVAL: [2/3] Elapsed 0m 2s (remain 0m 0s) Loss: 1.1629 
valid_epoch took 2.6924 seconds.



Train:  10%|█         | 1/10 [00:03<00:32,  3.56s/train_batch]

Epoch: [3][0/10] Elapsed 0m 3s (remain 0m 32s) Loss: 1.2796 Grad: 138379.3750  LR: 0.00050020  


Train: 100%|██████████| 10/10 [00:48<00:00,  4.83s/train_batch]


Epoch: [3][9/10] Elapsed 0m 48s (remain 0m 0s) Loss: 1.0854 Grad: 181087.3750  LR: 0.00014679  
train_epoch took 48.3438 seconds.


Validation:  33%|███▎      | 1/3 [00:01<00:02,  1.26s/valid_batch]

EVAL: [0/3] Elapsed 0m 1s (remain 0m 2s) Loss: 0.9970 


Validation: 100%|██████████| 3/3 [00:02<00:00,  1.12valid_batch/s]

EVAL: [2/3] Elapsed 0m 2s (remain 0m 0s) Loss: 1.0397 
valid_epoch took 2.6915 seconds.



Train:  10%|█         | 1/10 [00:03<00:32,  3.59s/train_batch]

Epoch: [4][0/10] Elapsed 0m 3s (remain 0m 32s) Loss: 1.0189 Grad: 137235.6250  LR: 0.00011733  


Train: 100%|██████████| 10/10 [00:48<00:00,  4.85s/train_batch]


Epoch: [4][9/10] Elapsed 0m 48s (remain 0m 0s) Loss: 0.9537 Grad: 113226.3203  LR: 0.00000230  
train_epoch took 48.5006 seconds.


Validation:  33%|███▎      | 1/3 [00:01<00:02,  1.26s/valid_batch]

EVAL: [0/3] Elapsed 0m 1s (remain 0m 2s) Loss: 1.0151 


Validation: 100%|██████████| 3/3 [00:02<00:00,  1.11valid_batch/s]

EVAL: [2/3] Elapsed 0m 2s (remain 0m 0s) Loss: 1.0330 
valid_epoch took 2.7179 seconds.





train_loop took 206.2917 seconds.
Starting Stage 2 Training for Fold 3


Train:  50%|█████     | 1/2 [00:05<00:05,  5.20s/train_batch]

Epoch: [1][0/2] Elapsed 0m 5s (remain 0m 5s) Loss: 0.9896 Grad: 106951.9844  LR: 0.00093304  


Train: 100%|██████████| 2/2 [00:08<00:00,  4.49s/train_batch]


Epoch: [1][1/2] Elapsed 0m 8s (remain 0m 0s) Loss: 0.8994 Grad: 80132.1719  LR: 0.00078687  
train_epoch took 8.9836 seconds.


Validation:  33%|███▎      | 1/3 [00:01<00:02,  1.25s/valid_batch]

EVAL: [0/3] Elapsed 0m 1s (remain 0m 2s) Loss: 1.4141 


Validation: 100%|██████████| 3/3 [00:02<00:00,  1.13valid_batch/s]

EVAL: [2/3] Elapsed 0m 2s (remain 0m 0s) Loss: 1.4495 
valid_epoch took 2.6728 seconds.



Train:  50%|█████     | 1/2 [00:05<00:05,  5.28s/train_batch]

Epoch: [2][0/2] Elapsed 0m 5s (remain 0m 5s) Loss: 0.7780 Grad: 82517.6250  LR: 0.00058699  


Train: 100%|██████████| 2/2 [00:09<00:00,  4.53s/train_batch]


Epoch: [2][1/2] Elapsed 0m 9s (remain 0m 0s) Loss: 0.7726 Grad: 139852.3750  LR: 0.00037084  
train_epoch took 9.0749 seconds.


Validation:  33%|███▎      | 1/3 [00:01<00:02,  1.25s/valid_batch]

EVAL: [0/3] Elapsed 0m 1s (remain 0m 2s) Loss: 1.3904 


Validation: 100%|██████████| 3/3 [00:02<00:00,  1.13valid_batch/s]

EVAL: [2/3] Elapsed 0m 2s (remain 0m 0s) Loss: 1.4212 
valid_epoch took 2.6761 seconds.



Train:  50%|█████     | 1/2 [00:05<00:05,  5.19s/train_batch]

Epoch: [3][0/2] Elapsed 0m 5s (remain 0m 5s) Loss: 0.6291 Grad: 119349.8750  LR: 0.00017893  


Train: 100%|██████████| 2/2 [00:08<00:00,  4.48s/train_batch]


Epoch: [3][1/2] Elapsed 0m 8s (remain 0m 0s) Loss: 0.6348 Grad: 54733.8164  LR: 0.00004723  
train_epoch took 8.9778 seconds.


Validation:  33%|███▎      | 1/3 [00:01<00:02,  1.25s/valid_batch]

EVAL: [0/3] Elapsed 0m 1s (remain 0m 2s) Loss: 1.3775 


Validation: 100%|██████████| 3/3 [00:02<00:00,  1.12valid_batch/s]

EVAL: [2/3] Elapsed 0m 2s (remain 0m 0s) Loss: 1.4099 
valid_epoch took 2.6962 seconds.



Train:  50%|█████     | 1/2 [00:05<00:05,  5.21s/train_batch]

Epoch: [4][0/2] Elapsed 0m 5s (remain 0m 5s) Loss: 0.6409 Grad: 109681.5156  LR: 0.00000040  


Train: 100%|██████████| 2/2 [00:08<00:00,  4.49s/train_batch]


Epoch: [4][1/2] Elapsed 0m 8s (remain 0m 0s) Loss: 0.6655 Grad: 116503.3438  LR: 0.00004723  
train_epoch took 8.9957 seconds.


Validation:  33%|███▎      | 1/3 [00:01<00:02,  1.25s/valid_batch]

EVAL: [0/3] Elapsed 0m 1s (remain 0m 2s) Loss: 1.3563 


Validation: 100%|██████████| 3/3 [00:02<00:00,  1.13valid_batch/s]

EVAL: [2/3] Elapsed 0m 2s (remain 0m 0s) Loss: 1.3900 
valid_epoch took 2.6720 seconds.





train_loop took 48.9102 seconds.
Starting Stage 1 Training for Fold 4


Train:  10%|█         | 1/10 [00:03<00:33,  3.71s/train_batch]

Epoch: [1][0/10] Elapsed 0m 3s (remain 0m 33s) Loss: 1.5316 Grad: 143364.2812  LR: 0.00028000  


Train: 100%|██████████| 10/10 [00:47<00:00,  4.79s/train_batch]


Epoch: [1][9/10] Elapsed 0m 47s (remain 0m 0s) Loss: 1.3478 Grad: 201270.2812  LR: 0.00090961  
train_epoch took 47.9233 seconds.


Validation:  33%|███▎      | 1/3 [00:01<00:02,  1.46s/valid_batch]

EVAL: [0/3] Elapsed 0m 1s (remain 0m 2s) Loss: 1.3229 


Validation: 100%|██████████| 3/3 [00:03<00:00,  1.00s/valid_batch]

EVAL: [2/3] Elapsed 0m 3s (remain 0m 0s) Loss: 1.4356 
valid_epoch took 3.0193 seconds.



Train:  10%|█         | 1/10 [00:03<00:33,  3.69s/train_batch]

Epoch: [2][0/10] Elapsed 0m 3s (remain 0m 33s) Loss: 1.3645 Grad: 151697.5000  LR: 0.00088307  


Train: 100%|██████████| 10/10 [00:48<00:00,  4.80s/train_batch]


Epoch: [2][9/10] Elapsed 0m 48s (remain 0m 0s) Loss: 1.2123 Grad: 105433.6953  LR: 0.00054376  
train_epoch took 48.0438 seconds.


Validation:  33%|███▎      | 1/3 [00:01<00:02,  1.46s/valid_batch]

EVAL: [0/3] Elapsed 0m 1s (remain 0m 2s) Loss: 1.1546 


Validation: 100%|██████████| 3/3 [00:03<00:00,  1.01s/valid_batch]

EVAL: [2/3] Elapsed 0m 3s (remain 0m 0s) Loss: 1.2480 
valid_epoch took 3.0341 seconds.



Train:  10%|█         | 1/10 [00:03<00:33,  3.70s/train_batch]

Epoch: [3][0/10] Elapsed 0m 3s (remain 0m 33s) Loss: 1.1716 Grad: 262954.1562  LR: 0.00050020  


Train: 100%|██████████| 10/10 [00:48<00:00,  4.82s/train_batch]


Epoch: [3][9/10] Elapsed 0m 48s (remain 0m 0s) Loss: 0.9663 Grad: 111703.1641  LR: 0.00014679  
train_epoch took 48.1979 seconds.


Validation:  33%|███▎      | 1/3 [00:01<00:02,  1.47s/valid_batch]

EVAL: [0/3] Elapsed 0m 1s (remain 0m 2s) Loss: 1.0177 


Validation: 100%|██████████| 3/3 [00:03<00:00,  1.02s/valid_batch]

EVAL: [2/3] Elapsed 0m 3s (remain 0m 0s) Loss: 1.0893 
valid_epoch took 3.0565 seconds.



Train:  10%|█         | 1/10 [00:03<00:33,  3.72s/train_batch]

Epoch: [4][0/10] Elapsed 0m 3s (remain 0m 33s) Loss: 0.9900 Grad: 126916.1484  LR: 0.00011733  


Train: 100%|██████████| 10/10 [00:48<00:00,  4.81s/train_batch]


Epoch: [4][9/10] Elapsed 0m 48s (remain 0m 0s) Loss: 0.8730 Grad: 131830.5625  LR: 0.00000230  
train_epoch took 48.0652 seconds.


Validation:  33%|███▎      | 1/3 [00:01<00:02,  1.46s/valid_batch]

EVAL: [0/3] Elapsed 0m 1s (remain 0m 2s) Loss: 1.0120 


Validation: 100%|██████████| 3/3 [00:03<00:00,  1.01s/valid_batch]

EVAL: [2/3] Elapsed 0m 3s (remain 0m 0s) Loss: 1.0692 
valid_epoch took 3.0372 seconds.





train_loop took 206.4848 seconds.
Starting Stage 2 Training for Fold 4


Train:  33%|███▎      | 1/3 [00:05<00:10,  5.29s/train_batch]

Epoch: [1][0/3] Elapsed 0m 5s (remain 0m 10s) Loss: 0.8146 Grad: 108176.0234  LR: 0.00098653  


Train: 100%|██████████| 3/3 [00:16<00:00,  5.45s/train_batch]


Epoch: [1][2/3] Elapsed 0m 16s (remain 0m 0s) Loss: 0.8390 Grad: 153540.3281  LR: 0.00084318  
train_epoch took 16.3651 seconds.


Validation:  33%|███▎      | 1/3 [00:01<00:02,  1.50s/valid_batch]

EVAL: [0/3] Elapsed 0m 1s (remain 0m 3s) Loss: 1.3761 


Validation: 100%|██████████| 3/3 [00:03<00:00,  1.02s/valid_batch]

EVAL: [2/3] Elapsed 0m 3s (remain 0m 0s) Loss: 1.5077 
valid_epoch took 3.0765 seconds.



Train:  33%|███▎      | 1/3 [00:05<00:10,  5.15s/train_batch]

Epoch: [2][0/3] Elapsed 0m 5s (remain 0m 10s) Loss: 0.7815 Grad: 372374.2812  LR: 0.00072451  


Train: 100%|██████████| 3/3 [00:16<00:00,  5.43s/train_batch]


Epoch: [2][2/3] Elapsed 0m 16s (remain 0m 0s) Loss: 0.7159 Grad: 108057.5391  LR: 0.00044218  
train_epoch took 16.2995 seconds.


Validation:  33%|███▎      | 1/3 [00:01<00:02,  1.47s/valid_batch]

EVAL: [0/3] Elapsed 0m 1s (remain 0m 2s) Loss: 1.3777 


Validation: 100%|██████████| 3/3 [00:03<00:00,  1.01s/valid_batch]

EVAL: [2/3] Elapsed 0m 3s (remain 0m 0s) Loss: 1.4822 
valid_epoch took 3.0529 seconds.



Train:  33%|███▎      | 1/3 [00:05<00:10,  5.19s/train_batch]

Epoch: [3][0/3] Elapsed 0m 5s (remain 0m 10s) Loss: 0.6353 Grad: 88941.9531  LR: 0.00030224  


Train: 100%|██████████| 3/3 [00:16<00:00,  5.41s/train_batch]


Epoch: [3][2/3] Elapsed 0m 16s (remain 0m 0s) Loss: 0.6344 Grad: 209431.8594  LR: 0.00008262  
train_epoch took 16.2483 seconds.


Validation:  33%|███▎      | 1/3 [00:01<00:02,  1.47s/valid_batch]

EVAL: [0/3] Elapsed 0m 1s (remain 0m 2s) Loss: 1.3197 


Validation: 100%|██████████| 3/3 [00:03<00:00,  1.01s/valid_batch]

EVAL: [2/3] Elapsed 0m 3s (remain 0m 0s) Loss: 1.4112 
valid_epoch took 3.0413 seconds.



Train:  33%|███▎      | 1/3 [00:05<00:10,  5.24s/train_batch]

Epoch: [4][0/3] Elapsed 0m 5s (remain 0m 10s) Loss: 0.7093 Grad: 134372.3750  LR: 0.00002140  


Train: 100%|██████████| 3/3 [00:16<00:00,  5.43s/train_batch]


Epoch: [4][2/3] Elapsed 0m 16s (remain 0m 0s) Loss: 0.6916 Grad: 220671.8438  LR: 0.00002140  
train_epoch took 16.3005 seconds.


Validation:  33%|███▎      | 1/3 [00:01<00:02,  1.48s/valid_batch]

EVAL: [0/3] Elapsed 0m 1s (remain 0m 2s) Loss: 1.2667 


Validation: 100%|██████████| 3/3 [00:03<00:00,  1.02s/valid_batch]

EVAL: [2/3] Elapsed 0m 3s (remain 0m 0s) Loss: 1.3548 
valid_epoch took 3.0642 seconds.





train_loop took 79.6918 seconds.
get_result took 0.0043 seconds.
get_result took 0.0012 seconds.
