In [None]:
import torch
from torch import nn
import torch.nn.functional as F
import numpy as np
from torchvision import datasets, transforms
from tqdm import tqdm
from collections import OrderedDict, deque
import matplotlib.pyplot as plt

In [None]:
# --- Experiment configuration ------------------------------------------------
batch_size = 256
n_epochs = 5

# Pick the best available backend instead of hard-coding 'mps': Apple's Metal
# backend first (the original target), then CUDA, then plain CPU. This keeps
# the notebook runnable on machines that have no MPS device.
if torch.backends.mps.is_available():
    device = 'mps'
elif torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

# Which dataset to train on: 'mnist' or 'cifar10'.
dataset = 'cifar10'

# --- Datasets ----------------------------------------------------------------
# Each branch builds the per-dataset normalisation transform and the
# train / validation splits. Only the training split requests a download.
if dataset == 'mnist':
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,)),
    ])
    dataset_train = datasets.MNIST('./mnist-dataset', train=True, download=True, transform=transform)
    dataset_val = datasets.MNIST('./mnist-dataset', train=False, transform=transform)

elif dataset == 'cifar10':
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])
    dataset_train = datasets.CIFAR10('./cifar10-dataset', train=True, download=True, transform=transform)
    dataset_val = datasets.CIFAR10('./cifar10-dataset', train=False, transform=transform)

else:
    raise ValueError('unsupported dataset')

# --- Data loaders ------------------------------------------------------------
# Training batches are reshuffled every epoch; validation order stays fixed.
dataloader_train = torch.utils.data.DataLoader(dataset_train, batch_size=batch_size, shuffle=True)
dataloader_val = torch.utils.data.DataLoader(dataset_val, batch_size=batch_size, shuffle=False)


In [None]:

class ImageClassifier(nn.Module):
    """Small convolutional classifier: two conv layers followed by an MLP head.

    The network is stored as a named ``nn.Sequential`` in ``self.net``:
    two 3x3, stride-1 convolutions (32 then 64 output channels), a 2x2 max
    pool, a flatten, and five hidden linear layers (128, 512, 128, 64, 32
    units) before the output layer. Batch norm follows every hidden linear
    layer except ``linear3``. Input sizes of the lazy layers are inferred on
    the first forward pass.

    Note that the output layer is registered under the name ``linear8``.

    Args:
        n_classes: Number of output logits.
    """

    def __init__(self, n_classes: int = 10):
        super().__init__()

        # (name, module) pairs in execution order. The names become the
        # state-dict key prefixes, so they must stay exactly as they are.
        stages = [
            ('conv1', nn.LazyConv2d(out_channels=32, kernel_size=3, stride=1)),
            ('relu1', nn.ReLU()),
            ('conv2', nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, stride=1)),
            ('relu2', nn.ReLU()),
            ('maxpool', nn.MaxPool2d(kernel_size=2)),
            ('flatten', nn.Flatten()),
            ('linear1', nn.LazyLinear(out_features=128)),
            ('bn1', nn.LazyBatchNorm1d()),
            ('relu3', nn.ReLU()),
            ('linear2', nn.LazyLinear(out_features=512)),
            ('bn2', nn.LazyBatchNorm1d()),
            ('relu4', nn.ReLU()),
            ('linear3', nn.LazyLinear(out_features=128)),
            ('relu5', nn.ReLU()),
            ('linear4', nn.LazyLinear(out_features=64)),
            ('bn3', nn.LazyBatchNorm1d()),
            ('relu6', nn.ReLU()),
            ('linear5', nn.LazyLinear(out_features=32)),
            ('bn4', nn.LazyBatchNorm1d()),
            ('relu7', nn.ReLU()),
            ('linear8', nn.LazyLinear(out_features=n_classes)),
        ]

        pipeline = nn.Sequential()
        for stage_name, stage in stages:
            pipeline.add_module(stage_name, stage)
        self.net = pipeline

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the class logits produced by ``self.net`` for ``x``."""
        logits = self.net(x)
        return logits


class LinearClassifier(nn.Module):
    """Fully connected classifier operating on flattened inputs.

    ``self.net`` is a named ``nn.Sequential``: a flatten, three hidden
    linear layers (64, 64 and 32 units) each followed by a ReLU, and an
    output layer registered as ``linear8``. All linear layers are lazy, so
    their input sizes are inferred on the first forward pass.

    Args:
        n_classes: Number of output logits.
    """

    def __init__(self, n_classes: int = 10):
        super().__init__()

        # (linear layer name, activation name, hidden width). The names are
        # part of the state-dict keys and are kept exactly as they are.
        hidden_spec = (
            ('linear1', 'relu3', 64),
            ('linear3', 'relu5', 64),
            ('linear4', 'relu7', 32),
        )

        stack = nn.Sequential()
        stack.add_module('flatten', nn.Flatten())
        for linear_name, relu_name, width in hidden_spec:
            stack.add_module(linear_name, nn.LazyLinear(width))
            stack.add_module(relu_name, nn.ReLU())
        stack.add_module('linear8', nn.LazyLinear(n_classes))
        self.net = stack

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the class logits produced by ``self.net`` for ``x``."""
        logits = self.net(x)
        return logits




In [None]:
def train_epoch(exit_early: bool, randomize_inputs: bool = False):
    """Run one training pass of ``img_classifier`` over ``dataloader_train``.

    Relies on the notebook-level ``img_classifier``, ``optimizer``,
    ``dataloader_train`` and ``device``. Each batch is moved to ``device``,
    used for one cross-entropy optimisation step, and its loss / accuracy are
    shown in the progress-bar description.

    Args:
        exit_early: If True, stop the pass as soon as the mean accuracy over
            the 20 most recent batches exceeds 0.95.
        randomize_inputs: If True, replace every image batch with standard
            normal noise of the same shape (labels are left untouched). This
            used to be applied unconditionally by a leftover debug line; it
            is now opt-in so the network trains on the real images by default.

    Returns:
        True if the pass was cut short by the early-exit criterion,
        otherwise False.
    """
    window = 20          # number of recent batches averaged for early exit
    target_acc = 0.95    # mean accuracy that triggers the early exit

    img_classifier.train()
    is_finished = False
    recent_accs = deque(maxlen=window)
    t = tqdm(dataloader_train)
    for imgs, labels in t:
        imgs, labels = imgs.to(device=device), labels.to(device=device)

        if randomize_inputs:
            # Ablation only: feed noise instead of the actual images.
            imgs = torch.randn_like(imgs)

        optimizer.zero_grad()
        out = img_classifier(imgs)
        loss = F.cross_entropy(out, labels)
        loss.backward()
        optimizer.step()

        # Fraction of samples in this batch whose arg-max logit is the label.
        acc = (torch.argmax(out, dim=1) == labels).float().mean().item()
        t.set_description(f'train - loss: {round(loss.item(), 2)}, acc: {round(acc, 2)}')

        recent_accs.append(acc)
        # Only judge once the window is full, so a few lucky batches at the
        # start cannot trigger the exit.
        if exit_early and len(recent_accs) == window and np.mean(recent_accs) > target_acc:
            is_finished = True
            break

    return is_finished

def val_epoch():
    """Run one gradient-free pass of ``img_classifier`` over ``dataloader_val``.

    Relies on the notebook-level ``img_classifier``, ``dataloader_val`` and
    ``device``. The model is switched to eval mode, and the cross-entropy
    loss and accuracy of each batch are written to the progress-bar
    description. Nothing is accumulated or returned.
    """
    img_classifier.eval()
    progress = tqdm(dataloader_val)
    with torch.no_grad():
        for batch_imgs, batch_labels in progress:
            batch_imgs = batch_imgs.to(device=device)
            batch_labels = batch_labels.to(device=device)

            logits = img_classifier(batch_imgs)
            batch_loss = F.cross_entropy(logits, batch_labels)

            # Fraction of samples whose arg-max logit equals the label.
            hits = logits.argmax(dim=1) == batch_labels
            batch_acc = hits.float().mean().item()

            progress.set_description(f'val - loss: {round(batch_loss.item(), 2)}, acc: {round(batch_acc, 2)}')


In [None]:
# Learning-rate sweep: train one freshly initialised network per learning
# rate and keep every trained network in `nets`, in the same order as
# `learning_rates` (later cells pair them up by index).
nets = []
# Commented-out candidates from earlier runs: 5e-1 at the top end and
# 5e-6, 1e-6, 5e-7 at the bottom end.
learning_rates = [1e-1, 5e-2, 1e-2, 5e-3, 1e-3, 5e-4, 1e-4, 5e-5, 1e-5]
for lr in learning_rates:
    # Re-seed before each run so every learning rate starts from the same
    # RNG state and the runs differ only in `lr`.
    torch.manual_seed(0)
    print('current lr:', lr)

    # TODO: revert to the convolutional model once the linear baseline is done:
    #     img_classifier = ImageClassifier().to(device=device)
    img_classifier = LinearClassifier().to(device=device)
    # Alternative optimizer kept for reference:
    #     optimizer = torch.optim.SGD(img_classifier.parameters(), lr=lr, momentum=0.9)
    optimizer = torch.optim.Adam(img_classifier.parameters(), lr=lr)

    # Train for the configured number of epochs. An unconditional `break`
    # used to stop after the first epoch, which made `n_epochs` ineffective;
    # the loop now stops early only when train_epoch reports that it finished.
    for epoch in range(n_epochs):
        is_finished = train_epoch(exit_early=False)
        val_epoch()
        if is_finished:
            break

    nets.append(img_classifier)

In [None]:
# torch.save(nets, f'nets_{dataset}.pth')

In [None]:
# nets = torch.load(f'nets_{dataset}.pth')  # same file name as the torch.save call above

In [None]:
# State-dict keys of the first trained network that belong to the weight
# tensor of a linear layer (both 'linear' and 'weight' occur in the key).
linear_layer_names = [
    key
    for key in nets[0].state_dict()
    if 'linear' in key and 'weight' in key
]

In [None]:
# Weight statistics of every linear layer, stored as
# {layer name: {learning rate: value}}.
layer_means = {}
layer_stds = {}
for layer_name in linear_layer_names:
    layer_means[layer_name] = {}
    layer_stds[layer_name] = {}

# nets[i] is the network trained with learning_rates[i].
for i, net in enumerate(nets):
    for layer_name, param in net.named_parameters():
        if layer_name not in linear_layer_names:
            continue
        weights = param.data
        lr = learning_rates[i]
        layer_means[layer_name][lr] = weights.mean().item()
        layer_stds[layer_name][lr] = weights.std().item()


In [None]:

# One row per linear layer: |mean| of the weights on the left, their standard
# deviation on the right, both plotted against the learning rate.
#
# squeeze=False keeps `axes` two-dimensional even when there is only a single
# linear layer; without it plt.subplots returns a 1-D array in that case and
# the axes[row][col] indexing below fails.
fig, axes = plt.subplots(nrows=len(linear_layer_names), ncols=2, figsize=(14, 24), squeeze=False)

for row, layer_name in enumerate(linear_layer_names):
    ax_mean, ax_std = axes[row][0], axes[row][1]

    # Left column: absolute value of the mean weight per learning rate.
    # Learning rates are plotted as strings, i.e. as evenly spaced categories.
    lr_to_mean = layer_means[layer_name]
    x = [str(lr) for lr in lr_to_mean.keys()]
    y = np.abs(list(lr_to_mean.values()))
    ax_mean.set_title(layer_name)
    ax_mean.plot(x, y, label='abs mean')
    ax_mean.set_xlabel('learning rate')
    ax_mean.set_ylabel('abs mean of weights')
    ax_mean.legend()

    # Right column: standard deviation of the weights per learning rate.
    lr_to_stds = layer_stds[layer_name]
    x = [str(lr) for lr in lr_to_stds.keys()]
    y = list(lr_to_stds.values())
    ax_std.set_title(layer_name)
    ax_std.plot(x, y, label='std')
    ax_std.set_xlabel('learning rate')
    ax_std.set_ylabel('std of weights')
    ax_std.legend()

# Keep titles and axis labels of neighbouring rows from overlapping.
fig.tight_layout()




In [None]:
# Take the image tensor (element 0) of one training batch and move it to the
# compute device.
first_batch = next(iter(dataloader_train))
sample = first_batch[0].to(device)

# Probe the network stored at index 8 of `nets` in evaluation mode.
probe_net = nets[8]
probe_net.eval()

# Cell output: standard deviation of the outputs along dim 1 for each sample,
# averaged over the batch.
probe_net(sample).std(axis=1).mean()