In [None]:
# Root of the project checkout. Later cells build paths by plain string concatenation
# (e.g. proj_path + 'python/'), so the value is normalised to end with a path separator.
# Set HIDDEN_SINGLES_PROJ_PATH to run the notebook from another machine/checkout.
import os
proj_path = os.path.join(
    os.environ.get('HIDDEN_SINGLES_PROJ_PATH', '/home/ajhnam/projects/hidden_singles_public/'), '')

In [None]:
import sys
sys.path.append(proj_path + 'python/')

import random
import numpy as np
import itertools
import pandas as pd
import copy

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader as DataLoader
from tqdm.auto import tqdm


from hiddensingles.misc import torch_utils as tu
from hiddensingles.misc import utils, TensorDict, TensorDictDataset, nnModule, MLP
from hiddensingles.experiment.sudoku_hs_service import create_tutorial, create_phase1, create_phase2

In [None]:
# Device used for every tensor and model in this notebook. Prefer GPU index
# CUDA_DEVICE_INDEX, but fall back to CPU when that GPU does not exist so a fresh
# kernel on another machine can still run the notebook end to end.
CUDA_DEVICE_INDEX = 3
device = CUDA_DEVICE_INDEX if torch.cuda.device_count() > CUDA_DEVICE_INDEX else 'cpu'

In [None]:
def get_neighbors(dim_x, dim_y, device='cpu'):
    """
    Build the Sudoku constraint graph as a boolean adjacency matrix.

    Cells are enumerated in the order produced by
    ``utils.get_combinations(range(side), range(side))`` with ``side = dim_x * dim_y``.
    Entry ``[i, j]`` is True when cells i and j are different cells that share their
    first coordinate, their second coordinate, or their box (``x // dim_x`` and
    ``y // dim_y`` both equal).

    Returns a bool tensor of shape [(dim_x * dim_y)^2, (dim_x * dim_y)^2] on ``device``.
    """
    side = dim_x * dim_y
    cells = np.array(utils.get_combinations(range(side), range(side))).reshape(-1, 2)
    xs = cells[:, 0]
    ys = cells[:, 1]

    # pairwise comparisons via broadcasting: rows index cell i, columns index cell j
    same_x = xs[:, None] == xs[None, :]
    same_y = ys[:, None] == ys[None, :]
    same_box = ((xs[:, None] // dim_x) == (xs[None, :] // dim_x)) & \
               ((ys[:, None] // dim_y) == (ys[None, :] // dim_y))
    same_cell = same_x & same_y

    adjacency = (same_x | same_y | same_box) & ~same_cell
    return torch.tensor(adjacency, device=device)

In [None]:
class RRN(nnModule):
    """
    Recurrent Relational Network over a Sudoku grid of (dim_x * dim_y)^2 cells.

    Every cell is a node connected to the cells it shares a row, column or box with
    (see ``get_neighbors``). On each step the node input (digit embedding + summed
    incoming messages) goes through a linear layer and a shared LSTM, new messages are
    computed from the updated hidden states, and a linear head emits per-cell logits
    over the ``max_digit`` digits.
    """

    def __init__(self,
                 dim_x=3,
                 dim_y=3,
                 digit_embed_size=10,
                 hidden_vector_size=64,
                 message_size=64):
        super().__init__()
        self.dim_x = dim_x
        self.dim_y = dim_y
        self.digit_embed_size = digit_embed_size
        self.hidden_vector_size = hidden_vector_size
        self.message_size = message_size

        # one embedding per cell value 0..max_digit (presumably 0 marks an empty cell)
        self.num_embed = nn.Embedding(1 + self.max_digit, digit_embed_size)
        # a message is computed from the concatenated (cell, neighbour) hidden states
        self.message_linear = nn.Linear(hidden_vector_size+hidden_vector_size, message_size)
        self.update_node_linear = nn.Linear(digit_embed_size+message_size, hidden_vector_size)
        self.update_node_lstm = nn.LSTM(hidden_vector_size, hidden_vector_size)
        self.output_linear = nn.Linear(hidden_vector_size, self.max_digit)

        # reusable tensors that can be pre-computed.
        # neighbors[c] holds the indices of the cells adjacent to cell c; every cell has the
        # same number of neighbours, so the flat column-index list reshapes into a table.
        neighbors = get_neighbors(dim_x, dim_y)
        self.neighbors = torch.where(neighbors)[1].view((dim_x*dim_y)**2, -1)

        # not used inside this class; kept so external code reading these attributes still works
        self.grid_x, self.grid_y = torch.meshgrid(torch.arange(self.max_digit),
                                                  torch.arange(self.max_digit))

    @property
    def max_digit(self):
        """Largest digit on the grid, which is also the grid's side length."""
        return self.dim_x * self.dim_y

    @property
    def num_cells(self):
        """Number of cells in the grid (side length squared)."""
        return self.max_digit**2

    @property
    def num_neighbors(self):
        """Number of neighbours of each cell (the same for every cell)."""
        return self.neighbors.shape[-1]

    def get_neighbor_embeds(self, state):
        """
        Gather, for every cell, the hidden states of its neighbours.

        state: tensor of shape [batch_size, num_cells, hidden_vector_size]
        return: tensor of shape [batch_size, num_cells, num_neighbors, hidden_vector_size]
        """
        batch_size = state.shape[0]
        # The index table is stored as a plain attribute (not registered as a buffer in
        # __init__), so it may not follow model.to(device); move it to the input's device.
        if self.neighbors.device != state.device:
            self.neighbors = self.neighbors.to(state.device)

        neighbors = tu.prepend_shape(self.neighbors, batch_size)
        state = tu.expand_along_dim(state, 1, self.num_cells)
        return tu.select_subtensors_at(state, neighbors)

    def get_message_vectors(self, state):
        """
        Compute each cell's incoming message: ``message_linear`` applied to every
        (cell state, neighbour state) pair, summed over the neighbours.

        state: tensor of shape [batch_size, num_cells, hidden_vector_size]
        return: tensor of shape [batch_size, num_cells, message_size]
        """
        neighbors = self.get_neighbor_embeds(state)
        state = tu.expand_along_dim(state, 2, self.num_neighbors)
        messages = self.message_linear(torch.cat([state, neighbors], dim=-1))
        return messages.sum(dim=2)

    def get_input_embedding(self, grids):
        """
        Embed the digit in every cell.

        grids: integer tensor holding num_cells values per sample
        return: tensor of shape [batch_size, num_cells, digit_embed_size]
        """
        input_embed = self.num_embed(grids).view(-1, self.num_cells, self.digit_embed_size)
        return input_embed

    def forward(self, grids, num_steps=16):
        """
        Run ``num_steps`` rounds of message passing over the grid.

        grids: integer tensor with num_cells cell values per sample (batch first)
        num_steps: number of recurrent steps; must be at least 1
        return: logits of shape [batch_size, num_steps, max_digit, max_digit, max_digit],
            i.e. one prediction per step, per cell (indexed by its two grid
            coordinates), per digit.
        raises: ValueError if num_steps < 1.
        """
        if num_steps < 1:
            # torch.stack on an empty list would otherwise fail with an opaque error
            raise ValueError(f"num_steps must be at least 1, got {num_steps}")

        batch_size = len(grids)
        device = grids.device

        input_embed = self.get_input_embedding(grids)  # [batch_size, num_cells, digit_embed_size]

        lstm_ch = None  # LSTM (h, c) carried across steps; None makes nn.LSTM start from zeros
        messages = torch.zeros(batch_size, self.num_cells, self.message_size, device=device)
        outputs = []
        for i in range(num_steps):
            lstm_inputs = self.update_node_linear(torch.cat([input_embed, messages], dim=-1)) # [batch_size, num_cells, hidden_vector_size]
            # every cell of every puzzle is one length-1 sequence for the shared LSTM
            lstm_inputs = lstm_inputs.view(1, batch_size*self.num_cells, self.hidden_vector_size)
            state, lstm_ch = self.update_node_lstm(lstm_inputs, lstm_ch)
            state = state.view(batch_size, self.num_cells, self.hidden_vector_size)

            if i < num_steps - 1:  # messages from the last step would never be consumed
                messages = self.get_message_vectors(state)

            output = self.output_linear(state)  # [batch_size, num_cells, max_digit]
            outputs.append(output)

        outputs = torch.stack(outputs, dim=1)  # [batch_size, num_steps, num_cells, max_digit]
        outputs = outputs.view(batch_size, num_steps, self.max_digit, self.max_digit, self.max_digit)
        return outputs

In [None]:
def get_results(model, dataset, num_steps=8):
    """
    Run ``model`` on a batch of hidden-single puzzles and score two read-outs.

    model: module called as ``model(dataset.inputs, num_steps=num_steps)``; its output is
        indexed like the RRN logits [batch, num_steps, max_digit, max_digit, max_digit].
    dataset: TensorDict with ``inputs``, ``goals``, ``targets``, ``coords`` and
        ``coord_targets`` (as built by ``hidden_singles_to_tensordict``).
    num_steps: number of recurrent steps to unroll.

    Returns a TensorDict with
        loss:    goal loss + coord loss (the training objective),
        outputs: the raw model logits,
        goal:    loss / probs / outputs for the digit predicted at each puzzle's goal cell,
        coord:   loss / probs / outputs for the digits predicted at the cells listed
                 in ``dataset.coords``.
    """
    outputs = model(dataset.inputs, num_steps=num_steps)

    # goal cell: logits at the cell each puzzle asks to be filled, at every step
    goals = tu.expand_along_dim(dataset.goals, 1, num_steps)
    goal_outputs = tu.select(outputs, goals, select_dims=1)

    targets = tu.expand_along_dim(dataset.targets, 1, num_steps)
    goal_loss = tu.cross_entropy(goal_outputs, targets)
    goal_probs = tu.select(goal_outputs.softmax(-1), dataset.targets)
    goal_td = TensorDict(loss=goal_loss,
                         probs=goal_probs,
                         outputs=goal_outputs)

    # listed cells: one copy of the logits per entry of dataset.coords. The count is taken
    # from the data (it was hard-coded to 9), so puzzles with a different number of
    # listed cells also work.
    num_coords = dataset.coords.shape[1]
    coords = tu.expand_along_dim(dataset.coords, 1, num_steps)
    out_exp = tu.expand_along_dim(outputs, 2, num_coords)
    coord_outputs = tu.select(out_exp, coords, select_dims=1)
    coord_targets = tu.expand_along_dim(dataset.coord_targets, 1, num_steps)
    coord_loss = tu.cross_entropy(coord_outputs, coord_targets)
    coord_probs = tu.select_subtensors(coord_outputs.softmax(-1), coord_targets)
    coord_td = TensorDict(loss=coord_loss,
                          probs=coord_probs,
                          outputs=coord_outputs)

    loss = goal_loss + coord_loss
    return TensorDict(loss=loss,
                      outputs=outputs,
                      goal=goal_td,
                      coord=coord_td)

In [None]:
def get_phase2_conditions(phase2):
    """
    Tabulate the experimental condition of every Phase 2 item.

    phase2: sequence of items exposing ``.condition.house_type``,
        ``.condition.house_index``, ``.condition.cell_index`` and ``.condition.digit_set``.
    return: DataFrame with one row per item and columns
        ['house_type', 'house_index', 'cell_index', 'digit_set'].

    Each column is built from its own Python list so it keeps its own dtype. Routing the
    values through a single 2-D ``np.array`` coerced every column to one common type
    (e.g. integer indices became strings whenever another column held strings).
    """
    conditions = [p.condition for p in phase2]
    return pd.DataFrame({
        'house_type': [c.house_type for c in conditions],
        'house_index': [c.house_index for c in conditions],
        'cell_index': [c.cell_index for c in conditions],
        'digit_set': [c.digit_set for c in conditions],
    })

def hidden_singles_to_tensordict(list_of_hidden_singles):
    """
    Convert hidden-single puzzle objects into a batch of tensors on ``device``.

    Each item must expose ``grid.array`` (the digit grid), ``coordinates['goal']`` (with
    ``.x`` / ``.y``) and ``digits['target']`` (the goal digit, 1-based).

    Returns a TensorDict with
        inputs:        the stacked grids
        goals:         [N, 2] coordinates of each puzzle's goal cell
        targets:       [N] goal digit shifted to 0-based
        coords:        [N, K, 2] grid coordinates of the K nonzero cells of each grid
        coord_targets: the 0-based digit found at each of those K cells
    Every grid must contain the same number K of nonzero cells (required by the reshape).
    """
    puzzles = list_of_hidden_singles
    # stack into one ndarray first: torch is very slow at building a tensor from a list of ndarrays
    grids = torch.tensor(np.array([p.grid.array for p in puzzles]), device=device)
    goals = [p.coordinates['goal'] for p in puzzles]
    goals = torch.tensor([[g.x, g.y] for g in goals], device=device)
    targets = torch.tensor([p.digits['target'] for p in puzzles], device=device) - 1 # make it 0-8
    # nonzero() rows are (puzzle index, grid coordinates...); drop the puzzle index column
    coords = grids.nonzero()[:, 1:].view(len(puzzles), -1, 2)
    # one copy of each grid per nonzero cell so select() can read every cell's digit;
    # the copy count comes from the data instead of a hard-coded 9
    num_coords = coords.shape[1]
    coord_targets = tu.select(tu.expand_along_dim(grids, 1, num_coords), coords) - 1 # make it 0-8

    return TensorDict(inputs=grids,
                      goals=goals,
                      targets=targets,
                      coords=coords,
                      coord_targets=coord_targets)

def create_dataset(num_train, num_valid):
    """
    Generate one simulated participant's puzzles and split them into train/valid/test.

    Two disjoint sets of 4 digits are drawn from 1-9. The tutorial is created from
    ``digit_set1``, Phase 1 (Practice Phase) puzzles are generated from the tutorial,
    and Phase 2 is generated from the tutorial and both digit sets.

    num_train: number of Phase 1 puzzles used for training.
    num_valid: number of additional Phase 1 puzzles held out for validation.
    return: (dataset, conditions) where dataset is a TensorDict with ``train`` / ``valid``
        (Phase 1 splits) and ``test`` (Phase 2), and conditions is the per-item Phase 2
        condition table from ``get_phase2_conditions``.
    """
    all_digits = set(range(1, 10))
    # random.sample() no longer accepts a set (deprecated in Python 3.9, TypeError since 3.11);
    # sample from a sorted list so the draw also does not depend on set iteration order
    digit_set1 = set(random.sample(sorted(all_digits), 4))
    digit_set2 = set(random.sample(sorted(all_digits - digit_set1), 4))
    tutorial = create_tutorial(digit_set1)
    phase1 = create_phase1(tutorial, num_train + num_valid)
    phase2 = create_phase2(tutorial, digit_set1, digit_set2)
    conditions = get_phase2_conditions(phase2)

    phase1 = hidden_singles_to_tensordict(phase1)
    phase2 = hidden_singles_to_tensordict([p.hidden_single for p in phase2])

    dataset = TensorDict(train=phase1[:num_train],
                         valid=phase1[num_train:],
                         test=phase2)
    return dataset, conditions

def train_model(model, dataset, num_epochs=100, record_epoch=1, verbose=False,
                num_steps=8, lr=1e-3, weight_decay=1e-4, batch_size=100):
    """
    Train ``model`` on ``dataset.train`` with Adam, periodically scoring ``dataset.valid``.

    Runs ``num_epochs + 1`` passes over the training data (epochs 0..num_epochs) in
    shuffled mini-batches. On every ``record_epoch``-th epoch the validation set is
    evaluated *before* that epoch's updates (so epoch 0 reports the untrained model), and
    a row is recorded with:
        epoch, loss (mean training goal loss over that epoch's batches),
        probability (mean validation probability of the goal digit at the last step),
        accuracy (validation goal-digit accuracy at the last step).

    num_steps: recurrent steps unrolled, used for both training and validation.
    lr, weight_decay: Adam hyper-parameters.
    batch_size: training mini-batch size.
    return: DataFrame with columns epoch, loss, probability, accuracy.
    raises: ValueError if record_epoch < 1.
    """
    if record_epoch < 1:
        raise ValueError(f"record_epoch must be at least 1, got {record_epoch}")

    optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
    dataloader = DataLoader(TensorDictDataset(dataset.train), batch_size=batch_size, shuffle=True)

    epochs = range(num_epochs + 1)
    if verbose:
        epochs = tqdm(epochs)

    all_results = []
    for epoch in epochs:
        record = epoch % record_epoch == 0
        if record:
            with torch.no_grad():
                # validate with the same number of steps used for training
                v_results = get_results(model, dataset.valid, num_steps=num_steps)
            predictions = v_results.goal.outputs[:, -1].argmax(-1)
            accuracy = (predictions == dataset.valid.targets).float().mean().item()
            probability = v_results.goal.probs[:, -1].mean().item()

        goal_loss = []
        for batch in dataloader:
            batch = TensorDict(**batch)
            optimizer.zero_grad()
            results = get_results(model, batch, num_steps=num_steps)
            goal_loss.append(results.goal.loss.item())
            results.loss.backward()
            optimizer.step()

        if record:
            row = {'epoch': epoch,
                   'loss': np.mean(goal_loss),
                   'probability': probability,
                   'accuracy': accuracy}
            all_results.append(row)

            if verbose:
                utils.kv_print(**row)

    return pd.DataFrame(all_results)

def get_test_results(model, dataset, conditions):
    """
    Score ``model`` on the Phase 2 test puzzles.

    ``dataset.test`` is evaluated without gradients, and the goal-cell prediction at the
    last recurrent step is read off.

    return: a copy of ``conditions`` with two added columns:
        probability: probability the model assigns to the target goal digit,
        correct:     whether the argmax goal digit equals the target.
    """
    test_set = dataset.test
    with torch.no_grad():
        goal = get_results(model, test_set).goal

    final_logits = goal.outputs[:, -1]
    is_correct = final_logits.argmax(-1) == test_set.targets

    test_results = conditions.copy()
    test_results['probability'] = goal.probs[:, -1].cpu().numpy()
    test_results['correct'] = is_correct.cpu().numpy()
    return test_results

In [None]:
# Train one fresh model per training-set size (i.e. number of puzzles in the Practice
# Phase), then evaluate each on the Phase 2 test puzzles. Every RNG is re-seeded per size
# so a Restart & Run All regenerates the same puzzles, initialisations and batch orders
# (up to nondeterminism in GPU kernels).
SEED = 0
TRAIN_SIZES = [25, 50, 100, 200, 300, 400, 500]
NUM_VALID = 100
NUM_EPOCHS = 1000
RECORD_EPOCH = 10

all_train_results = []
all_test_results = []
for num_train in tqdm(TRAIN_SIZES):
    run_seed = SEED + num_train
    random.seed(run_seed)
    np.random.seed(run_seed)
    torch.manual_seed(run_seed)  # seeds the generators on all devices, including CUDA

    model = RRN(digit_embed_size=10,
                hidden_vector_size=48,
                message_size=48).to(device)
    dataset, conditions = create_dataset(num_train=num_train, num_valid=NUM_VALID)
    size_train_results = train_model(model, dataset, num_epochs=NUM_EPOCHS,
                                     record_epoch=RECORD_EPOCH, verbose=False)
    size_test_results = get_test_results(model, dataset, conditions)
    size_train_results['train_size'] = num_train
    size_test_results['train_size'] = num_train

    all_train_results.append(size_train_results)
    all_test_results.append(size_test_results)

train_results = pd.concat(all_train_results, ignore_index=True)
test_results = pd.concat(all_test_results, ignore_index=True)

train_results = train_results[['train_size', 'epoch', 'loss', 'accuracy', 'probability']]
test_results = test_results[['train_size', 'house_type', 'house_index', 'cell_index',
                             'digit_set', 'correct', 'probability']]

In [None]:
# Save both result tables as tab-separated files under <proj_path>data/rrn/
dirpath = proj_path + 'data/rrn/'
utils.mkdir(dirpath)

for frame, filename in ((train_results, "train_results.tsv"),
                        (test_results, "test_results.tsv")):
    frame.to_csv(dirpath + filename, sep='\t', index=False)