In [1]:
pip install pandas

Note: you may need to restart the kernel to use updated packages.


In [2]:
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

In [3]:
# Run settings handed to train.main(config=...) further down in the notebook.
# NOTE(review): the meaning of each key is defined by the `train` module, which
# is not visible here — confirm units/semantics there before tuning.
CONFIG = dict(
    input_length=5,
    input_dim=1,
    num_classes=10,
    num_hidden=128,
    batch_size=128,
    learning_rate=0.001,
    max_epoch=200,
    max_norm=10,
    data_size=100000,
    portion_train=0.8,
    use_scheduler=False,
)

In [4]:
import matplotlib.pyplot as plt
import pandas as pd
import os

def plot_training_curve(csv_path, file_name, output_dir='../img/Part1/'):
    """Plot train/validation loss and accuracy per epoch and save the figure.

    Reads a CSV log, draws a two-panel figure (loss on the left, accuracy on
    the right), writes it to ``<output_dir>/<file_name>.png`` and shows it.

    Args:
        csv_path: Path to a CSV file containing the columns ``epoch``,
            ``train_loss``, ``train_acc``, ``val_loss`` and ``val_acc``.
        file_name: Base name used both for the panel titles and for the
            saved PNG (without extension).
        output_dir: Directory the PNG is written to; created if missing.
            Defaults to the location the notebook originally hard-coded.

    Raises:
        KeyError: If any of the required columns is absent from the CSV.
    """
    data = pd.read_csv(csv_path)

    # Fail early, before any figure is created, with a message that names
    # every missing column instead of only the first one pandas trips over.
    required = ('epoch', 'train_loss', 'train_acc', 'val_loss', 'val_acc')
    missing = [col for col in required if col not in data.columns]
    if missing:
        raise KeyError(f"Missing column(s) {missing} in {csv_path}")

    epochs = data['epoch']

    # Explicit figure/axes handles avoid relying on pyplot's implicit
    # "current figure", which other notebook cells may have changed.
    fig, axes = plt.subplots(1, 2, figsize=(12, 6))
    try:
        panels = (
            (axes[0], 'Loss', 'train_loss', 'val_loss'),
            (axes[1], 'Accuracy', 'train_acc', 'val_acc'),
        )
        for ax, metric, train_col, val_col in panels:
            ax.plot(epochs, data[train_col], label=f'Train {metric}')
            ax.plot(epochs, data[val_col], label=f'Val {metric}')
            ax.set_title(f'{file_name} {metric}')
            ax.set_xlabel('Epoch')
            ax.set_ylabel(metric)
            ax.legend()
            ax.grid(True)

        fig.tight_layout()

        os.makedirs(output_dir, exist_ok=True)
        path = os.path.join(output_dir, file_name + '.png')
        # savefig raises on failure, so reaching the print means the file
        # was written; no separate existence check is needed.
        fig.savefig(path)
        print(f"Plot successfully saved to {path}")

        plt.show()
    finally:
        # Always release the figure, even if plotting or saving failed.
        plt.close(fig)


## Train data

In [5]:
import train

# Launch training with the notebook-level CONFIG, passing the CSV path below
# as `csv_file`.
# NOTE(review): the file is named t15 while CONFIG['input_length'] is 5 —
# confirm the two are meant to agree before re-running.
csv_file = 'result/t15.csv'
train.main(csv_file=csv_file, config=CONFIG)

ModuleNotFoundError: No module named 'torch'

In [None]:
# Draw the curves logged in result/t15.csv; plot_training_curve titles and
# saves the figure using `file_name`.
file_name = 't15'
csv_path = "result/t15.csv"
plot_training_curve(csv_path=csv_path, file_name=file_name)

In [None]:
from IPython.display import display, Image

# Show the previously saved figure for each T, preceded by a "T=<n>" label.
for seq_len in (5, 10, 20, 30):
    print(f"T={seq_len}")
    display(Image(f"../img/Part1/t{seq_len}.png"))


## Accuracy vs. length

In [None]:
import pandas as pd
import matplotlib.pyplot as plt
import os

# Define the file paths for the CSV files located in the 'result' folder
# CSV logs to compare, one per text length T; T is encoded in the file name
# as 't<T>.csv'.
csv_files = ['./result/t4.csv', './result/t5.csv', './result/t10.csv', './result/t20.csv']


def length_from_path(csv_path):
    """Return the integer T encoded in a file name of the form ``t<T>.csv``.

    Raises:
        ValueError: If the base name is not ``t`` followed only by digits.
    """
    stem = os.path.splitext(os.path.basename(csv_path))[0]
    digits = stem[1:]
    if not stem.startswith('t') or not digits.isdigit():
        raise ValueError(f"Cannot extract text length from file name: {csv_path}")
    return int(digits)


def max_train_accuracy(csv_path):
    """Return the highest ``train_acc`` value in ``csv_path``.

    Prints a warning and returns ``None`` when the file does not exist, has
    no ``train_acc`` column, or the column holds no usable values.
    """
    if not os.path.exists(csv_path):
        print(f"Warning: {csv_path} not found")
        return None

    data = pd.read_csv(csv_path)
    if 'train_acc' not in data.columns:
        print(f"Warning: 'train_acc' column not found in {csv_path}")
        return None

    # The plot is labelled "Max Training Accuracy", so take the maximum
    # (the previous code took the mean, contradicting its own labels).
    best = data['train_acc'].max()
    if pd.isna(best):
        print(f"Warning: no 'train_acc' values in {csv_path}")
        return None
    return float(best)


# Collect (length, accuracy) pairs together so the two lists can never get
# out of step when a file is missing or malformed.
text_lengths = []
max_accuracies = []
for csv_file in csv_files:
    max_acc = max_train_accuracy(csv_file)
    if max_acc is None:
        continue
    text_lengths.append(length_from_path(csv_file))
    max_accuracies.append(max_acc)

print(f"text_lengths: {text_lengths}")
print(f"max_accuracies: {max_accuracies}")

if max_accuracies:
    # Plot whatever data is available; skipped files were already reported.
    plt.figure(figsize=(8, 6))
    plt.plot(text_lengths, max_accuracies, marker='o', linestyle='-', color='b')

    plt.xlabel('Text Length (T)', fontsize=12)
    plt.ylabel('Max Training Accuracy', fontsize=12)
    plt.title('Max Accuracy vs Text Length', fontsize=14)
    plt.grid(True)
    plt.ylim(bottom=0.4)
    plt.xticks(text_lengths)  # one tick per plotted text length
    plt.show()
else:
    print("Error: no accuracy data found, unable to plot.")
