In [None]:
import copy

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision

# Use the first GPU when available; fall back to CPU so the notebook still runs.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

transforms = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor()
])

# NOTE(review): batch_size=2 looks deliberate -- the training target below is
# the batch flipped along dim 0, i.e. each image's batch partner. Confirm.
# download=False expects MNIST to already exist under ./data/.
trainset = torchvision.datasets.MNIST(root='./data/', train=True, download=False, transform=transforms)
train_loader = torch.utils.data.DataLoader(trainset, drop_last=True, batch_size=2, shuffle=True)

# The held-out split; previously this loader wrapped `trainset` by mistake.
testset = torchvision.datasets.MNIST(root='./data/', train=False, download=False, transform=transforms)
test_loader = torch.utils.data.DataLoader(testset, drop_last=True, batch_size=2, shuffle=False)

class ReshapeLayer(nn.Module):
    """Module wrapper around ``torch.reshape`` so a fixed reshape can sit in ``nn.Sequential``.

    Args:
        shape: target shape handed unchanged to ``torch.reshape`` on every
            forward call; one entry may be ``-1`` to infer that dimension.
    """

    def __init__(self, shape):
        super().__init__()
        self.shape = shape

    def forward(self, x):
        # Equivalent to x.reshape(self.shape).
        return torch.reshape(x, self.shape)

In [None]:
# Convolutional autoencoder for 28x28 MNIST digits. Each of the four unpadded
# 5x5 convolutions trims 4 pixels off the side length, so the bottleneck
# feature map is 64 x (28 - 4*4) x (28 - 4*4).
latent_space_size = 10
bottleneck_side = 28 - 4 * 4
bottleneck_units = 64 * bottleneck_side ** 2

# Layers are created in the same order as they appear in the network.
layers = []
for in_channels in (1, 64, 64, 64):
    layers += [nn.Conv2d(in_channels, 64, 5), nn.BatchNorm2d(64), nn.ReLU()]
layers += [
    nn.Flatten(),
    nn.Linear(bottleneck_units, latent_space_size),
    nn.Linear(latent_space_size, bottleneck_units),
    nn.ReLU(),
    ReshapeLayer([-1, 64, bottleneck_side, bottleneck_side]),
]
for decoder_stage in range(3):
    layers += [nn.ConvTranspose2d(64, 64, 5), nn.BatchNorm2d(64), nn.ReLU()]
layers.append(nn.ConvTranspose2d(64, 1, 5))
transform = nn.Sequential(*layers).to(device)

optim = torch.optim.Adam(transform.parameters())

for epoch in range(2):
    for it, (x, _) in enumerate(train_loader):
        x = x.to(device)
        prediction = transform(x)
        # Target is the batch in reverse order along dim 0: each output should
        # match its batch partner. Presumably only learnable through
        # train-mode BatchNorm batch statistics (the effect studied here).
        loss = F.mse_loss(prediction, torch.flip(x, dims=[0]))
        loss.backward()

        optim.step()
        optim.zero_grad()

        if it % 100 == 0:
            print(loss)

        if it % 1000 == 0:
            fig, axis = plt.subplots(ncols=4)
            panels = [
                (x, 0, 'input 1'),
                (prediction, 0, 'transformed 1'),
                (x, 1, 'input 2'),
                (prediction, 1, 'transformed 2'),
            ]
            for ax, (images, index, title) in zip(axis, panels):
                ax.imshow(images[index, 0].detach().cpu().numpy())
                ax.set_title(title)

            plt.show()



In [None]:

# For the first four training batches, show each input next to the network's
# output with BatchNorm in training mode (current-batch statistics) and in
# inference mode (running statistics). Each figure is saved as example<it>.svg.

# Train-mode forwards run on a throwaway copy: in training mode BatchNorm
# updates its running statistics on every forward pass. Running them on
# `transform` itself would shift the statistics that its inference mode, in
# this cell and in every later cell, relies on.
train_mode_transform = copy.deepcopy(transform).train(True)
transform.train(False)

with torch.no_grad():
    for it, (x, _) in enumerate(train_loader):
        x = x.to(device)
        train_prediction = train_mode_transform(x)[:, 0].cpu().numpy()
        test_prediction = transform(x)[:, 0].cpu().numpy()
        x = x.cpu().numpy()[:, 0]

        fig, axis = plt.subplots(nrows=2, ncols=3, figsize=(10, 5))
        # Rows are the two batch elements; columns are input / train mode / test mode.
        for row in range(2):
            for col, images in enumerate((x, train_prediction, test_prediction)):
                axis[row, col].imshow(images[row])
                axis[row, col].set_xticks(ticks=[])
                axis[row, col].set_yticks(ticks=[])
            axis[row, 0].set_ylabel('Element ' + str(row + 1))
        axis[0, 0].set_title('Input')
        axis[0, 1].set_title('Reconstruction in training mode')
        axis[0, 2].set_title('Reconstruction in test mode')

        plt.savefig('example' + str(it) + '.svg', bbox_inches='tight')
        plt.show()

        if it == 3:
            break

In [None]:
# Mean MSE against the batch-flipped target used in training, with BatchNorm
# in training mode (batch statistics) versus inference mode (running statistics).
batches = 0
sum_train_losses = 0.0
sum_test_losses = 0.0

# Train-mode forwards run on a throwaway copy, because they update BatchNorm's
# running statistics. Otherwise the inference-mode losses measured alongside
# would drift during the evaluation itself.
train_mode_transform = copy.deepcopy(transform).train(True)
transform.train(False)

with torch.no_grad():
    for it, (x, _) in enumerate(test_loader):
        x = x.to(device)
        target = torch.flip(x, dims=[0])

        train_prediction = train_mode_transform(x)
        test_prediction = transform(x)

        batches += 1
        sum_train_losses += F.mse_loss(train_prediction, target).item()
        sum_test_losses += F.mse_loss(test_prediction, target).item()

print('Train', sum_train_losses / batches, 'test', sum_test_losses / batches)

    

In [None]:
# The signal used by SimpleContrast: sin(x) + 2*cos(x/2), 100 samples on [0, 4*pi].
x = np.linspace(start=0, stop=4 * np.pi, num=100)
slow_wave = 2 * np.cos(0.5 * x)
fx = np.sin(x) + slow_wave

plt.plot(x, fx)
plt.savefig('function.svg')
plt.show()

In [None]:


class SimpleContrast(torch.utils.data.Dataset):
    """Sliding-window regression dataset over the signal sin(x) + 2*cos(x/2).

    The signal is sampled at ``resolution`` points on [0, 4*pi]. Item ``idx``
    starts at sample ``start = idx * sequence_length`` and holds
    ``windows_per_item`` overlapping windows. Window ``i`` is the
    ``sequence_length - 1`` samples beginning at ``start + i``, and its target
    is the sample immediately after it. Consecutive windows overlap, so the
    target of window ``i`` also appears inside window ``i + 1``.

    Args:
        resolution: number of samples taken of the signal.
        sequence_length: window length plus one (the target); must be >= 2.
        windows_per_item: windows returned per item. Defaults to
            ``sequence_length - 1``. Must lie in ``[1, sequence_length + 1]``
            so the last item never reads past the end of the signal.

    ``__getitem__`` returns ``(inputs, targets)`` as float32 tensors of shape
    ``(windows_per_item, 1, sequence_length - 1)`` and ``(windows_per_item,)``.

    Raises:
        ValueError: on an invalid ``sequence_length`` or ``windows_per_item``.
    """

    def __init__(self, resolution, sequence_length, windows_per_item=None):
        if sequence_length < 2:
            raise ValueError('sequence_length must be >= 2')
        if windows_per_item is None:
            # NOTE(review): the original loop used `self.batch_size`, which was
            # never set, so every item access raised AttributeError. Using
            # `sequence_length - 1` windows is an assumption -- confirm intent.
            windows_per_item = sequence_length - 1
        if not 1 <= windows_per_item <= sequence_length + 1:
            raise ValueError('windows_per_item must be in [1, sequence_length + 1]')

        x = np.linspace(0, 4*np.pi, resolution)
        self.data = np.sin(x) + 2 * np.cos(0.5*x)

        self.sequence_length = sequence_length
        self.windows_per_item = windows_per_item

        # The `- 1` keeps one sequence of headroom, because an item's windows
        # extend past its own `sequence_length` samples. Clamp at 0 because
        # len() rejects negative values.
        self._len = max(0, resolution // sequence_length - 1)

    def __len__(self):
        return self._len

    def __getitem__(self, idx):
        # An explicit IndexError also makes plain iteration over the dataset stop correctly.
        if not 0 <= idx < self._len:
            raise IndexError('SimpleContrast index out of range')

        start = idx * self.sequence_length
        input_length = self.sequence_length - 1
        offsets = range(start, start + self.windows_per_item)

        elements = np.stack([self.data[o:o + input_length] for o in offsets], axis=0)
        targets = np.array([self.data[o + input_length] for o in offsets])

        return (torch.from_numpy(elements[:, None, :]).type(torch.float32),
                torch.from_numpy(targets).type(torch.float32))
    

In [None]:
# Sweep sequence length and DataLoader batch size on SimpleContrast. For each
# setting, train a fresh 1-D conv net, then record
#   (sequence_length, batch_size, train-mode MSE, inference-mode MSE),
# where train mode means BatchNorm uses batch statistics.
results = []
steps = 1000.0
for sequence_length in range(4, 21):
    ds = SimpleContrast(1000, sequence_length)
    for batch_size in range(1, 7):
        print(sequence_length, batch_size)
        print('Training')
        train_loader = torch.utils.data.DataLoader(ds, drop_last=True, batch_size=batch_size, shuffle=True)

        model = nn.Sequential(
            nn.Conv1d(1 , 32, 3, 1, 1), nn.BatchNorm1d(32), nn.ReLU(),
            nn.Conv1d(32, 32, 3, 1, 1), nn.BatchNorm1d(32), nn.ReLU(),
            nn.Conv1d(32, 32, 3, 1, 1), nn.BatchNorm1d(32), nn.ReLU(),
            nn.Conv1d(32, 32, 3, 1, 1), nn.BatchNorm1d(32), nn.ReLU(),
            nn.Conv1d(32, 32, 3, 1, 1), nn.BatchNorm1d(32), nn.ReLU(),
            nn.Conv1d(32, 32, 3, 1, 1), nn.BatchNorm1d(32), nn.ReLU(),
            # Kernel spans the whole window, collapsing it to one output per window.
            nn.Conv1d(32, 1, sequence_length - 1), nn.Flatten()).to(device)

        optim = torch.optim.Adam(model.parameters())

        # Enough epochs for roughly `steps` updates: ceil(steps / (len(ds) / batch_size)).
        # The previous `int(x + 0.9999)` rounded down when x was just above an integer.
        epochs = int(np.ceil(steps / (len(ds) / batch_size)))
        for epoch in range(epochs):
            for it, (x, y) in enumerate(train_loader):
                # Fold each item's windows into the batch axis.
                x = x.reshape([-1, 1, sequence_length - 1]).to(device)
                y = y.reshape([-1, 1]).to(device)

                loss = F.mse_loss(model(x), y)
                loss.backward()

                optim.step()
                optim.zero_grad()

        print('testing')
        # Train-mode forwards run on a throwaway copy, because they update
        # BatchNorm running statistics. That would otherwise alter the
        # inference-mode losses measured in the same loop.
        train_mode_model = copy.deepcopy(model).train(True)
        model.train(False)

        sum_train_loss = 0.0
        sum_test_loss = 0.0
        batches = 0.0

        # NOTE(review): evaluation reuses train_loader, which has the same
        # batching as training. The original also built an unused
        # batch_size=1 `test_loader` over the same dataset -- confirm which
        # batching was intended.
        with torch.no_grad():
            for it, (x, y) in enumerate(train_loader):
                x = x.reshape([-1, 1, sequence_length - 1]).to(device)
                y = y.reshape([-1, 1]).to(device)

                sum_train_loss += F.mse_loss(train_mode_model(x), y).item()
                sum_test_loss += F.mse_loss(model(x), y).item()

                batches += 1

        results.append((sequence_length, batch_size, sum_train_loss / batches, sum_test_loss / batches))
        print(results[-1])
results

In [None]:


# Each result is (sequence_length, batch_size, train-mode loss, test-mode loss).
# A run counts as "cheating" when its train-mode loss beats its test-mode loss.
# Runs where the comparison is undefined (NaN) land in neither group.
a_cheat, b_cheat, a_fair, b_fair = [], [], [], []
for run in results:
    train_loss, test_loss = run[2], run[3]
    if train_loss < test_loss:
        a_cheat.append(train_loss)
        b_cheat.append(test_loss)
    elif train_loss >= test_loss:
        a_fair.append(train_loss)
        b_fair.append(test_loss)

plt.scatter(a_cheat, b_cheat, label='cheating')
plt.scatter(a_fair, b_fair, label='fair')
plt.plot([0, 2], [0, 2])  # y = x reference diagonal
plt.ylim([0, 1])
plt.xlim([0, 1])
plt.ylabel('test mode performance')
plt.xlabel('train mode performance')
plt.legend()
plt.show()

In [None]:
# Where in (sequence length, batch size) space the cheating runs fall. The
# black line is a hand-drawn separator between the two regions.
lengths_cheat, batch_size_cheat, lengths_fair, batch_size_fair = [], [], [], []
for run in results:
    if run[2] < run[3]:
        lengths_cheat.append(run[0])
        batch_size_cheat.append(run[1])
    elif run[2] >= run[3]:
        lengths_fair.append(run[0])
        batch_size_fair.append(run[1])

plt.scatter(lengths_cheat, batch_size_cheat, label='cheat')
plt.scatter(lengths_fair, batch_size_fair, label='fair')
plt.plot([3.5, 20.5], [1, 6], color='black')
plt.legend()
plt.ylabel('batch_size')
plt.xlabel('sequence length')
plt.savefig('cheatfair.svg')
plt.show()