In [1]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, TensorDataset
import matplotlib.pyplot as plt

In [2]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


# repo

In [5]:
def compute_cp(pred, true): # pred and true are of shape = batch,81
    # number of exactly correct predictions
    return torch.sum(torch.sum(pred==true,dim=1)==81)

def compute_micro_score(pred, true):
    # this finds percentage of correct predicited digits and then averaged over batch
    temp = 100.*torch.sum(pred==true,dim=1)/81.
    return torch.sum(temp)/temp.shape[0]

def get_edges():
    def cross(a):
        return [(i, j) for i in a.flatten() for j in a.flatten() if not i == j]

    idx = np.arange(81).reshape(9, 9)
    rows, columns, squares = [], [], []
    for i in range(9):
        rows += cross(idx[i, :])
        columns += cross(idx[:, i])
    for i in range(3):
        for j in range(3):
            squares += cross(idx[i * 3:(i + 1) * 3, j * 3:(j + 1) * 3])
    
    edges_base = list(set(rows + columns + squares))
    batched_edges = [(i + (b * 81), j + (b * 81)) for b in range(BATCH_SIZE) for i, j in edges_base]
    return torch.Tensor(batched_edges).long()

def get_start_embeds(embed, X):
    X = embed(X.long(), EMB_SIZE).float()
    return X


def message_passing(nodes, edges, message_fn):
    n_nodes = nodes.shape[0]
    n_edges = edges.shape[0]
    n_embed = nodes.shape[1]

    message_inputs = nodes[edges]
    message_inputs = message_inputs.view(n_edges, 2*n_embed)
    messages = message_fn(message_inputs)

    updates = torch.zeros(n_nodes, n_embed).to(device)
    idx_j = edges[:, 1].to(device)
    updates = updates.index_add(0, idx_j, messages)
    return updates

class MLP(nn.Module):
    def __init__(self, input_size):
        super(MLP, self).__init__()
        self.fc_in = nn.Linear(input_size, HIDDEN_SIZE)
        self.fc = nn.Linear(HIDDEN_SIZE, HIDDEN_SIZE)
        self.fc_out = nn.Linear(HIDDEN_SIZE, HIDDEN_SIZE)
    def forward(self, x):
        x = F.relu(self.fc_in(x))
        x = F.relu(self.fc(x))
        x = self.fc_out(x)
        return x

class Pred(nn.Module):
    def __init__(self, input_size):
        super(Pred, self).__init__()
        self.fc1 = nn.Linear(input_size, 10)
    def forward(self, x):
        x = self.fc1(x)
        return x

def check_val():
    with torch.no_grad():
        almost_correct = 0
        correct = 0
        total = 0
        for batch_id, (X_batched, Y_batched) in enumerate(testloader):
            if X_batched.shape[0] != BATCH_SIZE:
                continue
            X = X_batched.flatten()

            X = get_start_embeds(embed, X)
            X = X.to(device)
            Y_batched = Y_batched.to(device)
            X = mlp1(X)
            H = X.detach().clone().to(device)

            CellState = torch.zeros(X.shape).to(device)
            HiddenState = torch.zeros(X.shape).to(device)
            for i in range(32):
                H = message_passing(H, edges, mlp2) # message_fn
                H = mlp3(torch.cat([H, X], dim=1))
                HiddenState, CellState = lstm(H, (HiddenState, CellState))
                H = CellState
                pred = r(H)

            pred = torch.argmax(pred, dim=1)
            pred = pred.view(-1, 81).cpu()
            Y_batched = Y_batched.view(32, -1).cpu()
            amam = torch.sum(pred == Y_batched, dim=1)

            correct += torch.sum(torch.sum(pred == Y_batched, dim=1) == 81)
            almost_correct += torch.sum(torch.sum(pred == Y_batched, dim=1) >= 60)
            total += Y_batched.shape[0]

            if batch_id == 0:
                print("predicted:",pred[0])
                print("true:",Y_batched[0])
        

        print("Correctly solved: {}, out of: {}".format(correct, total))
        print("Almost correctly solved: {}, out of: {}".format(almost_correct, total))

In [6]:
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import sys

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
EMB_SIZE = 16
HIDDEN_SIZE = 96
BATCH_SIZE = 32

data_X=np.load('drive/My Drive/Colab Notebooks/COL870/sample_X.npy')[:1500]
data_Y=np.load('drive/My Drive/Colab Notebooks/COL870/sample_Y.npy')[:1500]
batch_size=32

dataset=TensorDataset(torch.tensor(data_X[:1024]),torch.tensor(data_Y[:1024]))
trainloader=DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=4)

dataset=TensorDataset(torch.tensor(data_X[1024:]),torch.tensor(data_Y[1024:]))
testloader=DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=4)

mlp1 = MLP(EMB_SIZE).to(device)
mlp2 = MLP(2*HIDDEN_SIZE).to(device)
mlp3 = MLP(2*HIDDEN_SIZE).to(device)
r = Pred(HIDDEN_SIZE).to(device)
lstm = nn.LSTMCell(HIDDEN_SIZE, HIDDEN_SIZE).to(device)
embed = torch.nn.functional.one_hot

optimizer_mlp1 = torch.optim.Adam(mlp1.parameters(), lr=2e-4, weight_decay=1e-4)
optimizer_mlp2 = torch.optim.Adam(mlp2.parameters(), lr=2e-4, weight_decay=1e-4)
optimizer_mlp3 = torch.optim.Adam(mlp3.parameters(), lr=2e-4, weight_decay=1e-4)
optimizer_r = torch.optim.Adam(r.parameters(), lr=2e-4, weight_decay=1e-4)
optimizer_lstm = torch.optim.Adam(lstm.parameters(), lr=2e-4, weight_decay=1e-4)

optimizers = [optimizer_mlp1, optimizer_mlp2, optimizer_mlp3, optimizer_r, optimizer_lstm]
criterion = nn.CrossEntropyLoss()

edges = get_edges()

for epoch in range(100):
    lss = 0
    print("Started epoch: ", epoch)
    total, correct, micro_score = 0, 0, 0
    for batch_id, (X, Y) in enumerate(trainloader):
        if X.shape[0] != BATCH_SIZE:
            continue

        X = X.flatten()
        Y = Y.flatten()
        X = get_start_embeds(embed, X)
        X = X.to(device)
        Y = Y.to(device)
        X = mlp1(X)
        H = X.detach().clone().to(device)

        for optimizer in optimizers:
            optimizer.zero_grad()
        
        loss = 0
        CellState = torch.zeros(X.shape).to(device)
        HiddenState = torch.zeros(X.shape).to(device)
        for i in range(32):
            H = message_passing(H, edges, mlp2) # message_fn
            H = mlp3(torch.cat([H, X], dim=1))
            HiddenState, CellState = lstm(H, (HiddenState, CellState))
            H = HiddenState
            pred = r(H)
            loss += criterion(pred, Y.long())
        
        loss /= BATCH_SIZE
        loss.backward()
        for optimizer in optimizers:
            optimizer.step()
    # check_val()
         
        lss += loss.item()

        pred = pred.argmax(dim=1).view(-1,81)
        Y = Y.view(-1,81)

        correct_predictions = compute_cp(pred.cpu(),Y.cpu()) # number of exactly correct predictions
        correct += correct_predictions
        total += Y.shape[0]

        micro_correct_digits = compute_micro_score(pred.cpu(),Y.cpu()) # this finds percentage of correct predicited digits and then averaged over batch
        micro_score += micro_correct_digits
        
    micro_score /= batch_id
    lss /= batch_id

    print("epoch:",epoch,"|loss:",lss,"| Completely correct predictions:",100.*correct/total,"| Percentage of Correctly predicted digits:",micro_score)

Started epoch:  0


  cpuset_checked))


epoch: 0 |loss: 2.3368322695455244 | Completely correct predictions: tensor(0.) | Percentage of Correctly predicted digits: tensor(11.4695)
Started epoch:  1
epoch: 1 |loss: 2.275004425356465 | Completely correct predictions: tensor(0.) | Percentage of Correctly predicted digits: tensor(11.6288)
Started epoch:  2
epoch: 2 |loss: 2.269856306814378 | Completely correct predictions: tensor(0.) | Percentage of Correctly predicted digits: tensor(11.7707)
Started epoch:  3
epoch: 3 |loss: 2.2691740451320523 | Completely correct predictions: tensor(0.) | Percentage of Correctly predicted digits: tensor(11.6151)
Started epoch:  4
epoch: 4 |loss: 2.268906062649142 | Completely correct predictions: tensor(0.) | Percentage of Correctly predicted digits: tensor(11.9437)
Started epoch:  5
epoch: 5 |loss: 2.268739915663196 | Completely correct predictions: tensor(0.) | Percentage of Correctly predicted digits: tensor(11.6711)
Started epoch:  6
epoch: 6 |loss: 2.26833854183074 | Completely correct pr

KeyboardInterrupt: ignored

# Copying from repo

In [7]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, TensorDataset
import matplotlib.pyplot as plt

class MLP_for_RRN(nn.Module):
    def __init__(self, input_dim, output_dim):
        super(MLP_for_RRN, self).__init__()
        self.fc1=nn.Linear(input_dim, output_dim)
        self.fc2=nn.Linear(output_dim, output_dim)
        self.fc3=nn.Linear(output_dim, output_dim)
        # self.fc4=nn.Linear(output_dim, output_dim)
    
    def forward(self, inp):
        out = F.relu(self.fc1(inp))
        out = F.relu(self.fc2(out))
        out = self.fc3(out)
        # out = self.fc4(out)
        return out

def compute_cp(pred, true): # pred and true are of shape = batch,81
    # number of exactly correct predictions
    return torch.sum(torch.sum(pred==true,dim=1)==81)

def compute_micro_score(pred, true):
    # this finds percentage of correct predicited digits and then averaged over batch
    temp = 100.*torch.sum(pred==true,dim=1)/81.
    return torch.sum(temp)/temp.shape[0]

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
embed_dim=16
sudoku_cells=9
hidden_dim=96
num_steps=32
batch_size=32


data_X=np.load('drive/My Drive/Colab Notebooks/COL870/sample_X.npy')[:1500]
data_Y=np.load('drive/My Drive/Colab Notebooks/COL870/sample_Y.npy')[:1500]

dataset=TensorDataset(torch.tensor(data_X[:1024]),torch.tensor(data_Y[:1024]))
trainloader=DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=4)

dataset=TensorDataset(torch.tensor(data_X[1024:]),torch.tensor(data_Y[1024:]))
testloader=DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=4)


embeds_to_x = MLP_for_RRN(3*embed_dim, hidden_dim).to(device)
message_mlp = MLP_for_RRN(2*hidden_dim, hidden_dim).to(device)
mlp_for_lstm_inp = MLP_for_RRN(2*hidden_dim, hidden_dim).to(device)
r_to_o_mlp = nn.Linear(hidden_dim, sudoku_cells+1).to(device) # only one linear layer as given in architecture details
mlps = [embeds_to_x, message_mlp, mlp_for_lstm_inp, r_to_o_mlp]
LSTM = nn.LSTMCell(input_size=hidden_dim, hidden_size=hidden_dim).to(device)# since x and m will be concatentated and fed into lstm; x and m are of shape : batch_size*9*9, hidden_dim


embed = torch.nn.functional.one_hot ## ALERT

# embed_1 = nn.Linear(sudoku_cells+1,embed_dim)

# embed_1 = nn.Embedding.from_pretrained(torch.rand(sudoku_cells+1, embed_dim).to(device)).to(device)


optimizers = []
for mlp in mlps:
    optimizers.append(optim.Adam(mlp.parameters(), lr=2e-4, weight_decay=1e-4))
optimizers.append(optim.Adam(LSTM.parameters(), lr=2e-4, weight_decay=1e-4))

# optimizers.append(optim.Adam(embed_1.parameters(), lr=2e-4, weight_decay=1e-4))


loss_fn = nn.CrossEntropyLoss()


############################ get edges
# find the required edges in the graph to have communication of message signals
indices_of_cells=np.arange(0,sudoku_cells*sudoku_cells).reshape((sudoku_cells,sudoku_cells))
edges_row, edges_col, edges_in_3x3=[],[],[]
for i in range(9):
    vector = indices_of_cells[i,:]
    edges_row += [(i,j) for i in vector for j in vector if i!=j]
    vector = indices_of_cells[:,i]
    edges_col += [(i,j) for i in vector for j in vector if i!=j]
for i in range(3):
    for j in range(3):
        vector = indices_of_cells[3*i:3*(i+1),3*j:3*(j+1)].reshape(-1)
        edges_in_3x3 += [(i,j) for i in vector for j in vector if i!=j]

edges = list(set(edges_row + edges_col + edges_in_3x3))
# edges = [ (i + (b*81), j + (b*81)) for b in range(batch_size) for i,j in edges]
edges = torch.tensor(edges).long().to(device)
# self.edges contains all the possible pairs of communication between the cells of sudoku
############################



############################ FOR EMBEDDING ROW, COL INFORMATION
# create row and col labels for the cells of sudoku table
row_col = []
for i in range(sudoku_cells):
    for j in range(sudoku_cells):
        row_col.append((i,j))
row_col = torch.tensor(row_col).long()
############################


num_epochs=100
for epoch in range(num_epochs):
    lss = 0
        
    total, correct, micro_score = 0, 0, 0

    for batch_id, (X,Y) in enumerate(trainloader):
        if X.shape[0]!=batch_size:
            continue

        X = X.flatten()
        Y = Y.flatten()
        # if epoch == 0 and batch_id == 0:
        #     print('printing X shape: ', X.shape)
        
        X = embed(X.long(), embed_dim).float()
        row_col_batched = row_col.repeat(batch_size,1) # IMPORTANT - HERE REMOVE BATCH_SIZE VARIABLE LATER
        embedded_row = embed(row_col_batched[:,0].long(), embed_dim).float()
        embedded_col = embed(row_col_batched[:,1].long(), embed_dim).float()
        
        X = torch.cat([X,embedded_row,embedded_col],dim=1)
        # X = F.one_hot(X.long(), sudoku_cells+1).float()
        
        # if epoch == 0 and batch_id == 0:
        #     print('printing embedded X shape: ', X.shape)


        X = X.to(device)
        Y = Y.to(device)

        # X = embed_1(X)

        X = embeds_to_x(X)

        H = X.detach().clone().to(device)

        for optimizer in optimizers:
            optimizer.zero_grad()

        loss = 0
        HiddenState,CellState = torch.zeros(X.shape).to(device), torch.zeros(X.shape).to(device)

        for i in range(num_steps):

            n_nodes = H.shape[0]
            n_edges = edges.shape[0]
            n_embed = H.shape[1]
            assert n_embed == 96

            H = H.view(-1,81,96)

            assert H.shape[0] == 32

            message_inputs = H[:,edges]
            message_inputs = message_inputs.view(-1, 2*96)

            messages = message_mlp(message_inputs).view(H.shape[0],-1,96)

            updates = torch.zeros(H.shape).to(device)
            idx_j = edges[:, 1].to(device)
            H = updates.index_add(1, idx_j, messages)

            H = H.view(-1,96)

            H = mlp_for_lstm_inp(torch.cat([H, X], dim=1))
            HiddenState, CellState = LSTM(H, (HiddenState, CellState))

            H = HiddenState ## ALERT

            Y_pred = r_to_o_mlp(H)

            loss += loss_fn(Y_pred, Y.long())

        loss /= batch_size

        loss.backward()
        for optimizer in optimizers:
            optimizer.step()


        lss += loss.item()

        Y_pred = Y_pred.argmax(dim=1).view(-1,81)
        Y = Y.view(-1,81)

        correct_predictions = compute_cp(Y_pred.cpu(),Y.cpu()) # number of exactly correct predictions
        correct += correct_predictions
        total += Y.shape[0]

        micro_correct_digits = compute_micro_score(Y_pred.cpu(),Y.cpu()) # this finds percentage of correct predicited digits and then averaged over batch
        micro_score += micro_correct_digits
        
    micro_score /= batch_id
    lss /= batch_id

    print("epoch:",epoch,"|loss:",lss,"| Completely correct predictions:",100.*correct/total,"| Percentage of Correctly predicted digits:",micro_score)


  cpuset_checked))


epoch: 0 |loss: 2.332360367621145 | Completely correct predictions: tensor(0.) | Percentage of Correctly predicted digits: tensor(11.4708)
epoch: 1 |loss: 2.2722580432891846 | Completely correct predictions: tensor(0.) | Percentage of Correctly predicted digits: tensor(11.4695)
epoch: 2 |loss: 2.2695021860061155 | Completely correct predictions: tensor(0.) | Percentage of Correctly predicted digits: tensor(11.4695)
epoch: 3 |loss: 2.26906600306111 | Completely correct predictions: tensor(0.) | Percentage of Correctly predicted digits: tensor(11.4695)
epoch: 4 |loss: 2.2688424510340535 | Completely correct predictions: tensor(0.) | Percentage of Correctly predicted digits: tensor(11.4695)
epoch: 5 |loss: 2.268687140557074 | Completely correct predictions: tensor(0.) | Percentage of Correctly predicted digits: tensor(11.5405)
epoch: 6 |loss: 2.268586358716411 | Completely correct predictions: tensor(0.) | Percentage of Correctly predicted digits: tensor(11.5268)
epoch: 7 |loss: 2.2685368

Traceback (most recent call last):
  File "/usr/lib/python3.7/multiprocessing/queues.py", line 242, in _feed
    send_bytes(obj)
  File "/usr/lib/python3.7/multiprocessing/connection.py", line 200, in send_bytes
    self._send_bytes(m[offset:offset + size])
  File "/usr/lib/python3.7/multiprocessing/connection.py", line 404, in _send_bytes
    self._send(header + buf)
  File "/usr/lib/python3.7/multiprocessing/connection.py", line 368, in _send
    n = write(self._handle, buf)
BrokenPipeError: [Errno 32] Broken pipe
Traceback (most recent call last):


KeyboardInterrupt: ignored

  File "/usr/lib/python3.7/multiprocessing/queues.py", line 242, in _feed
    send_bytes(obj)
  File "/usr/lib/python3.7/multiprocessing/connection.py", line 200, in send_bytes
    self._send_bytes(m[offset:offset + size])
  File "/usr/lib/python3.7/multiprocessing/connection.py", line 404, in _send_bytes
    self._send(header + buf)
  File "/usr/lib/python3.7/multiprocessing/connection.py", line 368, in _send
    n = write(self._handle, buf)
BrokenPipeError: [Errno 32] Broken pipe


# my implementation

In [None]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, TensorDataset
import matplotlib.pyplot as plt

class MLP_for_RRN(nn.Module):
    def __init__(self, input_dim, output_dim):
        super(MLP_for_RRN, self).__init__()
        self.fc1=nn.Linear(input_dim, output_dim)
        self.fc2=nn.Linear(output_dim, output_dim)
        self.fc3=nn.Linear(output_dim, output_dim)
        # self.fc4=nn.Linear(output_dim, output_dim)
    
    def forward(self, inp):
        out = F.relu(self.fc1(inp))
        out = F.relu(self.fc2(out))
        out = self.fc3(out)
        # out = self.fc4(out)
        return out

class RRN(nn.Module):
    def __init__(self, embed_dim=16, sudoku_cells=9, hidden_dim=96, num_steps=32, device='cpu'):
        # sudoku_cells means we will have sudoku_cells x sudoku_cells in the sudoku table
        super(RRN, self).__init__()
        self.hidden_dim = hidden_dim
        self.num_steps = num_steps
        self.device = device
        

        ############################ FOR MESSAGE SIGNALS
        # find the required edges in the graph to have communication of message signals
        indices_of_cells=np.arange(0,sudoku_cells*sudoku_cells).reshape((sudoku_cells,sudoku_cells))
        edges_row, edges_col, edges_in_3x3=[],[],[]
        for i in range(9):
            vector = indices_of_cells[i,:]
            edges_row += [(i,j) for i in vector for j in vector if i!=j]
            vector = indices_of_cells[:,i]
            edges_col += [(i,j) for i in vector for j in vector if i!=j]
        for i in range(3):
            for j in range(3):
                vector = indices_of_cells[3*i:3*(i+1),3*j:3*(j+1)].reshape(-1)
                edges_in_3x3 += [(i,j) for i in vector for j in vector if i!=j]
        self.edges = torch.tensor(list(set(edges_row + edges_col + edges_in_3x3))).long().to(device)
        # self.edges contains all the possible pairs of communication between the cells of sudoku
        ############################
        
        

        ############################ FOR EMBEDDING ROW, COL INFORMATION
        # create row and col labels for the cells of sudoku table
        row_col = []
        for i in range(sudoku_cells):
            for j in range(sudoku_cells):
                row_col.append((i,j))
        self.row_col = torch.tensor(row_col).long().to(device)
        ############################
        
        

        ############################ EMBEDDING LAYERS
        # embedding the cell content {0,1,2,...,sudoku_cells}, row and column information for each cell in sudoku
        self.embed_dim = embed_dim
        # embed_1_init = torch.rand(sudoku_cells+1, self.embed_dim).to(device) #sudoku_cells+1 because possible digits in input : 0,1,2,3,...,sudoku_cells
        # self.embed_1 = nn.Linear(sudoku_cells+1, self.embed_dim)#nn.Embedding.from_pretrained(embed_1_init, freeze=False) 
        # embed_2_init = torch.rand(sudoku_cells, self.embed_dim).to(device)
        # self.embed_2 = nn.Linear(sudoku_cells, self.embed_dim)#nn.Embedding.from_pretrained(embed_2_init, freeze=False)
        # embed_3_init = torch.rand(sudoku_cells, self.embed_dim).to(device)
        # self.embed_3 = nn.Linear(sudoku_cells, self.embed_dim)#nn.Embedding.from_pretrained(embed_3_init, freeze=False)
        ############################


        ############################ MLPs
        self.embeds_to_x = MLP_for_RRN(3*embed_dim, hidden_dim)
        self.message_mlp = MLP_for_RRN(2*hidden_dim, hidden_dim)
        self.mlp_for_lstm_inp = MLP_for_RRN(2*hidden_dim, hidden_dim)
        self.r_to_o_mlp = nn.Linear(hidden_dim, sudoku_cells+1) # only one linear layer as given in architecture details
        ############################


        # LSTM for looping over time i.e. num_steps
        self.LSTM = nn.LSTMCell(input_size=hidden_dim, hidden_size=hidden_dim) # since x and m will be concatentated and fed into lstm; x and m are of shape : batch_size*9*9, hidden_dim
        
        
    def forward(self, inp): # inp.shape=batch_size,9*9
        bs = inp.shape[0]
        inp = inp.view(-1)


        # embed the cell content
        inp = F.one_hot(inp, self.embed_dim).float()
        embedded_inp = inp # batch_size*9*9, embed_dim
        
        # now also get row and column info of each cell embedded
        row_col = self.row_col.repeat(batch_size, 1)
        inp_row = F.one_hot(row_col[:,0], self.embed_dim).float()
        embedded_row = inp_row
        inp_col = F.one_hot(row_col[:,1], self.embed_dim).float()
        embedded_col = inp_col
        
        embedded_all = torch.cat([embedded_inp,embedded_row,embedded_col], dim=1)
        x = self.embeds_to_x(embedded_all) # batch_size*9*9, hidden_dim
        
        assert x.shape[1] == self.hidden_dim
        

        # x will be concatenated with m and then fed into LSTM
        # find message signals : over time i.e. num_steps
        # m_{i,j}^{t} = MLP(h_{i}^{t-1}, h_{j}^{t-1} 
        # since m^t requires h^{t-1}, maintain a list of h and c
        # cell state is also required since we will use LSTM cell and loop over LSTM cell num_steps times
        
        h_for_msgs = x.detach().clone().to(self.device)
        o_t = []

        for t in range(self.num_steps):

            h_for_msgs = h_for_msgs.view(-1, 81, self.hidden_dim)
            inp_for_msgs = h_for_msgs[:,self.edges].view(-1, 2*self.hidden_dim)
            msgs = self.message_mlp(inp_for_msgs).view(bs, -1, self.hidden_dim)
            

            # now sum up the message signals appropriately
            final_msgs = torch.zeros(h_for_msgs.shape).to(device)
            indices = self.edges[:,1].to(device)
            final_msgs = final_msgs.index_add(1, indices, msgs) # shape : batch_size, 81, self.hidden_dim
            final_msgs = final_msgs.view(-1, self.hidden_dim)
            
            # h_for_msgs = h_for_msgs.view(-1, self.hidden_dim) # required for input to lstm cell
            
            inp_to_lstm = self.mlp_for_lstm_inp(torch.cat([final_msgs,x],dim=1))
            h, c = self.LSTM(inp_to_lstm, (h,c)) if t!=0 else self.LSTM(inp_to_lstm, (h_for_msgs.view(-1, 96), torch.zeros(x.shape).to(device)))
            
            h_for_msgs = h

            o = self.r_to_o_mlp(h)
            o_t.append(o)
        
        out = torch.stack(o_t) # shape : num_steps, batch_size*9*9, sudoku_cells+1
        return out # out.shape = num_steps, batch_size*9*9, 10 : last dim is without-softmax over sudoku_cells(10)

In [None]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, TensorDataset
import matplotlib.pyplot as plt

def compute_cp(pred, true): # pred and true are of shape = batch,81
    # number of exactly correct predictions
    return torch.sum(torch.sum(pred==true,dim=1)==81)

def compute_micro_score(pred, true):
    # this finds percentage of correct predicited digits and then averaged over batch
    temp = 100.*torch.sum(pred==true,dim=1)/81.
    return torch.sum(temp)/temp.shape[0]

batch_size=64
embed_dim=16
sudoku_cells=9
hidden_dim=96
num_steps=32
device='cuda:0' if torch.cuda.is_available() else 'cpu'

data_X=np.load('drive/My Drive/Colab Notebooks/COL870/sample_X.npy')[:1024]
data_Y=np.load('drive/My Drive/Colab Notebooks/COL870/sample_Y.npy')[:1024]
dataset=TensorDataset(torch.tensor(data_X),torch.tensor(data_Y))
data_loader=DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=2)

model = RRN(embed_dim=embed_dim, sudoku_cells=sudoku_cells, hidden_dim=hidden_dim, num_steps=num_steps, device=device)
model = model.to(device)
print(model)

optimizer=torch.optim.Adam(model.parameters(), lr=2e-4, weight_decay=1e-4)
loss_fn=nn.CrossEntropyLoss()

RRN(
  (embeds_to_x): MLP_for_RRN(
    (fc1): Linear(in_features=48, out_features=96, bias=True)
    (fc2): Linear(in_features=96, out_features=96, bias=True)
    (fc3): Linear(in_features=96, out_features=96, bias=True)
  )
  (message_mlp): MLP_for_RRN(
    (fc1): Linear(in_features=192, out_features=96, bias=True)
    (fc2): Linear(in_features=96, out_features=96, bias=True)
    (fc3): Linear(in_features=96, out_features=96, bias=True)
  )
  (mlp_for_lstm_inp): MLP_for_RRN(
    (fc1): Linear(in_features=192, out_features=96, bias=True)
    (fc2): Linear(in_features=96, out_features=96, bias=True)
    (fc3): Linear(in_features=96, out_features=96, bias=True)
  )
  (r_to_o_mlp): Linear(in_features=96, out_features=10, bias=True)
  (LSTM): LSTMCell(96, 96)
)


In [7]:
num_epochs=100
train_loss=[]
for epoch in range(num_epochs):
    lss=0

    total, correct, micro_score = 0, 0, 0

    for batch_id, (X,Y) in enumerate(data_loader):
        if X.shape[0] != batch_size:
            continue

        X, Y = X.to(device).long(), Y.to(device)
        Y = Y.view(-1)
        
        optimizer.zero_grad()

        Y_ = model(X)
        
        l=0
        for i in range(num_steps):
            ls=loss_fn(Y_[i],Y.long())
            l+=ls

        Y_pred = Y_[-1].argmax(dim=1)

        l /= batch_size
        l.backward()
        nn.utils.clip_grad_norm_(model.parameters(), 1) # clip gradient to 5
        optimizer.step()
        
        Y_pred = Y_pred.view(-1,81)
        Y = Y.view(-1,81)
        
        assert X.shape[0]==Y.shape[0]

        lss += l.item()
        correct_predictions = compute_cp(Y_pred.cpu(),Y.cpu()) # number of exactly correct predictions
        correct += correct_predictions
        total += Y.shape[0]

        micro_correct_digits = compute_micro_score(Y_pred.cpu(),Y.cpu()) # this finds percentage of correct predicited digits and then averaged over batch
        micro_score += micro_correct_digits
        
    lss /= batch_id
    micro_score /= batch_id
    print("epoch:",epoch,"|	loss:",lss,"| Completely correct predictions:",100.*correct/total,"| Percentage of Correctly predicted digits:",micro_score)
    
    # scheduler.step(lss)

torch.save(model.state_dict(),'RRN.pth')

RuntimeError: ignored