In [1]:
import pennylane as qml
from pennylane import numpy as np
from pennylane.optimize import NesterovMomentumOptimizer

In [90]:
def get_winner(board):
    # Check the board for any winning combinations
    winning_combinations = [
        # Rows
        (0, 1, 2),
        (3, 4, 5),
        (6, 7, 8),
        # Columns
        (0, 3, 6),
        (1, 4, 7),
        (2, 5, 8),
        # Diagonals
        (0, 4, 8),
        (2, 4, 6),
    ]

    x_wins = False
    o_wins = False

    for combo in winning_combinations:
        if board[combo[0]] == board[combo[1]] == board[combo[2]] and board[combo[0]] != '':
            if board[combo[0]] == 'x':
                return [0, 0, 1]
            else:
                return [1, 0, 0]
    return [0, 1, 0]

#This function checks for a couple of things, length of the board,
def is_valid_tic_tac_toe(board):
    # Check that the board has exactly 9 elements
    if len(board) != 9:
        return False

    # Count the number of 'x' and 'o' on the board
    count_x = board.count('x')
    count_o = board.count('o')

    # Check that the difference in count between 'x' and 'o' is 0 or 1
    if abs(count_x - count_o) > 1:
        return False

    # Check the board for any winning combinations
    winning_combinations = [
        # Rows
        (0, 1, 2),
        (3, 4, 5),
        (6, 7, 8),
        # Columns
        (0, 3, 6),
        (1, 4, 7),
        (2, 5, 8),
        # Diagonals
        (0, 4, 8),
        (2, 4, 6),
    ]

    x_wins = False
    o_wins = False

    for combo in winning_combinations:
        if board[combo[0]] == board[combo[1]] == board[combo[2]] and board[combo[0]] != '':
            if board[combo[0]] == 'x':
                x_wins = True
            else:
                o_wins = True

    # Check if both 'x' and 'o' won or if neither won
    if x_wins and o_wins or (not x_wins and not o_wins):
        return False

    # Check that the board is a valid final board configuration
    if (x_wins and count_x != count_o + 1) or (o_wins and count_x != count_o):
        return False
    # All checks have passed, so the board is valid
    return True


def generate_tic_tac_toe_configs():
    valid_configs = []
    winners = []

    # Generate all possible configurations of the board
    for i in range(3 ** 9):
        board = []
        for j in range(9):
            symbol = ''
            if i % 3 == 0:
                symbol = 'x'
            elif i % 3 == 1:
                symbol = 'o'
            board.append(symbol)
            i //= 3

        # Check if the configuration is valid
        if is_valid_tic_tac_toe(board):
            valid_configs.append(board)
            winners.append(get_winner(board))

    return valid_configs, winners

boards, winners = generate_tic_tac_toe_configs()

import pennylane as qml
from pennylane import numpy as np

def encode_data(tic_tac_toe_field):
    # data_g = [1 if entry == 'x' else -1 if entry == 'o' else 0 for entry in tic_tac_toe_field]
    for entry, index in zip(tic_tac_toe_field, range(len(tic_tac_toe_field))):
        qml.RX(entry, wires=[index])
        #print(qml.RX(entry * 2 * np.pi / 3, wires=[index]))

    return


def add_single_qubit_gates(params):
    # define edges centers and lats.
    edges = [0, 2, 6, 8]
    lats = [1, 3, 5, 7]
    center = 4

    for i in edges:
        qml.RX(params[0], wires=[i])
        qml.RY(params[1], wires=[i])
    for i in lats:
        qml.RX(params[2], wires=[i])
        qml.RY(params[3], wires=[i])


    qml.RX(params[4], wires=[center])
    qml.RY(params[5], wires=[center])

    return

def add_two_qubit_gates(params):
    # corners (green)
    corner_qubits = [0, 2, 6, 8]
    edge_qubits = [1, 3, 5, 7]
    center_qubit = 4


    # yellow two-qubit gates
    for i in range(4):
        qml.CPhase(params[2], wires=[4, edge_qubits[i]])


    # red two-qubit gates, hard coded
    qml.CPhase(params[1], wires=[1, 0])
    qml.CPhase(params[1], wires=[1, 2])
    qml.CPhase(params[1], wires=[3, 0])
    qml.CPhase(params[1], wires=[3, 6])
    qml.CPhase(params[1], wires=[5, 2])
    qml.CPhase(params[1], wires=[5, 8])
    qml.CPhase(params[1], wires=[7, 6])
    qml.CPhase(params[1], wires=[7, 8])

    # green two-qubit gates, hard coded
    qml.CPhase(params[0], wires=[0, 4])
    qml.CPhase(params[0], wires=[2, 4])
    qml.CPhase(params[0], wires=[6, 4])
    qml.CPhase(params[0], wires=[8, 4])

    return

In [102]:
### SO FAR SO GOOD ###

print("ok")

obs_ZIZIIIZIZ = 0.25 * qml.PauliZ(0) @ qml.PauliZ(2) @ qml.PauliZ(6) @ qml.PauliZ(8)
obs_IIIIZIIII = qml.PauliZ(4)
obs_IZIZIZIZI = 0.25 * qml.PauliZ(1) @ qml.PauliZ(3) @ qml.PauliZ(5) @ qml.PauliZ(7)

observables = (obs_ZIZIIIZIZ,)# obs_IIIIZIIII, obs_IZIZIZIZI)

dev = qml.device("default.qubit", wires=9)
@qml.qnode(dev)
def circuit(params, tactoe):
    encode_data(tactoe)
    add_single_qubit_gates(params[:6])
    add_two_qubit_gates(params[6:])
    return qml.expval(obs_ZIZIIIZIZ)#, qml.expval(obs_IIIIZIIII), qml.expval(obs_IZIZIZIZI)

tictac = ['x', 'o', '', '', '', '', '', '', '']
total_params=[0.2,0.3,0.2,0.4,0.5, 0.6, 0.2,0.3,0.2]

ok


In [104]:
##SETUP OF THE LABELS AND TRAININGS

x = np.array([[0 if e == '' else 2*np.pi/3 if e == 'x' else -2*np.pi/3 for e in b] for b in boards])
y = np.array(winners)[:,0]

# shuffle the indices
shuffle_indices = np.random.permutation(len(x))
train_size = int(len(x) * 0.3)

# split the indices into training and testing sets
train_indices = np.array(shuffle_indices[:train_size])
test_indices = np.array(shuffle_indices[train_size:])

# create the training and testing sets
X, Y = np.take(x, train_indices, axis=0), np.take(y, train_indices, axis=0)
x_test, y_test = np.take(x, test_indices, axis=0), np.take(y, test_indices, axis=0)

In [103]:
print("Example train data: ", X[17], Y[17])

circ1, circ2, circ3 = circuit(total_params, X[17])
expectation = (circ1, circ2, circ3)
print(expectation)

Example train data:  [ 0.         2.0943951  2.0943951  2.0943951  2.0943951 -2.0943951
 -2.0943951 -2.0943951 -2.0943951] 1


TypeError: iteration over a 0-d array

In [112]:
##DEFINE A COST FUNCTION

def variational_classifier(weights, bias, x):
    return circuit(weights, x) + bias

def square_loss(labels, predictions):
    loss = 0
    for l, p in zip(labels, predictions):
        loss = loss + (l - p) ** 2

    loss = loss / len(labels)
    return loss

def accuracy(labels, predictions):

    loss = 0
    for l, p in zip(labels, predictions):
        if abs(l - p) < 1e-5:
            loss = loss + 1
    loss = loss / len(labels)

    return loss

#check if the cost definition is correct
def cost(weights, bias, X, Y):
    predictions = [variational_classifier(weights, bias, x) for x in X]
    # print(predictions)
    return square_loss(Y, predictions)

In [106]:
##Actual Training of it
#Nesterov goes from 0.01 to 0.99, smaller values indicate smaller step size so lower computation speed but higher accuracy.
##Vice versa for bigger  values. Let's take a simple 0.5 for now. Also try PSPA and Adam.


bias_init = np.array(0.0, requires_grad=True)
opt = NesterovMomentumOptimizer(0.5)
batch_size = 10

In [107]:
from tqdm import tqdm

In [None]:
params = total_params  #at some point we can start it randomly though now i can't
bias = bias_init
weights = np.random.rand(9)*2*np.pi
for it in tqdm(range(100)):
    acc=0
    batch_index = np.random.randint(0, len(X), (batch_size))
    X_batch = X[batch_index]
    Y_batch = Y[batch_index]
    weights, bias, _, _ = opt.step(cost, weights, bias, X_batch, Y_batch)
    print(weights)
    pred_index = np.random.randint(0, len(X), (10))
    predictions = [np.sign(variational_classifier(weights, bias, x)) for x in X[pred_index]]
    acc = accuracy(Y, predictions)

    print(
        "Iter: {:5d} | Cost: {:0.7f} | Accuracy: {:0.7f} ".format(
            it + 1, cost(weights, bias, X, Y), acc
        )
    )
print(params)

  0%|                                                                                          | 0/100 [00:00<?, ?it/s]

[1.97722857 3.86753962 4.98382994 1.83581191 5.09442993 1.98786539
 0.79222249 6.21125359 0.30230072]


  1%|▊                                                                                 | 1/100 [00:24<41:01, 24.86s/it]

Iter:     1 | Cost: 0.2150773 | Accuracy: 0.0070922 
[1.98532091 3.8699603  4.98382994 1.83581191 5.09442993 1.98786539
 0.79222249 6.21125359 0.30230072]


  2%|█▋                                                                                | 2/100 [00:49<40:52, 25.02s/it]

Iter:     2 | Cost: 0.2048675 | Accuracy: 0.0070922 
[2.0045372  3.90037994 4.98382994 1.83581191 5.09442993 1.98786539
 0.79222249 6.21125359 0.30230072]


  3%|██▍                                                                               | 3/100 [01:15<40:29, 25.04s/it]

Iter:     3 | Cost: 0.2086716 | Accuracy: 0.0070922 
[2.02620953 3.92024107 4.98382994 1.83581191 5.09442993 1.98786539
 0.79222249 6.21125359 0.30230072]


  4%|███▎                                                                              | 4/100 [01:41<40:41, 25.43s/it]

Iter:     4 | Cost: 0.2118212 | Accuracy: 0.0070922 
[2.04701076 3.95790498 4.98382994 1.83581191 5.09442993 1.98786539
 0.79222249 6.21125359 0.30230072]


  5%|████                                                                              | 5/100 [02:06<40:08, 25.35s/it]

Iter:     5 | Cost: 0.2184774 | Accuracy: 0.0070922 
[2.06773362 3.96771982 4.98382994 1.83581191 5.09442993 1.98786539
 0.79222249 6.21125359 0.30230072]


  6%|████▉                                                                             | 6/100 [02:31<39:40, 25.32s/it]

Iter:     6 | Cost: 0.3065067 | Accuracy: 0.0000000 
[2.09645658 3.9993225  4.98382994 1.83581191 5.09442993 1.98786539
 0.79222249 6.21125359 0.30230072]


  7%|█████▋                                                                            | 7/100 [02:56<39:17, 25.35s/it]

Iter:     7 | Cost: 0.2048574 | Accuracy: 0.0070922 
[2.12781938 4.03183986 4.98382994 1.83581191 5.09442993 1.98786539
 0.79222249 6.21125359 0.30230072]


  8%|██████▌                                                                           | 8/100 [03:22<38:54, 25.38s/it]

Iter:     8 | Cost: 0.2082123 | Accuracy: 0.0070922 
[2.15751502 4.06293137 4.98382994 1.83581191 5.09442993 1.98786539
 0.79222249 6.21125359 0.30230072]


  9%|███████▍                                                                          | 9/100 [03:47<38:24, 25.33s/it]

Iter:     9 | Cost: 0.2154680 | Accuracy: 0.0070922 
[2.19077891 4.10227127 4.98382994 1.83581191 5.09442993 1.98786539
 0.79222249 6.21125359 0.30230072]


 10%|████████                                                                         | 10/100 [04:12<37:49, 25.21s/it]

Iter:    10 | Cost: 0.2980347 | Accuracy: 0.0070922 
[2.21730543 4.11849405 4.98382994 1.83581191 5.09442993 1.98786539
 0.79222249 6.21125359 0.30230072]


 11%|████████▉                                                                        | 11/100 [04:38<37:30, 25.29s/it]

Iter:    11 | Cost: 0.3098606 | Accuracy: 0.0000000 
[2.24302249 4.1246758  4.98382994 1.83581191 5.09442993 1.98786539
 0.79222249 6.21125359 0.30230072]


 12%|█████████▋                                                                       | 12/100 [05:03<37:16, 25.41s/it]

Iter:    12 | Cost: 0.3824093 | Accuracy: 0.0070922 
[2.25655159 4.11335694 4.98382994 1.83581191 5.09442993 1.98786539
 0.79222249 6.21125359 0.30230072]


 13%|██████████▌                                                                      | 13/100 [05:29<36:55, 25.46s/it]

Iter:    13 | Cost: 0.3839186 | Accuracy: 0.0000000 
[2.27444744 4.1121691  4.98382994 1.83581191 5.09442993 1.98786539
 0.79222249 6.21125359 0.30230072]


 14%|███████████▎                                                                     | 14/100 [05:54<36:20, 25.36s/it]

Iter:    14 | Cost: 0.2741035 | Accuracy: 0.0070922 
[2.30056675 4.10856721 4.98382994 1.83581191 5.09442993 1.98786539
 0.79222249 6.21125359 0.30230072]


 15%|████████████▏                                                                    | 15/100 [06:19<35:51, 25.31s/it]

Iter:    15 | Cost: 0.2404118 | Accuracy: 0.0070922 
[2.32581889 4.10158556 4.98382994 1.83581191 5.09442993 1.98786539
 0.79222249 6.21125359 0.30230072]


 16%|████████████▉                                                                    | 16/100 [06:45<35:31, 25.37s/it]

Iter:    16 | Cost: 0.2113007 | Accuracy: 0.0070922 
[2.3483759  4.09512207 4.98382994 1.83581191 5.09442993 1.98786539
 0.79222249 6.21125359 0.30230072]


 17%|█████████████▊                                                                   | 17/100 [07:10<35:16, 25.50s/it]

Iter:    17 | Cost: 0.2646025 | Accuracy: 0.0070922 
[2.36717619 4.08766807 4.98382994 1.83581191 5.09442993 1.98786539
 0.79222249 6.21125359 0.30230072]


 18%|██████████████▌                                                                  | 18/100 [07:36<34:40, 25.38s/it]

Iter:    18 | Cost: 0.2746545 | Accuracy: 0.0070922 
[2.38389718 4.08082517 4.98382994 1.83581191 5.09442993 1.98786539
 0.79222249 6.21125359 0.30230072]


 19%|███████████████▍                                                                 | 19/100 [08:01<34:19, 25.43s/it]

Iter:    19 | Cost: 0.2879488 | Accuracy: 0.0070922 
[2.38717181 4.06228914 4.98382994 1.83581191 5.09442993 1.98786539
 0.79222249 6.21125359 0.30230072]


 20%|████████████████▏                                                                | 20/100 [08:27<33:56, 25.46s/it]

Iter:    20 | Cost: 0.4676304 | Accuracy: 0.0070922 
[2.37834445 4.04007891 4.98382994 1.83581191 5.09442993 1.98786539
 0.79222249 6.21125359 0.30230072]


 21%|█████████████████                                                                | 21/100 [08:52<33:33, 25.49s/it]

Iter:    21 | Cost: 0.2599924 | Accuracy: 0.0070922 
[2.34265738 3.97106171 4.98382994 1.83581191 5.09442993 1.98786539
 0.79222249 6.21125359 0.30230072]


 22%|█████████████████▊                                                               | 22/100 [09:18<33:13, 25.55s/it]

Iter:    22 | Cost: 0.3930856 | Accuracy: 0.0070922 
[2.34126154 3.94110841 4.98382994 1.83581191 5.09442993 1.98786539
 0.79222249 6.21125359 0.30230072]


 23%|██████████████████▋                                                              | 23/100 [09:43<32:36, 25.41s/it]

Iter:    23 | Cost: 0.2370914 | Accuracy: 0.0070922 
[2.3542325  3.92602226 4.98382994 1.83581191 5.09442993 1.98786539
 0.79222249 6.21125359 0.30230072]


 24%|███████████████████▍                                                             | 24/100 [10:09<32:16, 25.49s/it]

Iter:    24 | Cost: 0.2064538 | Accuracy: 0.0070922 
[2.3678482  3.91255112 4.98382994 1.83581191 5.09442993 1.98786539
 0.79222249 6.21125359 0.30230072]


 25%|████████████████████▎                                                            | 25/100 [10:34<31:40, 25.34s/it]

Iter:    25 | Cost: 0.2073242 | Accuracy: 0.0070922 
[2.36899072 3.88908004 4.98382994 1.83581191 5.09442993 1.98786539
 0.79222249 6.21125359 0.30230072]


 26%|█████████████████████                                                            | 26/100 [10:59<31:16, 25.36s/it]

Iter:    26 | Cost: 0.2162354 | Accuracy: 0.0070922 
[2.35706714 3.86053073 4.98382994 1.83581191 5.09442993 1.98786539
 0.79222249 6.21125359 0.30230072]


 27%|█████████████████████▊                                                           | 27/100 [11:24<30:46, 25.29s/it]

Iter:    27 | Cost: 0.2107130 | Accuracy: 0.0070922 
[2.33191899 3.82509805 4.98382994 1.83581191 5.09442993 1.98786539
 0.79222249 6.21125359 0.30230072]


 28%|██████████████████████▋                                                          | 28/100 [11:50<30:36, 25.51s/it]

Iter:    28 | Cost: 0.2168224 | Accuracy: 0.0070922 
[2.28193567 3.82842038 4.98382994 1.83581191 5.09442993 1.98786539
 0.79222249 6.21125359 0.30230072]


 29%|███████████████████████▍                                                         | 29/100 [12:16<30:09, 25.49s/it]

Iter:    29 | Cost: 0.3045684 | Accuracy: 0.0000000 
[2.24025732 3.85111512 4.98382994 1.83581191 5.09442993 1.98786539
 0.79222249 6.21125359 0.30230072]


 30%|████████████████████████▎                                                        | 30/100 [12:41<29:43, 25.48s/it]

Iter:    30 | Cost: 0.2431649 | Accuracy: 0.0070922 
[2.19169417 3.85535923 4.98382994 1.83581191 5.09442993 1.98786539
 0.79222249 6.21125359 0.30230072]


 31%|█████████████████████████                                                        | 31/100 [13:07<29:29, 25.65s/it]

Iter:    31 | Cost: 0.2104261 | Accuracy: 0.0070922 
[2.14020178 3.86898139 4.98382994 1.83581191 5.09442993 1.98786539
 0.79222249 6.21125359 0.30230072]


 32%|█████████████████████████▉                                                       | 32/100 [13:32<28:58, 25.57s/it]

Iter:    32 | Cost: 0.2067323 | Accuracy: 0.0070922 
[2.09248017 3.86989387 4.98382994 1.83581191 5.09442993 1.98786539
 0.79222249 6.21125359 0.30230072]


 33%|██████████████████████████▋                                                      | 33/100 [13:57<28:16, 25.32s/it]

Iter:    33 | Cost: 0.2101301 | Accuracy: 0.0070922 
[2.03393828 3.85784377 4.98382994 1.83581191 5.09442993 1.98786539
 0.79222249 6.21125359 0.30230072]


 34%|███████████████████████████▌                                                     | 34/100 [14:23<27:53, 25.36s/it]

Iter:    34 | Cost: 0.2095590 | Accuracy: 0.0070922 
[1.97736607 3.81833713 4.98382994 1.83581191 5.09442993 1.98786539
 0.79222249 6.21125359 0.30230072]


 35%|████████████████████████████▎                                                    | 35/100 [14:48<27:31, 25.40s/it]

Iter:    35 | Cost: 0.2151007 | Accuracy: 0.0070922 
[1.91881803 3.76889761 4.98382994 1.83581191 5.09442993 1.98786539
 0.79222249 6.21125359 0.30230072]


 36%|█████████████████████████████▏                                                   | 36/100 [15:14<27:21, 25.64s/it]

Iter:    36 | Cost: 0.2094779 | Accuracy: 0.0070922 
[1.85272248 3.70845731 4.98382994 1.83581191 5.09442993 1.98786539
 0.79222249 6.21125359 0.30230072]


 37%|█████████████████████████████▉                                                   | 37/100 [15:40<26:46, 25.50s/it]

Iter:    37 | Cost: 0.2743035 | Accuracy: 0.0035461 
[1.78513882 3.60533293 4.98382994 1.83581191 5.09442993 1.98786539
 0.79222249 6.21125359 0.30230072]


 38%|██████████████████████████████▊                                                  | 38/100 [16:05<26:15, 25.42s/it]

Iter:    38 | Cost: 0.2309678 | Accuracy: 0.0070922 
[1.69004145 3.51171697 4.98382994 1.83581191 5.09442993 1.98786539
 0.79222249 6.21125359 0.30230072]


 39%|███████████████████████████████▌                                                 | 39/100 [16:30<25:46, 25.35s/it]

Iter:    39 | Cost: 0.4476223 | Accuracy: 0.0070922 
[1.5480735  3.46629553 4.98382994 1.83581191 5.09442993 1.98786539
 0.79222249 6.21125359 0.30230072]


 40%|████████████████████████████████▍                                                | 40/100 [16:55<25:22, 25.37s/it]

Iter:    40 | Cost: 0.2404122 | Accuracy: 0.0070922 
[1.35800422 3.33774097 4.98382994 1.83581191 5.09442993 1.98786539
 0.79222249 6.21125359 0.30230072]


 41%|█████████████████████████████████▏                                               | 41/100 [17:22<25:10, 25.60s/it]

Iter:    41 | Cost: 0.2152262 | Accuracy: 0.0070922 
[1.12806638 3.21920518 4.98382994 1.83581191 5.09442993 1.98786539
 0.79222249 6.21125359 0.30230072]


 42%|██████████████████████████████████                                               | 42/100 [17:46<24:27, 25.31s/it]

Iter:    42 | Cost: 0.2711997 | Accuracy: 0.0070922 
[0.95977537 3.1095063  4.98382994 1.83581191 5.09442993 1.98786539
 0.79222249 6.21125359 0.30230072]


 43%|██████████████████████████████████▊                                              | 43/100 [18:11<24:00, 25.28s/it]

Iter:    43 | Cost: 0.2390017 | Accuracy: 0.0070922 
[0.7843573  2.99794908 4.98382994 1.83581191 5.09442993 1.98786539
 0.79222249 6.21125359 0.30230072]


 44%|███████████████████████████████████▋                                             | 44/100 [18:37<23:45, 25.46s/it]

Iter:    44 | Cost: 0.2383184 | Accuracy: 0.0070922 
[0.64903444 2.89931991 4.98382994 1.83581191 5.09442993 1.98786539
 0.79222249 6.21125359 0.30230072]


 45%|████████████████████████████████████▍                                            | 45/100 [19:03<23:19, 25.44s/it]

Iter:    45 | Cost: 0.3038853 | Accuracy: 0.0070922 
[0.50120184 2.81066956 4.98382994 1.83581191 5.09442993 1.98786539
 0.79222249 6.21125359 0.30230072]


 46%|█████████████████████████████████████▎                                           | 46/100 [19:28<22:57, 25.51s/it]

Iter:    46 | Cost: 0.2374707 | Accuracy: 0.0070922 
[0.33317866 2.68691651 4.98382994 1.83581191 5.09442993 1.98786539
 0.79222249 6.21125359 0.30230072]


 47%|██████████████████████████████████████                                           | 47/100 [19:53<22:22, 25.32s/it]

Iter:    47 | Cost: 0.2068175 | Accuracy: 0.0070922 
[0.19908107 2.5624305  4.98382994 1.83581191 5.09442993 1.98786539
 0.79222249 6.21125359 0.30230072]


 48%|██████████████████████████████████████▉                                          | 48/100 [20:19<22:00, 25.39s/it]

Iter:    48 | Cost: 0.2253548 | Accuracy: 0.0070922 
[0.07239212 2.44979409 4.98382994 1.83581191 5.09442993 1.98786539
 0.79222249 6.21125359 0.30230072]


 49%|███████████████████████████████████████▋                                         | 49/100 [20:44<21:30, 25.30s/it]

Iter:    49 | Cost: 0.2108857 | Accuracy: 0.0070922 
[-0.04146793  2.35186107  4.98382994  1.83581191  5.09442993  1.98786539
  0.79222249  6.21125359  0.30230072]


 50%|████████████████████████████████████████▌                                        | 50/100 [21:09<21:09, 25.40s/it]

Iter:    50 | Cost: 0.2962029 | Accuracy: 0.0035461 
[-0.14546995  2.26569316  4.98382994  1.83581191  5.09442993  1.98786539
  0.79222249  6.21125359  0.30230072]


 51%|█████████████████████████████████████████▎                                       | 51/100 [21:34<20:39, 25.29s/it]

Iter:    51 | Cost: 0.2698070 | Accuracy: 0.0070922 
[-0.23968676  2.18914834  4.98382994  1.83581191  5.09442993  1.98786539
  0.79222249  6.21125359  0.30230072]


 52%|██████████████████████████████████████████                                       | 52/100 [22:00<20:15, 25.32s/it]

Iter:    52 | Cost: 0.2164353 | Accuracy: 0.0070922 
[-0.3207986   2.11397688  4.98382994  1.83581191  5.09442993  1.98786539
  0.79222249  6.21125359  0.30230072]


 53%|██████████████████████████████████████████▉                                      | 53/100 [22:24<19:39, 25.10s/it]

Iter:    53 | Cost: 0.2195823 | Accuracy: 0.0070922 
[-0.39313691  2.04671702  4.98382994  1.83581191  5.09442993  1.98786539
  0.79222249  6.21125359  0.30230072]


 54%|███████████████████████████████████████████▋                                     | 54/100 [22:49<19:11, 25.03s/it]

Iter:    54 | Cost: 0.2186855 | Accuracy: 0.0070922 
[-0.45887562  1.98779409  4.98382994  1.83581191  5.09442993  1.98786539
  0.79222249  6.21125359  0.30230072]


 55%|████████████████████████████████████████████▌                                    | 55/100 [23:15<18:54, 25.21s/it]

Iter:    55 | Cost: 0.2237258 | Accuracy: 0.0070922 
[-0.51805834  1.93958205  4.98382994  1.83581191  5.09442993  1.98786539
  0.79222249  6.21125359  0.30230072]


 56%|█████████████████████████████████████████████▎                                   | 56/100 [23:40<18:27, 25.18s/it]

Iter:    56 | Cost: 0.3306610 | Accuracy: 0.0070922 
[-0.57129439  1.89619064  4.98382994  1.83581191  5.09442993  1.98786539
  0.79222249  6.21125359  0.30230072]


 57%|██████████████████████████████████████████████▏                                  | 57/100 [24:05<18:03, 25.21s/it]

Iter:    57 | Cost: 0.2379197 | Accuracy: 0.0070922 
[-0.61870273  1.85650502  4.98382994  1.83581191  5.09442993  1.98786539
  0.79222249  6.21125359  0.30230072]


 58%|██████████████████████████████████████████████▉                                  | 58/100 [24:30<17:37, 25.17s/it]

Iter:    58 | Cost: 0.2062213 | Accuracy: 0.0070922 
[-0.66134567  1.82074742  4.98382994  1.83581191  5.09442993  1.98786539
  0.79222249  6.21125359  0.30230072]


 59%|███████████████████████████████████████████████▊                                 | 59/100 [24:56<17:17, 25.32s/it]

Iter:    59 | Cost: 0.2120286 | Accuracy: 0.0070922 
[-0.69972135  1.78902459  4.98382994  1.83581191  5.09442993  1.98786539
  0.79222249  6.21125359  0.30230072]


 60%|████████████████████████████████████████████████▌                                | 60/100 [25:22<17:00, 25.50s/it]

Iter:    60 | Cost: 0.2169399 | Accuracy: 0.0070922 
[-0.73425677  1.76046593  4.98382994  1.83581191  5.09442993  1.98786539
  0.79222249  6.21125359  0.30230072]


 61%|█████████████████████████████████████████████████▍                               | 61/100 [25:47<16:27, 25.31s/it]

Iter:    61 | Cost: 0.2729873 | Accuracy: 0.0070922 
[-0.76536939  1.73494327  4.98382994  1.83581191  5.09442993  1.98786539
  0.79222249  6.21125359  0.30230072]


 62%|██████████████████████████████████████████████████▏                              | 62/100 [26:12<15:59, 25.24s/it]

Iter:    62 | Cost: 0.2101813 | Accuracy: 0.0070922 
[-0.79335955  1.71213443  4.98382994  1.83581191  5.09442993  1.98786539
  0.79222249  6.21125359  0.30230072]


 63%|███████████████████████████████████████████████████                              | 63/100 [26:38<15:39, 25.40s/it]

Iter:    63 | Cost: 0.2813157 | Accuracy: 0.0070922 
[-0.81855344  1.69180452  4.98382994  1.83581191  5.09442993  1.98786539
  0.79222249  6.21125359  0.30230072]


 64%|███████████████████████████████████████████████████▊                             | 64/100 [27:03<15:09, 25.27s/it]

Iter:    64 | Cost: 0.2291551 | Accuracy: 0.0070922 
[-0.84122767  1.67354435  4.98382994  1.83581191  5.09442993  1.98786539
  0.79222249  6.21125359  0.30230072]


 65%|████████████████████████████████████████████████████▋                            | 65/100 [27:28<14:46, 25.32s/it]

Iter:    65 | Cost: 0.3550675 | Accuracy: 0.0070922 
[-0.86163414  1.65723729  4.98382994  1.83581191  5.09442993  1.98786539
  0.79222249  6.21125359  0.30230072]


 66%|█████████████████████████████████████████████████████▍                           | 66/100 [27:52<14:06, 24.90s/it]

Iter:    66 | Cost: 0.2513929 | Accuracy: 0.0070922 
[-0.87999993  1.64256435  4.98382994  1.83581191  5.09442993  1.98786539
  0.79222249  6.21125359  0.30230072]


 67%|██████████████████████████████████████████████████████▎                          | 67/100 [28:17<13:41, 24.90s/it]

Iter:    67 | Cost: 0.2097778 | Accuracy: 0.0070922 
[-0.89652868  1.62937084  4.98382994  1.83581191  5.09442993  1.98786539
  0.79222249  6.21125359  0.30230072]


 68%|███████████████████████████████████████████████████████                          | 68/100 [28:43<13:26, 25.20s/it]

Iter:    68 | Cost: 0.2079742 | Accuracy: 0.0070922 
[-0.91140452  1.61749759  4.98382994  1.83581191  5.09442993  1.98786539
  0.79222249  6.21125359  0.30230072]


 69%|███████████████████████████████████████████████████████▉                         | 69/100 [29:08<13:01, 25.22s/it]

Iter:    69 | Cost: 0.2150597 | Accuracy: 0.0070922 
[-0.92479283  1.60681333  4.98382994  1.83581191  5.09442993  1.98786539
  0.79222249  6.21125359  0.30230072]


 70%|████████████████████████████████████████████████████████▋                        | 70/100 [29:34<12:40, 25.36s/it]

Iter:    70 | Cost: 0.3077411 | Accuracy: 0.0070922 
[-0.93684235  1.59719641  4.98382994  1.83581191  5.09442993  1.98786539
  0.79222249  6.21125359  0.30230072]


 71%|█████████████████████████████████████████████████████████▌                       | 71/100 [29:59<12:13, 25.28s/it]

Iter:    71 | Cost: 0.3394866 | Accuracy: 0.0000000 
[-0.94768693  1.588543    4.98382994  1.83581191  5.09442993  1.98786539
  0.79222249  6.21125359  0.30230072]


 72%|██████████████████████████████████████████████████████████▎                      | 72/100 [30:24<11:45, 25.19s/it]

Iter:    72 | Cost: 0.6060033 | Accuracy: 0.0070922 
[-0.95744705  1.58075507  4.98382994  1.83581191  5.09442993  1.98786539
  0.79222249  6.21125359  0.30230072]


 73%|███████████████████████████████████████████████████████████▏                     | 73/100 [30:49<11:19, 25.18s/it]

Iter:    73 | Cost: 0.2792885 | Accuracy: 0.0070922 
[-0.96623115  1.57374594  4.98382994  1.83581191  5.09442993  1.98786539
  0.79222249  6.21125359  0.30230072]


 74%|███████████████████████████████████████████████████████████▉                     | 74/100 [31:16<11:12, 25.85s/it]

Iter:    74 | Cost: 0.4287127 | Accuracy: 0.0070922 
[-0.97413685  1.56743773  4.98382994  1.83581191  5.09442993  1.98786539
  0.79222249  6.21125359  0.30230072]


 75%|████████████████████████████████████████████████████████████▊                    | 75/100 [31:42<10:43, 25.75s/it]

Iter:    75 | Cost: 0.2565541 | Accuracy: 0.0070922 
[-0.98125197  1.5617604   4.98382994  1.83581191  5.09442993  1.98786539
  0.79222249  6.21125359  0.30230072]


 76%|█████████████████████████████████████████████████████████████▌                   | 76/100 [32:07<10:13, 25.56s/it]

Iter:    76 | Cost: 0.2336161 | Accuracy: 0.0070922 
[-0.98765559  1.55665051  4.98382994  1.83581191  5.09442993  1.98786539
  0.79222249  6.21125359  0.30230072]


 77%|██████████████████████████████████████████████████████████████▎                  | 77/100 [32:32<09:45, 25.44s/it]

Iter:    77 | Cost: 0.4548575 | Accuracy: 0.0070922 
[-0.99341884  1.55205239  4.98382994  1.83581191  5.09442993  1.98786539
  0.79222249  6.21125359  0.30230072]


 78%|███████████████████████████████████████████████████████████████▏                 | 78/100 [32:57<09:16, 25.31s/it]

Iter:    78 | Cost: 0.2063336 | Accuracy: 0.0070922 
[-0.99860577  1.54791429  4.98382994  1.83581191  5.09442993  1.98786539
  0.79222249  6.21125359  0.30230072]


In [120]:
opt_weights = weights

In [121]:
opt_weights

tensor([0.21097864, 4.56547731, 3.90377161, 3.99236746, 0.13792279,
        4.8763316 , 5.20257472, 4.96202827, 4.93813199], requires_grad=True)