In [None]:
import torch
import torch.nn as nn
import math
import matplotlib.pyplot as plt

class SimpleNN(nn.Module):
    """Single ``nn.Linear(100, 100)`` layer whose weight is re-initialized by name.

    Args:
        init_type: Initialization scheme applied to ``self.fc.weight``. One of
            "Signed Kaiming Constant", "Scaled Signed Kaiming Constant",
            "Kaiming Normal", "Scaled Kaiming Normal", "Kaiming Uniform",
            "Xavier Normal". The default "Signed Constant" is accepted as an
            alias for "Signed Kaiming Constant".
        remain_rate: Fraction of weights assumed to remain (not pruned). Both
            "Scaled ..." schemes use an effective fan-in of
            ``fan_in * remain_rate``, which enlarges the std by
            ``1 / sqrt(remain_rate)``. Must be > 0 for those schemes.

    Raises:
        ValueError: If ``init_type`` is not a known scheme.
    """

    def __init__(self, init_type="Signed Constant", remain_rate=0.5):
        super(SimpleNN, self).__init__()
        self.fc = nn.Linear(100, 100)
        self.remain_rate = remain_rate
        self._init_weight(self.fc.weight, init_type)

    @staticmethod
    def _kaiming_std(weight, fan_scale=1.0):
        """Return ``calculate_gain("relu") / sqrt(fan_in * fan_scale)`` for ``weight``."""
        fan = nn.init._calculate_correct_fan(weight, mode="fan_in") * fan_scale
        return nn.init.calculate_gain("relu") / math.sqrt(fan)

    def _init_weight(self, weight, name="Signed Constant"):
        """Overwrite ``weight`` in place using the scheme ``name``.

        The signed-constant schemes keep only the sign of the weight's current
        values, so ``weight`` must already hold a random nonzero init (as the
        default ``nn.Linear`` init does).

        Raises:
            ValueError: If ``name`` is not a known scheme.
        """
        if name in ("Signed Kaiming Constant", "Signed Constant"):
            # Every weight becomes +/- the Kaiming-normal std.
            std = self._kaiming_std(weight)
            with torch.no_grad():
                weight.copy_(weight.sign() * std)

        elif name == "Scaled Signed Kaiming Constant":
            # Same fan scaling as "Scaled Kaiming Normal": fan_in * remain_rate.
            std = self._kaiming_std(weight, fan_scale=self.remain_rate)
            with torch.no_grad():
                weight.copy_(weight.sign() * std)

        elif name == "Kaiming Normal":
            nn.init.kaiming_normal_(weight, mode="fan_in", nonlinearity="relu")

        elif name == "Scaled Kaiming Normal":
            std = self._kaiming_std(weight, fan_scale=self.remain_rate)
            with torch.no_grad():
                weight.normal_(0, std)

        elif name == "Kaiming Uniform":
            nn.init.kaiming_uniform_(weight, mode="fan_in", nonlinearity="relu")

        elif name == "Xavier Normal":
            nn.init.xavier_normal_(weight)

        else:
            raise ValueError(f"Unknown initialization scheme: {name!r}")

init_methods = [
    "Signed Kaiming Constant", "Scaled Signed Kaiming Constant", "Kaiming Normal",
    "Scaled Kaiming Normal", "Kaiming Uniform", "Xavier Normal",
]

# Seed so the histograms are identical on every Restart & Run All.
torch.manual_seed(0)

fig, axs = plt.subplots(2, 3, figsize=(36, 24))
fig.suptitle("Weight Initialization Methods Comparison", fontsize=42)

# One panel per scheme, filled row by row (2 x 3 grid for 6 schemes).
for ax, method in zip(axs.flat, init_methods):
    model = SimpleNN(init_type=method)
    weights = model.fc.weight.detach().cpu().numpy().ravel()

    ax.hist(weights, bins=30, alpha=0.7)
    ax.set_title(method, fontsize=42)
    ax.set_xlabel("Weight Values", fontsize=28)
    ax.set_ylabel("Frequency", fontsize=28)
    ax.tick_params(axis="both", which="major", labelsize=28)  # tick label font size

# Reserve the top strip of the figure for the suptitle so tight_layout does not
# overlap it with the subplot titles and it stays inside the canvas.
fig.tight_layout(rect=[0, 0, 1, 0.95])
plt.show()


In [None]:
import torch
import torch.nn as nn
import math
import matplotlib.pyplot as plt

class SimpleNN(nn.Module):
    """Single ``nn.Linear(100, 100)`` layer whose weight is re-initialized by name.

    Args:
        init_type: Initialization scheme applied to ``self.fc.weight``. One of
            "signed_constant", "scaled_signed_constant", "kaiming_normal",
            "scaled_kaiming_normal", "kaiming_uniform", "xavier_normal".
        remain_rate: Fraction of weights assumed to remain (not pruned). Both
            "scaled_*" schemes use an effective fan-in of
            ``fan_in * remain_rate``, which enlarges the std by
            ``1 / sqrt(remain_rate)``. Must be > 0 for those schemes.

    Raises:
        ValueError: If ``init_type`` is not a known scheme.
    """

    def __init__(self, init_type="signed_constant", remain_rate=0.5):
        super(SimpleNN, self).__init__()
        self.fc = nn.Linear(100, 100)
        self.remain_rate = remain_rate
        self._init_weight(self.fc.weight, init_type)

    @staticmethod
    def _kaiming_std(weight, fan_scale=1.0):
        """Return ``calculate_gain("relu") / sqrt(fan_in * fan_scale)`` for ``weight``."""
        fan = nn.init._calculate_correct_fan(weight, mode="fan_in") * fan_scale
        return nn.init.calculate_gain("relu") / math.sqrt(fan)

    def _init_weight(self, weight, name="signed_constant"):
        """Overwrite ``weight`` in place using the scheme ``name``.

        The signed-constant schemes keep only the sign of the weight's current
        values, so ``weight`` must already hold a random nonzero init (as the
        default ``nn.Linear`` init does).

        Raises:
            ValueError: If ``name`` is not a known scheme.
        """
        if name == "signed_constant":
            # Every weight becomes +/- the Kaiming-normal std.
            std = self._kaiming_std(weight)
            with torch.no_grad():
                weight.copy_(weight.sign() * std)

        elif name == "scaled_signed_constant":
            # Same fan scaling as "scaled_kaiming_normal": fan_in * remain_rate.
            std = self._kaiming_std(weight, fan_scale=self.remain_rate)
            with torch.no_grad():
                weight.copy_(weight.sign() * std)

        elif name == "kaiming_normal":
            nn.init.kaiming_normal_(weight, mode="fan_in", nonlinearity="relu")

        elif name == "scaled_kaiming_normal":
            std = self._kaiming_std(weight, fan_scale=self.remain_rate)
            with torch.no_grad():
                weight.normal_(0, std)

        elif name == "kaiming_uniform":
            nn.init.kaiming_uniform_(weight, mode="fan_in", nonlinearity="relu")

        elif name == "xavier_normal":
            nn.init.xavier_normal_(weight)

        else:
            raise ValueError(f"Unknown initialization scheme: {name!r}")

# Initialization schemes to compare (names understood by SimpleNN).
init_methods = [
    "signed_constant", "scaled_signed_constant", "kaiming_normal",
    "scaled_kaiming_normal", "kaiming_uniform", "xavier_normal",
]

# Seed so the overlay is identical on every Restart & Run All.
torch.manual_seed(0)

fig, ax = plt.subplots(figsize=(15, 10))

# density=True normalizes each histogram to unit area, so schemes with very
# different spreads can be overlaid on one axis and compared by shape.
for method in init_methods:
    model = SimpleNN(init_type=method)
    weights = model.fc.weight.detach().cpu().numpy().ravel()
    ax.hist(weights, bins=30, alpha=0.3, label=method, density=True)

ax.set_title("Weight Initialization Methods Comparison", fontsize=32)
ax.set_xlabel("Weight values", fontsize=32)
# The bars are probability densities (density=True), not raw counts.
ax.set_ylabel("Density", fontsize=32)
ax.legend(fontsize=20)
ax.tick_params(axis="both", labelsize=20)
plt.show()


In [None]:
import torch
import torch.nn as nn
import math
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
from scipy.stats import entropy

def calculate_kl_divergence(dist1, dist2):
    """Return the KL divergence KL(dist1 || dist2) in nats.

    ``scipy.stats.entropy`` normalizes both inputs to sum to 1, so raw
    histogram counts or densities can be passed directly. The result is
    ``inf`` wherever ``dist2`` is zero but ``dist1`` is positive.
    """
    divergence = entropy(pk=dist1, qk=dist2)
    return divergence

def _kaiming_std(weight, fan_scale=1.0):
    """Return ``calculate_gain("relu") / sqrt(fan_in * fan_scale)`` for ``weight``."""
    fan = nn.init._calculate_correct_fan(weight, mode="fan_in") * fan_scale
    return nn.init.calculate_gain("relu") / math.sqrt(fan)


def init_weight(weight, name="signed_constant", remain_rate=0.9):
    """Overwrite ``weight`` in place using the initialization scheme ``name``.

    Args:
        weight: Parameter to re-initialize (at least 2-D, for fan-in).
            The signed-constant schemes keep only the sign of the current
            values, so the tensor must already hold random nonzero values;
            an uninitialized ``torch.empty`` tensor is NOT sufficient.
        name: One of "signed_constant", "scaled_signed_constant",
            "kaiming_normal", "scaled_kaiming_normal", "kaiming_uniform",
            "xavier_normal".
        remain_rate: Fraction of weights assumed to remain (not pruned). Both
            "scaled_*" schemes use an effective fan-in of
            ``fan_in * remain_rate``. Must be > 0 for those schemes.

    Raises:
        ValueError: If ``name`` is not a known scheme.
    """
    if name == "signed_constant":
        # Every weight becomes +/- the Kaiming-normal std.
        std = _kaiming_std(weight)
        with torch.no_grad():
            weight.copy_(weight.sign() * std)

    elif name == "scaled_signed_constant":
        # Same fan scaling as "scaled_kaiming_normal": fan_in * remain_rate.
        std = _kaiming_std(weight, fan_scale=remain_rate)
        with torch.no_grad():
            weight.copy_(weight.sign() * std)

    elif name == "kaiming_normal":
        nn.init.kaiming_normal_(weight, mode="fan_in", nonlinearity="relu")

    elif name == "scaled_kaiming_normal":
        std = _kaiming_std(weight, fan_scale=remain_rate)
        with torch.no_grad():
            weight.normal_(0, std)

    elif name == "kaiming_uniform":
        nn.init.kaiming_uniform_(weight, mode="fan_in", nonlinearity="relu")

    elif name == "xavier_normal":
        nn.init.xavier_normal_(weight)

    else:
        raise ValueError(f"Unknown initialization scheme: {name!r}")

methods = [
    "signed_constant", "scaled_signed_constant", "kaiming_normal",
    "scaled_kaiming_normal", "kaiming_uniform", "xavier_normal",
]

num_params = (1000, 2000)  # shape of each sampled weight tensor (fan_in = 2000)
num_bins = 10
# Pseudo-count added to every histogram bin. A bin that is empty in Q but not
# in P would otherwise make KL(P || Q) infinite (e.g. the gap between the two
# spikes of a signed-constant init), which the heatmap cannot display.
kl_eps = 1e-10

# Seed so the matrix is identical on every Restart & Run All.
torch.manual_seed(0)

# Draw one weight sample per method up front and reuse it for every pair;
# the diagonal then compares a sample with itself and is exactly 0.
samples = {}
for method in methods:
    # Fill with random values first: the signed-constant schemes keep only the
    # sign of the existing values, and torch.empty() is uninitialized memory.
    weight = nn.Parameter(torch.empty(num_params).uniform_(-1.0, 1.0))
    init_weight(weight, name=method)
    samples[method] = weight.detach().numpy().ravel()

kld_matrix = np.zeros((len(methods), len(methods)))
for i, method1 in enumerate(methods):
    for j, method2 in enumerate(methods):
        samples1, samples2 = samples[method1], samples[method2]
        # Both histograms must share bin edges; otherwise bin k covers a
        # different value range in each and the bin-wise KL is meaningless.
        lo = min(samples1.min(), samples2.min())
        hi = max(samples1.max(), samples2.max())
        edges = np.linspace(lo, hi, num_bins + 1)
        hist1, _ = np.histogram(samples1, bins=edges)
        hist2, _ = np.histogram(samples2, bins=edges)
        # Row i is P (method1), column j is Q (method2): kld_matrix[i, j] = KL(P || Q).
        kld_matrix[i, j] = calculate_kl_divergence(hist1 + kl_eps, hist2 + kl_eps)




In [None]:
# kld_matrix[i, j] = KL(P || Q) with P = methods[i] (rows) and Q = methods[j]
# (columns). KL is asymmetric, so the axis labels state which side is which.
fig, ax = plt.subplots(figsize=(10, 8))
sns.heatmap(
    kld_matrix,
    annot=True,
    fmt=".2f",
    cmap="YlGnBu",
    xticklabels=methods,
    yticklabels=methods,
    ax=ax,
)
ax.set_title("KL Divergence between Weight Initialization Methods")
ax.set_xlabel("Q (reference initialization)")
ax.set_ylabel("P (compared initialization)")
ax.tick_params(axis="x", rotation=45)
fig.tight_layout()
plt.show()