In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable
from torch.utils.data import Dataset, DataLoader
import numpy as np
from torch.utils.tensorboard import SummaryWriter
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts

In [2]:
# --- Training budget ---
batch_size = 128
steps = 200_000  # total optimisation steps; the training cell converts this to epochs
# tensorboard_writer = SummaryWriter()

# --- CosineAnnealingWarmRestarts settings (passed to the scheduler) ---
T_0 = 3
T_mult = 2
eta_min = 1e-6

# Only referenced by commented-out target scaling in the train/eval cells.
epsilon = 0.01

# Constant 1x3x3 matrix subtracted from targets and predictions at evaluation.
# NOTE(review): the name suggests the inverse of a "base" matrix -- confirm.
inv_base_matrix = [
    [
        [2.0, -2.0, 1.0],
        [-1.0, 3.0, -2.0],
        [0.0, -1.0, 1.0],
    ]
]

In [3]:


class BaseModel(nn.Module):
    """Small fully connected regressor: 9 inputs -> 72 hidden (ReLU) -> 9 outputs.

    Inputs are flattened to rows of 9 values, so a batch of 3x3 matrices of
    shape (batch, 3, 3) yields an output of shape (batch, 9).
    """

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(9, 72)
        self.fc3 = nn.Linear(72, 9)

    def forward(self, x):
        """Run the network on ``x``.

        Args:
            x: tensor whose total number of elements is a multiple of 9.

        Returns:
            Tensor of shape (x.numel() // 9, 9).
        """
        # reshape() instead of view(): view() raises on inputs whose strides
        # cannot be merged (e.g. a transposed batch); reshape() copies only
        # when it has to and is otherwise identical.
        x = x.reshape(-1, 9)
        x = F.relu(self.fc1(x))
        return self.fc3(x)
    
# Use the GPU when one is visible, otherwise fall back to the CPU so this cell
# does not crash on CPU-only machines.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Double-precision parameters.
# NOTE(review): presumably chosen to match float64 data in the .npy files -- confirm.
model = BaseModel().to(device=device, dtype=torch.double)
model.train()
model

BaseModel(
  (fc1): Linear(in_features=9, out_features=72, bias=True)
  (fc3): Linear(in_features=72, out_features=9, bias=True)
)

In [4]:
class CustomDataset(Dataset):
    """Dataset backed by a single ``.npy`` array loaded fully into memory.

    The array is indexed as ``[sample, :, :, channel]``: channel 0 is returned
    as the input ``x`` and channel 1 as the target ``y``.
    """

    def __init__(self, root_dir):
        """Load the array stored at ``root_dir`` (a path handed to ``np.load``)."""
        self.dataset = np.load(root_dir)
        print('number of data points', len(self))

    def __len__(self):
        """Number of samples, i.e. the size of the first array axis."""
        return self.dataset.shape[0]

    def __getitem__(self, idx):
        """Return the ``(x, y)`` pair for ``idx`` as NumPy arrays."""
        # Samplers may hand over tensor indices; NumPy wants plain ints/lists.
        index = idx.tolist() if torch.is_tensor(idx) else idx
        inputs = self.dataset[index, :, :, 0]
        targets = self.dataset[index, :, :, 1]
        return inputs, targets

In [5]:
# Load both splits from disk; each constructor prints its sample count.
train_set = CustomDataset('train_set.npy')
val_set = CustomDataset('val_set.npy')

# Only the training split is shuffled; validation keeps file order.
train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_set, batch_size=batch_size)

number of data points 100000
number of data points 10000


In [6]:
# Adam with a small weight decay; the learning rate is then driven by a cosine
# schedule with warm restarts (stepped once per epoch in the training cell).
optimizer = optim.Adam(
    model.parameters(),
    lr=5e-5,
    weight_decay=1e-7,
)
scheduler = CosineAnnealingWarmRestarts(
    optimizer,
    T_0=T_0,
    T_mult=T_mult,
    eta_min=eta_min,
)

In [7]:

# Train on whichever device the model's parameters live on instead of
# hard-coding 'cuda'.
device = next(model.parameters()).device

# The evaluation cell switches the model to eval mode; switch back so this
# cell behaves the same when re-run afterwards.
model.train()

i = 1  # global step counter (1-based), used only for logging
for epoch in range(steps // len(train_loader)):
    for data, target in train_loader:
        # Tensors take part in autograd directly; the old Variable wrapper is
        # a deprecated no-op and has been dropped.
        data = data.to(device)
        target = target.to(device)

        optimizer.zero_grad()
        output = model(data)
        # The model emits (batch, 9); flatten the targets to match.
        loss = F.mse_loss(output, target.view(-1, 9))
        loss.backward()
        optimizer.step()

        if i % 100 == 0:
            # Read the learning rate from the optimizer: it is the value that
            # is actually applied, and avoids scheduler.get_lr(), which is
            # deprecated outside of scheduler.step() and emits a warning.
            # loss.item() forces a host sync, so call it only when logging.
            print('\rTrain Step: %d, Loss: %.4f, lr: %.8f'
                  % (i, loss.item(), optimizer.param_groups[0]['lr']), end="")
        i += 1
    # The cosine schedule advances once per epoch.
    scheduler.step()
    print('\n')



Train Step: 700, Loss: 1.7295, lr: 0.00005000

Train Step: 1500, Loss: 0.9072, lr: 0.00003775

Train Step: 2300, Loss: 0.6374, lr: 0.00001325

Train Step: 3100, Loss: 0.2161, lr: 0.00005000

Train Step: 3900, Loss: 0.0765, lr: 0.00004672

Train Step: 4600, Loss: 0.0370, lr: 0.00003775

Train Step: 5400, Loss: 0.0330, lr: 0.00002550

Train Step: 6200, Loss: 0.0308, lr: 0.00001325

Train Step: 7000, Loss: 0.0269, lr: 0.00000428

Train Step: 7800, Loss: 0.0192, lr: 0.00005000

Train Step: 8600, Loss: 0.0201, lr: 0.00004917

Train Step: 9300, Loss: 0.0147, lr: 0.00004672

Train Step: 10100, Loss: 0.0102, lr: 0.00004282

Train Step: 10900, Loss: 0.0085, lr: 0.00003775

Train Step: 11700, Loss: 0.0061, lr: 0.00003184

Train Step: 12500, Loss: 0.0051, lr: 0.00002550

Train Step: 13200, Loss: 0.0038, lr: 0.00001916

Train Step: 14000, Loss: 0.0042, lr: 0.00001325

Train Step: 14800, Loss: 0.0038, lr: 0.00000818

Train Step: 15600, Loss: 0.0032, lr: 0.00000428

Train Step: 16400, Loss: 0.0027, 

In [8]:
# Persist only the learned parameters (state dict), not the whole module object.
torch.save(obj=model.state_dict(), f='direct_3.pth')

In [9]:
from tqdm import tqdm

# Evaluate on whichever device the model's parameters live on instead of
# hard-coding 'cuda'.
device = next(model.parameters()).device

# Loop-invariant constant, shape (1, 3, 3); broadcasts over the batch axis.
base = torch.tensor(inv_base_matrix)

model.eval()
total_error = 0.0
total_number = 0

# Inference only: no_grad() skips building the autograd graph altogether,
# which replaces the per-batch .detach() and saves memory.
with torch.no_grad():
    for data, target in tqdm(val_loader):
        # Subtracting the same constant from target and prediction cancels in
        # the error below; it is kept because later cells print these residual
        # `output` / `target` tensors.
        target = target - base
        data = data.to(device)
        output = model(data)
        output = output.cpu() - base.view(-1, 9)
        total_error += torch.sum(torch.abs(output - target.view(-1, 9)))
        total_number += output.numel()

# Mean absolute error per matrix entry over the validation set.
print(total_error.numpy() / total_number)

100%|██████████| 79/79 [00:00<00:00, 559.61it/s]

9.856039268487429e-05





####

model 1 test error: 8.96243563785012e-05
model 2 test error: 7.503371569420155e-05
model 3 test error: 9.856039268487429e-05

In [10]:
# Test errors of the three trained models, copied from the runs noted above.
temp = [
    8.96243563785012e-05,
    7.503371569420155e-05,
    9.856039268487429e-05,
]
# Mean, then variance (np.var default, i.e. population variance with ddof=0).
print(np.mean(temp), np.var(temp), sep='\n')

8.773948825252567e-05
9.402711896340092e-11


In [12]:
# `target` is whatever the evaluation loop left behind, i.e. its final batch.
print(target.size())

torch.Size([16, 3, 3])


In [16]:
# First five rows of the last evaluation batch: predictions, then targets
# flattened to the same (batch, 9) layout.
print(output[:5])
print(target.view(-1, 9)[:5])

tensor([[ 0.0404, -0.0266,  0.0062, -0.0268, -0.0193,  0.0274,  0.0051,  0.0159,
         -0.0153],
        [ 0.0962, -0.1912,  0.1210, -0.1133,  0.2512, -0.1669,  0.0342, -0.0854,
          0.0588],
        [ 0.0088, -0.0102,  0.0102, -0.0217,  0.0052, -0.0009,  0.0118, -0.0003,
         -0.0025],
        [-0.0076, -0.0255,  0.0227, -0.0051,  0.0137, -0.0075, -0.0006,  0.0065,
         -0.0057],
        [-0.0153, -0.0034,  0.0076,  0.0070,  0.0320, -0.0300, -0.0003, -0.0134,
          0.0123]], dtype=torch.float64)
tensor([[ 0.0404, -0.0266,  0.0062, -0.0268, -0.0194,  0.0276,  0.0052,  0.0160,
         -0.0155],
        [ 0.0958, -0.1915,  0.1213, -0.1131,  0.2515, -0.1670,  0.0342, -0.0852,
          0.0586],
        [ 0.0086, -0.0099,  0.0101, -0.0217,  0.0050, -0.0008,  0.0118, -0.0003,
         -0.0026],
        [-0.0076, -0.0255,  0.0227, -0.0051,  0.0137, -0.0075, -0.0005,  0.0066,
         -0.0057],
        [-0.0154, -0.0034,  0.0076,  0.0070,  0.0319, -0.0300, -0.0003, -0.013