In [14]:
import pandas as pd                                         
import numpy as np                                              
from scipy.special import comb                                      
import math
from operator import mul
import neal 
import dimod
#from pyqubo import Array, Constraint, Placeholder, solve_qubo           
import itertools                                                        
import random                                                           
import matplotlib.pyplot as plt                                         
import timeit
from itertools import combinations

In [15]:
def calc_marginals(df):
    """Return the marginal totals of the binary response column ``Y``.

    Parameters
    ----------
    df : pandas.DataFrame
        Must provide the columns ``Y``, ``LI``, ``SEX`` and ``AOP``.

    Returns
    -------
    numpy.ndarray
        Four entries, in this order: ``sum(Y)``, ``Y . LI``, ``Y . SEX``
        and ``Y . AOP`` (dot products taken positionally via ``np.dot``).
    """
    y = df['Y']
    # Index 0 is the plain count of Y; the remaining entries weight Y by one
    # covariate each.
    totals = [sum(y)]
    for column in ('LI', 'SEX', 'AOP'):
        totals.append(np.dot(y, df[column]))
    return np.array(totals)

In [16]:
def make_Hamiltonian(df, t1):
    """Build the penalty Hamiltonian for the marginal constraints as a BQM.

    With one binary variable y_i per row of ``df``, the energy is

        (sum_i y_i - t0)^2 + (sum_i LI_i y_i - t1)^2
        + (sum_i SEX_i y_i - t2)^2 + (sum_i AOP_i y_i - t3)^2

    where t0, t2 and t3 are entries 0, 2 and 3 of ``calc_marginals(df)`` and
    ``t1`` is the caller-supplied target for the LI total (entry 1 of the
    observed marginals is deliberately not used here).

    Parameters
    ----------
    df : pandas.DataFrame
        Needs the columns read by ``calc_marginals`` (``Y``, ``LI``, ``SEX``,
        ``AOP``).
    t1 : number
        Target value for ``sum_i LI_i y_i``.

    Returns
    -------
    dimod.BinaryQuadraticModel
        BINARY model whose variables are the row positions ``0..len(df)-1``.
    """
    targets = calc_marginals(df)
    # Keep the row count in a short local name; it is used in several places.
    n_rows = len(df)
    # All unordered pairs (i, j) with i < j, in lexicographic order.
    pairs = list(combinations(range(n_rows), 2))

    def expand(coef, target):
        """Expand (sum_i coef[i]*y_i - target)^2 using y_i^2 == y_i."""
        # Diagonal terms (a variable paired with itself).
        linear = [(coef[i] - 2 * target) * coef[i] for i in range(n_rows)]
        # Cross terms between two different variables.
        quadratic = [2 * coef[i] * coef[j] for (i, j) in pairs]
        # Constant term: the squared target.
        return linear, quadratic, target ** 2

    # One expansion per constraint, in the order Y, LI, SEX, AOP.  The Y
    # constraint has coefficient 1 on every variable.
    expansions = [
        expand([1] * n_rows, targets[0]),
        expand(df['LI'].iloc, t1),
        expand(df['SEX'].iloc, targets[2]),
        expand(df['AOP'].iloc, targets[3]),
    ]
    linear_parts, quadratic_parts, constants = zip(*expansions)

    # Merge step: zip walks the four per-constraint lists in parallel and a
    # dict comprehension turns the per-variable / per-pair sums into biases.
    linear = {i: sum(terms) for i, terms in enumerate(zip(*linear_parts))}
    quadratic = {
        pair: sum(terms) for pair, terms in zip(pairs, zip(*quadratic_parts))
    }
    offset = sum(constants)

    # Arguments are: linear dict, quadratic dict, numeric offset, vartype.
    return dimod.BinaryQuadraticModel(linear, quadratic, offset, dimod.Vartype.BINARY)

In [17]:
def find_valid_y(df, num_reads):
    """Collect zero-energy assignments for every candidate LI total.

    For each ``t1`` in ``0 .. sum(df['LI'])`` a model is built with
    ``make_Hamiltonian(df, t1)`` and sampled with simulated annealing.  The
    Hamiltonian is a sum of squared constraint violations, so a sample with
    energy exactly 0 satisfies every constraint; only those samples are kept.

    Parameters
    ----------
    df : pandas.DataFrame
        Passed through to ``make_Hamiltonian``; its ``LI`` column also bounds
        the range of ``t1`` values that are tried.
    num_reads : int
        Number of annealing reads requested from the sampler for each ``t1``.

    Returns
    -------
    tuple(dict, dict)
        ``valid_y_list`` maps each ``t1`` to the list of zero-energy samples
        (each sample converted to a plain list), and ``valid_y_num`` maps each
        ``t1`` to how many such samples were found.  Repeated samples are not
        merged, so the same assignment may appear (and be counted) more than
        once when several reads land on it.
    """
    sa_sampler = neal.sampler.SimulatedAnnealingSampler()

    valid_y_list = {}
    valid_y_num = {}
    for t1 in range(0, sum(df['LI']) + 1):
        bqm = make_Hamiltonian(df, t1)
        # Forward num_reads to the sampler: previously the argument was
        # accepted but never used, so the sampler always fell back to its
        # default number of reads regardless of what the caller asked for.
        res = sa_sampler.sample(bqm, num_reads=num_reads)

        # Read the record by field name rather than by tuple position.
        record = res.record
        zero_energy = [
            list(sample)
            for sample, energy in zip(record.sample, record.energy)
            if energy == 0.
        ]
        valid_y_list[t1] = zero_energy
        valid_y_num[t1] = len(zero_energy)
    return valid_y_list, valid_y_num

In [18]:
#==========
# Test code
#==========
def test_find_valid_y():
    """Smoke-test find_valid_y on the ost16 data set.

    Computes the observed LI total (sum of Y * LI), runs the search, prints
    both result dicts and asserts that at least one valid assignment was found
    for the observed total.
    """
    data = pd.read_csv('../../input/ost16.csv', sep=',', index_col=0)
    observed_t1 = sum(data['Y'] * data['LI'])
    solutions, counts = find_valid_y(data, num_reads=10)
    print(solutions, counts)
    assert counts[observed_t1] > 0
    
#test_find_valid_y()

lin: {0: -21, 1: -21, 2: -28, 3: -28, 4: -28, 5: -28, 6: -21, 7: -21, 8: -28, 9: -21, 10: -22, 11: -28, 12: -28, 13: -15, 14: -21, 15: -14}
quad: {(0, 1): 4, (0, 2): 6, (0, 3): 6, (0, 4): 6, (0, 5): 6, (0, 6): 6, (0, 7): 4, (0, 8): 6, (0, 9): 4, (0, 10): 2, (0, 11): 6, (0, 12): 6, (0, 13): 2, (0, 14): 6, (0, 15): 4, (1, 2): 6, (1, 3): 6, (1, 4): 6, (1, 5): 6, (1, 6): 4, (1, 7): 6, (1, 8): 6, (1, 9): 6, (1, 10): 4, (1, 11): 6, (1, 12): 6, (1, 13): 2, (1, 14): 4, (1, 15): 4, (2, 3): 8, (2, 4): 8, (2, 5): 8, (2, 6): 6, (2, 7): 6, (2, 8): 8, (2, 9): 6, (2, 10): 4, (2, 11): 8, (2, 12): 8, (2, 13): 2, (2, 14): 6, (2, 15): 4, (3, 4): 8, (3, 5): 8, (3, 6): 6, (3, 7): 6, (3, 8): 8, (3, 9): 6, (3, 10): 4, (3, 11): 8, (3, 12): 8, (3, 13): 2, (3, 14): 6, (3, 15): 4, (4, 5): 8, (4, 6): 6, (4, 7): 6, (4, 8): 8, (4, 9): 6, (4, 10): 4, (4, 11): 8, (4, 12): 8, (4, 13): 2, (4, 14): 6, (4, 15): 4, (5, 6): 6, (5, 7): 6, (5, 8): 8, (5, 9): 6, (5, 10): 4, (5, 11): 8, (5, 12): 8, (5, 13): 2, (5, 14): 6, (5, 

In [21]:
def test_validity():
    """Check that an alternative Y vector keeps the constrained marginals.

    Loads the ost16 data twice, replaces ``Y`` in the second copy with a
    hand-picked 16-element vector, prints both marginal vectors and asserts
    that entries 0, 2 and 3 of ``calc_marginals`` (the Y count and the SEX and
    AOP totals) agree.  Entry 1 (the LI total) is intentionally not compared.
    """
    csv_path = '../../input/ost16.csv'
    observed = pd.read_csv(csv_path, sep=',', index_col=0)
    candidate = pd.read_csv(csv_path, sep=',', index_col=0)
    candidate['Y'] = np.array([0, 1, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 1, 1, 1])

    observed_marginals = calc_marginals(observed)
    candidate_marginals = calc_marginals(candidate)
    print(observed_marginals)
    print(candidate_marginals)

    # Positions in the calc_marginals output that must match.
    constrained = [0, 2, 3]
    assert (observed_marginals[constrained] == candidate_marginals[constrained]).all()

#test_validity()

[8 6 4 4]
[8 7 4 4]
