In [2]:
import torchvision.datasets as datasets
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torch
import torchvision.transforms as transforms
import torch.optim as optim
from sklearn.metrics import r2_score, precision_score, recall_score, f1_score, confusion_matrix, accuracy_score
import numpy as np
import time
import sys
import matplotlib.image as mpimg
import os
import copy
import torch.utils.data as torch_data
import matplotlib.pyplot as plt

In [None]:
np.random.seed(1)
torch.manual_seed(1)
torch.cuda.manual_seed(1)
torch.backends.cudnn.benchmark = True

In [3]:
class RRN(nn.Module):
    
    def __init__(self):
        super(RRN, self).__init__()
        
        self.input_encoder = nn.Sequential(
                                nn.Linear(25, 96), nn.ReLU(),
                                nn.Linear(96, 96), nn.ReLU(),
                                nn.Linear(96, 96), nn.ReLU(),
                                nn.Linear(96, 16)
                            )
        
        self.msg_encoder = nn.Sequential(
                                nn.Linear(32, 96), nn.ReLU(),
                                nn.Linear(96, 96), nn.ReLU(),
                                nn.Linear(96, 96), nn.ReLU(),
                                nn.Linear(96, 16)
                            )
        
        self.msg_combiner = nn.Sequential(
                                nn.Linear(32, 96), nn.ReLU(),
                                nn.Linear(96, 96), nn.ReLU(),
                                nn.Linear(96, 96), nn.ReLU(),
                                nn.Linear(96, 16)
                            )
        
        self.lstm_cell = nn.LSTMCell(16, 16)
        self.decoder = nn.Linear(16, 8)
        
        self.adj_mask = self.generate_mask()
    
    def generate_mask(self):
        mask = torch.zeros(64, 64)
        for i in range(8):
            for j in range(8):
                start = 8 * i + j
                for x in range(8):
                    end = 8 * i + x
                    mask[start][end] = 1
                    end = 8 * x + j
                    mask[start][end] = 1
                
                block_start_x = i // 2 * 2
                block_start_y = j // 4 * 4
                
                for x in range(2):
                    for y in range(4):
                        X, Y = block_start_x + x, block_start_y + y
                        end = 8 * X + Y
                        mask[start][end] = 1
        return mask > 0
        
    
    def forward(self, h_prev, s_prev, x, m):
        '''
            h_prev (B, 8, 8, 16)
            s_prev (B, 8, 8, 16)
            x (B, 8, 8, 16)
            m (B, 8, 8, 16)
        '''
        B = h_prev.shape[0]
        xm = self.msg_combiner(torch.cat((x, m), dim=3))
        h_next, s_next = self.lstm_cell(xm.flatten(0, 2), (h_prev.flatten(0, 2), s_prev.flatten(0, 2)))
        
        o = self.decoder(h_next.reshape(-1, h_next.shape[-1])) # (B * 64 * 64, 16)
        o = o.reshape(B, 8, 8, 8) # (B, 64, 64, 16)
        h_next, s_next = h_next.reshape(B, 8, 8, 16), s_next.reshape(B, 8, 8, 16)
        
        return o, h_next, s_next
    
    def message_passing(self, h):
        B = h.shape[0]
        # h (B, 8, 8, 16)
        h_ = h.flatten(1, 2) 
        #h_ (B, 4096, 16)
        
        M_all = torch.cat((h_[:,:,None,:].repeat(1, 1, 64, 1), h_[:,None,:,:].repeat(1, 64, 1, 1)), dim=3)
        M_all = M_all.flatten(1, 2)
        # M_all (B, 4096, 32)
        all_pairs = M_all[:,self.adj_mask.flatten(),:]
        # (B, #constraints, 32)
        
        msg_pairs = self.msg_encoder(all_pairs.flatten(0, 1)).reshape(B, all_pairs.shape[1], 16)
        # (B, #constraints, 16)
        
        all_msgs = torch.zeros(B, 4096, 16).cuda()
        all_msgs[:,self.adj_mask.flatten(),:] = msg_pairs
        
        all_msgs = all_msgs.reshape(B, 64, 64, 16)
        
        return all_msgs.sum(dim=2).reshape(B, 8, 8, 16)
    
    def encode_input(self, sudoku):
        # sudoku (B, 8, 8)
        B = sudoku.shape[0]
        col = torch.arange(0, 8, 1)
        col = col[None,:].repeat(8, 1)[None,:,:].repeat(B, 1, 1)
        row = col.transpose(1, 2)
        
        row_one_hot = F.one_hot(row, num_classes=8).cuda()
        col_one_hot = F.one_hot(col, num_classes=8).cuda()
        val_one_hot = F.one_hot(sudoku, num_classes=9).cuda()
        
        input_ = torch.cat((row_one_hot, col_one_hot, val_one_hot), dim=3)
        #print(input_)
        return self.input_encoder(input_.float())
        #return input_
    
        
        
        

In [None]:
class SudokuDataset(torch_data.Dataset):
    def __init__(self):
        self.query_pred = torch.tensor(torch.load('Assignment 2/visual_sudoku/query_pred.pt'))
        self.target_pred = torch.tensor(torch.load('Assignment 2/visual_sudoku/target_pred.pt'))
        
    def __len__(self):
        return self.query_pred.shape[0] // 64
    
    def __getitem__(self, i):
        query = self.query_pred[64*i : 64*(i + 1)].reshape(8, 8)
        target = self.target_pred[64*i : 64*(i + 1)].reshape(8, 8)
        return query, target

In [None]:
dataset = SudokuDataset()
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')    
print(device)
net = RRN().to(device)
#net.load_state_dict(torch.load('step_32_bn.pth'))

In [None]:
tic = time.time()
from tqdm.auto import tqdm
num_epochs = 40
num_steps = 20

lmbda = lambda epoch: 0.90
optimizer = optim.Adam(net.parameters(), lr=5e-4)
scheduler = optim.lr_scheduler.LambdaLR(optimizer, lmbda)

criterion = nn.CrossEntropyLoss()

split = int(0.8 * len(dataset))
train_data, val_data = torch.utils.data.random_split(dataset, [split, len(dataset) - split])

train_loader = torch.utils.data.DataLoader(train_data, batch_size=32, shuffle=True)
val_loader = torch.utils.data.DataLoader(val_data, batch_size=256, shuffle=False)

running_loss = 0.0
for epoch in range(num_epochs):
    net.train()
    print("Starting epoch " + str(epoch) + ", time: ", time.time() - tic)
    for i, data in enumerate(train_loader, 0):
        
        
        if i == 0:
            net.eval()
            with torch.no_grad():
                total, correct = 0, 0
                total_z, correct_z = 0, 0
                for j, data in enumerate(val_loader):
                    
                    query, target = data
                    query, target = query.to(device), target.to(device) - 1
                    B = query.shape[0]
                    
                    x = net.encode_input(query)
                    s = torch.zeros(B, 8, 8, 16).to(device)

                    h = x
                    for steps in range(num_steps):
                        m = net.message_passing(h)
                        o, h, s = net(h, s, x, m)
                    del h, s
                    
                    pred = o.argmax(dim=3)
                    
                    total += query.nelement()
                    correct += (pred == target).sum()
                    
                    total_z += (query == 0).sum()
                    correct_z += ((pred == target) * (query == 0)).sum()
                    
                    #print(query.flatten(), "\n", target.flatten(), "\n", pred.flatten())
                    
                    #print(query.nelement(), (pred == target).sum().item(), (query == 0).sum().item(), ((pred == target) * (query == 0)).sum().item())
                    
                print("%(z):", (correct_z.item() / total_z.item()) * 100.0,
                     "%(t):", (correct.item() / total) * 100.0)
        
        net.train()
        optimizer.zero_grad()
        query, target = data
        query, target = query.to(device), target.to(device) - 1
        B = query.shape[0]
        
        x = net.encode_input(query) #B, 8, 8, 16
        s = torch.zeros(B, 8, 8, 16).to(device)
        h = x
        loss = 0.0
        for j in range(num_steps):
            m = net.message_passing(h)
            o, h, s = net(h, s, x, m)
            curr_loss = criterion(o.flatten(0, 2), target.flatten(0, 2))
            loss += curr_loss
            #print(curr_loss.item())
        
        del o, h, s
        torch.cuda.empty_cache()
        loss = loss / num_steps
        loss.backward()
        #max_grad = -float('inf')
        
        #for param in net.named_parameters():
        #    if param[1].grad.norm().item() > max_grad:
        #        max_grad = max(param[1].grad.norm().item(), max_grad)
        #        param_name = param[0]
        #print(max_grad, param_name, loss.item())
                
        nn.utils.clip_grad_value_(net.parameters(), 1)
            
        optimizer.step()
        running_loss += loss.item()
        
        #print(loss.item())
        if i % 10 == 9:
            print("epoch: ", epoch + 1, "iter: ", i + 1, "loss: ", running_loss / 10, "time: ", time.time() - tic)
            running_loss = 0
            
    scheduler.step()
                
                
            
    running_loss = 0.0
    

In [14]:
net.eval()
with torch.no_grad():
    total, correct = 0, 0
    total_z, correct_z = 0, 0
    for j, data in enumerate(train_loader):

        query, target = data
        query, target = query.to(device), target.to(device) - 1
        B = query.shape[0]

        x = net.encode_input(query)
        s = torch.zeros(B, 8, 8, 16).to(device)

        h = x
        for steps in range(num_steps):
            m = net.message_passing(x)
            o, h, s = net(h, s, x, m)
        del h, s
        pred = o.argmax(dim=3)

        total += query.nelement()
        correct += (pred == target).sum()

        total_z += (query == 0).sum()
        correct_z += ((pred == target) * (query == 0)).sum()

    print("%(z):", (correct_z.item() / total_z.item()) * 100.0,
         "%(t):", (correct.item() / total) * 100.0)

%(z): 55.39162910292662 %(t): 77.9072265625


In [28]:
count = 0
total = 0
givens = []
train_loader = torch.utils.data.DataLoader(train_data, batch_size=1, shuffle=False)

for j, data in enumerate(train_loader):
    q, t = data
    total += t.nelement()
    count += t.nelement() - (t == q).sum()
    givens.append((t == q).sum().item())

print(total, count.item() / total)
print(givens[0:100])

512000 0.493873046875
[34, 25, 24, 29, 34, 27, 38, 36, 29, 24, 29, 35, 38, 27, 37, 27, 33, 24, 23, 29, 36, 29, 26, 36, 39, 25, 31, 40, 29, 32, 25, 38, 38, 24, 27, 27, 40, 25, 38, 34, 35, 35, 37, 33, 32, 33, 25, 26, 25, 36, 38, 34, 31, 39, 36, 40, 31, 26, 29, 26, 39, 22, 34, 37, 31, 27, 27, 40, 40, 35, 32, 37, 33, 39, 30, 28, 29, 32, 37, 39, 29, 32, 34, 24, 26, 34, 26, 28, 35, 38, 27, 31, 36, 33, 31, 35, 36, 33, 35, 39]


In [66]:
idx = 20
print(train_data.__getitem__(idx)[0])
x = net.encode_input(train_data.__getitem__(idx)[0].unsqueeze(0).cuda())
s = torch.zeros(1, 8, 8, 16).to(device)


output = train_data.__getitem__(idx)[1].unsqueeze(0).cuda()
num_steps = 32
h = x
print(64 - (train_data.__getitem__(idx)[0].unsqueeze(0).cuda() == 0).sum())

for steps in range(num_steps):
    m = net.message_passing(x)
    o, h, s = net(h, s, x, m)
    #print((o.argmax(dim=3) + 1).cpu().float() * mask.cpu().float())
    print(steps)
    print(((o.argmax(dim=3) + 1) == output).sum())
del h, s
pred = o.argmax(dim=3)
print(pred + 1)

tensor([[0, 0, 0, 0, 4, 0, 2, 3],
        [0, 0, 6, 0, 7, 0, 0, 5],
        [8, 0, 1, 6, 0, 7, 0, 0],
        [0, 0, 5, 0, 0, 0, 6, 0],
        [0, 0, 7, 3, 0, 0, 0, 0],
        [4, 1, 0, 0, 6, 5, 3, 7],
        [1, 6, 0, 0, 0, 2, 7, 4],
        [2, 0, 4, 7, 5, 3, 1, 6]])
tensor(32, device='cuda:0')
0
tensor(53, device='cuda:0')
1
tensor(53, device='cuda:0')
2
tensor(53, device='cuda:0')
3
tensor(53, device='cuda:0')
4
tensor(53, device='cuda:0')
5
tensor(53, device='cuda:0')
6
tensor(53, device='cuda:0')
7
tensor(54, device='cuda:0')
8
tensor(54, device='cuda:0')
9
tensor(54, device='cuda:0')
10
tensor(54, device='cuda:0')
11
tensor(54, device='cuda:0')
12
tensor(54, device='cuda:0')
13
tensor(54, device='cuda:0')
14
tensor(54, device='cuda:0')
15
tensor(54, device='cuda:0')
16
tensor(54, device='cuda:0')
17
tensor(54, device='cuda:0')
18
tensor(53, device='cuda:0')
19
tensor(53, device='cuda:0')
20
tensor(53, device='cuda:0')
21
tensor(53, device='cuda:0')
22
tensor(53, device='cuda:

In [5]:
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')    
net = RRN().to(device)
net.load_state_dict(torch.load('step_25_correct.pth'))

IncompatibleKeys(missing_keys=[], unexpected_keys=[])

In [35]:
import time
net.eval()

sudoku = torch.tensor([[
    [4, 0, 0, 2, 1, 0, 0, 7], 
    [0, 0, 5, 0, 6, 0, 0, 0], 
    [0, 3, 0, 0, 0, 0, 7, 0], 
    [0, 4, 0, 0, 0, 6, 0, 0],
    [0, 0, 0, 0, 3, 4, 0, 1],
    [0, 0, 0, 0, 0, 0, 0, 2],
    [1, 0, 0, 4, 0, 0, 8, 0],
    [2, 0, 0, 0, 0, 0, 0, 5]
]])

tic = time.time()
print(64 - (sudoku == 0).sum())
x = net.encode_input(sudoku)
s = torch.zeros(1, 8, 8, 16).to(device)

h = x
for steps in range(25):
    m = net.message_passing(h)
    o, h, s = net(h, s, x, m)
del h, s

pred = o.argmax(dim=3) + 1
toc = time.time()
print(pred)
print(constraint_violation(pred.squeeze()))
print(toc - tic)

tensor(18)
tensor([[[4, 6, 8, 2, 1, 5, 3, 7],
         [3, 1, 5, 1, 6, 2, 4, 8],
         [6, 3, 1, 8, 5, 7, 2, 4],
         [7, 4, 2, 8, 8, 6, 2, 3],
         [8, 2, 6, 7, 3, 4, 7, 1],
         [8, 1, 4, 3, 7, 8, 6, 2],
         [1, 5, 7, 4, 2, 3, 8, 6],
         [2, 8, 3, 6, 4, 1, 7, 5]]], device='cuda:0')
True
0.05385398864746094


In [17]:
def constraint_violation(sudoku):
    for i in range(8):
        a = set()
        for j in range(8):
            elem = sudoku[i][j].item()
            a.add(elem)
        #print("row", a)
        if not len(a) == 8:
            return True
        
    for i in range(8):
        a = set()
        for j in range(8):
            elem = sudoku[j][i].item()
            a.add(elem)
        #print("col", a)
        if not len(a) == 8:
            return True
    
    for i in range(0, 8, 2):
        for j in range(0, 8, 4):
            a = set()
            for x in range(0, 2):
                for y in range(0, 4):
                    elem = sudoku[i + x][j + y].item()
                    a.add(elem)
            #print("block", a)      
            if not len(a) == 8:
                return True
    return False