In [None]:
import re
import pandas as pd
import matplotlib.pyplot as plt
from datetime import datetime

def read_log_file(file_path):
    """
    Load a log file and hand back its full text.

    Args:
        file_path (str): Location of the log file to read.

    Returns:
        str: Everything in the file, read in text mode with the platform's
        default encoding.
    """
    with open(file_path, mode='r') as log_handle:
        log_text = log_handle.read()
    return log_text

def split_into_sections(content):
    """
    Split raw log text into one chunk per "Batch[<n>]dataset[<m>]" marker.

    Args:
        content (str): Full text of a log file.

    Returns:
        list[str]: Whitespace-stripped sections, each one beginning with a
        "Batch[<n>]dataset[<m>]" marker. Text that precedes the first marker
        (a header / preamble) is discarded and empty chunks are dropped.
    """
    marker = r'Batch\[\d+\]dataset\[\d+\]'

    # The zero-width lookahead splits *before* every marker, so the marker
    # stays attached to the section it introduces.
    chunks = re.split(f'(?={marker})', content)
    sections = [chunk.strip() for chunk in chunks if chunk.strip()]

    # Only the first chunk can be a header. Test it against the real marker
    # rather than the bare word "Batch": a preamble such as "Batch size: 8"
    # would otherwise be kept and counted as a training section.
    if sections and not re.match(marker, sections[0]):
        sections = sections[1:]

    return sections


def extract_epoch_data(section):
    """
    Parse the per-epoch metric lines of one log section.

    A line is recognised when it has the form
    "Epoch <n>..., loss: <x>, lr: <y>, Post active train Correct <c>,
    incorrect <i>, loss: <z>".

    Args:
        section (str): Text of a single log section.

    Returns:
        pd.DataFrame: One row per matched line with columns
        ['epoch', 'training_loss', 'learning_rate', 'correct', 'incorrect',
        'post_train_loss']. When nothing matches, the result is an empty
        DataFrame with no columns.
    """
    # One well-formed decimal with an optional exponent. Using the same token
    # for every float field means a loss printed as "9.5e-05" is parsed instead
    # of silently dropping the whole epoch, and punctuation that follows the
    # last number (e.g. a sentence-ending ".") is not swallowed into the
    # capture where it would make float() raise.
    number = r'(?:\d+\.?\d*|\.\d+)(?:[eE][+-]?\d+)?'
    epoch_pattern = (
        rf'Epoch (\d+).*, loss: ({number}), lr: ({number}), '
        rf'Post active train Correct (\d+), incorrect (\d+), loss: ({number})'
    )

    records = [
        {
            'epoch': int(epoch),
            'training_loss': float(train_loss),
            'learning_rate': float(lr),
            'correct': int(correct),
            'incorrect': int(incorrect),
            'post_train_loss': float(post_loss),
        }
        for epoch, train_loss, lr, correct, incorrect, post_loss
        in re.findall(epoch_pattern, section)
    ]
    return pd.DataFrame(records)

def extract_rank_data(section):
    """
    Extract rank, vote weight, and diff information from a log section.

    A line is recognised when it has the form
    "Rank: <r>, Vote weight: <w>,... Diff: [<tuples>]", where each tuple
    starts with two integers "(<position>, <value>".

    Args:
        section (str): Log section text containing rank information

    Returns:
        pd.DataFrame: One row per matched line with columns
        ['rank', 'vote_weight', 'diff_positions', 'diff_values'], where the
        last two hold parallel lists of ints (both empty when the Diff list
        is empty). When no line matches, the result is an empty DataFrame
        with no columns.
    """
    rank_pattern = r'Rank: (\d+), Vote weight: ([\d.]+),.* Diff: \[(.*?)\]'

    # `[^)]*` keeps every match inside a single "(pos, val, ...)" tuple. A
    # greedy `.*` here would run to the last ")" of the Diff list, so only the
    # first tuple of each line would ever be extracted.
    diff_item_pattern = r'\((\d+), (\d+)[^)]*\)'

    data = []
    for rank_str, weight_str, diff_str in re.findall(rank_pattern, section):
        # An empty Diff list simply yields no items.
        diff_items = re.findall(diff_item_pattern, diff_str)
        data.append({
            'rank': int(rank_str),
            'vote_weight': float(weight_str),
            'diff_positions': [int(pos) for pos, _ in diff_items],
            'diff_values': [int(val) for _, val in diff_items],
        })

    return pd.DataFrame(data)

def analyze_section(section, section_num):
    """
    Classify one log section and, for a regular section, collect its metrics.

    Args:
        section (str): Text of a single log section.
        section_num (int): Number to record for this section in the result.

    Returns:
        dict | None:
            * {'correct_and_skip': True} when the text contains
              'skip adaptive training';
            * {'oom': True} when the text contains 'OutOfMemoryError!';
            * None when no epoch lines could be parsed;
            * otherwise a dict with keys 'section_num', 'batch_num',
              'batch_shape', 'epoch_data' (with an added 'accuracy' column)
              and 'rank_data'.
    """
    # Marker strings take priority over any parsing; the skip marker is
    # checked first.
    if 'skip adaptive training' in section:
        return {'correct_and_skip': True}
    if 'OutOfMemoryError!' in section:
        return {'oom': True}

    # Batch header, e.g. "BATCH NO.<n>, shape: [<dims>]"; both fields are kept
    # as strings and fall back to "Unknown" when the header is absent.
    header_match = re.search(r'BATCH NO\.(\d+), shape: \[([^\]]+)\]', section)
    if header_match is None:
        batch_num = batch_shape = "Unknown"
    else:
        batch_num, batch_shape = header_match.group(1), header_match.group(2)

    epoch_df = extract_epoch_data(section)
    rank_df = extract_rank_data(section)

    # A section without epoch lines has nothing to report.
    if epoch_df.empty:
        return None

    # Share of correct predictions among all predictions for each epoch.
    attempts = epoch_df['correct'] + epoch_df['incorrect']
    epoch_df['accuracy'] = epoch_df['correct'] / attempts

    result = {}
    result['section_num'] = section_num
    result['batch_num'] = batch_num
    result['batch_shape'] = batch_shape
    result['epoch_data'] = epoch_df
    result['rank_data'] = rank_df
    return result

def plot_section_metrics(section_data):
    """
    Plot loss, accuracy and learning-rate curves for one analysed section.

    Args:
        section_data (dict | None): A value returned by ``analyze_section``.
            Nothing is drawn when it is ``None``, when it has no
            'epoch_data' entry (the skip / out-of-memory marker dicts), or
            when the epoch DataFrame is empty.

    Returns:
        None. The figure is displayed with ``plt.show()``.
    """
    # Marker dicts such as {'oom': True} carry no 'epoch_data'; use .get so
    # they are skipped instead of raising KeyError.
    epoch_df = section_data.get('epoch_data') if section_data else None
    if epoch_df is None or epoch_df.empty:
        return

    fig, (loss_ax, acc_ax, lr_ax) = plt.subplots(3, 1, figsize=(12, 12))

    # Training loss and post-train loss share the first panel.
    loss_ax.plot(epoch_df['epoch'], epoch_df['training_loss'], label='Training Loss')
    loss_ax.plot(epoch_df['epoch'], epoch_df['post_train_loss'], label='Post-train Loss')
    loss_ax.set_title(f"Section {section_data['section_num']} (Batch {section_data['batch_num']}) - Loss")
    loss_ax.set_xlabel('Epoch')
    loss_ax.set_ylabel('Loss')
    loss_ax.legend()
    loss_ax.grid(True)

    acc_ax.plot(epoch_df['epoch'], epoch_df['accuracy'])
    acc_ax.set_title('Accuracy')
    acc_ax.set_xlabel('Epoch')
    acc_ax.set_ylabel('Accuracy')
    acc_ax.grid(True)

    lr_ax.plot(epoch_df['epoch'], epoch_df['learning_rate'])
    lr_ax.set_title('Learning Rate')
    lr_ax.set_xlabel('Epoch')
    lr_ax.set_ylabel('Learning Rate')
    lr_ax.grid(True)

    plt.tight_layout()
    plt.show()

def analyze_content(content):
    """
    Summarise a whole training log and print aggregate counts.

    The log is split into sections, each section is analysed, and the
    following tallies are printed on one line:

    * sections where at least one rank has an empty diff list;
    * sections where a rank with ``rank <= 1`` has an empty diff list
      (also as a percentage of all sections);
    * sections flagged 'correct_and_skip' (also as a percentage);
    * sections flagged 'oom';
    * the total number of sections.

    Args:
        content (str): Full text of a log file.

    Returns:
        None. Results are written to stdout.
    """
    sections = split_into_sections(content)
    total = len(sections)

    print(f"Found {total} training sections")

    has_right_answer_count = 0
    has_right_answer_exist_in_top_2_ranks = 0
    correct_and_skip_count = 0
    oom_count = 0

    for i, section in enumerate(sections):
        section_data = analyze_section(section, i + 1)
        if not section_data:
            # No epoch lines could be parsed for this section.
            continue

        if 'correct_and_skip' in section_data:
            correct_and_skip_count += 1
        elif 'oom' in section_data:
            oom_count += 1
        elif 'rank_data' in section_data:
            rank_df = section_data['rank_data']
            # An empty rank frame has no columns at all, hence the guard.
            if 'diff_values' in rank_df:
                no_diff = rank_df['diff_values'].apply(len) == 0
                if no_diff.any():
                    has_right_answer_count += 1

                # NOTE(review): `rank <= 1` covers the top two ranks only if
                # ranks are 0-based — confirm against the log format.
                if no_diff[rank_df['rank'] <= 1].any():
                    has_right_answer_exist_in_top_2_ranks += 1

    def percent(count):
        # A log without any section would otherwise raise ZeroDivisionError.
        return count / total * 100 if total else 0.0

    print('has_right_answer_count', has_right_answer_count, 'has_right_answer_exist_in_top_2_ranks', has_right_answer_exist_in_top_2_ranks, percent(has_right_answer_exist_in_top_2_ranks), 'correct_and_skip_count', correct_and_skip_count, percent(correct_and_skip_count), 'oom_count', oom_count, 'over', total)

In [None]:
# Summarise each "meta_training_2500_*" log in turn.
# Alternative input for the first entry: "../meta_training_c15.log"
for file_path in (
    "../meta_training_2500_c15.log",
    "../meta_training_2500_c29.log",  # 2500
    "../meta_training_2500_c30.log",  # 2500
):
    analyze_content(read_log_file(file_path))

In [None]:
# Summarise each "meta_training_c15_*" log in turn.
for file_path in (
    "../meta_training_c15_b15.log",
    "../meta_training_c15_b8_complete.log",
    "../meta_training_c15_b8_complete[:234].log",
):
    analyze_content(read_log_file(file_path))

In [None]:
# Summarise the "fv_c29" log.
file_path = "../meta_training_fv_c29.log"
analyze_content(content=read_log_file(file_path=file_path))

In [None]:
# NOTE(review): this analyses the same "fv_c29" log as the preceding cell;
# confirm whether the repeat is intentional.
file_path = "../meta_training_fv_c29.log"
analyze_content(content=read_log_file(file_path=file_path))

In [None]:
# Summarise the two "[:156]" logs in turn.
for file_path in (
    "../meta_training_2500_c30[:156].log",
    "../meta_training_bark[:156].log",
):
    analyze_content(read_log_file(file_path))

In [None]:
# Summarise the "reverse_aug" log; the "[:86]" entry is currently disabled.
for file_path in (
    # "../meta_training_2500_c30[:86].log",
    "../meta_training_reverse_aug.log",
):
    analyze_content(read_log_file(file_path))