In [1]:
# Re-implementing train.py based on PyTorch
# 
# link.
# https://github.com/tensorflow/privacy/blob/master/research/mi_lira_2021/train.py
#
# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# pylint: skip-file
# pyformat: disable

import functools
import os
import shutil
from typing import Callable
import json

import jax
import jax.numpy as jn

import numpy as np
import tensorflow as tf  # For data augmentation.
import tensorflow_datasets as tfds
from absl import app, flags

import objax
from objax.jaxboard import SummaryWriter, Summary
from objax.util import EasyDict
from objax.zoo import convnet, wide_resnet

from dataset import DataSet

############################# from MODEL ZOO
import torch

import torchvision.models as models
import torchvision.transforms as transforms

from torchvision import datasets
from torchvision.transforms import ToTensor
from torch.utils.data import DataLoader, Dataset, SubsetRandomSampler, random_split

from torch.utils.data import Subset

from torch_ema import ExponentialMovingAverage

from tqdm import tqdm

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
import os
import numpy as np

2024-10-13 10:29:20.858212: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:485] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-10-13 10:29:20.870418: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:8454] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-10-13 10:29:20.874083: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1452] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


In [2]:
# FLAGS = flags.FLAGS

#
# a temporary class for handling attributes 
#### 
FLAGS_ = {
    'arch': 'resnet50',
    'dataset': 'imagenet',
    'epochs': 20,
    'logdir': 'exp/imagenet-1k',
    'expid': 3,
    'num_experiments': 16,
    'save_steps': 5,
    'lr': 0.1,
    'weight_decay': 0.0005,
    'batch': 256,
    'seed': None,
    'pkeep': 0.5,
    'augment': 'weak',
    'only_subset': None,
    'dataset_size': 40000,
    'eval_steps': 1,
    'abort_after_epoch': None,
    'patience': None,
    'tunename': False
}

class Flags:
    def __init__(self, **entries):
        self.__dict__.update(entries)
        
FLAGS = Flags(**FLAGS_)

# TODO 

- [ ] runner interface (FLAG, ...) 
- [ ] dataset load / pre-process
- [ ] model definition
- [ ] optimization (loss function)
- [ ] training loop

In [3]:
def augment(x, shift: int, mirror=True):
    """
    Augmentation function used in training the model.
    """
    y = x['image']
    if mirror:
        y = tf.image.random_flip_left_right(y)
    y = tf.pad(y, [[shift] * 2, [shift] * 2, [0] * 2], mode='REFLECT')
    y = tf.image.random_crop(y, tf.shape(x['image']))
    return dict(image=y, label=x['label'])

In [4]:
def get_data(seed):
    """
    This is the function to generate subsets of the data for training models.

    First, we get the training dataset either from the numpy cache
    or otherwise we load it from tensorflow datasets.

    Then, we compute the subset. This works in one of two ways.

    1. If we have a seed, then we just randomly choose examples based on
       a prng with that seed, keeping FLAGS.pkeep fraction of the data.

    2. Otherwise, if we have an experiment ID, then we do something fancier.
       If we run each experiment independently then even after a lot of trials
       there will still probably be some examples that were always included
       or always excluded. So instead, with experiment IDs, we guarantee that
       after FLAGS.num_experiments are done, each example is seen exactly half
       of the time in train, and half of the time not in train.
    """
    
        
    if os.path.exists(os.path.join(FLAGS.logdir, "x_train.npy")):
        print("Loading dataset from local... ")

        train_inputs = np.load(os.path.join(FLAGS.logdir, "x_train.npy"))
        train_labels = np.load(os.path.join(FLAGS.logdir, "y_train.npy"))

        test_inputs = np.load(os.path.join(FLAGS.logdir, "x_test.npy"))
        test_labels = np.load(os.path.join(FLAGS.logdir, "y_test.npy"))
        
    else: 
        print("First time, creating dataset...")

        # TODO update to a relative ptah 
        DATA_DIR = '/serenity/scratch/psml/repo/psml/data/ILSVRC2012'

        transform = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])

        imagenet_data = datasets.ImageNet(root=DATA_DIR, split='val', transform=transform)
        
        TRAIN_TEST_RATIO = 0.8
        dataset_size = len(imagenet_data)
        train_size = int(TRAIN_TEST_RATIO * dataset_size)
        validation_size = dataset_size - train_size

        train_dataset, test_dataset = random_split(imagenet_data, [train_size, validation_size])

        train_loader = DataLoader(train_dataset, batch_size=256, shuffle=True)
        test_loader = DataLoader(test_dataset, batch_size=256, shuffle=False)

        train_inputs_, train_labels_ = [], [] 
        test_inputs_, test_labels_ = [], []

        for images_, labels_ in tqdm(train_loader, desc="Transforming Tensor to Ndarray (Train) ..."): 
            train_inputs_.append(images_)
            train_labels_.append(labels_)
        
        for images_, labels_ in tqdm(test_loader, desc="Transforming Tensor to Ndarray (Test) ..."): 
            test_inputs_.append(images_)
            test_labels_.append(labels_)

        train_inputs, train_labels = np.concatenate(train_inputs_, axis=0), np.concatenate(train_labels_, axis=0)
        test_inputs, test_labels = np.concatenate(test_inputs_, axis=0), np.concatenate(test_labels_, axis=0)
                
        train_inputs = tf.transpose(train_inputs, perm=[0, 2, 3, 1])
        test_inputs = tf.transpose(test_inputs, perm=[0, 2, 3, 1])

        print("Saving *train.npy...")
        np.save(os.path.join(FLAGS.logdir, "x_train.npy"), train_inputs)
        np.save(os.path.join(FLAGS.logdir, "y_train.npy"), train_labels)

        print("Saving *test.npy...")
        np.save(os.path.join(FLAGS.logdir, "x_test.npy"), test_inputs)
        np.save(os.path.join(FLAGS.logdir, "y_test.npy"), test_labels)
        
##
    nclass = np.max(train_labels)+1
    
    np.random.seed(seed)
    if FLAGS.num_experiments is not None:
        np.random.seed(0)
        keep = np.random.uniform(0,1,size=(FLAGS.num_experiments, FLAGS.dataset_size))
        order = keep.argsort(0)
        keep = order < int(FLAGS.pkeep * FLAGS.num_experiments)
        keep = np.array(keep[FLAGS.expid], dtype=bool)
    else:
        keep = np.random.uniform(0, 1, size=FLAGS.dataset_size) <= FLAGS.pkeep

    if FLAGS.only_subset is not None:
        keep[FLAGS.only_subset:] = 0
    
    xs = train_inputs[keep]
    ys = train_labels[keep]
    
    if FLAGS.augment == 'weak':
        aug = lambda x: augment(x, 4)
    elif FLAGS.augment == 'mirror':
        aug = lambda x: augment(x, 0)
    elif FLAGS.augment == 'none':
        aug = lambda x: augment(x, 0, mirror=False)
    else:
        raise
    
    train = DataSet.from_arrays(xs, ys, augment_fn=aug)
    train = train.cache().shuffle(8192).repeat().parse().augment().batch(FLAGS.batch)
    train = train.nchw().one_hot(nclass).prefetch(16)
    
    test = DataSet.from_arrays(test_inputs, test_labels)
    test = test.cache().parse().batch(FLAGS.batch).nchw().prefetch(16)
    
    return train, test, xs, ys, keep, nclass

In [5]:
class Trainer:
    def __init__(self, model, train_loader, test_loader, device='cuda', lr=0.01, weight_decay=1e-4, logdir='logs'):
        self.model = model.to(device)
        self.train_loader = train_loader
        self.test_loader = test_loader
        self.device = device

        self.lr = lr
        self.weight_decay = weight_decay
        self.optimizer = optim.SGD(self.model.parameters(), lr=self.lr, momentum=0.9, weight_decay=self.weight_decay)
        self.ema_model = ExponentialMovingAverage(model.parameters(), decay=0.999)
#         self.ema_decay = 0.999
        
        self.writer = SummaryWriter(FLAGS.logdir)
        
        # TODO 
        # global parameter: num_epochs
        
        
    def adjust_learning_rate(self, progress):
        lr = self.initial_lr * torch.cos(progress * (7 * torch.pi) / (2 * 8))
        lr = lr * torch.clamp(progress * 100, 0, 1)
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = lr.item()
        return lr.item()

    def train_one_epoch(self, epoch, num_epochs):
        self.model.train()
        
        running_loss = 0.0
        
#         number_of_training_examples // batch_size +1  .
#         total_batches = self.dataset_size // FLAGS.batch + 1 
        
        for batch_idx, data in enumerate(self.train_loader):
            images, labels = data['image'].to(self.device), data['label'].to(self.device)
            progress = (batch_idx + epoch * FLAGS.dataset_size) / (num_epochs * FLAGS.dataset_size)

            self.optimizer.zero_grad()
            logits = self.model(images)
            
            loss_xe = F.cross_entropy(logits, labels)
#             loss_wd = 0.5 * sum((p ** 2).sum() for p in self.model.parameters())
#             loss = loss_xe + self.weight_decay * loss_wd
            loss.backward()
            self.optimizer.step()

            # EMA update
            self.ema_model.update()

            running_loss += loss_xe.item()
            if batch_idx % 100 == 99:  # log every 100 mini-batches
                print(f'[Epoch {epoch+1}, Batch {batch_idx+1}] loss: {running_loss/100:.3f}')
                self.writer.add_scalar('training loss', running_loss / 100, epoch * len(self.train_loader) + batch_idx)
                running_loss = 0.0

    def evaluate(self, epoch):
        self.model.eval()
        
        correct = 0
        total = 0
        with torch.no_grad():
            for data in self.test_loader:
                images, labels = data['image'].to(self.device), data['label'].to(self.device)
                logits = self.model(images)
                _, predicted = torch.max(logits.data, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()

        accuracy = 100 * correct / total
        self.writer.add_scalar('eval/accuracy', accuracy, epoch)
        print(f'Accuracy of the model on test images: {accuracy:.2f}%')
        
        return accuracy

    def save_checkpoint(self, epoch, best=False):
        checkpoint_dir = 'checkpoints'
        os.makedirs(checkpoint_dir, exist_ok=True)
        checkpoint_path = os.path.join(checkpoint_dir, f'model_epoch_{epoch+1}.pth')
        if best:
            checkpoint_path = os.path.join(checkpoint_dir, 'best_model.pth')
        torch.save(self.model.state_dict(), checkpoint_path)

    def fit(self, num_epochs, patience=None, save_steps=1):
        best_acc = 0.0
        best_epoch = -1
        patience_counter = 0

        for epoch in range(num_epochs):
            # Train for one epoch
            self.train_one_epoch(epoch, num_epochs)

            # Evaluate the model
            accuracy = self.evaluate(epoch)

            # Save the best model
            if accuracy > best_acc:
                best_acc = accuracy
                best_epoch = epoch
                self.save_checkpoint(epoch, best=True)
            else:
                patience_counter += 1

            # Early stopping check
            if patience is not None and patience_counter >= patience:
                print(f"Early stopping at epoch {epoch+1}")
                break

            # Save checkpoints periodically
            if epoch % save_steps == save_steps - 1:
                self.save_checkpoint(epoch)

        print(f'Best Accuracy: {best_acc:.2f}% at epoch {best_epoch+1}')
        self.writer.close()


# Usage
# Assuming `model`, `train_loader`, and `test_loader` are defined elsewhere


In [6]:
def network(arch: str):
    # https://pytorch.org/vision/stable/models.html
    TORCHVISION_MODELS = ['resnet18', 'resnet50', 'resnet101', 'vgg16', 'vgg19', 'densenet121', 
                          'wide_resnet50_2', 'wide_resnet101_2'
                          'densenet201', 'mobilenet_v2', 'inception_v3', 
                          'efficientnet_b0', 'efficientnet_b7', 
                          'squeezenet1_0', 'alexnet', 'googlenet', 'shufflenet_v2_x1_0']
    
    # https://github.com/huggingface/pytorch-image-models
    PYTORCH_IMAGE_MODELS = ['vit_base_patch16_224', 'vit_large_patch16_224', 'deit_base_patch16_224',
                        'convnext_base', 'convnext_large']
    
    if arch in TORCHVISION_MODELS:
        return models.__dict__[arch](pretrained=True)
    elif arch in PYTORCH_IMAGE_MODELS:
        return timm.create_model(arch, pretrained=True)
    else:
        raise ValueError(f"Model {model_name} not available.")

In [7]:
# loss function

In [8]:
train_loader, test_loader, xs, ys, keep, nclass = get_data(1024)

Loading dataset from local... 


In [11]:
train_loader.data.batch_size

AttributeError: '_PrefetchDataset' object has no attribute 'batch_size'

In [9]:
m = network('resnet50')



In [10]:
trainer = Trainer(m, train_loader, test_loader, lr=0.01, weight_decay=1e-4, logdir='logs')
trainer.fit(num_epochs=20, patience=5, save_steps=5)

TypeError: object of type 'DataSet' has no len()

In [14]:
type(train_loader.data)

tensorflow.python.data.ops.prefetch_op._PrefetchDataset

In [15]:
iterator = iter(train_loader)

In [16]:
data = iterator.get_next()
data

{'image': <tf.Tensor: shape=(256, 3, 224, 224), dtype=float32, numpy=
 array([[[[-0.4054286 , -0.57667613, -0.7650484 , ..., -0.67942464,
           -0.69654936, -0.9362959 ],
          [-0.5253019 , -0.5424266 , -0.55955136, ..., -0.81642264,
           -0.9362959 , -0.69654936],
          [-0.4054286 , -0.57667613, -0.7650484 , ..., -0.67942464,
           -0.69654936, -0.9362959 ],
          ...,
          [-1.0561693 , -1.0219197 , -0.95342064, ..., -0.6451751 ,
           -0.42255333, -0.25130582],
          [-1.141793  , -1.004795  , -1.0219197 , ..., -0.7307989 ,
           -0.6109256 , -0.59380084],
          [-1.0561693 , -0.9020464 , -0.6280504 , ..., -0.79929787,
           -1.0561693 , -1.0219197 ]],
 
         [[-0.32002798, -0.565126  , -0.705182  , ..., -0.65266097,
           -0.722689  , -0.862745  ],
          [-0.565126  , -0.565126  , -0.495098  , ..., -0.792717  ,
           -0.897759  , -0.65266097],
          [-0.32002798, -0.565126  , -0.705182  , ..., -0.652660

In [17]:
len(train_loader.data)

TypeError: The dataset is infinite.

In [80]:
logits = m(torch.from_numpy(images.numpy()))

In [21]:
train_loader.image_shape

(224, 224, 3)

In [25]:
len(xs)

19917

In [81]:
logits

tensor([[-4.4196e+00, -1.2758e+00, -3.9712e+00,  ..., -1.0529e+00,
         -1.3773e+00,  6.4069e+00],
        [-2.1180e+00, -2.7036e-01, -5.4485e-03,  ..., -5.5878e-01,
          9.9498e-02, -4.4625e-01],
        [ 2.7040e+00, -8.7923e-01, -6.9519e-01,  ...,  1.2780e+00,
          2.2677e-01, -6.8925e-01],
        ...,
        [-5.6048e-01,  9.8632e-01,  5.2000e+00,  ..., -1.7108e+00,
          2.4272e+00,  1.3268e+00],
        [-2.5208e+00, -3.2628e+00,  1.6351e+00,  ...,  9.9490e-01,
          3.1441e+00,  7.1338e-01],
        [ 5.2069e-01, -5.4838e+00, -2.8739e+00,  ..., -1.2691e+00,
          2.9342e+00,  2.0972e+00]], grad_fn=<AddmmBackward0>)

In [83]:
print(type(logits))
print(logits.shape)

<class 'torch.Tensor'>
torch.Size([256, 1000])


In [29]:
# train, test, xs, ys, keep, nclass = get_data(1024)

In [65]:
m = network('resnet18')



In [68]:
type(m)

torchvision.models.resnet.ResNet

In [None]:
if __name__ == '__main__':

    flags.DEFINE_float('lr', 0.1, 'Learning rate.')
    flags.DEFINE_float('weight_decay', 0.0005, 'Weight decay ratio.')
    flags.DEFINE_integer('batch', 256, 'Batch size')
    flags.DEFINE_integer('seed', None, 'Training seed.')
    flags.DEFINE_float('pkeep', .5, 'Probability to keep examples.')
    flags.DEFINE_string('augment', 'weak', 'Strong or weak augmentation')
    flags.DEFINE_integer('only_subset', None, 'Only train on a subset of images.')
    flags.DEFINE_integer('dataset_size', 50000, 'number of examples to keep.')
    flags.DEFINE_integer('eval_steps', 1, 'how often to get eval accuracy.')
    flags.DEFINE_integer('abort_after_epoch', None, 'stop trainin early at an epoch')
    flags.DEFINE_integer('patience', None, 'Early stopping after this many epochs without progress')
    flags.DEFINE_bool('tunename', False, 'Use tune name?')

    ### override 
    # https://github.com/tensorflow/privacy/tree/d965556ebb67bd62626830339478e9ebab7ab9bd/research/mi_lira_2021/scripts  
    #     CUDA_VISIBLE_DEVICES='0' python3 -u train.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 0 --logdir exp/cifar10 &> logs/log_0

    flags.DEFINE_string('dataset', 'cifar10', 'Dataset.')
    flags.DEFINE_string('arch', 'wrn28-2', 'Model architecture.')

    flags.DEFINE_integer('epochs', 20, 'Training duration in number of epochs.')
    flags.DEFINE_integer('save_steps', 5, 'how often to get save model.')
    flags.DEFINE_integer('num_experiments', 16, 'Number of experiments')
    flags.DEFINE_integer('expid', 3, 'Experiment ID')
    flags.DEFINE_string('logdir', 'experiments', 'Directory where to save checkpoints and tensorboard data.')