In [31]:
import gym
import torch
from torch import nn, optim
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
from copy import deepcopy
import random
from statistics import mean

In [2]:
print(torch.cuda.is_available())

True


In [3]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [90]:
class Network(nn.Module):
    def __init__(self,seed):
        super(Network, self).__init__()
        self.seed = torch.manual_seed(seed)
        self.fc1 = nn.Linear(8, 256)
        self.fc2 = nn.Linear(256,128)
        self.fc3 = nn.Linear(128,64)
        self.fc4 = nn.Linear(64, 4)
    
    def forward(self, state):
        h = F.relu(self.fc1(state))
        h = F.relu(self.fc2(h))
        h = F.relu(self.fc3(h))
        y = self.fc4(h)
        return y

In [114]:
class QLearningAgent(object):
    def __init__(self,alpha,gamma,epsilon,n_eps,N,C,M,seed):
        self.memory = []
        self.memory_max = N
        self.target_update = C
        self.Q_t = Network(seed).to(device)
        self.Q = Network(seed).to(device)
        self.alpha = alpha
        self.optimizer = optim.SGD(self.Q.parameters(), lr=self.alpha)
        self.gamma = gamma
        self.epsilon = epsilon
        self.seed = seed
        self.C = C
        self.n_eps = n_eps
        self.mini_batch_size = M
        self.env = gym.make('LunarLander-v2')
        self.env.seed(seed)
    
    def store_memory(self,state,action,reward,next_state,done):
        reward = np.array([reward],dtype = float)
        action = np.array([action],dtype = int)
        done = np.array([done],dtype = int)
        self.memory.append((state,action,reward,next_state,done))
    
    def sample_memory(self,M):
        batch = np.array(random.sample(self.memory, k=M),dtype = object)
        batch = batch.T
        batch = batch.tolist()
#         print(batch[1])
        return (torch.tensor(batch[0]).to(device),torch.tensor(batch[1],dtype=torch.int64).to(device),torch.tensor(batch[2],dtype = torch.float).to(device),torch.tensor(batch[3]).to(device),torch.tensor(batch[4]).to(device))
    
    def solve(self):
        states = self.env.observation_space
        actions = self.env.action_space
        np.random.seed(self.seed)
        count = 0
        scores = []
        for eps in range(self.n_eps):
            state = self.env.reset()
            score = 0
            for i in range(1000000):
                greed = np.random.random()
                #Feed Forward once to predict the best action for current state
                self.Q.eval()
                with torch.no_grad():
                    weights = self.Q(torch.tensor(state).to(device))
#                     print(weights)
                self.Q.train()
                if greed < self.epsilon:
                    action = np.random.randint(0, 4)
                else:
                    action = np.argmax(weights.detach().cpu().numpy())
                next_state, reward, done, data = self.env.step(action)
                score+=reward
                self.store_memory(state,action,reward,next_state,done)
                
                if len(self.memory)<self.mini_batch_size:
                    break
                else:
                    transitions = self.sample_memory(self.mini_batch_size)
                
                states,actions,rewards,next_states,dones = transitions
                Q_t = self.Q_t(next_states).detach()
                Q_tmax = Q_t.max(1)[0].unsqueeze(1)
                case2 = rewards + self.gamma * Q_tmax
                case1 = rewards
#                 y_j = reward + (gamma * Q_tmax * (1-done))  
                y = torch.where(dones<1,case2,case1)    
                Q = self.Q(states).gather(1, actions)
                print(y)
                print(Q)
                loss = F.mse_loss(Q, y)
#                 print(loss)
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()
                count+=1
                if count == self.C:
                    count = 0
#                     self.Q_t = deepcopy(self.Q)
                    self.Q_t.load_state_dict(self.Q.state_dict())
                state = deepcopy(next_state)
                if done:
                    break
            scores.append(score)
        print(mean(scores))


In [115]:
agent = QLearningAgent(0.001,0.99,0.7,300,300,50,64,5)

In [116]:
agent.solve()

tensor([[-2.8034],
        [ 0.1124],
        [-2.0223],
        [ 0.8182],
        [-1.5688],
        [-0.7718],
        [ 0.6705],
        [ 0.0831],
        [-2.0394],
        [-1.8530],
        [ 1.4764],
        [-0.6657],
        [-0.4819],
        [ 2.2737],
        [-2.4323],
        [-1.9792],
        [ 3.3549],
        [-1.6846],
        [-1.9973],
        [ 1.9711],
        [-1.5296],
        [-1.5450],
        [-1.9230],
        [-2.0575],
        [ 0.6077],
        [-1.7831],
        [-0.6417],
        [ 1.5677],
        [-1.1218],
        [ 0.2709],
        [-1.9832],
        [-0.7079],
        [ 0.5161],
        [ 1.7517],
        [ 0.2448],
        [ 0.6231],
        [-0.7676],
        [ 1.7187],
        [ 0.4347],
        [-1.0096],
        [-0.8431],
        [-1.7115],
        [ 1.5441],
        [ 0.4616],
        [-1.8827],
        [ 0.3167],
        [-1.5262],
        [-0.6446],
        [-1.5263],
        [-1.1352],
        [-1.0047],
        [-0.8245],
        [ 1.

tensor([[ 0.0451],
        [ 0.0192],
        [ 0.0434],
        [ 0.0195],
        [ 0.0429],
        [-0.1406],
        [ 0.0498],
        [ 0.0487],
        [ 0.0407],
        [ 0.0300],
        [ 0.0406],
        [ 0.0139],
        [-0.0054],
        [ 0.0146],
        [ 0.0515],
        [ 0.0491],
        [ 0.0145],
        [-0.0024],
        [-0.1541],
        [ 0.0176],
        [ 0.0332],
        [-0.0112],
        [-0.0055],
        [-0.0057],
        [ 0.0360],
        [-0.1610],
        [-0.1412],
        [ 0.0459],
        [ 0.0435],
        [-0.0040],
        [ 0.0315],
        [-0.0062],
        [ 0.0229],
        [ 0.0258],
        [ 0.0220],
        [ 0.0444],
        [-0.0041],
        [ 0.0475],
        [ 0.0484],
        [ 0.0304],
        [ 0.0522],
        [ 0.0220],
        [ 0.0393],
        [ 0.0096],
        [-0.0032],
        [ 0.0435],
        [-0.1297],
        [ 0.0440],
        [ 0.0505],
        [-0.0067],
        [ 0.0500],
        [ 0.0495],
        [ 0.

        [-1.0047]], device='cuda:0')
tensor([[ 0.0427],
        [ 0.0491],
        [ 0.0463],
        [ 0.0494],
        [-0.1317],
        [ 0.0492],
        [ 0.0249],
        [-0.0147],
        [ 0.0394],
        [-0.0138],
        [-0.1427],
        [ 0.0255],
        [ 0.0339],
        [ 0.0328],
        [ 0.0436],
        [-0.1362],
        [-0.0100],
        [ 0.0296],
        [ 0.0521],
        [ 0.0431],
        [-0.0159],
        [ 0.0417],
        [-0.0106],
        [ 0.0192],
        [-0.1172],
        [-0.0093],
        [-0.0140],
        [-0.0139],
        [-0.0183],
        [ 0.0342],
        [ 0.0182],
        [ 0.0124],
        [ 0.0523],
        [ 0.0483],
        [-0.0085],
        [ 0.0108],
        [-0.0103],
        [-0.0110],
        [-0.0133],
        [ 0.0382],
        [ 0.0188],
        [-0.0094],
        [-0.0115],
        [-0.1304],
        [ 0.0500],
        [ 0.0247],
        [-0.0140],
        [ 0.0148],
        [-0.0136],
        [ 0.0450],
        [ 0.0

tensor([[-1.1352],
        [-1.5688],
        [-1.7831],
        [ 0.5537],
        [-0.8431],
        [-0.8245],
        [ 0.6870],
        [ 1.9711],
        [-0.5570],
        [-2.4323],
        [ 0.0831],
        [-0.5670],
        [ 0.4616],
        [-1.5262],
        [ 0.6705],
        [-1.1821],
        [-1.8827],
        [ 0.4022],
        [-0.7718],
        [ 1.3948],
        [ 1.1600],
        [-1.7584],
        [ 3.3549],
        [-1.1218],
        [ 0.2069],
        [-0.7676],
        [ 1.3502],
        [ 1.4764],
        [-1.9295],
        [-2.6403],
        [-3.5008],
        [-0.5040],
        [ 1.0774],
        [ 0.0957],
        [ 1.7517],
        [-1.8409],
        [-1.0096],
        [-2.3148],
        [ 0.6231],
        [ 1.7187],
        [-2.0965],
        [-1.8530],
        [-1.9230],
        [-1.9832],
        [ 0.2448],
        [-2.0223],
        [ 2.2737],
        [-0.6657],
        [ 0.0563],
        [ 1.1011],
        [-0.6446],
        [-1.6846],
        [ 0.

tensor([[-0.8431],
        [-2.6403],
        [-0.3783],
        [-1.5688],
        [-1.5263],
        [-0.5040],
        [ 0.5537],
        [-3.5008],
        [ 0.6231],
        [-1.7584],
        [-2.3148],
        [ 0.2138],
        [-0.7718],
        [-2.8999],
        [ 0.3167],
        [-0.6446],
        [ 3.3549],
        [ 0.0563],
        [ 0.0831],
        [-1.8409],
        [ 1.6637],
        [-0.2519],
        [ 1.5441],
        [-1.7920],
        [ 0.8182],
        [ 1.1600],
        [ 0.4347],
        [-2.0575],
        [ 2.2973],
        [ 0.0957],
        [-2.8034],
        [-0.4819],
        [-1.7831],
        [ 2.2737],
        [ 0.3849],
        [-1.3660],
        [-2.0965],
        [ 1.0774],
        [-1.0096],
        [ 0.2709],
        [-1.1476],
        [ 1.2753],
        [ 0.1384],
        [-0.6417],
        [-1.6846],
        [ 0.5773],
        [-0.9090],
        [-1.7115],
        [-0.5670],
        [-0.7676],
        [ 1.4764],
        [-1.6344],
        [ 0.

tensor([[ 1.3948],
        [-2.0965],
        [-1.8640],
        [-0.6657],
        [-1.5688],
        [-0.8431],
        [-1.1218],
        [ 0.6231],
        [-1.0975],
        [-2.0394],
        [ 3.3549],
        [ 1.8051],
        [-0.3783],
        [-1.5296],
        [ 0.6870],
        [-0.2519],
        [-1.1352],
        [-2.0223],
        [-0.3847],
        [-2.6510],
        [-0.8245],
        [-0.5570],
        [-0.8082],
        [-0.5040],
        [-1.5262],
        [ 0.1124],
        [ 1.1563],
        [-2.6403],
        [-1.7584],
        [ 2.2737],
        [-0.7718],
        [-1.7115],
        [-1.5450],
        [-2.0575],
        [-2.8034],
        [ 1.5441],
        [ 1.1011],
        [ 0.0831],
        [-0.5467],
        [ 1.5677],
        [ 1.9711],
        [ 3.3671],
        [-2.5800],
        [-0.3988],
        [-2.4323],
        [-1.0096],
        [-1.9792],
        [-2.8999],
        [ 0.2438],
        [-1.7920],
        [-1.7194],
        [ 0.6077],
        [ 1.

        [ 0.0360]], device='cuda:0', grad_fn=<GatherBackward>)
tensor([[ 0.6626],
        [-1.3784],
        [ 1.1481],
        [-0.5714],
        [ 0.6757],
        [-0.8287],
        [ 1.5427],
        [-1.1073],
        [-0.9192],
        [ 0.3166],
        [-1.1662],
        [-0.4574],
        [ 0.4288],
        [ 0.1056],
        [-2.0476],
        [ 0.2338],
        [-1.7875],
        [-0.4906],
        [-1.1566],
        [-0.6502],
        [-2.0291],
        [-1.2284],
        [-1.6906],
        [-1.7284],
        [ 2.2926],
        [ 0.2897],
        [-1.7146],
        [-1.1942],
        [ 0.3734],
        [-1.6044],
        [-0.5681],
        [ 0.6228],
        [ 1.5645],
        [ 3.3642],
        [-1.9896],
        [-1.5307],
        [ 0.3901],
        [ 1.7156],
        [-1.8762],
        [-0.5549],
        [-0.3903],
        [ 1.2637],
        [-2.6603],
        [ 1.9666],
        [ 0.0840],
        [-0.2644],
        [-1.1059],
        [-0.2797],
        [ 0.1269],
      

tensor([[ 0.0115],
        [ 0.0036],
        [ 0.0338],
        [ 0.0399],
        [-0.1388],
        [ 0.0382],
        [-0.0300],
        [-0.0412],
        [ 0.0324],
        [ 0.0331],
        [-0.1566],
        [ 0.0384],
        [ 0.0156],
        [ 0.0294],
        [ 0.0280],
        [ 0.0262],
        [ 0.0443],
        [-0.0403],
        [ 0.0321],
        [ 0.0363],
        [ 0.0314],
        [ 0.0326],
        [-0.0416],
        [ 0.0376],
        [ 0.0240],
        [-0.0039],
        [-0.1398],
        [ 0.0325],
        [ 0.0131],
        [ 0.0165],
        [-0.1614],
        [-0.1521],
        [-0.0325],
        [-0.1434],
        [ 0.0211],
        [ 0.0324],
        [ 0.0295],
        [ 0.0269],
        [-0.0393],
        [-0.1547],
        [ 0.0333],
        [ 0.0228],
        [ 0.0389],
        [ 0.0386],
        [ 0.0433],
        [-0.0331],
        [ 0.0380],
        [ 0.0368],
        [ 0.0356],
        [-0.0031],
        [ 0.0365],
        [ 0.0337],
        [ 0.

tensor([[ 2.0349e-02],
        [-1.5176e-01],
        [ 3.2194e-02],
        [ 3.2691e-02],
        [ 2.1382e-02],
        [ 1.6111e-02],
        [ 1.3202e-02],
        [ 1.8192e-02],
        [-4.4531e-02],
        [ 3.3577e-02],
        [-1.6859e-01],
        [-4.4128e-02],
        [ 3.2515e-02],
        [-8.6310e-03],
        [-1.8372e-01],
        [ 2.6897e-02],
        [ 3.4499e-02],
        [-1.6672e-01],
        [-1.5740e-01],
        [ 1.5274e-02],
        [ 2.8208e-02],
        [ 4.0414e-02],
        [ 7.1284e-03],
        [ 3.4810e-02],
        [ 3.4128e-02],
        [-3.1362e-02],
        [-4.4840e-02],
        [ 1.7305e-02],
        [-3.5189e-02],
        [ 2.3953e-02],
        [-1.4464e-01],
        [ 2.3128e-02],
        [ 7.0542e-05],
        [ 2.5115e-02],
        [-3.6847e-02],
        [ 1.8338e-02],
        [ 1.6053e-02],
        [-3.2293e-02],
        [-1.4189e-01],
        [ 2.1842e-02],
        [ 3.7116e-03],
        [ 2.2001e-02],
        [ 3.2166e-02],
        [-3

tensor([[-0.0384],
        [-0.1715],
        [-0.0458],
        [ 0.0273],
        [ 0.0180],
        [ 0.0146],
        [ 0.0258],
        [-0.1733],
        [-0.1636],
        [ 0.0186],
        [ 0.0293],
        [ 0.0148],
        [-0.0438],
        [ 0.0094],
        [-0.1499],
        [ 0.0270],
        [ 0.0120],
        [ 0.0278],
        [ 0.0295],
        [-0.0474],
        [ 0.0141],
        [ 0.0310],
        [ 0.0162],
        [ 0.0182],
        [-0.0476],
        [ 0.0193],
        [-0.0172],
        [ 0.0299],
        [ 0.0191],
        [ 0.0252],
        [-0.0386],
        [-0.0400],
        [-0.0181],
        [-0.0377],
        [ 0.0263],
        [ 0.0100],
        [ 0.0274],
        [ 0.0103],
        [ 0.0237],
        [-0.0352],
        [ 0.0155],
        [-0.0146],
        [-0.0482],
        [ 0.0359],
        [ 0.0191],
        [ 0.0039],
        [ 0.0097],
        [-0.0467],
        [-0.0417],
        [-0.0391],
        [-0.1384],
        [-0.1816],
        [ 0.

tensor([[-0.1777],
        [ 0.0028],
        [-0.1567],
        [ 0.0088],
        [-0.0449],
        [ 0.0221],
        [ 0.0202],
        [ 0.0197],
        [ 0.0087],
        [ 0.0047],
        [-0.0141],
        [-0.0182],
        [ 0.0230],
        [-0.0422],
        [ 0.0180],
        [-0.0242],
        [-0.0104],
        [ 0.0042],
        [-0.0511],
        [ 0.0098],
        [-0.1535],
        [-0.0162],
        [ 0.0058],
        [ 0.0232],
        [ 0.0113],
        [ 0.0049],
        [ 0.0031],
        [ 0.0160],
        [ 0.0042],
        [ 0.0036],
        [-0.0437],
        [ 0.0102],
        [-0.0039],
        [ 0.0043],
        [ 0.0057],
        [ 0.0182],
        [ 0.0103],
        [-0.1795],
        [-0.0424],
        [ 0.0160],
        [-0.1577],
        [ 0.0121],
        [ 0.0065],
        [ 0.0203],
        [-0.0386],
        [-0.0440],
        [-0.0414],
        [ 0.0212],
        [-0.1444],
        [ 0.0076],
        [ 0.0058],
        [-0.1736],
        [ 0.

tensor([[ 0.0052],
        [-0.0049],
        [-0.0033],
        [-0.0068],
        [-0.0066],
        [-0.0030],
        [ 0.0013],
        [ 0.0165],
        [-0.0469],
        [-0.0543],
        [-0.0006],
        [-0.1862],
        [-0.0155],
        [ 0.0044],
        [-0.1417],
        [-0.0004],
        [ 0.0028],
        [-0.0440],
        [ 0.0099],
        [ 0.0141],
        [-0.0080],
        [ 0.0079],
        [-0.0537],
        [-0.0014],
        [-0.1809],
        [-0.0040],
        [-0.0012],
        [ 0.0114],
        [ 0.0093],
        [-0.0022],
        [-0.0445],
        [-0.0289],
        [-0.0501],
        [-0.0544],
        [-0.1640],
        [-0.1650],
        [-0.1760],
        [-0.1763],
        [-0.0029],
        [-0.0017],
        [-0.0519],
        [-0.1587],
        [ 0.0049],
        [ 0.0009],
        [ 0.0156],
        [-0.1787],
        [-0.1607],
        [-0.0326],
        [-0.0494],
        [-0.1640],
        [ 0.0013],
        [ 0.0142],
        [-0.

tensor([[-4.5273e-04],
        [-5.6679e-02],
        [-9.3359e-03],
        [-1.8384e-01],
        [-1.5677e-01],
        [-8.1034e-03],
        [ 7.5787e-03],
        [ 3.8996e-03],
        [-2.6963e-03],
        [-1.5669e-01],
        [-7.4123e-05],
        [-1.8468e-01],
        [-1.8807e-01],
        [ 7.4785e-03],
        [ 2.3394e-03],
        [-1.3496e-02],
        [-8.0566e-03],
        [ 8.5754e-03],
        [-5.5620e-03],
        [ 1.0465e-02],
        [-1.7895e-01],
        [-4.2964e-03],
        [ 1.4987e-02],
        [-4.7756e-02],
        [-5.2657e-02],
        [-5.2960e-02],
        [-3.1543e-03],
        [-1.6845e-01],
        [-5.6413e-02],
        [ 7.1333e-03],
        [-5.0127e-02],
        [-4.4744e-03],
        [-4.2046e-02],
        [-1.7196e-01],
        [-8.7439e-03],
        [ 6.7079e-03],
        [-1.6417e-02],
        [-1.7178e-02],
        [-1.7296e-01],
        [ 1.5713e-02],
        [-1.9188e-01],
        [ 2.3241e-03],
        [-4.8662e-02],
        [-4

tensor([[-6.4613e-02],
        [-1.8060e-01],
        [-2.7242e-02],
        [-5.2735e-02],
        [ 1.1390e-02],
        [-8.9634e-03],
        [-1.3990e-02],
        [-2.4573e-02],
        [-2.4293e-02],
        [-2.5953e-03],
        [-2.0877e-01],
        [ 2.4389e-05],
        [-2.1471e-01],
        [-1.0486e-02],
        [ 9.2707e-03],
        [-1.6624e-02],
        [ 1.1980e-02],
        [-7.2774e-03],
        [-5.8654e-02],
        [-5.2098e-02],
        [-1.8061e-01],
        [-8.4962e-03],
        [-1.9104e-02],
        [-1.0357e-02],
        [-1.6340e-02],
        [-1.4041e-02],
        [-8.2974e-03],
        [-5.9252e-02],
        [-4.7091e-02],
        [-7.6213e-03],
        [-7.9108e-03],
        [-1.5042e-02],
        [-1.6019e-02],
        [-1.2583e-02],
        [-1.6449e-02],
        [-4.5481e-02],
        [-1.9233e-01],
        [-1.0015e-03],
        [-1.5363e-01],
        [-5.1363e-02],
        [-1.8724e-01],
        [-1.5122e-02],
        [-1.9228e-01],
        [-1

        [-0.0724]], device='cuda:0', grad_fn=<GatherBackward>)
tensor([[-0.8098],
        [-2.6603],
        [-0.8051],
        [-0.1670],
        [ 1.1030],
        [ 0.6311],
        [-0.7383],
        [-1.9239],
        [-0.5970],
        [ 0.2002],
        [-1.4557],
        [-0.5647],
        [-1.3777],
        [-3.5681],
        [ 0.5619],
        [-1.9605],
        [ 0.1623],
        [ 1.2263],
        [-1.4889],
        [-0.7003],
        [-2.6829],
        [-1.5237],
        [ 1.4356],
        [ 0.0178],
        [ 1.3591],
        [ 0.4718],
        [-1.1914],
        [ 1.3002],
        [-2.7311],
        [-2.4414],
        [-1.1533],
        [ 0.7955],
        [ 2.2566],
        [-1.3981],
        [-1.8787],
        [ 0.0460],
        [-1.7608],
        [-0.4414],
        [ 0.7834],
        [-1.1414],
        [-3.0202],
        [-1.7457],
        [-2.6214],
        [-1.5039],
        [-1.6090],
        [-2.7046],
        [ 0.2057],
        [-2.1404],
        [-1.5622],
      

        [-2.8376]], device='cuda:0')
tensor([[-0.1735],
        [-0.0512],
        [-0.0179],
        [-0.0609],
        [-0.0543],
        [-0.0173],
        [-0.0484],
        [-0.0192],
        [-0.0574],
        [-0.0218],
        [-0.0555],
        [-0.0250],
        [-0.0622],
        [-0.0228],
        [-0.0494],
        [-0.0614],
        [-0.0459],
        [-0.0753],
        [-0.0146],
        [-0.0276],
        [-0.0317],
        [-0.1814],
        [-0.0261],
        [-0.0199],
        [-0.1694],
        [-0.1923],
        [-0.0264],
        [-0.0280],
        [-0.1682],
        [-0.0597],
        [-0.0267],
        [-0.0177],
        [-0.0620],
        [-0.1841],
        [-0.2074],
        [-0.0242],
        [-0.0571],
        [-0.0199],
        [-0.0560],
        [-0.0207],
        [-0.0253],
        [-0.0551],
        [-0.2278],
        [-0.0293],
        [-0.1956],
        [-0.0309],
        [-0.0284],
        [-0.0264],
        [-0.2152],
        [-0.1822],
        [-0.0

tensor([[-2.7046],
        [ 0.0945],
        [-1.3436],
        [-3.4845],
        [-2.4414],
        [ 0.5339],
        [ 0.5063],
        [-3.9299],
        [-1.8363],
        [-2.0259],
        [-2.0218],
        [-2.6928],
        [-3.4482],
        [-0.1670],
        [-0.8953],
        [-1.1914],
        [ 3.3281],
        [ 0.3551],
        [-2.4526],
        [-3.0397],
        [-0.6045],
        [ 1.6858],
        [-0.8051],
        [-1.9613],
        [-3.5681],
        [ 1.5280],
        [ 1.7532],
        [-2.6214],
        [-4.0642],
        [-3.5758],
        [-1.4285],
        [-1.2615],
        [-1.9239],
        [ 0.4718],
        [-1.6090],
        [-3.0828],
        [-2.6371],
        [ 1.3002],
        [-1.2119],
        [-2.9446],
        [-3.5265],
        [-3.2481],
        [-2.8376],
        [-5.9195],
        [ 0.7955],
        [-0.2969],
        [-1.1923],
        [ 1.0499],
        [ 0.2902],
        [-2.9608],
        [-2.6745],
        [-1.8793],
        [-3.

        [-2.0622e+00]], device='cuda:0')
tensor([[-0.0607],
        [-0.0482],
        [-0.0724],
        [-0.0576],
        [-0.0581],
        [-0.0449],
        [-0.0477],
        [-0.0513],
        [-0.0485],
        [-0.0467],
        [-0.0448],
        [-0.2063],
        [-0.0615],
        [-0.0613],
        [-0.0499],
        [-0.0558],
        [-0.0795],
        [-0.0496],
        [-0.0520],
        [-0.0434],
        [-0.0647],
        [-0.0651],
        [-0.0640],
        [-0.0632],
        [-0.0427],
        [-0.2204],
        [-0.0729],
        [-0.0602],
        [-0.0449],
        [-0.0595],
        [-0.0764],
        [-0.0501],
        [-0.2407],
        [-0.1451],
        [-0.0454],
        [-0.2290],
        [-0.0483],
        [-0.0583],
        [-0.0465],
        [-0.0527],
        [-0.0527],
        [-0.0487],
        [-0.0827],
        [-0.0534],
        [-0.0533],
        [-0.0514],
        [-0.0516],
        [-0.0472],
        [-0.0927],
        [-0.0586],
        [

tensor([[-0.0682],
        [-0.0678],
        [-0.0665],
        [-0.0772],
        [-0.0700],
        [-0.0760],
        [-0.0973],
        [-0.0479],
        [-0.0557],
        [-0.2418],
        [-0.0771],
        [-0.0749],
        [-0.0782],
        [-0.0681],
        [-0.0668],
        [-0.0676],
        [-0.0725],
        [-0.0700],
        [-0.0745],
        [-0.0717],
        [-0.2232],
        [-0.0729],
        [-0.0826],
        [-0.2142],
        [-0.0691],
        [-0.0775],
        [-0.0842],
        [-0.0819],
        [-0.2327],
        [-0.2426],
        [-0.0677],
        [-0.1999],
        [-0.2282],
        [-0.0645],
        [-0.0739],
        [-0.0695],
        [-0.0736],
        [-0.0837],
        [-0.2272],
        [-0.0631],
        [-0.0640],
        [-0.2170],
        [-0.0627],
        [-0.0643],
        [-0.0769],
        [-0.0591],
        [-0.2065],
        [-0.0678],
        [-0.0731],
        [-0.0835],
        [-0.0596],
        [-0.0765],
        [-0.

tensor([[ -2.5511],
        [ -0.8292],
        [ -0.5653],
        [  0.7190],
        [ -1.4860],
        [ -1.9416],
        [ -4.0497],
        [ -2.8582],
        [ -2.1164],
        [  0.2091],
        [ -4.3796],
        [  0.5665],
        [ -1.5128],
        [  0.2647],
        [ -0.6637],
        [ -1.6422],
        [ -1.8380],
        [ -1.0277],
        [ -1.2638],
        [ -2.9111],
        [ -1.2162],
        [ -2.3107],
        [ -1.5721],
        [  1.7210],
        [-24.6958],
        [  1.5441],
        [ -2.0245],
        [ -2.7986],
        [  3.2657],
        [ -6.0265],
        [ -1.5807],
        [ -1.7588],
        [  0.3399],
        [ -0.7969],
        [ -1.2656],
        [ -1.8767],
        [ -0.8655],
        [ -1.4797],
        [ -1.5560],
        [ -0.9485],
        [ -7.0513],
        [ -3.1975],
        [ -0.4974],
        [ -1.8259],
        [ -0.0391],
        [ -2.1970],
        [ -1.4067],
        [ -2.5117],
        [ -1.2863],
        [ -3.0311],


        [-2.0245e+00]], device='cuda:0')
tensor([[-0.1059],
        [-0.1006],
        [-0.2615],
        [-0.1019],
        [-0.1230],
        [-0.0911],
        [-0.2324],
        [-0.0903],
        [-0.0733],
        [-0.1212],
        [-0.0845],
        [-0.1197],
        [-0.2633],
        [-0.0931],
        [-0.1236],
        [-0.0748],
        [-0.0945],
        [-0.0908],
        [-0.0961],
        [-0.1134],
        [-0.1305],
        [-0.1088],
        [-0.0949],
        [-0.0734],
        [-0.1029],
        [-0.0661],
        [-0.1045],
        [-0.0979],
        [-0.0964],
        [-0.0973],
        [-0.0979],
        [-0.1015],
        [-0.1154],
        [-0.1008],
        [-0.1819],
        [-0.1080],
        [-0.1008],
        [-0.2414],
        [-0.1172],
        [-0.0924],
        [-0.1008],
        [-0.1047],
        [-0.0882],
        [-0.1089],
        [-0.2715],
        [-0.1105],
        [-0.1058],
        [-0.2643],
        [-0.1003],
        [-0.1088],
        [

tensor([[-0.1154],
        [-0.1134],
        [-0.1113],
        [-0.2622],
        [-0.1127],
        [-0.2504],
        [-0.0835],
        [-0.1007],
        [-0.1027],
        [-0.1115],
        [-0.1094],
        [-0.0842],
        [-0.1204],
        [-0.1448],
        [-0.1294],
        [-0.1051],
        [-0.1040],
        [-0.1237],
        [-0.1132],
        [-0.1181],
        [-0.1058],
        [-0.2746],
        [-0.2722],
        [-0.0850],
        [-0.1190],
        [-0.1290],
        [-0.1062],
        [-0.1153],
        [-0.2689],
        [-0.0825],
        [-0.0828],
        [-0.1214],
        [-0.1058],
        [-0.1266],
        [-0.1098],
        [-0.1230],
        [-0.0828],
        [-0.1078],
        [-0.1034],
        [-0.2689],
        [-0.1120],
        [-0.1090],
        [-0.1073],
        [-0.2939],
        [-0.1074],
        [-0.1173],
        [-0.1173],
        [-0.2733],
        [-0.1150],
        [-0.2704],
        [-0.2653],
        [-0.2351],
        [-0.

        [-1.2134e+00]], device='cuda:0')
tensor([[-0.1381],
        [-0.1430],
        [-0.1536],
        [-0.2775],
        [-0.1331],
        [-0.0836],
        [-0.1421],
        [-0.1026],
        [-0.1402],
        [-0.1927],
        [-0.1372],
        [-0.1513],
        [-0.1124],
        [-0.1374],
        [-0.0867],
        [-0.1654],
        [-0.1561],
        [-0.1319],
        [-0.1408],
        [-0.2847],
        [-0.2791],
        [-0.1058],
        [-0.2749],
        [-0.1479],
        [-0.1178],
        [-0.1500],
        [-0.1621],
        [-0.0893],
        [-0.1344],
        [-0.2843],
        [-0.1073],
        [-0.1216],
        [-0.1496],
        [-0.1266],
        [-0.0883],
        [-0.2628],
        [-0.0890],
        [-0.1522],
        [-0.1268],
        [-0.1501],
        [-0.1435],
        [-0.1262],
        [-0.1121],
        [-0.1487],
        [-0.0878],
        [-0.1536],
        [-0.2772],
        [-0.1483],
        [-0.1535],
        [-0.1134],
        [

tensor([[-0.1966],
        [-0.1910],
        [-0.2087],
        [-0.1118],
        [-0.1580],
        [-0.1869],
        [-0.1897],
        [-0.2667],
        [-0.1943],
        [-0.1883],
        [-0.0963],
        [-0.2220],
        [-0.1225],
        [-0.1895],
        [-0.1460],
        [-0.0992],
        [-0.0972],
        [-0.2346],
        [-0.1941],
        [-0.1945],
        [-0.1333],
        [-0.2768],
        [-0.0998],
        [-0.0949],
        [-0.2510],
        [-0.0931],
        [-0.1448],
        [-0.1331],
        [-0.1948],
        [-0.1816],
        [-0.2061],
        [-0.0980],
        [-0.0960],
        [-0.0996],
        [-0.0916],
        [-0.2967],
        [-0.1410],
        [-0.1353],
        [-0.1308],
        [-0.2945],
        [-0.1413],
        [-0.1439],
        [-0.1819],
        [-0.2788],
        [-0.1482],
        [-0.0937],
        [-0.1950],
        [-0.1892],
        [-0.1223],
        [-0.0939],
        [-0.1355],
        [-0.2951],
        [-0.

tensor([[-0.3095],
        [-0.7672],
        [-0.3082],
        [-0.1685],
        [-0.2689],
        [-0.1011],
        [-0.0982],
        [-0.0982],
        [-0.1002],
        [-0.3017],
        [-0.1003],
        [-0.1057],
        [-0.2487],
        [-0.1609],
        [-0.1206],
        [-0.1006],
        [-0.0992],
        [-0.1673],
        [-0.2289],
        [-0.1017],
        [-0.3085],
        [-0.2322],
        [-0.1046],
        [-0.2914],
        [-0.1659],
        [-0.1489],
        [-0.1402],
        [-0.1505],
        [-0.3196],
        [-0.2343],
        [-0.1627],
        [-0.1048],
        [-0.3097],
        [-0.2856],
        [-0.2340],
        [-0.3060],
        [-0.1543],
        [-0.2947],
        [-0.2236],
        [-0.1600],
        [-0.1504],
        [-0.1010],
        [-0.3112],
        [-0.2510],
        [-0.2758],
        [-0.3045],
        [-0.2325],
        [-0.0995],
        [-0.1009],
        [-0.3113],
        [-0.2188],
        [-0.3071],
        [-0.

tensor([[ 1.1469e+00],
        [-2.3412e-01],
        [-1.7763e+00],
        [-2.7101e+00],
        [ 4.6354e+01],
        [ 6.8951e-01],
        [ 7.0605e-01],
        [ 8.6278e-02],
        [ 5.2303e+00],
        [ 2.2026e+00],
        [-7.1515e-01],
        [-8.3509e-01],
        [-2.0429e+00],
        [ 1.7242e+00],
        [ 7.1539e-03],
        [ 2.8274e-01],
        [-3.6675e+00],
        [-6.2148e-01],
        [-1.2884e+00],
        [-1.2162e+00],
        [-1.4797e+00],
        [-3.6623e+00],
        [-2.6987e+00],
        [-2.4696e+01],
        [ 8.7146e-01],
        [-7.0513e+00],
        [-2.9111e+00],
        [-1.9111e+00],
        [ 7.0533e-01],
        [-2.2842e+00],
        [-2.2892e+00],
        [-4.0282e-01],
        [-3.3211e-01],
        [-1.2101e+00],
        [-3.4509e+00],
        [-1.4860e+00],
        [-1.5807e+00],
        [-2.0720e+00],
        [-4.0370e+00],
        [-1.4208e+00],
        [-5.8192e-01],
        [-2.4558e+00],
        [-7.4069e-01],
        [-3

tensor([[-0.0963],
        [-0.2090],
        [-0.2067],
        [-0.3347],
        [-0.1851],
        [-0.3226],
        [-0.3148],
        [-0.3490],
        [-0.1978],
        [-0.2936],
        [-0.0990],
        [-0.3141],
        [-0.2047],
        [-0.2098],
        [-0.1965],
        [-0.3432],
        [-0.3018],
        [-0.2979],
        [-0.0988],
        [-0.2043],
        [-0.3308],
        [-0.3068],
        [-0.4813],
        [-0.2089],
        [-0.2086],
        [-0.0997],
        [-0.3396],
        [-0.0947],
        [-0.4319],
        [-0.3072],
        [-0.2098],
        [-0.2836],
        [-0.3420],
        [-0.0994],
        [-0.3031],
        [-0.0961],
        [-0.1861],
        [-0.3232],
        [-0.1147],
        [-0.2024],
        [-0.3092],
        [-0.3088],
        [-0.2096],
        [-0.1977],
        [-0.1020],
        [-0.0901],
        [-0.2062],
        [-0.1010],
        [-0.2862],
        [-0.3416],
        [-0.3068],
        [-0.1992],
        [-0.

KeyboardInterrupt: 

[ 0.0491, -0.1246,  0.0294, -0.0133]