In [None]:
import os
import time
from typing import Optional, Tuple, List, Dict

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from sklearn.metrics import accuracy_score, recall_score, precision_score, f1_score, confusion_matrix
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm

class TransformerEncoderLayer(nn.Module):
    """
    Pre-norm transformer encoder layer.

    Each of the two sub-layers (multi-head self-attention, then a two-layer
    feed-forward network) sees a LayerNorm-ed copy of its input; the sub-layer
    output goes through dropout and is added back onto the residual stream.
    Attention is constructed with ``batch_first=True``.
    """
    def __init__(
        self,
        d_model: int,
        nhead: int,
        dim_feedforward: int = 2048,
        dropout: float = 0.1,
        activation: str = "relu",
        layer_norm_eps: float = 1e-5,
    ):
        super().__init__()
        # Submodules are registered in the same order as before so parameter
        # initialisation consumes the RNG identically and state_dict keys match.
        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=True)

        # Position-wise feed-forward network: d_model -> dim_feedforward -> d_model
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)

        # Pre-sub-layer normalisation and residual-branch dropout
        self.norm1, self.norm2 = (nn.LayerNorm(d_model, eps=layer_norm_eps) for _ in range(2))
        self.dropout1, self.dropout2 = (nn.Dropout(dropout) for _ in range(2))

        self.activation = _get_activation_fn(activation)

    def _attention_branch(self, x, attn_mask, key_padding_mask):
        """Residual branch 1: norm -> self-attention -> dropout."""
        normed = self.norm1(x)
        attended, _ = self.self_attn(
            normed, normed, normed,
            attn_mask=attn_mask,
            key_padding_mask=key_padding_mask,
        )
        return self.dropout1(attended)

    def _feedforward_branch(self, x):
        """Residual branch 2: norm -> linear -> activation -> dropout -> linear -> dropout."""
        hidden = self.activation(self.linear1(self.norm2(x)))
        return self.dropout2(self.linear2(self.dropout(hidden)))

    def forward(self, src, src_mask=None, src_key_padding_mask=None):
        src = src + self._attention_branch(src, src_mask, src_key_padding_mask)
        src = src + self._feedforward_branch(src)
        return src


class TransformerEncoder(nn.Module):
    """
    Stack of ``num_layers`` identical ``TransformerEncoderLayer`` blocks followed
    by a final LayerNorm (needed because the layers are pre-norm, so the
    residual stream leaving the last layer is otherwise unnormalised).
    """
    def __init__(
        self,
        d_model: int,
        nhead: int,
        num_layers: int,
        dim_feedforward: int = 2048,
        dropout: float = 0.1,
        activation: str = "relu",
        layer_norm_eps: float = 1e-5,
    ):
        super().__init__()

        layer_config = dict(
            d_model=d_model,
            nhead=nhead,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
            activation=activation,
            layer_norm_eps=layer_norm_eps,
        )
        self.layers = nn.ModuleList(
            [TransformerEncoderLayer(**layer_config) for _ in range(num_layers)]
        )
        self.norm = nn.LayerNorm(d_model, eps=layer_norm_eps)

    def forward(self, src, src_mask=None, src_key_padding_mask=None):
        hidden = src
        for encoder_layer in self.layers:
            hidden = encoder_layer(
                hidden, src_mask=src_mask, src_key_padding_mask=src_key_padding_mask
            )
        return self.norm(hidden)


class LineSequenceClassifier(nn.Module):
    """
    Transformer model for binary classification of 2D line sequences.

    Each line is a ``line_dim``-vector (by default 4: x1, y1, x2, y2). Lines are
    linearly embedded, encoded by a padding-masked transformer encoder, and
    the encoder output at each sequence's last valid position is fed through an
    MLP head that yields one logit per sequence.
    """
    def __init__(
        self,
        line_dim: int = 4,  # Dimension of each line (x1, y1, x2, y2)
        d_model: int = 128, # Embedding dimension
        nhead: int = 8,
        num_encoder_layers: int = 4,
        dim_feedforward: int = 512,
        dropout: float = 0.1,
        activation: str = "relu",
        layer_norm_eps: float = 1e-5,
    ):
        super().__init__()

        self.d_model = d_model

        # Per-line projection into the model dimension
        self.line_embedding = nn.Linear(line_dim, d_model)

        self.transformer_encoder = TransformerEncoder(
            d_model=d_model,
            nhead=nhead,
            num_layers=num_encoder_layers,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
            activation=activation,
            layer_norm_eps=layer_norm_eps
        )

        # MLP head: d_model -> 128 -> 64 -> 16 -> 1. Built as
        # [Linear, ReLU, Dropout] * 3 + [Linear] so Sequential indices (and
        # therefore checkpoint keys) are classifier.0/.3/.6/.9.
        head_layers = []
        for in_features, out_features in ((d_model, 128), (128, 64), (64, 16)):
            head_layers += [nn.Linear(in_features, out_features), nn.ReLU(), nn.Dropout(dropout)]
        head_layers.append(nn.Linear(16, 1))
        self.classifier = nn.Sequential(*head_layers)

    def forward(self, src, src_lengths):
        """
        Args:
            src: [batch_size, seq_len, line_dim] padded line features.
            src_lengths: [batch_size] number of valid lines per row (expected >= 1,
                since the last valid position is read).

        Returns:
            [batch_size] raw logits.
        """
        batch_size, max_len, _ = src.shape

        # True marks padding positions (index >= the row's valid length)
        positions = torch.arange(max_len, device=src.device)
        padding_mask = positions.expand(batch_size, max_len) >= src_lengths.unsqueeze(1)

        encoded = self.transformer_encoder(
            self.line_embedding(src), src_key_padding_mask=padding_mask
        )

        # Pool by picking the encoder output at each row's final valid position
        rows = torch.arange(encoded.size(0), device=encoded.device)
        pooled = encoded[rows, src_lengths - 1]

        return self.classifier(pooled).squeeze(-1)

    def predict(self, src, src_lengths):
        """
        Convenience method that returns boolean predictions (probability >= 0.5).
        """
        probabilities = torch.sigmoid(self.forward(src, src_lengths))
        return probabilities >= 0.5


class LineSequenceDataset(Dataset):
    """
    Map-style dataset of line-segment sequences read from a whitespace-separated text file.

    Each line of the file describes one scene: obstacle segments given as
    groups of four numbers ``x1 y1 x2 y2``, followed by zero or more queries
    written as ``q x1 y1 x2 y2 label`` (label parsed with ``int``). Every query
    becomes one sample whose sequence is the scene's obstacles followed by the
    query segment.

    Raises:
        ValueError: (from construction) on a truncated record, a non-numeric
            value, or an unexpected token where a ``q`` marker was expected.
            The message includes the file path and 1-based line number.
    """
    def __init__(self, file_path):
        self.file_path = file_path
        self.cache = []    # one (obstacles, queries) tuple per file line
        self.samples = []  # (scene_idx, query_idx) for every query, in file order

        # Read and parse the whole file eagerly during initialization
        self._parse_file()

    @staticmethod
    def _parse_scene(tokens):
        """Split one line's tokens into (obstacles, queries).

        obstacles: list of 4-float tuples; queries: list of ((4 floats), int label).
        """
        n_tokens = len(tokens)

        obstacles = []
        idx = 0
        while idx < n_tokens and tokens[idx] != 'q':
            if idx + 4 > n_tokens:
                raise ValueError(f"truncated obstacle record at token {idx}")
            obstacles.append(tuple(float(tok) for tok in tokens[idx:idx + 4]))
            idx += 4

        queries = []
        while idx < n_tokens:
            # Anything other than a query marker here is malformed. Fail loudly:
            # previously idx never advanced past such a token and the loop hung.
            if tokens[idx] != 'q':
                raise ValueError(f"expected 'q' at token {idx}, got {tokens[idx]!r}")
            if idx + 6 > n_tokens:
                raise ValueError(f"truncated query record at token {idx}")
            coords = tuple(float(tok) for tok in tokens[idx + 1:idx + 5])
            queries.append((coords, int(tokens[idx + 5])))
            idx += 6

        return obstacles, queries

    def _parse_file(self):
        """Read the file, cache every scene, and index each of its queries as a sample."""
        with open(self.file_path, 'r') as file:
            for line_no, raw_line in enumerate(file, start=1):
                try:
                    obstacles, queries = self._parse_scene(raw_line.split())
                except ValueError as err:
                    raise ValueError(f"{self.file_path}:{line_no}: {err}") from err

                scene_idx = len(self.cache)
                self.cache.append((obstacles, queries))
                self.samples.extend((scene_idx, query_idx) for query_idx in range(len(queries)))

    def __len__(self):
        """Total number of samples (queries)."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Return ``(sequence, len(sequence), label)`` where sequence = obstacles + [query]."""
        scene_idx, query_idx = self.samples[idx]
        obstacles, queries = self.cache[scene_idx]
        query, label = queries[query_idx]

        sequence = obstacles + [query]
        return sequence, len(sequence), label
    

def _get_activation_fn(activation):
    """Map an activation name ("relu" or "gelu") to its torch.nn.functional callable.

    Raises:
        ValueError: for any other name.
    """
    for name, fn in (("relu", F.relu), ("gelu", F.gelu)):
        if activation == name:
            return fn
    raise ValueError(f"Activation function {activation} not supported")


def collate_fn(batch):
    """
    Custom collate function for variable length sequences.

    Args:
        batch: iterable of ``(sequence, length, label)`` triples as produced by
            ``LineSequenceDataset.__getitem__``; ``sequence`` is a list of
            equal-length numeric tuples (one per line segment).

    Returns:
        (sequences, lengths, labels):
          - float32 tensor [batch, max_len, line_dim], zero-padded past each length
          - int64 tensor [batch] of sequence lengths
          - float32 tensor [batch] of labels
    """
    sequences, sequence_lengths, labels = zip(*batch)
    max_len = max(sequence_lengths)

    # Infer the per-line feature size from the data instead of hard-coding 4,
    # so other ``line_dim`` values work; fall back to 4 (x1, y1, x2, y2) only
    # when every sequence in the batch is empty.
    line_dim = next((len(seq[0]) for seq in sequences if len(seq) > 0), 4)

    # Fill a single preallocated float32 buffer rather than building one
    # float64 array per sample and copying them all again with np.array.
    padded = np.zeros((len(sequences), max_len, line_dim), dtype=np.float32)
    for row, (seq, seq_len) in enumerate(zip(sequences, sequence_lengths)):
        if seq_len:  # assigning an empty list to an empty slice would fail to broadcast
            padded[row, :seq_len] = seq

    sequences_tensor = torch.from_numpy(padded)
    sequence_lengths_tensor = torch.tensor(sequence_lengths, dtype=torch.long)
    labels_tensor = torch.tensor(labels, dtype=torch.float32)

    return sequences_tensor, sequence_lengths_tensor, labels_tensor


def calculate_metrics(y_true, y_pred):
    """
    Calculate binary classification metrics: accuracy, recall, precision,
    specificity, and F1 score.

    All five values are derived from a single count of the confusion-matrix
    cells instead of five separate scikit-learn calls that each re-validate the
    inputs. This function runs once per batch inside the training loop.

    Args:
        y_true: ground-truth labels (list, numpy array, or tensor on any device).
            A value equal to 1 is positive; anything else counts as negative.
        y_pred: predicted labels with the same length and encoding as ``y_true``.

    Returns:
        dict of floats with keys 'accuracy', 'recall', 'precision',
        'specificity', 'f1_score'. Ratios with a zero denominator are 0.0
        (same as scikit-learn's ``zero_division=0``), except specificity,
        which is 1.0 when there are no negative labels. Empty input yields
        0.0 for every metric.

    Raises:
        ValueError: if ``y_true`` and ``y_pred`` have different lengths.
    """
    def _to_numpy(values):
        # detach() so tensors that still require grad can be converted
        if isinstance(values, torch.Tensor):
            return values.detach().cpu().numpy()
        return np.asarray(values)

    y_true = _to_numpy(y_true).ravel()
    y_pred = _to_numpy(y_pred).ravel()
    if y_true.shape != y_pred.shape:
        raise ValueError(
            f"y_true and y_pred lengths differ: {y_true.size} vs {y_pred.size}"
        )

    total = y_true.size
    if total == 0:
        return {
            'accuracy': 0.0,
            'recall': 0.0,
            'precision': 0.0,
            'specificity': 0.0,
            'f1_score': 0.0,
        }

    true_pos = y_true == 1
    pred_pos = y_pred == 1
    tp = int(np.count_nonzero(true_pos & pred_pos))
    fp = int(np.count_nonzero(~true_pos & pred_pos))
    fn = int(np.count_nonzero(true_pos & ~pred_pos))
    tn = int(np.count_nonzero(~true_pos & ~pred_pos))

    def _ratio(numerator, denominator, undefined):
        return numerator / denominator if denominator > 0 else undefined

    return {
        'accuracy': (tp + tn) / total,
        'recall': _ratio(tp, tp + fn, 0.0),
        'precision': _ratio(tp, tp + fp, 0.0),
        # True negative rate; treated as perfect when there are no negatives
        'specificity': _ratio(tn, tn + fp, 1.0),
        'f1_score': _ratio(2 * tp, 2 * tp + fp + fn, 0.0),
    }


def _run_epoch(model, loader, criterion, device, desc, optimizer=None, scaler=None):
    """
    Run one full pass over ``loader`` and return aggregate metrics.

    Trains (forward, backward, optimizer step) when ``optimizer`` is given;
    otherwise evaluates with gradients disabled. When ``scaler`` is also given,
    the training forward pass runs under CUDA autocast with gradient scaling.

    Returns:
        dict from ``calculate_metrics`` over every sample of the pass, plus
        ``'loss'``: the mean of the per-batch losses.
    """
    training = optimizer is not None
    model.train(training)

    total_loss = 0.0
    pred_chunks, label_chunks = [], []
    # Progress-bar display only: sample-weighted running average of per-batch metrics
    running = {'loss': 0.0, 'acc': 0.0, 'prec': 0.0, 'rec': 0.0, 'spec': 0.0, 'f1': 0.0}
    samples_seen = 0

    progress_bar = tqdm(loader, desc=desc)
    with torch.set_grad_enabled(training):
        for sequences, seq_lengths, labels in progress_bar:
            batch_size = labels.size(0)
            samples_seen += batch_size

            sequences = sequences.to(device, non_blocking=True)
            seq_lengths = seq_lengths.to(device, non_blocking=True)
            labels = labels.to(device, non_blocking=True)

            if not training:
                logits = model(sequences, seq_lengths)
                loss = criterion(logits, labels)
            elif scaler is not None:
                optimizer.zero_grad()
                with torch.amp.autocast('cuda'):
                    logits = model(sequences, seq_lengths)
                    loss = criterion(logits, labels)
                scaler.scale(loss).backward()
                scaler.step(optimizer)
                scaler.update()
            else:
                optimizer.zero_grad()
                logits = model(sequences, seq_lengths)
                loss = criterion(logits, labels)
                loss.backward()
                optimizer.step()

            loss_value = loss.item()
            total_loss += loss_value

            batch_preds = (torch.sigmoid(logits.detach()) >= 0.5).float().cpu().numpy()
            batch_labels = labels.cpu().numpy()
            # Keep whole batch arrays; a list of per-sample numpy scalars costs far more memory
            pred_chunks.append(batch_preds)
            label_chunks.append(batch_labels)

            batch_metrics = calculate_metrics(batch_labels, batch_preds)
            previous_weight = samples_seen - batch_size
            for key, value in (
                ('loss', loss_value),
                ('acc', batch_metrics['accuracy']),
                ('prec', batch_metrics['precision']),
                ('rec', batch_metrics['recall']),
                ('spec', batch_metrics['specificity']),
                ('f1', batch_metrics['f1_score']),
            ):
                running[key] = (running[key] * previous_weight + value * batch_size) / samples_seen
            progress_bar.set_postfix({key: f"{value:.4f}" for key, value in running.items()})

    all_preds = np.concatenate(pred_chunks) if pred_chunks else np.zeros(0, dtype=np.float32)
    all_labels = np.concatenate(label_chunks) if label_chunks else np.zeros(0, dtype=np.float32)
    metrics = calculate_metrics(all_labels, all_preds)
    num_batches = len(loader)
    metrics['loss'] = total_loss / num_batches if num_batches > 0 else 0.0
    return metrics


def _format_epoch_metrics(split, metrics):
    """Render one split's epoch metrics as a single human-readable line."""
    return (f"{split} Loss: {metrics['loss']:.4f}, "
            f"Accuracy: {metrics['accuracy']:.4f}, "
            f"Precision: {metrics['precision']:.4f}, "
            f"Recall: {metrics['recall']:.4f}, "
            f"Specificity: {metrics['specificity']:.4f}, "
            f"F1: {metrics['f1_score']:.4f}")


def train(model, train_loader, val_loader, criterion, optimizer, device, num_epochs=10, patience=3, resume_from=None):
    """
    Train the model and evaluate on the validation set with early stopping.

    Each epoch runs one training pass and one validation pass, then steps a
    ReduceLROnPlateau scheduler on validation F1. The latest state is written
    to ``model.pt``. If validation F1 improved, the state is also written to
    ``best_line_classifier.pt``. Training stops early after ``patience``
    consecutive epochs without improvement. Progress bars show running metrics.

    Args:
        model: the classifier; called as ``model(sequences, seq_lengths)``.
        train_loader / val_loader: loaders yielding (sequences, lengths, labels).
        criterion: loss taking (logits, float labels).
        optimizer: optimizer over ``model``'s parameters.
        device: torch.device; mixed precision is used only when it is CUDA.
        num_epochs: total epoch count (including epochs done before a resume).
        patience: epochs without validation-F1 improvement before stopping.
        resume_from: optional checkpoint path written by this function. If it
            exists, model and optimizer state are restored, plus scheduler,
            scaler, best F1 and metric histories when present. Training then
            continues from the following epoch.

    Returns:
        (train_metrics_history, val_metrics_history): per-epoch metric dicts.
    """
    train_metrics_history = []
    val_metrics_history = []
    best_val_f1 = 0
    best_val_metrics = {'f1_score': 0}
    epochs_without_improvement = 0
    start_epoch = 0
    checkpoint = None

    if resume_from is not None and os.path.exists(resume_from):
        print(f"Loading checkpoint from {resume_from}")
        # NOTE: weights_only=False because checkpoints also hold metric dicts and
        # histories. This unpickles arbitrary objects: only resume from trusted files.
        checkpoint = torch.load(resume_from, map_location=device, weights_only=False)
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        start_epoch = checkpoint['epoch'] + 1  # Start from the next epoch
        if 'best_val_f1' in checkpoint:
            best_val_f1 = checkpoint['best_val_f1']
            best_val_metrics = checkpoint.get('best_val_metrics', {'f1_score': best_val_f1})
        elif 'val_metrics' in checkpoint:
            # Older checkpoints lack 'best_val_f1'. Without this fallback, best F1
            # restarted at 0 and the first resumed epoch overwrote the best
            # checkpoint regardless of its quality.
            best_val_metrics = checkpoint['val_metrics']
            best_val_f1 = best_val_metrics['f1_score']
        train_metrics_history = checkpoint.get('train_metrics_history', train_metrics_history)
        val_metrics_history = checkpoint.get('val_metrics_history', val_metrics_history)
        print(f"Resuming from epoch {start_epoch} with validation F1: {best_val_f1:.4f}")

    # For tracking GPU memory
    if device.type == 'cuda':
        start_gpu_memory = torch.cuda.memory_allocated(device) / (1024 ** 2)  # MB
        print(f"Initial GPU memory usage: {start_gpu_memory:.2f} MB")

    # `verbose=` is deprecated on ReduceLROnPlateau; LR changes are reported below instead.
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='max', factor=0.5, patience=2
    )

    # Mixed precision only on CUDA; on CPU the plain fp32 path is used.
    scaler = torch.amp.GradScaler('cuda') if device.type == 'cuda' else None

    if checkpoint is not None:
        if checkpoint.get('scheduler_state_dict') is not None:
            scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
        if scaler is not None and checkpoint.get('scaler_state_dict') is not None:
            scaler.load_state_dict(checkpoint['scaler_state_dict'])
        checkpoint = None  # release the loaded tensors

    for epoch in range(start_epoch, num_epochs):
        epoch_start_time = time.time()

        train_metrics = _run_epoch(
            model, train_loader, criterion, device,
            desc=f"Epoch {epoch+1}/{num_epochs} [Train]",
            optimizer=optimizer, scaler=scaler,
        )
        train_metrics_history.append(train_metrics)

        val_metrics = _run_epoch(
            model, val_loader, criterion, device,
            desc=f"Epoch {epoch+1}/{num_epochs} [Val]",
        )
        val_metrics_history.append(val_metrics)

        # Update learning rate based on validation F1 score
        lrs_before = [group['lr'] for group in optimizer.param_groups]
        scheduler.step(val_metrics['f1_score'])
        lrs_after = [group['lr'] for group in optimizer.param_groups]
        if lrs_after != lrs_before:
            print(f"Learning rate reduced: {lrs_before} -> {lrs_after}")

        improved = val_metrics['f1_score'] > best_val_f1
        if improved:
            best_val_f1 = val_metrics['f1_score']
            best_val_metrics = val_metrics
            epochs_without_improvement = 0
        else:
            epochs_without_improvement += 1

        # Everything needed to resume exactly where this epoch left off
        checkpoint_state = {
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict(),
            'scaler_state_dict': scaler.state_dict() if scaler is not None else None,
            'val_metrics': val_metrics,
            'best_val_f1': best_val_f1,
            'best_val_metrics': best_val_metrics,
            'train_metrics_history': train_metrics_history,
            'val_metrics_history': val_metrics_history,
        }

        # Always save the latest model so an interrupted run can resume
        torch.save(checkpoint_state, 'model.pt')
        print(f"Saved model at epoch: {epoch}")

        if improved:
            torch.save(checkpoint_state, 'best_line_classifier.pt')
            print(f"Saved best model with validation F1: {best_val_f1:.4f}")

        epoch_time = time.time() - epoch_start_time
        print(f"Epoch {epoch+1}/{num_epochs} - Time: {epoch_time:.2f}s")
        print(_format_epoch_metrics("Train", train_metrics))
        print(_format_epoch_metrics("Val", val_metrics))

        # Track GPU memory usage
        if device.type == 'cuda':
            current_gpu_memory = torch.cuda.memory_allocated(device) / (1024 ** 2)  # MB
            max_gpu_memory = torch.cuda.max_memory_allocated(device) / (1024 ** 2)  # MB
            print(f"GPU Memory: Current {current_gpu_memory:.2f} MB, Max {max_gpu_memory:.2f} MB")
            # Reset peak memory stats for next epoch
            torch.cuda.reset_peak_memory_stats(device)

        # Checked after the summary so the final epoch's metrics are still reported
        if epochs_without_improvement >= patience:
            print(f"Early stopping triggered after {epoch+1} epochs")
            break

    print("Training complete!")
    print("Best validation metrics:")
    for metric, value in best_val_metrics.items():
        print(f"{metric}: {value:.4f}")

    return train_metrics_history, val_metrics_history


def main():
    """
    Script entry point.

    Seeds the RNGs, reports the compute device, builds the train/validation
    loaders from the Kaggle input files, constructs the classifier, and runs
    ``train``, resuming from a fixed checkpoint path.
    """
    # Fixed seeds for reproducibility (torch first, then numpy)
    for seed_fn in (torch.manual_seed, np.random.seed):
        seed_fn(42)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")
    use_cuda = device.type == 'cuda'

    if use_cuda:
        print(f"GPU Device: {torch.cuda.get_device_name(0)}")
        print(f"CUDA Version: {torch.version.cuda}")
        # Let cuDNN auto-tune kernels for the input shapes it encounters
        torch.backends.cudnn.benchmark = True
        print(f"CuDNN Enabled: {torch.backends.cudnn.enabled}")
        print(f"CuDNN Version: {torch.backends.cudnn.version()}")
        torch.cuda.empty_cache()
        print(f"Initial GPU Memory: {torch.cuda.memory_allocated(0)/(1024**2):.2f} MB")

    train_dataset, val_dataset = (
        LineSequenceDataset(data_path)
        for data_path in ("/kaggle/input/data1002/train.txt", "/kaggle/input/data1002/val.txt")
    )
    # test_dataset = LineSequenceDataset()

    # At most 4 loader workers, bounded by the available CPU cores
    num_workers = min(4, os.cpu_count() or 0)
    print(f"Using {num_workers} worker threads for data loading")

    # Shared loader settings; pinned host memory speeds up host-to-GPU copies
    loader_options = {
        'batch_size': 32,
        'collate_fn': collate_fn,
        'num_workers': num_workers,
        'pin_memory': use_cuda,
    }
    train_loader = DataLoader(train_dataset, shuffle=True, **loader_options)
    val_loader = DataLoader(val_dataset, shuffle=False, **loader_options)
    # test_loader = DataLoader(test_dataset, shuffle=False, **loader_options)

    model = LineSequenceClassifier(
        line_dim=4,
        d_model=128,
        nhead=8,
        num_encoder_layers=4,
        dim_feedforward=512,
        dropout=0.1,
    ).to(device)
    print(f"Model has {sum(p.numel() for p in model.parameters()):,} parameters")

    criterion = nn.BCEWithLogitsLoss()
    optimizer = torch.optim.AdamW(model.parameters(), lr=0.001, weight_decay=0.01)

    print("Starting training...")
    train_metrics_history, val_metrics_history = train(
        model, train_loader, val_loader, criterion, optimizer, device,
        num_epochs=2,
        patience=10,
        resume_from="/kaggle/input/model107/pytorch/default/1/best_line_classifier.pt",
    )

    # # Load best model
    # checkpoint = torch.load('best_line_classifier.pt', map_location=device)
    # model.load_state_dict(checkpoint['model_state_dict'])
    # best_epoch = checkpoint['epoch']
    # print(f"Loaded best model from epoch {best_epoch+1}")

    # # Evaluate on test set
    # model.eval()
    # test_preds = []
    # test_labels = []
    # test_loss = 0.0

    # progress_bar = tqdm(test_loader, desc="Evaluating on test set")

    # with torch.no_grad():
    #     for sequences, seq_lengths, labels in progress_bar:
    #         # Move data to device
    #         sequences = sequences.to(device, non_blocking=True)
    #         seq_lengths = seq_lengths.to(device, non_blocking=True)
    #         labels = labels.to(device, non_blocking=True)

    #         # Forward pass
    #         logits = model(sequences, seq_lengths)
    #         loss = criterion(logits, labels)
    #         test_loss += loss.item()

    #         # Store predictions and true labels
    #         preds = (torch.sigmoid(logits) >= 0.5).float()
    #         test_preds.extend(preds.cpu().numpy())
    #         test_labels.extend(labels.cpu().numpy())

    #         # Update progress bar
    #         progress_bar.set_postfix({'loss': f"{loss.item():.4f}"})

    # # Calculate and print test metrics
    # test_metrics = calculate_metrics(test_labels, test_preds)
    # test_metrics['loss'] = test_loss / len(test_loader)
    # print("\nTest metrics:")
    # for metric, value in test_metrics.items():
    #     print(f"{metric}: {value:.4f}")

    # Release cached GPU memory before exiting
    if use_cuda:
        torch.cuda.empty_cache()


if __name__ == "__main__":
    # Run the training script. This `import os` binds `os` as a module global,
    # which `train` and `main` reference.
    import os
    main()

Using device: cuda
GPU Device: Tesla P100-PCIE-16GB
CUDA Version: 12.1
CuDNN Enabled: True
CuDNN Version: 90100
Initial GPU Memory: 41.65 MB


  checkpoint = torch.load(resume_from, map_location=device)


Using 4 worker threads for data loading
Model has 819,809 parameters
Starting training...
Loading checkpoint from /kaggle/input/model107/pytorch/default/1/best_line_classifier.pt
Resuming from epoch 1 with validation F1: 0.0000
Initial GPU memory usage: 41.65 MB


Epoch 2/2 [Train]:   0%|          | 189/136290 [00:31<5:59:10,  6.32it/s, loss=0.4061, acc=0.8191, prec=0.8108, rec=0.7910, spec=0.8415, f1=0.7948]