In [1]:
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, confusion_matrix, accuracy_score

import numpy as np
import pandas as pd

In [2]:
import torch
from torch import nn, optim
import torch.nn.functional as F

In [3]:
from torch.utils.data import Dataset, DataLoader


class TabularDataset(Dataset):
    def __init__(self, data, cat_cols=None, output_col=None):
        """
        Characterizes a Dataset for PyTorch

        Parameters
        ----------

        data: pandas data frame
          The data frame object for the input data. It must
          contain all the continuous, categorical and the
          output columns to be used.

        cat_cols: List of strings
          The names of the categorical columns in the data.
          These columns will be passed through the embedding
          layers in the model. These columns must be
          label encoded beforehand. 

        output_col: string
          The name of the output variable column in the data
          provided.
        """

        self.n = data.shape[0]

        if output_col:
            self.y = data[output_col].astype(np.float32).values.reshape(-1, 1)
        else:
            self.y =  np.zeros((self.n, 1))

        self.cat_cols = cat_cols if cat_cols else []
        self.cont_cols = [col for col in data.columns
                          if col not in self.cat_cols + [output_col]]

        if self.cont_cols:
            self.cont_X = data[self.cont_cols].astype(np.float32).values
        else:
            self.cont_X = np.zeros((self.n, 1))

        if self.cat_cols:
            self.cat_X = data[cat_cols].astype(np.int64).values
        else:
            self.cat_X =  np.zeros((self.n, 1))

    def __len__(self):
        """
        Denotes the total number of samples.
        """
        return self.n

    def __getitem__(self, idx):
        """
        Generates one sample of data.
        """
        return [self.y[idx], self.cont_X[idx], self.cat_X[idx]]

In [4]:
class FeedForwardNN(nn.Module):

    def __init__(self, emb_dims, no_of_cont, lin_layer_sizes,
               output_size, emb_dropout, lin_layer_dropouts):

        """
        Parameters
        ----------

        emb_dims: List of two element tuples
          This list will contain a two element tuple for each
          categorical feature. The first element of a tuple will
          denote the number of unique values of the categorical
          feature. The second element will denote the embedding
          dimension to be used for that feature.

        no_of_cont: Integer
          The number of continuous features in the data.

        lin_layer_sizes: List of integers.
          The size of each linear layer. The length will be equal
          to the total number
          of linear layers in the network.

        output_size: Integer
          The size of the final output.

        emb_dropout: Float
          The dropout to be used after the embedding layers.

        lin_layer_dropouts: List of floats
          The dropouts to be used after each linear layer.
        """

        super().__init__()
        
        self.no_of_cont = no_of_cont

        # Embedding layers
        self.emb_layers = nn.ModuleList([nn.Embedding(x, y)
                                         for x, y in emb_dims])
        self.emb_dropout_layer = nn.Dropout(emb_dropout)
        self.first_bn_layer = nn.BatchNorm1d(self.no_of_cont)

        no_of_embs = sum([y for x, y in emb_dims])
        self.no_of_embs = no_of_embs

        # Linear Layers
        first_lin_layer = nn.Linear(self.no_of_embs + self.no_of_cont,
                                    lin_layer_sizes[0])
        
        sizes = self.get_sizes(lin_layer_sizes, output_size)
        actns = [nn.ReLU(inplace=True)] * (len(sizes)-2) + [None]
        layers = []
        for i,(n_in,n_out,dp,act) in enumerate(zip(sizes[:-1],sizes[1:],[0.]+lin_layer_dropouts,actns)):
            layers.append(nn.BatchNorm1d(n_in))
            layers.append(nn.Dropout(dp))
            layers.append(nn.Linear(n_in, n_out))
            
            if act is not None:
                layers.append(act)
        
        self.layers = nn.Sequential(*layers)
        self.layers[-1].register_forward_hook(self.hook_output)
    
    def hook_output(self, module, inp, out):
        self.input_hook = inp
        self.output_hook = out
    
    def get_sizes(self, layers, out_sz):
        return [self.no_of_embs + self.no_of_cont] + layers + [out_sz]

    def forward(self, cont_data, cat_data):
        if self.no_of_embs != 0:
            x = [emb_layer(cat_data[:, i])
               for i, emb_layer in enumerate(self.emb_layers)]
            x = torch.cat(x, 1)
            x = self.emb_dropout_layer(x)

        if self.no_of_cont != 0:
            normalized_cont_data = self.first_bn_layer(cont_data)

            if self.no_of_embs != 0:
                x = torch.cat([x, normalized_cont_data], 1) 
            else:
                x = normalized_cont_data

        x = self.layers(x)
        return x

In [5]:
def rebalance(frame, col='hab_lbl', factor=1):
    max_size = frame[col].value_counts().max()
    lst = [frame]
    for class_index, group in frame.groupby(col):
        lst.append(group.sample(int((max_size-len(group)) / factor), replace=True))
    frame_new = pd.concat(lst)
    
    return frame_new

In [6]:
df = pd.read_csv('dataset-rocky-no-STemp.csv')
df = df.drop('P. Habitable', axis=1)
y = df['hab_lbl']
x = df.drop('hab_lbl', axis=1)

In [7]:
train_df, valid_df = train_test_split(df, train_size=0.7)



In [8]:
train_df = rebalance(train_df)

In [9]:
cat_vars = train_df.columns[np.where(train_df.dtypes == 'int64')]
cat_vars = cat_vars.tolist()
cat_vars.remove('hab_lbl')

In [10]:
from sklearn.preprocessing import LabelEncoder

In [11]:
label_encoders = {}

In [12]:
for cat_col in cat_vars:
    label_encoders[cat_col] = LabelEncoder()
    df[cat_col] = label_encoders[cat_col].fit_transform(df[cat_col])

In [13]:
dataset = TabularDataset(data=df, cat_cols=cat_vars,
                        output_col='hab_lbl')

In [14]:
batchsize = 128
dataloader = DataLoader(dataset, batchsize, shuffle=True, num_workers=1)

In [15]:
cat_dims = [train_df[col].nunique() for col in cat_vars]

In [16]:
cat_dims

[4, 6, 2, 4]

In [17]:
emb_dims = [(x, int(round(1.6 * x ** 0.56))) for x in cat_dims]

In [18]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [19]:
model = FeedForwardNN(emb_dims, no_of_cont=37, lin_layer_sizes=[50, 100],
                          output_size=3, emb_dropout=0.04,
                          lin_layer_dropouts=[0.001,0.01]).to(device)

In [20]:
model

FeedForwardNN(
  (emb_layers): ModuleList(
    (0): Embedding(4, 3)
    (1): Embedding(6, 4)
    (2): Embedding(2, 2)
    (3): Embedding(4, 3)
  )
  (emb_dropout_layer): Dropout(p=0.04)
  (first_bn_layer): BatchNorm1d(37, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (layers): Sequential(
    (0): BatchNorm1d(49, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (1): Dropout(p=0.0)
    (2): Linear(in_features=49, out_features=50, bias=True)
    (3): ReLU(inplace)
    (4): BatchNorm1d(50, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): Dropout(p=0.001)
    (6): Linear(in_features=50, out_features=100, bias=True)
    (7): ReLU(inplace)
    (8): BatchNorm1d(100, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (9): Dropout(p=0.01)
    (10): Linear(in_features=100, out_features=3, bias=True)
  )
)

In [21]:
no_of_epochs = 10
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.1)

In [22]:
for y, cont_x, cat_x in dataloader:
    print(y)

tensor([[0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [2.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [1.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [2.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
      

In [24]:
for epoch in range(no_of_epochs):
    running_loss = 0
    for y, cont_x, cat_x in dataloader:
        cat_x = cat_x.to(device)
        cont_x = cont_x.to(device)
        y  = y.to(device)
        # Forward Pass
        preds = model(cont_x, cat_x)
        loss = criterion(preds, y)
        
        running_loss += loss.item()
        
        # Backward Pass and Optimization
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    
    print('Epoch', epoch + 1, 'loss =', running_loss / len(dataloader))

Epoch 1 loss = 0.02317662101372012
Epoch 2 loss = 0.013983973584670042
Epoch 3 loss = 0.012686326650769584
Epoch 4 loss = 0.015390440821647644
Epoch 5 loss = 0.013417569853897606
Epoch 6 loss = 0.012732091492840223
Epoch 7 loss = 0.012395727315119334
Epoch 8 loss = 0.012726676683606846
Epoch 9 loss = 0.012433614964330835
Epoch 10 loss = 0.014887962490320206


In [25]:
from Scheduler import LipschitzSGDScheduler

In [26]:
optimizer = optim.SGD(model.parameters(), lr=0.1)

In [27]:
scheduler = LipschitzSGDScheduler(optimizer, model, 3, batchsize)

In [28]:
for epoch in range(no_of_epochs):
    scheduler.step()
    running_loss = 0
    
    for y, cont_x, cat_x in dataloader:          
        cat_x = cat_x.to(device)
        cont_x = cont_x.to(device)
        y  = y.to(device)

        # Forward Pass
        preds = model(cont_x, cat_x)
        print(torch.argmax(preds, dim=1))
        print(y)
        loss = criterion(preds, y)
        
        running_loss += loss.item()

        # Backward Pass and Optimization
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    
    print('Epoch', epoch+1, 'LR =', scheduler.get_lr(), 'Loss =', running_loss / len(dataloader))

tensor([1, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 1, 1, 0, 1, 0, 1, 1, 1, 0, 1, 1,
        1, 0, 1, 1, 0, 1, 0, 1, 0, 1, 0, 1, 1, 0, 0, 1, 0, 1, 1, 1, 1, 0, 1, 1,
        0, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 0,
        1, 1, 1, 1, 0, 1, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 1, 1, 0, 0, 1,
        1, 0, 1, 1, 0, 1, 1, 0, 0, 1, 1, 0, 1, 1, 1, 0, 1, 1, 1, 0, 1, 0, 0, 1,
        1, 1, 1, 0, 1, 1, 1, 0])
tensor([[0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [2.],
        [0.],
       

tensor([[0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [2.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [2.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
      

Epoch 1 LR = [0.01741372173031171] Loss = 0.010483896093709128
tensor([1, 0, 2, 1, 1, 0, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 0, 0,
        1, 1, 0, 0, 1, 1, 1, 0, 1, 0, 0, 1, 0, 1, 0, 1, 1, 1, 1, 0, 1, 1, 0, 0,
        1, 0, 0, 1, 1, 1, 1, 0, 1, 1, 0, 0, 1, 0, 0, 0, 1, 1, 2, 1, 0, 0, 2, 0,
        0, 2, 1, 0, 0, 0, 1, 0, 2, 1, 2, 1, 0, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1,
        1, 0, 0, 1, 1, 1, 0, 0, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 0, 1, 0, 0, 1, 1,
        2, 1, 0, 1, 1, 0, 0, 0])
tensor([[0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],


tensor([1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 1, 0, 1, 1, 1, 1, 1, 1, 0, 0, 2, 0, 1, 1,
        0, 0, 1, 1, 0, 1, 1, 0, 1, 1, 1, 1, 1, 1, 0, 0, 1, 0, 1, 1, 1, 1, 1, 0,
        1, 1, 0, 0, 1, 2, 0, 1, 1, 0, 0, 1, 0, 0, 1, 1, 1, 0, 1, 0, 1, 0, 0, 1,
        1, 0, 0, 1, 1, 1, 1, 1, 1, 1, 0, 0, 1, 1, 1, 0, 1, 1, 1, 0, 1, 1, 1, 1,
        1, 0, 0, 1, 1, 1, 1, 1, 0, 0, 1, 0, 1, 0, 1, 1, 0, 0, 0, 0, 0, 1, 0, 1,
        0, 1, 0, 1, 0, 1, 0, 1])
tensor([[0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [2.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [2.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
       

tensor([[0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [1.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [1.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [1.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [2.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
      

tensor([1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 1, 0, 1, 1, 1, 0, 0, 2, 1, 0, 1,
        0, 1, 1, 1, 1, 0, 1, 1, 1, 2, 1, 1, 2, 1, 0, 1, 0, 1, 1, 1, 1, 0, 1, 1,
        1, 2, 1, 1, 1, 0, 0, 1, 0, 1, 1, 1, 1, 1, 0, 1, 0, 0, 1, 1, 1, 1, 0, 0,
        1, 0, 0, 0, 0, 0, 1, 0, 1, 0, 1, 1, 1, 1, 1, 0, 1, 0, 1, 1, 1, 0, 0, 0,
        0, 1, 0, 1, 1, 1, 0, 1, 0, 1, 1, 1, 1, 0, 2, 0, 1, 0, 0, 1, 1, 0, 1, 0,
        1, 1, 1, 2, 0, 1, 0, 1])
tensor([[0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [1.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [2.],
        [0.],
        [1.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
       

tensor([1, 0, 1, 1, 0, 1, 1, 2, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 0, 0, 1, 1, 1, 1,
        0, 1, 1, 1, 1, 0, 1, 0, 1, 0, 1, 1, 1, 1, 1, 0, 0, 1, 0, 1, 1, 1, 1, 0,
        1, 1, 0, 1, 1, 0, 0, 1, 1, 1, 1, 1, 0, 1, 0, 1, 1, 0, 1, 0, 0, 0, 0, 1,
        1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 1, 2, 1, 1, 1, 0, 1, 0, 0, 1, 1, 0,
        1, 0, 0, 1, 1, 1, 0, 1, 1, 0, 1, 1, 0, 0, 1, 0, 1, 1, 1, 1, 0, 1, 0, 1,
        1, 1, 1, 0, 0, 1, 0, 1])
tensor([[0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [2.],
       

tensor([0, 1, 1, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 1, 2, 0, 2, 0, 0, 1, 1, 0, 0,
        1, 0, 1, 0, 0, 0, 0, 1, 1, 1, 1, 0, 1, 1, 1, 1, 0, 0, 1, 0, 0, 0, 2, 1,
        1])
tensor([[0.],
        [0.],
        [0.],
        [2.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [2.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [2.],
        [0.],
        [0.],
        [0.]])
Epoch 3 LR = [0.023549703260262806] Loss = 0.010126363407055448
tensor([0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 1, 1, 0, 1, 1, 2, 0, 1, 0, 1, 1, 1, 0, 

tensor([0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 1, 0, 1, 1, 1, 0, 1, 1, 1, 1,
        0, 0, 1, 0, 1, 1, 1, 0, 1, 1, 1, 0, 1, 1, 1, 0, 0, 1, 1, 1, 1, 1, 1, 1,
        0, 0, 1, 1, 1, 1, 2, 0, 1, 1, 1, 2, 1, 1, 0, 1, 0, 0, 1, 1, 0, 0, 0, 1,
        1, 1, 0, 0, 1, 0, 1, 1, 1, 1, 0, 0, 0, 1, 1, 1, 1, 1, 2, 1, 2, 1, 1, 2,
        1, 1, 1, 0, 0, 2, 0, 1, 0, 1, 1, 1, 1, 0, 0, 0, 1, 0, 1, 0, 1, 1, 1, 0,
        1, 1, 1, 0, 1, 0, 1, 0])
tensor([[0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [2.],
        [0.],
        [0.],
        [2.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
       

tensor([1, 2, 1, 1, 0, 0, 0, 0, 0, 1, 2, 1, 2, 0, 1, 0, 0, 1, 1, 0, 1, 1, 1, 1,
        1, 1, 2, 1, 1, 0, 0, 1, 0, 0, 0, 1, 1, 0, 2, 1, 1, 1, 0, 0, 0, 1, 1, 1,
        1, 2, 1, 0, 1, 1, 0, 1, 1, 2, 0, 1, 1, 1, 1, 2, 0, 1, 1, 2, 2, 1, 2, 1,
        1, 1, 0, 1, 0, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 0, 1,
        1, 0, 1, 0, 1, 1, 1, 1, 0, 2, 0, 1, 1, 1, 1, 0, 2, 1, 0, 1, 1, 2, 2, 0,
        1, 0, 2, 0, 1, 0, 1, 0])
tensor([[0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
       

tensor([1, 2, 0, 0, 1, 2, 0, 1, 1, 1, 1, 1, 1, 0, 1, 2, 1, 0, 0, 1, 1, 1, 1, 0,
        1, 1, 1, 1, 0, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1,
        0, 1, 1, 1, 1, 0, 1, 0, 1, 1, 0, 0, 2, 0, 1, 0, 1, 1, 2, 1, 1, 1, 1, 1,
        0, 1, 1, 1, 1, 2, 0, 2, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 2, 0, 1, 1,
        1, 0, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 1, 1, 0, 1, 2, 0, 0, 2, 1, 1, 1, 1,
        0, 0, 1, 2, 0, 1, 1, 0])
tensor([[0.],
        [0.],
        [0.],
        [0.],
        [2.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [2.],
        [0.],
       

tensor([1, 1, 1, 0, 0, 1, 1, 0, 1, 1, 1, 1, 0, 1, 2, 0, 0, 0, 1, 0, 0, 0, 0, 1,
        0, 0, 2, 1, 1, 0, 1, 1, 0, 1, 2, 1, 1, 0, 1, 0, 2, 0, 0, 0, 0, 0, 1, 1,
        0, 2, 1, 1, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 0, 1, 0,
        1, 1, 0, 1, 1, 2, 1, 1, 1, 0, 0, 1, 0, 1, 2, 1, 1, 1, 1, 0, 1, 1, 1, 0,
        0, 1, 2, 1, 1, 1, 0, 1, 0, 1, 1, 1, 0, 1, 0, 1, 0, 1, 0, 0, 1, 1, 1, 1,
        1, 1, 1, 1, 0, 0, 1, 2])
tensor([[0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
       

tensor([2, 2, 0, 0, 2, 1, 2, 1, 1, 1, 1, 1, 1, 0, 1, 1, 0, 0, 0, 0, 1, 0, 1, 1,
        1, 2, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 2, 1, 1, 1, 2, 0, 1, 0, 1, 1, 1,
        1, 0, 1, 1, 1, 1, 1, 1, 0, 2, 1, 2, 2, 0, 1, 1, 1, 1, 1, 2, 1, 1, 2, 0,
        1, 1, 2, 0, 0, 1, 1, 1, 1, 1, 0, 1, 2, 0, 2, 0, 0, 0, 1, 1, 1, 0, 0, 1,
        1, 0, 1, 0, 1, 2, 1, 0, 0, 1, 0, 2, 2, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 0,
        1, 0, 1, 1, 1, 2, 1, 1])
tensor([[0.],
        [0.],
        [2.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [1.],
        [2.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
       

tensor([1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 2, 1, 1, 0, 1, 0, 1, 0, 2, 1, 1, 1, 2,
        2, 1, 1, 1, 2, 0, 1, 0, 1, 1, 0, 0, 1, 1, 0, 1, 1, 1, 0, 0, 0, 0, 0, 1,
        0, 0, 2, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1, 2, 2, 2, 0, 1, 0, 1, 0, 1, 1, 1,
        1, 1, 0, 1, 1, 0, 1, 2, 0, 2, 1, 0, 0, 0, 0, 0, 0, 0, 0, 2, 1, 1, 1, 1,
        1, 2, 1, 0, 1, 1, 1, 1, 0, 0, 2, 0, 1, 1, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0,
        0, 0, 1, 2, 0, 0, 1, 0])
tensor([[0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
       

tensor([0, 0, 1, 0, 0, 1, 2, 1, 1, 2, 0, 2, 1, 0, 0, 1, 2, 1, 1, 1, 0, 1, 0, 1,
        0, 2, 1, 2, 1, 2, 2, 2, 2, 0, 0, 0, 2, 0, 1, 1, 1, 2, 1, 0, 1, 0, 2, 1,
        1, 1, 1, 1, 1, 1, 0, 2, 0, 2, 0, 1, 1, 1, 1, 2, 1, 0, 2, 1, 2, 2, 0, 0,
        0, 0, 0, 0, 1, 1, 0, 1, 1, 2, 1, 1, 1, 1, 0, 1, 2, 1, 0, 1, 1, 1, 1, 1,
        1, 0, 0, 1, 1, 1, 1, 2, 2, 0, 1, 1, 1, 0, 0, 1, 1, 1, 1, 2, 0, 0, 2, 0,
        1, 2, 1, 1, 1, 1, 1, 0])
tensor([[0.],
        [0.],
        [0.],
        [0.],
        [1.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
       

tensor([1, 1, 1, 0, 0, 2, 1, 1, 0, 2, 0, 0, 1, 1, 0, 2, 1, 0, 0, 2, 1, 0, 1, 2,
        1, 0, 2, 0, 1, 2, 2, 1, 0, 0, 2, 2, 0, 2, 0, 1, 2, 1, 1, 1, 0, 2, 0, 2,
        1, 2, 0, 0, 0, 2, 1, 1, 0, 0, 1, 1, 1, 1, 0, 2, 1, 0, 1, 0, 1, 1, 1, 0,
        0, 1, 0, 0, 1, 1, 1, 0, 1, 2, 1, 1, 2, 1, 0, 0, 1, 1, 1, 1, 0, 1, 0, 1,
        2, 0, 1, 2, 0, 1, 0, 2, 0, 2, 0, 2, 1, 2, 1, 0, 1, 1, 2, 1, 0, 1, 2, 0,
        1, 1, 2, 1, 1, 1, 2, 0])
tensor([[0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [1.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
       

tensor([2, 1, 0, 2, 1, 1, 1, 1, 0, 2, 0, 2, 1, 0, 1, 0, 0, 1, 1, 2, 1, 1, 1, 1,
        0, 1, 0, 0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 0, 0, 2, 0, 0, 2, 0, 1, 1, 0,
        0, 1, 1, 0, 1, 2, 2, 1, 1, 2, 1, 0, 2, 0, 1, 2, 0, 0, 1, 2, 0, 2, 0, 0,
        0, 1, 2, 1, 1, 0, 2, 1, 2, 1, 1, 1, 2, 1, 0, 1, 2, 2, 2, 1, 1, 0, 2, 0,
        2, 1, 2, 0, 0, 1, 1, 1, 2, 0, 0, 2, 2, 1, 1, 1, 1, 0, 0, 1, 1, 0, 2, 1,
        1, 0, 2, 0, 1, 2, 1, 1])
tensor([[0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [2.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
       

tensor([[0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
      

tensor([1, 1, 0, 0, 1, 1, 0, 0, 0, 2, 0, 1, 0, 0, 0, 1, 1, 0, 0, 0, 0, 1, 1, 1,
        2, 1, 0, 1, 1, 2, 2, 1, 0, 2, 2, 1, 0, 2, 0, 2, 0, 2, 2, 1, 0, 2, 2, 1,
        1, 1, 1, 0, 0, 2, 1, 1, 1, 0, 0, 0, 1, 2, 2, 0, 1, 2, 0, 1, 0, 0, 1, 1,
        0, 0, 1, 1, 2, 0, 1, 1, 1, 1, 2, 0, 1, 1, 1, 2, 0, 1, 1, 1, 0, 1, 1, 2,
        1, 1, 1, 1, 2, 1, 1, 2, 1, 0, 1, 0, 0, 0, 1, 1, 1, 1, 2, 0, 1, 1, 1, 0,
        0, 0, 1, 1, 1, 1, 1, 2])
tensor([[0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [2.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
       

tensor([[0.],
        [0.],
        [0.],
        [1.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
      

tensor([[0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [2.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [2.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [2.],
        [2.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
      

tensor([[0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
      

tensor([2, 0, 1, 1, 1, 1, 1, 2, 1, 1, 0, 1, 0, 1, 1, 0, 2, 0, 1, 1, 1, 0, 1, 0,
        1, 0, 0, 2, 1, 2, 1, 1, 1, 0, 0, 2, 2, 1, 0, 0, 0, 0, 1, 1, 1, 2, 1, 2,
        2])
tensor([[0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [2.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.]])
Epoch 9 LR = [0.01844946170846621] Loss = 0.006809518013532008
tensor([2, 1, 0, 0, 1, 1, 1, 2, 1, 0, 1, 1, 1, 1, 0, 0, 0, 1, 1, 1, 1, 1, 2, 1

tensor([2, 0, 2, 2, 1, 1, 1, 2, 0, 1, 0, 0, 1, 0, 2, 1, 2, 0, 2, 2, 2, 1, 1, 1,
        1, 1, 1, 0, 0, 2, 2, 0, 1, 1, 1, 0, 2, 2, 1, 2, 2, 1, 1, 1, 0, 1, 2, 0,
        0, 1, 1, 0, 0, 1, 2, 1, 1, 0, 0, 0, 1, 1, 0, 2, 1, 1, 2, 1, 2, 0, 2, 1,
        0, 2, 1, 2, 0, 1, 2, 2, 2, 0, 2, 0, 1, 2, 1, 1, 2, 0, 1, 1, 1, 2, 1, 2,
        0, 0, 0, 1, 1, 1, 2, 1, 2, 0, 0, 0, 1, 2, 2, 1, 0, 1, 1, 1, 2, 2, 1, 1,
        2, 1, 2, 1, 1, 1, 1, 1])
tensor([[0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [1.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
       

tensor([1, 0, 1, 0, 1, 0, 1, 2, 0, 0, 0, 0, 1, 2, 0, 0, 1, 0, 0, 1, 0, 0, 2, 0,
        0, 1, 1, 2, 0, 0, 1, 1, 2, 0, 1, 0, 1, 1, 1, 0, 2, 1, 1, 2, 1, 0, 0, 1,
        2, 1, 1, 0, 0, 1, 1, 2, 2, 1, 0, 1, 1, 1, 2, 0, 1, 1, 0, 1, 0, 1, 2, 1,
        1, 0, 0, 0, 2, 0, 0, 0, 1, 0, 2, 1, 0, 0, 2, 2, 1, 0, 1, 0, 0, 0, 1, 0,
        1, 2, 0, 2, 0, 1, 1, 0, 0, 0, 0, 1, 1, 2, 1, 2, 2, 1, 1, 2, 2, 0, 2, 1,
        0, 1, 1, 0, 0, 0, 0, 1])
tensor([[0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
       