Skip to content

Commit

Permalink
fix ArminSampler mps mode
Browse files Browse the repository at this point in the history
  • Loading branch information
ShoyaYasuda committed Feb 20, 2024
1 parent 49ca669 commit 13b273e
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 71 deletions.
2 changes: 1 addition & 1 deletion tytan/_version.py
@@ -1 +1 @@
__version__ = "0.0.25"
__version__ = "0.0.26"
74 changes: 4 additions & 70 deletions tytan/sampler.py
Expand Up @@ -698,14 +698,14 @@ def run(self, qubo, shots=100, T_num=2000, show=False):

# --- テンソル疑似SA ---
#
qmatrix = torch.tensor(qmatrix, device=self.device).float()
qmatrix = torch.tensor(qmatrix, dtype=torch.float32, device=self.device).float()

# プール初期化
pool_num = shots
pool = torch.randint(0, 2, (pool_num, N), device=self.device).float()
pool = torch.randint(0, 2, (pool_num, N), dtype=torch.float32, device=self.device).float()

# スコア初期化
score = torch.sum((pool @ qmatrix) * pool, dim=1)
score = torch.sum((pool @ qmatrix) * pool, dim=1, dtype=torch.float32)

# フリップ数リスト(2個まで下がる)
flip = np.sort(nr.rand(T_num) ** 2)[::-1]
Expand Down Expand Up @@ -794,70 +794,4 @@ def run(self, qubo, shots=100, T_num=2000, show=False):


if __name__ == "__main__":
import time, os, sys
from tytan import symbols_list, Compile

size = 24
shots = 100

num_people = size #人の数
num_time = size #シフトの時間

#量子ビットを用意
q = symbols_list([num_people], 'q_{}')

#シフト枠
shift = np.ones((num_time)) * 2

ppl = np.zeros((num_people, num_time))
for i in range(num_people//2):
ppl[i, i] = 2
ppl[i, num_time//2 + i] = 2
# choice = [0, 1]
# weight = [(num_time-2.0)/num_time, 2.0/num_time]
for i in range(num_people//2, num_people):
ppl[i] = np.random.randint(0, 2, num_time) * 8.0 / num_time
# ppl[i] = np.random.choice(choice, num_time, p=weight)
np.set_printoptions(threshold=np.inf)
# print(ppl)

# QUBOを作る。それぞれの時間に対して、枠の人数に収まるように変数を格納。
H = 0
for j in range(num_time):
tmp = 0
for i in range(num_people):
tmp += ppl[i][j] * q[i]
H += (tmp - shift[i])**2
#print(H)

# H = 2*q[0]*q[1] + 1*q[2]*q[3] + 2*q[4] + q[5]

#コンパイル
s = time.time()
qubo, offset = Compile(H).get_qubo()
print(f'{round(time.time() - s, 1)} s')


#サンプラー選択
solver = ArminSampler(seed=None, mode='GPU')
#サンプリング
s = time.time()
result = solver.run(qubo, shots=shots, show=True)

#確認
tmp = np.zeros((num_people, num_time), int)
r = result[0][0]
print(result[0][1], result[0][2])
for k in r.keys():
b = int(r[k])
i = int(k.split('_')[1])
if b == 1:
tmp[i] = ppl[i]
#print(tmp)
print(np.sum(tmp, axis=0))
print(result[0][1] + size*2**2)

# suc = (np.sum(tmp, axis=0) == 2).all()
suc = abs(result[0][1] + size*2**2) < 0.0001
print(f'schedule | {size:03} | {shots:04} | {suc} | {round(time.time() - s, 1)} s | {round(sys.getsizeof(qubo)/1024, 1)} KB |')

pass

0 comments on commit 13b273e

Please sign in to comment.