In [1]:
import pandas as pd                                         
import numpy as np                                              
from scipy.special import comb                                      
import math                                            
from neal import SimulatedAnnealingSampler                            
from pyqubo import Array, Constraint, Placeholder, solve_qubo           
import itertools                                                        
import random                                                           
import matplotlib.pyplot as plt                                         
import timeit

In [2]:
#=============
# Function definitions
#=============
# make_t_list_columns_num_samples
# only computed the marginal totals of the exact-test table,
# so it is replaced by the following function.
def calc_marginals(df):
    """Return the marginal totals of the response column ``Y``.

    Parameters
    ----------
    df : pandas.DataFrame
        Frame providing the columns ``'Y'``, ``'LI'``, ``'SEX'`` and ``'AOP'``.

    Returns
    -------
    numpy.ndarray
        Four entries, in this order: the sum of ``Y``, then the dot product
        of ``Y`` with ``LI``, with ``SEX`` and with ``AOP``.
    """
    response = df['Y']
    # Entry 0 is the plain total of Y; every further entry pairs Y with
    # one covariate column.
    totals = [sum(response)]
    for covariate in ('LI', 'SEX', 'AOP'):
        totals.append(np.dot(response, df[covariate]))
    return np.array(totals)

In [8]:
def find_valid_y(df, num_reads):
    """Sample binary response vectors whose marginal totals match the data.

    For every candidate value ``t1`` from 0 up to the sum of ``df['LI']``, a
    QUBO penalty is built that is zero exactly when a binary vector ``y``
    satisfies all of:

    * ``sum(y)`` equals the observed total of ``Y`` (``t_list[0]``),
    * ``y . column 1`` equals ``t1``,
    * ``y . column j`` equals ``t_list[j]`` for ``j = 2 .. len(t_list) - 1``.

    The penalty is minimised with simulated annealing and every read whose
    energy is zero is recorded.

    Parameters
    ----------
    df : pandas.DataFrame
        Must provide the columns read by ``calc_marginals``. Columns are also
        addressed by position here, so column 1 is presumably ``LI`` and
        columns 2 and 3 ``SEX`` and ``AOP`` -- TODO confirm against the input
        file layout.
    num_reads : int
        Number of annealing reads requested per value of ``t1``.

    Returns
    -------
    valid_y_list : dict
        Maps each ``t1`` to the list of zero-energy samples, each a list with
        one entry per row of ``df`` (entry ``i`` belongs to row ``i``).
        Repeated samples are kept, as before.
    valid_y_num : dict
        Maps each ``t1`` to the number of zero-energy reads.
    """
    t_list = calc_marginals(df)
    n_rows = df.shape[0]
    var_y = Array.create('y', shape=n_rows, vartype='BINARY')
    sa_sampler = SimulatedAnnealingSampler()

    # Loop-invariant: convert the frame once instead of once per constraint.
    data = df.to_numpy()

    valid_y_list = {}
    valid_y_num = {}
    for t1 in range(0, sum(df['LI'])+1):
        # QUBO formulation: one squared-deviation penalty per marginal.
        H = (sum(var_y) - t_list[0])**2
        H += (var_y.dot(data[:, 1]) - t1) ** 2
        for j in range(2, len(t_list)):
            H += (var_y.dot(data[:, j]) - t_list[j]) ** 2

        bqm = H.compile().to_bqm()
        res = sa_sampler.sample(bqm, num_reads=num_reads)
        # Debug dump of the raw reads; its sample columns are in the
        # sampler's own variable order, not necessarily row order.
        print('res(t1=', t1, ')',res.record)

        # The columns of res.record.sample follow res.variables, which is not
        # guaranteed to be y[0], y[1], ..., y[n-1]. Look up where each row's
        # variable landed so that stored samples line up with the rows of df.
        column_of = {label: pos for pos, label in enumerate(res.variables)}
        row_order = [column_of['y[{}]'.format(i)] for i in range(n_rows)]

        valid_y_list[t1] = []
        valid_y_num[t1] = 0
        for sample, energy in zip(res.record.sample, res.record.energy):
            # Zero energy means every squared penalty vanished; compare with
            # a tolerance instead of exact float equality.
            if np.isclose(energy, 0.0):
                valid_y_num[t1] += 1
                valid_y_list[t1].append(list(sample[row_order]))
                print('energy0')
    return valid_y_list, valid_y_num

In [13]:
#==========
# Test code
#==========
def test_find_valid_y():
    """Smoke-test ``find_valid_y`` on ``ost16.csv``.

    Loads the data set, runs ``find_valid_y`` with 10 reads, prints both
    returned dictionaries and asserts that at least one valid vector was
    recorded for the ``t1`` value computed from the observed data
    (the sum of ``Y * LI``).
    """
    data = pd.read_csv('../../input/ost16.csv', sep=',', index_col=0)
    observed_t1 = sum(data['Y'] * data['LI'])
    found_lists, found_counts = find_valid_y(data, num_reads=10)
    print(found_lists, found_counts)
    # The observed response itself satisfies every constraint at this t1,
    # so the check is only whether the sampler recorded a hit there.
    assert found_counts[observed_t1] > 0
    
# Run the smoke test when this cell executes.
test_find_valid_y()

Index(['Y', 'LI', 'SEX', 'AOP'], dtype='object')
res(t1= 0 ) [([0, 1, 1, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0], 19., 1)
 ([0, 1, 0, 1, 1, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0], 19., 1)
 ([0, 1, 0, 0, 1, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0], 19., 1)
 ([0, 1, 1, 0, 1, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0], 19., 1)
 ([0, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0], 19., 1)
 ([0, 1, 0, 1, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0], 19., 1)
 ([0, 1, 0, 1, 1, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0], 19., 1)
 ([0, 1, 0, 0, 1, 0, 0, 0, 0, 1, 1, 0, 0, 0, 1, 0], 19., 1)
 ([0, 1, 0, 0, 1, 0, 0, 0, 1, 1, 0, 1, 0, 0, 0, 0], 19., 1)
 ([0, 1, 0, 1, 1, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0], 19., 1)]
res(t1= 1 ) [([0, 1, 0, 1, 1, 1, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0], 13., 1)
 ([0, 1, 0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 0, 0, 1, 0], 13., 1)
 ([1, 1, 1, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0], 13., 1)
 ([1, 1, 1, 0, 1, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0], 13., 1)
 ([1, 1, 1, 0, 1, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0], 13., 1)
 ([0, 1, 1, 0, 1, 0, 0, 0,

In [11]:
def test_validity():
    """Check a hard-coded candidate ``Y`` against the observed marginals.

    Loads ``ost16.csv`` twice, overwrites ``Y`` in the second copy with a
    fixed 16-element vector, prints the marginals of both frames and asserts
    that entries 0, 2 and 3 agree (entry 1 is not compared).

    NOTE(review): the saved cell output shows this assertion failing for the
    vector below ([8 6 4 4] vs [8 7 5 4]).
    """
    csv_path = '../../input/ost16.csv'
    observed = pd.read_csv(csv_path, sep=',',index_col=0)
    candidate = pd.read_csv(csv_path, sep=',',index_col=0)
    candidate['Y'] = np.array([1, 0, 0, 1, 1, 1, 1, 0, 0, 1, 0, 0, 0, 1, 0, 1])

    observed_marginals = calc_marginals(observed)
    candidate_marginals = calc_marginals(candidate)
    print(observed_marginals)
    print(candidate_marginals)

    # Positions of the marginals that must match; position 1 is left out.
    fixed_positions = [0, 2, 3]
    assert np.all(
        observed_marginals[fixed_positions] == candidate_marginals[fixed_positions]
    )

# Run the marginal-consistency check when this cell executes.
test_validity()

[8 6 4 4]
[8 7 5 4]


AssertionError: 