# Fine-tuning on the reference dataset
In this notebook, we demonstrate fine-tuning a pre-trained CNN on the 30-isolate classification task shown in Figure 2. Here, fine-tuning adapts the pre-trained CNN to data collected with new measurement parameters. This code illustrates the procedure described in the `CNN architecture & training details` section of the Methods. For speed and clarity, this demo trains on only a single randomly selected train/validation split.

## Loading data
The first step is to load the fine-tuning dataset.

In [1]:
from time import time
t00 = time()
import numpy as np

In [2]:
# Paths to the fine-tuning spectra (X) and labels (y); edit these to point to your local copy
X_fn = 'F:/Datasets/RAMAN_data/X_finetune.npy'
y_fn = 'F:/Datasets/RAMAN_data/y_finetune.npy'
X = np.load(X_fn)
y = np.load(y_fn)
print(X.shape, y.shape)

(3000, 1000) (3000,)
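Before building the model, it can help to confirm that the arrays match what the CNN will expect: one row per spectrum with 1000 points (matching `input_dim` below) and labels 0 to 29 (matching `n_classes`). The helper `check_spectra` is hypothetical, and it is demonstrated here on random stand-in arrays of the same shape rather than on the real spectra; it assumes the labels are integer-coded starting from 0.

```python
import numpy as np

def check_spectra(X, y, input_dim=1000, n_classes=30):
    """Hypothetical sanity check for spectra X of shape (n, input_dim) and labels y of shape (n,)."""
    if X.ndim != 2 or X.shape[1] != input_dim:
        raise ValueError(f"expected X of shape (n, {input_dim}), got {X.shape}")
    if y.shape != (X.shape[0],):
        raise ValueError(f"expected y of shape ({X.shape[0]},), got {y.shape}")
    y_int = y.astype(int)
    if not np.array_equal(y_int, y):
        raise ValueError("labels must be whole numbers")
    if y_int.min() < 0 or y_int.max() >= n_classes:
        raise ValueError(f"labels must lie in [0, {n_classes})")
    # Number of spectra per class; useful for spotting class imbalance
    return np.bincount(y_int, minlength=n_classes)

# Stand-in data with the same shapes as the fine-tuning set (not the real spectra)
rng = np.random.default_rng(0)
X_demo = rng.normal(size=(3000, 1000))
y_demo = np.repeat(np.arange(30), 100)
counts = check_spectra(X_demo, y_demo)
print(counts.min(), counts.max())  # 100 100
```

On the real data, calling `check_spectra(X, y)` would raise early with a readable message if, for example, the spectra had been saved with a different number of points.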


## Loading pre-trained CNN
Now we set up a ResNet CNN and load weights that we previously trained for the 30-isolate task using the full training dataset. 

In [3]:
from resnet import ResNet
import os
import torch

In [4]:
# CNN parameters
layers = 6
hidden_size = 100
block_size = 2
hidden_sizes = [hidden_size] * layers
num_blocks = [block_size] * layers
input_dim = 1000
in_channels = 64
n_classes = 30
os.environ['CUDA_VISIBLE_DEVICES'] = '0'  # expose only GPU 0; must be set before any CUDA call
cuda = torch.cuda.is_available()

In [5]:
# Load trained weights for demo
cnn = ResNet(hidden_sizes, num_blocks, input_dim=input_dim,
                in_channels=in_channels, n_classes=n_classes)
if cuda: cnn.cuda()
cnn.load_state_dict(torch.load(
    './pretrained_model.ckpt', map_location=lambda storage, loc: storage))

<All keys matched successfully>

## Fine-tuning
Now we can fine-tune the pre-trained CNN on the fine-tuning dataset. In the experiments reported in the paper, we fine-tune across 5 randomly selected train/validation splits; here we show just one split for clarity. The loop below trains for up to 30 epochs and stops early once validation accuracy has not improved for 5 consecutive epochs. To shorten the demo, reduce `epochs` (e.g., to 1); to train the CNN to convergence, we recommend ~30 epochs.
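The paper's protocol of fine-tuning across 5 random splits could be set up roughly as follows. This is a sketch under our own assumptions: the helper name `make_splits` and the fixed seed are illustrative, and each split is drawn independently (so validation sets from different splits may overlap), which may differ from how the original experiments sampled them.

```python
import numpy as np

def make_splits(n, n_splits=5, p_val=0.1, seed=0):
    """Hypothetical helper: draw n_splits independent random train/val splits of n samples."""
    rng = np.random.default_rng(seed)  # seeded so the splits are reproducible
    n_val = int(n * p_val)
    splits = []
    for _ in range(n_splits):
        perm = rng.permutation(n)
        splits.append((perm[n_val:], perm[:n_val]))  # (train indices, val indices)
    return splits

splits = make_splits(3000)
for k, (idx_tr, idx_val) in enumerate(splits):
    print(k, len(idx_tr), len(idx_val))  # 2700 train / 300 val per split
```

Each `(idx_tr, idx_val)` pair could then be passed to `spectral_dataloader` through its `idxs` argument. Reloading the pre-trained weights before each split would keep every split starting from the same initialization.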

In [6]:
from datasets import spectral_dataloader
from training import run_epoch
from torch import optim

### Train/val split
We split the fine-tuning dataset into train and validation sets. We randomly sample 10% of the dataset to use as a validation set.
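Plain random sampling does not guarantee that every isolate is equally represented in the validation set. A stratified split is one alternative; note that it is a different technique from the random split used in this notebook. The hypothetical helper `stratified_split` below holds out about 10% of each class, shown on stand-in labels that assume 30 balanced classes.

```python
import numpy as np

def stratified_split(y, p_val=0.1, seed=0):
    """Hypothetical helper: hold out ~p_val of the samples of each class for validation."""
    rng = np.random.default_rng(seed)
    idx_tr, idx_val = [], []
    for c in np.unique(y):
        idx_c = rng.permutation(np.flatnonzero(y == c))  # shuffled indices of class c
        n_val_c = int(round(len(idx_c) * p_val))
        idx_val.extend(idx_c[:n_val_c].tolist())
        idx_tr.extend(idx_c[n_val_c:].tolist())
    return idx_tr, idx_val

y_demo = np.repeat(np.arange(30), 100)  # stand-in labels: 30 classes x 100 spectra
idx_tr, idx_val = stratified_split(y_demo)
print(len(idx_tr), len(idx_val))  # 2700 300
```

The returned training indices are grouped by class, which is harmless here because the training dataloader is created with `shuffle=True`.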

In [7]:
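p_val = 0.1
n = len(y)                 # number of spectra (3000 here)
n_val = int(n * p_val)     # 10% held out for validation
idx_tr = list(range(n))
np.random.shuffle(idx_tr)  # random permutation; call np.random.seed(...) first for a reproducible split
idx_val = idx_tr[:n_val]
idx_tr = idx_tr[n_val:]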

In [9]:
# Fine-tune CNN
epochs = 30  # Maximum number of epochs; reduce (e.g., to 1) for a quicker demo
batch_size = 10
t0 = time()
# Set up Adam optimizer
optimizer = optim.Adam(cnn.parameters(), lr=1e-3, betas=(0.5, 0.999))
# Set up dataloaders
dl_tr = spectral_dataloader(X, y, idxs=idx_tr,
    batch_size=batch_size, shuffle=True)
dl_val = spectral_dataloader(X, y, idxs=idx_val,
    batch_size=batch_size, shuffle=False)
# Fine-tune CNN for first fold
best_val = 0
no_improvement = 0
max_no_improvement = 5
print('Starting fine-tuning!')
for epoch in range(epochs):
    print(' Epoch {}: {:0.2f}s'.format(epoch+1, time()-t0))
    # Train
    acc_tr, loss_tr = run_epoch(epoch, cnn, dl_tr, cuda,
        training=True, optimizer=optimizer)
    print('  Train acc: {:0.2f}'.format(acc_tr))
    # Val
    acc_val, loss_val = run_epoch(epoch, cnn, dl_val, cuda,
        training=False, optimizer=optimizer)
    print('  Val acc  : {:0.2f}'.format(acc_val))
    # Check performance for early stopping
    if acc_val > best_val or epoch == 0:
        best_val = acc_val
        no_improvement = 0
    else:
        no_improvement += 1
    if no_improvement >= max_no_improvement:
        print('Finished after {} epochs!'.format(epoch+1))
        break

print('\n This demo was completed in: {:0.2f}s'.format(time()-t00))

Starting fine-tuning!
 Epoch 1: 0.00s


Loss: 0.309 | Acc: 89.9%: 100%|██████████| 270/270 [00:12<00:00, 21.27it/s]


  Train acc: 89.93


Loss: 0.270 | Acc: 87.3%: 100%|██████████| 30/30 [00:00<00:00, 41.21it/s]


  Val acc  : 87.33
 Epoch 2: 27.57s


Loss: 0.172 | Acc: 94.0%: 100%|██████████| 270/270 [00:12<00:00, 21.19it/s]


  Train acc: 94.00


Loss: 0.215 | Acc: 90.3%: 100%|██████████| 30/30 [00:00<00:00, 42.02it/s]


  Val acc  : 90.33
 Epoch 3: 55.05s


Loss: 0.099 | Acc: 96.7%: 100%|██████████| 270/270 [00:12<00:00, 21.08it/s]


  Train acc: 96.70


Loss: 0.202 | Acc: 91.0%: 100%|██████████| 30/30 [00:00<00:00, 40.76it/s]


  Val acc  : 91.00
 Epoch 4: 82.57s


Loss: 0.059 | Acc: 98.0%: 100%|██████████| 270/270 [00:12<00:00, 20.88it/s]


  Train acc: 97.96


Loss: 0.212 | Acc: 91.0%: 100%|██████████| 30/30 [00:00<00:00, 40.00it/s]


  Val acc  : 91.00
 Epoch 5: 110.39s


Loss: 0.036 | Acc: 99.1%: 100%|██████████| 270/270 [00:12<00:00, 21.03it/s]


  Train acc: 99.15


Loss: 0.198 | Acc: 91.3%: 100%|██████████| 30/30 [00:00<00:00, 41.67it/s]


  Val acc  : 91.33
 Epoch 6: 137.93s


Loss: 0.024 | Acc: 99.4%: 100%|██████████| 270/270 [00:12<00:00, 20.94it/s]


  Train acc: 99.41


Loss: 0.193 | Acc: 93.3%: 100%|██████████| 30/30 [00:00<00:00, 38.27it/s]


  Val acc  : 93.33
 Epoch 7: 165.81s


Loss: 0.019 | Acc: 99.5%: 100%|██████████| 270/270 [00:13<00:00, 20.04it/s] 


  Train acc: 99.52


Loss: 0.232 | Acc: 90.7%: 100%|██████████| 30/30 [00:00<00:00, 40.87it/s]


  Val acc  : 90.67
 Epoch 8: 194.93s


Loss: 0.013 | Acc: 99.8%: 100%|██████████| 270/270 [00:13<00:00, 20.37it/s] 


  Train acc: 99.78


Loss: 0.226 | Acc: 92.3%: 100%|██████████| 30/30 [00:00<00:00, 43.29it/s]


  Val acc  : 92.33
 Epoch 9: 223.18s


Loss: 0.013 | Acc: 99.7%: 100%|██████████| 270/270 [00:13<00:00, 20.11it/s] 


  Train acc: 99.70


Loss: 0.210 | Acc: 92.3%: 100%|██████████| 30/30 [00:00<00:00, 40.38it/s]


  Val acc  : 92.33
 Epoch 10: 253.45s


Loss: 0.018 | Acc: 99.6%: 100%|██████████| 270/270 [00:13<00:00, 20.00it/s]


  Train acc: 99.59


Loss: 0.192 | Acc: 92.7%: 100%|██████████| 30/30 [00:00<00:00, 39.74it/s]


  Val acc  : 92.67
 Epoch 11: 283.13s


Loss: 0.005 | Acc: 100.0%: 100%|██████████| 270/270 [00:13<00:00, 19.99it/s]


  Train acc: 99.96


Loss: 0.219 | Acc: 92.7%: 100%|██████████| 30/30 [00:00<00:00, 41.06it/s]

  Val acc  : 92.67
Finished after 11 epochs!

 This demo was completed in: 1450.66s
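The early-stopping rule in the loop above can be isolated as a small pure-Python function and checked against the validation accuracies printed in this run. The helper `early_stop_epoch` is hypothetical. Beyond the original loop, it also reports the best epoch, because the loop as written keeps the weights from the last epoch rather than the best one; saving a copy of `cnn.state_dict()` whenever validation accuracy improves would be one way to recover the best model.

```python
def early_stop_epoch(val_accs, patience=5):
    """Replay the notebook's early-stopping rule on per-epoch validation accuracies.

    Returns (stop_epoch, best_epoch), both 1-indexed.
    """
    best_val, best_epoch, no_improvement = 0.0, 1, 0
    for epoch, acc in enumerate(val_accs):
        if acc > best_val or epoch == 0:
            best_val, best_epoch, no_improvement = acc, epoch + 1, 0
        else:
            no_improvement += 1  # ties count as "no improvement", as in the loop above
        if no_improvement >= patience:
            return epoch + 1, best_epoch
    return len(val_accs), best_epoch  # ran out of epochs without triggering the stop

# Validation accuracies (%) printed in the run above
val_accs = [87.33, 90.33, 91.00, 91.00, 91.33, 93.33, 90.67, 92.33, 92.33, 92.67, 92.67]
print(early_stop_epoch(val_accs))  # (11, 6)
```

This reproduces "Finished after 11 epochs!" and shows that the best validation accuracy in this run (93.33%) was reached at epoch 6.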

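The accuracies seen here come from a single random train/validation split, so they are not necessarily representative of the results reported in the paper, which were obtained across 5 splits. If you reduce the number of epochs to shorten the demo, expect lower accuracies than training to convergence (~30 epochs) would achieve. This code demonstrates how a pre-trained CNN can be fine-tuned and evaluated using a randomly selected train/validation split.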
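When fine-tuning is repeated over several splits, the per-split validation accuracies could be summarized by their mean and sample standard deviation. The helper `summarize_accuracies` below is a hypothetical sketch using only the standard library; the input values in the usage line are placeholders, not results from this notebook or the paper.

```python
import statistics

def summarize_accuracies(accs):
    """Hypothetical helper: mean and sample standard deviation of per-split accuracies (%)."""
    if len(accs) < 2:
        raise ValueError("need at least two splits to estimate a standard deviation")
    return statistics.mean(accs), statistics.stdev(accs)

mean_acc, std_acc = summarize_accuracies([90.0, 92.0, 94.0])  # placeholder values, not results
print(f"{mean_acc:.2f} +/- {std_acc:.2f}")  # 92.00 +/- 2.00
```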