In [1]:
import torch
import torchvision
from torch.utils.data import DataLoader
from torch.utils.data import sampler
from torchvision.datasets import CIFAR10
import torchvision.transforms as T
from torchvision.transforms.functional import rotate
import matplotlib.pyplot as plt
import torch.optim as optim
import torchvision.models as models
import torch.nn as nn

In [2]:
# Hyper-parameters and global tensor format shared by every later cell.
B = 64              # mini-batch size
NUM_TRAIN = 50000   # angles sampled; must equal the CIFAR-10 train-set size because the angle mask indexes the image tensor
NUM_SAMPLES = 100   # Monte Carlo samples per prediction in TruncatedMSE.backward
EPS = 1e-5          # small constant for numerically safe divisions
NUM_EPOCHS = 20
SIGMA = 18          # std-dev (degrees) of the Gaussian noise added to the angles
SEED = 0            # fixed so angle sampling, label noise and shuffling are reproducible
torch.manual_seed(SEED)
# Fall back to CPU so the notebook still runs on a machine without a GPU.
FORMAT = {'device': torch.device('cuda:0' if torch.cuda.is_available() else 'cpu'),
          'dtype': torch.float}

In [3]:
# Draw a ground-truth rotation angle per image, corrupt it with Gaussian label
# noise, and keep only the samples whose *noisy* angle lies inside [0, 180]
# (the truncation the training loss later has to account for).
angle_prior = torch.distributions.uniform.Uniform(0, 360)
angles = angle_prior.sample((NUM_TRAIN,)).to(**FORMAT)
label_noise = torch.normal(0, SIGMA, (NUM_TRAIN,), **FORMAT)
noised_angles = angles + label_noise
predicate = noised_angles.ge(0) & noised_angles.le(180)
# Clean angles rotate the images; noisy angles are the training targets.
real_labels, labels = angles[predicate], noised_angles[predicate]
print(labels.shape)

torch.Size([23984])


In [4]:
# Load CIFAR-10 and rotate every image that survived the truncation by its
# clean (noise-free) angle. The tensor is built directly from `dset.data`, so
# the `transform` given to CIFAR10 is NOT applied here (it only runs when an
# item is fetched via `dset[i]`); normalization is therefore applied explicitly.
CIFAR_MEAN = (0.4914, 0.4822, 0.4465)
CIFAR_STD = (0.2023, 0.1994, 0.2010)
transform = T.Compose([T.ToTensor(), T.Normalize(CIFAR_MEAN, CIFAR_STD)])
dset = CIFAR10(root='.', download=True, train=True, transform=transform)
# (N, H, W, C) -> (N, C, H, W), scaled by 1/255.
data = torch.tensor(dset.data, **FORMAT).permute(0, 3, 1, 2).div_(255)
kept = data[predicate]  # mask once instead of twice
images = torch.zeros_like(kept)
for i, (img, angle) in enumerate(zip(kept, real_labels)):
  images[i] = rotate(img, angle.item())
# Normalize after rotating so `rotate` still sees the original [0, 1] pixels.
mean = torch.tensor(CIFAR_MEAN, **FORMAT).view(1, 3, 1, 1)
std = torch.tensor(CIFAR_STD, **FORMAT).view(1, 3, 1, 1)
images.sub_(mean).div_(std)
print(images.shape)

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./cifar-10-python.tar.gz


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Extracting ./cifar-10-python.tar.gz to .

torch.Size([23984, 3, 32, 32])


In [5]:
# Shuffled mini-batches of (rotated image, noisy angle) pairs over the kept subset.
train_pairs = list(zip(images, labels))
loader = DataLoader(
    train_pairs,
    batch_size=B,
    sampler=sampler.SubsetRandomSampler(range(len(labels))),
)

In [6]:
class Model(nn.Module):
  """Pretrained ResNet-50 with its final linear layer swapped for a
  single-output head, used to regress one rotation angle per image."""

  def __init__(self):
    super().__init__()
    self.resnet50 = models.resnet50(pretrained=True)
    # Scalar regression head in place of the original classifier.
    self.resnet50.fc = nn.Linear(self.resnet50.fc.in_features, 1)
    # Move backbone and new head together. Moving before the head swap left
    # the freshly created `fc` on the CPU.
    self.to(**FORMAT)

  def forward(self, x):
    """Return the network output for a batch `x`, shape (batch, 1)."""
    return self.resnet50(x)

In [7]:
class TruncatedMSE(torch.autograd.Function):
  """MSE loss whose backward pass uses a Monte Carlo truncated-regression gradient.

  Forward returns the plain mean squared error ``0.5 * mean((pred - targ)**2)``,
  which is only used for logging. Backward replaces ``pred`` in the MSE gradient
  with a Monte Carlo estimate of E[z | 0 <= z <= 180], where
  z = pred + NOISE_STD * N(0, 1). The interval [0, 180] is the same truncation
  used to select the training targets. Elements for which no sample lands in the
  interval get a zero gradient.

  Reads the module-level ``NUM_SAMPLES`` (samples drawn per element).
  """

  # Std-dev of the sampling noise used in backward.
  # NOTE(review): the noisy targets are generated with std SIGMA, whereas this
  # sampler uses unit std. Confirm which variance is intended before changing it.
  NOISE_STD = 1.0

  @staticmethod
  def forward(ctx, pred, targ):
    ctx.save_for_backward(pred, targ)
    return 0.5 * (pred.float() - targ.float()).pow(2).mean()

  @staticmethod
  def backward(ctx, grad_output):
    pred, targ = ctx.saved_tensors
    # NUM_SAMPLES noisy copies of pred: shape (NUM_SAMPLES, *pred.shape).
    # Works for any pred rank, unlike a hard-coded 3-D repeat.
    stacked = pred.unsqueeze(0).expand(NUM_SAMPLES, *pred.shape)
    noised = stacked + TruncatedMSE.NOISE_STD * torch.randn_like(stacked)
    # Only copies inside the truncation interval contribute to the mean.
    inside = (noised >= 0) & (noised <= 180)
    count = inside.sum(dim=0)
    cond_mean = (noised * inside).sum(dim=0) / count.clamp(min=1)
    # No accepted sample -> substitute targ, which yields a zero gradient.
    grad = torch.where(count > 0, cond_mean, targ) - targ
    # Scale like the forward's .mean() and apply the chain rule (grad_output was
    # previously ignored, breaking scaled or composed losses).
    grad_pred = (grad * grad_output / pred.numel()).to(pred.dtype)
    # Targets are data, not parameters: no gradient flows to them.
    return grad_pred, None

In [8]:
# Model, custom truncated loss, SGD optimizer and a step LR schedule
# (StepLR with its default gamma, stepping every 8 epochs).
model = Model().to(**FORMAT)
criterion = TruncatedMSE()
optimizer = optim.SGD(
    model.parameters(),
    lr=5e-6,
    momentum=0.9,
    weight_decay=5e-4,
)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=8)

Downloading: "https://download.pytorch.org/models/resnet50-19c8e357.pth" to /root/.cache/torch/hub/checkpoints/resnet50-19c8e357.pth


HBox(children=(FloatProgress(value=0.0, max=102502400.0), HTML(value='')))




In [9]:
# Training loop. After every epoch the model is evaluated on the same kept
# images against the clean angles (`real_labels`) they were actually rotated by.
# NOTE(review): this is training-set error; there is no held-out split.
LOG_EVERY = 100  # mini-batches between running-loss prints
eval_loader = DataLoader(list(zip(images, real_labels)), batch_size=5000)
for epoch in range(NUM_EPOCHS):
  running_loss = 0.0
  for t, (x, y) in enumerate(loader):
    optimizer.zero_grad()
    loss = criterion.apply(model(x), y.reshape(-1, 1))
    loss.backward()
    optimizer.step()

    running_loss += loss.item()
    if t % LOG_EVERY == LOG_EVERY - 1:
      print('[%d, %5d] loss: %.3f' % (epoch + 1, t + 1, running_loss / LOG_EVERY))
      running_loss = 0.0

  # Mean squared error against the noise-free angles.
  model.eval()
  sq_err = 0.0
  with torch.no_grad():
    for img, angle in eval_loader:
      sq_err += (model(img) - angle.reshape(-1, 1)).pow(2).sum().item()
  print('Real loss: %f' % (sq_err / len(real_labels)))
  model.train()

  scheduler.step()

[1,   100] loss: 2353.261
[1,   200] loss: 923.523
[1,   300] loss: 863.486


KeyboardInterrupt: ignored