# Text Classification - Training a GAT


## $\color{blue}{Sections:}$

* Preamble
1.   Admin
2.   Dataset
3.   Model
4.   Train - Validate
5.   Training Loop

## $\color{blue}{Preamble:}$

We now train a GAT in basic PyTorch. The model replaces the GCN node updates with attention-based updates. Note that the GAT solution is implemented as intended, but the results are sub-standard: there is no learning in the graph attention mechanism.

## $\color{blue}{Admin}$
* Install relevant Libraries
* Import relevant Libraries

In [None]:
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from google.colab import drive
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

cuda


In [None]:
drive.mount("/content/drive")
%cd '/content/drive/MyDrive'

Mounted at /content/drive
/content/drive/MyDrive


## $\color{blue}{Data}$

* Connect to Drive
* Load the data
* Load adjacency matrices

In [None]:
path = 'class/datasets/'
df_train = pd.read_pickle(path + 'df_train')
df_dev = pd.read_pickle(path + 'df_dev')
df_test = pd.read_pickle(path + 'df_test')

In [None]:
path = 'class/tensors/adj_{}.pt'

# train
train_people = torch.load(path.format('train_people'))
train_locations = torch.load(path.format('train_locations'))
train_entities = torch.load(path.format('train_entities'))

# dev
dev_people = torch.load(path.format('dev_people'))
dev_locations = torch.load(path.format('dev_locations'))
dev_entities = torch.load(path.format('dev_entities'))

# val (contains the adjacency matrix for both the training and the development set)
val_people = torch.load(path.format('val_people'))
val_locations = torch.load(path.format('val_locations'))
val_entities = torch.load(path.format('val_entities'))

  train_people = torch.load(path.format('train_people'))
  train_locations = torch.load(path.format('train_locations'))
  train_entities = torch.load(path.format('train_entities'))
  dev_people = torch.load(path.format('dev_people'))
  dev_locations = torch.load(path.format('dev_locations'))
  dev_entities = torch.load(path.format('dev_entities'))
  val_people = torch.load(path.format('val_people'))
  val_locations = torch.load(path.format('val_locations'))
  val_entities = torch.load(path.format('val_entities'))


## $\color{blue}{Dataset:}$

In [None]:
from torch.utils.data import Dataset, DataLoader
from copy import deepcopy

def sample_neighborhood(A, inds, neighbor_max, branch_max, seed=None):
    """Expand a mini-batch of node indices with sampled 1- and 2-hop neighbors.

    For each seed node in ``inds`` the direct neighbors (columns ``j`` with
    ``A[node, j] > 0``) that are not already selected are collected and, if
    there are ``neighbor_max`` or more, randomly cut down to ``neighbor_max``.
    Each kept neighbor is then visited in turn and up to ``neighbor_max`` of
    its own unseen neighbors are added, until the seed node has contributed
    ``branch_max`` extra nodes.

    Args:
        A: square adjacency matrix (torch tensor); an entry > 0 is an edge.
        inds: iterable of integer node indices forming the mini-batch.
        neighbor_max: maximum number of nodes sampled from one neighbor list.
        branch_max: maximum number of extra nodes contributed per seed node.
        seed: optional integer. When given, sampling is reproducible and uses
            a private random generator, so the global NumPy random state is
            left untouched. When ``None``, the global NumPy generator is used.

    Returns:
        list[int]: the unique selected node indices (the seed nodes plus the
        sampled neighbors), in set iteration order.
    """
    # A private RandomState(seed) yields the same stream as
    # np.random.seed(seed) followed by np.random.* calls, but does not
    # clobber the global generator that other code relies on.
    rng = np.random.RandomState(seed) if seed is not None else np.random

    sampled_indices = set(inds)

    for ind in inds:
        # First hop: neighbors of the seed node that are not selected yet.
        exclude = {ind} | sampled_indices
        neighbors = {n for n in _positive_neighbors(A, ind) if n not in exclude}

        if len(neighbors) >= neighbor_max:
            picked = rng.choice(list(neighbors), neighbor_max, replace=False)
            neighbors = set(picked.tolist())

        # Second hop: iterate over a snapshot because `neighbors` grows below.
        for idx in list(neighbors):
            exclude = {ind, idx} | sampled_indices | neighbors
            second_hop = [n for n in _positive_neighbors(A, idx) if n not in exclude]
            if len(second_hop) > neighbor_max:
                second_hop = rng.choice(second_hop, neighbor_max, replace=False).tolist()
            second_hop = set(second_hop)

            if len(neighbors) + len(second_hop) >= branch_max:
                # Budget reached: top up to exactly branch_max and stop.
                # Clamp at zero so neighbor_max > branch_max cannot request
                # a negative sample size.
                room = max(branch_max - len(neighbors), 0)
                if room > 0:
                    extra = rng.choice(list(second_hop), room, replace=False)
                    neighbors.update(extra.tolist())
                break

            neighbors.update(second_hop)

        sampled_indices.update(neighbors)

    # Normalise to plain Python ints so callers get a homogeneous list.
    return [int(i) for i in sampled_indices]


def _positive_neighbors(A, node):
    """Return the column indices ``j`` with ``A[node, j] > 0`` as Python ints."""
    # A single bulk ``tolist()`` replaces one ``.item()`` call per element.
    return (A[node] > 0).nonzero(as_tuple=True)[0].tolist()

In [None]:
class GNNDataset(Dataset):
    """Dataset that returns a sampled sub-graph for a batch of node indices.

    Indexing with a list (or tensor) of node indices returns the features,
    adjacency sub-matrix, labels and meta indices of those nodes together
    with the neighbors chosen by ``sample_neighborhood``.
    """

    def __init__(self, H, A, labels, meta_indices, neighbor_max=4, branch_max=16, seed=None):
        """Store the graph tensors and the sampling configuration.

        Args:
            H (torch.Tensor): input embeddings (n x d).
            A (torch.Tensor): adjacency matrix (n x n).
            labels (torch.LongTensor): target class per node.
            meta_indices (torch.LongTensor): index of each datapoint, used by
                the caller to filter the validation score.
            neighbor_max (int): passed to ``sample_neighborhood`` as the cap
                on nodes drawn from one neighbor list.
            branch_max (int): passed to ``sample_neighborhood`` as the cap on
                extra nodes contributed per seed node.
            seed (int | None): passed to ``sample_neighborhood``; a fixed value
                makes the sampling reproducible.
        """
        # H, A and labels are moved to the module-level `device`;
        # meta_indices stays on whatever device the caller created it on.
        self.H = H.to(device)
        self.A = A.to(device)
        self.labels = labels.to(device)
        self.meta_indices = meta_indices
        self.neighbor_max = neighbor_max
        self.branch_max = branch_max
        self.seed = seed

    def __len__(self):
        """Return the number of labelled nodes."""
        return len(self.labels)

    def __getitem__(self, inds):
        """Return ``(H_batch, A_batch, labels_batch, index_batch)`` for ``inds``.

        ``inds`` may be a single index, a list of indices or a tensor of
        indices. The returned tensors cover the requested nodes plus their
        sampled neighbors.
        """
        # Normalise the request to a list; a 0-d tensor's tolist() is a bare int.
        if torch.is_tensor(inds):
            inds = inds.tolist()
        if not isinstance(inds, list):
            inds = [inds]

        sampled_indices = sample_neighborhood(
            self.A, inds, self.neighbor_max, self.branch_max, seed=self.seed
        )

        # Build the index tensor once instead of converting the list per lookup.
        idx = torch.as_tensor(sampled_indices, dtype=torch.long, device=self.A.device)

        H_batch = self.H[idx]
        # Select rows first, then columns, to obtain the induced sub-graph.
        A_batch = self.A[idx][:, idx]
        labels_batch = self.labels[idx]
        index_batch = self.meta_indices[idx.to(self.meta_indices.device)]

        return H_batch, A_batch, labels_batch, index_batch

### Sampling Verification

In [None]:
H = torch.stack(list(df_train['vanilla_embedding.1']))
labels = torch.LongTensor(list(df_train['chapter_idx']))
A = train_entities
meta_indices = torch.LongTensor(list(range(df_train.shape[0])))

In [None]:
train_dataset = GNNDataset(H, A, labels, meta_indices, neighbor_max=4, branch_max=10)

# Prevent dataloader from calling a single index at a time
custom_sampler = torch.utils.data.sampler.BatchSampler(
    torch.utils.data.sampler.RandomSampler(train_dataset),
    batch_size=6,
    drop_last=False)


train_loader = DataLoader(train_dataset, sampler = custom_sampler)


In [None]:
# Check batches

# Number of batches to inspect
num_batches_to_check = 2

# Print the type and shape of every tensor in the first few batches.
# NOTE: the loop variables (`inputs`, `adjacency`, `labels`, `indices`) are
# module-level names in a notebook cell, so after this loop `labels` refers to
# the last inspected batch rather than to any tensor named `labels` that was
# assigned before the loop.
for batch_idx, (inputs, adjacency, labels, indices) in enumerate(train_loader):
    print('\n##########################\n')
    print(f"Batch {batch_idx + 1}/{num_batches_to_check}:")
    print('-' * 10)
    print("Inputs:")
    print(f"  Type: {type(inputs)}")
    print(f"  Shape: {inputs.size()}")
    print('-' * 10)
    print("Adjacency:")
    print(f"  Type: {type(adjacency)}")
    # Only the first element along the leading dimension is reported here.
    print(f"  Shape: {adjacency[0].size()}")
    print('-' * 10)
    print("Indices:")
    print(f"  Type: {type(indices)}")
    print(f"  Shape: {indices.size()}")
    print(indices)
    print('-' * 10)
    print("Labels:")
    print(f"  Type: {type(labels)}")
    print(f"  Shape: {labels.size()}")
    print(labels)

    # Stop after inspecting the desired number of batches
    if batch_idx + 1 >= num_batches_to_check:
        break


##########################

Batch 1/2:
----------
Inputs:
  Type: <class 'torch.Tensor'>
  Shape: torch.Size([1, 21, 768])
----------
Adjacency:
  Type: <class 'torch.Tensor'>
  Shape: torch.Size([21, 21])
----------
Indices:
  Type: <class 'torch.Tensor'>
  Shape: torch.Size([1, 21])
tensor([[ 1473, 11203, 11972,  4997,   390, 11975, 11917,  1936,  1746,  9875,
         10770,  2196,  8665,  1562,  2735,  7536,  8112,  6517, 11765,  1019,
          8508]])
----------
Labels:
  Type: <class 'torch.Tensor'>
  Shape: torch.Size([1, 21])
tensor([[14, 16, 14, 61,  6,  7,  9, 22, 67, 13,  6,  4,  0, 14, 14,  2, 14,  5,
         14, 14, 51]], device='cuda:0')

##########################

Batch 2/2:
----------
Inputs:
  Type: <class 'torch.Tensor'>
  Shape: torch.Size([1, 22, 768])
----------
Adjacency:
  Type: <class 'torch.Tensor'>
  Shape: torch.Size([22, 22])
----------
Indices:
  Type: <class 'torch.Tensor'>
  Shape: torch.Size([1, 22])
tensor([[ 2366,  4354,  5699,  9284, 10309,  7880,

In [None]:
H = torch.stack(list(df_train['vanilla_embedding.1']))
labels = torch.LongTensor(list(df_train['chapter_idx']))
A = train_entities
meta_indices = torch.LongTensor(list(range(df_train.shape[0])))

The dataloader seems to be working correctly. Because the custom sampler hands all indices of a mini-batch to the dataset at once, the DataLoader's collate function inserts a new leading dimension that it would normally stack samples along; since the whole mini-batch arrives as a single item, there is nothing to stack and that dimension has size 1.

The simple solution will be to simply squeeze the tensors in the training loop. The validation loader eradicates randomness from the process.

## $\color{blue}{Model:}$

In [None]:
class GATVaswaniLayer(nn.Module):
    """Single-head dot-product attention layer restricted to graph edges.

    Node features are projected to queries, keys and values. The pairwise
    query-key scores receive a learnable bonus on the diagonal, are multiplied
    by the adjacency values, masked to the positive entries of the adjacency
    matrix and normalised with a row-wise softmax. The attention-weighted
    values then pass through LeakyReLU, dropout and batch normalisation.

    Args:
        in_features (int): size of each input node feature vector.
        out_features (int): size of the query/key/value projections and of
            the output node feature vector.
        dropout_rate (float): dropout probability applied to the output.
    """

    def __init__(self, in_features, out_features, dropout_rate):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.dropout = nn.Dropout(dropout_rate)

        # Query / key / value projection matrices.
        self.Wq = nn.Parameter(torch.empty(in_features, out_features))
        self.Wk = nn.Parameter(torch.empty(in_features, out_features))
        self.Wv = nn.Parameter(torch.empty(in_features, out_features))
        nn.init.xavier_uniform_(self.Wq)
        nn.init.xavier_uniform_(self.Wk)
        nn.init.xavier_uniform_(self.Wv)

        # Learnable bonus added to each node's score for itself (starts at 1).
        self.self_bias = nn.Parameter(torch.ones(1))
        self.leakyrelu = nn.LeakyReLU()

        self.batch_norm = nn.BatchNorm1d(out_features)

    def forward(self, H, A):
        """Compute attention-based node updates.

        Args:
            H (torch.Tensor): node features, n x in_features.
            A (torch.Tensor): adjacency matrix, n x n (dense or sparse);
                entries > 0 mark the pairs that may attend to each other.

        Returns:
            torch.Tensor: updated node features, n x out_features.
        """
        H_q = torch.matmul(H, self.Wq)  # n x d
        H_k = torch.matmul(H, self.Wk)  # n x d
        H_v = torch.matmul(H, self.Wv)  # n x d

        # Raw pairwise scores, n x n.
        # NOTE(review): the scores are not divided by sqrt(out_features) as in
        # scaled dot-product attention; for large feature sizes the softmax
        # may saturate -- confirm whether that is intended.
        scores = torch.matmul(H_q, H_k.transpose(0, 1))

        # Add the self-loop bonus on the diagonal. The parameter is used
        # directly (not via .item()) so that gradients reach it.
        n = H.size(0)
        scores = scores + self.self_bias * torch.eye(n, device=H.device)

        # to_dense() returns the tensor itself when A is already dense.
        adjacency = A.to_dense()

        # Scale the scores by the adjacency values (a no-op for 0/1 matrices).
        scores = scores * adjacency

        # Non-edges get a large negative score so softmax gives them ~0 weight.
        # A row with no positive entry ends up with uniform attention.
        scores = scores.masked_fill(~(adjacency > 0), -9e15)

        attention = F.softmax(scores, dim=1)

        H_out = self.leakyrelu(torch.matmul(attention, H_v))
        H_out = self.dropout(H_out)
        H_out = self.batch_norm(H_out)

        return H_out

class GATModel(nn.Module):
    """Stack of graph attention layers followed by a two-layer classifier.

    Args:
        d (int): node feature size (kept unchanged by every attention layer).
        h (int): hidden size of the fully connected layer.
        c (int): number of output classes.
        num_layers (int): number of stacked attention layers.
        dropout_rate (float): dropout probability used in the attention
            layers and in the classifier head.
    """

    def __init__(self, d, h, c, num_layers=2, dropout_rate=0.3):
        super().__init__()
        self.num_layers = num_layers

        # Attention stack: every layer maps d -> d.
        attention_stack = nn.ModuleList()
        for _ in range(num_layers):
            attention_stack.append(GATVaswaniLayer(d, d, dropout_rate))
        self.gnn_layers = attention_stack

        # Classifier head.
        self.fc1 = nn.Linear(d, h)
        self.batch_norm_fc1 = nn.BatchNorm1d(h)
        self.fc2 = nn.Linear(h, c)
        self.dropout = nn.Dropout(dropout_rate)
        self.relu = nn.ReLU()

    def forward(self, H, A):
        """Return class scores (n x c) for node features ``H`` and adjacency ``A``."""
        hidden = H
        for gnn in self.gnn_layers:
            hidden = gnn(hidden, A)

        # Head: linear -> dropout -> batch norm -> ReLU -> linear.
        hidden = self.fc1(hidden)
        hidden = self.dropout(hidden)
        hidden = self.batch_norm_fc1(hidden)
        hidden = self.relu(hidden)
        return self.fc2(hidden)

    def forward_layer(self, H, A, layer_idx):
        """Apply only the attention layer at position ``layer_idx``."""
        return self.gnn_layers[layer_idx](H, A)

In [None]:
import torch.optim as optim

d = 768
h = 400   # hidden dimension of fully connected layer
c = 70   # number of classes
num_relations = 1   # number of relationship types

# Model, Loss, Optimizer
model = GATModel(d, h, c)
criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(model.parameters(), lr=0.001)



In [None]:
def count_parameters_per_module(model):
    """Print the number of trainable parameters held by each sub-module.

    Every named sub-module (the top-level model itself excluded) that owns at
    least one trainable parameter is printed on its own line. Containers are
    listed as well, so their totals include the parameters of their children.
    """
    print("Module and parameter counts:")

    for name, module in model.named_modules():
        # The empty name denotes the model itself; skip it.
        if name == "" or not isinstance(module, nn.Module):
            continue

        trainable = 0
        for param in module.parameters():
            if param.requires_grad:
                trainable += param.numel()

        # Parameter-free modules (activations, dropout, ...) are not printed.
        if trainable > 0:
            print(f"{name}: {trainable} parameters")

In [None]:
count_parameters_per_module(model)

Module and parameter counts:
gnn_layers: 3542018 parameters
gnn_layers.0: 1771009 parameters
gnn_layers.0.batch_norm: 1536 parameters
gnn_layers.1: 1771009 parameters
gnn_layers.1.batch_norm: 1536 parameters
fc1: 307600 parameters
batch_norm_fc1: 800 parameters
fc2: 28070 parameters


## $\color{blue}{Train-Validate:}$

In [None]:
def accuracy(outputs, labels):
    """Return the fraction of rows in ``outputs`` whose top class equals ``labels``.

    Args:
        outputs: 2-D tensor of class scores, one row per sample.
        labels: 1-D tensor of target class indices.

    Returns:
        float: number of correct predictions divided by ``labels.size(0)``.
    """
    # The predicted class is the position of the largest score in each row.
    predicted = outputs.argmax(dim=1)
    hits = predicted.eq(labels).sum().item()
    return hits / labels.size(0)

In [None]:
import numpy as np

def train(model, train_loader, criterion, optimizer):
    """Run one optimisation pass over ``train_loader``.

    Each batch is a ``(H, A, y, indices)`` tuple; the indices are not used
    here. For every batch the model is evaluated on ``(H, A)``, the loss and
    accuracy against ``y`` are recorded, and one optimiser step is taken.

    Returns:
        tuple: (mean of the per-batch losses, mean of the per-batch accuracies).
    """
    model.train()
    loss_log, accuracy_log = [], []

    for step, (feats, adj, targets, _unused) in enumerate(train_loader):
        print('TTT',step)
        optimizer.zero_grad()

        # squeeze(0) drops the leading axis when it has size 1.
        feats, adj, targets = feats.squeeze(0), adj.squeeze(0), targets.squeeze(0)

        logits = model(feats, adj)
        loss = criterion(logits, targets)

        loss_log.append(loss.item())
        accuracy_log.append(accuracy(logits, targets))

        # Back-propagate and update the weights.
        loss.backward()
        optimizer.step()

    return np.mean(loss_log), np.mean(accuracy_log)

In [None]:
def validate(model, dev_loader, criterion, threshold=12000):
    """Evaluate ``model`` on the development points contained in ``dev_loader``.

    Each batch is a ``(H, A, y, indices)`` tuple. The model is run on the
    whole batch, but loss and accuracy are computed only for the rows whose
    meta index is ``>= threshold``; rows below the threshold are treated as
    training points and ignored. Batches without any such row are skipped.

    Args:
        model: module called as ``model(H, A)``.
        dev_loader: iterable of ``(H, A, y, indices)`` batches.
        criterion: loss function called as ``criterion(outputs, targets)``.
        threshold: smallest meta index that counts as a development point.
            NOTE(review): the default of 12000 presumably equals the size of
            the training set -- confirm against the data.

    Returns:
        tuple: (mean of the per-batch losses, mean of the per-batch
        accuracies). ``(nan, nan)`` when no development point was evaluated.
        NOTE(review): batches are averaged with equal weight regardless of
        how many development points each one contains.
    """
    model.eval()
    epoch_dev_losses = []
    epoch_dev_accuracy = []

    with torch.no_grad():
        for batch_idx, (H, A, y, indices) in enumerate(dev_loader):
            print('VVV', batch_idx)
            # squeeze(0) drops the leading axis when it has size 1.
            H = H.squeeze(0)
            A = A.squeeze(0)
            y = y.squeeze(0)
            indices = indices.squeeze(0)

            out = model(H, A)

            # Keep only the development points.
            mask = indices >= threshold
            filtered_out = out[mask]
            filtered_y = y[mask]

            if filtered_out.size(0) > 0:
                dev_loss = criterion(filtered_out, filtered_y)
                dev_accuracy = accuracy(filtered_out, filtered_y)

                epoch_dev_losses.append(dev_loss.item())
                epoch_dev_accuracy.append(dev_accuracy)

    # Nothing evaluated: report NaN explicitly instead of letting np.mean
    # warn about the mean of an empty list.
    if not epoch_dev_losses:
        return float('nan'), float('nan')

    return np.mean(epoch_dev_losses), np.mean(epoch_dev_accuracy)

## $\color{blue}{Training-Loop:}$

In [None]:
from collections import namedtuple
# Summary of one training run: the metrics of a selected epoch, the
# hyperparameters used (lr, alpha) and the running best accuracy.
Stats = namedtuple(
    'Stats',
    'train_loss train_accuracy dev_loss dev_accuracy epoch lr alpha max_accuracy',
)

In [None]:
def gen_config(lr_low, lr_high, alpha_low, alpha_high):
    """Draw a random (lr, alpha) pair on a log10 scale.

    The bounds are base-10 exponents: an exponent is sampled uniformly between
    the low and high bound, and ``10 ** exponent`` rounded to six decimals is
    returned.

    Returns:
        tuple: (lr, alpha).
    """
    # Re-seed the global generator without a fixed seed before drawing.
    np.random.seed()

    lr_exponent = float(np.random.uniform(lr_low, lr_high))
    lr = round(10 ** lr_exponent, 6)

    alpha_exponent = float(np.random.uniform(alpha_low, alpha_high))
    alpha = round(10 ** alpha_exponent, 6)

    return lr, alpha

In [None]:
def gen_ranges(lr, lr_range, alpha, alpha_range):
    """Build search windows of the given widths centred on ``lr`` and ``alpha``.

    Returns:
        tuple: (lr_low, lr_high, alpha_low, alpha_high), where each window
        spans half of its range on either side of its centre.
    """
    lr_half_width = lr_range/2
    alpha_half_width = alpha_range/2

    return (
        lr - lr_half_width,
        lr + lr_half_width,
        alpha - alpha_half_width,
        alpha + alpha_half_width,
    )

In [None]:
def search_stats(results):
    """Return the entry of ``results`` with the highest ``dev_accuracy``.

    Ties are resolved in favour of the earliest entry. Entries whose accuracy
    is NaN are ignored. An accuracy of exactly 0 is a valid candidate, so a
    list of runs that all scored 0 still yields its first entry.

    Args:
        results: iterable of objects exposing a ``dev_accuracy`` attribute.

    Returns:
        The best entry, or ``None`` when ``results`` is empty or every entry
        has a NaN accuracy.
    """
    best_stats = None
    for candidate in results:
        acc = candidate.dev_accuracy
        # NaN is the only value that is not equal to itself.
        if acc != acc:
            continue
        if best_stats is None or acc > best_stats.dev_accuracy:
            best_stats = candidate
    return best_stats

In [None]:
def tv_run(epochs, model, lr, alpha, max_accuracy, path, verbose = 0, patience = 3):
    """Train and validate ``model`` with early stopping; return the best epoch's stats.

    The function builds its data loaders from the module-level
    ``train_dataset`` / ``custom_train_sampler`` and ``validation_dataset`` /
    ``custom_validation_sampler``, and moves the model to the module-level
    ``device``.

    Args:
        epochs (int): maximum number of epochs; must be at least 1.
        model: network to train.
        lr (float): learning rate for AdamW.
        alpha (float): weight decay for AdamW.
        max_accuracy (float): best development accuracy seen so far (possibly
            in earlier runs). The model's state dict is written to ``path``
            only when an epoch beats this value.
        path (str): file the best state dict is saved to.
        verbose (int): 1 prints the run summary, 2 additionally prints the
            results of every epoch.
        patience (int): number of consecutive epochs without a new best
            development accuracy in this run after which training stops.

    Returns:
        Stats: metrics of the epoch with the highest development accuracy in
        this run, the hyperparameters, and the updated ``max_accuracy``.

    Raises:
        ValueError: if ``epochs`` is smaller than 1.
    """
    if epochs < 1:
        # Without at least one epoch there is no result to report.
        raise ValueError(f'epochs must be at least 1, got {epochs}')

    print(f'lr {lr}, alpha {alpha}')
    model = model.to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=alpha)

    # Prepare data loaders
    train_loader = DataLoader(train_dataset, sampler = custom_train_sampler)
    dev_loader = DataLoader(validation_dataset, sampler = custom_validation_sampler)

    # Hold epoch stats
    train_losses = []
    train_accuracy = []
    dev_losses = []
    dev_accuracy = []
    epoch_holder = []

    # Early-stopping bookkeeping for this run
    current_best = 0
    no_improvement = 0

    for epoch in range(epochs):

        # Stop once `patience` epochs in a row brought no improvement
        if no_improvement >= patience:
            break

        train_loss, train_acc = train(model, train_loader, criterion, optimizer)
        dev_loss, dev_acc = validate(model, dev_loader, criterion)

        # Store epoch stats
        train_losses.append(train_loss)
        train_accuracy.append(train_acc)
        dev_losses.append(dev_loss)
        dev_accuracy.append(dev_acc)
        epoch_holder.append(epoch + 1)

        # Check for improvement within this run
        if dev_acc > current_best:
            current_best = dev_acc
            no_improvement = 0
        else:
            no_improvement += 1

        # Save the model only when it beats the best accuracy across runs
        if dev_acc > max_accuracy:
            torch.save(model.state_dict(), path)
            max_accuracy = dev_acc

        # Optionally print epoch results
        if verbose == 2:
            print(f'\n --------- \nEpoch: {epoch + 1}\n')
            print(f'Epoch {epoch + 1} train loss: {train_loss:.4f}')
            print(f'Epoch {epoch + 1} train accuracy: {train_acc:.4f}')
            print(f'Epoch {epoch + 1} dev loss: {dev_loss:.4f}')
            print(f'Epoch {epoch + 1} dev accuracy: {dev_acc:.4f}')

    # Report the epoch with the highest development accuracy
    max_ind = np.argmax(dev_accuracy)

    stats = Stats(
        train_losses[max_ind],
        train_accuracy[max_ind],
        dev_losses[max_ind],
        dev_accuracy[max_ind],
        epoch_holder[max_ind],
        lr, alpha,
        max_accuracy
    )

    # Optionally print model results
    if verbose in [1,2]:
        print('\n ######## \n')
        print(f'lr:{stats.lr}, alpha:{stats.alpha} @ epoch {stats.epoch}.')
        print(f'TL:{stats.train_loss}, TA:{stats.train_accuracy}.')
        print(f'DL:{stats.dev_loss}, DA:{stats.dev_accuracy}')

    return stats

#### $\color{red}{Sanity-check:}$

In [None]:
# model
model = GATModel(d, h, c)

In [None]:
# training loader
H_train = torch.stack(list(df_train['vanilla_embedding.1']))
labels_train = torch.LongTensor(list(df_train['chapter_idx']))
A_train = train_entities
train_indices = torch.LongTensor(list(range(df_train.shape[0])))

train_dataset = GNNDataset(H_train, A_train, labels_train, train_indices, neighbor_max=4, branch_max=16)

# Prevent dataloader from calling a single index at a time
custom_train_sampler = torch.utils.data.sampler.BatchSampler(
    torch.utils.data.sampler.RandomSampler(train_dataset),
    batch_size=2048,
    drop_last=False)


train_loader = DataLoader(train_dataset, sampler = custom_train_sampler)

# training loader
df1 = df_train[['vanilla_embedding.1', 'chapter_idx']]
df2 = df_dev[['vanilla_embedding.1', 'chapter_idx']]
df_val = pd.concat([df1, df2])
H_val = torch.stack(list(df_val['vanilla_embedding.1']))
labels_val = torch.LongTensor(list(df_val['chapter_idx']))
A_val = val_entities
val_indices = torch.LongTensor(list(range(df_val.shape[0])))



validation_dataset = GNNDataset(H_val, A_val, labels_val, val_indices, neighbor_max=4, branch_max=16, seed=42)

# Prevent dataloader from calling a single index at a time
custom_validation_sampler = torch.utils.data.sampler.BatchSampler(
    torch.utils.data.sampler.SequentialSampler(validation_dataset),
    batch_size=2048,
    drop_last=False)


dev_loader = DataLoader(validation_dataset, sampler = custom_validation_sampler)



In [None]:
epochs = 10
lr = 0.0005
alpha = 0.0001
path = "class/models/GAT.pt"
max_accuracy = 0

In [None]:
tv_run(epochs, model, lr, alpha, max_accuracy, path, verbose = 2)

lr 0.0005, alpha 0.0001
TTT 0
TTT 1
TTT 2
TTT 3
TTT 4
TTT 5
VVV 0
VVV 1
VVV 2
VVV 3
VVV 4
VVV 5
VVV 6

 --------- 
Epoch: 1

Epoch 1 train loss: 2.7746
Epoch 1 train accuracy: 0.3085
Epoch 1 dev loss: 4.0973
Epoch 1 dev accuracy: 0.0380
TTT 0
TTT 1
TTT 2
TTT 3
TTT 4
TTT 5
VVV 0
VVV 1
VVV 2
VVV 3
VVV 4
VVV 5
VVV 6

 --------- 
Epoch: 2

Epoch 2 train loss: 2.5677
Epoch 2 train accuracy: 0.3516
Epoch 2 dev loss: 4.1256
Epoch 2 dev accuracy: 0.0380
TTT 0
TTT 1
TTT 2
TTT 3
TTT 4
TTT 5
VVV 0
VVV 1
VVV 2
VVV 3
VVV 4
VVV 5
VVV 6

 --------- 
Epoch: 3

Epoch 3 train loss: 2.4397
Epoch 3 train accuracy: 0.3750
Epoch 3 dev loss: 4.1322
Epoch 3 dev accuracy: 0.0380
TTT 0
TTT 1
TTT 2
TTT 3
TTT 4
TTT 5
VVV 0
VVV 1
VVV 2
VVV 3
VVV 4
VVV 5
VVV 6

 --------- 
Epoch: 4

Epoch 4 train loss: 2.3502
Epoch 4 train accuracy: 0.3854
Epoch 4 dev loss: 4.1128
Epoch 4 dev accuracy: 0.0380

 ######## 

lr:0.0005, alpha:0.0001 @ epoch 1.
TL:2.7745629151662192, TA:0.3085403757639949.
DL:4.0972602026803155, DA:0.03

Stats(train_loss=2.7745629151662192, train_accuracy=0.3085403757639949, dev_loss=4.0972602026803155, dev_accuracy=0.03804008961005582, epoch=1, lr=0.0005, alpha=0.0001, max_accuracy=0.03804008961005582)

#### $\color{red}{Run:}$

In [None]:
"""
Main Admin
"""
# Shared settings for every run of the search.
epochs = 22
max_accuracy = 0  # best development accuracy over all runs so far
path = "class/models/GAT.pt"
results = []  # one Stats entry per completed run

"""
init random search
lr [10^-5 - 10^-1]
alpha [10^-5 - 10^-1]
bs [8, 32, 128]
"""
# NOTE: the bounds below are log10 exponents and differ from the note above:
# lr is drawn from 10^-2 .. 10^-1 and alpha from 10^-5 .. 10^-3.
# No batch size is varied in this cell.
lr_low = -2
lr_high = -1
lr_range = lr_high - lr_low

alpha_low = -5
alpha_high = -3
alpha_range = alpha_high - alpha_low

# Model dimensions (num_relations is not used in this cell).
d = 768
h = 400
c = 70
num_relations = 1

count = 0  # running number of the current training run

"""
Hyperparameter Search
"""

# Three rounds of three random configurations each; after every round the
# search window shrinks to a third of its width around the best result.
for i in range(3):
  # debug
  print("\n################\n")
  print(f'round: {i}')
  # print(f'lr_low{lr_low}, lr_high{lr_high}, lr_range{lr_range}')
  # print(f'alpha_low{alpha_low}, lr_high{alpha_high}, lr_range{alpha_range}')
  print('max', max_accuracy)
  print("\n################\n")


  for j in range(3):
    count += 1
    print(count)

    # get config
    lr, alpha = gen_config(lr_low, lr_high, alpha_low, alpha_high)
    # define model (a fresh model for every configuration)
    model = GATModel(d, h, c)
    model = model.to(device)

    # run training; carry the best accuracy forward so that only a better
    # model overwrites the saved checkpoint
    res = tv_run(epochs, model, lr, alpha, max_accuracy, path, verbose = 2)
    max_accuracy = res.max_accuracy
    results.append(res)

  # get best result of the round or even so far
  # (the search covers every run collected in `results`, not just this round)
  stats = search_stats(results)


  print(stats) # debug

  # reconfigure the new hypers: centre the window on the best run (in log10
  # space) and divide its width by three
  lr = np.log10(stats.lr)
  lr_range = lr_range / 3

  alpha = np.log10(stats.alpha)
  alpha_range = alpha_range / 3

  config = gen_ranges(lr, lr_range, alpha, alpha_range)
  lr_low, lr_high, alpha_low, alpha_high = config
  lr_range = lr_high - lr_low
  alpha_range = alpha_high - alpha_low



################

round: 0
max 0

################

1
lr 0.013422, alpha 0.000411

 --------- 
Epoch: 1

Epoch 1 train loss: 3.6948
Epoch 1 train accuracy: 0.1359
Epoch 1 dev loss: 3.4371
Epoch 1 dev accuracy: 0.1587

 --------- 
Epoch: 2

Epoch 2 train loss: 3.5486
Epoch 2 train accuracy: 0.1550
Epoch 2 dev loss: 3.4074
Epoch 2 dev accuracy: 0.1777

 --------- 
Epoch: 3

Epoch 3 train loss: 3.5000
Epoch 3 train accuracy: 0.1602
Epoch 3 dev loss: 3.6082
Epoch 3 dev accuracy: 0.1561

 --------- 
Epoch: 4

Epoch 4 train loss: 3.5645
Epoch 4 train accuracy: 0.1534
Epoch 4 dev loss: 3.5864
Epoch 4 dev accuracy: 0.1516

 --------- 
Epoch: 5

Epoch 5 train loss: 3.4625
Epoch 5 train accuracy: 0.1655
Epoch 5 dev loss: 3.6374
Epoch 5 dev accuracy: 0.1399

 ######## 

lr:0.013422, alpha:0.000411 @ epoch 2.
TL:3.548618005434672, TA:0.15499645821744848.
DL:3.4073648167777733, DA:0.1777405110706763
2
lr 0.015027, alpha 3.2e-05

 --------- 
Epoch: 1

Epoch 1 train loss: 3.6556
Epoch 1 train accura

KeyboardInterrupt: 