In [None]:
import os
# Work from the shared-drive root so the relative paths used later in this
# notebook (./sudoku-hard/..., ./bigger-model/...) resolve.
# NOTE(review): this absolute path looks like a Google Colab Drive mount;
# adjust it when running in any other environment.
os.chdir('/content/drive/Shareddrives/')
os.getcwd()

In [None]:
import torch
import numpy as np
import pandas as pd
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset

# Prefer the GPU when one is visible to PyTorch, otherwise fall back to CPU.
device = "cpu"
if torch.cuda.is_available():
  device = "cuda"
print(device)

cuda


In [None]:
class CustomDataset(Dataset):
    """Sudoku puzzles read from a two-column, header-less CSV.

    Column 0 holds the puzzle and column 1 its solution. Each is written as a
    string of digit characters; 0 marks an empty cell in the puzzle.

    Both columns are read with ``dtype=str``. Otherwise pandas infers a numeric
    dtype for all-digit strings and drops leading zeros (a puzzle whose first
    cells are empty would turn into a short integer). That breaks the
    per-character parsing in ``__getitem__``.
    """
    def __init__(self, filepath):
        self.dataframe = pd.read_csv(filepath, header = None, dtype = str)

    def __len__(self):
        return self.dataframe.shape[0]

    def __getitem__(self, idx):
        """Return ``(puzzle, solution)`` as 1-D ``torch.long`` tensors of digits."""
        unsolved_sudoku = [int(cell) for cell in self.dataframe.iloc[idx, 0]]
        target = [int(cell) for cell in self.dataframe.iloc[idx, 1]]

        return torch.tensor(unsolved_sudoku, dtype = torch.long), torch.tensor(target, dtype = torch.long)

# Paths are relative to the working directory chosen by the os.chdir call at
# the top of the notebook.
train_dataset = CustomDataset('./sudoku-hard/train.csv')
test_dataset = CustomDataset('./sudoku-hard/test.csv')

In [None]:
def calc_row(row):
  """Flat (row-major) indices of the 9 cells in Sudoku row ``row``."""
  first = 9 * row
  return list(range(first, first + 9))

def calc_col(col):
  """Flat (row-major) indices of the 9 cells in Sudoku column ``col``."""
  return list(range(col, col + 81, 9))

def calc_grid(gridx, gridy):
  """Flat indices of the 3x3 box at box-row ``gridx`` and box-column ``gridy``.

  The 9 indices are returned in row-major order as a list of NumPy integers.
  """
  top_left = 27 * gridx + 3 * gridy
  offsets = np.array([0, 1, 2, 9, 10, 11, 18, 19, 20])
  return list(offsets + top_left)

def sudoku_edges():
  """Directed edges of the 9x9 Sudoku constraint graph.

  Every cell ``i`` (0..80, row-major) is linked to each of its peers: the
  other cells that share its row, column or 3x3 box (20 per cell). Edges are
  grouped by source cell in ascending order, with destinations sorted within
  each group.

  Returns:
    (src_ids, dest_ids): two equal-length 1-D ``torch.long`` tensors.
  """
  src_ids = []
  dest_ids = []

  for i in range(81):
    # Integer arithmetic instead of np.floor on float division.
    row, col = divmod(i, 9)
    gridx, gridy = row // 3, col // 3

    peers = set(calc_row(row) + calc_col(col) + calc_grid(gridx, gridy))
    peers.discard(i)
    dest_id = sorted(peers)

    # Size the source list from the peer list itself rather than a literal 20,
    # so the two edge lists can never drift out of alignment.
    src_ids += [i] * len(dest_id)
    dest_ids += dest_id

  return torch.tensor(src_ids, dtype = torch.long), torch.tensor(dest_ids, dtype = torch.long)

In [None]:
def collate_fn(batch):
  """Merge ``(puzzle, solution)`` samples into one disjoint graph batch.

  Each sample contributes 81 nodes. Sample ``k``'s node ids in the edge lists
  are shifted by ``81*k``, so the per-puzzle graphs stay disconnected.

  Returns, each moved to ``device``:
    digits, rows, cols, labels, src_edges, dest_edges (all 1-D long tensors).
  NOTE: moving to ``device`` here means the DataLoader should keep its default
  single-process loading; CUDA calls inside worker processes are problematic.
  """
  n_puzzles = len(batch)
  base_src, base_dest = sudoku_edges()
  n_edges = base_src.shape[0]

  # Node-id offset of each puzzle, repeated once per edge of that puzzle.
  shift = (torch.arange(n_puzzles, dtype = torch.long) * 81).repeat_interleave(n_edges)
  src_edges = base_src.repeat(n_puzzles) + shift
  dest_edges = base_dest.repeat(n_puzzles) + shift

  # Row/column coordinate of every cell, tiled once per puzzle.
  cell = torch.arange(81, dtype = torch.long)
  rows = (cell // 9).repeat(n_puzzles)
  cols = (cell % 9).repeat(n_puzzles)

  digits = torch.cat([puzzle for puzzle, _ in batch])
  labels = torch.cat([solution for _, solution in batch])

  return digits.to(device), rows.to(device), cols.to(device), labels.to(device), src_edges.to(device), dest_edges.to(device)

batch_size = 16
train_dataloader = DataLoader(train_dataset, batch_size = batch_size, collate_fn = collate_fn, shuffle = True)
# evaluate() only sums counts over the test set, so its result does not depend
# on order. Iterate the test set deterministically instead of shuffling it.
# NOTE: collate_fn moves batches to `device`; keep num_workers at its default.
test_dataloader = DataLoader(test_dataset, batch_size = batch_size, collate_fn = collate_fn, shuffle = False)

In [None]:
class RRN(nn.Module):
  """Recurrent Relational Network over Sudoku cells.

  Each node (cell) is embedded from its (digit, row, col) triple and passed
  through an MLP. It is then refined for ``n_iters`` rounds of message passing
  along the given edges, with a GRU updating every node's hidden state. A
  linear head emits 10-way logits per node after every round.

  Args:
    embedding_dim: size of each of the digit/row/col embeddings.
    message_size: size of each message vector.
    node_hidden_state_size: size of the per-node hidden state.
    n_iters: number of message-passing rounds.

  The defaults reproduce the original hard-coded configuration. Parameter
  names and shapes are unchanged, so existing checkpoints still load.
  """
  def __init__(self, embedding_dim = 16, message_size = 96, node_hidden_state_size = 96, n_iters = 32):
    super().__init__()

    self.embedding_dim = embedding_dim
    self.message_size = message_size
    self.node_hidden_state_size = node_hidden_state_size
    self.n_iters = n_iters

    width = 96  # internal width of both MLPs (as in the original model)

    self.digit_embedding = nn.Embedding(10, self.embedding_dim)
    self.row_embedding = nn.Embedding(9, self.embedding_dim)
    self.col_embedding = nn.Embedding(9, self.embedding_dim)

    self.mlp1 = nn.Sequential(
        nn.Linear(3*self.embedding_dim, width),
        nn.ReLU(),
        nn.Linear(width, width),
        nn.ReLU(),
        nn.Linear(width, width),
        nn.ReLU(),
        nn.Linear(width, self.node_hidden_state_size)
    )

    self.message_network = nn.Sequential(
        nn.Linear(2*self.node_hidden_state_size, width),
        nn.ReLU(),
        nn.Linear(width, width),
        nn.ReLU(),
        nn.Linear(width, width),
        nn.ReLU(),
        nn.Linear(width, self.message_size)
    )

    # The GRU input is cat(node features x, aggregated messages). x comes out of
    # mlp1 with node_hidden_state_size features, so the input width is
    # node_hidden_state_size + message_size. The old 2*message_size only
    # matched because both sizes were 96.
    self.gru = nn.GRU(self.node_hidden_state_size + self.message_size, self.node_hidden_state_size)

    self.linear = nn.Linear(self.node_hidden_state_size, 10)

  def forward(self, digits, rows, cols, src_edges, dest_edges):
    """Run all message-passing rounds.

    Args:
      digits, rows, cols: 1-D long tensors with one entry per node.
      src_edges, dest_edges: 1-D long tensors of edge endpoints (node ids).

    Returns:
      A list of ``n_iters`` logit tensors of shape (num_nodes, 10), one per round.
    """
    combined_embedding = torch.cat((self.digit_embedding(digits), self.row_embedding(rows), self.col_embedding(cols)), dim = -1)
    x = self.mlp1(combined_embedding)

    # nn.GRU expects (seq_len, batch, features). Each round is a single time
    # step, with every node treated as an independent batch element.
    hidden = x.unsqueeze(0)
    output_from_each_iter = []

    for _ in range(self.n_iters):
      h = hidden[0]
      messages = self.message_network(torch.cat([h[src_edges], h[dest_edges]], dim = -1))

      # Sum incoming messages per destination node. Allocate on the input's
      # device rather than a module-level global so the model works wherever
      # it is placed.
      aggregated_messages = torch.zeros(x.shape[0], self.message_size, device = x.device, dtype = messages.dtype)
      aggregated_messages.index_add_(0, dest_edges, messages)

      input_to_gru = torch.cat((x, aggregated_messages), dim = -1).unsqueeze(0)
      output_from_gru, hidden = self.gru(input_to_gru, hidden)

      output_from_each_iter.append(self.linear(output_from_gru.squeeze(0)))

    return output_from_each_iter

In [None]:
def evaluate(model, dataloader = None, min_givens = 17, n_buckets = 18):
  """Measure the percentage of puzzles the model solves completely.

  A puzzle counts as solved only if the argmax of the model's final iteration
  matches the target in every cell. Puzzles are also bucketed by their number
  of givens (non-zero input cells) into ``n_buckets`` bins, starting at
  ``min_givens``.

  Args:
    model: network whose forward returns a list of per-iteration logits.
    dataloader: batches in the format produced by ``collate_fn``. Defaults to
      ``test_dataloader``.
    min_givens: givens count that maps to bucket 0.
    n_buckets: number of givens buckets.

  Returns:
    (accuracy_percent, solved_givens, correct_givens). The two histograms are
    float tensors of length ``n_buckets``. They hold, per bucket, the number
    of puzzles evaluated and the number solved. Puzzles whose givens fall
    outside ``[min_givens, min_givens + n_buckets)`` still count toward the
    accuracy but are left out of the histograms, instead of wrapping around
    via negative indexing or raising IndexError.

  Raises:
    ValueError: if the dataloader yields no batches.
  """
  if dataloader is None:
    dataloader = test_dataloader

  total_solved = 0
  total_correct = 0
  solved_givens = None
  correct_givens = None

  # Inference only: without no_grad every forward pass records the autograd
  # graph for all message-passing rounds, wasting memory and time.
  with torch.no_grad():
    for (digits, rows, cols, target, src_edges, dest_edges) in dataloader:
      if solved_givens is None:
        # Keep the histograms on the same device as the batches.
        solved_givens = torch.zeros(n_buckets, device = digits.device)
        correct_givens = torch.zeros(n_buckets, device = digits.device)

      n_puzzles = digits.shape[0] // 81
      n_givens = (digits.reshape(-1, 81) != 0).sum(dim = -1)

      output_from_last_iter = model(digits, rows, cols, src_edges, dest_edges)[-1]
      predictions = output_from_last_iter.argmax(dim = -1).view(n_puzzles, -1)
      is_correct = (predictions == target.view(n_puzzles, -1)).all(dim = -1)

      bucket = n_givens - min_givens
      in_range = (bucket >= 0) & (bucket < n_buckets)
      solved_givens += torch.bincount(bucket[in_range], minlength = n_buckets).to(solved_givens.dtype)
      correct_givens += torch.bincount(bucket[in_range & is_correct], minlength = n_buckets).to(correct_givens.dtype)

      total_solved += n_puzzles
      total_correct += is_correct.sum().item()

  if total_solved == 0:
    raise ValueError("evaluate() got an empty dataloader")

  return total_correct*100 / total_solved, solved_givens, correct_givens

def train(model, criterion, optimizer, epochs = 35, resuming = -1, print_every = None, checkpoint_prefix = './bigger-model/big'):
  """Train ``model`` on ``train_dataloader``, checkpointing and evaluating every epoch.

  The loss for a batch is ``criterion`` averaged over the logits from every
  message-passing iteration the model returns.

  Args:
    model: network whose forward returns a list of per-iteration logits.
    criterion: loss applied to each iteration's logits against the targets.
    optimizer: optimizer over ``model``'s parameters.
    epochs: train up to, but not including, this epoch index.
    resuming: last completed epoch; training starts at ``resuming + 1``.
    print_every: log the running loss every this many batches. Defaults to a
      tenth of an epoch (at least 1).
    checkpoint_prefix: the checkpoint for epoch ``e`` is written to
      ``checkpoint_prefix + str(e) + '.pth'``.
  """
  if print_every is None:
    # Clamp to 1 so a loader with fewer than 10 batches doesn't cause modulo-by-zero.
    print_every = max(1, len(train_dataloader) // 10)

  for epoch in range(resuming+1, epochs):
    loss_per_epoch = 0
    loss_per_print_every = 0
    count = 0
    for itr, (digits, rows, cols, target, src_edges, dest_edges) in enumerate(train_dataloader):
      output_from_each_iter = model(digits, rows, cols, src_edges, dest_edges)

      # Supervise every iteration's prediction, averaged over iterations.
      loss = sum(criterion(output, target) for output in output_from_each_iter) / len(output_from_each_iter)

      optimizer.zero_grad()
      loss.backward()
      optimizer.step()

      # Accumulate a detached copy. Adding `loss` itself would chain each
      # batch's autograd history into these running totals for the whole epoch.
      detached_loss = loss.detach()
      loss_per_epoch += detached_loss
      loss_per_print_every += detached_loss
      count += 1

      if (itr+1) % print_every == 0:
        print(f"Epoch: {epoch} \t itr: {itr}/{len(train_dataloader)} \t loss: {loss_per_print_every / count}")
        loss_per_print_every = 0
        count = 0

    torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            }, checkpoint_prefix + str(epoch) + '.pth')

    print()
    print(f"Epoch: {epoch} \t loss: {loss_per_epoch / len(train_dataloader)}")
    model.eval()
    try:
      acc, solved_givens, correct_givens = evaluate(model)
    finally:
      # Restore training mode even if evaluation fails.
      model.train()
    print("Testing Summary")
    print("Accuracy:", acc)
    print("Correct Givens: ", correct_givens)
    print("Solved Givens: ", solved_givens)
    print()

In [None]:
rrn = RRN().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(params = rrn.parameters(), lr = 2e-4, weight_decay = 1e-4)

# Resume from a saved checkpoint. map_location restores tensors onto the
# device selected for this session, so a GPU-written checkpoint also loads on
# a CPU-only machine.
# NOTE: torch.load unpickles the file; only load checkpoints you trust.
checkpoint = torch.load('./bigger-model/big10.pth', map_location = device)
rrn.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']

In [None]:
# Continue training from the epoch after the one stored in the loaded checkpoint.
# NOTE(review): the saved output below starts at epoch 7. That does not match
# the big10 checkpoint loaded above (which would resume at epoch 11), so it
# appears to be stale output from an earlier run; re-run to refresh.
train(rrn, criterion, optimizer, resuming = epoch)

Epoch: 7 	 itr: 1124/11250 	 loss: 0.26574835181236267
Epoch: 7 	 itr: 2249/11250 	 loss: 0.26590436697006226
Epoch: 7 	 itr: 3374/11250 	 loss: 0.2571168541908264
Epoch: 7 	 itr: 4499/11250 	 loss: 0.27118608355522156
Epoch: 7 	 itr: 5624/11250 	 loss: 0.26990416646003723
Epoch: 7 	 itr: 6749/11250 	 loss: 0.26391470432281494
Epoch: 7 	 itr: 7874/11250 	 loss: 0.2603660225868225
Epoch: 7 	 itr: 8999/11250 	 loss: 0.2732493281364441
Epoch: 7 	 itr: 10124/11250 	 loss: 0.26391980051994324
Epoch: 7 	 itr: 11249/11250 	 loss: 0.26939448714256287

Epoch: 7 	 loss: 0.26607006788253784
Testing Summary
Accuracy: 80.45555555555555
Correct Givens:  tensor([ 179.,  310.,  430.,  549.,  676.,  785.,  821.,  903.,  930.,  964.,
         971.,  985.,  991.,  992.,  998.,  999., 1000.,  999.],
       device='cuda:0')
Solved Givens:  tensor([1000., 1000., 1000., 1000., 1000., 1000., 1000., 1000., 1000., 1000.,
        1000., 1000., 1000., 1000., 1000., 1000., 1000., 1000.],
       device='cuda:0')

E

In [None]:
# Same call as the previous training cell. train() does not modify the global
# `epoch`, so this resumes from the same checkpoint epoch again.
# NOTE(review): presumably re-run after the previous cell was interrupted
# (its output ends in a stray "E"); consider deleting one of the two cells.
train(rrn, criterion, optimizer, resuming = epoch)

Epoch: 11 	 itr: 1124/11250 	 loss: 0.235174298286438
Epoch: 11 	 itr: 2249/11250 	 loss: 0.2259645015001297
Epoch: 11 	 itr: 3374/11250 	 loss: 0.2250165045261383
Epoch: 11 	 itr: 4499/11250 	 loss: 0.24912774562835693
Epoch: 11 	 itr: 5624/11250 	 loss: 0.22169968485832214
Epoch: 11 	 itr: 6749/11250 	 loss: 0.25477150082588196
Epoch: 11 	 itr: 7874/11250 	 loss: 0.22199347615242004
Epoch: 11 	 itr: 8999/11250 	 loss: 0.22347407042980194
Epoch: 11 	 itr: 10124/11250 	 loss: 0.2183961570262909
Epoch: 11 	 itr: 11249/11250 	 loss: 0.2401539534330368

Epoch: 11 	 loss: 0.23157723248004913
Testing Summary
Accuracy: 88.08333333333333
Correct Givens:  tensor([ 444.,  547.,  644.,  775.,  823.,  893.,  900.,  943.,  953.,  970.,
         988.,  987.,  992.,  996., 1000., 1000., 1000., 1000.],
       device='cuda:0')
Solved Givens:  tensor([1000., 1000., 1000., 1000., 1000., 1000., 1000., 1000., 1000., 1000.,
        1000., 1000., 1000., 1000., 1000., 1000., 1000., 1000.],
       device='cud