In [1]:
%matplotlib inline

import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt

from IPython.display import clear_output
from torch.utils.data import DataLoader, TensorDataset
from sklearn.model_selection import train_test_split

from simple_movement import DataGen, SimpleMovement

In [2]:
class MMC(nn.Module):
    
    def __init__(self, beta=5, num_iter=20):
        super(MMC, self).__init__()
        self.df = 1. / beta
        self.num_iter = num_iter
        self.weights = nn.Parameter(
            data= torch.tensor([[1, self.df, 0], [-1, 0, 1], [0, 0, 1]]),
            requires_grad=False
        ).unsqueeze(0).expand(9, 3, 3)
            
    def forward(self, x):
#         iteration = 0
#         while iteration < self.num_iter:
#             for i in range(x.shape[0]):
#                 x[i] = torch.matmul(self.weights, x[i])
#             iteration += 1
        # x = torch.acos(x)
        # Can not set mmc target, alpha, alpha goals were angles, for image x, y coordinates
        x = torch.bmm(torch.pow(self.weights, self.num_iter), x.unsqueeze(2)).squeeze(2)
        return x

In [3]:
class Encoder(nn.Module):
    
    def __init__(self, num_channel):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels=num_channel, out_channels=16, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(16, 8, 3, padding=1)
        self.conv3 = nn.Conv2d(8, 8, 3, padding=1)
        self.conv4 = nn.Conv2d(8, 8, 2)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.fc1 = nn.Linear(200, 128)
        self.fc2 = nn.Linear(128, 3)
        # self.mmc = MMC(5, 20)
    
    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = self.pool(F.relu(self.conv3(x)))
        x = self.pool(F.relu(self.conv4(x)))
        x = x.view(-1, 5*5*8)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        print('before mmc: ', x.grad_fn, self.fc2.weight.grad_fn)
        # x = self.mmc(x)
        print('after mmc: ', x.grad_fn, self.fc2.weight.grad_fn)
        return x

In [5]:
# for p in Encoder(1).mmc.parameters():
#     print(p)

In [6]:
def get_model():
    
    net = Encoder(num_channel=1)
    optimizer = torch.optim.Adam(net.parameters(), lr=0.001)
    
    return net, optimizer

In [7]:
def to_tensor(x):
    return torch.from_numpy(x).float()

def get_data(height, width, num_channel, path, size, bs=9):
    
    dgen = DataGen(height, width, num_channel)
    x, y = dgen.get_feature_target_pairs(path, size)
    print('Data loaded...\nx:{}\ty:{}\n'.format(x.shape, y.shape))
    
    x = x / 255.
    x = x[:90]
    y = y[:90]
    x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.1, shuffle=True, random_state=331)
    x_train, x_test, y_train, y_test = map(to_tensor, (x_train, x_test, y_train, y_test))
    
    train_ds = TensorDataset(x_train, y_train)
    train_dl = DataLoader(train_ds, batch_size=bs, shuffle=True)

    test_ds = TensorDataset(x_test, y_test)
    test_dl = DataLoader(test_ds, batch_size=bs)
    
    return train_dl, test_dl

In [8]:
def fit(net, optimizer, train_dl, test_dl, epochs=100):
    
    loss_function = nn.MSELoss()
    for epoch in range(epochs):
        net.train()
        for x, y in train_dl:
            pred = net(x)
            loss = loss_function(pred, y)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        net.eval()
        with torch.no_grad():
            test_loss = sum(loss_function(net(x), y) for x, y in test_dl)
        print('epoch:{}, test_loss:{}'.format(epoch+1, test_loss/len(test_dl)))

In [9]:
net, optimizer = get_model()
train_dl, test_dl = get_data(100, 100, 1, './data/data_simple_movement_1/', 8)
fit(net, optimizer, train_dl, test_dl, epochs=200)

Data loaded...
x:(96, 1, 100, 100)	y:(96, 3)

before mmc:  <AddmmBackward object at 0x000001E29E9F06A0> None
after mmc:  <AddmmBackward object at 0x000001E29E9F06A0> None
before mmc:  <AddmmBackward object at 0x000001E29E916668> None
after mmc:  <AddmmBackward object at 0x000001E29E916668> None
before mmc:  <AddmmBackward object at 0x000001E29E916668> None
after mmc:  <AddmmBackward object at 0x000001E29E916668> None
before mmc:  <AddmmBackward object at 0x000001E29E9F07B8> None
after mmc:  <AddmmBackward object at 0x000001E29E9F07B8> None
before mmc:  <AddmmBackward object at 0x000001E29E9F07B8> None
after mmc:  <AddmmBackward object at 0x000001E29E9F07B8> None
before mmc:  <AddmmBackward object at 0x000001E29E916668> None
after mmc:  <AddmmBackward object at 0x000001E29E916668> None
before mmc:  <AddmmBackward object at 0x000001E29E916668> None
after mmc:  <AddmmBackward object at 0x000001E29E916668> None
before mmc:  <AddmmBackward object at 0x000001E29E916668> None
after mmc:  <Add

KeyboardInterrupt: 