In [2]:
import matplotlib.pyplot as plt
import random
import tqdm
import torch
import torchvision
import torchinfo

In [3]:
# Compute device: the first CUDA GPU when PyTorch reports one, otherwise the CPU.
if torch.cuda.is_available():
    DEVICE = torch.device("cuda:0")
else:
    DEVICE = torch.device("cpu")

# Preprocessing applied to every image: convert to a tensor, then normalize the
# single channel with mean 0.1307 and std 0.3081
# (presumably the usual MNIST statistics -- TODO confirm).
IMAGE_TRANSFORM = torchvision.transforms.Compose(
    [
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize(mean=(0.1307,), std=(0.3081,)),
    ]
)

In [4]:
# Load the MNIST train split first, then the test split, from ./data
# (download=True is passed for both). IMAGE_TRANSFORM is attached to each dataset.
train_mnist_data, test_mnist_data = (
    torchvision.datasets.MNIST(root="data",
                               train=is_train_split,
                               transform=IMAGE_TRANSFORM,
                               download=True)
    for is_train_split in (True, False)
)

100%|██████████| 9.91M/9.91M [00:00<00:00, 17.6MB/s]
100%|██████████| 28.9k/28.9k [00:00<00:00, 486kB/s]
100%|██████████| 1.65M/1.65M [00:00<00:00, 4.38MB/s]
100%|██████████| 4.54k/4.54k [00:00<00:00, 3.81MB/s]


In [5]:
# Inspect the training set and its first sample; one value is printed per line.
print(
    type(train_mnist_data),        # behaves like a PyTorch Dataset
    len(train_mnist_data),         # (data1, data2, ..., data60000)
    type(train_mnist_data[0]),
    len(train_mnist_data[0]),      # data1 is a tuple with two elements
    type(train_mnist_data[0][0]),
    train_mnist_data[0][0].shape,  # first element of data1: the image tensor
    type(train_mnist_data[0][1]),
    train_mnist_data[0][1],        # second element of data1: the ground-truth digit
    sep="\n",
)

<class 'torchvision.datasets.mnist.MNIST'>
60000
<class 'tuple'>
2
<class 'torch.Tensor'>
torch.Size([1, 28, 28])
<class 'int'>
5


In [6]:
# num = 50
# fig = plt.figure(figsize=(16, ((num//10)+1)*1.25))
# for i in range(50):
#     fig.add_subplot(num//10, 10, i+1)
#     plt.imshow(X=train_mnist_data[i][0][0], cmap="gray")
#     plt.axis("off")
# plt.subplots_adjust(wspace=0.25, hspace=0.25)
# plt.show()

In [7]:
# Mini-batch loaders, 100 samples per batch.
# Training batches are reshuffled on every pass. The test loader keeps the
# dataset order: shuffling does not change the accuracy, and a fixed order
# makes the per-digit error report (and the tie-breaking of max() over that
# report) reproducible from run to run.
train_mnist_data_dataloader = torch.utils.data.DataLoader(dataset=train_mnist_data,
                                                          batch_size=100,
                                                          shuffle=True)
test_mnist_data_dataloader = torch.utils.data.DataLoader(dataset=test_mnist_data,
                                                         batch_size=100,
                                                         shuffle=False)

In [8]:
class MyNet(torch.nn.Module):
    """Fully connected classifier: 784 -> 1000 -> 2000 -> 10.

    ReLU is applied after the first two linear layers; the output of the last
    layer is returned as-is (no activation).
    """

    def __init__(self):
        super().__init__()
        # Layers are created in this order so parameter initialization
        # consumes the random generator in a fixed sequence.
        self.layer_1 = torch.nn.Linear(28 * 28, 1000)
        self.layer_2 = torch.nn.Linear(1000, 2000)
        self.layer_3 = torch.nn.Linear(2000, 10)

    def forward(self, in_data):
        """Map a batch of flattened images to 10 scores per sample.

        Args:
            in_data: tensor whose last dimension is 784 (28 * 28).

        Returns:
            Tensor with last dimension 10 (output of ``layer_3``).
        """
        hidden = torch.relu(self.layer_1(in_data))
        hidden = torch.relu(self.layer_2(hidden))
        return self.layer_3(hidden)

In [9]:
# Build the network, then move its parameters to the selected device.
nn_model = MyNet()
nn_model = nn_model.to(DEVICE)

In [10]:
# Per-layer summary for a batch of 100 flattened 28x28 images.
torchinfo.summary(nn_model, input_size=(100, 28 * 28))

Layer (type:depth-idx)                   Output Shape              Param #
MyNet                                    [100, 10]                 --
├─Linear: 1-1                            [100, 1000]               785,000
├─Linear: 1-2                            [100, 2000]               2,002,000
├─Linear: 1-3                            [100, 10]                 20,010
Total params: 2,807,010
Trainable params: 2,807,010
Non-trainable params: 0
Total mult-adds (Units.MEGABYTES): 280.70
Input size (MB): 0.31
Forward/backward pass size (MB): 2.41
Params size (MB): 11.23
Estimated Total Size (MB): 13.95

In [11]:
# Train the base model for one pass over the full training set.
optimizer = torch.optim.Adam(params=nn_model.parameters(),
                             lr=0.001,
                             betas=(0.9, 0.999))
loss_f = torch.nn.CrossEntropyLoss()
nn_model.train()
# tqdm takes the total number of mini-batches from len() of the dataloader.
for image_minibatch, label_minibatch in tqdm.tqdm(train_mnist_data_dataloader):
    image_minibatch = image_minibatch.to(device=DEVICE)
    label_minibatch = label_minibatch.to(device=DEVICE)
    # Read the batch size from the data so the shape checks below stay valid
    # for any DataLoader batch_size, including a shorter final mini-batch.
    batch_size = image_minibatch.shape[0]
    # Grayscale images: drop the channel axis of the tensor.
    image_minibatch_drop_channel = image_minibatch[:, 0]
    # Flatten height x width into one dimension, as the Linear layers expect.
    image_minibatch_drop_channel_for_linear = image_minibatch_drop_channel.view(-1, 28*28)
    output = nn_model(in_data=image_minibatch_drop_channel_for_linear)
    assert output.shape==(batch_size, 10), "推論結果は想定する次元数ではない"
    # CrossEntropyLoss is given the labels as a 1-D tensor of class indices.
    assert label_minibatch.shape==(batch_size,), "正解ラベルは想定する次元数ではない"
    loss = loss_f(input=output,
                  target=label_minibatch)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

100%|██████████| 600/600 [01:00<00:00,  9.90it/s]


In [12]:
# Evaluate the base model on the test set: overall accuracy plus, for every
# ground-truth digit, the number of misclassified test images.
nn_model.eval()
num_test_data = len(test_mnist_data)
num_correct_total = 0
num_wrong_total_per_digit = {}  # ground-truth digit -> number of wrong predictions
with torch.no_grad():
    for image_minibatch, label_minibatch in tqdm.tqdm(test_mnist_data_dataloader):
        image_minibatch = image_minibatch.to(device=DEVICE)
        label_minibatch = label_minibatch.to(device=DEVICE)
        # Read the batch size from the data so the shape checks hold for any
        # DataLoader batch_size, including a shorter final mini-batch.
        batch_size = image_minibatch.shape[0]
        # Drop the channel axis and flatten each image for the Linear layers.
        image_minibatch_for_linear = image_minibatch[:, 0].view(-1, 28*28)
        output = nn_model(in_data=image_minibatch_for_linear)
        assert output.shape==(batch_size, 10), "推論結果は想定する次元ではない"
        assert label_minibatch.shape==(batch_size,), "正解ラベルは想定する次元ではない"
        predict_index = output.argmax(dim=1,
                                      keepdim=True)
        assert predict_index.shape==(batch_size, 1),  "推論結果のargmaxが想定する次元ではない"
        # Compare the whole mini-batch at once instead of calling .item() on
        # every sample (each call forces a device-to-host transfer).
        is_correct = predict_index[:, 0] == label_minibatch
        num_correct_total = num_correct_total + int(is_correct.sum().item())
        # tolist() keeps the in-batch order, so digits enter the dict in the
        # same order as a per-sample loop would insert them.
        for answer in label_minibatch[~is_correct].tolist():
            num_wrong_total_per_digit[answer] = num_wrong_total_per_digit.get(answer, 0) + 1
accuracy = num_correct_total / num_test_data
print("正解率：{a}".format(a=accuracy))
for digit, num_wrong in num_wrong_total_per_digit.items():
    print("画像{a}で予測が外れた個数：{b}".format(a=digit, b=num_wrong))

100%|██████████| 100/100 [00:03<00:00, 26.39it/s]

正解率：0.9647
画像9で予測が外れた個数：51
画像3で予測が外れた個数：26
画像2で予測が外れた個数：58
画像7で予測が外れた個数：41
画像8で予測が外れた個数：15
画像5で予測が外れた個数：37
画像4で予測が外れた個数：48
画像6で予測が外れた個数：17
画像0で予測が外れた個数：42
画像1で予測が外れた個数：18





In [13]:
# Pick the ground-truth label that was misclassified most often.
# max() returns the first maximal entry in dict order, so ties resolve to the
# digit that was inserted into the dict first.
max_num_wrong_digit, max_num_wrong = max(num_wrong_total_per_digit.items(),
                                         key=lambda digit_and_count: digit_and_count[1])
print("{a}：{b}個".format(a=max_num_wrong_digit, b=max_num_wrong))

2：58個


$$
h = W_0x + \Delta Wx = W_0x + BAx
$$

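- $h$: LoRA適用後の層の出力（元の重みによる出力$W_0x$と、重みの更新量による出力$\Delta Wx$の和）
- $W_0$: 学習済みの元の重み（$d \times k$の行列。LoRAの学習中は更新しない）
- $\Delta W$: モデルの重みの更新量
- $BA$: 低ランク分解されたモデルの重みの更新量（$\Delta W = BA$）
	- $B$: $d \times r$の行列
	- $A$: $r \times k$の行列
	- ランク$r$を小さく設定し近似することで、学習対象のパラメータを大幅に削減できる
	- $A$はランダムなガウシアンノイズで初期化
	- $B$はゼロで初期化（学習開始時は$\Delta W = BA = 0$となり、元のモデルと同じ出力になる）
	- $\Delta W$を$\frac{\alpha}{r}$でスケール
	- $\alpha$は定数で、学習率と似た役割を持ち学習を安定化させる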

In [14]:
# LoRA definition
class MyLoRA(torch.nn.Module):
    """Low-rank additive update for a 2-D weight (usable as a parametrization).

    ``forward`` returns ``original_weight + (lora_B @ lora_A) * (alpha / rank)``.
    ``lora_A`` is drawn from N(0, 1) and ``lora_B`` starts at zero, so a freshly
    built instance returns the weight unchanged.

    Args:
        pre_layer_dim: number of rows of the weight being adapted. Note that
            ``add_LoRA`` passes ``weight.shape[0]`` here.
        post_layer_dim: number of columns of the weight being adapted
            (``add_LoRA`` passes ``weight.shape[1]``).
        rank: inner dimension ``r`` of the factorization.
        alpha: scaling constant; the update is multiplied by ``alpha / rank``.
        device: device on which ``lora_A`` / ``lora_B`` are allocated. ``None``
            (the default) keeps the previous behavior of using the
            notebook-level ``DEVICE``.

    Attributes:
        lora_enable_flag: when falsy, ``forward`` returns the weight untouched.
    """

    def __init__(self, pre_layer_dim, post_layer_dim, rank, alpha, device=None):
        super().__init__()
        if device is None:
            device = DEVICE
        # Allocate directly on the target device instead of building the
        # tensor on the CPU and copying it over.
        self.lora_A = torch.nn.Parameter(data=torch.zeros(size=(rank, post_layer_dim),
                                                          device=device))
        torch.nn.init.normal_(tensor=self.lora_A,
                              mean=0,
                              std=1)
        assert self.lora_A.shape==(rank, post_layer_dim), "LoRAのAで想定している次元ではない"
        self.lora_B = torch.nn.Parameter(data=torch.zeros(size=(pre_layer_dim, rank),
                                                          device=device))
        assert self.lora_B.shape==(pre_layer_dim, rank), "LoRAのBで想定している次元ではない"
        self.scale = alpha / rank
        self.lora_enable_flag = True

    def forward(self, original_weight):
        """Return ``original_weight`` plus the scaled low-rank update.

        Args:
            original_weight: tensor of shape ``(pre_layer_dim, post_layer_dim)``.

        Returns:
            ``original_weight`` itself when LoRA is disabled, otherwise
            ``original_weight + (lora_B @ lora_A) * scale``.
        """
        if not self.lora_enable_flag:
            return original_weight
        delta_W = torch.matmul(input=self.lora_B,
                               other=self.lora_A)
        assert delta_W.shape==original_weight.shape, "LoRAで想定している次元ではない"
        return original_weight + (delta_W * self.scale)

In [15]:
# Helper that instantiates the LoRA module defined above for a given layer
def add_LoRA(layer_for_LoRA, rank, alpha):
    """Build a ``MyLoRA`` sized to match ``layer_for_LoRA.weight``.

    Args:
        layer_for_LoRA: object with a ``weight`` attribute of at least two dimensions.
        rank: rank forwarded to ``MyLoRA``.
        alpha: scaling constant forwarded to ``MyLoRA``.

    Returns:
        A new ``MyLoRA`` whose ``pre_layer_dim`` / ``post_layer_dim`` are the
        first / second dimension of the layer's weight.
    """
    weight_shape = layer_for_LoRA.weight.shape
    lora_module = MyLoRA(pre_layer_dim=weight_shape[0],
                         post_layer_dim=weight_shape[1],
                         rank=rank,
                         alpha=alpha)
    return lora_module

In [16]:
# Show layer_1 together with the shapes of its weight and bias, one value per line.
print(nn_model.layer_1,
      nn_model.layer_1.weight.shape,
      nn_model.layer_1.bias.shape,
      sep="\n")

Linear(in_features=784, out_features=1000, bias=True)
torch.Size([1000, 784])
torch.Size([1000])


In [17]:
# Attach LoRA to the weight of every Linear layer by passing the module built
# by add_LoRA as the `parametrization` argument of register_parametrization.
# Each layer gets rank=1 and alpha=1. The last call stays the final expression
# of the cell so that its return value is displayed.
torch.nn.utils.parametrize.register_parametrization(
    nn_model.layer_1,
    "weight",
    add_LoRA(layer_for_LoRA=nn_model.layer_1, rank=1, alpha=1),
)
torch.nn.utils.parametrize.register_parametrization(
    nn_model.layer_2,
    "weight",
    add_LoRA(layer_for_LoRA=nn_model.layer_2, rank=1, alpha=1),
)
torch.nn.utils.parametrize.register_parametrization(
    nn_model.layer_3,
    "weight",
    add_LoRA(layer_for_LoRA=nn_model.layer_3, rank=1, alpha=1),
)

ParametrizedLinear(
  in_features=2000, out_features=10, bias=True
  (parametrizations): ModuleDict(
    (weight): ParametrizationList(
      (0): MyLoRA()
    )
  )
)

In [18]:
# Show layer_1 after LoRA was attached: the layer itself, the shapes of its
# weight and bias, and the shapes of the LoRA factors A and B.
print(nn_model.layer_1,
      nn_model.layer_1.weight.shape,
      nn_model.layer_1.bias.shape,
      nn_model.layer_1.parametrizations.weight[0].lora_A.shape,
      nn_model.layer_1.parametrizations.weight[0].lora_B.shape,
      sep="\n")

ParametrizedLinear(
  in_features=784, out_features=1000, bias=True
  (parametrizations): ModuleDict(
    (weight): ParametrizationList(
      (0): MyLoRA()
    )
  )
)
torch.Size([1000, 784])
torch.Size([1000])
torch.Size([1, 784])
torch.Size([1000, 1])


In [19]:
# 変数lora_enable_flagをTrueにするとLoRAを加算したモデル(LoRA付き)、FalseにするとLoRAを加算しないモデル(LoRA無しと同じ)
# layer_for_LoRA_list = [nn_model.layer_1, nn_model.layer_2, nn_model.layer_3]
# for layer_for_LoRA in layer_for_LoRA_list:
#     layer_for_LoRA.parametrizations.weight[0].lora_enable_flag = False
#     print(layer_for_LoRA.parametrizations.weight[0].lora_enable_flag)
#     layer_for_LoRA.parametrizations.weight[0].lora_enable_flag = True
#     print(layer_for_LoRA.parametrizations.weight[0].lora_enable_flag)

In [20]:
# Per-layer summary of the model with the LoRA parametrizations attached,
# for a batch of 100 flattened 28x28 images.
torchinfo.summary(nn_model, input_size=(100, 28 * 28))

Layer (type:depth-idx)                   Output Shape              Param #
MyNet                                    [100, 10]                 --
├─ParametrizedLinear: 1-1                [100, 1000]               1,000
│    └─ModuleDict: 2-1                   --                        --
│    │    └─ParametrizationList: 3-1     [1000, 784]               785,784
├─ParametrizedLinear: 1-2                [100, 2000]               2,000
│    └─ModuleDict: 2-2                   --                        --
│    │    └─ParametrizationList: 3-2     [2000, 1000]              2,003,000
├─ParametrizedLinear: 1-3                [100, 10]                 10
│    └─ModuleDict: 2-3                   --                        --
│    │    └─ParametrizationList: 3-3     [10, 2000]                22,010
Total params: 2,813,804
Trainable params: 2,813,804
Non-trainable params: 0
Total mult-adds (Units.MEGABYTES): 0
Input size (MB): 0.31
Forward/backward pass size (MB): 22.43
Params size (MB): 0.03
Estima

In [21]:
# List every parameter that will be updated (requires_grad is True) with its
# element count, then print the total.
print("更新対象(requires_gradがTrue)のパラメータの一覧")
trainable_param_counts = {
    param_name: param.numel()
    for param_name, param in nn_model.named_parameters()
    if param.requires_grad
}
for param_name, param_count in trainable_param_counts.items():
    print("{a}：{b}個".format(a=param_name, b=param_count))
trainable_total_param_num = sum(trainable_param_counts.values())
print("更新対象のパラメータの合計数：{a}個".format(a=trainable_total_param_num))

更新対象(requires_gradがTrue)のパラメータの一覧
layer_1.bias：1000個
layer_1.parametrizations.weight.original：784000個
layer_1.parametrizations.weight.0.lora_A：784個
layer_1.parametrizations.weight.0.lora_B：1000個
layer_2.bias：2000個
layer_2.parametrizations.weight.original：2000000個
layer_2.parametrizations.weight.0.lora_A：1000個
layer_2.parametrizations.weight.0.lora_B：2000個
layer_3.bias：10個
layer_3.parametrizations.weight.original：20000個
layer_3.parametrizations.weight.0.lora_A：2000個
layer_3.parametrizations.weight.0.lora_B：10個
更新対象のパラメータの合計数：2813804個


In [22]:
# Make only the LoRA parameters trainable (requires_grad = True)
for model_layer_name, model_layer_parameter in nn_model.named_parameters():
    # Parameters whose name contains "lora" stay trainable; every other
    # parameter is excluded from updates.
    model_layer_parameter.requires_grad = "lora" in model_layer_name

In [23]:
# After freezing: list the parameters that are still trainable and their total size.
print("更新対象(requires_gradがTrue)のパラメータの一覧")
trainable_total_param_num = 0
for model_layer_name, model_layer_parameter in nn_model.named_parameters():
    if not model_layer_parameter.requires_grad:
        continue  # frozen parameter: not part of the report
    num_elements = model_layer_parameter.numel()
    print("{a}：{b}個".format(a=model_layer_name, b=num_elements))
    trainable_total_param_num += num_elements
print("更新対象のパラメータの合計数：{a}個".format(a=trainable_total_param_num))

更新対象(requires_gradがTrue)のパラメータの一覧
layer_1.parametrizations.weight.0.lora_A：784個
layer_1.parametrizations.weight.0.lora_B：1000個
layer_2.parametrizations.weight.0.lora_A：1000個
layer_2.parametrizations.weight.0.lora_B：2000個
layer_3.parametrizations.weight.0.lora_A：2000個
layer_3.parametrizations.weight.0.lora_B：10個
更新対象のパラメータの合計数：6794個


In [24]:
# From the training data, pick only the samples whose ground-truth label is
# the most frequently misclassified digit.
# The labels are read from the dataset's `targets` tensor: indexing the dataset
# itself would decode and transform every image just to look at its label.
target_digit_row_id_list = torch.nonzero(train_mnist_data.targets == max_num_wrong_digit,
                                         as_tuple=True)[0].tolist()
print(len(target_digit_row_id_list))
target_digit_train_mnist_data = torch.utils.data.Subset(dataset=train_mnist_data,
                                                        indices=target_digit_row_id_list)
print(len(target_digit_train_mnist_data))
target_digit_train_mnist_data_dataloader = torch.utils.data.DataLoader(dataset=target_digit_train_mnist_data,
                                                                       batch_size=100,
                                                                       shuffle=True)

5958
5958


In [25]:
# Train the LoRA-equipped model (training data: only the most frequently
# misclassified ground-truth label)
nn_model.train()
loss_f = torch.nn.CrossEntropyLoss()
# A new optimizer has to be created here; otherwise the LoRA parameters are not updated
optimizer = torch.optim.Adam(nn_model.parameters(), lr=0.001, betas=(0.9, 0.999))
progress_bar = tqdm.tqdm(iterable=enumerate(target_digit_train_mnist_data_dataloader),
                         total=len(target_digit_train_mnist_data_dataloader))
for done_minibatch, (image_minibatch, label_minibatch) in progress_bar:
    images_on_device = image_minibatch.to(DEVICE)
    labels_on_device = label_minibatch.to(DEVICE)
    # Drop the channel axis, then flatten height x width for the Linear layers.
    flat_images = images_on_device[:, 0].reshape(-1, 28 * 28)
    output = nn_model(in_data=flat_images)
    loss = loss_f(output, labels_on_device)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

100%|██████████| 60/60 [00:09<00:00,  6.65it/s]


In [26]:
# Inference with the LoRA-equipped model: overall accuracy plus, for every
# ground-truth digit, the number of misclassified test images.
nn_model.eval()
num_test_data = len(test_mnist_data)
num_correct_total = 0
num_wrong_total_per_digit = {}  # ground-truth digit -> number of wrong predictions
with torch.no_grad():
    for image_minibatch, label_minibatch in tqdm.tqdm(test_mnist_data_dataloader):
        image_minibatch = image_minibatch.to(device=DEVICE)
        label_minibatch = label_minibatch.to(device=DEVICE)
        # Drop the channel axis and flatten each image for the Linear layers.
        image_minibatch_for_linear = image_minibatch[:, 0].view(-1, 28*28)
        output = nn_model(in_data=image_minibatch_for_linear)
        # Score the whole mini-batch at once instead of calling .item() on
        # every sample (each call forces a device-to-host transfer).
        is_correct = output.argmax(dim=1) == label_minibatch
        num_correct_total = num_correct_total + int(is_correct.sum().item())
        # tolist() keeps the in-batch order, so digits enter the dict in the
        # same order as a per-sample loop would insert them.
        for answer in label_minibatch[~is_correct].tolist():
            num_wrong_total_per_digit[answer] = num_wrong_total_per_digit.get(answer, 0) + 1
accuracy = num_correct_total / num_test_data
print("正解率：{a}".format(a=accuracy))
for digit, num_wrong in num_wrong_total_per_digit.items():
    print("画像{a}で予測が外れた個数：{b}".format(a=digit, b=num_wrong))

100%|██████████| 100/100 [00:04<00:00, 21.89it/s]

正解率：0.3766
画像7で予測が外れた個数：926
画像1で予測が外れた個数：812
画像0で予測が外れた個数：969
画像6で予測が外れた個数：691
画像8で予測が外れた個数：842
画像9で予測が外れた個数：856
画像5で予測が外れた個数：303
画像3で予測が外れた個数：453
画像4で予測が外れた個数：380
画像2で予測が外れた個数：2





In [27]:
# Re-initialize LoRA: A gets fresh N(0, 1) noise and B is set back to zero.
# The in-place initializers from torch.nn.init are used instead of mutating
# the parameters through `.data`.
layer_for_LoRA_list = [nn_model.layer_1, nn_model.layer_2, nn_model.layer_3]
with torch.no_grad():
    for layer_for_LoRA in layer_for_LoRA_list:
        lora_module = layer_for_LoRA.parametrizations.weight[0]
        torch.nn.init.normal_(lora_module.lora_A, mean=0, std=1)
        torch.nn.init.zeros_(lora_module.lora_B)

In [28]:
# Inference after re-initializing LoRA (same as the model without LoRA):
# overall accuracy plus the number of misclassified test images per digit.
nn_model.eval()
num_test_data = len(test_mnist_data)
num_correct_total = 0
num_wrong_total_per_digit = {}  # ground-truth digit -> number of wrong predictions
with torch.no_grad():
    for image_minibatch, label_minibatch in tqdm.tqdm(test_mnist_data_dataloader):
        image_minibatch = image_minibatch.to(device=DEVICE)
        label_minibatch = label_minibatch.to(device=DEVICE)
        # Drop the channel axis and flatten each image for the Linear layers.
        image_minibatch_for_linear = image_minibatch[:, 0].view(-1, 28*28)
        output = nn_model(in_data=image_minibatch_for_linear)
        # Score the whole mini-batch at once instead of calling .item() on
        # every sample (each call forces a device-to-host transfer).
        is_correct = output.argmax(dim=1) == label_minibatch
        num_correct_total = num_correct_total + int(is_correct.sum().item())
        # tolist() keeps the in-batch order, so digits enter the dict in the
        # same order as a per-sample loop would insert them.
        for answer in label_minibatch[~is_correct].tolist():
            num_wrong_total_per_digit[answer] = num_wrong_total_per_digit.get(answer, 0) + 1
accuracy = num_correct_total / num_test_data
print("正解率：{a}".format(a=accuracy))
for digit, num_wrong in num_wrong_total_per_digit.items():
    print("画像{a}で予測が外れた個数：{b}".format(a=digit, b=num_wrong))

100%|██████████| 100/100 [00:04<00:00, 21.42it/s]

正解率：0.9647
画像9で予測が外れた個数：51
画像6で予測が外れた個数：17
画像5で予測が外れた個数：37
画像8で予測が外れた個数：15
画像2で予測が外れた個数：58
画像4で予測が外れた個数：48
画像0で予測が外れた個数：42
画像3で予測が外れた個数：26
画像7で予測が外れた個数：41
画像1で予測が外れた個数：18





In [29]:
# From the training data, pick a random one-eighth of the samples whose
# ground-truth label is the most frequently misclassified digit.
# The labels are read from the dataset's `targets` tensor: indexing the dataset
# itself would decode and transform every image just to look at its label.
target_digit_row_id_list = torch.nonzero(train_mnist_data.targets == max_num_wrong_digit,
                                         as_tuple=True)[0].tolist()
random.shuffle(target_digit_row_id_list)
part_of_target_digit_row_id_list = target_digit_row_id_list[0:len(target_digit_row_id_list)//8]
print(len(part_of_target_digit_row_id_list))
target_digit_train_mnist_data = torch.utils.data.Subset(dataset=train_mnist_data,
                                                        indices=part_of_target_digit_row_id_list)
print(len(target_digit_train_mnist_data))
target_digit_train_mnist_data_dataloader = torch.utils.data.DataLoader(dataset=target_digit_train_mnist_data,
                                                                       batch_size=100,
                                                                       shuffle=True)

744
744


In [30]:
# Train the LoRA-equipped model again (training data: only a part of the
# samples of the most frequently misclassified ground-truth label)
nn_model.train()
loss_f = torch.nn.CrossEntropyLoss()
# A new optimizer has to be created here; otherwise the LoRA parameters are not updated
optimizer = torch.optim.Adam(nn_model.parameters(), lr=0.001, betas=(0.9, 0.999))
progress_bar = tqdm.tqdm(iterable=enumerate(target_digit_train_mnist_data_dataloader),
                         total=len(target_digit_train_mnist_data_dataloader))
for done_minibatch, (image_minibatch, label_minibatch) in progress_bar:
    images_on_device = image_minibatch.to(DEVICE)
    labels_on_device = label_minibatch.to(DEVICE)
    # Drop the channel axis, then flatten height x width for the Linear layers.
    flat_images = images_on_device[:, 0].reshape(-1, 28 * 28)
    output = nn_model(in_data=flat_images)
    loss = loss_f(output, labels_on_device)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

100%|██████████| 8/8 [00:00<00:00,  9.43it/s]


In [31]:
# Inference with the re-trained LoRA-equipped model: overall accuracy plus,
# for every ground-truth digit, the number of misclassified test images.
nn_model.eval()
num_test_data = len(test_mnist_data)
num_correct_total = 0
num_wrong_total_per_digit = {}  # ground-truth digit -> number of wrong predictions
with torch.no_grad():
    for image_minibatch, label_minibatch in tqdm.tqdm(test_mnist_data_dataloader):
        image_minibatch = image_minibatch.to(device=DEVICE)
        label_minibatch = label_minibatch.to(device=DEVICE)
        # Drop the channel axis and flatten each image for the Linear layers.
        image_minibatch_for_linear = image_minibatch[:, 0].view(-1, 28*28)
        output = nn_model(in_data=image_minibatch_for_linear)
        # Score the whole mini-batch at once instead of calling .item() on
        # every sample (each call forces a device-to-host transfer).
        is_correct = output.argmax(dim=1) == label_minibatch
        num_correct_total = num_correct_total + int(is_correct.sum().item())
        # tolist() keeps the in-batch order, so digits enter the dict in the
        # same order as a per-sample loop would insert them.
        for answer in label_minibatch[~is_correct].tolist():
            num_wrong_total_per_digit[answer] = num_wrong_total_per_digit.get(answer, 0) + 1
accuracy = num_correct_total / num_test_data
print("正解率：{a}".format(a=accuracy))
for digit, num_wrong in num_wrong_total_per_digit.items():
    print("画像{a}で予測が外れた個数：{b}".format(a=digit, b=num_wrong))

100%|██████████| 100/100 [00:05<00:00, 19.01it/s]

正解率：0.9671
画像2で予測が外れた個数：32
画像7で予測が外れた個数：47
画像0で予測が外れた個数：40
画像6で予測が外れた個数：17
画像5で予測が外れた個数：40
画像4で予測が外れた個数：43
画像1で予測が外れた個数：14
画像9で予測が外れた個数：50
画像8で予測が外れた個数：18
画像3で予測が外れた個数：28



