In [1]:
import numpy as np
import torch, torch.nn as nn
import torch.nn.functional as F
import torchvision
from torchvision import transforms
from torchvision.utils import make_grid

from datetime import datetime
from torch.utils.tensorboard import SummaryWriter

import time
from tqdm import tqdm

import matplotlib.pyplot as plt
%matplotlib inline

import os

# !pip install einops
# !wget -O utils.py https://drive.google.com/uc?id=1O3QWBKrqA7s8nIGzhKMIz-YNK1-jzwml
# import utils
# from utils import plot_loss_and_accuracy

In [2]:
# Detect if we are in Google Colaboratory
try:
    import google.colab
    IN_COLAB = True
except ImportError:
    IN_COLAB = False

from pathlib import Path
# Determine the locations of auxiliary libraries and datasets.
# `AUX_DATA_ROOT` is where 'notmnist.py', 'animation.py' and 'tiny-imagenet-2020.zip' are.
if IN_COLAB:
    google.colab.drive.mount("/content/drive")
    
    # Change this if you created the shortcut in a different location
    # AUX_DATA_ROOT = Path("/content/drive/My Drive/mds20_cohortney")
    
    # assert AUX_DATA_ROOT.is_dir(), "Have you forgot to 'Add a shortcut to Drive'?"
# else:
AUX_DATA_ROOT = Path(".")

Mounted at /content/drive


In [91]:
 def make_model():
    model = nn.Sequential(
        nn.Conv1d(in_channels=30, out_channels=50, kernel_size=3, padding=1), # [1, 50, 40]
        nn.BatchNorm1d(50),
        nn.ReLU(),
        nn.MaxPool1d(kernel_size=2),   # [1, 50, 20]

        nn.Flatten(),                  # [1, 1000]
        nn.Linear(1000, 500),          # [1, 500]
        nn.Dropout(0.3),
        nn.Linear(500, 10),            # [1, 10]
        nn.LogSoftmax(dim=1)
    )
    return model

In [71]:
def compute_accuracy(logits, y_true, device='cuda:0'):
    y_pred = torch.argmax(logits, dim=1)
    return (y_pred == y_true.to(device)).float().mean()

In [72]:
def accuracy(model, images, labels):
    logits = model.forward(images)
    y_pred = torch.argmax(logits, dim=1)
    return (y_pred == labels).float().mean()

In [73]:
def iterate_minibatches(inputs, targets, batchsize, shuffle=False):
    assert len(inputs) == len(targets)
    if shuffle:
        indices = np.random.permutation(len(inputs))
    for start_idx in range(0, len(inputs) - batchsize + 1, batchsize):
        if shuffle:
            excerpt = indices[start_idx:start_idx + batchsize]
        else:
            excerpt = slice(start_idx, start_idx + batchsize)
        yield inputs[excerpt], targets[excerpt], excerpt

In [98]:
def sinkhorn(P, r, c, lmbd=0.05, eps=1e-3):
    P = -P
    Q = torch.exp(- lmbd * P).T
    Q /= torch.sum(Q)
    K, N = Q.shape
    u = torch.zeros(K)
    res = []
    res.append(-torch.sum(torch.trace(Q.T @ P.T)))
    while True:
        Q_prev = Q.clone()
        u = torch.sum(Q, dim=1)
        Q *= (c / u).unsqueeze(1)
        Q *= (r / torch.sum(Q, dim=0)).unsqueeze(0)
        res.append(-torch.sum(torch.trace(Q.T @ P.T)))
        if torch.sum(((Q-Q_prev)/Q_prev)**2) <eps:
            break
    return Q.T, res

def opt_sk(model, epoch):
    PS = np.zeros((N, K))
    
    for batch_idx, (data, _, _selected) in enumerate(iterate_minibatches(X_train, y_train, batch_size, shuffle=True)):
        data = data.to(device) # [100, 30, 40]
        p = nn.functional.softmax(model(data), 1) # [100, 10]
        PS[_selected, :] = p.detach().cpu().numpy()
    
    r = torch.ones(K)/K
    c = torch.ones(N)/N
    selflabels = sinkhorn(torch.from_numpy(PS).T, r, c)

    return selflabels

In [99]:
def train(model, selflabels, num_epochs=2, lr=0.1, weight_decay=0, exp_name='my_network'):
    # writer = SummaryWriter(f'logs/{exp_name}')
    opt = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
    scheduler = torch.optim.lr_scheduler.StepLR(opt, step_size=7, gamma=0.1)

    train_loss = []
    val_loss = []
    train_accuracy = []
    val_accuracy = []

    for epoch in range(num_epochs):
        start_time = time.time()
        model.train(True)

        train_accuracy_batch = []

        for batch_idx, (inputs, _, indexes) in enumerate(iterate_minibatches(X_train, y_train, batch_size, shuffle=True)):
            inputs = inputs.float().to(device)
            niter = epoch * N // batch_size + batch_idx
            
            with torch.no_grad():
                selflabels = opt_sk(model, epoch)
                
            inputs = inputs.to(device)
            opt.zero_grad()

            outputs = model(inputs)
            # print (selflabels)
            print (outputs.shape, selflabels)
            
            loss = nn.CrossEntropyLoss(outputs, selflabels[indexes])
           
            loss.backward()
            opt.step()
                                                    
        train_accuracy_overall = np.mean(train_accuracy_batch) * 100 # mean train accuracy over 1 epoch in %
        train_accuracy.append(train_accuracy_overall.item())

    return train_accuracy_overall, 0

In [100]:
N = 1000
K = 10

X_train = torch.Tensor(N, 30, 40) # [samples, n_steps, features]
selflabels = torch.randint(K, (N,))
y_train = torch.zeros(N) # dummy var

X_train.shape, y_train.shape, selflabels.shape, torch.unique(selflabels)

(torch.Size([1000, 30, 40]),
 torch.Size([1000]),
 torch.Size([1000]),
 tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9]))

In [101]:
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
device

'cuda:0'

In [102]:
model = make_model()
model = model.to(device)

In [103]:
train_accuracy = -1
val_accuracy = -1

DO_TRAIN = True
if DO_TRAIN:
    %%time
    lr=0.01
    exp_name = datetime.now().isoformat(timespec='seconds') + f'exp{lr}'
    batch_size = 100
    train_accuracy, val_accuracy = train(model, selflabels, num_epochs=2, weight_decay=0, lr=lr, exp_name=exp_name)

    torch.save(model.state_dict(), "checkpoint.pth")
    print (train_accuracy, val_accuracy)

CPU times: user 2 µs, sys: 0 ns, total: 2 µs
Wall time: 4.53 µs
torch.Size([100, 10]) (tensor([[9.9980e-05, 9.9983e-05, 9.9979e-05,  ..., 9.9979e-05, 9.9983e-05,
         9.9977e-05],
        [1.0010e-04, 1.0009e-04, 1.0009e-04,  ..., 1.0010e-04, 1.0009e-04,
         1.0009e-04],
        [9.9943e-05, 9.9943e-05, 9.9943e-05,  ..., 9.9938e-05, 9.9942e-05,
         9.9937e-05],
        ...,
        [1.0004e-04, 1.0003e-04, 1.0005e-04,  ..., 1.0003e-04, 1.0004e-04,
         1.0004e-04],
        [1.0002e-04, 1.0004e-04, 1.0003e-04,  ..., 1.0003e-04, 1.0003e-04,
         1.0003e-04],
        [1.0002e-04, 1.0001e-04, 1.0001e-04,  ..., 1.0001e-04, 1.0001e-04,
         1.0001e-04]], dtype=torch.float64), [tensor(0.1000, dtype=torch.float64), tensor(0.1000, dtype=torch.float64), tensor(0.1000, dtype=torch.float64)])


TypeError: ignored