In [None]:
import os
# Work from this directory so the relative './sudoku/...' paths used below resolve.
# NOTE(review): this looks like a Colab Google Drive mount point; assumes the drive is already mounted.
DRIVE_ROOT = '/content/drive/MyDrive/'
os.chdir(DRIVE_ROOT)
os.getcwd()

In [None]:
import torch
import numpy as np
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader

# Seed torch's RNG so weight initialisation (nn.Linear / nn.GRU in RRN) and any torch-driven
# randomness is reproducible across Restart & Run All. torch.manual_seed also seeds CUDA.
# Note: some CUDA kernels (e.g. index_add_) are nondeterministic, so GPU runs may still differ slightly.
SEED = 0
torch.manual_seed(SEED)

device = 'cuda' if torch.cuda.is_available() else 'cpu'
device

'cuda'

In [None]:
# Load the raw feature / label tensors (loaded in this order: features first, then labels).
# NOTE(review): torch.load unpickles the file, so only load these from a trusted source;
# newer PyTorch versions accept weights_only=True to restrict what can be deserialised.
X, Y = map(torch.load, ('./sudoku/features.pt', './sudoku/labels.pt'))

In [None]:
def generate_train_test(X, Y, n_train = 9000):
  """Reshape the raw Sudoku tensors and split them into train/test lists of pairs.

  Args:
    X: feature tensor; dims -3..-2 are flattened into one
       (presumably (N, 9, 9, 9) one-hot grids -> (N, 81, 9); TODO confirm).
    Y: label tensor; argmax over the last dim gives each cell's class, then the last two
       remaining dims are flattened (presumably (N, 9, 9, 9) -> (N, 81); TODO confirm).
    n_train: number of leading puzzles used for training (default 9000, the original
       hard-coded split). The remaining puzzles form the test set.

  Returns:
    (training_list, testing_list): lists of (features, labels) tuples, one per puzzle.

  Raises:
    ValueError: if X and Y hold a different number of puzzles, or n_train is negative.
  """
  X = torch.flatten(X, start_dim = -3, end_dim = -2)
  Y = torch.argmax(Y, dim = -1).flatten(start_dim = -2)

  if X.shape[0] != Y.shape[0]:
    raise ValueError(f"X has {X.shape[0]} puzzles but Y has {Y.shape[0]}")
  if n_train < 0:
    raise ValueError(f"n_train must be non-negative, got {n_train}")

  X_train, Y_train = X[:n_train], Y[:n_train]
  X_test, Y_test = X[n_train:], Y[n_train:]

  # Iterating a tensor yields its rows, so zip pairs puzzle i's features with its labels.
  training_list = list(zip(X_train, Y_train))
  testing_list = list(zip(X_test, Y_test))

  return training_list, testing_list

# Build the per-puzzle (features, labels) pair lists consumed by the DataLoaders further down.
training_set, testing_set = generate_train_test(X, Y)

In [None]:
def calc_row(row):
  """Flat (row-major, width 9) indices of the nine cells in grid row `row`."""
  first = 9*row
  return list(range(first, first + 9))

def calc_col(col):
  """Flat (row-major, width 9) indices of the nine cells in grid column `col`."""
  # Step down one row (9 cells) at a time, nine times.
  return list(range(col, col + 81, 9))

def calc_grid(gridx, gridy):
  """Flat indices of the 3x3 box at box-row `gridx`, box-column `gridy`, in row-major order.

  Returns a plain list whose elements are numpy integer scalars.
  """
  # Offsets of the nine box cells relative to the box's top-left cell.
  box_offsets = np.array([[0, 1, 2], [9, 10, 11], [18, 19, 20]])
  # 27 = three rows of 9 cells; 3 = three columns.
  top_left = 27*gridx + 3*gridy
  return list((top_left + box_offsets).reshape(-1))

def sudoku_edges():
  """Directed edge list of the 81-cell Sudoku constraint graph.

  Each cell is linked to every other cell sharing its row, column or 3x3 box (20 peers).
  Edge k runs from the first returned tensor's k-th entry to the second's. Edges are grouped by
  source cell 0..80, with destinations in ascending order within each group.

  Returns:
    (src_ids, dest_ids): two torch.long tensors of equal length.
  """
  sources = []
  targets = []
  for cell in range(81):
    row, col = divmod(cell, 9)
    peers = set(calc_row(row)) | set(calc_col(col)) | set(calc_grid(row // 3, col // 3))
    peers.discard(cell)
    sources.extend([cell] * 20)
    targets.extend(sorted(peers))

  return torch.tensor(sources, dtype = torch.long), torch.tensor(targets, dtype = torch.long)

In [None]:
# Edge list of a single 81-cell puzzle graph. It never changes, so build it once here
# instead of recomputing it for every batch.
SUDOKU_SRC_IDS, SUDOKU_DEST_IDS = sudoku_edges()

def collate_fn(batch):
  """Merge a list of (features, labels) puzzles into one disjoint graph batch.

  Puzzle i's cells are renumbered to 81*i .. 81*i + 80. This lets the edges of all puzzles
  live in a single graph with no cross-puzzle edges.

  Returns:
    (features, labels, src_edges, dest_edges), all moved to `device`. features and labels are
    concatenated along dim 0, presumably giving (len(batch)*81, 9) and (len(batch)*81,) —
    TODO confirm.
  """
  features = torch.cat([x for x, _ in batch], dim = 0)
  labels = torch.cat([y for _, y in batch])

  # Row i holds puzzle i's edges shifted by 81*i. Flattening row-major reproduces the
  # per-puzzle concatenation order.
  offsets = 81 * torch.arange(len(batch), dtype = torch.long).unsqueeze(1)
  src_edges = (SUDOKU_SRC_IDS.unsqueeze(0) + offsets).reshape(-1)
  dest_edges = (SUDOKU_DEST_IDS.unsqueeze(0) + offsets).reshape(-1)

  return features.to(device), labels.to(device), src_edges.to(device), dest_edges.to(device)

batch_size = 16
# Reshuffle the training puzzles every epoch; keep the test order fixed.
train_dataloader = DataLoader(training_set, batch_size = batch_size, shuffle = True, collate_fn = collate_fn)
test_dataloader = DataLoader(testing_set, batch_size = batch_size, collate_fn = collate_fn)

In [None]:
class RRN(nn.Module):
  """Recurrent relational network: repeated message passing with a GRU node update.

  Args:
    node_hidden_state_size: size of each node's GRU hidden state.
    message_size: size of each edge message.
    n_iters: number of message-passing steps; forward returns one prediction per step.
    mlp_hidden_size: width of the message MLP's hidden layers (default 96, the original value).
  """
  def __init__(self, node_hidden_state_size, message_size, n_iters, mlp_hidden_size = 96):
    super().__init__()
    self.node_hidden_state_size = node_hidden_state_size
    self.message_size = message_size
    self.n_iters = n_iters

    # Maps the concatenated (source, destination) hidden states of an edge to a message.
    self.message_network = nn.Sequential(
        nn.Linear(2*self.node_hidden_state_size, mlp_hidden_size),
        nn.ReLU(),
        nn.Linear(mlp_hidden_size, mlp_hidden_size),
        nn.ReLU(),
        nn.Linear(mlp_hidden_size, mlp_hidden_size),
        nn.ReLU(),
        nn.Linear(mlp_hidden_size, self.message_size)
    )

    # GRU input per node = its 9 input features concatenated with its summed incoming messages.
    self.gru = nn.GRU(9+self.message_size, self.node_hidden_state_size)

    # Maps each node's hidden state to 9 output logits.
    self.linear = nn.Linear(self.node_hidden_state_size, 9)

  def forward(self, features, src_edges, dest_edges):
    """Run `n_iters` rounds of message passing.

    Args:
      features: node inputs with 9 columns (one row per node), on the model's device.
      src_edges, dest_edges: 1-D long index tensors; edge k runs src_edges[k] -> dest_edges[k].

    Returns:
      list of `n_iters` tensors, each (num_nodes, 9) per-node logits.
    """
    num_nodes = features.shape[0]
    # Allocate state on the inputs' device rather than a notebook-global `device`, so the
    # model keeps working after .to() / .cpu().
    hidden = torch.zeros(1, num_nodes, self.node_hidden_state_size, device = features.device)
    outputs = []

    for _ in range(self.n_iters):
      node_states = hidden[0]
      src_encodings = torch.index_select(node_states, dim = 0, index = src_edges)
      dest_encodings = torch.index_select(node_states, dim = 0, index = dest_edges)

      messages = self.message_network(torch.cat([src_encodings, dest_encodings], dim = -1))
      # Sum every message arriving at each destination node.
      aggregated_messages = messages.new_zeros(num_nodes, self.message_size)
      aggregated_messages.index_add_(0, dest_edges, messages)

      # A length-1 sequence: one GRU step per message-passing iteration.
      gru_input = torch.cat((features, aggregated_messages), dim = -1).unsqueeze(0)
      gru_output, hidden = self.gru(gru_input, hidden)

      outputs.append(self.linear(gru_output.squeeze(0)))

    return outputs

In [None]:
def evaluate(model, dataloader = None):
  """Fraction of puzzles whose every cell is predicted correctly.

  Only the prediction from the model's final iteration is scored.

  Args:
    model: callable taking (features, src_edges, dest_edges) and returning a list of
      per-iteration logits with one row per cell.
    dataloader: iterable of (features, target, src_edges, dest_edges) batches. Defaults to
      the notebook's `test_dataloader`.

  Returns:
    float in [0, 1]; 0.0 if the loader yields no puzzles.
  """
  if dataloader is None:
    dataloader = test_dataloader

  n_puzzles = 0
  n_solved = 0

  # Inference only: skip building autograd graphs to save memory and time.
  with torch.no_grad():
    for features, target, src_edges, dest_edges in dataloader:
      puzzles_in_batch = features.shape[0] // 81
      predictions = model(features, src_edges, dest_edges)[-1].argmax(dim = -1)

      # Regroup the flat per-cell values into one row of cells per puzzle.
      predictions = predictions.view(puzzles_in_batch, -1)
      target = target.view(puzzles_in_batch, -1)

      n_solved += (predictions == target).all(dim = -1).sum().item()
      n_puzzles += puzzles_in_batch

  return n_solved / n_puzzles if n_puzzles else 0.0

def train(model, epochs, print_every):
  """Train `model` on `train_dataloader` with Adam, reporting test accuracy after each epoch.

  A batch's loss is the cross-entropy averaged over every iteration's output in the model's
  returned list, so the intermediate iterations are supervised too.

  Args:
    model: module returning a list of per-iteration logits (e.g. RRN).
    epochs: number of passes over `train_dataloader`.
    print_every: print the running mean loss every this many batches; <= 0 disables the
      intra-epoch log lines.
  """
  criterion = nn.CrossEntropyLoss()
  optimizer = optim.Adam(params = model.parameters())
  n_batches = len(train_dataloader)

  for epoch in range(epochs):
    # The caller may have left the model in eval mode; train mode is needed here.
    model.train()
    loss_per_epoch = 0.0
    loss_per_print_every = 0.0
    count = 0
    for itr, (features, target, src_edges, dest_edges) in enumerate(train_dataloader):
      output_from_each_iter = model(features, src_edges, dest_edges)

      # Average over the outputs actually returned (rather than model.n_iters).
      loss = sum(criterion(output, target) for output in output_from_each_iter) / len(output_from_each_iter)

      optimizer.zero_grad()
      loss.backward()
      optimizer.step()

      # Accumulate plain floats: summing the loss tensors themselves would chain every
      # batch's autograd graph together and keep that history alive for the whole epoch.
      loss_value = loss.item()
      loss_per_epoch += loss_value
      loss_per_print_every += loss_value
      count += 1

      if print_every > 0 and (itr+1)%print_every == 0:
        print(f"Epoch: {epoch} \t itr: {itr}/{n_batches} \t loss: {loss_per_print_every / count}")
        loss_per_print_every = 0.0
        count = 0

    print()
    print(f"Epoch: {epoch} \t loss: {loss_per_epoch / max(n_batches, 1)}")
    model.eval()
    print("Test set: ", evaluate(model))
    model.train()
    print()

# Build the model and train it, logging the running loss roughly ten times per epoch.
rrn = RRN(node_hidden_state_size = 10, message_size = 11, n_iters = 7).to(device)
train(rrn, epochs = 30, print_every = len(train_dataloader) // 10)

Epoch: 0 	 itr: 55/563 	 loss: 2.02996826171875
Epoch: 0 	 itr: 111/563 	 loss: 1.6108664274215698
Epoch: 0 	 itr: 167/563 	 loss: 1.4488059282302856
Epoch: 0 	 itr: 223/563 	 loss: 1.3424004316329956
Epoch: 0 	 itr: 279/563 	 loss: 1.2526931762695312
Epoch: 0 	 itr: 335/563 	 loss: 1.1896731853485107
Epoch: 0 	 itr: 391/563 	 loss: 1.1018798351287842
Epoch: 0 	 itr: 447/563 	 loss: 0.9866963028907776
Epoch: 0 	 itr: 503/563 	 loss: 0.8844743371009827
Epoch: 0 	 itr: 559/563 	 loss: 0.8334901928901672

Epoch: 0 	 loss: 1.2654423713684082
Test set:  0.0

Epoch: 1 	 itr: 55/563 	 loss: 0.7123574614524841
Epoch: 1 	 itr: 111/563 	 loss: 0.6369412541389465
Epoch: 1 	 itr: 167/563 	 loss: 0.590373694896698
Epoch: 1 	 itr: 223/563 	 loss: 0.551932692527771
Epoch: 1 	 itr: 279/563 	 loss: 0.5230867266654968
Epoch: 1 	 itr: 335/563 	 loss: 0.48592665791511536
Epoch: 1 	 itr: 391/563 	 loss: 0.4646262228488922
Epoch: 1 	 itr: 447/563 	 loss: 0.4578646123409271
Epoch: 1 	 itr: 503/563 	 loss: 0.