In [2]:
import numpy as np
import torchvision

import torch.nn.functional as F
import torch.nn as nn
import torch
import torch.optim as optim

import math
from torch.nn import init
from torch.autograd import Variable
from tqdm import trange
from torch.distributions.categorical import Categorical
import scipy
import scipy.linalg
from collections import Counter


#from model import CnnActorCriticNetwork, RNDModel
from utils import global_grad_norm_
device = 'cuda'


### Dynamical Isomentry Check

In [158]:
def noise_sample(obs_batch):
    obs_batch = obs_batch.cpu().detach().numpy()
    sample = np.random.normal(size=obs_batch.shape)
    sample = torch.from_numpy(sample).float().cuda()# use .to(self.device) soon 
    return sample

def noise_sample_step(obs_batch, epsilon=1):
    obs_batch = obs_batch.cpu().detach().numpy()
    step = np.random.normal(size=obs_batch.shape)
    step = (step / np.linalg.norm(step)) * epsilon
    z_obs_batch = obs_batch + step
    z_obs_batch = torch.from_numpy(z_obs_batch).float().cuda()# use .to(self.device) soon 
    return z_obs_batch

In [142]:
torch.Tensor(np.random.randint(0,10),0)

tensor([], size=(9, 0))

In [155]:
torch.LongTensor(1).random_(0, 10)

tensor([6])

In [831]:
def init_weights(m):
    if type(m)==nn.Linear:
        init.orthogonal_(m.weight)

In [854]:
lnn =  nn.Sequential(nn.Linear(784, 512))

In [855]:
lnn.apply(init_weights)

Sequential(
  (0): Linear(in_features=784, out_features=512, bias=True)
)

In [866]:
w = lnn[0].weight.detach().cpu().numpy()
u,s,v = scipy.linalg.svd(w)
print(np.mean(s))

0.9999999


###  Plan 
#### Train model on 10 samples from class 0 - Done
#### Create Model with separated predictors - Done
#### Train/Test Model with separated predictors - Done
#### Debug Model with separated predictors - In progress

In [3]:
train_loader = torch.utils.data.DataLoader(
  torchvision.datasets.MNIST('./data/MNIST/', train=True, download=True,
                             transform=torchvision.transforms.Compose([
                               torchvision.transforms.ToTensor(),
                               torchvision.transforms.Normalize(
                                 (0.1307,), (0.3081,))
                             ])),
  batch_size=1, shuffle=True)

test_loader = torch.utils.data.DataLoader(
  torchvision.datasets.MNIST('./data/MNIST/', train=False, download=True,
                             transform=torchvision.transforms.Compose([
                               torchvision.transforms.ToTensor(),
                               torchvision.transforms.Normalize(
                                 (0.1307,), (0.3081,))
                             ])),
  batch_size=1, shuffle=True)

In [23]:
class RNDModel(nn.Module):
    def __init__(self, n_classes):
        super(RNDModel, self).__init__()
        
        self.activated_predictor = None
        
        self.target =  nn.Sequential(nn.Linear(784, 512))
        
        self.predictors = {}
        for c in range(n_classes):
            self.predictors['predictor_'+str(c)] = nn.Sequential(
                nn.Linear(784, 512),
               # nn.ReLu(),
                #nn.Linear(512, 512),
            )
        
        for p in self.modules():
            if isinstance(p, nn.Conv2d):
                init.orthogonal_(p.weight, np.sqrt(2))
                p.bias.data.zero_()

            if isinstance(p, nn.Linear):
                #init.orthogonal_(p.weight, np.sqrt(2))
                init.orthogonal_(p.weight)
                #init.orthogonal_(p.bias)
                #p.bias.data.zero_()

        for param in self.target.parameters():
            param.requires_grad = False
        for predictor in self.predictors:
            for param in self.predictors[predictor].parameters():
                param.requires_grad = False

                
    def cuda_predictors(self):
        for predictor in self.predictors:
            self.predictors[predictor].cuda()
                
                
    def activate_predictor(self, class_):
        self.activated_predictor = self.predictors['predictor_'+str(class_)]
        for param in self.activated_predictor.parameters():
            param.requires_grad = True
                
    def deactivate_predictor(self):
        for param in self.activated_predictor.parameters():
            param.requires_grad = False
            
            
    def predict(self, next_obs):
        predict_features = []
        target_feature = self.target(next_obs)
        for predictor in self.predictors:
            predict_features.append(self.predictors[predictor](next_obs))
        return predict_features, target_feature
            
            
    def forward(self, next_obs):
        target_feature = self.target(next_obs)
        predict_feature = self.activated_predictor(next_obs)

        return predict_feature, target_feature

In [174]:
rnd = RNDModel(10)
rnd.to(device)
rnd.cuda_predictors()
print(rnd)

params =[]
for _, predictor in rnd.predictors.items():
    params += list(predictor.parameters())

optimizer = optim.Adam(params,lr=0.001)
forward_mse = nn.MSELoss(reduction='none')

update_proportion = 0.25


#Batch size must be 1!
def train(epoch, rnd, train_loader, shots_num):
    for batch_idx, (data, y) in enumerate(train_loader):
        data = data.view(data.shape[0],-1 )
        rnd.activate_predictor(class_=y.cpu().numpy()[0])

        predict_next_state_feature, target_next_state_feature = rnd(Variable(data.to(device)))
        forward_loss = forward_mse(predict_next_state_feature, target_next_state_feature.detach()).mean(-1)
        forward_loss = forward_loss.sum()/len(forward_loss)

        #Some unknown rnd regularization!
        #mask = torch.rand(len(forward_loss)).to(device)
        #mask = (mask < update_proportion).type(torch.FloatTensor).to(device)
        #forward_loss = (forward_loss * mask).sum() / torch.max(mask.sum(), torch.Tensor([1]).to(device))
        
        #params =[]
        #for _, predictor in rnd.predictors.items():
        #    params += list(predictor.parameters())
        #print('Now using predictor number ', y.cpu().numpy()[0])
        #print('params before update for predictor 0:', params[0])
        #print('params before update for predictor 1:', params[4])
        
        optimizer = optim.Adam(list(rnd.activated_predictor.parameters()),lr=0.001)
        optimizer.zero_grad()
        loss = forward_loss
        loss.backward()
        global_grad_norm_(list(rnd.activated_predictor.parameters()))
        optimizer.step()
        
        #params =[]
        #for _, predictor in rnd.predictors.items():
        #    params += list(predictor.parameters())
        #print('params after update for predictor 0:', params[0])
        #print('params after update for predictor 1:', params[4])

        #rnd.deactivate_predictor()

        if batch_idx % 1000 == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(epoch, batch_idx * len(data), shots_num,
            100. * batch_idx / shots_num, loss.item()))
          #train_losses.append(loss.item())


def pretrain(batch_idx, rnd, data, y):
    data = data.view(data.shape[0],-1 )
    rnd.activate_predictor(class_=y.cpu().numpy()[0])

    predict_next_state_feature, target_next_state_feature = rnd(Variable(data.to(device)))
    forward_loss = forward_mse(predict_next_state_feature, target_next_state_feature.detach()).mean(-1)
    forward_loss = forward_loss.sum()/len(forward_loss)

    optimizer = optim.Adam(list(rnd.activated_predictor.parameters()),lr=0.001)
    optimizer.zero_grad()
    loss = forward_loss
    loss.backward()
    global_grad_norm_(list(rnd.activated_predictor.parameters()))
    optimizer.step()

    if batch_idx % 1000 == 0:
        print('Loss: {:.6f}'.format(loss.item()))
        
        
def test(rnd, test_loader, shots_num=1000):
    rnd.eval()
    test_loss = 0
    correct = 0
    mses = []
    with torch.no_grad():
        for batch_idx, (data, y)  in enumerate(test_loader): 
            data = data.view(data.shape[0],-1 )
            predict_next_state_feature, target_next_state_feature = rnd.predict(Variable(data.to(device)))
            for predict in predict_next_state_feature:
                mses.append((target_next_state_feature - predict).pow(2).sum(1) / 2)
            min_mse = np.argmin(mses)
            #print('min_mse',min_mse)
            #print('y',y.cpu().numpy()[0])
            if min_mse==y.cpu().numpy()[0]:
                correct+=1
            mses = []
        print('Accuracy: {}/{} ({:.0f}%)\n'.format(correct, batch_idx+1, 100. * correct / (batch_idx+1)))
        #len(test_loader.dataset),100. * correct / len(test_loader.dataset)))
    #return(test_loss)

RNDModel(
  (target): Sequential(
    (0): Linear(in_features=784, out_features=512, bias=True)
  )
)


#### Random pretraining

In [171]:
for batch_idx, (data, _) in enumerate(train_loader):
    y = torch.LongTensor(1).random_(0, 10)
    data = noise_sample(data)
    pretrain(batch_idx, rnd, data, y)

Loss: 1.213436
Loss: 1.278655
Loss: 1.123878
Loss: 1.080909
Loss: 0.991388
Loss: 1.086716
Loss: 0.769583
Loss: 0.946670
Loss: 0.915132
Loss: 0.743425
Loss: 0.844830
Loss: 0.670870
Loss: 0.781253
Loss: 0.684995
Loss: 0.689395
Loss: 0.631037
Loss: 0.578234
Loss: 0.595365
Loss: 0.597020
Loss: 0.533122
Loss: 0.545753
Loss: 0.501525
Loss: 0.532914
Loss: 0.531696
Loss: 0.447315
Loss: 0.535963
Loss: 0.532086
Loss: 0.430336
Loss: 0.493182
Loss: 0.481700
Loss: 0.369355
Loss: 0.344260
Loss: 0.420057
Loss: 0.414868
Loss: 0.445324
Loss: 0.463982
Loss: 0.364953
Loss: 0.405648
Loss: 0.436920
Loss: 0.384271
Loss: 0.433264
Loss: 0.343799
Loss: 0.416409
Loss: 0.355474
Loss: 0.363487
Loss: 0.371014
Loss: 0.391464
Loss: 0.350464
Loss: 0.367048
Loss: 0.406647
Loss: 0.369788
Loss: 0.363228
Loss: 0.407877
Loss: 0.381963
Loss: 0.393192
Loss: 0.385000
Loss: 0.391670
Loss: 0.356310
Loss: 0.394854
Loss: 0.333692


### Full MNIST

In [70]:
num_of_samples = 200
few_shot_dataset = []
few_shot_dataset_y = []
for batch_idx, (data, target) in enumerate(train_loader):
    few_shot_dataset.append(data)
    few_shot_dataset_y.append(target)
    if len(few_shot_dataset)>num_of_samples:
        break

In [175]:
num_of_shots = 6
break_trashold = num_of_shots*15
few_shot_dataset = []
few_shot_dataset_y = []
few_shot_dataset_y_np = list(range(0,10))
for batch_idx, (data, target) in enumerate(train_loader):
    num_of_samples = [x for x in Counter(few_shot_dataset_y_np).values()]
    pos_of_samples = [x for x in Counter(few_shot_dataset_y_np).keys()]
    if num_of_samples[pos_of_samples.index(target.cpu().numpy()[0])]<num_of_shots:
        few_shot_dataset.append(data)
        few_shot_dataset_y.append(target)
        few_shot_dataset_y_np.append(target.cpu().numpy()[0])
    if batch_idx>break_trashold:
        break

In [170]:
test(rnd, test_loader)

Accuracy: 1250/10000 (12%)



In [176]:
for epoch in range(1, 500 + 1):
    train(epoch, rnd, zip(few_shot_dataset, few_shot_dataset_y), len(few_shot_dataset))
    test(rnd, test_loader)

Accuracy: 6510/10000 (65%)

Accuracy: 6707/10000 (67%)

Accuracy: 6799/10000 (68%)

Accuracy: 6848/10000 (68%)

Accuracy: 6920/10000 (69%)

Accuracy: 6913/10000 (69%)

Accuracy: 6970/10000 (70%)

Accuracy: 6886/10000 (69%)

Accuracy: 6831/10000 (68%)



KeyboardInterrupt: 