In [1]:
import torch
import matplotlib.pyplot as plt
import numpy as np
import sys
import os
import utils

import torch.nn as nn
from torch.autograd.variable import Variable

import model.vgg as vgg
from data.data_loader import data_loader, get_deri_loader, get_var_loader
from model.linear import Linear
from utils.essen_plot import plot_contour_trajectory, plot_neuron
from utils.get_loss import eval_loss, get_loss
from utils.para_feature import *
from utils.seed import seed_torch

In [7]:
def save_fig(pltm, fntmp, pdf=True, x_log=True, y_log=True):
    """
    Optionally log-scale the current axes, then save the current figure as PNG (and PDF) and display it.

    The figure is saved BEFORE it is shown. With Jupyter's inline backend, `show()` closes the
    figure, so a `savefig` issued afterwards writes an empty image.

    :param pltm: object providing `tight_layout()`, `savefig(path)` and `show()`; in this notebook
        the `matplotlib.pyplot` module itself is passed
    :param fntmp: output path WITHOUT extension; '.png' (and '.pdf') is appended. Missing parent
        directories are created.
    :param pdf: if True, also save a PDF copy, defaults to True (optional)
    :param x_log: if True, set the x-axis of the current pyplot axes to log scale, defaults to True (optional)
    :param y_log: if True, set the y-axis of the current pyplot axes to log scale, defaults to True (optional)
    """
    # Axis scaling goes through the pyplot state machine (current axes), independent of `pltm`.
    if x_log:
        plt.xscale('log')

    if y_log:
        plt.yscale('log')

    out_dir = os.path.dirname(fntmp)
    if out_dir:
        # Callers pass paths such as '<run>/output/<epoch>'; the sub-directory may not exist yet.
        os.makedirs(out_dir, exist_ok=True)

    pltm.tight_layout()
    pltm.savefig('%s.png' % (fntmp))
    if pdf:
        pltm.savefig("%s.pdf" % (fntmp))
    # Show last: the inline backend closes the figure on show().
    pltm.show()

In [11]:
# Checkpoint epoch whose network output is plotted. It is used for the checkpoint file name,
# the figure title and the output file name. `epoch` must be defined here: it was previously
# never assigned in this notebook, so the cell only ran thanks to leftover kernel state.
epoch = 100000

# One training run per network depth. The hidden-layer widths appear to be encoded in the
# path (100, 100_100_100, ...), consistent with the printed model summaries.
load_dirs = [
    r'/root/zhangzhongwang/data/book/100/0.5/20240307090130730317',
    r'/root/zhangzhongwang/data/book/100_100_100/0.5/20240307090159903066',
    r'/root/zhangzhongwang/data/book/100_100_100_100_100/0.5/20240307090228418033',
    r'/root/zhangzhongwang/data/book/100_100_100_100_100_100_100/0.5/20240307090335319552',
]

out_lst = []

for load_dir in load_dirs:
    # NOTE(review): these checkpoints hold pickled Python objects (e.g. `args`). On torch>=2.6,
    # torch.load defaults to weights_only=True and may refuse them; pass weights_only=False
    # for these trusted local files if that happens.
    result_dict = torch.load(r'%s/result.pth.tar' % (load_dir))
    args = torch.load(r'%s/model/tmp0.pth.tar' % (load_dir))['args']

    para = torch.load(r'%s/model/tmp%s.pth.tar' % (load_dir, epoch))['state_dict'][0]
    # get_network_output is presumably provided by `from utils.para_feature import *`.
    output = get_network_output(args, para)
    out_lst.append(output)

# After the loop, `args` / `load_dir` refer to the LAST run. The train/test data below come
# from it, which assumes every run shares the same dataset (TODO confirm).
path = load_dir
plt.figure()
ax = plt.gca()

plt.scatter(args.train_inputs.detach().cpu().numpy(),
            args.train_targets.detach().cpu().numpy(), c='black', label='True')
test_x = args.test_inputs.detach().cpu().numpy()  # identical for every curve
for i, out in enumerate(out_lst):
    plt.plot(test_x, out.detach().cpu().numpy(), label='Test %s' % (i))
plt.title('output epoch=%s' % (epoch), fontsize=15)
plt.legend(fontsize=18)

fntmp = os.path.join(path, 'output', str(epoch))
os.makedirs(os.path.dirname(fntmp), exist_ok=True)  # '<run>/output' may not exist yet
save_fig(plt, fntmp, pdf=False, x_log=False, y_log=False)

Linear(
  (features): Sequential(
    (0): Linear(in_features=1, out_features=100, bias=True)
    (1): Tanh()
    (2): Linear(in_features=100, out_features=1, bias=False)
  )
)
Linear(
  (features): Sequential(
    (0): Linear(in_features=1, out_features=100, bias=True)
    (1): Tanh()
    (2): Linear(in_features=100, out_features=100, bias=True)
    (3): Tanh()
    (4): Linear(in_features=100, out_features=100, bias=True)
    (5): Tanh()
    (6): Linear(in_features=100, out_features=1, bias=False)
  )
)
Linear(
  (features): Sequential(
    (0): Linear(in_features=1, out_features=100, bias=True)
    (1): Tanh()
    (2): Linear(in_features=100, out_features=100, bias=True)
    (3): Tanh()
    (4): Linear(in_features=100, out_features=100, bias=True)
    (5): Tanh()
    (6): Linear(in_features=100, out_features=100, bias=True)
    (7): Tanh()
    (8): Linear(in_features=100, out_features=100, bias=True)
    (9): Tanh()
    (10): Linear(in_features=100, out_features=1, bias=False)
  )
)


In [17]:
import matplotlib.pyplot as plt
import numpy as np


def format_settings(
        wspace=0.25,
        hspace=0.4,
        left=0.12,
        right=0.9,
        bottom=0.15,
        top=0.95,
        fs=12,
        show_dpi=80,
        save_dpi=300,
        lw=1.5,
        ms=5,
        axlw=1.5,
        major_tick_len=5,
        major_tick_width=1.5,
        major_tick_pad=5,
        minor_tick_len=0,
        minor_tick_width=0,
        minor_tick_pad=5,
        ):
    '''
    Apply a publication-style matplotlib configuration.

    Adjusts the subplot layout of the current figure via `plt.subplots_adjust`. It then writes a
    set of global `plt.rcParams` entries: line width and marker size, font sizes, axes and tick
    geometry, tick direction, legend frame, and display/save DPI. The rcParams changes are global
    and persist for the rest of the session.

    Usage:
        fig = plt.figure(figsize=(12, 4), dpi=300)
        format_settings()
        grid = plt.GridSpec(2, 2)
        ax1 = fig.add_subplot(grid[0, 0])  # top-left panel
        ax2 = fig.add_subplot(grid[0, 1])  # top-right panel
        ax3 = fig.add_subplot(grid[1, :])  # bottom row merged into one panel

    Note:
        The default text sizes and line widths suit a figure width of 12 inches; the height
        may vary. For a different width, scale `fs`, `lw`, `axlw` and the tick sizes accordingly.
        'figure.dpi' is typically read when a figure is created. A figure that already exists
        when this is called therefore keeps its own DPI.

    :param wspace, hspace: horizontal / vertical spacing between subplots
    :param left, right, bottom, top: bounds of the subplot area as figure fractions
    :param fs: font size shared by all text (base, axis labels, titles, tick labels, legend)
    :param show_dpi: value written to rcParams['figure.dpi']
    :param save_dpi: value written to rcParams['savefig.dpi']
    :param lw: default line width
    :param ms: default marker size
    :param axlw: line width of the axes spines
    :param major_tick_len, major_tick_width, major_tick_pad: major tick length / width / label gap
    :param minor_tick_len, minor_tick_width, minor_tick_pad: minor tick length / width / label gap
    '''
    # Line and marker defaults come before the layout call, as they always have.
    plt.rcParams.update({
        'lines.linewidth': lw,
        'lines.markersize': ms,
    })

    # Outer margins and inter-subplot gaps of the current figure.
    plt.subplots_adjust(wspace=wspace, hspace=hspace, left=left, right=right, bottom=bottom, top=top)

    plt.rcParams.update({
        # Text: a single size for every text element.
        'font.size': fs,
        'axes.labelsize': fs,
        'axes.titlesize': fs,
        'xtick.labelsize': fs,
        'ytick.labelsize': fs,
        'legend.fontsize': fs,
        # Axes frame: spine width, with all four spines drawn.
        'axes.linewidth': axlw,
        'axes.spines.top': True,
        'axes.spines.right': True,
        'axes.spines.left': True,
        'axes.spines.bottom': True,
        # Tick mark widths.
        'xtick.major.width': major_tick_width,
        'ytick.major.width': major_tick_width,
        'xtick.minor.width': minor_tick_width,
        'ytick.minor.width': minor_tick_width,
        # Tick mark lengths.
        'xtick.major.size': major_tick_len,
        'ytick.major.size': major_tick_len,
        'xtick.minor.size': minor_tick_len,
        'ytick.minor.size': minor_tick_len,
        # Gap between a tick and its label.
        'xtick.major.pad': major_tick_pad,
        'ytick.major.pad': major_tick_pad,
        'xtick.minor.pad': minor_tick_pad,
        'ytick.minor.pad': minor_tick_pad,
        # Ticks point into the axes.
        'xtick.direction': 'in',
        'ytick.direction': 'in',
        # No ticks on the top and right edges.
        'xtick.top': False,
        'ytick.right': False,
        # Minor ticks hidden.
        'xtick.minor.visible': False,
        'ytick.minor.visible': False,
        # Legend drawn without a frame box.
        'legend.frameon': False,
        # On-screen DPI and default DPI for savefig.
        'figure.dpi': show_dpi,
        'savefig.dpi': save_dpi,
    })


In [26]:
# Final figure: outputs of networks of increasing depth fitted to the same training targets.
# Relies on `out_lst` and `args` from the loading cell and on `format_settings` defined above.
# `args` is from the LAST loaded run (assumes all runs share the same train/test data).
FREQ_PREFER_FIG = r'/root/zhangzhongwang/book/1d_fitting/freq_prefer.png'

fig = plt.figure(figsize=(12, 8))
format_settings(left=0.15, right=0.85, bottom=0.18, top=0.9, fs=32)
ax = plt.gca()

linestyle_lst = ['-', '--', '-.', ':']
# Legend labels, one per entry of out_lst in loading order. They match the Linear-layer counts
# in the printed model summaries (one hidden layer -> 2 Linear layers, etc.).
layer_lst = ['2', '4', '6', '8']
assert len(out_lst) <= min(len(layer_lst), len(linestyle_lst)), (
    'need a label and a linestyle for each of the %d curves' % len(out_lst))

ax.scatter(args.train_inputs.detach().cpu().numpy(),
           args.train_targets.detach().cpu().numpy(), c='black', label='Target', s=100)
test_x = args.test_inputs.detach().cpu().numpy()  # identical for every curve
for out, layers, style in zip(out_lst, layer_lst, linestyle_lst):
    ax.plot(test_x, out.detach().cpu().numpy(),
            label='%s layer' % (layers), linestyle=style, lw=3)
ax.legend(fontsize=30, bbox_to_anchor=(0.11, 1), loc='upper left')

os.makedirs(os.path.dirname(FREQ_PREFER_FIG), exist_ok=True)
fig.savefig(FREQ_PREFER_FIG, dpi=300)