In [78]:
!pip install torch==1.0.0 torchvision==0.2.2 -f https://download.pytorch.org/whl/cu90/torch_stable.html
%load_ext autoreload
%autoreload 2


import argparse
import os
import torch
import capsulenet

# setting the hyper parameters
parser = argparse.ArgumentParser(description="Capsule Network on MNIST.")

parser.add_argument('--epochs', default=50, type=int)
parser.add_argument('--batch_size', default=100, type=int)
parser.add_argument('--lr', default=0.001, type=float,
                    help="Initial learning rate")
parser.add_argument('--lr_decay', default=0.9, type=float,
                    help="The value multiplied by lr at each epoch. Set a larger value for larger epochs")
parser.add_argument('--lam_recon', default=0.0005 * 784, type=float,
                    help="The coefficient for the loss of decoder")
parser.add_argument('-r', '--routings', default=3, type=int,
                    help="Number of iterations used in routing algorithm. should > 0")  # num_routing should > 0
parser.add_argument('--shift_pixels', default=2, type=int,
                    help="Number of pixels to shift at most in each direction.")
parser.add_argument('--data_dir', default='./data',
                    help="Directory of data. If no data, use \'--download\' flag to download it")
parser.add_argument('--download', action='store_true',
                    help="Download the required data.")
parser.add_argument('--save_dir', default='./result')
parser.add_argument('-t', '--testing', action='store_true',
                    help="Test the trained model on testing dataset")
parser.add_argument('-w', '--weights', default=None,
                    help="The path of the saved weights. Should be specified when testing")
args = parser.parse_args(args=[])
print(args)

if not os.path.exists(args.save_dir):
    os.makedirs(args.save_dir)

# load data
train_loader, test_loader = capsulenet.load_mnist(args.data_dir, download=True, batch_size=args.batch_size)

# define model
model = capsulenet.CapsuleNet(input_size=[2, 28, 28], classes=10, routings=3)
if torch.cuda.is_available():
  model.cuda()
print(model)


Looking in links: https://download.pytorch.org/whl/cu90/torch_stable.html
The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload
Namespace(batch_size=100, data_dir='./data', download=False, epochs=50, lam_recon=0.392, lr=0.001, lr_decay=0.9, routings=3, save_dir='./result', shift_pixels=2, testing=False, weights=None)
CapsuleNet(
  (conv1): Conv2d(2, 256, kernel_size=(9, 9), stride=(1, 1))
  (primarycaps): PrimaryCapsule(
    (conv2d): Conv2d(256, 256, kernel_size=(9, 9), stride=(2, 2))
  )
  (digitcaps): DenseCapsule()
  (decoder): Sequential(
    (0): Linear(in_features=160, out_features=512, bias=True)
    (1): ReLU(inplace)
    (2): Linear(in_features=512, out_features=1024, bias=True)
    (3): ReLU(inplace)
    (4): Linear(in_features=1024, out_features=1568, bias=True)
    (5): Sigmoid()
  )
  (relu): ReLU()
)


In [80]:

# train or test
if args.weights is not None:  # init the model weights with provided one
    # map_location='cpu' allows checkpoints saved on a GPU to be loaded on a CPU-only machine.
    # torch.load unpickles the file, so only load checkpoints from trusted sources.
    model.load_state_dict(torch.load(args.weights, map_location='cpu'))
if not args.testing:
    capsulenet.train(model, train_loader, test_loader, args)
else:  # testing
    if args.weights is None:
        print('No weights are provided. Will test using random initialized weights.')
    # `test` lives in the capsulenet module (like `train`); a bare `test` is undefined here.
    test_loss, test_acc = capsulenet.test(model=model, test_loader=test_loader, args=args)
    print('test acc = %.4f, test loss = %.5f' % (test_acc, test_loss))

Begin Training----------------------------------------------------------------------
0
y.size(0) 
200
y.view(-1, 1) 
torch.Size([800, 1])


RuntimeError: ignored

In [10]:
def caps_loss(y_true, y_pred, x, x_recon, lam_recon,
              m_plus=0.9, m_minus=0.1, down_weight=0.5):
    """Capsule-network loss: margin loss on the predictions plus a weighted reconstruction MSE.

    Args:
        y_true: target tensor with the same shape as ``y_pred``. The margin formulation
            assumes one-hot (or at least [0, 1]) targets; values outside [0, 1] make
            the ``(1 - y_true)`` or ``y_true`` weights negative, so the loss can go negative.
        y_pred: prediction tensor, presumably (batch, classes) capsule lengths; the
            margin term is summed over dim 1 and averaged over the remaining dim.
        x: original input tensor.
        x_recon: reconstruction of ``x`` (same shape as ``x``).
        lam_recon: weight of the reconstruction term.
        m_plus: upper margin; present classes are penalized when ``y_pred < m_plus``.
        m_minus: lower margin; absent classes are penalized when ``y_pred > m_minus``.
        down_weight: weight on the absent-class term.

    Returns:
        Scalar tensor ``L_margin + lam_recon * L_recon``.
    """
    present = y_true * torch.clamp(m_plus - y_pred, min=0.) ** 2
    absent = down_weight * (1 - y_true) * torch.clamp(y_pred - m_minus, min=0.) ** 2
    L_margin = (present + absent).sum(dim=1).mean()

    # Functional form of nn.MSELoss() (mean reduction); the fully qualified name only
    # needs `torch`, whereas the previous `nn.MSELoss()` relied on an `nn` import that
    # was not available (NameError).
    L_recon = torch.nn.functional.mse_loss(x_recon, x)

    return L_margin + lam_recon * L_recon


NameError: ignored

In [17]:
import torch
from torch.utils.data import TensorDataset,DataLoader
import numpy as np
import os
from scipy import stats

path = 'data/'
fieldname = '_19790101-20190228.npy'
SEED = 0              # fixes the random train/validation split (and loader shuffling)
TRAIN_FRACTION = 0.6  # int(0.6 * 3659) == 2195, the previously hard-coded train size
BATCH_SIZE = 200


def _flatten_zscore(arr):
    """Flatten each sample's 2-D grid into one row and z-score every column over the sample axis."""
    return stats.zscore(arr.reshape([arr.shape[0], arr.shape[1] * arr.shape[2]]))


x1_arr = np.load(os.path.join(path, 'z1000' + fieldname))  # geopotential height data (9*9 resolution)
x2_arr = np.load(os.path.join(path, 'pv300' + fieldname))  # potential vorticity data (9*9 resolution)
# FIXME: this loads 'z1000' a second time, so channel 2 is an exact copy of channel 0.
# Confirm which field was intended for the third channel.
x3_arr = np.load(os.path.join(path, 'z1000' + fieldname))  # geopotential height data (9*9 resolution)

x1_arr_flat = _flatten_zscore(x1_arr)  # normalize and flatten
x2_arr_flat = _flatten_zscore(x2_arr)
x3_arr_flat = _flatten_zscore(x3_arr)
y_arr = np.load(os.path.join(path, 'rain_basin_19790101-20190228.npy'))  # rain data

tensor_x = torch.Tensor(np.concatenate([x1_arr_flat, x2_arr_flat, x3_arr_flat], axis=1))  # join the fields
tensor_y = torch.Tensor(y_arr)

forecast_dataset = TensorDataset(tensor_x, tensor_y)  # creates a dataset based on tensors
# Each row of tensor_x is [x1 grid | x2 grid | x3 grid], 81 row-major values per field.
# Reshaping to (3, 9, 9) therefore gives one field per channel, each with its original 9x9 layout.
forecast_dataset2 = TensorDataset(tensor_x.reshape(-1, 3, 9, 9), tensor_y)

# Derive split sizes from the data so a different record count does not break random_split.
n_train = int(TRAIN_FRACTION * len(forecast_dataset2))
n_valid = len(forecast_dataset2) - n_train
torch.manual_seed(SEED)
training_ds, validation_ds = torch.utils.data.random_split(forecast_dataset2, [n_train, n_valid])
training_dataloader = DataLoader(training_ds, batch_size=BATCH_SIZE, shuffle=True)
valid_dataloader = DataLoader(validation_ds, batch_size=BATCH_SIZE)

for batch_x, batch_y in training_dataloader:
    print(batch_x.shape, batch_y.shape)

torch.Size([200, 3, 9, 9]) torch.Size([200, 4])
torch.Size([200, 3, 9, 9]) torch.Size([200, 4])
torch.Size([200, 3, 9, 9]) torch.Size([200, 4])
torch.Size([200, 3, 9, 9]) torch.Size([200, 4])
torch.Size([200, 3, 9, 9]) torch.Size([200, 4])
torch.Size([200, 3, 9, 9]) torch.Size([200, 4])
torch.Size([200, 3, 9, 9]) torch.Size([200, 4])
torch.Size([200, 3, 9, 9]) torch.Size([200, 4])
torch.Size([200, 3, 9, 9]) torch.Size([200, 4])
torch.Size([200, 3, 9, 9]) torch.Size([200, 4])
torch.Size([195, 3, 9, 9]) torch.Size([195, 4])


In [21]:
  # Sanity check: pass the targets in as predictions and the inputs in as reconstructions.
  # NOTE(review): batch_y holds rain values, not one-hot labels. caps_loss's margin term
  # assumes y_true in [0, 1]; with x_recon == x the reconstruction term is 0, so a negative
  # result means some targets fall outside [0, 1]. The value is then not a meaningful loss.
  loss = caps_loss(y_true=batch_y, y_pred=batch_y,
                   x=batch_x, x_recon=batch_x,
                   lam_recon=args.lam_recon)
  loss

tensor(-7271.9624)