In [2]:
import torch
import pickle
import numpy as np
import pandas as pd
import os

from os.path import dirname



root_path = dirname(os.getcwd()) + "/HGNN_NA"

pd.set_option("display.max_columns", None)
data_dir = root_path + "/data/datasets/original/"
data_dir_processed = root_path + "/data/datasets/processed/"
data_dir_graphs = root_path + "/data/datasets/graphs/"

print(root_path, data_dir, data_dir_processed, data_dir_graphs, sep="\n")

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
device

/home/sebdis/HGNN/HGNN_NA
/home/sebdis/HGNN/HGNN_NA/data/datasets/original/
/home/sebdis/HGNN/HGNN_NA/data/datasets/processed/
/home/sebdis/HGNN/HGNN_NA/data/datasets/graphs/


device(type='cuda', index=0)

In [3]:
dataset = "BPI_Challenge_2012_A"

In [4]:
tab_all = pd.read_csv(data_dir_processed+dataset+"_processed_all.csv")
print(tab_all.head())
list_activities = list(tab_all["Activity"].unique())

   org:resource lifecycle:transition           Activity       time:timestamp  \
0           112             COMPLETE        A_SUBMITTED  2011/09/30 22:38:44   
1           112             COMPLETE  A_PARTLYSUBMITTED  2011/09/30 22:38:44   
2           112             COMPLETE      A_PREACCEPTED  2011/09/30 22:39:37   
3         10862             COMPLETE         A_ACCEPTED  2011/10/01 09:42:43   
4         10862             COMPLETE        A_FINALIZED  2011/10/01 09:45:09   

         case:REG_DATE  CaseID  case:AMOUNT_REQ  
0  2011/10/01 00:38:44  173688            20000  
1  2011/10/01 00:38:44  173688            20000  
2  2011/10/01 00:38:44  173688            20000  
3  2011/10/01 00:38:44  173688            20000  
4  2011/10/01 00:38:44  173688            20000  


In [5]:
import random

torch.manual_seed(0)
torch.cuda.manual_seed(0)
random.seed(0)
np.random.seed(0)

In [6]:
with open(data_dir_graphs + dataset + "_TRAIN_v3.pkl", "rb") as f:
    X_train, Y_train = pickle.load(f)
with open(data_dir_graphs + dataset + "_VALID_v3.pkl", "rb") as f:
    X_valid, Y_valid = pickle.load(f)
with open(data_dir_graphs + dataset + "_TEST_v3.pkl", "rb") as f:
    X_test, Y_test = pickle.load(f)

In [None]:
with open(data_dir_graphs + dataset + "_TRAIN.pkl", "rb") as f:
    X_train, Y_train = pickle.load(f)
with open(data_dir_graphs + dataset + "_VALID.pkl", "rb") as f:
    X_valid, Y_valid = pickle.load(f)
with open(data_dir_graphs + dataset + "_TEST.pkl", "rb") as f:
    X_test, Y_test = pickle.load(f)

In [36]:
from typing_extensions import Self
from torch_geometric.data import Dataset
from torch.utils.data import DataLoader
from torch_geometric.transforms import ToUndirected, NormalizeFeatures

transform = ToUndirected()
t2 = NormalizeFeatures(["Activity"])


class Het_graph_data(Dataset):
    def __init__(self, prefix_graphs, labels) -> Self:
        self.X = prefix_graphs
        self.Y = labels

    # get the number of rows in the dataset
    def __len__(self):
        return len(self.Y)

    # get a row at a particular index in the dataset
    def __getitem__(self, idx):
        return self.X[idx], self.Y[idx]

    @staticmethod
    def collate(batch):
        # print(batch)
        data = [t2(item[0]) for item in batch]
        # data = [transform(item[0]) for item in batch]
        
        Y = [item[1] for item in batch]
        return [data, Y]

In [37]:
train_loader = DataLoader(
    Het_graph_data(X_train + X_valid, Y_train + Y_valid),
    batch_size=512,
    shuffle=True,
    collate_fn=Het_graph_data.collate,
)


valid_loader = DataLoader(
    Het_graph_data(X_valid, Y_valid),
    batch_size=2,
    shuffle=False,
    collate_fn=Het_graph_data.collate,
)

test_loader = DataLoader(
    Het_graph_data(X_test, Y_test),
    batch_size=32,
    shuffle=False,
    collate_fn=Het_graph_data.collate,
)

In [22]:
# A Class to keep track of the metrics of the classification process
class ClassificationMetrics:

  # Constructor takes the number of classes, in our case 20
  def __init__(self, num_classes=20):
    self.num_classes = num_classes
    # Initialize a confusion matrix
    self.C = torch.zeros(num_classes, num_classes)

  # Update the confusion matrix with the new scores
  def add(self, yp, yt):
    # yp: 1D tensor with predictions
    # yt: 1D tensor with ground-truth targets
    yp = yp.to("cpu")
    yt = yt.to("cpu")
    with torch.no_grad(): # We require no computation graph
      self.C+=(yt*self.C.shape[1]+yp).bincount(minlength=self.C.numel()).view(self.C.shape).float()

  def clear(self):
    # We set the confusion matrix to zero
    self.C.zero_()

  # Computes the global accuracy
  def acc(self):
    return self.C.diag().sum().item()/self.C.sum()

  # Computes the class-averaged accuracy
  def mAcc(self):
    return (self.C.diag()/self.C.sum(-1)).mean().item()

  # Computers the class-averaged Intersection over Union
  def mIoU(self):
    return (self.C.diag()/(self.C.sum(0)+self.C.sum(1)-self.C.diag())).mean().item()

  # Returns the confusion matrix
  def confusion_matrix(self):
    return self.C

In [38]:
loaders = {"train": train_loader, "validation" : valid_loader, "test" : test_loader}

In [24]:
node_types, edge_types = X_train[0].metadata()

In [None]:
X_train[0]


In [25]:
node_types

['org:resource',
 'lifecycle:transition',
 'Activity',
 'time:timestamp',
 'case:REG_DATE',
 'case:AMOUNT_REQ']

In [26]:
edge_types

[('org:resource', 'related_to', 'org:resource'),
 ('Activity', 'followed_by', 'Activity'),
 ('time:timestamp', 'related_to', 'time:timestamp'),
 ('Activity', 'related_to', 'org:resource'),
 ('Activity', 'related_to', 'lifecycle:transition'),
 ('Activity', 'related_to', 'time:timestamp'),
 ('Activity', 'related_to', 'case:REG_DATE'),
 ('Activity', 'related_to', 'case:AMOUNT_REQ')]

In [27]:
# from models.models import HGNN
import datetime

from torch import nn
from tqdm.notebook import tqdm


In [42]:
from copy import deepcopy
from typing_extensions import Self
from torch_geometric.nn import SAGEConv, HeteroConv, GATConv, GCN  # Linear, GCNConv
from torch.nn import ModuleList, Module, Sequential, Softmax, Dropout, Linear, ReLU
from torch import mean, stack, sum, concat


class HGNN(Module):

    def __init__(self, hid, out, layers, node_types, nodes_relations) -> Self:  # type: ignore
        super().__init__()

        # List of convolutional layers
        self.convs = ModuleList()
        for _ in range(layers):
            conv = HeteroConv(
                {relation: SAGEConv((-1, -1), hid, normalize=False) for relation in nodes_relations}
                # {('Activity', 'followed_by', 'Activity') : SAGEConv((-1,-1), hid),
                #  ("org:resource", "related_to", 'org:resource') : SAGEConv((-1,-1), hid),
                #  ('time:timestamp', "related_to", "time:timestamp") : SAGEConv((-1,-1), hid)}
                ,
                aggr="sum",
            )
            self.convs.append(conv)

        # print(nodes_relations)
        # Take each node hid representation and apply a linear layer
        # self.linear_nodes = Sequential(Linear(hid, hid),ReLU(), Dropout(p=0.5), Linear(hid, hid), ReLU(), Dropout(p=0.5), Linear(hid, int(hid / 2)), ReLU())

        # Return the softmax with the class probabilities
        # self.fc = Sequential(Linear(int(hid/2)*(len(node_types)), out), ReLU(), Softmax(dim=0))
        self.fc = Sequential(
            Linear(hid * (len(node_types)), 64),
            ReLU(),
            nn.BatchNorm1d(64),
            Dropout(p=0.5),
            Linear(64, 32),
            ReLU(),
            nn.BatchNorm1d(32),
            Linear(32, out),
            # Softmax(dim=1)
            # ReLU()
        )

    def forward(self, x, edge_index):
        outs = []
        for i in range(len(x)):
            for conv in self.convs:
                x[i] = conv(x[i], edge_index[i])
                x[i] = {node_type: nodes_features.relu() for node_type, nodes_features in x[i].items()}

            # Node features of each node in the graph
            nodes_features = [x[i][key] for key in x[i].keys()]
            # print(nodes_features)
            

            # Global sum of each node type
            for j in range(len(nodes_features)):
                nodes_features[j] = sum(nodes_features[j], dim=0)

            # Global mean pooling
            # nodes_features = mean(stack(nodes_features), dim=0)
            nodes_features = concat(nodes_features)
            outs.append(nodes_features)


        out = self.fc(stack(outs))

        return out  # {key : self.linear(x_dict[key]) for key in x_dict.keys()}, nodes_features

org:resource, related_to, org:resource)={
    edge_attr=[1, 1],
    edge_index=[2, 1],
  },
  (Activity, followed_by, Activity)={
    edge_attr=[1, 1],
    edge_index=[2, 1],
  },
  (time:timestamp, related_to, time:timestamp)

In [None]:
model = HGNN(hid=128, out=len(list_activities), layers=4, node_types=node_types, nodes_relations=edge_types)
model

In [None]:
X_train[0].x_dict.items()

In [None]:
X_train[0].edge_items()

In [None]:
from torch_geometric.transforms import ToUndirected

X_train[0] = ToUndirected()(X_train[0])

In [None]:
X_train[0].x_dict

In [None]:
X_train[0].edge_items()

In [None]:
X_train[0].is_directed()

In [27]:
cl_train = {key:0 for key in range(10)}
for i,(x,y) in enumerate(train_loader):
    _,classes = torch.max(stack(y), dim=1)
    for c in list(classes):
        try:
            cl_train[c.item()] +=1
        except KeyError:
            cl_train[c.item()] = 1
            
cl_train

{0: 3370,
 1: 1553,
 2: 1553,
 3: 1991,
 4: 5173,
 5: 3311,
 6: 0,
 7: 4846,
 8: 1553,
 9: 0}

In [None]:
weights = []

In [None]:
cl_val = {key:0 for key in range(10)}
for i,(x,y) in enumerate(valid_loader):
    _,classes = torch.max(stack(y), dim=1)
    for c in list(classes):
        try:
            cl_val[c.item()] +=1
        except KeyError:
            cl_val[c.item()] = 1
            
cl_val

In [None]:
cl_test = {key:0 for key in range(10)}
for i,(x,y) in enumerate(test_loader):
    _,classes = torch.max(stack(y), dim=1)
    for c in list(classes):
        try:
            cl_test[c.item()] +=1
        except KeyError:
            cl_test[c.item()] = 1
            
cl_test

In [53]:
from torch import tensor


s = 0
for i in cl_train:
    s += cl_train[i]
weights = [s/cl_train[k] if cl_train[k] != 0 else 0 for k in cl_train ]

# weights = [0.7,0.7,1,0.7,0.7,0.7,0.7,0.7,0.7,0.7]
weights = tensor(weights, device=device)
weights[0] *= 0.7
weights[1] *= 1
weights[2] *= 0.7
weights[3] *= 0.7
weights[4] *= 1

weights[0] *= 1
weights

tensor([ 6.9288, 10.5248, 10.5248,  8.2094,  4.5138,  7.0522,  0.0000,  4.8184,
        15.0354,  0.0000], device='cuda:0')

In [43]:
model = HGNN(hid=16, out=len(list_activities), layers=2, node_types=node_types, nodes_relations=edge_types)
model.to(device)

HGNN(
  (convs): ModuleList(
    (0-1): 2 x HeteroConv(num_relations=8)
  )
  (fc): Sequential(
    (0): Linear(in_features=96, out_features=64, bias=True)
    (1): ReLU()
    (2): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (3): Dropout(p=0.5, inplace=False)
    (4): Linear(in_features=64, out_features=32, bias=True)
    (5): ReLU()
    (6): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (7): Linear(in_features=32, out_features=10, bias=True)
  )
)

In [44]:


with torch.no_grad():
    for i,(x,y) in enumerate(train_loader):
        x = [xx.to(device) for xx in x]
        
        x_edge_index_dicts = [xx.edge_index_dict for xx in x]
        x_dicts = [xx.x_dict for xx in x]
        # print("//"*50)
        #print(x)
        #print(x.edge_index_dict)
        #print(x.x_dict)
        # print("//"*50)
        
        y = [yy.to(device) for yy in y]
        y = stack(y)
        model(x_dicts, x_edge_index_dicts)
        break

[tensor([[0.0000e+00, 0.0000e+00, 0.0000e+00, 3.2842e+00, 9.1865e+01, 0.0000e+00,
         3.2269e+01, 0.0000e+00, 0.0000e+00, 1.8723e+01, 0.0000e+00, 1.8195e+01,
         3.5143e+01, 0.0000e+00, 0.0000e+00, 3.8571e+01],
        [0.0000e+00, 0.0000e+00, 6.1416e+00, 6.2157e+01, 6.3387e+01, 0.0000e+00,
         0.0000e+00, 0.0000e+00, 1.6856e+01, 0.0000e+00, 0.0000e+00, 0.0000e+00,
         2.2825e+01, 0.0000e+00, 0.0000e+00, 1.0471e+02],
        [0.0000e+00, 0.0000e+00, 0.0000e+00, 3.7172e+02, 8.8430e+03, 0.0000e+00,
         3.0435e+03, 0.0000e+00, 0.0000e+00, 1.7689e+03, 0.0000e+00, 1.7521e+03,
         3.4165e+03, 0.0000e+00, 0.0000e+00, 3.8378e+03],
        [0.0000e+00, 0.0000e+00, 0.0000e+00, 2.0044e+03, 7.8361e+03, 0.0000e+00,
         1.4574e+03, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 1.1400e+03,
         2.8791e+03, 0.0000e+00, 0.0000e+00, 6.1738e+03],
        [0.0000e+00, 0.0000e+00, 0.0000e+00, 2.9375e+03, 7.3955e+03, 0.0000e+00,
         4.6267e+02, 0.0000e+00, 0.000

In [35]:
from torch import int32, tensor

# model = HGNN(hid=16, out=len(list_activities), layers=2, node_types=node_types, nodes_relations=edge_types)


num_epochs = 20
best_accuracy = 0
early_stop_patience = 10
lr_value = 0.01

best_model = None

num_runs = 1
running_time = []

metric_tracker = ClassificationMetrics(num_classes=len(list_activities))

for run in range(num_runs):

    start = datetime.datetime.now()
    print("Run: {}".format(run + 1))

    # model = model.to(device)

    # print("compiling...")
    # model = torch.compile(model, dynamic=True)
    # print("Compiled model")

    # criterion = nn.CrossEntropyLoss(weight=weights, label_smoothing=0.1)
    # criterion = nn.CrossEntropyLoss(weight=weights)
    criterion = nn.CrossEntropyLoss()

    # optimizer = torch.optim.Adagrad(model.parameters(), lr=lr_value) # momentum=0.9, weight_decay=1e-1)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr_value)

    not_improved_count = 0

    for epoch in range(num_epochs):
        print(
            "\n-- EPOCH {}/{} -------------------------\n".format(epoch + 1, num_epochs)
        )
        torch.cuda.empty_cache()
        count_train = [0 for _ in range(len(list_activities))]
        count_true_train = [0 for _ in range(len(list_activities))]
        count_val = [0 for _ in range(len(list_activities))]
        count_true_val = [0 for _ in range(len(list_activities))]
        for state in ["train", "test"]:
            if state == "train":
                model.train()
                metric_tracker.clear()
                # true_label = []
                # predictions = []
                # scores = []

            else:

                # print(metric_tracker.confusion_matrix())
                # print(count_train)
                # print(count_true_train)

                print(
                    "\tTRAIN | acc: {:.4f} | mAcc: {:.4f} | mIoU: {:.4f}".format(
                        metric_tracker.acc(),
                        metric_tracker.mAcc(),
                        metric_tracker.mIoU(),
                        # 0,0,
                    )
                )

                # y_true = np.concatenate(true_label)
                # y_pred = np.concatenate(predictions)
                # scores = np.concatenate(scores)

                # print("TRAIN")
                # print(conta(y_true, 0), conta(y_true, 1))
                # print(conta(y_pred, 0), conta(y_pred, 1))

                # print_stats(y_pred, y_true, scores)
                metric_tracker.clear()
                model.eval()
                # true_label = []
                # predictions = []
                # scores = []

            running_loss = 0.0
            running_corrects = 0

            for i, (x, y) in tqdm(enumerate(loaders[state])):

                if i % 1 == 0:
                    print(metric_tracker.acc())
                    # print(running_loss/i)
                    print(metric_tracker.confusion_matrix())
                    # print(count_train)
                    # print(count_true_train)
                    print("//" * 14)
                # print("X")
                # print("//"*50)
                # print(x)
                # print("//"*50)
                # print(y)
                x = [xx.to(device) for xx in x]
                x_edge_index_dicts = [xx.edge_index_dict for xx in x]
                x_dicts = [xx.x_dict for xx in x]
                # print("//"*50)
                # print(x)
                # print(x.edge_index_dict)
                # print(x.x_dict)
                # print("//"*50)

                y = [yy.to(device) for yy in y]
                y = stack(y)

                # x = [[sub_item.to(device=device) for sub_item in item] for item in x]

                # y = torch.tensor([torch.max(yi,0)[1] for yi in y])

                # y = y.to(device)

                # print(outputs)
                # print("//"*30)

                # outputs = outputs.to(device)

                # print(outputs)
                # print(y)

                if state == "train":
                    optimizer.zero_grad()
                    _, true_preds = torch.max(y, 1)
                    outputs = model(x_dicts, x_edge_index_dicts)
                    loss = criterion(outputs, true_preds)
                    loss.backward()
                    optimizer.step()
                else:
                    _, true_preds = torch.max(y, 1)
                    outputs = model(x_dicts, x_edge_index_dicts)

                _, preds = torch.max(outputs, 1)

                # print(y)
                # print(true_preds)
                preds = preds.to(device)
                true_preds = true_preds.to(device)
                # if state == "train": # For now batch is set to one
                #     #for i in preds:
                #     count_train[preds] += 1
                #     count_true_train[true_preds] += 1
                #
                # else:
                #     #for i in preds:
                #     count_val[preds] += 1
                #     count_true_val[true_preds] += 1

                # running_loss += loss.item()

                # print(outputs, y, sep="\n")
                # print(preds, true_preds, sep="\n")
                # print(preds)
                # print(true_preds)
                metric_tracker.add(
                    torch.tensor(preds.to(int32)), torch.tensor(true_preds.to(int32))
                )
                # print(metric_tracker.confusion_matrix())

        print(count_val)
        print(
            "\tEVAL  | acc: {:.4f} | mAcc: {:.4f} | mIoU: {:.4f}\n".format(
                metric_tracker.acc(), metric_tracker.mAcc(), metric_tracker.mIoU()
            )
            # 0,0)
        )

        if epoch == 0:
            best_accuracy = metric_tracker.acc()

        else:
            if metric_tracker.acc() > best_accuracy:
                print("SAVING MODEL..............\n")
                best_accuracy = metric_tracker.acc()
                not_improved_count = 0
            else:
                not_improved_count += 1

        if not_improved_count == early_stop_patience:
            print(
                "Validation performance didn't improve for {} epochs. "
                "Training stops.".format(early_stop_patience)
            )
            break

    running_time.append((datetime.datetime.now() - start).total_seconds())

Run: 1

-- EPOCH 1/20 -------------------------



0it [00:00, ?it/s]

tensor(nan)
tensor([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]])
////////////////////////////


  torch.tensor(preds.to(int32)), torch.tensor(true_preds.to(int32))


tensor(0.0527)
tensor([[ 3.,  0.,  4.,  0., 16.,  4., 22.,  1.,  4.,  2.],
        [ 4.,  9.,  0.,  7.,  0.,  0.,  0.,  8.,  0.,  6.],
        [ 1., 11.,  0.,  5.,  0.,  0.,  0.,  8.,  3.,  9.],
        [ 7.,  7.,  1.,  0.,  6.,  2.,  3.,  5.,  1., 12.],
        [13.,  2.,  8.,  5.,  3., 23., 35.,  5.,  8.,  6.],
        [14.,  2.,  0.,  0.,  1.,  3.,  0.,  0.,  2., 69.],
        [ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.],
        [13.,  0.,  8.,  1.,  9.,  4., 50.,  7.,  5.,  9.],
        [ 2.,  7.,  0.,  4.,  0.,  1.,  0.,  7.,  2., 13.],
        [ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.]])
////////////////////////////
tensor(0.1973)
tensor([[35.,  0.,  6.,  0., 16., 18., 22.,  2.,  4., 30.],
        [ 4., 24.,  0., 14.,  0.,  1.,  2.,  8.,  3.,  9.],
        [ 1., 26.,  1., 11.,  0.,  3.,  1.,  8., 13.,  9.],
        [11., 23.,  3.,  4.,  9.,  4.,  5.,  6.,  4., 22.],
        [38.,  9.,  8.,  7., 29., 35., 39., 23., 10., 12.],
        [17.,  2.,  7.,  0.,  1., 27.,  

0it [00:00, ?it/s]

tensor(nan)
tensor([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]])
////////////////////////////
tensor(0.6562)
tensor([[5., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 1., 0., 0., 1., 0., 0., 0., 0.],
        [0., 1., 1., 0., 0., 0., 0., 0., 0., 0.],
        [1., 0., 0., 0., 0., 1., 0., 0., 0., 0.],
        [1., 0., 0., 0., 3., 0., 0., 2., 0., 0.],
        [0., 0., 0., 0., 0., 5., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 1., 0., 0., 7., 0., 0.],
        [0., 1., 0., 0., 0., 1., 0., 0., 0.

0it [00:00, ?it/s]

tensor(nan)
tensor([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]])
////////////////////////////
tensor(0.6406)
tensor([[77.,  0.,  0.,  0.,  1.,  0.,  0.,  0.,  0.,  0.],
        [ 0., 19.,  4.,  4.,  0.,  0.,  0.,  0.,  0.,  0.],
        [ 0., 19.,  6., 12.,  0., 10.,  0.,  0.,  0.,  0.],
        [18.,  0.,  9.,  6.,  1.,  5.,  0.,  0.,  0.,  0.],
        [13.,  0.,  2.,  5., 48.,  3.,  0., 36.,  0.,  0.],
        [ 0.,  0.,  0.,  0.,  0., 82.,  0.,  0.,  0.,  0.],
        [ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.],
        [ 0.,  0.,  0.,

0it [00:00, ?it/s]

tensor(nan)
tensor([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]])
////////////////////////////
tensor(0.6875)
tensor([[5., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 1., 0., 1., 0., 0., 0., 0., 0., 0.],
        [0., 2., 0., 0., 0., 0., 0., 0., 0., 0.],
        [1., 0., 0., 1., 0., 0., 0., 0., 0., 0.],
        [1., 0., 0., 0., 3., 0., 0., 2., 0., 0.],
        [0., 0., 0., 0., 0., 5., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 1., 0., 0., 7., 0., 0.],
        [0., 1., 0., 1., 0., 0., 0., 0., 0.

0it [00:00, ?it/s]

tensor(nan)
tensor([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]])
////////////////////////////
tensor(0.6895)
tensor([[64.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.],
        [ 0., 22.,  4.,  7.,  0.,  0.,  0.,  0.,  0.,  0.],
        [ 0.,  7.,  6., 11.,  0.,  0.,  0.,  0.,  0.,  0.],
        [11.,  0.,  6., 27.,  1.,  2.,  0.,  0.,  0.,  0.],
        [11.,  0.,  5., 12., 59.,  2.,  0., 30.,  0.,  0.],
        [ 0.,  0.,  0.,  0.,  0., 79.,  0.,  0.,  0.,  0.],
        [ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.],
        [ 1.,  0.,  0.,

0it [00:00, ?it/s]

tensor(nan)
tensor([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]])
////////////////////////////
tensor(0.6875)
tensor([[5., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 1., 0., 1., 0., 0., 0., 0., 0., 0.],
        [0., 1., 0., 0., 0., 0., 0., 1., 0., 0.],
        [1., 0., 0., 1., 0., 0., 0., 0., 0., 0.],
        [1., 0., 0., 0., 3., 0., 0., 2., 0., 0.],
        [0., 0., 0., 0., 0., 5., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 1., 0., 0., 7., 0., 0.],
        [0., 0., 0., 1., 0., 0., 0., 1., 0.

0it [00:00, ?it/s]

tensor(nan)
tensor([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]])
////////////////////////////
tensor(0.6777)
tensor([[82.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.],
        [ 0., 20.,  8.,  5.,  0.,  1.,  0.,  0.,  2.,  0.],
        [ 0., 13.,  7., 11.,  0.,  0.,  0.,  0.,  0.,  0.],
        [10.,  0., 14., 14.,  0.,  0.,  0.,  0.,  0.,  0.],
        [16.,  0.,  5.,  4., 55.,  0.,  0., 35.,  0.,  0.],
        [ 0.,  0.,  0.,  0.,  0., 73.,  0.,  0.,  0.,  0.],
        [ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.],
        [ 0.,  0.,  0.,

KeyboardInterrupt: 

In [None]:
torch.save(model.state_dict(),"76_Acc")