In [1]:
import numpy as np
import matplotlib.pyplot as plt
import random
import torch
from torch.utils.data import Dataset, DataLoader
from torch import nn
from tqdm import tqdm
from typing import Union

In [2]:
# Central configuration shared by data generation, the model and the training loop.
hyperparams = dict(
    # --- data simulation ---
    n_tasks=1000,     # number of link-length pairs to simulate
    n_episode=150,    # episodes generated for every task
    n_timesteps=10,   # control steps in one episode
    # --- network ---
    input_dim=4,      # 2-D position concatenated with the 2-D control
    hidden_dim=512,
    output_dim=2,     # predicted 2-D position
    # --- optimisation ---
    lr=1e-4,
    batch_size=128,
    n_epochs=100,
)

# Data Simulation

In [3]:
class TwoArmLinkDataset(Dataset):
    """Simulated episodes of a planar two-link arm.

    One pair of link lengths is drawn per task from N(1, 0.3^2); every task
    then contributes `n_episode` independent episodes. All episodes are
    generated eagerly in the constructor and kept in Python lists.

    Args:
        n_tasks: number of link-length pairs to draw.
        n_episode: episodes simulated for each task.
        n_timesteps: control steps per episode.
        set_seed: when not None, NumPy's *global* RNG is seeded with this
            value before any sampling; None leaves the global RNG untouched.
    """

    def __init__(self, n_tasks=hyperparams['n_tasks'], n_episode=hyperparams['n_episode'],
                 n_timesteps=hyperparams['n_timesteps'], set_seed: Union[int, None]=None):
        super().__init__()
        # Parallel per-episode stores; entry k of each list belongs to episode k.
        self.x_0, self.u, self.x, self.z = [], [], [], []

        if set_seed is not None:
            np.random.seed(set_seed)

        # Row t holds the two link lengths of task t.
        l_of_tasks = np.random.normal(loc=1, scale=0.3, size=(n_tasks, 2))
        stores = (self.x_0, self.u, self.x, self.z)
        for task_idx in tqdm(range(n_tasks)):
            link_lengths = l_of_tasks[task_idx]
            for _ in range(n_episode):
                episode = TwoArmLinkDataset.generate_episode(link_lengths, n_timesteps)
                for store, value in zip(stores, episode):
                    store.append(value)
        print("Two Arm Link Dataset Generation Finished!")

    def __getitem__(self, index):
        """Return episode `index` as a dict of freshly built tensors."""
        fields = (
            ("init_pos", self.x_0),
            ("control", self.u),
            ("true_pos", self.x),
            ("noisy_pos", self.z),
        )
        return {key: torch.tensor(store[index]) for key, store in fields}

    def __len__(self):
        """Total number of stored episodes."""
        return len(self.x_0)

    @staticmethod
    def _end_effector(q, l):
        """Forward kinematics: joint angles `q` (two rows) and link lengths `l` -> stacked [x, y]."""
        first_angle = q[0]
        second_angle = q[0] + q[1]
        x = l[0] * np.cos(first_angle) + l[1] * np.cos(second_angle)
        y = l[0] * np.sin(first_angle) + l[1] * np.sin(second_angle)
        return np.array([x, y])

    @staticmethod
    def generate_episode(l, n_timesteps):
        """Simulate one episode for an arm with link lengths `l`.

        Draws, in this order, from NumPy's global RNG: the initial joint
        angles (uniform on [-pi, pi)), the controls (standard normal) and the
        observation noise (std 0.001).

        Returns:
            x_0: initial end-effector position, shape (2,).
            u:   joint-angle increments, shape (2, n_timesteps).
            x:   true end-effector positions, shape (2, n_timesteps).
            z:   `x` plus observation noise, shape (2, n_timesteps).
        """
        q_0 = np.random.uniform(low=-np.pi, high=np.pi, size=2)
        x_0 = TwoArmLinkDataset._end_effector(q_0, l)

        u = np.random.randn(2, n_timesteps)
        # Joint angles are the running sum of the controls on top of the start angles.
        q = np.cumsum(u, axis=1) + q_0[:, None]
        x = TwoArmLinkDataset._end_effector(q, l)

        noise = np.random.normal(0, scale=0.001, size=(2, n_timesteps))
        z = x + noise
        return x_0, u, x, z

In [4]:
def get_loader(batch_size=hyperparams['batch_size'], n_tasks=hyperparams['n_tasks'], set_seed=0):
    """Generate a `TwoArmLinkDataset` and wrap it in a shuffling `DataLoader`.

    Args:
        batch_size: samples per batch; a trailing incomplete batch is dropped.
        n_tasks: number of tasks to simulate.
        set_seed: seed forwarded to the dataset (None disables seeding).

    Returns:
        A `DataLoader` with shuffling, pinned memory and `drop_last` enabled.
    """
    return DataLoader(
        TwoArmLinkDataset(n_tasks=n_tasks, set_seed=set_seed),
        batch_size=batch_size,
        shuffle=True,
        pin_memory=True,
        drop_last=True,
    )

# Build model

In [5]:
class DropoutLayer(nn.Module):
    """Deterministic dropout: multiplies activations by an externally supplied mask.

    The mask starts as all ones (identity) and is replaced through `update`.
    It is stored as a *non-persistent buffer*, so it follows the module through
    `.to(device)` / `.double()` / `.cuda()` while staying out of `state_dict`;
    checkpoint keys are therefore the same as with a plain attribute.

    Args:
        hidden_dim: length of the mask, i.e. size of the last activation axis.
    """

    def __init__(self, hidden_dim=hyperparams['hidden_dim']):
        super(DropoutLayer, self).__init__()
        # Registered as a buffer so device/dtype moves of the parent model also
        # apply to the mask; persistent=False keeps it out of saved checkpoints.
        self.register_buffer(
            'mask',
            torch.ones(hidden_dim, dtype=torch.float32, requires_grad=False),
            persistent=False,
        )
        # Kept from the original implementation; forward() never reads this flag.
        self.training = False

    def forward(self, x):
        """Scale `x` element-wise by the current mask (broadcast over leading axes)."""
        return x * self.mask

    def update(self, mask: np.ndarray):
        """Replace the mask with a copy of `mask`.

        The new mask is created on the device and with the dtype of the current
        one, so the layer keeps working after the model has been moved or cast.

        Raises:
            AssertionError: if `mask` does not have the current mask's shape.
        """
        assert mask.shape == self.mask.shape, f"new mask shape should be {self.mask.shape} but giving {mask.shape}"
        self.mask = torch.tensor(mask, dtype=self.mask.dtype, device=self.mask.device)

    def get(self) -> np.ndarray:
        """Return the current mask as a NumPy array.

        The tensor is brought to the CPU first, so this also works when the
        model lives on a GPU. The array has the mask's current dtype.
        """
        return self.mask.detach().cpu().numpy()

In [6]:
class MyModel(nn.Module):
    """Three-layer MLP with two interchangeable dropout stages after the second layer.

    In training mode the hidden activations go through standard `nn.Dropout`
    (p=0.5); in eval mode they are multiplied by the explicit mask held by
    `DropoutLayer`, which can be read and replaced from outside.

    Args:
        input_dim: size of the input features.
        hidden_dim: width of both hidden layers (and of the dropout mask).
        output_dim: size of the prediction.
    """

    def __init__(self, input_dim=hyperparams['input_dim'], hidden_dim=hyperparams['hidden_dim'],
                 output_dim=hyperparams['output_dim']):
        super().__init__()
        self.lin1 = nn.Linear(in_features=input_dim, out_features=hidden_dim)
        self.lin2 = nn.Linear(in_features=hidden_dim, out_features=hidden_dim)
        self.lin3 = nn.Linear(in_features=hidden_dim, out_features=output_dim)
        self.relu = nn.ReLU()
        # Random dropout used during gradient-descent pre-training and fine-tuning.
        self.dropout_for_GD = nn.Dropout(p=0.5)
        # Mask-based dropout used for SMC (active in eval mode).
        self.dropout = DropoutLayer(hidden_dim=hidden_dim)

    def forward(self, x):
        """Map a batch of inputs to predictions; the dropout stage depends on `self.training`."""
        hidden = self.relu(self.lin1(x))
        hidden = self.relu(self.lin2(hidden))
        drop = self.dropout_for_GD if self.training else self.dropout
        return self.lin3(drop(hidden))

    def update_dropout_mask(self, mask: np.ndarray):
        """Install `mask` as the eval-mode dropout mask."""
        self.dropout.update(mask)

    def get_dropout_mask(self):
        """Return the eval-mode dropout mask as provided by `DropoutLayer.get`."""
        return self.dropout.get()

# Multi-Task Training

In [10]:
# Build the training loader; the task count comes from the shared configuration
# rather than a repeated literal, and the fixed seed keeps the simulated data reproducible.
train_loader = get_loader(n_tasks=hyperparams['n_tasks'], set_seed=0)


100%|██████████| 1000/1000 [00:06<00:00, 154.75it/s]

Two Arm Link Dataset Generation Finished!





In [13]:
# Training setup: use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# Parameters are cast to float64 to match the float64 tensors produced from the NumPy data.
model = MyModel().to(device).double()
model.train()
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=hyperparams['lr'])

In [11]:
def train(data_loader=train_loader, nb_epochs=hyperparams['n_epochs']):
    """Train the global `model` by rolling it out over each episode.

    For every batch the model is unrolled for `hyperparams['n_timesteps']`
    steps: its input is the previous prediction (the initial position at the
    first step) concatenated with the current control, and the MSE against the
    true position is averaged over the steps. Uses the module-level `model`,
    `optimizer`, `criterion` and `device`; the model's train/eval mode is not
    changed here.

    Args:
        data_loader: iterable of batches with keys 'init_pos', 'control' and
            'true_pos'.
        nb_epochs: number of passes over `data_loader`.

    Returns:
        List with the mean batch loss of each epoch.

    Raises:
        ZeroDivisionError: if `data_loader` yields no batches in an epoch.
    """
    n_timesteps = hyperparams['n_timesteps']
    total_loss = []
    for epoch in range(1, nb_epochs + 1):
        # Running sum / count of the batch losses seen so far in this epoch.
        epoch_loss_sum = 0.0
        n_batches = 0
        # tqdm gives a per-batch progress bar for the epoch.
        with tqdm(data_loader, unit="batch") as tepoch:
            for data in tepoch:
                tepoch.set_description(f"Epoch {epoch}")  # customise the tqdm prefix
                optimizer.zero_grad()
                control = data['control'].to(device)
                true_pos = data['true_pos'].to(device)
                # The initial position must live on the same device as the
                # controls, otherwise torch.cat fails when training on a GPU.
                output = data['init_pos'].to(device)
                loss = 0
                for i in range(n_timesteps):
                    input_ = torch.cat((output, control[:, :, i]), dim=1)
                    # Call the module (not .forward) so registered hooks run.
                    output = model(input_)
                    loss = loss + criterion(output, true_pos[:, :, i])
                loss = loss / n_timesteps
                loss.backward()   # backward propagation
                optimizer.step()  # update parameters
                epoch_loss_sum += loss.item()
                n_batches += 1
                # Show the running mean loss in the tqdm postfix.
                tepoch.set_postfix(loss=epoch_loss_sum / n_batches)
        total_loss.append(epoch_loss_sum / n_batches)
    return total_loss

In [14]:
# Run training with the default loader and epoch count.
total_loss = train() # list with the mean batch loss of each epoch

Epoch 1: 100%|██████████| 1171/1171 [01:25<00:00, 13.72batch/s, loss=0.799]
Epoch 2:  58%|█████▊    | 684/1171 [00:50<00:36, 13.48batch/s, loss=0.737]


KeyboardInterrupt: 

In [None]:
# Learning curve: mean training loss per epoch, labelled so the figure stands alone.
fig, ax = plt.subplots()
ax.plot(range(1, len(total_loss) + 1), total_loss)
ax.set(xlabel="Epoch", ylabel="Mean MSE loss", title="Multi-task training loss");

In [None]:
# Persist the model weights, the per-epoch loss history and the optimizer state
# in a single checkpoint file.
torch.save(
    obj=dict(
        state_dict=model.state_dict(),
        loss=total_loss,
        optimizer=optimizer.state_dict(),
    ),
    f='Task1_multiTask.pth.tar',
)

# Load model

In [None]:
# Rebuild the network in float64, the dtype it was trained and saved with, so the
# restored model accepts the float64 tensors produced by the dataset.
model = MyModel().to(device).double()
# map_location lets a checkpoint written on a GPU be restored on a CPU-only machine.
# NOTE(review): torch.load unpickles the file; only load checkpoints you trust.
checkpoint = torch.load('Task1_multiTask.pth.tar', map_location=device)
model.load_state_dict(checkpoint['state_dict'])
model.eval()

MyModel(
  (lin1): Linear(in_features=4, out_features=512, bias=True)
  (lin2): Linear(in_features=512, out_features=512, bias=True)
  (lin3): Linear(in_features=512, out_features=2, bias=True)
  (relu): ReLU()
  (dropout_for_GD): Dropout(p=0.5, inplace=False)
  (dropout): DropoutLayer()
)