In [1]:
import torch
import numpy as np
import seaborn
import yaml
import sklearn
#import devinterp
import torch.nn as nn
import torchvision
import pandas as pd
import matplotlib.pyplot as plt
import copy
from torch import optim
from tqdm import tqdm
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from sklearn.decomposition import PCA

In [2]:
if not torch.backends.mps.is_available():
    if not torch.backends.mps.is_built():
        print("MPS not available because the current PyTorch install was not "
              "built with MPS enabled.")
    else:
        print("MPS not available because the current MacOS version is not 12.3+ "
              "and/or you do not have an MPS-enabled device on this machine.")

else:
    mps_device = torch.device("mps")

device = 'mps' # 'cuda' if torch.cuda.is_available() else 'cpu'

In [3]:
# Let's define all the hyperparameters
use_adam_op = True
augmented = True
use_label_noise = True
pytorch_default_resnet = False
on_colab = False
model_width = 64
num_classes = 10
noise_levels = [0.10, 0.15, 0.20]
batch_size = 500
lr = 0.0001
epochs = 2000
model_seed = 42
data_seed = 42

In [4]:
## ResNet18 for CIFAR
## Based on: https://gitlab.com/harvard-machine-learning/double-descent/-/blob/master/models/resnet18k.py

class PreActBlock(nn.Module):
    '''Pre-activation version of the BasicBlock.'''
    expansion = 1

    def __init__(self, in_planes, planes, stride=1, **kwargs):
        super(PreActBlock, self).__init__()
        
        self.conv1 = nn.Conv2d(
            in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(in_planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3,
                               stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        if stride != 1 or in_planes != self.expansion*planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, self.expansion*planes,
                          kernel_size=1, stride=stride, bias=False),
                          nn.BatchNorm2d(self.expansion*planes)
            )

    def forward(self, x):
        out = self.bn1(x)
        out = self.relu(out)
        shortcut = self.shortcut(out) if hasattr(self, 'shortcut') else x
        out = self.conv1(out)
        out = self.bn2(out)
        out = self.relu(out)
        out = self.conv2(out)
        out += shortcut
        return out

class PreActResNet(nn.Module):
    def __init__(self, block, num_blocks, num_classes, init_channels):
        super(PreActResNet, self).__init__()
        self.in_planes = init_channels
        c = init_channels

        self.conv1 = nn.Conv2d(3, c, kernel_size=7,
                               stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(c)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
        self.layer1 = self._make_layer(block, c, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 2*c, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 4*c, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block, 8*c, num_blocks[3], stride=2)
        self.avpool = nn.AdaptiveAvgPool2d(output_size=(1, 1))
        self.linear = nn.Linear(8*c*block.expansion, num_classes)

    def _make_layer(self, block, planes, num_blocks, stride):
        # eg: [2, 1, 1, ..., 1]. Only the first one downsamples.
        strides = [stride] + [1]*(num_blocks-1)
        layers = []
        for stride in strides:
            layers.append(block(self.in_planes, planes, stride))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x):
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.maxpool(out)
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        out = self.avpool(out)
        out = out.view(out.size(0), -1)
        out = self.linear(out)
        return out

def make_resnet18k(k, num_classes) -> PreActResNet:
    ''' Returns a ResNet18 with width parameter k. (k=64 is standard ResNet18)'''
    return PreActResNet(PreActBlock, [2, 2, 2, 2], num_classes=num_classes, init_channels=k)

In [5]:
# Let's import the CIFAR10 dataset from torchvision
transform = transforms.Compose([transforms.ToTensor()]) if not augmented else transforms.Compose([transforms.ToTensor(), transforms.RandomCrop(32, padding=4), transforms.RandomHorizontalFlip(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

if data_seed is not None:
    torch.manual_seed(data_seed)

train_set = datasets.CIFAR10(root='./data',
                             train=True,
                             download=True,
                             transform=transform)
trainloader = DataLoader(train_set,
                         batch_size=batch_size,
                         shuffle=True
                         )

test_set = datasets.CIFAR10(root='./data',
                            train=False,
                            download=True,
                            transform=transform)
testloader = DataLoader(test_set,
                        shuffle=False,
                        batch_size=batch_size
                        )

Files already downloaded and verified
Files already downloaded and verified


In [6]:
checkpoints_path = '/Users/sienkadounia/lab/ai-futures/Project/ewdd/'
label_noise_path = '/Users/sienkadounia/lab/ai-futures/Project/label_noise/'
rlcts_path = '/Users/sienkadounia/lab/checkpoints/rlcts/ewdd/'

In [12]:
pca_logits = []
to_save = list(range(0, 135)) + list(range(130, 1000, 10))+ list(range(1000, 2000, 100)) + [1999]

for point in to_save:
  logits = []
  model = make_resnet18k(model_width, num_classes)
  checkpoint= torch.load('ewdd/noise_20' +'checkpoint-with-noise'+str(point)+'.pth', map_location=device)
  model.load_state_dict(checkpoint['model_state'])
  with torch.no_grad():
      for images, labels in testloader:
          features = model(images)
          logits.extend(features.numpy())
  # Apply PCA to reduce dimensionality to 3
  pca = PCA(n_components=3)
  pca_logit = pca.fit_transform(logits)
  pca_logits.append(pca_logit)
torch.save(pca_logits, 'ewdd/pca_logits')

In [16]:
! pip install /Users/sienkadounia/lab/ai-futures/Project/devinterp
import devinterp

Processing ./devinterp
  Installing build dependencies ... [?25ldone
[?25h  Getting requirements to build wheel ... [?25ldone
[?25h  Preparing metadata (pyproject.toml) ... [?25ldone
Building wheels for collected packages: devinterp
  Building wheel for devinterp (pyproject.toml) ... [?25ldone
[?25h  Created wheel for devinterp: filename=devinterp-0.2.0-py3-none-any.whl size=36054 sha256=60214053dd85c3d3a9232489ba2899680a0ccd9af8b36337e87ece2c4012e9cf
  Stored in directory: /private/var/folders/z1/4ncpvz_901x1gm_4cv8jlcwr0000gn/T/pip-ephem-wheel-cache-y3aeaz0w/wheels/d1/77/c6/79d956127053bf417b2be69c96cf13206452a78ec4d772f07c
Successfully built devinterp
Installing collected packages: devinterp
  Attempting uninstall: devinterp
    Found existing installation: devinterp 0.2.0
    Uninstalling devinterp-0.2.0:
      Successfully uninstalled devinterp-0.2.0
Successfully installed devinterp-0.2.0


In [17]:
from devinterp.slt.forms import *

In [18]:
pca_logits = torch.load('ewdd/pca_logits')

In [19]:
cat_pca_logits = np.concatenate(pca_logits)

In [20]:
from devinterp.utils import *

In [21]:
def sigma_helper(z, sigma_early, sigma_late, sigma_interp_end, interp_range=0.2):
    sigma_interp_start = interp_range * sigma_interp_end
    if z < sigma_interp_start:
        return sigma_early
    elif z > sigma_interp_end:
        return sigma_late
    else:
        return sigma_early + (sigma_late - sigma_early) / (
            sigma_interp_end - sigma_interp_start
        ) * (z - sigma_interp_start)

def get_smoothed_pcs(
    transformed_samples,
    num_pca_components,
    early_smoothing,
    late_smoothing,
    late_smoothing_from,
):
    smoothed_pcs = []
    for pca_component_i in range(0, num_pca_components):
        print(f"Processing smoothing for PC{pca_component_i+1}")
        smoothed_pc = np.copy(transformed_samples[:, 0])
        for z in range(len(transformed_samples)):
            sigma = sigma_helper(
                z, early_smoothing, late_smoothing, late_smoothing_from
            )
            smoothed_pc[z] = gaussian_filter1d(
                transformed_samples[:, pca_component_i], sigma
            )[z]
        smoothed_pcs.append(smoothed_pc)
    return smoothed_pcs

In [22]:
n_pca_components = 3
transformed_samples = cat_pca_logits 

In [112]:
smoothed_pcs = get_smoothed_pcs(
    transformed_samples,
    n_pca_components,
    early_smoothing=1,
    late_smoothing=1,
    late_smoothing_from=1,
)

Processing smoothing for PC1
Processing smoothing for PC2
Processing smoothing for PC3


In [23]:
smoothed_pcs = torch.load('ewdd/smoothed_pcs.pt')

In [24]:
TRANSITIONS = [  # TODO automate
    (0, 38, "First Descent"),
    (38, 128, "Increase"),
    (128, 1_999, "Second Descent"),
]

In [25]:
fig = plot_essential_dynamics_grid(
    transformed_samples,
    smoothed_pcs,
    transitions=TRANSITIONS,
    marked_cusp_data=[],
    num_plotted_pca_comps=n_pca_components,
    plot_vertex_influence=True,
    plot_caustic=True,
    figsize=(8, 8 / 10 * 6),
    num_sharp_points=5,
    num_vertices=5,
    osculate_start=1,
    osculate_end_offset=5,
    osculate_skip=1,
)
plt.show()

len(colors) != len(transitions), using rainbow palette.
Number of samples: 2330000
Plotting PC1 vs PC2




: 

In [23]:
len(smoothed_pcs), len(n_pca_components)

TypeError: object of type 'int' has no len()

In [44]:
fig = plot_essential_dynamics_grid(
    transformed_samples,
    smoothed_pcs=3,
    transitions=TRANSITIONS,
    # marked_cusp_data=marked_cusp_data,
    num_plotted_pca_comps=n_pca_components,
    plot_vertex_influence=True,
    plot_caustic=True,
    figsize=(8, 8 / 10 * 6),
    num_sharp_points=5,
    num_vertices=5,
    osculate_start=1,
    osculate_end_offset=2000,
    osculate_skip=8,
)
plt.show()

len(colors) != len(transitions), using rainbow palette.


TypeError: object of type 'int' has no len()

In [48]:
pca_logits[0].shape

(10000, 3)