#  PyTorch implementation of "A simple neural network module for relational reasoning"
Sort-of-CLEVRの実装

In [2]:
import argparse
import os
import pickle
import random
import numpy as np
import torch
from torch.autograd import Variable

from model import RN, CNN_MLP

ModuleNotFoundError: No module named 'model'

## オプションを受け取る部分
デフォルトは論文に記載されているパラメータ値

In [None]:
# Command-line options for training on Sort-of-CLEVR.
parser = argparse.ArgumentParser(description='PyTorch Relational-Network sort-of-CLVR Example')
# Which architecture to build; the previous help text was copied from
# --resume and described the wrong option.
parser.add_argument('--model', type=str, choices=['RN', 'CNN_MLP'], default='RN',
                    help='model architecture to train: RN or CNN_MLP (default: RN)')
parser.add_argument('--batch-size', type=int, default=64, metavar='N',
                    help='input batch size for training (default: 64)')
parser.add_argument('--epochs', type=int, default=20, metavar='N',
                    help='number of epochs to train (default: 20)')
parser.add_argument('--lr', type=float, default=0.0001, metavar='LR',
                    help='learning rate (default: 0.0001)')
parser.add_argument('--no-cuda', action='store_true', default=False,
                    help='disables CUDA training')
parser.add_argument('--seed', type=int, default=1, metavar='S',
                    help='random seed (default: 1)')
parser.add_argument('--log-interval', type=int, default=10, metavar='N',
                    help='how many batches to wait before logging training status')
# File name of a checkpoint to load before training; no default, so the
# attribute is None when the flag is omitted.
parser.add_argument('--resume', type=str,
                    help='resume from model stored')

##  変数の設定
argsから受け取った変数に名前をつけて格納

In [None]:
# Parse the command-line options; CUDA is used only when it was not disabled
# and a GPU is actually available.
# NOTE(review): inside a Jupyter kernel sys.argv carries kernel arguments, so
# parse_args() may reject them -- confirm how this cell is meant to be run.
args = parser.parse_args()
args.cuda = not args.no_cuda and torch.cuda.is_available()

# Seed every RNG this script draws from. train() shuffles the datasets with
# the stdlib `random` module, so seeding torch alone leaves the batch order
# different on every run even with a fixed --seed.
random.seed(args.seed)
torch.manual_seed(args.seed)
if args.cuda:
    torch.cuda.manual_seed(args.seed)

# Build the architecture selected by --model.
if args.model == 'CNN_MLP':
    model = CNN_MLP(args)
else:
    model = RN(args)

model_dirs = './model'
bs = args.batch_size
# Uninitialized placeholder buffers; tensor_data() resizes and overwrites them
# for every mini-batch.
input_img = torch.FloatTensor(bs, 3, 75, 75)  # image batch (presumably 3 channels of 75x75 -- confirm against gen_dataset.py)
input_qst = torch.FloatTensor(bs, 11)  # question batch, 11 values per question
label = torch.LongTensor(bs)  # answer labels, one per sample in the batch

if args.cuda:
    model.cuda()
    input_img = input_img.cuda()
    input_qst = input_qst.cuda()
    label = label.cuda()

# The three variables below are rewritten in place by tensor_data() on every
# mini-batch.
input_img = Variable(input_img)
input_qst = Variable(input_qst)
label = Variable(label)

## 配列のテンソル化，変形
1. 指定したバッチごとに対してimg, qst, ansからなるデータをテンソル化
2. L.57-59の変数を各バッチのサイズに合わせてリサイズし，値をインプレースでコピーする

In [None]:
def tensor_data(data, i):
    """Load mini-batch ``i`` of ``data`` into the shared input buffers.

    Args:
        data: Indexable whose items 0, 1 and 2 are the image, question and
            answer sequences (the layout produced by ``cvt_data_axis``).
        i: Zero-based batch index; the slice ``[bs*i : bs*(i+1)]`` of each
            sequence is used.

    Returns:
        None. The module-level ``input_img``, ``input_qst`` and ``label``
        are resized and overwritten in place.
    """
    start, stop = bs * i, bs * (i + 1)
    # Convert all three slices first, then write them into the buffers.
    batch = [
        torch.from_numpy(np.asarray(column[start:stop]))
        for column in (data[0], data[1], data[2])
    ]

    # Resize each buffer to the batch's shape and copy the values in, so the
    # callers keep passing the very same objects to the model.
    for buffer, values in zip((input_img, input_qst, label), batch):
        buffer.data.resize_(values.size()).copy_(values)

## datasetから各種データを取り出す
訓練/テストデータのリストのレコード: (img,qst,ans)  
このレコードを要素ごとに分割して，取り出す．

In [None]:
def cvt_data_axis(data):
    """Split a sequence of ``(img, qst, ans)`` records into three parallel lists.

    Args:
        data: Iterable of records indexable at positions 0, 1 and 2.

    Returns:
        A 3-tuple ``(img, qst, ans)`` where each item is a list holding the
        corresponding field of every record, in the original order.
    """
    columns = []
    # One pass per field, in field order: images, then questions, then answers.
    for field in range(3):
        columns.append([record[field] for record in data])
    return tuple(columns)

## 訓練

In [None]:
def train(epoch, rel, norel):
    """Run one training epoch, alternating relational and non-relational batches.

    Args:
        epoch: Epoch number; used only in the progress log line.
        rel: List of ``(img, qst, ans)`` records for relational questions.
            Shuffled in place.
        norel: List of ``(img, qst, ans)`` records for non-relational
            questions. Shuffled in place.

    Returns:
        None. When the two datasets differ in size a message is printed and
        no training step is performed.
    """
    model.train()

    # Compare the number of records. ``rel`` and ``norel`` are still lists of
    # records here, so ``len(rel[0])`` would only compare the arity of the
    # first record (always equal) and would raise on an empty list.
    if len(rel) != len(norel):
        print('Not equal length for relation dataset and non-relation dataset.')
        return

    # Shuffle the datasets.
    random.shuffle(rel)
    random.shuffle(norel)
    # Convert the record lists into (img, qst, ans) tuples of parallel lists.
    rel = cvt_data_axis(rel)
    norel = cvt_data_axis(norel)

    num_samples = len(rel[0])
    # Mini-batch training; a trailing partial batch is dropped by the floor
    # division.
    for i in range(num_samples // bs):
        tensor_data(rel, i)
        accuracy_rel = model.train_(input_img, input_qst, label)

        tensor_data(norel, i)
        accuracy_norel = model.train_(input_img, input_qst, label)

        if i % args.log_interval == 0:
            # Adjacent literals instead of a backslash continuation inside the
            # string, which leaked the source indentation into the log line.
            print('Train Epoch: {} [{}/{} ({:.0f}%)] Relations accuracy: {:.0f}% | '
                  'Non-relations accuracy: {:.0f}%'.format(
                      epoch, i * bs * 2, num_samples * 2,
                      100. * i * bs / num_samples, accuracy_rel, accuracy_norel))

## テスト

In [None]:
def test(epoch, rel, norel):
    """Evaluate the model on the relational and non-relational test sets.

    Args:
        epoch: Epoch number. Accepted for symmetry with ``train``; not used.
        rel: List of ``(img, qst, ans)`` records for relational questions.
        norel: List of ``(img, qst, ans)`` records for non-relational
            questions.

    Returns:
        None. Prints the mean of the per-batch values returned by
        ``model.test_`` for each question type.
    """
    model.eval()
    # Compare the number of records; ``len(rel[0])`` would only compare the
    # arity of the first record and would raise on an empty list.
    if len(rel) != len(norel):
        print('Not equal length for relation dataset and non-relation dataset.')
        return

    rel = cvt_data_axis(rel)
    norel = cvt_data_axis(norel)

    accuracy_rels = []
    accuracy_norels = []
    # Only full batches are evaluated; a trailing partial batch is dropped.
    for i in range(len(rel[0]) // bs):
        tensor_data(rel, i)
        accuracy_rels.append(model.test_(input_img, input_qst, label))

        tensor_data(norel, i)
        accuracy_norels.append(model.test_(input_img, input_qst, label))

    # With fewer samples than one batch nothing was evaluated; report that
    # instead of dividing by zero below.
    if not accuracy_rels:
        print('Test set has fewer than {} samples; no full batch to evaluate.'.format(bs))
        return

    # Average the per-batch results over all evaluated batches.
    accuracy_rel = sum(accuracy_rels) / len(accuracy_rels)
    accuracy_norel = sum(accuracy_norels) / len(accuracy_norels)
    print('\n Test set: Relation accuracy: {:.0f}% | Non-relation accuracy: {:.0f}%\n'.format(
        accuracy_rel, accuracy_norel))

## データロード
gen_dataset.pyで保存したpickleファイルから読み込み

In [None]:
def _split_questions(datasets):
    """Flatten ``(img, relations, norelations)`` entries into per-question records."""
    rel_records = []
    norel_records = []
    for img, relations, norelations in datasets:
        # Swap axes 0 and 2 so the last axis of the stored image comes first
        # (presumably to get a channel-first layout; tied to how
        # gen_dataset.py builds img -- confirm).
        img = np.swapaxes(img, 0, 2)
        # relations / norelations hold questions at index 0 and the matching
        # answers at index 1; every question shares the same image array.
        rel_records.extend(
            (img, qst, ans) for qst, ans in zip(relations[0], relations[1]))
        norel_records.extend(
            (img, qst, ans) for qst, ans in zip(norelations[0], norelations[1]))
    return rel_records, norel_records


def load_data(data_dir='./data'):
    """Load the pickled Sort-of-CLEVR dataset and flatten it per question.

    Args:
        data_dir: Directory containing ``sort-of-clevr.pickle``. Defaults to
            ``'./data'``, the location used before this parameter existed.

    Returns:
        A 4-tuple ``(rel_train, rel_test, norel_train, norel_test)``; each
        item is a list of ``(img, qst, ans)`` records.

    Raises:
        OSError: If the pickle file cannot be opened.
    """
    print('loading data...')
    filename = os.path.join(data_dir, 'sort-of-clevr.pickle')
    # NOTE(review): pickle.load can execute arbitrary code from the file; only
    # load datasets produced locally (e.g. by gen_dataset.py).
    with open(filename, 'rb') as f:
        train_datasets, test_datasets = pickle.load(f)
    print('processing data...')

    rel_train, norel_train = _split_questions(train_datasets)
    rel_test, norel_test = _split_questions(test_datasets)

    return (rel_train, rel_test, norel_train, norel_test)

## データセットの分割
訓練用，テスト用

In [None]:
# Load the four per-question record lists used by train() and test().
rel_train, rel_test, norel_train, norel_test = load_data()
# Create the checkpoint directory. Only "already exists" is an expected
# outcome; any other failure (e.g. permissions) now propagates instead of
# being reported as "already exists" by a bare except.
try:
    os.makedirs(model_dirs)
except FileExistsError:
    print('directory {} already exists'.format(model_dirs))

In [None]:
# Optionally restore model weights from a checkpoint inside model_dirs.
if args.resume:
    filename = os.path.join(model_dirs, args.resume)
    if os.path.isfile(filename):
        print('==> loading checkpoint {}'.format(filename))
        # Deserialize onto the CPU so a checkpoint saved from a GPU run can
        # still be loaded on a CPU-only host; load_state_dict copies the
        # values into the model's own parameters.
        checkpoint = torch.load(filename, map_location='cpu')
        model.load_state_dict(checkpoint)
        print('==> loaded checkpoint {}'.format(filename))
    else:
        # Say so explicitly rather than silently training from scratch.
        print('==> no checkpoint found at {}'.format(filename))

## 学習，テスト
一定間隔でコマンドラインに出力

In [None]:
# Main loop: for each epoch (numbered from 1) train on the training split,
# evaluate on the test split, then save the model through model.save_model.
for epoch in range(1, 1 + args.epochs):
    train(epoch, rel_train, norel_train)
    print()  # blank line between the training log and the test summary
    test(epoch, rel_test, norel_test)
    model.save_model(epoch)