In [16]:
# Enable IPython's autoreload extension; mode 2 re-imports all modules before
# each cell runs, so edits to project code are picked up without a restart.
%load_ext autoreload
%autoreload 2


The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [17]:
from pathlib import Path
import numpy as np
from matplotlib import pyplot as plt
import seaborn as sns
import ruamel.yaml as yaml

from soul_gan.utils.general_utils import ROOT_DIR, DotConfig


In [18]:
# Apply seaborn's default theme to the matplotlib figures created below.
sns.set_theme()

In [19]:
import matplotlib.pyplot as plt

# Font-size tiers (in points) used by the rc overrides below.
SMALL_SIZE, MEDIUM_SIZE, BIGGER_SIZE = 14, 16, 18

# Apply all matplotlib defaults in a single update; each key is the
# "<group>.<setting>" form of the corresponding plt.rc(group, setting=...) call.
plt.rcParams.update({
    'font.size': SMALL_SIZE,          # default text size
    'axes.titlesize': MEDIUM_SIZE,    # axes title
    'axes.labelsize': MEDIUM_SIZE,    # x and y axis labels
    'xtick.labelsize': SMALL_SIZE,    # x tick labels
    'ytick.labelsize': SMALL_SIZE,    # y tick labels
    'legend.fontsize': SMALL_SIZE,    # legend entries
    'lines.linewidth': 3,             # plotted line width
    'figure.titlesize': BIGGER_SIZE,  # figure-level title
})

In [25]:
def _save_and_close(fig, log_path, suffix):
    """Write ``fig`` to ``<log_path>/figs`` as PNG and PDF, then close it."""
    figs_dir = Path(log_path, 'figs')
    # Create the output directory on demand so plotting does not depend on the
    # caller having made it beforehand.
    figs_dir.mkdir(parents=True, exist_ok=True)
    fig.tight_layout()
    for ext in ('png', 'pdf'):
        fig.savefig(Path(figs_dir, f'{log_path.stem}_{suffix}.{ext}'))
    # Close this specific figure rather than whichever one is "current".
    plt.close(fig)


def _plot_curve(arange, values, log_path, suffix, ylabel, title,
                hline=None, hline_label=None):
    """Plot one series against ``arange``, optionally with a dashed reference line."""
    fig, ax = plt.subplots()
    ax.plot(arange, values)
    ax.set_xlabel('Iteration')
    ax.set_ylabel(ylabel)
    if hline is not None:
        ax.axhline(hline, linestyle='--', label=hline_label, color='r')
    ax.set_title(title)
    if hline is not None:
        # Only the reference line carries a label, so a legend is only
        # meaningful when it is drawn.
        ax.legend()
    _save_and_close(fig, log_path, suffix)


def _plot_band(arange, runs, log_path, suffix, ylabel, title, n_traces=5):
    """Plot the row-wise mean of ``runs`` with a +/-1.96*std band and a few raw rows."""
    mean = runs.mean(0)
    std = runs.std(0)
    fig, ax = plt.subplots()
    ax.plot(arange, mean)
    # The band is mean +/- 1.96 * std taken over rows; it is labelled '95% CI'
    # exactly as in the saved figures so far.
    ax.fill_between(arange, mean - 1.96 * std, mean + 1.96 * std,
                    alpha=0.3, label='95% CI')
    # Overlay the first few individual rows, faded, for context.
    for run in runs[:n_traces]:
        ax.plot(arange, run, alpha=0.3)
    ax.set_xlabel('Iteration')
    ax.set_ylabel(ylabel)
    ax.set_title(title)
    ax.legend()
    _save_and_close(fig, log_path, suffix)


def plot_res(log_path, config, arange):
    """Plot the metrics logged in ``log_path`` and save them under ``log_path/figs``.

    Reads ``is_values.txt`` (first column is plotted), ``fid_values.txt`` and
    ``callback_results.txt`` (row 0 is plotted as energy, row 1 as
    discriminator score). If ``weight_norm.txt`` and/or ``out.txt`` exist,
    a mean-with-band figure is produced for each of them as well. Every figure
    is written as both ``<stem>_<name>.png`` and ``<stem>_<name>.pdf``, where
    ``<stem>`` is ``log_path.stem``.

    Parameters
    ----------
    log_path : str or pathlib.Path
        Directory holding the ``*.txt`` logs; a ``figs`` subdirectory is
        created inside it if missing.
    config : mapping-like object
        ``config.thermalize[False]['real_score']`` is read and drawn as the
        horizontal reference line on the discriminator-score figure.
    arange : array-like
        X-axis values shared by all figures; its length must match the
        plotted series.

    Raises
    ------
    OSError
        If one of the three required log files cannot be read
        (raised by ``np.loadtxt``).
    """
    # Accept plain strings too: ``.stem`` below requires a Path.
    log_path = Path(log_path)

    # ndmin=2 keeps the column indexing valid even when the file has one row.
    is_values = np.loadtxt(Path(log_path, 'is_values.txt'), ndmin=2)[:, 0]
    fid_values = np.loadtxt(Path(log_path, 'fid_values.txt'))
    callback_results = np.loadtxt(Path(log_path, 'callback_results.txt'))
    energy_results = callback_results[0]
    dgz_results = callback_results[1]

    _plot_curve(arange, is_values, log_path, 'is', 'IS', 'Inception Score')
    _plot_curve(arange, fid_values, log_path, 'fid', 'FID', 'FID Score')
    _plot_curve(arange, energy_results, log_path, 'energy', r'$U(z)$', 'Energy')
    _plot_curve(
        arange, dgz_results, log_path, 'dgz', r'$d(G(z))$', 'Discriminator scores',
        hline=config.thermalize[False]['real_score'],
        hline_label='avg real score',
    )

    # Optional logs: each row is treated as one series over ``arange``.
    weight_norm_path = Path(log_path, 'weight_norm.txt')
    if weight_norm_path.exists():
        # ndmin=2 so a single-row file is still reduced over axis 0.
        weight_norms = np.loadtxt(weight_norm_path, ndmin=2)
        _plot_band(arange, weight_norms, log_path, 'weight',
                   r'$\|\theta\|_2$', 'Weight norm')

    out_path = Path(log_path, 'out.txt')
    if out_path.exists():
        outs = np.loadtxt(out_path, ndmin=2)
        _plot_band(arange, outs, log_path, 'out', r'$F(x)$', 'F(x)')

In [26]:
# Choose which feature's logs to plot; the commented lines are the alternatives.
# feature = 'dumb'
# feature = 'discriminator'
feature = 'cluster'

# Logs for the chosen feature are read from <ROOT_DIR>/log/<feature>_feature.
logdir = Path(ROOT_DIR, 'log', f'{feature}_feature')

In [27]:
# Every entry under ``logdir`` is treated as one run; sorted for a stable order.
logs = sorted(logdir.glob('*'))

for gan_logpath in logs:
    # Skip test runs and stray files that are not run directories.
    if 'test' in gan_logpath.stem or not gan_logpath.is_dir():
        continue

    # Pick the two configs by name instead of by glob position: glob order is
    # arbitrary, and a run may hold fewer or more than two YAML files.
    configs = sorted(gan_logpath.glob('*.yml'))
    gan_config_path = next((p for p in configs if p.stem == 'gan_config'), None)
    config_path = next((p for p in configs if p.stem != 'gan_config'), None)
    if gan_config_path is None or config_path is None:
        print(f'skip {gan_logpath.name}: need gan_config.yml and a run config')
        continue
    print(config_path.stem)

    Path(gan_logpath, 'figs').mkdir(exist_ok=True)

    # NOTE(review): round_trip_load appears to be deprecated in recent
    # ruamel.yaml releases -- confirm against the pinned version.
    # Context managers make sure the YAML files are closed after parsing.
    with config_path.open('r') as config_file:
        config = DotConfig(yaml.round_trip_load(config_file))
    n_steps = config.n_steps
    every = config.every

    with gan_config_path.open('r') as gan_config_file:
        gan_config = DotConfig(yaml.round_trip_load(gan_config_file)['gan_config'])

    # X-axis: one point every ``every`` steps, from 0 up to ``n_steps`` inclusive.
    arange = np.arange(0, n_steps + 1, every)
    try:
        plot_res(gan_logpath, gan_config, arange)
    except Exception as exc:
        # Best effort: keep going with the remaining runs, but report which
        # run failed and why. KeyboardInterrupt is no longer swallowed.
        print(f'fail: {gan_logpath.name}: {type(exc).__name__}: {exc}')


dcgan-cluster
