In [1]:
from utils.metrics import plot_cost, l2_loss, l1_loss


  _np_qint8 = np.dtype([("qint8", np.int8, 1)])
  _np_quint8 = np.dtype([("quint8", np.uint8, 1)])
  _np_qint16 = np.dtype([("qint16", np.int16, 1)])
  _np_quint16 = np.dtype([("quint16", np.uint16, 1)])
  _np_qint32 = np.dtype([("qint32", np.int32, 1)])
  np_resource = np.dtype([("resource", np.ubyte, 1)])


In [9]:
def extract_loss_values(filepath):
    """
    Read a training log and extract the per-epoch training and validation losses.

    A line is treated as an epoch record when it contains both a ``Train Loss:``
    and a ``Valid Loss:`` label, each followed by a number, e.g.::

        Epoch 12 | Train Loss: 0.1234, Valid Loss: 0.4567

    The text after each value does not matter (comma, space, ``|``, end of line).
    Plain decimals, scientific notation (``1e-3``), ``nan`` and ``inf`` are accepted.
    Lines that mention both labels but do not carry two parseable numbers (e.g. a
    table header) are skipped and counted. A single malformed line therefore no
    longer discards every value parsed from the file.

    Parameters:
    filepath (str or path-like): Path to the log file containing loss values

    Returns:
    tuple: (train_loss_values, valid_loss_values) as lists of floats in file
    order. Both lists always have the same length. Both are empty if the file
    does not exist or cannot be read/decoded (an error message is printed).
    """
    import re

    # A float literal as typically written by training loggers: 0.5, .5, 1e-3, -2.0, nan, inf.
    number = r"([-+]?(?:\d+\.?\d*|\.\d+)(?:[eE][-+]?\d+)?|[-+]?(?i:nan|inf))"
    train_re = re.compile(r"Train Loss:\s*" + number)
    valid_re = re.compile(r"Valid Loss:\s*" + number)

    train_loss_values = []
    valid_loss_values = []
    skipped = 0

    try:
        with open(filepath, 'r') as file:
            for line in file:
                # Only lines mentioning both labels are candidate epoch records
                # (this also skips blank lines).
                if "Train Loss" not in line or "Valid Loss" not in line:
                    continue

                train_match = train_re.search(line)
                valid_match = valid_re.search(line)
                if train_match is None or valid_match is None:
                    skipped += 1
                    continue

                # Both values are parsed before either is stored, so the two lists
                # can never drift out of alignment.
                train_loss_values.append(float(train_match.group(1)))
                valid_loss_values.append(float(valid_match.group(1)))
    except FileNotFoundError:
        print(f"Error: File {filepath} not found")
        return [], []
    except (OSError, UnicodeDecodeError) as e:
        print(f"Error reading file {filepath}: {e}")
        return [], []

    if skipped:
        print(f"Warning: skipped {skipped} line(s) mentioning both losses without two parseable values")
    print(f"Extracted {len(train_loss_values)} training loss values and {len(valid_loss_values)} validation loss values")
    return train_loss_values, valid_loss_values

In [11]:
# Load the per-epoch training/validation losses recorded in the SUPPORT model's log.
data = 'model_support.log'
train_loss, valid_loss = extract_loss_values(data)
# extract_loss_values reports a missing/unreadable log only by printing and returning
# empty lists; fail here instead of silently plotting an empty curve further down.
assert train_loss and len(train_loss) == len(valid_loss), f"no usable loss values parsed from {data}"


Extracted 400 training loss values and 400 validation loss values


In [18]:
# call the plot functions
plot_cost(training=train_loss, validation=valid_loss, model="SUPPORT", name="loss",
                  epochs=400, best_epoch=171)

In [19]:
# load in the data required
data = 'model_' + 'flchain' + '.log'
train_loss, valid_loss = extract_loss_values(data)

Extracted 600 training loss values and 600 validation loss values


In [20]:
# call the plot functions
plot_cost(training=train_loss, validation=valid_loss, model="Flchain", name="loss",
                  epochs=600, best_epoch=432)