In [None]:
from collections import OrderedDict
from robustness import datasets
from robustness.tools.breeds_helpers import make_living17
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torch.utils.data import sampler
import os
import torchvision.datasets as dset
import torchvision.transforms as T
import pandas as pd
# for plotting
import matplotlib.pyplot as plt
%matplotlib inline
# Notebook-wide plot defaults: a 10x8 inch figure canvas, nearest-neighbour
# interpolation when showing images, and a greyscale default colormap.
plt.rcParams.update({
    'figure.figsize': (10.0, 8.0),
    'image.interpolation': 'nearest',
    'image.cmap': 'gray',
})

from trainer_functions.cifartrainer import evaluate_cifar, train_cifar, build_cifar10
from models.ResNet import ResNetCifar as ResNet
from trainer_functions.cifarCloader import CIFARCDataset
from trainer_functions.imagenetc import ImageNetC
from trainer_functions.imagenetctrainer import build_resnet50, evaluate_imagenetc, train_imagenetc

In [None]:
# Empty results table with the columns the training cell records per run.
df = pd.DataFrame(columns=['accuracy', "lr", "state", "n_subset"])

# Dataset locations. On another machine, set the IMAGENET_DIR / BREEDS_INFO_DIR
# environment variables instead of editing these absolute local paths.
data_dir = os.environ.get("IMAGENET_DIR", '/Users/rada/.mxnet/datasets/imagenet')
info_dir = os.environ.get(
    "BREEDS_INFO_DIR",
    '/Users/rada/Documents/GitHub/BREEDS-Benchmarks/imagenet_class_hierarchy/modified',
)

# Number of target-domain images used as the fine-tuning training subset.
n_subset = 850
# Learning rates swept by the training cell (one training run per value).
lrs = [0.0005, 1e-4, 1e-5]

In [None]:
# Fine-tune a ResNet-50 on the Living-17 target subclasses once per learning
# rate in `lrs`, and save the value returned by `train_imagenetc` for each run.

# Which part of the network to fine-tune; every other parameter is frozen.
# The "state" label written to the results and the CSV filename are derived
# from this one setting, so they cannot drift from what was actually trained.
# NOTE(review): the previous version optimized only `tune_net.fc` but labelled
# the runs "first layer" and wrote imageNetC_layer1.csv. "fc" keeps the trained
# parameters unchanged; switch to "layer1" if first-layer tuning was intended.
TUNE_TARGETS = {
    "layer1": ("first layer", lambda net: net.layer1),
    "fc": ("last layer", lambda net: net.fc),
    "all": ("all layers", lambda net: net),
}
tune_target = "fc"
state_label, select_tuned_module = TUNE_TARGETS[tune_target]

SEED = 0          # fixes the data split, parameter init and batch order across lrs
N_VALID = 2000    # validation images taken right after the training subset
N_TEST = 2000     # held-out test images (loader is built but not used below)
BATCH_SIZE = 64
EPOCHS = 10

# The dataset and its split do not depend on the learning rate. Build them once
# so that every lr is trained and validated on exactly the same images.
superclasses, subclass_split, label_map = make_living17(info_dir, split="rand")
train_subclasses, test_subclasses = subclass_split
dataset_target = datasets.CustomImageNet(data_dir, test_subclasses)

split_generator = torch.Generator().manual_seed(SEED)
indices = torch.randperm(len(dataset_target), generator=split_generator)
valid_end = n_subset + N_VALID
test_end = valid_end + N_TEST

l17train_target = torch.utils.data.Subset(dataset_target, indices[:n_subset])
l17valid_target = torch.utils.data.Subset(dataset_target, indices[n_subset:valid_end])
l17test_target = torch.utils.data.Subset(dataset_target, indices[valid_end:test_end])

l17trainload_target = DataLoader(l17train_target, batch_size=BATCH_SIZE, shuffle=True)
# Evaluation order does not change accuracy, so the eval loaders are not shuffled.
l17validload_target = DataLoader(l17valid_target, batch_size=BATCH_SIZE, shuffle=False)
l17testload_target = DataLoader(l17test_target, batch_size=BATCH_SIZE, shuffle=False)

# Prefer Apple MPS, then CUDA, then CPU, so the cell also runs off a Mac.
if torch.backends.mps.is_available():
    device = torch.device("mps")
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

records = []
for lr in lrs:
    # Re-seed each run so randomly initialised weights and the training
    # loader's shuffle order are identical across learning rates.
    torch.manual_seed(SEED)

    tune_net = build_resnet50(device)

    # Freeze the whole network, then unfreeze only the module being tuned, so
    # no gradients are computed for parameters the optimizer never updates.
    tune_net.requires_grad_(False)
    tuned_module = select_tuned_module(tune_net)
    tuned_module.requires_grad_(True)

    optimizer = optim.Adam(tuned_module.parameters(), lr=lr)

    # NOTE(review): if train_imagenetc steps the scheduler once per epoch, a
    # step_size of 20 never decays the lr within EPOCHS=10. Confirm intent.
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.1)

    # Optional: resume from a checkpoint instead of the freshly built net.
    # checkpoint = torch.load('checkpoints/ckpt-2.pth')
    # tune_net.load_state_dict(checkpoint['net'])

    # Train on the Living-17 loaders built above; the old code passed the
    # undefined names `ictrainload` / `icvalidload` here.
    acc = train_imagenetc(tune_net, l17trainload_target, l17validload_target,
                          optimizer, scheduler, device=device, epochs=EPOCHS)

    records.append({'accuracy': acc, 'lr': lr, 'state': state_label, 'n_subset': n_subset})

# Build the results table once, fresh on every execution of this cell, so that
# re-running the cell does not append duplicate rows.
df = pd.DataFrame(records, columns=['accuracy', 'lr', 'state', 'n_subset'])
df.to_csv(f"imageNetC_{tune_target}.csv")
