In [None]:
import numpy as np
import matplotlib.pyplot as plt

In [None]:
from google.colab import drive
drive.mount('/content/drive')

# Create the persistent Drive directory that holds every run (endpoints, curves,
# evaluations, Hessian analyses). It must match the --dir prefix used by the
# train.py / eval_curve.py / hessian_curve_analysis.py cells below; previously a
# different, unused directory was created here.
import os
drive_dir = '/content/drive/MyDrive/mode_connectivity_runs'
os.makedirs(drive_dir, exist_ok=True)
drive_dir


In [None]:
# Always clone into the VM runtime (fast, clean)
!git clone https://github.com/xenakistheo/dnn-mode-connectivity.git

import os
os.chdir("dnn-mode-connectivity")

# Optional: pull latest changes if the folder already existed (rare)
!git pull


In [None]:
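# Training hyperparameters passed to train.py in the cells below:
# learning rate   = 0.1
# momentum        = 0.9   (not passed on the command line; presumably train.py's default — TODO confirm)
# weight decay    = 3e-4
# epochs          = 80    (endpoint training; the curve-training cells use 50)

# ResNet depths to study: 8, 26, 38, 62, 116.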

Use the following naming scheme for run directories (all under `/content/drive/MyDrive/mode_connectivity_runs/`):
- Endpoints: `[model_name]_EP_[seed]`, e.g. `ResNet8_EP_1`
- Connecting path: `[model_name]_CP_[pathtype]_[seed1]_[seed2]`, e.g. `ResNet8_CP_Bezier_1_2`
- Path evaluation: `[model_name]_EV_[pathtype]_[seed1]_[seed2]`, e.g. `ResNet8_EV_Bezier_1_2`
- Hessian analysis: `[model_name]_HA_[pathtype]_[seed1]_[seed2]`, e.g. `ResNet8_HA_Bezier_1_2`

In [None]:
### Train Endpoint 1
# Remember to
# 1. Set correct directory
# 2. Set correct model
!python3 train.py --dir=/content/drive/MyDrive/mode_connectivity_runs/ResNet8_EP_1 \
  --dataset=CIFAR10 --data_path=./data --transform=ResNet \
  --model=ResNet8 \
  --epochs=80 --lr=0.1 --wd=3e-4 --save_freq=40 --use_test \
  --seed=1

In [None]:
### Train Endpoint 2
# Remember to
# 1. Set correct directory
# 2. Set correct model
!python3 train.py --dir=/content/drive/MyDrive/mode_connectivity_runs/ResNet8_EP_2 \
        --dataset=CIFAR10 --data_path=./data --transform=ResNet \
        --model=ResNet8 \
        --epochs=80 --lr=0.1 --wd=3e-4 --save_freq=40 --use_test \
        --seed=2

In [None]:
### Find Connecting Path - Bezier
# Remember to
# 1. Set correct directory
# 2. Set correct models
# Set correct endpoint paths
!python3 train.py --dir=/content/drive/MyDrive/mode_connectivity_runs/ResNet8_CP_Bezier_1_2 \
        --dataset=CIFAR10 --use_test --transform=ResNet --data_path=./data \
        --model=ResNet8 \
        --curve=Bezier \
        --num_bends=3 \
        --init_start=/content/drive/MyDrive/mode_connectivity_runs/ResNet8_EP_1/checkpoint-80.pt \
        --init_end=/content/drive/MyDrive/mode_connectivity_runs/ResNet8_EP_2/checkpoint-80.pt \
        --fix_start --fix_end --epochs=50 --lr=0.1 --wd=3e-4

In [None]:
### Find Connecting Path - PolyChain
# Remember to
# 1. Set correct directory
# 2. Set correct models
# Set correct endpoint paths/content/drive/MyDrive/mode_connectivity_runs/PreResNet_e
!python3 train.py --dir=/content/drive/MyDrive/mode_connectivity_runs/ResNet8_CP_PolyChain_1_2 \
        --dataset=CIFAR10 --use_test --transform=ResNet --data_path=./data \
        --model=ResNet8 \
        --curve=PolyChain \
        --num_bends=3 \
        --init_start=/content/drive/MyDrive/mode_connectivity_runs/ResNet8_EP_1/checkpoint-80.pt \
        --init_end=/content/drive/MyDrive/mode_connectivity_runs/ResNet8_EP_2/checkpoint-80.pt \
        --fix_start --fix_end --epochs=50 --lr=0.1 --wd=3e-4

In [None]:
### Evaluate Path - Bezier
# Remember to
# 1. Set correct directory
# 2. Set correct models
# Set correct endpoint paths
!python3 eval_curve.py --dir=/content/drive/MyDrive/mode_connectivity_runs/ResNet8_EV_Bezier_1_2 \
        --dataset=CIFAR10 --data_path=./data --transform=ResNet \
        --model=ResNet8 \
        --wd=3e-4 \
        --curve=Bezier \
        --num_bends=3 \
        --ckpt=/content/drive/MyDrive/mode_connectivity_runs/ResNet8_CP_Bezier_1_2/checkpoint-50.pt\
        --num_points=61 --use_test

In [None]:
### Evaluate Path - PolyChain
# Remember to
# 1. Set correct directory
# 2. Set correct models
# Set correct endpoint paths
!python3 eval_curve.py --dir=/content/drive/MyDrive/mode_connectivity_runs/ResNet8_EV_PolyChain_1_2 \
        --dataset=CIFAR10 --data_path=./data --transform=ResNet \
        --model=ResNet8 \
        --wd=3e-4 \
        --curve=PolyChain \
        --num_bends=3 \
        --ckpt=/content/drive/MyDrive/mode_connectivity_runs/ResNet8_CP_PolyChain_1_2/checkpoint-50.pt \
        --num_points=61 --use_test

In [None]:
### Perform Hessian Analysis
# Remember to
# 1. Set correct directory
# 2. Set correct models
# Set correct endpoint paths

!python3 hessian_curve_analysis.py --curve_dir=/content/drive/MyDrive/mode_connectivity_runs/ResNet8_HA_Bezier_1_2 \
        --curve_ckpt=/content/drive/MyDrive/mode_connectivity_runs/ResNet8_CP_Bezier_1_2/checkpoint-50.pt \
        --dataset=CIFAR10 --data_path=./data \
        --model=ResNet8 \
        --curve=Bezier \
        --num_bends=3 --batch_size=128 --use_test


----

In [None]:
### Convex Combination Evaluation 
## - Measure performance along the linear path between the two endpoints
# Remember to
# 1. Set correct model
# 2. Set correct endpoint paths
# 3. Set correct output csv file name

!python convex_combo.py \
    --ckpt_a=./runs/ResNet8_EP_1/checkpoint-80.pt \
    --ckpt_b=./runs/ResNet8_EP_2/checkpoint-80.pt \
    --model=ResNet8 \
    --steps=31 \
    --recompute_bn \
    --data_path=./data \
    --save_csv ResNet8_convex.csv