In [None]:
# Single seed shared by every RNG-seeding call in the setup cell.
SEED_EXPERIMENT = 0

In [None]:
import numpy as np
import pandas as pd
import tqdm.notebook as tqdm

import torch
from torch import nn, optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

import torchvision
from torchvision import models
from models.utils import save_model

import pytorch_lightning as pl

from transforms.transform import base_transform, augmented_transform, IdentityTransform, simsiam_representation_transform

from models.encoders import get_resnet18_encoder, get_shufflenetv2_encoder, get_alexnet_encoder, SpectrumCNN

from models.lit_models import CNN_classifier, SimSiam
import os
import random


# Seed every RNG the experiment touches so runs with the same SEED_EXPERIMENT are comparable.
# NOTE: PYTHONHASHSEED is read only at interpreter start-up; setting it here does not change
# hash randomization of this kernel, only of Python subprocesses launched afterwards.
os.environ['PYTHONHASHSEED'] = str(SEED_EXPERIMENT)
# Torch RNG: CPU generator plus every visible CUDA device
# (manual_seed_all also covers the current device, so a separate cuda.manual_seed is redundant).
torch.manual_seed(SEED_EXPERIMENT)
torch.cuda.manual_seed_all(SEED_EXPERIMENT)
# NumPy global RNG and Python stdlib `random` RNG
np.random.seed(SEED_EXPERIMENT)
random.seed(SEED_EXPERIMENT)
# Seeding alone does not make cuDNN convolutions reproducible: require deterministic
# algorithms and disable autotuning, which can select different kernels from run to run.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# Mini-batch size shared by the training DataLoaders below.
batch_size = 64

# 1. Choose the dataset (run exactly one of the two cells below — a later cell overrides an earlier one)

In [None]:
# Option A: the images under data/images_amp_phase (tagged AMP_PHASE in the saved file name).
DATASET_NAME = "AMP_PHASE"
train_dataset_amp_phase = torchvision.datasets.ImageFolder(
    root='data/images_amp_phase/train',
    transform=base_transform,
)
train_dataloader = DataLoader(
    train_dataset_amp_phase,
    batch_size=batch_size,
    shuffle=True,
    num_workers=1,
)

In [None]:
# Option B: the images under data/images_spectrogram (tagged SPECTROGRAM in the saved file name).
DATASET_NAME = "SPECTROGRAM"
train_dataset_spectrogram = torchvision.datasets.ImageFolder(
    root='data/images_spectrogram/train',
    transform=base_transform,
)
train_dataloader = DataLoader(
    train_dataset_spectrogram,
    batch_size=batch_size,
    shuffle=True,
    num_workers=1,
)

# 2. Choose the model (run exactly one of the four cells below — a later cell overrides an earlier one)

In [None]:
# Encoder option: ResNet-18 (MODEL_NAME tags the saved encoder file).
encoder = get_resnet18_encoder()
MODEL_NAME = "RESNET_18"

In [None]:
# Encoder option: ShuffleNetV2 (MODEL_NAME tags the saved encoder file).
encoder = get_shufflenetv2_encoder()
MODEL_NAME = "SHUFFLENET"

In [None]:
# Encoder option: AlexNet (MODEL_NAME tags the saved encoder file).
encoder = get_alexnet_encoder()
MODEL_NAME = "ALEXNET"

In [None]:
# Encoder option: project-defined SpectrumCNN (MODEL_NAME tags the saved encoder file).
encoder = SpectrumCNN()
MODEL_NAME = "SPECTRUM_CNN"

# 3. Run every cell below: find a learning rate, pretrain with SimSiam, and save the encoder

In [None]:
# Wrap the selected encoder in the SimSiam LightningModule.
# NOTE(review): `transforms` is presumably the augmentation used to build SimSiam views — confirm in models.lit_models.
transform = simsiam_representation_transform
lit_model = SimSiam(transforms=transform, encoder=encoder)

In [None]:
# One trainer is used both for the learning-rate range test and for the pretraining run.
pretrainer = pl.Trainer(
    auto_lr_find=True,
    # Fall back to CPU instead of failing at construction when no CUDA device is present.
    gpus=1 if torch.cuda.is_available() else 0,
    max_epochs=10,
    enable_checkpointing=True,
    log_every_n_steps=10,
    # NOTE(review): deprecated since PL 1.5 and removed in 1.7; move to the logger if PL is upgraded.
    flush_logs_every_n_steps=50,
)

# Run the learning-rate range test; the raw sweep is available as `lr_finder.results`.
lr_finder = pretrainer.tuner.lr_find(lit_model, train_dataloaders=train_dataloader)

# Plot loss vs. learning rate, marking the suggested point when one exists.
fig = lr_finder.plot(suggest=True)
fig.show()

# suggestion() returns None when the finder cannot pick a point.
suggested_lr = lr_finder.suggestion()
print(suggested_lr)

In [None]:
# Learning rate used when the LR finder produces no suggestion
# (recorded as "failed(1e-3)" in the results table at the end of the notebook).
FALLBACK_LR = 1e-3

# Pick the finder's suggestion if there is one; otherwise use the fallback.
suggested_lr = lr_finder.suggestion()
used_fallback = suggested_lr is None
new_lr = FALLBACK_LR if used_fallback else suggested_lr
print(f"Pretraining LR: {new_lr}" + (" (fallback: no LR finder suggestion)" if used_fallback else ""))

# Update hparams of the model.
# NOTE(review): assumes SimSiam.configure_optimizers reads self.hparams.lr — confirm in models.lit_models.
lit_model.hparams.lr = new_lr

# Fit the model with the same trainer used for the LR search.
pretrainer.fit(lit_model, train_dataloaders=train_dataloader)

In [None]:
# Persist the pretrained encoder submodule, named by dataset and model choice.
# Create the output directory first so a fresh checkout doesn't fail on the write.
encoder_dir = "pretrained_encoders"
os.makedirs(encoder_dir, exist_ok=True)
save_model(model=lit_model.encoder, save_path=f"{encoder_dir}/{DATASET_NAME}__{MODEL_NAME}.pth")

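# Learning rates used
`amp` = AMP_PHASE dataset, `spec` = SPECTROGRAM dataset. `failed(1e-3)` means the LR finder returned no suggestion, so the fallback learning rate of 1e-3 was used.
- RESNET_18 (0, 1): amp=2.7542287033381663e-05, spec=4.786300923226385e-05
- SHUFFLENET (2, 3): amp=failed(1e-3), spec=failed(1e-3)
- ALEXNET (4, 6): amp=failed(1e-3), spec=failed(1e-3)
- SPECTRUM_CNN (7, 8): amp=failed(1e-3), spec=failed(1e-3)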