In [11]:
#!/usr/bin/env python
# coding: utf-8

# In[2]:


# import json
import new_env
import copy
import numpy as np
import json
import networkx as nx
import matplotlib.pyplot as plt


# In[3]:


class AST:
    """Node of the expression tree built from a token stream.

    Attributes:
        node: the token stored at this position (its ``value`` is printed by ``_print``).
        children: list of child ``AST`` nodes, in argument order.
        parent: list of parents; starts as ``[parent]`` (so ``[None]`` for a root) and
            may be extended later when a node is linked to additional parents.
    """
    def __init__(self, node, children=None, parent=None):
        self.node = node
        # Create a fresh list per instance. A literal ``[]`` default is evaluated once
        # and would be shared by every AST built without an explicit ``children``.
        self.children = [] if children is None else children
        self.parent = [parent]

    def _print(self, depth=1):
        """Print this subtree, one node per line, indented by ``depth``."""
        print (depth * "--- " + self.node.value)
        for child in self.children:
            child._print(depth+1)
        
class Token:
    """A single symbol of a polished (prefix-notation) goal.

    value: the symbol text.
    _type: the kind of token; the tokenizer in this file emits "func", "lambda"
           and "variable".
    arity: number of arguments for "func"/"lambda" tokens, None otherwise.
    """
    def __init__(self, value, _type, arity=None):
        self.value, self._type, self.arity = value, _type, arity
    
# Assumes the caller passes an `ast` whose `ast.node` is a function token.
def func_to_ast(ast, tokens, arity):
    """Attach up to `arity` arguments, read from the front of `tokens`, to `ast`.

    `tokens` is consumed in place. A "variable" token becomes a leaf child; a
    "func" or "lambda" token becomes a child whose own arguments are read
    recursively (using that token's arity) before the next sibling is read.
    Reading stops after `arity` arguments or when `tokens` runs out.

    Returns `ast` (the same object, now with children attached).
    """
    remaining = arity
    while tokens:
        token = tokens.pop(0)
        subtree = AST(token, children=[], parent=ast)

        if token._type == "variable":
            ast.children.append(subtree)
        elif token._type in ("func", "lambda"):
            # the nested call consumes this token's own arguments first
            ast.children.append(func_to_ast(subtree, tokens, token.arity))

        if remaining == 1:
            return ast
        remaining -= 1

    return ast
    
def tokens_to_ast(tokens):
    """Build the expression tree for a prefix-ordered token list.

    The first token becomes the root; the remaining tokens are handed to
    `func_to_ast` together with the root token's arity. `tokens` is consumed
    in place.
    """
    head = tokens.pop(0)
    root = AST(head, children=[])
    return func_to_ast(root, tokens, root.node.arity)


# In[4]:


def polished_to_tokens_2(goal):
    """Tokenize a space-separated polished (prefix) goal string.

    Rules, applied left to right:
      * A run of k consecutive '@' items introduces a "func" token of arity k
        whose name is the item that follows. If that name starts with 'C', the
        next item is consumed as well and joined to it with '|'.
      * An item starting with '|' becomes a "lambda" token of arity 2.
      * Any other item becomes a "variable" token; if it starts with 'C', the
        next item is consumed and concatenated to it (no separator).

    Empty items (from a blank goal, or from leading, trailing or repeated
    spaces) are skipped; previously they raised IndexError on the first-character
    checks.

    Returns:
        list of Token, in input order.

    Raises:
        IndexError: if the goal ends where another item is required (after '@'
            or after a 'C...' prefix).
    """
    pieces = [piece for piece in goal.split(" ") if piece]
    tokens = []

    # Walk with an index rather than list.pop(0), which is O(n) per call.
    pos = 0
    while pos < len(pieces):
        if pieces[pos] == '@':
            arity = 0
            while pieces[pos] == '@':
                arity += 1
                pos += 1

            func = pieces[pos]
            pos += 1

            if func[0] == 'C':
                # exactly one name item follows the library prefix
                func = func + "|" + pieces[pos]
                pos += 1

            # otherwise it is a variable function with nothing following it
            tokens.append(Token(func, "func", arity))

        else:
            var = pieces[pos]
            pos += 1

            if var[0] == "|":
                tokens.append(Token(var, "lambda", 2))
            else:
                if var[0] == "C":
                    # constants are written as two space-separated items
                    var = var + pieces[pos]
                    pos += 1

                tokens.append(Token(var, "variable"))

    return tokens
            


# In[5]:


def hierarchy_pos(G, root, levels=None, width=1., height=1.):
    '''Compute layered x/y positions for the nodes reachable from root.

       Nodes on the same level are spread evenly across the width; each level
       is placed one vertical gap below the previous one, with root at y = 0.
       A cycle reachable from root leads to infinite recursion.

       G: the graph (only G.neighbors(node) is used)
       root: the root node
       levels: optional dictionary
               key: level number (starting from 0)
               value: number of nodes in this level
               (computed by walking the graph when omitted)
       width: horizontal space allocated for drawing
       height: vertical space allocated for drawing

       Returns a dictionary mapping node -> (x, y).'''
    TOTAL = "total"
    CURRENT = "current"

    def count_levels(counts, node, depth, parent):
        """Count how many nodes sit on each level below (and including) node."""
        if depth not in counts:
            counts[depth] = {TOTAL: 0, CURRENT: 0}
        counts[depth][TOTAL] += 1
        for nbr in G.neighbors(node):
            if nbr != parent:
                count_levels(counts, nbr, depth + 1, node)
        return counts

    def place(positions, node, depth, parent, y):
        """Assign the next free horizontal slot on this level to node, then recurse."""
        slot = levels[depth]
        dx = 1 / slot[TOTAL]
        positions[node] = ((dx / 2 + dx * slot[CURRENT]) * width, y)
        slot[CURRENT] += 1
        for nbr in G.neighbors(node):
            if nbr != parent:
                place(positions, nbr, depth + 1, node, y - vert_gap)
        return positions

    if levels is None:
        levels = count_levels({}, root, 0, None)
    else:
        levels = {depth: {TOTAL: levels[depth], CURRENT: 0} for depth in levels}

    vert_gap = height / (max(levels) + 1)
    return place({}, root, 0, None, 0)


# In[6]:


def print_graph(ast):
    """Draw the tree rooted at `ast` and save the picture to "Graph.png".

    Nodes are labelled with their token value and edges with the child index
    recorded by `add_node`. The drawing goes onto matplotlib figure 1; nothing
    is shown interactively.
    """
    graph = nx.DiGraph()
    add_node(ast, graph)

    layout = hierarchy_pos(graph, ast.node)
    node_labels = nx.get_node_attributes(graph, 'value')

    plt.figure(1, figsize=(15, 30))
    nx.draw(graph, pos=layout, labels=node_labels, with_labels=True,
            arrowsize=20,
            node_color='none',
            node_size=6000)

    edge_labels = {(u, v): graph.get_edge_data(u, v)["child"] for u, v in graph.edges()}
    nx.draw_networkx_edge_labels(graph, layout, edge_labels=edge_labels)

    # recolour the outline of the first drawn collection (presumably the node markers)
    plt.gca().collections[0].set_edgecolor("#FF0000")
    plt.savefig("Graph.png", format="PNG")
# 0

def add_node(ast, graph):
    graph.add_node(ast.node, value = str(ast.node.value))
    for i, child in enumerate(ast.children):
        graph.add_edge(ast.node, child.node, child=i)
        add_node(child, graph)
    
#verticalalignment='bottom'


# In[7]:


def add_lambda_children(lambda_ast):
    """Link every occurrence of the bound variable directly to its lambda node.

    `lambda_ast.children[0]` is the bound variable and `lambda_ast.children[1]`
    the body. Each node in the body whose token value equals the bound
    variable's value is appended to `lambda_ast.children`, and `lambda_ast` is
    appended to that node's `parent` list. If the body itself is the bound
    variable, nothing is changed.

    Returns `lambda_ast`.
    """
    bound_name = lambda_ast.children[0].node.value
    body = lambda_ast.children[1]

    if body.node.value == bound_name:
        return lambda_ast

    def link_occurrences(tree):
        if tree.node.value == bound_name:
            lambda_ast.children.append(tree)
            tree.parent.append(lambda_ast)
        for child in tree.children:
            link_occurrences(child)

    link_occurrences(body)
    return lambda_ast

def process_lambdas(ast):
    """Run `add_lambda_children` on every lambda node of the tree.

    All lambda nodes are collected first (pre-order), and only then processed,
    so the links added while processing do not affect which nodes are found.

    Returns `ast`.
    """
    lambdas = []
    pending = [ast]
    while pending:
        current = pending.pop()
        if current.node._type == "lambda":
            lambdas.append(current)
        # reversed so that the leftmost child is visited next (pre-order)
        pending.extend(reversed(current.children))

    for lambda_ast in lambdas:
        add_lambda_children(lambda_ast)

    return ast


# In[8]:


def merge_leaves(ast):
    """Merge the leaf occurrences of each bound variable into its binder node.

    For every lambda node (collected pre-order before any change is made), the
    body is walked; whenever a childless node carrying the bound variable's
    value is met, each of its parents gets a rebuilt `children` list in which
    the first such leaf is replaced by the lambda's own variable node
    (`children[0]`) and any further such leaves are dropped. Merging is
    therefore confined to one quantified scope at a time.

    Returns `ast`.
    """
    def is_bound_leaf(tree, var):
        return tree.node.value == var and tree.children == []

    def merge_lambda(tree, var, binder):
        if is_bound_leaf(tree, var):
            # rebuilding (rather than editing in place) keeps a single entry per parent
            for parent in tree.parent:
                merged = []
                binder_kept = False
                for sibling in parent.children:
                    if not is_bound_leaf(sibling, var):
                        merged.append(sibling)
                    elif not binder_kept:
                        merged.append(binder)
                        binder_kept = True
                parent.children = merged

        for child in tree.children:
            merge_lambda(child, var, binder)

    def run_lambda(lambda_ast):
        binder = lambda_ast.children[0]
        var = binder.node.value

        # edge case "lambda x: x": only the variable itself is left
        if len(lambda_ast.children) == 1:
            return

        merge_lambda(lambda_ast.children[1], var, binder)

    lambdas = []
    pending = [ast]
    while pending:
        current = pending.pop()
        if current.node._type == "lambda":
            lambdas.append(current)
        pending.extend(reversed(current.children))

    for lambda_ast in lambdas:
        run_lambda(lambda_ast)

    return ast
        
    


# In[9]:


def rename(ast):
    """Replace variable names by generic placeholders, in place.

    Every node whose token value starts with 'V' is renamed to "VARFUNC" when
    it has children and to "VAR" when it is a leaf. Other tokens are untouched.

    Returns `ast`.
    """
    pending = [ast]
    while pending:
        current = pending.pop()
        if current.node.value[0] == 'V':
            current.node.value = "VARFUNC" if current.children != [] else "VAR"
        pending.extend(reversed(current.children))

    return ast


# In[10]:


def goal_to_graph(polished_goal):
    """Turn a polished goal string into its final graph form.

    Pipeline: tokenize -> build tree -> link bound variables to their lambdas
    -> merge bound leaf occurrences -> rename variables to placeholders.
    """
    token_stream = polished_to_tokens_2(polished_goal)
    tree = tokens_to_ast(token_stream)
    tree = process_lambdas(tree)
    tree = merge_leaves(tree)
    return rename(tree)


# In[11]:


# all_graphs = [goal_to_graph(g) for g in polished_goals]


# In[12]:


with open("new_db.json") as fp:
    new_db = json.load(fp)

# the third field of every database entry is used as the polished goal
polished_goals = [entry[2] for entry in new_db.values()]

# Vocabulary: every distinct token value except those starting with 'V'
# (those are collapsed to the VAR / VARFUNC placeholders by `rename`).
tokens = list({token.value
               for polished_goal in polished_goals
               for token in polished_to_tokens_2(polished_goal)
               if token.value[0] != 'V'})

tokens += ["VAR", "VARFUNC", "UNKNOWN"]

#print (len(tokens))


# In[13]:


from sklearn.preprocessing import OneHotEncoder


# In[14]:


enc = OneHotEncoder(handle_unknown='ignore')

# the encoder expects a 2-D array: one row per token, a single column
token_column = np.array(tokens).reshape(-1,1)

enc.fit(token_column)

e = enc.transform(token_column)

preds = enc.inverse_transform(e)

# sanity check: encoding followed by decoding gives back the vocabulary unchanged
assert list(preds[:,0]) == tokens


# In[15]:


# Display check: encode the first ten vocabulary entries and decode them again.
enc.inverse_transform(enc.transform(np.array(tokens[:10]).reshape(-1,1)))[:,0]


# In[16]:


def nodes_list_to_senders_receivers(node_list):
    """Build the edge lists (parent index -> child index) for a list of nodes.

    Args:
        node_list: nodes exposing a `children` list; every child must itself
            be an element of `node_list`.

    Returns:
        (senders, receivers): two equally long lists of ints. Edge k goes from
        node_list[senders[k]] to node_list[receivers[k]]. Edges are ordered by
        parent position, then by child order.

    Raises:
        ValueError: if a child is not contained in `node_list`.
    """
    # One pass to index the nodes instead of a linear list.index() per edge.
    # Nodes are matched by identity; setdefault keeps the first position,
    # as list.index() would.
    position = {}
    for i, node in enumerate(node_list):
        position.setdefault(id(node), i)

    senders = []
    receivers = []

    for i, node in enumerate(node_list):
        for child in node.children:
            try:
                child_position = position[id(child)]
            except KeyError:
                raise ValueError("child node is not in node_list") from None
            senders.append(i)
            receivers.append(child_position)
    return senders, receivers

def nodes_list(g, result=None):
    """Return every distinct node reachable from `g` through `children`.

    Args:
        g: the root node.
        result: optional list to extend; nodes discovered by this call are
            appended to it. A new list is used when omitted. (The former
            `result=[]` default was shared between calls, so nodes of earlier
            graphs leaked into later results.)

    Returns:
        A new list without duplicates (nodes are compared by identity), in
        first-visit (pre-order) order: any entries already in `result` first,
        then the nodes found from `g`.
    """
    if result is None:
        result = []

    # Iterative pre-order walk; a node shared by several parents is expanded once.
    visited = set()
    pending = [g]
    while pending:
        node = pending.pop()
        if id(node) in visited:
            continue
        visited.add(id(node))
        result.append(node)
        pending.extend(reversed(node.children))

    unique = []
    seen = set()
    for node in result:
        if id(node) not in seen:
            seen.add(id(node))
            unique.append(node)
    return unique

# Possible extension: add edge numbers (argument positions) so that the ordering of the arguments is preserved.


# In[17]:


import pickle
# NOTE: pickle.load can execute arbitrary code; only load files from a trusted source.
with open("gnn_dataset.pk", 'rb') as f:
    dataset = pickle.load(f)

# Each record is indexed as d[0], d[1] (the input pair) and d[2] (the target).
X = [(d[0], d[1]) for d in dataset]
y = [d[2] for d in dataset]

# unique_exps = list(set([x for xs in X for x in xs]))

# from tqdm import tqdm
# graph_dict = {g: graph_to_jgraph(goal_to_graph(g)) for g in tqdm(unique_exps)}


# with open("graph_dataset.pk", 'wb') as f:
#     pickle.dump(graph_dict, f)


# In[18]:


import torch
from torch import nn
import torch.nn.functional as F

import numpy as np
import pickle


class F_p_module(nn.Module):
    """Two-layer MLP on 256-dimensional inputs: Linear -> ReLU -> Linear."""

    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(256, 256)
        self.fc2 = nn.Linear(256, 256)

    def forward(self, x):
        hidden = torch.relu(self.fc(x))
        return self.fc2(hidden)

    

class F_i_module(nn.Module):
    """Two-layer MLP on 512-dimensional inputs with a ReLU after each layer."""

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(512, 512)
        self.fc2 = nn.Linear(512, 512)

    def forward(self, x):
        hidden = torch.relu(self.fc1(x))
        return torch.relu(self.fc2(hidden))
    

class F_o_module(nn.Module):
    """Two-layer MLP on 512-dimensional inputs with a ReLU after each layer.

    Same architecture as F_i_module, with its own parameters.
    """

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(512, 512)
        self.fc2 = nn.Linear(512, 512)

    def forward(self, x):
        hidden = torch.relu(self.fc1(x))
        return torch.relu(self.fc2(hidden))
    
    
    
class F_x_module(nn.Module):
    """Linear embedding of node features into 256 dimensions.

    input_shape: size of each input feature vector.
    """

    def __init__(self, input_shape):
        super().__init__()
        self.fc1 = nn.Linear(input_shape, 256)

    def forward(self, x):
        embedded = self.fc1(x)
        return embedded


    
class F_c_module(nn.Module):
    """Classifier head: 2048 -> 1024 -> 512 -> 1, ending in a sigmoid.

    A ReLU follows the first layer only; the second and third layers are
    applied back to back.
    """

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(2048, 1024)
        self.fc2 = nn.Linear(1024, 512)
        self.fc3 = nn.Linear(512, 1)

    def forward(self, x):
        hidden = torch.relu(self.fc1(x))
        logits = self.fc3(self.fc2(hidden))
        return torch.sigmoid(logits)


# In[19]:


def init_graph(nodes, embedding_net):
    """Return the initial node embeddings, i.e. `embedding_net` applied to `nodes`."""
    return embedding_net(nodes)



def run_embedding_iteration(nodes, senders, receivers, F_p, F_i, F_o):
    """Perform one message-passing update of the node embeddings.

    For every edge, F_i and F_o are evaluated on the concatenation of the
    sender's and the receiver's embedding. Node i then receives
    x_i + (sum of F_i over edges into i + sum of F_o over edges out of i) / degree(i),
    and F_p is applied to the stacked result.

    Fixes relative to the earlier version:
      * the degree used `len(torch.where(mask))`, which is the length of the
        returned tuple (always 1 for 1-D input), so the divisor was always 2;
        it is now the real number of incident edges (at least 1).
      * F_p was evaluated twice on the same input; it is now evaluated once.
      * the result no longer depends on a global `device`; every tensor stays
        on the device of `nodes`.

    NOTE(review): each per-node sum still reduces over all elements of the
    selected F_i / F_o rows (a scalar added to every feature), as before.
    A per-feature sum would need F_i / F_o outputs as wide as `nodes` - confirm
    the intended behavior before changing it.

    Args:
        nodes: tensor of node embeddings, one row per node.
        senders, receivers: 1-D index tensors with one entry per edge.
        F_p, F_i, F_o: callables applied as described above.

    Returns:
        (nodes, senders, receivers). With at most one edge, `nodes` is returned
        unchanged.
    """
    if senders.shape[0] <= 1:
        return nodes, senders, receivers

    nodes_with_neighbors = torch.cat([nodes[senders], nodes[receivers]], dim=-1)

    F_i_outputs = F_i(nodes_with_neighbors)
    F_o_outputs = F_o(nodes_with_neighbors)

    F_p_inputs = []
    for i, x in enumerate(nodes):
        incoming = receivers == i
        outgoing = senders == i
        # clamp guards against division by zero for a node without edges
        degree = (incoming.sum() + outgoing.sum()).clamp(min=1)
        message_total = torch.sum(F_i_outputs[incoming]) + torch.sum(F_o_outputs[outgoing])
        F_p_inputs.append(x + message_total / degree)

    nodes = F_p(torch.stack(F_p_inputs))

    return nodes, senders, receivers


def generate_graph_embedding(graph, F_p, F_i, F_o, F_x, conv1, conv2, num_iterations):
    """Embed one graph given as a 7-tuple (nodes, _, receivers, senders, _, _, _).

    Steps: embed the node features with F_x, run `num_iterations` rounds of
    `run_embedding_iteration`, apply conv1 and conv2 with features as channels
    and nodes as the length axis, then take the maximum over all nodes.

    Fixes relative to the earlier version:
      * removed the unused `time.time()` call (`time` is not imported in this
        file, so the function raised NameError).
      * `reshape(D, N)` on an (N, D) tensor re-chunks memory instead of
        swapping axes; the convolution input is now the transpose.
      * `MaxPool1d(1, N)` means kernel size 1 with stride N, which just selects
        the first column; the pooling is now a true maximum over the node axis.

    Relies on the module-level `device`.

    Returns:
        Tensor of shape (C, 1), where C is the number of conv2 output channels.
    """
    nodes, _, receivers, senders, _, _, _ = graph

    receivers = torch.LongTensor(receivers).to(device)
    senders = torch.LongTensor(senders).to(device)
    nodes = torch.tensor(nodes, requires_grad=True).to(device)

    nodes = F_x(nodes)

    for _iteration in range(num_iterations):
        nodes, senders, receivers = run_embedding_iteration(nodes, senders, receivers, F_p, F_i, F_o)

    # (N, D) -> (D, N): channels first, nodes along the length axis
    nodes = conv2(conv1(nodes.t()))

    # global max over nodes, keeping a trailing axis of size 1
    g_embedding = torch.max(nodes, dim=1, keepdim=True).values

    return g_embedding
    

In [None]:
# In[20]:


# Load the pickled `graph_dict`; the commented-out code above the model cells
# shows it being built as {goal: graph_to_jgraph(goal_to_graph(goal))}.
# NOTE: pickle.load can execute arbitrary code; only load files from a trusted source.
with open("graph_dataset.pk", 'rb') as f:
    graph_dict = pickle.load( f)


# In[21]:

In [42]:
from torch_geometric.nn import MessagePassing
from torch_geometric.data import Data

In [43]:
#turn jax dataset into torch
torch_graph_dict = {}

for goal, graph_tuple in graph_dict.items():
    # only the node features and the two edge index arrays of the 7-tuple are used
    node_features, _, edge_receivers, edge_senders, _, _, _ = graph_tuple

    # edge_index layout: row 0 = senders, row 1 = receivers
    edge_index = torch.tensor([edge_senders.tolist(), edge_receivers.tolist()], dtype=torch.long)

    torch_graph_dict[goal] = Data(x = torch.tensor(node_features, dtype=torch.float),
                                  edge_index = edge_index)
    
    
    

In [None]:
import torch
from torch.nn import Sequential as Seq, Linear, ReLU, Dropout
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import add_self_loops, degree
from torch.nn.functional import dropout

#no edge weights with this model
# sum (MLP(node, parents))

# NOTE(review): these module-level values are not used by the model construction
# visible in this notebook, which passes `embedding_dim` explicitly instead.
in_channels = out_channels = 256

#F_o summed over children 
class Child_Aggregation(MessagePassing):
    """F_o: average, per node, of an MLP applied to each outgoing edge.

    Messages are summed with flow 'target_to_source' and the sum is scaled by
    1 / (number of times the node appears in edge_index[0]); nodes that never
    appear there get a zero vector.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__(aggr='sum', flow='target_to_source')

        layers = [
            Dropout(),
            Linear(2 * in_channels, out_channels),
            ReLU(),
            Dropout(),
            Linear(out_channels, out_channels),
        ]
        self.mlp = Seq(*layers)

    def message(self, x_i, x_j):
        # x_i, x_j: [E, in_channels] each -> MLP input [E, 2 * in_channels]
        return self.mlp(torch.cat([x_i, x_j], dim=1))

    def forward(self, x, edge_index):
        # x: [N, in_channels], edge_index: [2, E]
        # row 0 of edge_index gives the degree with respect to children
        deg = degree(edge_index[0], x.size(0), dtype=x.dtype)

        # reciprocal degree, with 0 in place of 1/0
        deg_inv = deg.reciprocal().masked_fill(deg == 0, 0.)

        return deg_inv.view(-1, 1) * self.propagate(edge_index, x=x)

    
#F_i summed over parents 
class Parent_Aggregation(MessagePassing):
    """F_i: average, per node, of an MLP applied to each incoming edge.

    Messages are summed with flow 'source_to_target' and the sum is scaled by
    1 / (number of times the node appears in edge_index[1]); nodes that never
    appear there get a zero vector.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__(aggr='sum', flow='source_to_target')

        layers = [
            Dropout(),
            Linear(2 * in_channels, out_channels),
            ReLU(),
            Dropout(),
            Linear(out_channels, out_channels),
        ]
        self.mlp = Seq(*layers)

    def message(self, x_i, x_j):
        # x_i, x_j: [E, in_channels] each -> MLP input [E, 2 * in_channels]
        return self.mlp(torch.cat([x_i, x_j], dim=1))

    def forward(self, x, edge_index):
        # x: [N, in_channels], edge_index: [2, E]
        # row 1 of edge_index gives the degree with respect to parents
        deg = degree(edge_index[1], x.size(0), dtype=x.dtype)

        # reciprocal degree, with 0 in place of 1/0
        deg_inv = deg.reciprocal().masked_fill(deg == 0, 0.)

        return deg_inv.view(-1, 1) * self.propagate(edge_index, x=x)

class Final_Agg(nn.Module):
    """MLP that turns [node, parent aggregate, child aggregate] into a node update.

    Maps inputs of width 3 * embedding_dim to embedding_dim through one hidden
    layer of width 2 * embedding_dim, with dropout before each linear layer.
    """
    def __init__(self, embedding_dim):
        super(Final_Agg, self).__init__()

        self.fc = nn.Linear(embedding_dim * 3, embedding_dim * 2)

        self.fc2 = nn.Linear(embedding_dim * 2, embedding_dim)

    def forward(self, x):
        # The functional dropout defaults to training=True, so without the flag
        # it would stay active after .eval(); tie it to the module's mode.
        x = dropout(x, training=self.training)
        x = torch.relu(self.fc(x))
        x = dropout(x, training=self.training)
        return self.fc2(x)
    

    
        
    
class F_x_module(nn.Module):
    """Linear embedding of node features.

    input_shape: size of each input feature vector.
    embedding_dim: size of the produced embedding.

    Redefines the earlier F_x_module, whose output width was fixed at 256.
    """

    def __init__(self, input_shape, embedding_dim):
        super().__init__()
        self.fc1 = nn.Linear(input_shape, embedding_dim)

    def forward(self, x):
        embedded = self.fc1(x)
        return embedded

class F_c_module(nn.Module):
    """Classifier head: 2048 -> 1024 -> 512 -> 1 with ReLU, dropout and a final sigmoid.

    Redefines the earlier F_c_module; this version adds dropout before every
    linear layer and a ReLU after the second layer.
    """
    def __init__(self):
        super(F_c_module, self).__init__()

        self.fc1 = nn.Linear(2048, 1024)

        self.fc2 = nn.Linear(1024, 512)

        self.fc3 = nn.Linear(512, 1)

    def forward(self, x):
        # The functional dropout defaults to training=True, so without the flag
        # it would stay active after .eval(); tie it to the module's mode.
        x = F.relu(self.fc1(dropout(x, training=self.training)))
        x = F.relu(self.fc2(dropout(x, training=self.training)))
        return torch.sigmoid(self.fc3(dropout(x, training=self.training)))

    

In [363]:
def generate_graph_embedding(graph, MLP_Agg, F_i_sum, F_o_sum, F_x, conv1, conv2, num_iterations):
    """Embed one graph object exposing `x` (node features) and `edge_index`.

    Steps: embed the node features with F_x; for `num_iterations` rounds add
    MLP_Agg([nodes, F_i_sum(nodes, edges), F_o_sum(nodes, edges)]) to the
    nodes (residual update); apply conv1 and conv2 with features as channels
    and nodes as the length axis; take the maximum over all nodes.

    Fixes relative to the earlier version:
      * `reshape(D, N)` on an (N, D) tensor re-chunks memory instead of
        swapping axes; the convolution input is now the transpose.
      * `MaxPool1d(1, N)` means kernel size 1 with stride N, which just selects
        the first column; the pooling is now a true maximum over the node axis.

    Relies on the module-level `device`.

    Returns:
        Tensor of shape (C, 1), where C is the number of conv2 output channels.
    """
    nodes = F_x(graph.x.to(device))
    edges = graph.edge_index.to(device)

    for _iteration in range(num_iterations):
        fi_sum = F_i_sum(nodes, edges)
        fo_sum = F_o_sum(nodes, edges)
        node_update = MLP_Agg(torch.cat([nodes, fi_sum, fo_sum], dim=1))
        nodes = nodes + node_update

    # (N, D) -> (D, N): channels first, nodes along the length axis
    nodes = conv2(conv1(nodes.t()))

    # global max over nodes, keeping a trailing axis of size 1
    g_embedding = torch.max(nodes, dim=1, keepdim=True).values

    return g_embedding
    

In [402]:
def binary_loss(preds, targets):
    """Binary cross-entropy summed (not averaged) over all elements.

    `preds` must lie strictly between 0 and 1; no clamping is done here.
    """
    positive_term = targets * torch.log(preds)
    negative_term = (1 - targets) * torch.log((1. - preds))
    return -1. * torch.sum(positive_term + negative_term)

def loss(graph_net, x_batch, y_batch, F_p, F_i, F_o, F_x, F_c, conv1, conv2, num_iterations):
    """Summed binary cross-entropy of the pair classifier over one batch.

    For every pair in `x_batch`, both graphs are embedded with `graph_net`
    (first element first), the two embeddings are concatenated and flattened
    into a single row, and `F_c` produces the prediction. Predictions are
    clipped to [1e-6, 1 - 1e-6] before the loss is computed against `y_batch`.

    Relies on the module-level `device`.
    """
    eps = 1e-6

    def embed(graph):
        return graph_net(graph, F_p, F_i, F_o, F_x, conv1, conv2, num_iterations).to(device)

    preds = []
    for pair in x_batch:
        # list elements are evaluated left to right: first graph, then second
        pair_embedding = torch.concat([embed(pair[0]), embed(pair[1])], axis=0).reshape(1, -1)
        preds.append(F_c(pair_embedding)[0][0])

    preds = torch.clip(torch.stack(preds).to(device), eps, 1 - eps)
    targets = torch.FloatTensor(y_batch).to(device)

    return binary_loss(preds, targets)


# In[22]:


from sklearn.model_selection import train_test_split

# 80% of the samples go to training ...
# NOTE(review): no random_state is passed, so the split presumably differs between
# runs - confirm whether reproducible splits are wanted.
x_train, x_test, y_train, y_test = train_test_split(X, y, test_size=0.2)

# ... and the remaining 20% are halved into test and validation sets.
x_test, x_val, y_test, y_val = train_test_split(x_test, y_test, test_size=0.5)


# In[23]:


from tqdm import tqdm


# In[24]:

# --- Hyperparameters ---
# learning rate passed to every Adam optimiser below
step_size = 3e-4

# passed as Adam's weight_decay below
decay_rate = 0.02

num_epochs = 10

batch_size = 32

# use the first CUDA device when available, otherwise the CPU
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

embedding_dim = 128

# --- Model components ---
# NOTE: `fp` rebinds the name used earlier for the new_db.json file handle.
fp = Final_Agg(embedding_dim).to(device)
fi = Parent_Aggregation(embedding_dim, embedding_dim).to(device)
fo = Child_Aggregation(embedding_dim, embedding_dim).to(device)
# input width = vocabulary size (length of `tokens`)
fx = F_x_module(len(tokens), embedding_dim).to(device)
# F_c_module takes 2048 inputs: `loss` concatenates two graph embeddings,
# each with conv2's 1024 output channels.
fc = F_c_module().to(device)

# kernel-size-1 convolutions that widen the embeddings: embedding_dim -> 512 -> 1024 channels
conv1 = torch.nn.Conv1d(embedding_dim, 512, 1, stride=1).to(device)
conv2 = torch.nn.Conv1d(512, 1024, 1, stride=1).to(device)


#fp = torch.load("torch_gnn_models/initial/fp")
#fi = torch.load("torch_gnn_models/initial/fi")
#fo = torch.load("torch_gnn_models/initial/fo")
#fx = torch.load("torch_gnn_models/initial/fx")
#fc = torch.load("torch_gnn_models/initial/fc")


# --- Optimisers: one Adam instance per component, all with the same settings ---
optimiser_fp = torch.optim.Adam(list(fp.parameters()), lr=step_size, weight_decay=decay_rate)
optimiser_fi = torch.optim.Adam(list(fi.parameters()), lr=step_size, weight_decay=decay_rate)
optimiser_fo = torch.optim.Adam(list(fo.parameters()), lr=step_size, weight_decay=decay_rate)
optimiser_fx = torch.optim.Adam(list(fx.parameters()), lr=step_size, weight_decay=decay_rate)
optimiser_fc = torch.optim.Adam(list(fc.parameters()), lr=step_size, weight_decay=decay_rate)
optimiser_fconv1 = torch.optim.Adam(list(conv1.parameters()), lr=step_size, weight_decay=decay_rate)
optimiser_fconv2 = torch.optim.Adam(list(conv2.parameters()), lr=step_size, weight_decay=decay_rate)



data_size = len(x_train)

# loss histories appended to (and pickled) by run()
training_losses = []
val_losses = []

def run():
    """Train all model components for `num_epochs` epochs.

    Each epoch walks over the training pairs in consecutive batches of
    `batch_size` (a trailing partial batch is skipped), in the same order every
    epoch. Every 500 batches the loss histories are pickled to
    "training_losses_torch.pk" / "val_losses_torch.pk" and the loss on 32
    randomly drawn validation pairs is reported.

    Changes relative to the earlier version:
      * the validation loss is computed under torch.no_grad(), so no autograd
        graph is built for it;
      * only the sampled validation pairs are looked up in `torch_graph_dict`,
        instead of converting the whole validation set at every report.

    NOTE(review): the modules are never switched to eval mode, so dropout is
    still active while the validation loss is computed - confirm whether that
    is intended.
    """
    num_batches = data_size // batch_size

    # same order as the original individual zero_grad()/step() calls
    optimisers = (optimiser_fp, optimiser_fi, optimiser_fo, optimiser_fx,
                  optimiser_fc, optimiser_fconv1, optimiser_fconv2)

    for j in range(num_epochs):
        for i in tqdm(range(0, num_batches)):

            from_idx = i * batch_size
            to_idx = from_idx + batch_size

            X_batch = [(torch_graph_dict[x[0]], torch_graph_dict[x[1]]) for x in x_train[from_idx:to_idx]]

            y_batch = y_train[from_idx:to_idx]

            for optimiser in optimisers:
                optimiser.zero_grad()

            loss_val = loss(generate_graph_embedding, X_batch, y_batch, fp, fi, fo, fx, fc,conv1,conv2, 1)

            loss_val.backward()

            for optimiser in optimisers:
                optimiser.step()

            training_losses.append(loss_val.detach())

            if i % 500 == 0:
                with open("training_losses_torch.pk", "wb+") as f:
                    pickle.dump(training_losses, f)

                print ("Curr training loss avg: {}".format(sum(training_losses[-100:]) / len(training_losses[-100:])))

                # draw 32 validation pairs (with replacement)
                inds = np.random.randint(0, len(x_val), 32)

                x_val_batch = [(torch_graph_dict[x_val[k][0]], torch_graph_dict[x_val[k][1]]) for k in inds]
                y_val_batch = [y_val[k] for k in inds]

                with torch.no_grad():
                    loss_validation = loss(generate_graph_embedding, x_val_batch, y_val_batch, fp, fi, fo, fx, fc, conv1, conv2, 1)

                val_losses.append(loss_validation.detach())

                with open("val_losses_torch.pk", "wb+") as f:
                    pickle.dump(val_losses, f)

                print ("Avg val loss: {}".format(sum(val_losses[-100:])/len(val_losses[-100:])))
                print ("Curr val loss: {}".format(loss_validation))

        print ("Epoch {} done".format(j))

#         torch.save(fp, "torch_gnn_models/initial/fp")
#         torch.save(fi, "torch_gnn_models/initial/fi")
#         torch.save(fo, "torch_gnn_models/initial/fo")
#         torch.save(fx, "torch_gnn_models/initial/fx")
#         torch.save(fc, "torch_gnn_models/initial/fc")

In [None]:
run()

  0%|                                          | 1/9200 [00:00<47:18,  3.24it/s]

Curr training loss avg: 22.160371780395508
Avg val loss: 22.140254974365234
Curr val loss: 22.140254974365234


  1%|▍                                       | 102/9200 [00:17<25:46,  5.88it/s]