In [None]:
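# 1.a Train the reduced recurrent model on up to 50000 signatures per class (original heading: "Train VIs via OSVM (RBF) based on 50000 signatures"; no OSVM is fitted in this cell)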
from MNIST_models import *
import pickle 
import warnings
warnings.filterwarnings('ignore')

# Get dataset (V1X to V4X) for training VIs.
# num_of_X bounds the per-class sample index scanned on disk; prefixs holds one
# sub-directory per class, and a prefix's position in the list is its label.
num_of_X = 50000
prefixs = ['store_%s/' % digit_name
           for digit_name in 'zero one two three four five six seven eight nine'.split()]

class reduced_recurrent_model(nn.Module):
    """Four 10-way linear classifiers over progressively combined feature levels.

    Each level has its own classification head (``out1`` .. ``out4``). From the
    second level on, the head does not see the raw features alone: the previous
    level's representation is projected by a linear layer (``fc12``, ``fc23``,
    ``fc34``), passed through ReLU, and added to the current level's features.
    """

    def __init__(self):
        super().__init__()
        # Per-level classification heads: flattened features -> 10 class scores.
        self.out1 = nn.Linear(16*24*24, 10)
        self.out2 = nn.Linear(16*10*10, 10)
        self.out3 = nn.Linear(32*3*3, 10)
        self.out4 = nn.Linear(64, 10)

        # Projections carrying one level's representation into the next level's size.
        self.fc12 = nn.Linear(16*24*24, 16*10*10)
        self.fc23 = nn.Linear(16*10*10, 32*3*3)
        self.fc34 = nn.Linear(32*3*3, 64)

        # forward() always feeds 2-D (N, 10) tensors, for which the deprecated
        # implicit-dim Softmax already normalised over dim 1; naming the dim
        # keeps the outputs identical and removes the deprecation warning.
        self.softmax = nn.Softmax(dim=1)
        self.relu = nn.ReLU()

    def forward(self, f1, f2, f3, f4):
        """Classify the four feature levels.

        Args:
            f1, f2, f3, f4: tensors that can be viewed as (N, 16*24*24),
                (N, 16*10*10), (N, 32*3*3) and (N, 64) respectively.

        Returns:
            Tuple ``(out1, out2, out3, out4)`` of (N, 10) tensors. These are
            softmax probabilities (each row sums to 1), not raw logits.
        """
        # Flatten every level to (N, features).
        f1, f2, f3, f4 = f1.view(-1, 16*24*24), f2.view(-1, 16*10*10), f3.view(-1, 32*3*3), f4.view(-1, 64)
        out1 = self.softmax(self.out1(f1))

        # Level 2: own features plus the ReLU-projected level-1 features.
        h2 = torch.add(f2, self.relu(self.fc12(f1)))
        out2 = self.softmax(self.out2(h2))

        # Level 3 builds on the combined level-2 representation.
        h3 = torch.add(f3, self.relu(self.fc23(h2)))
        out3 = self.softmax(self.out3(h3))

        # Level 4 builds on the combined level-3 representation.
        h4 = torch.add(f4, self.relu(self.fc34(h3)))
        out4 = self.softmax(self.out4(h4))

        return out1, out2, out3, out4

# Create the model, loss and optimizer, then alternate one training pass and
# one evaluation pass over the same on-disk signature files for 100 epochs.
r_c_model = reduced_recurrent_model()
loss_func, optimizer = nn.CrossEntropyLoss(), torch.optim.Adam(r_c_model.parameters(), lr=1e-3)
# NOTE(review): the model returns softmax probabilities, while
# nn.CrossEntropyLoss applies log-softmax to its input itself, so the softmax
# is effectively applied twice. Left unchanged here so training results stay
# comparable; confirm whether this is intended.


def _load_features(index, prefix):
    """Load and preprocess one signature file.

    Args:
        index: zero-based sample index (file names on disk are one-based).
        prefix: class sub-directory, e.g. ``'store_zero/'``.

    Returns:
        Whatever ``preprocess`` returns for the unpickled signatures (unpacked
        by the callers as ``f1, f2, f3, f4``), or ``None`` when the file cannot
        be opened.
    """
    fn_name = 'store_subs_fadv/'+prefix+'normal'+'_'+str(index+1)+'.txt'
    try:
        fp = open(fn_name, 'rb')
    except OSError:
        # Files that are missing or unreadable are skipped, as before, but
        # KeyboardInterrupt and programming errors are no longer swallowed.
        return None
    # `with` guarantees the handle is closed even if unpickling fails.
    with fp:
        # SECURITY: pickle.load can execute arbitrary code embedded in the
        # file; only load signature files from a trusted source.
        signatures = pickle.load(fp)
    return preprocess(signatures)


for _ in range(100):
    # ---- training pass: one optimizer step per sample ----
    # Accumulate plain floats; summing the loss tensors themselves would keep
    # every sample's autograd graph alive for the whole epoch.
    total_loss = 0.0
    for i in range(num_of_X):
        if (i+1) % 1000 == 0: print(i+1)
        for prefix_i, prefix in enumerate(prefixs):
            features = _load_features(i, prefix)
            if features is None:
                continue
            f1, f2, f3, f4 = features
            y = prefix_i
            label = torch.tensor([y], dtype=torch.int64)

            # Forwarding (call the module, not .forward(), so hooks run).
            outputs = r_c_model(f1, f2, f3, f4)
            # The loss is the sum of the four per-level losses.
            loss = sum(loss_func(out, label) for out in outputs)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss += loss.item()

    print(total_loss)

    # ---- evaluation pass over the same files ----
    acc, total = 0, 0
    with torch.no_grad():  # no gradients needed while scoring
        for i in range(num_of_X):
            for prefix_i, prefix in enumerate(prefixs):
                features = _load_features(i, prefix)
                if features is None:
                    continue
                f1, f2, f3, f4 = features
                y = prefix_i

                outputs = r_c_model(f1, f2, f3, f4)
                predictions = [torch.argmax(out, dim=1)[0].item() for out in outputs]
                # A sample counts as correct only when all four levels
                # predict the true class.
                if all(prediction == y for prediction in predictions):
                    acc += 1
                total += 1

    # Avoid ZeroDivisionError when no signature file was found at all.
    accuracy = round(acc/total, 3) if total else float('nan')
    print(acc, total, accuracy)

